Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions atomicmemory/providers/atomicmemory/async_handle_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
scope_to_fields,
scope_to_query_pairs,
strip_agent_scope,
strip_read_filters,
)

Route = Callable[[str], str]
Expand Down Expand Up @@ -93,7 +94,7 @@ async def expand(self, refs: list[str], scope: MemoryScope) -> list[AtomicMemory
method="POST",
json=body,
)
echoed = strip_agent_scope(scope)
echoed = strip_read_filters(scope)
return [_to_atomic_memory(m, echoed) for m in raw.get("memories", [])]

async def list(
Expand All @@ -103,7 +104,7 @@ async def list(
) -> AtomicMemoryListResultPage:
opts = _coerce_list_options(options)
_assert_list_options_scope_compat(scope, opts)
pairs: list[tuple[str, str]] = scope_to_query_pairs(scope)
pairs: list[tuple[str, str]] = scope_to_query_pairs(scope, include_thread=True)
if opts.limit is not None:
pairs.append(("limit", str(opts.limit)))
if opts.offset is not None:
Expand All @@ -127,14 +128,16 @@ async def list(
)

async def get(self, id: str, scope: MemoryScope) -> AtomicMemoryMemory | None:
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}")
unfiltered_scope = strip_read_filters(scope)
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}")
raw = await afetch_json_or_none(self._client, self._http, path)
if raw is None:
return None
return _to_atomic_memory(raw, strip_agent_scope(scope))
return _to_atomic_memory(raw, unfiltered_scope)

async def delete(self, id: str, scope: MemoryScope) -> None:
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}")
unfiltered_scope = strip_read_filters(scope)
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}")
try:
await afetch_void(self._client, self._http, path, method="DELETE")
except ProviderError as exc:
Expand All @@ -152,7 +155,7 @@ async def _post_ingest(
) -> AtomicMemoryIngestResult:
assert_scope_allows_visibility(scope, input.visibility)
body: dict[str, Any] = {
**scope_to_fields(scope),
**scope_to_fields(scope, include_thread=True),
"conversation": input.conversation,
"source_site": input.source_site,
"source_url": input.source_url or "",
Expand All @@ -173,7 +176,7 @@ async def _post_search(
scope: MemoryScope,
) -> AtomicMemorySearchResultPage:
body: dict[str, Any] = {
**scope_to_fields(scope, include_agent_scope=True),
**scope_to_fields(scope, include_agent_scope=True, include_thread=True),
"query": request.query,
}
if request.limit is not None:
Expand Down
3 changes: 2 additions & 1 deletion atomicmemory/providers/atomicmemory/async_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from atomicmemory.providers.atomicmemory.path import normalize_api_version
from atomicmemory.providers.atomicmemory.provider import (
_build_ingest_body,
_build_list_path,
_build_package_body,
_build_search_body,
_qs,
Expand Down Expand Up @@ -137,7 +138,7 @@ async def do_delete(self, ref: MemoryRef) -> None:
async def do_list(self, request: ListRequest) -> ListResultPage:
offset = int(request.cursor) if request.cursor else 0
limit = request.limit if request.limit is not None else 20
path = self._route(f"/memories/list?user_id={_qs(request.scope.user)}&limit={limit}&offset={offset}")
path = self._route(_build_list_path(request.scope, limit, offset))
raw = await afetch_json(self._require_client(), self._http_options, path)
memories = [to_memory(m, request.scope) for m in raw.get("memories", [])]
next_offset = offset + len(memories)
Expand Down
2 changes: 2 additions & 0 deletions atomicmemory/providers/atomicmemory/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class UserScope(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
kind: Literal["user"] = "user"
user_id: str = Field(alias="userId")
thread: str | None = None


class WorkspaceScope(BaseModel):
Expand All @@ -45,6 +46,7 @@ class WorkspaceScope(BaseModel):
user_id: str = Field(alias="userId")
workspace_id: str = Field(alias="workspaceId")
agent_id: str = Field(alias="agentId")
thread: str | None = None
agent_scope: AgentScope | None = Field(default=None, alias="agentScope")


Expand Down
36 changes: 28 additions & 8 deletions atomicmemory/providers/atomicmemory/handle_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
scope_to_fields,
scope_to_query_pairs,
strip_agent_scope,
strip_read_filters,
)

Route = Callable[[str], str]
Expand Down Expand Up @@ -97,7 +98,7 @@ def expand(self, refs: list[str], scope: MemoryScope) -> list[AtomicMemoryMemory
method="POST",
json=body,
)
echoed = strip_agent_scope(scope)
echoed = strip_read_filters(scope)
return [_to_atomic_memory(m, echoed) for m in raw.get("memories", [])]

def list(
Expand All @@ -107,7 +108,7 @@ def list(
) -> AtomicMemoryListResultPage:
opts = _coerce_list_options(options)
_assert_list_options_scope_compat(scope, opts)
pairs: list[tuple[str, str]] = scope_to_query_pairs(scope)
pairs: list[tuple[str, str]] = scope_to_query_pairs(scope, include_thread=True)
if opts.limit is not None:
pairs.append(("limit", str(opts.limit)))
if opts.offset is not None:
Expand All @@ -131,14 +132,16 @@ def list(
)

def get(self, id: str, scope: MemoryScope) -> AtomicMemoryMemory | None:
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}")
unfiltered_scope = strip_read_filters(scope)
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}")
raw = fetch_json_or_none(self._client, self._http, path)
if raw is None:
return None
return _to_atomic_memory(raw, strip_agent_scope(scope))
return _to_atomic_memory(raw, unfiltered_scope)

def delete(self, id: str, scope: MemoryScope) -> None:
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}")
unfiltered_scope = strip_read_filters(scope)
path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}")
try:
fetch_void(self._client, self._http, path, method="DELETE")
except ProviderError as exc:
Expand All @@ -160,7 +163,7 @@ def _post_ingest(
) -> AtomicMemoryIngestResult:
assert_scope_allows_visibility(scope, input.visibility)
body: dict[str, Any] = {
**scope_to_fields(scope),
**scope_to_fields(scope, include_thread=True),
"conversation": input.conversation,
"source_site": input.source_site,
"source_url": input.source_url or "",
Expand All @@ -181,7 +184,7 @@ def _post_search(
scope: MemoryScope,
) -> AtomicMemorySearchResultPage:
body: dict[str, Any] = {
**scope_to_fields(scope, include_agent_scope=True),
**scope_to_fields(scope, include_agent_scope=True, include_thread=True),
"query": request.query,
}
if request.limit is not None:
Expand Down Expand Up @@ -263,7 +266,7 @@ def _to_atomic_memory(raw: dict[str, Any], scope: MemoryScope) -> AtomicMemoryMe
payload: dict[str, Any] = {
"id": raw["id"],
"content": raw.get("content") or "",
"scope": scope,
"scope": _build_memory_scope(raw, scope),
"created_at": _parse_iso(raw.get("created_at")) or _now_utc(),
}
if raw.get("updated_at"):
Expand All @@ -274,6 +277,23 @@ def _to_atomic_memory(raw: dict[str, Any], scope: MemoryScope) -> AtomicMemoryMe
return AtomicMemoryMemory.model_validate(payload)


def _build_memory_scope(raw: dict[str, Any], requested_scope: MemoryScope) -> MemoryScope:
"""Validate and project Core ``session_id`` back into namespace scope."""
session_id = raw.get("session_id")
if requested_scope.thread is not None:
if not session_id:
raise ValueError(
"atomicmemory provider: backend response missing required `session_id` for thread-scoped request"
)
if session_id != requested_scope.thread:
raise ValueError(
"atomicmemory provider: backend response `session_id` did not match requested thread scope"
)
if not session_id:
return requested_scope
return requested_scope.model_copy(update={"thread": session_id})


def _to_atomic_search_result(raw: dict[str, Any], scope: MemoryScope) -> AtomicMemorySearchResult:
similarity = _coalesce(raw.get("semantic_similarity"), raw.get("similarity"))
ranking_score = _coalesce(raw.get("ranking_score"), raw.get("score"))
Expand Down
26 changes: 25 additions & 1 deletion atomicmemory/providers/atomicmemory/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,37 @@ def to_memory(raw: dict[str, Any], scope: Scope) -> Memory:
return Memory(
id=raw["id"],
content=raw["content"],
scope=scope,
scope=_build_scope(raw, scope),
created_at=created_at,
provenance=_build_provenance(raw),
metadata=_build_metadata(raw),
)


def _build_scope(raw: dict[str, Any], scope: Scope) -> Scope:
"""Merge backend-projected scope fields and validate scoped reads."""
namespace = raw.get("namespace")
session_id = raw.get("session_id")
if scope.namespace is not None and namespace is not None and namespace != scope.namespace:
raise ValueError("atomicmemory provider: backend response `namespace` did not match requested namespace scope")
if scope.thread is not None:
if not session_id:
raise ValueError(
"atomicmemory provider: backend response missing required `session_id` for thread-scoped request"
)
if session_id != scope.thread:
raise ValueError(
"atomicmemory provider: backend response `session_id` did not match requested thread scope"
)

updates: dict[str, Any] = {}
if namespace:
updates["namespace"] = namespace
if session_id:
updates["thread"] = session_id
return scope.model_copy(update=updates)


def _build_provenance(raw: dict[str, Any]) -> Provenance | None:
fields: dict[str, Any] = {}
if "source_site" in raw and raw["source_site"] is not None:
Expand Down
21 changes: 19 additions & 2 deletions atomicmemory/providers/atomicmemory/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from datetime import datetime
from typing import Any
from urllib.parse import quote
from urllib.parse import quote, urlencode

import httpx

Expand All @@ -31,6 +31,7 @@
MemoryVersion,
PackageFormat,
PackageRequest,
Scope,
SearchRequest,
SearchResult,
SearchResultPage,
Expand Down Expand Up @@ -129,7 +130,7 @@ def do_delete(self, ref: MemoryRef) -> None:
def do_list(self, request: ListRequest) -> ListResultPage:
offset = int(request.cursor) if request.cursor else 0
limit = request.limit if request.limit is not None else 20
path = self._route(f"/memories/list?user_id={_qs(request.scope.user)}&limit={limit}&offset={offset}")
path = self._route(_build_list_path(request.scope, limit, offset))
raw = fetch_json(self._require_client(), self._http_options, path)
memories = [to_memory(m, request.scope) for m in raw.get("memories", [])]
next_offset = offset + len(memories)
Expand Down Expand Up @@ -264,6 +265,8 @@ def _build_ingest_body(input: IngestInput) -> dict[str, Any]:
"source_site": input.provenance.source if input.provenance and input.provenance.source else "sdk",
"source_url": input.provenance.source_url if input.provenance and input.provenance.source_url else "",
}
if input.scope.thread is not None:
body["session_id"] = input.scope.thread
if input.mode == "verbatim":
body["skip_extraction"] = True
if input.metadata:
Expand All @@ -282,6 +285,8 @@ def _build_search_body(request: SearchRequest) -> dict[str, Any]:
body["threshold"] = request.threshold
if request.scope.namespace is not None:
body["namespace_scope"] = request.scope.namespace
if request.scope.thread is not None:
body["session_id"] = request.scope.thread
return body


Expand All @@ -298,3 +303,15 @@ def _build_package_body(request: PackageRequest) -> dict[str, Any]:
def _qs(value: str | None) -> str:
"""URL-encode a query-string value; empty string when falsy."""
return quote(value, safe="") if value else ""


def _build_list_path(scope: Scope, limit: int, offset: int) -> str:
"""Build the Core list path, including optional thread scope."""
pairs = [
("user_id", scope.user or ""),
("limit", str(limit)),
("offset", str(offset)),
]
if scope.thread is not None:
pairs.append(("session_id", scope.thread))
return f"/memories/list?{urlencode(pairs)}"
33 changes: 28 additions & 5 deletions atomicmemory/providers/atomicmemory/scope_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from typing import Any

from atomicmemory.core.errors import ValidationError
from atomicmemory.providers.atomicmemory.handle import MemoryScope, WorkspaceScope
from atomicmemory.providers.atomicmemory.handle import MemoryScope, UserScope, WorkspaceScope


def scope_to_fields(
scope: MemoryScope,
*,
include_agent_scope: bool = False,
include_thread: bool = False,
) -> dict[str, Any]:
"""Translate a `MemoryScope` to wire-format request fields.

Expand All @@ -26,28 +27,36 @@ def scope_to_fields(
include_agent_scope: Emit ``agent_scope`` on the wire. Defaults
to ``False``; only the search routes opt in (core ignores
``agent_scope`` on expand/list/get/delete).
include_thread: Emit ``session_id`` on routes Core honors:
ingest, search, and list.

Returns:
A dict with ``user_id`` always set, plus ``workspace_id`` /
``agent_id`` (and optionally ``agent_scope``) for workspace
scopes.
"""
if not isinstance(scope, WorkspaceScope):
return {"user_id": scope.user_id}
fields: dict[str, Any] = {
user_fields: dict[str, Any] = {"user_id": scope.user_id}
if include_thread and scope.thread is not None:
user_fields["session_id"] = scope.thread
return user_fields
workspace_fields: dict[str, Any] = {
"user_id": scope.user_id,
"workspace_id": scope.workspace_id,
"agent_id": scope.agent_id,
}
if include_agent_scope and scope.agent_scope is not None:
fields["agent_scope"] = scope.agent_scope
return fields
workspace_fields["agent_scope"] = scope.agent_scope
if include_thread and scope.thread is not None:
workspace_fields["session_id"] = scope.thread
return workspace_fields


def scope_to_query_pairs(
scope: MemoryScope,
*,
include_agent_scope: bool = False,
include_thread: bool = False,
) -> list[tuple[str, str]]:
"""Translate a scope to ``[(key, value)]`` pairs for query strings.

Expand All @@ -66,6 +75,8 @@ def scope_to_query_pairs(
pairs.extend(("agent_scope", v) for v in value)
else:
pairs.append(("agent_scope", value))
if include_thread and scope.thread is not None:
pairs.append(("session_id", scope.thread))
return pairs


Expand All @@ -91,6 +102,18 @@ def strip_agent_scope(scope: MemoryScope) -> MemoryScope:
"""
if not isinstance(scope, WorkspaceScope):
return scope
return WorkspaceScope(
user_id=scope.user_id,
workspace_id=scope.workspace_id,
agent_id=scope.agent_id,
thread=scope.thread,
)


def strip_read_filters(scope: MemoryScope) -> MemoryScope:
"""Drop filters the target route did not apply before echoing scope."""
if not isinstance(scope, WorkspaceScope):
return UserScope(user_id=scope.user_id)
return WorkspaceScope(
user_id=scope.user_id,
workspace_id=scope.workspace_id,
Expand Down
Loading
Loading