{- VAM 2P, done the lazy way -}
module Interpreter where

import Code
--import Data.Function
import qualified Data.Map as M
import Env (PrlgEnv(..))
import IR (Id(..))
import qualified Control.Monad.Trans.State.Lazy as St

{- | Run the query code @g@ until the machine produces an answer.

   The interpreter's current choicepoint is replaced by one whose head code is
   the query and whose goal code is a bare 'LastCall', with empty scopes, an
   empty heap, no stack frames and an empty cut list; all pending choicepoints
   are dropped. 'proveStep' is then iterated, storing each intermediate state,
   until it yields a result. The final state is stored and the result returned.
-}
prove :: Code -> PrlgEnv (Either String Bool)
prove g = St.modify reset >> run
  where
    reset st = st {cur = start, cho = []}
    start =
      Cho
        { hed = g
        , hvar = emptyScope
        , gol = [LastCall]
        , gvar = emptyScope
        , heap = emptyHeap
        , stk = []
        , cut = []
        }
    {- one machine step; 'resume' loops back here, 'halt' stops the loop -}
    run = St.get >>= proveStep resume halt
    resume st = St.put st >> run
    halt st res = St.put st >> return res

-- | Result of following 'HeapRef' links from a heap address (see @deref@ in
-- 'proveStep').
data Dereferenced
  = -- | An unbound variable: the address of a 'HeapRef' cell that points at
    -- its own address.
    FreeRef Int
  | -- | The address where the chain reached a value that is not a 'HeapRef',
    -- paired with that value.
    BoundRef Int Datum
  | -- | The chain reached an address that is not present in the heap.
    NoRef

{- | Perform one step of the VAM 2P machine on the interpreter state @i@.

   The current choicepoint ('cur') carries two code streams that are consumed
   in lockstep: 'hed', the clause being tried (variables in 'hvar'), and 'gol',
   the goal being proved (variables in 'gvar'). When both streams start with a
   'U' item, the two data are unified. Otherwise the step performs a control
   transition (built-in, cut, call, return to caller, success).

   * @c@ receives the successor state when execution should continue.
   * @f@ receives a final state and the outcome:
     - @Right True@ on top-level success.
     - @Right False@ when there is no choicepoint left to backtrack into.
     - @Left msg@ when a definition is missing or the code/heap is in an
       unexpected shape.
-}
proveStep :: (Interp -> a) -> (Interp -> Either String Bool -> a) -> Interp -> a
proveStep c f i = go i
  where
    {- report an error through the final continuation -}
    ifail msg = f i $ Left msg
    {- Decide whether the remaining goal code is finished. A lone 'LastCall'
     - keeps the given choicepoints. 'LastCall' followed by 'Cut' gives the
     - saved cut list instead. Any other goal code gives 'Nothing'. -}
    tailcut [LastCall] chos _ = Just chos
    tailcut [LastCall, Cut] _ cut = Just cut
    tailcut _ _ _ = Nothing
    {- Feed the clause list of predicate @fn@ to a continuation. Fail with an
     - error when 'defs' has no entry for it.
     - NOTE(review): every caller matches the clause list as @(hs:ohs)@, so
     - an entry with an empty clause list would crash with a pattern-match
     - failure instead of failing/backtracking. -}
    withDef fn
      | Just d <- defs i M.!? fn = ($ d)
      | otherwise = const $ ifail $ "no definition: " ++ show fn
    {- Backtracking -}
    backtrack i@Interp {cho = chos}
      {- if available, restore the first (most recently pushed) choicepoint -}
      | (cho:chos) <- chos = c i {cur = cho, cho = chos}
      {- if there's no other choice, answer no -}
      | otherwise = f i $ Right False
    {- Unification -}
    {- Both streams start with a 'U' item: unify head datum @h@ with goal
     - datum @g@. Each case either continues via @c@, backtracks, or reports
     - a dangling heap reference via 'ifail'. -}
    go i@Interp {cur = cur@Cho { hed = U h:hs
                               , gol = U g:gs
                               , heap = heap@(Heap _ hmap)
                               }} = unify h g
        {- termination tools -}
      where
        {- this pair unified without changes: drop both items -}
        uok = c i {cur = cur {hed = hs, gol = gs}}
        {- this pair unified by storing @x@ into heap cell @r@: drop both items -}
        setHeap r x =
          c i {cur = cur {hed = hs, gol = gs, heap = writeHeap r x heap}}
        {- heap tools -}
        {- Follow 'HeapRef' links starting at address @x@:
         - - a 'HeapRef' stored at its own address is an unbound variable
         -   ('FreeRef');
         - - any other stored value ends the chain ('BoundRef' with its
         -   address);
         - - an address missing from the heap gives 'NoRef'.
         - NOTE(review): there is no cycle check, so a loop of 'HeapRef's
         - through distinct addresses would not terminate. -}
        deref x =
          case hmap M.!? x of
            Just (HeapRef x' _) ->
              if x == x'
                then FreeRef x'
                else deref x'
            Just x' -> BoundRef x x'
            _ -> NoRef
        {- replace the value of an existing cell (no-op if @addr@ is absent) -}
        writeHeap addr x (Heap nxt m) = Heap nxt (M.adjust (const x) addr m)
        {- allocate one fresh unbound cell; returns (new heap, its address) -}
        newHeapVar h = head <$> newHeapVars 1 h
        {- Allocate @n@ consecutive fresh unbound cells (self-pointing
         - 'HeapRef's) at addresses @nxt .. nxt + n - 1@. Returns the new heap
         - and the allocated addresses. -}
        newHeapVars n (Heap nxt m) =
          let addrs = [nxt + i - 1 | i <- [1 .. n]]
           in ( Heap (nxt + n) $
                foldr (uncurry M.insert) m [(a, HeapRef a Nothing) | a <- addrs]
              , addrs)
        {- Get the heap address of local register @reg@. Reuse the address
         - already recorded in @scope@, or allocate a fresh unbound cell and
         - record it. The continuation receives the scope, the heap and the
         - address. -}
        allocLocal reg scope cont
          | Just addr <- scope M.!? reg = cont scope heap addr
          | (heap', addr) <- newHeapVar heap =
            cont (M.insert reg addr scope) heap' addr
        {- Bind cell @addr@ to a fresh copy of structure @s@:
         - - allocate @arity + 1@ cells: the functor cell followed by one
         -   unbound cell per argument;
         - - make @addr@ point at the functor cell;
         - - pass references to the argument cells, plus the new heap, to
         -   @cont@. -}
        newHeapStruct addr s@(Struct Id {arity = arity}) cont =
          let (Heap nxt' m', addrs) = newHeapVars (arity + 1) heap
              m'' =
                M.insert addr (HeapRef (head addrs) Nothing) .
                M.insert (head addrs) s $
                m'
           in cont [HeapRef a Nothing | a <- tail addrs] (Heap nxt' m'')
        {- simple cases first -}
        {- (a structure's arguments follow it as the next 'U' items of the
         - same stream, so equal functors just move on to the arguments) -}
        unify (VoidRef _) (VoidRef _) = uok
        unify (Atom a) (Atom b)
          | a == b = uok
        unify (VoidRef _) (Atom _) = uok
        unify (Atom _) (VoidRef _) = uok
        unify (Struct a) (Struct b)
          | a == b = uok
        {- unifying a struct with void must cause us to skip the void -}
        {- (the other side gets one void per argument, so the structure's
         - argument items are matched against voids) -}
        unify (VoidRef _) (Struct Id {arity = a}) =
          c
            i
              { cur =
                  cur {hed = replicate a (U $ VoidRef Nothing) ++ hs, gol = gs}
              }
        unify (Struct Id {arity = a}) (VoidRef _) =
          c
            i
              { cur =
                  cur {hed = hs, gol = replicate a (U $ VoidRef Nothing) ++ gs}
              }
        {- handle local refs; first ignore their combination with voids to save memory -}
        unify (LocalRef _ _) (VoidRef _) = uok
        unify (VoidRef _) (LocalRef _ _) = uok
        {- allocate heap for LocalRefs and retry with HeapRefs -}
        unify (LocalRef hv ident) _ =
          allocLocal hv (hvar cur) $ \hvar' heap' addr ->
            c
              i
                { cur =
                    cur
                      { hed = U (HeapRef addr ident) : hs
                      , hvar = hvar'
                      , heap = heap'
                      }
                }
        unify _ (LocalRef gv ident) =
          allocLocal gv (gvar cur) $ \gvar' heap' addr ->
            c
              i
                { cur =
                    cur
                      { gol = U (HeapRef addr ident) : gs
                      , gvar = gvar'
                      , heap = heap'
                      }
                }
        {- handle heap refs; first ignore their combination with voids again -}
        unify (HeapRef _ _) (VoidRef _) = uok
        unify (VoidRef _) (HeapRef _ _) = uok
        {- actual HeapRefs, these are dereferenced and then unified (sometimes with copying) -}
        {- Head-side reference.
         - If the cell is free, it is bound to the goal datum:
         - - an atom is stored directly;
         - - a structure is copied with fresh argument cells, which are then
         -   unified with the goal's argument items;
         - - another reference is dereferenced and linked to.
         - If the cell is bound, it is replaced by its value and the pair is
         - retried; a structure is expanded into its functor followed by
         - references to its argument cells. -}
        unify (HeapRef hr' hident) g =
          case deref hr' of
            FreeRef hr ->
              case g of
                atom@(Atom _) -> setHeap hr atom
                s@(Struct _) ->
                  newHeapStruct
                    hr
                    s
                    (\nhs nheap ->
                       c
                         i
                           { cur =
                               cur
                                 {hed = map U nhs ++ hs, gol = gs, heap = nheap}
                           })
                HeapRef gr' _ ->
                  case deref gr' of
                    FreeRef gr -> setHeap hr (HeapRef gr hident)
                    BoundRef addr _ -> setHeap hr (HeapRef addr hident)
                    _ -> ifail "dangling goal ref (from head ref)"
            BoundRef _ atom@(Atom a) -> unify atom g
            BoundRef addr struct@(Struct Id {arity = arity}) ->
              c
                i
                  { cur =
                      cur
                        { hed =
                            U struct :
                            [U (HeapRef (addr + i) Nothing) | i <- [1 .. arity]] ++
                            hs
                        , gol = U g : gs
                        }
                  }
            _ -> ifail "dangling head ref"
        {- Goal-side reference: the mirror image of the head-side case above.
         - The head datum here is never a reference, because those were all
         - handled by the earlier equations. -}
        unify h (HeapRef gr' gident) =
          case deref gr' of
            FreeRef gr ->
              case h of
                atom@(Atom _) -> setHeap gr atom
                s@(Struct _) ->
                  newHeapStruct
                    gr
                    s
                    (\ngs nheap ->
                       c
                         i
                           { cur =
                               cur
                                 {hed = hs, gol = map U ngs ++ gs, heap = nheap}
                           })
            BoundRef _ atom@(Atom b) -> unify h atom
            BoundRef addr struct@(Struct Id {arity = arity}) ->
              c
                i
                  { cur =
                      cur
                        { hed = U h : hs
                        , gol =
                            U struct :
                            [U (HeapRef (addr + i) Nothing) | i <- [1 .. arity]] ++
                            gs
                        }
                  }
            _ -> ifail "dangling goal ref"
        {- anything else (different atoms/functors, atom vs. structure) fails -}
        unify _ _ = backtrack i
    {- Resolution -}
    {- The streams do not both start with 'U', so perform a control
     - transition. The guards are tried in order; if none matches, the final
     - 'go' equation reports broken code.
     - NOTE(review): at most call sites below, the current clause keeps its
     - previous 'cut' field. Meanwhile, the choicepoints created for the
     - remaining clauses @ohs@ record the choicepoint list of the call
     - (@chos@, @nchos@ or @rchos@) as their 'cut'. Only the call that follows
     - a body cut uses 'cut' for both. As a result, a 'Cut' at the front of
     - the first clause's body can restore a different list than the same
     - 'Cut' in an alternative clause. Returning to a caller does not restore
     - the caller's 'cut' either. Confirm the intended cut semantics. -}
    go i@Interp { cur = cur@Cho { hed = hed
                                , hvar = hvar
                                , gol = gol
                                , gvar = gvar
                                , heap = heap
                                , stk = stk
                                , cut = cut
                                }
                , cho = chos
                }
      {- invoke a built-in (this gets replaced by NoGoal by default, but the
       - builtin can actually do whatever it wants with the code) -}
      | [Builtin (BuiltinFunc bf)] <- hed =
        c (bf i {cur = cur {hed = [NoGoal]}})
      {- top-level success: the clause is finished, the goal code is at its
       - end and there is no caller frame left. Report success, keeping the
       - choicepoint list chosen by 'tailcut' in 'cho'. -}
      | [NoGoal] <- hed
      , Just nchos <- tailcut gol chos cut
      , [] <- stk =
        f i {cur = cur {hed = [], gol = []}, cho = nchos} $ Right True
      {- cut before the first goal (this solves all cuts in head) -}
      | Cut:hs <- hed = c i {cur = cur {hed = hs}, cho = cut}
      {- succeed and return to caller: pop the caller's frame and immediately
       - start the next goal of the caller's saved code, in the caller's
       - scope @ngvar@ -}
      | [NoGoal] <- hed
      , Just nchos <- tailcut gol chos cut
      , (Goal:U (Struct fn):gs, ngvar, _):ss <- stk =
        withDef fn $ \(hs:ohs) ->
          c
            i
              { cur =
                  cur
                    { hed = hs
                    , hvar = emptyScope
                    , gol = gs
                    , gvar = ngvar
                    , stk = ss
                    }
              , cho =
                  [Cho oh emptyScope gs ngvar heap ss nchos | oh <- ohs] ++
                  nchos
              }
      {- succeed and return to caller, and the caller wants a cut: the
       - choicepoint list saved in the frame (@rchos@) replaces 'cho' before
       - the next goal's alternatives are pushed.
       - NOTE(review): @rchos@ is the list saved when the call was made (see
       - the normal call below). This removes the callee's choicepoints but
       - keeps all that existed before the call. Confirm whether the calling
       - clause's own alternatives should be cut as well. -}
      | [NoGoal] <- hed
      , Just _ <- tailcut gol chos cut
      , (Cut:Goal:U (Struct fn):gs, ngvar, rchos):ss <- stk =
        withDef fn $ \(hs:ohs) ->
          c
            i
              { cur =
                  cur
                    { hed = hs
                    , hvar = emptyScope
                    , gol = gs
                    , gvar = ngvar
                    , stk = ss
                    }
              , cho =
                  [Cho oh emptyScope gs ngvar heap ss rchos | oh <- ohs] ++
                  rchos
              }
      {- start matching next goal: the clause was finished and the goal code
       - continues with another call. Try the first clause now and push the
       - remaining clauses as choicepoints. -}
      | [NoGoal] <- hed
      , (Call:Goal:U (Struct fn):gs) <- gol =
        withDef fn $ \(hs:ohs) ->
          c
            i
              { cur = cur {hed = hs, hvar = emptyScope, gol = gs}
              , cho =
                  [Cho oh emptyScope gs gvar heap stk chos | oh <- ohs] ++ chos
              }
      {- start matching next goal after a cut: the goal's cut list replaces
       - 'cho' before the new alternatives are pushed -}
      | [NoGoal] <- hed
      , (Call:Cut:Goal:U (Struct fn):gs) <- gol =
        withDef fn $ \(hs:ohs) ->
          c
            i
              { cur = cur {hed = hs, hvar = emptyScope, gol = gs}
              , cho =
                  [Cho oh emptyScope gs gvar heap stk cut | oh <- ohs] ++ cut
              }
      {- goal head matching succeeded, make a normal call.
       - The rest of the goal code (@gs@), its scope and the current
       - choicepoints are pushed as a stack frame. The clause body then
       - becomes the new goal code, with the clause scope 'hvar' as its
       - scope. -}
      | (Goal:U (Struct fn):ngs) <- hed
      , (Call:gs) <- gol =
        withDef fn $ \(hs:ohs) ->
          let nstk = (gs, gvar, chos) : stk
           in c i
                  { cur =
                      cur
                        { hed = hs
                        , hvar = emptyScope
                        , gol = ngs
                        , gvar = hvar
                        , stk = nstk
                        }
                  , cho =
                      [Cho oh emptyScope ngs hvar heap nstk chos | oh <- ohs] ++
                      chos
                  }
      {- successful match continued by tail call: the goal code is at its end,
       - so no frame is pushed and the clause body replaces the goal code
       - directly -}
      | (Goal:U (Struct fn):ngs) <- hed
      , Just nchos <- tailcut gol chos cut =
        withDef fn $ \(hs:ohs) ->
          c
            i
              { cur = cur {hed = hs, hvar = emptyScope, gol = ngs, gvar = hvar}
              , cho =
                  [Cho oh emptyScope ngs hvar heap stk nchos | oh <- ohs] ++
                  nchos
              }
    {- The End -}
    go _ = ifail "code broken: impossible instruction combo"