Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion diffsynth/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,54 @@
},
]

image_metrics_series = [
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors")
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors")
"model_hash": "b5e2c0bfcbf4085ccdb2feb8f0ba408a",
"model_name": "image_metrics_clip_hf",
"model_class": "diffsynth.models.clip.ImageMetricsCLIPModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsCLIPStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors")
"model_hash": "f79e72cec8ae5a540cff0304bfb21b00",
"model_name": "image_metrics_hpsv2",
"model_class": "diffsynth.models.clip.ImageMetricsCLIPModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsOpenCLIPStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors")
"model_hash": "5655d9cde15b759cfeefe7432d7a912c",
"model_name": "image_metrics_hpsv3",
"model_class": "diffsynth.models.hpsv3.HPSv3Qwen2VLRewardModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsHPSv3StateDictConverter",
"extra_kwargs": {"vocab_size": 151658, "output_dim": 2, "reward_token": "special", "rm_head_type": "ranknet"},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors")
"model_hash": "b3cc8e10b76ca98cde653daa5cf63139",
"model_name": "image_metrics_image_reward",
"model_class": "diffsynth.models.image_reward.ImageRewardModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsImageRewardStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors")
"model_hash": "306981222ec94302794e07cf676c84cc",
"model_name": "image_metrics_aesthetic",
"model_class": "diffsynth.models.aesthetic.AestheticModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsAestheticStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors")
"model_hash": "d4e9549be726259b444d1f62db4ce413",
"model_name": "image_metrics_fid_inception",
"model_class": "diffsynth.models.fid.FIDInceptionModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsFIDStateDictConverter",
},
]

MODEL_CONFIGS = (
stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series + image_metrics_series
)
22 changes: 22 additions & 0 deletions diffsynth/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from ..core import ModelConfig
from .aesthetic import AestheticMetric
from .base import Metric
from .clip import CLIPMetric
from .fid import FIDMetric
from .hpsv2 import HPSv2Metric
from .hpsv3 import HPSv3Metric
from .image_reward import ImageRewardMetric
from .pickscore import PickScoreMetric


__all__ = [
"Metric",
"ModelConfig",
"PickScoreMetric",
"ImageRewardMetric",
"HPSv2Metric",
"HPSv3Metric",
"CLIPMetric",
"AestheticMetric",
"FIDMetric",
]
42 changes: 42 additions & 0 deletions diffsynth/metrics/aesthetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.aesthetic import AestheticModel
from .base import Metric
from transformers import CLIPImageProcessor

class AestheticMetric(Metric):
def __init__(self, model: AestheticModel):
super().__init__()
self.model = model

@classmethod
def from_pretrained(
cls,
model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors"),
processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/"),
torch_dtype: torch.dtype = None,
device: torch.device = get_device_type(),
processor_kwargs: dict = None,
vram_limit: float = None,
):

processor_kwargs = processor_kwargs or {}
model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
model = model_pool.fetch_model("image_metrics_aesthetic")
processor_config.download_if_necessary()
model.processor = CLIPImageProcessor.from_pretrained(processor_config.path, **processor_kwargs)
model.layers = model.layers.float()
model = model.eval()
return cls(model)

@torch.no_grad()
def score(self, images):
scores = self.model(images)
return self.tensor_to_list(scores)

def compute(self, images):
return self.score(images)

def forward(self, images):
return self.score(images)
28 changes: 28 additions & 0 deletions diffsynth/metrics/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
from ..core import ModelConfig
from ..models.model_loader import ModelPool

class Metric(torch.nn.Module):

@staticmethod
def tensor_to_list(value):
if torch.is_tensor(value):
value = value.detach().cpu().tolist()
return value if isinstance(value, list) else [value]

@staticmethod
def download_and_load_models(model_configs: list[ModelConfig], torch_dtype: torch.dtype = torch.float32, device="cuda", vram_limit: float = None):
model_pool = ModelPool()
for model_config in model_configs:
model_config.download_if_necessary()
vram_config = model_config.vram_config()
vram_config["computation_dtype"] = vram_config["computation_dtype"] or torch_dtype or torch.float32
vram_config["computation_device"] = vram_config["computation_device"] or device
model_pool.auto_load_model(
model_config.path,
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
)
return model_pool
55 changes: 55 additions & 0 deletions diffsynth/metrics/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
from transformers import AutoProcessor
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.clip import CLIPModel
from .base import Metric

class CLIPMetric(Metric):
def __init__(self, model: CLIPModel):
super().__init__()
self.model = model

@classmethod
def from_pretrained(
cls,
model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors"),
processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/"),
torch_dtype: torch.dtype = None,
device: torch.device = get_device_type(),
max_length: int = 77,
processor_kwargs: dict = None,
vram_limit: float = None,
):

processor_kwargs = processor_kwargs or {}
model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
model = model_pool.fetch_model("image_metrics_clip_hf")
processor_config.download_if_necessary()
processor = AutoProcessor.from_pretrained(processor_config.path, **processor_kwargs)
model = CLIPModel(model=model, processor=processor, max_length=max_length).eval()
return cls(model)

@torch.no_grad()
def score(
self,
prompt: str | list[str],
images,
):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)

@torch.no_grad()
def similarity_matrix(
self,
prompt: str | list[str],
images,
):
scores = self.model.similarity_matrix(prompt, images)
return self.tensor_to_list(scores)

def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)

def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
37 changes: 37 additions & 0 deletions diffsynth/metrics/fid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.fid import FIDModel
from .base import Metric


class FIDMetric(Metric):
def __init__(self, model: FIDModel):
super().__init__()
self.model = model

@classmethod
def from_pretrained(
cls,
model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors"),
device: torch.device = get_device_type(),
batch_size: int = 16,
num_workers: int = 0,
vram_limit: float = None,
):
model_pool = cls.download_and_load_models([model_config], torch_dtype=torch.float32, device=device, vram_limit=vram_limit)
model = model_pool.fetch_model("image_metrics_fid_inception")
model = FIDModel(model=model, device=device, batch_size=batch_size, num_workers=num_workers)
return cls(model)

@torch.no_grad()
def compute(self, reference_images, generated_images, batch_size: int = None, num_workers: int = None):
score = self.model.compute(reference_images, generated_images, batch_size=batch_size, num_workers=num_workers)
return score.detach().cpu().item() if torch.is_tensor(score) else float(score)

def statistics(self, images, batch_size: int = None, num_workers: int = None):
return self.model.statistics(images, batch_size=batch_size, num_workers=num_workers)

def forward(self, reference_images, generated_images):
return self.compute(reference_images, generated_images)
41 changes: 41 additions & 0 deletions diffsynth/metrics/hpsv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from transformers import AutoProcessor
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.hpsv2 import HPSv2Model
from .base import Metric

class HPSv2Metric(Metric):
def __init__(self, model: HPSv2Model):
super().__init__()
self.model = model

@classmethod
def from_pretrained(
cls,
model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors"),
processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/"),
torch_dtype: torch.dtype = None,
device: torch.device = get_device_type(),
processor_kwargs: dict = None,
vram_limit: float = None,
):

processor_kwargs = processor_kwargs or {}
model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
model = model_pool.fetch_model("image_metrics_hpsv2")
processor_config.download_if_necessary()
processor = AutoProcessor.from_pretrained(processor_config.path, **processor_kwargs)
model = HPSv2Model(model=model, processor=processor).eval()
return cls(model)

@torch.no_grad()
def score(self, prompt: str | list[str], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)

def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)

def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
63 changes: 63 additions & 0 deletions diffsynth/metrics/hpsv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from transformers import AutoProcessor
import torch
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.hpsv3 import HPSv3Model
from .base import Metric


class HPSv3Metric(Metric):
def __init__(self, model: HPSv3Model):
super().__init__()
self.model = model

@classmethod
def from_pretrained(
cls,
model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors"),
processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/"),
torch_dtype: torch.dtype = torch.bfloat16,
device: torch.device = get_device_type(),
score_index: int = 0,
use_special_tokens: bool = True,
max_pixels: int = 256 * 28 * 28,
min_pixels: int = 256 * 28 * 28,
processor_kwargs: dict = None,
vram_limit: float = None,
):

processor_kwargs = processor_kwargs or {}
model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
model = model_pool.fetch_model("image_metrics_hpsv3")
processor_config.download_if_necessary()
processor = AutoProcessor.from_pretrained(processor_config.path, padding_side="right", **processor_kwargs)
if use_special_tokens:
special_tokens = ["<|Reward|>"]
processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
model.special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
model.reward_token = "special"
model.config.tokenizer_padding_side = processor.tokenizer.padding_side
model.config.pad_token_id = processor.tokenizer.pad_token_id
if hasattr(model.config, "text_config"):
model.config.text_config.pad_token_id = processor.tokenizer.pad_token_id
model.rm_head.to(torch.float32)
model = HPSv3Model(
model=model,
processor=processor,
use_special_tokens=use_special_tokens,
max_pixels=max_pixels,
min_pixels=min_pixels,
score_index=score_index,
).eval()
return cls(model)

@torch.no_grad()
def score(self, prompt: str | list[str], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)

def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)

def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
48 changes: 48 additions & 0 deletions diffsynth/metrics/image_reward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
from transformers import BertTokenizer
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.image_reward import ImageRewardModel
from .base import Metric

class ImageRewardMetric(Metric):
def __init__(self, model: ImageRewardModel):
super().__init__()
self.model = model

@classmethod
def from_pretrained(
cls,
model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors"),
tokenizer_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/"),
torch_dtype: torch.dtype = None,
device: torch.device = get_device_type(),
max_length: int = 35,
tokenizer_kwargs: dict = None,
vram_limit: float = None,
):

tokenizer_kwargs = tokenizer_kwargs or {}
model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
model = model_pool.fetch_model("image_metrics_image_reward")
tokenizer_config.download_if_necessary()
tokenizer = BertTokenizer.from_pretrained(tokenizer_config.path, **tokenizer_kwargs)
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
tokenizer.enc_token_id = tokenizer.convert_tokens_to_ids("[ENC]")
model.tokenizer = tokenizer
model.max_length = max_length
model.mlp = model.mlp.float()
model = model.eval()
return cls(model)

@torch.no_grad()
def score(self, prompt: str | list[str], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)

def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)

def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
Loading