From d9ca99199b5810b280cf4937db3363817475a5eb Mon Sep 17 00:00:00 2001
From: yjy415 <2471352175@qq.com>
Date: Mon, 18 May 2026 16:44:02 +0800
Subject: [PATCH 1/4] add: image quality metrics
---
diffsynth/metrics/__init__.py | 22 +
diffsynth/metrics/aesthetic.py | 52 +++
diffsynth/metrics/base.py | 35 ++
diffsynth/metrics/clip.py | 64 +++
diffsynth/metrics/fid.py | 240 +++++++++++
diffsynth/metrics/hpsv2.py | 56 +++
diffsynth/metrics/hpsv3.py | 65 +++
diffsynth/metrics/image_reward.py | 100 +++++
diffsynth/metrics/pickscore.py | 73 ++++
diffsynth/models/aesthetic.py | 211 ++++++++++
diffsynth/models/clip.py | 134 ++++++
diffsynth/models/fid.py | 304 ++++++++++++++
diffsynth/models/hpsv2.py | 219 ++++++++++
diffsynth/models/hpsv3.py | 394 ++++++++++++++++++
diffsynth/models/image_reward.py | 305 ++++++++++++++
diffsynth/models/pickscore.py | 107 +++++
.../en/Model_Details/Image-Quality-Metrics.md | 117 ++++++
docs/en/Model_Details/Overview.md | 39 ++
.../zh/Model_Details/Image-Quality-Metrics.md | 118 ++++++
docs/zh/Model_Details/Overview.md | 39 ++
examples/image_quality_metric/aesthetic.py | 13 +
examples/image_quality_metric/clipscore.py | 14 +
examples/image_quality_metric/fid.py | 16 +
examples/image_quality_metric/hpsv2.py | 16 +
examples/image_quality_metric/hpsv3.py | 15 +
examples/image_quality_metric/image_reward.py | 14 +
examples/image_quality_metric/pickscore.py | 15 +
27 files changed, 2797 insertions(+)
create mode 100644 diffsynth/metrics/__init__.py
create mode 100644 diffsynth/metrics/aesthetic.py
create mode 100644 diffsynth/metrics/base.py
create mode 100644 diffsynth/metrics/clip.py
create mode 100644 diffsynth/metrics/fid.py
create mode 100644 diffsynth/metrics/hpsv2.py
create mode 100644 diffsynth/metrics/hpsv3.py
create mode 100644 diffsynth/metrics/image_reward.py
create mode 100644 diffsynth/metrics/pickscore.py
create mode 100644 diffsynth/models/aesthetic.py
create mode 100644 diffsynth/models/clip.py
create mode 100644 diffsynth/models/fid.py
create mode 100644 diffsynth/models/hpsv2.py
create mode 100644 diffsynth/models/hpsv3.py
create mode 100644 diffsynth/models/image_reward.py
create mode 100644 diffsynth/models/pickscore.py
create mode 100644 docs/en/Model_Details/Image-Quality-Metrics.md
create mode 100644 docs/zh/Model_Details/Image-Quality-Metrics.md
create mode 100644 examples/image_quality_metric/aesthetic.py
create mode 100644 examples/image_quality_metric/clipscore.py
create mode 100644 examples/image_quality_metric/fid.py
create mode 100644 examples/image_quality_metric/hpsv2.py
create mode 100644 examples/image_quality_metric/hpsv3.py
create mode 100644 examples/image_quality_metric/image_reward.py
create mode 100644 examples/image_quality_metric/pickscore.py
diff --git a/diffsynth/metrics/__init__.py b/diffsynth/metrics/__init__.py
new file mode 100644
index 000000000..f555d3a7a
--- /dev/null
+++ b/diffsynth/metrics/__init__.py
@@ -0,0 +1,24 @@
+"""Public entry points of the image quality metrics package."""
+
+from ..core import ModelConfig
+from .aesthetic import AestheticMetric
+from .base import Metric
+from .clip import CLIPMetric
+from .fid import FIDMetric
+from .hpsv2 import HPSv2Metric
+from .hpsv3 import HPSv3Metric
+from .image_reward import ImageRewardMetric
+from .pickscore import PickScoreMetric
+
+
+__all__ = [
+ "Metric",
+ "ModelConfig",
+ "PickScoreMetric",
+ "ImageRewardMetric",
+ "HPSv2Metric",
+ "HPSv3Metric",
+ "CLIPMetric",
+ "AestheticMetric",
+ "FIDMetric",
+]
diff --git a/diffsynth/metrics/aesthetic.py b/diffsynth/metrics/aesthetic.py
new file mode 100644
index 000000000..c2f676638
--- /dev/null
+++ b/diffsynth/metrics/aesthetic.py
@@ -0,0 +1,71 @@
+from typing import Union
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.aesthetic import AestheticModel
+from .base import Metric
+
+class AestheticMetric(Metric):
+    """Prompt-free image metric that scores images with an ``AestheticModel``."""
+
+    def __init__(self, model: AestheticModel):
+        super().__init__()
+        self.model = model
+
+    @staticmethod
+    def default_model_config():
+        """Return the config used when no ``model_config`` is supplied."""
+        model_id = "AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE"
+        return AestheticMetric.local_or_modelscope_config(model_id)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: Union[ModelConfig, str] = None,
+        clip_config: Union[ModelConfig, str] = None,
+        clip_processor_config: Union[ModelConfig, str] = None,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = get_device_type(),
+        clip_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Resolve the configs and build the metric around a loaded ``AestheticModel``.
+
+        ``model_config`` falls back to ``default_model_config()`` and
+        ``clip_processor_config`` falls back to ``clip_config``; CLIP configs
+        left unset are forwarded to the model loader as ``None`` paths.
+        """
+        if model_config is None:
+            model_config = cls.default_model_config()
+        model_config = cls.resolve_model_config(model_config)
+        if clip_config is not None:
+            clip_config = cls.resolve_model_config(clip_config)
+        if clip_processor_config is None:
+            clip_processor_config = clip_config
+        else:
+            clip_processor_config = cls.resolve_model_config(clip_processor_config)
+        clip_model_path = clip_config.path if clip_config is not None else None
+        clip_processor_path = clip_processor_config.path if clip_processor_config is not None else None
+        model = AestheticModel.from_pretrained(
+            model_path=model_config.path,
+            clip_model_path=clip_model_path,
+            clip_processor_path=clip_processor_path,
+            torch_dtype=torch_dtype,
+            device=device,
+            clip_kwargs=clip_kwargs,
+            processor_kwargs=processor_kwargs,
+        )
+        return cls(model)
+
+    @torch.no_grad()
+    def score(self, images):
+        """Run the model on ``images`` and return the scores as a Python list."""
+        return self.tensor_to_list(self.model(images))
+
+    def calc_scores(self, images):
+        """Alias of ``score``."""
+        return self.score(images)
+
+    def forward(self, images):
+        """Alias of ``score`` so the module can be called directly."""
+        return self.score(images)
diff --git a/diffsynth/metrics/base.py b/diffsynth/metrics/base.py
new file mode 100644
index 000000000..3ad95a28f
--- /dev/null
+++ b/diffsynth/metrics/base.py
@@ -0,0 +1,62 @@
+import os
+from pathlib import Path
+from typing import Union
+import torch
+from ..core import ModelConfig
+
+class Metric(torch.nn.Module):
+    """Base class for image quality metrics with shared conversion and config helpers."""
+
+    @staticmethod
+    def tensor_to_list(value):
+        """Convert ``value`` to a Python list.
+
+        Tensors are detached, moved to CPU and converted with ``tolist()``;
+        anything that is still not a list afterwards (e.g. the scalar produced
+        by a 0-dim tensor) is wrapped in a one-element list.
+        """
+        if torch.is_tensor(value):
+            value = value.detach().cpu().tolist()
+        if not isinstance(value, list):
+            return [value]
+        return value
+
+    @staticmethod
+    def tensor_to_float(value):
+        """Convert a tensor or a plain number to a Python ``float``."""
+        if torch.is_tensor(value):
+            return float(value.detach().cpu())
+        return float(value)
+
+    @staticmethod
+    def resolve_model_config(config: Union[ModelConfig, str, Path]):
+        """Normalize ``config`` to a ``ModelConfig`` and call ``download_if_necessary()`` on it.
+
+        Strings and ``Path`` objects are wrapped as ``ModelConfig(path=...)``.
+        ``None`` is passed through unchanged.
+        """
+        if config is None:
+            return None
+        if isinstance(config, (str, Path)):
+            config = ModelConfig(path=str(config))
+        config.download_if_necessary()
+        return config
+
+    @staticmethod
+    def local_or_modelscope_config(model_id: str, origin_file_pattern: str = ""):
+        """Return a config for ``model_id``, preferring an existing local copy.
+
+        The directory ``<base>/<model_id>`` is looked up first under
+        ``$DIFFSYNTH_MODEL_BASE_PATH`` (when set) and then under ``./models``;
+        the first existing one is returned as ``ModelConfig(path=...)``. Otherwise
+        a config built from ``model_id`` and ``origin_file_pattern`` is returned.
+        """
+        base_paths = ["./models"]
+        env_base = os.environ.get("DIFFSYNTH_MODEL_BASE_PATH")
+        if env_base:
+            base_paths.insert(0, env_base)
+        for base_path in base_paths:
+            local_path = Path(base_path) / model_id
+            if local_path.exists():
+                return ModelConfig(path=str(local_path))
+        return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
diff --git a/diffsynth/metrics/clip.py b/diffsynth/metrics/clip.py
new file mode 100644
index 000000000..94c0219ed
--- /dev/null
+++ b/diffsynth/metrics/clip.py
@@ -0,0 +1,70 @@
+from typing import Union
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.clip import CLIPModel
+from .base import Metric
+
+class CLIPMetric(Metric):
+    """Prompt/image metric that scores with a wrapped ``CLIPModel``."""
+
+    def __init__(self, model: CLIPModel):
+        super().__init__()
+        self.model = model
+
+    @staticmethod
+    def default_model_config():
+        """Return the config used when no ``model_config`` is supplied."""
+        return CLIPMetric.local_or_modelscope_config("AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K")
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: Union[ModelConfig, str] = None,
+        processor_config: Union[ModelConfig, str] = None,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = get_device_type(),
+        max_length: int = 77,
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Resolve the configs and build the metric around a loaded ``CLIPModel``.
+
+        ``model_config`` falls back to ``default_model_config()``; when
+        ``processor_config`` is omitted the resolved model config is reused.
+        """
+        if model_config is None:
+            model_config = cls.default_model_config()
+        model_config = cls.resolve_model_config(model_config)
+        if processor_config is None:
+            processor_config = model_config
+        else:
+            processor_config = cls.resolve_model_config(processor_config)
+        model = CLIPModel.from_pretrained(
+            model_path=model_config.path,
+            processor_path=processor_config.path,
+            torch_dtype=torch_dtype,
+            device=device,
+            max_length=max_length,
+            model_kwargs=model_kwargs,
+            processor_kwargs=processor_kwargs,
+        )
+        return cls(model)
+
+    @torch.no_grad()
+    def score(self, prompt: Union[str, list[str]], images):
+        """Call the model on ``prompt`` and ``images`` and return the result as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    @torch.no_grad()
+    def similarity_matrix(self, prompt: Union[str, list[str]], images):
+        """Return ``self.model.similarity_matrix(prompt, images)`` converted to a Python list."""
+        return self.tensor_to_list(self.model.similarity_matrix(prompt, images))
+
+    def calc_scores(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score``."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score`` so the module can be called directly."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/fid.py b/diffsynth/metrics/fid.py
new file mode 100644
index 000000000..b740b0068
--- /dev/null
+++ b/diffsynth/metrics/fid.py
@@ -0,0 +1,286 @@
+import csv
+import json
+import os
+from pathlib import Path
+from urllib.parse import urlparse
+from urllib.request import urlopen
+from zipfile import ZipFile
+from typing import Union
+
+import torch
+
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.fid import FIDModel, IMAGE_EXTENSIONS
+from .base import Metric
+
+
+class FIDMetric(Metric):
+    """FID metric around ``FIDModel`` plus helpers for a default COCO 2014 reference set."""
+
+    DEFAULT_REFERENCE_NAME = "coco_2014_caption_validation"
+    DEFAULT_REFERENCE_DATASET_ID = "modelscope/coco_2014_caption"
+    DEFAULT_REFERENCE_METADATA_URL = "https://modelscope.oss-cn-beijing.aliyuncs.com/open_data/coco_2014_caption/val2014.csv.zip"
+    DEFAULT_REFERENCE_SPLIT = "validation"
+
+    def __init__(self, model: FIDModel):
+        super().__init__()
+        self.model = model
+
+    @staticmethod
+    def default_weights_config():
+        """Return the config of the default Inception weights file."""
+        return ModelConfig(
+            model_id="diffusionTry/weights-inception-2015-12-05-6726825d",
+            origin_file_pattern="weights-inception-2015-12-05-6726825d.pth",
+        )
+
+    @staticmethod
+    def default_reference_root():
+        """Return the default reference root under ``$DIFFSYNTH_DATA_BASE_PATH`` (``./data`` if unset)."""
+        base_path = os.environ.get("DIFFSYNTH_DATA_BASE_PATH", "./data")
+        return Path(base_path) / "fid_reference" / FIDMetric.DEFAULT_REFERENCE_NAME
+
+    @staticmethod
+    def _image_files(path: Union[str, Path]):
+        """Return the sorted image files found recursively under ``path`` (``[]`` if it is missing)."""
+        path = Path(path)
+        if not path.exists():
+            return []
+        return sorted(item for item in path.rglob("*") if item.is_file() and item.suffix.lower() in IMAGE_EXTENSIONS)
+
+    @classmethod
+    def _reference_is_complete(cls, images_dir: Path, manifest_path: Path, max_images: int = None):
+        """Tell whether the images already on disk satisfy the request.
+
+        With ``max_images`` set, at least that many images must exist. Without
+        it, a manifest recording an unlimited download must exist and list no
+        more images than are currently present.
+        """
+        existing = cls._image_files(images_dir)
+        if not existing:
+            return False
+        if max_images is not None:
+            return len(existing) >= max_images
+        if not manifest_path.exists():
+            return False
+        with open(manifest_path, "r", encoding="utf-8") as file:
+            manifest = json.load(file)
+        return manifest.get("max_images") is None and len(existing) >= manifest.get("image_count", 0)
+
+    @staticmethod
+    def _download_file(url: str, path: Path, timeout: int = 60, retries: int = 3, overwrite: bool = False):
+        """Download ``url`` to ``path`` through a temporary file, retrying on failure.
+
+        An existing non-empty ``path`` is reused unless ``overwrite`` is true.
+        Raises ``RuntimeError`` chained to the last error once all ``retries``
+        attempts have failed.
+        """
+        path.parent.mkdir(parents=True, exist_ok=True)
+        if not overwrite and path.exists() and path.stat().st_size > 0:
+            return path
+        temp_path = path.with_suffix(path.suffix + ".tmp")
+        last_error = None
+        for _ in range(retries):
+            try:
+                with urlopen(url, timeout=timeout) as response, open(temp_path, "wb") as file:
+                    # Stream in 1 MiB chunks instead of reading the whole body at once.
+                    while chunk := response.read(1024 * 1024):
+                        file.write(chunk)
+                # Rename only after a complete download so ``path`` never holds partial data.
+                temp_path.replace(path)
+                return path
+            except Exception as error:
+                # Best-effort retry: remember the failure and drop the partial temp file.
+                last_error = error
+                if temp_path.exists():
+                    temp_path.unlink()
+        raise RuntimeError(f"Failed to download {url}: {last_error}") from last_error
+
+    @staticmethod
+    def _metadata_rows(metadata_zip_path: Path):
+        """Read the first CSV in the archive and return its rows, one per distinct ``image`` URL.
+
+        Rows with an empty ``image`` value are skipped. Raises ``ValueError``
+        when the archive contains no ``.csv`` member.
+        """
+        with ZipFile(metadata_zip_path) as archive:
+            csv_names = [name for name in archive.namelist() if name.endswith(".csv")]
+            if not csv_names:
+                raise ValueError(f"No CSV file found in {metadata_zip_path}.")
+            with archive.open(csv_names[0]) as file:
+                reader = csv.DictReader(line.decode("utf-8") for line in file)
+                rows = []
+                seen = set()
+                for row in reader:
+                    url = row.get("image", "")
+                    if not url or url in seen:
+                        continue
+                    rows.append(row)
+                    seen.add(url)
+        return rows
+
+    @staticmethod
+    def _image_filename(row: dict, index: int):
+        """Derive a local file name from the row's URL, else from ``image_id`` or ``index``."""
+        url_path = urlparse(row["image"]).path
+        name = Path(url_path).name
+        if name and Path(name).suffix.lower() in IMAGE_EXTENSIONS:
+            return name
+        image_id = row.get("image_id") or f"{index:08d}"
+        return f"{image_id}.jpg"
+
+    @classmethod
+    def download_reference_dir(
+        cls,
+        local_dir: Union[str, Path] = None,
+        max_images: int = None,
+        force: bool = False,
+        metadata_url: str = None,
+        timeout: int = 60,
+        retries: int = 3,
+        verbose: bool = True,
+    ):
+        """
+        Download the default COCO 2014 caption validation reference images.
+
+        The ModelScope dataset stores a small CSV archive whose image column
+        points to ModelScope OSS image URLs. This helper downloads that metadata
+        and materializes the referenced real images as a normal image directory.
+
+        Existing images are reused when they already satisfy ``max_images``
+        (or a previous unlimited download). ``force=True`` skips that shortcut
+        and re-downloads the metadata archive and every image, overwriting
+        files already on disk. Returns the image directory as a string and
+        writes ``reference_manifest.json`` into the reference root.
+        """
+
+        root = Path(local_dir) if local_dir is not None else cls.default_reference_root()
+        images_dir = root / "images"
+        metadata_dir = root / "metadata"
+        metadata_url = cls.DEFAULT_REFERENCE_METADATA_URL if metadata_url is None else metadata_url
+        manifest_path = root / "reference_manifest.json"
+        if not force and cls._reference_is_complete(images_dir, manifest_path, max_images):
+            return str(images_dir)
+
+        metadata_zip_path = metadata_dir / "val2014.csv.zip"
+        cls._download_file(metadata_url, metadata_zip_path, timeout=timeout, retries=retries, overwrite=force)
+        rows = cls._metadata_rows(metadata_zip_path)
+        if max_images is not None:
+            rows = rows[:max_images]
+        if not rows:
+            raise ValueError("No reference images were found in the COCO 2014 caption metadata.")
+
+        images_dir.mkdir(parents=True, exist_ok=True)
+        downloaded = 0
+        for index, row in enumerate(rows):
+            image_path = images_dir / cls._image_filename(row, index)
+            if not force and image_path.exists() and image_path.stat().st_size > 0:
+                downloaded += 1
+                continue
+            # ``overwrite=force`` is what makes ``force`` effective: without it the
+            # helper would silently keep any non-empty file already on disk.
+            cls._download_file(row["image"], image_path, timeout=timeout, retries=retries, overwrite=force)
+            downloaded += 1
+            if verbose and downloaded % 100 == 0:
+                print(f"Downloaded {downloaded}/{len(rows)} FID reference images to {images_dir}")
+
+        manifest = {
+            "name": cls.DEFAULT_REFERENCE_NAME,
+            "dataset_id": cls.DEFAULT_REFERENCE_DATASET_ID,
+            "split": cls.DEFAULT_REFERENCE_SPLIT,
+            "metadata_url": metadata_url,
+            "max_images": max_images,
+            "image_count": len(rows),
+            "images_dir": str(images_dir),
+        }
+        with open(manifest_path, "w", encoding="utf-8") as file:
+            json.dump(manifest, file, indent=2, ensure_ascii=False)
+        return str(images_dir)
+
+    @classmethod
+    def default_reference_dir(
+        cls,
+        local_dir: Union[str, Path] = None,
+        max_images: int = None,
+        download: bool = True,
+        **download_kwargs,
+    ):
+        """Return the reference image directory, downloading it when allowed.
+
+        Images already on disk are used when they satisfy the request. If they
+        do not, ``download=False`` raises ``FileNotFoundError``; otherwise the
+        call is delegated to ``download_reference_dir`` with ``download_kwargs``.
+        """
+        root = Path(local_dir) if local_dir is not None else cls.default_reference_root()
+        images_dir = root / "images"
+        manifest_path = root / "reference_manifest.json"
+        if cls._reference_is_complete(images_dir, manifest_path, max_images):
+            return str(images_dir)
+        if not download:
+            raise FileNotFoundError(
+                f"FID reference directory does not exist: {images_dir}. "
+                "Call FIDMetric.download_reference_dir(...) first or pass your own reference directory."
+            )
+        return cls.download_reference_dir(local_dir=root, max_images=max_images, **download_kwargs)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        weights_config: Union[ModelConfig, str] = None,
+        pretrained: bool = True,
+        device: Union[str, torch.device] = get_device_type(),
+        batch_size: int = 50,
+        num_workers: int = 0,
+        use_fid_inception: bool = True,
+    ):
+        """Build the metric around a loaded ``FIDModel``.
+
+        The default weights config is used only when ``weights_config`` is
+        omitted and ``use_fid_inception`` is true. The resolved config's path,
+        or ``None`` when there is no config, is forwarded as ``weights_path``.
+        """
+        if weights_config is None and use_fid_inception:
+            weights_config = cls.default_weights_config()
+        weights_config = cls.resolve_model_config(weights_config) if weights_config is not None else None
+        model = FIDModel.from_pretrained(
+            weights_path=None if weights_config is None else weights_config.path,
+            pretrained=pretrained,
+            device=device,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            use_fid_inception=use_fid_inception,
+        )
+        return cls(model)
+
+    @torch.no_grad()
+    def compute(self, reference_images, generated_images, batch_size: int = None, num_workers: int = None):
+        """Return ``self.model.compute(...)`` for the two image sets as a Python float."""
+        score = self.model.compute(reference_images, generated_images, batch_size=batch_size, num_workers=num_workers)
+        return self.tensor_to_float(score)
+
+    @torch.no_grad()
+    def compute_with_default_reference(
+        self,
+        generated_images,
+        reference_dir: Union[str, Path] = None,
+        max_reference_images: int = None,
+        batch_size: int = None,
+        num_workers: int = None,
+        download_kwargs: dict = None,
+    ):
+        """Run ``compute`` against the directory returned by ``default_reference_dir``."""
+        reference_dir = self.default_reference_dir(
+            local_dir=reference_dir,
+            max_images=max_reference_images,
+            **({} if download_kwargs is None else download_kwargs),
+        )
+        return self.compute(reference_dir, generated_images, batch_size=batch_size, num_workers=num_workers)
+
+    def statistics(self, images, batch_size: int = None, num_workers: int = None):
+        """Return ``self.model.statistics(...)`` for ``images`` unchanged."""
+        return self.model.statistics(images, batch_size=batch_size, num_workers=num_workers)
+
+    def forward(self, reference_images, generated_images):
+        """Alias of ``compute`` with the model's default batch settings."""
+        return self.compute(reference_images, generated_images)
diff --git a/diffsynth/metrics/hpsv2.py b/diffsynth/metrics/hpsv2.py
new file mode 100644
index 000000000..4760ee11c
--- /dev/null
+++ b/diffsynth/metrics/hpsv2.py
@@ -0,0 +1,68 @@
+from typing import Union
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.hpsv2 import HPSv2Model
+from .base import Metric
+
+class HPSv2Metric(Metric):
+    """Prompt/image metric that scores with a wrapped ``HPSv2Model``."""
+
+    def __init__(self, model: HPSv2Model):
+        super().__init__()
+        self.model = model
+
+    @staticmethod
+    def default_model_config():
+        """Return the config used when no ``model_config`` is supplied."""
+        return HPSv2Metric.local_or_modelscope_config("AI-ModelScope/HPSv2")
+
+    @staticmethod
+    def default_processor_config():
+        """Return the config used when no ``processor_config`` is supplied."""
+        return HPSv2Metric.local_or_modelscope_config("AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K")
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: Union[ModelConfig, str] = None,
+        processor_config: Union[ModelConfig, str] = None,
+        version: str = "v2.0",
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = get_device_type(),
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Resolve the configs and build the metric around a loaded ``HPSv2Model``.
+
+        Omitted configs fall back to the corresponding ``default_*_config()``.
+        """
+        if model_config is None:
+            model_config = cls.default_model_config()
+        if processor_config is None:
+            processor_config = cls.default_processor_config()
+        model_config = cls.resolve_model_config(model_config)
+        processor_config = cls.resolve_model_config(processor_config)
+        model = HPSv2Model.from_pretrained(
+            model_path=model_config.path,
+            processor_path=processor_config.path,
+            version=version,
+            torch_dtype=torch_dtype,
+            device=device,
+            model_kwargs=model_kwargs,
+            processor_kwargs=processor_kwargs,
+        )
+        return cls(model)
+
+    @torch.no_grad()
+    def score(self, prompt: Union[str, list[str]], images):
+        """Call the model on ``prompt`` and ``images`` and return the result as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    def calc_scores(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score``."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score`` so the module can be called directly."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/hpsv3.py b/diffsynth/metrics/hpsv3.py
new file mode 100644
index 000000000..9c6de7bfd
--- /dev/null
+++ b/diffsynth/metrics/hpsv3.py
@@ -0,0 +1,78 @@
+from typing import Union
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.hpsv3 import HPSv3Model
+from .base import Metric
+
+
+class HPSv3Metric(Metric):
+    """Prompt/image metric that scores with a wrapped ``HPSv3Model``."""
+
+    def __init__(self, model: HPSv3Model):
+        super().__init__()
+        self.model = model
+
+    @staticmethod
+    def default_model_config():
+        """Return the config used when no ``model_config`` is supplied."""
+        return HPSv3Metric.local_or_modelscope_config("MizzenAI/HPSv3")
+
+    @staticmethod
+    def default_base_model_config():
+        """Return the config used when no ``base_model_config`` is supplied."""
+        return HPSv3Metric.local_or_modelscope_config("Qwen/Qwen2-VL-7B-Instruct")
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: Union[ModelConfig, str] = None,
+        base_model_config: Union[ModelConfig, str] = None,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        output_dim: int = 2,
+        score_index: int = 0,
+        use_special_tokens: bool = True,
+        max_pixels: int = 256 * 28 * 28,
+        min_pixels: int = 256 * 28 * 28,
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Resolve the configs and build the metric around a loaded ``HPSv3Model``.
+
+        Omitted configs fall back to the corresponding ``default_*_config()``;
+        every other argument is forwarded unchanged to the model loader.
+        """
+        if model_config is None:
+            model_config = cls.default_model_config()
+        if base_model_config is None:
+            base_model_config = cls.default_base_model_config()
+        model_config = cls.resolve_model_config(model_config)
+        base_model_config = cls.resolve_model_config(base_model_config)
+        model = HPSv3Model.from_pretrained(
+            model_path=model_config.path,
+            base_model_path=base_model_config.path,
+            torch_dtype=torch_dtype,
+            device=device,
+            output_dim=output_dim,
+            score_index=score_index,
+            use_special_tokens=use_special_tokens,
+            max_pixels=max_pixels,
+            min_pixels=min_pixels,
+            model_kwargs=model_kwargs,
+            processor_kwargs=processor_kwargs,
+        )
+        return cls(model)
+
+    @torch.no_grad()
+    def score(self, prompt: Union[str, list[str]], images):
+        """Call the model on ``prompt`` and ``images`` and return the result as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    def calc_scores(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score``."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score`` so the module can be called directly."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/image_reward.py b/diffsynth/metrics/image_reward.py
new file mode 100644
index 000000000..51ccd7348
--- /dev/null
+++ b/diffsynth/metrics/image_reward.py
@@ -0,0 +1,128 @@
+import os
+from pathlib import Path
+from typing import Union
+
+import torch
+from modelscope import snapshot_download
+
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.image_reward import ImageRewardModel
+from .base import Metric
+
+class ImageRewardMetric(Metric):
+    """Prompt/image metric that scores with a wrapped ``ImageRewardModel``."""
+
+    BERT_TOKENIZER_MODEL_ID = "AI-ModelScope/bert-base-uncased"
+    BERT_TOKENIZER_FILES = [
+        "config.json",
+        "tokenizer_config.json",
+        "tokenizer.json",
+        "vocab.txt",
+    ]
+
+    def __init__(self, model: ImageRewardModel):
+        super().__init__()
+        self.model = model
+
+    @staticmethod
+    def default_model_config():
+        """Return the config used when no ``model_config`` is supplied."""
+        return ImageRewardMetric.local_or_modelscope_config("ZhipuAI/ImageReward")
+
+    @staticmethod
+    def _tokenizer_dir():
+        """Return the tokenizer directory under ``$DIFFSYNTH_MODEL_BASE_PATH`` (``./models`` if unset)."""
+        base_path = os.environ.get("DIFFSYNTH_MODEL_BASE_PATH", "./models")
+        return Path(base_path) / ImageRewardMetric.BERT_TOKENIZER_MODEL_ID
+
+    @staticmethod
+    def _missing_tokenizer_files(local_path):
+        """Return the required tokenizer file names that do not exist under ``local_path``."""
+        return [filename for filename in ImageRewardMetric.BERT_TOKENIZER_FILES if not (local_path / filename).exists()]
+
+    @staticmethod
+    def default_tokenizer_config():
+        """Return a path-based config for the default tokenizer, downloading it if files are missing."""
+        return ModelConfig(path=str(ImageRewardMetric.download_default_tokenizer()))
+
+    @staticmethod
+    def download_default_tokenizer():
+        """Ensure the default tokenizer files exist locally and return their directory.
+
+        Nothing is downloaded when every required file is already present.
+        Raises ``FileNotFoundError`` if files are still missing after the download.
+        """
+        local_path = ImageRewardMetric._tokenizer_dir()
+        if not ImageRewardMetric._missing_tokenizer_files(local_path):
+            return local_path
+        local_path.mkdir(parents=True, exist_ok=True)
+        snapshot_download(
+            ImageRewardMetric.BERT_TOKENIZER_MODEL_ID,
+            local_dir=str(local_path),
+            allow_file_pattern=ImageRewardMetric.BERT_TOKENIZER_FILES,
+            local_files_only=False,
+        )
+        missing = ImageRewardMetric._missing_tokenizer_files(local_path)
+        if missing:
+            raise FileNotFoundError(f"Missing ImageReward tokenizer files under {local_path}: {missing}")
+        return local_path
+
+    @staticmethod
+    def _as_directory_path(path):
+        """Map a list of file paths to the parent directory of its first entry; pass other values through."""
+        if not isinstance(path, list):
+            return path
+        if len(path) == 0:
+            raise FileNotFoundError("Downloaded tokenizer files are empty.")
+        return str(Path(path[0]).parent)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: Union[ModelConfig, str] = None,
+        med_config: Union[ModelConfig, str] = None,
+        tokenizer_config: Union[ModelConfig, str] = None,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = get_device_type(),
+        max_length: int = 35,
+        model_kwargs: dict = None,
+        tokenizer_kwargs: dict = None,
+    ):
+        """Resolve the configs and build the metric around a loaded ``ImageRewardModel``.
+
+        Model and tokenizer configs fall back to their defaults when omitted;
+        an omitted ``med_config`` is forwarded as a ``None`` path.
+        """
+        if model_config is None:
+            model_config = cls.default_model_config()
+        if tokenizer_config is None:
+            tokenizer_config = cls.default_tokenizer_config()
+        model_config = cls.resolve_model_config(model_config)
+        if med_config is not None:
+            med_config = cls.resolve_model_config(med_config)
+        tokenizer_config = cls.resolve_model_config(tokenizer_config)
+        model = ImageRewardModel.from_pretrained(
+            model_path=model_config.path,
+            med_config_path=None if med_config is None else med_config.path,
+            tokenizer_path=cls._as_directory_path(tokenizer_config.path),
+            torch_dtype=torch_dtype,
+            device=device,
+            max_length=max_length,
+            model_kwargs=model_kwargs,
+            tokenizer_kwargs=tokenizer_kwargs,
+        )
+        return cls(model)
+
+    @torch.no_grad()
+    def score(self, prompt: Union[str, list[str]], images):
+        """Call the model on ``prompt`` and ``images`` and return the result as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    def calc_scores(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score``."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score`` so the module can be called directly."""
+        return self.score(prompt, images)
diff --git a/diffsynth/metrics/pickscore.py b/diffsynth/metrics/pickscore.py
new file mode 100644
index 000000000..a71ae2619
--- /dev/null
+++ b/diffsynth/metrics/pickscore.py
@@ -0,0 +1,78 @@
+from typing import Union
+import torch
+from ..core import ModelConfig
+from ..core.device.npu_compatible_device import get_device_type
+from ..models.pickscore import PickScoreModel
+from .base import Metric
+
+class PickScoreMetric(Metric):
+    """Prompt/image metric that scores with a wrapped ``PickScoreModel``."""
+
+    def __init__(self, model: PickScoreModel):
+        super().__init__()
+        self.model = model
+
+    @staticmethod
+    def default_model_config():
+        """Return the config used when no ``model_config`` is supplied."""
+        return PickScoreMetric.local_or_modelscope_config("AI-ModelScope/PickScore_v1")
+
+    @staticmethod
+    def default_processor_config():
+        """Return the config used when no ``processor_config`` is supplied."""
+        return PickScoreMetric.local_or_modelscope_config("AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K")
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_config: Union[ModelConfig, str] = None,
+        processor_config: Union[ModelConfig, str] = None,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = get_device_type(),
+        max_length: int = 77,
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Resolve the configs and build the metric around a loaded ``PickScoreModel``.
+
+        Omitted configs fall back to the corresponding ``default_*_config()``.
+        """
+        if model_config is None:
+            model_config = cls.default_model_config()
+        if processor_config is None:
+            processor_config = cls.default_processor_config()
+        model_config = cls.resolve_model_config(model_config)
+        processor_config = cls.resolve_model_config(processor_config)
+        model = PickScoreModel.from_pretrained(
+            model_path=model_config.path,
+            processor_path=processor_config.path,
+            torch_dtype=torch_dtype,
+            device=device,
+            max_length=max_length,
+            model_kwargs=model_kwargs,
+            processor_kwargs=processor_kwargs,
+        )
+        return cls(model)
+
+    @torch.no_grad()
+    def score(self, prompt: Union[str, list[str]], images):
+        """Call the model on ``prompt`` and ``images`` and return the result as a Python list."""
+        return self.tensor_to_list(self.model(prompt, images))
+
+    @torch.no_grad()
+    def probabilities(self, prompt: Union[str, list[str]], images):
+        """Return the model output after a softmax over its last dimension, as a Python list."""
+        logits = self.model(prompt, images)
+        return self.tensor_to_list(torch.softmax(logits, dim=-1))
+
+    def calc_probs(self, prompt: Union[str, list[str]], images):
+        """Alias of ``probabilities``."""
+        return self.probabilities(prompt, images)
+
+    def calc_scores(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score``."""
+        return self.score(prompt, images)
+
+    def forward(self, prompt: Union[str, list[str]], images):
+        """Alias of ``score`` so the module can be called directly."""
+        return self.score(prompt, images)
diff --git a/diffsynth/models/aesthetic.py b/diffsynth/models/aesthetic.py
new file mode 100644
index 000000000..a9cc20124
--- /dev/null
+++ b/diffsynth/models/aesthetic.py
@@ -0,0 +1,261 @@
+from pathlib import Path
+from typing import Union
+import json
+import torch
+import torch.nn as nn
+from PIL import Image
+from .clip import CLIPModel
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+class AestheticMLP(nn.Module):
+    """MLP head mapping an ``input_size``-dim feature vector to one score (Linear + Dropout only)."""
+
+    def __init__(self, input_size: int):
+        super().__init__()
+        self.input_size = input_size
+        self.layers = nn.Sequential(
+            nn.Linear(input_size, 1024),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            nn.Dropout(0.1),
+            nn.Linear(64, 16),
+            nn.Linear(16, 1),
+        )
+
+    def forward(self, x):
+        """Apply the layer stack to ``x``."""
+        return self.layers(x)
+
+
+def _as_image_list(images: Union[ImageInput, list[ImageInput], tuple[ImageInput, ...]]):
+    """Wrap a single PIL image in a list and convert every image to RGB."""
+    if isinstance(images, Image.Image):
+        images = [images]
+    return [image.convert("RGB") for image in images]
+
+
+class AestheticModel(torch.nn.Module):
+    """Aesthetic predictor: image features from ``clip_model`` or ``vision_model`` fed to an ``AestheticMLP``."""
+
+    def __init__(self, mlp: AestheticMLP, clip_model: CLIPModel = None, vision_model: torch.nn.Module = None, processor=None):
+        super().__init__()
+        self.clip_model = clip_model
+        self.vision_model = vision_model
+        self.processor = processor
+        self.mlp = mlp
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: str,
+        clip_model_path: str = None,
+        clip_processor_path: str = None,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = "cpu",
+        clip_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Load the full predictor (vision tower and MLP head) from ``model_path``.
+
+        NOTE(review): ``clip_model_path``, ``clip_processor_path`` and
+        ``clip_kwargs`` are accepted but not used by this loader; confirm
+        whether a separate-CLIP loading path is still intended.
+        """
+        checkpoint = cls._load_checkpoint(model_path)
+        return cls._from_full_predictor(
+            model_path=model_path,
+            checkpoint=checkpoint,
+            torch_dtype=torch_dtype,
+            device=device,
+            processor_kwargs=processor_kwargs,
+        )
+
+    @classmethod
+    def _from_full_predictor(
+        cls,
+        model_path: str,
+        checkpoint: dict,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = "cuda",
+        processor_kwargs: dict = None,
+    ):
+        """Build the model from a checkpoint holding both the vision tower and the MLP head.
+
+        Keys starting with ``layers.`` go to the MLP (loaded strictly); keys
+        present in the vision model's state dict go to the vision tower.
+        Raises ``ValueError`` when no vision-tower key is found.
+        """
+        from transformers import AutoProcessor, CLIPVisionModelWithProjection
+
+        processor_kwargs = {} if processor_kwargs is None else processor_kwargs
+        config = cls._load_vision_config(model_path)
+        vision_model = CLIPVisionModelWithProjection(config)
+        mlp = AestheticMLP(config.projection_dim)
+        normalized = cls._normalize_checkpoint_keys(checkpoint)
+        # Snapshot the key set once: ``state_dict()`` rebuilds its mapping on every call.
+        vision_keys = set(vision_model.state_dict())
+        vision_state = {}
+        mlp_state = {}
+        for key, value in normalized.items():
+            if key.startswith("layers."):
+                mlp_state[key] = value
+            elif key in vision_keys:
+                vision_state[key] = value
+        if not vision_state:
+            raise ValueError(f"Cannot find CLIP vision tower weights in Aesthetic checkpoint under {model_path}.")
+        # Non-strict load: vision keys absent from the checkpoint keep their initial values.
+        vision_model.load_state_dict(vision_state, strict=False)
+        mlp.load_state_dict(mlp_state, strict=True)
+        processor = AutoProcessor.from_pretrained(model_path, **processor_kwargs)
+        if torch_dtype is not None:
+            vision_model = vision_model.to(dtype=torch_dtype)
+        vision_model = vision_model.to(device).eval()
+        # The MLP head is kept in float32 regardless of ``torch_dtype``.
+        mlp = mlp.to(device).float().eval()
+        return cls(vision_model=vision_model, processor=processor, mlp=mlp).eval()
+
+    @staticmethod
+    def _load_vision_config(model_path):
+        """Read ``config.json`` under ``model_path`` and build a ``CLIPVisionConfig``.
+
+        Uses the ``vision_config`` section when present (else the whole file),
+        inherits a top-level ``projection_dim`` when the section has none, and
+        forwards only the allow-listed keys. Raises ``FileNotFoundError`` if the file is missing.
+        """
+        from transformers import CLIPVisionConfig
+
+        config_path = Path(model_path) / "config.json"
+        if not config_path.exists():
+            raise FileNotFoundError(f"Cannot find Aesthetic config.json under {model_path}.")
+        with open(config_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        config_data = data.get("vision_config", data)
+        if "projection_dim" not in config_data and "projection_dim" in data:
+            config_data = dict(config_data)
+            config_data["projection_dim"] = data["projection_dim"]
+        allowed = {
+            "attention_dropout",
+            "dropout",
+            "hidden_act",
+            "hidden_size",
+            "image_size",
+            "initializer_factor",
+            "initializer_range",
+            "intermediate_size",
+            "layer_norm_eps",
+            "num_attention_heads",
+            "num_channels",
+            "num_hidden_layers",
+            "patch_size",
+            "projection_dim",
+        }
+        return CLIPVisionConfig(**{key: value for key, value in config_data.items() if key in allowed})
+
+    @staticmethod
+    def _find_checkpoint(path):
+        """Locate the checkpoint file for ``path``.
+
+        A file path is returned as is. Otherwise each name or glob pattern is
+        tried in order, first directly under ``path`` and then recursively.
+        Raises ``FileNotFoundError`` when nothing matches.
+        """
+        path = Path(path)
+        if path.is_file():
+            return path
+        names = [
+            "model.safetensors",
+            "pytorch_model.bin",
+            "sac+logos+ava1-l14-linearMSE.pth",
+            "ava+logos-l14-linearMSE.pth",
+            "*.pth",
+            "*.pt",
+            "*.bin",
+            "*.safetensors",
+        ]
+        for name in names:
+            candidate = path / name
+            if candidate.exists():
+                return candidate
+            matches = sorted(path.rglob(name))
+            if matches:
+                return matches[0]
+        raise FileNotFoundError(f"Cannot find an Aesthetic MLP checkpoint under {path}.")
+
+    @classmethod
+    def _load_checkpoint(cls, path):
+        """Load the checkpoint onto CPU and unwrap a nested ``state_dict``/``model`` dict if present."""
+        checkpoint_path = cls._find_checkpoint(path)
+        if checkpoint_path.suffix == ".safetensors":
+            import safetensors.torch
+
+            checkpoint = safetensors.torch.load_file(str(checkpoint_path), device="cpu")
+        else:
+            # NOTE(security): depending on the installed torch version, ``torch.load``
+            # may unpickle arbitrary objects; only load trusted checkpoint files.
+            checkpoint = torch.load(checkpoint_path, map_location="cpu")
+        if isinstance(checkpoint, dict):
+            for key in ("state_dict", "model"):
+                if key in checkpoint and isinstance(checkpoint[key], dict):
+                    checkpoint = checkpoint[key]
+                    break
+        return checkpoint
+
+    @staticmethod
+    def _normalize_checkpoint_keys(checkpoint):
+        """Strip known wrapper prefixes from the checkpoint keys.
+
+        Prefixes are tested in the listed order and every match is removed, so
+        prefixes stacked in that order are all stripped from one key.
+        """
+        normalized = {}
+        for key, value in checkpoint.items():
+            for prefix in ("model.", "module.", "aesthetic_model.", "aesthetics_predictor.", "predictor."):
+                if key.startswith(prefix):
+                    key = key[len(prefix) :]
+            normalized[key] = value
+        return normalized
+
+    @property
+    def device(self):
+        """Device of ``clip_model`` if set, else of the first vision-model parameter (CPU if it has none)."""
+        if self.clip_model is not None:
+            return self.clip_model.device
+        try:
+            return next(self.vision_model.parameters()).device
+        except StopIteration:
+            return torch.device("cpu")
+
+    @property
+    def dtype(self):
+        """Dtype of ``clip_model`` if set, else of the first vision-model parameter (float32 if it has none)."""
+        if self.clip_model is not None:
+            return self.clip_model.dtype
+        try:
+            return next(self.vision_model.parameters()).dtype
+        except StopIteration:
+            return torch.float32
+
+    @torch.no_grad()
+    def get_image_features(self, images):
+        """Return image features for ``images``.
+
+        Delegates to ``clip_model`` when set; otherwise preprocesses with
+        ``processor``, runs ``vision_model`` and L2-normalizes ``image_embeds``
+        (norm clamped at 1e-12).
+        """
+        if self.clip_model is not None:
+            return self.clip_model.get_image_features(images)
+        images = _as_image_list(images)
+        inputs = self.processor(images=images, return_tensors="pt")
+        pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.dtype)
+        image_features = self.vision_model(pixel_values=pixel_values, return_dict=True).image_embeds
+        return image_features / image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+
+    @torch.no_grad()
+    def forward(self, images):
+        """Score ``images``: float32 image features go through the MLP and the last dim is squeezed."""
+        image_features = self.get_image_features(images).float()
+        return self.mlp(image_features).squeeze(-1)
diff --git a/diffsynth/models/clip.py b/diffsynth/models/clip.py
new file mode 100644
index 000000000..04f8db41e
--- /dev/null
+++ b/diffsynth/models/clip.py
@@ -0,0 +1,134 @@
+from typing import Union
+import torch
+from PIL import Image
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _feature_tensor(output, feature_name: str):
+    """Extract a feature tensor from a raw tensor or a model output object.
+
+    Lookup order: ``output`` itself when it is a tensor, then the attributes
+    ``image_embeds``, ``text_embeds`` and ``pooler_output``, then the first
+    tensor element when ``output`` is a list or tuple.
+
+    Raises:
+        TypeError: If no tensor can be found; ``feature_name`` is used in the message.
+    """
+    if torch.is_tensor(output):
+        return output
+    # Generator keeps attribute access lazy and in the original order.
+    by_attribute = (getattr(output, attr, None) for attr in ("image_embeds", "text_embeds", "pooler_output"))
+    found = next((candidate for candidate in by_attribute if torch.is_tensor(candidate)), None)
+    if found is not None:
+        return found
+    if isinstance(output, (list, tuple)):
+        found = next((item for item in output if torch.is_tensor(item)), None)
+        if found is not None:
+            return found
+    raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
+
+
+class CLIPModel(torch.nn.Module):
+    """Score prompt/image pairs with a CLIP-style model and its processor."""
+
+    def __init__(self, model: torch.nn.Module, processor, max_length: int = 77):
+        super().__init__()
+        self.model = model
+        self.processor = processor
+        # Forwarded to the processor as ``max_length`` on every call.
+        self.max_length = max_length
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: str,
+        processor_path: str = None,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = "cuda",
+        max_length: int = 77,
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Load the model and processor through ModelScope's auto classes.
+
+        ``processor_path`` defaults to ``model_path``. The model is put in eval
+        mode, optionally cast to ``torch_dtype`` and then moved to ``device``.
+        """
+        from modelscope import AutoModel, AutoProcessor
+
+        if model_kwargs is None:
+            model_kwargs = {}
+        if processor_kwargs is None:
+            processor_kwargs = {}
+        if processor_path is None:
+            processor_path = model_path
+        processor = AutoProcessor.from_pretrained(processor_path, **processor_kwargs)
+        model = AutoModel.from_pretrained(model_path, **model_kwargs)
+        model = model.eval()
+        if torch_dtype is not None:
+            model = model.to(dtype=torch_dtype)
+        return cls(model=model.to(device), processor=processor, max_length=max_length)
+
+    @property
+    def device(self):
+        """Device of the first model parameter, or CPU when there are none."""
+        first_param = next(self.model.parameters(), None)
+        return torch.device("cpu") if first_param is None else first_param.device
+
+    @property
+    def dtype(self):
+        """Dtype of the first model parameter, or float32 when there are none."""
+        first_param = next(self.model.parameters(), None)
+        return torch.float32 if first_param is None else first_param.dtype
+
+    @staticmethod
+    def _unit_norm(features):
+        """Divide features by their L2 norm along the last dim (norm clamped to >= 1e-12)."""
+        norms = features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+        return features / norms
+
+    def _scaled(self, scores):
+        """Multiply scores by ``exp(logit_scale)`` when the model exposes ``logit_scale``."""
+        if hasattr(self.model, "logit_scale"):
+            return self.model.logit_scale.exp() * scores
+        return scores
+
+    def _normalize_pairs(self, text, images):
+        """Return equally long prompt and RGB image lists.
+
+        A single prompt or a single image is repeated to match the other side.
+
+        Raises:
+            ValueError: If the two lists still differ in length.
+        """
+        prompts = [text] if isinstance(text, str) else list(text)
+        pictures = [images] if isinstance(images, Image.Image) else images
+        pictures = [picture.convert("RGB") for picture in pictures]
+        if len(prompts) == 1 and len(pictures) > 1:
+            prompts = prompts * len(pictures)
+        elif len(pictures) == 1 and len(prompts) > 1:
+            pictures = pictures * len(prompts)
+        if len(prompts) != len(pictures):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(pictures)}.")
+        return prompts, pictures
+
+    def _processor_call(self, **kwargs):
+        """Run the processor and move its output to the model's device and dtype.
+
+        When the model is not float32, the result is a plain dict in which
+        floating-point tensors are cast to the model dtype.
+        """
+        batch = self.processor(
+            padding=True,
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+            **kwargs,
+        )
+        batch = batch.to(self.device)
+        if self.dtype == torch.float32:
+            return batch
+        converted = {}
+        for name, value in batch.items():
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                value = value.to(dtype=self.dtype)
+            converted[name] = value
+        return converted
+
+    @torch.no_grad()
+    def get_image_features(self, images: ImageInput):
+        """Return L2-normalized image features for one image or a sequence of images."""
+        pictures = [images] if isinstance(images, Image.Image) else images
+        pictures = [picture.convert("RGB") for picture in pictures]
+        batch = self._processor_call(images=pictures)
+        raw = self.model.get_image_features(**batch)
+        return self._unit_norm(_feature_tensor(raw, "image_features"))
+
+    @torch.no_grad()
+    def get_text_features(self, text: Union[str, list[str]]):
+        """Return L2-normalized text features for one prompt or a list of prompts."""
+        batch = self._processor_call(text=text)
+        raw = self.model.get_text_features(**batch)
+        return self._unit_norm(_feature_tensor(raw, "text_features"))
+
+    @torch.no_grad()
+    def similarity_matrix(self, text: Union[str, list[str]], images: ImageInput):
+        """Return the text-by-image similarity matrix (``text_features @ image_features.T``)."""
+        image_features = self.get_image_features(images)
+        text_features = self.get_text_features(text)
+        return self._scaled(text_features @ image_features.T)
+
+    @torch.no_grad()
+    def forward(self, text: Union[str, list[str]], images: ImageInput):
+        """Return one similarity score per aligned prompt/image pair."""
+        prompts, pictures = self._normalize_pairs(text, images)
+        image_features = self.get_image_features(pictures)
+        text_features = self.get_text_features(prompts)
+        pairwise = (text_features * image_features).sum(dim=-1)
+        return self._scaled(pairwise)
diff --git a/diffsynth/models/fid.py b/diffsynth/models/fid.py
new file mode 100644
index 000000000..31b87b3d4
--- /dev/null
+++ b/diffsynth/models/fid.py
@@ -0,0 +1,304 @@
+from pathlib import Path
+from typing import Iterable, Union
+import warnings
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+from torchvision.models import Inception_V3_Weights, inception_v3
+from torchvision.models.inception import InceptionA, InceptionC, InceptionE
+
+ImageInput = Union[str, Path, Image.Image]
+
+IMAGE_EXTENSIONS = {".bmp", ".jpg", ".jpeg", ".pgm", ".png", ".ppm", ".tif", ".tiff", ".webp"}
+
+def _resolve_device(device: Union[str, torch.device, None]):
+    """Turn a device request into a usable ``torch.device``.
+
+    ``None`` selects CUDA when it is usable and CPU otherwise. An explicit
+    CUDA request that cannot be used falls back to CPU (``_is_cuda_usable``
+    emits the warning).
+    """
+    if device is None:
+        resolved = torch.device("cuda" if _is_cuda_usable("cuda", warn=False) else "cpu")
+    else:
+        resolved = torch.device(device)
+    if resolved.type != "cuda" or _is_cuda_usable(resolved, warn=True):
+        return resolved
+    return torch.device("cpu")
+
+
+def _is_cuda_usable(device: Union[str, torch.device], warn: bool = True):
+    """Return True when a tensor can actually be allocated on ``device``.
+
+    Returns False when ``torch.cuda.is_available()`` is False or when the
+    probe allocation raises. With ``warn`` set, a ``RuntimeWarning`` explains
+    that FID falls back to CPU.
+    """
+    try:
+        if torch.cuda.is_available():
+            # Probe allocation: availability alone does not prove the device works.
+            torch.empty(1, device=device)
+            return True
+        if warn:
+            warnings.warn("CUDA was requested but torch.cuda.is_available() is False. FID will run on CPU instead.", RuntimeWarning)
+    except Exception as error:
+        if warn:
+            warnings.warn(f"CUDA was requested but cannot be initialized ({error}). FID will run on CPU instead.", RuntimeWarning)
+    return False
+
+
+def _image_files(path: Union[str, Path]):
+    """Collect image files from a single file path or, recursively, from a directory.
+
+    Returns:
+        A list of ``Path`` objects; directory results are sorted.
+
+    Raises:
+        ValueError: If a file has an unsupported extension or a directory holds no images.
+        FileNotFoundError: If ``path`` does not exist.
+    """
+    path = Path(path)
+    if path.is_file():
+        if path.suffix.lower() in IMAGE_EXTENSIONS:
+            return [path]
+        raise ValueError(f"Unsupported image extension for FID: {path}")
+    if not path.exists():
+        raise FileNotFoundError(f"FID path does not exist: {path}")
+    files = []
+    for entry in sorted(path.rglob("*")):
+        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS:
+            files.append(entry)
+    if not files:
+        raise ValueError(f"No images found under {path}.")
+    return files
+
+
+class _ImageDataset(torch.utils.data.Dataset):
+    """Map-style dataset that yields transformed RGB images for FID.
+
+    Items may be PIL images or paths to image files.
+    """
+
+    def __init__(self, images: Iterable[ImageInput], transform):
+        # Materialize once so __len__ and indexing work for any iterable.
+        self.images = list(images)
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        """Load item ``index``, convert it to RGB and apply the transform.
+
+        Raises:
+            TypeError: If the item is neither a PIL image nor a path.
+        """
+        image = self.images[index]
+        if isinstance(image, (str, Path)):
+            # Image.open keeps the file open until the image is closed; convert
+            # inside a context manager so the handle is released per item
+            # instead of lingering until garbage collection.
+            with Image.open(image) as opened:
+                rgb = opened.convert("RGB")
+            return self.transform(rgb)
+        if not isinstance(image, Image.Image):
+            raise TypeError(f"FID expects PIL images or image paths, but received {type(image)}.")
+        return self.transform(image.convert("RGB"))
+
+
+class _InceptionFeatures(nn.Module):
+    """Inception-v3 feature extractor with the classifier replaced by identity.
+
+    Uses the FID-specific network when ``use_fid_inception`` is set and
+    ``weights_path`` is given; otherwise the torchvision network. Inputs are
+    normalized in ``forward`` according to ``normalize_input``.
+    """
+
+    def __init__(self, weights_path: str = None, pretrained: bool = True, use_fid_inception: bool = True):
+        super().__init__()
+        if use_fid_inception and weights_path is not None:
+            backbone = _fid_inception_v3(weights_path)
+            self.normalize_input = "fid"
+        else:
+            if use_fid_inception:
+                # FID network requested but no weights supplied: warn, then use torchvision.
+                warnings.warn(
+                    "FID-specific Inception weights were not provided. Falling back to torchvision Inception weights; "
+                    "scores are useful for relative comparisons but are not directly comparable to standard pytorch-fid values.",
+                    RuntimeWarning,
+                )
+            weights = Inception_V3_Weights.DEFAULT if pretrained else None
+            backbone = inception_v3(weights=weights, aux_logits=True, init_weights=False)
+            backbone.fc = nn.Identity()
+            self.normalize_input = "imagenet" if pretrained else None
+        backbone.eval()
+        self.model = backbone
+
+    def forward(self, images):
+        """Normalize ``images`` for the selected backbone and return its output."""
+        mode = self.normalize_input
+        if mode == "fid":
+            images = 2 * images - 1
+        elif mode == "imagenet":
+            channel_mean = images.new_tensor((0.485, 0.456, 0.406)).view(1, 3, 1, 1)
+            channel_std = images.new_tensor((0.229, 0.224, 0.225)).view(1, 3, 1, 1)
+            images = (images - channel_mean) / channel_std
+        return self.model(images)
+
+
+def _fid_inception_v3(weights_path: str):
+    """Build the FID variant of Inception-v3 and load its weights.
+
+    The Mixed_5b..Mixed_7c stages are replaced with the ``_FIDInception*``
+    blocks before the state dict is loaded; the classifier is then replaced
+    with an identity so the network returns features.
+    """
+    model = inception_v3(weights=None, aux_logits=False, num_classes=1008, init_weights=False)
+    fid_blocks = {
+        "Mixed_5b": _FIDInceptionA(192, pool_features=32),
+        "Mixed_5c": _FIDInceptionA(256, pool_features=64),
+        "Mixed_5d": _FIDInceptionA(288, pool_features=64),
+        "Mixed_6b": _FIDInceptionC(768, channels_7x7=128),
+        "Mixed_6c": _FIDInceptionC(768, channels_7x7=160),
+        "Mixed_6d": _FIDInceptionC(768, channels_7x7=160),
+        "Mixed_6e": _FIDInceptionC(768, channels_7x7=192),
+        "Mixed_7b": _FIDInceptionE1(1280),
+        "Mixed_7c": _FIDInceptionE2(2048),
+    }
+    for stage_name, stage in fid_blocks.items():
+        setattr(model, stage_name, stage)
+    # NOTE(review): torch.load unpickles the file; only use weights from a trusted source.
+    state_dict = torch.load(weights_path, map_location="cpu")
+    model.load_state_dict(state_dict)
+    model.fc = nn.Identity()
+    return model
+
+
+class _FIDInceptionA(InceptionA):
+    """InceptionA block whose pooling branch averages without counting padding."""
+
+    def forward(self, x):
+        out_1x1 = self.branch1x1(x)
+        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
+        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
+        # count_include_pad=False: border averages ignore the zero padding.
+        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        out_pool = self.branch_pool(pooled)
+        return torch.cat([out_1x1, out_5x5, out_3x3dbl, out_pool], 1)
+
+
+class _FIDInceptionC(InceptionC):
+    """InceptionC block whose pooling branch averages without counting padding."""
+
+    def forward(self, x):
+        out_1x1 = self.branch1x1(x)
+
+        out_7x7 = x
+        for layer in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
+            out_7x7 = layer(out_7x7)
+
+        out_7x7dbl = x
+        for layer in (
+            self.branch7x7dbl_1,
+            self.branch7x7dbl_2,
+            self.branch7x7dbl_3,
+            self.branch7x7dbl_4,
+            self.branch7x7dbl_5,
+        ):
+            out_7x7dbl = layer(out_7x7dbl)
+
+        # count_include_pad=False: border averages ignore the zero padding.
+        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        out_pool = self.branch_pool(pooled)
+        return torch.cat([out_1x1, out_7x7, out_7x7dbl, out_pool], 1)
+
+
+class _FIDInceptionE1(InceptionE):
+    """InceptionE block whose pooling branch averages without counting padding."""
+
+    def forward(self, x):
+        out_1x1 = self.branch1x1(x)
+
+        stem_3x3 = self.branch3x3_1(x)
+        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)
+
+        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
+        out_dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)
+
+        # count_include_pad=False: border averages ignore the zero padding.
+        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        out_pool = self.branch_pool(pooled)
+        return torch.cat([out_1x1, out_3x3, out_dbl, out_pool], 1)
+
+
+class _FIDInceptionE2(InceptionE):
+    """InceptionE block whose pooling branch uses max pooling instead of average pooling."""
+
+    def forward(self, x):
+        out_1x1 = self.branch1x1(x)
+
+        stem_3x3 = self.branch3x3_1(x)
+        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)
+
+        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
+        out_dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)
+
+        # Unlike _FIDInceptionE1, this block pools with max_pool2d.
+        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        out_pool = self.branch_pool(pooled)
+        return torch.cat([out_1x1, out_3x3, out_dbl, out_pool], 1)
+
+
+class FIDModel(torch.nn.Module):
+    """Compute the Frechet distance between feature statistics of two image sets.
+
+    ``model`` maps a batch of 299x299 images with values in [0, 1] (as
+    produced by ``transforms.ToTensor``) to feature vectors.
+    """
+
+    def __init__(self, model: torch.nn.Module, device: Union[str, torch.device] = "cpu", batch_size: int = 50, num_workers: int = 0):
+        super().__init__()
+        self.model = model
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((299, 299), interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+            ]
+        )
+        self.to(device)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        weights_path: str = None,
+        pretrained: bool = True,
+        device: Union[str, torch.device] = "cpu",
+        batch_size: int = 50,
+        num_workers: int = 0,
+        use_fid_inception: bool = True,
+    ):
+        """Build the Inception feature extractor and wrap it in a ``FIDModel``.
+
+        ``weights_path`` may also be a list or tuple holding exactly one path.
+
+        Raises:
+            FileNotFoundError: If ``weights_path`` is an empty list or tuple.
+            ValueError: If ``weights_path`` holds more than one path.
+        """
+        if isinstance(weights_path, (list, tuple)):
+            if len(weights_path) == 1:
+                weights_path = weights_path[0]
+            elif len(weights_path) == 0:
+                raise FileNotFoundError(
+                    "FID weights were not found. Please check the ModelScope model id and file pattern."
+                )
+            else:
+                raise ValueError(
+                    f"FID expects a single weights file, but got {len(weights_path)} paths: {weights_path}"
+                )
+        device = _resolve_device(device)
+        model = _InceptionFeatures(weights_path=weights_path, pretrained=pretrained, use_fid_inception=use_fid_inception).to(device).eval()
+        return cls(model=model, device=device, batch_size=batch_size, num_workers=num_workers)
+
+    @property
+    def device(self):
+        """Device of the first model parameter, or CPU when there are none."""
+        try:
+            return next(self.model.parameters()).device
+        except StopIteration:
+            return torch.device("cpu")
+
+    def _as_images(self, images):
+        """Normalize the input into a list of images or image paths.
+
+        A string or ``Path`` is expanded with ``_image_files``; a single PIL
+        image is wrapped in a list; any other iterable is materialized.
+        """
+        if isinstance(images, (str, Path)):
+            return _image_files(images)
+        if isinstance(images, Image.Image):
+            return [images]
+        return list(images)
+
+    @torch.no_grad()
+    def get_activations(self, images, batch_size: int = None, num_workers: int = None):
+        """Return float64 CPU feature activations, one row per image.
+
+        Raises:
+            ValueError: If ``images`` resolves to an empty collection.
+        """
+        images = self._as_images(images)
+        if not images:
+            # Without this guard the DataLoader rejects batch_size=0 with an
+            # error that does not mention the real cause.
+            raise ValueError("FID requires at least one image, but received an empty collection.")
+        batch_size = self.batch_size if batch_size is None else batch_size
+        num_workers = self.num_workers if num_workers is None else num_workers
+        dataset = _ImageDataset(images, transform=self.transform)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=min(batch_size, len(dataset)), shuffle=False, num_workers=num_workers)
+        activations = []
+        self.model.eval()
+        for batch in dataloader:
+            batch = batch.to(self.device)
+            features = self.model(batch)
+            if isinstance(features, tuple):
+                features = features[0]
+            if features.ndim == 4:
+                # Spatial feature maps are averaged down to one vector per image.
+                features = F.adaptive_avg_pool2d(features, output_size=(1, 1)).flatten(1)
+            activations.append(features.detach().cpu().to(torch.float64))
+        return torch.cat(activations, dim=0)
+
+    def statistics(self, images, batch_size: int = None, num_workers: int = None):
+        """Return ``(mean, covariance)`` of the activations of ``images``."""
+        activations = self.get_activations(images, batch_size=batch_size, num_workers=num_workers)
+        return self.activation_statistics(activations)
+
+    @staticmethod
+    def activation_statistics(activations):
+        """Return the float64 mean and unbiased covariance of ``activations``.
+
+        With one row or fewer the covariance is a zero matrix, because the
+        unbiased estimate would divide by zero.
+        """
+        activations = activations.to(torch.float64)
+        mean = activations.mean(dim=0)
+        centered = activations - mean
+        feature_dim = activations.shape[1]
+        if activations.shape[0] <= 1:
+            # Allocate on the activations' device so later arithmetic does not mix devices.
+            covariance = torch.zeros((feature_dim, feature_dim), dtype=torch.float64, device=activations.device)
+        else:
+            covariance = centered.T @ centered / (activations.shape[0] - 1)
+        return mean, covariance
+
+    @staticmethod
+    def _sqrtm_psd(matrix, eps: float = 1e-10):
+        """Square root of a symmetric PSD matrix via eigendecomposition.
+
+        The input is symmetrized first and eigenvalues are clamped to at
+        least ``eps`` before taking the root.
+        """
+        matrix = (matrix + matrix.T) * 0.5
+        eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
+        eigenvalues = eigenvalues.clamp_min(eps).sqrt()
+        return (eigenvectors * eigenvalues.unsqueeze(0)) @ eigenvectors.T
+
+    @classmethod
+    def frechet_distance(cls, mean1, covariance1, mean2, covariance2, eps: float = 1e-6):
+        """Return the Frechet distance between two Gaussians, clamped to be non-negative.
+
+        ``eps`` times the identity is added to both covariances before the
+        matrix square roots are taken.
+        """
+        mean1 = mean1.to(torch.float64)
+        covariance1 = covariance1.to(torch.float64)
+        mean2 = mean2.to(torch.float64)
+        covariance2 = covariance2.to(torch.float64)
+        diff = mean1 - mean2
+        # Build the offset on the covariance's device so non-CPU statistics work too.
+        offset = torch.eye(covariance1.shape[0], dtype=torch.float64, device=covariance1.device) * eps
+        sqrt_cov1 = cls._sqrtm_psd(covariance1 + offset)
+        covmean = cls._sqrtm_psd(sqrt_cov1 @ (covariance2 + offset) @ sqrt_cov1)
+        distance = diff.dot(diff) + torch.trace(covariance1) + torch.trace(covariance2) - 2 * torch.trace(covmean)
+        return distance.clamp_min(0)
+
+    def compute(self, reference_images, generated_images, batch_size: int = None, num_workers: int = None):
+        """Return the Frechet distance between the reference and generated image sets."""
+        mean1, covariance1 = self.statistics(reference_images, batch_size=batch_size, num_workers=num_workers)
+        mean2, covariance2 = self.statistics(generated_images, batch_size=batch_size, num_workers=num_workers)
+        return self.frechet_distance(mean1, covariance1, mean2, covariance2)
+
+    def forward(self, reference_images, generated_images):
+        """Alias for ``compute`` using the instance's batch size and worker count."""
+        return self.compute(reference_images, generated_images)
diff --git a/diffsynth/models/hpsv2.py b/diffsynth/models/hpsv2.py
new file mode 100644
index 000000000..4d6cb6086
--- /dev/null
+++ b/diffsynth/models/hpsv2.py
@@ -0,0 +1,219 @@
+from pathlib import Path
+from typing import Union
+import torch
+from PIL import Image
+from transformers import AutoConfig, AutoModel, AutoProcessor
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _as_list(value):
+    """Return ``value`` as a list: lists and tuples are copied, anything else is wrapped."""
+    return list(value) if isinstance(value, (list, tuple)) else [value]
+
+def _feature_tensor(output, feature_name: str):
+    """Return the feature tensor carried by ``output``.
+
+    ``output`` may be a tensor, an object exposing ``image_embeds``,
+    ``text_embeds`` or ``pooler_output`` (checked in that order), or a list or
+    tuple whose first tensor element is used.
+
+    Raises:
+        TypeError: If no tensor is found; the message names ``feature_name``.
+    """
+    if torch.is_tensor(output):
+        return output
+    for attribute in ("image_embeds", "text_embeds", "pooler_output"):
+        candidate = getattr(output, attribute, None)
+        if torch.is_tensor(candidate):
+            return candidate
+    sequence = output if isinstance(output, (list, tuple)) else ()
+    for item in sequence:
+        if torch.is_tensor(item):
+            return item
+    raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
+
+
+def _find_checkpoint(path, version):
+    """Locate an HPSv2 checkpoint file.
+
+    A file path is returned as is. For a directory, the version-specific file
+    name is tried first (when ``version`` is known), then the patterns
+    ``*.pt``, ``*.pth``, ``*.bin`` and ``*.safetensors``; each is checked
+    directly under ``path`` and then recursively (first match in sorted order).
+
+    Returns:
+        The checkpoint ``Path``, or ``None`` when nothing matches.
+    """
+    root = Path(path)
+    if root.is_file():
+        return root
+    known_files = {
+        "v2.0": "HPS_v2_compressed.pt",
+        "v2.1": "HPS_v2.1_compressed.pt",
+    }
+    patterns = ["*.pt", "*.pth", "*.bin", "*.safetensors"]
+    if version in known_files:
+        patterns.insert(0, known_files[version])
+    for pattern in patterns:
+        direct = root / pattern
+        if direct.exists():
+            return direct
+        found = sorted(root.rglob(pattern))
+        if found:
+            return found[0]
+    return None
+
+class HPSv2Model(torch.nn.Module):
+    """Prompt/image scorer that loads an HPSv2 checkpoint into a transformers CLIP-style model.
+
+    Checkpoint keys are converted from the OpenCLIP layout handled in
+    ``_convert_open_clip_key``.
+    """
+
+    def __init__(self, model: torch.nn.Module, processor):
+        super().__init__()
+        self.model = model
+        self.processor = processor
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: str,
+        processor_path: str,
+        version: str = "v2.0",
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = "cpu",
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Instantiate the model from ``processor_path``'s config and load HPSv2 weights.
+
+        Raises:
+            FileNotFoundError: If no checkpoint is found under ``model_path``.
+            ValueError: If no checkpoint tensor matches the model's parameters.
+        """
+        model_kwargs = {} if model_kwargs is None else model_kwargs
+        processor_kwargs = {} if processor_kwargs is None else processor_kwargs
+        # Fail fast before paying for processor and model construction.
+        checkpoint_path = _find_checkpoint(model_path, version)
+        if checkpoint_path is None:
+            raise FileNotFoundError(f"Cannot find an HPSv2 checkpoint under {model_path}.")
+        processor = AutoProcessor.from_pretrained(processor_path, **processor_kwargs)
+        config = AutoConfig.from_pretrained(processor_path)
+        model = AutoModel.from_config(config, **model_kwargs)
+        state_dict = cls._load_checkpoint(checkpoint_path)
+        state_dict = cls._prepare_state_dict(state_dict, model.state_dict())
+        if not state_dict:
+            # strict=False would otherwise leave a randomly initialized model
+            # that silently produces meaningless scores.
+            raise ValueError(
+                f"No weights in {checkpoint_path} match the model built from {processor_path}; "
+                "check that the checkpoint and the model config describe the same architecture."
+            )
+        model.load_state_dict(state_dict, strict=False)
+        if torch_dtype is not None:
+            model = model.to(dtype=torch_dtype)
+        model = model.to(device).eval()
+        return cls(model=model, processor=processor)
+
+    @staticmethod
+    def _load_checkpoint(checkpoint_path):
+        """Read a checkpoint file and return a flat state dict without ``module.`` prefixes.
+
+        A nested ``state_dict`` or ``model`` entry is unwrapped when present.
+        """
+        checkpoint_path = Path(checkpoint_path)
+        if checkpoint_path.suffix == ".safetensors":
+            # _find_checkpoint also matches *.safetensors, which torch.load cannot read.
+            import safetensors.torch
+
+            state_dict = safetensors.torch.load_file(str(checkpoint_path), device="cpu")
+        else:
+            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
+            state_dict = torch.load(checkpoint_path, map_location="cpu")
+        if isinstance(state_dict, dict):
+            for key in ("state_dict", "model"):
+                if key in state_dict and isinstance(state_dict[key], dict):
+                    state_dict = state_dict[key]
+                    break
+        return {key[len("module.") :] if key.startswith("module.") else key: value for key, value in state_dict.items()}
+
+    @staticmethod
+    def _prepare_state_dict(state_dict, target_state_dict):
+        """Convert checkpoint keys and keep only entries that fit the target model.
+
+        An entry is kept when its converted key exists in ``target_state_dict``
+        and the shapes are equal; everything else is dropped.
+        """
+        converted = {}
+        for key, value in state_dict.items():
+            for new_key, new_value in HPSv2Model._convert_open_clip_key(key, value):
+                target = target_state_dict.get(new_key)
+                if target is not None and tuple(target.shape) == tuple(new_value.shape):
+                    converted[new_key] = new_value
+        return converted
+
+    @staticmethod
+    def _convert_open_clip_key(key, value):
+        """Map one OpenCLIP-style entry to a list of ``(key, tensor)`` pairs for the target model.
+
+        Projection matrices are transposed; fused attention projections are
+        split by ``_convert_resblock``. Unknown keys yield an empty list.
+        """
+        if key == "logit_scale":
+            return [("logit_scale", value)]
+        if key == "token_embedding.weight":
+            return [("text_model.embeddings.token_embedding.weight", value)]
+        if key == "positional_embedding":
+            return [("text_model.embeddings.position_embedding.weight", value)]
+        if key.startswith("ln_final."):
+            return [("text_model.final_layer_norm." + key[len("ln_final.") :], value)]
+        if key == "text_projection":
+            return [("text_projection.weight", value.T)]
+        if key == "visual.class_embedding":
+            return [("vision_model.embeddings.class_embedding", value)]
+        if key == "visual.conv1.weight":
+            return [("vision_model.embeddings.patch_embedding.weight", value)]
+        if key == "visual.positional_embedding":
+            return [("vision_model.embeddings.position_embedding.weight", value)]
+        if key.startswith("visual.ln_pre."):
+            # NOTE(review): "pre_layrnorm" is presumably the target model's own
+            # attribute spelling; entries that do not match a target key are dropped.
+            return [("vision_model.pre_layrnorm." + key[len("visual.ln_pre.") :], value)]
+        if key.startswith("visual.ln_post."):
+            return [("vision_model.post_layernorm." + key[len("visual.ln_post.") :], value)]
+        if key == "visual.proj":
+            return [("visual_projection.weight", value.T)]
+        if key.startswith("transformer.resblocks."):
+            return HPSv2Model._convert_resblock("text_model.encoder.layers", key[len("transformer.resblocks.") :], value)
+        if key.startswith("visual.transformer.resblocks."):
+            return HPSv2Model._convert_resblock("vision_model.encoder.layers", key[len("visual.transformer.resblocks.") :], value)
+        return []
+
+    @staticmethod
+    def _convert_resblock(prefix, suffix, value):
+        """Convert one transformer-block entry; ``suffix`` is ``<layer>.<rest>``.
+
+        Fused ``attn.in_proj_*`` tensors are split into equal q/k/v chunks
+        along dim 0. Unknown entries yield an empty list.
+        """
+        layer, _, rest = suffix.partition(".")
+        layer_prefix = f"{prefix}.{layer}."
+        if rest in ("attn.in_proj_weight", "attn.in_proj_bias"):
+            kind = "weight" if rest == "attn.in_proj_weight" else "bias"
+            q, k, v = value.chunk(3, dim=0)
+            return [
+                (layer_prefix + f"self_attn.q_proj.{kind}", q),
+                (layer_prefix + f"self_attn.k_proj.{kind}", k),
+                (layer_prefix + f"self_attn.v_proj.{kind}", v),
+            ]
+        mapping = {
+            "attn.out_proj.": "self_attn.out_proj.",
+            "ln_1.": "layer_norm1.",
+            "ln_2.": "layer_norm2.",
+            "mlp.c_fc.": "mlp.fc1.",
+            "mlp.c_proj.": "mlp.fc2.",
+        }
+        for source, target in mapping.items():
+            if rest.startswith(source):
+                return [(layer_prefix + target + rest[len(source) :], value)]
+        return []
+
+    @property
+    def device(self):
+        """Device of the first model parameter, or CPU when there are none."""
+        try:
+            return next(self.model.parameters()).device
+        except StopIteration:
+            return torch.device("cpu")
+
+    @property
+    def dtype(self):
+        """Dtype of the first model parameter, or float32 when there are none."""
+        try:
+            return next(self.model.parameters()).dtype
+        except StopIteration:
+            return torch.float32
+
+    def _normalize_inputs(self, prompts, images):
+        """Return equally long prompt and image lists, repeating a single item if needed.
+
+        Raises:
+            ValueError: If the lists still differ in length.
+        """
+        images = _as_list(images)
+        prompts = _as_list(prompts)
+        if len(prompts) == 1 and len(images) > 1:
+            prompts = prompts * len(images)
+        if len(images) == 1 and len(prompts) > 1:
+            images = images * len(prompts)
+        if len(prompts) != len(images):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
+        return prompts, images
+
+    @torch.no_grad()
+    def forward(self, prompts: Union[str, list[str]], images: ImageInput):
+        """Return one score per prompt/image pair.
+
+        The score is the dot product of the L2-normalized image and text
+        features, multiplied by ``exp(logit_scale)`` when the model has one.
+        """
+        prompts, images = self._normalize_inputs(prompts, images)
+        images = [image.convert("RGB") for image in images]
+        inputs = self.processor(text=prompts, images=images, padding=True, truncation=True, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        if self.dtype != torch.float32:
+            # Only floating-point tensors follow the model dtype; token ids stay integral.
+            inputs = {
+                name: (
+                    value.to(dtype=self.dtype)
+                    if torch.is_tensor(value) and torch.is_floating_point(value)
+                    else value
+                )
+                for name, value in inputs.items()
+            }
+        image_features = _feature_tensor(
+            self.model.get_image_features(pixel_values=inputs["pixel_values"]),
+            "image_features",
+        )
+        text_features = _feature_tensor(
+            self.model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask")),
+            "text_features",
+        )
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+        scores = (image_features * text_features).sum(dim=-1)
+        if hasattr(self.model, "logit_scale"):
+            scores = self.model.logit_scale.exp() * scores
+        return scores
diff --git a/diffsynth/models/hpsv3.py b/diffsynth/models/hpsv3.py
new file mode 100644
index 000000000..0b7498850
--- /dev/null
+++ b/diffsynth/models/hpsv3.py
@@ -0,0 +1,394 @@
+import math
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from PIL import Image
+from transformers import AutoProcessor
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+HPSV3_INSTRUCTION = """
+You are tasked with evaluating a generated image based on Visual Quality and Text Alignment and give a overall score to estimate the human preference. Please provide a rating from 0 to 10, with 0 being the worst and 10 being the best.
+
+**Visual Quality:**
+Evaluate the overall visual quality of the image. The following sub-dimensions should be considered:
+- **Reasonableness:** The image should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups.
+- **Clarity:** Evaluate the sharpness and visibility of the image. The image should be clear and easy to interpret, with no blurring or indistinct areas.
+- **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows).
+- **Aesthetic and Creativity:** Assess the artistic aspects of the image, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance.
+- **Safety:** The image should not contain harmful or inappropriate content, such as political, violent, or adult material. If such content is present, the image quality and satisfaction score should be the lowest possible.
+
+**Text Alignment:**
+Assess how well the image matches the textual prompt across the following sub-dimensions:
+- **Subject Relevance** Evaluate how accurately the subject(s) in the image (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, and behavior.
+- **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the image adheres to this style.
+- **Contextual Consistency**: Assess whether the background, setting, and surrounding elements in the image logically fit the scenario described in the prompt. The environment should support and enhance the subject without contradictions.
+- **Attribute Fidelity**: Check if specific attributes mentioned in the prompt (e.g., colors, clothing, accessories, expressions, actions) are faithfully represented in the image. Minor deviations may be acceptable, but critical attributes should be preserved.
+- **Semantic Coherence**: Evaluate whether the overall meaning and intent of the prompt are captured in the image. The generated content should not introduce elements that conflict with or distort the original description.
+Textual prompt - {text_prompt}
+
+
+"""
+
+HPSV3_PROMPT_WITH_SPECIAL_TOKEN = """
+Please provide the overall ratings of this image: <|Reward|>
+
+END
+"""
+
+HPSV3_PROMPT_WITHOUT_SPECIAL_TOKEN = """
+Please provide the overall ratings of this image:
+"""
+
+def _as_list(value):
+    """Wrap a single value in a list; lists and tuples are returned as a new list."""
+    if not isinstance(value, (list, tuple)):
+        return [value]
+    return list(value)
+
+def _round_by_factor(number, factor):
+    """Return the multiple of ``factor`` nearest to ``number`` (using built-in ``round``)."""
+    multiples = round(number / factor)
+    return multiples * factor
+
+def _ceil_by_factor(number, factor):
+    """Return the smallest multiple of ``factor`` that is >= ``number``."""
+    multiples = math.ceil(number / factor)
+    return multiples * factor
+
+def _floor_by_factor(number, factor):
+    """Return the largest multiple of ``factor`` that is <= ``number``."""
+    multiples = math.floor(number / factor)
+    return multiples * factor
+
+def _smart_resize(height, width, factor=28, min_pixels=256 * 28 * 28, max_pixels=256 * 28 * 28):
+    """Pick a target size whose sides are multiples of ``factor``.
+
+    Both sides are first rounded to the nearest multiple of ``factor`` (at
+    least ``factor``). If the resulting area exceeds ``max_pixels`` the size
+    is scaled down with sides floored to a multiple; if it is below
+    ``min_pixels`` it is scaled up with sides ceiled to a multiple.
+
+    Args:
+        height: Source height in pixels; must be positive.
+        width: Source width in pixels; must be positive.
+        factor: Side granularity of the result.
+        min_pixels: Area below which the size is scaled up.
+        max_pixels: Area above which the size is scaled down.
+
+    Returns:
+        ``(h_bar, w_bar)``; each side is at least ``factor``.
+
+    Raises:
+        ValueError: If a side is not positive or the aspect ratio exceeds 200.
+    """
+    if height <= 0 or width <= 0:
+        # Previously surfaced as a ZeroDivisionError from the aspect-ratio check.
+        raise ValueError(f"Image height and width must be positive, got {height} and {width}.")
+    aspect_ratio = max(height, width) / min(height, width)
+    if aspect_ratio > 200:
+        raise ValueError(f"Image aspect ratio must be smaller than 200, got {aspect_ratio}.")
+    h_bar = max(factor, _round_by_factor(height, factor))
+    w_bar = max(factor, _round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        # Flooring can reach 0 for thin images with a small pixel budget;
+        # never return an empty side.
+        h_bar = max(factor, _floor_by_factor(height / beta, factor))
+        w_bar = max(factor, _floor_by_factor(width / beta, factor))
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = _ceil_by_factor(height * beta, factor)
+        w_bar = _ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
+def _find_checkpoint(path):
+    """Locate an HPSv3 checkpoint file.
+
+    A file path is returned as is. For a directory, ``HPSv3.safetensors`` is
+    tried first, then ``*.safetensors``, ``*.bin``, ``*.pt`` and ``*.pth``;
+    each is checked directly under ``path`` and then recursively (first match
+    in sorted order).
+
+    Returns:
+        The checkpoint ``Path``, or ``None`` when nothing matches.
+    """
+    root = Path(path)
+    if root.is_file():
+        return root
+    patterns = ("HPSv3.safetensors", "*.safetensors", "*.bin", "*.pt", "*.pth")
+    for pattern in patterns:
+        direct = root / pattern
+        if direct.exists():
+            return direct
+        found = sorted(root.rglob(pattern))
+        if found:
+            return found[0]
+    return None
+
+class HPSv3RewardModelMixin:
+    """Mixin that adds a reward head (``rm_head``) and a reward-pooling ``forward``.
+
+    The host class is expected to provide ``self.config`` and ``self.model``
+    (both are read below but not defined here).
+    """
+
+    def init_reward_head(
+        self,
+        output_dim=2,
+        reward_token="special",
+        special_token_ids=None,
+        rm_head_type="ranknet",
+        rm_head_kwargs=None,
+    ):
+        """Create ``self.rm_head`` and store the pooling configuration.
+
+        With ``rm_head_type == "ranknet"`` the head is a three-layer MLP
+        (``hidden_size -> hidden -> 16 -> output_dim``) with ReLU activations
+        and one dropout layer; ``hidden`` and ``dropout`` come from
+        ``rm_head_kwargs`` (defaults 1024 and 0.05). Any other type yields a
+        single bias-free linear layer. The head is cast to float32.
+        """
+        self.output_dim = output_dim
+        # Supplying special_token_ids forces "special" pooling regardless of reward_token.
+        self.reward_token = "special" if special_token_ids is not None else reward_token
+        self.special_token_ids = special_token_ids
+        # Prefer config.hidden_size; fall back to config.text_config.hidden_size.
+        hidden_size = getattr(self.config, "hidden_size", None)
+        if hidden_size is None and hasattr(self.config, "text_config"):
+            hidden_size = self.config.text_config.hidden_size
+        if rm_head_type == "ranknet":
+            rm_head_kwargs = {} if rm_head_kwargs is None else rm_head_kwargs
+            hidden = rm_head_kwargs.get("hidden_size", 1024)
+            dropout = rm_head_kwargs.get("dropout", 0.05)
+            self.rm_head = nn.Sequential(
+                nn.Linear(hidden_size, hidden),
+                nn.ReLU(),
+                nn.Dropout(dropout),
+                nn.Linear(hidden, 16),
+                nn.ReLU(),
+                nn.Linear(16, output_dim),
+            )
+        else:
+            self.rm_head = nn.Linear(hidden_size, output_dim, bias=False)
+        self.rm_head.to(torch.float32)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        mm_token_type_ids: Optional[torch.IntTensor] = None,
+        **kwargs,
+    ):
+        """Run the backbone, apply the reward head and pool one reward vector per sample.
+
+        ``labels`` and ``return_dict`` are accepted but not used in this method.
+
+        Returns:
+            A dict ``{"logits": pooled_logits}``.
+
+        Raises:
+            ValueError: If the batch size is not 1 while ``config.pad_token_id``
+                is None, or if ``self.reward_token`` is not a known mode.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            rope_deltas=rope_deltas,
+            mm_token_type_ids=mm_token_type_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            **kwargs,
+        )
+        hidden_states = outputs.last_hidden_state if hasattr(outputs, "last_hidden_state") else outputs[0]
+        # Cast hidden states to the head's parameter dtype before applying it.
+        logits = self.rm_head(hidden_states.to(next(self.rm_head.parameters()).dtype))
+
+        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        elif input_ids is not None:
+            # Index of the token just before the first pad token. For a row
+            # without pad tokens argmax is 0, so (0 - 1) % length selects the
+            # last position.
+            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+            sequence_lengths = sequence_lengths % input_ids.shape[-1]
+            sequence_lengths = sequence_lengths.to(logits.device)
+        else:
+            sequence_lengths = -1
+
+        if self.reward_token == "last":
+            # Take the logits at the position computed above for each row.
+            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        elif self.reward_token == "mean":
+            # Average the logits at positions strictly before valid_lengths[i].
+            # NOTE(review): assumes sequence_lengths is a tensor here (the -1
+            # fallbacks above are plain ints) - confirm.
+            valid_lengths = torch.clamp(sequence_lengths, min=0, max=logits.size(1) - 1)
+            pooled_logits = torch.stack([logits[i, : valid_lengths[i]].mean(dim=0) for i in range(batch_size)])
+        elif self.reward_token == "special":
+            # Select logits at positions whose token id is one of special_token_ids.
+            # NOTE(review): needs input_ids to be provided, and presumably the
+            # same number of special tokens in every row for the view - confirm.
+            special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+            for special_token_id in self.special_token_ids:
+                special_token_mask = special_token_mask | (input_ids == special_token_id)
+            pooled_logits = logits[special_token_mask, ...].view(batch_size, -1)
+        else:
+            raise ValueError(f"Invalid HPSv3 reward token mode: {self.reward_token}")
+        return {"logits": pooled_logits}
+
+def _create_reward_model_class():
+    """Build the Qwen2-VL based HPSv3 reward-model class.
+
+    ``transformers`` is imported inside the function body rather than at module import
+    time. A fresh class object is created on every call.
+
+    Returns:
+        A class deriving from ``HPSv3RewardModelMixin`` and
+        ``Qwen2VLForConditionalGeneration`` whose constructor also sets up the reward head.
+    """
+    from transformers import Qwen2VLForConditionalGeneration
+
+    class HPSv3Qwen2VLRewardModel(HPSv3RewardModelMixin, Qwen2VLForConditionalGeneration):
+        def __init__(
+            self,
+            config,
+            output_dim=2,
+            reward_token="special",
+            special_token_ids=None,
+            rm_head_type="ranknet",
+            rm_head_kwargs=None,
+        ):
+            # Build the backbone first, then attach the reward head configured by the
+            # remaining constructor arguments.
+            super().__init__(config)
+            head_options = {
+                "output_dim": output_dim,
+                "reward_token": reward_token,
+                "special_token_ids": special_token_ids,
+                "rm_head_type": rm_head_type,
+                "rm_head_kwargs": rm_head_kwargs,
+            }
+            self.init_reward_head(**head_options)
+
+    return HPSv3Qwen2VLRewardModel
+
+class HPSv3Model(torch.nn.Module):
+    """Score prompt/image pairs with an HPSv3 reward model.
+
+    The wrapper owns a reward model and its processor. ``forward`` resizes the images,
+    wraps each prompt/image pair in a single-turn chat message, runs the model and returns
+    the reward logits.
+    """
+
+    def __init__(
+        self,
+        model,
+        processor,
+        use_special_tokens=True,
+        max_pixels=256 * 28 * 28,
+        min_pixels=256 * 28 * 28,
+        score_index=0,
+    ):
+        """Store the reward model, its processor and the scoring options.
+
+        Args:
+            model: Reward model; called with the processor batch and expected to return a
+                mapping containing ``"logits"``.
+            processor: Processor providing ``apply_chat_template`` and batch encoding.
+            use_special_tokens: Selects which prompt suffix constant is appended.
+            max_pixels: Upper pixel bound forwarded to ``_smart_resize``.
+            min_pixels: Lower pixel bound forwarded to ``_smart_resize``.
+            score_index: Column picked from 2-D reward logits in ``forward``.
+        """
+        super().__init__()
+        self.model = model
+        self.processor = processor
+        self.use_special_tokens = use_special_tokens
+        self.max_pixels = max_pixels
+        self.min_pixels = min_pixels
+        self.score_index = score_index
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: str,
+        base_model_path: str = None,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = "cpu",
+        output_dim: int = 2,
+        score_index: int = 0,
+        use_special_tokens: bool = True,
+        reward_token: str = "special",
+        rm_head_type: str = "ranknet",
+        rm_head_kwargs: dict = None,
+        max_pixels: int = 256 * 28 * 28,
+        min_pixels: int = 256 * 28 * 28,
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Load the processor, backbone and HPSv3 checkpoint and return a ready wrapper.
+
+        Args:
+            model_path: Location searched for the HPSv3 checkpoint.
+            base_model_path: Location of the base model/processor; defaults to ``model_path``.
+            torch_dtype: Dtype used when loading the backbone.
+            device: Device the model is moved to after loading.
+            output_dim: Forwarded to the reward model constructor.
+            score_index: Column picked from 2-D reward logits.
+            use_special_tokens: When true, ``<|Reward|>`` is added to the tokenizer and the
+                embeddings are resized accordingly.
+            reward_token: Forwarded to the reward model constructor.
+            rm_head_type: Forwarded to the reward model constructor.
+            rm_head_kwargs: Forwarded to the reward model constructor.
+            max_pixels: Upper pixel bound used when resizing images.
+            min_pixels: Lower pixel bound used when resizing images.
+            model_kwargs: Extra keyword arguments for the model ``from_pretrained`` call.
+                The caller's dict is not modified.
+            processor_kwargs: Extra keyword arguments for ``AutoProcessor.from_pretrained``.
+
+        Raises:
+            FileNotFoundError: If no checkpoint is found under ``model_path``.
+        """
+        # Work on copies so the caller's dictionaries are never mutated by the pops below.
+        model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
+        processor_kwargs = {} if processor_kwargs is None else dict(processor_kwargs)
+        model_path = Path(model_path)
+        base_model_path = base_model_path or str(model_path)
+        checkpoint_path = _find_checkpoint(model_path)
+        if checkpoint_path is None:
+            raise FileNotFoundError(f"Cannot find an HPSv3 checkpoint under {model_path}.")
+
+        processor = AutoProcessor.from_pretrained(base_model_path, padding_side="right", **processor_kwargs)
+        special_token_ids = None
+        if use_special_tokens:
+            special_tokens = ["<|Reward|>"]
+            processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
+            special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
+
+        attn_implementation = model_kwargs.pop("attn_implementation", "sdpa")
+        # The reward-head options were previously accepted but never passed on, so the
+        # defaults of the model constructor were always used. Values already present in
+        # ``model_kwargs`` keep priority to stay compatible with existing callers.
+        model_kwargs.setdefault("rm_head_type", rm_head_type)
+        model_kwargs.setdefault("rm_head_kwargs", rm_head_kwargs)
+
+        reward_model_class = _create_reward_model_class()
+        model = reward_model_class.from_pretrained(
+            base_model_path,
+            output_dim=output_dim,
+            reward_token=reward_token,
+            special_token_ids=special_token_ids,
+            torch_dtype=torch_dtype,
+            attn_implementation=attn_implementation,
+            **model_kwargs,
+        )
+        if use_special_tokens:
+            model.resize_token_embeddings(len(processor.tokenizer))
+        state_dict = cls._load_checkpoint(checkpoint_path)
+        state_dict = cls._prepare_state_dict(state_dict, model.state_dict())
+        model.load_state_dict(state_dict, strict=True)
+        model.config.tokenizer_padding_side = processor.tokenizer.padding_side
+        model.config.pad_token_id = processor.tokenizer.pad_token_id
+        # The reward head is kept in float32 regardless of the backbone dtype.
+        model.rm_head.to(torch.float32)
+        model = model.to(device).eval()
+        return cls(
+            model=model,
+            processor=processor,
+            use_special_tokens=use_special_tokens,
+            max_pixels=max_pixels,
+            min_pixels=min_pixels,
+            score_index=score_index,
+        )
+
+    @staticmethod
+    def _load_checkpoint(checkpoint_path):
+        """Read a checkpoint file and return a flat state dict without ``module.`` prefixes.
+
+        ``.safetensors`` files are read with ``safetensors``; everything else goes through
+        ``torch.load``. A nested ``"state_dict"`` or ``"model"`` dict is unwrapped.
+        """
+        checkpoint_path = Path(checkpoint_path)
+        if checkpoint_path.suffix == ".safetensors":
+            import safetensors.torch
+
+            state_dict = safetensors.torch.load_file(str(checkpoint_path), device="cpu")
+        else:
+            # NOTE(review): torch.load unpickles the file; only load checkpoints from a
+            # trusted source, or consider passing weights_only=True if the checkpoints
+            # contain tensors only.
+            state_dict = torch.load(checkpoint_path, map_location="cpu")
+        if isinstance(state_dict, dict):
+            for key in ("state_dict", "model"):
+                if key in state_dict and isinstance(state_dict[key], dict):
+                    state_dict = state_dict[key]
+                    break
+        return {key[len("module.") :] if key.startswith("module.") else key: value for key, value in state_dict.items()}
+
+    @staticmethod
+    def _prepare_state_dict(state_dict, target_state_dict):
+        """Rename checkpoint keys so they match the key layout of the target model.
+
+        A key is only renamed when the renamed form exists in ``target_state_dict``;
+        otherwise it is kept unchanged.
+
+        Args:
+            state_dict: Checkpoint tensors keyed by parameter name.
+            target_state_dict: State dict of the model the checkpoint is loaded into.
+
+        Returns:
+            A new dict with the (possibly renamed) keys and the original values.
+        """
+        target_keys = set(target_state_dict.keys())
+        converted = {}
+        for key, value in state_dict.items():
+            new_key = key
+            if key.startswith("visual.") and f"model.{key}" in target_keys:
+                new_key = f"model.{key}"
+            elif key.startswith("model.visual.") and key[len("model.") :] in target_keys:
+                new_key = key[len("model.") :]
+            elif key.startswith("model.") and not key.startswith("model.language_model."):
+                suffix = key[len("model.") :]
+                if f"model.language_model.{suffix}" in target_keys:
+                    new_key = f"model.language_model.{suffix}"
+            elif key.startswith("model.language_model."):
+                suffix = key[len("model.language_model.") :]
+                if f"model.{suffix}" in target_keys:
+                    new_key = f"model.{suffix}"
+            elif key.startswith("lm_head.") and f"model.{key}" in target_keys:
+                new_key = f"model.{key}"
+            elif key.startswith("model.lm_head.") and key[len("model.") :] in target_keys:
+                new_key = key[len("model.") :]
+            converted[new_key] = value
+        return converted
+
+    @property
+    def device(self):
+        """Device of the first model parameter, or CPU when the model has no parameters."""
+        try:
+            return next(self.model.parameters()).device
+        except StopIteration:
+            return torch.device("cpu")
+
+    def _normalize_inputs(self, prompts, images):
+        """Turn prompts and images into two lists of equal length.
+
+        A single prompt is repeated for several images and a single image is repeated for
+        several prompts.
+
+        Raises:
+            ValueError: If the two lists still differ in length.
+        """
+        images = _as_list(images)
+        prompts = _as_list(prompts)
+        if len(prompts) == 1 and len(images) > 1:
+            prompts = prompts * len(images)
+        if len(images) == 1 and len(prompts) > 1:
+            images = images * len(prompts)
+        if len(prompts) != len(images):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
+        return prompts, images
+
+    def _prepare_images(self, images):
+        """Convert images to RGB and resize them to the size returned by ``_smart_resize``."""
+        prepared = []
+        for image in images:
+            image = image.convert("RGB")
+            height, width = image.height, image.width
+            resized_height, resized_width = _smart_resize(
+                height,
+                width,
+                min_pixels=self.min_pixels,
+                max_pixels=self.max_pixels,
+            )
+            # PIL expects (width, height).
+            prepared.append(image.resize((resized_width, resized_height), Image.BICUBIC))
+        return prepared
+
+    def _messages(self, prompts, images):
+        """Build one single-turn user message (image + instruction text) per pair."""
+        suffix = HPSV3_PROMPT_WITH_SPECIAL_TOKEN if self.use_special_tokens else HPSV3_PROMPT_WITHOUT_SPECIAL_TOKEN
+        messages = []
+        for prompt, image in zip(prompts, images):
+            messages.append(
+                [
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "image", "image": image},
+                            {"type": "text", "text": HPSV3_INSTRUCTION.format(text_prompt=prompt) + suffix},
+                        ],
+                    }
+                ]
+            )
+        return messages
+
+    def _prepare_batch(self, prompts, images):
+        """Resize images, render the chat template and encode a padded batch on ``self.device``."""
+        images = self._prepare_images(images)
+        messages = self._messages(prompts, images)
+        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        batch = self.processor(text=text, images=images, padding=True, return_tensors="pt")
+        return batch.to(self.device)
+
+    @torch.inference_mode()
+    def forward(self, prompts: Union[str, list[str]], images):
+        """Return the reward for each prompt/image pair.
+
+        Args:
+            prompts: One prompt or a list of prompts.
+            images: One image or a list/tuple of images.
+
+        Returns:
+            Column ``score_index`` of the reward logits when they are 2-D, otherwise the
+            logits unchanged.
+        """
+        prompts, images = self._normalize_inputs(prompts, images)
+        batch = self._prepare_batch(prompts, images)
+        rewards = self.model(return_dict=True, **batch)["logits"]
+        if rewards.ndim == 2:
+            return rewards[:, self.score_index]
+        return rewards
diff --git a/diffsynth/models/image_reward.py b/diffsynth/models/image_reward.py
new file mode 100644
index 000000000..c8f557058
--- /dev/null
+++ b/diffsynth/models/image_reward.py
@@ -0,0 +1,305 @@
+import json
+import fnmatch
+from pathlib import Path
+from typing import Union
+import torch
+import torch.nn as nn
+from PIL import Image
+from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
+from torchvision.transforms import InterpolationMode
+BICUBIC = InterpolationMode.BICUBIC
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _convert_image_to_rgb(image):
+    """Return ``image`` converted to RGB mode via its ``convert`` method."""
+    rgb_image = image.convert("RGB")
+    return rgb_image
+
+def _image_transform(image_size):
+    """Compose the preprocessing pipeline for images of side ``image_size``.
+
+    Steps, in order: bicubic resize, center crop, RGB conversion, tensor conversion and
+    per-channel normalization with fixed mean/std constants.
+    """
+    channel_mean = (0.48145466, 0.4578275, 0.40821073)
+    channel_std = (0.26862954, 0.26130258, 0.27577711)
+    steps = [
+        Resize(image_size, interpolation=BICUBIC),
+        CenterCrop(image_size),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize(channel_mean, channel_std),
+    ]
+    return Compose(steps)
+
+def _as_list(value):
+    """Return a list copy of a list/tuple, or wrap any other value in a one-element list."""
+    return list(value) if isinstance(value, (list, tuple)) else [value]
+
+def _find_file(path, names):
+    """Locate a file under ``path`` that matches one of ``names``.
+
+    Args:
+        path: A file or directory path.
+        names: File names or glob patterns, in priority order.
+
+    Returns:
+        ``path`` itself when it is a file whose name matches any pattern; otherwise the
+        first direct child named exactly like an entry of ``names``; otherwise the first
+        (sorted) recursive match of the first pattern that matches anything; ``None`` when
+        nothing is found.
+    """
+    root = Path(path)
+    if root.is_file():
+        name_matches = any(fnmatch.fnmatch(root.name, pattern) for pattern in names)
+        return root if name_matches else None
+
+    # Exact names directly under the directory win over recursive pattern matches.
+    direct_hit = next((root / name for name in names if (root / name).exists()), None)
+    if direct_hit is not None:
+        return direct_hit
+
+    for pattern in names:
+        found = sorted(root.rglob(pattern))
+        if found:
+            return found[0]
+    return None
+
+class ImageRewardMLP(nn.Module):
+    """Stack of linear layers (with dropout, no activations) mapping features to one score."""
+
+    def __init__(self, input_size):
+        """Create the layer stack and initialise it.
+
+        Weights are drawn from a normal distribution with std ``1 / (input_size + 1)``;
+        biases are set to zero.
+        """
+        super().__init__()
+        stack = [
+            nn.Linear(input_size, 1024),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            nn.Dropout(0.1),
+            nn.Linear(64, 16),
+            nn.Linear(16, 1),
+        ]
+        self.layers = nn.Sequential(*stack)
+
+        weight_std = 1.0 / (input_size + 1)
+        for param_name, tensor in self.layers.named_parameters():
+            if "weight" in param_name:
+                nn.init.normal_(tensor, mean=0.0, std=weight_std)
+            if "bias" in param_name:
+                nn.init.constant_(tensor, val=0)
+
+    def forward(self, x):
+        """Run ``x`` through the layer stack."""
+        return self.layers(x)
+
+class ImageRewardModel(nn.Module):
+    """ImageReward scorer: a BLIP image/text encoder followed by an MLP reward head."""
+
+    def __init__(self, blip, tokenizer, image_size=224, max_length=35, mean=0.16717362830052426, std=1.0333394966054072):
+        """Assemble the scorer from a BLIP model and a tokenizer.
+
+        Args:
+            blip: Model exposing ``vision_model``, ``text_encoder`` and
+                ``config.text_config.hidden_size``.
+            tokenizer: Tokenizer used to encode prompts.
+            image_size: Side length used by the image preprocessing transform.
+            max_length: Prompts are padded/truncated to this many tokens.
+            mean: Value subtracted from the raw reward.
+            std: Value the centred reward is divided by.
+        """
+        super().__init__()
+        self.blip = blip
+        self.tokenizer = tokenizer
+        self.preprocess = _image_transform(image_size)
+        self.max_length = max_length
+        self.mlp = ImageRewardMLP(blip.config.text_config.hidden_size)
+        # Non-persistent: the normalisation constants are not part of the state dict.
+        self.register_buffer("score_mean", torch.tensor(float(mean)), persistent=False)
+        self.register_buffer("score_std", torch.tensor(float(std)), persistent=False)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: str,
+        med_config_path: str = None,
+        tokenizer_path: str = None,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = "cpu",
+        max_length: int = 35,
+        model_kwargs: dict = None,
+        tokenizer_kwargs: dict = None,
+    ):
+        """Build the BLIP backbone, load an ImageReward checkpoint and return the scorer.
+
+        Args:
+            model_path: File or directory searched for the checkpoint (and, by default,
+                for ``med_config.json`` and tokenizer files).
+            med_config_path: Optional explicit path to the text-encoder JSON config.
+            tokenizer_path: Optional tokenizer location; required when ``model_path``
+                holds no tokenizer files.
+            torch_dtype: Optional dtype for the BLIP backbone; the MLP head stays float32.
+            device: Device the model is moved to.
+            max_length: Prompt length used for tokenisation.
+            model_kwargs: Optional overrides (``vision_hidden_size``,
+                ``vision_num_hidden_layers``, ``vision_num_attention_heads``,
+                ``image_size``, ``patch_size``, ``vision_layer_norm_eps``,
+                ``projection_dim``). The caller's dict is not modified.
+            tokenizer_kwargs: Extra keyword arguments for ``BertTokenizer.from_pretrained``.
+
+        Raises:
+            FileNotFoundError: If no checkpoint is found under ``model_path``.
+            ValueError: If no tokenizer path is given and none can be found.
+        """
+        from transformers import BertTokenizer, BlipConfig, BlipForImageTextRetrieval
+
+        # Copy so the ``pop`` calls below never mutate the caller's dictionary.
+        model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
+        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
+        model_path = Path(model_path)
+        checkpoint_path = _find_file(model_path, ["ImageReward.pt", "pytorch_model.bin", "*.pt", "*.bin", "*.safetensors"])
+        if checkpoint_path is None:
+            raise FileNotFoundError(f"Cannot find an ImageReward checkpoint under {model_path}.")
+
+        med_config_path = Path(med_config_path) if med_config_path is not None else _find_file(model_path, ["med_config.json"])
+        text_config = cls._load_text_config(med_config_path)
+        if tokenizer_path is None:
+            if cls._has_tokenizer_files(model_path):
+                tokenizer_path = str(model_path)
+            else:
+                raise ValueError(
+                    "ImageReward requires a local BERT tokenizer path. Use "
+                    "`ImageRewardMetric.from_pretrained(...)`, or pass a "
+                    "ModelScope-downloaded tokenizer such as "
+                    "`AI-ModelScope/bert-base-uncased`."
+                )
+        tokenizer = BertTokenizer.from_pretrained(tokenizer_path, **tokenizer_kwargs)
+        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+        tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
+        tokenizer.enc_token_id = tokenizer.convert_tokens_to_ids("[ENC]")
+
+        vision_hidden_size = model_kwargs.pop("vision_hidden_size", 1024)
+        # The same image size must drive both the vision config and the preprocessing
+        # transform; previously the transform always used the constructor default.
+        image_size = model_kwargs.pop("image_size", 224)
+        config = BlipConfig(
+            vision_config={
+                "hidden_size": vision_hidden_size,
+                "intermediate_size": vision_hidden_size * 4,
+                "num_hidden_layers": model_kwargs.pop("vision_num_hidden_layers", 24),
+                "num_attention_heads": model_kwargs.pop("vision_num_attention_heads", 16),
+                "image_size": image_size,
+                "patch_size": model_kwargs.pop("patch_size", 16),
+                "hidden_act": "gelu",
+                "layer_norm_eps": model_kwargs.pop("vision_layer_norm_eps", 1e-6),
+            },
+            text_config={
+                **text_config,
+                # Never shrink the vocabulary below what the extended tokenizer needs.
+                "vocab_size": max(text_config.get("vocab_size", 0), len(tokenizer)),
+                "encoder_hidden_size": vision_hidden_size,
+                "add_cross_attention": True,
+                "is_decoder": True,
+            },
+            projection_dim=model_kwargs.pop("projection_dim", 256),
+        )
+        blip = BlipForImageTextRetrieval(config)
+        model = cls(blip=blip, tokenizer=tokenizer, image_size=image_size, max_length=max_length)
+        state_dict = cls._load_checkpoint(checkpoint_path)
+        converted = cls._convert_state_dict(state_dict)
+        # NOTE(review): strict=False hides missing/unexpected keys; presumably needed
+        # because the BLIP retrieval model has parameters the checkpoint lacks — confirm.
+        model.load_state_dict(converted, strict=False)
+        if torch_dtype is not None:
+            model.blip = model.blip.to(dtype=torch_dtype)
+        model.mlp = model.mlp.float()
+        model = model.to(device).eval()
+        return model
+
+    @staticmethod
+    def _load_text_config(med_config_path):
+        """Return the text-encoder config dict.
+
+        With ``None`` a built-in default is returned; otherwise the JSON file is read and
+        filtered down to the supported keys.
+        """
+        if med_config_path is None:
+            return {
+                "hidden_size": 768,
+                "intermediate_size": 3072,
+                "num_hidden_layers": 12,
+                "num_attention_heads": 12,
+                "max_position_embeddings": 512,
+                "vocab_size": 30524,
+                "hidden_act": "gelu",
+                "layer_norm_eps": 1e-12,
+                "attention_probs_dropout_prob": 0.1,
+                "hidden_dropout_prob": 0.1,
+                "pad_token_id": 0,
+                "type_vocab_size": 2,
+            }
+        with open(med_config_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        allowed = {
+            "hidden_size",
+            "intermediate_size",
+            "num_hidden_layers",
+            "num_attention_heads",
+            "max_position_embeddings",
+            "vocab_size",
+            "hidden_act",
+            "layer_norm_eps",
+            "attention_probs_dropout_prob",
+            "hidden_dropout_prob",
+            "pad_token_id",
+            "type_vocab_size",
+        }
+        return {key: value for key, value in data.items() if key in allowed}
+
+    @staticmethod
+    def _has_tokenizer_files(path):
+        """Return True when ``path`` is a directory containing a known tokenizer file."""
+        path = Path(path)
+        return path.is_dir() and any((path / name).exists() for name in ("vocab.txt", "tokenizer.json", "tokenizer_config.json"))
+
+    @staticmethod
+    def _load_checkpoint(checkpoint_path):
+        """Read a checkpoint file, unwrapping a nested ``"state_dict"`` or ``"model"`` dict."""
+        checkpoint_path = Path(checkpoint_path)
+        if checkpoint_path.suffix == ".safetensors":
+            import safetensors.torch
+
+            state_dict = safetensors.torch.load_file(str(checkpoint_path), device="cpu")
+        else:
+            # NOTE(review): torch.load unpickles the file; only load checkpoints from a
+            # trusted source, or consider weights_only=True if they contain tensors only.
+            state_dict = torch.load(checkpoint_path, map_location="cpu")
+        if isinstance(state_dict, dict):
+            for key in ("state_dict", "model"):
+                if key in state_dict and isinstance(state_dict[key], dict):
+                    state_dict = state_dict[key]
+                    break
+        return state_dict
+
+    @staticmethod
+    def _convert_state_dict(state_dict):
+        """Strip ``module.`` prefixes and rename keys; entries mapped to ``None`` are dropped."""
+        converted = {}
+        for key, value in state_dict.items():
+            if key.startswith("module."):
+                key = key[len("module.") :]
+            new_key, new_value = ImageRewardModel._convert_key_value(key, value)
+            if new_key is not None:
+                converted[new_key] = new_value
+        return converted
+
+    @staticmethod
+    def _convert_key_value(key, value):
+        """Map one checkpoint entry to this module's parameter naming.
+
+        Keys under ``blip.visual_encoder.`` are renamed to the ``blip.vision_model.``
+        layout; unrecognised keys under that prefix yield ``(None, value)``. All other
+        keys are returned unchanged.
+        """
+        if key.startswith("blip.visual_encoder."):
+            suffix = key[len("blip.visual_encoder.") :]
+            if suffix == "cls_token":
+                return "blip.vision_model.embeddings.class_embedding", value
+            if suffix == "pos_embed":
+                return "blip.vision_model.embeddings.position_embedding", value
+            if suffix.startswith("patch_embed.proj."):
+                return "blip.vision_model.embeddings.patch_embedding." + suffix[len("patch_embed.proj.") :], value
+            if suffix.startswith("blocks."):
+                parts = suffix.split(".")
+                layer = parts[1]
+                rest = ".".join(parts[2:])
+                prefix = f"blip.vision_model.encoder.layers.{layer}."
+                mapping = {
+                    "norm1.": "layer_norm1.",
+                    "attn.qkv.": "self_attn.qkv.",
+                    "attn.proj.": "self_attn.projection.",
+                    "norm2.": "layer_norm2.",
+                    "mlp.fc1.": "mlp.fc1.",
+                    "mlp.fc2.": "mlp.fc2.",
+                }
+                for source, target in mapping.items():
+                    if rest.startswith(source):
+                        return prefix + target + rest[len(source) :], value
+            if suffix.startswith("norm."):
+                return "blip.vision_model.post_layernorm." + suffix[len("norm.") :], value
+            return None, value
+        return key, value
+
+    @property
+    def device(self):
+        """Device of the first parameter, or CPU when there are no parameters."""
+        try:
+            return next(self.parameters()).device
+        except StopIteration:
+            return torch.device("cpu")
+
+    @property
+    def dtype(self):
+        """Dtype of the first BLIP parameter, or float32 when BLIP has no parameters."""
+        try:
+            return next(self.blip.parameters()).dtype
+        except StopIteration:
+            return torch.float32
+
+    def _tokenize(self, prompts):
+        """Tokenise prompts to fixed-length tensors on ``self.device``."""
+        return self.tokenizer(
+            prompts,
+            padding="max_length",
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+        ).to(self.device)
+
+    def _preprocess_images(self, images):
+        """Preprocess images and stack them into one batch with the BLIP device and dtype."""
+        tensors = [self.preprocess(image.convert("RGB")) for image in images]
+        return torch.stack(tensors, dim=0).to(device=self.device, dtype=self.dtype)
+
+    def _normalize_inputs(self, prompts, images):
+        """Turn prompts and images into two lists of equal length.
+
+        A single prompt is repeated for several images and a single image is repeated for
+        several prompts.
+
+        Raises:
+            ValueError: If the two lists still differ in length.
+        """
+        images = _as_list(images)
+        prompts = _as_list(prompts)
+        if len(prompts) == 1 and len(images) > 1:
+            prompts = prompts * len(images)
+        if len(images) == 1 and len(prompts) > 1:
+            images = images * len(prompts)
+        if len(prompts) != len(images):
+            raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
+        return prompts, images
+
+    @torch.no_grad()
+    def forward(self, prompts: Union[str, list[str]], images):
+        """Return one normalised reward per prompt/image pair.
+
+        The image is encoded by the vision model, the prompt by the text encoder with
+        cross-attention over the image tokens; the first text token's hidden state is fed
+        to the MLP and the result is standardised with ``score_mean`` / ``score_std``.
+        """
+        prompts, images = self._normalize_inputs(prompts, images)
+        text_input = self._tokenize(prompts)
+        image_tensor = self._preprocess_images(images)
+        image_output = self.blip.vision_model(pixel_values=image_tensor, return_dict=True)
+        image_embeds = image_output.last_hidden_state
+        # Every image token is attended to.
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=self.device)
+        text_output = self.blip.text_encoder(
+            input_ids=text_input.input_ids,
+            attention_mask=text_input.attention_mask,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        # The MLP head runs in float32 regardless of the backbone dtype.
+        text_features = text_output.last_hidden_state[:, 0, :].float()
+        rewards = self.mlp(text_features).squeeze(-1)
+        rewards = (rewards - self.score_mean) / self.score_std
+        return rewards
diff --git a/diffsynth/models/pickscore.py b/diffsynth/models/pickscore.py
new file mode 100644
index 000000000..900e41efb
--- /dev/null
+++ b/diffsynth/models/pickscore.py
@@ -0,0 +1,107 @@
+from typing import Union
+import torch
+from PIL import Image
+
+ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
+
+def _feature_tensor(output, feature_name: str):
+    """Extract a feature tensor from a raw tensor or a model output object.
+
+    Lookup order: ``output`` itself if it is a tensor; then its ``image_embeds``,
+    ``text_embeds`` and ``pooler_output`` attributes; then, for a list/tuple, its first
+    tensor element.
+
+    Raises:
+        TypeError: If no tensor is found.
+    """
+    if torch.is_tensor(output):
+        return output
+
+    def _candidates():
+        # Lazily yield the places a tensor may live, in priority order.
+        for attribute in ("image_embeds", "text_embeds", "pooler_output"):
+            yield getattr(output, attribute, None)
+        if isinstance(output, (list, tuple)):
+            yield from output
+
+    for candidate in _candidates():
+        if torch.is_tensor(candidate):
+            return candidate
+    raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
+
+
+class PickScoreModel(torch.nn.Module):
+    """Text/image scorer built from a model exposing image/text feature heads and ``logit_scale``."""
+
+    def __init__(self, model: torch.nn.Module, processor, max_length: int = 77):
+        """Keep the model, the processor and the maximum token length passed to the processor."""
+        super().__init__()
+        self.model = model
+        self.processor = processor
+        self.max_length = max_length
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: str,
+        processor_path: str,
+        torch_dtype: torch.dtype = None,
+        device: Union[str, torch.device] = "cuda",
+        max_length: int = 77,
+        model_kwargs: dict = None,
+        processor_kwargs: dict = None,
+    ):
+        """Load processor and model through ModelScope and wrap them.
+
+        The model is put in eval mode, optionally cast to ``torch_dtype`` and then moved
+        to ``device``.
+        """
+        from modelscope import AutoModel, AutoProcessor
+
+        if model_kwargs is None:
+            model_kwargs = {}
+        if processor_kwargs is None:
+            processor_kwargs = {}
+        processor = AutoProcessor.from_pretrained(processor_path, **processor_kwargs)
+        model = AutoModel.from_pretrained(model_path, **model_kwargs).eval()
+        if torch_dtype is not None:
+            model = model.to(dtype=torch_dtype)
+        model = model.to(device)
+        return cls(model=model, processor=processor, max_length=max_length)
+
+    @property
+    def device(self):
+        """Device of the first model parameter, or CPU when the model has no parameters."""
+        first_param = next(self.model.parameters(), None)
+        if first_param is None:
+            return torch.device("cpu")
+        return first_param.device
+
+    @property
+    def dtype(self):
+        """Dtype of the first model parameter, or float32 when the model has no parameters."""
+        first_param = next(self.model.parameters(), None)
+        if first_param is None:
+            return torch.float32
+        return first_param.dtype
+
+    def _processor_call(self, **kwargs):
+        """Run the processor and move the result to the model's device (and dtype).
+
+        When the model is not float32, floating-point tensors are cast to the model dtype
+        and a plain dict is returned instead of the processor's own output object.
+        """
+        encoded = self.processor(
+            padding=True,
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+            **kwargs,
+        ).to(self.device)
+        if self.dtype == torch.float32:
+            return encoded
+
+        cast_inputs = {}
+        for name, value in encoded.items():
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                value = value.to(dtype=self.dtype)
+            cast_inputs[name] = value
+        return cast_inputs
+
+    @torch.no_grad()
+    def get_image_features(self, images: ImageInput):
+        """Return L2-normalised image features for one image or a sequence of images."""
+        if isinstance(images, Image.Image):
+            images = [images]
+        rgb_images = [image.convert("RGB") for image in images]
+        image_inputs = self._processor_call(images=rgb_images)
+        raw_output = self.model.get_image_features(**image_inputs)
+        image_features = _feature_tensor(raw_output, "image_features")
+        # clamp_min guards against division by a zero norm.
+        norms = image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+        return image_features / norms
+
+    @torch.no_grad()
+    def get_text_features(self, text: Union[str, list[str]]):
+        """Return L2-normalised text features for one string or a list of strings."""
+        text_inputs = self._processor_call(text=text)
+        raw_output = self.model.get_text_features(**text_inputs)
+        text_features = _feature_tensor(raw_output, "text_features")
+        norms = text_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+        return text_features / norms
+
+    @torch.no_grad()
+    def forward(self, text: Union[str, list[str]], images: ImageInput):
+        """Return scaled text-image similarity scores.
+
+        The result is the text-by-image similarity matrix multiplied by
+        ``exp(logit_scale)``; for a single string prompt only its row is returned.
+        """
+        image_features = self.get_image_features(images)
+        text_features = self.get_text_features(text)
+        similarity = text_features @ image_features.T
+        scores = self.model.logit_scale.exp() * similarity
+        return scores[0] if isinstance(text, str) else scores
diff --git a/docs/en/Model_Details/Image-Quality-Metrics.md b/docs/en/Model_Details/Image-Quality-Metrics.md
new file mode 100644
index 000000000..34e5778e8
--- /dev/null
+++ b/docs/en/Model_Details/Image-Quality-Metrics.md
@@ -0,0 +1,117 @@
+# Image Quality Evaluation Metrics
+
+DiffSynth-Studio provides a suite of image quality evaluation metrics and reward models in `diffsynth.metrics` to assess text alignment, aesthetic quality, human preference, and image distribution quality of generated images. Example code for these metrics can be found in [`examples/image_quality_metric/`](../../../examples/image_quality_metric/).
+
+## Installation
+
+Before using this project for model inference and training, please install DiffSynth-Studio first.
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
+
+## Quick Start
+
+Run the following code to quickly load PickScore and score an image against a prompt. The default models will be downloaded from ModelScope to `./models`.
+
+```python
+from PIL import Image
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+score = metric.calc_scores(prompt, image)[0]
+print("PickScore:", score)
+```
+
+## Metrics Overview
+
+| Metric | Default Model | Input | Output | Example Code |
+| --- | --- | --- | --- | --- |
+| PickScore | [AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/pickscore.py) |
+| ImageReward | [ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/image_reward.py) |
+| HPSv2 | [AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv2.py) |
+| HPSv3 | [MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv3.py) |
+| CLIP Score | [AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K) | prompt + PIL Image | Text-Image Similarity | [code](../../../examples/image_quality_metric/clipscore.py) |
+| Aesthetic | [AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE) | PIL Image | Aesthetic Score | [code](../../../examples/image_quality_metric/aesthetic.py) |
+| FID | [diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d) | reference image directory + generated image directory | Distribution Distance | [code](../../../examples/image_quality_metric/fid.py) |
+
+## Single-Image Reward Models
+
+**PickScore**, **ImageReward**, **HPSv2**, **HPSv3**, and **CLIP Score** share the same input format: a text prompt and an opened `PIL.Image.Image`. Example:
+
+```python
+from PIL import Image
+from diffsynth.metrics import CLIPMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = CLIPMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+score = metric.calc_scores(prompt, image)[0]
+```
+
+If you want to evaluate multiple images, you can pass a list of PIL images:
+
+```python
+scores = metric.calc_scores(prompt, [image1, image2, image3])
+```
+
+When the prompt is a single string, the same prompt will be used for every image. When the prompt is a list of strings, the number of prompts must match the number of images.
+
+## Aesthetic
+
+Aesthetic only evaluates the aesthetic quality of the image and does not use a prompt.
+
+```python
+from PIL import Image
+from diffsynth.metrics import AestheticMetric
+
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+metric = AestheticMetric.from_pretrained(device="cuda")
+score = metric.calc_scores(image)[0]
+```
+
+## FID
+
+FID is used to compare the feature distributions of two sets of images. It does not score single images, nor does it use a prompt. A typical use case is comparing a directory of real reference images against a directory of generated results:
+
+```python
+from diffsynth.metrics import FIDMetric
+
+reference_dir = FIDMetric.default_reference_dir(
+ local_dir="data/examples/ImageQualityMetric/reference/coco_2014_caption_validation",
+ max_images=10000,
+)
+generated_dir = ""
+
+metric = FIDMetric.from_pretrained(device="cuda", batch_size=16)
+score = metric.compute(reference_dir, generated_dir)
+print("FID:", score)
+```
+
+The reference for FID is not a single, fixed official answer. For general text-to-image quality evaluation, the COCO validation set is a convenient default choice; for vertical tasks such as portraits, product images, or medical images, a `reference_dir` consisting of real data from that specific domain should be provided.
+
+## Important Notes
+
+* The scores from PickScore, ImageReward, HPSv2, HPSv3, CLIP Score, and Aesthetic are suitable for relative comparison within the same metric. It is not recommended to directly compare the numerical values across different metrics.
+* HPSv3 is based on Qwen2-VL and is a larger model, requiring significantly more VRAM than CLIP-based metrics.
+* FID is sensitive to the choice of reference, the reference sample size, and the generated sample size.
\ No newline at end of file
diff --git a/docs/en/Model_Details/Overview.md b/docs/en/Model_Details/Overview.md
index 286141e83..a5aa9f5f2 100644
--- a/docs/en/Model_Details/Overview.md
+++ b/docs/en/Model_Details/Overview.md
@@ -289,3 +289,42 @@ graph LR;
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)
+
+
+## Image Quality Evaluation Metrics
+
+Documentation: [./Image-Quality-Metrics.md](../Model_Details/Image-Quality-Metrics.md)
+
+<details>
+
+<summary>Quick Start</summary>
+
+```python
+from PIL import Image
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+score = metric.calc_scores(prompt, image)[0]
+print("PickScore:", score)
+```
+
+</details>
+
+| Metric | Default Model | Example Code |
+|-|-|-|
+|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/pickscore.py)|
+|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/image_reward.py)|
+|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv2.py)|
+|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/clipscore.py)|
+|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/aesthetic.py)|
+|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/fid.py)|
\ No newline at end of file
diff --git a/docs/zh/Model_Details/Image-Quality-Metrics.md b/docs/zh/Model_Details/Image-Quality-Metrics.md
new file mode 100644
index 000000000..67f99cbab
--- /dev/null
+++ b/docs/zh/Model_Details/Image-Quality-Metrics.md
@@ -0,0 +1,118 @@
+# 图像质量评估指标
+
+DiffSynth-Studio 在 `diffsynth.metrics` 中提供了一组图像质量评估指标和奖励模型,用于评估生成图像的文本对齐、审美质量、人类偏好和图像分布质量。这些指标的示例代码位于 [`examples/image_quality_metric/`](../../../examples/image_quality_metric/)。
+
+## 安装
+
+在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
+
+## 快速开始
+
+运行以下代码可以快速加载 PickScore,并对一张图像和一段提示词进行评分。默认模型会从 ModelScope 下载到 `./models`。
+
+```python
+from PIL import Image
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+score = metric.calc_scores(prompt, image)[0]
+print("PickScore:", score)
+```
+
+## 指标总览
+
+|指标|默认模型|输入|输出|示例代码|
+|-|-|-|-|-|
+|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/pickscore.py)|
+|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/image_reward.py)|
+|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv2.py)|
+|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|prompt + PIL 图像|图文相似度|[code](../../../examples/image_quality_metric/clipscore.py)|
+|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|PIL 图像|美学分数|[code](../../../examples/image_quality_metric/aesthetic.py)|
+|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|reference 图像目录 + generated 图像目录|分布距离|[code](../../../examples/image_quality_metric/fid.py)|
+
+## 单图奖励模型
+
+**PickScore**、**ImageReward**、**HPSv2**、**HPSv3** 和 **CLIP Score** 的输入形式相同:一段文本提示词和一张已经打开的 `PIL.Image.Image`。示例:
+
+```python
+from PIL import Image
+from diffsynth.metrics import CLIPMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = CLIPMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+scores = metric.calc_scores(prompt, image)[0]
+```
+
+如果要评估多张图像,可以传入 PIL 图像列表:
+
+```python
+scores = metric.calc_scores(prompt, [image1, image2, image3])
+```
+
+其中 prompt 为单个字符串时,会对每张图像使用同一个 prompt。prompt 为字符串列表时,prompt 数量需要和图像数量一致。
+
+## Aesthetic
+
+Aesthetic 只评估图像审美质量,不使用 prompt。
+
+```python
+from PIL import Image
+from diffsynth.metrics import AestheticMetric
+
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+metric = AestheticMetric.from_pretrained(device="cuda")
+score = metric.calc_scores(image)[0]
+```
+
+## FID
+
+FID 用于比较两组图像的特征分布。它不是单图打分,也不使用 prompt。典型用法是比较真实参考图像目录和生成结果目录:
+
+```python
+from diffsynth.metrics import FIDMetric
+
+reference_dir = FIDMetric.default_reference_dir(
+ local_dir="data/examples/ImageQualityMetric/reference/coco_2014_caption_validation",
+ max_images=10000,
+)
+generated_dir = ""
+
+metric = FIDMetric.from_pretrained(device="cuda", batch_size=16)
+score = metric.compute(reference_dir, generated_dir)
+print("FID:", score)
+```
+
+FID 的 reference 不是固定唯一的官方答案。对于通用文生图质量评估,COCO validation 是一个方便的默认选择;对于人像、商品图、医学等垂直任务,应传入该领域真实数据构成的 `reference_dir`。
+
+
+## 注意事项
+
+* PickScore、ImageReward、HPSv2、HPSv3、CLIP Score、Aesthetic 的分数适合做同一指标内部的相对比较,不建议直接把不同指标的数值大小相互比较。
+* HPSv3 基于 Qwen2-VL,模型较大,显存需求明显高于 CLIP 类指标。
+* FID 对 reference 的选择、reference 样本量和 generated 样本量都较敏感。
diff --git a/docs/zh/Model_Details/Overview.md b/docs/zh/Model_Details/Overview.md
index cdfdce9f2..3f9a33579 100644
--- a/docs/zh/Model_Details/Overview.md
+++ b/docs/zh/Model_Details/Overview.md
@@ -286,3 +286,42 @@ graph LR;
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
+
+
+## 图像质量评估指标
+
+文档:[./Image-Quality-Metrics.md](../Model_Details/Image-Quality-Metrics.md)
+
+
+
+快速开始
+
+```python
+from PIL import Image
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+score = metric.calc_scores(prompt, image)[0]
+print("PickScore:", score)
+```
+
+
+
+|指标|默认模型|示例代码|
+|-|-|-|
+|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/pickscore.py)|
+|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/image_reward.py)|
+|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv2.py)|
+|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/clipscore.py)|
+|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/aesthetic.py)|
+|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/fid.py)|
\ No newline at end of file
diff --git a/examples/image_quality_metric/aesthetic.py b/examples/image_quality_metric/aesthetic.py
new file mode 100644
index 000000000..2271344df
--- /dev/null
+++ b/examples/image_quality_metric/aesthetic.py
@@ -0,0 +1,13 @@
+from PIL import Image
+from diffsynth.metrics import AestheticMetric, ModelConfig
+
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = AestheticMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE"),
+ device=device,
+)
+
+print("Aesthetic score:", metric.calc_scores(image)[0])
diff --git a/examples/image_quality_metric/clipscore.py b/examples/image_quality_metric/clipscore.py
new file mode 100644
index 000000000..563320247
--- /dev/null
+++ b/examples/image_quality_metric/clipscore.py
@@ -0,0 +1,14 @@
+from PIL import Image
+from diffsynth.metrics import CLIPMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = CLIPMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+
+print("CLIP score:", metric.calc_scores(prompt, image)[0])
diff --git a/examples/image_quality_metric/fid.py b/examples/image_quality_metric/fid.py
new file mode 100644
index 000000000..c8aa415f7
--- /dev/null
+++ b/examples/image_quality_metric/fid.py
@@ -0,0 +1,16 @@
+from diffsynth.metrics import FIDMetric
+
+generated_dir = ""
+device = "cuda"
+
+reference_dir = FIDMetric.default_reference_dir(
+ local_dir="data/examples/ImageQualityMetric/reference/coco_2014_caption_validation",
+ max_images=10000, # use None for the full validation split
+)
+
+metric = FIDMetric.from_pretrained(
+ device=device,
+ batch_size=16,
+)
+
+print("FID score:", metric.compute(reference_dir, generated_dir))
diff --git a/examples/image_quality_metric/hpsv2.py b/examples/image_quality_metric/hpsv2.py
new file mode 100644
index 000000000..2a67317dd
--- /dev/null
+++ b/examples/image_quality_metric/hpsv2.py
@@ -0,0 +1,16 @@
+from PIL import Image
+from diffsynth.metrics import HPSv2Metric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = HPSv2Metric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/HPSv2"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ version="v2.0",
+ device=device,
+)
+
+print("HPSv2 score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/hpsv3.py b/examples/image_quality_metric/hpsv3.py
new file mode 100644
index 000000000..0f3fa35f1
--- /dev/null
+++ b/examples/image_quality_metric/hpsv3.py
@@ -0,0 +1,15 @@
+from PIL import Image
+from diffsynth.metrics import HPSv3Metric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = HPSv3Metric.from_pretrained(
+ model_config=ModelConfig(model_id="MizzenAI/HPSv3"),
+ base_model_config=ModelConfig(model_id="Qwen/Qwen2-VL-7B-Instruct"),
+ device=device,
+)
+
+print("HPSv3 score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/image_reward.py b/examples/image_quality_metric/image_reward.py
new file mode 100644
index 000000000..f7845a2f5
--- /dev/null
+++ b/examples/image_quality_metric/image_reward.py
@@ -0,0 +1,14 @@
+from PIL import Image
+from diffsynth.metrics import ImageRewardMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = ImageRewardMetric.from_pretrained(
+ model_config=ModelConfig(model_id="ZhipuAI/ImageReward"),
+ device=device,
+)
+
+print("ImageReward score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/pickscore.py b/examples/image_quality_metric/pickscore.py
new file mode 100644
index 000000000..509877bf9
--- /dev/null
+++ b/examples/image_quality_metric/pickscore.py
@@ -0,0 +1,15 @@
+from PIL import Image
+from diffsynth.metrics import PickScoreMetric, ModelConfig
+
+prompt = ""
+path_to_image = ""
+image = Image.open(path_to_image).convert("RGB")
+device = "cuda"
+
+metric = PickScoreMetric.from_pretrained(
+ model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
+ processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ device=device,
+)
+
+print("PickScore score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
From 2ea54427a56484c3530c751417c6db1026c1bb78 Mon Sep 17 00:00:00 2001
From: yjy415 <2471352175@qq.com>
Date: Tue, 19 May 2026 15:52:35 +0800
Subject: [PATCH 2/4] fix: metric
---
diffsynth/metrics/aesthetic.py | 9 ++---
diffsynth/metrics/base.py | 23 ++++++++-----
diffsynth/metrics/clip.py | 9 ++---
diffsynth/metrics/fid.py | 26 ++++-----------
diffsynth/metrics/hpsv2.py | 16 ++-------
diffsynth/metrics/hpsv3.py | 16 ++-------
diffsynth/metrics/image_reward.py | 9 ++---
diffsynth/metrics/pickscore.py | 16 ++-------
.../en/Model_Details/Image-Quality-Metrics.md | 33 +++++++++++--------
docs/en/Model_Details/Overview.md | 14 ++++----
.../zh/Model_Details/Image-Quality-Metrics.md | 17 ++++++----
docs/zh/Model_Details/Overview.md | 14 ++++----
examples/image_quality_metric/aesthetic.py | 14 +++++---
examples/image_quality_metric/clipscore.py | 17 +++++++---
examples/image_quality_metric/fid.py | 9 ++++-
examples/image_quality_metric/hpsv2.py | 19 +++++++----
examples/image_quality_metric/hpsv3.py | 17 +++++++---
examples/image_quality_metric/image_reward.py | 17 +++++++---
examples/image_quality_metric/pickscore.py | 17 +++++++---
19 files changed, 160 insertions(+), 152 deletions(-)
diff --git a/diffsynth/metrics/aesthetic.py b/diffsynth/metrics/aesthetic.py
index c2f676638..bc5e0c1ed 100644
--- a/diffsynth/metrics/aesthetic.py
+++ b/diffsynth/metrics/aesthetic.py
@@ -10,14 +10,10 @@ def __init__(self, model: AestheticModel):
super().__init__()
self.model = model
- @staticmethod
- def default_model_config():
- return AestheticMetric.local_or_modelscope_config("AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE")
-
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = None,
+ model_config: Union[ModelConfig, str] = "AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE",
clip_config: Union[ModelConfig, str] = None,
clip_processor_config: Union[ModelConfig, str] = None,
torch_dtype: torch.dtype = None,
@@ -25,7 +21,6 @@ def from_pretrained(
clip_kwargs: dict = None,
processor_kwargs: dict = None,
):
- model_config = cls.default_model_config() if model_config is None else model_config
model_config = cls.resolve_model_config(model_config)
clip_config = cls.resolve_model_config(clip_config) if clip_config is not None else None
clip_processor_config = cls.resolve_model_config(clip_processor_config) if clip_processor_config is not None else clip_config
@@ -45,7 +40,7 @@ def score(self, images):
scores = self.model(images)
return self.tensor_to_list(scores)
- def calc_scores(self, images):
+ def compute(self, images):
return self.score(images)
def forward(self, images):
diff --git a/diffsynth/metrics/base.py b/diffsynth/metrics/base.py
index 3ad95a28f..dab54e420 100644
--- a/diffsynth/metrics/base.py
+++ b/diffsynth/metrics/base.py
@@ -19,17 +19,22 @@ def tensor_to_float(value):
return float(value)
@staticmethod
- def resolve_model_config(config: Union[ModelConfig, str, Path]):
- if isinstance(config, (str, Path)):
+ def resolve_model_config(config: Union[ModelConfig, str, Path], origin_file_pattern: str = ""):
+ if config is None:
+ return None
+ if isinstance(config, Path):
config = ModelConfig(path=str(config))
+ elif isinstance(config, str):
+ path = Path(config).expanduser()
+ if path.exists() or path.is_absolute() or config.startswith(("./", "../", "~")) or ("/" not in config and path.suffix):
+ config = ModelConfig(path=str(path))
+ else:
+ local_path = Path("./models") / config
+ if not origin_file_pattern and local_path.exists():
+ config = ModelConfig(path=str(local_path))
+ else:
+ config = ModelConfig(model_id=config, origin_file_pattern=origin_file_pattern)
if config is None:
return None
config.download_if_necessary()
return config
-
- @staticmethod
- def local_or_modelscope_config(model_id: str, origin_file_pattern: str = ""):
- local_path = Path("./models") / model_id
- if local_path.exists():
- return ModelConfig(path=str(local_path))
- return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
diff --git a/diffsynth/metrics/clip.py b/diffsynth/metrics/clip.py
index 94c0219ed..837c7f8a1 100644
--- a/diffsynth/metrics/clip.py
+++ b/diffsynth/metrics/clip.py
@@ -10,14 +10,10 @@ def __init__(self, model: CLIPModel):
super().__init__()
self.model = model
- @staticmethod
- def default_model_config():
- return CLIPMetric.local_or_modelscope_config("AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K")
-
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = None,
+ model_config: Union[ModelConfig, str] = "AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K",
processor_config: Union[ModelConfig, str] = None,
torch_dtype: torch.dtype = None,
device: Union[str, torch.device] = get_device_type(),
@@ -25,7 +21,6 @@ def from_pretrained(
model_kwargs: dict = None,
processor_kwargs: dict = None,
):
- model_config = cls.default_model_config() if model_config is None else model_config
model_config = cls.resolve_model_config(model_config)
processor_config = cls.resolve_model_config(processor_config) if processor_config is not None else model_config
model = CLIPModel.from_pretrained(
@@ -57,7 +52,7 @@ def similarity_matrix(
scores = self.model.similarity_matrix(prompt, images)
return self.tensor_to_list(scores)
- def calc_scores(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: Union[str, list[str]], images):
return self.score(prompt, images)
def forward(self, prompt: Union[str, list[str]], images):
diff --git a/diffsynth/metrics/fid.py b/diffsynth/metrics/fid.py
index b740b0068..5d06e52c9 100644
--- a/diffsynth/metrics/fid.py
+++ b/diffsynth/metrics/fid.py
@@ -20,18 +20,13 @@ class FIDMetric(Metric):
DEFAULT_REFERENCE_DATASET_ID = "modelscope/coco_2014_caption"
DEFAULT_REFERENCE_METADATA_URL = "https://modelscope.oss-cn-beijing.aliyuncs.com/open_data/coco_2014_caption/val2014.csv.zip"
DEFAULT_REFERENCE_SPLIT = "validation"
+ DEFAULT_WEIGHTS_ID = "diffusionTry/weights-inception-2015-12-05-6726825d"
+ DEFAULT_WEIGHTS_FILE = "weights-inception-2015-12-05-6726825d.pth"
def __init__(self, model: FIDModel):
super().__init__()
self.model = model
- @staticmethod
- def default_weights_config():
- return ModelConfig(
- model_id="diffusionTry/weights-inception-2015-12-05-6726825d",
- origin_file_pattern="weights-inception-2015-12-05-6726825d.pth",
- )
-
@staticmethod
def default_reference_root():
base_path = os.environ.get("DIFFSYNTH_DATA_BASE_PATH", "./data")
@@ -105,14 +100,6 @@ def download_reference_dir(
retries: int = 3,
verbose: bool = True,
):
- """
- Download the default COCO 2014 caption validation reference images.
-
- The ModelScope dataset stores a small CSV archive whose image column
- points to ModelScope OSS image URLs. This helper downloads that metadata
- and materializes the referenced real images as a normal image directory.
- """
-
root = Path(local_dir) if local_dir is not None else cls.default_reference_root()
images_dir = root / "images"
metadata_dir = root / "metadata"
@@ -191,16 +178,17 @@ def default_reference_dir(
@classmethod
def from_pretrained(
cls,
- weights_config: Union[ModelConfig, str] = None,
+ weights_config: Union[ModelConfig, str] = DEFAULT_WEIGHTS_ID,
pretrained: bool = True,
device: Union[str, torch.device] = get_device_type(),
batch_size: int = 50,
num_workers: int = 0,
use_fid_inception: bool = True,
):
- if weights_config is None and use_fid_inception:
- weights_config = cls.default_weights_config()
- weights_config = cls.resolve_model_config(weights_config) if weights_config is not None else None
+ weights_config = cls.resolve_model_config(
+ weights_config,
+ origin_file_pattern=cls.DEFAULT_WEIGHTS_FILE if weights_config == cls.DEFAULT_WEIGHTS_ID else "",
+ ) if weights_config is not None and use_fid_inception else None
model = FIDModel.from_pretrained(
weights_path=None if weights_config is None else weights_config.path,
pretrained=pretrained,
diff --git a/diffsynth/metrics/hpsv2.py b/diffsynth/metrics/hpsv2.py
index 4760ee11c..f9c6258e3 100644
--- a/diffsynth/metrics/hpsv2.py
+++ b/diffsynth/metrics/hpsv2.py
@@ -10,27 +10,17 @@ def __init__(self, model: HPSv2Model):
super().__init__()
self.model = model
- @staticmethod
- def default_model_config():
- return HPSv2Metric.local_or_modelscope_config("AI-ModelScope/HPSv2")
-
- @staticmethod
- def default_processor_config():
- return HPSv2Metric.local_or_modelscope_config("AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K")
-
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = None,
- processor_config: Union[ModelConfig, str] = None,
+ model_config: Union[ModelConfig, str] = "AI-ModelScope/HPSv2",
+ processor_config: Union[ModelConfig, str] = "AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K",
version: str = "v2.0",
torch_dtype: torch.dtype = None,
device: Union[str, torch.device] = get_device_type(),
model_kwargs: dict = None,
processor_kwargs: dict = None,
):
- model_config = cls.default_model_config() if model_config is None else model_config
- processor_config = cls.default_processor_config() if processor_config is None else processor_config
model_config = cls.resolve_model_config(model_config)
processor_config = cls.resolve_model_config(processor_config)
model = HPSv2Model.from_pretrained(
@@ -49,7 +39,7 @@ def score(self, prompt: Union[str, list[str]], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)
- def calc_scores(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: Union[str, list[str]], images):
return self.score(prompt, images)
def forward(self, prompt: Union[str, list[str]], images):
diff --git a/diffsynth/metrics/hpsv3.py b/diffsynth/metrics/hpsv3.py
index 9c6de7bfd..9ec2cb0c3 100644
--- a/diffsynth/metrics/hpsv3.py
+++ b/diffsynth/metrics/hpsv3.py
@@ -11,19 +11,11 @@ def __init__(self, model: HPSv3Model):
super().__init__()
self.model = model
- @staticmethod
- def default_model_config():
- return HPSv3Metric.local_or_modelscope_config("MizzenAI/HPSv3")
-
- @staticmethod
- def default_base_model_config():
- return HPSv3Metric.local_or_modelscope_config("Qwen/Qwen2-VL-7B-Instruct")
-
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = None,
- base_model_config: Union[ModelConfig, str] = None,
+ model_config: Union[ModelConfig, str] = "MizzenAI/HPSv3",
+ base_model_config: Union[ModelConfig, str] = "Qwen/Qwen2-VL-7B-Instruct",
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
output_dim: int = 2,
@@ -34,8 +26,6 @@ def from_pretrained(
model_kwargs: dict = None,
processor_kwargs: dict = None,
):
- model_config = cls.default_model_config() if model_config is None else model_config
- base_model_config = cls.default_base_model_config() if base_model_config is None else base_model_config
model_config = cls.resolve_model_config(model_config)
base_model_config = cls.resolve_model_config(base_model_config)
model = HPSv3Model.from_pretrained(
@@ -58,7 +48,7 @@ def score(self, prompt: Union[str, list[str]], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)
- def calc_scores(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: Union[str, list[str]], images):
return self.score(prompt, images)
def forward(self, prompt: Union[str, list[str]], images):
diff --git a/diffsynth/metrics/image_reward.py b/diffsynth/metrics/image_reward.py
index 51ccd7348..feb156058 100644
--- a/diffsynth/metrics/image_reward.py
+++ b/diffsynth/metrics/image_reward.py
@@ -23,10 +23,6 @@ def __init__(self, model: ImageRewardModel):
super().__init__()
self.model = model
- @staticmethod
- def default_model_config():
- return ImageRewardMetric.local_or_modelscope_config("ZhipuAI/ImageReward")
-
@staticmethod
def default_tokenizer_config():
local_path = Path(os.environ.get("DIFFSYNTH_MODEL_BASE_PATH", "./models")) / ImageRewardMetric.BERT_TOKENIZER_MODEL_ID
@@ -62,7 +58,7 @@ def _as_directory_path(path):
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = None,
+ model_config: Union[ModelConfig, str] = "ZhipuAI/ImageReward",
med_config: Union[ModelConfig, str] = None,
tokenizer_config: Union[ModelConfig, str] = None,
torch_dtype: torch.dtype = None,
@@ -71,7 +67,6 @@ def from_pretrained(
model_kwargs: dict = None,
tokenizer_kwargs: dict = None,
):
- model_config = cls.default_model_config() if model_config is None else model_config
tokenizer_config = cls.default_tokenizer_config() if tokenizer_config is None else tokenizer_config
model_config = cls.resolve_model_config(model_config)
med_config = cls.resolve_model_config(med_config) if med_config is not None else None
@@ -93,7 +88,7 @@ def score(self, prompt: Union[str, list[str]], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)
- def calc_scores(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: Union[str, list[str]], images):
return self.score(prompt, images)
def forward(self, prompt: Union[str, list[str]], images):
diff --git a/diffsynth/metrics/pickscore.py b/diffsynth/metrics/pickscore.py
index a71ae2619..f3d687301 100644
--- a/diffsynth/metrics/pickscore.py
+++ b/diffsynth/metrics/pickscore.py
@@ -9,28 +9,18 @@ class PickScoreMetric(Metric):
def __init__(self, model: PickScoreModel):
super().__init__()
self.model = model
-
- @staticmethod
- def default_model_config():
- return PickScoreMetric.local_or_modelscope_config("AI-ModelScope/PickScore_v1")
-
- @staticmethod
- def default_processor_config():
- return PickScoreMetric.local_or_modelscope_config("AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K")
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = None,
- processor_config: Union[ModelConfig, str] = None,
+ model_config: Union[ModelConfig, str] = "AI-ModelScope/PickScore_v1",
+ processor_config: Union[ModelConfig, str] = "AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K",
torch_dtype: torch.dtype = None,
device: Union[str, torch.device] = get_device_type(),
max_length: int = 77,
model_kwargs: dict = None,
processor_kwargs: dict = None,
):
- model_config = cls.default_model_config() if model_config is None else model_config
- processor_config = cls.default_processor_config() if processor_config is None else processor_config
model_config = cls.resolve_model_config(model_config)
processor_config = cls.resolve_model_config(processor_config)
model = PickScoreModel.from_pretrained(
@@ -66,7 +56,7 @@ def probabilities(
def calc_probs(self, prompt: Union[str, list[str]], images):
return self.probabilities(prompt, images)
- def calc_scores(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: Union[str, list[str]], images):
return self.score(prompt, images)
def forward(self, prompt: Union[str, list[str]], images):
diff --git a/docs/en/Model_Details/Image-Quality-Metrics.md b/docs/en/Model_Details/Image-Quality-Metrics.md
index 34e5778e8..51c44511c 100644
--- a/docs/en/Model_Details/Image-Quality-Metrics.md
+++ b/docs/en/Model_Details/Image-Quality-Metrics.md
@@ -19,12 +19,19 @@ For more information about installation, please refer to [Install Dependencies](
Run the following code to quickly load PickScore and score an image against a prompt. The default models will be downloaded from ModelScope to `./models`.
```python
-from PIL import Image
+import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = PickScoreMetric.from_pretrained(
@@ -32,21 +39,21 @@ metric = PickScoreMetric.from_pretrained(
processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
device=device,
)
-score = metric.calc_scores(prompt, image)[0]
-print("PickScore:", score)
+
+print("PickScore score:", metric.compute(prompt, image)[0])
```
## Metrics Overview
| Metric | Default Model | Input | Output | Example Code |
| --- | --- | --- | --- | --- |
-| PickScore | [AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1) | prompt + PIL Image | Preference Score | [code](https://www.google.com/search?q=../../../examples/image_quality_metric/pickscore.py) |
-| ImageReward | [ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward) | prompt + PIL Image | Preference Score | [code](https://www.google.com/search?q=../../../examples/image_quality_metric/image_reward.py) |
-| HPSv2 | [AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2) | prompt + PIL Image | Preference Score | [code](https://www.google.com/search?q=../../../examples/image_quality_metric/hpsv2.py) |
-| HPSv3 | [MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3) | prompt + PIL Image | Preference Score | [code](https://www.google.com/search?q=../../../examples/image_quality_metric/hpsv3.py) |
-| CLIP Score | [AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K) | prompt + PIL Image | Text-Image Similarity | [code](https://www.google.com/search?q=../../../examples/image_quality_metric/clipscore.py) |
-| Aesthetic | [AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE) | PIL Image | Aesthetic Score | [code](https://www.google.com/search?q=../../../examples/image_quality_metric/aesthetic.py) |
-| FID | [diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d) | reference image directory + generated image directory | Distribution Distance | [code](https://www.google.com/search?q=../../../examples/image_quality_metric/fid.py) |
+| PickScore | [AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/pickscore.py) |
+| ImageReward | [ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/image_reward.py) |
+| HPSv2 | [AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv2.py) |
+| HPSv3 | [MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv3.py) |
+| CLIP Score | [AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K) | prompt + PIL Image | Text-Image Similarity | [code](../../../examples/image_quality_metric/clipscore.py) |
+| Aesthetic | [AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE) | PIL Image | Aesthetic Score | [code](../../../examples/image_quality_metric/aesthetic.py) |
+| FID | [diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d) | reference image directory + generated image directory | Distribution Distance | [code](../../../examples/image_quality_metric/fid.py) |
## Single-Image Reward Models
diff --git a/docs/en/Model_Details/Overview.md b/docs/en/Model_Details/Overview.md
index a5aa9f5f2..695731d28 100644
--- a/docs/en/Model_Details/Overview.md
+++ b/docs/en/Model_Details/Overview.md
@@ -321,10 +321,10 @@ print("PickScore:", score)
| Metric | Default Model | Example Code |
|-|-|-|
-|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/pickscore.py)|
-|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/image_reward.py)|
-|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv2.py)|
-|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv3.py)|
-|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/clipscore.py)|
-|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/aesthetic.py)|
-|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/fid.py)|
\ No newline at end of file
+|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](../../../examples/image_quality_metric/pickscore.py)|
+|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](../../../examples/image_quality_metric/image_reward.py)|
+|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](../../../examples/image_quality_metric/hpsv2.py)|
+|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](../../../examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](../../../examples/image_quality_metric/clipscore.py)|
+|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](../../../examples/image_quality_metric/aesthetic.py)|
+|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](../../../examples/image_quality_metric/fid.py)|
\ No newline at end of file
diff --git a/docs/zh/Model_Details/Image-Quality-Metrics.md b/docs/zh/Model_Details/Image-Quality-Metrics.md
index 67f99cbab..a9c756ea8 100644
--- a/docs/zh/Model_Details/Image-Quality-Metrics.md
+++ b/docs/zh/Model_Details/Image-Quality-Metrics.md
@@ -19,12 +19,19 @@ pip install -e .
运行以下代码可以快速加载 PickScore,并对一张图像和一段提示词进行评分。默认模型会从 ModelScope 下载到 `./models`。
```python
-from PIL import Image
+import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = PickScoreMetric.from_pretrained(
@@ -32,8 +39,6 @@ metric = PickScoreMetric.from_pretrained(
processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
device=device,
)
-score = metric.calc_scores(prompt, image)[0]
-print("PickScore:", score)
```
## 指标总览
diff --git a/docs/zh/Model_Details/Overview.md b/docs/zh/Model_Details/Overview.md
index 3f9a33579..451251c03 100644
--- a/docs/zh/Model_Details/Overview.md
+++ b/docs/zh/Model_Details/Overview.md
@@ -318,10 +318,10 @@ print("PickScore:", score)
|指标|默认模型|示例代码|
|-|-|-|
-|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/pickscore.py)|
-|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/image_reward.py)|
-|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv2.py)|
-|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/hpsv3.py)|
-|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/clipscore.py)|
-|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/aesthetic.py)|
-|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/image_quality_metric/fid.py)|
\ No newline at end of file
+|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|[code](../../../examples/image_quality_metric/pickscore.py)|
+|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|[code](../../../examples/image_quality_metric/image_reward.py)|
+|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|[code](../../../examples/image_quality_metric/hpsv2.py)|
+|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|[code](../../../examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|[code](../../../examples/image_quality_metric/clipscore.py)|
+|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|[code](../../../examples/image_quality_metric/aesthetic.py)|
+|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|[code](../../../examples/image_quality_metric/fid.py)|
\ No newline at end of file
diff --git a/examples/image_quality_metric/aesthetic.py b/examples/image_quality_metric/aesthetic.py
index 2271344df..962eeeddf 100644
--- a/examples/image_quality_metric/aesthetic.py
+++ b/examples/image_quality_metric/aesthetic.py
@@ -1,8 +1,14 @@
-from PIL import Image
from diffsynth.metrics import AestheticMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
device = "cuda"
metric = AestheticMetric.from_pretrained(
@@ -10,4 +16,4 @@
device=device,
)
-print("Aesthetic score:", metric.calc_scores(image)[0])
+print("Aesthetic score:", metric.compute(image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/clipscore.py b/examples/image_quality_metric/clipscore.py
index 563320247..984118137 100644
--- a/examples/image_quality_metric/clipscore.py
+++ b/examples/image_quality_metric/clipscore.py
@@ -1,9 +1,16 @@
-from PIL import Image
+import csv
from diffsynth.metrics import CLIPMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = CLIPMetric.from_pretrained(
@@ -11,4 +18,4 @@
device=device,
)
-print("CLIP score:", metric.calc_scores(prompt, image)[0])
+print("CLIP score:", metric.compute(prompt, image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/fid.py b/examples/image_quality_metric/fid.py
index c8aa415f7..6a7d25d65 100644
--- a/examples/image_quality_metric/fid.py
+++ b/examples/image_quality_metric/fid.py
@@ -1,6 +1,13 @@
from diffsynth.metrics import FIDMetric
+from modelscope import dataset_snapshot_download
-generated_dir = ""
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+generated_dir = "data/diffsynth_example_dataset/flux/FLUX.1-dev"
device = "cuda"
reference_dir = FIDMetric.default_reference_dir(
diff --git a/examples/image_quality_metric/hpsv2.py b/examples/image_quality_metric/hpsv2.py
index 2a67317dd..b615174c8 100644
--- a/examples/image_quality_metric/hpsv2.py
+++ b/examples/image_quality_metric/hpsv2.py
@@ -1,16 +1,23 @@
-from PIL import Image
+import csv
from diffsynth.metrics import HPSv2Metric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = HPSv2Metric.from_pretrained(
model_config=ModelConfig(model_id="AI-ModelScope/HPSv2"),
processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- version="v2.0",
+ version="v2.0", # choice: v2.0, v2.1
device=device,
)
-print("HPSv2 score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
+print("HPSv2 score:", metric.compute(prompt, image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/hpsv3.py b/examples/image_quality_metric/hpsv3.py
index 0f3fa35f1..c436f5ded 100644
--- a/examples/image_quality_metric/hpsv3.py
+++ b/examples/image_quality_metric/hpsv3.py
@@ -1,9 +1,16 @@
-from PIL import Image
+import csv
from diffsynth.metrics import HPSv3Metric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = HPSv3Metric.from_pretrained(
@@ -12,4 +19,4 @@
device=device,
)
-print("HPSv3 score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
+print("HPSv3 score:", metric.compute(prompt, image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/image_reward.py b/examples/image_quality_metric/image_reward.py
index f7845a2f5..f9364364d 100644
--- a/examples/image_quality_metric/image_reward.py
+++ b/examples/image_quality_metric/image_reward.py
@@ -1,9 +1,16 @@
-from PIL import Image
+import csv
from diffsynth.metrics import ImageRewardMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = ImageRewardMetric.from_pretrained(
@@ -11,4 +18,4 @@
device=device,
)
-print("ImageReward score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
+print("ImageReward score:", metric.compute(prompt, image)[0])
\ No newline at end of file
diff --git a/examples/image_quality_metric/pickscore.py b/examples/image_quality_metric/pickscore.py
index 509877bf9..cbcb4b8be 100644
--- a/examples/image_quality_metric/pickscore.py
+++ b/examples/image_quality_metric/pickscore.py
@@ -1,9 +1,16 @@
-from PIL import Image
+import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = PickScoreMetric.from_pretrained(
@@ -12,4 +19,4 @@
device=device,
)
-print("PickScore score:", metric.calc_scores(prompt, image)[0])
\ No newline at end of file
+print("PickScore score:", metric.compute(prompt, image)[0])
\ No newline at end of file
From b60e3c617a541f8063a58692daa882f62ba83ce4 Mon Sep 17 00:00:00 2001
From: yjy415 <2471352175@qq.com>
Date: Tue, 19 May 2026 15:58:04 +0800
Subject: [PATCH 3/4] fix: docs
---
docs/en/Model_Details/Overview.md | 19 +++++++++++++------
docs/zh/Model_Details/Overview.md | 19 +++++++++++++------
2 files changed, 26 insertions(+), 12 deletions(-)
diff --git a/docs/en/Model_Details/Overview.md b/docs/en/Model_Details/Overview.md
index 695731d28..fecead8bd 100644
--- a/docs/en/Model_Details/Overview.md
+++ b/docs/en/Model_Details/Overview.md
@@ -300,12 +300,19 @@ Documentation: [./Image-Quality-Metrics.md](../Model_Details/Image-Quality-Metri
Quick Start
```python
-from PIL import Image
+import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = PickScoreMetric.from_pretrained(
@@ -313,8 +320,8 @@ metric = PickScoreMetric.from_pretrained(
processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
device=device,
)
-score = metric.calc_scores(prompt, image)[0]
-print("PickScore:", score)
+
+print("PickScore score:", metric.compute(prompt, image)[0])
```
diff --git a/docs/zh/Model_Details/Overview.md b/docs/zh/Model_Details/Overview.md
index 451251c03..2c1d32849 100644
--- a/docs/zh/Model_Details/Overview.md
+++ b/docs/zh/Model_Details/Overview.md
@@ -297,12 +297,19 @@ graph LR;
快速开始
```python
-from PIL import Image
+import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
+from modelscope import dataset_snapshot_download
+from PIL import Image
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
+dataset_snapshot_download(
+ "DiffSynth-Studio/diffsynth_example_dataset",
+ allow_file_pattern="flux/FLUX.1-dev/*",
+ local_dir="./data/diffsynth_example_dataset",
+)
+
+image = Image.open("data/diffsynth_example_dataset/flux/FLUX.1-dev/1.jpg").convert("RGB")
+prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = PickScoreMetric.from_pretrained(
@@ -310,8 +317,8 @@ metric = PickScoreMetric.from_pretrained(
processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
device=device,
)
-score = metric.calc_scores(prompt, image)[0]
-print("PickScore:", score)
+
+print("PickScore score:", metric.compute(prompt, image)[0])
```
From cd62b2cfa4d1ba1df40fc2d8d920dbf00aadc408 Mon Sep 17 00:00:00 2001
From: yjy415 <2471352175@qq.com>
Date: Wed, 20 May 2026 15:20:11 +0800
Subject: [PATCH 4/4] fix: image metric
---
diffsynth/configs/model_configs.py | 49 +++-
diffsynth/metrics/aesthetic.py | 31 +-
diffsynth/metrics/base.py | 48 ++-
diffsynth/metrics/clip.py | 36 +--
diffsynth/metrics/fid.py | 209 +------------
diffsynth/metrics/hpsv2.py | 35 +--
diffsynth/metrics/hpsv3.py | 50 ++--
diffsynth/metrics/image_reward.py | 89 ++----
diffsynth/metrics/pickscore.py | 38 ++-
diffsynth/models/aesthetic.py | 221 ++++----------
diffsynth/models/clip.py | 91 +++---
diffsynth/models/fid.py | 114 ++-----
diffsynth/models/hpsv2.py | 167 ++---------
diffsynth/models/hpsv3.py | 277 ++++++++----------
diffsynth/models/image_reward.py | 223 ++++----------
diffsynth/models/pickscore.py | 47 +--
.../state_dict_converters/image_metrics.py | 121 ++++++++
.../en/Model_Details/Image-Quality-Metrics.md | 91 +++---
.../zh/Model_Details/Image-Quality-Metrics.md | 87 +++---
examples/image_quality_metric/aesthetic.py | 9 +-
examples/image_quality_metric/clipscore.py | 10 +-
examples/image_quality_metric/fid.py | 22 +-
examples/image_quality_metric/hpsv2.py | 12 +-
examples/image_quality_metric/hpsv3.py | 11 +-
examples/image_quality_metric/image_reward.py | 10 +-
examples/image_quality_metric/pickscore.py | 11 +-
26 files changed, 766 insertions(+), 1343 deletions(-)
create mode 100644 diffsynth/utils/state_dict_converters/image_metrics.py
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index ad6784ffc..32665e712 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -1026,7 +1026,54 @@
},
]
+image_metrics_series = [
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors")
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors")
+ "model_hash": "b5e2c0bfcbf4085ccdb2feb8f0ba408a",
+ "model_name": "image_metrics_clip_hf",
+ "model_class": "diffsynth.models.clip.ImageMetricsCLIPModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsCLIPStateDictConverter",
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors")
+ "model_hash": "f79e72cec8ae5a540cff0304bfb21b00",
+ "model_name": "image_metrics_hpsv2",
+ "model_class": "diffsynth.models.clip.ImageMetricsCLIPModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsOpenCLIPStateDictConverter",
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors")
+ "model_hash": "5655d9cde15b759cfeefe7432d7a912c",
+ "model_name": "image_metrics_hpsv3",
+ "model_class": "diffsynth.models.hpsv3.HPSv3Qwen2VLRewardModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsHPSv3StateDictConverter",
+ "extra_kwargs": {"vocab_size": 151658, "output_dim": 2, "reward_token": "special", "rm_head_type": "ranknet"},
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors")
+ "model_hash": "b3cc8e10b76ca98cde653daa5cf63139",
+ "model_name": "image_metrics_image_reward",
+ "model_class": "diffsynth.models.image_reward.ImageRewardModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsImageRewardStateDictConverter",
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors")
+ "model_hash": "306981222ec94302794e07cf676c84cc",
+ "model_name": "image_metrics_aesthetic",
+ "model_class": "diffsynth.models.aesthetic.AestheticModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsAestheticStateDictConverter",
+ },
+ {
+ # Example: ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors")
+ "model_hash": "d4e9549be726259b444d1f62db4ce413",
+ "model_name": "image_metrics_fid_inception",
+ "model_class": "diffsynth.models.fid.FIDInceptionModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.image_metrics.ImageMetricsFIDStateDictConverter",
+ },
+]
+
MODEL_CONFIGS = (
stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
- + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
+ + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series + image_metrics_series
)
diff --git a/diffsynth/metrics/aesthetic.py b/diffsynth/metrics/aesthetic.py
index bc5e0c1ed..23d9b1957 100644
--- a/diffsynth/metrics/aesthetic.py
+++ b/diffsynth/metrics/aesthetic.py
@@ -1,9 +1,9 @@
-from typing import Union
import torch
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.aesthetic import AestheticModel
from .base import Metric
+from transformers import CLIPImageProcessor
class AestheticMetric(Metric):
def __init__(self, model: AestheticModel):
@@ -13,26 +13,21 @@ def __init__(self, model: AestheticModel):
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = "AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE",
- clip_config: Union[ModelConfig, str] = None,
- clip_processor_config: Union[ModelConfig, str] = None,
+ model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors"),
+ processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/"),
torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = get_device_type(),
- clip_kwargs: dict = None,
+ device: torch.device = get_device_type(),
processor_kwargs: dict = None,
+ vram_limit: float = None,
):
- model_config = cls.resolve_model_config(model_config)
- clip_config = cls.resolve_model_config(clip_config) if clip_config is not None else None
- clip_processor_config = cls.resolve_model_config(clip_processor_config) if clip_processor_config is not None else clip_config
- model = AestheticModel.from_pretrained(
- model_path=model_config.path,
- clip_model_path=None if clip_config is None else clip_config.path,
- clip_processor_path=None if clip_processor_config is None else clip_processor_config.path,
- torch_dtype=torch_dtype,
- device=device,
- clip_kwargs=clip_kwargs,
- processor_kwargs=processor_kwargs,
- )
+
+ processor_kwargs = processor_kwargs or {}
+ model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
+ model = model_pool.fetch_model("image_metrics_aesthetic")
+ processor_config.download_if_necessary()
+ model.processor = CLIPImageProcessor.from_pretrained(processor_config.path, **processor_kwargs)
+ model.layers = model.layers.float()
+ model = model.eval()
return cls(model)
@torch.no_grad()
diff --git a/diffsynth/metrics/base.py b/diffsynth/metrics/base.py
index dab54e420..784bf8eb1 100644
--- a/diffsynth/metrics/base.py
+++ b/diffsynth/metrics/base.py
@@ -1,40 +1,28 @@
-from pathlib import Path
-from typing import Union
import torch
from ..core import ModelConfig
+from ..models.model_loader import ModelPool
class Metric(torch.nn.Module):
+
@staticmethod
def tensor_to_list(value):
if torch.is_tensor(value):
value = value.detach().cpu().tolist()
- if not isinstance(value, list):
- return [value]
- return value
+ return value if isinstance(value, list) else [value]
@staticmethod
- def tensor_to_float(value):
- if torch.is_tensor(value):
- return float(value.detach().cpu())
- return float(value)
-
- @staticmethod
- def resolve_model_config(config: Union[ModelConfig, str, Path], origin_file_pattern: str = ""):
- if config is None:
- return None
- if isinstance(config, Path):
- config = ModelConfig(path=str(config))
- elif isinstance(config, str):
- path = Path(config).expanduser()
- if path.exists() or path.is_absolute() or config.startswith(("./", "../", "~")) or ("/" not in config and path.suffix):
- config = ModelConfig(path=str(path))
- else:
- local_path = Path("./models") / config
- if not origin_file_pattern and local_path.exists():
- config = ModelConfig(path=str(local_path))
- else:
- config = ModelConfig(model_id=config, origin_file_pattern=origin_file_pattern)
- if config is None:
- return None
- config.download_if_necessary()
- return config
+ def download_and_load_models(model_configs: list[ModelConfig], torch_dtype: torch.dtype = torch.float32, device="cuda", vram_limit: float = None):
+ model_pool = ModelPool()
+ for model_config in model_configs:
+ model_config.download_if_necessary()
+ vram_config = model_config.vram_config()
+ vram_config["computation_dtype"] = vram_config["computation_dtype"] or torch_dtype or torch.float32
+ vram_config["computation_device"] = vram_config["computation_device"] or device
+ model_pool.auto_load_model(
+ model_config.path,
+ vram_config=vram_config,
+ vram_limit=vram_limit,
+ clear_parameters=model_config.clear_parameters,
+ state_dict=model_config.state_dict,
+ )
+ return model_pool
diff --git a/diffsynth/metrics/clip.py b/diffsynth/metrics/clip.py
index 837c7f8a1..f15ad95f0 100644
--- a/diffsynth/metrics/clip.py
+++ b/diffsynth/metrics/clip.py
@@ -1,5 +1,5 @@
-from typing import Union
import torch
+from transformers import AutoProcessor
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.clip import CLIPModel
@@ -13,31 +13,27 @@ def __init__(self, model: CLIPModel):
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = "AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K",
- processor_config: Union[ModelConfig, str] = None,
+ model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors"),
+ processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/"),
torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = get_device_type(),
+ device: torch.device = get_device_type(),
max_length: int = 77,
- model_kwargs: dict = None,
processor_kwargs: dict = None,
+ vram_limit: float = None,
):
- model_config = cls.resolve_model_config(model_config)
- processor_config = cls.resolve_model_config(processor_config) if processor_config is not None else model_config
- model = CLIPModel.from_pretrained(
- model_path=model_config.path,
- processor_path=processor_config.path,
- torch_dtype=torch_dtype,
- device=device,
- max_length=max_length,
- model_kwargs=model_kwargs,
- processor_kwargs=processor_kwargs,
- )
+
+ processor_kwargs = processor_kwargs or {}
+ model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
+ model = model_pool.fetch_model("image_metrics_clip_hf")
+ processor_config.download_if_necessary()
+ processor = AutoProcessor.from_pretrained(processor_config.path, **processor_kwargs)
+ model = CLIPModel(model=model, processor=processor, max_length=max_length).eval()
return cls(model)
@torch.no_grad()
def score(
self,
- prompt: Union[str, list[str]],
+ prompt: str | list[str],
images,
):
scores = self.model(prompt, images)
@@ -46,14 +42,14 @@ def score(
@torch.no_grad()
def similarity_matrix(
self,
- prompt: Union[str, list[str]],
+ prompt: str | list[str],
images,
):
scores = self.model.similarity_matrix(prompt, images)
return self.tensor_to_list(scores)
- def compute(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)
- def forward(self, prompt: Union[str, list[str]], images):
+ def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
diff --git a/diffsynth/metrics/fid.py b/diffsynth/metrics/fid.py
index 5d06e52c9..1cc4429c1 100644
--- a/diffsynth/metrics/fid.py
+++ b/diffsynth/metrics/fid.py
@@ -1,225 +1,34 @@
-import csv
-import json
-import os
-from pathlib import Path
-from urllib.parse import urlparse
-from urllib.request import urlopen
-from zipfile import ZipFile
-from typing import Union
-
import torch
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
-from ..models.fid import FIDModel, IMAGE_EXTENSIONS
+from ..models.fid import FIDModel
from .base import Metric
class FIDMetric(Metric):
- DEFAULT_REFERENCE_NAME = "coco_2014_caption_validation"
- DEFAULT_REFERENCE_DATASET_ID = "modelscope/coco_2014_caption"
- DEFAULT_REFERENCE_METADATA_URL = "https://modelscope.oss-cn-beijing.aliyuncs.com/open_data/coco_2014_caption/val2014.csv.zip"
- DEFAULT_REFERENCE_SPLIT = "validation"
- DEFAULT_WEIGHTS_ID = "diffusionTry/weights-inception-2015-12-05-6726825d"
- DEFAULT_WEIGHTS_FILE = "weights-inception-2015-12-05-6726825d.pth"
-
def __init__(self, model: FIDModel):
super().__init__()
self.model = model
- @staticmethod
- def default_reference_root():
- base_path = os.environ.get("DIFFSYNTH_DATA_BASE_PATH", "./data")
- return Path(base_path) / "fid_reference" / FIDMetric.DEFAULT_REFERENCE_NAME
-
- @staticmethod
- def _image_files(path: Union[str, Path]):
- path = Path(path)
- if not path.exists():
- return []
- return sorted(item for item in path.rglob("*") if item.is_file() and item.suffix.lower() in IMAGE_EXTENSIONS)
-
- @staticmethod
- def _download_file(url: str, path: Path, timeout: int = 60, retries: int = 3):
- path.parent.mkdir(parents=True, exist_ok=True)
- if path.exists() and path.stat().st_size > 0:
- return path
- temp_path = path.with_suffix(path.suffix + ".tmp")
- last_error = None
- for _ in range(retries):
- try:
- with urlopen(url, timeout=timeout) as response, open(temp_path, "wb") as file:
- while True:
- chunk = response.read(1024 * 1024)
- if not chunk:
- break
- file.write(chunk)
- temp_path.replace(path)
- return path
- except Exception as error:
- last_error = error
- if temp_path.exists():
- temp_path.unlink()
- raise RuntimeError(f"Failed to download {url}: {last_error}")
-
- @staticmethod
- def _metadata_rows(metadata_zip_path: Path):
- with ZipFile(metadata_zip_path) as archive:
- csv_names = [name for name in archive.namelist() if name.endswith(".csv")]
- if not csv_names:
- raise ValueError(f"No CSV file found in {metadata_zip_path}.")
- with archive.open(csv_names[0]) as file:
- reader = csv.DictReader(line.decode("utf-8") for line in file)
- rows = []
- seen = set()
- for row in reader:
- url = row.get("image", "")
- if not url or url in seen:
- continue
- rows.append(row)
- seen.add(url)
- return rows
-
- @staticmethod
- def _image_filename(row: dict, index: int):
- url_path = urlparse(row["image"]).path
- name = Path(url_path).name
- if name and Path(name).suffix.lower() in IMAGE_EXTENSIONS:
- return name
- image_id = row.get("image_id") or f"{index:08d}"
- return f"{image_id}.jpg"
-
- @classmethod
- def download_reference_dir(
- cls,
- local_dir: Union[str, Path] = None,
- max_images: int = None,
- force: bool = False,
- metadata_url: str = None,
- timeout: int = 60,
- retries: int = 3,
- verbose: bool = True,
- ):
- root = Path(local_dir) if local_dir is not None else cls.default_reference_root()
- images_dir = root / "images"
- metadata_dir = root / "metadata"
- metadata_url = cls.DEFAULT_REFERENCE_METADATA_URL if metadata_url is None else metadata_url
- existing = cls._image_files(images_dir)
- manifest_path = root / "reference_manifest.json"
- if not force and existing:
- if max_images is not None and len(existing) >= max_images:
- return str(images_dir)
- if max_images is None and manifest_path.exists():
- with open(manifest_path, "r", encoding="utf-8") as file:
- manifest = json.load(file)
- if manifest.get("max_images") is None and len(existing) >= manifest.get("image_count", 0):
- return str(images_dir)
-
- metadata_zip_path = metadata_dir / "val2014.csv.zip"
- cls._download_file(metadata_url, metadata_zip_path, timeout=timeout, retries=retries)
- rows = cls._metadata_rows(metadata_zip_path)
- if max_images is not None:
- rows = rows[:max_images]
- if not rows:
- raise ValueError("No reference images were found in the COCO 2014 caption metadata.")
-
- images_dir.mkdir(parents=True, exist_ok=True)
- downloaded = 0
- for index, row in enumerate(rows):
- image_path = images_dir / cls._image_filename(row, index)
- if not force and image_path.exists() and image_path.stat().st_size > 0:
- downloaded += 1
- continue
- cls._download_file(row["image"], image_path, timeout=timeout, retries=retries)
- downloaded += 1
- if verbose and downloaded % 100 == 0:
- print(f"Downloaded {downloaded}/{len(rows)} FID reference images to {images_dir}")
-
- manifest = {
- "name": cls.DEFAULT_REFERENCE_NAME,
- "dataset_id": cls.DEFAULT_REFERENCE_DATASET_ID,
- "split": cls.DEFAULT_REFERENCE_SPLIT,
- "metadata_url": metadata_url,
- "max_images": max_images,
- "image_count": len(rows),
- "images_dir": str(images_dir),
- }
- with open(manifest_path, "w", encoding="utf-8") as file:
- json.dump(manifest, file, indent=2, ensure_ascii=False)
- return str(images_dir)
-
- @classmethod
- def default_reference_dir(
- cls,
- local_dir: Union[str, Path] = None,
- max_images: int = None,
- download: bool = True,
- **download_kwargs,
- ):
- root = Path(local_dir) if local_dir is not None else cls.default_reference_root()
- images_dir = root / "images"
- existing = cls._image_files(images_dir)
- if existing:
- if max_images is not None and len(existing) >= max_images:
- return str(images_dir)
- manifest_path = root / "reference_manifest.json"
- if max_images is None and manifest_path.exists():
- with open(manifest_path, "r", encoding="utf-8") as file:
- manifest = json.load(file)
- if manifest.get("max_images") is None and len(existing) >= manifest.get("image_count", 0):
- return str(images_dir)
- if not download:
- raise FileNotFoundError(
- f"FID reference directory does not exist: {images_dir}. "
- "Call FIDMetric.download_reference_dir(...) first or pass your own reference directory."
- )
- return cls.download_reference_dir(local_dir=root, max_images=max_images, **download_kwargs)
-
@classmethod
def from_pretrained(
cls,
- weights_config: Union[ModelConfig, str] = DEFAULT_WEIGHTS_ID,
- pretrained: bool = True,
- device: Union[str, torch.device] = get_device_type(),
- batch_size: int = 50,
+ model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors"),
+ device: torch.device = get_device_type(),
+ batch_size: int = 16,
num_workers: int = 0,
- use_fid_inception: bool = True,
+ vram_limit: float = None,
):
- weights_config = cls.resolve_model_config(
- weights_config,
- origin_file_pattern=cls.DEFAULT_WEIGHTS_FILE if weights_config == cls.DEFAULT_WEIGHTS_ID else "",
- ) if weights_config is not None and use_fid_inception else None
- model = FIDModel.from_pretrained(
- weights_path=None if weights_config is None else weights_config.path,
- pretrained=pretrained,
- device=device,
- batch_size=batch_size,
- num_workers=num_workers,
- use_fid_inception=use_fid_inception,
- )
+ model_pool = cls.download_and_load_models([model_config], torch_dtype=torch.float32, device=device, vram_limit=vram_limit)
+ model = model_pool.fetch_model("image_metrics_fid_inception")
+ model = FIDModel(model=model, device=device, batch_size=batch_size, num_workers=num_workers)
return cls(model)
@torch.no_grad()
def compute(self, reference_images, generated_images, batch_size: int = None, num_workers: int = None):
score = self.model.compute(reference_images, generated_images, batch_size=batch_size, num_workers=num_workers)
- return self.tensor_to_float(score)
-
- @torch.no_grad()
- def compute_with_default_reference(
- self,
- generated_images,
- reference_dir: Union[str, Path] = None,
- max_reference_images: int = None,
- batch_size: int = None,
- num_workers: int = None,
- download_kwargs: dict = None,
- ):
- reference_dir = self.default_reference_dir(
- local_dir=reference_dir,
- max_images=max_reference_images,
- **({} if download_kwargs is None else download_kwargs),
- )
- return self.compute(reference_dir, generated_images, batch_size=batch_size, num_workers=num_workers)
+ return score.detach().cpu().item() if torch.is_tensor(score) else float(score)
def statistics(self, images, batch_size: int = None, num_workers: int = None):
return self.model.statistics(images, batch_size=batch_size, num_workers=num_workers)
diff --git a/diffsynth/metrics/hpsv2.py b/diffsynth/metrics/hpsv2.py
index f9c6258e3..309bb8d3a 100644
--- a/diffsynth/metrics/hpsv2.py
+++ b/diffsynth/metrics/hpsv2.py
@@ -1,5 +1,5 @@
-from typing import Union
import torch
+from transformers import AutoProcessor
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.hpsv2 import HPSv2Model
@@ -13,34 +13,29 @@ def __init__(self, model: HPSv2Model):
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = "AI-ModelScope/HPSv2",
- processor_config: Union[ModelConfig, str] = "AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K",
- version: str = "v2.0",
+ model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors"),
+ processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/"),
torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = get_device_type(),
- model_kwargs: dict = None,
+ device: torch.device = get_device_type(),
processor_kwargs: dict = None,
+ vram_limit: float = None,
):
- model_config = cls.resolve_model_config(model_config)
- processor_config = cls.resolve_model_config(processor_config)
- model = HPSv2Model.from_pretrained(
- model_path=model_config.path,
- processor_path=processor_config.path,
- version=version,
- torch_dtype=torch_dtype,
- device=device,
- model_kwargs=model_kwargs,
- processor_kwargs=processor_kwargs,
- )
+
+ processor_kwargs = processor_kwargs or {}
+ model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
+ model = model_pool.fetch_model("image_metrics_hpsv2")
+ processor_config.download_if_necessary()
+ processor = AutoProcessor.from_pretrained(processor_config.path, **processor_kwargs)
+ model = HPSv2Model(model=model, processor=processor).eval()
return cls(model)
@torch.no_grad()
- def score(self, prompt: Union[str, list[str]], images):
+ def score(self, prompt: str | list[str], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)
- def compute(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)
- def forward(self, prompt: Union[str, list[str]], images):
+ def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
diff --git a/diffsynth/metrics/hpsv3.py b/diffsynth/metrics/hpsv3.py
index 9ec2cb0c3..83b93f45b 100644
--- a/diffsynth/metrics/hpsv3.py
+++ b/diffsynth/metrics/hpsv3.py
@@ -1,4 +1,4 @@
-from typing import Union
+from transformers import AutoProcessor
import torch
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
@@ -14,42 +14,50 @@ def __init__(self, model: HPSv3Model):
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = "MizzenAI/HPSv3",
- base_model_config: Union[ModelConfig, str] = "Qwen/Qwen2-VL-7B-Instruct",
+ model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors"),
+ processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/"),
torch_dtype: torch.dtype = torch.bfloat16,
- device: Union[str, torch.device] = get_device_type(),
- output_dim: int = 2,
+ device: torch.device = get_device_type(),
score_index: int = 0,
use_special_tokens: bool = True,
max_pixels: int = 256 * 28 * 28,
min_pixels: int = 256 * 28 * 28,
- model_kwargs: dict = None,
processor_kwargs: dict = None,
+ vram_limit: float = None,
):
- model_config = cls.resolve_model_config(model_config)
- base_model_config = cls.resolve_model_config(base_model_config)
- model = HPSv3Model.from_pretrained(
- model_path=model_config.path,
- base_model_path=base_model_config.path,
- torch_dtype=torch_dtype,
- device=device,
- output_dim=output_dim,
- score_index=score_index,
+
+ processor_kwargs = processor_kwargs or {}
+ model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
+ model = model_pool.fetch_model("image_metrics_hpsv3")
+ processor_config.download_if_necessary()
+ processor = AutoProcessor.from_pretrained(processor_config.path, padding_side="right", **processor_kwargs)
+ if use_special_tokens:
+ special_tokens = ["<|Reward|>"]
+ processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
+ model.special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
+ model.reward_token = "special"
+ model.config.tokenizer_padding_side = processor.tokenizer.padding_side
+ model.config.pad_token_id = processor.tokenizer.pad_token_id
+ if hasattr(model.config, "text_config"):
+ model.config.text_config.pad_token_id = processor.tokenizer.pad_token_id
+ model.rm_head.to(torch.float32)
+ model = HPSv3Model(
+ model=model,
+ processor=processor,
use_special_tokens=use_special_tokens,
max_pixels=max_pixels,
min_pixels=min_pixels,
- model_kwargs=model_kwargs,
- processor_kwargs=processor_kwargs,
- )
+ score_index=score_index,
+ ).eval()
return cls(model)
@torch.no_grad()
- def score(self, prompt: Union[str, list[str]], images):
+ def score(self, prompt: str | list[str], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)
- def compute(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)
- def forward(self, prompt: Union[str, list[str]], images):
+ def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
diff --git a/diffsynth/metrics/image_reward.py b/diffsynth/metrics/image_reward.py
index feb156058..85a042e8b 100644
--- a/diffsynth/metrics/image_reward.py
+++ b/diffsynth/metrics/image_reward.py
@@ -1,95 +1,48 @@
-import os
-from pathlib import Path
-from typing import Union
-
import torch
-from modelscope import snapshot_download
-
+from transformers import BertTokenizer
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
from ..models.image_reward import ImageRewardModel
from .base import Metric
class ImageRewardMetric(Metric):
- BERT_TOKENIZER_MODEL_ID = "AI-ModelScope/bert-base-uncased"
- BERT_TOKENIZER_FILES = [
- "config.json",
- "tokenizer_config.json",
- "tokenizer.json",
- "vocab.txt",
- ]
-
def __init__(self, model: ImageRewardModel):
super().__init__()
self.model = model
- @staticmethod
- def default_tokenizer_config():
- local_path = Path(os.environ.get("DIFFSYNTH_MODEL_BASE_PATH", "./models")) / ImageRewardMetric.BERT_TOKENIZER_MODEL_ID
- if all((local_path / filename).exists() for filename in ImageRewardMetric.BERT_TOKENIZER_FILES):
- return ModelConfig(path=str(local_path))
- return ModelConfig(path=str(ImageRewardMetric.download_default_tokenizer()))
-
- @staticmethod
- def download_default_tokenizer():
- local_path = Path(os.environ.get("DIFFSYNTH_MODEL_BASE_PATH", "./models")) / ImageRewardMetric.BERT_TOKENIZER_MODEL_ID
- if all((local_path / filename).exists() for filename in ImageRewardMetric.BERT_TOKENIZER_FILES):
- return local_path
- local_path.mkdir(parents=True, exist_ok=True)
- snapshot_download(
- ImageRewardMetric.BERT_TOKENIZER_MODEL_ID,
- local_dir=str(local_path),
- allow_file_pattern=ImageRewardMetric.BERT_TOKENIZER_FILES,
- local_files_only=False,
- )
- missing = [filename for filename in ImageRewardMetric.BERT_TOKENIZER_FILES if not (local_path / filename).exists()]
- if missing:
- raise FileNotFoundError(f"Missing ImageReward tokenizer files under {local_path}: {missing}")
- return local_path
-
- @staticmethod
- def _as_directory_path(path):
- if isinstance(path, list):
- if len(path) == 0:
- raise FileNotFoundError("Downloaded tokenizer files are empty.")
- return str(Path(path[0]).parent)
- return path
-
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = "ZhipuAI/ImageReward",
- med_config: Union[ModelConfig, str] = None,
- tokenizer_config: Union[ModelConfig, str] = None,
+ model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors"),
+ tokenizer_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/"),
torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = get_device_type(),
+ device: torch.device = get_device_type(),
max_length: int = 35,
- model_kwargs: dict = None,
tokenizer_kwargs: dict = None,
+ vram_limit: float = None,
):
- tokenizer_config = cls.default_tokenizer_config() if tokenizer_config is None else tokenizer_config
- model_config = cls.resolve_model_config(model_config)
- med_config = cls.resolve_model_config(med_config) if med_config is not None else None
- tokenizer_config = cls.resolve_model_config(tokenizer_config)
- model = ImageRewardModel.from_pretrained(
- model_path=model_config.path,
- med_config_path=None if med_config is None else med_config.path,
- tokenizer_path=cls._as_directory_path(tokenizer_config.path),
- torch_dtype=torch_dtype,
- device=device,
- max_length=max_length,
- model_kwargs=model_kwargs,
- tokenizer_kwargs=tokenizer_kwargs,
- )
+
+ tokenizer_kwargs = tokenizer_kwargs or {}
+ model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
+ model = model_pool.fetch_model("image_metrics_image_reward")
+ tokenizer_config.download_if_necessary()
+ tokenizer = BertTokenizer.from_pretrained(tokenizer_config.path, **tokenizer_kwargs)
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+ tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
+ tokenizer.enc_token_id = tokenizer.convert_tokens_to_ids("[ENC]")
+ model.tokenizer = tokenizer
+ model.max_length = max_length
+ model.mlp = model.mlp.float()
+ model = model.eval()
return cls(model)
@torch.no_grad()
- def score(self, prompt: Union[str, list[str]], images):
+ def score(self, prompt: str | list[str], images):
scores = self.model(prompt, images)
return self.tensor_to_list(scores)
- def compute(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)
- def forward(self, prompt: Union[str, list[str]], images):
+ def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
diff --git a/diffsynth/metrics/pickscore.py b/diffsynth/metrics/pickscore.py
index f3d687301..608e54484 100644
--- a/diffsynth/metrics/pickscore.py
+++ b/diffsynth/metrics/pickscore.py
@@ -1,4 +1,4 @@
-from typing import Union
+from transformers import AutoProcessor
import torch
from ..core import ModelConfig
from ..core.device.npu_compatible_device import get_device_type
@@ -13,31 +13,27 @@ def __init__(self, model: PickScoreModel):
@classmethod
def from_pretrained(
cls,
- model_config: Union[ModelConfig, str] = "AI-ModelScope/PickScore_v1",
- processor_config: Union[ModelConfig, str] = "AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K",
+ model_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+ processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/"),
torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = get_device_type(),
+ device: torch.device = get_device_type(),
max_length: int = 77,
- model_kwargs: dict = None,
processor_kwargs: dict = None,
+ vram_limit: float = None,
):
- model_config = cls.resolve_model_config(model_config)
- processor_config = cls.resolve_model_config(processor_config)
- model = PickScoreModel.from_pretrained(
- model_path=model_config.path,
- processor_path=processor_config.path,
- torch_dtype=torch_dtype,
- device=device,
- max_length=max_length,
- model_kwargs=model_kwargs,
- processor_kwargs=processor_kwargs,
- )
+
+ processor_kwargs = processor_kwargs or {}
+ model_pool = cls.download_and_load_models([model_config], torch_dtype=torch_dtype, device=device, vram_limit=vram_limit)
+ model = model_pool.fetch_model("image_metrics_clip_hf")
+ processor_config.download_if_necessary()
+ processor = AutoProcessor.from_pretrained(processor_config.path, **processor_kwargs)
+ model = PickScoreModel(model=model, processor=processor, max_length=max_length).eval()
return cls(model)
@torch.no_grad()
def score(
self,
- prompt: Union[str, list[str]],
+ prompt: str | list[str],
images,
):
scores = self.model(prompt, images)
@@ -46,18 +42,18 @@ def score(
@torch.no_grad()
def probabilities(
self,
- prompt: Union[str, list[str]],
+ prompt: str | list[str],
images,
):
scores = self.model(prompt, images)
probabilities = torch.softmax(scores, dim=-1)
return self.tensor_to_list(probabilities)
- def calc_probs(self, prompt: Union[str, list[str]], images):
+ def calc_probs(self, prompt: str | list[str], images):
return self.probabilities(prompt, images)
- def compute(self, prompt: Union[str, list[str]], images):
+ def compute(self, prompt: str | list[str], images):
return self.score(prompt, images)
- def forward(self, prompt: Union[str, list[str]], images):
+ def forward(self, prompt: str | list[str], images):
return self.score(prompt, images)
diff --git a/diffsynth/models/aesthetic.py b/diffsynth/models/aesthetic.py
index a9cc20124..b1ed41462 100644
--- a/diffsynth/models/aesthetic.py
+++ b/diffsynth/models/aesthetic.py
@@ -1,211 +1,90 @@
-from pathlib import Path
from typing import Union
-import json
import torch
-import torch.nn as nn
from PIL import Image
-from .clip import CLIPModel
ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
-class AestheticMLP(nn.Module):
+class AestheticMLP(torch.nn.Module):
def __init__(self, input_size: int):
super().__init__()
self.input_size = input_size
- self.layers = nn.Sequential(
- nn.Linear(input_size, 1024),
- nn.Dropout(0.2),
- nn.Linear(1024, 128),
- nn.Dropout(0.2),
- nn.Linear(128, 64),
- nn.Dropout(0.1),
- nn.Linear(64, 16),
- nn.Linear(16, 1),
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(input_size, 1024),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ torch.nn.Linear(16, 1),
)
def forward(self, x):
return self.layers(x)
-def _as_image_list(images: Union[ImageInput, list[ImageInput], tuple[ImageInput, ...]]):
+def _as_image_list(images: ImageInput):
if isinstance(images, Image.Image):
images = [images]
return [image.convert("RGB") for image in images]
class AestheticModel(torch.nn.Module):
- def __init__(self, mlp: AestheticMLP, clip_model: CLIPModel = None, vision_model: torch.nn.Module = None, processor=None):
+ def __init__(
+ self,
+ mlp: AestheticMLP = None,
+ vision_model: torch.nn.Module = None,
+ visual_projection: torch.nn.Module = None,
+ processor=None
+ ):
super().__init__()
- self.clip_model = clip_model
+ if vision_model is None:
+ vision_model, visual_projection = self.default_vision_model()
+ if mlp is None:
+ mlp = AestheticMLP(768)
+
self.vision_model = vision_model
+ self.visual_projection = visual_projection
self.processor = processor
- self.mlp = mlp
-
- @classmethod
- def from_pretrained(
- cls,
- model_path: str,
- clip_model_path: str = None,
- clip_processor_path: str = None,
- torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = "cpu",
- clip_kwargs: dict = None,
- processor_kwargs: dict = None,
- ):
- checkpoint = cls._load_checkpoint(model_path)
- model = cls._from_full_predictor(
- model_path=model_path,
- checkpoint=checkpoint,
- torch_dtype=torch_dtype,
- device=device,
- processor_kwargs=processor_kwargs,
- )
- return model
-
- @classmethod
- def _from_full_predictor(
- cls,
- model_path: str,
- checkpoint: dict,
- torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = "cuda",
- processor_kwargs: dict = None,
- ):
- from transformers import AutoProcessor, CLIPVisionModelWithProjection
-
- processor_kwargs = {} if processor_kwargs is None else processor_kwargs
- config = cls._load_vision_config(model_path)
- vision_model = CLIPVisionModelWithProjection(config)
- mlp = AestheticMLP(config.projection_dim)
- normalized = cls._normalize_checkpoint_keys(checkpoint)
- vision_state = {}
- mlp_state = {}
- for key, value in normalized.items():
- if key.startswith("layers."):
- mlp_state[key] = value
- elif key in vision_model.state_dict():
- vision_state[key] = value
- if not vision_state:
- raise ValueError(f"Cannot find CLIP vision tower weights in Aesthetic checkpoint under {model_path}.")
- vision_model.load_state_dict(vision_state, strict=False)
- mlp.load_state_dict(mlp_state, strict=True)
- processor = AutoProcessor.from_pretrained(model_path, **processor_kwargs)
- if torch_dtype is not None:
- vision_model = vision_model.to(dtype=torch_dtype)
- vision_model = vision_model.to(device).eval()
- mlp = mlp.to(device).float().eval()
- return cls(vision_model=vision_model, processor=processor, mlp=mlp).eval()
+ self.layers = mlp.layers
@staticmethod
- def _load_vision_config(model_path):
- from transformers import CLIPVisionConfig
-
- config_path = Path(model_path) / "config.json"
- if not config_path.exists():
- raise FileNotFoundError(f"Cannot find Aesthetic config.json under {model_path}.")
- with open(config_path, "r", encoding="utf-8") as f:
- data = json.load(f)
- config_data = data.get("vision_config", data)
- if "projection_dim" not in config_data and "projection_dim" in data:
- config_data = dict(config_data)
- config_data["projection_dim"] = data["projection_dim"]
- allowed = {
- "attention_dropout",
- "dropout",
- "hidden_act",
- "hidden_size",
- "image_size",
- "initializer_factor",
- "initializer_range",
- "intermediate_size",
- "layer_norm_eps",
- "num_attention_heads",
- "num_channels",
- "num_hidden_layers",
- "patch_size",
- "projection_dim",
- }
- return CLIPVisionConfig(**{key: value for key, value in config_data.items() if key in allowed})
-
- @staticmethod
- def _find_checkpoint(path):
- path = Path(path)
- if path.is_file():
- return path
- names = [
- "model.safetensors",
- "pytorch_model.bin",
- "sac+logos+ava1-l14-linearMSE.pth",
- "ava+logos-l14-linearMSE.pth",
- "*.pth",
- "*.pt",
- "*.bin",
- "*.safetensors",
- ]
- for name in names:
- candidate = path / name
- if candidate.exists():
- return candidate
- matches = sorted(path.rglob(name))
- if matches:
- return matches[0]
- raise FileNotFoundError(f"Cannot find an Aesthetic MLP checkpoint under {path}.")
-
- @classmethod
- def _load_checkpoint(cls, path):
- checkpoint_path = cls._find_checkpoint(path)
- if checkpoint_path.suffix == ".safetensors":
- import safetensors.torch
-
- checkpoint = safetensors.torch.load_file(str(checkpoint_path), device="cpu")
- else:
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
- if isinstance(checkpoint, dict):
- for key in ("state_dict", "model"):
- if key in checkpoint and isinstance(checkpoint[key], dict):
- checkpoint = checkpoint[key]
- break
- return checkpoint
+    def default_vision_model():
+        """Build the fallback CLIP vision tower and its projection layer.
+
+        Both modules are freshly constructed from a hard-coded configuration;
+        no weights are loaded here.
+
+        Returns:
+            A ``(vision_model, visual_projection)`` tuple, where the projection
+            is a bias-free linear layer mapping ``hidden_size`` features to
+            ``projection_dim`` features.
+        """
+        from transformers import CLIPVisionConfig, CLIPVisionModel
+
+        vision_settings = {
+            "hidden_size": 1024,
+            "intermediate_size": 4096,
+            "num_attention_heads": 16,
+            "num_hidden_layers": 24,
+            "image_size": 224,
+            "patch_size": 14,
+            "hidden_act": "quick_gelu",
+            "layer_norm_eps": 1e-5,
+            "projection_dim": 768,
+        }
+        vision_config = CLIPVisionConfig(**vision_settings)
+        # Build the tower before the projection, matching the original order.
+        tower = CLIPVisionModel(vision_config)
+        projection = torch.nn.Linear(
+            vision_config.hidden_size,
+            vision_config.projection_dim,
+            bias=False,
+        )
+        return tower, projection
- @staticmethod
- def _normalize_checkpoint_keys(checkpoint):
- normalized = {}
- for key, value in checkpoint.items():
- for prefix in ("model.", "module.", "aesthetic_model.", "aesthetics_predictor.", "predictor."):
- if key.startswith(prefix):
- key = key[len(prefix) :]
- normalized[key] = value
- return normalized
-
@property
def device(self):
- if self.clip_model is not None:
- return self.clip_model.device
- try:
- return next(self.vision_model.parameters()).device
- except StopIteration:
- return torch.device("cpu")
+ return next(self.parameters(), torch.tensor([])).device
@property
def dtype(self):
- if self.clip_model is not None:
- return self.clip_model.dtype
- try:
- return next(self.vision_model.parameters()).dtype
- except StopIteration:
- return torch.float32
+ return next(self.parameters(), torch.tensor(0.0)).dtype
@torch.no_grad()
- def get_image_features(self, images):
- if self.clip_model is not None:
- return self.clip_model.get_image_features(images)
+ def get_image_features(self, images):
images = _as_image_list(images)
inputs = self.processor(images=images, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.dtype)
- image_features = self.vision_model(pixel_values=pixel_values, return_dict=True).image_embeds
- return image_features / image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+
+ image_features = self.vision_model(pixel_values=pixel_values, return_dict=True).pooler_output
+ image_features = self.visual_projection(image_features)
+
+ return torch.nn.functional.normalize(image_features, dim=-1)
@torch.no_grad()
def forward(self, images):
- image_features = self.get_image_features(images).float()
- return self.mlp(image_features).squeeze(-1)
+ image_features = self.get_image_features(images)
+ return self.layers(image_features).squeeze(-1)
\ No newline at end of file
diff --git a/diffsynth/models/clip.py b/diffsynth/models/clip.py
index 04f8db41e..723107d8c 100644
--- a/diffsynth/models/clip.py
+++ b/diffsynth/models/clip.py
@@ -1,6 +1,7 @@
from typing import Union
import torch
from PIL import Image
+from transformers import CLIPModel as HFCLIPModel
ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
@@ -18,6 +19,45 @@ def _feature_tensor(output, feature_name: str):
raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
+class ImageMetricsCLIPModel(HFCLIPModel):
+    """Hugging Face CLIP model built from a configuration defined in code.
+
+    Only the ``"h14"`` variant is defined. The configuration is constructed
+    here rather than read from a ``config.json`` file.
+    """
+
+    def __init__(self, variant: str = "h14"):
+        super().__init__(self.config(variant))
+
+    @staticmethod
+    def config(variant: str):
+        """Return the ``CLIPConfig`` for ``variant``.
+
+        Args:
+            variant: Name of the architecture variant; only ``"h14"`` exists.
+
+        Raises:
+            ValueError: If ``variant`` is not a supported variant.
+        """
+        from transformers import CLIPConfig
+
+        # Validate first: the check used to sit after an unconditional
+        # ``return`` and therefore never ran, so any variant was accepted.
+        if variant != "h14":
+            raise ValueError(f"Unsupported ImageMetrics CLIP variant: {variant}")
+
+        # NOTE(review): ``hidden_act`` is not stored in the weights. Confirm
+        # "quick_gelu" against the source checkpoint's own config.json, since
+        # some H/14 CLIP exports declare "gelu" instead.
+        return CLIPConfig(
+            projection_dim=1024,
+            logit_scale_init_value=2.6592,
+            text_config={
+                "hidden_size": 1024,
+                "intermediate_size": 4096,
+                "num_attention_heads": 16,
+                "num_hidden_layers": 24,
+                "max_position_embeddings": 77,
+                "vocab_size": 49408,
+                "hidden_act": "quick_gelu",
+                "layer_norm_eps": 1e-5,
+                "projection_dim": 1024,
+                "bos_token_id": 0,
+                "eos_token_id": 2,
+                "pad_token_id": 1,
+            },
+            vision_config={
+                "hidden_size": 1280,
+                "intermediate_size": 5120,
+                "num_attention_heads": 16,
+                "num_hidden_layers": 32,
+                "image_size": 224,
+                "patch_size": 14,
+                "hidden_act": "quick_gelu",
+                "layer_norm_eps": 1e-5,
+                "projection_dim": 1024,
+            },
+        )
+
+
class CLIPModel(torch.nn.Module):
def __init__(self, model: torch.nn.Module, processor, max_length: int = 77):
super().__init__()
@@ -25,55 +65,29 @@ def __init__(self, model: torch.nn.Module, processor, max_length: int = 77):
self.processor = processor
self.max_length = max_length
- @classmethod
- def from_pretrained(
- cls,
- model_path: str,
- processor_path: str = None,
- torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = "cuda",
- max_length: int = 77,
- model_kwargs: dict = None,
- processor_kwargs: dict = None,
- ):
- from modelscope import AutoModel, AutoProcessor
-
- model_kwargs = {} if model_kwargs is None else model_kwargs
- processor_kwargs = {} if processor_kwargs is None else processor_kwargs
- processor_path = model_path if processor_path is None else processor_path
- processor = AutoProcessor.from_pretrained(processor_path, **processor_kwargs)
- model = AutoModel.from_pretrained(model_path, **model_kwargs).eval()
- if torch_dtype is not None:
- model = model.to(dtype=torch_dtype)
- model = model.to(device)
- return cls(model=model, processor=processor, max_length=max_length)
-
@property
def device(self):
- try:
- return next(self.model.parameters()).device
- except StopIteration:
- return torch.device("cpu")
+ return next(self.parameters(), torch.tensor([])).device
@property
def dtype(self):
- try:
- return next(self.model.parameters()).dtype
- except StopIteration:
- return torch.float32
+ return next(self.parameters(), torch.tensor(0.0)).dtype
def _normalize_pairs(self, text, images):
if isinstance(text, str):
text = [text]
else:
text = list(text)
+
if isinstance(images, Image.Image):
images = [images]
images = [image.convert("RGB") for image in images]
+
if len(text) == 1 and len(images) > 1:
text = text * len(images)
if len(images) == 1 and len(text) > 1:
images = images * len(text)
+
if len(text) != len(images):
raise ValueError(f"Expected the same number of prompts and images, got {len(text)} and {len(images)}.")
return text, images
@@ -86,6 +100,7 @@ def _processor_call(self, **kwargs):
return_tensors="pt",
**kwargs,
).to(self.device)
+
if self.dtype != torch.float32:
inputs = {
name: (
@@ -102,22 +117,24 @@ def get_image_features(self, images: ImageInput):
if isinstance(images, Image.Image):
images = [images]
images = [image.convert("RGB") for image in images]
+
image_inputs = self._processor_call(images=images)
image_features = _feature_tensor(self.model.get_image_features(**image_inputs), "image_features")
- image_features = image_features / image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
- return image_features
+
+ return torch.nn.functional.normalize(image_features, dim=-1)
@torch.no_grad()
def get_text_features(self, text: Union[str, list[str]]):
text_inputs = self._processor_call(text=text)
text_features = _feature_tensor(self.model.get_text_features(**text_inputs), "text_features")
- text_features = text_features / text_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
- return text_features
+
+ return torch.nn.functional.normalize(text_features, dim=-1)
@torch.no_grad()
def similarity_matrix(self, text: Union[str, list[str]], images: ImageInput):
image_features = self.get_image_features(images)
text_features = self.get_text_features(text)
+
scores = text_features @ image_features.T
if hasattr(self.model, "logit_scale"):
scores = self.model.logit_scale.exp() * scores
@@ -126,9 +143,11 @@ def similarity_matrix(self, text: Union[str, list[str]], images: ImageInput):
@torch.no_grad()
def forward(self, text: Union[str, list[str]], images: ImageInput):
text, images = self._normalize_pairs(text, images)
+
image_features = self.get_image_features(images)
text_features = self.get_text_features(text)
+
scores = (text_features * image_features).sum(dim=-1)
if hasattr(self.model, "logit_scale"):
scores = self.model.logit_scale.exp() * scores
- return scores
+ return scores
\ No newline at end of file
diff --git a/diffsynth/models/fid.py b/diffsynth/models/fid.py
index 31b87b3d4..af3671975 100644
--- a/diffsynth/models/fid.py
+++ b/diffsynth/models/fid.py
@@ -1,50 +1,32 @@
-from pathlib import Path
+import os
from typing import Iterable, Union
-import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
-from torchvision.models import Inception_V3_Weights, inception_v3
+from torchvision.models import inception_v3
from torchvision.models.inception import InceptionA, InceptionC, InceptionE
-ImageInput = Union[str, Path, Image.Image]
+ImageInput = Union[str, os.PathLike, Image.Image]
IMAGE_EXTENSIONS = {".bmp", ".jpg", ".jpeg", ".pgm", ".png", ".ppm", ".tif", ".tiff", ".webp"}
-def _resolve_device(device: Union[str, torch.device, None]):
- if device is None:
- device = "cuda" if _is_cuda_usable("cuda", warn=False) else "cpu"
- device = torch.device(device)
- if device.type == "cuda" and not _is_cuda_usable(device, warn=True):
- return torch.device("cpu")
- return device
-
-
-def _is_cuda_usable(device: Union[str, torch.device], warn: bool = True):
- try:
- if not torch.cuda.is_available():
- if warn:
- warnings.warn("CUDA was requested but torch.cuda.is_available() is False. FID will run on CPU instead.", RuntimeWarning)
- return False
- torch.empty(1, device=device)
- return True
- except Exception as error:
- if warn:
- warnings.warn(f"CUDA was requested but cannot be initialized ({error}). FID will run on CPU instead.", RuntimeWarning)
- return False
-
-
-def _image_files(path: Union[str, Path]):
- path = Path(path)
- if path.is_file():
- if path.suffix.lower() not in IMAGE_EXTENSIONS:
+
+def _image_files(path: Union[str, os.PathLike]):
+ path = os.fspath(path)
+ if os.path.isfile(path):
+ if os.path.splitext(path)[1].lower() not in IMAGE_EXTENSIONS:
raise ValueError(f"Unsupported image extension for FID: {path}")
return [path]
- if not path.exists():
+ if not os.path.exists(path):
raise FileNotFoundError(f"FID path does not exist: {path}")
- files = [item for item in sorted(path.rglob("*")) if item.is_file() and item.suffix.lower() in IMAGE_EXTENSIONS]
+ files = []
+ for root, dirs, names in os.walk(path):
+ dirs.sort()
+ for name in sorted(names):
+ if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS:
+ files.append(os.path.join(root, name))
if not files:
raise ValueError(f"No images found under {path}.")
return files
@@ -60,48 +42,24 @@ def __len__(self):
def __getitem__(self, index):
image = self.images[index]
- if isinstance(image, (str, Path)):
+ if isinstance(image, (str, os.PathLike)):
image = Image.open(image)
if not isinstance(image, Image.Image):
raise TypeError(f"FID expects PIL images or image paths, but received {type(image)}.")
return self.transform(image.convert("RGB"))
-class _InceptionFeatures(nn.Module):
- def __init__(self, weights_path: str = None, pretrained: bool = True, use_fid_inception: bool = True):
+class FIDInceptionModel(nn.Module):
+ def __init__(self):
super().__init__()
- if use_fid_inception and weights_path is not None:
- model = _fid_inception_v3(weights_path)
- self.normalize_input = "fid"
- elif use_fid_inception and weights_path is None:
- warnings.warn(
- "FID-specific Inception weights were not provided. Falling back to torchvision Inception weights; "
- "scores are useful for relative comparisons but are not directly comparable to standard pytorch-fid values.",
- RuntimeWarning,
- )
- weights = Inception_V3_Weights.DEFAULT if pretrained else None
- model = inception_v3(weights=weights, aux_logits=True, init_weights=False)
- model.fc = nn.Identity()
- self.normalize_input = "imagenet" if pretrained else None
- else:
- weights = Inception_V3_Weights.DEFAULT if pretrained else None
- model = inception_v3(weights=weights, aux_logits=True, init_weights=False)
- model.fc = nn.Identity()
- self.normalize_input = "imagenet" if pretrained else None
- model.eval()
- self.model = model
+ self.model = _fid_inception_v3()
def forward(self, images):
- if self.normalize_input == "fid":
- images = 2 * images - 1
- elif self.normalize_input == "imagenet":
- mean = images.new_tensor((0.485, 0.456, 0.406)).view(1, 3, 1, 1)
- std = images.new_tensor((0.229, 0.224, 0.225)).view(1, 3, 1, 1)
- images = (images - mean) / std
+ images = 2 * images - 1
return self.model(images)
-def _fid_inception_v3(weights_path: str):
+def _fid_inception_v3(weights_path: str = None):
model = inception_v3(weights=None, aux_logits=False, num_classes=1008, init_weights=False)
model.Mixed_5b = _FIDInceptionA(192, pool_features=32)
model.Mixed_5c = _FIDInceptionA(256, pool_features=64)
@@ -112,7 +70,8 @@ def _fid_inception_v3(weights_path: str):
model.Mixed_6e = _FIDInceptionC(768, channels_7x7=192)
model.Mixed_7b = _FIDInceptionE1(1280)
model.Mixed_7c = _FIDInceptionE2(2048)
- model.load_state_dict(torch.load(weights_path, map_location="cpu"))
+ if weights_path is not None:
+ model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.fc = nn.Identity()
return model
@@ -202,31 +161,6 @@ def __init__(self, model: torch.nn.Module, device: Union[str, torch.device] = "c
)
self.to(device)
- @classmethod
- def from_pretrained(
- cls,
- weights_path: str = None,
- pretrained: bool = True,
- device: Union[str, torch.device] = "cpu",
- batch_size: int = 50,
- num_workers: int = 0,
- use_fid_inception: bool = True,
- ):
- if isinstance(weights_path, (list, tuple)):
- if len(weights_path) == 1:
- weights_path = weights_path[0]
- elif len(weights_path) == 0:
- raise FileNotFoundError(
- "FID weights were not found. Please check the ModelScope model id and file pattern."
- )
- else:
- raise ValueError(
- f"FID expects a single weights file, but got {len(weights_path)} paths: {weights_path}"
- )
- device = _resolve_device(device)
- model = _InceptionFeatures(weights_path=weights_path, pretrained=pretrained, use_fid_inception=use_fid_inception).to(device).eval()
- return cls(model=model, device=device, batch_size=batch_size, num_workers=num_workers)
-
@property
def device(self):
try:
@@ -235,7 +169,7 @@ def device(self):
return torch.device("cpu")
def _as_images(self, images):
- if isinstance(images, (str, Path)):
+ if isinstance(images, (str, os.PathLike)):
return _image_files(images)
if isinstance(images, Image.Image):
return [images]
diff --git a/diffsynth/models/hpsv2.py b/diffsynth/models/hpsv2.py
index 4d6cb6086..c8cc4c924 100644
--- a/diffsynth/models/hpsv2.py
+++ b/diffsynth/models/hpsv2.py
@@ -1,8 +1,6 @@
-from pathlib import Path
from typing import Union
import torch
from PIL import Image
-from transformers import AutoConfig, AutoModel, AutoProcessor
ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
@@ -25,165 +23,29 @@ def _feature_tensor(output, feature_name: str):
raise TypeError(f"{feature_name} must be a tensor or a model output with projected features.")
-def _find_checkpoint(path, version):
- path = Path(path)
- version_to_file = {
- "v2.0": "HPS_v2_compressed.pt",
- "v2.1": "HPS_v2.1_compressed.pt",
- }
- if path.is_file():
- return path
- filename = version_to_file.get(version)
- names = [filename] if filename is not None else []
- names += ["*.pt", "*.pth", "*.bin", "*.safetensors"]
- for name in names:
- if name is None:
- continue
- candidate = path / name
- if candidate.exists():
- return candidate
- matches = sorted(path.rglob(name))
- if matches:
- return matches[0]
- return None
-
class HPSv2Model(torch.nn.Module):
def __init__(self, model: torch.nn.Module, processor):
super().__init__()
self.model = model
self.processor = processor
- @classmethod
- def from_pretrained(
- cls,
- model_path: str,
- processor_path: str,
- version: str = "v2.0",
- torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = "cpu",
- model_kwargs: dict = None,
- processor_kwargs: dict = None,
- ):
- model_kwargs = {} if model_kwargs is None else model_kwargs
- processor_kwargs = {} if processor_kwargs is None else processor_kwargs
- processor = AutoProcessor.from_pretrained(processor_path, **processor_kwargs)
- checkpoint_path = _find_checkpoint(model_path, version)
- config = AutoConfig.from_pretrained(processor_path)
- model = AutoModel.from_config(config, **model_kwargs)
- if checkpoint_path is None:
- raise FileNotFoundError(f"Cannot find an HPSv2 checkpoint under {model_path}.")
- state_dict = cls._load_checkpoint(checkpoint_path)
- state_dict = cls._prepare_state_dict(state_dict, model.state_dict())
- model.load_state_dict(state_dict, strict=False)
- if torch_dtype is not None:
- model = model.to(dtype=torch_dtype)
- model = model.to(device).eval()
- return cls(model=model, processor=processor)
-
- @staticmethod
- def _load_checkpoint(checkpoint_path):
- checkpoint_path = Path(checkpoint_path)
- state_dict = torch.load(checkpoint_path, map_location="cpu")
- if isinstance(state_dict, dict):
- for key in ("state_dict", "model"):
- if key in state_dict and isinstance(state_dict[key], dict):
- state_dict = state_dict[key]
- break
- return {key[len("module.") :] if key.startswith("module.") else key: value for key, value in state_dict.items()}
-
- @staticmethod
- def _prepare_state_dict(state_dict, target_state_dict):
- converted = {}
- for key, value in state_dict.items():
- updates = HPSv2Model._convert_open_clip_key(key, value)
- for new_key, new_value in updates:
- if new_key in target_state_dict and tuple(target_state_dict[new_key].shape) == tuple(new_value.shape):
- converted[new_key] = new_value
- return converted
-
- @staticmethod
- def _convert_open_clip_key(key, value):
- if key == "logit_scale":
- return [("logit_scale", value)]
- if key == "token_embedding.weight":
- return [("text_model.embeddings.token_embedding.weight", value)]
- if key == "positional_embedding":
- return [("text_model.embeddings.position_embedding.weight", value)]
- if key.startswith("ln_final."):
- return [("text_model.final_layer_norm." + key[len("ln_final.") :], value)]
- if key == "text_projection":
- return [("text_projection.weight", value.T)]
- if key == "visual.class_embedding":
- return [("vision_model.embeddings.class_embedding", value)]
- if key == "visual.conv1.weight":
- return [("vision_model.embeddings.patch_embedding.weight", value)]
- if key == "visual.positional_embedding":
- return [("vision_model.embeddings.position_embedding.weight", value)]
- if key.startswith("visual.ln_pre."):
- return [("vision_model.pre_layrnorm." + key[len("visual.ln_pre.") :], value)]
- if key.startswith("visual.ln_post."):
- return [("vision_model.post_layernorm." + key[len("visual.ln_post.") :], value)]
- if key == "visual.proj":
- return [("visual_projection.weight", value.T)]
- if key.startswith("transformer.resblocks."):
- return HPSv2Model._convert_resblock("text_model.encoder.layers", key[len("transformer.resblocks.") :], value)
- if key.startswith("visual.transformer.resblocks."):
- return HPSv2Model._convert_resblock("vision_model.encoder.layers", key[len("visual.transformer.resblocks.") :], value)
- return []
-
- @staticmethod
- def _convert_resblock(prefix, suffix, value):
- parts = suffix.split(".")
- layer = parts[0]
- rest = ".".join(parts[1:])
- layer_prefix = f"{prefix}.{layer}."
- if rest == "attn.in_proj_weight":
- q, k, v = value.chunk(3, dim=0)
- return [
- (layer_prefix + "self_attn.q_proj.weight", q),
- (layer_prefix + "self_attn.k_proj.weight", k),
- (layer_prefix + "self_attn.v_proj.weight", v),
- ]
- if rest == "attn.in_proj_bias":
- q, k, v = value.chunk(3, dim=0)
- return [
- (layer_prefix + "self_attn.q_proj.bias", q),
- (layer_prefix + "self_attn.k_proj.bias", k),
- (layer_prefix + "self_attn.v_proj.bias", v),
- ]
- mapping = {
- "attn.out_proj.": "self_attn.out_proj.",
- "ln_1.": "layer_norm1.",
- "ln_2.": "layer_norm2.",
- "mlp.c_fc.": "mlp.fc1.",
- "mlp.c_proj.": "mlp.fc2.",
- }
- for source, target in mapping.items():
- if rest.startswith(source):
- return [(layer_prefix + target + rest[len(source) :], value)]
- return []
-
@property
def device(self):
- try:
- return next(self.model.parameters()).device
- except StopIteration:
- return torch.device("cpu")
+ return next(self.parameters(), torch.tensor([])).device
@property
def dtype(self):
- try:
- return next(self.model.parameters()).dtype
- except StopIteration:
- return torch.float32
+ return next(self.parameters(), torch.tensor(0.0)).dtype
def _normalize_inputs(self, prompts, images):
images = _as_list(images)
prompts = _as_list(prompts)
+
if len(prompts) == 1 and len(images) > 1:
prompts = prompts * len(images)
if len(images) == 1 and len(prompts) > 1:
images = images * len(prompts)
+
if len(prompts) != len(images):
raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
return prompts, images
@@ -192,8 +54,15 @@ def _normalize_inputs(self, prompts, images):
def forward(self, prompts: Union[str, list[str]], images: ImageInput):
prompts, images = self._normalize_inputs(prompts, images)
images = [image.convert("RGB") for image in images]
- inputs = self.processor(text=prompts, images=images, padding=True, truncation=True, return_tensors="pt")
- inputs = inputs.to(self.device)
+
+ inputs = self.processor(
+ text=prompts,
+ images=images,
+ padding=True,
+ truncation=True,
+ return_tensors="pt"
+ ).to(self.device)
+
if self.dtype != torch.float32:
inputs = {
name: (
@@ -203,6 +72,7 @@ def forward(self, prompts: Union[str, list[str]], images: ImageInput):
)
for name, value in inputs.items()
}
+
image_features = _feature_tensor(
self.model.get_image_features(pixel_values=inputs["pixel_values"]),
"image_features",
@@ -211,9 +81,12 @@ def forward(self, prompts: Union[str, list[str]], images: ImageInput):
self.model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask")),
"text_features",
)
- image_features = image_features / image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
- text_features = text_features / text_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+
+ image_features = torch.nn.functional.normalize(image_features, dim=-1)
+ text_features = torch.nn.functional.normalize(text_features, dim=-1)
+
scores = (image_features * text_features).sum(dim=-1)
if hasattr(self.model, "logit_scale"):
scores = self.model.logit_scale.exp() * scores
- return scores
+
+ return scores
\ No newline at end of file
diff --git a/diffsynth/models/hpsv3.py b/diffsynth/models/hpsv3.py
index 0b7498850..b22966b6d 100644
--- a/diffsynth/models/hpsv3.py
+++ b/diffsynth/models/hpsv3.py
@@ -1,11 +1,8 @@
import math
-from pathlib import Path
from typing import Optional, Union
-
import torch
-import torch.nn as nn
from PIL import Image
-from transformers import AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration
ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
@@ -59,8 +56,10 @@ def _floor_by_factor(number, factor):
def _smart_resize(height, width, factor=28, min_pixels=256 * 28 * 28, max_pixels=256 * 28 * 28):
if max(height, width) / min(height, width) > 200:
raise ValueError(f"Image aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}.")
+
h_bar = max(factor, _round_by_factor(height, factor))
w_bar = max(factor, _round_by_factor(width, factor))
+
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = _floor_by_factor(height / beta, factor)
@@ -69,20 +68,9 @@ def _smart_resize(height, width, factor=28, min_pixels=256 * 28 * 28, max_pixels
beta = math.sqrt(min_pixels / (height * width))
h_bar = _ceil_by_factor(height * beta, factor)
w_bar = _ceil_by_factor(width * beta, factor)
+
return h_bar, w_bar
-def _find_checkpoint(path):
- path = Path(path)
- if path.is_file():
- return path
- for name in ("HPSv3.safetensors", "*.safetensors", "*.bin", "*.pt", "*.pth"):
- candidate = path / name
- if candidate.exists():
- return candidate
- matches = sorted(path.rglob(name))
- if matches:
- return matches[0]
- return None
class HPSv3RewardModelMixin:
def init_reward_head(
@@ -96,23 +84,26 @@ def init_reward_head(
self.output_dim = output_dim
self.reward_token = "special" if special_token_ids is not None else reward_token
self.special_token_ids = special_token_ids
+
hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "text_config"):
hidden_size = self.config.text_config.hidden_size
+
if rm_head_type == "ranknet":
- rm_head_kwargs = {} if rm_head_kwargs is None else rm_head_kwargs
+ rm_head_kwargs = rm_head_kwargs or {}
hidden = rm_head_kwargs.get("hidden_size", 1024)
dropout = rm_head_kwargs.get("dropout", 0.05)
- self.rm_head = nn.Sequential(
- nn.Linear(hidden_size, hidden),
- nn.ReLU(),
- nn.Dropout(dropout),
- nn.Linear(hidden, 16),
- nn.ReLU(),
- nn.Linear(16, output_dim),
+ self.rm_head = torch.nn.Sequential(
+ torch.nn.Linear(hidden_size, hidden),
+ torch.nn.ReLU(),
+ torch.nn.Dropout(dropout),
+ torch.nn.Linear(hidden, 16),
+ torch.nn.ReLU(),
+ torch.nn.Linear(16, output_dim),
)
else:
- self.rm_head = nn.Linear(hidden_size, output_dim, bias=False)
+ self.rm_head = torch.nn.Linear(hidden_size, output_dim, bias=False)
+
self.rm_head.to(torch.float32)
def forward(
@@ -132,12 +123,15 @@ def forward(
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
- mm_token_type_ids: Optional[torch.IntTensor] = None,
**kwargs,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ kwargs.pop("logits_to_keep", None)
+ mm_token_type_ids = kwargs.pop("mm_token_type_ids", None)
+
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -153,18 +147,25 @@ def forward(
mm_token_type_ids=mm_token_type_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
**kwargs,
)
+
hidden_states = outputs.last_hidden_state if hasattr(outputs, "last_hidden_state") else outputs[0]
logits = self.rm_head(hidden_states.to(next(self.rm_head.parameters()).dtype))
batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
- if self.config.pad_token_id is None and batch_size != 1:
+ pad_token_id = getattr(self.config, "pad_token_id", None)
+ if pad_token_id is None and hasattr(self.config, "text_config"):
+ pad_token_id = getattr(self.config.text_config, "pad_token_id", None)
+
+ if pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
- if self.config.pad_token_id is None:
+
+ if pad_token_id is None:
sequence_lengths = -1
elif input_ids is not None:
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
@@ -176,37 +177,91 @@ def forward(
valid_lengths = torch.clamp(sequence_lengths, min=0, max=logits.size(1) - 1)
pooled_logits = torch.stack([logits[i, : valid_lengths[i]].mean(dim=0) for i in range(batch_size)])
elif self.reward_token == "special":
+ if self.special_token_ids is None:
+ raise ValueError("HPSv3 reward_token='special' requires special_token_ids.")
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
for special_token_id in self.special_token_ids:
special_token_mask = special_token_mask | (input_ids == special_token_id)
pooled_logits = logits[special_token_mask, ...].view(batch_size, -1)
else:
raise ValueError(f"Invalid HPSv3 reward token mode: {self.reward_token}")
+
return {"logits": pooled_logits}
-def _create_reward_model_class():
- from transformers import Qwen2VLForConditionalGeneration
-
- class HPSv3Qwen2VLRewardModel(HPSv3RewardModelMixin, Qwen2VLForConditionalGeneration):
- def __init__(
- self,
- config,
- output_dim=2,
- reward_token="special",
- special_token_ids=None,
- rm_head_type="ranknet",
- rm_head_kwargs=None,
- ):
- super().__init__(config)
- self.init_reward_head(
- output_dim=output_dim,
- reward_token=reward_token,
- special_token_ids=special_token_ids,
- rm_head_type=rm_head_type,
- rm_head_kwargs=rm_head_kwargs,
- )
- return HPSv3Qwen2VLRewardModel
+class HPSv3Qwen2VLRewardModel(HPSv3RewardModelMixin, Qwen2VLForConditionalGeneration):
+ def __init__(
+ self,
+ config=None,
+ vocab_size=None,
+ output_dim=2,
+ reward_token="special",
+ special_token_ids=None,
+ rm_head_type="ranknet",
+ rm_head_kwargs=None,
+ ):
+ if config is None:
+ config = self.default_config(vocab_size or 151658)
+ elif vocab_size is not None and hasattr(config, "text_config"):
+ config.text_config.vocab_size = vocab_size
+
+ super().__init__(config)
+ self.init_reward_head(
+ output_dim=output_dim,
+ reward_token=reward_token,
+ special_token_ids=special_token_ids,
+ rm_head_type=rm_head_type,
+ rm_head_kwargs=rm_head_kwargs,
+ )
+
+ @staticmethod
+ def default_config(vocab_size=151658):
+ from transformers import Qwen2VLConfig
+
+ return Qwen2VLConfig(
+ text_config={
+ "vocab_size": vocab_size,
+ "hidden_size": 3584,
+ "intermediate_size": 18944,
+ "num_hidden_layers": 28,
+ "num_attention_heads": 28,
+ "num_key_value_heads": 4,
+ "hidden_act": "silu",
+ "max_position_embeddings": 32768,
+ "initializer_range": 0.02,
+ "rms_norm_eps": 1e-6,
+ "use_cache": True,
+ "use_sliding_window": False,
+ "sliding_window": 32768,
+ "max_window_layers": 28,
+ "attention_dropout": 0.0,
+ "rope_parameters": {
+ "rope_type": "default",
+ "type": "mrope",
+ "mrope_section": [16, 24, 24],
+ "rope_theta": 1000000.0,
+ },
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ },
+ vision_config={
+ "depth": 32,
+ "embed_dim": 1280,
+ "hidden_size": 3584,
+ "mlp_ratio": 4,
+ "num_heads": 16,
+ "in_channels": 3,
+ "patch_size": 14,
+ "spatial_merge_size": 2,
+ "temporal_patch_size": 2,
+ },
+ image_token_id=151655,
+ video_token_id=151656,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ tie_word_embeddings=False,
+ )
+
class HPSv3Model(torch.nn.Module):
def __init__(
@@ -226,122 +281,19 @@ def __init__(
self.min_pixels = min_pixels
self.score_index = score_index
- @classmethod
- def from_pretrained(
- cls,
- model_path: str,
- base_model_path: str = None,
- torch_dtype: torch.dtype = torch.bfloat16,
- device: Union[str, torch.device] = "cpu",
- output_dim: int = 2,
- score_index: int = 0,
- use_special_tokens: bool = True,
- reward_token: str = "special",
- rm_head_type: str = "ranknet",
- rm_head_kwargs: dict = None,
- max_pixels: int = 256 * 28 * 28,
- min_pixels: int = 256 * 28 * 28,
- model_kwargs: dict = None,
- processor_kwargs: dict = None,
- ):
- model_kwargs = {} if model_kwargs is None else model_kwargs
- processor_kwargs = {} if processor_kwargs is None else processor_kwargs
- model_path = Path(model_path)
- base_model_path = base_model_path or str(model_path)
- checkpoint_path = _find_checkpoint(model_path)
- if checkpoint_path is None:
- raise FileNotFoundError(f"Cannot find an HPSv3 checkpoint under {model_path}.")
-
- processor = AutoProcessor.from_pretrained(base_model_path, padding_side="right", **processor_kwargs)
- special_token_ids = None
- if use_special_tokens:
- special_tokens = ["<|Reward|>"]
- processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
- special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
-
- reward_model_class = _create_reward_model_class()
- model = reward_model_class.from_pretrained(
- base_model_path,
- output_dim=output_dim,
- reward_token=reward_token,
- special_token_ids=special_token_ids,
- torch_dtype=torch_dtype,
- attn_implementation=model_kwargs.pop("attn_implementation", "sdpa"),
- **model_kwargs,
- )
- if use_special_tokens:
- model.resize_token_embeddings(len(processor.tokenizer))
- state_dict = cls._load_checkpoint(checkpoint_path)
- state_dict = cls._prepare_state_dict(state_dict, model.state_dict())
- model.load_state_dict(state_dict, strict=True)
- model.config.tokenizer_padding_side = processor.tokenizer.padding_side
- model.config.pad_token_id = processor.tokenizer.pad_token_id
- model.rm_head.to(torch.float32)
- model = model.to(device).eval()
- return cls(
- model=model,
- processor=processor,
- use_special_tokens=use_special_tokens,
- max_pixels=max_pixels,
- min_pixels=min_pixels,
- score_index=score_index,
- )
-
- @staticmethod
- def _load_checkpoint(checkpoint_path):
- checkpoint_path = Path(checkpoint_path)
- if checkpoint_path.suffix == ".safetensors":
- import safetensors.torch
-
- state_dict = safetensors.torch.load_file(str(checkpoint_path), device="cpu")
- else:
- state_dict = torch.load(checkpoint_path, map_location="cpu")
- if isinstance(state_dict, dict):
- for key in ("state_dict", "model"):
- if key in state_dict and isinstance(state_dict[key], dict):
- state_dict = state_dict[key]
- break
- return {key[len("module.") :] if key.startswith("module.") else key: value for key, value in state_dict.items()}
-
- @staticmethod
- def _prepare_state_dict(state_dict, target_state_dict):
- target_keys = set(target_state_dict.keys())
- converted = {}
- for key, value in state_dict.items():
- new_key = key
- if key.startswith("visual.") and f"model.{key}" in target_keys:
- new_key = f"model.{key}"
- elif key.startswith("model.visual.") and key[len("model.") :] in target_keys:
- new_key = key[len("model.") :]
- elif key.startswith("model.") and not key.startswith("model.language_model."):
- suffix = key[len("model.") :]
- if f"model.language_model.{suffix}" in target_keys:
- new_key = f"model.language_model.{suffix}"
- elif key.startswith("model.language_model."):
- suffix = key[len("model.language_model.") :]
- if f"model.{suffix}" in target_keys:
- new_key = f"model.{suffix}"
- elif key.startswith("lm_head.") and f"model.{key}" in target_keys:
- new_key = f"model.{key}"
- elif key.startswith("model.lm_head.") and key[len("model.") :] in target_keys:
- new_key = key[len("model.") :]
- converted[new_key] = value
- return converted
-
@property
def device(self):
- try:
- return next(self.model.parameters()).device
- except StopIteration:
- return torch.device("cpu")
+ return next(self.parameters(), torch.tensor([])).device
def _normalize_inputs(self, prompts, images):
images = _as_list(images)
prompts = _as_list(prompts)
+
if len(prompts) == 1 and len(images) > 1:
prompts = prompts * len(images)
if len(images) == 1 and len(prompts) > 1:
images = images * len(prompts)
+
if len(prompts) != len(images):
raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
return prompts, images
@@ -351,6 +303,7 @@ def _prepare_images(self, images):
for image in images:
image = image.convert("RGB")
height, width = image.height, image.width
+
resized_height, resized_width = _smart_resize(
height,
width,
@@ -358,11 +311,13 @@ def _prepare_images(self, images):
max_pixels=self.max_pixels,
)
prepared.append(image.resize((resized_width, resized_height), Image.BICUBIC))
+
return prepared
def _messages(self, prompts, images):
suffix = HPSV3_PROMPT_WITH_SPECIAL_TOKEN if self.use_special_tokens else HPSV3_PROMPT_WITHOUT_SPECIAL_TOKEN
messages = []
+
for prompt, image in zip(prompts, images):
messages.append(
[
@@ -380,15 +335,19 @@ def _messages(self, prompts, images):
def _prepare_batch(self, prompts, images):
images = self._prepare_images(images)
messages = self._messages(prompts, images)
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
batch = self.processor(text=text, images=images, padding=True, return_tensors="pt")
+
return batch.to(self.device)
- @torch.inference_mode()
+ @torch.no_grad()
def forward(self, prompts: Union[str, list[str]], images):
prompts, images = self._normalize_inputs(prompts, images)
batch = self._prepare_batch(prompts, images)
+
rewards = self.model(return_dict=True, **batch)["logits"]
if rewards.ndim == 2:
return rewards[:, self.score_index]
- return rewards
+
+ return rewards
\ No newline at end of file
diff --git a/diffsynth/models/image_reward.py b/diffsynth/models/image_reward.py
index c8f557058..da3f39fc2 100644
--- a/diffsynth/models/image_reward.py
+++ b/diffsynth/models/image_reward.py
@@ -1,14 +1,10 @@
-import json
-import fnmatch
-from pathlib import Path
from typing import Union
import torch
-import torch.nn as nn
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from torchvision.transforms import InterpolationMode
-BICUBIC = InterpolationMode.BICUBIC
+BICUBIC = InterpolationMode.BICUBIC
ImageInput = Union[Image.Image, list[Image.Image], tuple[Image.Image, ...]]
def _convert_image_to_rgb(image):
@@ -30,201 +26,103 @@ def _as_list(value):
return list(value)
return [value]
-def _find_file(path, names):
- path = Path(path)
- if path.is_file():
- return path if any(fnmatch.fnmatch(path.name, name) for name in names) else None
- for name in names:
- candidate = path / name
- if candidate.exists():
- return candidate
- for pattern in names:
- matches = sorted(path.rglob(pattern))
- if matches:
- return matches[0]
- return None
-class ImageRewardMLP(nn.Module):
+class ImageRewardMLP(torch.nn.Module):
def __init__(self, input_size):
super().__init__()
- self.layers = nn.Sequential(
- nn.Linear(input_size, 1024),
- nn.Dropout(0.2),
- nn.Linear(1024, 128),
- nn.Dropout(0.2),
- nn.Linear(128, 64),
- nn.Dropout(0.1),
- nn.Linear(64, 16),
- nn.Linear(16, 1),
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(input_size, 1024),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ torch.nn.Linear(16, 1),
)
for name, param in self.layers.named_parameters():
if "weight" in name:
- nn.init.normal_(param, mean=0.0, std=1.0 / (input_size + 1))
+ torch.nn.init.normal_(param, mean=0.0, std=1.0 / (input_size + 1))
if "bias" in name:
- nn.init.constant_(param, val=0)
+ torch.nn.init.constant_(param, val=0)
def forward(self, x):
return self.layers(x)
-class ImageRewardModel(nn.Module):
- def __init__(self, blip, tokenizer, image_size=224, max_length=35, mean=0.16717362830052426, std=1.0333394966054072):
+
+class ImageRewardModel(torch.nn.Module):
+ def __init__(self, blip=None, tokenizer=None, image_size=224, max_length=35, mean=0.16717362830052426, std=1.0333394966054072):
super().__init__()
+ if blip is None:
+ blip = self.default_blip_model()
+
self.blip = blip
self.tokenizer = tokenizer
self.preprocess = _image_transform(image_size)
self.max_length = max_length
self.mlp = ImageRewardMLP(blip.config.text_config.hidden_size)
+
self.register_buffer("score_mean", torch.tensor(float(mean)), persistent=False)
self.register_buffer("score_std", torch.tensor(float(std)), persistent=False)
- @classmethod
- def from_pretrained(
- cls,
- model_path: str,
- med_config_path: str = None,
- tokenizer_path: str = None,
- torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = "cpu",
- max_length: int = 35,
- model_kwargs: dict = None,
- tokenizer_kwargs: dict = None,
- ):
- from transformers import BertTokenizer, BlipConfig, BlipForImageTextRetrieval
-
- model_kwargs = {} if model_kwargs is None else model_kwargs
- tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
- model_path = Path(model_path)
- checkpoint_path = _find_file(model_path, ["ImageReward.pt", "pytorch_model.bin", "*.pt", "*.bin", "*.safetensors"])
- if checkpoint_path is None:
- raise FileNotFoundError(f"Cannot find an ImageReward checkpoint under {model_path}.")
-
- med_config_path = Path(med_config_path) if med_config_path is not None else _find_file(model_path, ["med_config.json"])
- text_config = cls._load_text_config(med_config_path)
- if tokenizer_path is None:
- if cls._has_tokenizer_files(model_path):
- tokenizer_path = str(model_path)
- else:
- raise ValueError(
- "ImageReward requires a local BERT tokenizer path. Use "
- "`ImageRewardMetric.from_pretrained(...)`, or pass a "
- "ModelScope-downloaded tokenizer such as "
- "`AI-ModelScope/bert-base-uncased`."
- )
- tokenizer = BertTokenizer.from_pretrained(tokenizer_path, **tokenizer_kwargs)
- tokenizer.add_special_tokens({"bos_token": "[DEC]"})
- tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
- tokenizer.enc_token_id = tokenizer.convert_tokens_to_ids("[ENC]")
+ @staticmethod
+ def default_blip_model():
+ from transformers import BlipConfig, BlipForImageTextRetrieval
- vision_hidden_size = model_kwargs.pop("vision_hidden_size", 1024)
+ vision_hidden_size = 1024
+ text_config = ImageRewardModel._load_text_config(None)
config = BlipConfig(
vision_config={
"hidden_size": vision_hidden_size,
"intermediate_size": vision_hidden_size * 4,
- "num_hidden_layers": model_kwargs.pop("vision_num_hidden_layers", 24),
- "num_attention_heads": model_kwargs.pop("vision_num_attention_heads", 16),
- "image_size": model_kwargs.pop("image_size", 224),
- "patch_size": model_kwargs.pop("patch_size", 16),
+ "num_hidden_layers": 24,
+ "num_attention_heads": 16,
+ "image_size": 224,
+ "patch_size": 16,
"hidden_act": "gelu",
- "layer_norm_eps": model_kwargs.pop("vision_layer_norm_eps", 1e-6),
+ "layer_norm_eps": 1e-6,
},
text_config={
**text_config,
- "vocab_size": max(text_config.get("vocab_size", 0), len(tokenizer)),
+ "vocab_size": 30524,
"encoder_hidden_size": vision_hidden_size,
"add_cross_attention": True,
"is_decoder": True,
},
- projection_dim=model_kwargs.pop("projection_dim", 256),
+ projection_dim=256,
)
- blip = BlipForImageTextRetrieval(config)
- model = cls(blip=blip, tokenizer=tokenizer, max_length=max_length)
- state_dict = cls._load_checkpoint(checkpoint_path)
- converted = cls._convert_state_dict(state_dict)
- model.load_state_dict(converted, strict=False)
- if torch_dtype is not None:
- model.blip = model.blip.to(dtype=torch_dtype)
- model.mlp = model.mlp.float()
- model = model.to(device).eval()
- return model
+ return BlipForImageTextRetrieval(config)
@staticmethod
def _load_text_config(med_config_path):
- if med_config_path is None:
- return {
- "hidden_size": 768,
- "intermediate_size": 3072,
- "num_hidden_layers": 12,
- "num_attention_heads": 12,
- "max_position_embeddings": 512,
- "vocab_size": 30524,
- "hidden_act": "gelu",
- "layer_norm_eps": 1e-12,
- "attention_probs_dropout_prob": 0.1,
- "hidden_dropout_prob": 0.1,
- "pad_token_id": 0,
- "type_vocab_size": 2,
- }
- with open(med_config_path, "r", encoding="utf-8") as f:
- data = json.load(f)
- allowed = {
- "hidden_size",
- "intermediate_size",
- "num_hidden_layers",
- "num_attention_heads",
- "max_position_embeddings",
- "vocab_size",
- "hidden_act",
- "layer_norm_eps",
- "attention_probs_dropout_prob",
- "hidden_dropout_prob",
- "pad_token_id",
- "type_vocab_size",
+ return {
+ "hidden_size": 768,
+ "intermediate_size": 3072,
+ "num_hidden_layers": 12,
+ "num_attention_heads": 12,
+ "max_position_embeddings": 512,
+ "vocab_size": 30524,
+ "hidden_act": "gelu",
+ "layer_norm_eps": 1e-12,
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_dropout_prob": 0.1,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
}
- return {key: value for key, value in data.items() if key in allowed}
-
- @staticmethod
- def _has_tokenizer_files(path):
- path = Path(path)
- return path.is_dir() and any((path / name).exists() for name in ("vocab.txt", "tokenizer.json", "tokenizer_config.json"))
-
- @staticmethod
- def _load_checkpoint(checkpoint_path):
- checkpoint_path = Path(checkpoint_path)
- if checkpoint_path.suffix == ".safetensors":
- import safetensors.torch
-
- state_dict = safetensors.torch.load_file(str(checkpoint_path), device="cpu")
- else:
- state_dict = torch.load(checkpoint_path, map_location="cpu")
- if isinstance(state_dict, dict):
- for key in ("state_dict", "model"):
- if key in state_dict and isinstance(state_dict[key], dict):
- state_dict = state_dict[key]
- break
- return state_dict
-
- @staticmethod
- def _convert_state_dict(state_dict):
- converted = {}
- for key, value in state_dict.items():
- if key.startswith("module."):
- key = key[len("module.") :]
- new_key, new_value = ImageRewardModel._convert_key_value(key, value)
- if new_key is not None:
- converted[new_key] = new_value
- return converted
@staticmethod
- def _convert_key_value(key, value):
+ def convert_key_value(key, value):
if key.startswith("blip.visual_encoder."):
suffix = key[len("blip.visual_encoder.") :]
+
if suffix == "cls_token":
return "blip.vision_model.embeddings.class_embedding", value
if suffix == "pos_embed":
return "blip.vision_model.embeddings.position_embedding", value
if suffix.startswith("patch_embed.proj."):
return "blip.vision_model.embeddings.patch_embedding." + suffix[len("patch_embed.proj.") :], value
+
if suffix.startswith("blocks."):
parts = suffix.split(".")
layer = parts[1]
@@ -241,24 +139,20 @@ def _convert_key_value(key, value):
for source, target in mapping.items():
if rest.startswith(source):
return prefix + target + rest[len(source) :], value
+
if suffix.startswith("norm."):
return "blip.vision_model.post_layernorm." + suffix[len("norm.") :], value
+
return None, value
return key, value
@property
def device(self):
- try:
- return next(self.parameters()).device
- except StopIteration:
- return torch.device("cpu")
+ return next(self.parameters(), torch.tensor([])).device
@property
def dtype(self):
- try:
- return next(self.blip.parameters()).dtype
- except StopIteration:
- return torch.float32
+ return next(self.parameters(), torch.tensor(0.0)).dtype
def _tokenize(self, prompts):
return self.tokenizer(
@@ -276,12 +170,15 @@ def _preprocess_images(self, images):
def _normalize_inputs(self, prompts, images):
images = _as_list(images)
prompts = _as_list(prompts)
+
if len(prompts) == 1 and len(images) > 1:
prompts = prompts * len(images)
if len(images) == 1 and len(prompts) > 1:
images = images * len(prompts)
+
if len(prompts) != len(images):
raise ValueError(f"Expected the same number of prompts and images, got {len(prompts)} and {len(images)}.")
+
return prompts, images
@torch.no_grad()
@@ -289,9 +186,11 @@ def forward(self, prompts: Union[str, list[str]], images):
prompts, images = self._normalize_inputs(prompts, images)
text_input = self._tokenize(prompts)
image_tensor = self._preprocess_images(images)
+
image_output = self.blip.vision_model(pixel_values=image_tensor, return_dict=True)
image_embeds = image_output.last_hidden_state
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=self.device)
+
text_output = self.blip.text_encoder(
input_ids=text_input.input_ids,
attention_mask=text_input.attention_mask,
@@ -299,7 +198,9 @@ def forward(self, prompts: Union[str, list[str]], images):
encoder_attention_mask=image_atts,
return_dict=True,
)
- text_features = text_output.last_hidden_state[:, 0, :].float()
+
+ text_features = text_output.last_hidden_state[:, 0, :]
rewards = self.mlp(text_features).squeeze(-1)
rewards = (rewards - self.score_mean) / self.score_std
- return rewards
+
+ return rewards
\ No newline at end of file
diff --git a/diffsynth/models/pickscore.py b/diffsynth/models/pickscore.py
index 900e41efb..7012c3b93 100644
--- a/diffsynth/models/pickscore.py
+++ b/diffsynth/models/pickscore.py
@@ -25,41 +25,13 @@ def __init__(self, model: torch.nn.Module, processor, max_length: int = 77):
self.processor = processor
self.max_length = max_length
- @classmethod
- def from_pretrained(
- cls,
- model_path: str,
- processor_path: str,
- torch_dtype: torch.dtype = None,
- device: Union[str, torch.device] = "cuda",
- max_length: int = 77,
- model_kwargs: dict = None,
- processor_kwargs: dict = None,
- ):
- from modelscope import AutoModel, AutoProcessor
-
- model_kwargs = {} if model_kwargs is None else model_kwargs
- processor_kwargs = {} if processor_kwargs is None else processor_kwargs
- processor = AutoProcessor.from_pretrained(processor_path, **processor_kwargs)
- model = AutoModel.from_pretrained(model_path, **model_kwargs).eval()
- if torch_dtype is not None:
- model = model.to(dtype=torch_dtype)
- model = model.to(device)
- return cls(model=model, processor=processor, max_length=max_length)
-
@property
def device(self):
- try:
- return next(self.model.parameters()).device
- except StopIteration:
- return torch.device("cpu")
+ return next(self.parameters(), torch.tensor([])).device
@property
def dtype(self):
- try:
- return next(self.model.parameters()).dtype
- except StopIteration:
- return torch.float32
+ return next(self.parameters(), torch.tensor(0.0)).dtype
def _processor_call(self, **kwargs):
inputs = self.processor(
@@ -69,6 +41,7 @@ def _processor_call(self, **kwargs):
return_tensors="pt",
**kwargs,
).to(self.device)
+
if self.dtype != torch.float32:
inputs = {
name: (
@@ -78,6 +51,7 @@ def _processor_call(self, **kwargs):
)
for name, value in inputs.items()
}
+
return inputs
@torch.no_grad()
@@ -85,23 +59,26 @@ def get_image_features(self, images: ImageInput):
if isinstance(images, Image.Image):
images = [images]
images = [image.convert("RGB") for image in images]
+
image_inputs = self._processor_call(images=images)
image_features = _feature_tensor(self.model.get_image_features(**image_inputs), "image_features")
- image_features = image_features / image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
- return image_features
+
+ return torch.nn.functional.normalize(image_features, dim=-1)
@torch.no_grad()
def get_text_features(self, text: Union[str, list[str]]):
text_inputs = self._processor_call(text=text)
text_features = _feature_tensor(self.model.get_text_features(**text_inputs), "text_features")
- text_features = text_features / text_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
- return text_features
+
+ return torch.nn.functional.normalize(text_features, dim=-1)
@torch.no_grad()
def forward(self, text: Union[str, list[str]], images: ImageInput):
image_features = self.get_image_features(images)
text_features = self.get_text_features(text)
+
scores = self.model.logit_scale.exp() * (text_features @ image_features.T)
if isinstance(text, str):
scores = scores[0]
- return scores
+
+ return scores
\ No newline at end of file
diff --git a/diffsynth/utils/state_dict_converters/image_metrics.py b/diffsynth/utils/state_dict_converters/image_metrics.py
new file mode 100644
index 000000000..30c8b55a3
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/image_metrics.py
@@ -0,0 +1,121 @@
+import torch
+
+
+def ImageMetricsCLIPStateDictConverter(state_dict):
+ return {
+ key: state_dict[key]
+ for key in state_dict
+ if key not in ("text_model.embeddings.position_ids", "vision_model.embeddings.position_ids")
+ }
+
+
+def ImageMetricsOpenCLIPStateDictConverter(state_dict):
+ converted = {}
+ for key in state_dict:
+ value = state_dict[key]
+ if key == "logit_scale":
+ converted["logit_scale"] = value
+ elif key == "token_embedding.weight":
+ converted["text_model.embeddings.token_embedding.weight"] = value
+ elif key == "positional_embedding":
+ converted["text_model.embeddings.position_embedding.weight"] = value
+ elif key.startswith("ln_final."):
+ converted["text_model.final_layer_norm." + key[len("ln_final.") :]] = value
+ elif key == "text_projection":
+ converted["text_projection.weight"] = value.T
+ elif key == "visual.class_embedding":
+ converted["vision_model.embeddings.class_embedding"] = value
+ elif key == "visual.conv1.weight":
+ converted["vision_model.embeddings.patch_embedding.weight"] = value
+ elif key == "visual.positional_embedding":
+ converted["vision_model.embeddings.position_embedding.weight"] = value
+ elif key.startswith("visual.ln_pre."):
+ converted["vision_model.pre_layrnorm." + key[len("visual.ln_pre.") :]] = value
+ elif key.startswith("visual.ln_post."):
+ converted["vision_model.post_layernorm." + key[len("visual.ln_post.") :]] = value
+ elif key == "visual.proj":
+ converted["visual_projection.weight"] = value.T
+ elif key.startswith("transformer.resblocks."):
+ converted.update(_convert_open_clip_resblock("text_model.encoder.layers", key[len("transformer.resblocks.") :], value))
+ elif key.startswith("visual.transformer.resblocks."):
+ converted.update(_convert_open_clip_resblock("vision_model.encoder.layers", key[len("visual.transformer.resblocks.") :], value))
+ return converted
+
+
+def ImageMetricsImageRewardStateDictConverter(state_dict):
+ from diffsynth.models.image_reward import ImageRewardModel
+
+ converted = {}
+ for key in state_dict:
+ value = state_dict[key]
+ if key.startswith("module."):
+ key = key[len("module.") :]
+ new_key, new_value = ImageRewardModel.convert_key_value(key, value)
+ if new_key is not None and new_key != "blip.text_encoder.embeddings.position_ids":
+ converted[new_key] = new_value
+ hidden_size = converted["blip.text_encoder.embeddings.word_embeddings.weight"].shape[1]
+ converted["blip.itm_head.weight"] = torch.zeros((2, hidden_size), dtype=converted["blip.text_encoder.embeddings.word_embeddings.weight"].dtype)
+ converted["blip.itm_head.bias"] = torch.zeros((2,), dtype=converted["blip.text_encoder.embeddings.word_embeddings.weight"].dtype)
+ return converted
+
+
+def ImageMetricsAestheticStateDictConverter(state_dict):
+ converted = {}
+ for key in state_dict:
+ value = state_dict[key]
+ for prefix in ("model.", "module.", "aesthetic_model.", "aesthetics_predictor.", "predictor."):
+ if key.startswith(prefix):
+ key = key[len(prefix) :]
+ if key == "vision_model.embeddings.position_ids":
+ continue
+ converted[key] = value
+ return converted
+
+
+def ImageMetricsFIDStateDictConverter(state_dict):
+ return {"model." + key: state_dict[key] for key in state_dict if not key.startswith("fc.")}
+
+
+def ImageMetricsHPSv3StateDictConverter(state_dict):
+ converted = {}
+ for key in state_dict:
+ value = state_dict[key]
+ if key.startswith("visual."):
+ key = "model.visual." + key[len("visual.") :]
+ elif key.startswith("model.visual."):
+ pass
+ elif key.startswith("model.") and not key.startswith("model.language_model."):
+ key = "model.language_model." + key[len("model.") :]
+ converted[key] = value
+ return converted
+
+
+def _convert_open_clip_resblock(prefix, suffix, value):
+ converted = {}
+ parts = suffix.split(".")
+ layer = parts[0]
+ rest = ".".join(parts[1:])
+ layer_prefix = f"{prefix}.{layer}."
+ if rest == "attn.in_proj_weight":
+ q, k, v = value.chunk(3, dim=0)
+ converted[layer_prefix + "self_attn.q_proj.weight"] = q
+ converted[layer_prefix + "self_attn.k_proj.weight"] = k
+ converted[layer_prefix + "self_attn.v_proj.weight"] = v
+ elif rest == "attn.in_proj_bias":
+ q, k, v = value.chunk(3, dim=0)
+ converted[layer_prefix + "self_attn.q_proj.bias"] = q
+ converted[layer_prefix + "self_attn.k_proj.bias"] = k
+ converted[layer_prefix + "self_attn.v_proj.bias"] = v
+ else:
+ mapping = {
+ "attn.out_proj.": "self_attn.out_proj.",
+ "ln_1.": "layer_norm1.",
+ "ln_2.": "layer_norm2.",
+ "mlp.c_fc.": "mlp.fc1.",
+ "mlp.c_proj.": "mlp.fc2.",
+ }
+ for source, target in mapping.items():
+ if rest.startswith(source):
+ converted[layer_prefix + target + rest[len(source) :]] = value
+ break
+ return converted
diff --git a/docs/en/Model_Details/Image-Quality-Metrics.md b/docs/en/Model_Details/Image-Quality-Metrics.md
index 51c44511c..7d3f87139 100644
--- a/docs/en/Model_Details/Image-Quality-Metrics.md
+++ b/docs/en/Model_Details/Image-Quality-Metrics.md
@@ -19,7 +19,6 @@ For more information about installation, please refer to [Install Dependencies](
Run the following code to quickly load PickScore and score an image against a prompt. The default models will be downloaded from ModelScope to `./models`.
```python
-import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
@@ -35,90 +34,84 @@ prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = PickScoreMetric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
- processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+ device=device
+ )
-print("PickScore score:", metric.compute(prompt, image)[0])
+score = metric.compute(prompt, image)[0]
+print(f"PickScore score: {score:.3f}")
```
## Metrics Overview
-| Metric | Default Model | Input | Output | Example Code |
-| --- | --- | --- | --- | --- |
-| PickScore | [AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/pickscore.py) |
-| ImageReward | [ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/image_reward.py) |
-| HPSv2 | [AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv2.py) |
-| HPSv3 | [MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3) | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv3.py) |
-| CLIP Score | [AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K) | prompt + PIL Image | Text-Image Similarity | [code](../../../examples/image_quality_metric/clipscore.py) |
-| Aesthetic | [AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE) | PIL Image | Aesthetic Score | [code](../../../examples/image_quality_metric/aesthetic.py) |
-| FID | [diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d) | reference image directory + generated image directory | Distribution Distance | [code](../../../examples/image_quality_metric/fid.py) |
+| Metric | Input | Output | Example Code |
+| --- | --- | --- | --- |
+| PickScore | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/pickscore.py) |
+| ImageReward | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/image_reward.py) |
+| HPSv2 | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv2.py) |
+| HPSv3 | prompt + PIL Image | Preference Score | [code](../../../examples/image_quality_metric/hpsv3.py) |
+| CLIP Score | prompt + PIL Image | Text-Image Similarity | [code](../../../examples/image_quality_metric/clipscore.py) |
+| Aesthetic | PIL Image | Aesthetic Score | [code](../../../examples/image_quality_metric/aesthetic.py) |
+| FID | reference image directory + generated image directory | Distribution Distance | [code](../../../examples/image_quality_metric/fid.py) |
-## Single-Image Reward Models
+### Text-Image Alignment and Preference Evaluation
-**PickScore**, **ImageReward**, **HPSv2**, **HPSv3**, and **CLIP Score** share the same input format: a text prompt and an opened `PIL.Image.Image`. Example:
+Applicable metrics: **PickScore**, **ImageReward**, **HPSv2**, **HPSv3**, **CLIP Score**
-```python
-from PIL import Image
-from diffsynth.metrics import CLIPMetric, ModelConfig
+These models evaluate whether an image follows the prompt and aligns with human visual preferences. They require both a `prompt` and an `image` as input.
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
-device = "cuda"
+**Basic Scoring**
-metric = CLIPMetric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- device=device,
-)
-scores = metric.calc_scores(prompt, image)[0]
+```python
+score = metric.compute(prompt, image)[0]
```
-If you want to evaluate multiple images, you can pass a list of PIL images:
+**Batch Scoring**
+
+If you need to evaluate multiple images, you can directly pass a list:
```python
-scores = metric.calc_scores(prompt, [image1, image2, image3])
+scores = metric.compute("a cute cat", [image1, image2, image3])
+
+scores = metric.compute(["a cat", "a dog"], [image_cat, image_dog])
```
-When the prompt is a single string, the same prompt will be used for every image. When the prompt is a list of strings, the number of prompts must match the number of images.
+When `prompt` is a single string, the same prompt is applied to every image. When `prompt` is a list of strings, the number of prompts must exactly match the number of images.
+
+### Pure Image Aesthetics Evaluation
-## Aesthetic
+Applicable metric: **Aesthetic**
-Aesthetic only evaluates the aesthetic quality of the image and does not use a prompt.
+This model solely evaluates aesthetic features such as the composition, color, and clarity of the image itself. It does not require a prompt.
```python
-from PIL import Image
from diffsynth.metrics import AestheticMetric
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
metric = AestheticMetric.from_pretrained(device="cuda")
-score = metric.calc_scores(image)[0]
+score = metric.compute(image)[0]
```
-## FID
+### Dataset Distribution Evaluation
+
+Applicable metric: **FID** (Fréchet Inception Distance)
-FID is used to compare the feature distributions of two sets of images. It does not score single images, nor does it use a prompt. A typical use case is comparing a directory of real reference images against a directory of generated results:
+FID does not score individual images; instead, it compares the overall feature distribution distance between a real reference image set and a generated image set. A lower score indicates that the generated distribution is closer to the real distribution.
```python
from diffsynth.metrics import FIDMetric
-reference_dir = FIDMetric.default_reference_dir(
- local_dir="data/examples/ImageQualityMetric/reference/coco_2014_caption_validation",
- max_images=10000,
-)
-generated_dir = ""
+reference_dir = "path/to/real_reference_images"
+generated_dir = "path/to/model_generated_images"
metric = FIDMetric.from_pretrained(device="cuda", batch_size=16)
-score = metric.compute(reference_dir, generated_dir)
-print("FID:", score)
+fid_score = metric.compute(reference_dir, generated_dir)
+print(f"FID: {fid_score:.3f}")
```
-The reference for FID is not a single, fixed official answer. For general text-to-image quality evaluation, the COCO validation set is a convenient default choice; for vertical tasks such as portraits, product images, or medical images, a `reference_dir` consisting of real data from that specific domain should be provided.
+The reference set for FID is not fixed or unique. For general image generation, the COCO validation set is commonly used; for specific domains (such as medical images or e-commerce products), a `reference_dir` composed of real data from that specific domain should be provided.
## Important Notes
* The scores from PickScore, ImageReward, HPSv2, HPSv3, CLIPScore, and Aesthetic are suitable for relative comparison within the same metric. It is not recommended to directly compare the numerical values across different metrics.
* HPSv3 is based on Qwen2-VL and is a larger model, requiring significantly more VRAM than CLIP-based metrics.
-* FID is sensitive to the choice of reference, the reference sample size, and the generated sample size.
\ No newline at end of file
+* FID is sensitive to the choice of reference, the reference sample size, and the generated sample size.
diff --git a/docs/zh/Model_Details/Image-Quality-Metrics.md b/docs/zh/Model_Details/Image-Quality-Metrics.md
index a9c756ea8..20cb846a7 100644
--- a/docs/zh/Model_Details/Image-Quality-Metrics.md
+++ b/docs/zh/Model_Details/Image-Quality-Metrics.md
@@ -19,7 +19,6 @@ pip install -e .
运行以下代码可以快速加载 PickScore,并对一张图像和一段提示词进行评分。默认模型会从 ModelScope 下载到 `./models`。
```python
-import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
@@ -35,85 +34,79 @@ prompt = "dog,white and brown dog, sitting on wall, under pink flowers"
device = "cuda"
metric = PickScoreMetric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
- processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+ device=device
+ )
+
+score = metric.compute(prompt, image)[0]
+print(f"PickScore score: {score:.3f}")
```
## 指标总览
-|指标|默认模型|输入|输出|示例代码|
-|-|-|-|-|-|
-|PickScore|[AI-ModelScope/PickScore_v1](https://www.modelscope.cn/models/AI-ModelScope/PickScore_v1)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/pickscore.py)|
-|ImageReward|[ZhipuAI/ImageReward](https://www.modelscope.cn/models/ZhipuAI/ImageReward)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/image_reward.py)|
-|HPSv2|[AI-ModelScope/HPSv2](https://www.modelscope.cn/models/AI-ModelScope/HPSv2)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv2.py)|
-|HPSv3|[MizzenAI/HPSv3](https://www.modelscope.cn/models/MizzenAI/HPSv3)|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv3.py)|
-|CLIP Score|[AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K](https://www.modelscope.cn/models/AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K)|prompt + PIL 图像|图文相似度|[code](../../../examples/image_quality_metric/clipscore.py)|
-|Aesthetic|[AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE](https://www.modelscope.cn/models/AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE)|PIL 图像|美学分数|[code](../../../examples/image_quality_metric/aesthetic.py)|
-|FID|[diffusionTry/weights-inception-2015-12-05-6726825d](https://www.modelscope.cn/models/diffusionTry/weights-inception-2015-12-05-6726825d)|reference 图像目录 + generated 图像目录|分布距离|[code](../../../examples/image_quality_metric/fid.py)|
+|指标|输入|输出|示例代码|
+|-|-|-|-|
+|PickScore|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/pickscore.py)|
+|ImageReward|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/image_reward.py)|
+|HPSv2|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv2.py)|
+|HPSv3|prompt + PIL 图像|偏好分数|[code](../../../examples/image_quality_metric/hpsv3.py)|
+|CLIP Score|prompt + PIL 图像|图文匹配度|[code](../../../examples/image_quality_metric/clipscore.py)|
+|Aesthetic|PIL 图像|美学分数|[code](../../../examples/image_quality_metric/aesthetic.py)|
+|FID|reference 图像目录 + generated 图像目录|分布距离|[code](../../../examples/image_quality_metric/fid.py)|
-## 单图奖励模型
+### 文本-图像对齐与偏好评估
-**PickScore**、**ImageReward**、**HPSv2**、**HPSv3** 和 **CLIP Score** 的输入形式相同:一段文本提示词和一张已经打开的 `PIL.Image.Image`。示例:
+适用指标: **PickScore**,**ImageReward**,**HPSv2**,**HPSv3**,**CLIP Score**
-```python
-from PIL import Image
-from diffsynth.metrics import CLIPMetric, ModelConfig
+这类模型用于评估图像是否遵循提示词以及是否符合人类视觉偏好。它们必须同时接收 `prompt` 和 `image`。
-prompt = ""
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
-device = "cuda"
-
-metric = CLIPMetric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- device=device,
-)
-scores = metric.calc_scores(prompt, image)[0]
+**基础打分**
+```python
+score = metric.compute(prompt, image)[0]
```
-如果要评估多张图像,可以传入 PIL 图像列表:
+**批量打分**
+如果需要评估多张图像,可以直接传入列表:
```python
-scores = metric.calc_scores(prompt, [image1, image2, image3])
+scores = metric.compute("a cute cat", [image1, image2, image3])
+
+scores = metric.compute(["a cat", "a dog"], [image_cat, image_dog])
```
其中 prompt 为单个字符串时,会对每张图像使用同一个 prompt。prompt 为字符串列表时,prompt 数量需要和图像数量一致。
-## Aesthetic
+### 纯图像美学评估
+
+适用指标: **Aesthetic**
+
+该模型仅评估图像本身的构图、色彩、清晰度等美学特征,不需要提示词介入。
-Aesthetic 只评估图像审美质量,不使用 prompt。
```python
-from PIL import Image
from diffsynth.metrics import AestheticMetric
-path_to_image = ""
-image = Image.open(path_to_image).convert("RGB")
metric = AestheticMetric.from_pretrained(device="cuda")
-score = metric.calc_scores(image)[0]
+score = metric.compute(image)[0]
```
-## FID
+### 数据集分布评估
+适用指标: **FID** (Fréchet Inception Distance)
-FID 用于比较两组图像的特征分布。它不是单图打分,也不使用 prompt。典型用法是比较真实参考图像目录和生成结果目录:
+FID 不对单张图片打分,而是比较真实参考图像集与生成图像集的整体特征分布距离。分数越低,说明生成分布越接近真实分布。
```python
from diffsynth.metrics import FIDMetric
-reference_dir = FIDMetric.default_reference_dir(
- local_dir="data/examples/ImageQualityMetric/reference/coco_2014_caption_validation",
- max_images=10000,
-)
-generated_dir = ""
+reference_dir = "path/to/real_reference_images"
+generated_dir = "path/to/model_generated_images"
metric = FIDMetric.from_pretrained(device="cuda", batch_size=16)
-score = metric.compute(reference_dir, generated_dir)
-print("FID:", score)
+fid_score = metric.compute(reference_dir, generated_dir)
+print(f"FID: {fid_score:.3f}")
```
-FID 的 reference 不是固定唯一的官方答案。对于通用文生图质量评估,COCO validation 是一个方便的默认选择;对于人像、商品图、医学等垂直任务,应传入该领域真实数据构成的 `reference_dir`。
+FID 的基准不是固定唯一的。对于通用图像生成,常使用 COCO Validation;如果是特定领域(如医学图像、电商商品),应提供该领域真实数据构成的 `reference_dir`。
## 注意事项
diff --git a/examples/image_quality_metric/aesthetic.py b/examples/image_quality_metric/aesthetic.py
index 962eeeddf..ae0b518fe 100644
--- a/examples/image_quality_metric/aesthetic.py
+++ b/examples/image_quality_metric/aesthetic.py
@@ -12,8 +12,9 @@
device = "cuda"
metric = AestheticMetric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE"),
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="Aesthetic/model.safetensors"),
+ device=device
+ )
-print("Aesthetic score:", metric.compute(image)[0])
\ No newline at end of file
+score = metric.compute(image)[0]
+print(f"Aesthetic score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/clipscore.py b/examples/image_quality_metric/clipscore.py
index 984118137..4a6621541 100644
--- a/examples/image_quality_metric/clipscore.py
+++ b/examples/image_quality_metric/clipscore.py
@@ -1,4 +1,3 @@
-import csv
from diffsynth.metrics import CLIPMetric, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
@@ -14,8 +13,9 @@
device = "cuda"
metric = CLIPMetric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="CLIP-ViT-H-14-laion2B-s32B-b79K/model.safetensors"),
+ device=device
+ )
-print("CLIP score:", metric.compute(prompt, image)[0])
\ No newline at end of file
+score = metric.compute(prompt, image)[0]
+print(f"CLIP score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/fid.py b/examples/image_quality_metric/fid.py
index 6a7d25d65..94a982b47 100644
--- a/examples/image_quality_metric/fid.py
+++ b/examples/image_quality_metric/fid.py
@@ -1,23 +1,13 @@
-from diffsynth.metrics import FIDMetric
-from modelscope import dataset_snapshot_download
+from diffsynth.metrics import FIDMetric, ModelConfig
-dataset_snapshot_download(
- "DiffSynth-Studio/diffsynth_example_dataset",
- allow_file_pattern="flux/FLUX.1-dev/*",
- local_dir="./data/diffsynth_example_dataset",
-)
-
-generated_dir = "data/diffsynth_example_dataset/flux/FLUX.1-dev"
+reference_dir = ""
+generated_dir = ""
device = "cuda"
-reference_dir = FIDMetric.default_reference_dir(
- local_dir="data/examples/ImageQualityMetric/reference/coco_2014_caption_validation",
- max_images=10000, # use None for the full validation split
-)
-
metric = FIDMetric.from_pretrained(
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="FID/model.safetensors"),
device=device,
- batch_size=16,
)
-print("FID score:", metric.compute(reference_dir, generated_dir))
+score = metric.compute(reference_dir, generated_dir)
+print(f"FID score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/hpsv2.py b/examples/image_quality_metric/hpsv2.py
index b615174c8..c460357f6 100644
--- a/examples/image_quality_metric/hpsv2.py
+++ b/examples/image_quality_metric/hpsv2.py
@@ -1,4 +1,3 @@
-import csv
from diffsynth.metrics import HPSv2Metric, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
@@ -14,10 +13,9 @@
device = "cuda"
metric = HPSv2Metric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/HPSv2"),
- processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- version="v2.0", # choice: v2.0, v2.1
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv2/model.safetensors"),
+ device=device
+ )
-print("HPSv2 score:", metric.compute(prompt, image)[0])
\ No newline at end of file
+score = metric.compute(prompt, image)[0]
+print(f"HPSv2 score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/hpsv3.py b/examples/image_quality_metric/hpsv3.py
index c436f5ded..d6917f3d5 100644
--- a/examples/image_quality_metric/hpsv3.py
+++ b/examples/image_quality_metric/hpsv3.py
@@ -1,4 +1,3 @@
-import csv
from diffsynth.metrics import HPSv3Metric, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
@@ -14,9 +13,9 @@
device = "cuda"
metric = HPSv3Metric.from_pretrained(
- model_config=ModelConfig(model_id="MizzenAI/HPSv3"),
- base_model_config=ModelConfig(model_id="Qwen/Qwen2-VL-7B-Instruct"),
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="HPSv3/model.safetensors"),
+ device=device
+ )
-print("HPSv3 score:", metric.compute(prompt, image)[0])
\ No newline at end of file
+score = metric.compute(prompt, image)[0]
+print(f"HPSv3 score: {score:.3f}")
\ No newline at end of file
diff --git a/examples/image_quality_metric/image_reward.py b/examples/image_quality_metric/image_reward.py
index f9364364d..e8e7df5b6 100644
--- a/examples/image_quality_metric/image_reward.py
+++ b/examples/image_quality_metric/image_reward.py
@@ -1,4 +1,3 @@
-import csv
from diffsynth.metrics import ImageRewardMetric, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
@@ -14,8 +13,9 @@
device = "cuda"
metric = ImageRewardMetric.from_pretrained(
- model_config=ModelConfig(model_id="ZhipuAI/ImageReward"),
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="ImageReward/model.safetensors"),
+ device=device
+ )
-print("ImageReward score:", metric.compute(prompt, image)[0])
\ No newline at end of file
+score = metric.compute(prompt, image)[0]
+print(f"ImageReward score: {score:.3f}")
diff --git a/examples/image_quality_metric/pickscore.py b/examples/image_quality_metric/pickscore.py
index cbcb4b8be..b409f3b3c 100644
--- a/examples/image_quality_metric/pickscore.py
+++ b/examples/image_quality_metric/pickscore.py
@@ -1,4 +1,3 @@
-import csv
from diffsynth.metrics import PickScoreMetric, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
@@ -14,9 +13,9 @@
device = "cuda"
metric = PickScoreMetric.from_pretrained(
- model_config=ModelConfig(model_id="AI-ModelScope/PickScore_v1"),
- processor_config=ModelConfig(model_id="AI-ModelScope/CLIP-ViT-H-14-laion2B-s32B-b79K"),
- device=device,
-)
+ model_config=ModelConfig(model_id="DiffSynth-Studio/ImageMetrics", origin_file_pattern="PickScore/model.safetensors"),
+ device=device
+ )
-print("PickScore score:", metric.compute(prompt, image)[0])
\ No newline at end of file
+score = metric.compute(prompt, image)[0]
+print(f"PickScore score: {score:.3f}")
\ No newline at end of file