1+ import torch
2+ from torch import nn
3+ from transformers import SiglipVisionConfig , SiglipImageProcessor , SiglipVisionModel
4+
# Hugging Face SigLIP checkpoints supported by this tower:
# base models, patch size 16, at four square resolutions,
# for both the original SigLIP and SigLIP-2 families.
valid_model_name_list = [
    f'google/{family}-base-patch16-{resolution}'
    for family in ('siglip', 'siglip2')
    for resolution in (224, 256, 384, 512)
]
16+
class SigLipVisionTower(nn.Module):
    """Frozen SigLIP vision encoder exposing patch features from an intermediate layer.

    Wraps a Hugging Face ``SiglipVisionModel``: downloads the processor and
    weights on demand, drops the pooling head, freezes all parameters, and in
    ``forward`` returns the hidden states of ``select_layer``.
    """

    def __init__(self, vision_model_name, weights_dir):
        """
        Args:
            vision_model_name: Hugging Face model id (see ``valid_model_name_list``).
            weights_dir: cache directory for downloaded configs/weights.
        """
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_model_name
        self.weights_dir = weights_dir
        # Penultimate encoder layer is used as the feature map.
        self.select_layer = -2
        # Bug fix: the original called SiglipVisionConfig(self.vision_tower_name, ...),
        # which binds the model-name string to the config's first positional
        # parameter (hidden_size) instead of loading the checkpoint's config.
        # from_pretrained fetches the actual config for the named model.
        self.cfg_only = SiglipVisionConfig.from_pretrained(
            self.vision_tower_name, cache_dir=self.weights_dir
        )

    def load_model(self, device_map=None):
        """Download and prepare the image processor and frozen vision model.

        Idempotent: a second call is a no-op once ``is_loaded`` is set.

        Args:
            device_map: forwarded to ``SiglipVisionModel.from_pretrained``
                (e.g. a device string such as ``'cuda:0'``).
        """
        if self.is_loaded:
            return

        self.image_processor = SiglipImageProcessor.from_pretrained(
            self.vision_tower_name, cache_dir=self.weights_dir
        )
        self.vision_tower = SiglipVisionModel.from_pretrained(
            self.vision_tower_name, cache_dir=self.weights_dir, device_map=device_map
        )

        # The pooled head is not needed; only patch-level hidden states are used.
        self.vision_tower.vision_model.head = nn.Identity()
        self.vision_tower.requires_grad_(False)
        self.eval()

        self.is_loaded = True

    def forward(self, images):
        """Return the ``select_layer`` hidden states for ``images``.

        Args:
            images: pixel tensor accepted by the SigLIP model; moved to the
                model's device/dtype before the forward pass.

        Returns:
            Feature tensor cast back to ``images.dtype``.
        """
        image_forward_out = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype),
            output_hidden_states=True,
        )
        image_feature = image_forward_out.hidden_states[self.select_layer].to(images.dtype)
        return image_feature

    @property
    def dtype(self):
        """Dtype of the loaded vision tower's parameters."""
        # next() is the idiomatic form of "first parameter"; the original
        # returned the first iteration of a for-loop.
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the loaded vision tower's parameters."""
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        """Live model config once loaded, else the preloaded checkpoint config."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Encoder hidden dimension (feature channels per patch)."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Total number of image patches (``num_patches_per_side`` squared)."""
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        """Number of patches along one spatial side of the input image."""
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        """Square input resolution expected by the checkpoint."""
        return self.config.image_size
78+
79+
if __name__ == "__main__":
    import os
    import json
    import argparse
    from tqdm import tqdm
    from PIL import Image

    current_dir = os.path.dirname(os.path.abspath(__file__))

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--vision_model_name', help='')
    parser.add_argument('--coco_images_path', help='')
    parser.add_argument('--split', help='')
    parser.add_argument('--max_count', type=int, help='')
    parser.add_argument('--image_types', type=str, nargs='+', default=['png', 'jpg', 'jpeg'], help='')
    args = parser.parse_args()

    if args.vision_model_name not in valid_model_name_list:
        raise Exception(f'vision_model_name should be in {valid_model_name_list}')
    if not os.path.isdir(args.coco_images_path):
        raise Exception(f'coco_images_path should be a dir with images')
    if args.split not in ['train', 'val']:
        raise Exception(f'split should be in ["train", "val"]')

    #################################### paths and names ####################################
    device = 'cuda:0'
    vision_model_name = args.vision_model_name
    vision_model_name_for_path = '-'.join(vision_model_name.split('/'))
    weights_dir = os.path.join(current_dir, '..', 'feature_extractor_weights')
    datasets_dir = os.path.join(current_dir, '..', 'generated_datasets')
    os.makedirs(weights_dir, mode=0o777, exist_ok=True)
    os.makedirs(datasets_dir, mode=0o777, exist_ok=True)

    batch_size = 2
    mode = args.split
    max_images = args.max_count
    images_dir = args.coco_images_path
    features_dir = f'{datasets_dir}/{vision_model_name_for_path}/tensors_{mode}'
    features_json = f'{datasets_dir}/{vision_model_name_for_path}/map_{mode}.json'
    # isinstance (not type ==) is the correct check; lowercase and strip any
    # leading dot so '--image_types PNG .jpg' matches files as expected.
    raw_types = [args.image_types] if isinstance(args.image_types, str) else args.image_types
    image_types = {t.lower().lstrip('.') for t in raw_types}
    # sorted() makes the [:max_images] subset deterministic — os.listdir order
    # is arbitrary. splitext handles filenames that contain extra dots.
    image_names = sorted(
        n for n in os.listdir(images_dir)
        if os.path.splitext(n)[1][1:].lower() in image_types
    )[:max_images]

    os.makedirs(features_dir, mode=0o777, exist_ok=True)
    print('----------> A directory for the dataset has been created. <----------')

    #################################### dataset generation ####################################
    vision_tower = SigLipVisionTower(vision_model_name, weights_dir)
    vision_tower.load_model(device_map=device)
    print('----------> The model has been downloaded. <----------')

    image_feature_map = {}
    # inference_mode() already disables grad tracking; stacking no_grad() on
    # top (as the original did) is redundant.
    with torch.inference_mode():
        for i in tqdm(range(0, len(image_names), batch_size)):
            batch_image_names = image_names[i:i + batch_size]
            batch_processed_images = []
            batch_image_paths = []
            batch_feature_paths = []

            for image_name in batch_image_names:
                # splitext keeps everything before the final extension, unlike
                # the original split('.')[0], which truncated dotted names.
                feature_name = os.path.splitext(image_name)[0]
                feature_path = os.path.join(features_dir, f'{feature_name}.pt')
                image_path = os.path.join(images_dir, image_name)

                try:
                    # Context manager closes the underlying file handle; the
                    # original leaked one handle per image.
                    with Image.open(image_path) as img:
                        example = img.convert('RGB')
                    processed_image = vision_tower.image_processor(example, return_tensors='pt')['pixel_values'][0]
                    batch_processed_images.append(processed_image)
                    batch_image_paths.append(image_path)
                    batch_feature_paths.append(feature_path)
                except Exception as e:
                    # Best-effort: skip unreadable images, keep the run going.
                    print(f"Error processing image {image_path}: {e}")
                    continue

            if not batch_processed_images:
                continue

            images_batch = torch.stack(batch_processed_images).to(device)

            batch_features: torch.Tensor = vision_tower.forward(images_batch)
            batch_features = batch_features.to(torch.bfloat16)
            assert batch_features.dtype == torch.bfloat16
            assert batch_features.shape[0] == len(batch_image_paths)

            for idx in range(batch_features.shape[0]):
                image_path = batch_image_paths[idx]
                feature_path = batch_feature_paths[idx]
                image_feature_map[image_path] = feature_path

                features = batch_features[idx]

                # (num_patches, hidden) -> (side, side, hidden) grid layout.
                features_reshaped = features.unflatten(
                    0, [vision_tower.num_patches_per_side, vision_tower.num_patches_per_side]
                )
                # clone() detaches the slice from the batch tensor so torch.save
                # does not serialize the whole batch's storage.
                features_reshaped = features_reshaped.clone()
                torch.save(features_reshaped, feature_path)

    # Persist the image-path -> feature-path mapping for downstream loaders.
    with open(features_json, 'w') as config:
        json.dump(image_feature_map, config)