diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 367b486cd..e934d3a19 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -45,6 +45,33 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: return metadata +def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]: + """ + Flatten the 'info' field in metadata to the top level. + + If metadata contains an 'info' field that is a dictionary, all its key-value pairs + will be moved to the top level of metadata, and the 'info' field will be removed. + + Args: + metadata: Dictionary that may contain an 'info' field + + Returns: + Dictionary with 'info' fields flattened to top level + + Example: + Input: {"user_id": "xxx", "info": {"A": "value1", "B": "value2"}} + Output: {"user_id": "xxx", "A": "value1", "B": "value2"} + """ + if "info" in metadata and isinstance(metadata["info"], dict): + # Copy info fields to top level + info_dict = metadata.pop("info") + for key, value in info_dict.items(): + # Only add if key doesn't already exist at top level (to avoid overwriting) + if key not in metadata: + metadata[key] = value + return metadata + + class Neo4jGraphDB(BaseGraphDB): """Neo4j-based implementation of a graph memory store.""" @@ -170,6 +197,9 @@ def remove_oldest_memory( def add_node( self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None ) -> None: + logger.info(f"[add_node] metadata: {metadata},info: {metadata.get('info')}") + print(f"[add_node] metadata: {metadata},info: {metadata.get('info')}") + user_name = user_name if user_name else self.config.user_name if not self.config.use_multi_db and (self.config.user_name or user_name): metadata["user_name"] = user_name @@ -177,6 +207,9 @@ def add_node( # Safely process metadata metadata = _prepare_node_metadata(metadata) + # Flatten info fields to top level (for Neo4j flat structure) + metadata = _flatten_info_fields(metadata) + # Merge node and set 
metadata created_at = metadata.pop("created_at") updated_at = metadata.pop("updated_at") @@ -661,6 +694,8 @@ def search_by_embedding( threshold: float | None = None, search_filter: dict | None = None, user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -695,8 +730,21 @@ def search_by_embedding( where_clauses.append("node.memory_type = $scope") if status: where_clauses.append("node.status = $status") - if not self.config.use_multi_db and (self.config.user_name or user_name): - where_clauses.append("node.user_name = $user_name") + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + node_alias="node", + ) + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") # Add search_filter conditions if search_filter: @@ -704,6 +752,14 @@ def search_by_embedding( param_name = f"filter_{key}" where_clauses.append(f"node.{key} = ${param_name}") + # Build filter conditions using common method + filter_conditions, filter_params = self._build_filter_conditions_cypher( + filter=filter, + param_counter_start=0, + node_alias="node", + ) + where_clauses.extend(filter_conditions) + where_clause = "" if where_clauses: where_clause = "WHERE " + " AND ".join(where_clauses) @@ -721,18 +777,25 @@ def search_by_embedding( parameters["scope"] = scope if status: parameters["status"] = status - if not self.config.use_multi_db and (self.config.user_name or user_name): - if kwargs.get("cube_name"): - parameters["user_name"] = kwargs["cube_name"] - else: - parameters["user_name"] = user_name - # Add search_filter 
parameters + # Add user_name and knowledgebase_ids parameters using common method + parameters.update(user_name_params) + + # Handle cube_name override for user_name + if kwargs.get("cube_name"): + parameters["user_name"] = kwargs["cube_name"] + if search_filter: for key, value in search_filter.items(): param_name = f"filter_{key}" parameters[param_name] = value + # Add filter parameters + if filter_params: + parameters.update(filter_params) + + logger.info(f"[search_by_embedding] query: {query},parameters: {parameters}") + print(f"[search_by_embedding] query: {query},parameters: {parameters}") with self.driver.session(database=self.db_name) as session: result = session.run(query, parameters) records = [{"id": record["id"], "score": record["score"]} for record in result] @@ -744,7 +807,11 @@ def search_by_embedding( return records def get_by_metadata( - self, filters: list[dict[str, Any]], user_name: str | None = None + self, + filters: list[dict[str, Any]], + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, ) -> list[str]: """ TODO: @@ -770,6 +837,12 @@ def get_by_metadata( - Supports structured querying such as tag/category/importance/time filtering. - Can be used for faceted recall or prefiltering before embedding rerank. 
""" + logger.info( + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + ) + print( + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + ) user_name = user_name if user_name else self.config.user_name where_clauses = [] params = {} @@ -802,12 +875,43 @@ def get_by_metadata( else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and (self.config.user_name or user_name): - where_clauses.append("n.user_name = $user_name") - params["user_name"] = user_name + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + node_alias="n", + ) + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Build filter conditions using common method + filter_conditions, filter_params = self._build_filter_conditions_cypher( + filter=filter, + param_counter_start=len(filters), # Start from len(filters) to avoid conflicts + node_alias="n", + ) + where_clauses.extend(filter_conditions) + + where_str = " AND ".join(where_clauses) if where_clauses else "" + if where_str: + query = f"MATCH (n:Memory) WHERE {where_str} RETURN n.id AS id" + else: + query = "MATCH (n:Memory) RETURN n.id AS id" + + # Add user_name and knowledgebase_ids parameters using common method + params.update(user_name_params) - where_str = " AND ".join(where_clauses) - query = f"MATCH (n:Memory) WHERE {where_str} RETURN n.id AS id" + # Merge filter parameters + if filter_params: + params.update(filter_params) + logger.info(f"[get_by_metadata] query: 
{query},params: {params}") + print(f"[get_by_metadata] query: {query},params: {params}") with self.driver.session(database=self.db_name) as session: result = session.run(query, params) @@ -999,33 +1103,78 @@ def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> No target_id=edge["target"], ) - def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]: + def get_all_memory_items( + self, + scope: str, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: """ Retrieve all memory items of a specific memory_type. Args: scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. + filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. + Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]} Returns: Returns: list[dict]: Full list of memory items under this scope. """ + logger.info( + f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + ) + print( + f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + ) + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") - where_clause = "WHERE n.memory_type = $scope" + where_clauses = ["n.memory_type = $scope"] params = {"scope": scope} - if not self.config.use_multi_db and (self.config.user_name or user_name): - where_clause += " AND n.user_name = $user_name" - params["user_name"] = user_name + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + 
node_alias="n", + ) + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Build filter conditions using common method + filter_conditions, filter_params = self._build_filter_conditions_cypher( + filter=filter, + param_counter_start=0, + node_alias="n", + ) + where_clauses.extend(filter_conditions) + + where_clause = "WHERE " + " AND ".join(where_clauses) + + # Add user_name and knowledgebase_ids parameters using common method + params.update(user_name_params) + + # Add filter parameters + if filter_params: + params.update(filter_params) query = f""" MATCH (n:Memory) {where_clause} RETURN n """ + logger.info(f"[get_all_memory_items] query: {query},params: {params}") + print(f"[get_all_memory_items] query: {query},params: {params}") with self.driver.session(database=self.db_name) as session: results = session.run(query, params) @@ -1183,6 +1332,159 @@ def _index_exists(self, index_name: str) -> bool: return True return False + def _build_user_name_and_kb_ids_conditions_cypher( + self, + user_name: str | None, + knowledgebase_ids: list[str] | None, + default_user_name: str | None = None, + node_alias: str = "node", + ) -> tuple[list[str], dict[str, Any]]: + """ + Build user_name and knowledgebase_ids conditions for Cypher queries. 
+ + Args: + user_name: User name for filtering + knowledgebase_ids: List of knowledgebase IDs + default_user_name: Default user name from config if user_name is None + node_alias: Node alias in Cypher query (default: "node" or "n") + + Returns: + Tuple of (condition_strings_list, parameters_dict) + """ + user_name_conditions = [] + params = {} + effective_user_name = user_name if user_name else default_user_name + + # Only add user_name condition if not using multi-db mode + if not self.config.use_multi_db and (self.config.user_name or effective_user_name): + user_name_conditions.append(f"{node_alias}.user_name = $user_name") + params["user_name"] = effective_user_name + + # Add knowledgebase_ids conditions (checking user_name field in the data) + if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: + for idx, kb_id in enumerate(knowledgebase_ids): + if isinstance(kb_id, str): + param_name = f"kb_id_{idx}" + user_name_conditions.append(f"{node_alias}.user_name = ${param_name}") + params[param_name] = kb_id + + return user_name_conditions, params + + def _build_filter_conditions_cypher( + self, + filter: dict | None, + param_counter_start: int = 0, + node_alias: str = "node", + ) -> tuple[list[str], dict[str, Any]]: + """ + Build filter conditions for Cypher queries. + + Args: + filter: Filter dictionary with "or" or "and" logic + param_counter_start: Starting value for parameter counter (to avoid conflicts) + node_alias: Node alias in Cypher query (default: "node" or "n") + + Returns: + Tuple of (condition_strings_list, parameters_dict) + """ + filter_conditions = [] + filter_params = {} + + if not filter: + return filter_conditions, filter_params + + def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[str, dict]: + """Build a WHERE condition for a single filter item. 
+ + Args: + condition_dict: A dict like {"id": "xxx"} or {"A": "xxx"} or {"created_at": {"gt": "2025-11-01"}} + param_counter: List to track parameter counter for unique param names + + Returns: + Tuple of (condition_string, parameters_dict) + """ + condition_parts = [] + params = {} + + for key, value in condition_dict.items(): + # Check if value is a dict with comparison operators (gt, lt, gte, lte) + if isinstance(value, dict): + # Handle comparison operators: gt (greater than), lt (less than), gte (greater than or equal), lte (less than or equal) + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte"): + # Map operator to Cypher operator + cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} + cypher_op = cypher_op_map[op] + + # All fields are stored as flat properties in Neo4j + param_name = f"filter_{key}_{op}_{param_counter[0]}" + param_counter[0] += 1 + params[param_name] = op_value + + # Check if field is a date field (created_at, updated_at, etc.) 
+ # Use datetime() function for date comparisons + if key in ("created_at", "updated_at") or key.endswith("_at"): + condition_parts.append( + f"{node_alias}.{key} {cypher_op} datetime(${param_name})" + ) + else: + condition_parts.append( + f"{node_alias}.{key} {cypher_op} ${param_name}" + ) + elif op == "contains": + # Handle contains operator (for array fields like tags, sources) + param_name = f"filter_{key}_{op}_{param_counter[0]}" + param_counter[0] += 1 + params[param_name] = op_value + + # For array fields, check if element is in array + if key in ("tags", "sources"): + condition_parts.append(f"${param_name} IN {node_alias}.{key}") + else: + # For non-array fields, contains might not be applicable, but we'll treat it as IN for consistency + condition_parts.append(f"${param_name} IN {node_alias}.{key}") + elif op == "like": + # Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%') + # Neo4j uses CONTAINS for string matching + param_name = f"filter_{key}_{op}_{param_counter[0]}" + param_counter[0] += 1 + params[param_name] = op_value + condition_parts.append(f"{node_alias}.{key} CONTAINS ${param_name}") + else: + # All fields are stored as flat properties in Neo4j (simple equality) + param_name = f"filter_{key}_{param_counter[0]}" + param_counter[0] += 1 + params[param_name] = value + condition_parts.append(f"{node_alias}.{key} = ${param_name}") + + return " AND ".join(condition_parts), params + + param_counter = [param_counter_start] + + if isinstance(filter, dict): + if "or" in filter: + # OR logic: at least one condition must match + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str, params = build_filter_condition(condition, param_counter) + if condition_str: + or_conditions.append(f"({condition_str})") + filter_params.update(params) + if or_conditions: + filter_conditions.append(f"({' OR '.join(or_conditions)})") + + elif "and" in filter: + # AND logic: all conditions must match + for 
condition in filter["and"]: + if isinstance(condition, dict): + condition_str, params = build_filter_condition(condition, param_counter) + if condition_str: + filter_conditions.append(f"({condition_str})") + filter_params.update(params) + + return filter_conditions, filter_params + def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: node = node_data.copy() diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 6f7786834..ff7d5f50b 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -1,5 +1,7 @@ import json +import re +from datetime import datetime from typing import Any from memos.configs.graph_db import Neo4jGraphDBConfig @@ -143,6 +145,8 @@ def search_by_embedding( threshold: float | None = None, search_filter: dict | None = None, user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -155,6 +159,9 @@ def search_by_embedding( status (str, optional): Node status filter (e.g., 'activated', 'archived'). threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters to apply. + filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. + Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]} + knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by. Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. @@ -165,10 +172,12 @@ def search_by_embedding( - If 'status' is provided, it further filters nodes by status. - If 'threshold' is provided, only results with score >= threshold will be returned. - If 'search_filter' is provided, it applies additional metadata-based filtering. + - If 'filter' is provided, it applies complex filter conditions with AND/OR logic. 
- The returned IDs can be used to fetch full node data from Neo4j if needed. """ user_name = user_name if user_name else self.config.user_name - # Build VecDB filter + + # First, perform vector search in external vector DB vec_filter = {} if scope: vec_filter["memory_type"] = scope @@ -185,45 +194,518 @@ def search_by_embedding( vec_filter.update(search_filter) # Perform vector search - results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter) + vec_results = [] + if self.vec_db: + try: + vec_results = self.vec_db.search( + query_vector=vector, top_k=top_k, filter=vec_filter + ) + except Exception as e: + logger.warning(f"[VecDB] search failed: {e}") # Filter by threshold if threshold is not None: - results = [r for r in results if r.score is None or r.score >= threshold] + vec_results = [r for r in vec_results if r.score is None or r.score >= threshold] + + # If no filter or knowledgebase_ids provided, return vector search results directly + if not filter and not knowledgebase_ids: + return [{"id": r.id, "score": r.score} for r in vec_results] + + # Extract IDs from vector search results + vec_ids = [r.id for r in vec_results] + if not vec_ids: + return [] + + # Build WHERE clause for Neo4j filtering + where_clauses = ["n.id IN $vec_ids"] + params = {"vec_ids": vec_ids} + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + node_alias="n", + ) + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Build filter conditions using common method + filter_conditions, filter_params = self._build_filter_conditions_cypher( + filter=filter, + 
param_counter_start=0, + node_alias="n", + ) + where_clauses.extend(filter_conditions) + + where_clause = "WHERE " + " AND ".join(where_clauses) - # Return consistent format - return [{"id": r.id, "score": r.score} for r in results] + # Add user_name and knowledgebase_ids parameters using common method + params.update(user_name_params) + + # Add filter parameters + if filter_params: + params.update(filter_params) + + # Query Neo4j to filter results + query = f""" + MATCH (n:Memory) + {where_clause} + RETURN n.id AS id + """ + logger.info(f"[search_by_embedding] query: {query}, params: {params}") - def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]: + with self.driver.session(database=self.db_name) as session: + neo4j_results = session.run(query, params) + filtered_ids = {record["id"] for record in neo4j_results} + + # Filter vector results by Neo4j filtered IDs and return with scores + filtered_results = [ + {"id": r.id, "score": r.score} for r in vec_results if r.id in filtered_ids + ] + + return filtered_results + + def _normalize_date_string(self, date_str: str) -> str: + """ + Normalize date string to ISO 8601 format for Neo4j datetime() function. 
+ + Args: + date_str: Date string in various formats (e.g., "2025-09-19", "2025-09-19T00:00:00Z") + + Returns: + ISO 8601 formatted date string (e.g., "2025-09-19T00:00:00Z") + """ + if not isinstance(date_str, str): + return date_str + + # If already in ISO 8601 format with time, return as is + if "T" in date_str or date_str.endswith("Z") or "+" in date_str or "-" in date_str[-6:]: + return date_str + + # Check if it's a simple date format (YYYY-MM-DD) + date_pattern = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", date_str) + if date_pattern: + # Convert to ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ + # For "gt" (greater than), use 00:00:00 of the next day + # For "lt" (less than), use 00:00:00 of the same day + # For "gte" (greater than or equal), use 00:00:00 of the same day + # For "lte" (less than or equal), use 23:59:59.999999999 of the same day + # But we'll use 00:00:00Z as default and let the caller handle the logic + return f"{date_str}T00:00:00Z" + + # If it's already a datetime string, try to parse and reformat + try: + # Try to parse various datetime formats + dt = datetime.fromisoformat(date_str.replace("Z", "+00:00")) + return dt.isoformat().replace("+00:00", "Z") + except (ValueError, AttributeError): + # If parsing fails, return as is + return date_str + + def _build_filter_conditions_cypher( + self, + filter: dict | None, + param_counter_start: int = 0, + node_alias: str = "node", + ) -> tuple[list[str], dict[str, Any]]: + """ + Build filter conditions for Cypher queries with date normalization. + + This method extends the parent class method by normalizing date strings + to ISO 8601 format before building conditions. 
+ + Args: + filter: Filter dictionary with "or" or "and" logic + param_counter_start: Starting value for parameter counter (to avoid conflicts) + node_alias: Node alias in Cypher query (default: "node" or "n") + + Returns: + Tuple of (condition_strings_list, parameters_dict) + """ + normalized_filter = self._normalize_filter_dates(filter) if filter else filter + + # Call parent method with normalized filter + return super()._build_filter_conditions_cypher( + filter=normalized_filter, + param_counter_start=param_counter_start, + node_alias=node_alias, + ) + + def _normalize_filter_dates(self, filter: dict) -> dict: + """ + Recursively normalize date strings in filter dictionary. + + Args: + filter: Filter dictionary that may contain date strings + + Returns: + Filter dictionary with normalized date strings + """ + if not isinstance(filter, dict): + return filter + + normalized = {} + + if "and" in filter: + normalized["and"] = [ + self._normalize_condition_dates(cond) if isinstance(cond, dict) else cond + for cond in filter["and"] + ] + elif "or" in filter: + normalized["or"] = [ + self._normalize_condition_dates(cond) if isinstance(cond, dict) else cond + for cond in filter["or"] + ] + else: + # Single condition + normalized = self._normalize_condition_dates(filter) + + return normalized + + def _normalize_condition_dates(self, condition: dict) -> dict: + """ + Normalize date strings in a single condition dictionary. 
+ + Args: + condition: A condition dict like {"created_at": {"gt": "2025-09-19"}} + + Returns: + Condition dict with normalized date strings + """ + from datetime import timedelta + + normalized = {} + + for key, value in condition.items(): + # Check if this is a date field + is_date_field = key in ("created_at", "updated_at") or key.endswith("_at") + + if isinstance(value, dict): + # Handle comparison operators + normalized_value = {} + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte") and is_date_field: + # Normalize date string for date comparisons + if isinstance(op_value, str): + # Check if it's a simple date format (YYYY-MM-DD) + date_pattern = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", op_value) + if date_pattern: + try: + # Parse the date + dt = datetime.fromisoformat(op_value + "T00:00:00") + + if op == "gt": + # "gt": "2025-09-19" means > 2025-09-19 00:00:00 + # So we keep it as 2025-09-19T00:00:00Z + normalized_value[op] = dt.isoformat() + "Z" + elif op == "gte": + # "gte": "2025-09-19" means >= 2025-09-19 00:00:00 + normalized_value[op] = dt.isoformat() + "Z" + elif op == "lt": + # "lt": "2025-11-29" means < 2025-11-29 (exclude the entire day) + # So we convert to the start of the next day: 2025-11-30T00:00:00Z + # This ensures all times on 2025-11-29 are included + dt_next = dt + timedelta(days=1) + normalized_value[op] = dt_next.isoformat() + "Z" + elif op == "lte": + # "lte": "2025-11-29" means <= 2025-11-29 23:59:59.999999 + # So we convert to end of day: 2025-11-29T23:59:59.999999Z + dt_end = dt + timedelta(days=1) - timedelta(microseconds=1) + normalized_value[op] = dt_end.isoformat() + "Z" + except ValueError: + # If parsing fails, use the original normalization + normalized_value[op] = self._normalize_date_string(op_value) + else: + # Already in a more complex format, just normalize it + normalized_value[op] = self._normalize_date_string(op_value) + else: + normalized_value[op] = op_value + else: + normalized_value[op] = 
op_value + normalized[key] = normalized_value + else: + normalized[key] = value + + return normalized + + def get_all_memory_items( + self, + scope: str, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: """ Retrieve all memory items of a specific memory_type. Args: - scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. + scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', 'UserMemory', or 'OuterMemory'. + filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. + Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]} + knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by. + Returns: list[dict]: Full list of memory items under this scope. """ + logger.info( + f"[get_all_memory_items] scope: {scope}, filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + ) + print( + f"[get_all_memory_items] scope: {scope}, filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + ) + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name - if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory"}: + if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") - where_clause = "WHERE n.memory_type = $scope" + where_clauses = ["n.memory_type = $scope"] params = {"scope": scope} - if not self.config.use_multi_db and (self.config.user_name or user_name): - where_clause += " AND n.user_name = $user_name" - params["user_name"] = user_name + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + node_alias="n", + ) + + # Add user_name WHERE 
clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Build filter conditions using common method + filter_conditions, filter_params = self._build_filter_conditions_cypher( + filter=filter, + param_counter_start=0, + node_alias="n", + ) + where_clauses.extend(filter_conditions) + + where_clause = "WHERE " + " AND ".join(where_clauses) + + # Add user_name and knowledgebase_ids parameters using common method + params.update(user_name_params) + + # Add filter parameters + if filter_params: + params.update(filter_params) query = f""" MATCH (n:Memory) {where_clause} RETURN n """ + logger.info(f"[get_all_memory_items] query: {query}, params: {params}") + print(f"[get_all_memory_items] query: {query}, params: {params}") with self.driver.session(database=self.db_name) as session: results = session.run(query, params) return [self._parse_node(dict(record["n"])) for record in results] + def get_by_metadata( + self, + filters: list[dict[str, Any]], + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + ) -> list[str]: + """ + Retrieve node IDs that match given metadata filters. + Supports exact match. + + Args: + filters: List of filter dicts like: + [ + {"field": "key", "op": "in", "value": ["A", "B"]}, + {"field": "confidence", "op": ">=", "value": 80}, + {"field": "tags", "op": "contains", "value": "AI"}, + ... + ] + filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. + knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by user_name. + + Returns: + list[str]: Node IDs whose metadata match the filter conditions. (AND logic). + + Notes: + - Supports structured querying such as tag/category/importance/time filtering. + - Can be used for faceted recall or prefiltering before embedding rerank. 
+ """ + logger.info( + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + ) + print( + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + ) + user_name = user_name if user_name else self.config.user_name + where_clauses = [] + params = {} + + for i, f in enumerate(filters): + field = f["field"] + op = f.get("op", "=") + value = f["value"] + param_key = f"val{i}" + + # Build WHERE clause + if op == "=": + where_clauses.append(f"n.{field} = ${param_key}") + params[param_key] = value + elif op == "in": + where_clauses.append(f"n.{field} IN ${param_key}") + params[param_key] = value + elif op == "contains": + where_clauses.append(f"ANY(x IN ${param_key} WHERE x IN n.{field})") + params[param_key] = value + elif op == "starts_with": + where_clauses.append(f"n.{field} STARTS WITH ${param_key}") + params[param_key] = value + elif op == "ends_with": + where_clauses.append(f"n.{field} ENDS WITH ${param_key}") + params[param_key] = value + elif op in [">", ">=", "<", "<="]: + where_clauses.append(f"n.{field} {op} ${param_key}") + params[param_key] = value + else: + raise ValueError(f"Unsupported operator: {op}") + + # Build user_name filter with knowledgebase_ids support (OR relationship) + user_name_conditions = [] + if not self.config.use_multi_db and (self.config.user_name or user_name): + user_name_conditions.append("n.user_name = $user_name") + + # Add knowledgebase_ids conditions (checking user_name field in the data) + if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: + for idx, kb_id in enumerate(knowledgebase_ids): + if isinstance(kb_id, str): + param_name = f"kb_id_{idx}" + user_name_conditions.append(f"n.user_name = ${param_name}") + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + 
where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add filter conditions (supports "or" and "and" logic) + filter_params = {} + if filter: + # Helper function to build a single filter condition + def build_filter_condition( + condition_dict: dict, param_counter: list + ) -> tuple[str, dict]: + """Build a WHERE condition for a single filter item. + + Args: + condition_dict: A dict like {"id": "xxx"} or {"A": "xxx"} or {"created_at": {"gt": "2025-11-01"}} + param_counter: List to track parameter counter for unique param names + + Returns: + Tuple of (condition_string, parameters_dict) + """ + condition_parts = [] + filter_params_inner = {} + + for key, value in condition_dict.items(): + # Check if value is a dict with comparison operators (gt, lt, gte, lte) + if isinstance(value, dict): + # Handle comparison operators: gt (greater than), lt (less than), gte (greater than or equal), lte (less than or equal) + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte"): + # Map operator to Cypher operator + cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} + cypher_op = cypher_op_map[op] + + # All fields are stored as flat properties in Neo4j + param_name = f"filter_meta_{key}_{op}_{param_counter[0]}" + param_counter[0] += 1 + filter_params_inner[param_name] = op_value + + # Check if field is a date field (created_at, updated_at, etc.) 
+ # Use datetime() function for date comparisons + if key in ("created_at", "updated_at") or key.endswith("_at"): + condition_parts.append( + f"n.{key} {cypher_op} datetime(${param_name})" + ) + else: + condition_parts.append(f"n.{key} {cypher_op} ${param_name}") + else: + # All fields are stored as flat properties in Neo4j (simple equality) + param_name = f"filter_meta_{key}_{param_counter[0]}" + param_counter[0] += 1 + filter_params_inner[param_name] = value + condition_parts.append(f"n.{key} = ${param_name}") + + return " AND ".join(condition_parts), filter_params_inner + + # Process filter structure + param_counter = [ + len(filters) + ] # Use list to allow modification in nested function, start from len(filters) to avoid conflicts + + if isinstance(filter, dict): + if "or" in filter: + # OR logic: at least one condition must match + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str, filter_params_inner = build_filter_condition( + condition, param_counter + ) + if condition_str: + or_conditions.append(f"({condition_str})") + filter_params.update(filter_params_inner) + if or_conditions: + where_clauses.append(f"({' OR '.join(or_conditions)})") + + elif "and" in filter: + # AND logic: all conditions must match + for condition in filter["and"]: + if isinstance(condition, dict): + condition_str, filter_params_inner = build_filter_condition( + condition, param_counter + ) + if condition_str: + where_clauses.append(f"({condition_str})") + filter_params.update(filter_params_inner) + + where_str = " AND ".join(where_clauses) if where_clauses else "" + if where_str: + query = f"MATCH (n:Memory) WHERE {where_str} RETURN n.id AS id" + else: + query = "MATCH (n:Memory) RETURN n.id AS id" + + # Add user_name parameter + if not self.config.use_multi_db and (self.config.user_name or user_name): + params["user_name"] = user_name + + # Add knowledgebase_ids parameters + if knowledgebase_ids and isinstance(knowledgebase_ids, 
list) and len(knowledgebase_ids) > 0: + for idx, kb_id in enumerate(knowledgebase_ids): + if isinstance(kb_id, str): + param_name = f"kb_id_{idx}" + params[param_name] = kb_id + + # Merge filter parameters + if filter_params: + params.update(filter_params) + logger.info(f"[get_by_metadata] query: {query},params: {params}") + print(f"[get_by_metadata] query: {query},params: {params}") + + with self.driver.session(database=self.db_name) as session: + result = session.run(query, params) + return [record["id"] for record in result] + def clear(self, user_name: str | None = None) -> None: """ Clear the entire graph if the target database exists. diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index da1635296..a7e60704e 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1,5 +1,6 @@ import json import random +import textwrap from datetime import datetime from typing import Any, Literal @@ -1460,12 +1461,18 @@ def search_by_embedding( threshold: float | None = None, search_filter: dict | None = None, user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, **kwargs, ) -> list[dict]: """ Retrieve node IDs based on vector similarity using PostgreSQL vector operations. 
""" # Build WHERE clause dynamically like nebular.py + logger.info( + f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + ) + print(f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") where_clauses = [] if scope: where_clauses.append( @@ -1490,11 +1497,20 @@ def search_by_embedding( # else: # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype") """ - user_name = user_name if user_name else self.config.user_name - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, ) + # Add OR condition if we have any user_name conditions + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + # Add search_filter conditions like nebular.py if search_filter: for key, value in search_filter.items(): @@ -1507,6 +1523,10 @@ def search_by_embedding( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" # Keep original simple query structure but add dynamic WHERE clause @@ -1526,20 +1546,61 @@ def search_by_embedding( FROM t WHERE scope > 0.1; """ - params = [vector] + # Convert vector to string format for PostgreSQL vector type + # PostgreSQL vector type expects a string format like '[1,2,3]' + vector_str = 
convert_to_vector(vector) + # Use string format directly in query instead of parameterized query + # Replace %s with the vector string, but need to quote it properly + # PostgreSQL vector type needs the string to be quoted + query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)") + params = [] + + # Split query by lines and wrap long lines to prevent terminal truncation + query_lines = query.strip().split("\n") + for line in query_lines: + # Wrap lines longer than 200 characters to prevent terminal truncation + if len(line) > 200: + wrapped_lines = textwrap.wrap( + line, width=200, break_long_words=False, break_on_hyphens=False + ) + for wrapped_line in wrapped_lines: + print(wrapped_line) + else: + print(line) + + logger.info(f"[search_by_embedding] query: {query}, params: {params}") + print(f"[search_by_embedding] query: {query}, params: {params}") conn = self._get_connection() try: with conn.cursor() as cursor: - cursor.execute(query, params) + try: + # If params is empty, execute query directly without parameters + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + except Exception as e: + logger.error(f"[search_by_embedding] Error executing query: {e}") + logger.error(f"[search_by_embedding] Query length: {len(query)}") + logger.error( + f"[search_by_embedding] Params type: {type(params)}, length: {len(params)}" + ) + logger.error(f"[search_by_embedding] Query contains %s: {'%s' in query}") + raise results = cursor.fetchall() output = [] + print("=== Raw Results ===:", results) + print(f"=== Results count: {len(results)} ===") for row in results: """ polarId = row[0] # id properties = row[1] # properties # embedding = row[3] # embedding """ + if len(row) < 5: + logger.warning(f"Row has {len(row)} columns, expected 5. 
Row: {row}") + continue oldid = row[3] # old_id score = row[4] # scope id_val = str(oldid) @@ -1553,7 +1614,11 @@ def search_by_embedding( @timed def get_by_metadata( - self, filters: list[dict[str, Any]], user_name: str | None = None + self, + filters: list[dict[str, Any]], + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list | None = None, ) -> list[str]: """ Retrieve node IDs that match given metadata filters. @@ -1572,6 +1637,9 @@ def get_by_metadata( Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). """ + logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + print(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + user_name = user_name if user_name else self._get_config_value("user_name") # Build WHERE conditions for cypher query @@ -1617,16 +1685,31 @@ def get_by_metadata( where_conditions.append(f"n.{field} STARTS WITH {escaped_value}") elif op == "ends_with": where_conditions.append(f"n.{field} ENDS WITH {escaped_value}") + elif op == "like": + where_conditions.append(f"n.{field} CONTAINS {escaped_value}") elif op in [">", ">=", "<", "<="]: where_conditions.append(f"n.{field} {op} {escaped_value}") else: raise ValueError(f"Unsupported operator: {op}") - # Add user_name filter - escaped_user_name = user_name.replace("'", "''") - where_conditions.append(f"n.user_name = '{escaped_user_name}'") + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self._get_config_value("user_name"), + ) + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_conditions.append(user_name_conditions[0]) + else: + where_conditions.append(f"({' OR '.join(user_name_conditions)})") + + # Build filter 
conditions using common method + filter_where_clause = self._build_filter_conditions_cypher(filter) - where_str = " AND ".join(where_conditions) + where_str = " AND ".join(where_conditions) + filter_where_clause # Use cypher query cypher_query = f""" @@ -1639,6 +1722,8 @@ def get_by_metadata( ids = [] conn = self._get_connection() + logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") + print(f"[get_by_metadata] cypher_query: {cypher_query}") try: with conn.cursor() as cursor: cursor.execute(cypher_query) @@ -2044,7 +2129,12 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: @timed def get_all_memory_items( - self, scope: str, include_embedding: bool = False, user_name: str | None = None + self, + scope: str, + include_embedding: bool = False, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list | None = None, ) -> list[dict]: """ Retrieve all memory items of a specific memory_type. @@ -2057,17 +2147,52 @@ def get_all_memory_items( Returns: list[dict]: Full list of memory items under this scope. 
""" + logger.info( + f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + ) + print(f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + user_name = user_name if user_name else self._get_config_value("user_name") if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self._get_config_value("user_name"), + ) + + # Build user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + user_name_where = user_name_conditions[0] + else: + user_name_where = f"({' OR '.join(user_name_conditions)})" + else: + user_name_where = "" + + # Build filter conditions using common method + filter_where_clause = self._build_filter_conditions_cypher(filter) + # Use cypher query to retrieve memory items if include_embedding: + # Build WHERE clause with user_name/knowledgebase_ids and filter + where_parts = [f"n.memory_type = '{scope}'"] + if user_name_where: + # user_name_where already contains parentheses if it's an OR condition + where_parts.append(user_name_where) + if filter_where_clause: + # filter_where_clause already contains " AND " prefix, so we just append it + where_clause = " AND ".join(where_parts) + filter_where_clause + else: + where_clause = " AND ".join(where_parts) + cypher_query = f""" WITH t as ( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (n:Memory) - WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' + WHERE {where_clause} RETURN id(n) as id1,n LIMIT 100 $$) AS (id1 agtype,n agtype) @@ -2110,10 +2235,21 @@ def get_all_memory_items( return nodes else: + # Build WHERE clause with user_name/knowledgebase_ids and filter + 
where_parts = [f"n.memory_type = '{scope}'"] + if user_name_where: + # user_name_where already contains parentheses if it's an OR condition + where_parts.append(user_name_where) + if filter_where_clause: + # filter_where_clause already contains " AND " prefix, so we just append it + where_clause = " AND ".join(where_parts) + filter_where_clause + else: + where_clause = " AND ".join(where_parts) + cypher_query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (n:Memory) - WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' + WHERE {where_clause} RETURN properties(n) as props LIMIT 100 $$) AS (nprops agtype) @@ -2121,6 +2257,8 @@ def get_all_memory_items( nodes = [] conn = self._get_connection() + logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") + print(f"[get_all_memory_items] cypher_query: {cypher_query}") try: with conn.cursor() as cursor: cursor.execute(cypher_query) @@ -2495,12 +2633,12 @@ def add_node( self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None ) -> None: """Add a memory node to the graph.""" - logger.info(f"In add node polardb: id-{id} memory-{memory}") + logger.info(f"[add_node] id: {id}, memory: {memory}, metadata: {metadata}") + print(f"[add_node] metadata: {metadata}, info: {metadata.get('info')}") # user_name comes from metadata; fallback to config if missing metadata["user_name"] = user_name if user_name else self.config.user_name - # Safely process metadata metadata = _prepare_node_metadata(metadata) # Merge node and set metadata @@ -2578,6 +2716,12 @@ def add_node( cursor.execute( insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) ) + logger.info( + f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + print( + f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) else: insert_query = f""" INSERT INTO {self.db_name}_graph."Memory"(id, 
properties) @@ -2587,7 +2731,13 @@ def add_node( ) """ cursor.execute(insert_query, (id, json.dumps(properties))) - logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + logger.info( + f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + print( + f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + finally: logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") self._return_connection(conn) @@ -3083,3 +3233,593 @@ def format_param_value(self, value: str | None) -> str: else: # Add double quotes return f'"{value}"' + + def _build_user_name_and_kb_ids_conditions_cypher( + self, + user_name: str | None, + knowledgebase_ids: list | None, + default_user_name: str | None = None, + ) -> list[str]: + """ + Build user_name and knowledgebase_ids conditions for Cypher queries. + + Args: + user_name: User name for filtering + knowledgebase_ids: List of knowledgebase IDs + default_user_name: Default user name from config if user_name is None + + Returns: + List of condition strings (will be joined with OR) + """ + user_name_conditions = [] + effective_user_name = user_name if user_name else default_user_name + + if effective_user_name: + escaped_user_name = effective_user_name.replace("'", "''") + user_name_conditions.append(f"n.user_name = '{escaped_user_name}'") + + # Add knowledgebase_ids conditions (checking user_name field in the data) + if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: + for kb_id in knowledgebase_ids: + if isinstance(kb_id, str): + escaped_kb_id = kb_id.replace("'", "''") + user_name_conditions.append(f"n.user_name = '{escaped_kb_id}'") + + return user_name_conditions + + def _build_user_name_and_kb_ids_conditions_sql( + self, + user_name: str | None, + knowledgebase_ids: list | None, + default_user_name: str | None = None, + ) -> list[str]: + """ + Build 
user_name and knowledgebase_ids conditions for SQL queries. + + Args: + user_name: User name for filtering + knowledgebase_ids: List of knowledgebase IDs + default_user_name: Default user name from config if user_name is None + + Returns: + List of condition strings (will be joined with OR) + """ + user_name_conditions = [] + effective_user_name = user_name if user_name else default_user_name + + if effective_user_name: + user_name_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{effective_user_name}\"'::agtype" + ) + + # Add knowledgebase_ids conditions (checking user_name field in the data) + if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: + for kb_id in knowledgebase_ids: + if isinstance(kb_id, str): + user_name_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kb_id}\"'::agtype" + ) + + return user_name_conditions + + def _build_filter_conditions_cypher( + self, + filter: dict | None, + ) -> str: + """ + Build filter conditions for Cypher queries. 
+ + Args: + filter: Filter dictionary with "or" or "and" logic + + Returns: + Filter WHERE clause string (empty string if no filter) + """ + filter_where_clause = "" + filter = self.parse_filter(filter) + if filter: + + def escape_cypher_string(value: str) -> str: + return value.replace("'", "\\'") + + def build_cypher_filter_condition(condition_dict: dict) -> str: + """Build a Cypher WHERE condition for a single filter item.""" + condition_parts = [] + for key, value in condition_dict.items(): + # Check if value is a dict with comparison operators (gt, lt, gte, lte, =, contains) + if isinstance(value, dict): + # Handle comparison operators: gt, lt, gte, lte, =, contains + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte"): + # Map operator to Cypher operator + cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} + cypher_op = cypher_op_map[op] + + # Check if key starts with "info." prefix (for nested fields like info.A, info.B) + if key.startswith("info."): + # Nested field access: n.info.field_name + info_field = key[5:] # Remove "info." prefix + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + condition_parts.append( + f"n.info.{info_field} {cypher_op} '{escaped_value}'" + ) + else: + condition_parts.append( + f"n.info.{info_field} {cypher_op} {op_value}" + ) + else: + # Direct property access (e.g., "created_at" is directly in n, not in n.info) + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + condition_parts.append( + f"n.{key} {cypher_op} '{escaped_value}'" + ) + else: + condition_parts.append(f"n.{key} {cypher_op} {op_value}") + elif op == "=": + # Handle equality operator + # For array fields, = means exact match of the entire array (e.g., tags = ['test:zdy'] or tags = ['mode:fast', 'test:zdy']) + # For scalar fields, = means equality + # Check if key starts with "info." prefix + if key.startswith("info."): + info_field = key[5:] # Remove "info." 
prefix + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + # For array fields, check if array exactly equals [value] + # For scalar fields, use = + if info_field in ("tags", "sources"): + condition_parts.append( + f"n.info.{info_field} = ['{escaped_value}']" + ) + else: + condition_parts.append( + f"n.info.{info_field} = '{escaped_value}'" + ) + elif isinstance(op_value, list): + # For array fields, format list as Cypher array + if info_field in ("tags", "sources"): + escaped_items = [ + f"'{escape_cypher_string(str(item))}'" + for item in op_value + ] + array_str = "[" + ", ".join(escaped_items) + "]" + condition_parts.append( + f"n.info.{info_field} = {array_str}" + ) + else: + condition_parts.append( + f"n.info.{info_field} = {op_value}" + ) + else: + if info_field in ("tags", "sources"): + condition_parts.append( + f"n.info.{info_field} = [{op_value}]" + ) + else: + condition_parts.append( + f"n.info.{info_field} = {op_value}" + ) + else: + # Direct property access + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + # For array fields, check if array exactly equals [value] + # For scalar fields, use = + if key in ("tags", "sources"): + condition_parts.append(f"n.{key} = ['{escaped_value}']") + else: + condition_parts.append(f"n.{key} = '{escaped_value}'") + elif isinstance(op_value, list): + # For array fields, format list as Cypher array + if key in ("tags", "sources"): + escaped_items = [ + f"'{escape_cypher_string(str(item))}'" + for item in op_value + ] + array_str = "[" + ", ".join(escaped_items) + "]" + condition_parts.append(f"n.{key} = {array_str}") + else: + condition_parts.append(f"n.{key} = {op_value}") + else: + if key in ("tags", "sources"): + condition_parts.append(f"n.{key} = [{op_value}]") + else: + condition_parts.append(f"n.{key} = {op_value}") + elif op == "contains": + # Handle contains operator (for array fields) + # Check if key starts with "info." 
prefix + if key.startswith("info."): + info_field = key[5:] # Remove "info." prefix + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + condition_parts.append( + f"'{escaped_value}' IN n.info.{info_field}" + ) + else: + condition_parts.append(f"{op_value} IN n.info.{info_field}") + else: + # Direct property access + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + condition_parts.append(f"'{escaped_value}' IN n.{key}") + else: + condition_parts.append(f"{op_value} IN n.{key}") + elif op == "like": + # Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%') + # Check if key starts with "info." prefix + if key.startswith("info."): + info_field = key[5:] # Remove "info." prefix + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + condition_parts.append( + f"n.info.{info_field} CONTAINS '{escaped_value}'" + ) + else: + condition_parts.append( + f"n.info.{info_field} CONTAINS {op_value}" + ) + else: + # Direct property access + if isinstance(op_value, str): + escaped_value = escape_cypher_string(op_value) + condition_parts.append( + f"n.{key} CONTAINS '{escaped_value}'" + ) + else: + condition_parts.append(f"n.{key} CONTAINS {op_value}") + # Check if key starts with "info." 
prefix (for simple equality) + elif key.startswith("info."): + info_field = key[5:] + if isinstance(value, str): + escaped_value = escape_cypher_string(value) + condition_parts.append(f"n.info.{info_field} = '{escaped_value}'") + else: + condition_parts.append(f"n.info.{info_field} = {value}") + else: + # Direct property access (simple equality) + if isinstance(value, str): + escaped_value = escape_cypher_string(value) + condition_parts.append(f"n.{key} = '{escaped_value}'") + else: + condition_parts.append(f"n.{key} = {value}") + return " AND ".join(condition_parts) + + if isinstance(filter, dict): + if "or" in filter: + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str = build_cypher_filter_condition(condition) + if condition_str: + or_conditions.append(f"({condition_str})") + if or_conditions: + filter_where_clause = " AND " + f"({' OR '.join(or_conditions)})" + + elif "and" in filter: + and_conditions = [] + for condition in filter["and"]: + if isinstance(condition, dict): + condition_str = build_cypher_filter_condition(condition) + if condition_str: + and_conditions.append(f"({condition_str})") + if and_conditions: + filter_where_clause = " AND " + " AND ".join(and_conditions) + + return filter_where_clause + + def _build_filter_conditions_sql( + self, + filter: dict | None, + ) -> list[str]: + """ + Build filter conditions for SQL queries. 
+ + Args: + filter: Filter dictionary with "or" or "and" logic + + Returns: + List of filter WHERE clause strings (empty list if no filter) + """ + filter_conditions = [] + filter = self.parse_filter(filter) + if filter: + # Helper function to escape string value for SQL + def escape_sql_string(value: str) -> str: + """Escape single quotes in SQL string.""" + return value.replace("'", "''") + + # Helper function to build a single filter condition + def build_filter_condition(condition_dict: dict) -> str: + """Build a WHERE condition for a single filter item.""" + condition_parts = [] + for key, value in condition_dict.items(): + # Check if value is a dict with comparison operators (gt, lt, gte, lte, =, contains) + if isinstance(value, dict): + # Handle comparison operators: gt, lt, gte, lte, =, contains + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte"): + # Map operator to SQL operator + sql_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} + sql_op = sql_op_map[op] + + # Check if key starts with "info." prefix (for nested fields like info.A, info.B) + if key.startswith("info."): + # Nested field access: properties->'info'->'field_name' + info_field = key[5:] # Remove "info." 
prefix + if isinstance(op_value, str): + escaped_value = escape_sql_string(op_value) + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} {op_value}::agtype" + ) + else: + # Direct property access (e.g., "created_at" is directly in properties, not in properties.info) + if isinstance(op_value, str): + escaped_value = escape_sql_string(op_value) + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} {op_value}::agtype" + ) + elif op == "=": + # Handle equality operator + # For array fields, = means exact match of the entire array (e.g., tags = ['test:zdy'] or tags = ['mode:fast', 'test:zdy']) + # For scalar fields, = means equality + # Check if key starts with "info." prefix + if key.startswith("info."): + info_field = key[5:] # Remove "info." 
prefix + if isinstance(op_value, str): + escaped_value = escape_sql_string(op_value) + # For array fields, check if array exactly equals [value] + # For scalar fields, use = + if info_field in ("tags", "sources"): + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '[\"{escaped_value}\"]'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" + ) + elif isinstance(op_value, list): + # For array fields, format list as JSON array string + if info_field in ("tags", "sources"): + escaped_items = [ + escape_sql_string(str(item)) for item in op_value + ] + json_array = json.dumps(escaped_items) + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '{json_array}'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {op_value}::agtype" + ) + else: + if info_field in ("tags", "sources"): + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '[{op_value}]'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {op_value}::agtype" + ) + else: + # Direct property access + if isinstance(op_value, str): + escaped_value = escape_sql_string(op_value) + # For array fields, check if array exactly equals [value] + # For scalar fields, use = + if key in ("tags", "sources"): + condition_parts.append( + 
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '[\"{escaped_value}\"]'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" + ) + elif isinstance(op_value, list): + # For array fields, format list as JSON array string + if key in ("tags", "sources"): + escaped_items = [ + escape_sql_string(str(item)) for item in op_value + ] + json_array = json.dumps(escaped_items) + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '{json_array}'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {op_value}::agtype" + ) + else: + if key in ("tags", "sources"): + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '[{op_value}]'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {op_value}::agtype" + ) + elif op == "contains": + # Handle contains operator (for array fields) - use @> operator + # Check if key starts with "info." prefix + if key.startswith("info."): + info_field = key[5:] # Remove "info." 
prefix + if isinstance(op_value, str): + escaped_value = escape_sql_string(op_value) + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> {op_value}::agtype" + ) + else: + # Direct property access + if isinstance(op_value, str): + escaped_value = escape_sql_string(op_value) + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> {op_value}::agtype" + ) + elif op == "like": + # Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%') + # Check if key starts with "info." prefix + if key.startswith("info."): + info_field = key[5:] # Remove "info." 
prefix + if isinstance(op_value, str): + # Escape SQL special characters for LIKE: % and _ need to be escaped + escaped_value = ( + escape_sql_string(op_value) + .replace("%", "\\%") + .replace("_", "\\_") + ) + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{escaped_value}%'" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{op_value}%'" + ) + else: + # Direct property access + if isinstance(op_value, str): + # Escape SQL special characters for LIKE: % and _ need to be escaped + escaped_value = ( + escape_sql_string(op_value) + .replace("%", "\\%") + .replace("_", "\\_") + ) + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text LIKE '%{escaped_value}%'" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text LIKE '%{op_value}%'" + ) + # Check if key starts with "info." prefix (for simple equality) + elif key.startswith("info."): + # Extract the field name after "info." + info_field = key[5:] # Remove "info." 
prefix (5 characters) + if isinstance(value, str): + escaped_value = escape_sql_string(value) + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {value}::agtype" + ) + else: + # Direct property access (simple equality) + if isinstance(value, str): + escaped_value = escape_sql_string(value) + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + return " AND ".join(condition_parts) + + # Process filter structure + if isinstance(filter, dict): + if "or" in filter: + # OR logic: at least one condition must match + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + or_conditions.append(f"({condition_str})") + if or_conditions: + filter_conditions.append(f"({' OR '.join(or_conditions)})") + + elif "and" in filter: + # AND logic: all conditions must match + for condition in filter["and"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + filter_conditions.append(f"({condition_str})") + + return filter_conditions + + def parse_filter( + self, + filter_dict: dict | None = None, + ): + if filter_dict is None: + return None + full_fields = { + "id", + "key", + "tags", + "type", + "usage", + "memory", + "status", + "sources", + "user_id", + "graph_id", + "user_name", + "background", + "confidence", + "created_at", + "session_id", + "updated_at", + "memory_type", + "node_type", + "info", + 
"app_id", + "agent_id", + } + + def process_condition(condition): + if not isinstance(condition, dict): + return condition + + new_condition = {} + + for key, value in condition.items(): + if key.lower() in ["or", "and"]: + if isinstance(value, list): + processed_items = [] + for item in value: + if isinstance(item, dict): + processed_item = {} + for item_key, item_value in item.items(): + if item_key not in full_fields and not item_key.startswith( + "info." + ): + new_item_key = f"info.{item_key}" + else: + new_item_key = item_key + processed_item[new_item_key] = item_value + processed_items.append(processed_item) + else: + processed_items.append(item) + new_condition[key] = processed_items + else: + new_condition[key] = value + else: + if key not in full_fields and not key.startswith("info."): + new_key = f"info.{key}" + else: + new_key = key + + new_condition[new_key] = value + + return new_condition + + return process_condition(filter_dict)