diff --git a/clip_benchmark/models/__init__.py b/clip_benchmark/models/__init__.py
index 96a6129..8d82ee1 100644
--- a/clip_benchmark/models/__init__.py
+++ b/clip_benchmark/models/__init__.py
@@ -2,11 +2,16 @@ import torch

 from .open_clip import load_open_clip
 from .japanese_clip import load_japanese_clip
+from .synthclip import load_synthclip
+from .scaling import load_model

 # loading function must return (model, transform, tokenizer)
 TYPE2FUNC = {
     "open_clip": load_open_clip,
-    "ja_clip": load_japanese_clip
+    "ja_clip": load_japanese_clip,
+    "synthclip": load_synthclip,
+    "scaling": load_model,
+    "auto": None,
 }
 MODEL_TYPES = list(TYPE2FUNC.keys())

@@ -19,5 +24,18 @@ def load_clip(
         device: Union[str, torch.device] = "cuda"
 ):
     assert model_type in MODEL_TYPES, f"model_type={model_type} is invalid!"
-    load_func = TYPE2FUNC[model_type]
+    if model_type != "auto":
+        load_func = TYPE2FUNC[model_type]
+    else:
+        # It's a hack, but it works! Got a better way? Push a PR 😃. EOM - Victor
+        if "synthclip" in pretrained:
+            load_func = TYPE2FUNC["synthclip"]
+        elif "scaling" in pretrained:
+            load_func = TYPE2FUNC["scaling"]
+        elif pretrained in TYPE2FUNC:
+            load_func = TYPE2FUNC[pretrained]
+        else:
+            print(f"{model_type} with {pretrained=} is unsupported, defaulting to "
+                  "open_clip. The Lord be with you 🙏")
+            load_func = TYPE2FUNC["open_clip"]
     return load_func(model_name=model_name, pretrained=pretrained, cache_dir=cache_dir, device=device)
diff --git a/clip_benchmark/models/scaling.py b/clip_benchmark/models/scaling.py
new file mode 100644
index 0000000..bca5b5e
--- /dev/null
+++ b/clip_benchmark/models/scaling.py
@@ -0,0 +1,51 @@
+"""Loader for CLIP checkpoints from "Scaling Laws of Synthetic Images".
+
+References
+- Scaling Laws of Synthetic Images for Model Training ... for Now, Fan et al., CVPR 2024
+- https://github.com/google-research/syn-rep-learn/tree/main/Scaling#clip
+"""
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.utils.data
+
+
+from open_clip import create_model_and_transforms, get_tokenizer
+from torch.nn import functional as F
+
+
+def load_model(model_name: str = 'ViT-B-16', pretrained: str = None,
+               device: str = "cpu", cudnn_benchmark: bool = True, **kwargs):
+    if pretrained is None:
+        raise FileNotFoundError('Failing early: the scaling loader needs a checkpoint path (pretrained is None)!')
+
+    device = device if torch.cuda.is_available() else "cpu"
+    tokenizer = get_tokenizer(model_name)
+    model, preprocess_train, preprocess_val = create_model_and_transforms(
+        model_name,
+        '',
+        precision='amp',
+        device=device,
+        jit=False,
+        force_quick_gelu=True,
+        force_custom_text=False,
+        force_patch_dropout=None,
+        force_image_size=224,
+        pretrained_image=False,
+        image_mean=None,
+        image_std=None,
+        aug_cfg={},
+        output_dict=True,
+    )
+    state_dict = torch.load(pretrained, map_location=device)
+    logit_scale = np.exp(state_dict['logit_scale'].item())  # temperature is stored in log space; kept only for inspection
+    msg = model.load_state_dict(state_dict, strict=True)
+    print(msg)
+    model = model.to(device)
+    model.eval()
+    cudnn.benchmark = cudnn_benchmark
+    return model, preprocess_val, tokenizer
+
+
+if __name__ == '__main__':
+    load_model(pretrained='./logs/scaling_syn_real/371M.pt')
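For reference, the new `auto` route dispatches purely on substrings of `pretrained` ("synthclip" or "scaling"), so a checkpoint path has to carry one of those markers. A minimal usage sketch, assuming the placeholder checkpoint path from the `__main__` block of `scaling.py` exists locally:

    from clip_benchmark.models import load_clip

    # "scaling" appears in the pretrained path, so model_type="auto" routes to
    # scaling.load_model, which returns the usual (model, transform, tokenizer) triple.
    model, transform, tokenizer = load_clip(
        model_type="auto",
        model_name="ViT-B-16",
        pretrained="./logs/scaling_syn_real/371M.pt",  # local placeholder path
        cache_dir=None,
        device="cuda",
    )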
diff --git a/clip_benchmark/models/synthclip.py b/clip_benchmark/models/synthclip.py
new file mode 100644
index 0000000..eef8d7c
--- /dev/null
+++ b/clip_benchmark/models/synthclip.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from github.com/openai/CLIP
+from collections import OrderedDict
+from pathlib import Path
+
+import numpy as np
+import timm
+import torch
+import torchvision.transforms as transforms
+import open_clip
+from torch import nn
+
+# import losses
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict(
+                [
+                    ("c_fc", nn.Linear(d_model, d_model * 4)),
+                    ("gelu", QuickGELU()),
+                    ("c_proj", nn.Linear(d_model * 4, d_model)),
+                ]
+            )
+        )
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = (
+            self.attn_mask.to(dtype=x.dtype, device=x.device)
+            if self.attn_mask is not None
+            else None
+        )
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(
+            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class CLIP(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        # vision
+        vision_width: int,
+        vision_model: nn.Module,
+        # text
+        context_length: int,
+        vocab_size: int,
+        transformer_width: int,
+        transformer_heads: int,
+        transformer_layers: int,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.context_length = context_length
+        self.vision_width = vision_width
+
+        self.visual = vision_model
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask(),
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(
+            torch.empty(self.context_length, transformer_width)
+        )
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        proj_std = (self.transformer.width**-0.5) * (
+            (2 * self.transformer.layers) ** -0.5
+        )
+        attn_std = self.transformer.width**-0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        nn.init.normal_(self.image_projection, std=self.vision_width**-0.5)
+        nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def encode_image(self, image):
+        x = self.visual(image)
+        x = x @ self.image_projection
+
+        return x
+
+    def encode_text(self, text):
+        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+    def forward(self, image, text):
+        image_embed = self.encode_image(image)
+        text_embed = self.encode_text(text)
+
+        return {
+            "image_embed": image_embed,
+            "text_embed": text_embed,
+            "logit_scale": self.logit_scale.exp(),
+        }
+
+
+def get_loss(gather_with_grad=False):
+    # `losses` ships with the SynthCLIP training code and is only needed for
+    # training, so import it lazily instead of at module level.
+    import losses
+    return losses.CLIPLoss(gather_with_grad=gather_with_grad)
+
+
+def get_metric_names():
+    return ["loss", "clip_loss", "clip_acc"]
+
+
+def CLIP_VITB16(checkpoint_path: str = None, cache_dir: str = None, **kwargs):
+    vision_model = timm.create_model("vit_base_patch16_224", num_classes=0,
+                                     checkpoint_path=checkpoint_path,
+                                     cache_dir=cache_dir)
+    model = CLIP(
+        embed_dim=512,
+        vision_width=768,
+        vision_model=vision_model,
+        context_length=77,
+        vocab_size=49408,
+        transformer_width=512,
+        transformer_heads=8,
+        transformer_layers=12,
+        **kwargs,
+    )
+
+    return model
+
+
+def load_synthclip(model_name, pretrained, cache_dir, device):
+    if model_name != "ViT-B-16":
+        raise ValueError(f"Unsupported SynthCLIP model_name: {model_name!r} (only 'ViT-B-16' is available)")
+    model = CLIP_VITB16()
+    tokenizer = open_clip.get_tokenizer(model_name)
+
+    if pretrained:
+        pretrained = Path(pretrained)
+        pretrained = pretrained / "checkpoint_best.pt" if pretrained.is_dir() else pretrained
+        state_dict = torch.load(pretrained, map_location="cpu")["state_dict"]
+        new_state_dict = OrderedDict()
+        for k, v in state_dict.items():
+            name = k.replace("module.", "")  # strip the DistributedDataParallel prefix
+            new_state_dict[name] = v
+
+        load_status = model.load_state_dict(new_state_dict)
+        print(f'{__name__}:{load_synthclip.__name__}, {model_name=}, {pretrained=}, {device=}, {load_status=}')
+
+    model.to(device)
+    val_transform = transform_pipeline()
+    return model, val_transform, tokenizer
+
+
+def transform_pipeline():
+    # Taken from
+    # https://github.com/hammoudhasan/SynthCLIP/blob/02ef69764d8dc921650bcac4a98bd0f477790787/Training/main.py#L240
+    normalize = transforms.Normalize(
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+    )
+    transform = transforms.Compose(
+        [
+            transforms.Lambda(lambda img: img.convert('RGB')),
+            transforms.Resize((224, 224)),
+            transforms.ColorJitter(0.4, 0.4, 0.4),  # kept to mirror the referenced SynthCLIP training transform
+            transforms.ToTensor(),
+            normalize,
+        ]
+    )
+    return transform
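A similar hedged sketch for the SynthCLIP path: the checkpoint directory below is hypothetical (any `pretrained` value containing "synthclip" selects this loader under `model_type="auto"`, and `load_synthclip` appends `checkpoint_best.pt` when given a directory), and the image file name is only an example.

    import torch
    from PIL import Image

    from clip_benchmark.models import load_clip

    model, transform, tokenizer = load_clip(
        model_type="auto",
        model_name="ViT-B-16",  # the only architecture load_synthclip supports
        pretrained="./checkpoints/synthclip-vitb16/",  # hypothetical directory holding checkpoint_best.pt
        cache_dir=None,
        device="cpu",
    )
    model.eval()  # eval mode; the loader itself does not call this

    # Zero-shot scoring with the returned (model, transform, tokenizer) triple.
    image = transform(Image.open("example.jpg")).unsqueeze(0)  # example image file
    text = tokenizer(["a photo of a dog", "a photo of a cat"])
    with torch.no_grad():
        image_features = torch.nn.functional.normalize(model.encode_image(image), dim=-1)
        text_features = torch.nn.functional.normalize(model.encode_text(text), dim=-1)
        probs = (model.logit_scale.exp() * image_features @ text_features.T).softmax(dim=-1)
    print(probs)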