Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions docs/examples/image_text_models/vision_litellm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@
print(str(ch.content))

# test with PIL image
res = m.instruct(
res_instruct = m.instruct(
"Is there a person on the image? Is the subject in the image smiling?",
images=[test_pil],
)
print(str(res))
print(f"Test with PIL and instruct: \n{str(res_instruct)}\n-----")
# print(m.last_prompt())

# with PIL image and using m.chat
res = m.chat("How many eyes can you identify in the image? Explain.", images=[test_pil])
print(str(res.content))
res_chat = m.chat(
"How many eyes can you identify in the image? Explain.", images=[test_pil]
)
print(f"Test with PIL and chat: \n{str(res_chat.content)}\n-----")

# and now without images again...
res = m.instruct("How many eyes can you identify in the image?", images=[])
print(str(res))
res_empty = m.instruct("How many eyes can you identify in the image?", images=[])
print(f"Test without image: \n{str(res_empty)}\n-----")
11 changes: 3 additions & 8 deletions mellea/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ def __repr__(self):
return f"CBlock({self.value}, {self._meta.__repr__()})"


class ImageBlock:
class ImageBlock(CBlock):
"""An `ImageBlock` represents an image (as a base64-encoded PNG)."""

def __init__(self, value: str, meta: dict[str, Any] | None = None):
"""Initializes the ImageBlock with a base64 PNG string representation and some metadata."""
assert self.is_valid_base64_png(value), (
"Invalid base64 string representation of image."
)
self._value = value
self._meta = {} if meta is None else meta
super().__init__(value, meta)

@staticmethod
def is_valid_base64_png(s: str) -> bool:
Expand Down Expand Up @@ -117,13 +116,9 @@ def from_pil_image(
image_base64 = cls.pil_to_base64(image)
return cls(image_base64, meta)

def __str__(self):
"""Stringifies the block."""
return self._value

def __repr__(self):
"""Provides a Python-parsable representation of the block (usually)."""
return f"ImageBlock({self._value}, {self._meta.__repr__()})"
return f"ImageBlock({self.value}, {self._meta.__repr__()})"


S = typing_extensions.TypeVar("S", default=Any, covariant=True)
Expand Down
15 changes: 5 additions & 10 deletions mellea/stdlib/components/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,22 @@ def __init__(
self.content = content # TODO this should be private.
self._content_cblock = CBlock(self.content)
self._images = images
# TODO this should replace _images.
self._images_cblocks: list[CBlock] | None = None
if self._images is not None:
self._images_cblocks = [CBlock(str(i)) for i in self._images]
self._docs = documents

@property
def images(self) -> None | list[str]:
"""Returns the images associated with this message as a list of base64 strings."""
if self._images_cblocks is not None:
return [str(i.value) for i in self._images_cblocks]
if self._images is not None:
return [str(i) for i in self._images]
return None

def parts(self) -> list[Component | CBlock]:
"""Returns all of the constituent parts of this message."""
parts: list[Component | CBlock] = [self._content_cblock]
if self._docs is not None:
parts.extend(self._docs)
# TODO: we need to do this but images are not currently cblocks. This is captured in an issue on Jan 26 sprint. Leaving this code commented out for now.
# if self._images is not None:
# parts.extend(self._images)
if self._images is not None:
parts.extend(self._images)
return parts

def format_for_llm(self) -> TemplateRepresentation:
Expand All @@ -78,7 +73,7 @@ def format_for_llm(self) -> TemplateRepresentation:
args={
"role": self.role,
"content": self._content_cblock,
"images": self._images_cblocks,
"images": self._images,
"documents": self._docs,
},
template_order=["*", "Message"],
Expand Down
6 changes: 3 additions & 3 deletions test/backends/test_vision_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ def test_image_block_construction(pil_image: Image.Image):

image_block = ImageBlock(img_str)
assert isinstance(image_block, ImageBlock)
assert isinstance(image_block._value, str)
assert isinstance(image_block.value, str)


def test_image_block_construction_from_pil(pil_image: Image.Image):
image_block = ImageBlock.from_pil_image(pil_image)
assert isinstance(image_block, ImageBlock)
assert isinstance(image_block._value, str)
assert isinstance(image_block.value, str)
assert ImageBlock.is_valid_base64_png(str(image_block))


Expand Down Expand Up @@ -129,7 +129,7 @@ def test_image_block_in_chat(

# first image in image list should be the same as the image block
image0_str = last_action.images[0] # type: ignore
assert image0_str == ImageBlock.from_pil_image(pil_image)._value
assert image0_str == ImageBlock.from_pil_image(pil_image).value

# get prompt message
lp = turn.output._generate_log.prompt # type: ignore
Expand Down
10 changes: 5 additions & 5 deletions test/backends/test_vision_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def test_image_block_construction(pil_image: Image.Image):

image_block = ImageBlock(img_str)
assert isinstance(image_block, ImageBlock)
assert isinstance(image_block._value, str)
assert isinstance(image_block.value, str)


def test_image_block_construction_from_pil(pil_image: Image.Image):
image_block = ImageBlock.from_pil_image(pil_image)
assert isinstance(image_block, ImageBlock)
assert isinstance(image_block._value, str)
assert isinstance(image_block.value, str)
assert ImageBlock.is_valid_base64_png(str(image_block))


Expand Down Expand Up @@ -120,7 +120,7 @@ def test_image_block_in_instruction(
assert "url" in image_url

# check that the image is in the url content
assert image_block._value[:100] in image_url["url"]
assert image_block.value[:100] in image_url["url"]


def test_image_block_in_chat(
Expand All @@ -144,7 +144,7 @@ def test_image_block_in_chat(

# first image in image list should be the same as the image block
image0_str = last_action.images[0] # type: ignore
assert image0_str == ImageBlock.from_pil_image(pil_image)._value
assert image0_str == ImageBlock.from_pil_image(pil_image).value

# get prompt message
lp = turn.output._generate_log.prompt # type: ignore
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_image_block_in_chat(
assert "url" in image_url

# check that the image is in the url content
assert ImageBlock.from_pil_image(pil_image)._value[:100] in image_url["url"]
assert ImageBlock.from_pil_image(pil_image).value[:100] in image_url["url"]


if __name__ == "__main__":
Expand Down