import os
import time
import requests
import bittensor as bt
from typing import Dict, Any, Optional

from .base_service import BaseGenerationService
from ..task_manager import GenerationTask


class Models:
    """Symbolic constants for Replicate models."""

    FLUX_SCHNELL = "flux-schnell"
    FLUX_DEV = "flux-dev"
    FLUX_PRO = "flux-pro"
    SDXL = "sdxl"


# Per-model metadata:
#   version - Replicate model slug ("owner/name") or pinned "owner/name:hash"
#   name    - human-readable label
#   family  - model family
#   supports_negative_prompt - the FLUX family does not accept a negative prompt
MODEL_INFO = {
    Models.FLUX_SCHNELL: {
        "version": "black-forest-labs/flux-schnell",
        "name": "FLUX.1 Schnell",
        "family": "flux",
        "supports_negative_prompt": False,
    },
    Models.FLUX_DEV: {
        "version": "black-forest-labs/flux-dev",
        "name": "FLUX.1 Dev",
        "family": "flux",
        "supports_negative_prompt": False,
    },
    Models.FLUX_PRO: {
        "version": "black-forest-labs/flux-pro",
        "name": "FLUX.1 Pro",
        "family": "flux",
        "supports_negative_prompt": False,
    },
    Models.SDXL: {
        "version": "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
        "name": "Stable Diffusion XL",
        "family": "sdxl",
        "supports_negative_prompt": True,
    },
}


class ReplicateService(BaseGenerationService):
    """
    Replicate API service for FLUX and SDXL image generation.

    Features:
    - FLUX.1 models (Schnell, Dev, Pro)
    - SDXL support
    - Async prediction with polling
    - Multiple aspect ratio support
    """

    API_BASE = "https://api.replicate.com/v1"
    POLL_INTERVAL = 1.0   # seconds between status polls
    MAX_POLL_TIME = 120   # seconds before a pending prediction is abandoned

    def __init__(self, config: Any = None):
        """Read REPLICATE_API_TOKEN from the environment; never raises."""
        super().__init__(config)

        self.api_key = os.getenv("REPLICATE_API_TOKEN")
        self.timeout = 30  # per-request HTTP timeout, seconds
        self.default_model = Models.FLUX_SCHNELL

        if not self.api_key:
            bt.logging.warning("REPLICATE_API_TOKEN not found.")
        else:
            bt.logging.info("ReplicateService initialized with API token")

    # ---------------------------------------------------------------------
    # Base methods
    # ---------------------------------------------------------------------
    def is_available(self) -> bool:
        """True when a non-empty API token is configured."""
        return self.api_key is not None and self.api_key.strip() != ""

    def supports_modality(self, modality: str) -> bool:
        """Only image generation is supported by this service."""
        return modality == "image"

    def get_supported_tasks(self) -> Dict[str, list]:
        """Map each modality to the task types this service can run."""
        return {
            "image": ["image_generation"],
            "video": []
        }

    def get_api_key_requirements(self) -> Dict[str, str]:
        """Describe the environment variables required by this service."""
        return {"REPLICATE_API_TOKEN": "API token for Replicate image generation"}

    # ---------------------------------------------------------------------
    # Processing logic
    # ---------------------------------------------------------------------
    def process(self, task: GenerationTask) -> Dict[str, Any]:
        """
        Execute a generation task and return {"data": bytes, "metadata": dict}.

        Raises:
            ValueError: if the task modality is not "image".
            RuntimeError: if no API token is configured, or the Replicate API
                fails, times out, or returns no output.
        """
        if task.modality != "image":
            raise ValueError(f"ReplicateService does not support modality: {task.modality}")

        # Fail fast with a clear message instead of sending
        # "Authorization: Bearer None" and surfacing an opaque 401.
        if not self.is_available():
            raise RuntimeError("REPLICATE_API_TOKEN is not configured; ReplicateService is unavailable")

        return self._generate_image(task)

    # ---------------------------------------------------------------------
    # API helpers
    # ---------------------------------------------------------------------
    def _create_prediction(self, model_version: str, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a prediction via POST /predictions and return the parsed JSON.

        Official models ("owner/name" without a pinned ":hash") are sent via
        the "model" field; pinned versions via the "version" field.

        Raises:
            RuntimeError: on any non-201 response.
        """
        url = f"{self.API_BASE}/predictions"

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # NOTE(review): Replicate's documented endpoint for official models is
        # POST /v1/models/{owner}/{name}/predictions; confirm that
        # /v1/predictions accepts a "model" field before relying on this path.
        if "/" in model_version and ":" not in model_version:
            payload = {"model": model_version, "input": input_data}
        else:
            version_id = model_version.split(":")[-1] if ":" in model_version else model_version
            payload = {"version": version_id, "input": input_data}

        response = requests.post(url, headers=headers, json=payload, timeout=self.timeout)

        if response.status_code == 201:
            return response.json()
        else:
            raise RuntimeError(f"Replicate API error {response.status_code}: {response.text}")

    def _poll_prediction(self, prediction_url: str) -> Dict[str, Any]:
        """
        Poll a prediction URL until it succeeds, fails, is canceled,
        or MAX_POLL_TIME elapses.

        Returns:
            The final prediction JSON on success.

        Raises:
            RuntimeError: on poll errors, failure, cancellation, or timeout.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        start_time = time.time()

        while time.time() - start_time < self.MAX_POLL_TIME:
            response = requests.get(prediction_url, headers=headers, timeout=self.timeout)

            if response.status_code != 200:
                raise RuntimeError(f"Replicate poll error {response.status_code}: {response.text}")

            result = response.json()
            status = result.get("status")

            if status == "succeeded":
                return result
            elif status == "failed":
                error = result.get("error", "Unknown error")
                raise RuntimeError(f"Replicate prediction failed: {error}")
            elif status == "canceled":
                raise RuntimeError("Replicate prediction was canceled")

            time.sleep(self.POLL_INTERVAL)

        raise RuntimeError(f"Replicate prediction timed out after {self.MAX_POLL_TIME}s")

    def _download_image(self, url: str) -> bytes:
        """Fetch the generated image bytes; raises on HTTP errors."""
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        return response.content

    # ---------------------------------------------------------------------
    # Image generation core
    # ---------------------------------------------------------------------
    def _generate_image(self, task: GenerationTask) -> Dict[str, Any]:
        """
        Run the full create -> poll -> download pipeline for one image task.

        Optional task parameters forwarded to the model: negative_prompt
        (SDXL only), width, height, num_inference_steps, guidance_scale,
        seed, aspect_ratio.
        """
        try:
            params = task.parameters or {}
            model = params.get("model", self.default_model)
            prompt = task.prompt

            bt.logging.info(f"Replicate generating image with model={model}")

            if model not in MODEL_INFO:
                raise ValueError(f"Unknown Replicate model: {model}. "
                                 f"Available models: {list(MODEL_INFO.keys())}")

            model_info = MODEL_INFO[model]
            model_version = model_info["version"]

            input_data = {"prompt": prompt}

            # Only SDXL honors a negative prompt; FLUX models reject it.
            if model_info.get("supports_negative_prompt") and "negative_prompt" in params:
                input_data["negative_prompt"] = params["negative_prompt"]

            # Pass through optional generation knobs only when supplied.
            if "width" in params:
                input_data["width"] = params["width"]
            if "height" in params:
                input_data["height"] = params["height"]
            if "num_inference_steps" in params:
                input_data["num_inference_steps"] = params["num_inference_steps"]
            if "guidance_scale" in params:
                input_data["guidance_scale"] = params["guidance_scale"]
            if "seed" in params:
                input_data["seed"] = params["seed"]
            if "aspect_ratio" in params:
                input_data["aspect_ratio"] = params["aspect_ratio"]

            start_time = time.time()

            prediction = self._create_prediction(model_version, input_data)
            prediction_url = prediction.get("urls", {}).get("get")

            if not prediction_url:
                raise RuntimeError("No prediction URL returned from Replicate")

            bt.logging.info("Replicate prediction created, polling for result...")

            result = self._poll_prediction(prediction_url)
            gen_time = time.time() - start_time

            output = result.get("output")
            if not output:
                raise RuntimeError("No output returned from Replicate")

            # Some models return a list of URLs, others a single URL string.
            image_url = output[0] if isinstance(output, list) else output

            bt.logging.info(f"Replicate generated image in {gen_time:.2f}s, downloading...")

            img_bytes = self._download_image(image_url)

            bt.logging.success(f"Downloaded {len(img_bytes)} bytes from Replicate")

            return {
                "data": img_bytes,
                "metadata": {
                    "model": model,
                    "provider": "replicate",
                    "generation_time": round(gen_time, 2),
                    "prediction_id": result.get("id"),
                }
            }

        except Exception as e:
            # Log with context, then let the caller decide how to handle it.
            bt.logging.error(f"Replicate image generation failed: {e}")
            raise

    def get_service_info(self) -> Dict[str, Any]:
        """Summarize this service for registry/diagnostic display."""
        return {
            "name": "Replicate",
            "type": "api",
            "provider": "api.replicate.com",
            "available": self.is_available(),
            "supported_tasks": self.get_supported_tasks(),
            "default_model": self.default_model
        }
import os
import traceback
from PIL import Image
import io
import time

from neurons.generator.services.replicate_service import ReplicateService, Models
from neurons.generator.task_manager import TaskManager

# Note: Set REPLICATE_API_TOKEN in your environment before running tests


def save_image(img_bytes, filename):
    """Write raw image bytes to outputs/<filename> and return the path."""
    os.makedirs("outputs", exist_ok=True)
    # Fix: the filename argument was previously dropped from the f-string,
    # so every test overwrote the same literal file.
    out_path = f"outputs/{filename}"
    with open(out_path, "wb") as f:
        f.write(img_bytes)
    return out_path


def validate_image(img_bytes):
    """Return True when the bytes decode as a valid image via Pillow."""
    try:
        Image.open(io.BytesIO(img_bytes)).verify()
        return True
    except Exception:
        return False


def run_model_test(service, manager, model):
    """Generate one image with `model` and sanity-check bytes and metadata."""
    print(f"\n=== Running generation test for model: {model} ===")

    task_id = manager.create_task(
        modality="image",
        prompt="A neon cyberpunk city with flying cars",
        parameters={
            "model": model,
            "aspect_ratio": "16:9",
            "seed": 777,
        },
        webhook_url=None,
        signed_by="test-suite"
    )

    task = manager.get_task(task_id)

    try:
        start_time = time.time()
        result = service.process(task)
        elapsed = time.time() - start_time

        img_bytes = result["data"]
        meta = result["metadata"]

        print(f"✔ Generated image in {elapsed:.2f}s ({len(img_bytes)/1024:.1f} KB)")
        print(f"✔ Metadata keys: {list(meta.keys())}")

        out = save_image(img_bytes, f"replicate_{model.replace('.', '_')}.png")
        print(f"✔ Saved to {out}")

        assert validate_image(img_bytes), "Image failed Pillow validation"

        assert meta["model"] == model
        assert meta["provider"] == "replicate"
        assert meta["generation_time"] > 0
        assert "prediction_id" in meta

        print(f"=== Model {model} PASSED ===")

    except Exception:
        print(f"=== Model {model} FAILED ===")
        print(traceback.format_exc())
        # Re-raise so a failing model is not silently reported as "completed".
        raise


def test_invalid_api_key():
    """The service must fail loudly when the API token is rejected."""
    print("\n=== Testing invalid API key ===")

    # Preserve whatever token the environment had so later tests and the
    # caller's environment are not polluted by this test.
    original_token = os.environ.get("REPLICATE_API_TOKEN")
    try:
        os.environ["REPLICATE_API_TOKEN"] = "invalid-token"

        service = ReplicateService()
        manager = TaskManager()

        task_id = manager.create_task(
            modality="image",
            prompt="test prompt",
            parameters={"model": Models.FLUX_SCHNELL},
            webhook_url=None,
            signed_by="test"
        )
        task = manager.get_task(task_id)

        try:
            service.process(task)
            raise AssertionError("❌ Should have failed with invalid API token!")
        except AssertionError:
            # Our own failure marker must propagate, not be swallowed below.
            raise
        except Exception as e:
            print(f"✔ Correctly failed: {e}")
    finally:
        # Restore (or remove) the token exactly as it was before the test.
        if original_token is not None:
            os.environ["REPLICATE_API_TOKEN"] = original_token
        elif "REPLICATE_API_TOKEN" in os.environ:
            del os.environ["REPLICATE_API_TOKEN"]


def test_invalid_model(service, manager):
    """An unknown model name must raise ValueError before any API call."""
    print("\n=== Testing invalid model ===")

    task_id = manager.create_task(
        modality="image",
        prompt="test prompt",
        parameters={"model": "invalid-model"},
        webhook_url=None,
        signed_by="test"
    )
    task = manager.get_task(task_id)

    try:
        service.process(task)
        raise AssertionError("❌ Should have failed with invalid model!")
    except ValueError as e:
        print(f"✔ Correctly raised ValueError: {e}")


def test_sdxl_with_negative_prompt(service, manager):
    """SDXL is the only model that accepts a negative prompt; exercise it."""
    print("\n=== Testing SDXL with negative prompt ===")

    task_id = manager.create_task(
        modality="image",
        prompt="A beautiful sunset over mountains",
        parameters={
            "model": Models.SDXL,
            "negative_prompt": "low quality, blurry, distorted",
            "width": 1024,
            "height": 1024,
        },
        webhook_url=None,
        signed_by="test"
    )
    task = manager.get_task(task_id)

    result = service.process(task)
    assert result["data"] is not None
    assert result["metadata"]["model"] == Models.SDXL
    print("✔ SDXL with negative prompt succeeded")


def run_full_test_suite():
    """Run all Replicate service tests; requires a real API token in the env."""
    print("\n========== Replicate Full Test Suite ==========\n")

    # Normalize: an unset token becomes "" so is_available() reports False.
    os.environ["REPLICATE_API_TOKEN"] = os.getenv("REPLICATE_API_TOKEN", "")

    service = ReplicateService()
    manager = TaskManager()

    if not service.is_available():
        print("❌ API token missing — cannot run tests")
        return

    for model in [Models.FLUX_SCHNELL, Models.FLUX_DEV]:
        run_model_test(service, manager, model)

    test_sdxl_with_negative_prompt(service, manager)
    test_invalid_model(service, manager)
    test_invalid_api_key()

    print("\n========== All Tests Completed ==========\n")


if __name__ == "__main__":
    run_full_test_suite()