From d6be45e96f9ac5838847c1d6cea5563e47ede0c9 Mon Sep 17 00:00:00 2001 From: nackabhi Date: Sat, 14 Mar 2026 17:36:01 +0530 Subject: [PATCH] fixed 13177 --- .../pipelines/flux2/image_processor.py | 42 ++++++++++++++++--- .../pipelines/flux2/pipeline_flux2_klein.py | 7 ++-- .../flux2/test_pipeline_flux2_klein.py | 30 +++++++++++++ 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index e0a1b80ce533..5a07aaf4edcf 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -14,7 +14,9 @@ import math +import numpy as np +import PIL.Image +import torch from ...configuration_utils import register_to_config from ...image_processor import VaeImageProcessor @@ -56,27 +58,57 @@ def __init__( do_convert_rgb=do_convert_rgb, ) + @staticmethod + def to_pil(image) -> PIL.Image.Image: + """Convert torch.Tensor or np.ndarray to PIL.Image.Image.
+ + Accepts: + - PIL.Image.Image → returned as-is + - torch.Tensor → shape (C, H, W) or (B, C, H, W), values in [0, 1] + - np.ndarray → shape (H, W, C) or (B, H, W, C), values in [0, 1] + """ + if isinstance(image, PIL.Image.Image): + return image + + if isinstance(image, torch.Tensor): + image = image.detach().cpu().float() + if image.ndim == 4: + image = image[0] + image = image.permute(1, 2, 0).numpy() + elif isinstance(image, np.ndarray): + if image.ndim == 4: + image = image[0] + else: + raise ValueError( + f"Expected PIL.Image.Image, torch.Tensor, or np.ndarray, got {type(image)}" + ) + + if image.dtype != np.uint8: + image = (np.clip(image, 0, 1) * 255).astype(np.uint8) + + return PIL.Image.fromarray(image) + @staticmethod def check_image_input( - image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024 + image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024 ) -> PIL.Image.Image: """ Check if image meets minimum size and aspect ratio requirements. + Accepts PIL.Image.Image, torch.Tensor, or np.ndarray and converts to PIL.
Args: - image: PIL Image to validate + image: Image to validate (PIL, tensor, or numpy array) max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width) min_side_length: Minimum pixels required for width and height max_area: Maximum allowed area in pixels² Returns: - The input image if valid + The image as PIL.Image.Image Raises: ValueError: If image is too small or aspect ratio is too extreme """ - if not isinstance(image, PIL.Image.Image): - raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + image = Flux2ImageProcessor.to_pil(image) width, height = image.size diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 936d2c3804ab..a92c2d2bc527 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -25,6 +25,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor +from ...image_processor import PipelineImageInput from ..pipeline_utils import DiffusionPipeline from .image_processor import Flux2ImageProcessor from .pipeline_output import Flux2PipelineOutput @@ -608,7 +609,7 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - image: list[PIL.Image.Image] | PIL.Image.Image | None = None, + image: PipelineImageInput | None = None, prompt: str | list[str] = None, height: int | None = None, width: int | None = None, @@ -758,8 +759,8 @@ def __call__( condition_images = None if image is not None: - for img in image: - self.image_processor.check_image_input(img) + # Convert each image to PIL (handles tensor/numpy/PIL uniformly) + image = [self.image_processor.check_image_input(img) for img in image] condition_images = [] for img in image: diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index 8ed9bf3d1e91..acd31fb900e8 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -178,6 +178,36 @@ def test_image_input(self): # fmt: on assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4) + def test_image_input_tensor(self): + """Issue #13177: pipeline should accept torch.Tensor images.""" + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()).to(device) + inputs = self.get_dummy_inputs(device) + + inputs["image"] = torch.rand(3, 64, 64) + image = pipe(**inputs).images + assert image is not None and image.shape[-1] == 3 + + def test_image_input_numpy(self): + """Issue #13177: pipeline should accept np.ndarray images.""" + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()).to(device) + inputs = self.get_dummy_inputs(device) + + inputs["image"] = np.random.rand(64, 64, 3).astype(np.float32) + image = pipe(**inputs).images + assert image is not None and image.shape[-1] == 3 + + def test_image_input_tensor_list(self): + """Issue #13177: pipeline should accept list of tensors.""" + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()).to(device) + inputs = self.get_dummy_inputs(device) + + inputs["image"] = [torch.rand(3, 64, 64), torch.rand(3, 64, 64)] + image = pipe(**inputs).images + assert image is not None and image.shape[-1] == 3 + @unittest.skip("Needs to be revisited") def test_encode_prompt_works_in_isolation(self): pass