
Commit fa533e2

Author: 钮圣虓
Message: fix
Parent: 9f43400

6 files changed: 50 additions & 19 deletions

lightllm/models/internvl/model.py

Lines changed: 0 additions & 2 deletions
@@ -56,10 +56,8 @@ def init_imageitem_extral_params(
     ):
         if sampling_params.image_max_patch_num > 0:
             img.extra_params["image_patch_max_num"] = sampling_params.image_max_patch_num
-            return
         elif os.getenv("MAX_PATCH_NUM"):
             img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM"))
-            return
         else:
             num_images = len(multi_params.images)
             if num_images == 1:
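The only behavioral change in this file is dropping the two early return statements: the per-request override and the MAX_PATCH_NUM environment override still take precedence, but the method now falls through to any code placed after the if/elif/else chain. A minimal sketch of that fall-through pattern, with a hypothetical post_process step standing in for whatever follows the chain (not a function from this repository):

import os

def set_image_patch_max_num(img, multi_params, sampling_params, post_process):
    if sampling_params.image_max_patch_num > 0:
        img.extra_params["image_patch_max_num"] = sampling_params.image_max_patch_num
    elif os.getenv("MAX_PATCH_NUM"):
        img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM"))
    else:
        # heuristic branch keyed off len(multi_params.images); unchanged by this commit
        ...
    # the first two branches used to return here; now every branch reaches this point
    post_process(img)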

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 13 additions & 10 deletions
@@ -50,22 +50,18 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         img_start_token_ids = []
         img_token_lens = []
         img_start_locs_in_cache = []
-        unique_uids = []
-        all_uids = []
         device = layer_weight.wte_weight_.weight.device
         dtype = layer_weight.wte_weight_.weight.dtype
         hidden_size = layer_weight.wte_weight_.weight.shape[1]

-        for _, p in enumerate(infer_state.multimodal_params):
+        for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
-                all_uids.append(img["uuid"])
                 # skip the same image
                 if img["token_id"] in img_start_token_ids:
                     continue
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
                 img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
-                unique_uids.append(img["uuid"])
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)

         from lightllm.server.router.model_infer.infer_batch import g_infer_context
@@ -78,12 +74,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         )

         if self.args.enable_remote_vit:
-            for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache):
-                embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir)
-                self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache)
+            unique_image_uids = []
+            for _, p in enumerate(infer_state.multimodal_params):
+                for img in p["images"]:
+                    if img["uuid"] in unique_image_uids:
+                        continue
+                    img_uid = img["uuid"]
+                    img_idx = img["start_index_in_embed_cache"]
+                    unique_image_uids.append(img_uid)
+                    embed_tensor = load_tensor_afs(get_shm_name_embed(img_uid), self.args.image_embed_dir)
+                    self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, img_idx)

-        if all_uids:
-            self.cache_client.root.release(all_uids)
+            if unique_image_uids:
+                self.cache_client.root.release(unique_image_uids)

         assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
             f"Dimension mismatch: text weight dimension is {hidden_size}, "

lightllm/server/api_lightllm.py

Lines changed: 5 additions & 2 deletions
@@ -5,6 +5,7 @@
 from lightllm.server.core.objs.sampling_params import SamplingParams
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
+from lightllm.utils.envs_utils import get_env_start_args
 import ujson as json


@@ -154,13 +155,15 @@ async def stream_results() -> AsyncGenerator[bytes, None]:

 async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response:
     request_dict = await request.json()
-    # request_dict: {'parameters': {'max_new_tokens': 128},
-    # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}}
+    args = get_env_start_args()
+    assert not args.disable_vision
+    assert args.enable_remote_vit
     sample_params_dict = request_dict["parameters"]
     sampling_params = SamplingParams()
     sampling_params.init(tokenizer=None, **sample_params_dict)
     sampling_params.verify()
     multimodal_params_dict = request_dict.get("multimodal_params", {})
+    assert not multimodal_params_dict.get("audios")
     multimodal_params = MultimodalParams(**multimodal_params_dict)

     await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request)
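The removed comment documented the request body this endpoint expects, and that shape is still what a caller sends: a JSON object with parameters and multimodal_params carrying base64 images (audios are now rejected by the new assert). A hedged client-side sketch; the URL, port and path are assumptions for illustration, only the payload shape comes from the removed comment:

import base64
import requests

# endpoint location is an assumption, not taken from this diff
url = "http://localhost:8000/get_image_embedding"

with open("demo.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

payload = {
    "parameters": {"max_new_tokens": 128},
    "multimodal_params": {"images": [{"type": "base64", "data": img_b64}]},
    # note: the handler now asserts that no "audios" entries are present
}
resp = requests.post(url, json=payload)
print(resp.status_code, resp.text[:200])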

lightllm/server/api_start.py

Lines changed: 2 additions & 0 deletions
@@ -537,6 +537,8 @@ def visual_start(args):

     if args.visual_nccl_ports is not None:
         args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
+    else:
+        args.visual_nccl_ports = visual_nccl_ports

     args.router_port = router_port
     args.visual_port = visual_port
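The new else branch only supplies the pre-allocated defaults when --visual_nccl_ports was not passed; an explicit list is still truncated to one port per visual data-parallel rank. A minimal standalone sketch of that select-or-truncate rule (the port numbers are illustrative, not from the repository):

def resolve_visual_nccl_ports(user_ports, default_ports, visual_dp):
    # explicit list: keep one port per visual DP rank; otherwise use the defaults as-is
    if user_ports is not None:
        return user_ports[:visual_dp]
    return default_ports

print(resolve_visual_nccl_ports([28500, 28501, 28502], [29500, 29501], visual_dp=2))  # [28500, 28501]
print(resolve_visual_nccl_ports(None, [29500, 29501], visual_dp=2))                   # [29500, 29501]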

lightllm/server/embed_cache/impl/memory_cache_with_redis.py

Lines changed: 7 additions & 0 deletions
@@ -54,6 +54,13 @@ def set_items_embed(self, ids: list[int]) -> None:
             rec = self._records.get(id)
             if rec is not None:
                 rec.embed = True
+                # Before the embed becomes ready, concurrent miss requests are only
+                # tracked by the local record refcount. Materialize the remaining
+                # pending readers into Redis so each later release has a matching
+                # remote ref to consume.
+                pending_remote_readers = max(rec.ref - 1, 0)
+                for _ in range(pending_remote_readers):
+                    self.redis_cache.query_and_incre(str(id))
                 if rec.ref > 0:
                     self._update_record_ref_by_id(id, -1)
                     # keep one redis reference; release it only after the real consumer has finished reading,
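The fix closes a gap between two refcounts: readers that miss while the embed is still being computed are only tracked in the local record, so once the embed flips to ready the remaining pending readers are mirrored into Redis, and every later release then finds a matching remote reference. A simplified, self-contained sketch of that two-level bookkeeping, with plain dicts standing in for the real record store and Redis client:

class TwoLevelRefcount:
    """Toy model of the local-record / remote-Redis refcount handshake."""

    def __init__(self):
        self.local_ref = {}   # id -> readers waiting for the embed locally
        self.remote_ref = {}  # id -> references visible on the remote (Redis) side

    def miss(self, id):
        # a reader arrives before the embed is ready: counted locally only
        self.local_ref[id] = self.local_ref.get(id, 0) + 1

    def set_embed_ready(self, id):
        ref = self.local_ref.get(id, 0)
        # mirror all but one pending reader into the remote side, so each
        # later release has a remote reference to consume
        for _ in range(max(ref - 1, 0)):
            self.remote_ref[id] = self.remote_ref.get(id, 0) + 1
        if ref > 0:
            self.local_ref[id] = ref - 1

    def release(self, id):
        if self.remote_ref.get(id, 0) > 0:
            self.remote_ref[id] -= 1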

lightllm/server/httpserver/manager.py

Lines changed: 23 additions & 5 deletions
@@ -139,6 +139,7 @@ async def _alloc_resource(self, items, uuids, token_nums, datas):
             raise Exception(str(records) + "and try to set --embed_cache_storage_size bigger")

         uid_list = []
+        unique_image_uids = []
         for item, rec in zip(items, records):
             item: Union[ImageItem, AudioItem] = item
             item.uuid = rec["id"]
@@ -147,11 +148,13 @@ async def _alloc_resource(self, items, uuids, token_nums, datas):
             item.start_index_in_embed_cache = rec["start_index_in_embed_cache"]

             uid_list.append(rec["id"])
+            if isinstance(item, ImageItem) and rec["id"] not in unique_image_uids:
+                unique_image_uids.append(rec["id"])

-        # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server
+        # # If enable the vit-llm disaggregation, no need to cache the data in the memory of the server
         if self.args.enable_remote_vit:
             # avoid the remote LRU entry being evicted
-            self.cache_client.root.get_items_embed(uid_list, False)
+            self.cache_client.root.get_items_embed(unique_image_uids, False)

         ready_flags = obtain(self.cache_client.root.get_items_data(uid_list))
         update_data_ids = []
@@ -251,6 +254,15 @@ async def loop_for_request(self):
                     sampling_params,
                     multimodal_params,
                 ) = await self.multinode_req_manager.recv_pyobj()
+
+                if prompt is None:
+
+                    async def image_embedding_wrapper(sampling_params, multimodal_params):
+                        await self.get_image_embeding(sampling_params, multimodal_params, None)
+
+                    asyncio.create_task(image_embedding_wrapper(sampling_params, multimodal_params))
+                    continue
+
                 results_generator = self.generate(prompt, sampling_params, multimodal_params, None)

                 async def generate_wrapper(results_generator):
@@ -450,7 +462,11 @@ async def get_image_embeding(
         visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time)

         await self.transfer_to_next_module_or_node(
-            None, sampling_params, original_multimodal_params, visual_req_status
+            None,
+            sampling_params,
+            original_multimodal_params,
+            visual_req_status,
+            only_visual=True,
         )
         await self._release_multimodal_resources(multimodal_params)

@@ -573,6 +589,7 @@ async def transfer_to_next_module_or_node(
         sampling_params: SamplingParams,
         original_multimodal_params: MultimodalParams,
         group_req_objs: Optional[GroupReqObjs] = None,
+        only_visual: bool = False,
     ):
         # In multi-node pure tp mode, the master node needs to forward the request to the slave nodes.
         if self.is_multinode_tp_master:
@@ -582,19 +599,20 @@ async def transfer_to_next_module_or_node(
                 protocol=pickle.HIGHEST_PROTOCOL,
             )

-        await self.transfer_to_next_module(group_req_objs)
+        await self.transfer_to_next_module(group_req_objs, only_visual=only_visual)
         return

     async def transfer_to_next_module(
         self,
         group_req_objs: Optional[GroupReqObjs] = None,
+        only_visual: bool = False,
     ):

         if self.pd_mode.is_P_or_NORMAL():
             group_req_index = group_req_objs.to_group_req_index()
             if not self.args.disable_vision:
                 await self.vit_manager.send_to_vit(group_req_index, protocol=pickle.HIGHEST_PROTOCOL)
-                if not self.args.enable_remote_vit:
+                if only_visual or not self.args.enable_remote_vit:
                     return

             if not self.args.disable_audio:
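Two parts of this change are easy to miss: a forwarded request arriving with prompt set to None is now treated as a pure image-embedding request and handed to a background task instead of the generate path, and only_visual=True makes transfer_to_next_module stop right after the request has been sent to the ViT workers. A hedged sketch of the dispatch half of that, written independently of the real HttpServerManager (get_image_embedding and generate are stand-in async callables):

import asyncio

async def route_forwarded_request(prompt, sampling_params, multimodal_params, *,
                                  get_image_embedding, generate):
    # Prompt-less requests only need vision embeddings: fire-and-forget, no token generation.
    if prompt is None:
        asyncio.create_task(get_image_embedding(sampling_params, multimodal_params))
        return None
    # Normal requests continue down the usual generation path.
    return await generate(prompt, sampling_params, multimodal_params)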
