module Heap where

import Code
import CodeLens
import Data.Foldable (traverse_)
import qualified Data.Map as M
import IR (Id(..))
import Lens.Micro.Mtl

-- | Result of chasing a heap address through its chain of 'HeapRef' cells.
data Dereferenced
  -- | Unbound variable: the cell at this address is a 'HeapRef' to itself.
  = FreeRef Int
  -- | The first cell in the chain holding something other than a 'HeapRef':
  -- its address and its contents.
  | BoundRef Int Datum
  -- | Some address in the chain has no cell in the heap map.
  | NoRef
  deriving (Show, Eq)

-- | Chase heap address @x@ through its chain of 'HeapRef' cells.
--
-- The result is one of the following:
--
--   * 'FreeRef' when the chain reaches a cell that references itself.
--   * 'BoundRef' with the address and contents of the first cell that holds
--     anything other than a 'HeapRef'.
--   * 'NoRef' when an address in the chain is missing from the heap map.
--
-- TRICKY: HeapRefs alone must not form a cycle other than the FreeRef case.
-- No visited set is kept, so a longer pure-'HeapRef' cycle makes this
-- recurse forever.
deref' :: Heap -> Int -> Dereferenced
deref' h@(Heap _ hmap) x = maybe NoRef follow (M.lookup x hmap)
  where
    follow (HeapRef x')
      | x' == x = FreeRef x'
      | otherwise = deref' h x'
    follow found = BoundRef x found

-- | Alias of 'deref'' that is kept for existing callers.
-- TODO remove
derefHeap :: Heap -> Int -> Dereferenced
derefHeap = deref'

-- | Dereference address @x@ against the heap currently stored in the
-- environment. See 'deref'' for the possible results.
deref :: Int -> PrlgEnv Dereferenced
deref x = do
  h <- use (cur . heap)
  pure (deref' h x)

-- | Store datum @v@ in heap cell @a@, replacing any previous contents.
-- The allocation pointer (the first 'Heap' field) is left unchanged.
writeHeap :: Int -> Datum -> PrlgEnv ()
writeHeap a v = cur . heap %= store
  where
    store (Heap nxt cells) = Heap nxt (M.insert a v cells)

-- | Reserve @n@ consecutive heap addresses by advancing the allocation
-- pointer, and return the first reserved address.
--
-- The reserved cells are not written. Callers must initialise them.
allocHeap :: Int -> PrlgEnv Int
allocHeap n = use (cur . heap) >>= bump
  where
    bump (Heap base cells) = base <$ (cur . heap .= Heap (base + n) cells)

-- | Make heap cell @a@ a fresh unbound variable. The cell becomes a 'HeapRef'
-- to itself, which 'deref'' reports as 'FreeRef'.
makeVar :: Int -> PrlgEnv ()
makeVar a = writeHeap a $ HeapRef a

-- | Allocate one fresh unbound variable on the heap and return its address.
--
-- This has the same effects as taking the single element of
-- @newHeapVars 1@, but it avoids the partial 'head'.
newHeapVar :: PrlgEnv Int
newHeapVar = do
  addr <- allocHeap 1
  makeVar addr
  pure addr

-- | Allocate @n@ consecutive fresh unbound variables and return their
-- addresses in increasing order.
newHeapVars :: Int -> PrlgEnv [Int]
newHeapVars n = allocHeap n >>= initialise
  where
    initialise base = fresh <$ mapM_ makeVar fresh
      where
        fresh = [base .. base + n - 1]

-- | Build a new structure on the heap from the 'Struct' header @s@ and point
-- heap cell @addr@ at it.
--
-- This allocates @arity + 1@ consecutive cells. The first cell receives the
-- header @s@. The remaining @arity@ cells start out as fresh unbound
-- variables. Cell @addr@ is then overwritten with a 'HeapRef' to the header
-- cell.
--
-- Returns 'HeapRef's to the argument cells, in argument order.
--
-- Passing any datum other than a 'Struct' is a programming error and fails
-- with a descriptive message instead of a bare pattern-match failure.
putHeapStruct :: Int -> Datum -> PrlgEnv [Datum]
putHeapStruct addr s@(Struct Id {arity = arity}) = do
  base <- allocHeap (arity + 1)
  let paddrs = map (base +) [1 .. arity]
  traverse_ makeVar paddrs
  writeHeap base s
  writeHeap addr (HeapRef base)
  pure $ map HeapRef paddrs
putHeapStruct _ d =
  error $ "putHeapStruct: expected a Struct header, got " ++ show d

-- | Fold monadically over the term stored on the heap at address @hr@.
--
-- Chains of 'HeapRef' cells are followed transparently. The current path
-- of addresses from @hr@ is tracked so that cycles are cut. Each cell that
-- is reached is handled as follows:
--
--   * A 'Struct' header of arity @k@ at address @a@: the arguments at
--     @a+1 .. a+k@ are folded first, with effects sequenced left to right.
--     Their results are then passed to @onStruct@ together with the header.
--   * A 'HeapRef' that points to its own address (an unbound variable), or
--     to an address already on the current path (a cycle): @onRef@ receives
--     that 'HeapRef' datum and the address of the cell holding it.
--   * Any other datum is passed to @onAtom@.
--
-- An address missing from the heap map makes this fail via 'M.!'.
heapStruct ::
     Monad m
  => (Datum -> m a)
  -> (Datum -> [a] -> m a)
  -> (Datum -> Int -> m a)
  -> Heap
  -> Int
  -> m a
heapStruct onAtom onStruct onRef (Heap _ cells) hr = walk [hr] hr
  where
    -- @path@ always contains @addr@ itself plus every address visited on the
    -- way down from @hr@.
    walk path addr =
      case cells M.! addr of
        ref@(HeapRef target)
          | target == addr || target `elem` path -> onRef ref addr
          | otherwise -> walk (target : path) target
        hdr@(Struct (Id _ k)) ->
          mapM (\i -> walk (addr + i : path) (addr + i)) [1 .. k]
            >>= onStruct hdr
        other -> onAtom other