Skip to content

Commit c0a8a01

Browse files
authored
Update storage driver store context metadata (#1399)
1 parent 6c1bc40 commit c0a8a01

File tree

17 files changed

+1196
-309
lines changed

17 files changed

+1196
-309
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,11 @@ def feature_flag_is_on(workflow_id: str | None) -> bool:
533533
def feature_flag_selector(
534534
context: temporalio.converter.StorageDriverStoreContext, _payload: Payload
535535
) -> temporalio.converter.StorageDriver | None:
536-
workflow_id = None
537-
if isinstance(context.serialization_context, temporalio.converter.WorkflowSerializationContext):
538-
workflow_id = context.serialization_context.workflow_id
539-
elif isinstance(context.serialization_context, temporalio.converter.ActivitySerializationContext):
540-
workflow_id = context.serialization_context.workflow_id
536+
workflow_id = (
537+
context.target.id
538+
if isinstance(context.target, temporalio.converter.StorageDriverWorkflowInfo)
539+
else None
540+
)
541541
return my_driver if feature_flag_is_on(workflow_id) else None
542542

543543
options = ExternalStorage(

temporalio/client.py

Lines changed: 111 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@
6666
ActivitySerializationContext,
6767
DataConverter,
6868
SerializationContext,
69+
StorageDriverActivityInfo,
70+
StorageDriverStoreContext,
71+
StorageDriverWorkflowInfo,
6972
WithSerializationContext,
7073
WorkflowSerializationContext,
7174
)
@@ -6161,11 +6164,16 @@ async def _to_proto(
61616164
priority: temporalio.api.common.v1.Priority | None = None
61626165
if self.priority:
61636166
priority = self.priority._to_proto()
6164-
data_converter = client.data_converter.with_context(
6167+
data_converter = client.data_converter._with_contexts(
61656168
WorkflowSerializationContext(
61666169
namespace=client.namespace,
61676170
workflow_id=self.id,
6168-
)
6171+
),
6172+
StorageDriverStoreContext(
6173+
target=StorageDriverWorkflowInfo(
6174+
id=self.id, type=self.workflow, namespace=client.namespace
6175+
),
6176+
),
61696177
)
61706178
action = temporalio.api.schedule.v1.ScheduleAction(
61716179
start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo(
@@ -6210,7 +6218,8 @@ async def _to_proto(
62106218
# TODO (dan): confirm whether this be `is not None`
62116219
if self.typed_search_attributes:
62126220
temporalio.converter.encode_search_attributes(
6213-
self.typed_search_attributes, action.start_workflow.search_attributes
6221+
self.typed_search_attributes,
6222+
action.start_workflow.search_attributes,
62146223
)
62156224
if self.headers:
62166225
await _apply_headers(
@@ -8077,11 +8086,16 @@ async def _build_signal_with_start_workflow_execution_request(
80778086
self, input: StartWorkflowInput
80788087
) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest:
80798088
assert input.start_signal
8080-
data_converter = self._client.data_converter.with_context(
8089+
data_converter = self._client.data_converter._with_contexts(
80818090
WorkflowSerializationContext(
80828091
namespace=self._client.namespace,
80838092
workflow_id=input.id,
8084-
)
8093+
),
8094+
StorageDriverStoreContext(
8095+
target=StorageDriverWorkflowInfo(
8096+
id=input.id, type=input.workflow, namespace=self._client.namespace
8097+
),
8098+
),
80858099
)
80868100
req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest(
80878101
signal_name=input.start_signal
@@ -8108,11 +8122,16 @@ async def _populate_start_workflow_execution_request(
81088122
),
81098123
input: StartWorkflowInput | UpdateWithStartStartWorkflowInput,
81108124
) -> None:
8111-
data_converter = self._client.data_converter.with_context(
8125+
data_converter = self._client.data_converter._with_contexts(
81128126
WorkflowSerializationContext(
81138127
namespace=self._client.namespace,
81148128
workflow_id=input.id,
8115-
)
8129+
),
8130+
StorageDriverStoreContext(
8131+
target=StorageDriverWorkflowInfo(
8132+
id=input.id, type=input.workflow, namespace=self._client.namespace
8133+
),
8134+
),
81168135
)
81178136
req.namespace = self._client.namespace
81188137
req.workflow_id = input.id
@@ -8228,11 +8247,18 @@ async def count_workflows(
82288247
)
82298248

82308249
async def query_workflow(self, input: QueryWorkflowInput) -> Any:
8231-
data_converter = self._client.data_converter.with_context(
8250+
data_converter = self._client.data_converter._with_contexts(
82328251
WorkflowSerializationContext(
82338252
namespace=self._client.namespace,
82348253
workflow_id=input.id,
8235-
)
8254+
),
8255+
StorageDriverStoreContext(
8256+
target=StorageDriverWorkflowInfo(
8257+
id=input.id,
8258+
run_id=input.run_id or None,
8259+
namespace=self._client.namespace,
8260+
),
8261+
),
82368262
)
82378263
req = temporalio.api.workflowservice.v1.QueryWorkflowRequest(
82388264
namespace=self._client.namespace,
@@ -8255,7 +8281,10 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
82558281
await self._apply_headers(input.headers, req.query.header.fields)
82568282
try:
82578283
resp = await self._client.workflow_service.query_workflow(
8258-
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
8284+
req,
8285+
retry=True,
8286+
metadata=input.rpc_metadata,
8287+
timeout=input.rpc_timeout,
82598288
)
82608289
except RPCError as err:
82618290
# If the status is INVALID_ARGUMENT, we can assume it's a query
@@ -8281,11 +8310,18 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
82818310
return results[0]
82828311

82838312
async def signal_workflow(self, input: SignalWorkflowInput) -> None:
8284-
data_converter = self._client.data_converter.with_context(
8313+
data_converter = self._client.data_converter._with_contexts(
82858314
WorkflowSerializationContext(
82868315
namespace=self._client.namespace,
82878316
workflow_id=input.id,
8288-
)
8317+
),
8318+
StorageDriverStoreContext(
8319+
target=StorageDriverWorkflowInfo(
8320+
id=input.id,
8321+
run_id=input.run_id or None,
8322+
namespace=self._client.namespace,
8323+
),
8324+
),
82898325
)
82908326
req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest(
82918327
namespace=self._client.namespace,
@@ -8306,11 +8342,18 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None:
83068342
)
83078343

83088344
async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
8309-
data_converter = self._client.data_converter.with_context(
8345+
data_converter = self._client.data_converter._with_contexts(
83108346
WorkflowSerializationContext(
83118347
namespace=self._client.namespace,
83128348
workflow_id=input.id,
8313-
)
8349+
),
8350+
StorageDriverStoreContext(
8351+
target=StorageDriverWorkflowInfo(
8352+
id=input.id,
8353+
run_id=input.run_id or None,
8354+
namespace=self._client.namespace,
8355+
),
8356+
),
83148357
)
83158358
req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest(
83168359
namespace=self._client.namespace,
@@ -8365,7 +8408,7 @@ async def _build_start_activity_execution_request(
83658408
self, input: StartActivityInput
83668409
) -> temporalio.api.workflowservice.v1.StartActivityExecutionRequest:
83678410
"""Build StartActivityExecutionRequest from input."""
8368-
data_converter = self._client.data_converter.with_context(
8411+
data_converter = self._client.data_converter._with_contexts(
83698412
ActivitySerializationContext(
83708413
namespace=self._client.namespace,
83718414
activity_id=input.id,
@@ -8374,7 +8417,14 @@ async def _build_start_activity_execution_request(
83748417
is_local=False,
83758418
workflow_id=None,
83768419
workflow_type=None,
8377-
)
8420+
),
8421+
StorageDriverStoreContext(
8422+
target=StorageDriverActivityInfo(
8423+
id=input.id,
8424+
type=input.activity_type,
8425+
namespace=self._client.namespace,
8426+
),
8427+
),
83788428
)
83798429

83808430
req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest(
@@ -8560,11 +8610,20 @@ async def _build_update_workflow_execution_request(
85608610
input: StartWorkflowUpdateInput | UpdateWithStartUpdateWorkflowInput,
85618611
workflow_id: str,
85628612
) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest:
8563-
data_converter = self._client.data_converter.with_context(
8613+
data_converter = self._client.data_converter._with_contexts(
85648614
WorkflowSerializationContext(
85658615
namespace=self._client.namespace,
85668616
workflow_id=workflow_id,
8567-
)
8617+
),
8618+
StorageDriverStoreContext(
8619+
target=StorageDriverWorkflowInfo(
8620+
id=workflow_id,
8621+
run_id=(input.run_id or None)
8622+
if isinstance(input, StartWorkflowUpdateInput)
8623+
else None,
8624+
namespace=self._client.namespace,
8625+
),
8626+
),
85688627
)
85698628
run_id, first_execution_run_id = (
85708629
(
@@ -8739,10 +8798,34 @@ async def _start_workflow_update_with_start(
87398798

87408799
### Async activity calls
87418800

8801+
def _get_async_activity_store_context(
8802+
self, id_or_token: AsyncActivityIDReference | bytes
8803+
) -> StorageDriverStoreContext:
8804+
if isinstance(id_or_token, AsyncActivityIDReference):
8805+
if id_or_token.workflow_id:
8806+
return StorageDriverStoreContext(
8807+
target=StorageDriverWorkflowInfo(
8808+
id=id_or_token.workflow_id or None,
8809+
run_id=id_or_token.run_id or None,
8810+
namespace=self._client.namespace,
8811+
),
8812+
)
8813+
return StorageDriverStoreContext(
8814+
target=StorageDriverActivityInfo(
8815+
id=id_or_token.activity_id,
8816+
run_id=id_or_token.run_id or None,
8817+
namespace=self._client.namespace,
8818+
),
8819+
)
8820+
else:
8821+
return StorageDriverStoreContext(target=None)
8822+
87428823
async def heartbeat_async_activity(
87438824
self, input: HeartbeatAsyncActivityInput
87448825
) -> None:
8745-
data_converter = input.data_converter_override or self._client.data_converter
8826+
data_converter = (
8827+
input.data_converter_override or self._client.data_converter
8828+
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))
87468829
details = (
87478830
None
87488831
if not input.details
@@ -8797,7 +8880,9 @@ async def heartbeat_async_activity(
87978880
)
87988881

87998882
async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
8800-
data_converter = input.data_converter_override or self._client.data_converter
8883+
data_converter = (
8884+
input.data_converter_override or self._client.data_converter
8885+
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))
88018886
result = (
88028887
None
88038888
if input.result is temporalio.common._arg_unset
@@ -8831,7 +8916,9 @@ async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> No
88318916
)
88328917

88338918
async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
8834-
data_converter = input.data_converter_override or self._client.data_converter
8919+
data_converter = (
8920+
input.data_converter_override or self._client.data_converter
8921+
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))
88358922

88368923
failure = temporalio.api.failure.v1.Failure()
88378924
await data_converter.encode_failure(input.error, failure)
@@ -8872,7 +8959,9 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
88728959
async def report_cancellation_async_activity(
88738960
self, input: ReportCancellationAsyncActivityInput
88748961
) -> None:
8875-
data_converter = input.data_converter_override or self._client.data_converter
8962+
data_converter = (
8963+
input.data_converter_override or self._client.data_converter
8964+
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))
88768965
details = (
88778966
None
88788967
if not input.details

temporalio/contrib/aws/s3driver/_driver.py

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from temporalio.api.common.v1 import Payload
1616
from temporalio.contrib.aws.s3driver._client import S3StorageDriverClient
1717
from temporalio.converter import (
18-
ActivitySerializationContext,
1918
StorageDriver,
19+
StorageDriverActivityInfo,
2020
StorageDriverClaim,
2121
StorageDriverRetrieveContext,
2222
StorageDriverStoreContext,
23-
WorkflowSerializationContext,
23+
StorageDriverWorkflowInfo,
2424
)
2525

2626
_T = TypeVar("_T")
@@ -113,40 +113,25 @@ async def store(
113113
(e.g. proto binary). The returned list is the same length as
114114
``payloads``.
115115
"""
116-
workflow_id: str | None = None
117-
activity_id: str | None = None
118-
namespace: str | None = None
119-
if isinstance(context.serialization_context, WorkflowSerializationContext):
120-
workflow_id = context.serialization_context.workflow_id
121-
namespace = context.serialization_context.namespace
122-
if isinstance(context.serialization_context, ActivitySerializationContext):
123-
# Prioritize workflow over activity so that the same payload that
124-
# may be stored across workflow and activity boundaries are deduplicated.
125-
if context.serialization_context.workflow_id:
126-
workflow_id = context.serialization_context.workflow_id
127-
elif context.serialization_context.activity_id:
128-
activity_id = context.serialization_context.activity_id
129-
namespace = context.serialization_context.namespace
130-
131-
# URL encode values to avoid characters that break the key format
132-
# e.g. spaces, forward-slashes, etc.
133-
if namespace:
134-
namespace = urllib.parse.quote(namespace, safe="")
135-
if workflow_id:
136-
workflow_id = urllib.parse.quote(workflow_id, safe="")
137-
if activity_id:
138-
activity_id = urllib.parse.quote(activity_id, safe="")
139-
140-
namespace_segments = f"/ns/{namespace}" if namespace else ""
141116

117+
def _quote(val: str | None) -> str | None:
118+
return urllib.parse.quote(val, safe="") if val else None
119+
120+
# Build context segments from the target identity.
142121
context_segments = ""
143-
# Prioritize workflow over activity so that the same payload that
144-
# may be stored across workflow and activity boundaries are deduplicated.
145-
# Workflow and Activity IDs are case sensitive.
146-
if workflow_id:
147-
context_segments += f"/wfi/{workflow_id}"
148-
elif activity_id:
149-
context_segments += f"/aci/{activity_id}"
122+
target = context.target
123+
namespace = _quote(target.namespace) if target is not None else None
124+
namespace_segment = f"/ns/{namespace}" if namespace else ""
125+
if isinstance(target, StorageDriverWorkflowInfo):
126+
wf_type = _quote(target.type) or "null"
127+
wf_id = _quote(target.id) or "null"
128+
wf_run_id = _quote(target.run_id) or "null"
129+
context_segments = f"/wt/{wf_type}/wi/{wf_id}/ri/{wf_run_id}"
130+
elif isinstance(target, StorageDriverActivityInfo):
131+
act_type = _quote(target.type) or "null"
132+
act_id = _quote(target.id) or "null"
133+
act_run_id = _quote(target.run_id) or "null"
134+
context_segments = f"/at/{act_type}/ai/{act_id}/ri/{act_run_id}"
150135

151136
async def _upload(payload: Payload) -> StorageDriverClaim:
152137
bucket = self._get_bucket(context, payload)
@@ -162,7 +147,7 @@ async def _upload(payload: Payload) -> StorageDriverClaim:
162147

163148
digest_segments = f"/d/sha256/{hash_digest}"
164149

165-
key = f"v0{namespace_segments}{context_segments}{digest_segments}"
150+
key = f"v0{namespace_segment}{context_segments}{digest_segments}"
166151

167152
try:
168153
if not await self._client.object_exists(bucket=bucket, key=key):

temporalio/converter/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from temporalio.converter._extstore import (
88
ExternalStorage,
99
StorageDriver,
10+
StorageDriverActivityInfo,
1011
StorageDriverClaim,
1112
StorageDriverRetrieveContext,
1213
StorageDriverStoreContext,
14+
StorageDriverWorkflowInfo,
1315
StorageWarning,
1416
)
1517
from temporalio.converter._failure_converter import (
@@ -54,9 +56,11 @@
5456
"ActivitySerializationContext",
5557
"ExternalStorage",
5658
"StorageDriver",
59+
"StorageDriverActivityInfo",
5760
"StorageDriverClaim",
5861
"StorageDriverRetrieveContext",
5962
"StorageDriverStoreContext",
63+
"StorageDriverWorkflowInfo",
6064
"StorageWarning",
6165
"AdvancedJSONEncoder",
6266
"BinaryNullPayloadConverter",

0 commit comments

Comments
 (0)