
I was experimenting with Prompt today and found that you can get a "restricted monad" style of behavior out of a regular monad using Prompt:
    {-# LANGUAGE GADTs #-}
    module SetTest where

    import qualified Data.Set as S
Prompt is available from http://hackage.haskell.org/cgi-bin/hackage-scripts/package/MonadPrompt-1.0.0...
    import Control.Monad.Prompt

"OrdP" is a prompt type that provides MonadPlus-style operations for orderable types:

    data OrdP m a where
        PZero     :: OrdP m a                          -- no results
        PRestrict :: Ord a => m a -> OrdP m a          -- run m and deduplicate its results
        PPlus     :: Ord a => m a -> m a -> OrdP m a   -- combine the results of both

    type SetM = RecPrompt OrdP
We can't make this an instance of MonadPlus, because mplus would need an Ord constraint that the class doesn't allow. But as long as we don't import Control.Monad's versions, we can reuse the names mzero and mplus:
    mzero :: SetM a
    mzero = prompt PZero

    mplus :: Ord a => SetM a -> SetM a -> SetM a
    mplus x y = prompt (PPlus x y)

"mrestrict" can be inserted at various points in a computation to optimize it: it forces the passed-in computation to complete and uses a Set to eliminate duplicate outputs before the rest of the computation continues. We could also implement mrestrict without an additional constructor in our prompt datatype, at the cost of some performance: mrestrict m = mplus mzero m
    mrestrict :: Ord a => SetM a -> SetM a
    mrestrict x = prompt (PRestrict x)
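The effect of a synchronization point can be seen without Prompt at all. The sketch below uses only plain lists and Data.Set; step, walkList and walkSet are illustrative names, not part of the code above. It models a walk where each step moves one unit left or right. Keeping every path, as the list monad does, gives 2^n results after n steps; collapsing to a Set after every step (roughly what an mrestrict between steps does in SetM) keeps at most n+1 distinct positions alive.

```haskell
import Control.Monad (foldM)
import qualified Data.Set as S

-- One nondeterministic step of a walk on the integers.
step :: Int -> [Int]
step x = [x - 1, x + 1]

-- Plain list monad: every path is kept, so n steps give 2^n results.
walkList :: Int -> [Int]
walkList n = foldM (\x _ -> step x) 0 [1 .. n]

-- Collapse to a Set after every step: duplicate positions are merged
-- before the next step runs, so at most n+1 positions survive.
walkSet :: Int -> S.Set Int
walkSet n = iterate advance (S.singleton 0) !! n
  where
    advance s = S.fromList (concatMap step (S.toList s))
```

Both compute the same set of final positions; only the amount of intermediate work differs.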
Finally, we need an interpreter that runs the monad and extracts a Set from it:
    runSetM :: Ord r => SetM r -> S.Set r
    runSetM = runPromptC ret prm . unRecPrompt
      where
        -- a finished branch contributes one result
        ret :: a -> S.Set a
        ret = S.singleton
        -- k is the rest of the computation; prm is polymorphic in the
        -- result type so its signature needn't mention runSetM's r
        prm :: Ord b => OrdP SetM a -> (a -> S.Set b) -> S.Set b
        prm PZero         _ = S.empty
        prm (PRestrict m) k = unionMap k (runSetM m)
        prm (PPlus m1 m2) k = unionMap k (runSetM m1 `S.union` runSetM m2)
unionMap is the Set equivalent of concatMap for lists.
    unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
    unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty
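To make the PPlus clause of prm concrete, here is the outer mplus of the settest example below unfolded by hand with plain Data.Set operations. The names k, branchesAsList, gathered and viaUnionMap are made up for this illustration. The point is that the continuation k runs once per distinct value in the gathered Set (twice here), whereas a list-monad reading of the same expression runs it once per branch (three times).

```haskell
import qualified Data.Set as S

-- Same definition as unionMap above.
unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty

-- The rest of the computation after the outer mplus: \x -> return (x+3).
k :: Int -> S.Set Int
k x = S.singleton (x + 3)

-- List-monad reading: mzero contributes nothing and every branch is kept.
branchesAsList :: [Int]
branchesAsList = ([] ++ [2]) ++ ([2] ++ [3])

-- SetM reading: each mplus gathers its branches into a Set first.
-- runSetM (mplus mzero (return 2))      gives the Set  S.empty `S.union` S.singleton 2
-- runSetM (mplus (return 2) (return 3)) gives the Set  S.singleton 2 `S.union` S.singleton 3
gathered :: S.Set Int
gathered = (S.empty `S.union` S.singleton 2)
     `S.union` (S.singleton 2 `S.union` S.singleton 3)

-- prm (PPlus m1 m2) k = unionMap k (runSetM m1 `S.union` runSetM m2)
viaUnionMap :: S.Set Int
viaUnionMap = unionMap k gathered
```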
Oleg's test now works without modification:
    test1s_do () = do
        x <- return "a"
        return $ "b" ++ x

    olegtest :: S.Set String
    olegtest = runSetM $ test1s_do ()
    -- fromList ["ba"]

    settest :: S.Set Int
    settest = runSetM $ do
        x <- mplus (mplus mzero (return 2)) (mplus (return 2) (return 3))
        return (x+3)
    -- fromList [5,6]
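For anyone who wants to run these examples without installing MonadPrompt, here is a self-contained sketch that needs only base and containers. It is an approximation, not the package's code: RecPrompt is replaced by a hand-rolled continuation-passing monad, so SetM becomes a newtype and OrdP refers to it directly instead of taking the monad as a parameter. runCPS and mrestrictViaPlus are hypothetical names introduced here; everything else mirrors the definitions above.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}
import qualified Data.Set as S

-- Hypothetical stand-in for RecPrompt OrdP: a computation in
-- continuation-passing style, parameterised by what to do with a final
-- result and how to answer each prompt.
newtype SetM a = SetM
  { runCPS :: forall b. (a -> b) -> (forall x. OrdP x -> (x -> b) -> b) -> b }

-- The prompt type, tied directly to SetM rather than to a parameter m.
data OrdP a where
  PZero     :: OrdP a
  PRestrict :: Ord a => SetM a -> OrdP a
  PPlus     :: Ord a => SetM a -> SetM a -> OrdP a

instance Functor SetM where
  fmap f m = SetM (\done h -> runCPS m (done . f) h)

instance Applicative SetM where
  pure a = SetM (\done _ -> done a)
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad SetM where
  m >>= k = SetM (\done h -> runCPS m (\a -> runCPS (k a) done h) h)

-- Suspend at a prompt; the handler receives the continuation.
prompt :: OrdP a -> SetM a
prompt p = SetM (\done h -> h p done)

mzero :: SetM a
mzero = prompt PZero

mplus :: Ord a => SetM a -> SetM a -> SetM a
mplus x y = prompt (PPlus x y)

mrestrict :: Ord a => SetM a -> SetM a
mrestrict x = prompt (PRestrict x)

-- The slower alternative that needs no PRestrict constructor.
mrestrictViaPlus :: Ord a => SetM a -> SetM a
mrestrictViaPlus m = mplus mzero m

unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty

runSetM :: Ord r => SetM r -> S.Set r
runSetM m = runCPS m S.singleton prm
  where
    prm :: Ord b => OrdP x -> (x -> S.Set b) -> S.Set b
    prm PZero         _ = S.empty
    prm (PRestrict n) k = unionMap k (runSetM n)
    prm (PPlus n1 n2) k = unionMap k (runSetM n1 `S.union` runSetM n2)

test1s_do :: Monad m => () -> m String
test1s_do () = do
  x <- return "a"
  return ("b" ++ x)

olegtest :: S.Set String
olegtest = runSetM (test1s_do ())

settest :: S.Set Int
settest = runSetM $ do
  x <- mplus (mplus mzero (return 2)) (mplus (return 2) (return 3))
  return (x + 3)
```

Evaluating olegtest and settest in GHCi should give fromList ["ba"] and fromList [5,6], matching the comments above.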
Under the hood, this treats the computation on each element of the set separately, except at programmer-specified synchronization points, where the result type is required to be an instance of Ord. A synchronization point occurs at every "mplus" and "mrestrict": the results computed up to that point are gathered into a Set, and the remainder of the computation is then dispatched once from each element of that Set.

-- ryan