diff --git a/docs/examples/image_text_models/vision_litellm_backend.py b/docs/examples/image_text_models/vision_litellm_backend.py index 03180fb2..03f1ea1b 100644 --- a/docs/examples/image_text_models/vision_litellm_backend.py +++ b/docs/examples/image_text_models/vision_litellm_backend.py @@ -24,17 +24,19 @@ print(str(ch.content)) # test with PIL image -res = m.instruct( +res_instruct = m.instruct( "Is there a person on the image? Is the subject in the image smiling?", images=[test_pil], ) -print(str(res)) +print(f"Test with PIL and instruct: \n{str(res_instruct)}\n-----") # print(m.last_prompt()) # with PIL image and using m.chat -res = m.chat("How many eyes can you identify in the image? Explain.", images=[test_pil]) -print(str(res.content)) +res_chat = m.chat( + "How many eyes can you identify in the image? Explain.", images=[test_pil] +) +print(f"Test with PIL and chat: \n{str(res_chat.content)}\n-----") # and now without images again... -res = m.instruct("How many eyes can you identify in the image?", images=[]) -print(str(res)) +res_empty = m.instruct("How many eyes can you identify in the image?", images=[]) +print(f"Test without image: \n{str(res_empty)}\n-----") diff --git a/mellea/core/base.py b/mellea/core/base.py index 41179bfe..3ab7c4ec 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -62,7 +62,7 @@ def __repr__(self): return f"CBlock({self.value}, {self._meta.__repr__()})" -class ImageBlock: +class ImageBlock(CBlock): """A `ImageBlock` represents an image (as base64 PNG).""" def __init__(self, value: str, meta: dict[str, Any] | None = None): @@ -70,8 +70,7 @@ def __init__(self, value: str, meta: dict[str, Any] | None = None): assert self.is_valid_base64_png(value), ( "Invalid base64 string representation of image." ) - self._value = value - self._meta = {} if meta is None else meta + super().__init__(value, meta) @staticmethod def is_valid_base64_png(s: str) -> bool: @@ -117,13 +116,9 @@ def from_pil_image( image_base64 = cls.pil_to_base64(image) return cls(image_base64, meta) - def __str__(self): - """Stringifies the block.""" - return self._value - def __repr__(self): """Provides a python-parsable representation of the block (usually).""" - return f"ImageBlock({self._value}, {self._meta.__repr__()})" + return f"ImageBlock({self.value}, {self._meta.__repr__()})" S = typing_extensions.TypeVar("S", default=Any, covariant=True) diff --git a/mellea/stdlib/components/chat.py b/mellea/stdlib/components/chat.py index 8763a70b..5235bdb2 100644 --- a/mellea/stdlib/components/chat.py +++ b/mellea/stdlib/components/chat.py @@ -44,17 +44,13 @@ def __init__( self.content = content # TODO this should be private. self._content_cblock = CBlock(self.content) self._images = images - # TODO this should replace _images. - self._images_cblocks: list[CBlock] | None = None - if self._images is not None: - self._images_cblocks = [CBlock(str(i)) for i in self._images] self._docs = documents @property def images(self) -> None | list[str]: """Returns the images associated with this message as list of base 64 strings.""" - if self._images_cblocks is not None: - return [str(i.value) for i in self._images_cblocks] + if self._images is not None: + return [str(i) for i in self._images] return None def parts(self) -> list[Component | CBlock]: @@ -62,9 +58,8 @@ def parts(self) -> list[Component | CBlock]: parts: list[Component | CBlock] = [self._content_cblock] if self._docs is not None: parts.extend(self._docs) - # TODO: we need to do this but images are not currently cblocks. This is captured in an issue on Jan 26 sprint. Leaving this code commented out for now. - # if self._images is not None: - # parts.extend(self._images) + if self._images is not None: + parts.extend(self._images) return parts def format_for_llm(self) -> TemplateRepresentation: @@ -78,7 +73,7 @@ def format_for_llm(self) -> TemplateRepresentation: args={ "role": self.role, "content": self._content_cblock, - "images": self._images_cblocks, + "images": self._images, "documents": self._docs, }, template_order=["*", "Message"], diff --git a/test/backends/test_vision_ollama.py b/test/backends/test_vision_ollama.py index 740043d9..bae43d31 100644 --- a/test/backends/test_vision_ollama.py +++ b/test/backends/test_vision_ollama.py @@ -46,13 +46,13 @@ def test_image_block_construction(pil_image: Image.Image): image_block = ImageBlock(img_str) assert isinstance(image_block, ImageBlock) - assert isinstance(image_block._value, str) + assert isinstance(image_block.value, str) def test_image_block_construction_from_pil(pil_image: Image.Image): image_block = ImageBlock.from_pil_image(pil_image) assert isinstance(image_block, ImageBlock) - assert isinstance(image_block._value, str) + assert isinstance(image_block.value, str) assert ImageBlock.is_valid_base64_png(str(image_block)) @@ -129,7 +129,7 @@ def test_image_block_in_chat( # first image in image list should be the same as the image block image0_str = last_action.images[0] # type: ignore - assert image0_str == ImageBlock.from_pil_image(pil_image)._value + assert image0_str == ImageBlock.from_pil_image(pil_image).value # get prompt message lp = turn.output._generate_log.prompt # type: ignore diff --git a/test/backends/test_vision_openai.py b/test/backends/test_vision_openai.py index 9c958efe..319a94b4 100644 --- a/test/backends/test_vision_openai.py +++ b/test/backends/test_vision_openai.py @@ -55,13 +55,13 @@ def test_image_block_construction(pil_image: Image.Image): image_block = ImageBlock(img_str) assert isinstance(image_block, ImageBlock) - assert isinstance(image_block._value, str) + assert isinstance(image_block.value, str) def test_image_block_construction_from_pil(pil_image: Image.Image): image_block = ImageBlock.from_pil_image(pil_image) assert isinstance(image_block, ImageBlock) - assert isinstance(image_block._value, str) + assert isinstance(image_block.value, str) assert ImageBlock.is_valid_base64_png(str(image_block)) @@ -120,7 +120,7 @@ def test_image_block_in_instruction( assert "url" in image_url # check that the image is in the url content - assert image_block._value[:100] in image_url["url"] + assert image_block.value[:100] in image_url["url"] def test_image_block_in_chat( @@ -144,7 +144,7 @@ def test_image_block_in_chat( # first image in image list should be the same as the image block image0_str = last_action.images[0] # type: ignore - assert image0_str == ImageBlock.from_pil_image(pil_image)._value + assert image0_str == ImageBlock.from_pil_image(pil_image).value # get prompt message lp = turn.output._generate_log.prompt # type: ignore @@ -173,7 +173,7 @@ def test_image_block_in_chat( assert "url" in image_url # check that the image is in the url content - assert ImageBlock.from_pil_image(pil_image)._value[:100] in image_url["url"] + assert ImageBlock.from_pil_image(pil_image).value[:100] in image_url["url"] if __name__ == "__main__":