2 changes: 2 additions & 0 deletions docs/tutorial.md
@@ -334,6 +334,8 @@ Mellea's core abstraction is called a `Component`. A `Component` is a structured

Components are composite data structures; that is, a `Component` can be made up of many other parts. Each of those parts is either a `CBlock` or another `Component`. `CBlock`s, or "content blocks", are an atomic unit of text or data. CBlocks hold raw text (or sometimes parsed representations) and can be used as leaves in the Component DAG.

Components can also specify an expected output type along with a parse function to extract that type from the LLM's output. By default, this type is a string, but by defining a Component's expected type, you can get type hinting for outputs in the standard library.

Backends are the engines that actually run the LLM. Backends consume Components, format them, pass the formatted input to an LLM, and return model outputs, which are then parsed back into CBlocks or Components.

During the course of an interaction with an LLM, several Components and CBlocks may be created. Logic for handling this trace of interactions is provided by a `Context` object. Some book-keeping needs to be done in order for Contexts to appropriately handle a trace of Components and CBlocks. The `MelleaSession` class, which is created by `mellea.start_session()`, does this book-keeping as a simple wrapper around Contexts and Backends.
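The typed-Component paragraph above is abstract, so a minimal sketch of the idea follows. The `BulletList` class, its constructor, and its parse signature are illustrative assumptions (any other members `Component` may require are elided); only the notions of an expected output type and a per-Component `parse` hook come from the tutorial text and this change.

```python
# Hypothetical typed Component sketching the tutorial paragraph above.
# Only "expected output type + parse function" comes from the docs/PR;
# the class itself is illustrative and may omit required members.
from mellea.stdlib.base import Component, ModelOutputThunk


class BulletList(Component[list[str]]):
    """Asks a model for a bulleted summary; the expected output type is list[str]."""

    def __init__(self, document: str):
        self.document = document

    def parse(self, result: ModelOutputThunk) -> list[str]:
        # Extract the declared output type from the raw model text.
        # (The exact parse signature is an assumption.)
        text = result.value or ""
        return [line.lstrip("-* ").strip() for line in text.splitlines() if line.strip()]
```

With the expected type declared this way, standard-library helpers can surface `list[str]` instead of a plain string, which is the type-hinting benefit the tutorial mentions.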
45 changes: 37 additions & 8 deletions mellea/backends/__init__.py
@@ -5,17 +5,24 @@
import abc
import asyncio
import itertools
from typing import TypeVar
from collections.abc import Sequence
from typing import Any, overload

import pydantic
import typing_extensions

from mellea.backends.model_ids import ModelIdentifier
from mellea.backends.types import ModelOption
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk

BaseModelSubclass = TypeVar(
"BaseModelSubclass", bound=pydantic.BaseModel
from mellea.stdlib.base import C, CBlock, Component, Context, ModelOutputThunk

# Necessary to define a type that supports `None` so that the BaseModelSubclass
# can have a default value. Otherwise, Python complains about typed-components
# since types with default values must come after those without default values in
# function signatures (which is incompatible with our function parameter formatting).
pydantic_model_or_none = pydantic.BaseModel | None
BaseModelSubclass = typing_extensions.TypeVar(
"BaseModelSubclass", bound=pydantic_model_or_none, default=None
) # must be a subclass of BaseModel


@@ -39,13 +46,13 @@ def __init__(
@abc.abstractmethod
async def generate_from_context(
self,
action: Component | CBlock,
action: Component[C] | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> tuple[ModelOutputThunk, Context]:
) -> tuple[ModelOutputThunk[C], Context]:
"""Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request.

Args:
@@ -60,10 +67,32 @@ async def generate_from_context(
"""
...

@overload
async def generate_from_raw(
self,
actions: list[Component[C]],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk[C]]: ...

@overload
async def generate_from_raw(
self,
actions: list[Component[C] | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk[C | str]]: ...

@abc.abstractmethod
async def generate_from_raw(
self,
actions: list[Component | CBlock],
actions: Sequence[Component[C] | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
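The comment above the new `BaseModelSubclass` definition explains why the type variable needs a `None`-compatible bound and a default. Below is a small self-contained sketch of that `typing_extensions.TypeVar` pattern; the names `FormatT`, `Answer`, and `describe` are illustrative and not part of mellea.

```python
# Standalone sketch of a TypeVar with a None-admitting bound and a default,
# mirroring the BaseModelSubclass definition above. Illustrative names only.
import pydantic
import typing_extensions

pydantic_model_or_none = pydantic.BaseModel | None
FormatT = typing_extensions.TypeVar(
    "FormatT", bound=pydantic_model_or_none, default=None
)


class Answer(pydantic.BaseModel):
    text: str


def describe(format: type[FormatT] | None = None) -> str:
    # Because the TypeVar declares default=None, calling describe() with no
    # argument is well-typed: FormatT resolves to None rather than being left
    # unsolved, so `format` can keep its None default in generic signatures.
    return "unconstrained" if format is None else format.__name__


describe()        # FormatT -> None (the declared default)
describe(Answer)  # FormatT -> Answer
```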
6 changes: 3 additions & 3 deletions mellea/backends/dummy.py
@@ -1,7 +1,7 @@
"""This module holds shim backends used for smoke tests."""

from mellea.backends import Backend, BaseModelSubclass
from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk
from mellea.stdlib.base import C, CBlock, Component, Context, ModelOutputThunk


class DummyBackend(Backend):
@@ -18,13 +18,13 @@ def __init__(self, responses: list[str] | None):

async def generate_from_context(
self,
action: Component | CBlock,
action: Component[C] | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> tuple[ModelOutputThunk, Context]:
) -> tuple[ModelOutputThunk[C], Context]:
"""See constructor for an exmplanation of how DummyBackends work."""
assert format is None, "The DummyBackend does not support constrained decoding."
if self.responses is None:
77 changes: 7 additions & 70 deletions mellea/backends/formatter.py
@@ -20,7 +20,7 @@
ModelOutputThunk,
TemplateRepresentation,
)
from mellea.stdlib.chat import Message, ToolMessage
from mellea.stdlib.chat import Message


class Formatter(abc.ABC):
@@ -31,16 +31,6 @@ def print(self, c: Component | CBlock) -> str:
"""Renders a component for input to a model."""
...

@abc.abstractmethod
def parse(
self, source_component: Component | CBlock, result: ModelOutputThunk
) -> ModelOutputThunk:
"""Parses the output from a model and sets the parsed_repr of the result ModelOutputThunk.

Returns the ModelOutputThunk that was passed in.
"""
...

def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]:
"""Helper method that converts a linearized chat history into a list of messages. The purpose of this helper is to prepare a sequence of Messages for input to a chat endpoint."""

@@ -58,7 +48,12 @@ def _to_msg(c: Component | CBlock) -> Message:
) # This is already entailed by c.is_computed(); the line is included here to satisfy the type-checker.

if c.parsed_repr is not None:
c = c.parsed_repr # This might be a message.
if isinstance(c.parsed_repr, Component):
# Only use the parsed_repr if it's something that we know how to print.
c = c.parsed_repr # This might be a message.
else:
# Otherwise, explicitly stringify it.
c = Message(role=role, content=str(c.parsed_repr))
else:
c = Message(role=role, content=c.value) # type: ignore

@@ -104,64 +99,6 @@ def __init__(
# Key: obj.__class__.__name___ -> Value: jinja2.Template
self._template_cache = SimpleLRUCache(10) if self._use_template_cache else None

def parse(
self, source_component: Component | CBlock, result: ModelOutputThunk
) -> ModelOutputThunk:
"""Parses the output and updates the result's parsed_repr."""
parsed = self._parse(source_component=source_component, result=result)
result.parsed_repr = parsed
return result

def _parse(
self, source_component: Component | CBlock, result: ModelOutputThunk
) -> CBlock | Component:
"""Parses the output from a model."""
if result.tool_calls is not None:
# A tool was successfully requested.
# Assistant responses for tool calling differ by backend. For the default formatter,
# we put all of the function data into the content field in the same format we received it.

# Chat backends should provide an openai-like object in the _meta chat response, which we can use to properly format this output.
if "chat_response" in result._meta:
# Ollama.
return Message(
role=result._meta["chat_response"].message.role,
content=str(result._meta["chat_response"].message.tool_calls),
)
elif "oai_chat_response" in result._meta:
# OpenAI and Watsonx.
return Message(
role=result._meta["oai_chat_response"]["message"]["role"],
content=str(
result._meta["oai_chat_response"]["message"].get(
"tool_calls", []
)
),
)
else:
# HuggingFace (or others). There are no guarantees on how the model represented the function calls.
# Output it in the same format we received the tool call request.
assert result.value is not None
return Message(role="assistant", content=result.value)

if type(source_component) is Message:
if "chat_response" in result._meta:
# chat backends should provide an openai-like object in the _meta chat response, which we can use to properly format this output.
return Message(
role=result._meta["chat_response"].message.role,
content=result._meta["chat_response"].message.content,
)
elif "oai_chat_response" in result._meta:
return Message(
role=result._meta["oai_chat_response"]["message"]["role"],
content=result._meta["oai_chat_response"]["message"]["content"],
)
else:
assert result.value is not None
return Message(role="assistant", content=result.value)
else:
return result

def _stringify(
self,
c: (
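With the `Formatter.parse`/`_parse` machinery removed here, parsing responsibility moves to the action itself: backends set `parsed_repr` directly, as the `huggingface.py` hunk below does with `action.parse(result)`. The helper below is sketched only to illustrate that pattern and is not part of the PR.

```python
# Illustrative helper (not part of the PR) showing where parsing lives after
# this change: Components parse their own output; raw CBlock actions fall
# back to the generated text.
from mellea.stdlib.base import CBlock, Component, ModelOutputThunk


def attach_parsed_repr(action: Component | CBlock, result: ModelOutputThunk) -> None:
    result.parsed_repr = (
        action.parse(result) if isinstance(action, Component) else result.value
    )
```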
57 changes: 39 additions & 18 deletions mellea/backends/huggingface.py
@@ -13,9 +13,9 @@
import inspect
import json
import threading
from collections.abc import Callable, Coroutine
from collections.abc import Callable, Coroutine, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload

import granite_common
import outlines
@@ -42,7 +42,6 @@
LocalHFAdapter,
get_adapter_for_intrinsic,
)
from mellea.backends.adapters.catalog import fetch_intrinsic_metadata
from mellea.backends.cache import Cache, SimpleLRUCache
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
from mellea.backends.model_ids import ModelIdentifier
@@ -57,13 +56,13 @@
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
C,
CBlock,
Component,
Context,
GenerateLog,
GenerateType,
ModelOutputThunk,
ModelToolCall,
)
from mellea.stdlib.chat import Message
from mellea.stdlib.intrinsics.intrinsic import Intrinsic
@@ -201,13 +200,13 @@ def _make_dc_cache(self, toks, **model_options):

async def generate_from_context(
self,
action: Component | CBlock,
action: Component[C] | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
):
) -> tuple[ModelOutputThunk[C], Context]:
"""Generate using the huggingface model."""
await self.do_generate_walk(action)

@@ -558,13 +557,13 @@ def _make_merged_kv_cache(

async def _generate_from_context_with_kv_cache(
self,
action: Component | CBlock,
action: Component[C] | CBlock,
ctx: Context,
*,
_format: type[BaseModelSubclass] | None = None,
model_options: dict[str, Any],
tool_calls: bool = False,
) -> ModelOutputThunk:
) -> ModelOutputThunk[C]:
# Construct input.
# If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
# Otherwise, we will linearize the context and treat it as a raw input.
@@ -598,7 +597,7 @@ async def _generate_from_context_with_kv_cache(
if _format:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = _format.model_json_schema()
schema: dict[str, Any] = _format.model_json_schema() # type: ignore
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
@@ -765,7 +764,7 @@ async def _generate_from_context_standard(
if _format:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = _format.model_json_schema()
schema: dict[str, Any] = _format.model_json_schema() # type: ignore
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
@@ -931,8 +930,6 @@ async def post_processing(
"ModelOutputThunks should have their model_opts assigned during generation"
)

self.formatter.parse(mot._action, mot)

# Generate the log for this ModelOutputThunk.
generate_log = GenerateLog()
generate_log.prompt = conversation
@@ -951,17 +948,39 @@

mot._generate_log = generate_log

@overload
async def generate_from_raw(
self,
actions: list[Component | CBlock],
actions: list[Component[C]],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk[C]]: ...

@overload
async def generate_from_raw(
self,
actions: list[Component[C] | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk[C | str]]: ...

async def generate_from_raw(
self,
actions: Sequence[Component[C] | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generate using the completions api. Gives the input provided to the model without templating."""
await self.do_generate_walks(actions)
await self.do_generate_walks(list(actions))

if tool_calls:
FancyLogger.get_logger().warning(
@@ -992,7 +1011,7 @@ async def generate_from_raw(
if format:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = format.model_json_schema()
schema: dict[str, Any] = format.model_json_schema() # type: ignore
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
@@ -1046,8 +1065,10 @@ async def generate_from_raw(
}
},
)

self.formatter.parse(actions[i], result)
action = actions[i]
result.parsed_repr = (
action.parse(result) if isinstance(action, Component) else result.value
)

generate_log = GenerateLog()
generate_log.prompt = self.formatter.print(actions[i])
@@ -1056,7 +1077,7 @@
generate_log.date = datetime.datetime.now()
generate_log.model_output = decoded_result
generate_log.extra = {"format": format, "seed": seed}
generate_log.action = actions[i]
generate_log.action = action

result._generate_log = generate_log
results.append(result)
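For callers, the intent of the `generate_from_raw` overloads above is that a homogeneous list of typed Components yields typed thunks, while a mixed Component/CBlock list widens to `C | str`. The sketch below illustrates that expectation; `Haiku`, its parse logic, the `CBlock(...)` construction, and the `demo` wrapper are assumptions for illustration only.

```python
# Illustrative-only usage of the generate_from_raw overloads above.
# Haiku, the CBlock construction, and demo() are assumptions, and Haiku may
# omit members that Component actually requires.
from mellea.backends import Backend
from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk


class Haiku(Component[str]):
    """Hypothetical component whose expected output type is str."""

    def parse(self, result: ModelOutputThunk) -> str:
        return str(result.value)


async def demo(backend: Backend, ctx: Context) -> None:
    # Homogeneous Component[str] actions: the first overload is meant to apply,
    # so each element should be typed as ModelOutputThunk[str].
    poems = await backend.generate_from_raw([Haiku(), Haiku()], ctx)

    # Mixed Component/CBlock actions: the second overload is meant to apply,
    # widening each element toward ModelOutputThunk[str | str] (i.e. str here).
    mixed = await backend.generate_from_raw([Haiku(), CBlock("raw prompt text")], ctx)
    print(len(poems), len(mixed))
```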