Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 82 additions & 11 deletions fastdeploy/rl/dynamic_weight_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
# limitations under the License.
"""

import gc
import glob
import os
import re
import time
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Dict, List
Expand Down Expand Up @@ -111,20 +114,88 @@ def _normal_load_weight(self):
self._update_model_from_state(state_dict, "raw")

def _update_ipc_snapshot(self):
"""Update using IPC snapshot strategy for elastic recovery."""
model_path = os.path.join(
self.fd_config.model_config.model,
f"model_state.tp0{self.meta_src_id}.pdparams",
)
"""Update using IPC snapshot strategy for elastic recovery.

try:
Loading priority:
1. Chunked part files (model_state.tp{rank}.{id}.part{N}.pdparams)
2. Single full file (model_state.tp{rank}.{id}.pdparams)
3. Legacy format (model_state.tp0{id}.pdparams)
4. Shared fallback dir (/shared_ipc_meta/...)
"""
model_dir = self.fd_config.model_config.model
base_name = f"model_state.tp{paddle.distributed.get_rank()}.{self.meta_src_id}"
legacy_base_name = f"model_state.tp0{self.meta_src_id}"

# --- Priority 1: load from chunked part files to avoid memory spike ---
part_pattern = os.path.join(model_dir, f"{base_name}.part*.pdparams")
all_part_files = glob.glob(part_pattern)

valid_part_files = []
invalid_part_files = []
part_regex = re.compile(r"\.part(\d+)\.")

for path in all_part_files:
match = part_regex.search(path)
if not match:
invalid_part_files.append(os.path.basename(path))
continue
try:
part_idx = int(match.group(1))
except (TypeError, ValueError):
invalid_part_files.append(os.path.basename(path))
continue
valid_part_files.append((part_idx, path))

if invalid_part_files:
logger.warning(
"Found snapshot part files with invalid naming pattern under %s: %s. "
"These files will be ignored when loading IPC snapshot parts.",
model_dir,
", ".join(invalid_part_files),
)

part_files = [p for _, p in sorted(valid_part_files, key=lambda item: item[0])]

if part_files:
logger.info(f"Found {len(part_files)} snapshot part files for {base_name}")
for load_idx, part_path in enumerate(part_files):
match = re.search(r"\.part(\d+)\.", part_path)
# Use part index parsed from filename to keep logs and src_type consistent with file naming
part_index = int(match.group(1)) if match else load_idx
logger.info(f"Loading snapshot part {part_index+1}/{len(part_files)} from {part_path}")
ipc_state_dict = paddle.load(part_path, safetensors=True)
self._update_model_from_state(ipc_state_dict, f"snapshot-part{part_index}")
del ipc_state_dict
gc.collect()
logger.info(f"IPC snapshot update completed from {len(part_files)} part files under {model_dir}")
return

# --- Priority 2: single full pdparams file ---
model_path = os.path.join(model_dir, f"{base_name}.pdparams")
if os.path.exists(model_path):
ipc_state_dict = paddle.load(model_path, safetensors=True)
except FileNotFoundError:
fallback_path = f"/shared_ipc_meta/model_state.tp0{self.meta_src_id}.pdparams"
ipc_state_dict = paddle.load(fallback_path)

self._update_model_from_state(ipc_state_dict, "snapshot")
logger.info(f"IPC snapshot update completed from {model_path}")
return

# --- Priority 3: legacy format (model_state.tp0{id}.pdparams) ---
legacy_path = os.path.join(model_dir, f"{legacy_base_name}.pdparams")
if os.path.exists(legacy_path):
ipc_state_dict = paddle.load(legacy_path, safetensors=True)
self._update_model_from_state(ipc_state_dict, "snapshot")
logger.info(f"IPC snapshot update completed from legacy format {legacy_path}")
return

# --- Priority 4: shared directory fallback ---
fallback_path = f"/shared_ipc_meta/{base_name}.pdparams"
if not os.path.exists(fallback_path):
raise FileNotFoundError(
f"No snapshot found for {base_name}: " f"checked {model_dir} (new/legacy) and {fallback_path}"
)
logger.info(f"No local snapshot in {model_dir}, fallback to {fallback_path}")
ipc_state_dict = paddle.load(fallback_path)
self._update_model_from_state(ipc_state_dict, "snapshot")
logger.info(f"IPC snapshot update parameters completed from {model_path}")
logger.info(f"IPC snapshot update completed from {fallback_path}")

def _update_ipc(self):
"""Update using standard IPC strategy (requires Training Worker)."""
Expand Down
237 changes: 237 additions & 0 deletions tests/rl/test_update_ipc_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest
from unittest.mock import MagicMock, patch


def _make_manager(model_dir="/fake/model", rank=0, meta_src_id=0, load_strategy="ipc_snapshot"):
"""Build a DynamicWeightManager instance bypassing __init__, for testing _update_ipc_snapshot only."""
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager

obj = object.__new__(DynamicWeightManager)

# fd_config mock
fd_config = MagicMock()
fd_config.model_config.model = model_dir
fd_config.load_config.load_strategy = load_strategy
obj.fd_config = fd_config

obj.meta_src_id = meta_src_id
obj.rank = rank
# Two mock parameters to verify _update_model_from_state is called correctly
obj.state_dict = {
"layer.weight": MagicMock(name="layer.weight"),
"layer.bias": MagicMock(name="layer.bias"),
}
return obj


class TestUpdateIpcSnapshot(unittest.TestCase):
"""Unit tests for DynamicWeightManager._update_ipc_snapshot.

Covers all 4 loading priority branches inside the function:
Priority 1 - chunked part files
Priority 2 - single full pdparams file
Priority 3 - legacy format
Priority 4 - shared directory fallback
Plus the error path when no snapshot is found anywhere.
"""

def setUp(self):
# Pre-import the module so that fastdeploy.rl is set in fastdeploy.__dict__
# before @patch decorators resolve their targets via _importer.
# Without this, fastdeploy.__getattr__ prints a warning and returns None for "rl",
# causing _importer to later fail with AttributeError when the test runs first
# (e.g., alphabetical order in unittest).
import fastdeploy.rl.dynamic_weight_manager # noqa: F401

# ------------------------------------------------------------------
# Priority 1: chunked part files
# ------------------------------------------------------------------
@patch("fastdeploy.rl.dynamic_weight_manager.gc.collect")
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.load")
@patch("fastdeploy.rl.dynamic_weight_manager.os.path.exists")
@patch("fastdeploy.rl.dynamic_weight_manager.glob.glob")
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.distributed.get_rank", return_value=0)
def test_priority1_chunked_part_files(self, _mock_rank, mock_glob, mock_exists, mock_load, mock_gc):
"""When part files are found, load each part in order and do not fall through to other priorities."""
model_dir = "/fake/model"
part_files = [
f"{model_dir}/model_state.tp0.0.part1.pdparams",
f"{model_dir}/model_state.tp0.0.part2.pdparams",
]
mock_glob.return_value = part_files

fake_state_dict_1 = {"layer.weight": MagicMock()}
fake_state_dict_2 = {"layer.bias": MagicMock()}
mock_load.side_effect = [fake_state_dict_1, fake_state_dict_2]

mgr = _make_manager(model_dir=model_dir)
mgr._update_model_from_state = MagicMock()

mgr._update_ipc_snapshot()

# glob should be called to search for part files
mock_glob.assert_called_once()
# os.path.exists must NOT be called: priority 1 should return early
mock_exists.assert_not_called()
# paddle.load should be called once per part file
self.assertEqual(mock_load.call_count, 2)
mock_load.assert_any_call(part_files[0], safetensors=True)
mock_load.assert_any_call(part_files[1], safetensors=True)
# _update_model_from_state should be called once per part with correct args
self.assertEqual(mgr._update_model_from_state.call_count, 2)
mgr._update_model_from_state.assert_any_call(fake_state_dict_1, "snapshot-part1")
mgr._update_model_from_state.assert_any_call(fake_state_dict_2, "snapshot-part2")
# gc.collect should be called after each part to free memory
self.assertEqual(mock_gc.call_count, 2)

# ------------------------------------------------------------------
# Priority 2: single full pdparams file
# ------------------------------------------------------------------
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.load")
@patch("fastdeploy.rl.dynamic_weight_manager.os.path.exists")
@patch("fastdeploy.rl.dynamic_weight_manager.glob.glob", return_value=[])
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.distributed.get_rank", return_value=0)
def test_priority2_single_full_file(self, _mock_rank, _mock_glob, mock_exists, mock_load):
"""When no part files exist, load from model_state.tp{rank}.{id}.pdparams and return."""
model_dir = "/fake/model"
full_path = f"{model_dir}/model_state.tp0.0.pdparams"

# Only full_path exists
mock_exists.side_effect = lambda p: p == full_path

fake_state_dict = {"layer.weight": MagicMock()}
mock_load.return_value = fake_state_dict

mgr = _make_manager(model_dir=model_dir)
mgr._update_model_from_state = MagicMock()

mgr._update_ipc_snapshot()

mock_load.assert_called_once_with(full_path, safetensors=True)
mgr._update_model_from_state.assert_called_once_with(fake_state_dict, "snapshot")

# ------------------------------------------------------------------
# Priority 3: legacy format
# ------------------------------------------------------------------
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.load")
@patch("fastdeploy.rl.dynamic_weight_manager.os.path.exists")
@patch("fastdeploy.rl.dynamic_weight_manager.glob.glob", return_value=[])
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.distributed.get_rank", return_value=0)
def test_priority3_legacy_format(self, _mock_rank, _mock_glob, mock_exists, mock_load):
"""When the full path does not exist, fall back to legacy model_state.tp0{id}.pdparams."""
model_dir = "/fake/model"
legacy_path = f"{model_dir}/model_state.tp00.pdparams"

# full_path absent, legacy_path present
mock_exists.side_effect = lambda p: p == legacy_path

fake_state_dict = {"layer.weight": MagicMock()}
mock_load.return_value = fake_state_dict

mgr = _make_manager(model_dir=model_dir)
mgr._update_model_from_state = MagicMock()

mgr._update_ipc_snapshot()

mock_load.assert_called_once_with(legacy_path, safetensors=True)
mgr._update_model_from_state.assert_called_once_with(fake_state_dict, "snapshot")

# ------------------------------------------------------------------
# Priority 4: shared directory fallback
# ------------------------------------------------------------------
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.load")
@patch("fastdeploy.rl.dynamic_weight_manager.os.path.exists")
@patch("fastdeploy.rl.dynamic_weight_manager.glob.glob", return_value=[])
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.distributed.get_rank", return_value=0)
def test_priority4_shared_fallback(self, _mock_rank, _mock_glob, mock_exists, mock_load):
"""When all three local paths are absent, load from /shared_ipc_meta/."""
model_dir = "/fake/model"
fallback_path = "/shared_ipc_meta/model_state.tp0.0.pdparams"

# Only fallback_path exists
mock_exists.side_effect = lambda p: p == fallback_path

fake_state_dict = {"layer.weight": MagicMock()}
mock_load.return_value = fake_state_dict

mgr = _make_manager(model_dir=model_dir)
mgr._update_model_from_state = MagicMock()

mgr._update_ipc_snapshot()

mock_load.assert_called_once_with(fallback_path)
mgr._update_model_from_state.assert_called_once_with(fake_state_dict, "snapshot")

# ------------------------------------------------------------------
# Error path: no snapshot found anywhere
# ------------------------------------------------------------------
@patch("fastdeploy.rl.dynamic_weight_manager.os.path.exists", return_value=False)
@patch("fastdeploy.rl.dynamic_weight_manager.glob.glob", return_value=[])
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.distributed.get_rank", return_value=0)
def test_no_snapshot_raises_file_not_found(self, _mock_rank, _mock_glob, _mock_exists):
"""Should raise FileNotFoundError when none of the candidate paths exist."""
mgr = _make_manager()
mgr._update_model_from_state = MagicMock()

with self.assertRaises(FileNotFoundError) as ctx:
mgr._update_ipc_snapshot()
self.assertIn("No snapshot found", str(ctx.exception))

# ------------------------------------------------------------------
# Priority 1 sort correctness: part files loaded in numeric order
# ------------------------------------------------------------------
@patch("fastdeploy.rl.dynamic_weight_manager.gc.collect")
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.load")
@patch("fastdeploy.rl.dynamic_weight_manager.os.path.exists")
@patch("fastdeploy.rl.dynamic_weight_manager.glob.glob")
@patch("fastdeploy.rl.dynamic_weight_manager.paddle.distributed.get_rank", return_value=0)
def test_priority1_part_files_sorted_by_number(self, _mock_rank, mock_glob, mock_exists, mock_load, _mock_gc):
"""When glob returns part files out of order, they should be loaded in ascending numeric order."""
model_dir = "/fake/model"
# Intentionally return files in wrong order
part_files_unordered = [
f"{model_dir}/model_state.tp0.0.part3.pdparams",
f"{model_dir}/model_state.tp0.0.part1.pdparams",
f"{model_dir}/model_state.tp0.0.part2.pdparams",
]
mock_glob.return_value = part_files_unordered
mock_load.return_value = {}

mgr = _make_manager(model_dir=model_dir)
mgr._update_model_from_state = MagicMock()

# Capture the actual load order
loaded_paths = []
mock_load.side_effect = lambda p, **_kw: loaded_paths.append(p) or {}

mgr._update_ipc_snapshot()

# os.path.exists must NOT be called: priority 1 should return early
mock_exists.assert_not_called()

expected_order = [
f"{model_dir}/model_state.tp0.0.part1.pdparams",
f"{model_dir}/model_state.tp0.0.part2.pdparams",
f"{model_dir}/model_state.tp0.0.part3.pdparams",
]
self.assertEqual(loaded_paths, expected_order)


if __name__ == "__main__":
unittest.main()
Loading