module Juvix.Compiler.Builtins.Effect
  ( module Juvix.Compiler.Builtins.Effect,
  )
where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Builtins.Error
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Pretty
import Juvix.Prelude

-- | Effect for associating builtin primitives with the user-level names they
-- are bound to. Interpreted by 'runBuiltins'.
data Builtins :: Effect where
  -- | Look up the name registered for a builtin. The interval is used as the
  -- error location when the builtin has not been registered.
  GetBuiltinName' :: Interval -> BuiltinPrim -> Builtins m Name
  -- | Register a name for a builtin. Registering the same builtin twice is an
  -- error in 'runBuiltins'.
  RegisterBuiltin' :: BuiltinPrim -> Name -> Builtins m ()

makeSem ''Builtins

-- | Bind a builtin (of any kind convertible with 'toBuiltinPrim') to the
-- given name by sending a 'RegisterBuiltin'' request.
registerBuiltin :: (IsBuiltin a, Member Builtins r) => a -> Name -> Sem r ()
registerBuiltin bi nm = registerBuiltin' (toBuiltinPrim bi) nm

-- | Fetch the name bound to a builtin (of any kind convertible with
-- 'toBuiltinPrim'). The interval is reported if the builtin is missing.
getBuiltinName :: (IsBuiltin a, Member Builtins r) => Interval -> a -> Sem r Name
getBuiltinName intv bi = getBuiltinName' intv (toBuiltinPrim bi)

-- | State threaded through the 'Builtins' interpreter.
newtype BuiltinsState = BuiltinsState
  { -- | Builtins registered so far, mapped to the name each one is bound to.
    _builtinsTable :: HashMap BuiltinPrim Name
  }

makeLenses ''BuiltinsState

-- | Initial interpreter state: no builtin has been registered yet.
iniBuiltins :: BuiltinsState
iniBuiltins = BuiltinsState {_builtinsTable = mempty}

-- | Interpret 'Builtins' by keeping the builtin table in a local 'State',
-- starting from the given state. Returns the final state with the result.
--
-- * Looking up an unregistered builtin throws 'NotDefined', located at the
--   interval supplied with the request.
-- * Registering an already-registered builtin throws 'AlreadyDefined',
--   located at the newly supplied name; the table is left unchanged.
runBuiltins :: forall r a. (Member (Error JuvixError) r) => BuiltinsState -> Sem (Builtins ': r) a -> Sem r (BuiltinsState, a)
runBuiltins ini = reinterpret (runState ini) $ \case
  GetBuiltinName' intv bp -> do
    found <- gets (^. builtinsTable . at bp)
    case found of
      Just nm -> pure nm
      Nothing ->
        throw $
          JuvixError
            NotDefined
              { _notDefinedBuiltin = bp,
                _notDefinedLoc = intv
              }
  RegisterBuiltin' bp nm -> do
    prev <- gets (^. builtinsTable . at bp)
    case prev of
      Just {} ->
        throw $
          JuvixError
            AlreadyDefined
              { _alreadyDefinedBuiltin = bp,
                _alreadyDefinedLoc = getLoc nm
              }
      Nothing -> modify (set (builtinsTable . at bp) (Just nm))

-- | Like 'runBuiltins', but discards the final builtin table.
evalBuiltins :: (Member (Error JuvixError) r) => BuiltinsState -> Sem (Builtins ': r) a -> Sem r a
evalBuiltins st m = snd <$> runBuiltins st m

-- | Description of a builtin function, used by 'registerFun' to check a
-- user-supplied definition before registering it.
data FunInfo = FunInfo
  { -- | The user's definition; its name is what gets registered.
    _funInfoDef :: FunctionDef,
    -- | The builtin the definition is registered as.
    _funInfoBuiltin :: BuiltinFunction,
    -- | Expected type, compared with the definition's declared type.
    _funInfoSignature :: Expression,
    -- | Expected @(lhs, body)@ pairs, compared in order with the
    -- definition's clauses.
    _funInfoClauses :: [(Expression, Expression)],
    -- | Variables passed to the clause comparison (presumably those allowed
    -- to match up to renaming — TODO confirm semantics of '==%').
    _funInfoFreeVars :: [VarName],
    -- | Variables passed to the signature comparison (same caveat as
    -- '_funInfoFreeVars').
    _funInfoFreeTypeVars :: [VarName]
  }

makeLenses ''FunInfo

-- | Check a user definition of a builtin function against its 'FunInfo' and
-- register it.
--
-- The checks run in this order:
--
-- 1. The declared type must match '_funInfoSignature', compared with '==%'
--    under '_funInfoFreeTypeVars'.
-- 2. The builtin is registered; this may throw 'AlreadyDefined'.
-- 3. The definition's lambda clauses must pair up exactly with
--    '_funInfoClauses'. Each lhs and each body is compared with '==%' under
--    '_funInfoFreeVars'.
--
-- Mismatches are reported with 'error' rather than a 'JuvixError'. Every
-- message names the offending builtin and, where applicable, shows the
-- expected and actual terms.
registerFun ::
  (Members '[Builtins, NameIdGen] r) =>
  FunInfo ->
  Sem r ()
registerFun fi = do
  let op = fi ^. funInfoDef . funDefName
      ty = fi ^. funInfoDef . funDefType
      sig = fi ^. funInfoSignature
      opTxt = ppTrace op
  unless ((sig ==% ty) (HashSet.fromList (fi ^. funInfoFreeTypeVars))) $
    error
      ( "builtin has the wrong type signature: "
          <> opTxt
          <> "\nExpected: "
          <> ppTrace sig
          <> "\nActual: "
          <> ppTrace ty
      )
  registerBuiltin (fi ^. funInfoBuiltin) op
  let freeVars = HashSet.fromList (fi ^. funInfoFreeVars)
      a =% b = (a ==% b) freeVars
      -- If the body is not a lambda, this is empty and the clause-count
      -- check below fails (unless no clauses are expected).
      clauses :: [(Expression, Expression)]
      clauses =
        [ (clauseLhsAsExpression op (toList pats), body)
          | Just cls <- [unfoldLambdaClauses (fi ^. funInfoDef . funDefBody)],
            (pats, body) <- toList cls
        ]
  case zipExactMay (fi ^. funInfoClauses) clauses of
    Nothing -> error ("builtin has the wrong number of clauses: " <> opTxt)
    Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do
      unless
        (exLhs =% lhs)
        ( error
            ( "clause lhs does not match for "
                <> opTxt
                <> "\nExpected: "
                <> ppTrace exLhs
                <> "\nActual: "
                <> ppTrace lhs
            )
        )
      unless
        (exBody =% body)
        ( error
            ( "clause body does not match for "
                <> opTxt
                <> "\nExpected: "
                <> ppTrace exBody
                <> "\nActual: "
                <> ppTrace body
            )
        )
