Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 74 additions & 40 deletions src/Simplex/Messaging/Notifications/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import Data.Time.Format.ISO8601 (iso8601Show)
import GHC.IORef (atomicSwapIORef)
import GHC.Stats (getRTSStats)
import Network.Socket (ServiceName, Socket, socketToHandle)
import Numeric.Natural (Natural)
import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError, ServerTransmission (..))
import Simplex.Messaging.Client.Agent
import qualified Simplex.Messaging.Crypto as C
Expand Down Expand Up @@ -85,7 +86,7 @@ import System.Exit (exitFailure, exitSuccess)
import System.IO (BufferMode (..), hClose, hPrint, hPutStrLn, hSetBuffering, hSetNewlineMode, universalNewlineMode)
import System.Mem.Weak (deRefWeak)
import System.Timeout (timeout)
import UnliftIO (IOMode (..), UnliftIO, askUnliftIO, race_, unliftIO, withFile)
import UnliftIO (IOMode (..), UnliftIO (..), askUnliftIO, race_, unliftIO, withFile)
import UnliftIO.Concurrent (forkIO, killThread, mkWeakThreadId)
import UnliftIO.Directory (doesFileExist, renameFile)
import UnliftIO.Exception
Expand Down Expand Up @@ -116,7 +117,6 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
void $ forkIO $ resubscribe s
raceAny_
( ntfSubscriber s
: ntfPush ps
: periodicNtfsThread ps
: map runServer transports
<> serverStatsThread_ cfg
Expand Down Expand Up @@ -147,12 +147,17 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
saveServer
NtfSubscriber {smpSubscribers, smpAgent} <- asks subscriber
liftIO $ readTVarIO smpSubscribers >>= mapM_ stopSubscriber
NtfPushServer {pushWorkers} <- asks pushServer
liftIO $ readTVarIO pushWorkers >>= mapM_ stopPushWorker
liftIO $ closeSMPClientAgent smpAgent
logNote "Server stopped"
where
stopSubscriber v =
atomically (tryReadTMVar $ sessionVar v)
>>= mapM (deRefWeak . subThreadId >=> mapM_ killThread)
stopPushWorker v =
atomically (tryReadTMVar $ sessionVar v)
>>= mapM (deRefWeak . workerThreadId >=> mapM_ killThread)

saveServer :: M ()
saveServer = asks store >>= liftIO . closeNtfDbStore >> saveServerStats
Expand Down Expand Up @@ -257,7 +262,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
let threadsCount = 0
#endif
let NtfSubscriber {smpSubscribers, smpAgent = a} = subscriber
NtfPushServer {pushQ} = pushServer
NtfPushServer {pushWorkers} = pushServer
SMPClientAgent {smpClients, smpSessions, smpSubWorkers} = a
srvSubscribers <- getSMPWorkerMetrics a smpSubscribers
srvClients <- getSMPWorkerMetrics a smpClients
Expand All @@ -267,7 +272,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs snd
ntfPendingQueueSubs <- getSMPSubMetrics a pendingQueueSubs
smpSessionCount <- M.size <$> readTVarIO smpSessions
apnsPushQLength <- atomically $ lengthTBQueue pushQ
apnsPushQLength <- pushWorkersQLength pushWorkers
pure
NtfRealTimeMetrics
{ threadsCount,
Expand Down Expand Up @@ -526,35 +531,36 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} =
where
receiveSMP = do
st <- asks store
NtfPushServer {pushQ} <- asks pushServer
ps <- asks pushServer
stats <- asks serverStats
liftIO $ forever $ do
forever $ do
((_, srv@(SMPServer (h :| _) _ _), _), _thVersion, sessionId, ts) <- atomically $ readTBQueue msgQ
forM ts $ \(ntfId, t) -> case t of
forM_ ts $ \(ntfId, t) -> case t of
STUnexpectedError e -> logError $ "SMP client unexpected error: " <> tshow e -- uncorrelated response, should not happen
STResponse {} -> pure () -- it was already reported as timeout error
STEvent msgOrErr -> do
let smpQueue = SMPQueueNtf srv ntfId
case msgOrErr of
Right (SMP.NMSG nmsgNonce encNMsgMeta) -> do
ntfTs <- getSystemTime
updatePeriodStats (activeSubs stats) ntfId
ntfTs <- liftIO getSystemTime
liftIO $ updatePeriodStats (activeSubs stats) ntfId
let newNtf = PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta}
srvHost_ = if isOwnServer ca srv then Just (safeDecodeUtf8 $ strEncode h) else Nothing
addTokenLastNtf st newNtf >>= \case
srvHost = safeDecodeUtf8 $ strEncode h
isOwn = isOwnServer ca srv
liftIO (addTokenLastNtf st newNtf) >>= \case
Right (tkn, lastNtfs) -> do
atomically $ writeTBQueue pushQ (srvHost_, tkn, PNMessage lastNtfs)
incNtfStat_ stats ntfReceived
mapM_ (`incServerStat` ntfReceivedOwn stats) srvHost_
Left AUTH -> do
pushNotification ps (Just srvHost) isOwn tkn $ PNMessage lastNtfs
liftIO $ incNtfStat_ stats ntfReceived
when isOwn $ liftIO $ incServerStat srvHost (ntfReceivedOwn stats)
Left AUTH -> liftIO $ do
incNtfStat_ stats ntfReceivedAuth
mapM_ (`incServerStat` ntfReceivedAuthOwn stats) srvHost_
when isOwn $ incServerStat srvHost (ntfReceivedAuthOwn stats)
Left _ -> pure ()
Right SMP.END ->
whenM (atomically $ activeClientSession' ca sessionId srv) $
void $ updateSrvSubStatus st smpQueue NSEnd
void $ liftIO $ updateSrvSubStatus st smpQueue NSEnd
Right SMP.DELD ->
void $ updateSrvSubStatus st smpQueue NSDeleted
void $ liftIO $ updateSrvSubStatus st smpQueue NSDeleted
Right (SMP.ERR e) -> logError $ "SMP server error: " <> tshow e
Right _ -> logError "SMP server unexpected response"
Left e -> logError $ "SMP client error: " <> tshow e
Expand Down Expand Up @@ -632,9 +638,25 @@ logSubStatus srv event n updated =
showServer' :: SMPServer -> Text
showServer' = decodeLatin1 . strEncode . host

ntfPush :: NtfPushServer -> M ()
ntfPush s@NtfPushServer {pushQ} = forever $ do
(srvHost_, tkn@NtfTknRec {ntfTknId, token = t@(DeviceToken pp _), tknStatus}, ntf) <- atomically (readTBQueue pushQ)
-- | Queue a push notification for delivery.
-- The worker queue is selected by the pair of SMP server host and the push provider
-- taken from the token record; the worker is obtained via 'getOrCreatePushWorker'.
-- The write goes to a bounded queue, so it waits while that queue is full.
pushNotification :: NtfPushServer -> Maybe T.Text -> OwnServer -> NtfTknRec -> PushNotification -> M ()
pushNotification s srvHost_ isOwn tkn@NtfTknRec {token = DeviceToken pp _} ntf =
  getOrCreatePushWorker s (srvHost_, pp) isOwn >>= \workerQueue ->
    atomically $ writeTBQueue workerQueue (tkn, ntf)

-- | Return the notification queue of the push worker for the given key
-- (SMP server host, push provider), starting a new worker when the key has no session variable yet.
-- In the Left branch (presumably "variable newly created by this call" - confirm against getSessVar)
-- this caller creates the queue, forks 'runPushWorker' and fills the variable;
-- in the Right branch it waits for the variable to be filled and reuses the existing queue.
getOrCreatePushWorker :: NtfPushServer -> (Maybe T.Text, PushProvider) -> OwnServer -> M (TBQueue (NtfTknRec, PushNotification))
getOrCreatePushWorker s@NtfPushServer {pushWorkers, pushWorkerSeq, pushQSize} key@(srvHost_, _) isOwn = do
ts <- liftIO getCurrentTime
atomically (getSessVar pushWorkerSeq key pushWorkers ts) >>= \case
Left v -> do
q <- liftIO $ newTBQueueIO pushQSize
-- only a weak reference to the forked worker thread is kept in the worker record
tId <- mkWeakThreadId =<< forkIO (runPushWorker s srvHost_ isOwn q)
-- NOTE(review): if this thread fails or is interrupted before the next line, the session variable
-- appears to stay empty in the map and later callers would block in readTMVar - confirm whether cleanup is needed.
atomically $ putTMVar (sessionVar v) PushWorker {workerQ = q, workerThreadId = tId}
pure q
Right v -> workerQ <$> atomically (readTMVar $ sessionVar v)

runPushWorker :: NtfPushServer -> Maybe T.Text -> OwnServer -> TBQueue (NtfTknRec, PushNotification) -> M ()
runPushWorker s srvHost_ isOwn q = forever $ do
(tkn@NtfTknRec {ntfTknId, token = t@(DeviceToken pp _), tknStatus}, ntf) <- atomically (readTBQueue q)
liftIO $ logDebug $ "sending push notification to " <> T.pack (show pp)
st <- asks store
case ntf of
Expand All @@ -644,7 +666,7 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
void $ liftIO $ setTknStatusConfirmed st tkn
incNtfStatT t ntfVrfDelivered
Left _ -> incNtfStatT t ntfVrfFailed
PNCheckMessages -> do
PNCheckMessages ->
liftIO (deliverNotification st pp tkn ntf) >>= \case
Right _ -> do
void $ liftIO $ updateTokenCronSentAt st ntfTknId . systemSeconds =<< getSystemTime
Expand All @@ -656,35 +678,36 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
liftIO (deliverNotification st pp tkn ntf) >>= \case
Left _ -> do
incNtfStatT t ntfFailed
liftIO $ mapM_ (`incServerStat` ntfFailedOwn stats) srvHost_
when isOwn $ liftIO $ mapM_ (`incServerStat` ntfFailedOwn stats) srvHost_
Right () -> do
incNtfStatT t ntfDelivered
liftIO $ mapM_ (`incServerStat` ntfDeliveredOwn stats) srvHost_

when isOwn $ liftIO $ mapM_ (`incServerStat` ntfDeliveredOwn stats) srvHost_
where
checkActiveTkn :: NtfTknStatus -> M () -> M ()
checkActiveTkn status action
| status == NTActive = action
| otherwise = liftIO $ logError "bad notification token status"
deliverNotification :: NtfPostgresStore -> PushProvider -> NtfTknRec -> PushNotification -> IO (Either PushProviderError ())
deliverNotification st pp tkn@NtfTknRec {ntfTknId} ntf = do
deliver <- getPushClient s pp
runExceptT (deliver tkn ntf) >>= \case
deliverNotification st pp tkn@NtfTknRec {ntfTknId} ntf' = do
(deliver, clientVar) <- getPushClient s pp
runExceptT (deliver tkn ntf') >>= \case
Right _ -> pure $ Right ()
Left e -> case e of
PPConnection _ -> retryDeliver
PPRetryLater -> retryDeliver
PPConnection ce -> retryDeliver clientVar $ "connection " <> tshow ce
PPRetryLater r -> retryDeliver clientVar r
PPCryptoError _ -> err e
PPResponseError {} -> err e
PPTokenInvalid r -> do
void $ updateTknStatus st tkn $ NTInvalid $ Just r
err e
PPPermanentError -> err e
where
retryDeliver :: IO (Either PushProviderError ())
retryDeliver = do
deliver <- newPushClient s pp
runExceptT (deliver tkn ntf) >>= \case
retryDeliver :: PushClientVar -> Text -> IO (Either PushProviderError ())
retryDeliver oldVar reason = do
logWarn $ "retrying push (" <> tshow pp <> ", " <> tshow ntfTknId <> "): " <> reason
atomically $ removeSessVar oldVar pp (pushClients s)
(deliver, _) <- getPushClient s pp
runExceptT (deliver tkn ntf') >>= \case
Right _ -> pure $ Right ()
Left e -> case e of
PPTokenInvalid r -> do
Expand All @@ -693,15 +716,26 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
_ -> err e
err e = logError ("Push provider error (" <> tshow pp <> ", " <> tshow ntfTknId <> "): " <> tshow e) $> Left e

-- | Total number of notifications currently queued across all push workers.
-- The map is read once; each worker is then inspected in its own transactions,
-- so the result is a best-effort snapshot rather than an atomic total.
-- Workers whose session variable is not filled yet contribute nothing.
pushWorkersQLength :: TMap (Maybe T.Text, PushProvider) PushWorkerVar -> IO Natural
pushWorkersQLength workers =
  readTVarIO workers >>= fmap sum . mapM workerQLength
  where
    workerQLength v =
      atomically (tryReadTMVar $ sessionVar v)
        >>= maybe (pure 0) (\PushWorker {workerQ} -> atomically $ lengthTBQueue workerQ)

periodicNtfsThread :: NtfPushServer -> M ()
periodicNtfsThread NtfPushServer {pushQ} = do
periodicNtfsThread s = do
st <- asks store
ntfsInterval <- asks $ periodicNtfsInterval . config
let interval = 1000000 * ntfsInterval
UnliftIO unlift <- askUnliftIO
liftIO $ forever $ do
threadDelay interval
now <- systemSeconds <$> getSystemTime
cnt <- withPeriodicNtfTokens st now $ \tkn -> atomically $ writeTBQueue pushQ (Nothing, tkn, PNCheckMessages)
cnt <- withPeriodicNtfTokens st now $ \tkn -> unlift $ pushNotification s Nothing False tkn PNCheckMessages
logNote $ "Scheduled periodic notifications: " <> tshow cnt

runNtfClientTransport :: Transport c => THandleNTF c 'TServer -> M ()
Expand Down Expand Up @@ -791,7 +825,7 @@ verifyNtfTransmission st thAuth (tAuth, authorized, (corrId, entId, cmd)) = case
e -> VRFailed e

client :: NtfServerClient -> NtfSubscriber -> NtfPushServer -> M ()
client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServer {pushQ} =
client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} ps =
forever $
atomically (readTBQueue rcvQ)
>>= mapM processCommand
Expand All @@ -808,7 +842,7 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ
ts <- liftIO $ getSystemDate
let tkn = mkNtfTknRec tknId newTkn srvDhPrivKey dhSecret regCode ts
withNtfStore (`addNtfToken` tkn) $ \_ -> do
atomically $ writeTBQueue pushQ (Nothing, tkn, PNVerification regCode)
pushNotification ps Nothing False tkn $ PNVerification regCode
incNtfStatT token ntfVrfQueued
incNtfStatT token tknCreated
pure $ NRTknId tknId srvDhPubKey
Expand All @@ -824,7 +858,7 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ
| otherwise -> withNtfStore (\st -> updateTknStatus st tkn NTRegistered) $ \_ -> sendVerification
where
sendVerification = do
atomically $ writeTBQueue pushQ (Nothing, tkn, PNVerification tknRegCode)
pushNotification ps Nothing False tkn $ PNVerification tknRegCode
incNtfStatT token ntfVrfQueued
pure $ NRTknId ntfTknId $ C.publicKey tknDhPrivKey
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
Expand All @@ -842,7 +876,7 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ
regCode <- getRegCode
let tkn' = tkn {token = token', tknStatus = NTRegistered, tknRegCode = regCode}
withNtfStore (`replaceNtfToken` tkn') $ \_ -> do
atomically $ writeTBQueue pushQ (Nothing, tkn', PNVerification regCode)
pushNotification ps Nothing False tkn' $ PNVerification regCode
incNtfStatT token ntfVrfQueued
incNtfStatT token tknReplaced
pure NROk
Expand Down
72 changes: 56 additions & 16 deletions src/Simplex/Messaging/Notifications/Server/Env.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,27 @@ module Simplex.Messaging.Notifications.Server.Env
SMPSubscriberVar,
SMPSubscriber (..),
NtfPushServer (..),
PushClientVar,
PushWorker (..),
PushWorkerVar,
NtfRequest (..),
NtfServerClient (..),
defaultInactiveClientExpiration,
newNtfServerEnv,
newNtfSubscriber,
newNtfPushServer,
newPushClient,
getPushClient,
newNtfServerClient,
) where

import Control.Concurrent (ThreadId)
import qualified Control.Exception as E
import Control.Logger.Simple
import Control.Monad
import Crypto.Random
import Data.Functor (($>))
import Data.Int (Int64)
import Simplex.Messaging.Agent.RetryInterval
import Data.List.NonEmpty (NonEmpty)
import qualified Data.Text as T
import Data.Time.Clock (getCurrentTime)
Expand Down Expand Up @@ -58,6 +63,7 @@ import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ASrvTransport, SMPServiceRole (..), ServiceCredentials (..), THandleParams, TransportPeer (..))
import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials, TransportServerConfig, loadFingerprint, loadServerCredential)
import Simplex.Messaging.Util (tshow)
import System.Exit (exitFailure)
import System.Mem.Weak (Weak)
import UnliftIO.STM
Expand Down Expand Up @@ -163,28 +169,62 @@ data SMPSubscriber = SMPSubscriber
}

data NtfPushServer = NtfPushServer
{ pushQ :: TBQueue (Maybe T.Text, NtfTknRec, PushNotification), -- Maybe Text is a hostname of "own" server
pushClients :: TMap PushProvider PushProviderClient,
{ pushWorkers :: TMap (Maybe T.Text, PushProvider) PushWorkerVar,
pushWorkerSeq :: TVar Int,
pushQSize :: Natural,
pushClients :: TMap PushProvider PushClientVar,
pushClientSeq :: TVar Int,
apnsConfig :: APNSPushClientConfig
}

-- | A push delivery worker: the queue it reads notifications from
-- and a weak reference to the thread processing that queue.
data PushWorker = PushWorker
{ workerQ :: TBQueue (NtfTknRec, PushNotification), -- token record and notification pairs queued for this worker
workerThreadId :: Weak ThreadId -- weak reference to the worker thread
}

type PushWorkerVar = SessionVar PushWorker

-- The Either communicates client-creation failure from the winner to the waiters.
type PushClientVar = SessionVar (Either E.SomeException PushProviderClient)

newNtfPushServer :: Natural -> APNSPushClientConfig -> IO NtfPushServer
newNtfPushServer qSize apnsConfig = do
pushQ <- newTBQueueIO qSize
newNtfPushServer pushQSize apnsConfig = do
pushWorkers <- TM.emptyIO
pushWorkerSeq <- newTVarIO 0
pushClients <- TM.emptyIO
pure NtfPushServer {pushQ, pushClients, apnsConfig}

newPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient
newPushClient NtfPushServer {apnsConfig, pushClients} pp = do
c <- case apnsProviderHost pp of
pushClientSeq <- newTVarIO 0
pure NtfPushServer {pushWorkers, pushWorkerSeq, pushQSize, pushClients, pushClientSeq, apnsConfig}

-- | Single-flight access to the per-provider push client with bounded retry.
--
-- The caller that obtains a new session variable creates the client ('newPushClient');
-- callers that find an existing variable wait for it to be filled ('waitForPushClient').
-- Synchronous failures are logged and retried via 'withRetryIntervalCount' while the
-- attempt counter it passes is below 2; after that the exception is rethrown to the caller.
--
-- Asynchronous exceptions are never retried: 'E.try' at 'E.SomeException' also catches them,
-- and looping after e.g. 'killThread' would keep a cancelled thread running.
--
-- The returned PushClientVar is the handle retryDeliver passes to removeSessVar to evict
-- this specific instance before re-fetching.
getPushClient :: NtfPushServer -> PushProvider -> IO (PushProviderClient, PushClientVar)
getPushClient s@NtfPushServer {apnsConfig = APNSPushClientConfig {reconnectInterval}} pp =
  withRetryIntervalCount reconnectInterval $ \n _delay loop -> do
    ts <- getCurrentTime
    E.try (atomically (getSessVar (pushClientSeq s) pp (pushClients s) ts) >>= either (newPushClient s pp) waitForPushClient) >>= \case
      Right result -> pure result
      Left e
        -- cancellation must propagate, it is not a client-creation failure
        | isAsyncException e -> E.throwIO e
        | n < 2 -> do
            logError $ "getPushClient error (" <> tshow pp <> "): " <> tshow e
            loop
        | otherwise -> E.throwIO e
  where
    -- True when the exception belongs to the asynchronous exception hierarchy.
    isAsyncException :: E.SomeException -> Bool
    isAsyncException e = case E.fromException e of
      Just (E.SomeAsyncException _) -> True
      Nothing -> False

newPushClient :: NtfPushServer -> PushProvider -> PushClientVar -> IO (PushProviderClient, PushClientVar)
newPushClient NtfPushServer {pushClients, apnsConfig} pp v = do
r <- E.try $ case apnsProviderHost pp of
Nothing -> pure $ \_ _ -> pure ()
Just host -> apnsPushProviderClient <$> createAPNSPushClient host apnsConfig
atomically $ TM.insert pp c pushClients
pure c

getPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient
getPushClient s@NtfPushServer {pushClients} pp =
TM.lookupIO pp pushClients >>= maybe (newPushClient s pp) pure
atomically $ do
putTMVar (sessionVar v) r
case r of
Left _ -> removeSessVar v pp pushClients
Right _ -> pure ()
either E.throwIO (\c -> pure (c, v)) r

-- | Wait until the session variable is filled by the caller creating the client.
-- Rethrows the stored exception if creation failed; otherwise returns the client
-- together with the variable it was read from.
waitForPushClient :: PushClientVar -> IO (PushProviderClient, PushClientVar)
waitForPushClient v = do
  created <- atomically . readTMVar $ sessionVar v
  case created of
    Left e -> E.throwIO e
    Right c -> pure (c, v)

data NtfRequest
= NtfReqNew CorrId ANewNtfEntity
Expand Down
Loading
Loading