module Builtins where

import Code
  ( Builtin(..)
  , BuiltinFn
  , Cho(..)
  , Datum(..)
  , Instr(..)
  , Interp(..)
  , heapStruct
  )
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Lazy (get, gets, modify)
import Data.Functor.Identity (runIdentity)
import Data.List (intercalate)
import qualified Data.Map as M
import Env (PrlgEnv(..), findStruct)
import qualified IR
import Interpreter (backtrack)
import qualified Operators as O
import System.Console.Haskeline (getInputChar, outputStr, outputStrLn)

{- Short alias for the 'Builtin' constructor imported from "Code"; used below
   to wrap builtin actions before placing them in 'Invoke' instructions. -}
bi = Builtin

{- Render the heap term found at `ref` as a String.

   The term is folded with 'heapStruct' in the Identity monad:
     * an 'Atom' is shown by looking its id up in `itos`
       (the lookup uses `M.!`, so an unknown id is a runtime error);
     * 'VoidRef' is shown as "_";
     * a 'Struct' is shown as name(arg1,arg2,...), again via `itos`;
     * a 'HeapRef' is shown as "_X<n>" when its target equals the second
       argument handed to the callback by 'heapStruct', and as "_Rec<n>"
       otherwise. -}
showTerm itos heap ref =
  runIdentity $ heapStruct showAtom showStruct showRef heap ref
  where
    nameOf sym = itos M.! sym
    showAtom (Atom a) = pure (nameOf a)
    showAtom VoidRef = pure "_"
    showStruct (Struct (IR.Id h _)) args =
      pure $ concat [nameOf h, "(", intercalate "," args, ")"]
    showRef (HeapRef hr) at
      | hr == at = pure ("_X" ++ show hr)
      | otherwise = pure ("_Rec" ++ show hr)

{- Print every entry of the current `gvar` map as "<name> = <term>", one per
   line, then return Nothing.

   The name is looked up in the string table's id-to-string map; an id that
   is not present there is printed as "_". The term is rendered with
   'showTerm' against the current heap. -}
printLocals :: BuiltinFn
printLocals = do
  scope <- gets (gvar . cur)
  heap <- gets (heap . cur)
  IR.StrTable _ _ itos <- gets strtable
  let describe (ref, name) =
        M.findWithDefault "_" name itos ++ " = " ++ showTerm itos heap ref
  -- mapM_ rather than traverse: only the output effect matters, so there is
  -- no point in collecting (and then silently dropping) a list of units.
  mapM_ (lift . outputStrLn . describe) (M.elems scope)
  return Nothing

{- Show the "? " prompt and read a single character. If the character is ';'
   the builtin hands over to 'backtrack'; for any other input (including no
   input at all) it returns Nothing. -}
promptRetry :: BuiltinFn
promptRetry = lift (getInputChar "? ") >>= react
  where
    react (Just ';') = backtrack
    react _ = return Nothing

{- Print, without a trailing newline, the term referenced by entry 0 of the
   current `hvar` map, then return Nothing.

   The entry is fetched with `M.!`, so a missing key 0 is a runtime error. -}
write :: BuiltinFn
write = do
  current <- gets cur
  IR.StrTable _ _ itos <- gets strtable
  let (ref, _) = hvar current M.! 0
  lift $ outputStr (showTerm itos (heap current) ref)
  return Nothing

{- Emit an empty line on the console and return Nothing. -}
nl :: BuiltinFn
nl = lift (outputStrLn "") >> return Nothing

{- Run 'write' and then 'nl'; the result of 'write' is discarded and the
   result of 'nl' is returned. -}
writeln :: BuiltinFn
writeln = do
  _ <- write
  nl

{- Convert the term referenced by entry 0 of the current `hvar` map into
   clause code and register it with 'addClause'.

   The conversion folds the term with 'heapStruct' in the Maybe monad:
     * an atom becomes a single `U` instruction;
     * a struct becomes its own `U` instruction followed by the code of all
       its arguments;
     * a 'HeapRef' becomes `U (LocalRef tgt 0)` only when its target equals
       the second argument handed to the callback; any other reference makes
       the whole conversion yield Nothing.

   Outcomes:
     * code starting with a 'Struct' is stored under that struct's id, with
       'NoGoal' appended;
     * code consisting of a single 'Atom' is stored under `IR.Id a 0` with
       just 'NoGoal' as its code;
     * anything else (including a failed conversion) backtracks. -}
assertFact :: BuiltinFn
assertFact = do
  current <- gets cur
  {- TODO this needs to go through PrlgInt because of cuts in assertClause -}
  let termRef = fst (hvar current M.! 0)
      encodeAtom a = Just [U a]
      encodeStruct s args = Just (U s : concat args)
      encodeRef (HeapRef tgt) src
        | src == tgt = Just [U (LocalRef tgt 0)]
        | otherwise = Nothing
  case heapStruct encodeAtom encodeStruct encodeRef (heap current) termRef of
    Just (U (Struct s):rest) -> addClause s (rest ++ [NoGoal]) >> return Nothing
    Just [U (Atom a)] -> addClause (IR.Id a 0) [NoGoal] >> return Nothing
    _ -> backtrack

{- Currently a stub: it does not touch the interpreter state and simply
   returns Nothing. -}
retractall :: BuiltinFn
retractall = return Nothing

{- adding the builtins -}
{- Put `op` at the front of the `ops` list held in the interpreter state. -}
addOp op = modify register
  where
    register st = st {ops = op : ops st}

{- Add one clause to the definition stored under `struct` in `defs`.

   If `struct` has no definition yet, one is created holding just this
   clause; otherwise the clause is placed in front of the existing ones. -}
addClause struct clause =
  modify $ \st -> st {defs = M.alter prepend struct (defs st)}
  where
    prepend Nothing = Just [clause]
    prepend (Just existing) = Just (clause : existing)

{- Store `heads` as the complete definition of `struct` in `defs`, replacing
   whatever was stored under that key before. -}
addProcedure struct heads = modify install
  where
    install st = st {defs = M.insert struct heads (defs st)}

{- Resolve the name `n` with arity `a` through 'findStruct' and install `c`
   as the definition of the resulting struct id via 'addProcedure'. -}
addProc n a c = findStruct n a >>= \sym -> addProcedure sym c

{- Define the arity-0 predicate `n` as a single clause whose only
   instruction invokes the builtin action `b`. -}
addBi0 n b = addProc n 0 [clause]
  where
    clause = [Invoke (bi b)]

{- Populate the interpreter state with the built-in operators and
   predicates: true/fail, unification (=), the clause operators, assert and
   retractall, the query helpers, the output predicates and a debugging
   predicate that prints the interpreter state. -}
addPrelude :: PrlgEnv ()
addPrelude = do
  {- primitives -}
  addBi0 "true" (pure Nothing)
  addBi0 "fail" backtrack
  addOp $ O.xfx "=" 700
  addProc "=" 2 [[U (LocalRef 0 0), U (LocalRef 0 0), NoGoal]]
  {- clauses -}
  addOp $ O.xfy "," 1000
  addOp $ O.xfx ":-" 1200
  addOp $ O.fx ":-" 1200
  addProc "assert" 1 [[U (LocalRef 0 0), Invoke (bi assertFact)]]
  addProc "retractall" 1 [[U (LocalRef 0 0), Invoke (bi retractall)]]
  {- query tools -}
  addBi0 "print_locals" printLocals
  addBi0 "prompt_retry" promptRetry
  {- IO -}
  addProc "write" 1 [[U (LocalRef 0 0), Invoke (bi write)]]
  addProc "writeln" 1 [[U (LocalRef 0 0), Invoke (bi writeln)]]
  addBi0 "nl" nl
  {- debug -}
  addBi0 "interpreter_trace" (get >>= liftIO . print >> pure Nothing)