{-# OPTIONS_GHC -Wno-unused-foralls #-}

module Plutarch.Internal.TermCont (
  hashOpenTerm,
  TermCont (TermCont),
  runTermCont,
  unTermCont,
  tcont,
  pfindPlaceholder,
  pfindAllPlaceholders,
) where

import Data.Hashable (Hashed, hashed)
import Data.Kind (Type)
import Data.List (nub)
import Data.String (fromString)
import Plutarch.Internal.Term (
  Config (Tracing),
  HoistedTerm (..),
  RawTerm (..),
  S,
  Term (Term),
  TracingMode (DetTracing),
  asRawTerm,
  getTerm,
  perror,
  pgetConfig,
 )
import Plutarch.Internal.Trace (ptraceInfo)

{- | A continuation-passing wrapper around 'Term'.

A value of this type holds a function that, given a continuation from @a@
to a @'Term' s r@, produces a @'Term' s r@. The instances in this module
let such computations be sequenced in @do@ notation; 'runTermCont' supplies
an explicit final continuation to leave the monad.
-}
newtype TermCont :: forall (r :: S -> Type). S -> Type -> Type where
  TermCont :: forall r s a. {runTermCont :: (a -> Term s r) -> Term s r} -> TermCont @r s a

{- | Runs a 'TermCont' whose result is already a term of the answer type,
using 'id' as the final continuation.
-}
unTermCont :: TermCont @a s (Term s a) -> Term s a
unTermCont t = runTermCont t id

-- | Mapping pre-composes the function with whatever continuation is
-- eventually supplied, so the continuation receives the mapped value.
instance Functor (TermCont s) where
  fmap f (TermCont g) = TermCont $ \h -> g (h . f)

-- | 'pure' hands its value straight to the continuation. '<*>' runs the
-- function-producing computation first and then maps the resulting
-- function over the argument computation.
instance Applicative (TermCont s) where
  pure x = TermCont $ \k -> k x

  -- Distinct binder names avoid shadowing the operand with its own result.
  mf <*> mx = do
    f <- mf
    f <$> mx

-- | Binding threads the final continuation @h@ through: the first
-- computation's result is passed to @g@, and the 'TermCont' it returns is
-- run with the same @h@.
instance Monad (TermCont s) where
  (TermCont f) >>= g = TermCont $ \h ->
    f (\x -> runTermCont (g x) h)

{- | A failed pattern match in @do@ notation aborts the term with 'perror'.

The supplied continuation is discarded. The current 'Config' is inspected:
under 'DetTracing' a fixed message is passed to 'ptraceInfo'. In every other
configuration the message given to 'fail' is passed instead.
NOTE(review): presumably the fixed message keeps deterministic traces
independent of GHC's generated failure text — confirm.
-}
instance MonadFail (TermCont s) where
  fail s = TermCont $ \_ ->
    pgetConfig $ \case
      -- Note: This currently works because DetTracing is the most specific
      -- tracing mode.
      Tracing _ DetTracing -> ptraceInfo "Pattern matching failure in TermCont" perror
      _ -> ptraceInfo (fromString s) perror

-- | Builds a 'TermCont' directly from a continuation-passing function.
-- This is simply the 'TermCont' constructor exposed as an ordinary function.
tcont :: ((a -> Term s r) -> Term s r) -> TermCont @r s a
tcont = TermCont

{- | Hashes the raw representation of a (possibly open) term.

The term is converted with 'asRawTerm' using the same 'Word64' argument as
the enclosing term (NOTE(review): presumably the current binder level —
confirm). The 'RawTerm' it yields is hashed and the hash is handed to the
continuation. Only the hash of the term is used here; the term itself is
not spliced into the output.
-}
hashOpenTerm :: Term s a -> TermCont s (Hashed RawTerm)
hashOpenTerm x = TermCont $ \f -> Term $ \i -> do
  y <- asRawTerm x i
  let h = hashed $ getTerm y
  asRawTerm (f h) i

-- This could technically be done outside of TermCont.
-- Take care when using this to eliminate branches: if a term has been
-- pre-evaluated (via `evalTerm`), its RawTerm no longer contains the tagged
-- RPlaceHolder nodes, so the search will not find them.

{- | Given an integer tag and a term, checks whether the term's raw
representation contains an 'RPlaceHolder' carrying that tag.

The search descends into the sub-terms of 'RLamAbs', 'RApply', 'RForce',
'RDelay', 'RHoisted', 'RConstr' and 'RCase'. 'RVar', 'RConstant',
'RBuiltin', 'RCompiled' and 'RError' are treated as containing no
placeholder.
-}
pfindPlaceholder :: Integer -> Term s a -> TermCont s Bool
pfindPlaceholder idx x = TermCont $ \f -> Term $ \i -> do
  y <- asRawTerm x i
  asRawTerm (f . findPlaceholder . getTerm $ y) i
  where
    -- Exhaustive on purpose: a new RawTerm constructor should trigger an
    -- incomplete-pattern warning here rather than be silently skipped.
    findPlaceholder :: RawTerm -> Bool
    findPlaceholder = \case
      RLamAbs _ body -> findPlaceholder body
      RApply fun args -> any findPlaceholder (fun : args)
      RForce t -> findPlaceholder t
      RDelay t -> findPlaceholder t
      RHoisted (HoistedTerm _ t) -> findPlaceholder t
      RPlaceHolder idx' -> idx == idx'
      RConstr _ args -> any findPlaceholder args
      RCase scrut branches -> any findPlaceholder (scrut : branches)
      RVar _ -> False
      RConstant _ -> False
      RBuiltin _ -> False
      RCompiled _ -> False
      RError -> False
{- | Finds the tags of all 'RPlaceHolder' nodes in a term's raw
representation and returns them.

Duplicates are removed with 'nub', which keeps the first occurrence of each
tag in traversal order. Traversal covers the same constructors as
'pfindPlaceholder'.
-}
pfindAllPlaceholders :: Term s a -> TermCont s [Integer]
pfindAllPlaceholders x = TermCont $ \f -> Term $ \i -> do
  y <- asRawTerm x i
  asRawTerm (f . nub . placeholdersIn . getTerm $ y) i
  where
    -- Exhaustive on purpose: a new RawTerm constructor should trigger an
    -- incomplete-pattern warning here rather than be silently skipped.
    placeholdersIn :: RawTerm -> [Integer]
    placeholdersIn = \case
      RLamAbs _ body -> placeholdersIn body
      RApply fun args -> placeholdersIn fun <> foldMap placeholdersIn args
      RForce t -> placeholdersIn t
      RDelay t -> placeholdersIn t
      RHoisted (HoistedTerm _ t) -> placeholdersIn t
      RPlaceHolder tag -> [tag]
      RConstr _ args -> foldMap placeholdersIn args
      RCase scrut branches -> placeholdersIn scrut <> foldMap placeholdersIn branches
      RVar _ -> []
      RConstant _ -> []
      RBuiltin _ -> []
      RCompiled _ -> []
      RError -> []