-- See: L. Allison. Types and classes of machine learning and data mining.
--      26th Australasian Computer Science Conference (ACSC) pp.207-215,
--      Adelaide, February 2003
--      L. Allison. Models for machine learning and data mining in
--      functional programming.      doi:10.1017/S0956796804005301
--      J. Functional Programming, 15(1), pp.15-32, Jan. 2005
-- Author: Lloyd ALLISON           lloyd at bruce cs monash edu au
--         http://www.csse.monash.edu.au/~lloyd/tildeFP/II/200309/
-- This program is free software; you can redistribute it and/or modify it
-- under the terms of the GNU General Public License (GPL) as published by
-- the Free Software Foundation; either version 2 of the License, or (at
-- your option) any later version.  This program is distributed in the hope
-- that it will be useful, but without any warranty, without even the implied
-- warranty of merchantability or fitness for a particular purpose.  See the
-- GNU General Public License for more details.  You should have received a
-- copy of the GNU General Public License with this program; if not, write to:
-- Free Software Foundation, Inc., Boston, MA 02111, USA.

module Models (module Models) where
import SM_Utilities
import SM_Classes


-- Joint Model of a pair (d1, d2) built from one Model per component.
-- The components are treated as independent: the negative log probability
-- of a pair is the sum of the components' values, and the first-part
-- message length is the sum of the two components' msg1.
bivariate (m1, m2) = MnlPr firstPart pairNlPr describe
  where
    firstPart         = msg1 m1 + msg1 m2
    pairNlPr (d1, d2) = nlPr m1 d1 + nlPr m2 d2   -- d1, d2 assumed independent
    describe ()       = "(" ++ showsPrec 0 m1 ("," ++ show m2 ++ ")")

-- Estimate a bivariate Model from a data set of pairs.  est1 is applied to
-- the first components and est2 to the second; the two results are joined
-- with bivariate.
estBivariate (est1, est2) dataSet = bivariate (est1 firsts, est2 seconds)
  where (firsts, seconds) = unzip dataSet

-- Weighted variant of estBivariate, for mixture modelling.  The same list
-- of weights is given to both component estimators.
estBivariateWeighted (est1, est2) dataSet weights =
  bivariate (est1 firsts weights, est2 seconds weights)
  where (firsts, seconds) = unzip dataSet

-- ----------------------------------------------------------------------------

-- 2003: I moved wallaceIntModel (a Model of non-negative Int, n >= 0)
-- to module Classes for "pragmatic" reasons.


quadraticIntModel =        -- a simple parameterless Model of non-neg Int, >= 0
  -- pr n = 1/((n+1)*(n+2)) for n >= 0; a negative n is an error.
  -- The first-part message length is 0: there are no parameters to state.
  -- The negative log probability is computed as log(n+1) + log(n+2), after
  -- converting n to floating point.  The integer product (n+1)*(n+2) is
  -- never formed: for a bounded integer type it could overflow, and for a
  -- huge Integer it could exceed the range of a floating-point number.
  let negLogPr n
        | n < 0     = error "n < 0"
        | otherwise = log (fromIntegral n + 1) + log (fromIntegral n + 2)
  in MnlPr 0 negLogPr (\() -> "P(n)=1/((n+1)*(n+2)), n>=0")
  -- NB. Sum[n=0..] 1/((n+1)*(n+2)) = Sum[n=0..]{ 1/(n+1) - 1/(n+2) }
  --     = 1 - 1/2 + 1/2 - 1/3 + 1/3 ... = 1,  i.e. it is a prob' dist'n.


-- Fair model on {0,1}: uniform over the two values 0 and 1, so each has
-- probability 1/2.
fifty50 = uniform 0 1                        -- 50:50 on {0,1}


uniform lo hi =                              -- Model of Enum Bounded dataSpace
  -- Every value between the two bounds (inclusive) gets the same
  -- probability.  The bounds may be given in either order.  The first-part
  -- message length is 0.
  let a     = fromEnum lo
      b     = fromEnum (hi `asTypeOf` lo)    -- lo and hi share one type
      width = (if b >= a then b - a else a - b) + 1   -- values in range
  in MPr 0 (\_ -> 1 / fromIntegral width)
         (\() -> "uniform[" ++ show lo ++ ".." ++ show hi ++ "]")


-- Model of Int from an explicit list of probabilities: value n has
-- probability ps !! n, so an n outside 0 .. length ps - 1 is an error.
-- The first-part message length is 0 (NB. MPr 0 ...).
probs2model ps = MPr 0 (ps !!) describe     --  :: [Probability] -> Model of Int
  where describe () = "mState" ++ show ps


freqs2model fs =     -- :: frequencies, [Num] -> Model of Int (i.e. MultiState)
  -- Multistate Model of Int from (possibly fractional) frequencies, one per
  -- state 0 .. length fs - 1.
  -- Taking a short-cut here, see Wallace and Boulton (1968), and
  -- Wallace and Freeman (1987) for more information.
  -- MML: msg1 + msg2 = (cost of the data by the "uninformative" method)
  -- + delta, so msg1 is obtained by subtraction and clamped at 0 from below.
  let total         = foldl (+) 0 fs
      nParams       = fromIntegral (length fs - 1)
      -- probabilities: each frequency + 0.5 (the MML estimator), normalised
      probs         = normalise (map (0.5 +) fs)
      -- cost of the data by the "uninformative" method, NB base e
      uninformative = foldl (\acc f -> acc - stirling f) (stirling (total + 1)) fs
      -- about 0.176 nits per free parameter
      delta         = nParams * (log (pi / 6) + 1) / 2
      -- msg2: length of the data given the MML probabilities
      dataLen       = negate (foldl (+) 0 (zipWith (*) fs (map log probs)))
      firstPart     = uninformative + delta - dataLen
  in MPr (if firstPart > 0 then firstPart else 0)
         (probs !!)
         (\() -> "mState" ++ show probs)


       -- list of frequencies from a series of (weights and) Enum, Bounded type
-- The result has one counter per value of the data type, in order from
-- minBound to maxBound.  Each datum adds its weight (the element at the same
-- position in weights) to its value's counter.  weights must be at least as
-- long as dataSeries (count passes `repeat 1`); a shorter list is an error.
countWeighted dataSeries weights =
  let witness  = dataSeries !! 0   -- used only for its type (asTypeOf); never evaluated
      lwb      = minBound `asTypeOf` witness
      upb      = maxBound `asTypeOf` witness
      offset   = fromEnum lwb
      typeSize = fromEnum upb - offset + 1
      -- Counters are indexed relative to minBound (as modelInt2model does),
      -- so a type whose fromEnum does not start at 0 is still counted right.
      tally  []     _      counters = counters
      tally (d:ds) (w:ws) counters = tally ds ws (inc counters (fromEnum d - offset) w)
      tally  _      []     _        = error "countWeighted: fewer weights than data"
      inc  []    n _ = error ("count, in inc, ran off end of list "++(show n))
      inc (k:ks) 0 w = (k+w) : ks            -- counter k += weight w
      inc (k:ks) n w = k : (inc ks (n-1) w)
  in tally dataSeries weights (replicate typeSize 0)

count dataSeries =   -- list of frequencies from a series of Enum, Bounded type
  -- Unweighted counts: every datum contributes weight 1.
  countWeighted dataSeries unitWeights
  where unitWeights = repeat 1  -- totalAssignment



                      -- (Bounded t, Enum t) => t -> Model of Int -> Model of t
-- Turn a Model of Int into a Model of an Enum, Bounded type t.  A value of
-- t is mapped to its position counted from minBound (fromEnum v minus
-- fromEnum minBound).  egValue only fixes the type t (via asTypeOf) and is
-- never evaluated.  msg1 and the description come from intModel.
modelInt2model egValue intModel = MPr (msg1 intModel) prOfValue describe
  where
    base        = fromEnum (minBound `asTypeOf` egValue)
    prOfValue v = pr intModel (fromEnum (v `asTypeOf` egValue) - base)
    describe () = show intModel



-- Multistate Model of an Enum, Bounded type estimated from dataSet, with
-- one weight per datum (e.g. for mixture modelling).
estMultiStateWeighted dataSet weights =    -- weights and dataSet -> Model of t
  modelInt2model typeWitness (freqs2model weightedFreqs)
  where
    typeWitness   = dataSet !! 0   -- modelInt2model only uses this for its type
    weightedFreqs = countWeighted dataSet weights

-- Multistate Model of an Enum, Bounded type estimated from dataSet, with
-- every datum counted once.
estMultiState dataSet =                    -- :: [Enum Bounded t] -> Model of t
  modelInt2model typeWitness (freqs2model freqs)
  where
    typeWitness = dataSet !! 0     -- modelInt2model only uses this for its type
    freqs       = count dataSet



-- Model of a coin from the probability prH.  probs2model gives index 0
-- probability prH and index 1 probability 1 - prH.  modelInt2model maps
-- values by position from minBound; H is presumably the first constructor
-- (the coin type is defined elsewhere), so Pr(H) = prH.
coinByPrH prH = modelInt2model H headsThenTails
  where headsThenTails = probs2model [prH, 1 - prH]

-- ------------------------------9/2002--6/2003--L.Allison--CSSE--Monash--.au--

-- A *fixed* Normal: normalModel with a first-part message length of 0.
normal eps mu sigma = normalModel 0 eps mu sigma

-- Normal (Gaussian, a.k.a. Normal) Model of Float, with first-part message
-- length m1, mean m and standard deviation s.  eps is the data measurement
-- accuracy, +/-(eps/2).  It is an error for s to be smaller than eps.
normalModel m1 eps m s
  -- need better integration, much earlier really!!!
  | s < eps   = error "normal: std-dev too small v. data measurement accuracy"
  | otherwise = MnlPr m1 (normalNlPr eps m s) describe
  where
    describe () =
      "N(" ++ show m ++ "," ++ show s ++ ")(+-" ++ show eps ++ ")"

-- Negative log probability of a datum x, measured to +/-(eps/2), under
-- N(m, s): the negative log density at x minus log eps.
normalNlPr eps m s x = negLogDensity - logWidth
  where
    negLogDensity = normalNlPrDensity m s x
    logWidth      = log eps

-- Negative natural log of the N(m, s) probability density at x:
-- ln(2 pi)/2 + ln s + ((x-m)/s)^2 / 2.
normalNlPrDensity m s x = halfLog2Pi + log s + z ** 2 / 2
  where
    halfLog2Pi = log (2 * pi) / 2
    z          = (x - m) / s                  -- standardised deviation


estNormal minMean maxMean minSigma maxSigma eps xs =
-- MML estimator of a Normal Model for the data xs.
-- Dealing with the normal (Gaussian) distribution is a good deal trickier
-- than sometimes realised.  The programmer cannot have any prior idea of the
-- ranges of the mean and standard deviation, but the guy with the data does!
-- So we require information about the ranges to be passed in.
-- This version of an estimator uses a uniform prior for the mean over
-- the given range, and 1/sigma over the given range.
--   (In some applications it is not unknown to stretch correctness and
--    to use priors based on the population (data) statistics.)
-- Priors, mean:  uniform over [minMean..maxMean]
--         sigma: 1/sigma over [minSigma..maxSigma]
-- Data is measured to +/-eps/2

-- f(x|m,s) = (1/(sqrt(2 pi) s)) exp(-sqr(x-m)/(2 sqr(s)))        pr(x) density
-- L = -ln(f( ))                                         neg log likelihood(xi)
--   = N.{ln(2 pi)/2 + ln(s)} + {SUMi (xi-m)^2}/(2 s^2)
-- 1st derivatives:
--   d L/d m = {N.m - SUMi xi}/s^2         (hence max LH est', m ~ sum/N)
--   d L/d s = N/s - {SUMi (xi-m)^2}/s^3   (hence max LH est', s^2 ~ .../N)
-- 2nd derivatives:
--   d2 L/d m d s = zero in expectation
--   d2 L/d m2    = N/s^2
--   d2 L/d s2    = -N/s^2 + 3.{SUM (xi-m)^2}/s^4  = 2.N/s^2 in expectation
-- Fisher Info = 2.N^2 / s^4      (e.g. see Ch6 s6.2 CSW's book-draft 6/2003)
-- NB. This code is deliberately straightforward, not optimised.
-- NOTE(review): with no data (n == 0) xMean is 0/0 and log n is -Infinity,
--   so the resulting message length is degenerate; supply at least one datum.

 let sqr z = z*z
     n     = fromIntegral(length xs)
     xMean = (foldl (+) 0 xs) / n                            -- naive algorithm
     -- xVariance is a *variance*.  With fewer than two data it falls back to
     -- minSigma*maxSigma, the square of the geometric mean of the sigma
     -- range, so that sqrt gives sigma = that geometric mean.
     xVariance
       | n <= 1    = minSigma * maxSigma
       | otherwise = (foldl (\a b -> a + sqr (b - mean)) 0 xs) / (n-1) -- NB n-1

     mean  = min maxMean (max minMean xMean)  -- or error if outside limits ???
     sigma = min maxSigma (max (max eps minSigma) (sqrt xVariance))

-- 1. It does not make sense to infer a standard deviation less
--    than something-like(!) the data measurement accuracy.
-- 2. A better(!) method of integrating the prior and the
--    likelihood should be used if n is small.
-- 3. Sometimes (small) data-sets are "odd", with just the one value
--    repeated in which case should you really be using the normal?

     nlPriorMean   = log(maxMean - minMean)               -- i.e. uniform prior
-- NB. INTEGRAL[a..b] 1/x = log b - log a
     nlPriorSigma  = log(log maxSigma - log minSigma)+log sigma  -- i.e 1/sigma
     nlPrior       = nlPriorMean + nlPriorSigma
-- konstant: the *2, for 2 parameters, cancels with /2 in the general msg form.
     konstant      = (1 + log(latticeK2))
-- half of log(Fisher Info) = log(2.N^2/s^4) / 2
     halfLogFisher = (log 2)/2 + log n - 2 * log sigma
-- msg1, 1st part of message, e.g. Farr 1999, or
-- e.g. http://www.csse.monash.edu.au/~lloyd/tildeMML/Notes/Fisher.html
     msg1 = nlPrior + halfLogFisher + konstant

 in normalModel msg1 eps mean sigma

-- ----------------------------------------------------------------------------

-- Smoke test: prints probabilities and message lengths for several models
-- (multistate, coins, pairs of coins, a bit model, a mixture, the Wallace
-- integer model and some normals), to be checked by eye.
test03 = do
  print "-- test03 --"
  print ("multi01 = " ++ show multi01
         ++ " = " ++ show (map (pr multi01) [0, 1, 2]))
  print ("pr coin01 H and T = " ++ show (map (pr coin01) [H, T]))
  print ("pr coin02 H and T = " ++ show (map (pr coin02) [H, T]))
  print ("pr twoCoin same   = " ++ show (map (pr twoCoin) [(H,H), (T,T)]))
  print ("msgBase  2 bitsMdl [0,1]= "
         ++ show (msgBase 2 bitsMdl [0,1]) ++ " bits")
  print ("msg1Base 2 bitsMdl      = "
         ++ show (msg1Base 2 bitsMdl) ++ " bits")
  print ("msg2Base 2 bitsMdl [0,1]= "
         ++ show (msg2Base 2 bitsMdl [0,1]) ++ " bits")
  print ("msg1Base 2 (mixture bitsMdl ...) = "
         ++ show (msg1Base 2 (mixture (Mix bitsMdl [bitsMdl, bitsMdl])))
         ++ " bits")
  print ("wallaceIntModel 0, 1, ... = "
         ++ show (map (nlPrBase 2 wallaceIntModel) [0,1,2,3,4,5,6]) ++ " bits")
  print (show n01 ++ " 0.0 1.0 = " ++ show (map (pr n01) [0.0, 1.0]))
  print (show norm1 ++ " flts1 = " ++ show (msg1 norm1) ++ "+"
         ++ show (msg2 norm1 flts1) ++ " nits")
  print (show norm2 ++ " flts2 = " ++ show (msg1 norm2) ++ "+"
         ++ show (msg2 norm2 flts2) ++ " nits")
  where
    multi01 = freqs2model [1.1, 1.1, 2.2]     -- fractional frequencies
    coin01  = coinByPrH 0.5                   -- 50:50
    coin02  = estMultiState [H,H,T]           -- 5:3, i.e. 2+0.5 : 1+0.5
    -- NB. tuples of Enum Bounded types made instances of class Enum in Utilities
    twoCoin = estMultiState [(H,H), (T,T), (H,T), (H,T), (T,H), (T,H)]
    bitsMdl = freqs2model [100,100]
    n01     = normal 0.1 0 1                  -- eps 0.1: data measured to +/- 0.05
    flts1   = [1.0, 2.0, 0.0, 1.0, -1.0, 3.0]
    flts2   = take 60 (cycle flts1)           -- just more of the same
    -- mean in [-10, 10], sigma in [0.1, 10], data measured to +/- 0.05
    norm1   = estNormal (-10.0) 10.0 0.1 10.0 0.1 flts1
    norm2   = estNormal (-10.0) 10.0 0.1 10.0 0.1 flts2
-- ----------------------------------------------------------------------------

