Skip to content

Commit 70e7f5c

Browse files
committed
Type recipe package
1 parent 39c4450 commit 70e7f5c

4 files changed

Lines changed: 58 additions & 54 deletions

File tree

src/workflows/recipe/__init__.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626

2727
def _wrap_subscription(
2828
transport_layer: CommonTransport,
29-
subscription_call,
30-
channel,
31-
callback,
32-
*args,
29+
subscription_call: Callable[..., int],
30+
channel: str,
31+
callback: Callable[..., Any],
32+
*args: Any,
3333
mangle_for_receiving: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
3434
allow_non_recipe_messages: bool = False,
35-
log_extender=None,
36-
**kwargs,
37-
):
35+
log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None,
36+
**kwargs: Any,
37+
) -> int:
3838
"""Internal method to create an intercepting function for incoming messages
3939
to interpret recipes. This function is then used to subscribe to a channel
4040
on the transport layer.
@@ -61,7 +61,7 @@ def _wrap_subscription(
6161
"""
6262

6363
@functools.wraps(callback)
64-
def unwrap_recipe(header, message: dict[str, Any]):
64+
def unwrap_recipe(header: dict[str, Any], message: dict[str, Any]) -> Any:
6565
"""Unpack incoming messages when they are in a recipe format.
6666
6767
Other messages are passed through unmodified.
@@ -124,11 +124,11 @@ def wrap_subscribe(
124124
transport_layer: CommonTransport,
125125
channel: str,
126126
callback: Callable[[RecipeWrapper, dict, dict], None],
127-
*args,
127+
*args: Any,
128128
allow_non_recipe_messages: Literal[False] = False,
129129
mangle_for_receiving: Callable[[Any], Any] | None = None,
130130
log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None,
131-
**kwargs,
131+
**kwargs: Any,
132132
) -> int: ...
133133

134134

@@ -137,24 +137,24 @@ def wrap_subscribe(
137137
transport_layer: CommonTransport,
138138
channel: str,
139139
callback: Callable[[RecipeWrapper | None, dict, dict | bytes], None],
140-
*args,
140+
*args: Any,
141141
allow_non_recipe_messages: Literal[True],
142142
mangle_for_receiving: Callable[[Any], Any] | None = None,
143143
log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None,
144-
**kwargs,
144+
**kwargs: Any,
145145
) -> int: ...
146146

147147

148148
def wrap_subscribe(
149-
transport_layer,
150-
channel,
151-
callback,
152-
*args,
153-
allow_non_recipe_messages=False,
154-
mangle_for_receiving=None,
155-
log_extender=None,
156-
**kwargs,
157-
):
149+
transport_layer: CommonTransport,
150+
channel: str,
151+
callback: Callable[..., Any],
152+
*args: Any,
153+
allow_non_recipe_messages: bool = False,
154+
mangle_for_receiving: Callable[[Any], Any] | None = None,
155+
log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None,
156+
**kwargs: Any,
157+
) -> int:
158158
"""Listen to a queue on the transport layer, similar to the subscribe call in
159159
transport/common_transport.py. Intercept all incoming messages and parse
160160
for recipe information.
@@ -186,13 +186,13 @@ def wrap_subscribe(
186186

187187

188188
def wrap_subscribe_broadcast(
189-
transport_layer,
190-
channel,
191-
callback,
192-
*args,
189+
transport_layer: CommonTransport,
190+
channel: str,
191+
callback: Callable[..., Any],
192+
*args: Any,
193193
mangle_for_receiving: Callable[[Any], Any] | None = None,
194-
**kwargs,
195-
):
194+
**kwargs: Any,
195+
) -> int:
196196
"""Listen to a topic on the transport layer, similar to the
197197
subscribe_broadcast call in transport/common_transport.py. Intercept all
198198
incoming messages and parse for recipe information.

src/workflows/recipe/recipe.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def validate(self) -> None:
132132
# Detect cycles
133133
touched_nodes = {"start", "error"}
134134

135-
def flatten_links(struct):
135+
def flatten_links(struct: Any) -> list[int]:
136136
"""Take an output/error link object, list or dictionary and return flat list of linked nodes."""
137137
if struct is None:
138138
return []
@@ -153,7 +153,7 @@ def flatten_links(struct):
153153
"Invalid recipe: Invalid link in recipe (%s)" % str(struct)
154154
)
155155

156-
def find_cycles(path):
156+
def find_cycles(path: list[Any]) -> None:
157157
"""Depth-First-Search helper function to identify cycles."""
158158
if path[-1] not in self.recipe:
159159
raise workflows.Error(
@@ -215,23 +215,23 @@ def apply_parameters(self, parameters: dict[str, Any]) -> None:
215215
"""
216216

217217
class SafeString:
218-
def __init__(self, s):
218+
def __init__(self, s: str):
219219
self.string = s
220220

221-
def __repr__(self):
221+
def __repr__(self) -> str:
222222
return "{" + self.string + "}"
223223

224-
def __str__(self):
224+
def __str__(self) -> str:
225225
return "{" + self.string + "}"
226226

227-
def __getitem__(self, item):
227+
def __getitem__(self, item: str) -> SafeString:
228228
return SafeString(self.string + "[" + item + "]")
229229

230230
class SafeDict(dict):
231231
"""A dictionary that returns undefined keys as {keyname}.
232232
This can be used to selectively replace variables in datastructures."""
233233

234-
def __missing__(self, key):
234+
def __missing__(self, key: str) -> SafeString:
235235
return SafeString(key)
236236

237237
# By default the python formatter class is used to resolve {item} references
@@ -242,15 +242,15 @@ def __missing__(self, key):
242242
# string.
243243
ds_formatter = string.Formatter()
244244

245-
def ds_format_field(value, spec):
245+
def ds_format_field(value: Any, spec: str) -> str:
246246
ds_format_field.last = value # type: ignore
247247
return ""
248248

249249
ds_formatter.format_field = ds_format_field # type: ignore
250250

251251
params = SafeDict(parameters)
252252

253-
def _recursive_apply(item):
253+
def _recursive_apply(item: Any) -> Any:
254254
"""Helper function to recursively apply replacements."""
255255
if isinstance(item, str):
256256
if item.startswith("{$REPLACE") and item.endswith("}"):
@@ -323,7 +323,7 @@ def merge(self, other: Recipe | str) -> Recipe:
323323
new_recipe[translation[key]] = value
324324

325325
# Rewrite all copied entries to point to new keys
326-
def translate(x):
326+
def translate(x: Any) -> Any:
327327
if isinstance(x, list):
328328
return list(map(translate, x))
329329
elif isinstance(x, tuple):

src/workflows/recipe/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def validate_recipe(json_filename: str | os.PathLike[str]) -> None:
5454
raise e
5555

5656

57-
def main():
57+
def main() -> None:
5858
"""Run the program from entry point"""
5959
parser = argparse.ArgumentParser()
6060
parser.add_argument(

src/workflows/recipe/wrapper.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def start(
301301
header: dict[str, Any] | None = None,
302302
*,
303303
mangle_for_sending: Callable[[Any], Any] | None = None,
304-
**kwargs,
304+
**kwargs: Any,
305305
) -> None:
306306
"""
307307
Trigger the start of a recipe.
@@ -375,7 +375,7 @@ def checkpoint(
375375
usually "encode to JSON".
376376
**kwargs: Keywords passed on to the transport.
377377
"""
378-
if not self.recipe_step:
378+
if not self.recipe_step or self.recipe_pointer is None:
379379
raise ValueError(
380380
"This RecipeWrapper object does not contain "
381381
"a recipe with a selected step."
@@ -418,7 +418,9 @@ def apply_parameters(self, parameters: dict[str, Any]) -> None:
418418
assert self.recipe_pointer is not None
419419
self.recipe_step = self.recipe[self.recipe_pointer]
420420

421-
def _generate_full_recipe_message(self, destination, message, add_path_step):
421+
def _generate_full_recipe_message(
422+
self, destination: int, message: Any, add_path_step: bool
423+
) -> dict[str, Any]:
422424
"""Factory function to generate independent message objects for
423425
downstream recipients with different destinations."""
424426
if add_path_step and self.recipe_pointer:
@@ -436,17 +438,17 @@ def _generate_full_recipe_message(self, destination, message, add_path_step):
436438

437439
def _send_to_destinations(
438440
self,
439-
destinations,
440-
message,
441-
header=None,
441+
destinations: int | list[int],
442+
message: Any,
443+
header: dict[str, Any] | None = None,
442444
mangle_for_sending: Callable[[Any], Any] | None = None,
443-
**kwargs,
444-
):
445+
**kwargs: Any,
446+
) -> None:
445447
"""Send messages to a list of numbered destinations. This is an internal
446448
helper method used by the public 'send' methods.
447449
"""
448450
if not isinstance(destinations, list):
449-
destinations = (destinations,)
451+
destinations = [destinations]
450452
for destination in destinations:
451453
self._send_to_destination(
452454
destination,
@@ -458,13 +460,13 @@ def _send_to_destinations(
458460

459461
def _send_to_destination(
460462
self,
461-
destination,
462-
header,
463-
payload,
464-
transport_kwargs,
465-
add_path_step=True,
463+
destination: int,
464+
header: dict[str, Any] | None,
465+
payload: Any,
466+
transport_kwargs: dict[str, Any],
467+
add_path_step: bool = True,
466468
mangle_for_sending: Callable[[Any], Any] | None = None,
467-
):
469+
) -> None:
468470
"""Helper function to send a message to a specific recipe destination."""
469471
if header:
470472
header = header.copy()
@@ -516,7 +518,9 @@ def _send_to_destination(
516518
**dest_kwargs,
517519
)
518520

519-
def _retry_transport(self, function, *args, **kwargs):
521+
def _retry_transport(
522+
self, function: Callable[..., Any], *args: Any, **kwargs: Any
523+
) -> Any:
520524
"""Attempt to send a message, and in case the connection has been lost,
521525
attempt to reconnect. Reconnecting only works on the assumption that
522526
the previous connection did not include any subscriptions, which should

0 commit comments

Comments
 (0)