diff --git a/Network/Wai.hs b/Network/Wai.hs index 9807c44..4466b2c 100644 --- a/Network/Wai.hs +++ b/Network/Wai.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE Rank2Types #-} module Network.Wai ( -- * Data types -- ** Request method @@ -23,6 +23,8 @@ module Network.Wai , Status (..) , statusCode , statusMessage + -- * Enumerator + , Enumerator -- * WAI interface , Request (..) , Response (..) @@ -233,6 +235,8 @@ statusMessage Status405 = B8.pack "Method Not Allowed" statusMessage Status500 = B8.pack "Internal Server Error" statusMessage (Status _ m) = m +type Enumerator a = (a -> B.ByteString -> IO (Either a a)) -> a -> IO a + data Request = Request { requestMethod :: Method , httpVersion :: HttpVersion @@ -242,7 +246,7 @@ data Request = Request , serverPort :: Int , httpHeaders :: [(RequestHeader, B.ByteString)] , urlScheme :: UrlScheme - , requestBody :: IO (Maybe B.ByteString) + , requestBody :: forall a. Enumerator a , errorHandler :: String -> IO () , remoteHost :: String } @@ -250,7 +254,7 @@ data Request = Request data Response = Response { status :: Status , headers :: [(ResponseHeader, B.ByteString)] - , body :: Either FilePath ((B.ByteString -> IO ()) -> IO ()) + , body :: forall a. Either FilePath (Enumerator a) } type Application = Request -> IO Response diff --git a/Network/Wai/Handler/SimpleServer.hs b/Network/Wai/Handler/SimpleServer.hs index 72f7280..6e4d3ba 100644 --- a/Network/Wai/Handler/SimpleServer.hs +++ b/Network/Wai/Handler/SimpleServer.hs @@ -21,7 +21,6 @@ import Network.Wai import qualified System.IO import qualified Data.ByteString as BS -import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Char8 as B8 import Network ( listenOn, accept, sClose, PortID(PortNumber), Socket @@ -135,14 +134,16 @@ parseRequest port lines' handle remoteHost' = do , remoteHost = remoteHost' } -requestBodyHandle :: Handle -> MVar Int -> IO (Maybe BS.ByteString) -requestBodyHandle h mlen = modifyMVar mlen helper where - helper :: Int -> IO (Int, Maybe BS.ByteString) - helper 0 = return (0, Nothing) - helper len = do +requestBodyHandle :: Handle -> MVar Int -> Enumerator a +requestBodyHandle h mlen iter accum = modifyMVar mlen (helper accum) where + helper a 0 = return (0, a) + helper a len = do bs <- BS.hGet h len let newLen = len - BS.length bs - return (newLen, Just bs) + ea' <- iter a bs + case ea' of + Left a' -> return (newLen, a') + Right a' -> helper a' newLen parseFirst :: (StringLike s, MonadFailure InvalidRequest m) => s @@ -167,8 +168,9 @@ sendResponse h res = do BS.hPut h $ SL.pack "\r\n" case body res of Left fp -> unsafeSendFile h fp - Right enum -> enum $ BS.hPut h + Right enum -> enum myPut h >> return () where + myPut _ bs = BS.hPut h bs >> return (Right h) putHeader (x, y) = do BS.hPut h $ responseHeaderToBS x BS.hPut h $ SL.pack ": " diff --git a/test.hs b/test.hs index 1e8b5c1..a92d1e5 100644 --- a/test.hs +++ b/test.hs @@ -1,7 +1,7 @@ +{-# LANGUAGE Rank2Types #-} import Network.Wai import Network.Wai.Handler.SimpleServer import qualified Data.ByteString.Char8 as B8 -import qualified Data.ByteString as B main :: IO () main = putStrLn "http://localhost:3000/" >> run 3000 app @@ -18,19 +18,12 @@ indexResponse = return Response , body = index } -postResponse :: IO (Maybe B.ByteString) -> IO Response +postResponse :: (forall a. Enumerator a) -> IO Response postResponse rb = return Response { status = Status200 , headers = [(ContentType, B8.pack "text/plain")] - , body = Right $ postBody rb + , body = Right rb } index :: Either FilePath a index = Left "index.html" - -postBody :: IO (Maybe B.ByteString) -> (B.ByteString -> IO ()) -> IO () -postBody req res = do - mbs <- req - case mbs of - Nothing -> return () - Just bs -> res bs >> postBody req res