-- See: L. Allison. Types and classes of machine learning and data mining.
--      26th Australasian Computer Science Conference (ACSC) pp.207-215,
--      Adelaide, February 2003
--      L. Allison. Models for machine learning and data mining in
--      functional programming.      doi:10.1017/S0956796804005301
--      J. Functional Programming, 15(1), pp.15-32, Jan. 2005
-- Author: Lloyd ALLISON           lloyd at bruce cs monash edu au
--         http://www.csse.monash.edu.au/~lloyd/tildeFP/II/200309/
-- This program is free software; you can redistribute it and/or modify it
-- under the terms of the GNU General Public License (GPL) as published by
-- the Free Software Foundation; either version 2 of the License, or (at
-- your option) any later version.  This program is distributed in the hope
-- that it will be useful, but without any warranty, without even the implied
-- warranty of merchantability or fitness for a particular purpose.  See the
-- GNU General Public License for more details.  You should have received a
-- copy of the GNU General Public License with this program; if not, write to:
-- Free Software Foundation, Inc., Boston, MA 02111, USA.

module SM_Classes (module SM_Classes) where                       -- export all
import SM_Utilities

--  SuperModel ---.--- Model dataSpace
--                |
--                |--- TimeSeries dataSpace
--                |
--                |--- FunctionModel inSpace opSpace
--                |
--                .--- ?other?

-- One-line banner identifying this module, its author and its dates.
topline :: String
topline =
  "Haskell-98: Model types & classes, L.A., CSSE, Monash, .au, "
    ++ "9/2002, 6/2003"

-- Probabilities and message lengths are plain Doubles.  The class defaults
-- relate them through the natural log (msg1 = -log prior,
-- nlPr = -log pr), so message lengths are in nats unless a ...Base variant
-- (which divides by log b) is used.
type Probability   = Double
type MessageLength = Double


-- References:
--
-- L. Allison.  Types and classes of machine learning and data mining.
-- 26th Australasian Computer Science Conference (ACSC), pp207-215, Feb. 2003.
--
-- L. Allison.  The types of models.
-- 2nd Hawaii Int. Conf. on Statistics and Related Fields (HICS), June 2003.
-- Life is tough.


class (Show sMdl) => SuperModel sMdl where                        -- SuperModel
  -- Every SuperModel has a prior probability and, equivalently, a message
  -- length  msg1 = -log prior  (natural log), the 1st part of any 2-part
  -- message.
  prior    :: sMdl -> Probability    -- prior of blah-Model
  msg1     :: sMdl -> MessageLength  -- 1st part of any 2-part message
  msg1Base :: Double -> sMdl -> MessageLength   -- msg1 in base-b units
  mixture  :: (Mixture mx, SuperModel (mx sMdl)) =>
              mx sMdl -> sMdl        -- Mixture of blah_Model -> blah_Model

  -- prior and msg1 default to one another, so an instance that defines
  -- neither of them will loop forever as soon as either is used.
  prior      = exp . negate . msg1
  msg1       = negate . log . prior
  msg1Base b = (/ log b) . msg1

  mixture _ = error "mixture not defined for SuperModel"
  -- NB. Any instance of SuperModel should define mixture (the default only
  --     raises an error) and either prior or msg1 or both.
  -- In principle there might be a SuperModel not in class Show, but
  -- in practice easier to require them all to be, even if only trivially.



class Model mdl where                                                  -- Model
  -- A Model over dataSpace assigns each datum a probability (pr) or,
  -- equivalently, a negative log (natural) probability (nlPr).
  pr   :: (mdl dataSpace) -> dataSpace -> Probability
  nlPr :: (mdl dataSpace) -> dataSpace -> MessageLength         -- neg log prob
  nlPrBase :: Double -> (mdl dataSpace) -> dataSpace -> MessageLength

  -- Message lengths for a whole data set: msg is the complete 2-part
  -- message (msg1 of the model + msg2 of the data given the model).
  msg  :: SuperModel (mdl dataSpace) =>
          (mdl dataSpace) -> [dataSpace] -> MessageLength
  msgBase  :: SuperModel (mdl dataSpace) =>
              Double -> (mdl dataSpace) -> [dataSpace] -> MessageLength
  msg2 :: (mdl dataSpace) -> [dataSpace] -> MessageLength
  msg2Base :: Double -> (mdl dataSpace) -> [dataSpace] -> MessageLength

  pr   m datum = exp (- nlPr m datum) -- NB instances define pr or nlPr or both
  nlPr m datum = - log (pr m datum)
  nlPrBase b m datum = (nlPr m datum) / (log b)

  msg        m ds = (msg1 m) + (msg2 m ds)     -- complete 2-part message
  msgBase  b m ds = (msg  m ds) / (log b)
  -- part2  dataset|m  (list|m): the sum of nlPr over the data.  Prelude sum
  -- adds left to right from 0, exactly as the former  foldl (+) 0  did, but
  -- without a hand-written lazy accumulator over a possibly long data set.
  msg2       m ds = sum (map (nlPr m) ds)
  msg2Base b m ds = (msg2 m ds) / (log b)
-- Some other properties (methods) that are possible, even desirable:
-- print (Show) yourself, generate sample data, ...


class TimeSeries tsm where                                        -- TimeSeries
  -- Given a TimeSeries, ts, of a dataSeries,  `predictors ts dataSeries'
  -- gives a list of models for each position plus one (!) in the
  -- dataSeries, each model conditional on the preceding data elements.
  predictors :: (tsm dataSpace) -> [dataSpace] -> [ModelType dataSpace]
  prs        :: (tsm dataSpace) -> [dataSpace] -> [Probability]
  nlPrs      :: (tsm dataSpace) -> [dataSpace] -> [MessageLength]
  nlPrsBase  :: Double -> (tsm dataSpace) -> [dataSpace] -> [MessageLength]
  -- Any instance of TimeSeries must at least define  predictors,
  -- and note that the output must be one longer than the dataSeries !

  -- Each datum is scored by the model predicted for its position; zipWith
  -- stops at the end of the data, so the extra final predictor is unused.
  prs tsm ds = zipWith pr models ds
    where models = predictors tsm ds

  nlPrs tsm ds = zipWith nlPr models ds
    where models = predictors tsm ds

  nlPrsBase b tsm ds = zipWith (nlPrBase b) models ds
    where models = predictors tsm ds



class FunctionModel fm where                                   -- FunctionModel
  condModel :: (fm inSpace opSpace) -> inSpace -> ModelType opSpace
  condPr    :: (fm inSpace opSpace) -> inSpace -> opSpace -> Probability
  condNlPr  :: (fm inSpace opSpace) -> inSpace -> opSpace -> MessageLength
  condNlPrBase :: Double ->
               (fm inSpace opSpace) -> inSpace -> opSpace -> MessageLength
  -- Any instance of FunctionModel must at least define  condModel;
  -- the remaining methods score an output o with the Model that condModel
  -- yields for the input i.

  condPr fm i o = pr given o               -- cond prob, pr(o|i,fm)
    where given = condModel fm i
  condNlPr fm i o = nlPr given o           -- -log pr(o|i,fm)
    where given = condModel fm i
  condNlPrBase b fm i o = condNlPr fm i o / log b


-- -----------------------------------------------------------------L.Allison--

class Mixture mx where                   -- for miscellaneous weighted mixtures
  -- The mixer: a Model over component indices (Int), i.e. the weights.
  -- The mixture instances in this module use indices 0, 1, ... in the
  -- order of  components.
  mixer      :: (SuperModel t) => mx t -> ModelType Int
  -- The component (blah-)Models being mixed.
  components :: (SuperModel t) => mx t -> [t]

-- A weighted mixture: a mixer (a Model over component indices) together
-- with the list of component models it weights.
-- NB. This is a Haskell-98 datatype context (removed in Haskell 2010; GHC
--     accepts it only with the DatatypeContexts extension).
data (SuperModel elt) =>
  MixtureType elt = Mix (ModelType Int) [elt] -- weighted averages; MixtureType

instance Mixture MixtureType where
  -- Both accessors simply project a field out of the single Mix constructor.
  mixer      mx = case mx of Mix weights _  -> weights
  components mx = case mx of Mix _ elements -> elements

instance (SuperModel elt) => SuperModel (MixtureType elt) where
  -- Model-part length of a mixture:
  --   (number of components - 1) coded with wallaceIntModel
  --   + msg1 of the mixer + msg1 of every component,
  -- added left to right in that order.  Prelude sum over the ordered list
  -- replaces the former lazy  foldl (+) ; the summation order is unchanged.
  -- NB. This assumes that the receiver must be told the number of components.
  -- NB. es must be non-empty: wallaceIntModel applies  assert (0<=)  to
  --     length es - 1.
  msg1 (Mix m es) =
    sum ((nlPr wallaceIntModel (length es - 1) + msg1 m) : map msg1 es)

                                                                   -- Printable
instance (SuperModel ct) => Show (MixtureType ct) where
  -- Renders as  {Mix <mixer> : <components>} .  The closing brace must be
  -- emitted before the continuation string (the ShowS argument).  The old
  -- definition appended it after the continuation (s++"}"), which misplaced
  -- it whenever a mixture was shown inside a larger value, e.g. a list.
  showsPrec p (Mix m es) =
    showString "{Mix " . showsPrec p m . showString " : "
      . showsPrec p es . showChar '}'
  -- ct is in Show because SuperModels are in Show

-- Could have a MixtureModel, a sub-class of Mixture and of Model,
-- ditto MixtureTimeSeries, MixtureFunctionModel, but cannot(?)
-- have MixtureType as an instance of Model & TimeSeries & FunctionModel.
-- In any case,  mixture (Mix m cs)  is in the class of its components.

-- ----------------------------------------------------------------------------
                                                        -- conversion functions

-- Model of dataSpace -> TimeSeries of dataSpace (trivial but useful): every
-- context is ignored and the next element is always predicted by m.
model2timeSeries :: ModelType dataSpace -> TimeSeriesType dataSpace
model2timeSeries m = TSM (msg1 m) (const m) (\() -> "TSM " ++ show m)

                 -- :: Model of dataSpace -> FunctionModel of inSpace dataSpace
-- The input is ignored: every input is mapped to the same Model m.
model2functionModel :: ModelType opSpace -> FunctionModelType inSpace opSpace
model2functionModel m = FM (msg1 m) (const m) (\() -> "FM " ++ show m)


             -- Model of Int -> TimeSeries of dataSpace -> Model of [dataSpace]
             -- the 1st param is for "lengths"
-- A data series is coded as its length (under lenMdl, a Model of Int)
-- followed by each element under the TimeSeries' prediction for it.
timeSeries2model1 lenMdl tsm =
  MnlPr (msg1 tsm + msg1 lenMdl) nlp desc
  where
    -- Summed left to right starting with the length cost, the same order
    -- as the former lazy  foldl (+) , via Prelude sum.
    nlp dataSeries =
      sum (nlPr lenMdl (length dataSeries) : nlPrs tsm dataSeries)
    desc () = "[" ++ showsPrec 0 lenMdl (":" ++ show tsm ++ "]")

-- As timeSeries2model1, with the length of each series stated using the
-- ``universal'' wallaceIntModel.
timeSeries2model tsm =       -- TimeSeries of dataSpace -> Model of [dataSpace]
  timeSeries2model1 lengths tsm
  where lengths = wallaceIntModel


timeSeries2functionModel tsm =  -- TimeSeries of ds -> FunctionModel of [ds] ds
  FM (msg1 tsm) nextModel (\() -> "FM " ++ show tsm)
  where
    -- predictors returns one model per position plus one, so the list is
    -- never empty; its last model predicts the element after the series.
    nextModel dataSeries = last (predictors tsm dataSeries)


             -- FunctionModel of inSpace opSpace -> Model of (inSpace, opSpace)
functionModel2model fm = MnlPr (msg1 fm) nlp desc
  where
    -- only o is coded, i.e. pr(o|i,fm), i assumed common knowledge!
    nlp (i, o) = condNlPr fm i o
    desc () = "(_,o):" ++ show fm

           -- FunctionModel of [dataSpace] dataSpace -> TimeSeries of dataSpace
-- The FunctionModel receives the TSM context directly, and a TSM context is
-- ordered last..first (most recent element at the head).
functionModel2timeSeries fm = TSM (msg1 fm) predictNext desc
  where
    predictNext = condModel fm
    desc ()     = "TSM " ++ show fm


  -- Model inSpace -> FunctionModel inSpace opSpace -> Model (inSpace, opSpace)
condition m fm =                          -- ? should the param order be v.v. ?
  MnlPr (msg1 m + msg1 fm) nlp desc
  where
    -- code i under m, then o under the Model fm gives for that i
    nlp (i, o) = nlPr m i + nlPr (condModel fm i) o
    desc () = '(' : showsPrec 0 fm ('|' : show m ++ ")")

-- ----------------------------------------------------------------------------
                                                                   -- ModelType
-- A concrete Model, in one of two forms:
--   MPr   gives the probability of a datum directly,
--   MnlPr gives its negative log probability (a message length).
-- Both carry the model's own message length (returned by msg1) and a
-- description function, called with () by show.
data ModelType dataSpace =
  MPr   MessageLength (dataSpace -> Probability)   (() -> String) |
     -- msg len,      prob fn,                      description
  MnlPr MessageLength (dataSpace -> MessageLength) (() -> String)
     -- msg len,      neg log prob fn,              description

instance SuperModel (ModelType dataSpace)  where
  msg1 mdl = case mdl of
    MPr   len _ _ -> len
    MnlPr len _ _ -> len

  -- A mixture of Models is itself a Model.  For component k the term is
  --   nlPr (mixer mx) k + nlPr component_k datum     (component + datum|component)
  -- and the terms are combined left to right with logPlus (from
  -- SM_Utilities; presumably -log of a sum of probabilities -- confirm).
  -- NB. components mx must be non-empty (the (m0:rest) pattern).
  mixture mx = MnlPr (msg1 mx) nlp (\() -> show mx)
    where
      nlp datum =
        let (m0:rest) = components mx
            weights   = mixer mx
            term k c  = nlPr weights k + nlPr c datum
            go _ []     acc = acc                          -- done
            go k (c:cs) acc = go (k+1) cs (logPlus acc (term k c))
        in go 1 rest (term 0 m0)                   -- do weighted average

instance TimeSeries ModelType where   -- Model of ds is a trivial TimeSeries ds
  -- Every position, plus the one after the end, is predicted by m itself.
  -- Note that the output list is one longer than the input  dataSeries !
  predictors m dataSeries = m : map (const m) dataSeries

instance Model ModelType where
  -- pr is given directly for MPr models instead of inheriting the class
  -- default  exp (- nlPr m datum) .  That default round-trips the stored
  -- probability through log and exp, which can perturb it by rounding.
  -- MnlPr models still use  exp (- nlPr) .
  pr   (MPr   _ p _) datum = p datum
  pr   m             datum = exp (- nlPr m datum)
  nlPr (MPr   _ p _) datum = - log (p datum)
  nlPr (MnlPr _ n _) datum = n datum

instance Show (ModelType dataSpace)  where
  -- The description is stored as a function of (), so it is built only
  -- when the model is shown.
  show mdl = case mdl of
    MPr   _ _ describe -> describe ()
    MnlPr _ _ describe -> describe ()

-- of course there can be other types that are instances of Model

-- ----------------------------------------------------------------------------
-- Logically, wallaceIntModel belongs in module Models, but defining it there
-- caused an intermittent compile error, so it is defined here instead.

-- A ``universal'' code, from the Wallace tree code (Wallace & Patrick 1993).
-- The number of full binary trees of n leaves is the (n-1)th Catalan number;
-- the nth catalan number is {[2n]C[n]}/(n+1)
-- C[0]=1, C[1]=1, C[2]=2, C[3]=5, C[4]=14, ...
--
--    n  code     cat' cum'
--    0  0        1    1
--    1  100      1    2
--    2  10100    2    4
--    3  11000
--    4  1010100  5    9
--
-- Datum n gets message length (2*k+1)*log2, i.e. a (2*k+1)-bit code, where
-- k is the index of the first cumulative Catalan number exceeding n.
--
-- The Catalan table is computed in Integer.  In Int, the intermediate
-- products overflow around the 31st term (much earlier where Int has only
-- 32 bits), which silently corrupted the cumulative table and therefore the
-- code lengths of large n.
wallaceIntModel :: ModelType Int                 -- Model of non-neg Int,  >= 0
wallaceIntModel =
  let catalans :: [Integer]           -- NB. "cached"
      catalans =
        -- central = C(2(n-1), n-1) on entry; next = C(2n, n)
        let cats central n =
              let twoN = 2*n
                  n1   = n+1
                  -- next = central * twoN * (twoN-1) `div` (n*n)  -- obvious
                  next = central * 2 * (twoN-1) `div` n            -- better
              in (next `div` n1) : (cats next n1)
        in 1 : (cats 1 1)
      cumulativeCatalans = scanl1 (+) catalans
      -- index of the first cumulative Catalan number strictly above n
      find n posn (e:es) = if n < e then posn else find n (posn+1) es
      codeLength n =
        let k = find (toInteger (assert (0<=) n)) 0 cumulativeCatalans :: Int
        in fromIntegral (2*k + 1) * log2
  in MnlPr 0 codeLength (\()->"wallaceIntModel")
-- Note that it is a non-parametric model.

-- ----------------------------------------------------------------------------
                                                              -- TimeSeriesType
-- A concrete TimeSeries: its message length, a function from the context
-- (the data seen so far) to the Model of the next element, and a
-- description function called with () by show.
data TimeSeriesType dataSpace =
  TSM MessageLength ([dataSpace] -> ModelType dataSpace) (() -> String)
               -- i.e. context   -> Model of next elt     description
               -- NB. context ordered last..first

instance SuperModel (TimeSeriesType dataSpace) where
  msg1 tsm = case tsm of TSM len _ _ -> len
  -- In every context, mix the components' next-element Models using the
  -- mixture's own mixer.
  mixture mx = TSM (msg1 mx) mixedAt (\() -> show mx)
    where
      predictWith context (TSM _ next _) = next context
      mixedAt context =
        mixture (Mix (mixer mx) (map (predictWith context) (components mx)))

instance TimeSeries TimeSeriesType where
  -- Walk the data and pass f the elements seen so far, ordered last..first.
  -- This method of definition is efficient if f examines the last few datas.
  -- Note that the output list is one longer than the input  dataSeries !
  predictors (TSM _ f _) dataSeries = go [] dataSeries
    where
      go context remaining = case remaining of
        []       -> [f context]
        (d : ds) -> f context : go (d : context) ds

instance Show (TimeSeriesType dataSpace)  where
  -- Delegate to the stored description function.
  show tsm = case tsm of TSM _ _ describe -> describe ()

-- of course there can be other types that are instances of TimeSeries
-- ----------------------------------------------------------------------------
                                                           -- FunctionModelType
-- A concrete FunctionModel: its message length, a function from an input
-- to the Model of the output, and a description function called with ()
-- by show.
data FunctionModelType inSpace opSpace =
  FM MessageLength (inSpace -> ModelType opSpace) (() -> String)

instance SuperModel (FunctionModelType inSpace opSpace) where
  msg1 fm = case fm of FM len _ _ -> len
  -- For every input, mix the components' conditional Models using the
  -- mixture's own mixer.
  mixture mx = FM (msg1 mx) mixedAt (\() -> show mx)
    where
      mixedAt inp =
        mixture (Mix (mixer mx) [condModel f inp | f <- components mx])

instance FunctionModel FunctionModelType where
  -- The stored input -> Model function is the conditional model.
  condModel fm = case fm of FM _ conditional _ -> conditional

instance Show (FunctionModelType inSpace opSpace) where
  -- Delegate to the stored description function.
  show fm = case fm of FM _ _ describe -> describe ()

-- of course there can be other types that are instances of FunctionModel
-- ------------------------------9/2002--6/2003--L.Allison--CSSE--Monash--.au--

-- Some very simple (non-exhaustive) tests, printed to stdout.
test02 :: IO ()
test02 = do
  print "-- test02 --"
  print ("pr dice4 0, 1, 2, 3   = " ++ show (map (pr dice4) values))
  print ("nlPr dice4 0, 1, 2, 3 = " ++ show (map (nlPr dice4) values))
  print ("nlPrBase 2 dice4 0, 1, 2, 3 = "
         ++ show (map (nlPrBase 2 dice4) values) ++ " bits")
  print ("msgBase  2 dice4 [0,1,2,3]  = "
         ++ show (msgBase 2 dice4 values) ++ " bits")
  print ("nlPrBase 2 (timeSeries2model dice4) [0,1,2,3] = "
         ++ show (nlPrBase 2 series values) ++ " = "
         ++ show (msg2Base 2 series [values]))
  where
    p 0 = 1/2
    p 1 = 1/4
    p 2 = 1/8
    p 3 = 1/8
    dice4  = MPr 0 p (\() -> "skewed4")   -- a very(!) simple Model
    series = timeSeries2model dice4       -- dice4 as a Model of sequences
    values = [0, 1, 2, 3]

-- ----------------------------------------------------------------------------
