-- See: L. Allison. Types and classes of machine learning and data mining.
--      26th Australasian Computer Science Conference (ACSC) pp.207-215,
--      Adelaide, February 2003
--      L. Allison. Models for machine learning and data mining in
--      functional programming.      doi:10.1017/S0956796804005301
--      J. Functional Programming, 15(1), pp.15-32, Jan. 2005
-- Author: Lloyd ALLISON           lloyd at bruce cs monash edu au
--         http://www.csse.monash.edu.au/~lloyd/tildeFP/II/200309/
-- This program is free software; you can redistribute it and/or modify it
-- under the terms of the GNU General Public License (GPL) as published by
-- the Free Software Foundation; either version 2 of the License, or (at
-- your option) any later version.  This program is distributed in the hope
-- that it will be useful, but without any warranty, without even the implied
-- warranty of merchantability or fitness for a particular purpose.  See the
-- GNU General Public License for more details.  You should have received a
-- copy of the GNU General Public License with this program; if not, write to:
-- Free Software Foundation, Inc., Boston, MA 02111, USA.

module MixtureModels (module MixtureModels) where
import SM_Utilities
import SM_Classes
import Models

-- Estimate a mixture model from a list of component estimators (one per
-- component) and a data set, by a fixed number of expectation-maximisation
-- (EM) steps:
--   random memberships -> mixture -> 20 x { memberships -> mixture }.
-- A membership is the fractional degree to which a datum "belongs to" each
-- component of the mixture.  Memberships are handled as one list of weights
-- per component; each such list is passed, with the data set, to that
-- component's estimator.

estMixture ests dataSet =  -- [estimator] -> [dataSpace] -> model of dataSpace
  mixture (iterate emStep (fitMixture randomMemberships) !! 20)
                           -- i.e. [estimator] -> estimator
 where

  -- One EM step: re-derive the memberships from the current mixture and
  -- re-fit the mixture to them.
  emStep mx = fitMixture (memberships mx)

  -- memberships | Mixture.  For each datum, the normalised list of
  -- pr(c) * pr(datum|c) over the components, c being the component's index
  -- (counting from 0).  The per-datum lists are folded into the
  -- per-component lists with prepend (from an imported module; presumably
  -- it conses each weight onto the matching component's list).
  memberships (Mix mixer components) =
    foldr (prepend . shareOut) (map (const []) components) dataSet
   where
    shareOut datum =                                     -- one datum
      normalise [ pr mixer c * pr m datum | (c, m) <- zip [0..] components ]

  -- Starting point for EM: for every datum, one pseudo-random weight in
  -- 1..10 per estimator, normalised.  The seed starts at 4321 and is
  -- advanced by prng once per weight drawn.   -- should use a better prng!
  randomMemberships = go 4321 dataSet
   where
    go _    []       = map (const []) ests
    go seed (_:rest) =                                   -- all data
      let (seed', weights) = foldl draw (seed, []) ests  -- one datum
      in prepend (normalise weights) (go seed' rest)
    -- draw one weight from the current seed, then step the seed;
    -- weights accumulate in reverse order of the estimators
    draw (s, acc) _ = (prng s, fromIntegral (1 + s `mod` 10) : acc)

  -- Models | memberships: apply each estimator to the whole data set,
  -- weighted by the corresponding component's memberships.
  fitAll []     []     = []
  fitAll (e:es) (w:ws) = e dataSet w : fitAll es ws

  -- Mixture | memberships: the mixing weights come from the total
  -- membership of each component, the components from fitAll.
  fitMixture mems = Mix (freqs2model (map sum mems))     -- weights
                        (fitAll ests mems)               -- components

-- ------------------------------9/2002--6/2003--L.Allison--CSSE--Monash--.au--

-- Demonstration: fit mixtures of one, two and three two-coin models to the
-- twoUp data.  For each mixture, print the model, its pr values for the four
-- possible pairs, and its msg values (labelled "nits") for twoUp and for
-- 100 pairs cycled from pairs4.
test06 = do
  print "-- test06 --"
  report "mix1" mix1      -- too simple
  report "mix2" mix2      -- mix2 should be the best model, for twoUp
  report "mix3" mix3      -- unnec' complex
 where
  -- the first 100 elements of the endless repetition of a list
  hundred xs = take 100 (cycle xs)

  -- (imaginary) 2-up data ...
  twoUp = hundred
    [(H,H),(H,H),(H,H),(H,H), (T,T),(T,T),(T,T),(T,T), (H,T),(T,H)]

  pairs4 = [(H,H), (H,T), (T,H), (T,T)]                         -- fair coins

  -- weighted estimator for a pair, built from two multi-state estimators
  est2coins =
    estBivariateWeighted (estMultiStateWeighted, estMultiStateWeighted)

  mix1 = estMixture [est2coins] twoUp                         -- 1:1:1:1
  mix2 = estMixture [est2coins, est2coins] twoUp              -- 4:1:1:4 +/-
  mix3 = estMixture [est2coins, est2coins, est2coins] twoUp

  -- print the model mm, labelled str, followed by its pr and msg figures
  report str mm = do
    print (str ++ " = " ++ show mm)
    print ("pr " ++ str ++ " [HH HT TH TT] = " ++ show (map (pr mm) pairs4))
    print ("msg " ++ str ++ " twoUp   = " ++ show (msg mm twoUp)
           ++ "=" ++ show (msg1 mm) ++ "+" ++ show (msg2 mm twoUp) ++ " nits")
    print ("msg " ++ str ++ " pairs4* = "
           ++ show (msg mm (hundred pairs4)) ++ " nits")

-- ----------------------------------------------------------------------------
