diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index ad6784ffc..32665e712 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -1026,7 +1026,54 @@
},
]
+# Registry entries for the image-metric checkpoints hosted under "DiffSynth-Studio/ImageMetrics".
+# Each entry names the class to instantiate ("model_class") and the converter applied to the
+# checkpoint ("state_dict_converter"); "model_name" is the key the metric wrappers pass to
+# ModelPool.fetch_model. How "model_hash" is computed and matched is decided by the loader,
+# which is not visible here — TODO confirm.
+image_metrics_series = [
+ {
+ # This single entry serves both checkpoints listed in the two examples below (PickScore and CLIP-ViT-H-14).
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors")
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors")
+ "model_hash": "b5e2c0bfcbf4085ccdb2feb8f0ba408a",
+ "model_name": "image_metrics_clip_hf",
+ "model_class": "diffsynth.models.clip.ImageMetricsCLIPModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsCLIPStateDictConverter",
+ },
+ {
+ # Same model class as the entry above, but with the OpenCLIP state-dict converter.
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors")
+ "model_hash": "f79e72cec8ae5a540cff0304bfb21b00",
+ "model_name": "image_metrics_hpsv2",
+ "model_class": "diffsynth.models.clip.ImageMetricsCLIPModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsOpenCLIPStateDictConverter",
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors")
+ "model_hash": "5655d9cde15b759cfeefe7432d7a912c",
+ "model_name": "image_metrics_hpsv3",
+ "model_class": "diffsynth.models.hpsv3.HPSv3Qwen2VLRewardModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsHPSv3StateDictConverter",
+ # Presumably forwarded as constructor keyword arguments to the model class — confirm against the loader.
+ "extra_kwargs": {"vocab_size": 151658, "output_dim": 2, "reward_token": "special", "rm_head_type": "ranknet"},
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors")
+ "model_hash": "b3cc8e10b76ca98cde653daa5cf63139",
+ "model_name": "image_metrics_image_reward",
+ "model_class": "diffsynth.models.image_reward.ImageRewardModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsImageRewardStateDictConverter",
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors")
+ "model_hash": "306981222ec94302794e07cf676c84cc",
+ "model_name": "image_metrics_aesthetic",
+ "model_class": "diffsynth.models.aesthetic.AestheticModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsAestheticStateDictConverter",
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors")
+ "model_hash": "d4e9549be726259b444d1f62db4ce413",
+ "model_name": "image_metrics_fid_inception",
+ "model_class": "diffsynth.models.fid.FIDInceptionModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsFIDStateDictConverter",
+ },
+]
+
MODEL_CONFIGS = (
stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
- + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
+ + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series + image_metrics_series
)
diff --git a/diffsynth/metrics/__init__.py b/diffsynth/metrics/__init__.py
new file mode 100644
index 000000000..f555d3a7a
--- /dev/null
+++ b/diffsynth/metrics/__init__.py
@@ -0,0 +1,22 @@
+from ..core import ModelConfig
+from .aesthetic import AestheticMetric
+from .base import Metric
+from .clip import CLIPMetric
+from .fid import FIDMetric
+from .hpsv2 import HPSv2Metric
+from .hpsv3 import HPSv3Metric
+from .image_reward import ImageRewardMetric
+from .pickscore import PickScoreMetric
+
+
+# Public names of the metrics package: the Metric base class, every concrete metric wrapper,
+# and ModelConfig, which is re-exported from ..core.
+__all__ = [
+ "Metric",
+ "ModelConfig",
+ "PickScoreMetric",
+ "ImageRewardMetric",
+ "HPSv2Metric",
+ "HPSv3Metric",
+ "CLIPMetric",
+ "AestheticMetric",
+ "FIDMetric",
+]
diff --git a/diffsynth/metrics/aesthetic.py b/diffsynth/metrics/aesthetic.py
new file mode 100644
index 000000000..23d9b1957
--- /dev/null
+++ b/diffsynth/metrics/aesthetic.py
@@ -0,0 +1,42 @@
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.aesthetic import AestheticModel
+from .base import Metric
+from transformers import CLIPImageProcessor
+
+class AestheticMetric(Metric):
+    """Metric wrapper that scores images with an `AestheticModel`.
+
+    No prompt is involved: `score`, `compute` and `forward` all take images
+    only and return the same list of scores.
+    """
+
+    def __init__(self, model: AestheticModel):
+        super().__init__()
+        self.model = model
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors"),
+        processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/"),
+        torch_dtype: torch.dtype = None,
+        device: torch.device = get_device_type(),
+        processor_kwargs: dict = None,
+        vram_limit: float = None,
+    ):
+        """Load the aesthetic model and its image processor and wrap them in a metric.
+
+        Args:
+            model_config: Location of the model weights.
+            processor_config: Location of the `CLIPImageProcessor` files.
+            torch_dtype: Computation dtype handed to `download_and_load_models`.
+            device: Computation device handed to `download_and_load_models`.
+            processor_kwargs: Extra keyword arguments for `CLIPImageProcessor.from_pretrained`.
+            vram_limit: VRAM limit handed to the model loader.
+
+        Returns:
+            An `AestheticMetric` whose model is in eval mode.
+        """
+        pool = cls.download_and_load_models(
+            [model_config],
+            torch_dtype=torch_dtype,
+            device=device,
+            vram_limit=vram_limit,
+        )
+        predictor = pool.fetch_model("image_metrics_aesthetic")
+
+        processor_config.download_if_necessary()
+        extra = processor_kwargs if processor_kwargs else {}
+        predictor.processor = CLIPImageProcessor.from_pretrained(processor_config.path, **extra)
+
+        # The scoring head is always cast to float32, whatever dtype the rest of the model uses.
+        predictor.layers = predictor.layers.float()
+        return cls(predictor.eval())
+
+    @torch.no_grad()
+    def score(self, images):
+        """Return the model's score(s) for `images` as a Python list."""
+        return self.tensor_to_list(self.model(images))
+
+    def compute(self, images):
+        """Alias of `score`."""
+        return self.score(images)
+
+    def forward(self, images):
+        """Alias of `score`, so the metric can be called like a module."""
+        return self.score(images)
diff --git a/diffsynth/metrics/base.py b/diffsynth/metrics/base.py
new file mode 100644
index 000000000..784bf8eb1
--- /dev/null
+++ b/diffsynth/metrics/base.py
@@ -0,0 +1,28 @@
+import torch
+from ..core import ModelConfig
+from ..models.model_loader import ModelPool
+
+class Metric(torch.nn.Module):
+    """Base class for metric wrappers; provides result conversion and model loading helpers."""
+
+    @staticmethod
+    def tensor_to_list(value):
+        """Convert `value` to a Python list.
+
+        Tensors are detached, moved to CPU and converted with `tolist()`.
+        Anything that is still not a list afterwards (for example the float
+        produced by a 0-dim tensor) is wrapped in a one-element list.
+        """
+        if torch.is_tensor(value):
+            value = value.detach().cpu().tolist()
+        if isinstance(value, list):
+            return value
+        return [value]
+
+    @staticmethod
+    def download_and_load_models(model_configs: list[ModelConfig], torch_dtype: torch.dtype = torch.float32, device="cuda", vram_limit: float = None):
+        """Download every config if necessary and load it into a fresh `ModelPool`.
+
+        Args:
+            model_configs: Configs to download and load, in order.
+            torch_dtype: Fallback computation dtype when a config's VRAM settings leave it unset.
+            device: Fallback computation device when a config's VRAM settings leave it unset.
+            vram_limit: Passed through to `ModelPool.auto_load_model`.
+
+        Returns:
+            The populated `ModelPool`.
+        """
+        pool = ModelPool()
+        for config in model_configs:
+            config.download_if_necessary()
+            vram_config = config.vram_config()
+            # Per-config settings win; otherwise fall back to the arguments, then to float32.
+            computation_dtype = vram_config["computation_dtype"] or torch_dtype or torch.float32
+            vram_config["computation_dtype"] = computation_dtype
+            computation_device = vram_config["computation_device"] or device
+            vram_config["computation_device"] = computation_device
+            pool.auto_load_model(
+                config.path,
+                vram_config=vram_config,
+                vram_limit=vram_limit,
+                clear_parameters=config.clear_parameters,
+                state_dict=config.state_dict,
+            )
+        return pool
diff --git a/diffsynth/metrics/clip.py b/diffsynth/metrics/clip.py
new file mode 100644
index 000000000..f15ad95f0
--- /dev/null
+++ b/diffsynth/metrics/clip.py
@@ -0,0 +1,55 @@
+import torch
+from transformers import AutoProcessor
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.clip import CLIPModel
+from .base import Metric
+
+class CLIPMetric(Metric):
+    """Metric wrapper around `CLIPModel` for prompt/image similarity scores."""
+
+    def __init__(self, model: CLIPModel):
+        super().__init__()
+        self.model = model
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors"),
+        processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/"),
+        torch_dtype: torch.dtype = None,
+        device: torch.device = get_device_type(),
+        max_length: int = 77,
+        processor_kwargs: dict = None,
+        vram_limit: float = None,
+    ):
+        """Load the CLIP weights and processor and wrap them in a metric.
+
+        Args:
+            model_config: Location of the model weights.
+            processor_config: Location of the processor files.
+            torch_dtype: Computation dtype handed to `download_and_load_models`.
+            device: Computation device handed to `download_and_load_models`.
+            max_length: Maximum text length passed to `CLIPModel`.
+            processor_kwargs: Extra keyword arguments for `AutoProcessor.from_pretrained`.
+            vram_limit: VRAM limit handed to the model loader.
+
+        Returns:
+            A `CLIPMetric` whose model is in eval mode.
+        """
+        pool = cls.download_and_load_models(
+            [model_config],
+            torch_dtype=torch_dtype,
+            device=device,
+            vram_limit=vram_limit,
+        )
+        backbone = pool.fetch_model("image_metrics_clip_hf")
+
+        processor_config.download_if_necessary()
+        extra = processor_kwargs if processor_kwargs else {}
+        processor = AutoProcessor.from_pretrained(processor_config.path, **extra)
+
+        wrapped = CLIPModel(model=backbone, processor=processor, max_length=max_length)
+        return cls(wrapped.eval())
+
+    @torch.no_grad()
+    def score(
+        self,
+        prompt: str | list[str],
+        images,
+    ):
+        """Return the wrapped model's output for `(prompt, images)` as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    @torch.no_grad()
+    def similarity_matrix(
+        self,
+        prompt: str | list[str],
+        images,
+    ):
+        """Return `CLIPModel.similarity_matrix(prompt, images)` as a (nested) Python list."""
+        matrix = self.model.similarity_matrix(prompt, images)
+        return self.tensor_to_list(matrix)
+
+    def compute(self, prompt: str | list[str], images):
+        """Alias of `score`."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: str | list[str], images):
+        """Alias of `score`, so the metric can be called like a module."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/fid.py b/diffsynth/metrics/fid.py
new file mode 100644
index 000000000..1cc4429c1
--- /dev/null
+++ b/diffsynth/metrics/fid.py
@@ -0,0 +1,37 @@
+import torch
+
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.fid import FIDModel
+from .base import Metric
+
+
+class FIDMetric(Metric):
+    """Metric wrapper around `FIDModel` comparing reference images with generated images."""
+
+    def __init__(self, model: FIDModel):
+        super().__init__()
+        self.model = model
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors"),
+        device: torch.device = get_device_type(),
+        batch_size: int = 16,
+        num_workers: int = 0,
+        vram_limit: float = None,
+    ):
+        """Load the feature extractor (always requested in float32) and wrap it in a metric.
+
+        Args:
+            model_config: Location of the feature-extractor weights.
+            device: Device used both for loading and for the `FIDModel`.
+            batch_size: Default batch size stored on the `FIDModel`.
+            num_workers: Default DataLoader worker count stored on the `FIDModel`.
+            vram_limit: VRAM limit handed to the model loader.
+        """
+        pool = cls.download_and_load_models(
+            [model_config],
+            torch_dtype=torch.float32,
+            device=device,
+            vram_limit=vram_limit,
+        )
+        extractor = pool.fetch_model("image_metrics_fid_inception")
+        return cls(FIDModel(model=extractor, device=device, batch_size=batch_size, num_workers=num_workers))
+
+    @torch.no_grad()
+    def compute(self, reference_images, generated_images, batch_size: int = None, num_workers: int = None):
+        """Return the distance between the two image sets as a Python float."""
+        result = self.model.compute(reference_images, generated_images, batch_size=batch_size, num_workers=num_workers)
+        if torch.is_tensor(result):
+            return result.detach().cpu().item()
+        return float(result)
+
+    def statistics(self, images, batch_size: int = None, num_workers: int = None):
+        """Delegate to `FIDModel.statistics` and return its result unchanged."""
+        return self.model.statistics(images, batch_size=batch_size, num_workers=num_workers)
+
+    def forward(self, reference_images, generated_images):
+        """Alias of `compute` using the model's default batch size and worker count."""
+        return self.compute(reference_images, generated_images)
diff --git a/diffsynth/metrics/hpsv2.py b/diffsynth/metrics/hpsv2.py
new file mode 100644
index 000000000..309bb8d3a
--- /dev/null
+++ b/diffsynth/metrics/hpsv2.py
@@ -0,0 +1,41 @@
+import torch
+from transformers import AutoProcessor
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.hpsv2 import HPSv2Model
+from .base import Metric
+
+class HPSv2Metric(Metric):
+    """Metric wrapper around `HPSv2Model` for prompt-conditioned image scores."""
+
+    def __init__(self, model: HPSv2Model):
+        super().__init__()
+        self.model = model
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors"),
+        processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/"),
+        torch_dtype: torch.dtype = None,
+        device: torch.device = get_device_type(),
+        processor_kwargs: dict = None,
+        vram_limit: float = None,
+    ):
+        """Load the HPSv2 weights and processor and wrap them in a metric.
+
+        Args:
+            model_config: Location of the model weights.
+            processor_config: Location of the processor files.
+            torch_dtype: Computation dtype handed to `download_and_load_models`.
+            device: Computation device handed to `download_and_load_models`.
+            processor_kwargs: Extra keyword arguments for `AutoProcessor.from_pretrained`.
+            vram_limit: VRAM limit handed to the model loader.
+
+        Returns:
+            An `HPSv2Metric` whose model is in eval mode.
+        """
+        pool = cls.download_and_load_models(
+            [model_config],
+            torch_dtype=torch_dtype,
+            device=device,
+            vram_limit=vram_limit,
+        )
+        backbone = pool.fetch_model("image_metrics_hpsv2")
+
+        processor_config.download_if_necessary()
+        extra = processor_kwargs if processor_kwargs else {}
+        processor = AutoProcessor.from_pretrained(processor_config.path, **extra)
+
+        return cls(HPSv2Model(model=backbone, processor=processor).eval())
+
+    @torch.no_grad()
+    def score(self, prompt: str | list[str], images):
+        """Return the wrapped model's output for `(prompt, images)` as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    def compute(self, prompt: str | list[str], images):
+        """Alias of `score`."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: str | list[str], images):
+        """Alias of `score`, so the metric can be called like a module."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/hpsv3.py b/diffsynth/metrics/hpsv3.py
new file mode 100644
index 000000000..83b93f45b
--- /dev/null
+++ b/diffsynth/metrics/hpsv3.py
@@ -0,0 +1,63 @@
+from transformers import AutoProcessor
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.hpsv3 import HPSv3Model
+from .base import Metric
+
+
+class HPSv3Metric(Metric):
+    """Metric wrapper around `HPSv3Model` for prompt-conditioned image scores."""
+
+    def __init__(self, model: HPSv3Model):
+        super().__init__()
+        self.model = model
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors"),
+        processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/"),
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = get_device_type(),
+        score_index: int = 0,
+        use_special_tokens: bool = True,
+        max_pixels: int = 256 * 28 * 28,
+        min_pixels: int = 256 * 28 * 28,
+        processor_kwargs: dict = None,
+        vram_limit: float = None,
+    ):
+        """Load the HPSv3 reward model and processor and wrap them in a metric.
+
+        Args:
+            model_config: Location of the model weights.
+            processor_config: Location of the processor files.
+            torch_dtype: Computation dtype handed to `download_and_load_models`.
+            device: Computation device handed to `download_and_load_models`.
+            score_index: Forwarded to `HPSv3Model`.
+            use_special_tokens: When true, register the "<|Reward|>" token on the
+                tokenizer and record its id(s) on the model.
+            max_pixels: Forwarded to `HPSv3Model`.
+            min_pixels: Forwarded to `HPSv3Model`.
+            processor_kwargs: Extra keyword arguments for `AutoProcessor.from_pretrained`
+                (in addition to `padding_side="right"`, which is always passed).
+            vram_limit: VRAM limit handed to the model loader.
+
+        Returns:
+            An `HPSv3Metric` whose model is in eval mode.
+        """
+        pool = cls.download_and_load_models(
+            [model_config],
+            torch_dtype=torch_dtype,
+            device=device,
+            vram_limit=vram_limit,
+        )
+        reward_model = pool.fetch_model("image_metrics_hpsv3")
+
+        processor_config.download_if_necessary()
+        extra = processor_kwargs if processor_kwargs else {}
+        processor = AutoProcessor.from_pretrained(processor_config.path, padding_side="right", **extra)
+        tokenizer = processor.tokenizer
+
+        if use_special_tokens:
+            reward_tokens = ["<|Reward|>"]
+            tokenizer.add_special_tokens({"additional_special_tokens": reward_tokens})
+            reward_model.special_token_ids = tokenizer.convert_tokens_to_ids(reward_tokens)
+            reward_model.reward_token = "special"
+
+        # Mirror the tokenizer's padding setup onto the model config (and its text sub-config, if any).
+        model_cfg = reward_model.config
+        model_cfg.tokenizer_padding_side = tokenizer.padding_side
+        model_cfg.pad_token_id = tokenizer.pad_token_id
+        if hasattr(model_cfg, "text_config"):
+            model_cfg.text_config.pad_token_id = tokenizer.pad_token_id
+
+        # The reward head is moved to float32 regardless of the backbone dtype.
+        reward_model.rm_head.to(torch.float32)
+
+        wrapped = HPSv3Model(
+            model=reward_model,
+            processor=processor,
+            use_special_tokens=use_special_tokens,
+            max_pixels=max_pixels,
+            min_pixels=min_pixels,
+            score_index=score_index,
+        )
+        return cls(wrapped.eval())
+
+    @torch.no_grad()
+    def score(self, prompt: str | list[str], images):
+        """Return the wrapped model's output for `(prompt, images)` as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    def compute(self, prompt: str | list[str], images):
+        """Alias of `score`."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: str | list[str], images):
+        """Alias of `score`, so the metric can be called like a module."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/image_reward.py b/diffsynth/metrics/image_reward.py
new file mode 100644
index 000000000..85a042e8b
--- /dev/null
+++ b/diffsynth/metrics/image_reward.py
@@ -0,0 +1,48 @@
+import torch
+from transformers import BertTokenizer
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.image_reward import ImageRewardModel
+from .base import Metric
+
+class ImageRewardMetric(Metric):
+    """Metric wrapper around `ImageRewardModel` for prompt-conditioned image scores."""
+
+    def __init__(self, model: ImageRewardModel):
+        super().__init__()
+        self.model = model
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors"),
+        tokenizer_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/"),
+        torch_dtype: torch.dtype = None,
+        device: torch.device = get_device_type(),
+        max_length: int = 35,
+        tokenizer_kwargs: dict = None,
+        vram_limit: float = None,
+    ):
+        """Load the ImageReward weights and tokenizer and wrap them in a metric.
+
+        Args:
+            model_config: Location of the model weights.
+            tokenizer_config: Location of the `BertTokenizer` files.
+            torch_dtype: Computation dtype handed to `download_and_load_models`.
+            device: Computation device handed to `download_and_load_models`.
+            max_length: Stored on the model as `max_length`.
+            tokenizer_kwargs: Extra keyword arguments for `BertTokenizer.from_pretrained`.
+            vram_limit: VRAM limit handed to the model loader.
+
+        Returns:
+            An `ImageRewardMetric` whose model is in eval mode.
+        """
+        pool = cls.download_and_load_models(
+            [model_config],
+            torch_dtype=torch_dtype,
+            device=device,
+            vram_limit=vram_limit,
+        )
+        reward_model = pool.fetch_model("image_metrics_image_reward")
+
+        tokenizer_config.download_if_necessary()
+        extra = tokenizer_kwargs if tokenizer_kwargs else {}
+        tokenizer = BertTokenizer.from_pretrained(tokenizer_config.path, **extra)
+        # Register "[DEC]" as BOS and "[ENC]" as an extra special token, then expose the "[ENC]" id.
+        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+        tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
+        tokenizer.enc_token_id = tokenizer.convert_tokens_to_ids("[ENC]")
+
+        reward_model.tokenizer = tokenizer
+        reward_model.max_length = max_length
+        # The MLP head is always cast to float32, whatever dtype the rest of the model uses.
+        reward_model.mlp = reward_model.mlp.float()
+        return cls(reward_model.eval())
+
+    @torch.no_grad()
+    def score(self, prompt: str | list[str], images):
+        """Return the wrapped model's output for `(prompt, images)` as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    def compute(self, prompt: str | list[str], images):
+        """Alias of `score`."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: str | list[str], images):
+        """Alias of `score`, so the metric can be called like a module."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/pickscore.py b/diffsynth/metrics/pickscore.py
new file mode 100644
index 000000000..608e54484
--- /dev/null
+++ b/diffsynth/metrics/pickscore.py
@@ -0,0 +1,59 @@
+from transformers import AutoProcessor
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.pickscore import PickScoreModel
+from .base import Metric
+
+class PickScoreMetric(Metric):
+    """Metric wrapper around `PickScoreModel` for prompt-conditioned image scores."""
+
+    def __init__(self, model: PickScoreModel):
+        super().__init__()
+        self.model = model
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+        processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/"),
+        torch_dtype: torch.dtype = None,
+        device: torch.device = get_device_type(),
+        max_length: int = 77,
+        processor_kwargs: dict = None,
+        vram_limit: float = None,
+    ):
+        """Load the PickScore weights and processor and wrap them in a metric.
+
+        Args:
+            model_config: Location of the model weights.
+            processor_config: Location of the processor files.
+            torch_dtype: Computation dtype handed to `download_and_load_models`.
+            device: Computation device handed to `download_and_load_models`.
+            max_length: Maximum text length passed to `PickScoreModel`.
+            processor_kwargs: Extra keyword arguments for `AutoProcessor.from_pretrained`.
+            vram_limit: VRAM limit handed to the model loader.
+
+        Returns:
+            A `PickScoreMetric` whose model is in eval mode.
+        """
+        pool = cls.download_and_load_models(
+            [model_config],
+            torch_dtype=torch_dtype,
+            device=device,
+            vram_limit=vram_limit,
+        )
+        # The PickScore checkpoint is fetched under the "image_metrics_clip_hf" model name.
+        backbone = pool.fetch_model("image_metrics_clip_hf")
+
+        processor_config.download_if_necessary()
+        extra = processor_kwargs if processor_kwargs else {}
+        processor = AutoProcessor.from_pretrained(processor_config.path, **extra)
+
+        wrapped = PickScoreModel(model=backbone, processor=processor, max_length=max_length)
+        return cls(wrapped.eval())
+
+    @torch.no_grad()
+    def score(
+        self,
+        prompt: str | list[str],
+        images,
+    ):
+        """Return the wrapped model's output for `(prompt, images)` as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    @torch.no_grad()
+    def probabilities(
+        self,
+        prompt: str | list[str],
+        images,
+    ):
+        """Return the softmax of the scores over the last dimension as a Python list."""
+        raw_scores = self.model(prompt, images)
+        return self.tensor_to_list(torch.softmax(raw_scores, dim=-1))
+
+    def calc_probs(self, prompt: str | list[str], images):
+        """Alias of `probabilities`."""
+        return self.probabilities(prompt, images)
+
+    def compute(self, prompt: str | list[str], images):
+        """Alias of `score`."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: str | list[str], images):
+        """Alias of `score`, so the metric can be called like a module."""
+        return self.score(prompt, images)
diff --git a/diffsynth/models/aesthetic.py b/diffsynth/models/aesthetic.py
new file mode 100644
index 000000000..b1ed41462
--- /dev/null
+++ b/diffsynth/models/aesthetic.py
@@ -0,0 +1,90 @@
+from typing import Union
+import torch
+from PIL import Image
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+class AestheticMLP(torch.nn.Module):
+    """Stack of linear layers (with dropout, no activations) mapping a feature vector to one value.
+
+    Args:
+        input_size: Size of the last dimension of the input features.
+    """
+
+    def __init__(self, input_size: int):
+        super().__init__()
+        self.input_size = input_size
+        Linear, Dropout = torch.nn.Linear, torch.nn.Dropout
+        # Positions in this list become the Sequential indices and therefore the
+        # parameter names ("layers.0.weight", "layers.2.weight", ...).
+        stack = [
+            Linear(input_size, 1024),
+            Dropout(0.2),
+            Linear(1024, 128),
+            Dropout(0.2),
+            Linear(128, 64),
+            Dropout(0.1),
+            Linear(64, 16),
+            Linear(16, 1),
+        ]
+        self.layers = torch.nn.Sequential(*stack)
+
+    def forward(self, x):
+        """Apply the layer stack to `x`; the last dimension of the result has size 1."""
+        return self.layers(x)
+
+
+def _as_image_list(images: ImageInput):
+    """Return `images` as a list of RGB-converted images; a single image becomes a one-element list."""
+    batch = [images] if isinstance(images, Image.Image) else images
+    return [item.convert("RGB") for item in batch]
+
+
+class AestheticModel(torch.nn.Module):
+    """Vision encoder + projection + MLP head that produces one score per image.
+
+    Args:
+        mlp: Head whose `layers` are reused; defaults to `AestheticMLP(768)`.
+        vision_model: Vision encoder; when omitted, both it and `visual_projection`
+            are replaced by the result of `default_vision_model()`.
+        visual_projection: Projection applied to the encoder's pooled output.
+        processor: Image processor used by `get_image_features`; may be attached later.
+    """
+
+    def __init__(
+        self,
+        mlp: AestheticMLP = None,
+        vision_model: torch.nn.Module = None,
+        visual_projection: torch.nn.Module = None,
+        processor=None,
+    ):
+        super().__init__()
+        # Defaults are built vision tower first, then the head.
+        if vision_model is None:
+            vision_model, visual_projection = self.default_vision_model()
+        if mlp is None:
+            mlp = AestheticMLP(768)
+
+        self.vision_model = vision_model
+        self.visual_projection = visual_projection
+        self.processor = processor
+        # Only the Sequential is kept, so parameters are registered under "layers.*".
+        self.layers = mlp.layers
+
+    @staticmethod
+    def default_vision_model():
+        """Build a randomly initialised CLIP vision model and a bias-free projection layer."""
+        from transformers import CLIPVisionConfig, CLIPVisionModel
+
+        settings = {
+            "hidden_size": 1024,
+            "intermediate_size": 4096,
+            "num_attention_heads": 16,
+            "num_hidden_layers": 24,
+            "image_size": 224,
+            "patch_size": 14,
+            "hidden_act": "quick_gelu",
+            "layer_norm_eps": 1e-5,
+            "projection_dim": 768,
+        }
+        config = CLIPVisionConfig(**settings)
+        encoder = CLIPVisionModel(config)
+        projection = torch.nn.Linear(config.hidden_size, config.projection_dim, bias=False)
+        return encoder, projection
+
+    @property
+    def device(self):
+        """Device of the first parameter (or of an empty tensor when there are no parameters)."""
+        first = next(self.parameters(), torch.tensor([]))
+        return first.device
+
+    @property
+    def dtype(self):
+        """Dtype of the first parameter (or of a default float tensor when there are no parameters)."""
+        first = next(self.parameters(), torch.tensor(0.0))
+        return first.dtype
+
+    @torch.no_grad()
+    def get_image_features(self, images):
+        """Return L2-normalised projected image features for `images`."""
+        pil_images = _as_image_list(images)
+        processed = self.processor(images=pil_images, return_tensors="pt")
+        pixel_values = processed["pixel_values"].to(device=self.device, dtype=self.dtype)
+
+        pooled = self.vision_model(pixel_values=pixel_values, return_dict=True).pooler_output
+        projected = self.visual_projection(pooled)
+        return torch.nn.functional.normalize(projected, dim=-1)
+
+    @torch.no_grad()
+    def forward(self, images):
+        """Score `images`; the trailing size-1 dimension of the head output is squeezed away."""
+        features = self.get_image_features(images)
+        return self.layers(features).squeeze(-1)
\ No newline at end of file
diff --git a/diffsynth/models/clip.py b/diffsynth/models/clip.py
new file mode 100644
index 000000000..723107d8c
--- /dev/null
+++ b/diffsynth/models/clip.py
@@ -0,0 +1,153 @@
+from typing import Union
+import torch
+from PIL import Image
+from transformers import CLIPModel as HFCLIPModel
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _feature_tensor(output, feature_name: str):
+    """Extract a feature tensor from a model output.
+
+    Lookup order: `output` itself if it is a tensor; then its `image_embeds`,
+    `text_embeds` and `pooler_output` attributes; then, for a list or tuple,
+    its first tensor element.
+
+    Raises:
+        TypeError: If no tensor is found; the message starts with `feature_name`.
+    """
+    if torch.is_tensor(output):
+        return output
+
+    def _candidates():
+        # Lazy on purpose: attributes are read one at a time and only until a tensor is found.
+        for attribute in ("image_embeds", "text_embeds", "pooler_output"):
+            yield getattr(output, attribute, None)
+        if isinstance(output, (list, tuple)):
+            yield from output
+
+    found = next((item for item in _candidates() if torch.is_tensor(item)), None)
+    if found is None:
+        raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
+    return found
+
+
+class ImageMetricsCLIPModel(HFCLIPModel):
+    """Hugging Face `CLIPModel` built from a fixed, in-code configuration.
+
+    Args:
+        variant: Name of the architecture preset. Only "h14" is defined.
+
+    Raises:
+        ValueError: If `variant` is not a known preset.
+    """
+
+    def __init__(self, variant: str = "h14"):
+        super().__init__(self.config(variant))
+
+    @staticmethod
+    def config(variant: str):
+        """Return the `CLIPConfig` for `variant`.
+
+        Raises:
+            ValueError: If `variant` is not "h14".
+        """
+        from transformers import CLIPConfig
+
+        # Validate before building: an unknown variant must fail loudly rather than
+        # silently fall back to the "h14" preset.
+        if variant != "h14":
+            raise ValueError(f"Unsupported ImageMetrics CLIP variant: {variant}")
+
+        # NOTE(review): confirm that "quick_gelu" is the activation the shipped checkpoints
+        # were trained with; it cannot be verified from this file.
+        return CLIPConfig(
+            projection_dim=1024,
+            logit_scale_init_value=2.6592,
+            text_config={
+                "hidden_size": 1024,
+                "intermediate_size": 4096,
+                "num_attention_heads": 16,
+                "num_hidden_layers": 24,
+                "max_position_embeddings": 77,
+                "vocab_size": 49408,
+                "hidden_act": "quick_gelu",
+                "layer_norm_eps": 1e-5,
+                "projection_dim": 1024,
+                "bos_token_id": 0,
+                "eos_token_id": 2,
+                "pad_token_id": 1,
+            },
+            vision_config={
+                "hidden_size": 1280,
+                "intermediate_size": 5120,
+                "num_attention_heads": 16,
+                "num_hidden_layers": 32,
+                "image_size": 224,
+                "patch_size": 14,
+                "hidden_act": "quick_gelu",
+                "layer_norm_eps": 1e-5,
+                "projection_dim": 1024,
+            },
+        )
+
+
+class CLIPModel(torch.nn.Module):
+    """Pairs a CLIP-style model with its processor to score prompts against images.
+
+    Args:
+        model: Object exposing `get_image_features` / `get_text_features`
+            (and optionally `logit_scale`).
+        processor: Callable that turns text and/or images into model inputs.
+        max_length: `max_length` passed to every processor call.
+    """
+
+    def __init__(self, model: torch.nn.Module, processor, max_length: int = 77):
+        super().__init__()
+        self.model = model
+        self.processor = processor
+        self.max_length = max_length
+
+    @property
+    def device(self):
+        """Device of the first parameter (or of an empty tensor when there are no parameters)."""
+        first = next(self.parameters(), torch.tensor([]))
+        return first.device
+
+    @property
+    def dtype(self):
+        """Dtype of the first parameter (or of a default float tensor when there are no parameters)."""
+        first = next(self.parameters(), torch.tensor(0.0))
+        return first.dtype
+
+    def _normalize_pairs(self, text, images):
+        """Return equally long prompt and RGB-image lists, repeating a lone prompt or image.
+
+        Raises:
+            ValueError: If the lengths still differ after broadcasting.
+        """
+        prompts = [text] if isinstance(text, str) else list(text)
+        if isinstance(images, Image.Image):
+            images = [images]
+        pictures = [picture.convert("RGB") for picture in images]
+
+        if len(prompts) == 1 and len(pictures) > 1:
+            prompts = prompts * len(pictures)
+        if len(pictures) == 1 and len(prompts) > 1:
+            pictures = pictures * len(prompts)
+
+        if len(prompts) != len(pictures):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(pictures)}.")
+        return prompts, pictures
+
+    def _processor_call(self, **kwargs):
+        """Run the processor with the shared options and move the result to the model's device.
+
+        When the model dtype is not float32, the result is a plain dict whose
+        floating-point tensors are cast to that dtype.
+        """
+        encoded = self.processor(
+            padding=True,
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+            **kwargs,
+        ).to(self.device)
+
+        if self.dtype == torch.float32:
+            return encoded
+
+        casted = {}
+        for key, item in encoded.items():
+            if torch.is_tensor(item) and torch.is_floating_point(item):
+                item = item.to(dtype=self.dtype)
+            casted[key] = item
+        return casted
+
+    @torch.no_grad()
+    def get_image_features(self, images: ImageInput):
+        """Return L2-normalised image features for one image or a sequence of images."""
+        if isinstance(images, Image.Image):
+            images = [images]
+        pictures = [picture.convert("RGB") for picture in images]
+
+        model_inputs = self._processor_call(images=pictures)
+        raw = self.model.get_image_features(**model_inputs)
+        features = _feature_tensor(raw, "image_features")
+        return torch.nn.functional.normalize(features, dim=-1)
+
+    @torch.no_grad()
+    def get_text_features(self, text: Union[str, list[str]]):
+        """Return L2-normalised text features for one prompt or a list of prompts."""
+        model_inputs = self._processor_call(text=text)
+        raw = self.model.get_text_features(**model_inputs)
+        features = _feature_tensor(raw, "text_features")
+        return torch.nn.functional.normalize(features, dim=-1)
+
+    def _apply_logit_scale(self, similarities):
+        """Multiply by `exp(logit_scale)` when the wrapped model defines `logit_scale`."""
+        if not hasattr(self.model, "logit_scale"):
+            return similarities
+        return self.model.logit_scale.exp() * similarities
+
+    @torch.no_grad()
+    def similarity_matrix(self, text: Union[str, list[str]], images: ImageInput):
+        """Return the text-by-image matrix of (optionally scaled) feature dot products."""
+        image_features = self.get_image_features(images)
+        text_features = self.get_text_features(text)
+        return self._apply_logit_scale(text_features @ image_features.T)
+
+    @torch.no_grad()
+    def forward(self, text: Union[str, list[str]], images: ImageInput):
+        """Return one (optionally scaled) dot-product score per prompt/image pair."""
+        prompts, pictures = self._normalize_pairs(text, images)
+
+        image_features = self.get_image_features(pictures)
+        text_features = self.get_text_features(prompts)
+
+        paired = (text_features * image_features).sum(dim=-1)
+        return self._apply_logit_scale(paired)
\ No newline at end of file
diff --git a/diffsynth/models/fid.py b/diffsynth/models/fid.py
new file mode 100644
index 000000000..af3671975
--- /dev/null
+++ b/diffsynth/models/fid.py
@@ -0,0 +1,238 @@
+import os
+from typing import Iterable, Union
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+from torchvision.models import inception_v3
+from torchvision.models.inception import InceptionA, InceptionC, InceptionE
+
+ImageInput = Union[str, os.PathLike, Image.Image]
+
+# File suffixes (lower-case, including the dot) accepted when collecting images for FID.
+IMAGE_EXTENSIONS = {".bmp", ".jpg", ".jpeg", ".pgm", ".png", ".ppm", ".tif", ".tiff", ".webp"}
+
+
+def _image_files(path: Union[str, os.PathLike]):
+    """Return the image file(s) located at `path`.
+
+    A file path yields a one-element list. A directory is walked recursively,
+    visiting sub-directories and file names in sorted order, and every file
+    whose suffix (case-insensitive) is in `IMAGE_EXTENSIONS` is collected.
+
+    Raises:
+        ValueError: If `path` is a file with an unsupported suffix, or no image
+            is found under a directory.
+        FileNotFoundError: If `path` does not exist.
+    """
+    location = os.fspath(path)
+
+    def _is_image(filename):
+        return os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS
+
+    if os.path.isfile(location):
+        if not _is_image(location):
+            raise ValueError(f"Unsupported image extension for FID: {location}")
+        return [location]
+
+    if not os.path.exists(location):
+        raise FileNotFoundError(f"FID path does not exist: {location}")
+
+    collected = []
+    for folder, subfolders, filenames in os.walk(location):
+        # Sorting in place makes os.walk descend in a deterministic order.
+        subfolders.sort()
+        collected.extend(
+            os.path.join(folder, filename)
+            for filename in sorted(filenames)
+            if _is_image(filename)
+        )
+
+    if not collected:
+        raise ValueError(f"No images found under {location}.")
+    return collected
+
+
+class _ImageDataset(torch.utils.data.Dataset):
+    """Dataset over PIL images and/or image paths that applies `transform` to each RGB image.
+
+    Args:
+        images: Iterable of PIL images or image paths; materialised into a list.
+        transform: Callable applied to each RGB-converted image.
+    """
+
+    def __init__(self, images: Iterable[ImageInput], transform):
+        self.images = list(images)
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        """Return the transformed RGB image at `index`.
+
+        Raises:
+            TypeError: If the item is neither a PIL image nor a path.
+        """
+        item = self.images[index]
+        if isinstance(item, (str, os.PathLike)):
+            # Image.open is lazy and keeps the file open; converting inside the
+            # context manager loads the pixels and then releases the handle, so
+            # large datasets do not accumulate open file descriptors.
+            with Image.open(item) as opened:
+                rgb = opened.convert("RGB")
+        elif isinstance(item, Image.Image):
+            rgb = item.convert("RGB")
+        else:
+            raise TypeError(f"FID expects PIL images or image paths, but received {type(item)}.")
+        return self.transform(rgb)
+
+
+class FIDInceptionModel(nn.Module):
+    """Wraps the network built by `_fid_inception_v3()` and rescales its input."""
+
+    def __init__(self):
+        super().__init__()
+        self.model = _fid_inception_v3()
+
+    def forward(self, images):
+        """Apply `x -> 2 * x - 1` to `images` (maps [0, 1] onto [-1, 1]) and run the network."""
+        return self.model(2 * images - 1)
+
+
+def _fid_inception_v3(weights_path: str = None):
+    """Build a torchvision Inception v3 whose mixed blocks are replaced by the `_FIDInception*` variants.
+
+    Args:
+        weights_path: Optional state-dict file loaded before the classifier is removed.
+
+    Returns:
+        The network with `fc` replaced by `nn.Identity()`.
+    """
+    network = inception_v3(weights=None, aux_logits=False, num_classes=1008, init_weights=False)
+
+    # (attribute name, replacement class, input channels, extra keyword arguments)
+    replacements = (
+        ("Mixed_5b", _FIDInceptionA, 192, {"pool_features": 32}),
+        ("Mixed_5c", _FIDInceptionA, 256, {"pool_features": 64}),
+        ("Mixed_5d", _FIDInceptionA, 288, {"pool_features": 64}),
+        ("Mixed_6b", _FIDInceptionC, 768, {"channels_7x7": 128}),
+        ("Mixed_6c", _FIDInceptionC, 768, {"channels_7x7": 160}),
+        ("Mixed_6d", _FIDInceptionC, 768, {"channels_7x7": 160}),
+        ("Mixed_6e", _FIDInceptionC, 768, {"channels_7x7": 192}),
+        ("Mixed_7b", _FIDInceptionE1, 1280, {}),
+        ("Mixed_7c", _FIDInceptionE2, 2048, {}),
+    )
+    for attribute, block_cls, in_channels, extra in replacements:
+        setattr(network, attribute, block_cls(in_channels, **extra))
+
+    if weights_path is not None:
+        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
+        state = torch.load(weights_path, map_location="cpu")
+        network.load_state_dict(state)
+
+    # Drop the classifier so the network returns the features that fed it.
+    network.fc = nn.Identity()
+    return network
+
+
+class _FIDInceptionA(InceptionA):
+    """`InceptionA` whose pooling branch uses average pooling with `count_include_pad=False`."""
+
+    def forward(self, x):
+        """Run the four branches on `x` and concatenate their outputs along dim 1."""
+        out_1x1 = self.branch1x1(x)
+
+        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
+
+        out_3x3dbl = x
+        for stage in (self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3):
+            out_3x3dbl = stage(out_3x3dbl)
+
+        # Padded border positions are excluded from the average.
+        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        out_pool = self.branch_pool(pooled)
+
+        return torch.cat([out_1x1, out_5x5, out_3x3dbl, out_pool], 1)
+
+
+class _FIDInceptionC(InceptionC):
+    """`InceptionC` whose pooling branch uses average pooling with `count_include_pad=False`."""
+
+    def forward(self, x):
+        """Run the four branches on `x` and concatenate their outputs along dim 1."""
+        out_1x1 = self.branch1x1(x)
+
+        out_7x7 = x
+        for stage in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
+            out_7x7 = stage(out_7x7)
+
+        out_7x7dbl = x
+        for stage in (
+            self.branch7x7dbl_1,
+            self.branch7x7dbl_2,
+            self.branch7x7dbl_3,
+            self.branch7x7dbl_4,
+            self.branch7x7dbl_5,
+        ):
+            out_7x7dbl = stage(out_7x7dbl)
+
+        # Padded border positions are excluded from the average.
+        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        out_pool = self.branch_pool(pooled)
+
+        return torch.cat([out_1x1, out_7x7, out_7x7dbl, out_pool], 1)
+
+
+class _FIDInceptionE1(InceptionE):
+    """`InceptionE` whose pooling branch uses average pooling with `count_include_pad=False`."""
+
+    def forward(self, x):
+        """Run the four branches on `x` and concatenate their outputs along dim 1."""
+        out_1x1 = self.branch1x1(x)
+
+        stem_3x3 = self.branch3x3_1(x)
+        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)
+
+        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
+        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_3x3dbl), self.branch3x3dbl_3b(stem_3x3dbl)], 1)
+
+        # Padded border positions are excluded from the average.
+        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        out_pool = self.branch_pool(pooled)
+
+        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)
+
+
+class _FIDInceptionE2(InceptionE):
+    """`InceptionE` whose pooling branch uses max pooling instead of average pooling."""
+
+    def forward(self, x):
+        """Run the four branches on `x` and concatenate their outputs along dim 1."""
+        out_1x1 = self.branch1x1(x)
+
+        stem_3x3 = self.branch3x3_1(x)
+        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)
+
+        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
+        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_3x3dbl), self.branch3x3dbl_3b(stem_3x3dbl)], 1)
+
+        # Unlike _FIDInceptionE1, this block pools with max instead of average.
+        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        out_pool = self.branch_pool(pooled)
+
+        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)
+
+
+class FIDModel(torch.nn.Module):
+    """Extracts features for two image sets and computes the Frechet distance between them.
+
+    Args:
+        model: Feature extractor called on batches of transformed images.
+        device: Device the module is moved to on construction.
+        batch_size: Default batch size used by `get_activations`.
+        num_workers: Default DataLoader worker count used by `get_activations`.
+    """
+
+    def __init__(self, model: torch.nn.Module, device: Union[str, torch.device] = "cpu", batch_size: int = 50, num_workers: int = 0):
+        super().__init__()
+        self.model = model
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        # Every image is resized to 299x299 (bicubic) and converted to a tensor.
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((299, 299), interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+            ]
+        )
+        self.to(device)
+
+    @property
+    def device(self):
+        """Device of the feature extractor's first parameter, or CPU if it has none."""
+        try:
+            return next(self.model.parameters()).device
+        except StopIteration:
+            return torch.device("cpu")
+
+    def _as_images(self, images):
+        """Normalise `images` to a list: a path is expanded to image files, a single image is wrapped."""
+        if isinstance(images, (str, os.PathLike)):
+            return _image_files(images)
+        if isinstance(images, Image.Image):
+            return [images]
+        return list(images)
+
+    @torch.no_grad()
+    def get_activations(self, images, batch_size: int = None, num_workers: int = None):
+        """Return the extractor's features for `images` as a float64 CPU tensor, one row per image.
+
+        Args:
+            images: A path (file or directory), a PIL image, or an iterable of paths/images.
+            batch_size: Overrides the instance default when given.
+            num_workers: Overrides the instance default when given.
+
+        Raises:
+            ValueError: If `images` is empty.
+        """
+        images = self._as_images(images)
+        if not images:
+            # Fail early with a clear message instead of letting DataLoader reject batch_size=0.
+            raise ValueError("FID requires at least one image, but received an empty collection.")
+        batch_size = self.batch_size if batch_size is None else batch_size
+        num_workers = self.num_workers if num_workers is None else num_workers
+        dataset = _ImageDataset(images, transform=self.transform)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=min(batch_size, len(dataset)), shuffle=False, num_workers=num_workers)
+        activations = []
+        self.model.eval()
+        for batch in dataloader:
+            batch = batch.to(self.device)
+            features = self.model(batch)
+            if isinstance(features, tuple):
+                features = features[0]
+            if features.ndim == 4:
+                # Spatial feature maps are reduced to one vector per image.
+                features = F.adaptive_avg_pool2d(features, output_size=(1, 1)).flatten(1)
+            activations.append(features.detach().cpu().to(torch.float64))
+        return torch.cat(activations, dim=0)
+
+    def statistics(self, images, batch_size: int = None, num_workers: int = None):
+        """Return `(mean, covariance)` of the activations of `images`."""
+        activations = self.get_activations(images, batch_size=batch_size, num_workers=num_workers)
+        return self.activation_statistics(activations)
+
+    @staticmethod
+    def activation_statistics(activations):
+        """Return the float64 mean vector and unbiased covariance matrix of the rows of `activations`.
+
+        With a single row (or none) the covariance is all zeros.
+        """
+        activations = activations.to(torch.float64)
+        mean = activations.mean(dim=0)
+        centered = activations - mean
+        feature_dim = activations.shape[1]
+        if activations.shape[0] <= 1:
+            # Created on the input's device so both results always live together.
+            covariance = torch.zeros((feature_dim, feature_dim), dtype=torch.float64, device=activations.device)
+        else:
+            covariance = centered.T @ centered / (activations.shape[0] - 1)
+        return mean, covariance
+
+    @staticmethod
+    def _sqrtm_psd(matrix, eps: float = 1e-10):
+        """Matrix square root via eigendecomposition of the symmetrised input; eigenvalues are clamped to `eps`."""
+        matrix = (matrix + matrix.T) * 0.5
+        eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
+        eigenvalues = eigenvalues.clamp_min(eps).sqrt()
+        return (eigenvectors * eigenvalues.unsqueeze(0)) @ eigenvectors.T
+
+    @classmethod
+    def frechet_distance(cls, mean1, covariance1, mean2, covariance2, eps: float = 1e-6):
+        """Return the Frechet distance between two Gaussians, clamped to be non-negative.
+
+        `eps` times the identity is added to both covariances before the matrix
+        square roots are taken; the trace terms use the unmodified covariances.
+        """
+        mean1 = mean1.to(torch.float64)
+        covariance1 = covariance1.to(torch.float64)
+        mean2 = mean2.to(torch.float64)
+        covariance2 = covariance2.to(torch.float64)
+        diff = mean1 - mean2
+        # Built on the covariance's device so statistics kept on an accelerator also work.
+        offset = torch.eye(covariance1.shape[0], dtype=torch.float64, device=covariance1.device) * eps
+        sqrt_cov1 = cls._sqrtm_psd(covariance1 + offset)
+        covmean = cls._sqrtm_psd(sqrt_cov1 @ (covariance2 + offset) @ sqrt_cov1)
+        distance = diff.dot(diff) + torch.trace(covariance1) + torch.trace(covariance2) - 2 * torch.trace(covmean)
+        return distance.clamp_min(0)
+
+    def compute(self, reference_images, generated_images, batch_size: int = None, num_workers: int = None):
+        """Return the Frechet distance between the statistics of the two image sets."""
+        mean1, covariance1 = self.statistics(reference_images, batch_size=batch_size, num_workers=num_workers)
+        mean2, covariance2 = self.statistics(generated_images, batch_size=batch_size, num_workers=num_workers)
+        return self.frechet_distance(mean1, covariance1, mean2, covariance2)
+
+    def forward(self, reference_images, generated_images):
+        """Alias of `compute` using the instance's default batch size and worker count."""
+        return self.compute(reference_images, generated_images)
diff --git a/diffsynth/models/hpsv2.py b/diffsynth/models/hpsv2.py
new file mode 100644
index 000000000..c8cc4c924
--- /dev/null
+++ b/diffsynth/models/hpsv2.py
@@ -0,0 +1,92 @@
+from typing import Union
+import torch
+from PIL import Image
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _as_list(value):
+    """Return ``value`` as a new list: lists and tuples are copied, anything else is wrapped."""
+    return list(value) if isinstance(value, (list, tuple)) else [value]
+
+def _feature_tensor(output, feature_name: str):
+    """Extract a feature tensor from a model output.
+
+    A tensor is returned unchanged. Otherwise the attributes ``image_embeds``,
+    ``text_embeds`` and ``pooler_output`` are tried in that order, followed by
+    the elements of ``output`` when it is a list or tuple; the first tensor
+    found wins.
+
+    Raises:
+        TypeError: If no tensor is found. ``feature_name`` is only used to
+            label this error.
+    """
+    if torch.is_tensor(output):
+        return output
+
+    def candidates():
+        # Lazy on purpose: later attributes are not touched once a tensor is found.
+        for attribute in ("image_embeds", "text_embeds", "pooler_output"):
+            yield getattr(output, attribute, None)
+        if isinstance(output, (list, tuple)):
+            yield from output
+
+    for candidate in candidates():
+        if torch.is_tensor(candidate):
+            return candidate
+    raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
+
+
+class HPSv2Model(torch.nn.Module):
+    """Scores prompt/image pairs with a CLIP-style model and its processor.
+
+    The wrapped ``model`` must expose ``get_image_features`` and
+    ``get_text_features``; ``processor`` must accept ``text``/``images`` and
+    return an object supporting ``.to(device)``, ``.items()`` and ``.get``.
+    """
+
+    def __init__(self, model: torch.nn.Module, processor):
+        super().__init__()
+        self.model = model
+        self.processor = processor
+
+    @property
+    def device(self):
+        """Device of the first parameter, or of an empty tensor when there are none."""
+        for parameter in self.parameters():
+            return parameter.device
+        return torch.tensor([]).device
+
+    @property
+    def dtype(self):
+        """Dtype of the first parameter, or of a float scalar tensor when there are none."""
+        for parameter in self.parameters():
+            return parameter.dtype
+        return torch.tensor(0.0).dtype
+
+    def _normalize_inputs(self, prompts, images):
+        """Turn both inputs into equally long lists, repeating a lone prompt or image."""
+        image_list = _as_list(images)
+        prompt_list = _as_list(prompts)
+
+        if len(prompt_list) == 1 and len(image_list) > 1:
+            prompt_list = prompt_list * len(image_list)
+        elif len(image_list) == 1 and len(prompt_list) > 1:
+            image_list = image_list * len(prompt_list)
+
+        if len(prompt_list) != len(image_list):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompt_list)} and {len(image_list)}.")
+        return prompt_list, image_list
+
+    @torch.no_grad()
+    def forward(self, prompts: Union[str, list[str]], images: ImageInput):
+        """Return one score per prompt/image pair.
+
+        The score is the dot product of the L2-normalised image and text
+        features, multiplied by ``exp(logit_scale)`` when the wrapped model has
+        a ``logit_scale`` attribute.
+        """
+        prompts, images = self._normalize_inputs(prompts, images)
+        rgb_images = [image.convert("RGB") for image in images]
+
+        inputs = self.processor(
+            text=prompts,
+            images=rgb_images,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        ).to(self.device)
+
+        compute_dtype = self.dtype
+        if compute_dtype != torch.float32:
+            # Only floating-point tensors follow the model dtype; token ids and
+            # masks keep their integer type.
+            cast_inputs = {}
+            for name, value in inputs.items():
+                if torch.is_tensor(value) and torch.is_floating_point(value):
+                    value = value.to(dtype=compute_dtype)
+                cast_inputs[name] = value
+            inputs = cast_inputs
+
+        image_output = self.model.get_image_features(pixel_values=inputs["pixel_values"])
+        image_features = _feature_tensor(image_output, "image_features")
+        text_output = self.model.get_text_features(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs.get("attention_mask"),
+        )
+        text_features = _feature_tensor(text_output, "text_features")
+
+        image_features = torch.nn.functional.normalize(image_features, dim=-1)
+        text_features = torch.nn.functional.normalize(text_features, dim=-1)
+
+        scores = (image_features * text_features).sum(dim=-1)
+        if not hasattr(self.model, "logit_scale"):
+            return scores
+        return self.model.logit_scale.exp() * scores
\ No newline at end of file
diff --git a/diffsynth/models/hpsv3.py b/diffsynth/models/hpsv3.py
new file mode 100644
index 000000000..b22966b6d
--- /dev/null
+++ b/diffsynth/models/hpsv3.py
@@ -0,0 +1,353 @@
+import math
+from typing import Optional, Union
+import torch
+from PIL import Image
+from transformers import Qwen2VLForConditionalGeneration
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+# Rating rubric placed at the start of every HPSv3 query. The ``{text_prompt}``
+# placeholder is filled with the caption through ``str.format``. The wording is
+# a runtime input to the model, so it is left exactly as written.
+HPSV3_INSTRUCTION = """
+You are tasked with evaluating a generated image based on Visual Quality and Text Alignment and give a overall score to estimate the human preference. Please provide a rating from 0 to 10, with 0 being the worst and 10 being the best.
+
+**Visual Quality:**
+Evaluate the overall visual quality of the image. The following sub-dimensions should be considered:
+- **Reasonableness:** The image should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups.
+- **Clarity:** Evaluate the sharpness and visibility of the image. The image should be clear and easy to interpret, with no blurring or indistinct areas.
+- **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows).
+- **Aesthetic and Creativity:** Assess the artistic aspects of the image, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance.
+- **Safety:** The image should not contain harmful or inappropriate content, such as political, violent, or adult material. If such content is present, the image quality and satisfaction score should be the lowest possible.
+
+**Text Alignment:**
+Assess how well the image matches the textual prompt across the following sub-dimensions:
+- **Subject Relevance** Evaluate how accurately the subject(s) in the image (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, and behavior.
+- **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the image adheres to this style.
+- **Contextual Consistency**: Assess whether the background, setting, and surrounding elements in the image logically fit the scenario described in the prompt. The environment should support and enhance the subject without contradictions.
+- **Attribute Fidelity**: Check if specific attributes mentioned in the prompt (e.g., colors, clothing, accessories, expressions, actions) are faithfully represented in the image. Minor deviations may be acceptable, but critical attributes should be preserved.
+- **Semantic Coherence**: Evaluate whether the overall meaning and intent of the prompt are captured in the image. The generated content should not introduce elements that conflict with or distort the original description.
+Textual prompt - {text_prompt}
+
+
+"""
+
+# Query suffix containing the ``<|Reward|>`` marker; appended to the rubric when
+# ``HPSv3Model.use_special_tokens`` is true.
+HPSV3_PROMPT_WITH_SPECIAL_TOKEN = """
+Please provide the overall ratings of this image: <|Reward|>
+
+END
+"""
+
+# Query suffix without the marker; appended when ``use_special_tokens`` is false.
+HPSV3_PROMPT_WITHOUT_SPECIAL_TOKEN = """
+Please provide the overall ratings of this image:
+"""
+
+def _as_list(value):
+    """Copy a list or tuple into a new list; wrap any other value in a one-element list."""
+    is_sequence = isinstance(value, (list, tuple))
+    return list(value) if is_sequence else [value]
+
+def _round_by_factor(number, factor):
+    """Return the multiple of ``factor`` nearest to ``number`` (ties follow Python's ``round``)."""
+    multiples = round(number / factor)
+    return multiples * factor
+
+def _ceil_by_factor(number, factor):
+    """Return the smallest multiple of ``factor`` that is not below ``number``."""
+    multiples = math.ceil(number / factor)
+    return multiples * factor
+
+def _floor_by_factor(number, factor):
+    """Return the largest multiple of ``factor`` that does not exceed ``number``."""
+    multiples = math.floor(number / factor)
+    return multiples * factor
+
+def _smart_resize(height, width, factor=28, min_pixels=256 * 28 * 28, max_pixels=256 * 28 * 28):
+    """Choose a target size whose sides are multiples of ``factor``.
+
+    Both sides are first rounded to the nearest multiple of ``factor``. If the
+    resulting area exceeds ``max_pixels`` the image is scaled down (sides
+    floored to a multiple); if it is below ``min_pixels`` it is scaled up
+    (sides ceiled to a multiple).
+
+    Args:
+        height: Source height in pixels.
+        width: Source width in pixels.
+        factor: Granularity both output sides must be divisible by.
+        min_pixels: Lower bound on the output area.
+        max_pixels: Upper bound on the output area.
+
+    Returns:
+        ``(resized_height, resized_width)``; each side is at least ``factor``.
+
+    Raises:
+        ValueError: If the aspect ratio exceeds 200.
+    """
+    aspect_ratio = max(height, width) / min(height, width)
+    if aspect_ratio > 200:
+        raise ValueError(f"Image aspect ratio must be smaller than 200, got {aspect_ratio}.")
+
+    h_bar = max(factor, _round_by_factor(height, factor))
+    w_bar = max(factor, _round_by_factor(width, factor))
+
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        # Flooring a narrow side can reach 0 when ``max_pixels`` is small
+        # relative to the aspect ratio; keep at least one ``factor``-sized
+        # patch so the caller never resizes to an empty image.
+        h_bar = max(factor, _floor_by_factor(height / beta, factor))
+        w_bar = max(factor, _floor_by_factor(width / beta, factor))
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = _ceil_by_factor(height * beta, factor)
+        w_bar = _ceil_by_factor(width * beta, factor)
+
+    return h_bar, w_bar
+
+
+class HPSv3RewardModelMixin:
+    """Mixin that turns a vision-language backbone into a reward scorer.
+
+    ``init_reward_head`` attaches ``rm_head``; ``forward`` runs ``self.model``,
+    applies the head to every token's hidden state and pools the result into
+    one vector per sample. The host class must provide ``self.config`` and
+    ``self.model``.
+    """
+
+    def init_reward_head(
+        self,
+        output_dim=2,
+        reward_token="special",
+        special_token_ids=None,
+        rm_head_type="ranknet",
+        rm_head_kwargs=None,
+    ):
+        """Create ``self.rm_head`` and store the pooling configuration.
+
+        Args:
+            output_dim: Width of the head's output.
+            reward_token: Pooling mode read by ``forward``: ``"last"``,
+                ``"mean"`` or ``"special"``.
+            special_token_ids: Token ids that mark reward positions. When
+                given, the pooling mode is forced to ``"special"``.
+            rm_head_type: ``"ranknet"`` builds a three-layer MLP; any other
+                value builds a single bias-free linear layer.
+            rm_head_kwargs: Optional ``hidden_size`` (default 1024) and
+                ``dropout`` (default 0.05) for the ``"ranknet"`` head.
+        """
+        self.output_dim = output_dim
+        self.reward_token = "special" if special_token_ids is not None else reward_token
+        self.special_token_ids = special_token_ids
+
+        # Configs may keep the text hidden size at the top level or nested
+        # under ``text_config``.
+        hidden_size = getattr(self.config, "hidden_size", None)
+        if hidden_size is None and hasattr(self.config, "text_config"):
+            hidden_size = self.config.text_config.hidden_size
+
+        if rm_head_type == "ranknet":
+            rm_head_kwargs = rm_head_kwargs or {}
+            hidden = rm_head_kwargs.get("hidden_size", 1024)
+            dropout = rm_head_kwargs.get("dropout", 0.05)
+            self.rm_head = torch.nn.Sequential(
+                torch.nn.Linear(hidden_size, hidden),
+                torch.nn.ReLU(),
+                torch.nn.Dropout(dropout),
+                torch.nn.Linear(hidden, 16),
+                torch.nn.ReLU(),
+                torch.nn.Linear(16, output_dim),
+            )
+        else:
+            self.rm_head = torch.nn.Linear(hidden_size, output_dim, bias=False)
+
+        # The head is kept in float32; ``forward`` casts hidden states to the
+        # head's dtype before applying it.
+        self.rm_head.to(torch.float32)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        """Run the backbone and return pooled reward-head outputs.
+
+        ``labels`` is accepted but never read. ``logits_to_keep`` is dropped
+        from ``kwargs``; ``mm_token_type_ids`` is taken from ``kwargs`` and
+        passed to the backbone explicitly.
+
+        Returns:
+            ``{"logits": pooled_logits}`` with one row per sample.
+
+        Raises:
+            ValueError: If the batch size is not 1 and no pad token id is
+                configured, if the pooling mode is ``"special"`` without
+                ``special_token_ids``, or if the pooling mode is unknown.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs.pop("logits_to_keep", None)
+        mm_token_type_ids = kwargs.pop("mm_token_type_ids", None)
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            rope_deltas=rope_deltas,
+            mm_token_type_ids=mm_token_type_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+        # Accept both attribute-style outputs and plain tuples.
+        hidden_states = outputs.last_hidden_state if hasattr(outputs, "last_hidden_state") else outputs[0]
+        logits = self.rm_head(hidden_states.to(next(self.rm_head.parameters()).dtype))
+
+        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
+        pad_token_id = getattr(self.config, "pad_token_id", None)
+        if pad_token_id is None and hasattr(self.config, "text_config"):
+            pad_token_id = getattr(self.config.text_config, "pad_token_id", None)
+
+        if pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
+        if pad_token_id is None:
+            sequence_lengths = -1
+        elif input_ids is not None:
+            # Position just before the first pad token. For a row without any
+            # pad token argmax is 0, so ``-1`` wraps to the last position via
+            # the modulo.
+            sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
+            sequence_lengths = sequence_lengths % input_ids.shape[-1]
+            sequence_lengths = sequence_lengths.to(logits.device)
+        else:
+            sequence_lengths = -1
+
+        if self.reward_token == "last":
+            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        elif self.reward_token == "mean":
+            # The slice stops before ``valid_lengths[i]``, so the position at
+            # that index is not part of the mean.
+            # NOTE(review): ``sequence_lengths`` is the plain int -1 when no pad
+            # token is configured or only ``inputs_embeds`` is given; the clamp
+            # and the per-row indexing below presumably need a tensor - confirm.
+            valid_lengths = torch.clamp(sequence_lengths, min=0, max=logits.size(1) - 1)
+            pooled_logits = torch.stack([logits[i, : valid_lengths[i]].mean(dim=0) for i in range(batch_size)])
+        elif self.reward_token == "special":
+            if self.special_token_ids is None:
+                raise ValueError("HPSv3 reward_token='special' requires special_token_ids.")
+            # This branch reads ``input_ids`` directly, so it cannot be used
+            # with ``inputs_embeds`` alone.
+            special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+            for special_token_id in self.special_token_ids:
+                special_token_mask = special_token_mask | (input_ids == special_token_id)
+            # All marked positions of a sample are flattened into that sample's row.
+            pooled_logits = logits[special_token_mask, ...].view(batch_size, -1)
+        else:
+            raise ValueError(f"Invalid HPSv3 reward token mode: {self.reward_token}")
+
+        return {"logits": pooled_logits}
+
+
+class HPSv3Qwen2VLRewardModel(HPSv3RewardModelMixin, Qwen2VLForConditionalGeneration):
+    """Qwen2-VL backbone combined with the HPSv3 reward head.
+
+    The mixin is listed first in the bases, so its ``forward`` takes
+    precedence over the one inherited from the Qwen2-VL class.
+    """
+
+    def __init__(
+        self,
+        config=None,
+        vocab_size=None,
+        output_dim=2,
+        reward_token="special",
+        special_token_ids=None,
+        rm_head_type="ranknet",
+        rm_head_kwargs=None,
+    ):
+        """Build the backbone and attach the reward head.
+
+        Args:
+            config: Backbone config. When ``None``, ``default_config`` is used
+                with ``vocab_size`` (or 151658 if that is also ``None``).
+            vocab_size: When given together with a ``config`` that has a
+                ``text_config``, overwrites ``config.text_config.vocab_size``
+                in place.
+            output_dim, reward_token, special_token_ids, rm_head_type,
+                rm_head_kwargs: Forwarded unchanged to ``init_reward_head``.
+        """
+        if config is None:
+            config = self.default_config(vocab_size or 151658)
+        elif vocab_size is not None and hasattr(config, "text_config"):
+            config.text_config.vocab_size = vocab_size
+
+        super().__init__(config)
+        # The head reads ``self.config``, so it is created after the backbone
+        # constructor has run.
+        self.init_reward_head(
+            output_dim=output_dim,
+            reward_token=reward_token,
+            special_token_ids=special_token_ids,
+            rm_head_type=rm_head_type,
+            rm_head_kwargs=rm_head_kwargs,
+        )
+
+    @staticmethod
+    def default_config(vocab_size=151658):
+        """Return the built-in ``Qwen2VLConfig`` used when no config is supplied.
+
+        Only ``vocab_size`` is configurable; every other value is hard-coded.
+        """
+        # NOTE(review): these hyper-parameters presumably mirror the checkpoint
+        # this model is loaded from - verify against the published config.
+        from transformers import Qwen2VLConfig
+
+        return Qwen2VLConfig(
+            text_config={
+                "vocab_size": vocab_size,
+                "hidden_size": 3584,
+                "intermediate_size": 18944,
+                "num_hidden_layers": 28,
+                "num_attention_heads": 28,
+                "num_key_value_heads": 4,
+                "hidden_act": "silu",
+                "max_position_embeddings": 32768,
+                "initializer_range": 0.02,
+                "rms_norm_eps": 1e-6,
+                "use_cache": True,
+                "use_sliding_window": False,
+                "sliding_window": 32768,
+                "max_window_layers": 28,
+                "attention_dropout": 0.0,
+                "rope_parameters": {
+                    "rope_type": "default",
+                    "type": "mrope",
+                    "mrope_section": [16, 24, 24],
+                    "rope_theta": 1000000.0,
+                },
+                "bos_token_id": 151643,
+                "eos_token_id": 151645,
+            },
+            vision_config={
+                "depth": 32,
+                "embed_dim": 1280,
+                "hidden_size": 3584,
+                "mlp_ratio": 4,
+                "num_heads": 16,
+                "in_channels": 3,
+                "patch_size": 14,
+                "spatial_merge_size": 2,
+                "temporal_patch_size": 2,
+            },
+            image_token_id=151655,
+            video_token_id=151656,
+            vision_start_token_id=151652,
+            vision_end_token_id=151653,
+            tie_word_embeddings=False,
+        )
+
+
+class HPSv3Model(torch.nn.Module):
+    """End-to-end HPSv3 scorer: builds chat-style queries from prompt/image pairs and runs the reward model."""
+
+    def __init__(
+        self,
+        model,
+        processor,
+        use_special_tokens=True,
+        max_pixels=256 * 28 * 28,
+        min_pixels=256 * 28 * 28,
+        score_index=0,
+    ):
+        """Store the reward model, its processor and the preprocessing options.
+
+        Args:
+            model: Reward model; called with the processor batch and expected
+                to return a mapping with a ``"logits"`` entry.
+            processor: Provides ``apply_chat_template`` and is callable on
+                ``text``/``images``.
+            use_special_tokens: Selects which query suffix is appended.
+            max_pixels: Upper area bound used when resizing images.
+            min_pixels: Lower area bound used when resizing images.
+            score_index: Column returned from two-dimensional model output.
+        """
+        super().__init__()
+        self.model = model
+        self.processor = processor
+        self.use_special_tokens = use_special_tokens
+        self.max_pixels = max_pixels
+        self.min_pixels = min_pixels
+        self.score_index = score_index
+
+    @property
+    def device(self):
+        """Device of the first parameter, or of an empty tensor when there are none."""
+        for parameter in self.parameters():
+            return parameter.device
+        return torch.tensor([]).device
+
+    def _normalize_inputs(self, prompts, images):
+        """Turn both inputs into equally long lists, repeating a lone prompt or image."""
+        image_list = _as_list(images)
+        prompt_list = _as_list(prompts)
+
+        if len(prompt_list) == 1 and len(image_list) > 1:
+            prompt_list = prompt_list * len(image_list)
+        elif len(image_list) == 1 and len(prompt_list) > 1:
+            image_list = image_list * len(prompt_list)
+
+        if len(prompt_list) != len(image_list):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompt_list)} and {len(image_list)}.")
+        return prompt_list, image_list
+
+    def _prepare_images(self, images):
+        """Convert each image to RGB and resize it to the size chosen by ``_smart_resize``."""
+
+        def resize_one(source):
+            rgb = source.convert("RGB")
+            target_height, target_width = _smart_resize(
+                rgb.height,
+                rgb.width,
+                min_pixels=self.min_pixels,
+                max_pixels=self.max_pixels,
+            )
+            # PIL expects (width, height).
+            return rgb.resize((target_width, target_height), Image.BICUBIC)
+
+        return [resize_one(image) for image in images]
+
+    def _messages(self, prompts, images):
+        """Build one single-turn user conversation (image, then text) per pair."""
+        if self.use_special_tokens:
+            suffix = HPSV3_PROMPT_WITH_SPECIAL_TOKEN
+        else:
+            suffix = HPSV3_PROMPT_WITHOUT_SPECIAL_TOKEN
+
+        return [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": image},
+                        {"type": "text", "text": HPSV3_INSTRUCTION.format(text_prompt=prompt) + suffix},
+                    ],
+                }
+            ]
+            for prompt, image in zip(prompts, images)
+        ]
+
+    def _prepare_batch(self, prompts, images):
+        """Resize the images, render the chat template and tokenize everything onto ``self.device``."""
+        resized_images = self._prepare_images(images)
+        conversations = self._messages(prompts, resized_images)
+
+        rendered = self.processor.apply_chat_template(conversations, tokenize=False, add_generation_prompt=True)
+        encoded = self.processor(text=rendered, images=resized_images, padding=True, return_tensors="pt")
+        return encoded.to(self.device)
+
+    @torch.no_grad()
+    def forward(self, prompts: Union[str, list[str]], images):
+        """Score prompt/image pairs.
+
+        Returns column ``score_index`` when the model output is two-dimensional,
+        otherwise the raw output.
+        """
+        prompts, images = self._normalize_inputs(prompts, images)
+        batch = self._prepare_batch(prompts, images)
+
+        rewards = self.model(return_dict=True, **batch)["logits"]
+        if rewards.ndim != 2:
+            return rewards
+        return rewards[:, self.score_index]
\ No newline at end of file
diff --git a/diffsynth/models/image_reward.py b/diffsynth/models/image_reward.py
new file mode 100644
index 000000000..da3f39fc2
--- /dev/null
+++ b/diffsynth/models/image_reward.py
@@ -0,0 +1,206 @@
+from typing import Union
+import torch
+from PIL import Image
+from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
+from torchvision.transforms import InterpolationMode
+
+BICUBIC = InterpolationMode.BICUBIC
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _convert_image_to_rgb(image):
+    """Return ``image`` converted to RGB mode; used as a step in the preprocessing pipeline."""
+    rgb_image = image.convert("RGB")
+    return rgb_image
+
+def _image_transform(image_size):
+    """Build the image preprocessing pipeline for a square input of side ``image_size``.
+
+    Steps, in order: bicubic resize, centre crop, RGB conversion, tensor
+    conversion, per-channel normalisation with fixed mean and standard deviation.
+    """
+    channel_mean = (0.48145466, 0.4578275, 0.40821073)
+    channel_std = (0.26862954, 0.26130258, 0.27577711)
+    steps = [
+        Resize(image_size, interpolation=BICUBIC),
+        CenterCrop(image_size),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize(channel_mean, channel_std),
+    ]
+    return Compose(steps)
+
+def _as_list(value):
+    """Normalise ``value`` to a list: sequences (list/tuple) are copied, scalars are wrapped."""
+    if not isinstance(value, (list, tuple)):
+        return [value]
+    return list(value)
+
+
+class ImageRewardMLP(torch.nn.Module):
+    """Regression head mapping a feature vector to a single score.
+
+    The stack is linear layers separated by dropout only (no non-linear
+    activations): ``input_size -> 1024 -> 128 -> 64 -> 16 -> 1``.
+    """
+
+    def __init__(self, input_size):
+        super().__init__()
+        # Positions in the Sequential are part of the parameter names, so the
+        # layer order must not change.
+        stack = [
+            torch.nn.Linear(input_size, 1024),
+            torch.nn.Dropout(0.2),
+            torch.nn.Linear(1024, 128),
+            torch.nn.Dropout(0.2),
+            torch.nn.Linear(128, 64),
+            torch.nn.Dropout(0.1),
+            torch.nn.Linear(64, 16),
+            torch.nn.Linear(16, 1),
+        ]
+        self.layers = torch.nn.Sequential(*stack)
+
+        # Every weight shares one standard deviation derived from the input
+        # width; every bias starts at zero.
+        weight_std = 1.0 / (input_size + 1)
+        for layer in self.layers:
+            if not isinstance(layer, torch.nn.Linear):
+                continue
+            torch.nn.init.normal_(layer.weight, mean=0.0, std=weight_std)
+            torch.nn.init.constant_(layer.bias, val=0)
+
+    def forward(self, x):
+        """Apply the layer stack to ``x``."""
+        return self.layers(x)
+
+
+class ImageRewardModel(torch.nn.Module):
+    """ImageReward scorer: a BLIP image encoder and image-conditioned text encoder followed by an MLP head.
+
+    The raw MLP output is standardised with ``score_mean`` and ``score_std``.
+    """
+
+    def __init__(self, blip=None, tokenizer=None, image_size=224, max_length=35, mean=0.16717362830052426, std=1.0333394966054072):
+        """Assemble the scorer.
+
+        Args:
+            blip: BLIP model exposing ``vision_model``, ``text_encoder`` and
+                ``config.text_config.hidden_size``. Built by
+                ``default_blip_model`` when ``None``.
+            tokenizer: Callable tokenizer used by ``_tokenize``; it is stored
+                as given and must be set before ``forward`` is called.
+            image_size: Side length passed to the preprocessing pipeline.
+            max_length: Fixed token length prompts are padded/truncated to.
+            mean: Value subtracted from the raw score.
+            std: Value the centred score is divided by.
+        """
+        super().__init__()
+        if blip is None:
+            blip = self.default_blip_model()
+
+        self.blip = blip
+        self.tokenizer = tokenizer
+        self.preprocess = _image_transform(image_size)
+        self.max_length = max_length
+        self.mlp = ImageRewardMLP(blip.config.text_config.hidden_size)
+
+        # Non-persistent buffers follow the module across devices but are not
+        # written to the state dict.
+        self.register_buffer("score_mean", torch.tensor(float(mean)), persistent=False)
+        self.register_buffer("score_std", torch.tensor(float(std)), persistent=False)
+
+    @staticmethod
+    def default_blip_model():
+        """Construct a ``BlipForImageTextRetrieval`` from a hard-coded configuration."""
+        from transformers import BlipConfig, BlipForImageTextRetrieval
+
+        vision_hidden_size = 1024
+        text_config = ImageRewardModel._load_text_config(None)
+        config = BlipConfig(
+            vision_config={
+                "hidden_size": vision_hidden_size,
+                "intermediate_size": vision_hidden_size * 4,
+                "num_hidden_layers": 24,
+                "num_attention_heads": 16,
+                "image_size": 224,
+                "patch_size": 16,
+                "hidden_act": "gelu",
+                "layer_norm_eps": 1e-6,
+            },
+            text_config={
+                **text_config,
+                "vocab_size": 30524,
+                # The text encoder cross-attends to the vision features, so its
+                # encoder width is tied to the vision hidden size.
+                "encoder_hidden_size": vision_hidden_size,
+                "add_cross_attention": True,
+                "is_decoder": True,
+            },
+            projection_dim=256,
+        )
+        return BlipForImageTextRetrieval(config)
+
+    @staticmethod
+    def _load_text_config(med_config_path):
+        """Return the fixed text-encoder settings; ``med_config_path`` is accepted but not read."""
+        return {
+            "hidden_size": 768,
+            "intermediate_size": 3072,
+            "num_hidden_layers": 12,
+            "num_attention_heads": 12,
+            "max_position_embeddings": 512,
+            "vocab_size": 30524,
+            "hidden_act": "gelu",
+            "layer_norm_eps": 1e-12,
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_dropout_prob": 0.1,
+            "pad_token_id": 0,
+            "type_vocab_size": 2,
+        }
+
+    @staticmethod
+    def convert_key_value(key, value):
+        """Rename one checkpoint entry to this module's parameter naming.
+
+        Keys under ``blip.visual_encoder.`` are mapped to ``blip.vision_model.``
+        names; any other key is returned unchanged. ``value`` is never modified.
+
+        Returns:
+            ``(new_key, value)``; ``new_key`` is ``None`` for a visual-encoder
+            key that matches none of the known patterns.
+        """
+        if key.startswith("blip.visual_encoder."):
+            suffix = key[len("blip.visual_encoder.") :]
+
+            if suffix == "cls_token":
+                return "blip.vision_model.embeddings.class_embedding", value
+            if suffix == "pos_embed":
+                return "blip.vision_model.embeddings.position_embedding", value
+            if suffix.startswith("patch_embed.proj."):
+                return "blip.vision_model.embeddings.patch_embedding." + suffix[len("patch_embed.proj.") :], value
+
+            if suffix.startswith("blocks."):
+                # ``blocks.<layer>.<rest>``: keep the layer index, rename the
+                # sub-module prefix of ``rest``.
+                parts = suffix.split(".")
+                layer = parts[1]
+                rest = ".".join(parts[2:])
+                prefix = f"blip.vision_model.encoder.layers.{layer}."
+                mapping = {
+                    "norm1.": "layer_norm1.",
+                    "attn.qkv.": "self_attn.qkv.",
+                    "attn.proj.": "self_attn.projection.",
+                    "norm2.": "layer_norm2.",
+                    "mlp.fc1.": "mlp.fc1.",
+                    "mlp.fc2.": "mlp.fc2.",
+                }
+                for source, target in mapping.items():
+                    if rest.startswith(source):
+                        return prefix + target + rest[len(source) :], value
+
+            if suffix.startswith("norm."):
+                return "blip.vision_model.post_layernorm." + suffix[len("norm.") :], value
+
+            # Unrecognised visual-encoder entry: signal the caller to drop it.
+            return None, value
+        return key, value
+
+    @property
+    def device(self):
+        """Device of the first parameter, or of an empty tensor when there are none."""
+        return next(self.parameters(), torch.tensor([])).device
+
+    @property
+    def dtype(self):
+        """Dtype of the first parameter, or of a float scalar tensor when there are none."""
+        return next(self.parameters(), torch.tensor(0.0)).dtype
+
+    def _tokenize(self, prompts):
+        """Tokenize prompts to exactly ``max_length`` tokens and move the result to ``self.device``."""
+        return self.tokenizer(
+            prompts,
+            padding="max_length",
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+        ).to(self.device)
+
+    def _preprocess_images(self, images):
+        """Preprocess each image and stack the results into one batch tensor on the model's device and dtype."""
+        tensors = [self.preprocess(image.convert("RGB")) for image in images]
+        return torch.stack(tensors, dim=0).to(device=self.device, dtype=self.dtype)
+
+    def _normalize_inputs(self, prompts, images):
+        """Turn both inputs into equally long lists, repeating a lone prompt or image.
+
+        Raises:
+            ValueError: If the lengths still differ after broadcasting.
+        """
+        images = _as_list(images)
+        prompts = _as_list(prompts)
+
+        if len(prompts) == 1 and len(images) > 1:
+            prompts = prompts * len(images)
+        if len(images) == 1 and len(prompts) > 1:
+            images = images * len(prompts)
+
+        if len(prompts) != len(images):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
+
+        return prompts, images
+
+    @torch.no_grad()
+    def forward(self, prompts: Union[str, list[str]], images):
+        """Return one standardised reward per prompt/image pair."""
+        prompts, images = self._normalize_inputs(prompts, images)
+        text_input = self._tokenize(prompts)
+        image_tensor = self._preprocess_images(images)
+
+        image_output = self.blip.vision_model(pixel_values=image_tensor, return_dict=True)
+        image_embeds = image_output.last_hidden_state
+        # Every image position is attendable: an all-ones mask over the
+        # leading (non-feature) dimensions of the image embeddings.
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=self.device)
+
+        text_output = self.blip.text_encoder(
+            input_ids=text_input.input_ids,
+            attention_mask=text_input.attention_mask,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+
+        # The hidden state at the first token position feeds the MLP head.
+        text_features = text_output.last_hidden_state[:, 0, :]
+        rewards = self.mlp(text_features).squeeze(-1)
+        rewards = (rewards - self.score_mean) / self.score_std
+
+        return rewards
\ No newline at end of file
diff --git a/diffsynth/models/pickscore.py b/diffsynth/models/pickscore.py
new file mode 100644
index 000000000..7012c3b93
--- /dev/null
+++ b/diffsynth/models/pickscore.py
@@ -0,0 +1,84 @@
+from typing import Union
+import torch
+from PIL import Image
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _feature_tensor(output, feature_name: str):
+    """Return the feature tensor carried by ``output``.
+
+    Lookup order: ``output`` itself if it is a tensor; then its attributes
+    ``image_embeds``, ``text_embeds``, ``pooler_output``; then, for a list or
+    tuple, its elements. The first tensor encountered is returned.
+
+    Raises:
+        TypeError: If nothing in ``output`` is a tensor; ``feature_name``
+            labels the message.
+    """
+    if torch.is_tensor(output):
+        return output
+
+    for attribute in ("image_embeds", "text_embeds", "pooler_output"):
+        candidate = getattr(output, attribute, None)
+        if torch.is_tensor(candidate):
+            return candidate
+
+    elements = output if isinstance(output, (list, tuple)) else ()
+    for element in elements:
+        if torch.is_tensor(element):
+            return element
+
+    raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
+
+
+class PickScoreModel(torch.nn.Module):
+    """PickScore wrapper around a CLIP-style model and its processor.
+
+    ``forward`` returns ``exp(logit_scale)`` times the similarity of every
+    prompt with every image.
+    """
+
+    def __init__(self, model: torch.nn.Module, processor, max_length: int = 77):
+        super().__init__()
+        self.model = model
+        self.processor = processor
+        self.max_length = max_length
+
+    @property
+    def device(self):
+        """Device of the first parameter, or of an empty tensor when there are none."""
+        for parameter in self.parameters():
+            return parameter.device
+        return torch.tensor([]).device
+
+    @property
+    def dtype(self):
+        """Dtype of the first parameter, or of a float scalar tensor when there are none."""
+        for parameter in self.parameters():
+            return parameter.dtype
+        return torch.tensor(0.0).dtype
+
+    def _processor_call(self, **kwargs):
+        """Run the processor with the shared padding/truncation options.
+
+        The result is moved to ``self.device``. When the model is not float32,
+        floating-point tensors are cast to the model dtype and a plain dict is
+        returned; integer tensors and non-tensors are left untouched.
+        """
+        encoded = self.processor(
+            padding=True,
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+            **kwargs,
+        )
+        encoded = encoded.to(self.device)
+
+        model_dtype = self.dtype
+        if model_dtype == torch.float32:
+            return encoded
+
+        cast = {}
+        for name, value in encoded.items():
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                value = value.to(dtype=model_dtype)
+            cast[name] = value
+        return cast
+
+    @torch.no_grad()
+    def get_image_features(self, images: ImageInput):
+        """Return L2-normalised image features for one image or a sequence of images."""
+        if isinstance(images, Image.Image):
+            images = [images]
+        rgb_images = [image.convert("RGB") for image in images]
+
+        processed = self._processor_call(images=rgb_images)
+        raw_features = self.model.get_image_features(**processed)
+        features = _feature_tensor(raw_features, "image_features")
+        return torch.nn.functional.normalize(features, dim=-1)
+
+    @torch.no_grad()
+    def get_text_features(self, text: Union[str, list[str]]):
+        """Return L2-normalised text features for one prompt or a list of prompts."""
+        processed = self._processor_call(text=text)
+        raw_features = self.model.get_text_features(**processed)
+        features = _feature_tensor(raw_features, "text_features")
+        return torch.nn.functional.normalize(features, dim=-1)
+
+    @torch.no_grad()
+    def forward(self, text: Union[str, list[str]], images: ImageInput):
+        """Score prompts against images.
+
+        Returns a ``(prompt, image)`` score matrix; when ``text`` is a single
+        string only that prompt's row is returned.
+        """
+        image_features = self.get_image_features(images)
+        text_features = self.get_text_features(text)
+
+        similarity = text_features @ image_features.T
+        scores = self.model.logit_scale.exp() * similarity
+        return scores[0] if isinstance(text, str) else scores
\ No newline at end of file
diff --git a/diffsynth/utils/state_dict_converters/image_metrics.py b/diffsynth/utils/state_dict_converters/image_metrics.py
new file mode 100644
index 000000000..30c8b55a3
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/image_metrics.py
@@ -0,0 +1,121 @@
+import torch
+
+
+def ImageMetricsCLIPStateDictConverter(state_dict):
+    """Copy a CLIP state dict, leaving out the text and vision ``position_ids`` entries."""
+    dropped = ("text_model.embeddings.position_ids", "vision_model.embeddings.position_ids")
+    converted = {}
+    for name in state_dict:
+        if name in dropped:
+            continue
+        converted[name] = state_dict[name]
+    return converted
+
+
+def ImageMetricsOpenCLIPStateDictConverter(state_dict):
+    """Rename OpenCLIP-style checkpoint keys to the ``text_model``/``vision_model`` layout.
+
+    Four kinds of entries are handled: plain renames, projection matrices
+    (stored transposed under a ``.weight`` name), layer norms (prefix swap) and
+    transformer blocks (delegated to ``_convert_open_clip_resblock``). Keys
+    that match none of these are dropped.
+    """
+    renames = {
+        "logit_scale": "logit_scale",
+        "token_embedding.weight": "text_model.embeddings.token_embedding.weight",
+        "positional_embedding": "text_model.embeddings.position_embedding.weight",
+        "visual.class_embedding": "vision_model.embeddings.class_embedding",
+        "visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
+        "visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
+    }
+    transposed = {
+        "text_projection": "text_projection.weight",
+        "visual.proj": "visual_projection.weight",
+    }
+    # "pre_layrnorm" is the spelling the target key layout uses; keep it as is.
+    prefix_swaps = (
+        ("ln_final.", "text_model.final_layer_norm."),
+        ("visual.ln_pre.", "vision_model.pre_layrnorm."),
+        ("visual.ln_post.", "vision_model.post_layernorm."),
+    )
+    block_towers = (
+        ("transformer.resblocks.", "text_model.encoder.layers"),
+        ("visual.transformer.resblocks.", "vision_model.encoder.layers"),
+    )
+
+    converted = {}
+    for name in state_dict:
+        tensor = state_dict[name]
+        if name in renames:
+            converted[renames[name]] = tensor
+            continue
+        if name in transposed:
+            converted[transposed[name]] = tensor.T
+            continue
+
+        swap = next((pair for pair in prefix_swaps if name.startswith(pair[0])), None)
+        if swap is not None:
+            converted[swap[1] + name[len(swap[0]) :]] = tensor
+            continue
+
+        tower = next((pair for pair in block_towers if name.startswith(pair[0])), None)
+        if tower is not None:
+            converted.update(_convert_open_clip_resblock(tower[1], name[len(tower[0]) :], tensor))
+    return converted
+
+
+def ImageMetricsImageRewardStateDictConverter(state_dict):
+    """Convert an ImageReward checkpoint to the ``ImageRewardModel`` key layout.
+
+    A leading ``module.`` is stripped, each key is renamed with
+    ``ImageRewardModel.convert_key_value``, and entries that map to ``None`` or
+    to the text encoder's ``position_ids`` are dropped. Zero-filled
+    ``blip.itm_head`` weight and bias are then written into the result, sized
+    and typed after the word-embedding matrix.
+    """
+    from diffsynth.models.image_reward import ImageRewardModel
+
+    skipped_key = "blip.text_encoder.embeddings.position_ids"
+    converted = {}
+    for raw_key in state_dict:
+        tensor = state_dict[raw_key]
+        stripped_key = raw_key[len("module.") :] if raw_key.startswith("module.") else raw_key
+        target_key, target_value = ImageRewardModel.convert_key_value(stripped_key, tensor)
+        if target_key is None or target_key == skipped_key:
+            continue
+        converted[target_key] = target_value
+
+    word_embeddings = converted["blip.text_encoder.embeddings.word_embeddings.weight"]
+    hidden_size = word_embeddings.shape[1]
+    converted["blip.itm_head.weight"] = torch.zeros((2, hidden_size), dtype=word_embeddings.dtype)
+    converted["blip.itm_head.bias"] = torch.zeros((2,), dtype=word_embeddings.dtype)
+    return converted
+
+
+def ImageMetricsAestheticStateDictConverter(state_dict):
+    """Strip wrapper prefixes from aesthetic-predictor checkpoint keys.
+
+    The prefixes are tested once each, in the listed order, against the
+    progressively shortened key, so several different prefixes can be removed
+    from one key but only in that order. The vision ``position_ids`` entry is
+    dropped.
+    """
+    wrapper_prefixes = ("model.", "module.", "aesthetic_model.", "aesthetics_predictor.", "predictor.")
+    converted = {}
+    for original_key in state_dict:
+        tensor = state_dict[original_key]
+        name = original_key
+        for prefix in wrapper_prefixes:
+            if name.startswith(prefix):
+                name = name[len(prefix) :]
+        if name != "vision_model.embeddings.position_ids":
+            converted[name] = tensor
+    return converted
+
+
+def ImageMetricsFIDStateDictConverter(state_dict):
+    """Prefix every key with ``model.`` and leave out the ``fc.`` entries."""
+    converted = {}
+    for name in state_dict:
+        if name.startswith("fc."):
+            continue
+        converted["model." + name] = state_dict[name]
+    return converted
+
+
+def ImageMetricsHPSv3StateDictConverter(state_dict):
+    """Move HPSv3 checkpoint keys into the ``model.visual`` / ``model.language_model`` layout.
+
+    - ``visual.*`` gains a ``model.`` prefix.
+    - ``model.*`` keys not already under ``model.visual.`` or
+      ``model.language_model.`` are nested under ``model.language_model.``.
+    - Every other key is kept unchanged.
+    """
+    already_nested = ("model.visual.", "model.language_model.")
+    converted = {}
+    for name in state_dict:
+        tensor = state_dict[name]
+        if name.startswith("visual."):
+            target = "model." + name
+        elif name.startswith("model.") and not name.startswith(already_nested):
+            target = "model.language_model." + name[len("model.") :]
+        else:
+            target = name
+        converted[target] = tensor
+    return converted
+
+
+def _convert_open_clip_resblock(prefix, suffix, value):
+    """Convert one transformer-block entry of an OpenCLIP-style checkpoint.
+
+    Args:
+        prefix: Target path of the layer list, e.g. ``text_model.encoder.layers``.
+        suffix: Source key with the tower prefix removed: ``<layer>.<rest>``.
+        value: The checkpoint tensor.
+
+    Returns:
+        A dict of converted entries: three for a fused attention input
+        projection (split into q/k/v along dim 0), one for a recognised
+        sub-module, and none for anything else.
+    """
+    layer, _, rest = suffix.partition(".")
+    layer_prefix = f"{prefix}.{layer}."
+
+    fused_kinds = {"attn.in_proj_weight": "weight", "attn.in_proj_bias": "bias"}
+    if rest in fused_kinds:
+        kind = fused_kinds[rest]
+        # Strict unpacking: exactly three chunks are required.
+        q, k, v = value.chunk(3, dim=0)
+        return {
+            layer_prefix + "self_attn.q_proj." + kind: q,
+            layer_prefix + "self_attn.k_proj." + kind: k,
+            layer_prefix + "self_attn.v_proj." + kind: v,
+        }
+
+    submodule_renames = {
+        "attn.out_proj.": "self_attn.out_proj.",
+        "ln_1.": "layer_norm1.",
+        "ln_2.": "layer_norm2.",
+        "mlp.c_fc.": "mlp.fc1.",
+        "mlp.c_proj.": "mlp.fc2.",
+    }
+    match = next((item for item in submodule_renames.items() if rest.startswith(item[0])), None)
+    if match is None:
+        return {}
+    source, target = match
+    return {layer_prefix + target + rest[len(source) :]: value}
diff --git a/docs/en/Model_Details/Image-Quality-Metrics.md b/docs/en/Model_Details/Image-Quality-Metrics.md
new file mode 100644
index 000000000..7d3f87139
--- /dev/null
+++ b/docs/en/Model_Details/Image-Quality-Metrics.md
@@ -0,0 +1,117 @@
+# Image Quality Evaluation Metrics
+
+DiffSynth-Studio provides a suite of image quality evaluation metrics and reward models in `diffsynth.metrics` to assess text alignment, aesthetic quality, human preference, and image distribution quality of generated images. Example code for these metrics can be found in [`examples/image_quality_metric/`](../../../examples/image_quality_metric/).
+
+## Installation
+
+Before using this project for model inference and training, please install DiffSynth-Studio first.
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
+
+## Quick Start
+
+Run the following code to quickly load PickScore and score an image against a prompt. The default models will be downloaded from ModelScope to `./models`.
+
+```python
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]
+print(f"PickScore score: {score:.3f}")
+```
+
+## Metrics Overview
+
+| Metric | Input | Output | Example Code |
+| --- | --- | --- | --- |
+| PickScore | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/pickscore.py) |
+| ImageReward | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/image_reward.py) |
+| HPSv2 | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv2.py) |
+| HPSv3 | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv3.py) |
+| CLIP Score | prompt + PIL Image | Text-Image Similarity | [code](../../../examples/image_quality_metric/clipscore.py) |
+| Aesthetic | PIL Image | Aesthetic Score | [code](../../../examples/image_quality_metric/aesthetic.py) |
+| FID | reference image directory + generated image directory | Distribution Distance | [code](../../../examples/image_quality_metric/fid.py) |
+
+### Text-Image Alignment and Preference Evaluation
+
+Applicable metrics: **PickScore**, **ImageReward**, **HPSv2**, **HPSv3**, **CLIP Score**
+
+These models are used to evaluate whether an image follows the prompt and aligns with human visual preferences. They must receive both the `prompt` and the `image` simultaneously.
+
+**Basic Scoring**
+
+```python
+score = metric.compute(prompt, image)[0]
+```
+
+**Batch Scoring**
+
+If you need to evaluate multiple images, you can directly pass a list:
+
+```python
+scores = metric.compute("a cute cat", [image1, image2, image3])
+
+scores = metric.compute(["a cat", "a dog"], [image_cat, image_dog])
+```
+
+When prompt is a single string, the same prompt will be applied to every image. When prompt is a list of strings, the number of prompts must exactly match the number of images.
+
+### Pure Image Aesthetics Evaluation
+
+Applicable metric: **Aesthetic**
+
+This model solely evaluates aesthetic features such as the composition, color, and clarity of the image itself. It does not require a prompt.
+
+```python
+from diffsynth.metrics import AestheticMetric
+
+metric = AestheticMetric.from_pretrained(device="cuda")
+score = metric.compute(image)[0]
+```
+
+### Dataset Distribution Evaluation
+
+Applicable metric: **FID** (Fréchet Inception Distance)
+
+FID does not score individual images; instead, it compares the overall feature distribution distance between a real reference image set and a generated image set. A lower score indicates that the generated distribution is closer to the real distribution.
+
+```python
+from diffsynth.metrics import FIDMetric
+
+reference_dir = "path/to/real_reference_images"
+generated_dir = "path/to/model_generated_images"
+
+metric = FIDMetric.from_pretrained(device="cuda", batch_size=16)
+fid_score = metric.compute(reference_dir, generated_dir)
+print(f"FID: {fid_score:.3f}")
+```
+
+The baseline for FID is not fixed or unique. For general image generation, COCO Validation is commonly used; for specific domains (such as medical images or e-commerce products), a `reference_dir` composed of real data from that specific domain should be provided.
+
+## Important Notes
+
+* The scores from PickScore, ImageReward, HPSv2, HPSv3, CLIP Score, and Aesthetic are suitable for relative comparison within the same metric. It is not recommended to directly compare the numerical values across different metrics.
+* HPSv3 is based on Qwen2-VL and is a larger model, requiring significantly more VRAM than CLIP-based metrics.
+* FID is sensitive to the choice of reference, the reference sample size, and the generated sample size.
diff --git a/docs/en/Model_Details/Overview.md b/docs/en/Model_Details/Overview.md
index 286141e83..fecead8bd 100644
--- a/docs/en/Model_Details/Overview.md
+++ b/docs/en/Model_Details/Overview.md
@@ -289,3 +289,49 @@ graph LR;
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)
+
+
+## Image Quality Evaluation Metrics
+
+Documentation: [./Image-Quality-Metrics.md](../Model_Details/Image-Quality-Metrics.md)
+
+
+
+Quick Start
+
+```python
+import csv
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+
+print("PickScore score:", metric.compute(prompt, image)[0])
+```
+
+
+
+| Metric | Default Model | Example Code |
+|-|-|-|
+|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](../../../examples/image_quality_metric/pickscore.py)|
+|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](../../../examples/image_quality_metric/image_reward.py)|
+|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](../../../examples/image_quality_metric/hpsv2.py)|
+|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](../../../examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](../../../examples/image_quality_metric/clipscore.py)|
+|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](../../../examples/image_quality_metric/aesthetic.py)|
+|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](../../../examples/image_quality_metric/fid.py)|
\ No newline at end of file
diff --git a/docs/zh/Model_Details/Image-Quality-Metrics.md b/docs/zh/Model_Details/Image-Quality-Metrics.md
new file mode 100644
index 000000000..20cb846a7
--- /dev/null
+++ b/docs/zh/Model_Details/Image-Quality-Metrics.md
@@ -0,0 +1,116 @@
+# 图像质量评估指标
+
+DiffSynth-Studio 在 `diffsynth.metrics` 中提供了一组图像质量评估指标和奖励模型,用于评估生成图像的文本对齐、审美质量、人类偏好和图像分布质量。这些指标的示例代码位于 [`examples/image_quality_metric/`](../../../examples/image_quality_metric/)。
+
+## 安装
+
+在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
+
+## 快速开始
+
+运行以下代码可以快速加载 PickScore,并对一张图像和一段提示词进行评分。默认模型会从 ModelScope 下载到 `./models`。
+
+```python
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]
+print(f"PickScore score: {score:.3f}")
+```
+
+## 指标总览
+
+|指标|输入|输出|示例代码|
+|-|-|-|-|
+|PickScore|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/pickscore.py)|
+|ImageReward|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/image_reward.py)|
+|HPSv2|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv2.py)|
+|HPSv3|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|prompt + PIL 图像|图文匹配度|[code](../../../examples/image_quality_metric/clipscore.py)|
+|Aesthetic|PIL 图像|美学分数|[code](../../../examples/image_quality_metric/aesthetic.py)|
+|FID|reference 图像目录 + generated 图像目录|分布距离|[code](../../../examples/image_quality_metric/fid.py)|
+
+### 文本-图像对齐与偏好评估
+
+适用指标: **PickScore**,**ImageReward**,**HPSv2**,**HPSv3**,**CLIP Score**
+
+这类模型用于评估图像是否遵循提示词以及是否符合人类视觉偏好。它们必须同时接收 `prompt` 和 `image`。
+
+**基础打分**
+```python
+score = metric.compute(prompt, image)[0]
+```
+
+**批量打分**
+如果需要评估多张图像,可以直接传入列表:
+
+```python
+scores = metric.compute("a cute cat", [image1, image2, image3])
+
+scores = metric.compute(["a cat", "a dog"], [image_cat, image_dog])
+```
+
+其中 prompt 为单个字符串时,会对每张图像使用同一个 prompt。prompt 为字符串列表时,prompt 数量需要和图像数量一致。
+
+### 纯图像美学评估
+
+适用指标: **Aesthetic**
+
+该模型仅评估图像本身的构图、色彩、清晰度等美学特征,不需要提示词介入。
+
+
+```python
+from diffsynth.metrics import AestheticMetric
+
+metric = AestheticMetric.from_pretrained(device="cuda")
+score = metric.compute(image)[0]
+```
+
+### 数据集分布评估
+适用指标: **FID** (Fréchet Inception Distance)
+
+FID 不对单张图片打分,而是比较真实参考图像集与生成图像集的整体特征分布距离。分数越低,说明生成分布越接近真实分布。
+
+```python
+from diffsynth.metrics import FIDMetric
+
+reference_dir = "path/to/real_reference_images"
+generated_dir = "path/to/model_generated_images"
+
+metric = FIDMetric.from_pretrained(device="cuda", batch_size=16)
+fid_score = metric.compute(reference_dir, generated_dir)
+print(f"FID: {fid_score:.3f}")
+```
+
+FID 的基准不是固定唯一的。对于通用图像生成,常使用 COCO Validation;如果是特定领域(如医学图像、电商商品),应提供该领域真实数据构成的 `reference_dir`。
+
+
+## 注意事项
+
+* PickScore、ImageReward、HPSv2、HPSv3、CLIP Score、Aesthetic 的分数适合做同一指标内部的相对比较,不建议直接把不同指标的数值大小相互比较。
+* HPSv3 基于 Qwen2-VL,模型较大,显存需求明显高于 CLIP 类指标。
+* FID 对 reference 选择、reference 样本量和 generated 样本量较敏感。
diff --git a/docs/zh/Model_Details/Overview.md b/docs/zh/Model_Details/Overview.md
index cdfdce9f2..2c1d32849 100644
--- a/docs/zh/Model_Details/Overview.md
+++ b/docs/zh/Model_Details/Overview.md
@@ -286,3 +286,49 @@ graph LR;
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
+
+
+## 图像质量评估指标
+
+文档:[./Image-Quality-Metrics.md](../Model_Details/Image-Quality-Metrics.md)
+
+
+
+快速开始
+
+```python
+import csv
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+
+print("PickScore score:", metric.compute(prompt, image)[0])
+```
+
+
+
+|指标|默认模型|示例代码|
+|-|-|-|
+|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](../../../examples/image_quality_metric/pickscore.py)|
+|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](../../../examples/image_quality_metric/image_reward.py)|
+|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](../../../examples/image_quality_metric/hpsv2.py)|
+|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](../../../examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](../../../examples/image_quality_metric/clipscore.py)|
+|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](../../../examples/image_quality_metric/aesthetic.py)|
+|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](../../../examples/image_quality_metric/fid.py)|
\ No newline at end of file
diff --git a/examples/image_quality_metric/aesthetic.py b/examples/image_quality_metric/aesthetic.py
new file mode 100644
index 000000000..ae0b518fe
--- /dev/null
+++ b/examples/image_quality_metric/aesthetic.py
@@ -0,0 +1,20 @@
+from diffsynth.metrics import AestheticMetric, ModelConfig  # Example script: score one image with AestheticMetric (no prompt involved).
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(  # fetch the files matching flux/FLUX.1-dev/* into ./data/diffsynth_example_dataset
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")  # open the downloaded sample and convert it to RGB
+device = "cuda"
+
+metric = AestheticMetric.from_pretrained(  # build the metric from Aesthetic/model.safetensors in DiffSynth-Studio/ImageMetrics
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(image)[0]  # compute() result is indexed; [0] takes the entry for this single image
+print(f"Aesthetic score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/clipscore.py b/examples/image_quality_metric/clipscore.py
new file mode 100644
index 000000000..4a6621541
--- /dev/null
+++ b/examples/image_quality_metric/clipscore.py
@@ -0,0 +1,21 @@
+from diffsynth.metrics import CLIPMetric, ModelConfig  # Example script: score one prompt/image pair with CLIPMetric.
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(  # fetch the files matching flux/FLUX.1-dev/* into ./data/diffsynth_example_dataset
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")  # open the downloaded sample and convert it to RGB
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = CLIPMetric.from_pretrained(  # build the metric from the CLIP-ViT-H-14 weights file in DiffSynth-Studio/ImageMetrics
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]  # compute() result is indexed; [0] takes the entry for this single image
+print(f"CLIP score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/fid.py b/examples/image_quality_metric/fid.py
new file mode 100644
index 000000000..94a982b47
--- /dev/null
+++ b/examples/image_quality_metric/fid.py
@@ -0,0 +1,13 @@
+from diffsynth.metrics import FIDMetric, ModelConfig  # Example script: compute FID between two image directories.
+
+reference_dir = ""  # placeholder: set to the directory of reference images before running
+generated_dir = ""  # placeholder: set to the directory of generated images before running
+device = "cuda"
+
+metric = FIDMetric.from_pretrained(  # build the metric from FID/model.safetensors in DiffSynth-Studio/ImageMetrics
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors"),
+ device=device,
+)
+
+score = metric.compute(reference_dir, generated_dir)  # used directly as a number below (no [0] indexing, unlike the per-image metrics)
+print(f"FID score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/hpsv2.py b/examples/image_quality_metric/hpsv2.py
new file mode 100644
index 000000000..c460357f6
--- /dev/null
+++ b/examples/image_quality_metric/hpsv2.py
@@ -0,0 +1,21 @@
+from diffsynth.metrics import HPSv2Metric, ModelConfig  # Example script: score one prompt/image pair with HPSv2Metric.
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(  # fetch the files matching flux/FLUX.1-dev/* into ./data/diffsynth_example_dataset
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")  # open the downloaded sample and convert it to RGB
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = HPSv2Metric.from_pretrained(  # build the metric from HPSv2/model.safetensors in DiffSynth-Studio/ImageMetrics
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]  # compute() result is indexed; [0] takes the entry for this single image
+print(f"HPSv2 score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/hpsv3.py b/examples/image_quality_metric/hpsv3.py
new file mode 100644
index 000000000..d6917f3d5
--- /dev/null
+++ b/examples/image_quality_metric/hpsv3.py
@@ -0,0 +1,21 @@
+from diffsynth.metrics import HPSv3Metric, ModelConfig  # Example script: score one prompt/image pair with HPSv3Metric.
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(  # fetch the files matching flux/FLUX.1-dev/* into ./data/diffsynth_example_dataset
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")  # open the downloaded sample and convert it to RGB
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = HPSv3Metric.from_pretrained(  # build the metric from HPSv3/model.safetensors in DiffSynth-Studio/ImageMetrics
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]  # compute() result is indexed; [0] takes the entry for this single image
+print(f"HPSv3 score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/image_reward.py b/examples/image_quality_metric/image_reward.py
new file mode 100644
index 000000000..e8e7df5b6
--- /dev/null
+++ b/examples/image_quality_metric/image_reward.py
@@ -0,0 +1,21 @@
+from diffsynth.metrics import ImageRewardMetric, ModelConfig  # Example script: score one prompt/image pair with ImageRewardMetric.
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(  # fetch the files matching flux/FLUX.1-dev/* into ./data/diffsynth_example_dataset
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")  # open the downloaded sample and convert it to RGB
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = ImageRewardMetric.from_pretrained(  # build the metric from ImageReward/model.safetensors in DiffSynth-Studio/ImageMetrics
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]  # compute() result is indexed; [0] takes the entry for this single image
+print(f"ImageReward score: {score:.3f}")
diff --git a/examples/image_quality_metric/pickscore.py b/examples/image_quality_metric/pickscore.py
new file mode 100644
index 000000000..b409f3b3c
--- /dev/null
+++ b/examples/image_quality_metric/pickscore.py
@@ -0,0 +1,21 @@
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]
+print(f"PickScore score:: {score:.3f}")
\ No newline at end of file