diff --git a/fastapi_startkit/src/fastapi_startkit/ai/__init__.py b/fastapi_startkit/src/fastapi_startkit/ai/__init__.py index ff4cde47..76c04f99 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/__init__.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/__init__.py @@ -2,12 +2,28 @@ Provides a LangGraph-powered declarative API for building AI agents backed by Anthropic, OpenAI, or Google provider SDKs. + +Also exposes a Laravel-style fluent API for image generation and text-to-speech:: + + from fastapi_startkit.ai import Image, Audio, Document + + image = await Image.of("A donut on a counter").generate() + + # With a photo attachment + doc = await Document.from_url("https://example.com/photo.jpg") + image = await Image.of("Make impressionist").attachments([doc]).generate() + + audio = await Audio.of("Hello world").female().generate() """ from .agent import Agent +from .audio import Audio, AudioResponse +from .audio_factory import AudioFactory from .config import AIConfig, AnthropicConfig, GoogleConfig, OpenAIConfig from .decorators import max_steps, max_tokens, memory, model, provider, timeout, top_p from .document import Document +from .image import Image, ImageResponse +from .image_factory import ImageFactory from .providers.ai_provider import AIProvider from .response import AgentResponse, AgentSnapshot @@ -18,8 +34,14 @@ "AIConfig", "AIProvider", "AnthropicConfig", + "Audio", + "AudioResponse", + "AudioFactory", "Document", "GoogleConfig", + "Image", + "ImageFactory", + "ImageResponse", "OpenAIConfig", "max_steps", "max_tokens", diff --git a/fastapi_startkit/src/fastapi_startkit/ai/audio.py b/fastapi_startkit/src/fastapi_startkit/ai/audio.py new file mode 100644 index 00000000..b709daff --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/audio.py @@ -0,0 +1,214 @@ +"""Audio generation API — text-to-speech via a pluggable provider.""" + +from __future__ import annotations + +import asyncio +import uuid +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from .audio_factory import AudioFactory + +try: + from fastapi_startkit.storage.storage import Storage +except Exception: # pragma: no cover + Storage = None # type: ignore[assignment,misc] + +try: + from fastapi_startkit import Config +except Exception: # pragma: no cover + Config = None # type: ignore[assignment,misc] + + +class AudioResponse: + """Returned by :meth:`Audio.generate`. + + Holds raw MP3 (or other format) bytes and provides async helpers to + persist the audio to any configured storage disk:: + + audio = await Audio.of("Hello world").generate() + + path = await audio.store() # auto-named, private disk + path = await audio.storeAs("greeting.mp3") # named, private disk + path = await audio.storePublicly() # auto-named, public disk + path = await audio.storePubliclyAs("greeting.mp3") + """ + + def __init__(self, data: bytes, fmt: str = "mp3"): + self._data = data + self._fmt = fmt + + @property + def data(self) -> bytes: + """Raw audio bytes.""" + return self._data + + def _auto_filename(self) -> str: + return f"{uuid.uuid4()}.{self._fmt}" + + # ── Storage helpers ──────────────────────────────────────────────────────── + + async def store(self) -> str: + """Save to the default private disk with an auto-generated filename.""" + return await self._save(self._auto_filename(), disk="local") + + async def storeAs(self, name: str) -> str: + """Save to the default private disk with a custom filename.""" + return await self._save(name, disk="local") + + async def storePublicly(self) -> str: + """Save to the public disk with an auto-generated filename.""" + return await self._save(self._auto_filename(), disk="public") + + async def storePubliclyAs(self, name: str) -> str: + """Save to the public disk with a custom filename.""" + return await self._save(name, disk="public") + + # ── Internal ─────────────────────────────────────────────────────────────── + + async def _save(self, name: str, disk: str = "local") -> str: + return await asyncio.to_thread(self._save_sync, name, disk) + + def _save_sync(self, name: str, disk: str) -> str: + """Try the Storage facade first; fall back to a temp file.""" + if Storage is not None: + try: + Storage.disk(disk).put(name, self._data) + return name + except Exception: + pass + import os + import tempfile + + path = os.path.join(tempfile.gettempdir(), name) + with open(path, "wb") as f: + f.write(self._data) + return path + + +class Audio: + """Fluent builder for text-to-speech generation. + + The active backend is selected from :attr:`~fastapi_startkit.ai.AIConfig.audio_provider` + (env: ``AI_AUDIO_PROVIDER``). Defaults to OpenAI TTS. + + Usage:: + + audio = await Audio.of("Hello world").generate() + audio = await Audio.of("Hello world").female().generate() + audio = await Audio.of("Hello world").male().generate() + audio = await Audio.of("Hello world").voice("nova").generate() + + Available OpenAI TTS voices: alloy, echo, fable, onyx, nova, shimmer. + """ + + # OpenAI TTS voice presets + _DEFAULT_VOICE = "alloy" + _DEFAULT_FEMALE_VOICE = "nova" + _DEFAULT_MALE_VOICE = "onyx" + + def __init__(self, text: str): + self._text = text + self._voice: str = self._DEFAULT_VOICE + self._model: str = "tts-1" + self._speed: float = 1.0 + self._response_format: str = "mp3" + + @classmethod + def of(cls, text: str) -> "Audio": + """Create an :class:`Audio` builder with the given input text.""" + return cls(text) + + # ── Modifier methods (chainable) ─────────────────────────────────────────── + + def female(self) -> "Audio": + """Use a female voice (``nova``).""" + self._voice = self._DEFAULT_FEMALE_VOICE + return self + + def male(self) -> "Audio": + """Use a male voice (``onyx``).""" + self._voice = self._DEFAULT_MALE_VOICE + return self + + def voice(self, name: str) -> "Audio": + """Set an explicit TTS voice name. + + OpenAI voices: ``alloy``, ``echo``, ``fable``, ``onyx``, ``nova``, + ``shimmer``. + """ + self._voice = name + return self + + def model(self, name: str) -> "Audio": + """Override the TTS model (default: ``tts-1``). + + Use ``tts-1-hd`` for higher quality at the cost of latency. + """ + self._model = name + return self + + def speed(self, value: float) -> "Audio": + """Set speech speed (0.25 – 4.0, default: 1.0).""" + self._speed = value + return self + + def format(self, fmt: str) -> "Audio": + """Set output format: ``mp3``, ``opus``, ``aac``, or ``flac``.""" + self._response_format = fmt + return self + + # ── Generation ───────────────────────────────────────────────────────────── + + async def generate(self) -> AudioResponse: + """Call the configured TTS provider and return an :class:`AudioResponse`.""" + provider = self._resolve_provider() + data = await provider.synthesize( + text=self._text, + voice=self._voice, + model=self._model, + speed=self._speed, + fmt=self._response_format, + ) + return AudioResponse(data=data, fmt=self._response_format) + + # ── Internal ─────────────────────────────────────────────────────────────── + + def _resolve_provider(self) -> "AudioFactory": + from .audio_factory import ( # noqa: PLC0415 + ElevenLabsAudioFactory, + GoogleAudioFactory, + OpenAIAudioFactory, + ) + + provider_name = "openai" + api_key: Optional[str] = None + base_url: Optional[str] = None + google_key: Optional[str] = None + elevenlabs_key: Optional[str] = None + + try: + ai_config = Config.get("ai") if Config is not None else None # type: ignore[union-attr] + if ai_config is None: + raise RuntimeError("Config not available") + provider_name = ai_config.audio_provider + openai_cfg = ai_config.providers.get("openai") + if openai_cfg: + api_key = openai_cfg.key or None + base_url = openai_cfg.url or None + google_cfg = ai_config.providers.get("google") + if google_cfg: + google_key = google_cfg.key or None + el_cfg = ai_config.providers.get("elevenlabs") + if el_cfg: + elevenlabs_key = el_cfg.key or None + except Exception: + pass + + if provider_name == "openai": + return OpenAIAudioFactory(api_key=api_key, base_url=base_url) + if provider_name == "google": + return GoogleAudioFactory(api_key=google_key) + if provider_name == "elevenlabs": + return ElevenLabsAudioFactory(api_key=elevenlabs_key) + raise ValueError(f"Unknown audio provider: {provider_name!r}. Use 'openai', 'google', or 'elevenlabs'.") diff --git a/fastapi_startkit/src/fastapi_startkit/ai/audio_factory.py b/fastapi_startkit/src/fastapi_startkit/ai/audio_factory.py new file mode 100644 index 00000000..14c817b6 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/audio_factory.py @@ -0,0 +1,260 @@ +"""Audio synthesis provider abstractions. + +Providers implement the :class:`AudioFactory` ABC so that the +:class:`~fastapi_startkit.ai.Audio` builder is not hard-wired to a single +vendor. Select the active provider via ``AI_AUDIO_PROVIDER`` in your +``.env`` (or ``AIConfig.audio_provider``). + +Supported providers +------------------- +* ``openai`` — OpenAI TTS (tts-1 / tts-1-hd) (default) +* ``google`` — Google Gemini TTS via the ``google-genai`` SDK +* ``elevenlabs`` — ElevenLabs TTS via the ``elevenlabs`` SDK +""" + +from __future__ import annotations + +import asyncio +import struct +from abc import ABC, abstractmethod + + +class AudioFactory(ABC): + """Abstract base for text-to-speech backends.""" + + @abstractmethod + async def synthesize( + self, + text: str, + voice: str, + model: str, + speed: float, + fmt: str, + ) -> bytes: + """Convert *text* to speech and return raw audio bytes.""" + + +class OpenAIAudioFactory(AudioFactory): + """OpenAI TTS provider using :class:`openai.AsyncOpenAI`. + + Supported voices: ``alloy``, ``echo``, ``fable``, ``onyx``, ``nova``, + ``shimmer``. Supported formats: ``mp3``, ``opus``, ``aac``, ``flac``. + """ + + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self._api_key = api_key + self._base_url = base_url + + async def synthesize( + self, + text: str, + voice: str, + model: str, + speed: float, + fmt: str, + ) -> bytes: + from openai import AsyncOpenAI # noqa: PLC0415 + + client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url) + response = await client.audio.speech.create( + model=model, + voice=voice, + input=text, + speed=speed, + response_format=fmt, + ) + return response.read() + + +class GoogleAudioFactory(AudioFactory): + """Google Gemini TTS provider via the ``google-genai`` SDK. + + Requires: ``pip install google-genai`` + + Configure via ``.env``:: + + AI_AUDIO_PROVIDER=google + GEMINI_API_KEY=your-key # or GOOGLE_API_KEY + + Default model: ``gemini-2.5-flash-preview-tts``. + + Google voice names (pass via :meth:`~fastapi_startkit.ai.Audio.voice`): + ``Kore``, ``Aoede``, ``Puck``, ``Charon``, ``Fenrir``, ``Leda``, + ``Orus``, ``Zephyr``. + + OpenAI-compatible voice aliases are also accepted and mapped + automatically: + + +----------+-------------+ + | Alias | Google voice| + +==========+=============+ + | nova | Aoede | + | alloy | Kore | + | echo | Charon | + | fable | Puck | + | onyx | Fenrir | + | shimmer | Leda | + +----------+-------------+ + + The provider returns **WAV** bytes regardless of the requested format, + because Gemini TTS yields raw PCM16 which is wrapped in a WAV container. + """ + + # Map OpenAI-style voice aliases → Google Gemini voice names + _VOICE_MAP: dict[str, str] = { + "nova": "Aoede", + "alloy": "Kore", + "echo": "Charon", + "fable": "Puck", + "onyx": "Fenrir", + "shimmer": "Leda", + } + + def __init__(self, api_key: str | None = None): + self._api_key = api_key + + async def synthesize( + self, + text: str, + voice: str, + model: str, + speed: float, + fmt: str, + ) -> bytes: + """Synthesise *text* via Gemini TTS and return WAV bytes. + + .. note:: + The ``speed`` parameter is accepted for API compatibility but is + not currently supported by the Gemini TTS API. + """ + from google import genai # noqa: PLC0415 + from google.genai import types # noqa: PLC0415 + + client = genai.Client(api_key=self._api_key) + google_voice = self._VOICE_MAP.get(voice, voice) + + response = await asyncio.to_thread( + client.models.generate_content, + model=model, + contents=text, + config=types.GenerateContentConfig( + response_modalities=["AUDIO"], + speech_config=types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name=google_voice, + ) + ) + ), + ), + ) + + # Gemini TTS returns raw PCM16 samples — wrap in WAV container + pcm_data = response.candidates[0].content.parts[0].inline_data.data + return _pcm_to_wav(pcm_data) + + +class ElevenLabsAudioFactory(AudioFactory): + """ElevenLabs TTS provider via the ``elevenlabs`` SDK. + + Requires: ``pip install elevenlabs`` + + Configure via ``.env``:: + + AI_AUDIO_PROVIDER=elevenlabs + ELEVENLABS_API_KEY=your-key + + Voice names are mapped to ElevenLabs voice IDs. You can pass any + ElevenLabs voice ID directly via :meth:`~fastapi_startkit.ai.Audio.voice`, + or use one of the built-in aliases: + + +----------+----------------+---------+ + | Alias | EL Name | Gender | + +==========+================+=========+ + | nova | Rachel | female | + | alloy | Bella | female | + | shimmer | Elli | female | + | onyx | Adam | male | + | echo | Antoni | male | + | fable | Arnold | male | + +----------+----------------+---------+ + + Default model: ``eleven_multilingual_v2``. + Override with :meth:`~fastapi_startkit.ai.Audio.model`. + + .. note:: + The ``speed`` parameter is accepted for API compatibility but is + not currently supported by ElevenLabs. + """ + + # Map OpenAI-style aliases → ElevenLabs voice IDs + _VOICE_MAP: dict[str, str] = { + "nova": "21m00Tcm4TlvDq8ikWAM", # Rachel — female + "alloy": "EXAVITQu4vr4xnSDxMaL", # Bella — female + "shimmer": "MF3mGyEYCl7XYWbV9V6O", # Elli — female + "onyx": "pNInz6obpgDQGcFmaJgB", # Adam — male + "echo": "ErXwobaYiN019PkySvjV", # Antoni — male + "fable": "VR6AewLTigWG4xSOukaG", # Arnold — male + } + + def __init__(self, api_key: str | None = None): + self._api_key = api_key + + async def synthesize( + self, + text: str, + voice: str, + model: str, + speed: float, + fmt: str, + ) -> bytes: + from elevenlabs.client import ElevenLabs # noqa: PLC0415 + + voice_id = self._VOICE_MAP.get(voice, voice) + client = ElevenLabs(api_key=self._api_key) + + audio_chunks = await asyncio.to_thread( + client.text_to_speech.convert, + voice_id=voice_id, + text=text, + model_id=model, + output_format="mp3_44100_128", + ) + return b"".join(audio_chunks) + + +# ─── PCM → WAV helper ───────────────────────────────────────────────────────── + + +def _pcm_to_wav( + pcm_data: bytes, + sample_rate: int = 24000, + channels: int = 1, + bit_depth: int = 16, +) -> bytes: + """Wrap raw PCM16 samples in a minimal RIFF/WAV container. + + Gemini TTS returns signed 16-bit little-endian PCM at 24 kHz mono. + Most audio players and APIs accept WAV, so we wrap it here. + """ + data_size = len(pcm_data) + byte_rate = sample_rate * channels * bit_depth // 8 + block_align = channels * bit_depth // 8 + + header = struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", + 36 + data_size, + b"WAVE", + b"fmt ", + 16, + 1, # PCM + channels, + sample_rate, + byte_rate, + block_align, + bit_depth, + b"data", + data_size, + ) + return header + pcm_data diff --git a/fastapi_startkit/src/fastapi_startkit/ai/config.py b/fastapi_startkit/src/fastapi_startkit/ai/config.py index af1a1acf..16a74206 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/config.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/config.py @@ -33,6 +33,14 @@ class GoogleConfig: key: str = field(default_factory=lambda: env("GEMINI_API_KEY", "") or env("GOOGLE_API_KEY", "")) +@dataclass +class ElevenLabsConfig: + """Configuration for the ElevenLabs provider.""" + + driver: str = "elevenlabs" + key: str = field(default_factory=lambda: env("ELEVENLABS_API_KEY", "")) + + @dataclass class AIConfig: """Top-level AI configuration — selects the default provider and holds per-provider configs.""" @@ -44,5 +52,10 @@ class AIConfig: "openai": OpenAIConfig(), "anthropic": AnthropicConfig(), "google": GoogleConfig(), + "elevenlabs": ElevenLabsConfig(), } ) + + # Media-generation provider selection + image_provider: str = field(default_factory=lambda: env("AI_IMAGE_PROVIDER", "openai")) + audio_provider: str = field(default_factory=lambda: env("AI_AUDIO_PROVIDER", "openai")) diff --git a/fastapi_startkit/src/fastapi_startkit/ai/document.py b/fastapi_startkit/src/fastapi_startkit/ai/document.py index f6dffee9..9ac7fdbb 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/document.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/document.py @@ -1,27 +1,123 @@ -"""Document helper — attach files or text to agent prompts.""" +"""Document helper — attach files, images, or text to agent prompts.""" from __future__ import annotations +import asyncio +import base64 + +# Optional runtime dependency — imported at module level so tests can patch it. +try: + from fastapi_startkit.storage.storage import Storage +except Exception: # pragma: no cover + Storage = None # type: ignore[assignment,misc] + class Document: - """Attach documents to agent.prompt() calls.""" + """Attach text or binary content to :meth:`~fastapi_startkit.ai.Agent.prompt` calls. + + Supports both text (for LLM context documents) and binary (for image + attachments sent to :class:`~fastapi_startkit.ai.Image`). + + Text:: + + doc = Document.from_path("report.txt") + agent.prompt("Summarise this", attachments=[doc]) + + Binary image:: + + doc = await Document.from_url("https://example.com/photo.jpg") + image = await Image.of("Make this impressionist").attachments([doc]).generate() + """ - def __init__(self, content: str, name: str = "", media_type: str = "text/plain"): + def __init__(self, content: str | bytes, name: str = "", media_type: str = "text/plain"): self.content = content self.name = name self.media_type = media_type + # ── Sync constructors (text) ─────────────────────────────────────────────── + @classmethod def from_path(cls, path: str) -> "Document": - """Load a document from a local file path.""" - with open(path) as f: - content = f.read() + """Load a document from a local file path. + + Text files are returned with ``str`` content; binary files + (e.g. images) fall back to ``bytes`` automatically. + """ + try: + with open(path) as f: + content: str | bytes = f.read() + except UnicodeDecodeError: + with open(path, "rb") as f: + content = f.read() return cls(content=content, name=path) + # ── Async constructors (binary) ──────────────────────────────────────────── + + @classmethod + async def from_storage(cls, key: str) -> "Document": + """Load a binary file from application storage (``storage/``) asynchronously. + + Falls back to reading directly from the ``storage/`` directory relative + to the current working directory if the Storage facade is not configured. + """ + + def _read() -> bytes: + if Storage is not None: + try: + disk = Storage.disk("local") + # Resolve the full path and read as binary + resolved_path = disk.get_path(key) + with open(resolved_path, "rb") as f: + return f.read() + except Exception: + pass + import os # noqa: PLC0415 + + with open(os.path.join("storage", key), "rb") as f: + return f.read() + + data = await asyncio.to_thread(_read) + return cls(content=data, name=key) + @classmethod - def from_storage(cls, key: str) -> "Document": - """Load a document from application storage (storage/).""" - return cls.from_path(f"storage/{key}") + async def from_url(cls, url: str) -> "Document": + """Download bytes from a URL asynchronously using *httpx*. + + Example:: + + doc = await Document.from_url("https://example.com/photo.jpg") + """ + import httpx # noqa: PLC0415 + + async with httpx.AsyncClient() as client: + response = await client.get(url) + response.raise_for_status() + name = url.rstrip("/").split("/")[-1] + return cls(content=response.content, name=name) + + # ── Binary accessor ──────────────────────────────────────────────────────── + + def to_bytes(self) -> bytes: + """Return the document content as raw bytes. + + If the content was loaded as text (e.g. via :meth:`from_path`), + it is UTF-8 encoded. Binary content is returned as-is. + """ + if isinstance(self.content, bytes): + return self.content + return self.content.encode("utf-8") + + def to_base64(self) -> str: + """Return the document content base64-encoded as a plain string. + + Useful when an API expects a base64-encoded image or audio payload:: + + doc = Document.from_path("/tmp/photo.jpg") + encoded = doc.to_base64() + """ + return base64.b64encode(self.to_bytes()).decode("utf-8") + + # ── LLM content blocks ───────────────────────────────────────────────────── def to_anthropic_block(self) -> dict: """Return an Anthropic-compatible content block for this document.""" diff --git a/fastapi_startkit/src/fastapi_startkit/ai/image.py b/fastapi_startkit/src/fastapi_startkit/ai/image.py new file mode 100644 index 00000000..c6794293 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/image.py @@ -0,0 +1,218 @@ +"""Image generation API — text-to-image and image editing via a pluggable provider.""" + +from __future__ import annotations + +import asyncio +import uuid +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from .document import Document + from .image_factory import ImageFactory + +try: + from fastapi_startkit.storage.storage import Storage +except Exception: # pragma: no cover + Storage = None # type: ignore[assignment,misc] + +try: + from fastapi_startkit import Config +except Exception: # pragma: no cover + Config = None # type: ignore[assignment,misc] + + +class ImageResponse: + """Returned by :meth:`Image.generate`. + + Holds raw PNG bytes and provides async helpers to persist the image to + any configured storage disk:: + + image = await Image.of("A donut on a counter").generate() + + path = await image.store() # auto-named, private disk + path = await image.storeAs("result.png") # named, private disk + path = await image.storePublicly() # auto-named, public disk + path = await image.storePubliclyAs("result.png") + """ + + def __init__(self, data: bytes, fmt: str = "png"): + self._data = data + self._fmt = fmt + + @property + def data(self) -> bytes: + """Raw image bytes.""" + return self._data + + def _auto_filename(self) -> str: + return f"{uuid.uuid4()}.{self._fmt}" + + # ── Storage helpers ──────────────────────────────────────────────────────── + + async def store(self) -> str: + """Save to the default private disk with an auto-generated filename.""" + return await self._save(self._auto_filename(), disk="local") + + async def storeAs(self, name: str) -> str: + """Save to the default private disk with a custom filename.""" + return await self._save(name, disk="local") + + async def storePublicly(self) -> str: + """Save to the public disk with an auto-generated filename.""" + return await self._save(self._auto_filename(), disk="public") + + async def storePubliclyAs(self, name: str) -> str: + """Save to the public disk with a custom filename.""" + return await self._save(name, disk="public") + + # ── Internal ─────────────────────────────────────────────────────────────── + + async def _save(self, name: str, disk: str = "local") -> str: + return await asyncio.to_thread(self._save_sync, name, disk) + + def _save_sync(self, name: str, disk: str) -> str: + """Try the Storage facade first; fall back to a temp file.""" + if Storage is not None: + try: + Storage.disk(disk).put(name, self._data) + return name + except Exception: + pass + import os + import tempfile + + path = os.path.join(tempfile.gettempdir(), name) + with open(path, "wb") as f: + f.write(self._data) + return path + + +class Image: + """Fluent builder for image generation and editing. + + The active backend is selected from :attr:`~fastapi_startkit.ai.AIConfig.image_provider` + (env: ``AI_IMAGE_PROVIDER``). Defaults to OpenAI DALL-E. + + Usage — text to image:: + + image = await Image.of("A donut on a counter").generate() + + Usage — edit with :class:`~fastapi_startkit.ai.Document` attachments:: + + from fastapi_startkit.ai import Document + + image = await ( + Image.of("Make this impressionist") + .attachments([await Document.from_url("https://example.com/photo.jpg")]) + .landscape() + .generate() + ) + """ + + # DALL-E 3 size presets + _LANDSCAPE_SIZE = "1792x1024" + _PORTRAIT_SIZE = "1024x1792" + _SQUARE_SIZE = "1024x1024" + + def __init__(self, prompt: str): + self._prompt = prompt + self._attachments: list[Document] = [] + self._size: str = self._SQUARE_SIZE + self._model: str = "dall-e-3" + self._quality: str = "standard" + + @classmethod + def of(cls, prompt: str) -> "Image": + """Create an :class:`Image` builder with the given prompt.""" + return cls(prompt) + + # ── Modifier methods (chainable) ─────────────────────────────────────────── + + def attachments(self, docs: list) -> "Image": + """Attach :class:`~fastapi_startkit.ai.Document` objects for an editing request.""" + self._attachments = list(docs) + return self + + def landscape(self) -> "Image": + """Use landscape size (1792×1024). DALL-E 3 only.""" + self._size = self._LANDSCAPE_SIZE + return self + + def portrait(self) -> "Image": + """Use portrait size (1024×1792). DALL-E 3 only.""" + self._size = self._PORTRAIT_SIZE + return self + + def square(self) -> "Image": + """Use square size (1024×1024).""" + self._size = self._SQUARE_SIZE + return self + + def model(self, name: str) -> "Image": + """Override the model (default: ``dall-e-3``).""" + self._model = name + return self + + def quality(self, q: str) -> "Image": + """Set quality — ``'standard'`` or ``'hd'`` (DALL-E 3 only).""" + self._quality = q + return self + + # ── Generation ───────────────────────────────────────────────────────────── + + async def generate(self) -> ImageResponse: + """Call the configured image provider and return an :class:`ImageResponse`.""" + provider = self._resolve_provider() + + if self._attachments: + image_bytes = await provider.edit( + prompt=self._prompt, + image_bytes=self._attachments[0].to_bytes(), + size=self._size, + ) + else: + image_bytes = await provider.generate( + prompt=self._prompt, + size=self._size, + model=self._model, + quality=self._quality, + ) + + return ImageResponse(data=image_bytes, fmt="png") + + # ── Internal ─────────────────────────────────────────────────────────────── + + def _resolve_provider(self) -> "ImageFactory": + from .image_factory import ( # noqa: PLC0415 + GoogleImageFactory, + OpenAIImageFactory, + StabilityImageFactory, + ) + + provider_name = "openai" + api_key: Optional[str] = None + base_url: Optional[str] = None + google_key: Optional[str] = None + + try: + ai_config = Config.get("ai") if Config is not None else None # type: ignore[union-attr] + if ai_config is None: + raise RuntimeError("Config not available") + provider_name = ai_config.image_provider + openai_cfg = ai_config.providers.get("openai") + if openai_cfg: + api_key = openai_cfg.key or None + base_url = openai_cfg.url or None + google_cfg = ai_config.providers.get("google") + if google_cfg: + google_key = google_cfg.key or None + except Exception: + pass + + if provider_name == "openai": + return OpenAIImageFactory(api_key=api_key, base_url=base_url) + if provider_name == "google": + return GoogleImageFactory(api_key=google_key) + if provider_name == "stability": + return StabilityImageFactory() + raise ValueError(f"Unknown image provider: {provider_name!r}. Use 'openai', 'google', or 'stability'.") diff --git a/fastapi_startkit/src/fastapi_startkit/ai/image_factory.py b/fastapi_startkit/src/fastapi_startkit/ai/image_factory.py new file mode 100644 index 00000000..b699dafa --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/image_factory.py @@ -0,0 +1,157 @@ +"""Image generation provider abstractions. + +Providers implement the :class:`ImageFactory` ABC so that the +:class:`~fastapi_startkit.ai.Image` builder is not hard-wired to a single +vendor. Select the active provider via ``AI_IMAGE_PROVIDER`` in your +``.env`` (or ``AIConfig.image_provider``). + +Supported providers +------------------- +* ``openai`` — OpenAI DALL-E 3 / DALL-E 2 (default) +* ``google`` — Google Imagen 3 via the ``google-genai`` SDK +* ``stability`` — Stability AI (stub, raises :exc:`NotImplementedError`) +""" + +from __future__ import annotations + +import asyncio +import base64 +from abc import ABC, abstractmethod + + +class ImageFactory(ABC): + """Abstract base for image generation backends.""" + + @abstractmethod + async def generate(self, prompt: str, size: str, model: str, quality: str) -> bytes: + """Generate a new image from a text prompt and return raw PNG bytes.""" + + @abstractmethod + async def edit(self, prompt: str, image_bytes: bytes, size: str) -> bytes: + """Edit an existing image (described by *image_bytes*) and return raw PNG bytes.""" + + +class OpenAIImageFactory(ImageFactory): + """OpenAI DALL-E provider using :class:`openai.AsyncOpenAI`. + + Uses DALL-E 3 for generation and DALL-E 2 for editing (the only model + that supports inpainting as of mid-2025). + """ + + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self._api_key = api_key + self._base_url = base_url + + async def generate(self, prompt: str, size: str, model: str, quality: str) -> bytes: + from openai import AsyncOpenAI # noqa: PLC0415 + + client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url) + params: dict = { + "model": model, + "prompt": prompt, + "size": size, + "n": 1, + "response_format": "b64_json", + } + if model == "dall-e-3": + params["quality"] = quality + + response = await client.images.generate(**params) + return base64.b64decode(response.data[0].b64_json) + + async def edit(self, prompt: str, image_bytes: bytes, size: str) -> bytes: + import io # noqa: PLC0415 + + from openai import AsyncOpenAI # noqa: PLC0415 + + client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url) + image_file = io.BytesIO(image_bytes) + image_file.name = "image.png" + + response = await client.images.edit( + model="dall-e-2", + image=image_file, + prompt=prompt, + size="1024x1024", + n=1, + response_format="b64_json", + ) + return base64.b64decode(response.data[0].b64_json) + + +class GoogleImageFactory(ImageFactory): + """Google Imagen 3 provider via the ``google-genai`` SDK. + + Requires: ``pip install google-genai`` + + Configure via ``.env``:: + + AI_IMAGE_PROVIDER=google + GEMINI_API_KEY=your-key # or GOOGLE_API_KEY + + Default model: ``imagen-3.0-generate-002``. + Override with :meth:`~fastapi_startkit.ai.Image.model`:: + + image = await Image.of("A sunset").model("imagen-3.0-fast-generate-001").generate() + + Size mapping (DALL-E pixel sizes → Imagen aspect ratios): + + +--------------+-----------+ + | Size string | Ratio | + +==============+===========+ + | 1024×1024 | 1:1 | + | 1792×1024 | 16:9 | + | 1024×1792 | 9:16 | + +--------------+-----------+ + """ + + # Map DALL-E-style pixel sizes to Imagen aspect ratios + _ASPECT_MAP: dict[str, str] = { + "1024x1024": "1:1", + "1792x1024": "16:9", + "1024x1792": "9:16", + "1280x720": "16:9", + "720x1280": "9:16", + } + + def __init__(self, api_key: str | None = None): + self._api_key = api_key + + async def generate(self, prompt: str, size: str, model: str, quality: str) -> bytes: + """Generate an image via Imagen 3 and return raw PNG bytes.""" + from google import genai # noqa: PLC0415 + from google.genai import types # noqa: PLC0415 + + client = genai.Client(api_key=self._api_key) + aspect_ratio = self._ASPECT_MAP.get(size, "1:1") + + response = await asyncio.to_thread( + client.models.generate_images, + model=model, + prompt=prompt, + config=types.GenerateImagesConfig( + number_of_images=1, + aspect_ratio=aspect_ratio, + ), + ) + return response.generated_images[0].image.image_bytes + + async def edit(self, prompt: str, image_bytes: bytes, size: str) -> bytes: + """Image editing is not yet supported by :class:`GoogleImageFactory`. + + Use :class:`OpenAIImageFactory` for editing workflows. + """ + raise NotImplementedError( + "GoogleImageFactory does not support image editing yet. " + "Use OpenAIImageFactory (AI_IMAGE_PROVIDER=openai) for editing." + ) + + +class StabilityImageFactory(ImageFactory): + """Stability AI provider stub — raises :exc:`NotImplementedError` until implemented.""" + + async def generate(self, prompt: str, size: str, model: str, quality: str) -> bytes: + raise NotImplementedError("StabilityImageFactory is not yet implemented") + + async def edit(self, prompt: str, image_bytes: bytes, size: str) -> bytes: + raise NotImplementedError("StabilityImageFactory is not yet implemented") diff --git a/fastapi_startkit/tests/ai/test_audio.py b/fastapi_startkit/tests/ai/test_audio.py new file mode 100644 index 00000000..0087de9f --- /dev/null +++ b/fastapi_startkit/tests/ai/test_audio.py @@ -0,0 +1,344 @@ +"""Tests for the Audio generation API (Audio, AudioResponse).""" + +from __future__ import annotations + +import os +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi_startkit.ai import Audio, AudioResponse + + +# ─── Shared helpers ─────────────────────────────────────────────────────────── + + +def _fake_audio_bytes() -> bytes: + return b"ID3\x03\x00" # minimal MP3 magic + + +def _mock_provider(result: bytes | None = None) -> MagicMock: + """Return a mock AudioFactory.""" + p = MagicMock() + p.synthesize = AsyncMock(return_value=result if result is not None else _fake_audio_bytes()) + return p + + +# ─── Audio builder — chainable API ──────────────────────────────────────────── + + +class TestAudioBuilder(IsolatedAsyncioTestCase): + def test_of_returns_audio_instance(self): + audio = Audio.of("Hello world") + assert isinstance(audio, Audio) + assert audio._text == "Hello world" + + def test_default_voice_is_alloy(self): + assert Audio.of("Hello")._voice == "alloy" + + def test_female_sets_nova_voice(self): + assert Audio.of("Hello").female()._voice == "nova" + + def test_male_sets_onyx_voice(self): + assert Audio.of("Hello").male()._voice == "onyx" + + def test_voice_sets_explicit_voice(self): + assert Audio.of("Hello").voice("shimmer")._voice == "shimmer" + + def test_voice_overrides_previous_setting(self): + assert Audio.of("Hello").female().voice("echo")._voice == "echo" + + def test_model_override(self): + assert Audio.of("Hello").model("tts-1-hd")._model == "tts-1-hd" + + def test_speed_override(self): + assert Audio.of("Hello").speed(1.5)._speed == 1.5 + + def test_format_override(self): + assert Audio.of("Hello").format("opus")._response_format == "opus" + + def test_chainable_methods_return_self(self): + audio = Audio.of("Hello") + assert audio.female() is audio + assert audio.male() is audio + assert audio.voice("alloy") is audio + assert audio.model("tts-1") is audio + assert audio.speed(1.0) is audio + assert audio.format("mp3") is audio + + +# ─── Audio.generate() ───────────────────────────────────────────────────────── + + +class TestAudioGeneration(IsolatedAsyncioTestCase): + async def test_generate_calls_provider_and_returns_response(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + result = await Audio.of("Hello world").generate() + + assert isinstance(result, AudioResponse) + assert result.data == _fake_audio_bytes() + provider.synthesize.assert_called_once() + + async def test_generate_passes_text_to_provider(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hello world").generate() + + assert provider.synthesize.call_args[1]["text"] == "Hello world" + + async def test_generate_female_passes_nova_voice(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").female().generate() + + assert provider.synthesize.call_args[1]["voice"] == "nova" + + async def test_generate_male_passes_onyx_voice(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").male().generate() + + assert provider.synthesize.call_args[1]["voice"] == "onyx" + + async def test_generate_explicit_voice(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").voice("shimmer").generate() + + assert provider.synthesize.call_args[1]["voice"] == "shimmer" + + async def test_generate_passes_speed(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").speed(1.25).generate() + + assert provider.synthesize.call_args[1]["speed"] == 1.25 + + async def test_generate_passes_format(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").format("opus").generate() + + assert provider.synthesize.call_args[1]["fmt"] == "opus" + + async def test_generate_hd_model(self): + provider = _mock_provider() + + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").model("tts-1-hd").generate() + + assert provider.synthesize.call_args[1]["model"] == "tts-1-hd" + + +# ─── AudioResponse storage methods ──────────────────────────────────────────── + + +class TestAudioResult(IsolatedAsyncioTestCase): + async def test_store_writes_to_temp_when_no_storage(self): + resp = AudioResponse(data=_fake_audio_bytes()) + + path = await resp.store() + + assert os.path.exists(path) + with open(path, "rb") as f: + assert f.read() == _fake_audio_bytes() + os.remove(path) + + async def test_store_as_uses_given_name(self): + resp = AudioResponse(data=_fake_audio_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/greeting.mp3" + await resp.storeAs("greeting.mp3") + + mock_save.assert_called_once_with("greeting.mp3", "local") + + async def test_store_publicly_as_uses_public_disk(self): + resp = AudioResponse(data=_fake_audio_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/greeting.mp3" + await resp.storePubliclyAs("greeting.mp3") + + mock_save.assert_called_once_with("greeting.mp3", "public") + + async def test_store_publicly_uses_public_disk(self): + resp = AudioResponse(data=_fake_audio_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.mp3" + await resp.storePublicly() + + _, disk = mock_save.call_args[0] + assert disk == "public" + + async def test_store_auto_filename_has_mp3_ext(self): + resp = AudioResponse(data=_fake_audio_bytes(), fmt="mp3") + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.mp3" + await resp.store() + + name, _ = mock_save.call_args[0] + assert name.endswith(".mp3") + + async def test_store_uses_storage_facade_when_available(self): + resp = AudioResponse(data=_fake_audio_bytes()) + mock_disk = MagicMock() + + with patch("fastapi_startkit.ai.audio.Storage") as mock_storage_cls: + mock_storage_cls.disk.return_value = mock_disk + await resp.storeAs("hello.mp3") + + mock_storage_cls.disk.assert_called_once_with("local") + mock_disk.put.assert_called_once_with("hello.mp3", _fake_audio_bytes()) + + +# ─── GoogleAudioFactory ─────────────────────────────────────────────────────── + + +class TestGoogleAudioFactory(IsolatedAsyncioTestCase): + async def test_synthesize_returns_wav_bytes(self): + from fastapi_startkit.ai.audio_factory import GoogleAudioFactory + + fake_pcm = b"\x00\x01" * 100 # 200 bytes of fake PCM16 + + mock_part = MagicMock() + mock_part.inline_data.data = fake_pcm + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + provider = GoogleAudioFactory(api_key="test-key") + + with patch.dict( + "sys.modules", + {"google": MagicMock(), "google.genai": MagicMock(), "google.genai.types": MagicMock()}, + ): + with patch( + "fastapi_startkit.ai.audio_factory.asyncio.to_thread", + new=AsyncMock(return_value=mock_response), + ): + result = await provider.synthesize("Hello world", "nova", "gemini-2.5-flash-preview-tts", 1.0, "mp3") + + assert result[:4] == b"RIFF" + assert result[8:12] == b"WAVE" + + def test_voice_map_covers_openai_aliases(self): + from fastapi_startkit.ai.audio_factory import GoogleAudioFactory + + provider = GoogleAudioFactory() + assert provider._VOICE_MAP["nova"] == "Aoede" + assert provider._VOICE_MAP["onyx"] == "Fenrir" + assert provider._VOICE_MAP["alloy"] == "Kore" + + def test_unknown_voice_passed_through(self): + from fastapi_startkit.ai.audio_factory import GoogleAudioFactory + + provider = GoogleAudioFactory() + assert provider._VOICE_MAP.get("Zephyr", "Zephyr") == "Zephyr" + + async def test_audio_builder_resolves_google_factory(self): + mock_ai_config = MagicMock() + mock_ai_config.audio_provider = "google" + mock_ai_config.providers = { + "google": MagicMock(key="gkey"), + "openai": MagicMock(key=""), + "elevenlabs": MagicMock(key=""), + } + + with patch("fastapi_startkit.ai.audio.Config") as mock_config: + mock_config.get.return_value = mock_ai_config + from fastapi_startkit.ai.audio_factory import GoogleAudioFactory + + provider = Audio.of("test")._resolve_provider() + + assert isinstance(provider, GoogleAudioFactory) + + +# ─── _pcm_to_wav helper ─────────────────────────────────────────────────────── + + +class TestPcmToWav(IsolatedAsyncioTestCase): + def test_wav_header_structure(self): + from fastapi_startkit.ai.audio_factory import _pcm_to_wav + + pcm = b"\x00\x00" * 24000 # 1 second of silence at 24 kHz mono 16-bit + wav = _pcm_to_wav(pcm) + + assert wav[:4] == b"RIFF" + assert wav[8:12] == b"WAVE" + assert wav[12:16] == b"fmt " + assert wav[36:40] == b"data" + + def test_wav_size_matches_pcm(self): + from fastapi_startkit.ai.audio_factory import _pcm_to_wav + + pcm = b"\x01\x02" * 100 + wav = _pcm_to_wav(pcm) + assert len(wav) == 44 + len(pcm) + + +# ─── ElevenLabsAudioFactory ─────────────────────────────────────────────────── + + +class TestElevenLabsAudioFactory(IsolatedAsyncioTestCase): + async def test_synthesize_joins_audio_chunks(self): + from fastapi_startkit.ai.audio_factory import ElevenLabsAudioFactory + + fake_chunks = [b"chunk1", b"chunk2", b"chunk3"] + mock_elevenlabs_module = MagicMock() + provider = ElevenLabsAudioFactory(api_key="test-key") + + with patch.dict( + "sys.modules", + {"elevenlabs": mock_elevenlabs_module, "elevenlabs.client": mock_elevenlabs_module}, + ): + with patch( + "fastapi_startkit.ai.audio_factory.asyncio.to_thread", + new=AsyncMock(return_value=iter(fake_chunks)), + ): + result = await provider.synthesize("Hello", "nova", "eleven_multilingual_v2", 1.0, "mp3") + + assert result == b"chunk1chunk2chunk3" + + def test_voice_map_covers_openai_aliases(self): + from fastapi_startkit.ai.audio_factory import ElevenLabsAudioFactory + + provider = ElevenLabsAudioFactory() + assert "nova" in provider._VOICE_MAP + assert "onyx" in provider._VOICE_MAP + assert "echo" in provider._VOICE_MAP + + def test_direct_voice_id_passed_through(self): + from fastapi_startkit.ai.audio_factory import ElevenLabsAudioFactory + + provider = ElevenLabsAudioFactory() + direct_id = "some-custom-voice-id" + assert provider._VOICE_MAP.get(direct_id, direct_id) == direct_id + + async def test_audio_builder_resolves_elevenlabs_factory(self): + mock_ai_config = MagicMock() + mock_ai_config.audio_provider = "elevenlabs" + mock_ai_config.providers = { + "google": MagicMock(key=""), + "openai": MagicMock(key=""), + "elevenlabs": MagicMock(key="elkey"), + } + + with patch("fastapi_startkit.ai.audio.Config") as mock_config: + mock_config.get.return_value = mock_ai_config + from fastapi_startkit.ai.audio_factory import ElevenLabsAudioFactory + + provider = Audio.of("test")._resolve_provider() + + assert isinstance(provider, ElevenLabsAudioFactory) + assert provider._api_key == "elkey" diff --git a/fastapi_startkit/tests/ai/test_image.py b/fastapi_startkit/tests/ai/test_image.py new file mode 100644 index 00000000..4c98c3df --- /dev/null +++ b/fastapi_startkit/tests/ai/test_image.py @@ -0,0 +1,295 @@ +"""Tests for the Image generation API (Image, ImageResponse, Document attachments).""" + +from __future__ import annotations + +import os +import tempfile +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi_startkit.ai import Document, Image, ImageResponse + + +# ─── Shared helpers ─────────────────────────────────────────────────────────── + + +def _fake_image_bytes() -> bytes: + return b"\x89PNG\r\n\x1a\n" # minimal PNG magic + + +def _mock_provider(generate_result: bytes | None = None, edit_result: bytes | None = None) -> MagicMock: + """Return a mock ImageFactory.""" + p = MagicMock() + p.generate = AsyncMock(return_value=generate_result if generate_result is not None else _fake_image_bytes()) + p.edit = AsyncMock(return_value=edit_result if edit_result is not None else _fake_image_bytes()) + return p + + +# ─── Document used as image attachment ──────────────────────────────────────── + + +class TestDocumentImageAttachment(IsolatedAsyncioTestCase): + def test_document_from_path_reads_binary_via_to_bytes(self): + """from_path auto-detects binary files and stores bytes content.""" + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff") + tmp_file = f.name + try: + doc = Document.from_path(tmp_file) + assert doc.to_bytes() == b"\xff\xd8\xff" + finally: + os.remove(tmp_file) + + def test_document_content_bytes_stored_directly(self): + doc = Document(content=b"\x89PNG", name="photo.png") + assert doc.to_bytes() == b"\x89PNG" + + def test_document_content_str_encoded_to_bytes(self): + doc = Document(content="hello", name="text.txt") + assert doc.to_bytes() == b"hello" + + async def test_document_from_url_downloads_bytes(self): + fake_data = b"fake-image-bytes" + + with patch("httpx.AsyncClient") as MockClient: + mock_response = MagicMock() + mock_response.content = fake_data + mock_response.raise_for_status = MagicMock() + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.get = AsyncMock(return_value=mock_response) + + doc = await Document.from_url("https://example.com/photo.jpg") + + assert doc.to_bytes() == fake_data + assert doc.name == "photo.jpg" + + async def test_document_from_storage_reads_bytes(self): + with tempfile.TemporaryDirectory() as tmp: + storage_dir = os.path.join(tmp, "storage") + os.makedirs(storage_dir) + with open(os.path.join(storage_dir, "photo.jpg"), "wb") as f: + f.write(b"\x89PNG") + + orig_dir = os.getcwd() + os.chdir(tmp) + try: + with patch("fastapi_startkit.ai.document.Storage", None): + doc = await Document.from_storage("photo.jpg") + finally: + os.chdir(orig_dir) + + assert doc.to_bytes() == b"\x89PNG" + + def test_document_to_base64_encodes_bytes(self): + import base64 + + doc = Document(content=b"\x89PNG", name="photo.png") + assert doc.to_base64() == base64.b64encode(b"\x89PNG").decode("utf-8") + + def test_document_to_base64_encodes_text(self): + import base64 + + doc = Document(content="hello", name="text.txt") + assert doc.to_base64() == base64.b64encode(b"hello").decode("utf-8") + + +# ─── Image builder — chainable API ──────────────────────────────────────────── + + +class TestImageBuilder(IsolatedAsyncioTestCase): + def test_of_returns_image_instance(self): + img = Image.of("A donut on a counter") + assert isinstance(img, Image) + assert img._prompt == "A donut on a counter" + + def test_landscape_sets_size(self): + assert Image.of("test").landscape()._size == "1792x1024" + + def test_portrait_sets_size(self): + assert Image.of("test").portrait()._size == "1024x1792" + + def test_square_sets_size(self): + assert Image.of("test").landscape().square()._size == "1024x1024" + + def test_model_override(self): + assert Image.of("test").model("dall-e-2")._model == "dall-e-2" + + def test_quality_override(self): + assert Image.of("test").quality("hd")._quality == "hd" + + def test_attachments_sets_list(self): + doc = Document(content=b"img", name="x.png") + assert Image.of("test").attachments([doc])._attachments == [doc] + + +# ─── Image.generate() ───────────────────────────────────────────────────────── + + +class TestImageGeneration(IsolatedAsyncioTestCase): + async def test_generate_calls_provider_and_returns_response(self): + provider = _mock_provider() + + with patch.object(Image, "_resolve_provider", return_value=provider): + result = await Image.of("A donut on a counter").generate() + + assert isinstance(result, ImageResponse) + assert result.data == _fake_image_bytes() + provider.generate.assert_called_once() + + async def test_generate_passes_landscape_size_to_provider(self): + provider = _mock_provider() + + with patch.object(Image, "_resolve_provider", return_value=provider): + await Image.of("test").landscape().generate() + + assert provider.generate.call_args[1]["size"] == "1792x1024" + + async def test_generate_passes_quality_to_provider(self): + provider = _mock_provider() + + with patch.object(Image, "_resolve_provider", return_value=provider): + await Image.of("test").quality("hd").generate() + + assert provider.generate.call_args[1]["quality"] == "hd" + + async def test_generate_uses_edit_when_attachments_present(self): + provider = _mock_provider() + doc = Document(content=b"img-bytes", name="photo.png") + + with patch.object(Image, "_resolve_provider", return_value=provider): + result = await Image.of("Make impressionist").attachments([doc]).generate() + + assert isinstance(result, ImageResponse) + provider.edit.assert_called_once() + provider.generate.assert_not_called() + + async def test_generate_passes_attachment_bytes_to_edit(self): + provider = _mock_provider() + doc = Document(content=b"raw-image-bytes", name="photo.png") + + with patch.object(Image, "_resolve_provider", return_value=provider): + await Image.of("Make impressionist").attachments([doc]).generate() + + assert provider.edit.call_args[1]["image_bytes"] == b"raw-image-bytes" + + +# ─── ImageResponse storage methods ──────────────────────────────────────────── + + +class TestImageResult(IsolatedAsyncioTestCase): + async def test_store_writes_to_temp_when_no_storage(self): + """Falls back to tempfile when Storage facade is unavailable.""" + resp = ImageResponse(data=_fake_image_bytes()) + + path = await resp.store() + + assert os.path.exists(path) + with open(path, "rb") as f: + assert f.read() == _fake_image_bytes() + os.remove(path) + + async def test_store_as_uses_given_name(self): + resp = ImageResponse(data=_fake_image_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/result.png" + path = await resp.storeAs("result.png") + + mock_save.assert_called_once_with("result.png", "local") + assert path.endswith("result.png") + + async def test_store_publicly_as_uses_public_disk(self): + resp = ImageResponse(data=_fake_image_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/result.png" + await resp.storePubliclyAs("result.png") + + mock_save.assert_called_once_with("result.png", "public") + + async def test_store_publicly_uses_public_disk(self): + resp = ImageResponse(data=_fake_image_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.png" + await resp.storePublicly() + + _, disk = mock_save.call_args[0] + assert disk == "public" + + async def test_store_auto_filename_has_png_ext(self): + resp = ImageResponse(data=_fake_image_bytes(), fmt="png") + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.png" + await resp.store() + + name, _ = mock_save.call_args[0] + assert name.endswith(".png") + + async def test_store_uses_storage_facade_when_available(self): + resp = ImageResponse(data=_fake_image_bytes()) + mock_disk = MagicMock() + + with patch("fastapi_startkit.ai.image.Storage") as mock_storage_cls: + mock_storage_cls.disk.return_value = mock_disk + await resp.storeAs("photo.png") + + mock_storage_cls.disk.assert_called_once_with("local") + mock_disk.put.assert_called_once_with("photo.png", _fake_image_bytes()) + + +# ─── GoogleImageFactory ─────────────────────────────────────────────────────── + + +class TestGoogleImageFactory(IsolatedAsyncioTestCase): + async def test_generate_calls_genai_and_returns_bytes(self): + from fastapi_startkit.ai.image_factory import GoogleImageFactory + + fake_bytes = b"\x89PNG\r\n" + mock_image = MagicMock() + mock_image.image.image_bytes = fake_bytes + mock_response = MagicMock() + mock_response.generated_images = [mock_image] + + with patch.dict( + "sys.modules", + {"google": MagicMock(), "google.genai": MagicMock(), "google.genai.types": MagicMock()}, + ): + with patch( + "fastapi_startkit.ai.image_factory.asyncio.to_thread", + new=AsyncMock(return_value=mock_response), + ): + provider = GoogleImageFactory(api_key="test-key") + result = await provider.generate("A sunset", "1024x1024", "imagen-3.0-generate-002", "standard") + + assert result == fake_bytes + + def test_aspect_ratio_mapping(self): + from fastapi_startkit.ai.image_factory import GoogleImageFactory + + provider = GoogleImageFactory() + assert provider._ASPECT_MAP["1024x1024"] == "1:1" + assert provider._ASPECT_MAP["1792x1024"] == "16:9" + assert provider._ASPECT_MAP["1024x1792"] == "9:16" + + async def test_edit_raises_not_implemented(self): + from fastapi_startkit.ai.image_factory import GoogleImageFactory + + provider = GoogleImageFactory(api_key="test-key") + with self.assertRaisesRegex(NotImplementedError, "does not support image editing"): + await provider.edit("Make it blue", b"\x89PNG", "1024x1024") + + async def test_image_builder_resolves_google_factory(self): + mock_ai_config = MagicMock() + mock_ai_config.image_provider = "google" + mock_ai_config.providers = {"google": MagicMock(key="gkey"), "openai": MagicMock(key="")} + + with patch("fastapi_startkit.ai.image.Config") as mock_config: + mock_config.get.return_value = mock_ai_config + from fastapi_startkit.ai.image_factory import GoogleImageFactory + + provider = Image.of("test")._resolve_provider() + + assert isinstance(provider, GoogleImageFactory)