Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions Control/Concurrent/Stream.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
{-
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
-}

-- | Processing streams with a fixed number of worker threads
module Control.Concurrent.Stream
( stream
, streamBound
, streamWithInput
, streamWithOutput
, streamWithInputOutput
, mapConcurrentlyBounded
, forConcurrentlyBounded
) where

import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Maybe
import Data.IORef

-- | Selects how 'stream_' starts its worker threads.
data ShouldBindThreads
  = BoundThreads
    -- ^ Workers are started with 'withAsyncBound'.
  | UnboundThreads
    -- ^ Workers are started with 'mapConcurrently_'.

-- | Runs a fixed number of workers concurrently over the values emitted by
-- a producer. The producer receives a callback and invokes it once per
-- work item. If a worker throws a synchronous exception, it will be
-- propagated to the caller.
stream
  :: Int -- ^ Maximum Concurrency
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> (a -> IO ()) -- ^ Worker
  -> IO ()
stream maxConcurrency producer worker =
  stream_ UnboundThreads producer workerSlots (\_ item -> worker item)
  where
    -- One placeholder per worker; the workers carry no per-worker input.
    workerSlots = replicate maxConcurrency ()

-- | Variant of 'stream' whose workers run on bound threads. See
-- 'Control.Concurrent.forkOS' for details on bound threads.
streamBound
  :: Int -- ^ Maximum Concurrency
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> (a -> IO ()) -- ^ Worker
  -> IO ()
streamBound maxConcurrency producer worker =
  stream_ BoundThreads producer workerSlots (\_ item -> worker item)
  where
    -- One placeholder per worker; the workers carry no per-worker input.
    workerSlots = replicate maxConcurrency ()

-- | Variant of 'stream' in which one worker is started per element of the
-- given list, and each worker receives its element as an extra argument.
streamWithInput
  :: ((a -> IO ()) -> IO ()) -- ^ Producer
  -> [b] -- ^ Per-worker input; one worker is started per element
  -> (b -> a -> IO ()) -- ^ Worker
  -> IO ()
streamWithInput producer workerInput worker =
  stream_ UnboundThreads producer workerInput worker

-- | Variant of 'stream' that also gathers the value returned by the worker
-- for each item and returns those values as a list.
streamWithOutput
  :: Int -- ^ Maximum Concurrency
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> (a -> IO c) -- ^ Worker
  -> IO [c]
streamWithOutput maxConcurrency producer worker =
  streamWithInputOutput producer workerSlots (\_ -> worker)
  where
    -- One placeholder per worker; the workers carry no per-worker input.
    workerSlots = replicate maxConcurrency ()

-- | Like 'streamWithInput', but collects the results of each worker.
--
-- Every item emitted by the producer is paired with a fresh result slot.
-- The worker's return value is stored in that slot, and once all work has
-- finished the slots are read back in the order in which they were
-- registered, i.e. the order in which the producer emitted the items.
--
-- The slot list is updated atomically, so the callback handed to the
-- producer may be invoked from several threads without losing results,
-- just like the callback handed to the producer by 'stream'.
streamWithInputOutput
  :: ((a -> IO ()) -> IO ()) -- ^ Producer
  -> [b] -- ^ Worker input
  -> (b -> a -> IO c) -- ^ Worker
  -> IO [c]
streamWithInputOutput producer workerInput worker = do
  -- Result slots, most recently registered first.
  results <- newIORef []
  let prod write = producer $ \a -> do
        res <- newIORef Nothing
        -- Strict and atomic: no thunk chain builds up inside the IORef and
        -- concurrent registrations cannot overwrite each other.
        atomicModifyIORef' results (\rs -> (res : rs, ()))
        write (a, res)
  stream_ UnboundThreads prod workerInput $ \s (a, ref) ->
    worker s a >>= writeIORef ref . Just
  slots <- readIORef results
  -- Slots were consed on, so reverse to restore registration order.
  fmap (catMaybes . reverse) (mapM readIORef slots)

-- | Shared implementation behind the @stream*@ functions.
--
-- One worker is started per element of the worker-input list, and the
-- producer runs concurrently with the workers. Items are handed over
-- through a bounded queue whose capacity equals the number of workers.
-- After the producer returns, one end-marker per worker is queued so that
-- every worker loop terminates.
--
-- With an empty worker-input list no workers are started. A producer that
-- emits nothing then completes normally; the first attempt to emit an item
-- throws an 'ErrorCall', because nothing could ever consume it (the queue
-- would have capacity zero and the write would block forever).
stream_
  :: ShouldBindThreads -- ^ Use bound threads?
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> [b] -- ^ Worker input
  -> (b -> a -> IO ()) -- ^ Worker
  -> IO ()
stream_ useBoundThreads producer workerInput worker = do
  let maxConcurrency = length workerInput
  q <- atomically $ newTBQueue (fromIntegral maxConcurrency)
  let write x
        -- Fail fast instead of blocking forever on a zero-capacity queue.
        | maxConcurrency == 0 = throwIO noWorkers
        | otherwise = atomically $ writeTBQueue q (Just x)
  mask $ \unmask ->
    concurrently_ (runWorkers unmask q) $ unmask $ do
      -- run the producer
      producer write
      -- write end-markers for all workers
      replicateM_ maxConcurrency $
        atomically $ writeTBQueue q Nothing
  where
    noWorkers =
      ErrorCall
        "Control.Concurrent.Stream: work item written but there are no workers"

    runWorkers unmask q = case useBoundThreads of
      BoundThreads ->
        runBound (map (runWorker unmask q) workerInput)
      UnboundThreads ->
        mapConcurrently_ (runWorker unmask q) workerInput

    -- Total replacement for @foldr1 concurrentlyBound@: nests the actions
    -- the same way for non-empty lists and is a no-op for the empty list.
    runBound [] = return ()
    runBound [w] = w
    runBound (w : ws) = concurrentlyBound w (runBound ws)

    -- Run two actions on bound threads and wait for both of them.
    concurrentlyBound l r =
      withAsyncBound l $ \a ->
        withAsyncBound r $ \b ->
          void $ waitBoth a b

    -- Worker loop: take items off the queue until an end-marker arrives.
    -- Only the user-supplied worker action runs with 'unmask' applied.
    runWorker unmask q s = do
      v <- atomically $ readTBQueue q
      case v of
        Nothing -> return ()
        Just t -> do
          unmask (worker s t)
          runWorker unmask q s

-- | Applies an 'IO' function to every element of a list concurrently, with
-- the number of worker threads limited to the given bound.
mapConcurrentlyBounded
  :: Int -- ^ Maximum concurrency
  -> (a -> IO b) -- ^ Function to map over the input values
  -> [a] -- ^ List of input values
  -> IO [b] -- ^ List of output values
mapConcurrentlyBounded maxConcurrency f input =
  streamWithOutput maxConcurrency feed f
  where
    -- Producer: hand each input value to the stream in list order.
    feed emit = mapM_ emit input

-- | Same as 'mapConcurrentlyBounded', except that the input list comes
-- before the function.
forConcurrentlyBounded
  :: Int -- ^ Maximum concurrency
  -> [a] -- ^ List of input values
  -> (a -> IO b) -- ^ Function to map over the input values
  -> IO [b] -- ^ List of output values
forConcurrentlyBounded maxConcurrency input f =
  mapConcurrentlyBounded maxConcurrency f input
4 changes: 3 additions & 1 deletion async.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: async
version: 2.2.5
version: 2.2.6
-- don't forget to update ./changelog.md!
synopsis: Run IO operations asynchronously and wait for their results

Expand Down Expand Up @@ -81,6 +81,7 @@ library
other-extensions: Trustworthy
exposed-modules: Control.Concurrent.Async
Control.Concurrent.Async.Internal
Control.Concurrent.Stream
build-depends: base >= 4.3 && < 4.22,
hashable >= 1.1.2.0 && < 1.6,
stm >= 2.2 && < 2.6
Expand All @@ -89,6 +90,7 @@ library

test-suite test-async
default-language: Haskell2010
ghc-options: -threaded
type: exitcode-stdio-1.0
hs-source-dirs: test
main-is: test-async.hs
Expand Down
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## Changes in 2.2.6

- Added Control.Concurrent.Stream for processing streams with a fixed
number of workers. Includes a bounded version of mapConcurrently:
mapConcurrentlyBounded.

## Changes in 2.2.5

- #117: Document that empty for Concurrently waits forever
Expand Down
60 changes: 60 additions & 0 deletions test/test-async.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Test.HUnit

import Control.Concurrent.STM
import Control.Concurrent.Async
import Control.Concurrent.Stream
import Control.Exception
import Data.IORef
import Data.Typeable
Expand Down Expand Up @@ -65,6 +66,15 @@ tests = [
, testCase "concurrentlyE_Monoid" concurrentlyE_Monoid
, testCase "concurrentlyE_Monoid_fail" concurrentlyE_Monoid_fail
#endif
, testCase "stream" $ case_stream False
, testCase "streamBound" $ case_stream True
, testCase "stream_exception" $ case_stream_exception False
, testCase "streamBound_exception" $ case_stream_exception True
, testCase "streamWithInput" case_streamInput
, testCase "streamWithInput_exception" case_streamInput_exception
, testCase "mapConcurrentlyBounded" case_mapConcurrentlyBounded
, testCase "mapConcurrentlyBounded_exception"
case_mapConcurrentlyBounded_exception
]
]

Expand Down Expand Up @@ -459,3 +469,53 @@ concurrentlyE_Monoid_fail = do
r :: Either Char [Char] <- runConcurrentlyE $ foldMap ConcurrentlyE $ current
assertEqual "The earliest failure" (Left 'u') r
#endif

-- | Every item written by the producer must be seen by a worker exactly
-- once. The flag selects 'streamBound' ('True') or 'stream' ('False').
case_stream :: Bool -> Assertion
case_stream bound = do
  seen <- newIORef []
  run 4 feed (record seen)
  got <- readIORef seen
  -- Workers run concurrently, so compare without regard to order.
  sort got @?= sort items
  where
    items = map show ([1 .. 100] :: [Integer])
    run
      | bound = streamBound
      | otherwise = stream
    feed emit = mapM_ emit items
    record ref item = atomicModifyIORef ref (\acc -> (item : acc, ()))

-- | An exception thrown by a worker must reach the caller. The flag
-- selects 'streamBound' ('True') or 'stream' ('False').
case_stream_exception :: Bool -> Assertion
case_stream_exception bound = do
  outcome <- try (run 4 feed failOnThree)
  outcome @?= Left (ErrorCall "3")
  where
    run
      | bound = streamBound
      | otherwise = stream
    feed emit = mapM_ (emit . show) ([1 .. 100] :: [Integer])
    -- Only the item "3" makes the worker fail.
    failOnThree item
      | item == "3" = throwIO (ErrorCall item)
      | otherwise = return ()

-- | With 'streamWithInput', every item must be processed and every worker
-- must be handed one of the supplied worker inputs.
case_streamInput :: Assertion
case_streamInput = do
  seen <- newIORef []
  streamWithInput feed workerIds $ \wid item ->
    atomicModifyIORef seen (\acc -> ((wid, item) : acc, ()))
  pairs <- readIORef seen
  let (usedIds, gotItems) = unzip pairs
  -- Workers run concurrently, so compare without regard to order.
  sort gotItems @?= sort items
  all (`elem` workerIds) usedIds @?= True
  where
    workerIds = [1 .. 4] :: [Int]
    items = map show ([1 .. 100] :: [Integer])
    feed emit = mapM_ emit items

-- | An exception thrown by a 'streamWithInput' worker must reach the
-- caller.
case_streamInput_exception :: Assertion
case_streamInput_exception = do
  outcome <- try (streamWithInput feed workerIds failOnThree)
  outcome @?= Left (ErrorCall "3")
  where
    workerIds = [1 .. 4] :: [Int]
    feed emit = mapM_ (emit . show) ([1 .. 100] :: [Integer])
    -- The worker input is ignored; only the item "3" makes the worker fail.
    failOnThree _ item
      | item == "3" = throwIO (ErrorCall item)
      | otherwise = return ()

-- | 'mapConcurrentlyBounded' must return the results in input order.
case_mapConcurrentlyBounded :: Assertion
case_mapConcurrentlyBounded = do
  actual <- mapConcurrentlyBounded 4 slowDouble inputs
  actual @?= [x * 2 | x <- inputs]
  where
    inputs = [1 .. 100] :: [Integer]
    -- Short delay before each result is returned.
    slowDouble x = do
      threadDelay 1000
      return (x * 2)

-- | An exception thrown by the mapped function must reach the caller of
-- 'mapConcurrentlyBounded'.
case_mapConcurrentlyBounded_exception :: Assertion
case_mapConcurrentlyBounded_exception = do
  outcome <- try (mapConcurrentlyBounded 4 doubleUnlessThree inputs)
  outcome @?= Left (ErrorCall "3")
  where
    inputs = [1 .. 100] :: [Integer]
    -- Fails on the input 3; doubles every other input after a short delay.
    doubleUnlessThree x =
      if x == 3
        then throwIO (ErrorCall "3")
        else threadDelay 1000 >> return (x * 2)