Skip to content

Commit 2be48da

Browse files
Wang-Daojiyuan.wangendxxxxCCC
authored
feat: skill memory (#1056)
* feat: skill memory * feat: split task chunks for skill memories * fix: refine the returned format from llm and parsing * feat: add new pack oss * feat: skill mem pipeline * feat: fill code * feat: modify code * feat: modify code * feat: async add skill memory * feat: update ollama version * feat: get memory return skill memory * feat: get api add skill mem * feat: get api add skill mem * feat: modify env config * feat: back set oss client * feat: delete tmp skill code * feat: process new package import error * feat: modify oss config * feat: modiy prompt and add two api * feat: modify prompt * feat: modify code * feat: add logger * feat: fix bug in memory id * fix:skill OSS + LOCAL存 zip * fix:skill OSS + LOCAL存 zip * fix:skill OSS + LOCAL存 zip * feat: new code * fix: fix name error in polardb and related code * fix: bug in polardb * feat: optimize skill * feat: local deploy --------- Co-authored-by: yuan.wang <yuan.wang@yuanwangdebijibendiannao.local> Co-authored-by: Wenqiang Wei <wwq38556399@163.com> Co-authored-by: CCC <15764764+triple-c-individual@user.noreply.gitee.com>
1 parent 15921be commit 2be48da

3 files changed

Lines changed: 167 additions & 73 deletions

File tree

src/memos/api/config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,12 @@ def get_product_default_config() -> dict[str, Any]:
821821
"oss_config": APIConfig.get_oss_config(),
822822
"skills_dir_config": {
823823
"skills_oss_dir": os.getenv("SKILLS_OSS_DIR", "skill_memory/"),
824-
"skills_local_dir": os.getenv("SKILLS_LOCAL_DIR", "/tmp/skill_memory/"),
824+
"skills_local_tmp_dir": os.getenv(
825+
"SKILLS_LOCAL_TMP_DIR", "/tmp/skill_memory/"
826+
),
827+
"skills_local_dir": os.getenv(
828+
"SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/"
829+
),
825830
},
826831
},
827832
},

src/memos/api/server_api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import logging
2+
import os
23

4+
from dotenv import load_dotenv
35
from fastapi import FastAPI, HTTPException
46
from fastapi.exceptions import RequestValidationError
7+
from starlette.staticfiles import StaticFiles
58

69
from memos.api.exceptions import APIExceptionHandler
710
from memos.api.middleware.request_context import RequestContextMiddleware
811
from memos.api.routers.server_router import router as server_router
912

1013

14+
load_dotenv()
15+
1116
# Configure logging
1217
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
1318
logger = logging.getLogger(__name__)
@@ -18,6 +23,8 @@
1823
version="1.0.1",
1924
)
2025

26+
app.mount("/download", StaticFiles(directory=os.getenv("FILE_LOCAL_PATH")), name="static_mapping")
27+
2128
app.add_middleware(RequestContextMiddleware, source="server_api")
2229
# Include routers
2330
app.include_router(server_router)

src/memos/mem_reader/read_skill_memory/process_skill_memory.py

Lines changed: 154 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
import os
44
import shutil
55
import uuid
6+
import warnings
67
import zipfile
78

89
from concurrent.futures import as_completed
910
from datetime import datetime
1011
from pathlib import Path
1112
from typing import TYPE_CHECKING, Any
1213

14+
from dotenv import load_dotenv
15+
1316
from memos.context.context import ContextThreadPoolExecutor
1417
from memos.dependency import require_python_package
1518
from memos.embedders.base import BaseEmbedder
@@ -36,6 +39,8 @@
3639
from memos.types import MessageList
3740

3841

42+
load_dotenv()
43+
3944
if TYPE_CHECKING:
4045
from memos.types.general_types import UserContext
4146

@@ -653,42 +658,88 @@ def _rewrite_query(task_type: str, messages: MessageList, llm: BaseLLM, rewrite_
653658
import_name="alibabacloud_oss_v2",
654659
install_command="pip install alibabacloud-oss-v2",
655660
)
656-
def _upload_skills_to_oss(local_file_path: str, oss_file_path: str, client: Any) -> str:
657-
import alibabacloud_oss_v2 as oss
658-
659-
result = client.put_object_from_file(
660-
request=oss.PutObjectRequest(
661-
bucket=os.getenv("OSS_BUCKET_NAME"),
662-
key=oss_file_path,
663-
),
664-
filepath=local_file_path,
665-
)
661+
def _upload_skills(
662+
skills_repo_backend: str,
663+
skills_oss_dir: dict[str, Any] | None,
664+
local_tmp_file_path: str,
665+
local_save_file_path: str,
666+
client: Any,
667+
user_id: str,
668+
) -> str:
669+
if skills_repo_backend == "OSS":
670+
zip_filename = Path(local_tmp_file_path).name
671+
oss_path = (Path(skills_oss_dir) / user_id / zip_filename).as_posix()
672+
673+
import alibabacloud_oss_v2 as oss
674+
675+
result = client.put_object_from_file(
676+
request=oss.PutObjectRequest(
677+
bucket=os.getenv("OSS_BUCKET_NAME"),
678+
key=oss_path,
679+
),
680+
filepath=local_tmp_file_path,
681+
)
666682

667-
if result.status_code != 200:
668-
logger.warning("[PROCESS_SKILLS] Failed to upload skill to OSS")
669-
return ""
683+
if result.status_code != 200:
684+
logger.warning("[PROCESS_SKILLS] Failed to upload skill to OSS")
685+
return ""
686+
687+
# Construct and return the URL
688+
bucket_name = os.getenv("OSS_BUCKET_NAME")
689+
endpoint = os.getenv("OSS_ENDPOINT").replace("https://", "").replace("http://", "")
690+
url = f"https://{bucket_name}.{endpoint}/{oss_path}"
691+
return url
692+
else:
693+
import sys
694+
695+
args = sys.argv
696+
port = (
697+
int(args[args.index("--port") + 1])
698+
if "--port" in args and args.index("--port") + 1 < len(args)
699+
else "8000"
700+
)
670701

671-
# Construct and return the URL
672-
bucket_name = os.getenv("OSS_BUCKET_NAME")
673-
endpoint = os.getenv("OSS_ENDPOINT").replace("https://", "").replace("http://", "")
674-
url = f"https://{bucket_name}.{endpoint}/{oss_file_path}"
675-
return url
702+
zip_path = str(local_tmp_file_path)
703+
os.makedirs(local_save_file_path, exist_ok=True)
704+
file_name = os.path.basename(zip_path)
705+
target_full_path = os.path.join(local_save_file_path, file_name)
706+
shutil.copy2(zip_path, target_full_path)
707+
return f"http://localhost:{port}/download/{file_name}"
676708

677709

678710
@require_python_package(
679711
import_name="alibabacloud_oss_v2",
680712
install_command="pip install alibabacloud-oss-v2",
681713
)
682-
def _delete_skills_from_oss(oss_file_path: str, client: Any) -> Any:
683-
import alibabacloud_oss_v2 as oss
684-
685-
result = client.delete_object(
686-
oss.DeleteObjectRequest(
687-
bucket=os.getenv("OSS_BUCKET_NAME"),
688-
key=oss_file_path,
714+
def _delete_skills(
715+
skills_repo_backend: str,
716+
zip_filename: str,
717+
client: Any,
718+
skills_oss_dir: dict[str, Any] | None,
719+
local_save_file_path: str,
720+
user_id: str,
721+
) -> Any:
722+
if skills_repo_backend == "OSS":
723+
old_path = (Path(skills_oss_dir) / user_id / zip_filename).as_posix()
724+
import alibabacloud_oss_v2 as oss
725+
726+
return client.delete_object(
727+
oss.DeleteObjectRequest(
728+
bucket=os.getenv("OSS_BUCKET_NAME"),
729+
key=old_path,
730+
)
689731
)
690-
)
691-
return result
732+
else:
733+
target_full_path = os.path.join(local_save_file_path, zip_filename)
734+
target_path = Path(target_full_path)
735+
try:
736+
if target_path.is_file():
737+
target_path.unlink()
738+
logger.info(f"本地文件 {target_path} 已成功删除")
739+
else:
740+
print(f"本地文件 {target_path} 不存在,无需删除")
741+
except Exception as e:
742+
print(f"删除本地文件时出错:{e}")
692743

693744

694745
def _write_skills_to_file(
@@ -698,7 +749,7 @@ def _write_skills_to_file(
698749
skill_name = skill_memory.get("name", "unnamed_skill").replace(" ", "_").lower()
699750

700751
# Create tmp directory for user if it doesn't exist
701-
tmp_dir = Path(skills_dir_config["skills_local_dir"]) / user_id
752+
tmp_dir = Path(skills_dir_config["skills_local_tmp_dir"]) / user_id
702753
tmp_dir.mkdir(parents=True, exist_ok=True)
703754

704755
# Create skill directory directly in tmp_dir
@@ -889,6 +940,54 @@ def create_skill_memory_item(
889940
return TextualMemoryItem(id=item_id, memory=memory_content, metadata=metadata)
890941

891942

943+
def _skill_init(skills_repo_backend, oss_config, skills_dir_config):
944+
if skills_repo_backend == "OSS":
945+
# Validate required configurations
946+
if not oss_config:
947+
logger.warning(
948+
"[PROCESS_SKILLS] OSS configuration is required for skill memory processing"
949+
)
950+
return None, None, False
951+
952+
if not skills_dir_config:
953+
logger.warning(
954+
"[PROCESS_SKILLS] Skills directory configuration is required for skill memory processing"
955+
)
956+
return None, None, False
957+
958+
# Validate skills_dir has required keys
959+
required_keys = ["skills_local_tmp_dir", "skills_local_dir", "skills_oss_dir"]
960+
missing_keys = [key for key in required_keys if key not in skills_dir_config]
961+
if missing_keys:
962+
logger.warning(
963+
f"[PROCESS_SKILLS] Skills directory configuration missing required keys: {', '.join(missing_keys)}"
964+
)
965+
return None, None, False
966+
967+
oss_client = create_oss_client(oss_config)
968+
if not oss_client:
969+
logger.warning("[PROCESS_SKILLS] Failed to create OSS client")
970+
return None, None, False
971+
return oss_client, missing_keys, True
972+
else:
973+
return None, None, True
974+
975+
976+
def _get_skill_file_storage_location() -> str:
977+
# SKILLS_REPO_BACKEND: Skill 文件保存地址 OSS/LOCAL
978+
allowed_backends = {"OSS", "LOCAL"}
979+
raw_backend = os.getenv("SKILLS_REPO_BACKEND")
980+
if raw_backend in allowed_backends:
981+
return raw_backend
982+
else:
983+
warnings.warn(
984+
"环境变量【SKILLS_REPO_BACKEND】赋值错误,本次使用 LOCAL 存储 skill",
985+
UserWarning,
986+
stacklevel=1,
987+
)
988+
return "LOCAL"
989+
990+
892991
def process_skill_memory_fine(
893992
fast_memory_items: list[TextualMemoryItem],
894993
info: dict[str, Any],
@@ -902,36 +1001,16 @@ def process_skill_memory_fine(
9021001
complete_skill_memory: bool = True,
9031002
**kwargs,
9041003
) -> list[TextualMemoryItem]:
905-
# Validate required configurations
906-
if not oss_config:
907-
logger.warning("[PROCESS_SKILLS] OSS configuration is required for skill memory processing")
908-
return []
909-
910-
if not skills_dir_config:
911-
logger.warning(
912-
"[PROCESS_SKILLS] Skills directory configuration is required for skill memory processing"
913-
)
1004+
skills_repo_backend = _get_skill_file_storage_location()
1005+
oss_client, missing_keys, flag = _skill_init(skills_repo_backend, oss_config, skills_dir_config)
1006+
if not flag:
9141007
return []
9151008

9161009
chat_history = kwargs.get("chat_history")
9171010
if not chat_history or not isinstance(chat_history, list):
9181011
chat_history = []
9191012
logger.warning("[PROCESS_SKILLS] History is None in Skills")
9201013

921-
# Validate skills_dir has required keys
922-
required_keys = ["skills_local_dir", "skills_oss_dir"]
923-
missing_keys = [key for key in required_keys if key not in skills_dir_config]
924-
if missing_keys:
925-
logger.warning(
926-
f"[PROCESS_SKILLS] Skills directory configuration missing required keys: {', '.join(missing_keys)}"
927-
)
928-
return []
929-
930-
oss_client = create_oss_client(oss_config)
931-
if not oss_client:
932-
logger.warning("[PROCESS_SKILLS] Failed to create OSS client")
933-
return []
934-
9351014
messages = _reconstruct_messages_from_memory_items(fast_memory_items)
9361015

9371016
chat_history, messages = _preprocess_extract_messages(chat_history, messages)
@@ -1060,23 +1139,27 @@ def _full_extract():
10601139
old_memory = old_memories_map.get(old_memory_id)
10611140

10621141
if old_memory:
1063-
# Get old OSS path from the old memory's metadata
1064-
old_oss_path = getattr(old_memory.metadata, "url", None)
1142+
# Get old path from the old memory's metadata
1143+
old_path = getattr(old_memory.metadata, "url", None)
10651144

1066-
if old_oss_path:
1145+
if old_path:
10671146
try:
10681147
# delete old skill from OSS
1069-
zip_filename = Path(old_oss_path).name
1070-
old_oss_path = (
1071-
Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename
1072-
).as_posix()
1073-
_delete_skills_from_oss(old_oss_path, oss_client)
1148+
zip_filename = Path(old_path).name
1149+
_delete_skills(
1150+
skills_repo_backend=skills_repo_backend,
1151+
zip_filename=zip_filename,
1152+
client=oss_client,
1153+
skills_oss_dir=skills_dir_config["skills_oss_dir"],
1154+
local_save_file_path=skills_dir_config["skills_local_dir"],
1155+
user_id=user_id,
1156+
)
10741157
logger.info(
1075-
f"[PROCESS_SKILLS] Deleted old skill from OSS: {old_oss_path}"
1158+
f"[PROCESS_SKILLS] Deleted old skill from {skills_repo_backend}: {old_path}"
10761159
)
10771160
except Exception as e:
10781161
logger.warning(
1079-
f"[PROCESS_SKILLS] Failed to delete old skill from OSS: {e}"
1162+
f"[PROCESS_SKILLS] Failed to delete old skill from {skills_repo_backend}: {e}"
10801163
)
10811164

10821165
# delete old skill from graph db
@@ -1086,24 +1169,23 @@ def _full_extract():
10861169
f"[PROCESS_SKILLS] Deleted old skill from graph db: {old_memory_id}"
10871170
)
10881171

1089-
# Upload new skill to OSS
1172+
# Upload new skill
10901173
# Use the same filename as the local zip file
1091-
zip_filename = Path(zip_path).name
1092-
oss_path = (
1093-
Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename
1094-
).as_posix()
1095-
1096-
# _upload_skills_to_oss returns the URL
1097-
url = _upload_skills_to_oss(
1098-
local_file_path=str(zip_path), oss_file_path=oss_path, client=oss_client
1174+
url = _upload_skills(
1175+
skills_repo_backend=skills_repo_backend,
1176+
skills_oss_dir=skills_dir_config["skills_oss_dir"],
1177+
local_tmp_file_path=zip_path,
1178+
local_save_file_path=skills_dir_config["skills_local_dir"],
1179+
client=oss_client,
1180+
user_id=user_id,
10991181
)
11001182

11011183
# Set URL directly to skill_memory
11021184
skill_memory["url"] = url
11031185

1104-
logger.info(f"[PROCESS_SKILLS] Uploaded skill to OSS: {url}")
1186+
logger.info(f"[PROCESS_SKILLS] Uploaded skill to {skills_repo_backend}: {url}")
11051187
except Exception as e:
1106-
logger.warning(f"[PROCESS_SKILLS] Error uploading skill to OSS: {e}")
1188+
logger.warning(f"[PROCESS_SKILLS] Error uploading skill to {skills_repo_backend}: {e}")
11071189
skill_memory["url"] = "" # Set to empty string if upload fails
11081190
finally:
11091191
# Clean up local files after upload

0 commit comments

Comments
 (0)