diff --git a/docs/tutorial.md b/docs/tutorial.md index 768606e8..5d7afe22 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -334,6 +334,8 @@ Mellea's core abstraction is called a `Component`. A `Component` is a structured Components are composite data structures; that is, a `Component` can be made up of many other parts. Each of those parts is either a `CBlock` or another `Component`. `CBlock`s, or "content blocks", are an atomic unit of text or data. CBlocks hold raw text (or sometimes parsed representations) and can be used as leaves in the Component DAG. +Components can also specify an expected output type along with a parse function to extract that type from the LLM's output. By default, this type is a string, but by defining a Component's expected type, you get type hints for outputs throughout the standard library. + Backends are the engines that actually run the LLM. Backends consume Components, format the Component, pass the formatted input to an LLM, and return model outputs, which are then parsed back into CBlocks or Components. During the course of an interaction with an LLM, several Components and CBlocks may be created. Logic for handling this trace of interactions is provided by a `Context` object. Some book-keeping needs to be done in order for Contexts to appropriately handle a trace of Components and CBlocks. The `MelleaSession` class, which is created by `mellea.start_session()`, does this book-keeping and provides a simple wrapper around Contexts and Backends. diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index 097a184a..2ebb0059 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -5,17 +5,24 @@ import abc import asyncio import itertools -from typing import TypeVar +from collections.abc import Sequence +from typing import Any, overload import pydantic +import typing_extensions from mellea.backends.model_ids import ModelIdentifier from mellea.backends.types import ModelOption from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk - -BaseModelSubclass = TypeVar( - "BaseModelSubclass", bound=pydantic.BaseModel +from mellea.stdlib.base import C, CBlock, Component, Context, ModelOutputThunk + +# Necessary to define a type that supports `None` so that the BaseModelSubclass +# can have a default value. Otherwise, Python complains about typed-components +# since types with default values must come after those without default values in +# function signatures (which is incompatible with our function parameter formatting). +pydantic_model_or_none = pydantic.BaseModel | None +BaseModelSubclass = typing_extensions.TypeVar( + "BaseModelSubclass", bound=pydantic_model_or_none, default=None ) # must be a subclass of BaseModel @@ -39,13 +46,13 @@ def __init__( @abc.abstractmethod async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> tuple[ModelOutputThunk, Context]: + ) -> tuple[ModelOutputThunk[C], Context]: """Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request. Args: @@ -60,10 +67,32 @@ async def generate_from_context( """ ...
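As a companion to the tutorial paragraph added above, here is a minimal sketch of a typed `Component` under the generics introduced in this diff. The `WordCount` class and its prompt text are hypothetical illustrations, not part of this change:

```python
from mellea.stdlib.base import CBlock, Component, ModelOutputThunk


class WordCount(Component[int]):
    """Hypothetical component whose expected output type is `int`."""

    def __init__(self, text: str):
        self._text = text

    def parts(self) -> list[Component | CBlock]:
        return []

    def format_for_llm(self) -> str:
        return (
            "Count the words in the following text and answer with a single integer:\n"
            + self._text
        )

    def _parse(self, computed: ModelOutputThunk) -> int:
        # Exceptions raised here are re-wrapped as ComponentParseError by Component.parse().
        return int(computed.value.strip()) if computed.value is not None else 0
```

With a component like this, `act(WordCount(...), ctx, backend)` from `mellea.stdlib.functional` is typed as returning `tuple[ModelOutputThunk[int], Context]`, so `result.parsed_repr` is hinted as `int`.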
+ @overload + async def generate_from_raw( + self, + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... + + @overload + async def generate_from_raw( + self, + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ... + @abc.abstractmethod async def generate_from_raw( self, - actions: list[Component | CBlock], + actions: Sequence[Component[C] | CBlock], ctx: Context, *, format: type[BaseModelSubclass] | None = None, diff --git a/mellea/backends/dummy.py b/mellea/backends/dummy.py index e8673cd6..3b45999d 100644 --- a/mellea/backends/dummy.py +++ b/mellea/backends/dummy.py @@ -1,7 +1,7 @@ """This module holds shim backends used for smoke tests.""" from mellea.backends import Backend, BaseModelSubclass -from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk +from mellea.stdlib.base import C, CBlock, Component, Context, ModelOutputThunk class DummyBackend(Backend): @@ -18,13 +18,13 @@ def __init__(self, responses: list[str] | None): async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> tuple[ModelOutputThunk, Context]: + ) -> tuple[ModelOutputThunk[C], Context]: """See constructor for an exmplanation of how DummyBackends work.""" assert format is None, "The DummyBackend does not support constrained decoding." if self.responses is None: diff --git a/mellea/backends/formatter.py b/mellea/backends/formatter.py index 7a060f42..0a3bcb35 100644 --- a/mellea/backends/formatter.py +++ b/mellea/backends/formatter.py @@ -20,7 +20,7 @@ ModelOutputThunk, TemplateRepresentation, ) -from mellea.stdlib.chat import Message, ToolMessage +from mellea.stdlib.chat import Message class Formatter(abc.ABC): @@ -31,16 +31,6 @@ def print(self, c: Component | CBlock) -> str: """Renders a component for input to a model.""" ... - @abc.abstractmethod - def parse( - self, source_component: Component | CBlock, result: ModelOutputThunk - ) -> ModelOutputThunk: - """Parses the output from a model and sets the parsed_repr of the result ModelOutputThunk. - - Returns the ModelOutputThunk that was passed in. - """ - ... - def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]: """Helper method that converts a linearized chat history into a list of messages. The purpose of this helper is to prepare a sequence of Messages for input to a chat endpoint.""" @@ -58,7 +48,12 @@ def _to_msg(c: Component | CBlock) -> Message: ) # This is already entailed by c.is_computed(); the line is included here to satisfy the type-checker. if c.parsed_repr is not None: - c = c.parsed_repr # This might be a message. + if isinstance(c.parsed_repr, Component): + # Only use the parsed_repr if it's something that we know how to print. + c = c.parsed_repr # This might be a message. + else: + # Otherwise, explicitly stringify it. 
+ c = Message(role=role, content=str(c.parsed_repr)) else: c = Message(role=role, content=c.value) # type: ignore @@ -104,64 +99,6 @@ def __init__( # Key: obj.__class__.__name___ -> Value: jinja2.Template self._template_cache = SimpleLRUCache(10) if self._use_template_cache else None - def parse( - self, source_component: Component | CBlock, result: ModelOutputThunk - ) -> ModelOutputThunk: - """Parses the output and updates the result's parsed_repr.""" - parsed = self._parse(source_component=source_component, result=result) - result.parsed_repr = parsed - return result - - def _parse( - self, source_component: Component | CBlock, result: ModelOutputThunk - ) -> CBlock | Component: - """Parses the output from a model.""" - if result.tool_calls is not None: - # A tool was successfully requested. - # Assistant responses for tool calling differ by backend. For the default formatter, - # we put all of the function data into the content field in the same format we received it. - - # Chat backends should provide an openai-like object in the _meta chat response, which we can use to properly format this output. - if "chat_response" in result._meta: - # Ollama. - return Message( - role=result._meta["chat_response"].message.role, - content=str(result._meta["chat_response"].message.tool_calls), - ) - elif "oai_chat_response" in result._meta: - # OpenAI and Watsonx. - return Message( - role=result._meta["oai_chat_response"]["message"]["role"], - content=str( - result._meta["oai_chat_response"]["message"].get( - "tool_calls", [] - ) - ), - ) - else: - # HuggingFace (or others). There are no guarantees on how the model represented the function calls. - # Output it in the same format we received the tool call request. - assert result.value is not None - return Message(role="assistant", content=result.value) - - if type(source_component) is Message: - if "chat_response" in result._meta: - # chat backends should provide an openai-like object in the _meta chat response, which we can use to properly format this output. 
- return Message( - role=result._meta["chat_response"].message.role, - content=result._meta["chat_response"].message.content, - ) - elif "oai_chat_response" in result._meta: - return Message( - role=result._meta["oai_chat_response"]["message"]["role"], - content=result._meta["oai_chat_response"]["message"]["content"], - ) - else: - assert result.value is not None - return Message(role="assistant", content=result.value) - else: - return result - def _stringify( self, c: ( diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index ff1e92da..7dca2cac 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -13,9 +13,9 @@ import inspect import json import threading -from collections.abc import Callable, Coroutine +from collections.abc import Callable, Coroutine, Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast, overload import granite_common import outlines @@ -42,7 +42,6 @@ LocalHFAdapter, get_adapter_for_intrinsic, ) -from mellea.backends.adapters.catalog import fetch_intrinsic_metadata from mellea.backends.cache import Cache, SimpleLRUCache from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier @@ -57,13 +56,13 @@ from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( + C, CBlock, Component, Context, GenerateLog, GenerateType, ModelOutputThunk, - ModelToolCall, ) from mellea.stdlib.chat import Message from mellea.stdlib.intrinsics.intrinsic import Intrinsic @@ -201,13 +200,13 @@ def _make_dc_cache(self, toks, **model_options): async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ): + ) -> tuple[ModelOutputThunk[C], Context]: """Generate using the huggingface model.""" await self.do_generate_walk(action) @@ -558,13 +557,13 @@ def _make_merged_kv_cache( async def _generate_from_context_with_kv_cache( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, model_options: dict[str, Any], tool_calls: bool = False, - ) -> ModelOutputThunk: + ) -> ModelOutputThunk[C]: # Construct input. # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template. # Otherwise, we will linearize the context and treat it as a raw input. @@ -598,7 +597,7 @@ async def _generate_from_context_with_kv_cache( if _format: # outlines.generate.json always parses the resulting json into a python dict. # We however want to keep it as a json string for later storing it in ModelOutputThunk - schema: dict[str, Any] = _format.model_json_schema() + schema: dict[str, Any] = _format.model_json_schema() # type: ignore schema_json: str = json.dumps(schema) regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore schema_json @@ -765,7 +764,7 @@ async def _generate_from_context_standard( if _format: # outlines.generate.json always parses the resulting json into a python dict. 
# We however want to keep it as a json string for later storing it in ModelOutputThunk - schema: dict[str, Any] = _format.model_json_schema() + schema: dict[str, Any] = _format.model_json_schema() # type: ignore schema_json: str = json.dumps(schema) regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore schema_json @@ -931,8 +930,6 @@ async def post_processing( "ModelOutputThunks should have their model_opts assigned during generation" ) - self.formatter.parse(mot._action, mot) - # Generate the log for this ModelOutputThunk. generate_log = GenerateLog() generate_log.prompt = conversation @@ -951,9 +948,31 @@ async def post_processing( mot._generate_log = generate_log + @overload async def generate_from_raw( self, - actions: list[Component | CBlock], + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... + + @overload + async def generate_from_raw( + self, + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ... + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], ctx: Context, *, format: type[BaseModelSubclass] | None = None, @@ -961,7 +980,7 @@ async def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" - await self.do_generate_walks(actions) + await self.do_generate_walks(list(actions)) if tool_calls: FancyLogger.get_logger().warning( @@ -992,7 +1011,7 @@ async def generate_from_raw( if format: # outlines.generate.json always parses the resulting json into a python dict. 
# We however want to keep it as a json string for later storing it in ModelOutputThunk - schema: dict[str, Any] = format.model_json_schema() + schema: dict[str, Any] = format.model_json_schema() # type: ignore schema_json: str = json.dumps(schema) regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore schema_json @@ -1046,8 +1065,10 @@ async def generate_from_raw( } }, ) - - self.formatter.parse(actions[i], result) + action = actions[i] + result.parsed_repr = ( + action.parse(result) if isinstance(action, Component) else result.value + ) generate_log = GenerateLog() generate_log.prompt = self.formatter.print(actions[i]) @@ -1056,7 +1077,7 @@ async def generate_from_raw( generate_log.date = datetime.datetime.now() generate_log.model_output = decoded_result generate_log.extra = {"format": format, "seed": seed} - generate_log.action = actions[i] + generate_log.action = action result._generate_log = generate_log results.append(result) diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 80adbc8b..a1de8951 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -5,8 +5,8 @@ import functools import json import os -from collections.abc import Callable, Coroutine -from typing import Any +from collections.abc import Callable, Coroutine, Sequence +from typing import Any, overload import litellm # type: ignore import litellm.litellm_core_utils # type: ignore @@ -29,6 +29,7 @@ extract_model_tool_requests, ) from mellea.stdlib.base import ( + C, CBlock, Component, Context, @@ -112,13 +113,13 @@ def __init__( async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ): + ) -> tuple[ModelOutputThunk[C], Context]: """See `generate_from_chat_context`.""" assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." @@ -233,14 +234,14 @@ def _make_backend_specific_and_remove( async def _generate_from_chat_context_standard( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: + ) -> ModelOutputThunk[C]: await self.do_generate_walk(action) model_opts = self._simplify_and_merge(model_options) @@ -278,7 +279,7 @@ async def _generate_from_chat_context_standard( "type": "json_schema", "json_schema": { "name": _format.__name__, - "schema": _format.model_json_schema(), + "schema": _format.model_json_schema(), # type: ignore "strict": True, }, } @@ -437,8 +438,6 @@ async def post_processing( for key, val in tool_chunk.items(): mot.tool_calls[key] = val - self.formatter.parse(mot._action, mot) - # Generate the log for this ModelOutputThunk. generate_log = GenerateLog() generate_log.prompt = conversation @@ -476,9 +475,31 @@ def _extract_tools( FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") return tools + @overload + async def generate_from_raw( + self, + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... 
+ + @overload async def generate_from_raw( self, - actions: list[Component | CBlock], + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ... + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], ctx: Context, *, format: type[BaseModelSubclass] | None = None, @@ -486,7 +507,7 @@ async def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" - await self.do_generate_walks(actions) + await self.do_generate_walks(list(actions)) extra_body = {} if format is not None: FancyLogger.get_logger().warning( @@ -495,7 +516,7 @@ async def generate_from_raw( ) # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests. - extra_body["guided_json"] = format.model_json_schema() + extra_body["guided_json"] = format.model_json_schema() # type: ignore if tool_calls: FancyLogger.get_logger().warning( "The completion endpoint does not support tool calling." @@ -539,7 +560,9 @@ async def generate_from_raw( else None, } - self.formatter.parse(action, output) + output.parsed_repr = ( + action.parse(output) if isinstance(action, Component) else output.value + ) generate_log = GenerateLog() generate_log.prompt = prompt diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index b2b8cb39..94bfb2cc 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -3,8 +3,8 @@ import asyncio import datetime import functools -from collections.abc import AsyncIterator, Callable, Coroutine -from typing import Any +from collections.abc import AsyncIterator, Callable, Coroutine, Sequence +from typing import Any, overload import ollama from tqdm import tqdm @@ -26,6 +26,7 @@ from mellea.helpers.event_loop_helper import _run_async_in_thread from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( + C, CBlock, Component, Context, @@ -254,13 +255,13 @@ def _make_backend_specific_and_remove( async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ): + ) -> tuple[ModelOutputThunk[C], Context]: """See `generate_from_chat_context`.""" assert ctx.is_chat_context, ( "The ollama backend only supports chat-like contexts." @@ -277,13 +278,13 @@ async def generate_from_context( async def generate_from_chat_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: + ) -> ModelOutputThunk[C]: """Generates a ModelOutputThunk. The final value for this object can be awaited. The new completion is generated from the provided Context using this backend's `Formatter`. 
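The `generate_from_raw` overloads repeated across the backends in this diff are intended to keep the component type for homogeneous lists and widen to `C | str` when plain `CBlock`s are mixed in. A rough, type-checker-level illustration of how they resolve; the `demo` helper and its arguments are assumptions for the example, not code from this diff:

```python
from mellea.backends import Backend
from mellea.stdlib.base import CBlock, Context, ModelOutputThunk
from mellea.stdlib.chat import Message


async def demo(backend: Backend, ctx: Context) -> None:
    # Homogeneous list of Component[Message]: the first overload applies,
    # giving list[ModelOutputThunk[Message]].
    typed: list[ModelOutputThunk[Message]] = await backend.generate_from_raw(
        [Message(role="user", content="hi")], ctx
    )

    # Components mixed with plain CBlocks: the second overload applies,
    # widening the element type to ModelOutputThunk[Message | str].
    widened: list[ModelOutputThunk[Message | str]] = await backend.generate_from_raw(
        [Message(role="user", content="hi"), CBlock("raw text")], ctx
    )
```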
@@ -354,7 +355,7 @@ async def generate_from_chat_context( think=model_opts.get(ModelOption.THINKING, None), stream=model_opts.get(ModelOption.STREAM, False), options=self._make_backend_specific_and_remove(model_opts), - format=_format.model_json_schema() if _format is not None else None, + format=_format.model_json_schema() if _format is not None else None, # type: ignore ) # type: ignore output = ModelOutputThunk(None) @@ -390,9 +391,31 @@ async def generate_from_chat_context( return output + @overload + async def generate_from_raw( + self, + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... + + @overload async def generate_from_raw( self, - actions: list[Component | CBlock], + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ... + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], ctx: Context, *, format: type[BaseModelSubclass] | None = None, @@ -430,7 +453,7 @@ async def generate_from_raw( prompt=prompt, raw=True, think=model_opts.get(ModelOption.THINKING, None), - format=format.model_json_schema() if format is not None else None, + format=format.model_json_schema() if format is not None else None, # type: ignore options=self._make_backend_specific_and_remove(model_opts), ) coroutines.append(co) @@ -462,8 +485,10 @@ async def generate_from_raw( }, }, ) - - self.formatter.parse(actions[i], result) + action = actions[i] + result.parsed_repr = ( + action.parse(result) if isinstance(action, Component) else result.value + ) generate_log = GenerateLog() generate_log.prompt = prompts[i] @@ -476,7 +501,7 @@ async def generate_from_raw( "thinking": model_opts.get(ModelOption.THINKING, None), "seed": model_opts.get(ModelOption.SEED, None), } - generate_log.action = actions[i] + generate_log.action = action if error: generate_log.extra["error"] = error @@ -555,7 +580,6 @@ async def post_processing( assert mot._model_options is not None, ( "ModelOutputThunks should have their model_opts assigned during generation" ) - self.formatter.parse(mot._action, mot) # Generate the log for this ModelOutputThunk. generate_log = GenerateLog() diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 0ae82729..f6c3c21b 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -7,10 +7,10 @@ import inspect import json import os -from collections.abc import Callable, Coroutine +from collections.abc import Callable, Coroutine, Sequence from copy import deepcopy from enum import Enum -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, overload from urllib.parse import urlparse import granite_common @@ -48,6 +48,7 @@ extract_model_tool_requests, ) from mellea.stdlib.base import ( + C, CBlock, Component, Context, @@ -299,13 +300,13 @@ def _make_backend_specific_and_remove( async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ): + ) -> tuple[ModelOutputThunk[C], Context]: """See `generate_from_chat_context`.""" assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." 
@@ -320,14 +321,14 @@ async def generate_from_context( async def generate_from_chat_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, tool_calls: bool = False, - ) -> tuple[ModelOutputThunk, Context]: + ) -> tuple[ModelOutputThunk[C], Context]: """Generates a new completion from the provided Context using this backend's `Formatter`.""" await self.do_generate_walk(action) @@ -623,7 +624,7 @@ async def _generate_from_chat_context_standard( # This only addresses the additionalProperties=False constraint. # Other constraints we should be checking/patching are described here: # https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat - monkey_patched_response_schema = _format.model_json_schema() + monkey_patched_response_schema = _format.model_json_schema() # type: ignore monkey_patched_response_schema["additionalProperties"] = False extra_params["response_format"] = { "type": "json_schema", @@ -641,7 +642,7 @@ async def _generate_from_chat_context_standard( "type": "json_schema", "json_schema": { "name": _format.__name__, - "schema": _format.model_json_schema(), + "schema": _format.model_json_schema(), # type: ignore "strict": True, }, } @@ -801,8 +802,6 @@ async def post_processing( for key, val in tool_chunk.items(): mot.tool_calls[key] = val - self.formatter.parse(mot._action, mot) - # Generate the log for this ModelOutputThunk. generate_log = GenerateLog() generate_log.prompt = conversation @@ -821,9 +820,31 @@ async def post_processing( generate_log.result = mot mot._generate_log = generate_log + @overload async def generate_from_raw( self, - actions: list[Component | CBlock], + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... + + @overload + async def generate_from_raw( + self, + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ... + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], ctx: Context, *, format: type[BaseModelSubclass] | None = None, @@ -831,7 +852,7 @@ async def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" - await self.do_generate_walks(actions) + await self.do_generate_walks(list(actions)) extra_body = {} if format is not None: @@ -841,7 +862,7 @@ async def generate_from_raw( ) # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests. - extra_body["guided_json"] = format.model_json_schema() + extra_body["guided_json"] = format.model_json_schema() # type: ignore if tool_calls: FancyLogger.get_logger().warning( "The completion endpoint does not support tool calling at the moment." 
@@ -888,7 +909,9 @@ async def generate_from_raw( else None, } - self.formatter.parse(action, output) + output.parsed_repr = ( + action.parse(output) if isinstance(action, Component) else output.value + ) generate_log = GenerateLog() generate_log.prompt = prompt diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py index 6d3fc228..56d85b50 100644 --- a/mellea/backends/vllm.py +++ b/mellea/backends/vllm.py @@ -15,8 +15,8 @@ import json import os import shutil -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Optional, overload import msgspec # type:ignore import outlines @@ -39,6 +39,7 @@ from mellea.helpers.event_loop_helper import _run_async_in_thread from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( + C, CBlock, Component, Context, @@ -239,14 +240,14 @@ def _model(self) -> vllm.AsyncLLMEngine: async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, - ) -> tuple[ModelOutputThunk, Context]: + ) -> tuple[ModelOutputThunk[C], Context]: """Generate using the huggingface model.""" await self.do_generate_walk(action) @@ -267,14 +268,14 @@ async def generate_from_context( async def _generate_from_context_standard( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, model_options: dict[str, Any], generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: + ) -> ModelOutputThunk[C]: # Construct input. # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template. # Otherwise, we will linearize the context and treat it as a raw input. @@ -322,7 +323,7 @@ async def _generate_from_context_standard( if _format is not None: # outlines.generate.json always parses the resulting json into a python dict. # We however want to keep it as a json string for later storing it in ModelOutputThunk - schema: dict[str, Any] = _format.model_json_schema() + schema: dict[str, Any] = _format.model_json_schema() # type: ignore schema_json: str = json.dumps(schema) regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore schema_json # type: ignore @@ -409,8 +410,6 @@ async def post_processing( "ModelOutputThunks should have their model_opts assigned during generation" ) - self.formatter.parse(mot._action, mot) - # Generate the log for this ModelOutputThunk. generate_log = GenerateLog() generate_log.prompt = conversation @@ -429,9 +428,31 @@ async def post_processing( mot._generate_log = generate_log + @overload + async def generate_from_raw( + self, + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... + + @overload async def generate_from_raw( self, - actions: list[Component | CBlock], + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ... 
+ + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], ctx: Context, *, format: type[BaseModelSubclass] | None = None, @@ -439,7 +460,7 @@ async def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" - await self.do_generate_walks(actions) + await self.do_generate_walks(list(actions)) if tool_calls: FancyLogger.get_logger().warning( @@ -458,7 +479,7 @@ async def generate_from_raw( ) if format is not None: - schema: dict[str, Any] = format.model_json_schema() + schema: dict[str, Any] = format.model_json_schema() # type: ignore schema_json: str = json.dumps(schema) regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore schema_json # type: ignore @@ -487,9 +508,13 @@ async def generate(prompt, request_id): results = [ModelOutputThunk(value=text) for text in decoded_results] for i, result in enumerate(results): - self.formatter.parse(actions[i], result) date = datetime.datetime.now() + action = actions[i] + result.parsed_repr = ( + action.parse(result) if isinstance(action, Component) else result.value + ) + generate_log = GenerateLog() generate_log.prompt = prompts[i] generate_log.backend = f"vllm::{self.model_id!s}" @@ -500,7 +525,7 @@ async def generate(prompt, request_id): "format": format, "seed": model_options.get(ModelOption.SEED, None), } - generate_log.action = actions[i] + generate_log.action = action generate_log.result = results[i] result._generate_log = generate_log diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 721e4f05..73f257e5 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -6,9 +6,9 @@ import json import os import warnings -from collections.abc import AsyncGenerator, Callable, Coroutine +from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence from dataclasses import fields -from typing import Any +from typing import Any, overload from ibm_watsonx_ai import APIClient, Credentials from ibm_watsonx_ai.foundation_models import ModelInference @@ -34,6 +34,7 @@ extract_model_tool_requests, ) from mellea.stdlib.base import ( + C, CBlock, Component, Context, @@ -238,13 +239,13 @@ def _make_backend_specific_and_remove( async def generate_from_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ): + ) -> tuple[ModelOutputThunk[C], Context]: """See `generate_from_chat_context`.""" assert ctx.is_chat_context, NotImplementedError( "The watsonx.ai backend only supports chat-like contexts." 
@@ -260,14 +261,14 @@ async def generate_from_context( async def generate_from_chat_context( self, - action: Component | CBlock, + action: Component[C] | CBlock, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: + ) -> ModelOutputThunk[C]: """Generates a new completion from the provided Context using this backend's `Formatter`.""" await self.do_generate_walk(action) @@ -301,7 +302,7 @@ async def generate_from_chat_context( "type": "json_schema", "json_schema": { "name": _format.__name__, - "schema": _format.model_json_schema(), + "schema": _format.model_json_schema(), # type: ignore "strict": True, }, } @@ -463,8 +464,6 @@ async def post_processing( for key, val in tool_chunk.items(): mot.tool_calls[key] = val - self.formatter.parse(mot._action, mot) - # Generate the log for this ModelOutputThunk. generate_log = GenerateLog() generate_log.prompt = conversation @@ -482,9 +481,31 @@ async def post_processing( generate_log.action = mot._action mot._generate_log = generate_log + @overload + async def generate_from_raw( + self, + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... + + @overload async def generate_from_raw( self, - actions: list[Component | CBlock], + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ... + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], ctx: Context, *, format: type[BaseModelSubclass] | None = None, @@ -492,7 +513,7 @@ async def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generates a completion text. Gives the input provided to the model without templating.""" - await self.do_generate_walks(actions) + await self.do_generate_walks(list(actions)) if format is not None: FancyLogger.get_logger().warning( @@ -529,7 +550,10 @@ async def generate_from_raw( }, ) - self.formatter.parse(actions[i], result) + action = actions[i] + result.parsed_repr = ( + action.parse(result) if isinstance(action, Component) else result.value + ) generate_log = GenerateLog() generate_log.prompt = prompts[i] @@ -541,7 +565,7 @@ async def generate_from_raw( "format": format, "seed": model_opts.get(ModelOption.SEED, None), } - generate_log.action = actions[i] + generate_log.action = action result._generate_log = generate_log diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index d83b102d..c608179a 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -12,13 +12,24 @@ from copy import copy, deepcopy from dataclasses import dataclass from io import BytesIO -from typing import Any, Protocol, TypeVar, runtime_checkable +from typing import Any, Generic, Protocol, TypeVar, runtime_checkable +import typing_extensions from PIL import Image as PILImage -from mellea.helpers.fancy_logger import FancyLogger +S = typing_extensions.TypeVar("S", default=Any, covariant=True) +"""Used for class definitions for Component and ModelOutputThunk; also used for functions that don't accept CBlocks. 
Defaults to `Any`.""" +C = typing_extensions.TypeVar("C", default=str) +"""Used for component typing in function parameters where the function takes a Component[C] and/or CBlock and can return a ModelOutputThunk[C]. Defaults to `str`.""" +class ComponentParseError(Exception): + """Raised by `Component.parse()` when the underlying parsing method throws an exception.""" + pass + + +# For ModelOutputThunk return types to be typed correctly, CBlocks must be defined +# using generics and a type var that defaults to str. CBlocks should never be initialized with [type]. class CBlock: """A `CBlock` is a block of content that can serve as input to or output from an LLM.""" @@ -128,7 +139,7 @@ def __repr__(self): @runtime_checkable -class Component(Protocol): +class Component(Protocol, Generic[S]): """A `Component` is a composite data structure that is intended to be represented to an LLM.""" def parts(self) -> list[Component | CBlock]: @@ -142,6 +153,16 @@ def format_for_llm(self) -> TemplateRepresentation | str: """ raise NotImplementedError("format_for_llm isn't implemented by default") + def parse(self, computed: ModelOutputThunk) -> S: + try: + return self._parse(computed) + except Exception as e: + raise ComponentParseError(f"component parsing failed: {e}") + + def _parse(self, computed: ModelOutputThunk) -> S: + """Components can define a return type that is parsed from the text output of an LLM.""" + raise NotImplementedError("parse isn't implemented by default") + def get_images_from_component(c: Component) -> None | list[ImageBlock]: """Gets images from a `Component` if they are present and a non-empty list, otherwise returns None.""" @@ -163,7 +184,7 @@ def get_images_from_component(c: Component) -> None | list[ImageBlock]: # TODO: Add support for passing in docs as model options. -class Document(Component): +class Document(Component[str]): """Documents should typically be used in a Message object.""" def __init__(self, text: str, title: str | None = None, doc_id: str | None = None): @@ -190,6 +211,10 @@ def format_for_llm(self) -> str: return doc + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" + class GenerateType(enum.Enum): """Used to track what functions can be used to extract a value from a ModelOutputThunk.""" @@ -199,19 +224,21 @@ class GenerateType(enum.Enum): SYNC = 2 -class ModelOutputThunk(CBlock): +class ModelOutputThunk(CBlock, Generic[S]): """A `ModelOutputThunk` is a special type of `CBlock` that we know came from a model's output. It is possible to instantiate one without the output being computed yet.""" def __init__( self, value: str | None, meta: dict[str, Any] | None = None, - parsed_repr: CBlock | Component | Any | None = None, + parsed_repr: S | None = None, tool_calls: dict[str, ModelToolCall] | None = None, ): """Initializes as a cblock, optionally also with a parsed representation from an output formatter.""" super().__init__(value, meta) - self.parsed_repr: CBlock | Component | Any | None = parsed_repr + + self.parsed_repr: S | None = parsed_repr + """Will be non-`None` once computed.""" # Set computed to True if a value is passed in. self._computed: bool = True if value is not None else False @@ -266,7 +293,7 @@ async def avalue(self) -> str: RuntimeError: If called when the ModelOutputThunk's generate function is not async compatible. """ if self._computed: - assert self.value # If computed, the value cannot be None. 
+ assert self.value is not None # If computed, the value cannot be None. return self.value if not self._generate_type == GenerateType.ASYNC: @@ -355,6 +382,22 @@ async def astream(self) -> str: assert self._post_process is not None await self._post_process(self) + match self._action: + case Component(): + self.parsed_repr = self._action._parse(self) + case CBlock(): + assert self.value is not None, ( + "value must be non-None since this thunk is computed" + ) + self.parsed_repr = self.value # type: ignore + case _: + raise ValueError( + "attempted to astream from a model output thunk with no ._action set" + ) + assert self.parsed_repr is not None, ( + "enforce constraint that a computed ModelOutputThunk has a non-None parsed_repr" + ) + return self._underlying_value # type: ignore def __repr__(self): @@ -374,7 +417,7 @@ def __copy__(self): # itself if the parsing didn't result in a new representation. It makes sense to update the # parsed_repr to the copied ModelOutputThunk in that case. if self.parsed_repr is self: - copied.parsed_repr = copied + copied.parsed_repr = copied # type: ignore copied._computed = self._computed copied._thinking = self._thinking @@ -671,7 +714,7 @@ def call_func(self) -> Any: return self.func(**self.args) -class SimpleComponent(Component): +class SimpleComponent(Component[str]): """A Component that is make up of named spans.""" def __init__(self, **kwargs): @@ -719,3 +762,7 @@ def make_json_string(kwargs): def format_for_llm(self): """Uses a string rep.""" return SimpleComponent.make_json_string(self._kwargs) + + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" diff --git a/mellea/stdlib/chat.py b/mellea/stdlib/chat.py index ad582be6..f024089e 100644 --- a/mellea/stdlib/chat.py +++ b/mellea/stdlib/chat.py @@ -16,7 +16,7 @@ ) -class Message(Component): +class Message(Component["Message"]): """A single Message in a Chat history. TODO: we may want to deprecate this Component entirely. @@ -96,6 +96,53 @@ def __str__(self): docs = [f"{doc.format_for_llm()[:10]}..." for doc in self._docs] return f'mellea.Message(role="{self.role}", content="{self.content}", images="{images}", documents="{docs}")' + def _parse(self, computed: ModelOutputThunk) -> "Message": + """Parse the model output into a Message.""" + # TODO: There's some specific logic for tool calls. Storing that here for now. + # We may eventually need some generic parsing logic that gets run for all Component types... + if computed.tool_calls is not None: + # A tool was successfully requested. + # Assistant responses for tool calling differ by backend. For the default formatter, + # we put all of the function data into the content field in the same format we received it. + + # Chat backends should provide an openai-like object in the _meta chat response, which we can use to properly format this output. + if "chat_response" in computed._meta: + # Ollama. + return Message( + role=computed._meta["chat_response"].message.role, + content=str(computed._meta["chat_response"].message.tool_calls), + ) + elif "oai_chat_response" in computed._meta: + # OpenAI and Watsonx. + return Message( + role=computed._meta["oai_chat_response"]["message"]["role"], + content=str( + computed._meta["oai_chat_response"]["message"].get( + "tool_calls", [] + ) + ), + ) + else: + # HuggingFace (or others). There are no guarantees on how the model represented the function calls. 
+ # Output it in the same format we received the tool call request. + assert computed.value is not None + return Message(role="assistant", content=computed.value) + + if "chat_response" in computed._meta: + # Chat backends should provide an openai-like object in the _meta chat response, which we can use to properly format this output. + return Message( + role=computed._meta["chat_response"].message.role, + content=computed._meta["chat_response"].message.content, + ) + elif "oai_chat_response" in computed._meta: + return Message( + role=computed._meta["oai_chat_response"]["message"]["role"], + content=computed._meta["oai_chat_response"]["message"]["content"], + ) + else: + assert computed.value is not None + return Message(role="assistant", content=computed.value) + class ToolMessage(Message): """Adds the name field for function name.""" diff --git a/mellea/stdlib/docs/richdocument.py b/mellea/stdlib/docs/richdocument.py index 251254fb..6913d71f 100644 --- a/mellea/stdlib/docs/richdocument.py +++ b/mellea/stdlib/docs/richdocument.py @@ -11,11 +11,16 @@ from docling_core.types.doc.document import DoclingDocument, TableItem from docling_core.types.io import DocumentStream -from mellea.stdlib.base import CBlock, Component, TemplateRepresentation +from mellea.stdlib.base import ( + CBlock, + Component, + ModelOutputThunk, + TemplateRepresentation, +) from mellea.stdlib.mobject import MObject, Query, Transform -class RichDocument(Component): +class RichDocument(Component[str]): """A `RichDocument` is a block of content with an underlying DoclingDocument. It has helper functions for working with the document and extracting parts of it. @@ -41,6 +46,10 @@ def format_for_llm(self) -> TemplateRepresentation | str: """ return self.to_markdown() + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" + def docling(self) -> DoclingDocument: """Get the underlying Docling Document.""" return self._doc diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index e758be04..950e103c 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -4,7 +4,7 @@ import asyncio from collections.abc import Coroutine -from typing import Any, Literal, overload +from typing import Any, Literal, TypeVar, overload from PIL import Image as PILImage @@ -14,12 +14,12 @@ from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( CBlock, - ChatContext, Component, Context, GenerateLog, ImageBlock, ModelOutputThunk, + S, SimpleContext, ) from mellea.stdlib.chat import Message, ToolMessage @@ -36,7 +36,7 @@ @overload def act( - action: Component, + action: Component[S], context: Context, backend: Backend, *, @@ -46,12 +46,12 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> tuple[ModelOutputThunk, Context]: ... +) -> tuple[ModelOutputThunk[S], Context]: ... @overload def act( - action: Component, + action: Component[S], context: Context, backend: Backend, *, @@ -61,11 +61,11 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> SamplingResult: ... +) -> SamplingResult[S]: ... 
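Since `Component.parse()` in `mellea/stdlib/base.py` above wraps any failure from `_parse` in a `ComponentParseError`, callers that parse model output themselves can handle malformed text in one place. A hedged sketch of that pattern; the helper function is illustrative only:

```python
from mellea.stdlib.base import Component, ComponentParseError, ModelOutputThunk


def parse_or_none(component: Component[int], thunk: ModelOutputThunk) -> int | None:
    """Return the component's parsed result, or None if parsing failed."""
    try:
        return component.parse(thunk)
    except ComponentParseError:
        # _parse raised (for example, int() on non-numeric text).
        return None
```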
def act( - action: Component, + action: Component[S], context: Context, backend: Backend, *, @@ -75,7 +75,7 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> tuple[ModelOutputThunk, Context] | SamplingResult: +) -> tuple[ModelOutputThunk[S], Context] | SamplingResult[S]: """Runs a generic action, and adds both the action and the result to the context. Args: @@ -129,7 +129,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> tuple[ModelOutputThunk, Context]: ... +) -> tuple[ModelOutputThunk[str], Context]: ... @overload @@ -150,7 +150,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> SamplingResult: ... +) -> SamplingResult[str]: ... def instruct( @@ -170,7 +170,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> tuple[ModelOutputThunk, Context] | SamplingResult: +) -> tuple[ModelOutputThunk[str], Context] | SamplingResult[str]: """Generates from an instruction. Args: @@ -418,7 +418,7 @@ def transform( @overload async def aact( - action: Component, + action: Component[S], context: Context, backend: Backend, *, @@ -429,12 +429,12 @@ async def aact( model_options: dict | None = None, tool_calls: bool = False, silence_context_type_warning: bool = False, -) -> tuple[ModelOutputThunk, Context]: ... +) -> tuple[ModelOutputThunk[S], Context]: ... @overload async def aact( - action: Component, + action: Component[S], context: Context, backend: Backend, *, @@ -445,11 +445,11 @@ async def aact( model_options: dict | None = None, tool_calls: bool = False, silence_context_type_warning: bool = False, -) -> SamplingResult: ... +) -> SamplingResult[S]: ... async def aact( - action: Component, + action: Component[S], context: Context, backend: Backend, *, @@ -460,7 +460,7 @@ async def aact( model_options: dict | None = None, tool_calls: bool = False, silence_context_type_warning: bool = False, -) -> tuple[ModelOutputThunk, Context] | SamplingResult: +) -> tuple[ModelOutputThunk[S], Context] | SamplingResult: """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. Args: @@ -567,7 +567,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> tuple[ModelOutputThunk, Context]: ... +) -> tuple[ModelOutputThunk[str], Context]: ... @overload @@ -588,7 +588,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> SamplingResult: ... +) -> SamplingResult[S]: ... async def ainstruct( @@ -608,7 +608,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, -) -> tuple[ModelOutputThunk, Context] | SamplingResult: +) -> tuple[ModelOutputThunk[str], Context] | SamplingResult: """Generates from an instruction. 
Args: diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 666263bf..f56379ae 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -139,12 +139,12 @@ def __init__( self.validation = validation_results -class Function: +class Function(Generic[P, R]): """A Function.""" - def __init__(self, func: Callable): + def __init__(self, func: Callable[P, R]): """A Function.""" - self._func: Callable = func + self._func: Callable[P, R] = func self._function_dict: FunctionDict = describe_function(func) @@ -254,7 +254,7 @@ def __init__(self): """A list of parameter names used by Mellea. Cannot use these in functions decorated with @generative.""" -class GenerativeSlot(Component, Generic[P, R]): +class GenerativeSlot(Component[R], Generic[P, R]): """A generative slot component.""" def __init__(self, func: Callable[P, R]): @@ -281,6 +281,8 @@ def __init__(self, func: Callable[P, R]): self._arguments: Arguments | None = None functools.update_wrapper(self, func) + self._response_model = create_response_format(self._function._func) + # Set when calling the decorated func. self.precondition_requirements: list[Requirement] = [] self.requirements: list[Requirement] = [] @@ -407,6 +409,16 @@ def format_for_llm(self) -> TemplateRepresentation: template_order=["*", "GenerativeSlot"], ) + def _parse(self, computed: ModelOutputThunk) -> R: + """Parse the model output. Returns the original function's return type.""" + function_response: FunctionResponse[R] = ( + self._response_model.model_validate_json( + computed.value # type: ignore + ) + ) + + return function_response.result + class SyncGenerativeSlot(GenerativeSlot, Generic[P, R]): @overload @@ -474,8 +486,6 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: slot_args.append(get_argument(slot_copy._function._func, key, val)) slot_copy._arguments = Arguments(slot_args) - response_model = create_response_format(self._function._func) - # Do precondition validation first. if slot_copy._arguments is not None: if extracted.m is not None: @@ -517,7 +527,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: slot_copy, requirements=slot_copy.requirements, strategy=extracted.strategy, - format=response_model, + format=self._response_model, model_options=extracted.model_options, ) else: @@ -530,18 +540,15 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: extracted.backend, requirements=slot_copy.requirements, strategy=extracted.strategy, - format=response_model, + format=self._response_model, model_options=extracted.model_options, ) - function_response: FunctionResponse[R] = response_model.model_validate_json( - response.value # type: ignore - ) - + assert response.parsed_repr is not None if context is None: - return function_response.result + return response.parsed_repr else: - return function_response.result, context + return response.parsed_repr, context class AsyncGenerativeSlot(GenerativeSlot, Generic[P, R]): @@ -612,8 +619,6 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R slot_args.append(get_argument(slot_copy._function._func, key, val)) slot_copy._arguments = Arguments(slot_args) - response_model = create_response_format(self._function._func) - # AsyncGenerativeSlots are used with async functions. In order to support that behavior, # they must return a coroutine object. 
async def __async_call__() -> tuple[R, Context] | R: @@ -660,7 +665,7 @@ async def __async_call__() -> tuple[R, Context] | R: slot_copy, requirements=slot_copy.requirements, strategy=extracted.strategy, - format=response_model, + format=self._response_model, model_options=extracted.model_options, ) else: @@ -673,18 +678,15 @@ async def __async_call__() -> tuple[R, Context] | R: extracted.backend, requirements=slot_copy.requirements, strategy=extracted.strategy, - format=response_model, + format=self._response_model, model_options=extracted.model_options, ) - function_response: FunctionResponse[R] = response_model.model_validate_json( - response.value # type: ignore - ) - + assert response.parsed_repr is not None if context is None: - return function_response.result + return response.parsed_repr else: - return function_response.result, context + return response.parsed_repr, context return __async_call__() diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/instruction.py index 8d5b3a68..7e9440ef 100644 --- a/mellea/stdlib/instruction.py +++ b/mellea/stdlib/instruction.py @@ -10,13 +10,14 @@ CBlock, Component, ImageBlock, + ModelOutputThunk, TemplateRepresentation, blockify, ) from mellea.stdlib.requirement import Requirement, reqify -class Instruction(Component): +class Instruction(Component[str]): """The Instruction in an instruct/validate/repair loop.""" def __init__( @@ -180,3 +181,7 @@ def copy_and_repair(self, repair_string: str) -> Instruction: res = deepcopy(self) res._repair_string = repair_string return res + + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" diff --git a/mellea/stdlib/intrinsics/intrinsic.py b/mellea/stdlib/intrinsics/intrinsic.py index e21eb571..925dba6b 100644 --- a/mellea/stdlib/intrinsics/intrinsic.py +++ b/mellea/stdlib/intrinsics/intrinsic.py @@ -5,10 +5,15 @@ from typing import cast from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata -from mellea.stdlib.base import CBlock, Component, TemplateRepresentation +from mellea.stdlib.base import ( + CBlock, + Component, + ModelOutputThunk, + TemplateRepresentation, +) -class Intrinsic(Component): +class Intrinsic(Component[str]): """A component representing an intrinsic.""" def __init__( @@ -64,3 +69,7 @@ def format_for_llm(self) -> TemplateRepresentation | str: "`Intrinsic` doesn't implement format_for_llm by default. You should only " "use an `Intrinsic` as the action and not as a part of the context." ) + + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" diff --git a/mellea/stdlib/mify.py b/mellea/stdlib/mify.py index 2059e426..43a3df07 100644 --- a/mellea/stdlib/mify.py +++ b/mellea/stdlib/mify.py @@ -5,7 +5,12 @@ from collections.abc import Callable from typing import Any, Protocol, TypeVar, overload, runtime_checkable -from mellea.stdlib.base import CBlock, Component, TemplateRepresentation +from mellea.stdlib.base import ( + CBlock, + Component, + ModelOutputThunk, + TemplateRepresentation, +) from mellea.stdlib.mobject import MObjectProtocol, Query, Transform @@ -192,6 +197,13 @@ def format_for_llm(self) -> TemplateRepresentation: template_order=template_order, ) + def parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now. 
+ + [no-index] + """ + return computed.value if computed.value is not None else "" + T = TypeVar("T") diff --git a/mellea/stdlib/mobject.py b/mellea/stdlib/mobject.py index 19852264..e9f20f1a 100644 --- a/mellea/stdlib/mobject.py +++ b/mellea/stdlib/mobject.py @@ -6,10 +6,15 @@ from collections.abc import Callable from typing import Protocol, runtime_checkable -from mellea.stdlib.base import CBlock, Component, TemplateRepresentation +from mellea.stdlib.base import ( + CBlock, + Component, + ModelOutputThunk, + TemplateRepresentation, +) -class Query(Component): +class Query(Component[str]): """A Query component.""" def __init__(self, obj: Component, query: str) -> None: @@ -48,8 +53,12 @@ def format_for_llm(self) -> TemplateRepresentation | str: template_order=["Query"], ) + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" -class Transform(Component): + +class Transform(Component[str]): """A Transform component.""" def __init__(self, obj: Component, transformation: str) -> None: @@ -88,6 +97,10 @@ def format_for_llm(self) -> TemplateRepresentation | str: template_order=["Transform"], ) + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" + @runtime_checkable class MObjectProtocol(Protocol): @@ -137,8 +150,12 @@ def format_for_llm(self) -> TemplateRepresentation | str: """ ... + def parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output.""" + ... -class MObject(Component): + +class MObject(Component[str]): """An extension of `Component` for adding query and transform operations.""" def __init__( @@ -218,3 +235,7 @@ def format_for_llm(self) -> TemplateRepresentation | str: fields=[], template_order=["*", "MObject"], ) + + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index f30aa4fc..8ad261e8 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -96,7 +96,7 @@ def __bool__(self) -> bool: return self.as_bool() -class Requirement(Component): +class Requirement(Component[str]): """Requirements are a special type of Component used as input to the Validate step in Instruct/Validate/Repair patterns.""" def __init__( @@ -177,6 +177,10 @@ def format_for_llm(self) -> TemplateRepresentation | str: template_order=["*", "Requirement"], ) + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" + class LLMaJRequirement(Requirement): """A requirement that always uses LLM-as-a-Judge. 
Any available constraint ALoRA will be ignored.""" diff --git a/mellea/stdlib/sampling/__init__.py b/mellea/stdlib/sampling/__init__.py index 4c933cd9..81bc507c 100644 --- a/mellea/stdlib/sampling/__init__.py +++ b/mellea/stdlib/sampling/__init__.py @@ -6,4 +6,4 @@ RejectionSamplingStrategy, RepairTemplateStrategy, ) -from .types import SamplingResult, SamplingStrategy +from .types import S, SamplingResult, SamplingStrategy diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 06401ec4..6c00a008 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -13,7 +13,7 @@ from mellea.stdlib.instruction import Instruction from mellea.stdlib.requirement import Requirement, ValidationResult -from .types import SamplingResult, SamplingStrategy +from .types import S, SamplingResult, SamplingStrategy class BaseSamplingStrategy(SamplingStrategy): @@ -61,6 +61,8 @@ def repair( Returns: The next action component and context to be used for the next generation attempt. """ + # TODO: For Component/ModelOutputThunk-typing to work, repair strategies should always return a Component with the same parsing + # as the initial action used for this sampling strategy. ... @staticmethod @@ -84,7 +86,7 @@ def select_from_failure( async def sample( self, - action: Component, + action: Component[S], context: Context, backend: Backend, requirements: list[Requirement] | None, @@ -94,7 +96,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, - ) -> SamplingResult: + ) -> SamplingResult[S]: """This method performs a sampling operation based on the given instruction. Args: @@ -160,6 +162,13 @@ async def sample( ) await result.avalue() + # Sampling strategies may use different components from the original + # action. This might cause discrepancies in the expected parsed_repr + # type / value. Explicitly overwrite that here. + # TODO: See if there's a more elegant way for this so that each sampling + # strategy doesn't have to re-implement it. + result.parsed_repr = action.parse(result) + # validation pass val_scores_co = mfuncs.avalidate( reqs=reqs, diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index be082159..b0b827ef 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -12,6 +12,7 @@ from mellea.stdlib.instruction import Instruction from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult from mellea.stdlib.sampling import BaseSamplingStrategy, SamplingResult +from mellea.stdlib.sampling.types import S class BestofNSamplingStrategy(BaseSamplingStrategy): @@ -19,7 +20,7 @@ class BestofNSamplingStrategy(BaseSamplingStrategy): async def sample( self, - action: Component, + action: Component[S], context: Context, backend: Backend, requirements: list[Requirement] | None, @@ -29,7 +30,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, - ) -> SamplingResult: + ) -> SamplingResult[S]: """This method performs a sampling operation based on the given instruction. Args: @@ -115,6 +116,12 @@ async def sample( model_options=model_options, tool_calls=tool_calls, ) + + # Sampling strategies may use different components from the original + # action. This might cause discrepancies in the expected parsed_repr + # type / value. Explicitly overwrite that here. 
+ result.parsed_repr = action.parse(result) + sampled_results.append(result) sampled_actions.append(next_action) sample_contexts.append(result_ctx) diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index 692fae66..90934938 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -12,6 +12,7 @@ from mellea.stdlib.requirement import Requirement, ValidationResult from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult from mellea.stdlib.sampling.base import Component, Context +from mellea.stdlib.sampling.types import S from mellea.stdlib.sampling_algos.budget_forcing_alg import think_budget_forcing @@ -72,7 +73,7 @@ def __init__( async def sample( self, - action: Component, + action: Component[S], context: Context, backend: Backend, requirements: list[Requirement] | None, @@ -82,7 +83,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, - ) -> SamplingResult: + ) -> SamplingResult[S]: """This method performs a sampling operation based on the given instruction. Args: @@ -164,6 +165,11 @@ async def sample( ) result_ctx = next_context + # Sampling strategies may use different components from the original + # action. This might cause discrepancies in the expected parsed_repr + # type / value. Explicitly overwrite that here. + result.parsed_repr = action.parse(result) + # validation pass val_scores_co = mfuncs.avalidate( reqs=reqs, diff --git a/mellea/stdlib/sampling/majority_voting.py b/mellea/stdlib/sampling/majority_voting.py index 8ba99798..d68d06c0 100644 --- a/mellea/stdlib/sampling/majority_voting.py +++ b/mellea/stdlib/sampling/majority_voting.py @@ -11,6 +11,7 @@ from mellea.stdlib.requirement import Requirement from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult from mellea.stdlib.sampling.base import Component, Context +from mellea.stdlib.sampling.types import S class BaseMBRDSampling(RejectionSamplingStrategy): @@ -64,7 +65,7 @@ def maybe_apply_weighted(self, scr: np.ndarray): async def sample( self, - action: Component, + action: Component[S], context: Context, backend: Backend, requirements: list[Requirement] | None, @@ -74,7 +75,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, - ) -> SamplingResult: + ) -> SamplingResult[S]: """Samples using majority voting. Args: diff --git a/mellea/stdlib/sampling/types.py b/mellea/stdlib/sampling/types.py index 391b2c89..69cd13c2 100644 --- a/mellea/stdlib/sampling/types.py +++ b/mellea/stdlib/sampling/types.py @@ -1,13 +1,14 @@ """Base types for sampling.""" import abc +from typing import Generic, TypeVar from mellea.backends import Backend, BaseModelSubclass -from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk +from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk, S from mellea.stdlib.requirement import Requirement, ValidationResult -class SamplingResult(CBlock): +class SamplingResult(CBlock, Generic[S]): """Stores the results from a sampling operation. 
This includes successful and failed samplings.""" def __init__( @@ -15,7 +16,7 @@ def __init__( result_index: int, success: bool, *, - sample_generations: list[ModelOutputThunk] | None = None, + sample_generations: list[ModelOutputThunk[S]] | None = None, sample_validations: list[list[tuple[Requirement, ValidationResult]]] | None = None, sample_actions: list[Component] | None = None, @@ -56,7 +57,7 @@ def __init__( self.sample_contexts = sample_contexts @property - def result(self) -> ModelOutputThunk: + def result(self) -> ModelOutputThunk[S]: """The final output or result from applying the sampling strategy.""" return self.sample_generations[self.result_index] @@ -66,7 +67,7 @@ def result_ctx(self) -> Context: return self.sample_contexts[self.result_index] @property - def result_action(self) -> Component: + def result_action(self) -> Component[S]: """The action that generated the final output or result from applying the sampling strategy.""" return self.sample_actions[self.result_index] @@ -86,7 +87,7 @@ class SamplingStrategy(abc.ABC): @abc.abstractmethod async def sample( self, - action: Component, + action: Component[S], context: Context, backend: Backend, requirements: list[Requirement] | None, @@ -95,7 +96,7 @@ async def sample( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> SamplingResult: + ) -> SamplingResult[S]: """This method is the abstract method for sampling a given component. It must be implemented by any concrete subclasses to provide specific sampling logic. diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index db4f47be..ef2b348b 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -5,7 +5,7 @@ import contextvars import inspect from copy import copy -from typing import Any, Literal, overload +from typing import Any, Literal, TypeVar, overload from PIL import Image as PILImage @@ -26,6 +26,7 @@ GenerateLog, ImageBlock, ModelOutputThunk, + S, SimpleContext, ) from mellea.stdlib.chat import Message @@ -231,7 +232,7 @@ def cleanup(self) -> None: @overload def act( self, - action: Component, + action: Component[S], *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -239,12 +240,12 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: ... + ) -> ModelOutputThunk[S]: ... @overload def act( self, - action: Component, + action: Component[S], *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -252,11 +253,11 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> SamplingResult: ... + ) -> SamplingResult[S]: ... def act( self, - action: Component, + action: Component[S], *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -264,7 +265,7 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: + ) -> ModelOutputThunk[S] | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. 
Args: @@ -316,7 +317,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: ... + ) -> ModelOutputThunk[str]: ... @overload def instruct( @@ -335,7 +336,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> SamplingResult: ... + ) -> SamplingResult[str]: ... def instruct( self, @@ -353,7 +354,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: + ) -> ModelOutputThunk[str] | SamplingResult: """Generates from an instruction. Args: @@ -515,7 +516,7 @@ def transform( @overload async def aact( self, - action: Component, + action: Component[S], *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -523,12 +524,12 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: ... + ) -> ModelOutputThunk[S]: ... @overload async def aact( self, - action: Component, + action: Component[S], *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -536,11 +537,11 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> SamplingResult: ... + ) -> SamplingResult[S]: ... async def aact( self, - action: Component, + action: Component[S], *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -548,7 +549,7 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: + ) -> ModelOutputThunk[S] | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. Args: @@ -600,7 +601,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: ... + ) -> ModelOutputThunk[str]: ... @overload async def ainstruct( @@ -619,7 +620,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> SamplingResult: ... + ) -> SamplingResult[str]: ... async def ainstruct( self, @@ -637,7 +638,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: + ) -> ModelOutputThunk[str] | SamplingResult[str]: """Generates from an instruction. 
Args: diff --git a/mellea/stdlib/test_based_eval.py b/mellea/stdlib/test_based_eval.py index 1e96ad61..56df2344 100644 --- a/mellea/stdlib/test_based_eval.py +++ b/mellea/stdlib/test_based_eval.py @@ -6,7 +6,12 @@ from pydantic import BaseModel, Field, field_validator -from mellea.stdlib.base import CBlock, Component, TemplateRepresentation +from mellea.stdlib.base import ( + CBlock, + Component, + ModelOutputThunk, + TemplateRepresentation, +) class Message(BaseModel): @@ -42,7 +47,7 @@ def validate_examples(cls, v): return v -class TestBasedEval(Component): +class TestBasedEval(Component[str]): """Each TestBasedEval represents a single unit test.""" def __init__( @@ -76,6 +81,10 @@ def format_for_llm(self) -> TemplateRepresentation: template_order=["*"], ) + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. Returns string value for now.""" + return computed.value if computed.value is not None else "" + def set_judge_context( self, input_text: str, prediction: str, targets_for_input: list[str] ): diff --git a/pyproject.toml b/pyproject.toml index 1f6cdb67..2bad08c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,7 @@ combine-as-imports = true split-on-trailing-comma = false [tool.codespell] -ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge' +ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge,Strat' check-filenames = true check-hidden = false regex = "(? None: super().__init__() @@ -22,6 +28,9 @@ def format_for_llm(self) -> TemplateRepresentation: obj=self, args={"arg": None}, tools={self.tool1.__name__: self.tool1} ) + def _parse(self, computed: ModelOutputThunk) -> str: + return "" + class FakeToolComponentWithExtraTool(FakeToolComponent): def __init__(self) -> None: diff --git a/test/stdlib_basics/test_base.py b/test/stdlib_basics/test_base.py index e19c6adc..c1184e64 100644 --- a/test/stdlib_basics/test_base.py +++ b/test/stdlib_basics/test_base.py @@ -1,5 +1,7 @@ +from typing import Any import pytest -from mellea.stdlib.base import CBlock, Component +from mellea.stdlib.base import CBlock, Component, ModelOutputThunk +from mellea.stdlib.chat import Message def test_cblock(): @@ -16,16 +18,51 @@ def test_cblpock_meta(): def test_component(): - class _ClosuredComponent(Component): + class _ClosuredComponent(Component[str]): def parts(self): return [] def format_for_llm(self) -> str: return "" + def _parse(self, computed: ModelOutputThunk) -> str: + return "" + c = _ClosuredComponent() assert len(c.parts()) == 0 +def test_parse(): + class _ChatResponse: + def __init__(self, msg: Message) -> None: + self.message = msg + + source = Message(role="user", content="source message") + result = ModelOutputThunk( + value="result value", + meta={ + "chat_response": _ChatResponse( + Message(role="assistant", content="assistant reply") + ) + }, + ) + + result.parsed_repr = source.parse(result) + assert isinstance(result.parsed_repr, Message), ( + "result's parsed repr should be a message when meta includes a chat_response" + ) + assert result.parsed_repr.role == "assistant", ( + "result's parsed repr role should be assistant" + ) + assert result.parsed_repr.content == "assistant reply" + + result = ModelOutputThunk(value="result value") + result.parsed_repr = source.parse(result) + assert isinstance(result.parsed_repr, Message), ( + "result's parsed repr should be a message when source component is a message" + ) + assert result.parsed_repr.content == "result value" + + if __name__ == "__main__": pytest.main([__file__]) 
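The new `test/stdlib_basics/test_component_typing.py` file below exercises the typed-component flow end to end. As a quick orientation before reading those tests, here is a minimal sketch of how a caller defines and uses a typed component under the `Component[T]` / `_parse` / `parsed_repr` API introduced in this patch; the `WordCountComp` name, its prompt text, and the hard-coded thunk value are illustrative only and are not part of the change.

```python
# Minimal sketch (not part of the patch) of the typed-Component flow that the
# tests below assert. Assumes the Component[T] / _parse / parsed_repr API added
# in this patch; WordCountComp and its prompt are hypothetical examples.
from mellea.stdlib.base import CBlock, Component, ModelOutputThunk


class WordCountComp(Component[int]):
    """Asks the model for a word count and parses the reply into an int."""

    def __init__(self, text: str) -> None:
        self.text = text

    def parts(self) -> list[Component | CBlock]:
        return []

    def format_for_llm(self) -> str:
        return f"Reply with only the number of words in: {self.text}"

    def _parse(self, computed: ModelOutputThunk) -> int:
        # Fall back to -1 when the model produced nothing usable.
        if computed.value is None:
            return -1
        try:
            return int(computed.value.strip())
        except ValueError:
            return -1


# Component.parse() wraps _parse; backends and sampling strategies assign its
# return value to the thunk's parsed_repr, so downstream code sees an int.
comp = WordCountComp("one two three")
thunk = ModelOutputThunk[int](value="3")
assert comp.parse(thunk) == 3
```

When such a component is passed through `session.act(...)` or `mfuncs.act(...)`, the returned `ModelOutputThunk` carries the parsed value in `parsed_repr`, which is exactly what the tests in the new file check.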
diff --git a/test/stdlib_basics/test_component_typing.py b/test/stdlib_basics/test_component_typing.py new file mode 100644 index 00000000..188146b7 --- /dev/null +++ b/test/stdlib_basics/test_component_typing.py @@ -0,0 +1,221 @@ +"""Tests for checking the functionality of typed components, model output thunks and sampling results.""" + +import pytest +from typing import get_args +from mellea import start_session +from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B +from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.base import ( + CBlock, + ChatContext, + Component, + ComponentParseError, + Context, + ModelOutputThunk, + SimpleContext, +) +from mellea.stdlib.chat import Message +from mellea.stdlib.instruction import Instruction +from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.sampling.base import BaseSamplingStrategy +from mellea.stdlib.session import MelleaSession + +import mellea.stdlib.functional as mfuncs + + +class FloatComp(Component[float]): + def __init__(self, value: str) -> None: + self.value = value + + def parts(self) -> list[Component | CBlock]: + return [] + + def format_for_llm(self) -> str: + return self.value + + def _parse(self, computed: ModelOutputThunk) -> float: + if computed.value is None: + return -1 + return float(computed.value) + + +class IntComp(FloatComp, Component[int]): + def _parse(self, computed: ModelOutputThunk) -> int: + if computed.value is None: + return -1 + try: + return int(computed.value) + except: + return -2 + +class ExceptionRaisingComp(Component[int]): + def parts(self) -> list[Component | CBlock]: + return [] + + def format_for_llm(self) -> str: + return "" + + def _parse(self, computed: ModelOutputThunk) -> int: + raise ValueError("random error") + + +@pytest.fixture(scope="module") +def backend(gh_run: int): + """Shared backend.""" + if gh_run == 1: + return OllamaModelBackend( + model_id=IBM_GRANITE_4_MICRO_3B.ollama_name # type: ignore + ) + else: + return OllamaModelBackend(model_id="granite3.3:8b") + + +@pytest.fixture(scope="module") +def session(backend) -> MelleaSession: + return MelleaSession(backend=backend, ctx=SimpleContext()) + + +def test_mot_init_typing(): + mot = ModelOutputThunk[float](value="1") + assert hasattr(mot, "__orig_class__"), ( + f"mots are generics and should have this field" + ) + assert get_args(mot.__orig_class__)[0] == float, ( + f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" + ) # type: ignore + + unknown_mot = ModelOutputThunk(value="2") + assert not hasattr(unknown_mot, "__orig_class__"), ( + f"unknown mots / mots with no type defined at instantiate don't have this attribute" + ) + + +def test_simple_component_parsing(): + fc = FloatComp(value="generate a float") + mot = ModelOutputThunk[float](value="1") + assert fc.parse(mot) == 1 + assert isinstance(fc.parse(mot), float) + + +def test_subclassed_component_parsing(): + ic = IntComp("generate an int") + mot = ModelOutputThunk[float](value="1") + assert ic.parse(mot) == 1 + +def test_component_parsing_fails(): + erc = ExceptionRaisingComp() + mot = ModelOutputThunk[float](value="1") + + with pytest.raises(ComponentParseError): + _ = erc.parse(mot) == 1 + +def test_incorrect_type_override(): + with pytest.raises(TypeError): + instruction = Instruction[int](description="this is an instruction") # type: ignore + + +# Marking as qualitative for now since there's so much generation required for this. 
+@pytest.mark.qualitative +async def test_generating(session): + m = session + ic = IntComp("generate an int") + + out, _ = mfuncs.act(ic, context=ChatContext(), backend=m.backend, strategy=None) + assert isinstance(out.parsed_repr, int) + + # `out` typed as ModelOutputThunk[str] + out, _ = await m.backend.generate_from_context( + CBlock("Say Hello!"), ctx=ChatContext() + ) + await out.avalue() + assert isinstance(out.parsed_repr, str) + + # `out` typed as ModelOutputThunk[float] + out, _ = await m.backend.generate_from_context(ic, ctx=ChatContext()) + await out.avalue() + assert isinstance(out.parsed_repr, int) + + # `out` typed as ModelOutputThunk[float | str] + out = await m.backend.generate_from_raw([ic, CBlock("")], ctx=ChatContext()) + for result in out: + await result.avalue() + assert isinstance(out[0].parsed_repr, int) + assert isinstance(out[1].parsed_repr, str) + + # `out` typed as ModelOutputThunk[float] + out = await m.backend.generate_from_raw([ic, ic], ctx=ChatContext()) + for result in out: + await result.avalue() + assert isinstance(result.parsed_repr, int) + + # `out` typed as ModelOutputThunk[str] + out = await m.backend.generate_from_raw([CBlock("")], ctx=ChatContext()) + for result in out: + await result.avalue() + assert isinstance(result.parsed_repr, str) + + +@pytest.mark.qualitative +def test_message_typing(session): + m = session + user_message = Message("user", "Hello!") + response = m.act(user_message) + assert response.parsed_repr is not None + assert isinstance(response.parsed_repr, Message) + + second_response = m.act(response.parsed_repr) + assert second_response.parsed_repr is not None + assert isinstance(second_response.parsed_repr, Message) + + +@pytest.mark.qualitative +async def test_generating_with_sampling(session): + m = session + m = start_session() + + class CustomSamplingStrat(BaseSamplingStrategy): + @staticmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> int: + return len(sampled_actions) - 1 + + @staticmethod + def repair( + old_ctx: Context, + new_ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> tuple[Component, Context]: + return Instruction("print another number 100 greater"), old_ctx + + css = CustomSamplingStrat(loop_budget=3) + out = await css.sample( + action=IntComp("2000"), + context=ChatContext(), + backend=m.backend, + requirements=[ + Requirement( + None, validation_fn=lambda x: ValidationResult(False), check_only=True + ) + ], + ) + + # Even though the intermediate actions are Instructions, the parsed_reprs at each stage + # are ints. 
+ for result in out.sample_generations: + assert isinstance(result.parsed_repr, int), ( + "model output thunks should have the correct parsed_repr type" + ) + + for action in out.sample_actions[1:]: + assert isinstance(action, Instruction), ( + "repair strategy should force repaired actions to be Instructions" + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/stdlib_basics/test_mify.py b/test/stdlib_basics/test_mify.py index 1cae3ed1..00ead004 100644 --- a/test/stdlib_basics/test_mify.py +++ b/test/stdlib_basics/test_mify.py @@ -8,8 +8,10 @@ def test_protocol_adherence(): mobj = MObject() - assert isinstance(mobj, MObjectProtocol) - assert isinstance(mobj, Component) + assert isinstance(mobj, MObjectProtocol), ( + "mobject doesn't conform to mobject protocol" + ) + assert isinstance(mobj, Component), "mobject doesn't conform to component protocol" @mify class _Customer: @@ -17,8 +19,12 @@ def __init__(self, name: str) -> None: self.name = name mified_class = _Customer("Jake") - assert isinstance(mified_class, MObjectProtocol) - assert isinstance(mified_class, Component) + assert isinstance(mified_class, MObjectProtocol), ( + "mified class doesn't conform to mobject protocol" + ) + assert isinstance(mified_class, Component), ( + "mified class doesn't conform to component protocol" + ) class _Customer2: def __init__(self, name: str) -> None: @@ -26,8 +32,12 @@ def __init__(self, name: str) -> None: c = _Customer2("jake") mify(c) - assert isinstance(c, MObjectProtocol) - assert isinstance(c, Component) + assert isinstance(c, MObjectProtocol), ( + "mified object doesn't conform to mobject protocol" + ) + assert isinstance(c, Component), ( + "mified object doesn't conform to component protocol" + ) def test_mify_class(): diff --git a/test/test_formatter_baseclasses.py b/test/test_formatter_baseclasses.py index 9bc35575..2ba232a4 100644 --- a/test/test_formatter_baseclasses.py +++ b/test/test_formatter_baseclasses.py @@ -2,7 +2,7 @@ import os import sys import tempfile -from typing import List, Optional +from typing import Any, List, Optional import pytest @@ -73,44 +73,6 @@ def test_to_chat_messages(tf: TemplateFormatter): ) -def test_parse(tf: TemplateFormatter): - class _ChatResponse: - def __init__(self, msg: Message) -> None: - self.message = msg - - source = Message(role="user", content="source message") - result = ModelOutputThunk( - value="result value", - meta={ - "chat_response": _ChatResponse( - Message(role="assistant", content="assistant reply") - ) - }, - ) - tf.parse(source, result) - assert isinstance(result.parsed_repr, Message), ( - "result's parsed repr should be a message when meta includes a chat_response" - ) - assert result.parsed_repr.role == "assistant", ( - "result's parsed repr role should be assistant" - ) - assert result.parsed_repr.content == "assistant reply" - - result = ModelOutputThunk(value="result value") - tf.parse(source, result) - assert isinstance(result.parsed_repr, Message), ( - "result's parsed repr should be a message when source component is a message" - ) - assert result.parsed_repr.content == "result value" - - cblock_source = CBlock("cblock source") - result = ModelOutputThunk(value="result value from cblock") - tf.parse(cblock_source, result) - assert result.parsed_repr is result, ( - "parse should set the result object to result.parsed_repr if it's not parsing a message" - ) - - def test_custom_template_string(tf: TemplateFormatter): class _TemplInstruction(Instruction): def format_for_llm(self) -> 
TemplateRepresentation: @@ -203,13 +165,16 @@ def test_no_module(tf: TemplateFormatter): def test_no_template(tf: TemplateFormatter): - class _NoTemplate(Component): + class _NoTemplate(Component[str]): def parts(self) -> List[Component | CBlock]: return [] def format_for_llm(self) -> TemplateRepresentation: return TemplateRepresentation(self, {}) + def _parse(self, computed: ModelOutputThunk) -> str: + return "" + with pytest.raises(Exception): tf._load_template(_NoTemplate().format_for_llm()) @@ -280,8 +245,8 @@ def test_custom_component_external_package(tf: TemplateFormatter): Ensures template loading works for custom components defined in other packages.""" new_component_content = """ -from mellea.stdlib.base import Component, TemplateRepresentation -class NewComponent(Component): +from mellea.stdlib.base import Component, TemplateRepresentation, ModelOutputThunk +class NewComponent(Component[str]): def parts(self): raise NotImplementedError( "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" @@ -292,6 +257,9 @@ def format_for_llm(self) -> dict: self, {"text": "template arg version of new component"} ) + + def _parse(self, computed: ModelOutputThunk) -> str: + return "" """ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: