Skip to content

Conversation

@jakelorocco
Copy link
Contributor

@jakelorocco jakelorocco commented Jan 9, 2026

Addresses: #236 and #291

For quick understanding, I created a temporary example file in another branch that highlights most use cases for the typing / most scenarios: https://github.com/generative-computing/mellea/blob/jal/typing-example/typing-examples.py

Changes:
Makes Component, ModelOutputThunk, and SamplingResult into Generics. These class can now take a type and our code will understand the relationship between this Type T and all these classes when calling functions.

Components also now specify a parse function that takes a ModelOutputThunk and parses a value of type T from it. This value gets set as the parsed_repr for that ModelOutputThunk.

The general flow is:

  • Component[T] -> generate -> ModelOutputThunk[T] -> type(ModelOutputThunk.parsed_repr) == T
  • Componet[T] -> generate(return_sampling_results=True) -> SamplingResult[T].result == ModelOutputThunk[T]

Things show as Component[Any] or ModelOutputThunk[Any] when used for parameters, meaning that by default, functions can take an type of Component / ModelOutputThunk without needing the Type [T] type hints to be added.

Had to make changes to existing code:

  • default component parsing methods / return types to strings
  • modify generative slots to utilize the new parsing function
  • remove the current formatter's parsing method

There's some open questions about tool calling that I'm leaving for when we revamp that section of our code base. We may eventually need some generic parsing code that always runs for those cases.

In order to support typed Components, there's a couple of tradeoffs that must be made. Bolded options are ones I opted for.

  • default component type:
    • background
      • by default, our components will return strings for the parsed_repr value of the model output thunk
      • as a result, we have to decide where to define this default value
    • tradeoff
      • have Component (default str) and ComponentType (no default) where ComponentType is used for parameterization and Component is used for class instantiation, OR
      • -> force users to explicitly declare Component[str] during class instantiation
  • CBlock typing
    • background
      • our backends can generate from CBlocks
      • CBlocks exist as simple contianers for strings and basic data
    • tradeoff
      • have CBlocks be typed; this requires making them generics and requires some sort of subclassing system to enable parsing functions for each type
      • -> have CBlocks always create ModelOutputThunks with str parsed_repr values which requires overloading our backend generate from raw functions

Testing:
Added new tests and all current tests pass.

Examples:

Details
import asyncio
from typing import Any, get_args
from mellea import start_session
from mellea.backends.ollama import OllamaModelBackend
from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk, Component
from mellea.stdlib.chat import Message
from mellea.stdlib.genslot import generative
from mellea.stdlib.instruction import Instruction
from mellea.stdlib.requirement import Requirement, ValidationResult,
from mellea import  generative
from mellea.stdlib.sampling import RejectionSamplingStrategy
import mellea.stdlib.functional as mfuncs
from mellea.stdlib.sampling.base import BaseSamplingStrategy

# 1. Works by default with model output thunks.
mot = ModelOutputThunk[float](value="1")
assert hasattr(mot, "__orig_class__"), f"mots are generics and should have this field"
assert get_args(mot.__orig_class__)[0] == float, f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" # type: ignore

unknown_mot = ModelOutputThunk(value="2")
assert not hasattr(unknown_mot, "__orig_class__"), f"unknown mots / mots with no type defined at instantiate don't have this attribute"

# 2. The output parse type works.
class FloatComp(Component[float]):
    def __init__(self, value: str) -> None:
        self.value = value

    def parts(self) -> list[Component | CBlock]:
        return []

    def format_for_llm(self) -> str:
        return self.value

    def parse(self, computed: ModelOutputThunk) -> float:
        if computed.value is None:
            return -1
        return float(computed.value)

fc = FloatComp(value="generate a float")
assert fc.parse(mot) == 1

# 3. We can subclass them too as long as the types are covariant.
class IntComp(FloatComp, Component[int]):
    def parse(self, computed: ModelOutputThunk) -> int:
        if computed.value is None:
            return -1
        try:
            return int(computed.value)
        except:
            return -2

ic = IntComp("generate an int")
assert ic.parse(mot) == 1

# 4. We can't override the generic type for component subclasses outside of the
#    class definition.
try:
    instruction = Instruction[int](description="this is an instruction") # type: ignore
    assert False, "previous line should have raised a TypeError exception"
except TypeError as e:
    # This is expected. Just confirming it's not a generic class.
    assert "not a generic class" in str(e)
except Exception:
    assert False, "code in the try block should raise a TypeError exception"

# 5. Test in context of generation
m = start_session(ctx=ChatContext().add(CBlock("goodbye")))
out, _ = mfuncs.act(ic, context=ChatContext(), backend=m.backend)
print(out.parsed_repr)

# `out` typed as ModelOutputThunk[str]
out = m.backend.generate_from_context(CBlock(""), ctx=ChatContext())

# `out` typed as ModelOutputThunk[float]
out = m.backend.generate_from_context(ic, ctx=ChatContext())

# `out` typed as ModelOutputThunk[float | str]
out = m.backend.generate_from_raw([ic, CBlock("")], ctx=ChatContext())
# `out` typed as ModelOutputThunk[float]
out = m.backend.generate_from_raw([ic, ic], ctx=ChatContext())
# `out` typed as ModelOutputThunk[str]
out = m.backend.generate_from_raw([CBlock("")], ctx=ChatContext())

# 6. Components that return Components work correctly.
class CompWithComp(Component[Instruction]):
    def __init__(self) -> None:
        super().__init__()

    def parse(self, computed: ModelOutputThunk) -> Instruction:
        return Instruction()

    def format_for_llm(self) -> str:
        return ""

    def parts(self) -> list[Component[str] | CBlock]:
        return []

# typed as ModelOutputThunk[Instruction]
mot_with_comp_type, _ = mfuncs.act(action=CompWithComp(), context=ChatContext(), backend=m.backend)
assert mot_with_comp_type.parsed_repr is not None

# typed as ModelOutputThunk[str]
mot_using_previous_parsed, _ = mfuncs.act(action=mot_with_comp_type.parsed_repr, context=ChatContext(), backend=m.backend)


# 7. Individual backends are get typed correctly.
ob = OllamaModelBackend()
mot = ob.generate_from_context(CBlock(""), ctx=m.ctx)
mot = ob.generate_from_raw([ic, CBlock("")], ctx=m.ctx)

# 8. Example with messages.
user_message = Message("user", "Hello!")
response = m.act(user_message)
assert response.parsed_repr is not None
second_response = m.act(response.parsed_repr)

# 9. Sampling strategies are correctly typed
strat = RejectionSamplingStrategy()

async def sampling_test():
    # typed as SamplingResult[float]
    sampling_result = await strat.sample(ic, context=m.ctx, backend=m.backend, requirements=None)

    # typed as ModelOutputThunk[float]
    sampling_result.result
    # typed as list[ModelOutputThunk[float]]
    sampling_result.sample_generations

    # typed as list[Component[Any]]
    # NOTE: We can't make any guarantees about what action ended up being used to generate a result.
    sampling_result.sample_actions

# 10. Works when returning sampling results from a session / functional level.
# typed as SamplingResult[float]
results = m.act(ic, return_sampling_results=True)
results = mfuncs.act(ic, context=m.ctx, backend=m.backend, return_sampling_results=True)

# typed as SamplingResult[str]
results = m.instruct("hello", return_sampling_results=True)
results = mfuncs.instruct("Hello", context=m.ctx, backend=m.backend, return_sampling_results=True)

# 11. Works with Genslots as well.
@generative
def test(val1: int) -> bool:
    ...

# typed as bool
out = test(m=m, val1=1)


# 12. Test that sampling strategies with repair strats return the correct parsed_repr.
async def sampling_return_type():
    m = start_session()
    class CustomSamplingStrat(BaseSamplingStrategy):
        @staticmethod
        def select_from_failure(sampled_actions: list[Component[Any]], sampled_results: list[ModelOutputThunk[Any]], sampled_val: list[list[tuple[Requirement, ValidationResult]]]) -> int:
            return len(sampled_actions) - 1
        
        @staticmethod
        def repair(old_ctx: Context, new_ctx: Context, past_actions: list[Component[Any]], past_results: list[ModelOutputThunk[Any]], past_val: list[list[tuple[Requirement, ValidationResult]]]) -> tuple[Component[Any], Context]:
            return Instruction("print another number 100 greater"), old_ctx

    css = CustomSamplingStrat(loop_budget=3)
    out = await css.sample(
        action=IntComp("2000"),
        context=ChatContext(),
        backend=m.backend,
        requirements=[
            Requirement(None, validation_fn=lambda x: ValidationResult(False), check_only=True)
        ]
    )

    # Even though the intermediate actions are Instructions, the parsed_reprs at each stage
    # are ints.
    for result in out.sample_generations:
        assert isinstance(result.parsed_repr, int), "model output thunks should have the correct parsed_repr type"

    for action in out.sample_actions[1:]:
        assert isinstance(action, Instruction), "repair strategy should force repaired actions to be Instructions"

asyncio.run(sampling_return_type())

# 13. Random functions that have components / model output thunks in them.
def rand_comp(action: Component):
    ...
rand_comp(ic)

def rand_mot(mot: ModelOutputThunk):
    ...
rand_mot(ModelOutputThunk[int](""))

@mergify
Copy link

mergify bot commented Jan 9, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🟢 Enforce conventional commit

Wonderful, this rule succeeded.

Make sure that we follow https://www.conventionalcommits.org/en/v1.0.0/

  • title ~= ^(fix|feat|docs|style|refactor|perf|test|build|ci|chore|revert|release)(?:\(.+\))?:

@jakelorocco jakelorocco force-pushed the jal/typed-components-changes branch from 9ab2e7a to 6154c65 Compare January 9, 2026 22:28
@jakelorocco jakelorocco marked this pull request as ready for review January 9, 2026 22:29
@jakelorocco jakelorocco merged commit 2eb689d into main Jan 12, 2026
1 of 4 checks passed
@jakelorocco jakelorocco deleted the jal/typed-components-changes branch January 12, 2026 15:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Extend generative slot args and output types to support any subtype of Component

2 participants