Commit 0ed0cb6

Merge pull request #1 from mrsndmn/main
Initial commit
2 parents 01fc9d5 + aa1598c commit 0ed0cb6

21 files changed (+5481, −0 lines)

.gitignore

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
generated_datasets/
feature_extractor_weights/
__pycache__/
coco_subsets/train2017/
*.zip
coco_subsets/
metrics_calculation/precalculated_weights/
metrics_calculation/rb_swap/.ipynb_checkpoints/

README.md

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
## Installation
To set up the environment, navigate to the root directory containing `environment.yml` and run:

```bash
conda env create --name interpretation_env --file environment.yml
conda activate interpretation_env
```

## Dataset Generation
Given a feature extractor *E* and an image *i*, we can obtain its features as *f = E(i)*. The reconstruction model is trained on pairs *(i, f)*. To generate such dataset pairs:

**Generate Validation Split**

```bash
# Prepare data
mkdir coco_subsets
cd coco_subsets
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/train2017.zip

unzip train2017.zip  # Extract archive
unzip val2017.zip    # Extract archive
cd -  # Return to the project root
```

**Generate Validation Split Features**

Generated dataset size: ~300 MB

```bash
# Run validation dataset generation
VISION_MODEL="google/siglip2-base-patch16-512"
python dataset_generation/generation.py \
    --vision_model_name "$VISION_MODEL" \
    --coco_images_path "./coco_subsets/val2017" \
    --split val \
    --max_count 1000
```

**Generate Train Split Features**

Generated dataset size: ~30 GB

You can limit the number of processed images with the `--max_count 1000` parameter.

```bash
VISION_MODEL="google/siglip2-base-patch16-512"

python dataset_generation/generation.py \
    --vision_model_name "$VISION_MODEL" \
    --coco_images_path "./coco_subsets/train2017" \
    --split train
```

This script will:
1. Create a `feature_extractor_weights` directory for storing pretrained weights
2. Generate datasets in the `generated_datasets` directory
3. Use images from `coco_subsets/val2017` by default (configurable via script flags)
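
For reference, the on-disk layout can be sketched as a small helper that mirrors the path logic in `dataset_generation/generation.py` (the function itself is illustrative and not part of the repository):

```python
import os

def feature_paths(vision_model_name: str, image_name: str, split: str,
                  datasets_dir: str = "generated_datasets"):
    """Mirror the path scheme used by dataset_generation/generation.py:
    one .pt tensor per image plus a JSON map from image path to tensor path."""
    model_dir = "-".join(vision_model_name.split("/"))  # 'google/x' -> 'google-x'
    stem = image_name.split(".")[0]                     # drop the file extension
    tensor_path = os.path.join(datasets_dir, model_dir, f"tensors_{split}", f"{stem}.pt")
    map_path = os.path.join(datasets_dir, model_dir, f"map_{split}.json")
    return tensor_path, map_path

tensor_path, map_path = feature_paths(
    "google/siglip2-base-patch16-512", "000000000139.jpg", "val")
```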

## Reconstruction Model Training
Run reconstructor training:

Training can take from 6 to 24 hours, depending on the model's supported image resolution.
```bash
python training/train.py --vision_model_name "$VISION_MODEL"
```

This will:
- Train a reconstructor for `google/siglip2-base-patch16-512` by default
- Use the generated dataset from the previous step
- Create `training/samples` for training logs
- Save weights in `training/checkpoint`

### Supported Feature Extractors
- `google/siglip-base-patch16-{224,256,384,512}`
- `google/siglip2-base-patch16-{224,256,384,512}`

Modify the script arguments to use different extractors.

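Each extractor's feature grid follows from its name: a `patch16-512` model sees a 512×512 image as a 32×32 grid of 16×16 patches. A small sketch (a hypothetical helper, mirroring the `num_patches` property in `dataset_generation/generation.py`):

```python
def patch_grid(model_name: str):
    """Derive the patch grid from a name like 'google/siglip2-base-patch16-512',
    i.e. the trailing '...-patch{P}-{image_size}' suffix."""
    tail = model_name.rsplit("-", 2)               # [..., 'patch16', '512']
    patch_size = int(tail[-2].removeprefix("patch"))
    image_size = int(tail[-1])
    per_side = image_size // patch_size            # patches per side
    return per_side, per_side * per_side           # (side, total patches)

per_side, num_patches = patch_grid("google/siglip2-base-patch16-512")
```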
## CLIP Similarity Calculation
To compute CLIP similarity metrics:

1. Generate the dataset for your target feature extractor
2. Train a reconstructor or use precomputed [weights](https://drive.google.com/file/d/1i-B-5yBpSwcZL3_Z2Dz53jfxiY9T-fkb/view?usp=drive_link)
3. Place the weights in `metrics_calculation/precalculated_weights/` following the pattern:
   - `models--google--siglip-base-patch16-512.pt`
   - `models--google--siglip2-base-patch16-512.pt`
4. Run:
```bash
bash metrics_calculation/siglip_vs_siglip2/calculate_similarity.sh
```

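The weight filenames above follow the Hugging Face cache naming convention (`models--{org}--{model}.pt`). A one-line helper (an illustrative sketch, not part of the repository) that derives the expected filename from a model name:

```python
def weights_filename(model_name: str) -> str:
    """Map 'google/siglip2-base-patch16-512' to the expected checkpoint
    filename 'models--google--siglip2-base-patch16-512.pt'."""
    return "models--" + model_name.replace("/", "--") + ".pt"
```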
For the SigLIP vs SigLIP2 comparison:
1. Compute metrics for all 8 models
2. Run the analysis notebook:
```
metrics_calculation/siglip_vs_siglip2/understanding_graphs_for_article.ipynb
```

Example output:
<div align="center">
<figure>
<img src="resources/v1_vs_v2.png" width="600">
<figcaption>SigLIP vs SigLIP2 Feature Space Comparison</figcaption>
</figure>
</div>

## Orthogonal Transformation Learning
To study orthogonal transformations in feature space:

1. Generate the dataset for `google/siglip2-base-patch16-512`
2. Train a reconstructor or use precomputed [weights](https://drive.google.com/file/d/1i-B-5yBpSwcZL3_Z2Dz53jfxiY9T-fkb/view?usp=drive_link)
3. Place the weights at:
```
metrics_calculation/precalculated_weights/models--google--siglip2-base-patch16-512.pt
```
4. Run the analysis notebook:
```
metrics_calculation/rb_swap/understanding_rgb-to-bgr_rotation.ipynb
```
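
The notebook's exact method is not shown in this diff; as a hedged illustration, the standard way to fit an orthogonal map between paired feature sets `A` and `B` is the orthogonal Procrustes solution `W = U @ Vt` from the SVD of `A.T @ B` (a NumPy sketch with hypothetical variable names and synthetic data):

```python
import numpy as np

def fit_orthogonal_map(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Orthogonal Procrustes: find orthogonal W minimizing ||A @ W - B||_F.
    Closed form: W = U @ Vt, where U, _, Vt = svd(A.T @ B)."""
    U, _, Vt = np.linalg.svd(A.T @ B)
    return U @ Vt

# Tiny synthetic check: recover a known orthogonal transform from paired features.
rng = np.random.default_rng(0)
A = rng.standard_normal((100, 8))                  # features of original images
Q, _ = np.linalg.qr(rng.standard_normal((8, 8)))   # ground-truth orthogonal map
B = A @ Q                                          # features of transformed images
W = fit_orthogonal_map(A, B)
```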

Example output:
<div align="center">
<figure>
<img src="resources/rb_swap.png" width="600">
<figcaption>RGB Channel Swap in Feature Space</figcaption>
</figure>
</div>

dataset_generation/check_generation_result.ipynb

Lines changed: 149 additions & 0 deletions
Large diffs are not rendered by default.

dataset_generation/generation.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
1+
import torch
2+
from torch import nn
3+
from transformers import SiglipVisionConfig, SiglipImageProcessor, SiglipVisionModel
4+
5+
valid_model_name_list = [
6+
'google/siglip-base-patch16-224',
7+
'google/siglip-base-patch16-256',
8+
'google/siglip-base-patch16-384',
9+
'google/siglip-base-patch16-512',
10+
11+
'google/siglip2-base-patch16-224',
12+
'google/siglip2-base-patch16-256',
13+
'google/siglip2-base-patch16-384',
14+
'google/siglip2-base-patch16-512'
15+
]
16+
17+
class SigLipVisionTower(nn.Module):
18+
def __init__(self, vision_model_name, weights_dir):
19+
super().__init__()
20+
21+
self.is_loaded = False
22+
23+
self.vision_tower_name = vision_model_name
24+
self.weights_dir = weights_dir
25+
self.select_layer = -2
26+
self.cfg_only = SiglipVisionConfig(self.vision_tower_name, cache_dir=self.weights_dir)
27+
28+
def load_model(self, device_map=None):
29+
if self.is_loaded:
30+
return
31+
32+
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name, cache_dir=self.weights_dir)
33+
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, cache_dir=self.weights_dir, device_map=device_map)
34+
35+
self.vision_tower.vision_model.head = nn.Identity()
36+
self.vision_tower.requires_grad_(False)
37+
self.eval()
38+
39+
self.is_loaded = True
40+
41+
def forward(self, images):
42+
image_forward_out = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
43+
image_feature = image_forward_out.hidden_states[self.select_layer].to(images.dtype)
44+
return image_feature
45+
46+
@property
47+
def dtype(self):
48+
for p in self.vision_tower.parameters():
49+
return p.dtype
50+
51+
@property
52+
def device(self):
53+
for p in self.vision_tower.parameters():
54+
return p.device
55+
56+
@property
57+
def config(self):
58+
if self.is_loaded:
59+
return self.vision_tower.config
60+
else:
61+
return self.cfg_only
62+
63+
@property
64+
def hidden_size(self):
65+
return self.config.hidden_size
66+
67+
@property
68+
def num_patches(self):
69+
return (self.config.image_size // self.config.patch_size) ** 2
70+
71+
@property
72+
def num_patches_per_side(self):
73+
return self.config.image_size // self.config.patch_size
74+
75+
@property
76+
def image_size(self):
77+
return self.config.image_size
78+
79+
80+
if __name__ == "__main__":
    import os
    import json
    import argparse
    from tqdm import tqdm
    from PIL import Image

    current_dir = os.path.dirname(os.path.abspath(__file__))

    parser = argparse.ArgumentParser(description='Generate (image, feature) dataset pairs from a COCO subset.')
    parser.add_argument('--vision_model_name', help='one of the supported SigLIP/SigLIP2 checkpoints')
    parser.add_argument('--coco_images_path', help='directory containing the input images')
    parser.add_argument('--split', help='dataset split: "train" or "val"')
    parser.add_argument('--max_count', type=int, help='maximum number of images to process')
    parser.add_argument('--image_types', type=str, nargs='+', default=['png', 'jpg', 'jpeg'], help='accepted image file extensions')
    args = parser.parse_args()

    if args.vision_model_name not in valid_model_name_list:
        raise ValueError(f'vision_model_name should be in {valid_model_name_list}')
    if not os.path.isdir(args.coco_images_path):
        raise ValueError('coco_images_path should be a directory with images')
    if args.split not in ['train', 'val']:
        raise ValueError('split should be in ["train", "val"]')

    #################################### paths and names ####################################
    device = 'cuda:0'
    vision_model_name = args.vision_model_name
    vision_model_name_for_path = '-'.join(vision_model_name.split('/'))
    weights_dir = os.path.join(current_dir, '..', 'feature_extractor_weights')
    datasets_dir = os.path.join(current_dir, '..', 'generated_datasets')
    os.makedirs(weights_dir, mode=0o777, exist_ok=True)
    os.makedirs(datasets_dir, mode=0o777, exist_ok=True)

    batch_size = 2
    mode = args.split
    max_images = args.max_count
    images_dir = args.coco_images_path
    features_dir = f'{datasets_dir}/{vision_model_name_for_path}/tensors_{mode}'
    features_json = f'{datasets_dir}/{vision_model_name_for_path}/map_{mode}.json'
    image_types = [args.image_types] if isinstance(args.image_types, str) else args.image_types
    image_names = [
        n for n in os.listdir(images_dir)
        if n.split('.')[-1].lower() in image_types
    ][:max_images]

    os.makedirs(features_dir, mode=0o777, exist_ok=True)
    print('----------> A directory for the dataset has been created. <----------')

    #################################### dataset generation ####################################
    vision_tower = SigLipVisionTower(vision_model_name, weights_dir)
    vision_tower.load_model(device_map=device)
    print('----------> The model has been downloaded. <----------')

    image_feature_map = {}
    with torch.inference_mode():
        for i in tqdm(range(0, len(image_names), batch_size)):
            batch_image_names = image_names[i:i + batch_size]
            batch_processed_images = []
            batch_image_paths = []
            batch_feature_paths = []

            for image_name in batch_image_names:
                feature_name = image_name.split('.')[0]
                feature_path = os.path.join(features_dir, f'{feature_name}.pt')
                image_path = os.path.join(images_dir, image_name)

                try:
                    example = Image.open(image_path).convert('RGB')
                    processed_image = vision_tower.image_processor(example, return_tensors='pt')['pixel_values'][0]
                    batch_processed_images.append(processed_image)
                    batch_image_paths.append(image_path)
                    batch_feature_paths.append(feature_path)
                except Exception as e:
                    print(f"Error processing image {image_path}: {e}")
                    continue

            if not batch_processed_images:
                continue

            images_batch = torch.stack(batch_processed_images).to(device)

            batch_features: torch.Tensor = vision_tower.forward(images_batch)
            batch_features = batch_features.to(torch.bfloat16)
            assert batch_features.dtype == torch.bfloat16
            assert batch_features.shape[0] == len(batch_image_paths)

            for idx in range(batch_features.shape[0]):
                image_path = batch_image_paths[idx]
                feature_path = batch_feature_paths[idx]
                image_feature_map[image_path] = feature_path

                features = batch_features[idx]

                # Reshape the flat patch sequence into a 2D patch grid before saving.
                features_reshaped = features.unflatten(0, [vision_tower.num_patches_per_side, vision_tower.num_patches_per_side])
                features_reshaped = features_reshaped.clone()
                torch.save(features_reshaped, feature_path)

    with open(features_json, 'w') as config:
        json.dump(image_feature_map, config)

dataset_generation/generation.sh

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
#!/bin/bash

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"

models=(
    "google/siglip-base-patch16-224"
    "google/siglip-base-patch16-256"
    "google/siglip-base-patch16-384"
    "google/siglip-base-patch16-512"
    "google/siglip2-base-patch16-224"
    "google/siglip2-base-patch16-256"
    "google/siglip2-base-patch16-384"
    "google/siglip2-base-patch16-512"
)

coco_images_path="$SCRIPT_DIR/../coco_subsets/val2017"
split="val"
max_count=10

for model in "${models[@]}"; do
    echo "========================================"
    echo "Processing model: $model"
    echo "========================================"

    python "$SCRIPT_DIR/generation.py" \
        --vision_model_name "$model" \
        --coco_images_path "$coco_images_path" \
        --split "$split" \
        --max_count "$max_count"

    echo "Finished processing model: $model"
    echo
done

echo "All models processed successfully!"
