diff --git a/Control/Concurrent/Async/Internal.hs b/Control/Concurrent/Async/Internal.hs index cfe47a2..11e3432 100644 --- a/Control/Concurrent/Async/Internal.hs +++ b/Control/Concurrent/Async/Internal.hs @@ -142,7 +142,7 @@ asyncOnWithUnmask cpu actionWith = asyncUsing :: CALLSTACK (IO () -> IO ThreadId) -> IO a -> IO (Async a) -asyncUsing doFork = \action -> do +asyncUsing doFork action = do var <- newEmptyTMVarIO let action_plus = debugLabelMe >> action -- t <- forkFinally action (\r -> atomically $ putTMVar var r) @@ -207,7 +207,7 @@ withAsyncUsing :: (IO () -> IO ThreadId) -> IO a -> (Async a -> IO b) -> IO b -- The bracket version works, but is slow. We can do better by -- hand-coding it: -withAsyncUsing doFork = \action inner -> do +withAsyncUsing doFork action inner = do var <- newEmptyTMVarIO mask $ \restore -> do let action_plus = debugLabelMe >> action @@ -734,7 +734,7 @@ concurrently' left right collect = do -- ensure the children are really dead replicateM_ count' (tryAgain $ takeMVar done) - r <- collect (tryAgain $ takeDone) `onException` stop + r <- collect (tryAgain takeDone) `onException` stop stop return r @@ -801,7 +801,7 @@ forConcurrently_ = flip mapConcurrently_ replicateConcurrently :: CALLSTACK Int -> IO a -> IO [a] -replicateConcurrently cnt = runConcurrently . sequenceA . replicate cnt . Concurrently +replicateConcurrently cnt = runConcurrently . replicateM cnt . Concurrently -- | Same as 'replicateConcurrently', but ignore the results. -- @@ -927,7 +927,7 @@ rawForkIO :: CALLSTACK IO () -> IO ThreadId rawForkIO action = IO $ \ s -> - case (fork# action_plus s) of (# s1, tid #) -> (# s1, ThreadId tid #) + case fork# action_plus s of (# s1, tid #) -> (# s1, ThreadId tid #) where (IO action_plus) = debugLabelMe >> action @@ -936,7 +936,7 @@ rawForkOn :: CALLSTACK Int -> IO () -> IO ThreadId rawForkOn (I# cpu) action = IO $ \ s -> - case (forkOn# cpu action_plus s) of (# s1, tid #) -> (# s1, ThreadId tid #) + case forkOn# cpu action_plus s of (# s1, tid #) -> (# s1, ThreadId tid #) where (IO action_plus) = debugLabelMe >> action diff --git a/Control/Concurrent/Async/Warden.hs b/Control/Concurrent/Async/Warden.hs new file mode 100644 index 0000000..ec259be --- /dev/null +++ b/Control/Concurrent/Async/Warden.hs @@ -0,0 +1,89 @@ +{- + Copyright (c) Meta Platforms, Inc. and affiliates. + All rights reserved. + + This source code is licensed under the BSD-style license found in the + LICENSE file in the root directory of this source tree. +-} + +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | A more flexible way to create 'Async's and have them automatically +-- cancelled when the 'Warden' is shut down. +module Control.Concurrent.Async.Warden + ( Warden + , withWarden + , create + , shutdown + , spawn + , spawn_ + , spawnMask + , WardenException(..) + ) where + +import Control.Concurrent (forkIO) +import Control.Concurrent.Async (Async) +import qualified Control.Concurrent.Async as Async +import Control.Concurrent.MVar +import Control.Exception +import Control.Monad +import Data.HashSet (HashSet) +import qualified Data.HashSet as HashSet +import System.IO (fixIO) + + +-- | A 'Warden' is an owner of 'Async's which cancels them on 'shutdown'. +-- +-- 'Nothing' in the MVar means the 'Warden' has been shut down. +newtype Warden = Warden (MVar (Maybe (HashSet (Async ())))) + +-- | Run the action with a new 'Warden', and call 'shutdown' when the action +-- exits. +withWarden :: (Warden -> IO a) -> IO a +withWarden = bracket create shutdown + +-- | Create a new 'Warden'. +create :: IO Warden +create = Warden <$> newMVar (Just mempty) + +-- | Shutdown a 'Warden', calling 'cancel' on all owned threads. Subsequent +-- calls to 'spawn' and 'shutdown' will be no-ops. +-- +-- Note that any exceptions thrown by the threads will be ignored. If you want +-- exceptions to be propagated, either call `wait` explicitly on the 'Async', +-- or use 'link'. +shutdown :: Warden -> IO () +shutdown (Warden v) = do + r <- swapMVar v Nothing + mapM_ (Async.mapConcurrently_ Async.cancel) r + +forget :: Warden -> Async a -> IO () +forget (Warden v) async = modifyMVar_ v $ \x -> case x of + Just xs -> return $! Just $! HashSet.delete (void async) xs + Nothing -> return Nothing + +-- | Spawn a thread with masked exceptions and pass an unmask function to the +-- action. +spawnMask :: Warden -> ((forall b. IO b -> IO b) -> IO a) -> IO (Async a) +spawnMask (Warden v) action = modifyMVar v $ \r -> case r of + Just asyncs -> do + -- Create a new thread which removes itself from the 'HashSet' when it + -- exits. + this <- fixIO $ \this -> mask_ $ Async.asyncWithUnmask $ \unmask -> + action unmask `finally` forget (Warden v) this + return (Just $ HashSet.insert (void this) asyncs, this) + Nothing -> throwIO $ WardenException "Warden has been shut down" + +newtype WardenException = WardenException String + deriving (Show) + +instance Exception WardenException + +-- | Spawn a new thread owned by the 'Warden'. +spawn :: Warden -> IO a -> IO (Async a) +spawn warden action = spawnMask warden $ \unmask -> unmask action + +-- | Spawn a new thread owned by the 'Warden'. +spawn_ :: Warden -> IO () -> IO () +spawn_ w = void . spawn w diff --git a/Control/Concurrent/Stream.hs b/Control/Concurrent/Stream.hs new file mode 100644 index 0000000..3c8d666 --- /dev/null +++ b/Control/Concurrent/Stream.hs @@ -0,0 +1,138 @@ +{- + Copyright (c) Meta Platforms, Inc. and affiliates. + All rights reserved. + + This source code is licensed under the BSD-style license found in the + LICENSE file in the root directory of this source tree. +-} + +-- | Processing streams with a fixed number of worker threads +module Control.Concurrent.Stream + ( stream + , streamBound + , streamWithInput + , streamWithOutput + , streamWithInputOutput + , mapConcurrentlyBounded + , forConcurrentlyBounded + ) where + +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Exception +import Control.Monad +import Data.Maybe +import Data.IORef + +data ShouldBindThreads = BoundThreads | UnboundThreads + +-- | Maps a fixed number of workers concurrently over a stream of values +-- produced by a producer function. The producer is passed a function to +-- call for each work item. If a worker throws a synchronous exception, it +-- will be propagated to the caller. +stream + :: Int -- ^ Maximum Concurrency + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> (a -> IO ()) -- ^ Worker + -> IO () +stream maxConcurrency producer worker = + streamWithInput producer (replicate maxConcurrency ()) $ const worker + +-- | Like stream, but uses bound threads for the workers. See +-- 'Control.Concurrent.forkOS' for details on bound threads. +streamBound + :: Int -- ^ Maximum Concurrency + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> (a -> IO ()) -- ^ Worker + -> IO () +streamBound maxConcurrency producer worker = + stream_ BoundThreads producer (replicate maxConcurrency ()) $ const worker + +-- | Like stream, but each worker is passed an element of an input list. +streamWithInput + :: ((a -> IO ()) -> IO ()) -- ^ Producer + -> [b] -- ^ Worker state + -> (b -> a -> IO ()) -- ^ Worker + -> IO () +streamWithInput = stream_ UnboundThreads + +-- | Like 'stream', but collects the results of each worker +streamWithOutput + :: Int + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> (a -> IO c) -- ^ Worker + -> IO [c] +streamWithOutput maxConcurrency producer worker = + streamWithInputOutput producer (replicate maxConcurrency ()) $ + const worker + +-- | Like 'streamWithInput', but collects the results of each worker +streamWithInputOutput + :: ((a -> IO ()) -> IO ()) -- ^ Producer + -> [b] -- ^ Worker input + -> (b -> a -> IO c) -- ^ Worker + -> IO [c] +streamWithInputOutput producer workerInput worker = do + results <- newIORef [] + let prod write = producer $ \a -> do + res <- newIORef Nothing + modifyIORef results (res :) + write (a, res) + stream_ UnboundThreads prod workerInput $ \s (a,ref) -> do + worker s a >>= writeIORef ref . Just + readIORef results >>= mapM readIORef >>= return . catMaybes . reverse + +stream_ + :: ShouldBindThreads -- use bound threads? + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> [b] -- Worker input + -> (b -> a -> IO ()) -- ^ Worker + -> IO () +stream_ useBoundThreads producer workerInput worker = do + let maxConcurrency = length workerInput + q <- atomically $ newTBQueue (fromIntegral maxConcurrency) + let write x = atomically $ writeTBQueue q (Just x) + mask $ \unmask -> + concurrently_ (runWorkers unmask q) $ unmask $ do + -- run the producer + producer write + -- write end-markers for all workers + replicateM_ maxConcurrency $ + atomically $ writeTBQueue q Nothing + where + runWorkers unmask q = case useBoundThreads of + BoundThreads -> + foldr1 concurrentlyBound $ + map (runWorker unmask q) workerInput + UnboundThreads -> + mapConcurrently_ (runWorker unmask q) workerInput + + concurrentlyBound l r = + withAsyncBound l $ \a -> + withAsyncBound r $ \b -> + void $ waitBoth a b + + runWorker unmask q s = do + v <- atomically $ readTBQueue q + case v of + Nothing -> return () + Just t -> do + unmask (worker s t) + runWorker unmask q s + +-- | Concurrent map over a list of values, using a bounded number of threads. +mapConcurrentlyBounded + :: Int -- ^ Maximum concurrency + -> (a -> IO b) -- ^ Function to map over the input values + -> [a] -- ^ List of input values + -> IO [b] -- ^ List of output values +mapConcurrentlyBounded maxConcurrency f input = + streamWithOutput maxConcurrency (forM_ input) f + +-- | 'mapConcurrentlyBounded' but with its arguments reversed +forConcurrentlyBounded + :: Int -- ^ Maximum concurrency + -> [a] -- ^ List of input values + -> (a -> IO b) -- ^ Function to map over the input values + -> IO [b] -- ^ List of output values +forConcurrentlyBounded = flip . mapConcurrentlyBounded diff --git a/async.cabal b/async.cabal index a0f5838..a382f1e 100644 --- a/async.cabal +++ b/async.cabal @@ -1,5 +1,5 @@ name: async -version: 2.2.5 +version: 2.2.6 -- don't forget to update ./changelog.md! synopsis: Run IO operations asynchronously and wait for their results @@ -81,14 +81,18 @@ library other-extensions: Trustworthy exposed-modules: Control.Concurrent.Async Control.Concurrent.Async.Internal + Control.Concurrent.Async.Warden + Control.Concurrent.Stream build-depends: base >= 4.3 && < 4.22, hashable >= 1.1.2.0 && < 1.6, - stm >= 2.2 && < 2.6 + stm >= 2.2 && < 2.6, + unordered-containers >= 0.2 && < 0.3 if flag(debug-auto-label) cpp-options: -DDEBUG_AUTO_LABEL test-suite test-async default-language: Haskell2010 + ghc-options: -threaded type: exitcode-stdio-1.0 hs-source-dirs: test main-is: test-async.hs diff --git a/changelog.md b/changelog.md index 0544b82..c674e01 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,12 @@ +## Changes in 2.2.6 + + - Added Control.Concurrent.Stream for processing streams with a fixed + number of workers. Includes a bounded version of mapConcurrently: + mapConcurrentlyBounded. + - Added Control.Concurrent.Async.Warden for a way to create Asyncs that + is more flexible than 'withAsync' but retains the guarantee of cancelling + orphaned threads, unlike 'async'. + ## Changes in 2.2.5 - #117: Document that empty for Concurrently waits forever diff --git a/test/test-async.hs b/test/test-async.hs index f781b15..c0d15e0 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -8,6 +8,8 @@ import Test.HUnit import Control.Concurrent.STM import Control.Concurrent.Async +import Control.Concurrent.Async.Warden +import Control.Concurrent.Stream import Control.Exception import Data.IORef import Data.Typeable @@ -65,6 +67,17 @@ tests = [ , testCase "concurrentlyE_Monoid" concurrentlyE_Monoid , testCase "concurrentlyE_Monoid_fail" concurrentlyE_Monoid_fail #endif + , testCase "stream" $ case_stream False + , testCase "streamBound" $ case_stream True + , testCase "stream_exception" $ case_stream_exception False + , testCase "streamBound_exception" $ case_stream_exception True + , testCase "streamWithInput" case_streamInput + , testCase "streamWithInput_exception" case_streamInput_exception + , testCase "mapConcurrentlyBounded" case_mapConcurrentlyBounded + , testCase "mapConcurrentlyBounded_exception" + case_mapConcurrentlyBounded_exception + , testCase "Warden" case_Warden + , testCase "Warden_spawn_after_shutdown" case_Warden_spawn_after_shutdown ] ] @@ -459,3 +472,79 @@ concurrentlyE_Monoid_fail = do r :: Either Char [Char] <- runConcurrentlyE $ foldMap ConcurrentlyE $ current assertEqual "The earliest failure" (Left 'u') r #endif + +case_stream :: Bool -> Assertion +case_stream bound = do + ref <- newIORef [] + let inp = [1..100] + let producer write = forM_ inp $ \x -> write (show x) + (if bound then streamBound else stream) 4 producer $ \s -> atomicModifyIORef ref (\l -> (s:l, ())) + res <- readIORef ref + sort res @?= sort (map show inp) + +case_stream_exception :: Bool -> Assertion +case_stream_exception bound = do + let inp = [1..100] + let producer write = forM_ inp $ \x -> write (show x) + r <- try $ (if bound then streamBound else stream) 4 producer $ \s -> + when (s == "3") $ throwIO (ErrorCall s) + r @?= Left (ErrorCall "3" :: ErrorCall) + +case_streamInput :: Assertion +case_streamInput = do + ref <- newIORef [] + let inp = [1..100]; workers = [1..4] :: [Int] + let producer write = forM_ inp $ \x -> write (show x) + streamWithInput producer workers $ \s t -> atomicModifyIORef ref (\l -> ((s,t):l, ())) + res <- readIORef ref + sort (map snd res) @?= sort (map show inp) + all ((`elem` workers) . fst) res @?= True + +case_streamInput_exception :: Assertion +case_streamInput_exception = do + let inp = [1..100]; workers = [1..4] :: [Int] + let producer write = forM_ inp $ \x -> write (show x) + r <- try $ streamWithInput producer workers $ \s t -> + when (t == "3") $ throwIO (ErrorCall t) + r @?= Left (ErrorCall "3" :: ErrorCall) + +case_mapConcurrentlyBounded :: Assertion +case_mapConcurrentlyBounded = do + let inp = [1..100] + let f x = threadDelay 1000 >> return (x * 2) + res <- mapConcurrentlyBounded 4 f inp + res @?= map (*2) inp + +case_mapConcurrentlyBounded_exception :: Assertion +case_mapConcurrentlyBounded_exception = do + let inp = [1..100] + let f x | x == 3 = throwIO $ ErrorCall "3" + | otherwise = threadDelay 1000 >> return (x * 2) + res <- try $ mapConcurrentlyBounded 4 f inp + res @?= Left (ErrorCall "3" :: ErrorCall) + +case_Warden :: Assertion +case_Warden = do + a3 <- withWarden $ \warden -> do + a1 <- spawn warden $ return 1 + a2 <- spawnMask warden $ \unmask -> unmask (return 2) + a3 <- spawn warden $ threadDelay 10000000 + spawn_ warden $ throwIO (ErrorCall "a4") -- ignored + r1 <- wait a1 + r1 @?= 1 + r2 <- wait a2 + r2 @?= 2 + return a3 + r3 <- waitCatch a3 + case r3 of + Right _ -> assertFailure "Expected AsyncCancelled" + Left e -> fromException e @?= Just AsyncCancelled + +case_Warden_spawn_after_shutdown :: Assertion +case_Warden_spawn_after_shutdown = do + warden <- create + shutdown warden + r <- try $ spawn warden $ return () + case r of + Left (WardenException{}) -> return () -- expected + Right _ -> assertFailure "Expected WardenException" \ No newline at end of file