diff --git a/bitmind/__init__.py b/bitmind/__init__.py
index 1c4a95f8..7cd2bfa6 100644
--- a/bitmind/__init__.py
+++ b/bitmind/__init__.py
@@ -18,11 +18,10 @@
 # DEALINGS IN THE SOFTWARE.
 
 
-__version__ = "2.2.10"
+__version__ = "2.2.11"
 version_split = __version__.split(".")
 __spec_version__ = (
     (1000 * int(version_split[0]))
     + (10 * int(version_split[1]))
     + (1 * int(version_split[2]))
 )
-
diff --git a/bitmind/synthetic_data_generation/image_utils.py b/bitmind/synthetic_data_generation/image_utils.py
index 4593edb4..f3ad50be 100644
--- a/bitmind/synthetic_data_generation/image_utils.py
+++ b/bitmind/synthetic_data_generation/image_utils.py
@@ -107,3 +107,10 @@ def create_random_mask(size: Tuple[int, int]) -> Image.Image:
     )
     return mask, (x, y)
 
+
+def is_black_image(img: Image.Image, threshold: int = 10) -> bool:
+    """
+    Returns True if the image is (almost) completely black.
+    """
+    arr = np.array(img)
+    return np.mean(arr) < threshold
\ No newline at end of file
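# Illustrative sketch (not part of the patch): a minimal sanity check for the new
# is_black_image() helper above. A solid black frame should be flagged while a
# mid-gray one should not. Assumes only Pillow and numpy, which the patched module
# already depends on.
import numpy as np
from PIL import Image

def is_black_image(img: Image.Image, threshold: int = 10) -> bool:
    # Same logic as the helper added in image_utils.py, repeated so this snippet runs standalone.
    return bool(np.mean(np.array(img)) < threshold)

black = Image.new("RGB", (64, 64), (0, 0, 0))
gray = Image.new("RGB", (64, 64), (128, 128, 128))
assert is_black_image(black)
assert not is_black_image(gray)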
" + "Only respond with the sanitized prompt.[/INST]" + ) + }, + { + "role": "user", + "content": prompt + } + ] + try: + sanitized = self.llm_pipeline( + messages, + max_new_tokens=max_new_tokens, + pad_token_id=self.llm_pipeline.tokenizer.eos_token_id, + return_full_text=False + ) + return sanitized[0]['generated_text'] + except Exception as e: + bt.logging.error(f"An error occurred during prompt sanitization: {e}") + return prompt \ No newline at end of file diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py index 45eab8e4..36257ad7 100644 --- a/bitmind/synthetic_data_generation/synthetic_data_generator.py +++ b/bitmind/synthetic_data_generation/synthetic_data_generator.py @@ -10,7 +10,6 @@ import bittensor as bt import numpy as np -import random import torch from diffusers.utils import export_to_video from PIL import Image @@ -33,7 +32,7 @@ MediaType, Modality ) -from bitmind.synthetic_data_generation.image_utils import create_random_mask +from bitmind.synthetic_data_generation.image_utils import create_random_mask, is_black_image from bitmind.synthetic_data_generation.prompt_utils import truncate_prompt_if_too_long from bitmind.synthetic_data_generation.prompt_generator import PromptGenerator from bitmind.validator.cache import ImageCache @@ -141,76 +140,126 @@ def batch_generate(self, batch_size: int = 5) -> None: """ Asynchronously generate synthetic data in batches. + This method handles the complete generation pipeline: + 1. Samples source images for captioning + 2. Determines which models to use (random or specific) + 3. Groups models by type (video vs image) + 4. Generates prompts by task type + 5. Performs generation for each model using appropriate prompts + 6. Saves outputs with metadata + Args: - batch_size: Number of prompts to generate in each batch. 
diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py
index 45eab8e4..36257ad7 100644
--- a/bitmind/synthetic_data_generation/synthetic_data_generator.py
+++ b/bitmind/synthetic_data_generation/synthetic_data_generator.py
@@ -10,7 +10,6 @@
 
 import bittensor as bt
 import numpy as np
-import random
 import torch
 from diffusers.utils import export_to_video
 from PIL import Image
@@ -33,7 +32,7 @@
     MediaType,
     Modality
 )
-from bitmind.synthetic_data_generation.image_utils import create_random_mask
+from bitmind.synthetic_data_generation.image_utils import create_random_mask, is_black_image
 from bitmind.synthetic_data_generation.prompt_utils import truncate_prompt_if_too_long
 from bitmind.synthetic_data_generation.prompt_generator import PromptGenerator
 from bitmind.validator.cache import ImageCache
@@ -141,76 +140,126 @@ def batch_generate(self, batch_size: int = 5) -> None:
         """
         Asynchronously generate synthetic data in batches.
 
+        This method handles the complete generation pipeline:
+        1. Samples source images for captioning
+        2. Determines which models to use (random or specific)
+        3. Groups models by type (video vs image)
+        4. Generates prompts by task type
+        5. Performs generation for each model using appropriate prompts
+        6. Saves outputs with metadata
+
         Args:
-            batch_size: Number of prompts to generate in each batch.
+            batch_size: Number of prompts/generations to create per model
         """
-        prompts = []
+        # Step 1: Sample source images to reuse across all generations
         images = []
-        bt.logging.info(f"Generating {batch_size} prompts")
-
-        # Generate all prompts first
+        bt.logging.info(f"Starting batch generation with size {batch_size}")
         for i in range(batch_size):
             image_sample = self.image_cache.sample()
             images.append(image_sample['image'])
             bt.logging.info(f"Sampled image {i+1}/{batch_size} for captioning: {image_sample['path']}")
-            task = get_task(self.model_name) if self.model_name else None
-            prompts.append(self.generate_prompt(
-                image=image_sample['image'],
-                clear_gpu=i==batch_size-1,
-                task=task
-            ))
-            bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")
-
-        # If specific model is set, use only that model
+
+        # Step 2: Determine which models to use
+        # Either use a single specified model or create a balanced random selection
         if not self.use_random_model and self.model_name:
             model_names = [self.model_name]
         else:
-            # shuffle and interleave models to add stochasticity
+            # Shuffle each model type separately to maintain diversity
+            # Then interleave them to distribute different types evenly
             i2i_model_names = random.sample(I2I_MODEL_NAMES, len(I2I_MODEL_NAMES))
             t2i_model_names = random.sample(T2I_MODEL_NAMES, len(T2I_MODEL_NAMES))
             t2v_model_names = random.sample(T2V_MODEL_NAMES, len(T2V_MODEL_NAMES))
             i2v_model_names = random.sample(I2V_MODEL_NAMES, len(I2V_MODEL_NAMES))
+            # Interleave all model types to ensure variety in generation order
             model_names = [
                 m for quad in zip_longest(t2v_model_names, t2i_model_names, i2i_model_names, i2v_model_names)
                 for m in quad if m is not None
             ]
+
+        bt.logging.info(f"Using {'random models' if self.use_random_model else f'specific model: {self.model_name}'}")
 
-        # Generate for each model/prompt combination
+        # Step 3: Group models by type to minimize prompt generator reloading
+        # Video models (t2v, i2v) need enhanced prompts with motion
+        # Image models (t2i, i2i) use standard prompts
+        video_models = []  # t2v, i2v models
+        image_models = []  # t2i, i2i models
         for model_name in model_names:
-            modality = get_modality(model_name)
             task = get_task(model_name)
-            media_type = get_output_media_type(model_name)
-
-            for i, prompt in enumerate(prompts):
-                bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}")
-
-                # Generate image/video from current model and prompt
-                output = self._run_generation(prompt, task=task, model_name=model_name, image=images[i])
-
-                model_output_dir = self.output_dir / modality / media_type / model_name.split('/')[1]
-                model_output_dir.mkdir(parents=True, exist_ok=True)
-                base_path = model_output_dir / str(output['time'])
-
-                bt.logging.info(f'Writing to cache {model_output_dir}')
-
-                metadata = {k: v for k, v in output.items() if k != 'gen_output' and 'image' not in k}
-                base_path.with_suffix('.json').write_text(json.dumps(metadata))
-
-                if modality == 'image':
-                    out_path = base_path.with_suffix('.png')
-                    output['gen_output'].images[0].save(out_path)
-                elif modality == 'video':
-                    bt.logging.info("Writing to cache")
-                    out_path = str(base_path.with_suffix('.mp4'))
-                    export_to_video(
-                        output['gen_output'].frames[0],
-                        out_path,
-                        fps=30
-                    )
-                    bt.logging.info(f"Wrote to {out_path}")
+            if task in ['t2v', 'i2v']:
+                video_models.append((model_name, task))
+            else:
+                image_models.append((model_name, task))
+
+        # Log model distribution for monitoring
+        bt.logging.info(f"Model distribution:")
+        bt.logging.info(f"- Video models ({len(video_models)}): {[m[0] for m in video_models]}")
+        bt.logging.info(f"- Image models ({len(image_models)}): {[m[0] for m in image_models]}")
+
+        # Step 4: Process each group (video/image) separately
+        # This allows us to generate appropriate prompts once per type
+        for is_video_task, model_group in [
+            (True, video_models),
+            (False, image_models)
+        ]:
+            if not model_group:
+                bt.logging.info(f"Skipping {'video' if is_video_task else 'image'} generation - no models in group")
+                continue
+
+            # Generate all prompts for this task type at once
+            prompts = []
+            bt.logging.info(f"Generating {batch_size} {'video' if is_video_task else 'image'} prompts")
+
+            try:
+                for i in range(batch_size):
+                    # Use the first model's task type as reference for prompt generation
+                    reference_task = model_group[0][1]  # Each entry is (model_name, task)
+                    prompts.append(self.generate_prompt(
+                        image=images[i],
+                        clear_gpu=False,  # Keep prompt generator loaded until all prompts are done
+                        task=reference_task
+                    ))
+                    bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")
+            finally:
+                # Ensure prompt generator is cleared from GPU
+                self.prompt_generator.clear_gpu()
+                torch.cuda.empty_cache()
+                gc.collect()
+            # Step 5: Generate outputs for each model using the prepared prompts
+            for model_name, task in model_group:
+                modality = get_modality(model_name)
+                media_type = get_output_media_type(model_name)
+
+                for i, prompt in enumerate(prompts):
+                    bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}")
+
+                    # Generate the actual output (image/video)
+                    output = self._run_generation(prompt, task=task, model_name=model_name, image=images[i])
+
+                    # Step 6: Save the generated output and its metadata
+                    # Organize outputs by modality/type/model for easy retrieval
+                    model_output_dir = self.output_dir / modality / media_type / model_name.split('/')[1]
+                    model_output_dir.mkdir(parents=True, exist_ok=True)
+                    base_path = model_output_dir / str(output['time'])
+
+                    bt.logging.info(f'Writing to cache {model_output_dir}')
+
+                    # Save metadata separately from the actual output
+                    metadata = {k: v for k, v in output.items() if k != 'gen_output' and 'image' not in k}
+                    base_path.with_suffix('.json').write_text(json.dumps(metadata))
+
+                    # Save the output in appropriate format based on modality
+                    if modality == Modality.IMAGE:
+                        out_path = base_path.with_suffix('.png')
+                        output['gen_output'].images[0].save(out_path)
+                    elif modality == Modality.VIDEO:
+                        bt.logging.info("Writing to cache")
+                        out_path = str(base_path.with_suffix('.mp4'))
+                        export_to_video(output['gen_output'].frames[0], out_path, fps=30)
+                        bt.logging.info(f"Wrote to {out_path}")
+
     def generate(
         self,
         image: Optional[Image.Image] = None,
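# Illustrative sketch (not part of the patch): the interleaving used in batch_generate()
# above. Per-task model lists are shuffled, then woven together with zip_longest so
# generation alternates between task types instead of exhausting one list first.
# The model names here are made up.
import random
from itertools import zip_longest

t2v = random.sample(["t2v/model-a", "t2v/model-b"], 2)
t2i = random.sample(["t2i/model-a", "t2i/model-b", "t2i/model-c"], 3)
i2i = random.sample(["i2i/model-a"], 1)
i2v = random.sample(["i2v/model-a"], 1)

model_names = [
    m for quad in zip_longest(t2v, t2i, i2i, i2v)
    for m in quad if m is not None
]
print(model_names)
# e.g. ['t2v/model-b', 't2i/model-c', 'i2i/model-a', 'i2v/model-a', 't2v/model-a', 't2i/model-a', 't2i/model-b']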
bt.logging.info(f"- Video models ({len(video_models)}): {[m[0] for m in video_models]}") + bt.logging.info(f"- Image models ({len(image_models)}): {[m[0] for m in image_models]}") + + # Step 4: Process each group (video/image) separately + # This allows us to generate appropriate prompts once per type + for is_video_task, model_group in [ + (True, video_models), + (False, image_models) + ]: + if not model_group: + bt.logging.info(f"Skipping {'video' if is_video_task else 'image'} generation - no models in group") + continue + + # Generate all prompts for this task type at once + prompts = [] + bt.logging.info(f"Generating {batch_size} {'video' if is_video_task else 'image'} prompts") + + try: + for i in range(batch_size): + # Use the first model's task type as reference for prompt generation + reference_task = model_group[0][1] # Each entry is (model_name, task) + prompts.append(self.generate_prompt( + image=images[i], + clear_gpu=False, # Keep prompt generator loaded until all prompts are done + task=reference_task + )) + bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}") + finally: + # Ensure prompt generator is cleared from GPU + self.prompt_generator.clear_gpu() + torch.cuda.empty_cache() + gc.collect() + # Step 5: Generate outputs for each model using the prepared prompts + for model_name, task in model_group: + modality = get_modality(model_name) + media_type = get_output_media_type(model_name) + + for i, prompt in enumerate(prompts): + bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}") + + # Generate the actual output (image/video) + output = self._run_generation(prompt, task=task, model_name=model_name, image=images[i]) + + # Step 6: Save the generated output and its metadata + # Organize outputs by modality/type/model for easy retrieval + model_output_dir = self.output_dir / modality / media_type / model_name.split('/')[1] + model_output_dir.mkdir(parents=True, exist_ok=True) + base_path = model_output_dir / str(output['time']) + + bt.logging.info(f'Writing to cache {model_output_dir}') + + # Save metadata separately from the actual output + metadata = {k: v for k, v in output.items() if k != 'gen_output' and 'image' not in k} + base_path.with_suffix('.json').write_text(json.dumps(metadata)) + + # Save the output in appropriate format based on modality + if modality == Modality.IMAGE: + out_path = base_path.with_suffix('.png') + output['gen_output'].images[0].save(out_path) + elif modality == Modality.VIDEO: + bt.logging.info("Writing to cache") + out_path = str(base_path.with_suffix('.mp4')) + export_to_video(output['gen_output'].frames[0], out_path, fps=30) + bt.logging.info(f"Wrote to {out_path}") + def generate( self, image: Optional[Image.Image] = None, @@ -257,7 +306,7 @@ def generate_prompt( else: raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}") return prompt - + def _run_generation( self, prompt: str, @@ -398,6 +447,29 @@ def _run_generation( bt.logging.error(f"Image generation error: {e}") raise RuntimeError(f"Failed to generate image: {e}") + if self.model_name == "DeepFloyd/IF": + max_retries = 3 + attempt = 0 + while attempt < max_retries: + img = gen_output.images[0] + if is_black_image(img): + bt.logging.warning("DeepFloyd/IF returned a black image (likely NSFW). 
diff --git a/bitmind/validator/config.py b/bitmind/validator/config.py
index fc56eefb..08d04842 100644
--- a/bitmind/validator/config.py
+++ b/bitmind/validator/config.py
@@ -420,7 +420,27 @@ class Modality(StrEnum):
 T2V_MODEL_NAMES: List[str] = list(T2V_MODELS.keys())
 
 # Image-to-video model configurations
-I2V_MODELS: Dict[str, Dict[str, Any]] = {}
+I2V_MODELS: Dict[str, Dict[str, Any]] = {
+    "THUDM/CogVideoX1.5-5B-I2V": {
+        "pipeline_cls": CogVideoXImageToVideoPipeline,
+        "from_pretrained_args": {
+            "use_safetensors": True,
+            "torch_dtype": torch.bfloat16
+        },
+        "generate_args": {
+            "guidance_scale": 2,
+            "num_videos_per_prompt": 1,
+            "num_inference_steps": {"min": 50, "max": 125},
+            "num_frames": 49,
+            "height": 768,
+            "width": 768,
+        },
+        "save_args": {"fps": 8},
+        "enable_model_cpu_offload": True,
+        "vae_enable_slicing": True,
+        "vae_enable_tiling": True
+    }
+}
 I2V_MODEL_NAMES: List[str] = list(I2V_MODELS.keys())
 
 # Combined model configurations
@@ -466,7 +486,7 @@ def select_random_model(task: Optional[str] = None) -> str:
         NotImplementedError: If the specified modality is not supported.
     """
     if task is None or task == 'random':
-        task = np.random.choice(['t2i', 'i2i', 't2v'])
+        task = np.random.choice(['t2i', 'i2i', 't2v', 'i2v'])
 
     if task == 't2i':
         return np.random.choice(T2I_MODEL_NAMES)
@@ -475,8 +495,6 @@
     elif task == 'i2i':
         return np.random.choice(I2I_MODEL_NAMES)
     elif task == 'i2v':
-        if not I2V_MODEL_NAMES:
-            raise NotImplementedError("I2V models are not currently configured")
         return np.random.choice(I2V_MODEL_NAMES)
     else:
         raise NotImplementedError(f"Unsupported task: {task}")
\ No newline at end of file
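# Illustrative sketch (not part of the patch): one way an I2V_MODELS entry like the new
# CogVideoX1.5-5B-I2V config could be consumed. The loader name and flow below are
# assumptions; the validator has its own model-loading code. The diffusers calls used
# (from_pretrained, enable_model_cpu_offload, vae.enable_slicing/enable_tiling) are standard.
def load_i2v_pipeline(model_name: str, config: dict):
    pipeline_cls = config["pipeline_cls"]  # e.g. CogVideoXImageToVideoPipeline
    pipe = pipeline_cls.from_pretrained(model_name, **config["from_pretrained_args"])
    if config.get("enable_model_cpu_offload"):
        pipe.enable_model_cpu_offload()    # trade speed for lower peak VRAM
    if config.get("vae_enable_slicing"):
        pipe.vae.enable_slicing()
    if config.get("vae_enable_tiling"):
        pipe.vae.enable_tiling()
    return pipe

# Usage against the config module above:
#   from bitmind.validator.config import I2V_MODELS
#   name = "THUDM/CogVideoX1.5-5B-I2V"
#   pipe = load_i2v_pipeline(name, I2V_MODELS[name])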