@@ -139,6 +139,7 @@ async def _alloc_resource(self, items, uuids, token_nums, datas):
139139 raise Exception (str (records ) + "and try to set --embed_cache_storage_size bigger" )
140140
141141 uid_list = []
142+ unique_image_uids = []
142143 for item , rec in zip (items , records ):
143144 item : Union [ImageItem , AudioItem ] = item
144145 item .uuid = rec ["id" ]
@@ -147,11 +148,13 @@ async def _alloc_resource(self, items, uuids, token_nums, datas):
147148 item .start_index_in_embed_cache = rec ["start_index_in_embed_cache" ]
148149
149150 uid_list .append (rec ["id" ])
151+ if isinstance (item , ImageItem ) and rec ["id" ] not in unique_image_uids :
152+ unique_image_uids .append (rec ["id" ])
150153
151- # # If enable the vit/audio -llm disaggregation, no need to cache the data in the memory of the server
154+ # # If vit-llm disaggregation is enabled, there is no need to cache the data in the server's memory
152155 if self .args .enable_remote_vit :
153156 # 避免远端lru被逐出
154- self .cache_client .root .get_items_embed (uid_list , False )
157+ self .cache_client .root .get_items_embed (unique_image_uids , False )
155158
156159 ready_flags = obtain (self .cache_client .root .get_items_data (uid_list ))
157160 update_data_ids = []
@@ -251,6 +254,15 @@ async def loop_for_request(self):
251254 sampling_params ,
252255 multimodal_params ,
253256 ) = await self .multinode_req_manager .recv_pyobj ()
257+
258+ if prompt is None :
259+
260+ async def image_embedding_wrapper (sampling_params , multimodal_params ):
261+ await self .get_image_embeding (sampling_params , multimodal_params , None )
262+
263+ asyncio .create_task (image_embedding_wrapper (sampling_params , multimodal_params ))
264+ continue
265+
254266 results_generator = self .generate (prompt , sampling_params , multimodal_params , None )
255267
256268 async def generate_wrapper (results_generator ):
@@ -450,7 +462,11 @@ async def get_image_embeding(
450462 visual_req_status = GroupReqObjs (group_request_id , multimodal_params , None , start_time )
451463
452464 await self .transfer_to_next_module_or_node (
453- None , sampling_params , original_multimodal_params , visual_req_status
465+ None ,
466+ sampling_params ,
467+ original_multimodal_params ,
468+ visual_req_status ,
469+ only_visual = True ,
454470 )
455471 await self ._release_multimodal_resources (multimodal_params )
456472
@@ -573,6 +589,7 @@ async def transfer_to_next_module_or_node(
573589 sampling_params : SamplingParams ,
574590 original_multimodal_params : MultimodalParams ,
575591 group_req_objs : Optional [GroupReqObjs ] = None ,
592+ only_visual : bool = False ,
576593 ):
577594 # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点.
578595 if self .is_multinode_tp_master :
@@ -582,19 +599,20 @@ async def transfer_to_next_module_or_node(
582599 protocol = pickle .HIGHEST_PROTOCOL ,
583600 )
584601
585- await self .transfer_to_next_module (group_req_objs )
602+ await self .transfer_to_next_module (group_req_objs , only_visual = only_visual )
586603 return
587604
588605 async def transfer_to_next_module (
589606 self ,
590607 group_req_objs : Optional [GroupReqObjs ] = None ,
608+ only_visual : bool = False ,
591609 ):
592610
593611 if self .pd_mode .is_P_or_NORMAL ():
594612 group_req_index = group_req_objs .to_group_req_index ()
595613 if not self .args .disable_vision :
596614 await self .vit_manager .send_to_vit (group_req_index , protocol = pickle .HIGHEST_PROTOCOL )
597- if not self .args .enable_remote_vit :
615+ if only_visual or not self .args .enable_remote_vit :
598616 return
599617
600618 if not self .args .disable_audio :
0 commit comments