Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def dispatch_task_loop(self):
self.waiting_dict[task.get_key()] = task
else:
task.start_trans_time = time.time()
self.success_queue.put((None, task))
self.success_queue.put((None, None, task))

# up status
task = trans_task_group.task_list[0]
Expand Down Expand Up @@ -335,7 +335,10 @@ def read_page_to_mems_loop(self):
while True:
trans_task: NIXLChunckedTransTask = self.ready_page_task_queue.get()
# 将数据写回 mem manger
copy_start_event = torch.cuda.Event(enable_timing=True)
copy_end_event = torch.cuda.Event(enable_timing=True)
with torch.cuda.stream(stream=self.copy_cuda_stream):
copy_start_event.record(self.copy_cuda_stream)
cur_mem = self.mem_managers[self.device_id]
cur_mem.read_page_kv_move_buffer_to_mem(
mem_indexes=trans_task.mem_indexes,
Expand All @@ -344,22 +347,21 @@ def read_page_to_mems_loop(self):
mem_managers=self.mem_managers,
dp_world_size=self.dp_world_size,
)
sync_event = torch.cuda.Event()
sync_event.record()
copy_end_event.record(self.copy_cuda_stream)

self.success_queue.put((sync_event, trans_task))
self.success_queue.put((copy_end_event, copy_start_event, trans_task))
return

@log_exception
def success_loop(self):
torch.cuda.set_device(self.device_id)
while True:
sync_event, trans_task = self.success_queue.get()
copy_end_event, copy_start_event, trans_task = self.success_queue.get()
trans_task: NIXLChunckedTransTask = trans_task
sync_event: Optional[torch.cuda.Event] = sync_event
# 兼容传输kv 数量为0的时候, sync_event 为 None的情况。
if sync_event is not None:
sync_event.synchronize()
read_page_gpu_time_ms = -1.0
if copy_end_event is not None:
copy_end_event.synchronize()
read_page_gpu_time_ms = copy_start_event.elapsed_time(copy_end_event)
Comment on lines +362 to +364
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To ensure robust defensive programming, verify that both copy_end_event and copy_start_event are not None before calling elapsed_time to prevent potential AttributeError exceptions.

Suggested change
if copy_end_event is not None:
copy_end_event.synchronize()
read_page_gpu_time_ms = copy_start_event.elapsed_time(copy_end_event)
if copy_end_event is not None and copy_start_event is not None:
copy_end_event.synchronize()
read_page_gpu_time_ms = copy_start_event.elapsed_time(copy_end_event)


if trans_task.nixl_dst_page_index is not None:
self.page_index_queue.put(trans_task.nixl_dst_page_index)
Expand All @@ -369,7 +371,13 @@ def success_loop(self):

ret = trans_task.createRetObj()
self.task_out_queue.put(ret)
logger.info(f"trans task ret success:{ret} cost time: {trans_task.transfer_time()} s")
if read_page_gpu_time_ms >= 0:
logger.info(
f"trans task ret success:{ret} cost time: {trans_task.transfer_time()} s "
f"read_page_gpu_time: {read_page_gpu_time_ms:.3f} ms"
)
else:
logger.info(f"trans task ret success:{ret} cost time: {trans_task.transfer_time()} s")

@log_exception
def fail_loop(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
import copy
import time
from dataclasses import dataclass
from collections import defaultdict
from typing import Dict, List, Any, Optional, Tuple
Expand Down Expand Up @@ -63,6 +64,8 @@ def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata):
if remote_agent.agent_name in self.remote_agents:
return

start_time = time.time()

peer_name = self.nixl_agent.add_remote_agent(remote_agent.agent_metadata)
if isinstance(peer_name, bytes):
peer_name = peer_name.decode()
Expand All @@ -77,7 +80,9 @@ def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata):
)
remote_agent.page_xfer_handles = kv_page_xfer_handles

logger.info(f"Added remote agent {peer_name} with mem desc {page_mem_desc}")
logger.info(
f"Added remote agent {peer_name} with mem desc {page_mem_desc} cost time: {time.time() - start_time} s"
)

self.remote_agents[remote_agent.agent_name] = remote_agent
return
Expand Down
Loading