diff --git a/pyproject.toml b/pyproject.toml index bc63ec921..9d38e5fbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "pyjwt[crypto]", "tomlkit", "graypy>=2.1.0", + "jinja2>=3.1.6", ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 2ae736cff..523732565 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -39,6 +39,7 @@ from blueapi.service.model import DeviceResponse, PlanResponse, SourceInfo, TaskRequest from blueapi.worker import ProgressEvent, WorkerEvent +from . import stubgen from .scratch import setup_scratch from .updates import CliEventRenderer @@ -152,6 +153,23 @@ def start_application(obj: dict): start(config) +@main.command() +@click.pass_obj +@click.argument("target", type=click.Path(file_okay=False)) +def generate_stubs(obj: dict, target: Path): + """ + Generate a type-stubs project for blueapi for the currently running server. + This enables users using blueapi as a library to benefit from type checking + and linting when writing scripts against the BlueapiClient. + """ + click.echo(f"Writing stubs to {target}") + + config: ApplicationConfig = obj["config"] + bc = BlueapiClient.from_config(config) + + stubgen.generate_stubs(Path(target), list(bc.plans), list(bc.devices)) + + @main.group() @click.option( "-o", diff --git a/src/blueapi/cli/stubgen.py b/src/blueapi/cli/stubgen.py new file mode 100644 index 000000000..6f6fbf4bd --- /dev/null +++ b/src/blueapi/cli/stubgen.py @@ -0,0 +1,117 @@ +import logging +from dataclasses import dataclass +from inspect import cleandoc +from pathlib import Path +from textwrap import dedent +from typing import Self, TextIO + +from jinja2 import Environment, PackageLoader + +from blueapi.client.cache import DeviceRef, Plan +from blueapi.core import context +from blueapi.core.bluesky_types import BLUESKY_PROTOCOLS + +log = logging.getLogger(__name__) + + +@dataclass +class ArgSpec: + name: str + type: str + optional: bool + + +@dataclass +class PlanSpec: + name: str + docs: str + args: list[ArgSpec] + + @classmethod + def from_plan(cls, plan: Plan) -> Self: + req = set(plan.required) + args = [ + ArgSpec(arg, _type_string(spec), arg not in req) + for arg, spec in plan.model.parameter_schema.get("properties", {}).items() + ] + return cls(plan.name, plan.help_text, args) + + +BLUESKY_PROTOCOL_NAMES = {context.qualified_name(proto) for proto in BLUESKY_PROTOCOLS} + + +def _type_string(spec) -> str: + """Best effort attempt at making useful type hints for plans""" + match spec.get("type"): + case "array": + return f"list[{_type_string(spec.get('items'))}]" + case "integer": + return "int" + case "number": + return "float" + case proto if proto in BLUESKY_PROTOCOL_NAMES: + return "DeviceRef" + case "object": + return "dict[str, Any]" + case "string": + return "str" + case "boolean": + return "bool" + case None if opts := spec.get("anyOf"): + return " | ".join(_type_string(opt) for opt in opts) + case _: + return "Any" + + +def generate_stubs(target: Path, plans: list[Plan], devices: list[DeviceRef]): + log.info("Generating stubs for %d plans and %d devices", len(plans), len(devices)) + target.mkdir(parents=True, exist_ok=True) + client_dir = target / "src" / "blueapi-stubs" / "client" + + log.debug("Making project structure: %s", client_dir) + client_dir.mkdir(parents=True, exist_ok=True) + + stub_file = client_dir / "cache.pyi" + project_file = target / "pyproject.toml" + py_typed = target / "src" / "blueapi-stubs" / "py.typed" + + log.debug("Writing pyproject.toml to %s", project_file) + with open(project_file, "w") as out: + out.write( + dedent(""" + [project] + name = "blueapi-stubs" + version = "0.1.0" + description = "Generated client stubs for a running server" + readme = "README.md" + requires-python = ">=3.11" + + dependencies = [ + "blueapi" + ] + """) + ) + + log.debug("Writing py.typed file to %s", py_typed) + with open(py_typed, "w") as out: + out.write("partial\n") + + log.debug("Writing stub file to %s", stub_file) + with open(stub_file, "w") as out: + render_stub_file(out, plans, devices) + + +def _docstring(text: str) -> str: + # """Convert a docstring to a format that can be inserted into the template""" + return cleandoc(text).replace('"""', '\\"""') + + +def render_stub_file( + stub_file: TextIO, plan_models: list[Plan], devices: list[DeviceRef] +): + plans = [PlanSpec.from_plan(p) for p in plan_models] + + env = Environment(loader=PackageLoader("blueapi", package_path="stubs/templates")) + env.filters["docstring"] = _docstring + tmpl = env.get_template("cache_template.pyi.jinja") + stub_file.write(tmpl.render(plans=plans, devices=devices)) diff --git a/src/blueapi/client/cache.py b/src/blueapi/client/cache.py new file mode 100644 index 000000000..0ec8c4c87 --- /dev/null +++ b/src/blueapi/client/cache.py @@ -0,0 +1,177 @@ +import logging +from collections.abc import Callable +from itertools import chain +from typing import Any + +from blueapi.client.rest import BlueapiRestClient +from blueapi.service.model import DeviceModel, PlanModel +from blueapi.worker.event import WorkerEvent + +log = logging.getLogger(__name__) + + +# This file should be kept in sync with the type stub template in stubs/templates + + +PlanRunner = Callable[[str, dict[str, Any]], WorkerEvent] + + +class PlanCache: + """ + Cache of plans available on the server + """ + + def __init__(self, runner: PlanRunner, plans: list[PlanModel]): + self._cache = {model.name: Plan(model=model, runner=runner) for model in plans} + for name, plan in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, plan) + + def __getitem__(self, name: str) -> "Plan": + return self._cache[name] + + def __getattr__(self, name: str) -> "Plan": + raise AttributeError(f"No plan named '{name}' available") + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"PlanCache({len(self._cache)} plans)" + + +class Plan: + """ + An interface to a plan on the blueapi server + + This allows remote plans to be called (mostly) as if they were local + methods when writing user scripts. + + If you are seeing this help while using blueapi as a library, generating + type stubs may be helpful for type checking and plan discovery, eg + + blueapi generate-stubs /tmp/blueapi-stubs + uv add --editable /tmp/blueapi-stubs + + """ + + model: PlanModel + + def __init__(self, model: PlanModel, runner: PlanRunner): + self.model = model + self._runner = runner + self.__doc__ = model.description + + def __call__(self, *args, **kwargs) -> WorkerEvent: + """ + Run the plan on the server mapping the given args into the required parameters + """ + return self._runner(self.name, self._build_args(*args, **kwargs)) + + @property + def name(self) -> str: + return self.model.name + + @property + def help_text(self) -> str: + return self.model.description or f"Plan {self!r}" + + @property + def properties(self) -> set[str]: + return self.model.parameter_schema.get("properties", {}).keys() + + @property + def required(self) -> list[str]: + return self.model.parameter_schema.get("required", []) + + def _build_args(self, *args, **kwargs): + log.info( + "Building args for %s, using %s and %s", + "[" + ",".join(self.properties) + "]", + args, + kwargs, + ) + + if len(args) > len(self.properties): + raise TypeError(f"{self.name} got too many arguments") + if extra := {k for k in kwargs if k not in self.properties}: + raise TypeError(f"{self.name} got unexpected arguments: {extra}") + + params = {} + # Initially fill parameters using positional args assuming the order + # from the parameter_schema + for req, arg in zip(self.properties, args, strict=False): + params[req] = arg + + # Then append any values given via kwargs + for key, value in kwargs.items(): + # If we've already assumed a positional arg was this value, bail out + if key in params: + raise TypeError(f"{self.name} got multiple values for {key}") + params[key] = value + + if missing := {k for k in self.required if k not in params}: + raise TypeError(f"Missing argument(s) for {missing}") + return params + + def __repr__(self): + opts = [p for p in self.properties if p not in self.required] + params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) + return f"{self.name}({params})" + + +class DeviceCache: + def __init__(self, rest: BlueapiRestClient): + self._rest = rest + self._cache = { + model.name: DeviceRef(name=model.name, cache=self, model=model) + for model in rest.get_devices().devices + } + for name, device in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, device) + + def __getitem__(self, name: str) -> "DeviceRef": + if dev := self._cache.get(name): + return dev + try: + model = self._rest.get_device(name) + device = DeviceRef(name=name, cache=self, model=model) + self._cache[name] = device + setattr(self, model.name, device) + return device + except KeyError: + pass + raise AttributeError(f"No device named '{name}' available") + + def __getattr__(self, name: str) -> "DeviceRef": + if name.startswith("_"): + return super().__getattribute__(name) + return self[name] + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"DeviceCache({len(self._cache)} devices)" + + +class DeviceRef(str): + model: DeviceModel + _cache: DeviceCache + + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): + instance = super().__new__(cls, name) + instance.model = model + instance._cache = cache + return instance + + def __getattr__(self, name) -> "DeviceRef": + if name.startswith("_"): + raise AttributeError(f"No child device named {name}") + return self._cache[f"{self}.{name}"] + + def __repr__(self): + return f"Device({self})" diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index e6f1e83e3..39fbf5054 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -4,9 +4,8 @@ from collections.abc import Iterable from concurrent.futures import Future from functools import cached_property -from itertools import chain from pathlib import Path -from typing import Self +from typing import Any, Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -23,10 +22,8 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import SessionManager from blueapi.service.model import ( - DeviceModel, EnvironmentResponse, OIDCConfig, - PlanModel, PythonEnvironmentResponse, SourceInfo, TaskRequest, @@ -36,6 +33,7 @@ from blueapi.worker import WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus +from .cache import DeviceCache, PlanCache from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent from .rest import BlueapiRestClient, BlueskyRemoteControlError @@ -49,149 +47,6 @@ class MissingInstrumentSessionError(Exception): pass -class PlanCache: - def __init__(self, client: "BlueapiClient", plans: list[PlanModel]): - self._cache = { - model.name: Plan(name=model.name, model=model, client=client) - for model in plans - } - for name, plan in self._cache.items(): - if name.startswith("_"): - continue - setattr(self, name, plan) - - def __getitem__(self, name: str) -> "Plan": - return self._cache[name] - - def __getattr__(self, name: str) -> "Plan": - raise AttributeError(f"No plan named '{name}' available") - - def __iter__(self): - return iter(self._cache.values()) - - def __repr__(self) -> str: - return f"PlanCache({len(self._cache)} plans)" - - -class DeviceCache: - def __init__(self, rest: BlueapiRestClient): - self._rest = rest - self._cache = { - model.name: DeviceRef(name=model.name, cache=self, model=model) - for model in rest.get_devices().devices - } - for name, device in self._cache.items(): - if name.startswith("_"): - continue - setattr(self, name, device) - - def __getitem__(self, name: str) -> "DeviceRef": - if dev := self._cache.get(name): - return dev - try: - model = self._rest.get_device(name) - device = DeviceRef(name=name, cache=self, model=model) - self._cache[name] = device - setattr(self, model.name, device) - return device - except KeyError: - pass - raise AttributeError(f"No device named '{name}' available") - - def __getattr__(self, name: str) -> "DeviceRef": - if name.startswith("_"): - return super().__getattribute__(name) - return self[name] - - def __iter__(self): - return iter(self._cache.values()) - - def __repr__(self) -> str: - return f"DeviceCache({len(self._cache)} devices)" - - -class DeviceRef(str): - model: DeviceModel - _cache: DeviceCache - - def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): - instance = super().__new__(cls, name) - instance.model = model - instance._cache = cache - return instance - - def __getattr__(self, name) -> "DeviceRef": - if name.startswith("_"): - raise AttributeError(f"No child device named {name}") - return self._cache[f"{self}.{name}"] - - def __repr__(self): - return f"Device({self})" - - -class Plan: - def __init__(self, name, model: PlanModel, client: "BlueapiClient"): - self.name = name - self.model = model - self._client = client - self.__doc__ = model.description - - def __call__(self, *args, **kwargs): - req = TaskRequest( - name=self.name, - params=self._build_args(*args, **kwargs), - instrument_session=self._client.instrument_session, - ) - self._client.run_task(req) - - @property - def help_text(self) -> str: - return self.model.description or f"Plan {self!r}" - - @property - def properties(self) -> set[str]: - return self.model.parameter_schema.get("properties", {}).keys() - - @property - def required(self) -> list[str]: - return self.model.parameter_schema.get("required", []) - - def _build_args(self, *args, **kwargs): - log.info( - "Building args for %s, using %s and %s", - "[" + ",".join(self.properties) + "]", - args, - kwargs, - ) - - if len(args) > len(self.properties): - raise TypeError(f"{self.name} got too many arguments") - if extra := {k for k in kwargs if k not in self.properties}: - raise TypeError(f"{self.name} got unexpected arguments: {extra}") - - params = {} - # Initially fill parameters using positional args assuming the order - # from the parameter_schema - for req, arg in zip(self.properties, args, strict=False): - params[req] = arg - - # Then append any values given via kwargs - for key, value in kwargs.items(): - # If we've already assumed a positional arg was this value, bail out - if key in params: - raise TypeError(f"{self.name} got multiple values for {key}") - params[key] = value - - if missing := {k for k in self.required if k not in params}: - raise TypeError(f"Missing argument(s) for {missing}") - return params - - def __repr__(self): - opts = [p for p in self.properties if p not in self.required] - params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) - return f"{self.name}({params})" - - class BlueapiClient: """Unified client for controlling blueapi""" @@ -214,7 +69,7 @@ def __init__( @cached_property @start_as_current_span(TRACER) def plans(self) -> PlanCache: - return PlanCache(self, self._rest.get_plans().plans) + return PlanCache(self.run_plan, self._rest.get_plans().plans) @cached_property @start_as_current_span(TRACER) @@ -333,6 +188,15 @@ def active_task(self) -> WorkerTask: return self._rest.get_active_task() + @start_as_current_span(TRACER, "name", "params") + def run_plan(self, name: str, params: dict[str, Any]) -> WorkerEvent: + req = TaskRequest( + name=name, + params=params, + instrument_session=self.instrument_session, + ) + return self.run_task(req) + @start_as_current_span(TRACER, "task", "timeout") def run_task( self, diff --git a/src/blueapi/stubs/templates/cache_template.pyi.jinja b/src/blueapi/stubs/templates/cache_template.pyi.jinja new file mode 100644 index 000000000..b06ef3638 --- /dev/null +++ b/src/blueapi/stubs/templates/cache_template.pyi.jinja @@ -0,0 +1,72 @@ +from collections.abc import Callable +from typing import Any +from blueapi.client.rest import BlueapiRestClient +from blueapi.service.model import DeviceModel, PlanModel +from blueapi.worker.event import WorkerEvent + +{#- + This file is based on the cache.py file in blueapi/client/cache.py and should + be kept in sync with changes there. +#} + +# This file is auto-generated for a live server and should not be modified directly + +PlanRunner = Callable[[str, dict[str, Any]], WorkerEvent] + +class PlanCache: + def __init__(self, runner: PlanRunner, plans: list[PlanModel]) -> None: ... + def __getitem__(self, name: str) -> Plan: ... + def __iter__(self): # -> Iterator[Plan]: + ... + def __repr__(self) -> str: ... + +### Generated plans +{%- for item in plans %} + def {{ item.name }}(self,{% for arg in item.args %} + {{ arg.name }}: {{ arg.type }}{% if arg.optional %} | None = None{% endif %}, + {%- endfor %} + ) -> WorkerEvent: + """ + {{ item.docs | docstring | indent(8) }} + """ + ... +{%- endfor %} +### End + + +class Plan: + model: PlanModel + def __init__(self, model: PlanModel, runner: PlanRunner) -> None: ... + def __call__(self, *args, **kwargs): # -> None: + ... + + @property + def name(self) -> str: ... + @property + def help_text(self) -> str: ... + @property + def properties(self) -> set[str]: ... + @property + def required(self) -> list[str]: ... + def __repr__(self) -> str: ... + + +class DeviceRef(str): + model: DeviceModel + _cache: DeviceCache + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): ... + def __getattr__(self, name) -> DeviceRef: ... + def __repr__(self) -> str: ... + +class DeviceCache: + def __init__(self, rest: BlueapiRestClient) -> None: ... + def __getitem__(self, name: str) -> DeviceRef: ... + def __iter__(self): # -> Iterator[DeviceRef]: + ... + def __repr__(self) -> str: ... + +### Generated devices + {%- for item in devices %} + {{ item }}: DeviceRef + {%- endfor %} +### End diff --git a/tests/unit_tests/cli/test_stubgen.py b/tests/unit_tests/cli/test_stubgen.py new file mode 100644 index 000000000..766f3e2f0 --- /dev/null +++ b/tests/unit_tests/cli/test_stubgen.py @@ -0,0 +1,214 @@ +from io import StringIO +from textwrap import dedent +from types import FunctionType +from unittest.mock import Mock + +import pytest + +from blueapi.cli.stubgen import ( + _docstring, + _type_string, + generate_stubs, + render_stub_file, +) +from blueapi.client.cache import DeviceRef, Plan +from blueapi.service.model import DeviceModel, PlanModel + + +def single_line(): + """Single line docstring""" + + +def single_line_new_line(): + """ + Single line docstring + """ + + +def multi_line_inline(): + """First line + Second line""" + + +def multi_line_new_line(): + """ + First line + Second line + """ + + +def indented_multi_line(): + """ + First line + indented + """ + + +@pytest.mark.parametrize( + "input,expected", + [ + (single_line, "Single line docstring"), + (single_line_new_line, "Single line docstring"), + (multi_line_inline, "First line\nSecond line"), + (multi_line_new_line, "First line\nSecond line"), + (indented_multi_line, "First line\n indented"), + ], +) +def test_docstring_filter(input: FunctionType, expected: str): + assert input.__doc__ + assert _docstring(input.__doc__) == expected + + +@pytest.mark.parametrize( + "typ,expected", + [ + ({"type": "string"}, "str"), + ({"type": "number"}, "float"), + ({"type": "integer"}, "int"), + ({"type": "object"}, "dict[str, Any]"), + ({"type": "boolean"}, "bool"), + ({"type": "array", "items": {"type": "integer"}}, "list[int]"), + ({"type": "array", "items": {"type": "object"}}, "list[dict[str, Any]]"), + ( + { + "type": "array", + "items": {"anyOf": [{"type": "integer"}, {"type": "boolean"}]}, + }, + "list[int | bool]", + ), + ({"anyOf": [{"type": "object"}, {"type": "string"}]}, "dict[str, Any] | str"), + ({"type": "unknown.other.Type"}, "Any"), + # Special case the bluesky protocols to require device references + ({"type": "bluesky.protocols.Readable"}, "DeviceRef"), + ({}, "Any"), + ], + ids=lambda param: param.get("type") if isinstance(param, dict) else param, +) +def test_type_string(typ: dict, expected: str): + assert _type_string(typ) == expected + + +def test_render_empty(): + output = StringIO() + + render_stub_file(output, [], []) + plan_text, device_text = _extract_rendered(output) + + assert plan_text == "" + assert device_text == "" + + +FOO = PlanModel(name="empty", description="Doc string for empty", schema={}) + +BAR = PlanModel( + name="two_args", + description="Doc string for two_args", + schema={ + "properties": { + "one": {"type": "integer"}, + "two": {"type": "string"}, + }, + "required": ["one"], + }, +) + + +def test_render_empty_plan_function(): + output = StringIO() + plans = [Plan(model=FOO, runner=Mock())] + render_stub_file(output, plans, []) + plan_text, device_text = _extract_rendered(output) + + assert device_text == "" + + assert ( + plan_text + == """\ + def empty(self, + ) -> WorkerEvent: + \""" + Doc string for empty + \""" + ...\n""" + ) + + +def test_render_multiple_plan_functions(): + output = StringIO() + runner = Mock() + plans = [Plan(FOO, runner), Plan(BAR, runner)] + render_stub_file(output, plans, []) + plan_text, device_text = _extract_rendered(output) + assert device_text == "" + + assert ( + plan_text + == """\ + def empty(self, + ) -> WorkerEvent: + \""" + Doc string for empty + \""" + ... + def two_args(self, + one: int, + two: str | None = None, + ) -> WorkerEvent: + \""" + Doc string for two_args + \""" + ...\n""" + ) + + +def test_device_fields(): + output = StringIO() + cache = Mock() + devices = [ + DeviceRef("one", cache, DeviceModel(name="one", protocols=[])), + DeviceRef("two", cache, DeviceModel(name="two", protocols=[])), + ] + render_stub_file(output, [], devices) + + plan_text, device_text = _extract_rendered(output) + assert plan_text == "" + assert device_text == " one: DeviceRef\n two: DeviceRef\n" + + +def test_package_creation(tmp_path): + generate_stubs(tmp_path / "blueapi-stubs", [], []) + with open(tmp_path / "blueapi-stubs" / "pyproject.toml") as pyproj: + assert pyproj.read().startswith( + dedent(""" + [project] + name = "blueapi-stubs" + version = "0.1.0" + """) + ) + with open( + tmp_path / "blueapi-stubs" / "src" / "blueapi-stubs" / "py.typed" + ) as typed: + assert typed.read() == "partial\n" + + assert ( + tmp_path / "blueapi-stubs" / "src" / "blueapi-stubs" / "client" / "cache.pyi" + ).exists() + + +def _extract_rendered(src: StringIO) -> tuple[str, str]: + src.seek(0) + _read_until_line(src, "### Generated plans") + plan_text = _read_until_line(src, "### End") + _read_until_line(src, "### Generated devices") + device_text = _read_until_line(src, "### End") + return plan_text, device_text + + +def _read_until_line(src: StringIO, match: str) -> str: + text = "" + for line in src: + if line.startswith(match): + break + text += line + + return text diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 98aad7871..8bbc54578 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -10,13 +10,10 @@ ) from pydantic import HttpUrl +from blueapi.client.cache import DeviceCache, DeviceRef, Plan, PlanCache from blueapi.client.client import ( BlueapiClient, - DeviceCache, - DeviceRef, MissingInstrumentSessionError, - Plan, - PlanCache, ) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError @@ -409,6 +406,15 @@ def test_run_task_fails_on_failing_event( on_event.assert_called_with(FAILED_EVENT) +@patch("blueapi.client.client.BlueapiClient.run_task") +def test_run_plan(run_task, client, mock_rest): + client.instrument_session = "cm12345-2" + client.run_plan("foo", {"foo": "bar"}) + run_task.assert_called_once_with( + TaskRequest(name="foo", params={"foo": "bar"}, instrument_session="cm12345-2") + ) + + @pytest.mark.parametrize( "test_event", [ @@ -677,40 +683,38 @@ def test_device_ignores_underscores(): cache.__getitem__.assert_not_called() -def test_plan_help_text(client): - plan = Plan("foo", PlanModel(name="foo", description="help for foo"), client) +def test_plan_help_text(): + plan = Plan(PlanModel(name="foo", description="help for foo"), Mock()) assert plan.help_text == "help for foo" -def test_plan_fallback_help_text(client): +def test_plan_fallback_help_text(): plan = Plan( - "foo", PlanModel( name="foo", schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, ), - client, + Mock(), ) assert plan.help_text == "Plan foo(one, two=None)" -def test_plan_properties(client): +def test_plan_properties(): plan = Plan( - "foo", PlanModel( name="foo", schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, ), - client, + Mock(), ) assert plan.properties == {"one", "two"} assert plan.required == ["one"] -def test_plan_empty_fallback_help_text(client): +def test_plan_empty_fallback_help_text(): plan = Plan( - "foo", PlanModel(name="foo", schema={"properties": {}, "required": []}), client + PlanModel(name="foo", schema={"properties": {}, "required": []}), Mock() ) assert plan.help_text == "Plan foo()" @@ -729,18 +733,11 @@ def test_plan_empty_fallback_help_text(client): ], ) def test_plan_param_mapping(args, kwargs, params): - client = Mock() - client.instrument_session = "cm12345-1" - plan = Plan( - FULL_PLAN.name, - FULL_PLAN, - client, - ) + runner = Mock() + plan = Plan(FULL_PLAN, runner) plan(*args, **kwargs) - client.run_task.assert_called_once_with( - TaskRequest(name="foobar", instrument_session="cm12345-1", params=params) - ) + runner.assert_called_once_with("foobar", params) @pytest.mark.parametrize( @@ -759,17 +756,15 @@ def test_plan_param_mapping(args, kwargs, params): ], ) def test_plan_invalid_param_mapping(args, kwargs, msg): - client = Mock() - client.instrument_session = "cm12345-1" + runner = Mock(spec=Callable) plan = Plan( - FULL_PLAN.name, FULL_PLAN, - client, + runner, ) with pytest.raises(TypeError, match=msg): plan(*args, **kwargs) - client.run_task.assert_not_called() + runner.assert_not_called() def test_adding_removing_callback(client): diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 7210ec2bb..a789d21eb 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -1329,3 +1329,17 @@ def test_config_schema( stream.write.assert_called() else: assert json.loads(result.output) == expected + pass + + +@patch("blueapi.client.client.BlueapiClient.from_config") +@patch("blueapi.cli.cli.stubgen") +def test_genstubs( + stubgen, + client, + runner: CliRunner, +): + runner.invoke(main, ["generate-stubs", "/path/to/stub_dir"]) + stubgen.generate_stubs.assert_called_once_with( + Path("/path/to/stub_dir"), list(client().plans), list(client().devices) + ) diff --git a/uv.lock b/uv.lock index efbb556aa..d44bac795 100644 --- a/uv.lock +++ b/uv.lock @@ -438,6 +438,7 @@ dependencies = [ { name = "fastapi" }, { name = "gitpython" }, { name = "graypy" }, + { name = "jinja2" }, { name = "observability-utils" }, { name = "opentelemetry-distro" }, { name = "opentelemetry-instrumentation-fastapi" }, @@ -495,6 +496,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.112.0" }, { name = "gitpython" }, { name = "graypy", specifier = ">=2.1.0" }, + { name = "jinja2", specifier = ">=3.1.6" }, { name = "observability-utils", specifier = ">=0.1.4" }, { name = "opentelemetry-distro", specifier = ">=0.48b0" }, { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.48b0" },