diff --git a/Control/Concurrent/Stream.hs b/Control/Concurrent/Stream.hs new file mode 100644 index 0000000..3c8d666 --- /dev/null +++ b/Control/Concurrent/Stream.hs @@ -0,0 +1,138 @@ +{- + Copyright (c) Meta Platforms, Inc. and affiliates. + All rights reserved. + + This source code is licensed under the BSD-style license found in the + LICENSE file in the root directory of this source tree. +-} + +-- | Processing streams with a fixed number of worker threads +module Control.Concurrent.Stream + ( stream + , streamBound + , streamWithInput + , streamWithOutput + , streamWithInputOutput + , mapConcurrentlyBounded + , forConcurrentlyBounded + ) where + +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Exception +import Control.Monad +import Data.Maybe +import Data.IORef + +data ShouldBindThreads = BoundThreads | UnboundThreads + +-- | Maps a fixed number of workers concurrently over a stream of values +-- produced by a producer function. The producer is passed a function to +-- call for each work item. If a worker throws a synchronous exception, it +-- will be propagated to the caller. +stream + :: Int -- ^ Maximum Concurrency + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> (a -> IO ()) -- ^ Worker + -> IO () +stream maxConcurrency producer worker = + streamWithInput producer (replicate maxConcurrency ()) $ const worker + +-- | Like stream, but uses bound threads for the workers. See +-- 'Control.Concurrent.forkOS' for details on bound threads. +streamBound + :: Int -- ^ Maximum Concurrency + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> (a -> IO ()) -- ^ Worker + -> IO () +streamBound maxConcurrency producer worker = + stream_ BoundThreads producer (replicate maxConcurrency ()) $ const worker + +-- | Like stream, but each worker is passed an element of an input list. +streamWithInput + :: ((a -> IO ()) -> IO ()) -- ^ Producer + -> [b] -- ^ Worker state + -> (b -> a -> IO ()) -- ^ Worker + -> IO () +streamWithInput = stream_ UnboundThreads + +-- | Like 'stream', but collects the results of each worker +streamWithOutput + :: Int + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> (a -> IO c) -- ^ Worker + -> IO [c] +streamWithOutput maxConcurrency producer worker = + streamWithInputOutput producer (replicate maxConcurrency ()) $ + const worker + +-- | Like 'streamWithInput', but collects the results of each worker +streamWithInputOutput + :: ((a -> IO ()) -> IO ()) -- ^ Producer + -> [b] -- ^ Worker input + -> (b -> a -> IO c) -- ^ Worker + -> IO [c] +streamWithInputOutput producer workerInput worker = do + results <- newIORef [] + let prod write = producer $ \a -> do + res <- newIORef Nothing + modifyIORef results (res :) + write (a, res) + stream_ UnboundThreads prod workerInput $ \s (a,ref) -> do + worker s a >>= writeIORef ref . Just + readIORef results >>= mapM readIORef >>= return . catMaybes . reverse + +stream_ + :: ShouldBindThreads -- use bound threads? + -> ((a -> IO ()) -> IO ()) -- ^ Producer + -> [b] -- Worker input + -> (b -> a -> IO ()) -- ^ Worker + -> IO () +stream_ useBoundThreads producer workerInput worker = do + let maxConcurrency = length workerInput + q <- atomically $ newTBQueue (fromIntegral maxConcurrency) + let write x = atomically $ writeTBQueue q (Just x) + mask $ \unmask -> + concurrently_ (runWorkers unmask q) $ unmask $ do + -- run the producer + producer write + -- write end-markers for all workers + replicateM_ maxConcurrency $ + atomically $ writeTBQueue q Nothing + where + runWorkers unmask q = case useBoundThreads of + BoundThreads -> + foldr1 concurrentlyBound $ + map (runWorker unmask q) workerInput + UnboundThreads -> + mapConcurrently_ (runWorker unmask q) workerInput + + concurrentlyBound l r = + withAsyncBound l $ \a -> + withAsyncBound r $ \b -> + void $ waitBoth a b + + runWorker unmask q s = do + v <- atomically $ readTBQueue q + case v of + Nothing -> return () + Just t -> do + unmask (worker s t) + runWorker unmask q s + +-- | Concurrent map over a list of values, using a bounded number of threads. +mapConcurrentlyBounded + :: Int -- ^ Maximum concurrency + -> (a -> IO b) -- ^ Function to map over the input values + -> [a] -- ^ List of input values + -> IO [b] -- ^ List of output values +mapConcurrentlyBounded maxConcurrency f input = + streamWithOutput maxConcurrency (forM_ input) f + +-- | 'mapConcurrentlyBounded' but with its arguments reversed +forConcurrentlyBounded + :: Int -- ^ Maximum concurrency + -> [a] -- ^ List of input values + -> (a -> IO b) -- ^ Function to map over the input values + -> IO [b] -- ^ List of output values +forConcurrentlyBounded = flip . mapConcurrentlyBounded diff --git a/async.cabal b/async.cabal index a0f5838..da76dd3 100644 --- a/async.cabal +++ b/async.cabal @@ -1,5 +1,5 @@ name: async -version: 2.2.5 +version: 2.2.6 -- don't forget to update ./changelog.md! synopsis: Run IO operations asynchronously and wait for their results @@ -81,6 +81,7 @@ library other-extensions: Trustworthy exposed-modules: Control.Concurrent.Async Control.Concurrent.Async.Internal + Control.Concurrent.Stream build-depends: base >= 4.3 && < 4.22, hashable >= 1.1.2.0 && < 1.6, stm >= 2.2 && < 2.6 @@ -89,6 +90,7 @@ library test-suite test-async default-language: Haskell2010 + ghc-options: -threaded type: exitcode-stdio-1.0 hs-source-dirs: test main-is: test-async.hs diff --git a/changelog.md b/changelog.md index 0544b82..50baea8 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,9 @@ +## Changes in 2.2.6 + + - Added Control.Concurrent.Stream for processing streams with a fixed + number of workers. Includes a bounded version of mapConcurrently: + mapConcurrentlyBounded. + ## Changes in 2.2.5 - #117: Document that empty for Concurrently waits forever diff --git a/test/test-async.hs b/test/test-async.hs index f781b15..58801fa 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -8,6 +8,7 @@ import Test.HUnit import Control.Concurrent.STM import Control.Concurrent.Async +import Control.Concurrent.Stream import Control.Exception import Data.IORef import Data.Typeable @@ -65,6 +66,15 @@ tests = [ , testCase "concurrentlyE_Monoid" concurrentlyE_Monoid , testCase "concurrentlyE_Monoid_fail" concurrentlyE_Monoid_fail #endif + , testCase "stream" $ case_stream False + , testCase "streamBound" $ case_stream True + , testCase "stream_exception" $ case_stream_exception False + , testCase "streamBound_exception" $ case_stream_exception True + , testCase "streamWithInput" case_streamInput + , testCase "streamWithInput_exception" case_streamInput_exception + , testCase "mapConcurrentlyBounded" case_mapConcurrentlyBounded + , testCase "mapConcurrentlyBounded_exception" + case_mapConcurrentlyBounded_exception ] ] @@ -459,3 +469,53 @@ concurrentlyE_Monoid_fail = do r :: Either Char [Char] <- runConcurrentlyE $ foldMap ConcurrentlyE $ current assertEqual "The earliest failure" (Left 'u') r #endif + +case_stream :: Bool -> Assertion +case_stream bound = do + ref <- newIORef [] + let inp = [1..100] + let producer write = forM_ inp $ \x -> write (show x) + (if bound then streamBound else stream) 4 producer $ \s -> atomicModifyIORef ref (\l -> (s:l, ())) + res <- readIORef ref + sort res @?= sort (map show inp) + +case_stream_exception :: Bool -> Assertion +case_stream_exception bound = do + let inp = [1..100] + let producer write = forM_ inp $ \x -> write (show x) + r <- try $ (if bound then streamBound else stream) 4 producer $ \s -> + when (s == "3") $ throwIO (ErrorCall s) + r @?= Left (ErrorCall "3" :: ErrorCall) + +case_streamInput :: Assertion +case_streamInput = do + ref <- newIORef [] + let inp = [1..100]; workers = [1..4] :: [Int] + let producer write = forM_ inp $ \x -> write (show x) + streamWithInput producer workers $ \s t -> atomicModifyIORef ref (\l -> ((s,t):l, ())) + res <- readIORef ref + sort (map snd res) @?= sort (map show inp) + all ((`elem` workers) . fst) res @?= True + +case_streamInput_exception :: Assertion +case_streamInput_exception = do + let inp = [1..100]; workers = [1..4] :: [Int] + let producer write = forM_ inp $ \x -> write (show x) + r <- try $ streamWithInput producer workers $ \s t -> + when (t == "3") $ throwIO (ErrorCall t) + r @?= Left (ErrorCall "3" :: ErrorCall) + +case_mapConcurrentlyBounded :: Assertion +case_mapConcurrentlyBounded = do + let inp = [1..100] + let f x = threadDelay 1000 >> return (x * 2) + res <- mapConcurrentlyBounded 4 f inp + res @?= map (*2) inp + +case_mapConcurrentlyBounded_exception :: Assertion +case_mapConcurrentlyBounded_exception = do + let inp = [1..100] + let f x | x == 3 = throwIO $ ErrorCall "3" + | otherwise = threadDelay 1000 >> return (x * 2) + res <- try $ mapConcurrentlyBounded 4 f inp + res @?= Left (ErrorCall "3" :: ErrorCall)