Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/python-merge-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI__APIKEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
ANTHROPIC_CHAT_MODEL_ID: ${{ vars.ANTHROPIC_CHAT_MODEL_ID }}
OPENAI_EMBEDDING_MODEL_ID: ${{ vars.OPENAI__EMBEDDINGMODELID }}
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__CHATDEPLOYMENTNAME }}
AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__RESPONSESDEPLOYMENTNAME }}
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__EMBEDDINGDEPLOYMENTNAME }}
AZURE_OPENAI_ENDPOINT: ${{ vars.AZUREOPENAI__ENDPOINT }}
LOCAL_MCP_URL: ${{ vars.LOCAL_MCP__URL }}
# For Azure Functions integration tests
Expand Down
385 changes: 385 additions & 0 deletions docs/features/vector-stores-and-embeddings/README.md

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions python/packages/core/agent_framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from ._agents import Agent, BaseAgent, RawAgent, SupportsAgentRun
from ._clients import (
BaseChatClient,
BaseEmbeddingClient,
SupportsChatGetResponse,
SupportsCodeInterpreterTool,
SupportsFileSearchTool,
SupportsGetEmbeddings,
SupportsImageGenerationTool,
SupportsMCPTool,
SupportsWebSearchTool,
Expand Down Expand Up @@ -82,9 +84,14 @@
ChatResponseUpdate,
Content,
ContinuationToken,
Embedding,
EmbeddingGenerationOptions,
EmbeddingInputT,
EmbeddingT,
FinalT,
FinishReason,
FinishReasonLiteral,
GeneratedEmbeddings,
Message,
OuterFinalT,
OuterUpdateT,
Expand Down Expand Up @@ -201,6 +208,7 @@
"BaseAgent",
"BaseChatClient",
"BaseContextProvider",
"BaseEmbeddingClient",
"BaseHistoryProvider",
"Case",
"ChatAndFunctionMiddlewareTypes",
Expand All @@ -218,6 +226,10 @@
"Edge",
"EdgeCondition",
"EdgeDuplicationError",
"Embedding",
"EmbeddingGenerationOptions",
"EmbeddingInputT",
"EmbeddingT",
"Executor",
"FanInEdgeGroup",
"FanOutEdgeGroup",
Expand All @@ -232,6 +244,7 @@
"FunctionMiddleware",
"FunctionMiddlewareTypes",
"FunctionTool",
"GeneratedEmbeddings",
"GraphConnectivityError",
"InMemoryCheckpointStorage",
"InMemoryHistoryProvider",
Expand Down Expand Up @@ -261,6 +274,7 @@
"SupportsChatGetResponse",
"SupportsCodeInterpreterTool",
"SupportsFileSearchTool",
"SupportsGetEmbeddings",
"SupportsImageGenerationTool",
"SupportsMCPTool",
"SupportsWebSearchTool",
Expand Down
142 changes: 141 additions & 1 deletion python/packages/core/agent_framework/_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
from ._types import (
ChatResponse,
ChatResponseUpdate,
EmbeddingGenerationOptions,
EmbeddingInputT,
EmbeddingT,
GeneratedEmbeddings,
Message,
ResponseStream,
validate_chat_options,
Expand All @@ -56,7 +60,6 @@

InputT = TypeVar("InputT", contravariant=True)

EmbeddingT = TypeVar("EmbeddingT")
BaseChatClientT = TypeVar("BaseChatClientT", bound="BaseChatClient")

logger = logging.getLogger("agent_framework")
Expand Down Expand Up @@ -660,3 +663,140 @@ def get_file_search_tool(**kwargs: Any) -> Any:


# endregion


# region SupportsGetEmbeddings Protocol

# Contravariant/covariant TypeVars for the Protocol
EmbeddingInputContraT = TypeVar(
"EmbeddingInputContraT",
default="str",
contravariant=True,
)
EmbeddingCoT = TypeVar(
"EmbeddingCoT",
default="list[float]",
)
EmbeddingOptionsContraT = TypeVar(
"EmbeddingOptionsContraT",
bound=TypedDict, # type: ignore[valid-type]
default="EmbeddingGenerationOptions",
contravariant=True,
)


@runtime_checkable
class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingCoT, EmbeddingOptionsContraT]):
"""Protocol for an embedding client that can generate embeddings.

This protocol enables duck-typing for embedding generation. Any class that
implements ``get_embeddings`` with a compatible signature satisfies this protocol.

Generic over the input type (defaults to ``str``), output embedding type
(defaults to ``list[float]``), and options type.

Examples:
.. code-block:: python

from agent_framework import SupportsGetEmbeddings


async def use_embeddings(client: SupportsGetEmbeddings) -> None:
result = await client.get_embeddings(["Hello, world!"])
for embedding in result:
print(embedding.vector)
"""

additional_properties: dict[str, Any]

def get_embeddings(
self,
values: Sequence[EmbeddingInputContraT],
*,
options: EmbeddingOptionsContraT | None = None,
) -> Awaitable[GeneratedEmbeddings[EmbeddingCoT]]:
"""Generate embeddings for the given values.

Args:
values: The values to generate embeddings for.
options: Optional embedding generation options.

Returns:
Generated embeddings with metadata.
"""
...


# endregion


# region BaseEmbeddingClient

# Covariant for the BaseEmbeddingClient
EmbeddingOptionsCoT = TypeVar(
"EmbeddingOptionsCoT",
bound=TypedDict, # type: ignore[valid-type]
default="EmbeddingGenerationOptions",
covariant=True,
)


class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsCoT]):
"""Abstract base class for embedding clients.

Subclasses implement ``get_embeddings`` to provide the actual
embedding generation logic.

Generic over the input type (defaults to ``str``), output embedding type
(defaults to ``list[float]``), and options type.

Examples:
.. code-block:: python

from agent_framework import BaseEmbeddingClient, Embedding, GeneratedEmbeddings
from collections.abc import Sequence


class CustomEmbeddingClient(BaseEmbeddingClient):
async def get_embeddings(self, values, *, options=None):
return GeneratedEmbeddings([Embedding(vector=[0.1, 0.2, 0.3]) for _ in values])
"""

OTEL_PROVIDER_NAME: ClassVar[str] = "unknown"
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"}

def __init__(
self,
*,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize a BaseEmbeddingClient instance.

Args:
additional_properties: Additional properties to pass to the client.
**kwargs: Additional keyword arguments passed to parent classes (for MRO).
"""
self.additional_properties = additional_properties or {}
super().__init__(**kwargs)

@abstractmethod
async def get_embeddings(
self,
values: Sequence[EmbeddingInputT],
*,
options: EmbeddingOptionsCoT | None = None,
) -> GeneratedEmbeddings[EmbeddingT]:
"""Generate embeddings for the given values.

Args:
values: The values to generate embeddings for.
options: Optional embedding generation options.

Returns:
Generated embeddings with metadata.
"""
...


# endregion
141 changes: 139 additions & 2 deletions python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
import re
import sys
from asyncio import iscoroutine
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableMapping, Sequence
from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Iterable,
Mapping,
MutableMapping,
Sequence,
)
from copy import deepcopy
from datetime import datetime
from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewType, cast, overload

from pydantic import BaseModel
Expand Down Expand Up @@ -272,7 +282,8 @@ def _serialize_value(value: Any, exclude_none: bool) -> Any:

# region Constants and types
_T = TypeVar("_T")
EmbeddingT = TypeVar("EmbeddingT")
EmbeddingT = TypeVar("EmbeddingT", default="list[float]")
EmbeddingInputT = TypeVar("EmbeddingInputT", default="str")
ChatResponseT = TypeVar("ChatResponseT", bound="ChatResponse")
ToolModeT = TypeVar("ToolModeT", bound="ToolMode")
AgentResponseT = TypeVar("AgentResponseT", bound="AgentResponse")
Expand Down Expand Up @@ -3158,3 +3169,129 @@ def merge_chat_options(
result[key] = value

return result


# region Embedding Types


class EmbeddingGenerationOptions(TypedDict, total=False):
"""Common request settings for embedding generation.

All fields are optional (total=False) to allow partial specification.
Provider-specific TypedDicts extend this with additional options.

Examples:
.. code-block:: python

from agent_framework import EmbeddingGenerationOptions

options: EmbeddingGenerationOptions = {
"model_id": "text-embedding-3-small",
"dimensions": 1536,
}
"""

model_id: str
dimensions: int


class Embedding(Generic[EmbeddingT]):
"""A single embedding vector with metadata.

Generic over the embedding vector type, e.g. ``Embedding[list[float]]``,
``Embedding[list[int]]``, or ``Embedding[bytes]``.

Args:
vector: The embedding vector data.
model_id: The model used to generate this embedding.
dimensions: Explicit dimension count (computed from vector length if omitted).
created_at: Timestamp of when the embedding was generated.
additional_properties: Additional metadata.

Examples:
.. code-block:: python

from agent_framework import Embedding

embedding = Embedding(
vector=[0.1, 0.2, 0.3],
model_id="text-embedding-3-small",
)
assert embedding.dimensions == 3
"""

def __init__(
self,
vector: EmbeddingT,
*,
model_id: str | None = None,
dimensions: int | None = None,
created_at: datetime | None = None,
additional_properties: dict[str, Any] | None = None,
) -> None:
self.vector = vector
self._dimensions = dimensions
self.model_id = model_id
self.created_at = created_at
self.additional_properties = additional_properties or {}

@property
def dimensions(self) -> int | None:
"""Return the number of dimensions in the embedding vector.

Uses the explicitly provided value if set, otherwise computes from vector length.
"""
if self._dimensions is not None:
return self._dimensions
if isinstance(self.vector, (list, tuple, bytes)):
return len(self.vector)
return None


EmbeddingOptionsT = TypeVar(
"EmbeddingOptionsT",
bound=TypedDict, # type: ignore[valid-type]
default="EmbeddingGenerationOptions",
)


class GeneratedEmbeddings(list[Embedding[EmbeddingT]], Generic[EmbeddingT, EmbeddingOptionsT]):
"""A list of generated embeddings with usage metadata.

Extends list for direct iteration and indexing.
Generic over both the embedding vector type and the options type used for generation.

Args:
embeddings: Sequence of Embedding objects.
options: The options used to generate these embeddings.
usage: Token usage information (e.g. prompt_tokens, total_tokens).
additional_properties: Additional metadata.

Examples:
.. code-block:: python

from agent_framework import Embedding, GeneratedEmbeddings

embeddings = GeneratedEmbeddings(
[Embedding(vector=[0.1, 0.2]), Embedding(vector=[0.3, 0.4])],
usage={"prompt_tokens": 10, "total_tokens": 10},
)
assert len(embeddings) == 2
assert embeddings.usage["prompt_tokens"] == 10
"""

def __init__(
self,
embeddings: Iterable[Embedding[EmbeddingT]] | None = None,
*,
options: EmbeddingOptionsT | None = None,
usage: dict[str, Any] | None = None,
additional_properties: dict[str, Any] | None = None,
) -> None:
super().__init__(embeddings or [])
self.options = options
self.usage = usage
self.additional_properties = additional_properties or {}


# endregion
Loading