diff --git a/Microsoft.DurableTask.sln b/Microsoft.DurableTask.sln index 0b8ef935..051dd2c1 100644 --- a/Microsoft.DurableTask.sln +++ b/Microsoft.DurableTask.sln @@ -1,4 +1,4 @@ - + Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio Version 17 VisualStudioVersion = 17.3.32901.215 @@ -115,6 +115,24 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NamespaceGenerationSample", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ReplaySafeLoggerFactorySample", "samples\ReplaySafeLoggerFactorySample\ReplaySafeLoggerFactorySample.csproj", "{8E7BECBC-7226-4778-B8F2-8EBDFF0D3BA4}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Extensions", "Extensions", "{21303FBF-2A2B-17C2-D2DF-3E924022E940}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AzureManagedServerless", "src\Extensions\AzureManagedServerless\AzureManagedServerless.csproj", "{C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "AzureManaged", "AzureManaged", "{D4587EC0-1B16-8420-7502-A967139249D4}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "AzureManaged", "AzureManaged", "{53193780-CD18-2643-6953-C26F59EAEDF5}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Extensions", "Extensions", "{00205C88-F000-28F2-A910-C6FA00E065EE}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AzureManagedServerless.Tests", "test\Extensions\AzureManagedServerless.Tests\AzureManagedServerless.Tests.csproj", "{4D50F5B2-4782-486F-A9AA-073D798CC60D}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "serverless", "serverless", "{5BD6F026-413E-9AC5-D159-8E8D9F26EF1B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "main-app", "samples\serverless\main-app\main-app.csproj", "{4535F88F-EA1C-4C6F-84D5-93535EE1568C}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "remote-worker", "samples\serverless\remote-worker\remote-worker.csproj", "{562E5DB9-761B-4DE9-98CB-C364F6DE558E}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -701,7 +719,54 @@ Global {8E7BECBC-7226-4778-B8F2-8EBDFF0D3BA4}.Release|x64.Build.0 = Release|Any CPU {8E7BECBC-7226-4778-B8F2-8EBDFF0D3BA4}.Release|x86.ActiveCfg = Release|Any CPU {8E7BECBC-7226-4778-B8F2-8EBDFF0D3BA4}.Release|x86.Build.0 = Release|Any CPU - + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Debug|x64.ActiveCfg = Debug|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Debug|x64.Build.0 = Debug|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Debug|x86.ActiveCfg = Debug|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Debug|x86.Build.0 = Debug|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Release|Any CPU.Build.0 = Release|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Release|x64.ActiveCfg = Release|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Release|x64.Build.0 = Release|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Release|x86.ActiveCfg = Release|Any CPU + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2}.Release|x86.Build.0 = Release|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Debug|x64.ActiveCfg = Debug|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Debug|x64.Build.0 = Debug|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Debug|x86.ActiveCfg = Debug|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Debug|x86.Build.0 = Debug|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Release|Any CPU.Build.0 = Release|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Release|x64.ActiveCfg = Release|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Release|x64.Build.0 = Release|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Release|x86.ActiveCfg = Release|Any CPU + {4D50F5B2-4782-486F-A9AA-073D798CC60D}.Release|x86.Build.0 = Release|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Debug|x64.ActiveCfg = Debug|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Debug|x64.Build.0 = Debug|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Debug|x86.ActiveCfg = Debug|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Debug|x86.Build.0 = Debug|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Release|Any CPU.Build.0 = Release|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Release|x64.ActiveCfg = Release|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Release|x64.Build.0 = Release|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Release|x86.ActiveCfg = Release|Any CPU + {4535F88F-EA1C-4C6F-84D5-93535EE1568C}.Release|x86.Build.0 = Release|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Debug|x64.ActiveCfg = Debug|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Debug|x64.Build.0 = Debug|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Debug|x86.ActiveCfg = Debug|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Debug|x86.Build.0 = Debug|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Release|Any CPU.Build.0 = Release|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Release|x64.ActiveCfg = Release|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Release|x64.Build.0 = Release|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Release|x86.ActiveCfg = Release|Any CPU + {562E5DB9-761B-4DE9-98CB-C364F6DE558E}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -759,7 +824,15 @@ Global {4A7305AE-AAAE-43AE-AAB2-DA58DACC6FA8} = {EFF7632B-821E-4CFC-B4A0-ED4B24296B17} {5A69FD28-D814-490E-A76B-B0A5F88C25B2} = {EFF7632B-821E-4CFC-B4A0-ED4B24296B17} {8E7BECBC-7226-4778-B8F2-8EBDFF0D3BA4} = {EFF7632B-821E-4CFC-B4A0-ED4B24296B17} - + {21303FBF-2A2B-17C2-D2DF-3E924022E940} = {8AFC9781-F6F1-4696-BB4A-9ED7CA9D612B} + {C6DC28DC-95CE-42DA-B02C-FFB2BA1CB1A2} = {21303FBF-2A2B-17C2-D2DF-3E924022E940} + {D4587EC0-1B16-8420-7502-A967139249D4} = {1C217BB2-CE16-41CC-9D47-0FC0DB60BDB3} + {53193780-CD18-2643-6953-C26F59EAEDF5} = {5B448FF6-EC42-491D-A22E-1DC8B618E6D5} + {00205C88-F000-28F2-A910-C6FA00E065EE} = {E5637F81-2FB9-4CD7-900D-455363B142A7} + {4D50F5B2-4782-486F-A9AA-073D798CC60D} = {00205C88-F000-28F2-A910-C6FA00E065EE} + {5BD6F026-413E-9AC5-D159-8E8D9F26EF1B} = {EFF7632B-821E-4CFC-B4A0-ED4B24296B17} + {4535F88F-EA1C-4C6F-84D5-93535EE1568C} = {5BD6F026-413E-9AC5-D159-8E8D9F26EF1B} + {562E5DB9-761B-4DE9-98CB-C364F6DE558E} = {5BD6F026-413E-9AC5-D159-8E8D9F26EF1B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {AB41CB55-35EA-4986-A522-387AB3402E71} diff --git a/README.md b/README.md index 7226f201..2094e890 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,8 @@ The Durable Task Scheduler for Azure Functions is a managed backend that is curr This SDK can also be used with the Durable Task Scheduler directly, without any Durable Functions dependency. To get started, sign up for the [Durable Task Scheduler private preview](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) and follow the instructions to create a new Durable Task Scheduler instance. Once granted access to the private preview GitHub repository, you can find samples and documentation for getting started [here](https://github.com/Azure/Azure-Functions-Durable-Task-Scheduler-Private-Preview/tree/main/samples/portable-sdk/dotnet/AspNetWebApp#readme). +The [serverless activities sample](samples/serverless/README.md) shows how to declare selected activities for DTS-managed serverless execution and build the remote worker container image separately from the declarer app. + ## Obtaining the Protobuf definitions This project utilizes protobuf definitions from [durabletask-protobuf](https://github.com/microsoft/durabletask-protobuf), which are copied (vendored) into this repository under the `src/Grpc` directory. See the corresponding [README.md](./src/Grpc/README.md) for more information about how to update the protobuf definitions. diff --git a/samples/serverless/README.md b/samples/serverless/README.md new file mode 100644 index 00000000..d6ce29b1 --- /dev/null +++ b/samples/serverless/README.md @@ -0,0 +1,57 @@ +# Serverless Activities Sample + +This sample shows how to run selected Durable Task activities in DTS-managed serverless sandboxes. + +The sample is intentionally split into two projects: + +| Path | Purpose | +| --- | --- | +| `main-app/` | Runs locally or in a normal app host. It declares the serverless activity and starts one hello orchestration. | +| `remote-worker/` | Builds the container image that DTS starts inside a serverless sandbox. It contains the remote hello activity. | + +## Build + +```powershell +dotnet build .\samples\serverless\main-app\main-app.csproj +dotnet build .\samples\serverless\remote-worker\remote-worker.csproj +``` + +## Build the remote worker image + +Run from the repository root: + +```powershell +$image = ".azurecr.io/dts-serverless-sample:" +docker build -f .\samples\serverless\remote-worker\Containerfile -t $image . +docker push $image +``` + +## Run a hello orchestration + +The main app uses `DefaultAzureCredential`; sign in with Azure CLI or configure another supported Azure identity before running it. + +```powershell +$env:DTS_ENDPOINT = "https://" +$env:DTS_TASK_HUB = "" +$env:DTS_SERVERLESS_ACTIVITY_IMAGE = ".azurecr.io/dts-serverless-sample:" +$env:DTS_SERVERLESS_CPU = "1000m" +$env:DTS_SERVERLESS_MEMORY = "2048Mi" +$env:DTS_SERVERLESS_MAX_ACTIVITIES = "1" +$env:DTS_SAMPLE_HELLO_INPUT = "serverless-sample" + +dotnet run --project .\samples\serverless\main-app\main-app.csproj +``` + +Expected output includes the serverless activity result: + +```text +Runtime status: Completed +Output: "hello from pid=: serverless-sample" +``` + +Use the Durable Task Scheduler dashboard's Serverless Activities preview tab to inspect serverless activity runtimes and stream runtime logs. + +The remote worker image does not need customer-provided DTS runtime settings. +DTS injects the scheduler endpoint, task hub, worker profile, capacity, substrate, +and sandbox identifier when it starts the sandbox. The worker reports the +activities registered in the image when it connects. diff --git a/samples/serverless/main-app/Activities.cs b/samples/serverless/main-app/Activities.cs new file mode 100644 index 00000000..73c5caee --- /dev/null +++ b/samples/serverless/main-app/Activities.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Samples.Serverless.MainApp; + +internal static class ServerlessTaskNames +{ + public const string RemoteHello = "RemoteHello"; + public const string HelloOrchestrator = nameof(HelloOrchestrator); +} diff --git a/samples/serverless/main-app/Orchestrators.cs b/samples/serverless/main-app/Orchestrators.cs new file mode 100644 index 00000000..b5fedd88 --- /dev/null +++ b/samples/serverless/main-app/Orchestrators.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DurableTask; + +namespace Microsoft.DurableTask.Samples.Serverless.MainApp; + +[DurableTask(nameof(HelloOrchestrator))] +internal sealed class HelloOrchestrator : TaskOrchestrator +{ + public override async Task RunAsync(TaskOrchestrationContext context, string input) + { + string remoteResult = await context.CallActivityAsync(ServerlessTaskNames.RemoteHello, input); + return remoteResult; + } +} diff --git a/samples/serverless/main-app/Program.cs b/samples/serverless/main-app/Program.cs new file mode 100644 index 00000000..b589e6cb --- /dev/null +++ b/samples/serverless/main-app/Program.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Core; +using Azure.Identity; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.AzureManaged; +using Microsoft.DurableTask.Samples.Serverless.MainApp; +using Microsoft.DurableTask.Worker; +using Microsoft.DurableTask.Worker.AzureManaged; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +string endpoint = GetRequiredEnvironmentVariable("DTS_ENDPOINT"); +string taskHub = Environment.GetEnvironmentVariable("DTS_TASK_HUB") ?? "ServerlessPocHub"; +string workerProfileId = Environment.GetEnvironmentVariable("DTS_WORKER_PROFILE_ID") ?? "default"; +string serverlessActivityImage = Environment.GetEnvironmentVariable("DTS_SERVERLESS_ACTIVITY_IMAGE") + ?? "serverless-remote-worker:local"; +string helloInput = ParseHelloInput( + args, + Environment.GetEnvironmentVariable("DTS_SAMPLE_HELLO_INPUT") ?? "serverless-sample"); +TokenCredential credential = new DefaultAzureCredential(); + +HostApplicationBuilder builder = Host.CreateApplicationBuilder(args); +builder.Logging.AddSimpleConsole(options => +{ + options.SingleLine = true; + options.UseUtcTimestamp = true; + options.TimestampFormat = "yyyy-MM-ddTHH:mm:ss.fffZ "; +}); + +builder.Services.AddDurableTaskWorker(workerBuilder => +{ + workerBuilder.AddTasks(tasks => tasks.AddAllGeneratedTasks()); + workerBuilder.UseDurableTaskScheduler(options => + { + options.EndpointAddress = endpoint; + options.TaskHubName = taskHub; + options.Credential = credential; + }); + + workerBuilder.DeclareServerlessActivities(options => + { + options.TaskHub = taskHub; + options.WorkerProfileId = workerProfileId; + options.ContainerImage = serverlessActivityImage; + options.Cpu = Environment.GetEnvironmentVariable("DTS_SERVERLESS_CPU") ?? "1000m"; + options.Memory = Environment.GetEnvironmentVariable("DTS_SERVERLESS_MEMORY") ?? "2048Mi"; + options.MaxConcurrentActivities = GetIntEnv("DTS_SERVERLESS_MAX_ACTIVITIES", 1); + options.AddActivity(ServerlessTaskNames.RemoteHello); + }); +}); + +builder.Services.AddDurableTaskClient(clientBuilder => +{ + clientBuilder.UseDurableTaskScheduler(options => + { + options.EndpointAddress = endpoint; + options.TaskHubName = taskHub; + options.Credential = credential; + }); +}); + +using IHost host = builder.Build(); + +await host.StartAsync(); + +DurableTaskClient client = host.Services.GetRequiredService(); +string instanceId = await client.ScheduleNewOrchestrationInstanceAsync( + ServerlessTaskNames.HelloOrchestrator, + input: helloInput); +OrchestrationMetadata? result = await client.WaitForInstanceCompletionAsync( + instanceId, + getInputsAndOutputs: true); + +Console.WriteLine($"Started orchestration: {instanceId}"); +Console.WriteLine($"Runtime status: {result?.RuntimeStatus}"); +Console.WriteLine($"Output: {result?.SerializedOutput ?? ""}"); + +await host.StopAsync(); + +static string GetRequiredEnvironmentVariable(string name) + => Environment.GetEnvironmentVariable(name) + ?? throw new InvalidOperationException($"An environment variable named '{name}' is required."); + +static int GetIntEnv(string name, int defaultValue) +{ + string? value = Environment.GetEnvironmentVariable(name); + if (string.IsNullOrWhiteSpace(value)) + { + return defaultValue; + } + + return int.TryParse(value, out int parsed) && parsed > 0 + ? parsed + : throw new InvalidOperationException($"Environment variable '{name}' must be a positive integer."); +} + +static string ParseHelloInput(string[] args, string defaultHelloInput) +{ + if (args.Length == 0) + { + return defaultHelloInput; + } + + string verb = args[0].ToLowerInvariant(); + return verb switch + { + "hello" => args.Length > 1 ? args[1] : defaultHelloInput, + _ => throw new InvalidOperationException("Supported commands: hello [name]."), + }; +} diff --git a/samples/serverless/main-app/main-app.csproj b/samples/serverless/main-app/main-app.csproj new file mode 100644 index 00000000..f3987d2a --- /dev/null +++ b/samples/serverless/main-app/main-app.csproj @@ -0,0 +1,24 @@ + + + + Exe + net10.0 + enable + ServerlessMainApp + Microsoft.DurableTask.Samples.Serverless.MainApp + + + + + + + + + + + + + + + + diff --git a/samples/serverless/remote-worker/Activities.cs b/samples/serverless/remote-worker/Activities.cs new file mode 100644 index 00000000..51a576c1 --- /dev/null +++ b/samples/serverless/remote-worker/Activities.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DurableTask; + +namespace Microsoft.DurableTask.Samples.Serverless.RemoteWorker; + +[DurableTask("RemoteHello")] +internal sealed class RemoteHelloActivity : TaskActivity +{ + public override Task RunAsync(TaskActivityContext context, string input) + => Task.FromResult($"hello from {Environment.MachineName} pid={Environment.ProcessId}: {input}"); +} diff --git a/samples/serverless/remote-worker/Containerfile b/samples/serverless/remote-worker/Containerfile new file mode 100644 index 00000000..18f24b23 --- /dev/null +++ b/samples/serverless/remote-worker/Containerfile @@ -0,0 +1,34 @@ +# syntax=docker/dockerfile:1.7 + +FROM --platform=$TARGETPLATFORM mcr.microsoft.com/dotnet/sdk:10.0 AS build +WORKDIR /src +ARG TARGETARCH + +COPY . /src/durabletask-dotnet + +WORKDIR /src/durabletask-dotnet/samples/serverless/remote-worker +RUN case "$TARGETARCH" in \ + amd64) runtime_identifier=linux-x64 ;; \ + arm64) runtime_identifier=linux-arm64 ;; \ + *) echo "Unsupported target architecture: $TARGETARCH" >&2; exit 1 ;; \ + esac \ + && dotnet publish remote-worker.csproj \ + -c Release \ + -r "$runtime_identifier" \ + --self-contained false \ + -o /app/publish \ + --configfile /src/durabletask-dotnet/nuget.config \ + /p:DebugSymbols=false \ + /p:DebugType=None \ + && find /app/publish -type f \( -name '*.xml' -o -name '*.pdb' \) -delete + +FROM mcr.microsoft.com/dotnet/aspnet:10.0 AS runtime +WORKDIR /app + +ENV ASPNETCORE_URLS=http://+:8080 + +EXPOSE 8080 + +COPY --from=build /app/publish ./ + +ENTRYPOINT ["dotnet", "ServerlessRemoteWorker.dll"] diff --git a/samples/serverless/remote-worker/Containerfile.dockerignore b/samples/serverless/remote-worker/Containerfile.dockerignore new file mode 100644 index 00000000..9b8836cf --- /dev/null +++ b/samples/serverless/remote-worker/Containerfile.dockerignore @@ -0,0 +1,20 @@ +** +!Directory.Build.props +!Directory.Build.targets +!Directory.Packages.props +!global.json +!nuget.config +!stylecop.json +!eng/ +!eng/** +!src/ +!src/** +!samples/ +!samples/Directory.Build.props +!samples/Directory.Packages.props +!samples/serverless/ +!samples/serverless/** +**/bin/ +**/obj/ +**/.git/ +**/.tunnel-url \ No newline at end of file diff --git a/samples/serverless/remote-worker/Program.cs b/samples/serverless/remote-worker/Program.cs new file mode 100644 index 00000000..b705fba5 --- /dev/null +++ b/samples/serverless/remote-worker/Program.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DurableTask.Samples.Serverless.RemoteWorker; +using Microsoft.DurableTask.Worker; +using Microsoft.DurableTask.Worker.AzureManaged; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +HostApplicationBuilder builder = Host.CreateApplicationBuilder(args); +builder.Logging.AddSimpleConsole(options => +{ + options.SingleLine = true; + options.UseUtcTimestamp = true; + options.TimestampFormat = "yyyy-MM-ddTHH:mm:ss.fffZ "; +}); + +builder.Services.AddDurableTaskWorker(workerBuilder => +{ + workerBuilder.AddTasks(tasks => + { + tasks.AddActivity(); + }); + workerBuilder.UseServerlessWorker(); +}); + +await builder.Build().RunAsync(); diff --git a/samples/serverless/remote-worker/remote-worker.csproj b/samples/serverless/remote-worker/remote-worker.csproj new file mode 100644 index 00000000..c358c3cb --- /dev/null +++ b/samples/serverless/remote-worker/remote-worker.csproj @@ -0,0 +1,21 @@ + + + + Exe + net10.0 + enable + ServerlessRemoteWorker + Microsoft.DurableTask.Samples.Serverless.RemoteWorker + + + + + + + + + + + + + diff --git a/src/Extensions/AzureManagedServerless/AzureManagedServerless.csproj b/src/Extensions/AzureManagedServerless/AzureManagedServerless.csproj new file mode 100644 index 00000000..413b65ea --- /dev/null +++ b/src/Extensions/AzureManagedServerless/AzureManagedServerless.csproj @@ -0,0 +1,25 @@ + + + + net6.0;net8.0;net10.0 + Azure Managed serverless activities support for Durable Task. + Microsoft.DurableTask.AzureManaged.Serverless + true + + + + + + + + + + + + + + + + + + diff --git a/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClient.cs b/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClient.cs new file mode 100644 index 00000000..01f32ae7 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClient.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Client.AzureManaged; + +/// +/// Client for DTS serverless activity management operations. +/// +public sealed class ServerlessActivitiesClient +{ + readonly Proto.ServerlessActivities.ServerlessActivitiesClient client; + + /// + /// Initializes a new instance of the class. + /// + /// The generated gRPC client used to call DTS serverless management operations. + internal ServerlessActivitiesClient(Proto.ServerlessActivities.ServerlessActivitiesClient client) + { + this.client = client; + } + + /// + /// Removes a serverless activity declaration for a worker profile. + /// + /// The worker profile ID whose declaration should be removed. + /// The cancellation token used to cancel the request. + /// A task that completes when DTS removes the declaration. + public Task RemoveServerlessActivityDeclarationAsync( + string workerProfileId, + CancellationToken cancellation = default) + => this.client.RemoveServerlessActivityDeclarationAsync(workerProfileId, cancellation); +} diff --git a/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClientExtensions.cs b/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClientExtensions.cs new file mode 100644 index 00000000..05de3530 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClientExtensions.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Grpc.Core; +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Client.AzureManaged; + +/// +/// Extension methods for the generated serverless activities gRPC client. +/// +public static class ServerlessActivitiesClientExtensions +{ + /// + /// Removes a serverless activity declaration for a worker profile using task hub metadata already configured on the gRPC channel. + /// + /// The generated serverless activities gRPC client. + /// The worker profile ID whose declaration should be removed. + /// The cancellation token used to cancel the request. + /// A task that completes when DTS removes the declaration. + public static Task RemoveServerlessActivityDeclarationAsync( + this Proto.ServerlessActivities.ServerlessActivitiesClient client, + string workerProfileId, + CancellationToken cancellation = default) + { + return RemoveServerlessActivityDeclarationCoreAsync( + client, + workerProfileId, + cancellation); + } + + static async Task RemoveServerlessActivityDeclarationCoreAsync( + Proto.ServerlessActivities.ServerlessActivitiesClient client, + string workerProfileId, + CancellationToken cancellation) + { + ArgumentNullException.ThrowIfNull(client); + ValidateRequired(workerProfileId, nameof(workerProfileId), "Worker profile ID is required."); + + Proto.RemoveServerlessActivityDeclarationRequest request = new() + { + WorkerProfileId = workerProfileId, + }; + + using AsyncUnaryCall call = client.RemoveServerlessActivityDeclarationAsync( + request, + headers: null, + cancellationToken: cancellation); + await call.ResponseAsync.ConfigureAwait(false); + } + + static void ValidateRequired(string value, string parameterName, string message) + { + if (string.IsNullOrWhiteSpace(value)) + { + throw new ArgumentException(message, parameterName); + } + } +} diff --git a/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClientServiceCollectionExtensions.cs b/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClientServiceCollectionExtensions.cs new file mode 100644 index 00000000..ff0ed57a --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Client/ServerlessActivitiesClientServiceCollectionExtensions.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Grpc.Net.Client; +using Microsoft.DurableTask.Client.Grpc; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Client.AzureManaged; + +/// +/// Extension methods for registering DTS serverless activity management clients. +/// +public static class ServerlessActivitiesClientServiceCollectionExtensions +{ + /// + /// Adds a DTS serverless activity management client using the default Durable Task client configuration. + /// + /// The service collection to configure. + /// The original service collection, for call chaining. + public static IServiceCollection AddDurableTaskSchedulerServerlessActivitiesClient(this IServiceCollection services) + => AddDurableTaskSchedulerServerlessActivitiesClient(services, Options.DefaultName); + + /// + /// Adds a DTS serverless activity management client using a named Durable Task client configuration. + /// + /// The service collection to configure. + /// The Durable Task client name whose scheduler channel should be reused. + /// The original service collection, for call chaining. + public static IServiceCollection AddDurableTaskSchedulerServerlessActivitiesClient( + this IServiceCollection services, + string clientName) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(clientName); + + services.AddSingleton(provider => + { + GrpcDurableTaskClientOptions options = provider + .GetRequiredService>() + .Get(clientName); + + if (options.CallInvoker is { } callInvoker) + { + return new ServerlessActivitiesClient(new Proto.ServerlessActivities.ServerlessActivitiesClient(callInvoker)); + } + + if (options.Channel is GrpcChannel channel) + { + return new ServerlessActivitiesClient(new Proto.ServerlessActivities.ServerlessActivitiesClient(channel.CreateCallInvoker())); + } + + throw new InvalidOperationException("DTS serverless activity management requires a configured Durable Task Scheduler client."); + }); + return services; + } +} diff --git a/src/Extensions/AzureManagedServerless/DurableTaskSchedulerServerlessWorkerExtensions.cs b/src/Extensions/AzureManagedServerless/DurableTaskSchedulerServerlessWorkerExtensions.cs new file mode 100644 index 00000000..695408f1 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/DurableTaskSchedulerServerlessWorkerExtensions.cs @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Linq; +using Azure.Identity; +using Grpc.Net.Client; +using Microsoft.DurableTask.Protobuf.Serverless; +using Microsoft.DurableTask.Worker.AzureManaged.Serverless; +using Microsoft.DurableTask.Worker.Grpc; +using Microsoft.DurableTask.Worker.Grpc.Internal; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Microsoft.DurableTask.Worker.AzureManaged; + +/// +/// Extension methods for configuring Azure Managed Durable Task workers with serverless activity support. +/// +public static class DurableTaskSchedulerServerlessWorkerExtensions +{ + /// + /// Declares serverless activities with DTS and excludes them from local execution. + /// Call this on the local coordinator worker — not on the sandbox worker binary. + /// + /// The Durable Task worker builder to configure. + /// Callback to configure serverless activity behavior. + /// The original builder, for call chaining. + public static IDurableTaskWorkerBuilder DeclareServerlessActivities( + this IDurableTaskWorkerBuilder builder, + Action configure) + { + Check.NotNull(builder); + Check.NotNull(configure); + + builder.Services.AddOptions(builder.Name) + .Configure(configure) + .PostConfigure>((options, schedulerOptions) => + ApplyTaskHubDefault(options, schedulerOptions.Get(builder.Name).TaskHubName)); + + builder.Services.AddOptions(builder.Name) + .PostConfigure>( + (filters, serverlessOptions) => ExcludeServerlessActivitiesFromLocalExecution(filters, serverlessOptions.Get(builder.Name))); + + builder.Services.AddSingleton(sp => CreateServerlessActivityDeclarationHostedService(sp, builder.Name)); + return builder; + } + + /// + /// Configures this worker as a serverless activity worker that connects to DTS to receive and execute + /// serverless activities. Use this on a dedicated worker binary that runs inside serverless infrastructure. + /// Runtime configuration is read from environment variables injected by DTS. + /// + /// + /// + /// This method is for separate worker binaries only. The coordinator uses + /// to declare and provision the serverless activity configuration. + /// + /// + /// Required environment variables injected automatically by DTS: + /// + /// DTS_ENDPOINT — canonical scheduler endpoint + /// DTS_TASK_HUB — task hub name from the declaration + /// DTS_SUBSTRATE — identifies the sandbox substrate + /// + /// + /// + /// The Durable Task worker builder to configure. + /// The original builder, for call chaining. + public static IDurableTaskWorkerBuilder UseServerlessWorker(this IDurableTaskWorkerBuilder builder) + { + Check.NotNull(builder); + + ConfigureDurableTaskSchedulerFromEnvironment(builder); + builder.UseWorkItemFilters(); + + builder.Services.AddOptions(builder.Name) + .PostConfigure>((options, schedulerOptions) => + { + ApplyRuntimeTaskHubDefault(options, schedulerOptions.Get(builder.Name).TaskHubName); + ApplyWorkerEnvironmentOverrides(options); + }); + + builder.Services.AddOptions(builder.Name) + .PostConfigure(IncludeOnlyRegisteredActivities); + + builder.Services.AddSingleton(); + builder.Services.AddOptions(builder.Name) + .Configure((options, activityTracker) => + options.ConfigureActivityNotification(phase => + { + if (phase == ActivityNotificationPhase.Started) + { + activityTracker.NotifyActivityStarted(); + } + else if (phase == ActivityNotificationPhase.Completed) + { + activityTracker.NotifyActivityCompleted(); + } + })); + + builder.Services.AddSingleton(sp => CreateServerlessActivityWorkerRegistrationHostedService(sp, builder.Name)); + return builder; + } + + static void ExcludeServerlessActivitiesFromLocalExecution(DurableTaskWorkerWorkItemFilters filters, ServerlessOptions options) + { + string[] activityNames = ServerlessActivityConfiguration.ResolveActivityNames(options.ActivityNames); + if (activityNames.Length == 0) + { + return; + } + + filters.ExcludedActivities = MergeActivityFilters(filters.ExcludedActivities, activityNames); + } + + static void IncludeOnlyRegisteredActivities(DurableTaskWorkerWorkItemFilters filters) + { + filters.Orchestrations = []; + filters.ExcludedActivities = []; + filters.Entities = []; + } + + static ServerlessActivityDeclarationHostedService CreateServerlessActivityDeclarationHostedService( + IServiceProvider services, + string builderName) + { + ServerlessOptions options = services.GetRequiredService>().Get(builderName); + ILoggerFactory loggerFactory = services.GetRequiredService(); + ServerlessWorkerRuntimeOptions runtimeOptions = services.GetRequiredService>().Get(builderName); + + return new ServerlessActivityDeclarationHostedService( + CreateServerlessActivitiesClient(services, builderName), + options, + runtimeOptions, + loggerFactory.CreateLogger()); + } + + static ServerlessActivityWorkerRegistrationHostedService CreateServerlessActivityWorkerRegistrationHostedService( + IServiceProvider services, + string builderName) + { + ServerlessWorkerRuntimeOptions options = services.GetRequiredService>().Get(builderName); + ILoggerFactory loggerFactory = services.GetRequiredService(); + IHostApplicationLifetime? lifetime = services.GetService(); + ServerlessActivityTracker activityTracker = services.GetRequiredService(); + DurableTaskWorkerWorkItemFilters filters = services.GetRequiredService>().Get(builderName); + + return new ServerlessActivityWorkerRegistrationHostedService( + CreateServerlessActivitiesClient(services, builderName), + options, + ResolveActivityFilterNames(filters.Activities), + loggerFactory.CreateLogger(), + lifetime, + activityTracker); + } + + static ServerlessActivitiesClientAdapter CreateServerlessActivitiesClient(IServiceProvider services, string builderName) + { + GrpcDurableTaskWorkerOptions options = services.GetRequiredService>().Get(builderName); + if (options.CallInvoker is { } callInvoker) + { + return new ServerlessActivitiesClientAdapter(new ServerlessActivities.ServerlessActivitiesClient(callInvoker)); + } + + if (options.Channel is { } channel) + { + return new ServerlessActivitiesClientAdapter( + new ServerlessActivities.ServerlessActivitiesClient(channel.CreateCallInvoker()), + attachTaskHubMetadata: false); + } + + throw new InvalidOperationException("Azure Managed serverless activities require a configured gRPC channel or call invoker."); + } + + static void ApplyTaskHubDefault(ServerlessOptions options, string taskHubName) + { + if (string.IsNullOrWhiteSpace(options.TaskHub) && !string.IsNullOrWhiteSpace(taskHubName)) + { + options.TaskHub = taskHubName; + } + } + + static void ApplyRuntimeTaskHubDefault(ServerlessWorkerRuntimeOptions options, string taskHubName) + { + if (string.IsNullOrWhiteSpace(options.TaskHub) && !string.IsNullOrWhiteSpace(taskHubName)) + { + options.TaskHub = taskHubName; + } + } + + static void ConfigureDurableTaskSchedulerFromEnvironment(IDurableTaskWorkerBuilder builder) + { + string endpoint = GetRequiredEnvironmentVariable("DTS_ENDPOINT"); + string taskHub = GetRequiredEnvironmentVariable("DTS_TASK_HUB"); + + // Private preview: DTS-owned sandbox workers authenticate with the injected + // managed identity via DefaultAzureCredential. Revisit this if customer-owned + // worker identities or non-default auth modes are introduced. + builder.UseDurableTaskScheduler(endpoint, taskHub, new DefaultAzureCredential()); + } + + static string GetRequiredEnvironmentVariable(string name) + { + string? value = Environment.GetEnvironmentVariable(name); + return string.IsNullOrWhiteSpace(value) + ? throw new InvalidOperationException($"{name} must be injected by DTS for serverless workers.") + : value.Trim(); + } + + static void ApplyWorkerEnvironmentOverrides(ServerlessWorkerRuntimeOptions options) + { + // Auto-detect worker mode from DTS_SUBSTRATE, which the backend injects when + // launching a sandbox. This is the authoritative signal that this process is a sandbox worker. + string? substrate = Environment.GetEnvironmentVariable("DTS_SUBSTRATE"); + if (string.Equals(substrate, "Sandbox", StringComparison.OrdinalIgnoreCase) + || string.Equals(substrate, "AcaSessionPool", StringComparison.OrdinalIgnoreCase)) + { + options.Mode = ServerlessMode.ServerlessInclude; + } + + ApplyWorkerProfileEnvironmentOverride(profile => options.WorkerProfileId = profile); + + if (int.TryParse(Environment.GetEnvironmentVariable("DTS_SERVERLESS_MAX_ACTIVITIES"), out int maxActivities) && maxActivities > 0) + { + options.MaxConcurrentActivities = maxActivities; + } + } + + static void ApplyWorkerProfileEnvironmentOverride(Action setWorkerProfileId) + { + string? workerProfileId = Environment.GetEnvironmentVariable("DTS_WORKER_PROFILE_ID"); + if (!string.IsNullOrWhiteSpace(workerProfileId)) + { + setWorkerProfileId(workerProfileId.Trim()); + } + } + + static DurableTaskWorkerWorkItemFilters.ActivityFilter[] MergeActivityFilters( + IReadOnlyList existingFilters, + IEnumerable activityNames) + { + Dictionary merged = new(StringComparer.OrdinalIgnoreCase); + foreach (DurableTaskWorkerWorkItemFilters.ActivityFilter filter in existingFilters) + { + if (!string.IsNullOrWhiteSpace(filter.Name)) + { + merged[filter.Name] = filter; + } + } + + foreach (string activityName in activityNames) + { + merged[activityName] = new DurableTaskWorkerWorkItemFilters.ActivityFilter { Name = activityName }; + } + + return merged.Values.ToArray(); + } + + static string[] ResolveActivityFilterNames(IReadOnlyList activityFilters) + { + return activityFilters + .Select(static filter => filter.Name) + .Where(static name => !string.IsNullOrWhiteSpace(name)) + .Select(static name => name.Trim()) + .Distinct(StringComparer.Ordinal) + .ToArray(); + } +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/Logs.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/Logs.cs new file mode 100644 index 00000000..3f1bbfa0 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/Logs.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Logging; +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Log messages for serverless activity services. +/// +static partial class Logs +{ + [LoggerMessage( + EventId = 1, + Level = LogLevel.Information, + Message = "No serverless activities discovered for hub={Hub}; skipping declaration")] + public static partial void NoServerlessActivitiesForDeclaration(ILogger logger, string hub); + + [LoggerMessage( + EventId = 2, + Level = LogLevel.Information, + Message = "Serverless activities declared hub={Hub} workerProfile={WorkerProfile} count={Count} image={Image}")] + public static partial void ServerlessActivitiesDeclared(ILogger logger, string hub, string workerProfile, int count, string image); + + [LoggerMessage( + EventId = 4, + Level = LogLevel.Error, + Message = "Serverless activity declaration failed hub={Hub}")] + public static partial void ServerlessActivityDeclarationFailed(ILogger logger, Exception exception, string hub); + + [LoggerMessage( + EventId = 5, + Level = LogLevel.Information, + Message = "No serverless activities discovered for worker hub={Hub}; skipping live registration")] + public static partial void NoServerlessActivitiesForWorkerRegistration(ILogger logger, string hub); + + [LoggerMessage( + EventId = 6, + Level = LogLevel.Information, + Message = "Serverless activity worker registered hub={Hub} count={Count} substrate={Substrate} sandboxId={SandboxId}")] + public static partial void ServerlessActivityWorkerRegistered( + ILogger logger, string hub, int count, Proto.SubstrateKind substrate, string sandboxId); + + [LoggerMessage( + EventId = 7, + Level = LogLevel.Error, + Message = "Serverless activity worker registration stream failed hub={Hub}")] + public static partial void ServerlessActivityWorkerRegistrationFailed(ILogger logger, Exception exception, string hub); +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivitiesClientAdapter.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivitiesClientAdapter.cs new file mode 100644 index 00000000..2ad14739 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivitiesClientAdapter.cs @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Grpc.Core; +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Client abstraction for the serverless activities gRPC service. +/// +interface IServerlessActivitiesClient +{ + /// + /// Declares serverless activities to DTS. + /// + /// The declaration message. + /// The task hub that owns the declaration. + /// The cancellation token. + /// The declaration result. + Task DeclareServerlessActivitiesAsync( + Proto.ServerlessActivityDeclaration declaration, + string taskHub, + CancellationToken cancellationToken); + + /// + /// Opens a serverless activity worker registration session. + /// + /// The task hub that owns the worker session. + /// The cancellation token. + /// The worker registration session. + IServerlessActivityWorkerSession OpenServerlessActivityWorkerSession(string taskHub, CancellationToken cancellationToken); +} + +/// +/// Client-streaming session used by a serverless activity worker registration. +/// +interface IServerlessActivityWorkerSession : IAsyncDisposable +{ + /// + /// Writes a worker registration message to the stream. + /// + /// The message to write. + /// A task that completes when the message is written. + Task WriteMessageAsync(Proto.ServerlessActivityWorkerMessage message); + + /// + /// Waits for the server to complete the worker registration session. + /// + /// The worker session result. + Task WaitForCompletionAsync(); + + /// + /// Completes the request stream and waits for the server response. + /// + /// A task that completes when the server response is observed. + Task CompleteAsync(); +} + +/// +/// gRPC-backed implementation of . +/// +sealed class ServerlessActivitiesClientAdapter : IServerlessActivitiesClient +{ + readonly Proto.ServerlessActivities.ServerlessActivitiesClient client; + readonly bool attachTaskHubMetadata; + + /// + /// Initializes a new instance of the class. + /// + /// The generated serverless activities gRPC client. + /// True to add per-call task hub metadata when the underlying channel does not already do so. + public ServerlessActivitiesClientAdapter( + Proto.ServerlessActivities.ServerlessActivitiesClient client, + bool attachTaskHubMetadata = true) + { + this.client = Check.NotNull(client); + this.attachTaskHubMetadata = attachTaskHubMetadata; + } + + /// + public async Task DeclareServerlessActivitiesAsync( + Proto.ServerlessActivityDeclaration declaration, + string taskHub, + CancellationToken cancellationToken) + { + return await this.client.DeclareServerlessActivitiesAsync( + declaration, + headers: this.CreateTaskHubHeaders(taskHub), + cancellationToken: cancellationToken) + .ResponseAsync.ConfigureAwait(false); + } + + /// + public IServerlessActivityWorkerSession OpenServerlessActivityWorkerSession(string taskHub, CancellationToken cancellationToken) + { + AsyncClientStreamingCall call = + this.client.ConnectServerlessActivityWorker( + headers: this.CreateTaskHubHeaders(taskHub), + cancellationToken: cancellationToken); + return new GrpcServerlessActivityWorkerSession(call); + } + + Metadata? CreateTaskHubHeaders(string taskHub) => this.attachTaskHubMetadata + ? new Metadata { { "taskhub", taskHub }, } + : null; + + /// + /// gRPC-backed serverless activity worker registration session. + /// + sealed class GrpcServerlessActivityWorkerSession : IServerlessActivityWorkerSession + { + readonly AsyncClientStreamingCall call; + + /// + /// Initializes a new instance of the class. + /// + /// The active gRPC client-streaming call. + public GrpcServerlessActivityWorkerSession(AsyncClientStreamingCall call) + { + this.call = call; + } + + /// + public Task WriteMessageAsync(Proto.ServerlessActivityWorkerMessage message) => + this.call.RequestStream.WriteAsync(message); + + /// + public async Task WaitForCompletionAsync() => + await this.call.ResponseAsync.ConfigureAwait(false); + + /// + public async Task CompleteAsync() + { + await this.call.RequestStream.CompleteAsync().ConfigureAwait(false); + await this.WaitForCompletionAsync().ConfigureAwait(false); + } + + /// + public ValueTask DisposeAsync() + { + this.call.Dispose(); + return default; + } + } +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityConfiguration.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityConfiguration.cs new file mode 100644 index 00000000..9e8bc2f6 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityConfiguration.cs @@ -0,0 +1,234 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Builds and normalizes serverless activity protocol messages. +/// +static class ServerlessActivityConfiguration +{ + /// + /// Resolves configured activity names for serverless activity execution. + /// + /// The configured activity names. + /// The normalized activity names. + public static string[] ResolveActivityNames(IEnumerable configuredNames) + { + return configuredNames + .Where(static name => !string.IsNullOrWhiteSpace(name)) + .Select(static name => name.Trim()) + .Distinct(StringComparer.Ordinal) + .ToArray(); + } + + /// + /// Builds a serverless activity declaration protocol message. + /// + /// The serverless options. + /// The activity names included in the declaration. + /// The declaration protocol message. + public static Proto.ServerlessActivityDeclaration BuildDeclaration(ServerlessOptions options, IReadOnlyCollection activityNames) + { + Check.NotNull(options); + Check.NotNull(activityNames); + + ValidateTaskHub(options.TaskHub, "Serverless activity declaration requires a task hub name."); + + if (activityNames.Count == 0) + { + throw new InvalidOperationException("Serverless activity declaration requires at least one activity name."); + } + + string workerProfileId = NormalizeWorkerProfileId(options.WorkerProfileId, "Serverless activity declaration requires a worker profile ID."); + + if (options.MaxConcurrentActivities <= 0) + { + throw new InvalidOperationException("Serverless activity max concurrent activities must be greater than zero."); + } + + Proto.ServerlessActivityDeclaration declaration = new() + { + WorkerProfileId = workerProfileId, + Image = BuildImage(options), + Resources = BuildResources(options), + MaxConcurrentActivities = options.MaxConcurrentActivities, + }; + + declaration.ActivityNames.AddRange(activityNames); + declaration.EnvironmentVariables.Add(options.EnvironmentVariables); + declaration.Entrypoint.AddRange(NormalizeOptionalStrings(options.Entrypoint)); + declaration.Cmd.AddRange(NormalizeOptionalStrings(options.Cmd)); + return declaration; + } + + /// + /// Builds the initial serverless activity worker registration message. + /// + /// The serverless options. + /// The activity handlers registered by the worker process. + /// The worker start protocol message. + public static Proto.ServerlessActivityWorkerMessage BuildWorkerStart( + ServerlessWorkerRuntimeOptions options, + IReadOnlyCollection registeredActivityNames) + { + Check.NotNull(options); + Check.NotNull(registeredActivityNames); + + ValidateTaskHub(options.TaskHub, "Serverless activity worker registration requires a task hub name."); + string[] activityNames = ResolveActivityNames(registeredActivityNames); + if (activityNames.Length == 0) + { + throw new InvalidOperationException("Serverless activity worker registration requires at least one registered activity."); + } + + if (options.MaxConcurrentActivities <= 0) + { + throw new InvalidOperationException("Serverless activity worker max concurrent activities must be greater than zero."); + } + + string workerProfileId = NormalizeWorkerProfileId(options.WorkerProfileId, "Serverless activity worker registration requires a worker profile ID."); + + Proto.ServerlessActivityWorkerStart start = new() + { + TaskHub = options.TaskHub, + WorkerProfileId = workerProfileId, + MaxActivitiesCount = options.MaxConcurrentActivities, + Substrate = GetSubstrateFromEnvironment(), + DtsSandboxIdentifier = Environment.GetEnvironmentVariable("DTS_SANDBOX_ID") ?? string.Empty, + }; + start.ActivityNames.AddRange(activityNames); + + return new Proto.ServerlessActivityWorkerMessage { Start = start }; + } + + /// + /// Builds a serverless activity worker heartbeat message. + /// + /// The number of activities currently executing. + /// The heartbeat protocol message. + public static Proto.ServerlessActivityWorkerMessage BuildWorkerHeartbeat(int activeActivitiesCount) + { + if (activeActivitiesCount < 0) + { + throw new InvalidOperationException("Serverless activity worker active activity count cannot be negative."); + } + + return new Proto.ServerlessActivityWorkerMessage + { + Heartbeat = new Proto.ServerlessActivityWorkerHeartbeat + { + ActiveActivitiesCount = activeActivitiesCount, + }, + }; + } + + static Proto.ServerlessActivityImage BuildImage(ServerlessOptions options) + { + string? imageRef = Coalesce( + options.ContainerImage, + BuildImageRef(options.RegistryServer, options.Repository, options.Tag, options.ImageDigest)); + + if (string.IsNullOrWhiteSpace(imageRef)) + { + throw new InvalidOperationException("Serverless activity image metadata requires a container image reference."); + } + + return new Proto.ServerlessActivityImage + { + ImageRef = imageRef, + }; + } + + static Proto.ServerlessActivityResources BuildResources(ServerlessOptions options) + { + string cpu = NormalizeRequired(options.Cpu, "Serverless activity declaration requires CPU resources."); + string memory = NormalizeRequired(options.Memory, "Serverless activity declaration requires memory resources."); + + return new Proto.ServerlessActivityResources + { + Cpu = cpu, + Memory = memory, + }; + } + + static Proto.SubstrateKind GetSubstrateFromEnvironment() + { + string? substrate = Environment.GetEnvironmentVariable("DTS_SUBSTRATE"); + if (substrate is null) + { + return Proto.SubstrateKind.Unspecified; + } + + if (substrate.Equals("Sandbox", StringComparison.OrdinalIgnoreCase)) + { + return Proto.SubstrateKind.Sandbox; + } + + if (substrate.Equals("AcaSessionPool", StringComparison.OrdinalIgnoreCase)) + { + return Proto.SubstrateKind.AcaSessionPool; + } + + return Proto.SubstrateKind.Unspecified; + } + + static void ValidateTaskHub(string value, string errorMessage) + { + _ = NormalizeRequired(value, errorMessage); + } + + static string NormalizeWorkerProfileId(string value, string errorMessage) + { + return NormalizeRequired(value, errorMessage); + } + + static string NormalizeRequired(string value, string errorMessage) + { + if (string.IsNullOrWhiteSpace(value)) + { + throw new InvalidOperationException(errorMessage); + } + + return value.Trim(); + } + + static string[] NormalizeOptionalStrings(IEnumerable values) + { + return values + .Where(static value => !string.IsNullOrWhiteSpace(value)) + .Select(static value => value.Trim()) + .ToArray(); + } + + static string? BuildImageRef(string? registryServer, string? repository, string? tag, string? digest) + { + if (string.IsNullOrWhiteSpace(repository)) + { + return null; + } + + string image = string.IsNullOrWhiteSpace(registryServer) ? repository : $"{registryServer}/{repository}"; + if (!string.IsNullOrWhiteSpace(digest)) + { + return $"{image}@{digest}"; + } + + return string.IsNullOrWhiteSpace(tag) ? image : $"{image}:{tag}"; + } + + static string? Coalesce(params string?[] values) + { + foreach (string? value in values) + { + if (!string.IsNullOrWhiteSpace(value)) + { + return value.Trim(); + } + } + + return null; + } +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityDeclarationHostedService.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityDeclarationHostedService.cs new file mode 100644 index 00000000..4c3d8590 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityDeclarationHostedService.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Hosted service that declares serverless activities with DTS when the local worker starts. +/// +sealed class ServerlessActivityDeclarationHostedService : IHostedService +{ + readonly IServerlessActivitiesClient client; + readonly ServerlessOptions options; + readonly ServerlessWorkerRuntimeOptions? runtimeOptions; + readonly ILogger logger; + + /// + /// Initializes a new instance of the class. + /// + /// The serverless activities client. + /// The serverless options. + /// The optional serverless worker runtime options. + /// The logger. + public ServerlessActivityDeclarationHostedService( + IServerlessActivitiesClient client, + ServerlessOptions options, + ServerlessWorkerRuntimeOptions? runtimeOptions, + ILogger logger) + { + this.client = Check.NotNull(client); + this.options = Check.NotNull(options); + this.runtimeOptions = runtimeOptions; + this.logger = Check.NotNull(logger); + } + + /// + public async Task StartAsync(CancellationToken cancellationToken) + { + if (this.runtimeOptions?.Mode == ServerlessMode.ServerlessInclude) + { + return; + } + + string[] activityNames = ServerlessActivityConfiguration.ResolveActivityNames(this.options.ActivityNames); + if (activityNames.Length == 0) + { + Logs.NoServerlessActivitiesForDeclaration(this.logger, this.options.TaskHub); + return; + } + + Proto.ServerlessActivityDeclaration declaration = ServerlessActivityConfiguration.BuildDeclaration( + this.options, + activityNames); + try + { + await this.client.DeclareServerlessActivitiesAsync( + declaration, + this.options.TaskHub, + cancellationToken).ConfigureAwait(false); + Logs.ServerlessActivitiesDeclared( + this.logger, + this.options.TaskHub, + declaration.WorkerProfileId, + declaration.ActivityNames.Count, + declaration.Image?.ImageRef ?? string.Empty); + } + catch (Exception ex) + { + Logs.ServerlessActivityDeclarationFailed(this.logger, ex, this.options.TaskHub); + throw; + } + } + + /// + public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask; +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityTracker.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityTracker.cs new file mode 100644 index 00000000..36237ce2 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityTracker.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Tracks activity execution state for a serverless worker process. +/// +sealed class ServerlessActivityTracker +{ + int activeActivityCount; + + /// + /// Gets the number of activities currently in flight on this worker. + /// + public int InFlightCount => Volatile.Read(ref this.activeActivityCount); + + /// + /// Records the start of an in-flight activity. + /// + internal void NotifyActivityStarted() => Interlocked.Increment(ref this.activeActivityCount); + + /// + /// Records the completion of an activity. + /// + internal void NotifyActivityCompleted() + { + while (true) + { + int currentCount = Volatile.Read(ref this.activeActivityCount); + if (currentCount == 0) + { + return; + } + + if (Interlocked.CompareExchange(ref this.activeActivityCount, currentCount - 1, currentCount) == currentCount) + { + return; + } + } + } +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityWorkerRegistrationHostedService.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityWorkerRegistrationHostedService.cs new file mode 100644 index 00000000..fa00e396 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessActivityWorkerRegistrationHostedService.cs @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.IO; +using Grpc.Core; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Proto = Microsoft.DurableTask.Protobuf.Serverless; + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Hosted service that registers a running process as a serverless activity worker with DTS. +/// +sealed class ServerlessActivityWorkerRegistrationHostedService : IHostedService, IAsyncDisposable +{ + readonly object sync = new(); + readonly IServerlessActivitiesClient client; + readonly ServerlessWorkerRuntimeOptions options; + readonly IReadOnlyCollection registeredActivityNames; + readonly ILogger logger; + readonly IHostApplicationLifetime? lifetime; + readonly ServerlessActivityTracker? activityTracker; + readonly Random reconnectJitter; + readonly SemaphoreSlim streamSync = new(1, 1); + CancellationTokenSource? cts; + IServerlessActivityWorkerSession? session; + Task? pump; + + /// + /// Initializes a new instance of the class. + /// + /// The serverless activities client. + /// The serverless worker runtime options. + /// The activity handlers registered by this worker process. + /// The logger. + /// The optional application lifetime used to stop the host when a non-retriable registration stream failure occurs. + /// The optional activity tracker used to report live in-flight activity count. + /// The optional random source used to jitter reconnect delays. + public ServerlessActivityWorkerRegistrationHostedService( + IServerlessActivitiesClient client, + ServerlessWorkerRuntimeOptions options, + IReadOnlyCollection registeredActivityNames, + ILogger logger, + IHostApplicationLifetime? lifetime = null, + ServerlessActivityTracker? activityTracker = null, + Random? reconnectJitter = null) + { + this.client = Check.NotNull(client); + this.options = Check.NotNull(options); + this.registeredActivityNames = Check.NotNull(registeredActivityNames); + this.logger = Check.NotNull(logger); + this.lifetime = lifetime; + this.activityTracker = activityTracker; + this.reconnectJitter = reconnectJitter ?? Random.Shared; + } + + /// + public Task StartAsync(CancellationToken cancellationToken) + { + if (this.options.Mode != ServerlessMode.ServerlessInclude) + { + this.pump = Task.CompletedTask; + return Task.CompletedTask; + } + + string[] activityNames = ServerlessActivityConfiguration.ResolveActivityNames(this.registeredActivityNames); + if (activityNames.Length == 0) + { + Logs.NoServerlessActivitiesForWorkerRegistration(this.logger, this.options.TaskHub); + this.pump = Task.CompletedTask; + return Task.CompletedTask; + } + + CancellationTokenSource registrationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + Task registrationPump = Task.Run( + () => this.RunRegistrationLoopAsync(activityNames.Length, registrationCts.Token), + CancellationToken.None); + lock (this.sync) + { + this.cts = registrationCts; + this.pump = registrationPump; + } + + return Task.CompletedTask; + } + + /// + public async Task StopAsync(CancellationToken cancellationToken) + { + CancellationTokenSource? localCts; + IServerlessActivityWorkerSession? localSession; + Task? localPump; + lock (this.sync) + { + localCts = this.cts; + localSession = this.session; + localPump = this.pump; + } + + localCts?.Cancel(); + + if (localSession is not null) + { + try + { + await this.CompleteSessionAsync(localSession, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (ex is OperationCanceledException or ObjectDisposedException or RpcException) + { + } + } + + if (localPump is not null) + { + try + { + await localPump.WaitAsync(cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (Exception ex) when (ex is OperationCanceledException or ObjectDisposedException or RpcException) + { + } + } + + lock (this.sync) + { + if (ReferenceEquals(this.cts, localCts)) + { + this.cts = null; + } + + if (ReferenceEquals(this.session, localSession)) + { + this.session = null; + } + + if (ReferenceEquals(this.pump, localPump)) + { + this.pump = Task.CompletedTask; + } + } + + localCts?.Dispose(); + } + + /// + public ValueTask DisposeAsync() => new(this.StopAsync(CancellationToken.None)); + + /// + /// Computes a full-jitter reconnect delay in the range [0, retryDelay). + /// + /// The current exponential retry delay. + /// The random source used for jitter. + /// The jittered reconnect delay. + internal static TimeSpan ComputeJitteredReconnectDelay(TimeSpan retryDelay, Random random) + { + Check.NotNull(random); + if (retryDelay <= TimeSpan.Zero) + { + return TimeSpan.Zero; + } + + long jitteredTicks = (long)(random.NextDouble() * retryDelay.Ticks); + return TimeSpan.FromTicks(jitteredTicks); + } + + static async ValueTask DisposeSessionAsync(IServerlessActivityWorkerSession registrationSession) + { + try + { + await registrationSession.DisposeAsync().ConfigureAwait(false); + } + catch (Exception ex) when (ex is OperationCanceledException or ObjectDisposedException or RpcException) + { + } + } + + static bool IsRetriableRegistrationFailure(Exception exception) => + (exception is OperationCanceledException or ObjectDisposedException or IOException) + || (exception is RpcException rpcException + && rpcException.StatusCode is StatusCode.Cancelled + or StatusCode.DeadlineExceeded + or StatusCode.Internal + or StatusCode.ResourceExhausted + or StatusCode.Unavailable + or StatusCode.Unknown); + + async Task RunRegistrationLoopAsync(int activityCount, CancellationToken cancellationToken) + { + TimeSpan retryDelay = this.GetInitialRetryDelay(); + while (!cancellationToken.IsCancellationRequested) + { + IServerlessActivityWorkerSession? registrationSession = null; + try + { + registrationSession = this.client.OpenServerlessActivityWorkerSession(this.options.TaskHub, cancellationToken); + this.SetCurrentSession(registrationSession); + + Proto.ServerlessActivityWorkerMessage startMessage = ServerlessActivityConfiguration.BuildWorkerStart(this.options, this.registeredActivityNames); + await this.WriteSessionMessageAsync(registrationSession, startMessage, cancellationToken).ConfigureAwait(false); + Logs.ServerlessActivityWorkerRegistered( + this.logger, + startMessage.Start.TaskHub, + activityCount, + startMessage.Start.Substrate, + startMessage.Start.DtsSandboxIdentifier); + + retryDelay = this.GetInitialRetryDelay(); + await this.RunRegistrationSessionAsync(registrationSession, cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + break; + } + catch (Exception ex) when (!IsRetriableRegistrationFailure(ex)) + { + Logs.ServerlessActivityWorkerRegistrationFailed(this.logger, ex, this.options.TaskHub); + this.lifetime?.StopApplication(); + break; + } + catch (Exception ex) + { + Logs.ServerlessActivityWorkerRegistrationFailed(this.logger, ex, this.options.TaskHub); + await this.DelayBeforeReconnectAsync(retryDelay, cancellationToken).ConfigureAwait(false); + retryDelay = this.GetNextRetryDelay(retryDelay); + } + finally + { + if (registrationSession is not null) + { + this.ClearCurrentSession(registrationSession); + await DisposeSessionAsync(registrationSession).ConfigureAwait(false); + } + } + } + } + + async Task RunRegistrationSessionAsync( + IServerlessActivityWorkerSession registrationSession, + CancellationToken cancellationToken) + { + using CancellationTokenSource heartbeatCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + Task heartbeatTask = this.PumpHeartbeatsAsync(registrationSession, heartbeatCts.Token); + Task completionTask = registrationSession.WaitForCompletionAsync(); + Task completedTask = await Task.WhenAny(heartbeatTask, completionTask).ConfigureAwait(false); + + if (ReferenceEquals(completedTask, completionTask)) + { + heartbeatCts.Cancel(); + try + { + await heartbeatTask.ConfigureAwait(false); + } + catch (OperationCanceledException) when (heartbeatCts.IsCancellationRequested) + { + } + catch (Exception) + { + // The server response is authoritative once the response task wins the race. + } + + await completionTask.ConfigureAwait(false); + return; + } + + await heartbeatTask.ConfigureAwait(false); + } + + async Task PumpHeartbeatsAsync( + IServerlessActivityWorkerSession registrationSession, + CancellationToken cancellationToken) + { + using PeriodicTimer timer = new(this.options.HeartbeatInterval); + while (await timer.WaitForNextTickAsync(cancellationToken).ConfigureAwait(false)) + { + int activeActivitiesCount = this.activityTracker?.InFlightCount ?? 0; + await this.WriteSessionMessageAsync( + registrationSession, + ServerlessActivityConfiguration.BuildWorkerHeartbeat(activeActivitiesCount), + cancellationToken).ConfigureAwait(false); + } + } + + async Task WriteSessionMessageAsync( + IServerlessActivityWorkerSession registrationSession, + Proto.ServerlessActivityWorkerMessage message, + CancellationToken cancellationToken) + { + await this.streamSync.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + cancellationToken.ThrowIfCancellationRequested(); + await registrationSession.WriteMessageAsync(message).ConfigureAwait(false); + } + finally + { + this.streamSync.Release(); + } + } + + async Task CompleteSessionAsync( + IServerlessActivityWorkerSession registrationSession, + CancellationToken cancellationToken) + { + await this.streamSync.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + await registrationSession.CompleteAsync().WaitAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + this.streamSync.Release(); + } + } + + void SetCurrentSession(IServerlessActivityWorkerSession registrationSession) + { + lock (this.sync) + { + this.session = registrationSession; + } + } + + void ClearCurrentSession(IServerlessActivityWorkerSession registrationSession) + { + lock (this.sync) + { + if (ReferenceEquals(this.session, registrationSession)) + { + this.session = null; + } + } + } + + TimeSpan GetInitialRetryDelay() => + this.options.WorkerRegistrationRetryInitialDelay <= this.options.WorkerRegistrationRetryMaxDelay + ? this.options.WorkerRegistrationRetryInitialDelay + : this.options.WorkerRegistrationRetryMaxDelay; + + TimeSpan GetNextRetryDelay(TimeSpan retryDelay) + { + if (retryDelay <= TimeSpan.Zero) + { + return retryDelay; + } + + long nextTicks = Math.Min(retryDelay.Ticks * 2, this.options.WorkerRegistrationRetryMaxDelay.Ticks); + return TimeSpan.FromTicks(nextTicks); + } + + async Task DelayBeforeReconnectAsync(TimeSpan retryDelay, CancellationToken cancellationToken) + { + TimeSpan jitteredDelay = ComputeJitteredReconnectDelay(retryDelay, this.reconnectJitter); + if (jitteredDelay > TimeSpan.Zero) + { + await Task.Delay(jitteredDelay, cancellationToken).ConfigureAwait(false); + } + } +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessOptions.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessOptions.cs new file mode 100644 index 00000000..e3d07e11 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessOptions.cs @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Options for declaring serverless activities and the worker image DTS should start for them. +/// +public sealed class ServerlessOptions +{ + /// + /// Default worker profile ID used when no profile is specified. + /// + internal const string DefaultWorkerProfileId = "default"; + + /// + /// Gets the serverless activity names to declare. Remote workers report their registered + /// activities separately when they connect. + /// + public IList ActivityNames { get; } = new List(); + + /// + /// Gets or sets the task hub where the serverless activity declaration is stored. + /// + public string TaskHub { get; set; } = string.Empty; + + /// + /// Gets or sets the worker profile ID used for the serverless activity pool. + /// + public string WorkerProfileId { get; set; } = DefaultWorkerProfileId; + + /// + /// Gets or sets the full container image reference for serverless workers. + /// + public string? ContainerImage { get; set; } + + /// + /// Gets or sets the registry server for the serverless worker image. + /// + public string? RegistryServer { get; set; } + + /// + /// Gets or sets the repository for the serverless worker image. + /// + public string? Repository { get; set; } + + /// + /// Gets or sets the tag for the serverless worker image. + /// + public string? Tag { get; set; } + + /// + /// Gets or sets the digest for the serverless worker image. + /// + public string? ImageDigest { get; set; } + + /// + /// Gets or sets the CPU quantity declared for each serverless sandbox. + /// + public string Cpu { get; set; } = "1000m"; + + /// + /// Gets or sets the memory quantity declared for each serverless sandbox. + /// + public string Memory { get; set; } = "2048Mi"; + + /// + /// Gets custom environment variables DTS should provide to serverless workers created from this declaration. + /// DTS-owned runtime variables such as DTS_ENDPOINT, DTS_TASK_HUB, and + /// DTS_SANDBOX_ID are injected by the backend and should not be supplied here. + /// + public IDictionary EnvironmentVariables { get; } = new Dictionary(StringComparer.Ordinal); + + /// + /// Gets the sandbox entrypoint declared for serverless workers. + /// + public IList Entrypoint { get; } = new List(); + + /// + /// Gets the sandbox command declared for serverless workers. + /// + public IList Cmd { get; } = new List(); + + /// + /// Gets or sets the maximum number of concurrent activities expected from each serverless worker. + /// + public int MaxConcurrentActivities { get; set; } = 100; + + /// + /// Adds an activity name to the serverless declaration. + /// + /// The activity name to execute serverlessly. + /// The current options instance. + public ServerlessOptions AddActivity(string activityName) + { + if (string.IsNullOrWhiteSpace(activityName)) + { + throw new ArgumentException("Serverless activity name cannot be empty.", nameof(activityName)); + } + + this.ActivityNames.Add(activityName.Trim()); + return this; + } + + /// + /// Adds an activity type to the serverless declaration. + /// + /// The activity type to execute serverlessly. + /// The current options instance. + public ServerlessOptions AddActivity() + where TActivity : class, ITaskActivity + { + return this.AddActivity(GetTaskName(typeof(TActivity))); + } + + static string GetTaskName(Type type) + { + Check.NotNull(type); + return Attribute.GetCustomAttribute(type, typeof(DurableTaskAttribute)) is DurableTaskAttribute { Name.Name: not null and not "" } attr + ? attr.Name.Name + : type.Name; + } +} diff --git a/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessWorkerRuntimeOptions.cs b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessWorkerRuntimeOptions.cs new file mode 100644 index 00000000..16852746 --- /dev/null +++ b/src/Extensions/AzureManagedServerless/Worker/Serverless/ServerlessWorkerRuntimeOptions.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Worker.AzureManaged.Serverless; + +/// +/// Defines how a worker participates in serverless activity execution. +/// +internal enum ServerlessMode +{ + /// + /// The worker is not running inside serverless infrastructure. + /// + LocalExclude, + + /// + /// The worker runs inside serverless infrastructure and executes only serverless activities. + /// + ServerlessInclude, +} + +/// +/// Internal runtime settings for a sandbox serverless worker process. +/// +internal sealed class ServerlessWorkerRuntimeOptions +{ + /// + /// Gets or sets the task hub used by serverless worker registration. + /// + public string TaskHub { get; set; } = string.Empty; + + /// + /// Gets or sets the worker profile ID used by serverless worker registration. + /// + public string WorkerProfileId { get; set; } = ServerlessOptions.DefaultWorkerProfileId; + + /// + /// Gets or sets the maximum number of concurrent activities expected from this serverless worker. + /// + public int MaxConcurrentActivities { get; set; } = 100; + + /// + /// Gets or sets the interval used to refresh live worker capacity while the registration stream is open. + /// + public TimeSpan HeartbeatInterval { get; set; } = TimeSpan.FromSeconds(2); + + /// + /// Gets or sets the initial delay before retrying a failed worker registration stream. + /// + public TimeSpan WorkerRegistrationRetryInitialDelay { get; set; } = TimeSpan.FromSeconds(1); + + /// + /// Gets or sets the maximum delay before retrying a failed worker registration stream. + /// + public TimeSpan WorkerRegistrationRetryMaxDelay { get; set; } = TimeSpan.FromSeconds(30); + + /// + /// Gets or sets the worker mode for serverless activity execution. Set automatically from the runtime environment. + /// + public ServerlessMode Mode { get; set; } = ServerlessMode.LocalExclude; +} diff --git a/src/Grpc/orchestrator_service.proto b/src/Grpc/orchestrator_service.proto index 3d7c8eb4..f782a5fe 100644 --- a/src/Grpc/orchestrator_service.proto +++ b/src/Grpc/orchestrator_service.proto @@ -856,6 +856,11 @@ message WorkItemFilters { repeated OrchestrationFilter orchestrations = 1; repeated ActivityFilter activities = 2; repeated EntityFilter entities = 3; + // Activities the worker explicitly does NOT want to process. When set, + // matching activity work items are skipped for this connection even if + // they would otherwise match `activities`. Mutually exclusive with + // `activities` for the same name. + repeated ActivityFilter exclude_activities = 4; } message OrchestrationFilter { diff --git a/src/Grpc/serverless_activities_service.proto b/src/Grpc/serverless_activities_service.proto new file mode 100644 index 00000000..153d62db --- /dev/null +++ b/src/Grpc/serverless_activities_service.proto @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +syntax = "proto3"; + +package microsoft.durabletask.serverless; + +option csharp_namespace = "Microsoft.DurableTask.Protobuf.Serverless"; + +service ServerlessActivities { + // Opens a live serverless activity worker session. The first message must be a + // start message with static worker metadata. Heartbeats carry dynamic state + // only. Closing the stream deregisters the worker. + rpc ConnectServerlessActivityWorker(stream ServerlessActivityWorkerMessage) returns (ServerlessActivityWorkerSessionResult); + + // Declares serverless activities before any live worker stream exists. This is a + // configuration contract and does not advertise active worker capacity. + rpc DeclareServerlessActivities(ServerlessActivityDeclaration) returns (ServerlessActivityDeclarationResult); + + // Removes a serverless activity declaration so the backend stops waking new workers + // for the specified worker profile. Existing workers are not terminated by this RPC. + rpc RemoveServerlessActivityDeclaration(RemoveServerlessActivityDeclarationRequest) returns (RemoveServerlessActivityDeclarationResult); +} + +message ServerlessActivityWorkerMessage { + oneof message { + ServerlessActivityWorkerStart start = 1; + ServerlessActivityWorkerHeartbeat heartbeat = 2; + } +} + +message ServerlessActivityWorkerStart { + reserved 2; + reserved "worker_instance_id"; + + string task_hub = 1; + int32 max_activities_count = 3; + // Substrate the worker is running in. UNSPECIFIED = legacy (pre-substrate-aware) workers. + SubstrateKind substrate = 4; + // DTS-generated sandbox identifier injected as DTS_SANDBOX_ID. This is not + // the ADC provider sandbox resource id. + string dts_sandbox_identifier = 5; + string worker_profile_id = 6; + // Activity handlers registered by the worker process. DTS validates this + // matches the declaration before advertising worker capacity. + repeated string activity_names = 7; +} + +message ServerlessActivityWorkerHeartbeat { + int32 active_activities_count = 1; +} + +message ServerlessActivityWorkerSessionResult { + bool accepted = 1; + string message = 2; +} + +message ServerlessActivityDeclaration { + string worker_profile_id = 2; + repeated string activity_names = 3; + ServerlessActivityImage image = 4; + map environment_variables = 5; + int32 max_concurrent_activities = 6; + ServerlessActivityResources resources = 7; + repeated string entrypoint = 8; + repeated string cmd = 9; +} + +message ServerlessActivityImage { + string image_ref = 1; +} + +message ServerlessActivityResources { + string cpu = 1; + string memory = 2; +} + +message ServerlessActivityDeclarationResult { +} + +message RemoveServerlessActivityDeclarationRequest { + string worker_profile_id = 1; +} + +message RemoveServerlessActivityDeclarationResult { +} + +// Compute substrate executing the activity worker. +enum SubstrateKind { + SUBSTRATE_KIND_UNSPECIFIED = 0; + SUBSTRATE_KIND_ACA_SESSION_POOL = 1; + SUBSTRATE_KIND_SANDBOX = 2; +} diff --git a/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs b/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs index 3b9d4e55..2b832a54 100644 --- a/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs +++ b/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs @@ -300,6 +300,7 @@ and not AccessViolationException } } } + GC.SuppressFinalize(this); } diff --git a/src/Worker/Core/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs b/src/Worker/Core/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs index a9274078..50e967ae 100644 --- a/src/Worker/Core/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs +++ b/src/Worker/Core/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs @@ -171,12 +171,14 @@ public static IDurableTaskWorkerBuilder UseWorkItemFilters(this IDurableTaskWork { opts.Orchestrations = []; opts.Activities = []; + opts.ExcludedActivities = []; opts.Entities = []; } else { opts.Orchestrations = workItemFilters.Orchestrations; opts.Activities = workItemFilters.Activities; + opts.ExcludedActivities = workItemFilters.ExcludedActivities; opts.Entities = workItemFilters.Entities; } }); @@ -194,6 +196,7 @@ public static IDurableTaskWorkerBuilder UseWorkItemFilters(this IDurableTaskWork if (workItemFilters is not null && (workItemFilters.Orchestrations.Count > 0 || workItemFilters.Activities.Count > 0 + || workItemFilters.ExcludedActivities.Count > 0 || workItemFilters.Entities.Count > 0)) { builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton< diff --git a/src/Worker/Core/DependencyInjection/DurableTaskWorkerWorkItemFiltersValidator.cs b/src/Worker/Core/DependencyInjection/DurableTaskWorkerWorkItemFiltersValidator.cs index c8eefb12..bcb45212 100644 --- a/src/Worker/Core/DependencyInjection/DurableTaskWorkerWorkItemFiltersValidator.cs +++ b/src/Worker/Core/DependencyInjection/DurableTaskWorkerWorkItemFiltersValidator.cs @@ -42,6 +42,7 @@ public ValidateOptionsResult Validate(string? name, DurableTaskWorkerWorkItemFil // reports a verdict for workers that actually configured filters. if (options.Orchestrations.Count == 0 && options.Activities.Count == 0 + && options.ExcludedActivities.Count == 0 && options.Entities.Count == 0) { return ValidateOptionsResult.Skip; @@ -53,11 +54,14 @@ public ValidateOptionsResult Validate(string? name, DurableTaskWorkerWorkItemFil options.Orchestrations.Select(o => o.Name), n => registry.Orchestrators.ContainsKey(n)); List unknownActivities = FindUnknown( options.Activities.Select(a => a.Name), n => registry.Activities.ContainsKey(n)); + List unknownExcludedActivities = FindUnknown( + options.ExcludedActivities.Select(a => a.Name), n => registry.Activities.ContainsKey(n)); List unknownEntities = FindUnknown( options.Entities.Select(e => e.Name), n => registry.Entities.ContainsKey(n)); if (unknownOrchestrations.Count == 0 && unknownActivities.Count == 0 + && unknownExcludedActivities.Count == 0 && unknownEntities.Count == 0) { return ValidateOptionsResult.Success; @@ -71,6 +75,7 @@ public ValidateOptionsResult Validate(string? name, DurableTaskWorkerWorkItemFil .Append("or remove them from the filters."); AppendCategory(sb, "Orchestrations", unknownOrchestrations); AppendCategory(sb, "Activities", unknownActivities); + AppendCategory(sb, "ExcludedActivities", unknownExcludedActivities); AppendCategory(sb, "Entities", unknownEntities); return ValidateOptionsResult.Fail(sb.ToString()); diff --git a/src/Worker/Core/DurableTaskWorkerWorkItemFilters.cs b/src/Worker/Core/DurableTaskWorkerWorkItemFilters.cs index 8a5df2f1..ec24fd2e 100644 --- a/src/Worker/Core/DurableTaskWorkerWorkItemFilters.cs +++ b/src/Worker/Core/DurableTaskWorkerWorkItemFilters.cs @@ -22,6 +22,11 @@ public class DurableTaskWorkerWorkItemFilters /// public IReadOnlyList Activities { get; set; } = []; + /// + /// Gets or sets the activity filters that should be excluded from this worker connection. + /// + public IReadOnlyList ExcludedActivities { get; set; } = []; + /// /// Gets or sets the entity filters. /// @@ -55,6 +60,7 @@ internal static DurableTaskWorkerWorkItemFilters FromDurableTaskRegistry(Durable Name = activity.Key, Versions = versions, }).ToList(), + ExcludedActivities = [], Entities = registry.Entities.Select(entity => new EntityFilter { // Entity names are normalized to lowercase in the backend. diff --git a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs index 4c5a18b2..40a2240b 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs @@ -156,7 +156,7 @@ await this.ProcessWorkItemsAsync( this.internalOptions.ReconnectBackoffBase, this.internalOptions.ReconnectBackoffCap, backoffRandom, - fullJitter: true); + fullJitter: true); this.Logger.ReconnectBackoff(reconnectAttempt, (int)delay.TotalMilliseconds); reconnectAttempt++; await Task.Delay(delay, cancellation); @@ -405,12 +405,23 @@ void DispatchWorkItem(P.WorkItem workItem, CancellationToken cancellation) } else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.ActivityRequest) { + this.internalOptions.NotifyActivity?.Invoke(ActivityNotificationPhase.Started); this.RunBackgroundTask( workItem, - () => this.OnRunActivityAsync( - workItem.ActivityRequest, - workItem.CompletionToken, - cancellation), + async () => + { + try + { + await this.OnRunActivityAsync( + workItem.ActivityRequest, + workItem.CompletionToken, + cancellation).ConfigureAwait(false); + } + finally + { + this.internalOptions.NotifyActivity?.Invoke(ActivityNotificationPhase.Completed); + } + }, cancellation); } else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequest) diff --git a/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs b/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs index 49bd6350..59c21a00 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using Microsoft.DurableTask.Worker.Grpc.Internal; using P = Microsoft.DurableTask.Protobuf; namespace Microsoft.DurableTask.Worker.Grpc; @@ -165,5 +166,10 @@ internal class InternalOptions /// deferring disposal of the old channel so in-flight RPCs already using it are not interrupted. /// public Func>? ChannelRecreator { get; set; } + + /// + /// Gets or sets a callback that is invoked when activity work items are received or finished. + /// + public Action? NotifyActivity { get; set; } } } diff --git a/src/Worker/Grpc/Internal/DurableTaskWorkerWorkItemFiltersExtension.cs b/src/Worker/Grpc/Internal/DurableTaskWorkerWorkItemFiltersExtension.cs index 176d376c..63c2b052 100644 --- a/src/Worker/Grpc/Internal/DurableTaskWorkerWorkItemFiltersExtension.cs +++ b/src/Worker/Grpc/Internal/DurableTaskWorkerWorkItemFiltersExtension.cs @@ -39,6 +39,16 @@ public static P.WorkItemFilters ToGrpcWorkItemFilters(this DurableTaskWorkerWork grpcWorkItemFilters.Activities.Add(grpcActivityFilter); } + foreach (var activityFilter in workItemFilter.ExcludedActivities) + { + var grpcActivityFilter = new P.ActivityFilter + { + Name = activityFilter.Name, + }; + grpcActivityFilter.Versions.AddRange(activityFilter.Versions); + grpcWorkItemFilters.ExcludeActivities.Add(grpcActivityFilter); + } + foreach (var entityFilter in workItemFilter.Entities) { var grpcEntityFilter = new P.EntityFilter diff --git a/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs b/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs index b26b36cc..764db9f7 100644 --- a/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs +++ b/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs @@ -7,6 +7,22 @@ namespace Microsoft.DurableTask.Worker.Grpc.Internal; +/// +/// Identifies the phase of activity execution being reported to internal worker hooks. +/// +public enum ActivityNotificationPhase +{ + /// + /// The worker has received and started processing an activity work item. + /// + Started, + + /// + /// The worker has finished processing an activity work item. + /// + Completed, +} + /// /// Provides access to configuring internal options for the gRPC worker. /// @@ -28,6 +44,24 @@ public static void ConfigureForAzureManaged(this GrpcDurableTaskWorkerOptions op options.Internal.InsertEntityUnlocksOnCompletion = true; } + /// + /// Registers a callback invoked when activity work items start and finish execution. + /// + /// The gRPC worker options. + /// The activity notification callback. + /// + /// This is an internal API that supports the DurableTask infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new DurableTask release. + /// + public static void ConfigureActivityNotification( + this GrpcDurableTaskWorkerOptions options, + Action notification) + { + options.Internal.NotifyActivity += notification ?? throw new ArgumentNullException(nameof(notification)); + } + /// /// Sets a callback that the worker invokes when the underlying gRPC channel needs to be recreated /// after repeated connect failures (e.g., because the backend was replaced and the existing channel diff --git a/test/Extensions/AzureManagedServerless.Tests/AzureManagedServerless.Tests.csproj b/test/Extensions/AzureManagedServerless.Tests/AzureManagedServerless.Tests.csproj new file mode 100644 index 00000000..a03e480e --- /dev/null +++ b/test/Extensions/AzureManagedServerless.Tests/AzureManagedServerless.Tests.csproj @@ -0,0 +1,16 @@ + + + + net10.0 + + + + + + + + + + + + diff --git a/test/Extensions/AzureManagedServerless.Tests/ServerlessActivitiesClientExtensionsTests.cs b/test/Extensions/AzureManagedServerless.Tests/ServerlessActivitiesClientExtensionsTests.cs new file mode 100644 index 00000000..9db1c9ad --- /dev/null +++ b/test/Extensions/AzureManagedServerless.Tests/ServerlessActivitiesClientExtensionsTests.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using FluentAssertions; +using Grpc.Core; +using Microsoft.DurableTask.Client.Grpc; +using Microsoft.DurableTask.Protobuf.Serverless; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.DurableTask.Client.AzureManaged.Tests; + +public class ServerlessActivitiesClientExtensionsTests +{ + [Fact] + public async Task AddDurableTaskSchedulerServerlessActivitiesClient_UsesConfiguredDurableTaskClientInvoker() + { + // Arrange + RecordingServerlessLogCallInvoker callInvoker = new(); + ServiceCollection services = new(); + services.AddOptions(Options.DefaultName) + .Configure(options => options.CallInvoker = callInvoker); + services.AddDurableTaskSchedulerServerlessActivitiesClient(); + + using ServiceProvider provider = services.BuildServiceProvider(); + ServerlessActivitiesClient client = provider.GetRequiredService(); + + // Act + await client.RemoveServerlessActivityDeclarationAsync("default"); + + // Assert + callInvoker.RemoveRequest.Should().NotBeNull(); + callInvoker.RemoveRequest!.WorkerProfileId.Should().Be("default"); + } + + [Fact] + public async Task RemoveServerlessActivityDeclarationAsync_SendsRequest() + { + // Arrange + RecordingServerlessLogCallInvoker callInvoker = new(); + ServerlessActivities.ServerlessActivitiesClient client = new(callInvoker); + + // Act + await client.RemoveServerlessActivityDeclarationAsync("default"); + + // Assert + callInvoker.RemoveRequest.Should().NotBeNull(); + callInvoker.RemoveRequest!.WorkerProfileId.Should().Be("default"); + callInvoker.RemoveHeaders.Should().NotContain(header => header.Key == "taskhub"); + callInvoker.UnaryDisposeCount.Should().Be(1); + } + + sealed class RecordingServerlessLogCallInvoker : CallInvoker + { + public RemoveServerlessActivityDeclarationRequest? RemoveRequest { get; private set; } + + public Metadata RemoveHeaders { get; private set; } = []; + + public int UnaryDisposeCount { get; private set; } + + public override TResponse BlockingUnaryCall( + Method method, + string? host, + CallOptions options, + TRequest request) + { + throw new NotSupportedException(); + } + + public override AsyncUnaryCall AsyncUnaryCall( + Method method, + string? host, + CallOptions options, + TRequest request) + { + method.FullName.Should().EndWith("/RemoveServerlessActivityDeclaration"); + this.RemoveRequest = (RemoveServerlessActivityDeclarationRequest)(object)request; + this.RemoveHeaders = options.Headers ?? []; + + return new AsyncUnaryCall( + Task.FromResult((TResponse)(object)new RemoveServerlessActivityDeclarationResult()), + Task.FromResult(new Metadata()), + () => new Status(StatusCode.OK, string.Empty), + () => new Metadata(), + () => this.UnaryDisposeCount++); + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall( + Method method, + string? host, + CallOptions options, + TRequest request) + { + throw new NotSupportedException(); + } + + public override AsyncClientStreamingCall AsyncClientStreamingCall( + Method method, + string? host, + CallOptions options) + { + throw new NotSupportedException(); + } + + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall( + Method method, + string? host, + CallOptions options) + { + throw new NotSupportedException(); + } + } +} diff --git a/test/Extensions/AzureManagedServerless.Tests/ServerlessActivitiesTests.cs b/test/Extensions/AzureManagedServerless.Tests/ServerlessActivitiesTests.cs new file mode 100644 index 00000000..f2d8b462 --- /dev/null +++ b/test/Extensions/AzureManagedServerless.Tests/ServerlessActivitiesTests.cs @@ -0,0 +1,957 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; +using FluentAssertions; +using Grpc.Core; +using Microsoft.DurableTask.Protobuf.Serverless; +using Microsoft.DurableTask.Worker.AzureManaged; +using Microsoft.DurableTask.Worker.AzureManaged.Serverless; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Microsoft.DurableTask.Worker.AzureManaged.Tests; + +public class ServerlessActivitiesTests +{ + const string TaskHub = "testhub"; + + [Fact] + public void ServerlessDeclarationContract_DoesNotExposeRemovedOptions() + { + typeof(ServerlessOptions).GetProperty("LaunchCommand").Should().BeNull(); + typeof(ServerlessOptions).GetProperty("DeclarationRetryMaxAttempts").Should().BeNull(); + typeof(ServerlessOptions).GetProperty("DeclarationRetryDelay").Should().BeNull(); + typeof(ServerlessOptions).GetProperty( + "HeartbeatInterval", + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic).Should().BeNull(); + typeof(ServerlessOptions).GetProperty("WakeupPort").Should().BeNull(); + typeof(ServerlessOptions).GetProperty( + "WorkerRegistrationRetryInitialDelay", + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic).Should().BeNull(); + typeof(ServerlessOptions).GetProperty( + "WorkerRegistrationRetryMaxDelay", + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic).Should().BeNull(); + typeof(ServerlessOptions).GetProperty( + "Mode", + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic).Should().BeNull(); + typeof(ServerlessActivityDeclaration).GetProperty("LaunchCommand").Should().BeNull(); + } + + [Fact] + public void ServerlessOptions_AddActivity_AddsStringAndTypedActivityNames() + { + // Arrange + ServerlessOptions options = new(); + + // Act + options + .AddActivity(" RemoteHello ") + .AddActivity(); + + // Assert + options.ActivityNames.Should().Equal("RemoteHello", "TypedRemoteHello"); + } + + [Fact] + public async Task ServerlessActivityDeclarationHostedService_SendsDeclarationPayload() + { + // Arrange + ServerlessOptions options = new() + { + TaskHub = TaskHub, + WorkerProfileId = "profile-a", + ContainerImage = "mcr.microsoft.com/durabletask/demo-worker:1.0", + Cpu = "500m", + Memory = "1024Mi", + MaxConcurrentActivities = 7, + }; + options.AddActivity("RemoteHello"); + options.EnvironmentVariables.Add("CUSTOM_SETTING", "enabled"); + options.Entrypoint.Add("/usr/bin/tini"); + options.Entrypoint.Add("--"); + options.Cmd.Add("dotnet"); + options.Cmd.Add("/app/DemoWorker.dll"); + FakeServerlessActivitiesClient client = new(); + ServerlessActivityDeclarationHostedService service = new( + client, + options, + runtimeOptions: null, + NullLogger.Instance); + + // Act + await service.StartAsync(CancellationToken.None); + + // Assert + ServerlessActivityDeclaration declaration = client.Declarations.Should().ContainSingle().Subject; + client.DeclarationTaskHubs.Should().Equal(TaskHub); + declaration.WorkerProfileId.Should().Be("profile-a"); + declaration.ActivityNames.Should().Equal("RemoteHello"); + declaration.Image.ImageRef.Should().Be("mcr.microsoft.com/durabletask/demo-worker:1.0"); + declaration.Resources.Cpu.Should().Be("500m"); + declaration.Resources.Memory.Should().Be("1024Mi"); + declaration.EnvironmentVariables.Should().ContainKey("CUSTOM_SETTING").WhoseValue.Should().Be("enabled"); + declaration.Entrypoint.Should().Equal("/usr/bin/tini", "--"); + declaration.Cmd.Should().Equal("dotnet", "/app/DemoWorker.dll"); + declaration.MaxConcurrentActivities.Should().Be(7); + } + + [Fact] + public async Task ServerlessActivitiesClientAdapter_SendsTaskHubMetadata() + { + // Arrange + RecordingServerlessActivitiesCallInvoker callInvoker = new(); + ServerlessActivitiesClientAdapter adapter = new(new ServerlessActivities.ServerlessActivitiesClient(callInvoker)); + ServerlessActivityDeclaration declaration = new() + { + WorkerProfileId = "profile-a", + Image = new ServerlessActivityImage + { + ImageRef = "example.com/repo/worker:latest", + }, + Resources = new ServerlessActivityResources + { + Cpu = "500m", + Memory = "1024Mi", + }, + MaxConcurrentActivities = 7, + }; + declaration.ActivityNames.Add("RemoteHello"); + + // Act + await adapter.DeclareServerlessActivitiesAsync(declaration, TaskHub, CancellationToken.None); + await using IServerlessActivityWorkerSession session = adapter.OpenServerlessActivityWorkerSession( + TaskHub, + CancellationToken.None); + + // Assert + callInvoker.DeclarationHeaders.Should().Contain(header => header.Key == "taskhub" && header.Value == TaskHub); + callInvoker.WorkerSessionHeaders.Should().Contain(header => header.Key == "taskhub" && header.Value == TaskHub); + } + + [Fact] + public async Task ServerlessActivitiesClientAdapter_CanRelyOnChannelTaskHubMetadata() + { + // Arrange + RecordingServerlessActivitiesCallInvoker callInvoker = new(); + ServerlessActivitiesClientAdapter adapter = new( + new ServerlessActivities.ServerlessActivitiesClient(callInvoker), + attachTaskHubMetadata: false); + ServerlessActivityDeclaration declaration = new() + { + WorkerProfileId = "profile-a", + Image = new ServerlessActivityImage + { + ImageRef = "example.com/repo/worker:latest", + }, + Resources = new ServerlessActivityResources + { + Cpu = "500m", + Memory = "1024Mi", + }, + MaxConcurrentActivities = 7, + }; + declaration.ActivityNames.Add("RemoteHello"); + + // Act + await adapter.DeclareServerlessActivitiesAsync(declaration, TaskHub, CancellationToken.None); + await using IServerlessActivityWorkerSession session = adapter.OpenServerlessActivityWorkerSession( + TaskHub, + CancellationToken.None); + + // Assert + callInvoker.DeclarationHeaders.Should().NotContain(header => header.Key == "taskhub"); + callInvoker.WorkerSessionHeaders.Should().NotContain(header => header.Key == "taskhub"); + } + + [Fact] + public async Task ServerlessActivityDeclarationHostedService_OmitsEntrypointAndCmdByDefault() + { + // Arrange + ServerlessOptions options = new() + { + TaskHub = TaskHub, + ContainerImage = "mcr.microsoft.com/durabletask/demo-worker:1.0", + }; + options.ActivityNames.Add("RemoteHello"); + FakeServerlessActivitiesClient client = new(); + ServerlessActivityDeclarationHostedService service = new( + client, + options, + runtimeOptions: null, + NullLogger.Instance); + + // Act + await service.StartAsync(CancellationToken.None); + + // Assert + ServerlessActivityDeclaration declaration = client.Declarations.Should().ContainSingle().Subject; + declaration.Entrypoint.Should().BeEmpty(); + declaration.Cmd.Should().BeEmpty(); + } + + [Fact] + public async Task ServerlessActivityDeclarationHostedService_SkipsDeclarationWhenNamesAreEmpty() + { + // Arrange + ServerlessOptions options = new() + { + TaskHub = TaskHub, + ContainerImage = "example.com/repo/worker:latest", + }; + FakeServerlessActivitiesClient client = new(); + ServerlessActivityDeclarationHostedService service = new( + client, + options, + runtimeOptions: null, + NullLogger.Instance); + + // Act + await service.StartAsync(CancellationToken.None); + + // Assert + client.Declarations.Should().BeEmpty(); + } + + [Fact] + public async Task ServerlessActivityDeclarationHostedService_DoesNotRetryTransientFailures() + { + // Arrange + ServerlessOptions options = new() + { + TaskHub = TaskHub, + ContainerImage = "example.com/repo/worker@sha256:abc", + }; + options.ActivityNames.Add("RemoteHello"); + FakeServerlessActivitiesClient client = new() { TransientDeclarationFailures = 1 }; + ServerlessActivityDeclarationHostedService service = new( + client, + options, + runtimeOptions: null, + NullLogger.Instance); + + // Act + Func action = () => service.StartAsync(CancellationToken.None); + + // Assert + await action.Should().ThrowAsync() + .Where(exception => exception.StatusCode == StatusCode.Unavailable); + client.DeclarationAttempts.Should().Be(1); + client.Declarations.Should().BeEmpty(); + } + + [Fact] + public async Task ServerlessActivityWorkerRegistrationHostedService_SendsLiveWorkerMetadataWithRegisteredActivities() + { + // Arrange + string? originalSubstrate = Environment.GetEnvironmentVariable("DTS_SUBSTRATE"); + string? originalSandboxId = Environment.GetEnvironmentVariable("DTS_SANDBOX_ID"); + Environment.SetEnvironmentVariable("DTS_SUBSTRATE", "Sandbox"); + Environment.SetEnvironmentVariable("DTS_SANDBOX_ID", "sandbox-1"); + + try + { + ServerlessWorkerRuntimeOptions options = new() + { + Mode = ServerlessMode.ServerlessInclude, + TaskHub = TaskHub, + WorkerProfileId = "profile-a", + MaxConcurrentActivities = 3, + HeartbeatInterval = TimeSpan.FromDays(1), + }; + FakeServerlessActivitiesClient client = new(); + ServerlessActivityWorkerRegistrationHostedService service = new( + client, + options, + ["RemoteHello"], + NullLogger.Instance); + + // Act + await service.StartAsync(CancellationToken.None); + await client.Session.WaitForMessageAsync(message => message.Start != null); + await service.StopAsync(CancellationToken.None); + + // Assert + client.SessionTaskHubs.Should().Equal(TaskHub); + ServerlessActivityWorkerMessage message = client.Session.Messages.Should().ContainSingle().Subject; + ServerlessActivityWorkerStart start = message.Start; + start.TaskHub.Should().Be(TaskHub); + start.WorkerProfileId.Should().Be("profile-a"); + start.MaxActivitiesCount.Should().Be(3); + start.Substrate.Should().Be(SubstrateKind.Sandbox); + start.DtsSandboxIdentifier.Should().Be("sandbox-1"); + start.ActivityNames.Should().Equal("RemoteHello"); + } + finally + { + Environment.SetEnvironmentVariable("DTS_SUBSTRATE", originalSubstrate); + Environment.SetEnvironmentVariable("DTS_SANDBOX_ID", originalSandboxId); + } + } + + [Fact] + public void ServerlessActivityTracker_TracksInFlightActivityCount() + { + // Arrange + ServerlessActivityTracker activityTracker = new(); + + // Act + activityTracker.NotifyActivityStarted(); + activityTracker.NotifyActivityStarted(); + + // Assert + activityTracker.InFlightCount.Should().Be(2); + + // Act + activityTracker.NotifyActivityCompleted(); + + // Assert + activityTracker.InFlightCount.Should().Be(1); + + // Act + activityTracker.NotifyActivityCompleted(); + activityTracker.NotifyActivityCompleted(); + + // Assert + activityTracker.InFlightCount.Should().Be(0); + } + + [Fact] + public async Task ServerlessActivityWorkerRegistrationHostedService_SendsHeartbeatWithCurrentInFlightCount() + { + // Arrange + ServerlessWorkerRuntimeOptions options = new() + { + Mode = ServerlessMode.ServerlessInclude, + TaskHub = TaskHub, + WorkerProfileId = "profile-a", + MaxConcurrentActivities = 3, + HeartbeatInterval = TimeSpan.FromMilliseconds(10), + }; + + FakeServerlessActivitiesClient client = new(); + ServerlessActivityTracker activityTracker = new(); + activityTracker.NotifyActivityStarted(); + activityTracker.NotifyActivityStarted(); + + ServerlessActivityWorkerRegistrationHostedService service = new( + client, + options, + ["RemoteHello"], + NullLogger.Instance, + activityTracker: activityTracker); + + // Act + await service.StartAsync(CancellationToken.None); + await client.Session.WaitForMessageAsync(message => message.Heartbeat?.ActiveActivitiesCount == 2); + activityTracker.NotifyActivityCompleted(); + await client.Session.WaitForMessageAsync(message => message.Heartbeat?.ActiveActivitiesCount == 1); + await service.StopAsync(CancellationToken.None); + + // Assert + client.Session.Messages.Should().Contain(message => message.Heartbeat != null && message.Heartbeat.ActiveActivitiesCount == 2); + client.Session.Messages.Should().Contain(message => message.Heartbeat != null && message.Heartbeat.ActiveActivitiesCount == 1); + } + + [Fact] + public async Task ServerlessActivityWorkerRegistrationHostedService_ReopensSessionAfterTransientStreamFailure() + { + // Arrange + ServerlessWorkerRuntimeOptions options = new() + { + Mode = ServerlessMode.ServerlessInclude, + TaskHub = TaskHub, + WorkerProfileId = "profile-a", + MaxConcurrentActivities = 3, + HeartbeatInterval = TimeSpan.FromMilliseconds(10), + WorkerRegistrationRetryInitialDelay = TimeSpan.FromMilliseconds(10), + WorkerRegistrationRetryMaxDelay = TimeSpan.FromMilliseconds(10), + }; + + FakeServerlessActivityWorkerSession failedSession = new() { ThrowOnWriteAttempt = 2 }; + FakeServerlessActivityWorkerSession recoveredSession = new(); + FakeServerlessActivitiesClient client = new(); + client.QueueSession(failedSession); + client.QueueSession(recoveredSession); + + ServerlessActivityWorkerRegistrationHostedService service = new( + client, + options, + ["RemoteHello"], + NullLogger.Instance); + + // Act + await service.StartAsync(CancellationToken.None); + await failedSession.WaitForMessageAsync(message => message.Start != null); + await recoveredSession.WaitForMessageAsync(message => message.Start != null); + await service.StopAsync(CancellationToken.None); + + // Assert + client.SessionTaskHubs.Should().Equal(TaskHub, TaskHub); + failedSession.Messages.Should().ContainSingle(message => message.Start != null); + recoveredSession.Messages.Should().ContainSingle(message => message.Start != null); + } + + [Fact] + public async Task ServerlessActivityWorkerRegistrationHostedService_ReopensSessionAfterTerminalServerFailure() + { + // Arrange + ServerlessWorkerRuntimeOptions options = new() + { + Mode = ServerlessMode.ServerlessInclude, + TaskHub = TaskHub, + WorkerProfileId = "profile-a", + MaxConcurrentActivities = 3, + HeartbeatInterval = TimeSpan.FromDays(1), + WorkerRegistrationRetryInitialDelay = TimeSpan.FromMilliseconds(10), + WorkerRegistrationRetryMaxDelay = TimeSpan.FromMilliseconds(10), + }; + + FakeServerlessActivityWorkerSession failedSession = new(); + FakeServerlessActivityWorkerSession recoveredSession = new(); + FakeServerlessActivitiesClient client = new(); + client.QueueSession(failedSession); + client.QueueSession(recoveredSession); + + ServerlessActivityWorkerRegistrationHostedService service = new( + client, + options, + ["RemoteHello"], + NullLogger.Instance); + + // Act + await service.StartAsync(CancellationToken.None); + await failedSession.WaitForMessageAsync(message => message.Start != null); + failedSession.FailCompletion(new RpcException(new Status(StatusCode.Unavailable, "terminal"))); + await recoveredSession.WaitForMessageAsync(message => message.Start != null); + await service.StopAsync(CancellationToken.None); + + // Assert + client.SessionTaskHubs.Should().Equal(TaskHub, TaskHub); + failedSession.Messages.Should().ContainSingle(message => message.Start != null); + recoveredSession.Messages.Should().ContainSingle(message => message.Start != null); + } + + [Fact] + public void ServerlessActivityWorkerRegistrationHostedService_ComputeJitteredReconnectDelay_UsesFullJitterWindow() + { + // Arrange + TimeSpan retryDelay = TimeSpan.FromSeconds(10); + + // Act + TimeSpan zero = ServerlessActivityWorkerRegistrationHostedService.ComputeJitteredReconnectDelay( + TimeSpan.Zero, + new DeterministicRandom(0.5)); + TimeSpan low = ServerlessActivityWorkerRegistrationHostedService.ComputeJitteredReconnectDelay( + retryDelay, + new DeterministicRandom(0.0)); + TimeSpan mid = ServerlessActivityWorkerRegistrationHostedService.ComputeJitteredReconnectDelay( + retryDelay, + new DeterministicRandom(0.5)); + TimeSpan high = ServerlessActivityWorkerRegistrationHostedService.ComputeJitteredReconnectDelay( + retryDelay, + new DeterministicRandom(0.999999)); + + // Assert + zero.Should().Be(TimeSpan.Zero); + low.Should().Be(TimeSpan.Zero); + mid.Should().Be(TimeSpan.FromSeconds(5)); + high.Should().BeGreaterThan(TimeSpan.FromSeconds(9)); + high.Should().BeLessThan(retryDelay); + } + + [Fact] + public async Task ServerlessActivityWorkerRegistrationHostedService_AppliesJitterToReconnectDelay() + { + // Arrange + ServerlessWorkerRuntimeOptions options = new() + { + Mode = ServerlessMode.ServerlessInclude, + TaskHub = TaskHub, + WorkerProfileId = "profile-a", + MaxConcurrentActivities = 3, + HeartbeatInterval = TimeSpan.FromMilliseconds(10), + WorkerRegistrationRetryInitialDelay = TimeSpan.FromDays(1), + WorkerRegistrationRetryMaxDelay = TimeSpan.FromDays(1), + }; + + FakeServerlessActivityWorkerSession failedSession = new() { ThrowOnWriteAttempt = 2 }; + FakeServerlessActivityWorkerSession recoveredSession = new(); + FakeServerlessActivitiesClient client = new(); + client.QueueSession(failedSession); + client.QueueSession(recoveredSession); + + ServerlessActivityWorkerRegistrationHostedService service = new( + client, + options, + ["RemoteHello"], + NullLogger.Instance, + reconnectJitter: new DeterministicRandom(0.0)); + + // Act + await service.StartAsync(CancellationToken.None); + await failedSession.WaitForMessageAsync(message => message.Start != null); + await recoveredSession.WaitForMessageAsync(message => message.Start != null); + await service.StopAsync(CancellationToken.None); + + // Assert + client.SessionTaskHubs.Should().Equal(TaskHub, TaskHub); + } + + [Fact] + public async Task ServerlessActivityWorkerRegistrationHostedService_StopAsync_DoesNotCompleteStreamWhileWriteIsInFlight() + { + // Arrange + ServerlessWorkerRuntimeOptions options = new() + { + Mode = ServerlessMode.ServerlessInclude, + TaskHub = TaskHub, + WorkerProfileId = "profile-a", + MaxConcurrentActivities = 3, + HeartbeatInterval = TimeSpan.FromMilliseconds(10), + }; + + FakeServerlessActivityWorkerSession session = new() { BlockWriteAttempt = 2 }; + FakeServerlessActivitiesClient client = new(); + client.QueueSession(session); + + ServerlessActivityWorkerRegistrationHostedService service = new( + client, + options, + ["RemoteHello"], + NullLogger.Instance); + + // Act + await service.StartAsync(CancellationToken.None); + await session.WaitForBlockedWriteAsync(); + Task stopTask = service.StopAsync(CancellationToken.None); + Task completeAttempt = session.WaitForCompleteAsync(); + Task completeBeforeWriteReleased = await Task.WhenAny( + completeAttempt, + Task.Delay(TimeSpan.FromMilliseconds(100))); + session.ReleaseBlockedWrite(); + await stopTask.WaitAsync(TimeSpan.FromSeconds(5)); + + // Assert + completeBeforeWriteReleased.Should().NotBe(completeAttempt); + session.CompleteCalled.Should().BeTrue(); + session.CompleteCalledWhileWriteActive.Should().BeFalse(); + } + + [Fact] + public async Task DeclareServerlessActivities_ConfiguresLocalWorkerExclusionFilter() + { + // Arrange + using EnvironmentVariableScope endpoint = new("DTS_ENDPOINT", "https://example.scheduler"); + using EnvironmentVariableScope taskHub = new("DTS_TASK_HUB", TaskHub); + ServiceCollection services = new(); + Mock mockBuilder = new(); + mockBuilder.Setup(builder => builder.Services).Returns(services); + mockBuilder.Setup(builder => builder.Name).Returns(Options.DefaultName); + + // Act + mockBuilder.Object.DeclareServerlessActivities(options => + { + options.TaskHub = TaskHub; + options.ContainerImage = "example.com/repo/worker:latest"; + options.ActivityNames.Add("RemoteHello"); + }); + + await using ServiceProvider provider = services.BuildServiceProvider(); + DurableTaskWorkerWorkItemFilters filters = provider.GetRequiredService>().Get(Options.DefaultName); + + // Assert + filters.ExcludedActivities.Select(filter => filter.Name).Should().Equal("RemoteHello"); + filters.Activities.Should().BeEmpty(); + } + + [Fact] + public async Task DeclareServerlessActivities_DoesNotConfigureFilterWhenActivityNamesAreEmpty() + { + // Arrange + using EnvironmentVariableScope endpoint = new("DTS_ENDPOINT", "https://example.scheduler"); + using EnvironmentVariableScope taskHub = new("DTS_TASK_HUB", TaskHub); + ServiceCollection services = new(); + Mock mockBuilder = new(); + mockBuilder.Setup(builder => builder.Services).Returns(services); + mockBuilder.Setup(builder => builder.Name).Returns(Options.DefaultName); + + // Act + mockBuilder.Object.DeclareServerlessActivities(options => + { + options.TaskHub = TaskHub; + options.ContainerImage = "example.com/repo/worker:latest"; + }); + + await using ServiceProvider provider = services.BuildServiceProvider(); + DurableTaskWorkerWorkItemFilters filters = provider.GetRequiredService>().Get(Options.DefaultName); + + // Assert + filters.ExcludedActivities.Should().BeEmpty(); + filters.Activities.Should().BeEmpty(); + } + + [Fact] + public async Task UseServerlessWorker_ConfiguresRegisteredActivityWorkerFilter() + { + // Arrange + using EnvironmentVariableScope endpoint = new("DTS_ENDPOINT", "https://example.scheduler"); + using EnvironmentVariableScope taskHub = new("DTS_TASK_HUB", TaskHub); + ServiceCollection services = new(); + services.Configure( + Options.DefaultName, + registry => registry.AddActivityFunc(new TaskName("RemoteHello"), (_, input) => input)); + Mock mockBuilder = new(); + mockBuilder.Setup(builder => builder.Services).Returns(services); + mockBuilder.Setup(builder => builder.Name).Returns(Options.DefaultName); + + // Act + mockBuilder.Object.UseServerlessWorker(); + + await using ServiceProvider provider = services.BuildServiceProvider(); + DurableTaskWorkerWorkItemFilters filters = provider.GetRequiredService>().Get(Options.DefaultName); + + // Assert + filters.Activities.Select(filter => filter.Name).Should().Equal("RemoteHello"); + filters.ExcludedActivities.Should().BeEmpty(); + filters.Orchestrations.Should().BeEmpty(); + filters.Entities.Should().BeEmpty(); + } + + [Fact] + public void UseServerlessWorker_DoesNotRegisterWakeupServerHostedService() + { + // Arrange + using EnvironmentVariableScope endpoint = new("DTS_ENDPOINT", "https://example.scheduler"); + using EnvironmentVariableScope taskHub = new("DTS_TASK_HUB", TaskHub); + ServiceCollection services = new(); + Mock mockBuilder = new(); + mockBuilder.Setup(builder => builder.Services).Returns(services); + mockBuilder.Setup(builder => builder.Name).Returns(Options.DefaultName); + + // Act + mockBuilder.Object.UseServerlessWorker(); + + // Assert + services.Count(descriptor => descriptor.ServiceType == typeof(IHostedService)).Should().Be(1); + } + + [Fact] + public void UseServerlessWorker_MissingInjectedEndpoint_Throws() + { + // Arrange + using EnvironmentVariableScope endpoint = new("DTS_ENDPOINT", null); + using EnvironmentVariableScope taskHub = new("DTS_TASK_HUB", TaskHub); + ServiceCollection services = new(); + Mock mockBuilder = new(); + mockBuilder.Setup(builder => builder.Services).Returns(services); + mockBuilder.Setup(builder => builder.Name).Returns(Options.DefaultName); + + // Act + Action action = () => mockBuilder.Object.UseServerlessWorker(); + + // Assert + action.Should().Throw() + .WithMessage("DTS_ENDPOINT must be injected by DTS for serverless workers."); + } + + [Fact] + public void UseServerlessWorker_MissingInjectedTaskHub_Throws() + { + // Arrange + using EnvironmentVariableScope endpoint = new("DTS_ENDPOINT", "https://example.scheduler"); + using EnvironmentVariableScope taskHub = new("DTS_TASK_HUB", null); + ServiceCollection services = new(); + Mock mockBuilder = new(); + mockBuilder.Setup(builder => builder.Services).Returns(services); + mockBuilder.Setup(builder => builder.Name).Returns(Options.DefaultName); + + // Act + Action action = () => mockBuilder.Object.UseServerlessWorker(); + + // Assert + action.Should().Throw() + .WithMessage("DTS_TASK_HUB must be injected by DTS for serverless workers."); + } + + [DurableTask("TypedRemoteHello")] + sealed class TypedRemoteHelloActivity : TaskActivity + { + public override Task RunAsync(TaskActivityContext context, string input) + { + return Task.FromResult(input); + } + } + + sealed class FakeServerlessActivitiesClient : IServerlessActivitiesClient + { + readonly Queue queuedSessions = new(); + + public int TransientDeclarationFailures { get; init; } + + public int DeclarationAttempts { get; private set; } + + public List Declarations { get; } = []; + + public List DeclarationTaskHubs { get; } = []; + + public List SessionTaskHubs { get; } = []; + + public List Sessions { get; } = []; + + public FakeServerlessActivityWorkerSession Session { get; } = new(); + + public void QueueSession(FakeServerlessActivityWorkerSession session) => this.queuedSessions.Enqueue(session); + + public Task DeclareServerlessActivitiesAsync( + ServerlessActivityDeclaration declaration, + string taskHub, + CancellationToken cancellationToken) + { + this.DeclarationAttempts++; + if (this.DeclarationAttempts <= this.TransientDeclarationFailures) + { + throw new RpcException(new Status(StatusCode.Unavailable, "transient")); + } + + this.DeclarationTaskHubs.Add(taskHub); + this.Declarations.Add(declaration.Clone()); + return Task.FromResult(new ServerlessActivityDeclarationResult()); + } + + public IServerlessActivityWorkerSession OpenServerlessActivityWorkerSession(string taskHub, CancellationToken cancellationToken) + { + this.SessionTaskHubs.Add(taskHub); + FakeServerlessActivityWorkerSession session = this.queuedSessions.Count > 0 + ? this.queuedSessions.Dequeue() + : this.Session; + this.Sessions.Add(session); + return session; + } + } + + sealed class RecordingServerlessActivitiesCallInvoker : CallInvoker + { + public Metadata DeclarationHeaders { get; private set; } = []; + + public Metadata WorkerSessionHeaders { get; private set; } = []; + + public override TResponse BlockingUnaryCall( + Method method, + string? host, + CallOptions options, + TRequest request) + { + throw new NotSupportedException(); + } + + public override AsyncUnaryCall AsyncUnaryCall( + Method method, + string? host, + CallOptions options, + TRequest request) + { + method.FullName.Should().EndWith("/DeclareServerlessActivities"); + this.DeclarationHeaders = options.Headers ?? []; + + return new AsyncUnaryCall( + Task.FromResult((TResponse)(object)new ServerlessActivityDeclarationResult()), + Task.FromResult(new Metadata()), + () => new Status(StatusCode.OK, string.Empty), + () => [], + () => { }); + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall( + Method method, + string? host, + CallOptions options, + TRequest request) + { + throw new NotSupportedException(); + } + + public override AsyncClientStreamingCall AsyncClientStreamingCall( + Method method, + string? host, + CallOptions options) + { + method.FullName.Should().EndWith("/ConnectServerlessActivityWorker"); + this.WorkerSessionHeaders = options.Headers ?? []; + + return new AsyncClientStreamingCall( + new RecordingClientStreamWriter(), + Task.FromResult((TResponse)(object)new ServerlessActivityWorkerSessionResult()), + Task.FromResult(new Metadata()), + () => new Status(StatusCode.OK, string.Empty), + () => [], + () => { }); + } + + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall( + Method method, + string? host, + CallOptions options) + { + throw new NotSupportedException(); + } + } + + sealed class RecordingClientStreamWriter : IClientStreamWriter + { + public WriteOptions? WriteOptions { get; set; } + + public Task WriteAsync(T message) => Task.CompletedTask; + + public Task CompleteAsync() => Task.CompletedTask; + } + + sealed class FakeServerlessActivityWorkerSession : IServerlessActivityWorkerSession + { + readonly object sync = new(); + readonly TaskCompletionSource completion = + new(TaskCreationOptions.RunContinuationsAsynchronously); + readonly TaskCompletionSource blockedWriteStarted = + new(TaskCreationOptions.RunContinuationsAsynchronously); + readonly TaskCompletionSource releaseBlockedWrite = + new(TaskCreationOptions.RunContinuationsAsynchronously); + int writeAttempts; + int activeWrites; + + public List Messages { get; } = []; + + public int? ThrowOnWriteAttempt { get; init; } + + public int? BlockWriteAttempt { get; init; } + + public bool CompleteCalled { get; private set; } + + public bool CompleteCalledWhileWriteActive { get; private set; } + + public void FailCompletion(Exception exception) => this.completion.TrySetException(exception); + + public Task WaitForBlockedWriteAsync() => this.blockedWriteStarted.Task.WaitAsync(TimeSpan.FromSeconds(5)); + + public Task WaitForCompleteAsync() + { + lock (this.sync) + { + return this.CompleteCalled ? Task.CompletedTask : this.completion.Task; + } + } + + public void ReleaseBlockedWrite() => this.releaseBlockedWrite.TrySetResult(); + + public async Task WaitForMessageAsync(Func predicate) + { + using CancellationTokenSource timeout = new(TimeSpan.FromSeconds(5)); + while (!timeout.IsCancellationRequested) + { + lock (this.sync) + { + if (this.Messages.Any(predicate)) + { + return; + } + } + + await Task.Delay(TimeSpan.FromMilliseconds(10), timeout.Token); + } + + throw new TimeoutException("Timed out waiting for serverless worker message."); + } + + public Task WriteMessageAsync(ServerlessActivityWorkerMessage message) + { + int attempt; + bool blockWrite; + lock (this.sync) + { + attempt = ++this.writeAttempts; + if (this.ThrowOnWriteAttempt == attempt) + { + throw new RpcException(new Status(StatusCode.Unavailable, "transient")); + } + + this.activeWrites++; + blockWrite = this.BlockWriteAttempt == attempt; + if (blockWrite) + { + this.blockedWriteStarted.TrySetResult(); + } + } + + return this.WriteMessageCoreAsync(message, blockWrite); + } + + public Task WaitForCompletionAsync() => this.completion.Task; + + public async Task CompleteAsync() + { + lock (this.sync) + { + this.CompleteCalled = true; + this.CompleteCalledWhileWriteActive = this.activeWrites > 0; + } + + this.completion.TrySetResult(new ServerlessActivityWorkerSessionResult { Accepted = true }); + await this.completion.Task.ConfigureAwait(false); + } + + public ValueTask DisposeAsync() => default; + + async Task WriteMessageCoreAsync(ServerlessActivityWorkerMessage message, bool blockWrite) + { + try + { + if (blockWrite) + { + await this.releaseBlockedWrite.Task.ConfigureAwait(false); + } + + lock (this.sync) + { + this.Messages.Add(message.Clone()); + } + } + finally + { + lock (this.sync) + { + this.activeWrites--; + } + } + } + } + + sealed class DeterministicRandom : Random + { + readonly double value; + + public DeterministicRandom(double value) + { + this.value = value; + } + + protected override double Sample() => this.value; + } + + sealed class EnvironmentVariableScope : IDisposable + { + readonly string name; + readonly string? originalValue; + + public EnvironmentVariableScope(string name, string? value) + { + this.name = name; + this.originalValue = Environment.GetEnvironmentVariable(name); + Environment.SetEnvironmentVariable(name, value); + } + + public void Dispose() => Environment.SetEnvironmentVariable(this.name, this.originalValue); + } +} diff --git a/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs index db6c98da..0fe26c55 100644 --- a/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs +++ b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Collections.Concurrent; using System.IO; using System.Reflection; using Google.Protobuf.WellKnownTypes; @@ -28,6 +29,9 @@ public class GrpcDurableTaskWorkerTests static readonly MethodInfo ProcessorConnectAsyncMethod = typeof(GrpcDurableTaskWorker) .GetNestedType("Processor", BindingFlags.NonPublic)! .GetMethod("ConnectAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + static readonly MethodInfo DispatchWorkItemMethod = typeof(GrpcDurableTaskWorker) + .GetNestedType("Processor", BindingFlags.NonPublic)! + .GetMethod("DispatchWorkItem", BindingFlags.Instance | BindingFlags.NonPublic)!; static readonly MethodInfo TryRecreateChannelAsyncMethod = typeof(GrpcDurableTaskWorker) .GetMethod("TryRecreateChannelAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; @@ -250,6 +254,58 @@ public async Task ProcessorExecuteAsync_GracefulDrainAfterFirstMessage_Reconnect logs.Should().NotContain(log => log.Message.Contains("Recreating gRPC channel to backend")); } + [Fact] + public async Task DispatchWorkItem_ActivityRequest_NotifiesActivityStartAndCompletion() + { + // Arrange + ConcurrentQueue notifications = new(); + TaskCompletionSource completed = new(TaskCreationOptions.RunContinuationsAsynchronously); + GrpcDurableTaskWorkerOptions grpcOptions = new(); + grpcOptions.ConfigureActivityNotification(phase => + { + notifications.Enqueue(phase); + if (phase == ActivityNotificationPhase.Completed) + { + completed.TrySetResult(); + } + }); + + P.WorkItem activityWorkItem = new() + { + ActivityRequest = new P.ActivityRequest + { + Name = "MyActivity", + TaskId = 42, + OrchestrationInstance = new P.OrchestrationInstance + { + InstanceId = "instance1", + ExecutionId = "execution1", + }, + }, + CompletionToken = "completion1", + }; + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions); + Mock clientMock = new( + MockBehavior.Strict, + new object[] { Mock.Of() }); + clientMock + .Setup(client => client.CompleteActivityTaskAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns(CreateUnaryCall(Task.FromResult(new P.CompleteTaskResponse()))); + object processor = CreateProcessor(worker, clientMock.Object); + + // Act + InvokeDispatchWorkItem(processor, activityWorkItem, CancellationToken.None); + await completed.Task.WaitAsync(TimeSpan.FromSeconds(5)); + + // Assert + notifications.Should().Equal(ActivityNotificationPhase.Started, ActivityNotificationPhase.Completed); + } + [Fact] public async Task ProcessorExecuteAsync_HelloDeadlineExceeded_ReturnsChannelRecreateRequested() { @@ -542,6 +598,11 @@ static async Task InvokeProcessorExecuteAsync(object proces return (ProcessorExitReason)task.GetType().GetProperty("Result")!.GetValue(task)!; } + static void InvokeDispatchWorkItem(object processor, P.WorkItem workItem, CancellationToken cancellationToken) + { + DispatchWorkItemMethod.Invoke(processor, new object?[] { workItem, cancellationToken }); + } + static void InvokeApplySuccessfulRecreate( GrpcDurableTaskWorker worker, object result,