Skip to content

Commit ed7ef59

Browse files
committed
feat: enhance source connection handling with target reference and host settings resolution
1 parent a1fb355 commit ed7ef59

2 files changed

Lines changed: 165 additions & 13 deletions

File tree

src/iop/production/source_inference.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class SourceConnection:
1515
source: str
1616
detail: str = ""
1717
interaction: str = "request"
18+
target_reference: str = ""
1819

1920

2021
_PYTHON_REQUEST_METHODS = {
@@ -93,7 +94,7 @@ def infer_source_connections(
9394

9495
def _infer_python_connections(
9596
class_name: str,
96-
_host_settings: dict[str, Any],
97+
host_settings: dict[str, Any],
9798
roots: tuple[Path, ...],
9899
) -> list[SourceConnection]:
99100
candidates = []
@@ -114,7 +115,12 @@ def _infer_python_connections(
114115

115116
connections: list[SourceConnection] = []
116117
for _root_index, candidate in candidates[:1]:
117-
connections.extend(candidate.connections)
118+
connections.extend(
119+
_python_connections_for_host_settings(
120+
candidate.connections,
121+
host_settings,
122+
)
123+
)
118124
return _unique_connections(connections)
119125

120126

@@ -281,7 +287,7 @@ def _python_class_connections(class_node: ast.ClassDef) -> list[SourceConnection
281287
continue
282288
call_name = _python_call_name(node.func)
283289
if call_name in _PYTHON_REQUEST_METHODS:
284-
for target, detail, interaction in _python_request_targets(
290+
for target, detail, interaction, target_reference in _python_request_targets(
285291
node,
286292
string_values,
287293
call_name,
@@ -292,10 +298,11 @@ def _python_class_connections(class_node: ast.ClassDef) -> list[SourceConnection
292298
source="Python source",
293299
detail=detail,
294300
interaction=interaction,
301+
target_reference=target_reference,
295302
)
296303
)
297304
elif call_name in _PYTHON_MULTI_REQUEST_METHODS:
298-
for target, detail, interaction in _python_multi_request_targets(
305+
for target, detail, interaction, target_reference in _python_multi_request_targets(
299306
node,
300307
string_values,
301308
call_name,
@@ -306,6 +313,7 @@ def _python_class_connections(class_node: ast.ClassDef) -> list[SourceConnection
306313
source="Python source",
307314
detail=detail,
308315
interaction=interaction,
316+
target_reference=target_reference,
309317
)
310318
)
311319
return _unique_connections(connections)
@@ -367,7 +375,7 @@ def _python_request_targets(
367375
call: ast.Call,
368376
string_values: dict[str, list[str]],
369377
call_name: str,
370-
) -> list[tuple[str, str, str]]:
378+
) -> list[tuple[str, str, str, str]]:
371379
target_node = _python_call_target_node(call)
372380
if target_node is None:
373381
return []
@@ -383,14 +391,14 @@ def _python_multi_request_targets(
383391
call: ast.Call,
384392
string_values: dict[str, list[str]],
385393
call_name: str,
386-
) -> list[tuple[str, str, str]]:
394+
) -> list[tuple[str, str, str, str]]:
387395
if not call.args:
388396
return []
389397
collection = call.args[0]
390398
if not isinstance(collection, (ast.List, ast.Tuple)):
391399
return []
392400

393-
targets: list[tuple[str, str, str]] = []
401+
targets: list[tuple[str, str, str, str]] = []
394402
for item in collection.elts:
395403
if not isinstance(item, ast.Tuple) or not item.elts:
396404
continue
@@ -410,20 +418,59 @@ def _resolve_python_targets(
410418
string_values: dict[str, list[str]],
411419
call_name: str,
412420
interaction: str,
413-
) -> list[tuple[str, str, str]]:
421+
) -> list[tuple[str, str, str, str]]:
414422
literal = _python_literal_string(node)
415423
if literal is not None:
416-
return [(literal, f"{call_name} literal", interaction)]
424+
return [(literal, f"{call_name} literal", interaction, "")]
417425

418426
name = _python_reference_name(node)
419427
if name is None:
420428
return []
429+
detail = f"{call_name} {name}"
430+
values = string_values.get(name, [])
431+
if not values:
432+
return [("", detail, interaction, name)]
421433
return [
422-
(value, f"{call_name} {name}", interaction)
423-
for value in string_values.get(name, [])
434+
(value, detail, interaction, name)
435+
for value in values
424436
]
425437

426438

439+
def _python_connections_for_host_settings(
440+
connections: tuple[SourceConnection, ...],
441+
host_settings: dict[str, Any],
442+
) -> list[SourceConnection]:
443+
resolved: list[SourceConnection] = []
444+
for connection in connections:
445+
setting_name = _python_reference_setting_name(connection.target_reference)
446+
setting_targets = (
447+
_split_targets(host_settings.get(setting_name, ""))
448+
if setting_name
449+
else []
450+
)
451+
if setting_targets:
452+
for target in setting_targets:
453+
resolved.append(
454+
SourceConnection(
455+
target=target,
456+
source=connection.source,
457+
detail=connection.detail,
458+
interaction=connection.interaction,
459+
target_reference=connection.target_reference,
460+
)
461+
)
462+
continue
463+
if connection.target:
464+
resolved.append(connection)
465+
return resolved
466+
467+
468+
def _python_reference_setting_name(reference: str) -> str:
469+
if not reference.startswith("self."):
470+
return ""
471+
return reference.removeprefix("self.")
472+
473+
427474
def _python_call_target_node(call: ast.Call) -> ast.AST | None:
428475
for keyword in call.keywords:
429476
if keyword.arg == "target":
@@ -664,15 +711,16 @@ def _root_key(root: Path) -> str:
664711

665712
def _unique_connections(connections: list[SourceConnection]) -> list[SourceConnection]:
666713
unique: list[SourceConnection] = []
667-
seen: set[tuple[str, str, str, str]] = set()
714+
seen: set[tuple[str, str, str, str, str]] = set()
668715
for connection in connections:
669-
if not connection.target:
716+
if not connection.target and not connection.target_reference:
670717
continue
671718
key = (
672719
connection.target,
673720
connection.source,
674721
connection.detail,
675722
connection.interaction,
723+
connection.target_reference,
676724
)
677725
if key in seen:
678726
continue

src/tests/unit/test_production.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,110 @@ def on_message(self, request):
467467
]["interaction"] == "sync"
468468

469469

470+
def test_production_from_dict_prefers_python_host_setting_over_source_default(
471+
tmp_path,
472+
monkeypatch,
473+
):
474+
(tmp_path / "bench_bp.py").write_text(
475+
"""
476+
class BenchIoPProcess:
477+
target = target()
478+
479+
def on_init(self):
480+
if not hasattr(self, "target"):
481+
self.target = "Python.BenchIoPOperation"
482+
483+
def on_message(self, request):
484+
self.send_request_sync(self.target, request)
485+
""",
486+
encoding="utf-8",
487+
)
488+
monkeypatch.chdir(tmp_path)
489+
classpaths = str(tmp_path)
490+
491+
prod = Production.from_dict(
492+
{
493+
"Bench.Production": {
494+
"Item": [
495+
{
496+
"@Name": "Python.BenchIoPProcess",
497+
"@ClassName": "Python.BenchIoPProcess",
498+
"Setting": [
499+
{
500+
"@Target": "Host",
501+
"@Name": "%classpaths",
502+
"#text": classpaths,
503+
},
504+
{
505+
"@Target": "Host",
506+
"@Name": "target",
507+
"#text": "Python.BenchIoPOperation",
508+
},
509+
],
510+
},
511+
{
512+
"@Name": "Python.BenchIoPProcess.To.Cls",
513+
"@ClassName": "Python.BenchIoPProcess",
514+
"Setting": [
515+
{
516+
"@Target": "Host",
517+
"@Name": "%classpaths",
518+
"#text": classpaths,
519+
},
520+
{
521+
"@Target": "Host",
522+
"@Name": "target",
523+
"#text": "Bench.Operation",
524+
},
525+
],
526+
},
527+
{
528+
"@Name": "Python.BenchIoPOperation",
529+
"@ClassName": "Python.BenchIoPOperation",
530+
},
531+
{"@Name": "Bench.Operation", "@ClassName": "Bench.Operation"},
532+
]
533+
}
534+
},
535+
connections={
536+
"items": [
537+
{
538+
"item": "Python.BenchIoPProcess",
539+
"iop": True,
540+
"module": "bench_bp",
541+
"classname": "BenchIoPProcess",
542+
"classpaths": classpaths,
543+
"connections": [],
544+
},
545+
{
546+
"item": "Python.BenchIoPProcess.To.Cls",
547+
"iop": True,
548+
"module": "bench_bp",
549+
"classname": "BenchIoPProcess",
550+
"classpaths": classpaths,
551+
"connections": [],
552+
},
553+
]
554+
},
555+
)
556+
557+
edges = {
558+
(edge["source_item"], edge["target"]): edge
559+
for edge in prod.graph().to_dict()["edges"]
560+
}
561+
562+
assert (
563+
"Python.BenchIoPProcess.To.Cls",
564+
"Python.BenchIoPOperation",
565+
) not in edges
566+
assert edges[
567+
("Python.BenchIoPProcess", "Python.BenchIoPOperation")
568+
]["interaction"] == "sync"
569+
assert edges[
570+
("Python.BenchIoPProcess.To.Cls", "Bench.Operation")
571+
]["interaction"] == "sync"
572+
573+
470574
def test_production_from_dict_does_not_treat_python_prefix_as_python_source(
471575
tmp_path,
472576
monkeypatch,

0 commit comments

Comments
 (0)