-- See: L. Allison. Types and classes of machine learning and data mining.
--      26th Australasian Computer Science Conference (ACSC) pp.207-215,
--      Adelaide, February 2003
--      L. Allison. Models for machine learning and data mining in
--      functional programming.      doi:10.1017/S0956796804005301
--      J. Functional Programming, 15(1), pp.15-32, Jan. 2005
-- Author: Lloyd ALLISON           lloyd at bruce cs monash edu au
--         http://www.csse.monash.edu.au/~lloyd/tildeFP/II/200309/
-- This program is free software; you can redistribute it and/or modify it
-- under the terms of the GNU General Public License (GPL) as published by
-- the Free Software Foundation; either version 2 of the License, or (at
-- your option) any later version.  This program is distributed in the hope
-- that it will be useful, but without any warranty, without even the implied
-- warranty of merchantability or fitness for a particular purpose.  See the
-- GNU General Public License for more details.  You should have received a
-- copy of the GNU General Public License with this program; if not, write to:
-- Free Software Foundation, Inc., Boston, MA 02111, USA.

module CTrees (module CTrees) where
import SM_Utilities
import SM_Classes
import Models


-- A classification tree, a.k.a. decision tree, is
-- an instance of SuperModel and FunctionModel


-- A classification tree from inputs of type inSpace to outputs of type
-- opSpace.  Constructors:
--   CTleaf   - a leaf: a model of the output space, used for every input
--              that reaches it.
--   CTfork   - a traditional fork: the message length charged for stating
--              its splitter (msg1 adds it to the tree's cost; estCTree sets
--              it to the log of the number of candidate splits), the
--              Splitter mapping an input to a branch number, and the
--              subtrees indexed by that branch number.
--   CTforkFM - a fork whose branch is chosen by a FunctionModel from inputs
--              to branch numbers (Int), plus the subtrees.
data CTreeType inSpace opSpace =
  CTleaf   (ModelType opSpace)  |
  CTfork   MessageLength (Splitter inSpace) [CTreeType inSpace opSpace]  |
  CTforkFM (FunctionModelType inSpace Int)  [CTreeType inSpace opSpace]
  -- The last option, CTforkFM, is rather speculative as of 2002, 2003.


instance SuperModel (CTreeType inSpace opSpace) where
  -- msg1 t: the cost of stating the tree t itself (test07 reports it in
  -- nits next to msg2, the cost of the data given the tree).
  -- NB. For simplicity, this costs the *structure* at 1-bit (log2) per node,
  -- which is only optimal for binary trees. See Wallace & Patrick 1993.
  -- It also assumes that a tree uses only Trad or only non-Trad (FM) forks.
  -- Subtree costs are combined with sum, a strict left fold, rather than
  -- the lazy foldl (+) 0, which only builds a chain of thunks.

  -- a leaf: structure + its output model
  msg1 (CTleaf leafModel)     = log2 + msg1 leafModel
  -- a traditional fork: structure + stated splitter cost + all subtrees
  msg1 (CTfork fnLen _ dts)   = log2 + fnLen + sum (map msg1 dts)
  -- an FM fork: structure + the mixing FunctionModel + all subtrees
  -- (pity there is no built-in class Function in Haskell)
  msg1 (CTforkFM fnMixer dts) = log2 + msg1 fnMixer + sum (map msg1 dts)


instance FunctionModel CTreeType where
  -- A leaf answers with its own model, whatever the input.
  condModel (CTleaf leafModel) _ = leafModel

  -- A traditional fork sends the input down the subtree whose index the
  -- splitter computes for that input.
  condModel (CTfork _ splitter subTrees) ip =
    condModel (subTrees !! applySplitter splitter ip) ip

  -- A FunctionModel fork (kept mainly so that the types get checked):
  -- score every branch index with the mixer's condPr and follow the
  -- highest-scoring branch, i.e. the "1-branch" method.
  -- NB. an empty subtree list makes (!!) fail.
  condModel (CTforkFM mixer subTrees) ip =
    condModel (subTrees !! argMax 0 (-1) 0 branchPrs) ip
    where
      branchPrs = [ condPr mixer ip n | n <- [0 .. length subTrees - 1] ]
      -- argMax bestIx bestPr ix prs: index of the largest element of prs,
      -- starting from a best score of -1; the strict (>) keeps the
      -- earliest index when several scores tie.
      argMax bestIx _      _  []       = bestIx
      argMax bestIx bestPr ix (pr:prs)
        | pr > bestPr = argMax ix     pr     (ix + 1) prs
        | otherwise   = argMax bestIx bestPr (ix + 1) prs


instance Show (CTreeType inSpace opSpace) where
  -- Each node prints as "{<constructor> <contents>}". A CTfork's message
  -- length is not printed; the splitter / mixer is printed with show and
  -- the subtrees as a list.
  showsPrec d (CTleaf mdl) =
    showString "{CTleaf " . showsPrec d mdl . showString "}"
  showsPrec d (CTfork _ splitter subTrees) =
    showString "{CTfork " . showString (show splitter)
      . showsPrec d subTrees . showString "}"
  showsPrec d (CTforkFM mixer subTrees) =
    showString "{CTforkFM " . showString (show mixer)
      . showsPrec d subTrees . showString "}"


                                     -- estimate a classification tree, a CTree
-- estCTree estLeafMdl splits ipSet opSet: estimate a classification tree
-- for the paired data (ipSet !! k, opSet !! k).
--   estLeafMdl - estimator of a leaf model from a list of outputs; it must
--                handle an empty list, because a split may send no data to
--                some of its parts.
--   splits     - gives the candidate Splitters for a list of inputs.
-- Greedy, 0-lookahead search: at each node compare the 1-level tree (a
-- single leaf) with every 2-level tree (one split, one leaf per part) and
-- keep whichever gives the shorter total message, msg, for that node's data.
-- If a split wins, each of its parts is grown recursively in the same way.
estCTree estLeafMdl splits ipSet opSet = search ipSet opSet
  where
    search ips ops =
      case alternatives splts leafMsg leaf [ips] [ops] of
        (CTfork fnLen sp _, ipParts, opParts) ->    -- a split won: grow it
          CTfork fnLen sp (zipWith search ipParts opParts)
        (t, _, _) -> t                              -- the single leaf wins
      where
        leaf      = CTleaf (estLeafMdl ops)         -- simplest, 1-level tree
        leafMsg   = msg (functionModel2model leaf) (zip ips ops)
                    -- this is lazy programming, not lazy evaluation!-)
        splts     = splits ips                      -- beware if very long (slow)
        -- Cost of stating which splitter a fork uses (uniform over splts).
        -- It is the same for every alternative, so compute it once here
        -- rather than re-taking length splts per alternative (quadratic).
        splitCost = log (fromIntegral (length splts))

        -- Scan the candidate splits, carrying the best message length,
        -- tree, input parts and output parts found so far.  The strict (<)
        -- keeps the earliest split on ties.  NB. valid, but probably better
        -- to bias towards "earlier" splits.
        alternatives [] _ bestCTree bestIpParts bestOpParts =
          (bestCTree, bestIpParts, bestOpParts)
        alternatives (sp:sps) bestML bestCTree bestIpParts bestOpParts
          | msgLen < bestML = alternatives sps msgLen theTree ipParts opParts
          | otherwise       =
              alternatives sps bestML bestCTree bestIpParts bestOpParts
          where
            arity    = aritySplitter sp
            partNums = map (applySplitter sp) ips   -- branch number per datum
            ipParts  = partition arity partNums ips
            opParts  = partition arity partNums ops
            -- 2-level tree: the split with one estimated leaf per part
            theTree  = CTfork splitCost sp (map (CTleaf . estLeafMdl) opParts)
            msgLen   = msg (functionModel2model theTree) (zip ips ops)

-- ------------------------------9/2002--6/2003--L.Allison--CSSE--Monash--.au--

test07 = mapM_ print                                                    -- test
  [ "-- test07 --"
  , "condPr ct01 0 0 and 1 0 = "
      ++ show (zipWith (condPr ct01) [0,1] [0,0])
  , "condPr ct02 0 H and 1 H = "
      ++ show (zipWith (condPr ct02) [0,1] [H,H])
  , "splits ips = " ++ show (splits ips)
  , "ct03 = " ++ show ct03 ++ " = "
      ++ show (msg1 ct03) ++ " nits, data|ct03 = "
      ++ show (msg2 (functionModel2model ct03) (zip ips ops)) ++ " nits"
  , "condPr ct03 [(H,H,1.0) ... ] H = "
      ++ show (zipWith (condPr ct03) ipCases (replicate 16 H))
  , "ct04 = " ++ show ct04
  ]
  where
    -- ct01: a CTforkFM whose fork is a (trivial) FunctionModel giving
    -- input n branch n with probability 1.0, and 0.0 to every other branch.
    ct01 = CTforkFM
             (FM 0 (\ip->MPr 0 (\op->if ip==op then 1.0 else 0.0) (\()->"?"))
                   (\()->"?"))
             [ CTleaf (freqs2model [2,2])     -- 1:1
             , CTleaf (freqs2model [1,3]) ]   -- 3:7

    -- ct02: a traditional fork on a 2-way identity Splitter over {0,1}.
    ct02 = CTfork 0 (Splitter 2 (\ip->ip) (\()->"id|{0,1}"))
                  [ CTleaf (coinByPrH 0.5), CTleaf (coinByPrH 0.2) ]

    -- ct03: a CTree inferred from ips/ops, using the default splits for ips
    -- (see Utilities).
    ct03 = estCTree estMultiState splits ips ops
    ipCases = [ (H, H,  1.0), (H, H,  2.0), (H, T,  1.0), (H, T,  2.0 :: Float),
                (T, H,  1.0), (T, H,  2.0), (T, T,  1.0), (T, T,  2.0),
                (H, H, -1.0), (H, H, -2.0), (H, T, -1.0), (H, T, -2.0),
                (T, H, -1.0), (T, H, -2.0), (T, T, -1.0), (T, T, -2.0) ]
    opCases = [ H, H, H, H,   T, T, T, T,
                H, H, H, H,   H, H, H, H ]   -- H iff @0 == H or @2 < 1.0
    -- 100 clean cases, then one noisy copy of ipCases with flipped outputs
    ips = take 100 (cycle ipCases) ++ ipCases
    ops = take 100 (cycle opCases) ++ [T,T,T,T, H,H,H,H, T,T,T,T, T,T,T,T]

    -- ct04: Normal (Gaussian) leaves; @0 => N(0,1) v. N(sep'n,1), and @1 is
    -- irrelevant.
    ct04 = estCTree (estNormal (-10.0) 10.0 0.1 10.0 0.1) splits
                    (take 100 (cycle b2)) (take 100 (cycle cts))
    b2 = [ (True, True),  (True, False),  (True, True),  (True, False),
           (False, True), (False, False), (False, True), (False, False) ]
    cts = base ++ map (+ separation) base
      where
        base       = [1.0, 1.0, -1.0, -1.0]
        separation = 0.9   -- play with sep'n (and with n, the 100s above)
-- ----------------------------------------------------------------------------
