module MagnusTrie3 (insert, insertWith, singleton, empty, lookup)
    where

-- TODO:
--  test with lazy parameters and no buckets
--  test with strict constructors
--  test with a simple (unbalanced) binary tree (instead of Map)


import Prelude hiding (lookup)
import qualified Data.Map  as M  

-- | A trie mapping keys of type @[a]@ to values of type @v@.
--
-- Each constructor name has two parts. The first says whether a value is
-- stored for the key that ends at this node ("Just") or not ("Nothing").
-- The second names the shape of the children: none ("Nothing"), a
-- branching node ordered on a pivot symbol ("Just"), or a single outgoing
-- symbol ("Simple"). All fields are strict.
data Trie a v
    = NothingNothingNode
      -- ^ No value and no children: the empty trie.
    | JustNothingNode !v
      -- ^ A value for the key ending here, and no children.
    | NothingJustNode !a !(Trie a v) !(Trie a v) !(Trie a v)
      -- ^ No value. The next key symbol is compared with the pivot @a@.
      -- Smaller continues in the first subtrie (whole key kept). Equal
      -- consumes the symbol and continues in the second. Greater continues
      -- in the third (whole key kept).
    | JustJustNode !v !a !(Trie a v) !(Trie a v) !(Trie a v)
      -- ^ Like 'NothingJustNode', but also stores a value for the key
      -- ending here.
    | NothingSimpleNode !a !(Trie a v)
      -- ^ No value, and a single outgoing symbol @a@ leading to the subtrie.
    | JustSimpleNode !v !a !(Trie a v)
      -- ^ Like 'NothingSimpleNode', but also stores a value for the key
      -- ending here.
    deriving (Show)

-- | The trie containing no keys; 'lookup' returns 'Nothing' for every key.
empty :: (Ord a)=> Trie a v
empty = NothingNothingNode

-- | Build a trie holding exactly one key/value pair.
--
-- Each symbol of the key becomes one 'NothingSimpleNode' link in a single
-- chain. The value sits in the 'JustNothingNode' at the end of the chain.
-- An empty key yields just that final node.
singleton :: (Ord a)=> [a] -> v -> Trie a v
singleton key value = foldr NothingSimpleNode (JustNothingNode value) key

-- | Look up the value stored under a key.
--
-- Returns 'Nothing' when the key was never inserted. This includes the case
-- where the key is only a prefix of stored keys and was not stored itself.
lookup :: Ord a => [a] -> Trie a v -> Maybe v
lookup [] node = storedValue node
  where
    -- The key is exhausted here, so only the "Just*" constructors carry an answer.
    storedValue (JustNothingNode x)      = Just x
    storedValue (JustJustNode x _ _ _ _) = Just x
    storedValue (JustSimpleNode x _ _)   = Just x
    storedValue _                        = Nothing
lookup key@(k:ks) node = case node of
    NothingNothingNode              -> Nothing
    JustNothingNode _               -> Nothing
    JustJustNode _ pivot lo mid hi  -> branch pivot lo mid hi
    NothingJustNode pivot lo mid hi -> branch pivot lo mid hi
    JustSimpleNode _ label rest     -> follow label rest
    NothingSimpleNode label rest    -> follow label rest
  where
    -- Branching node. A symbol below or above the pivot moves sideways and
    -- keeps the whole key. A match consumes the symbol and goes down the
    -- middle subtrie.
    branch pivot lo mid hi = case compare k pivot of
        LT -> lookup key lo
        EQ -> lookup ks mid
        GT -> lookup key hi
    -- Single-link node: the next symbol must match the link exactly.
    follow label rest
        | k == label = lookup ks rest
        | otherwise  = Nothing

-- | Insert a key/value pair into the trie.
--
-- If the key is already present, its stored value is replaced by the new one.
insert :: (Ord a)=> [a] -> v -> Trie a v -> Trie a v
insert key value trie = insertWith keepNew key value trie
  where
    -- 'insertWith' calls the combiner as @combiner new old@.
    keepNew new _old = new


-- | Insert a key/value pair, combining it with any value already stored.
--
-- If the key already has a value @old@, the stored value becomes
-- @combine new old@, with the new value as the first argument. Otherwise
-- @new@ is stored as-is.
insertWith :: (Ord a)=> (v->v->v) -> [a] -> v -> Trie a v -> Trie a v
insertWith combine key new node = case node of
    NothingNothingNode -> singleton key new
    JustNothingNode old -> case key of
        []     -> JustNothingNode (combine new old)
        (k:ks) -> JustSimpleNode old k (singleton ks new)
    NothingJustNode pivot lo mid hi -> case key of
        []     -> JustJustNode new pivot lo mid hi
        (k:ks) -> fork NothingJustNode k ks pivot lo mid hi
    JustJustNode old pivot lo mid hi -> case key of
        []     -> JustJustNode (combine new old) pivot lo mid hi
        (k:ks) -> fork (JustJustNode old) k ks pivot lo mid hi
    NothingSimpleNode pivot rest -> case key of
        []     -> JustSimpleNode new pivot rest
        (k:ks) -> chain NothingJustNode NothingSimpleNode k ks pivot rest
    JustSimpleNode old pivot rest -> case key of
        []     -> JustSimpleNode (combine new old) pivot rest
        (k:ks) -> chain (JustJustNode old) (JustSimpleNode old) k ks pivot rest
  where
    -- Branching node, rebuilt with @rebuild@ so the node keeps its own
    -- value or lack of one. A symbol below or above the pivot goes into
    -- the side subtrie with the whole key. A match consumes the symbol and
    -- goes into the middle subtrie.
    fork rebuild k ks pivot lo mid hi = case compare k pivot of
        LT -> rebuild pivot (insertWith combine key new lo) mid hi
        EQ -> rebuild pivot lo (insertWith combine ks new mid) hi
        GT -> rebuild pivot lo mid (insertWith combine key new hi)
    -- Single-link node. A matching symbol extends the chain in place with
    -- @keepChain@. A mismatch turns the node into a branching node with
    -- @toFork@, putting a fresh chain for the whole key on the appropriate
    -- side and leaving the other side empty.
    chain toFork keepChain k ks pivot rest = case compare k pivot of
        LT -> toFork pivot (singleton key new) rest empty
        EQ -> keepChain pivot (insertWith combine ks new rest)
        GT -> toFork pivot empty rest (singleton key new)

