3 changes: 1 addition & 2 deletions bitmind/__init__.py
@@ -18,11 +18,10 @@
# DEALINGS IN THE SOFTWARE.


__version__ = "2.2.10"
__version__ = "2.2.11"
version_split = __version__.split(".")
__spec_version__ = (
(1000 * int(version_split[0]))
+ (10 * int(version_split[1]))
+ (1 * int(version_split[2]))
)

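Side note on the bump: plugging the new version string into the `__spec_version__` formula above gives 2031 (versus 2030 for 2.2.10). A minimal worked check, using only the arithmetic already present in this file:

# Worked example of the __spec_version__ formula for the bumped version.
major, minor, patch = (int(x) for x in "2.2.11".split("."))
assert 1000 * major + 10 * minor + 1 * patch == 2031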
7 changes: 7 additions & 0 deletions bitmind/synthetic_data_generation/image_utils.py
@@ -107,3 +107,10 @@ def create_random_mask(size: Tuple[int, int]) -> Image.Image:
)

return mask, (x, y)

def is_black_image(img: Image.Image, threshold: int = 10) -> bool:
"""
Returns True if the image is (almost) completely black.
"""
arr = np.array(img)
return np.mean(arr) < threshold
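A minimal usage sketch of the new helper, assuming Pillow is installed; the synthetic images below are hypothetical stand-ins for real generator output, not part of this diff:

from PIL import Image
from bitmind.synthetic_data_generation.image_utils import is_black_image

# An all-black frame (mean pixel value 0) should trip the detector,
# while a mid-gray frame (mean 128) should not.
black = Image.new("RGB", (64, 64), (0, 0, 0))
gray = Image.new("RGB", (64, 64), (128, 128, 128))
assert is_black_image(black)
assert not is_black_image(gray)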
59 changes: 47 additions & 12 deletions bitmind/synthetic_data_generation/prompt_generator.py
@@ -55,16 +55,8 @@ def __init__(

def are_models_loaded(self) -> bool:
return (self.vlm is not None) and (self.llm_pipeline is not None)

def load_models(self) -> None:
"""
Load the necessary models for image annotation and text moderation onto
the specified device.
"""
if self.are_models_loaded():
bt.logging.warning(f"Models already loaded")
return


def load_vlm(self) -> None:
bt.logging.info(f"Loading caption generation model {self.vlm_name}")
self.vlm_processor = Blip2Processor.from_pretrained(
self.vlm_name,
@@ -77,7 +69,8 @@ def load_models(self) -> None:
)
self.vlm.to(self.device)
bt.logging.info(f"Loaded image annotation model {self.vlm_name}")


def load_llm(self) -> None:
bt.logging.info(f"Loading caption moderation model {self.llm_name}")
llm = AutoModelForCausalLM.from_pretrained(
self.llm_name,
@@ -95,6 +88,17 @@
tokenizer=tokenizer
)
bt.logging.info(f"Loaded caption moderation model {self.llm_name}")

def load_models(self) -> None:
"""
Load the necessary models for image annotation and text moderation onto
the specified device.
"""
if self.are_models_loaded():
bt.logging.warning(f"Models already loaded")
return
self.load_vlm()
self.load_llm()

def clear_gpu(self) -> None:
"""
@@ -281,4 +285,35 @@ def enhance(self, description: str, max_new_tokens: int = 80) -> str:

except Exception as e:
bt.logging.error(f"An error occurred during motion enhancement: {e}")
return description
return description

def sanitize_prompt(self, prompt: str, max_new_tokens: int = 80) -> str:
"""
Use the LLM to make the prompt more SFW (less NSFW).
"""
messages = [
{
"role": "system",
"content": (
"[INST]You are an expert at making prompts safe for work (SFW). "
"Rephrase the following prompt to remove or neutralize any NSFW, sexual, or explicit content. "
"Keep the prompt as close as possible to the original intent, but ensure it is SFW. "
"Only respond with the sanitized prompt.[/INST]"
)
},
{
"role": "user",
"content": prompt
}
]
try:
sanitized = self.llm_pipeline(
messages,
max_new_tokens=max_new_tokens,
pad_token_id=self.llm_pipeline.tokenizer.eos_token_id,
return_full_text=False
)
return sanitized[0]['generated_text']
except Exception as e:
bt.logging.error(f"An error occurred during prompt sanitization: {e}")
return prompt
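A hedged sketch of how the split-out loaders and the new sanitize_prompt might be used together. The constructor arguments are assumptions (the __init__ signature is not shown in this diff); only the method calls are taken from the code above:

from bitmind.synthetic_data_generation.prompt_generator import PromptGenerator

# Hypothetical model names and device; adjust to whatever the validator config uses.
generator = PromptGenerator(
    vlm_name="Salesforce/blip2-opt-2.7b",
    llm_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="cuda",
)

generator.load_llm()  # only the LLM is needed for sanitization, so the VLM load is skipped
safe_prompt = generator.sanitize_prompt("an explicit, graphic depiction of ...")
generator.clear_gpu()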
170 changes: 121 additions & 49 deletions bitmind/synthetic_data_generation/synthetic_data_generator.py
@@ -10,7 +10,6 @@

import bittensor as bt
import numpy as np
import random
import torch
from diffusers.utils import export_to_video
from PIL import Image
@@ -33,7 +32,7 @@
MediaType,
Modality
)
from bitmind.synthetic_data_generation.image_utils import create_random_mask
from bitmind.synthetic_data_generation.image_utils import create_random_mask, is_black_image
from bitmind.synthetic_data_generation.prompt_utils import truncate_prompt_if_too_long
from bitmind.synthetic_data_generation.prompt_generator import PromptGenerator
from bitmind.validator.cache import ImageCache
@@ -141,76 +140,126 @@ def batch_generate(self, batch_size: int = 5) -> None:
"""
Asynchronously generate synthetic data in batches.

This method handles the complete generation pipeline:
1. Samples source images for captioning
2. Determines which models to use (random or specific)
3. Groups models by type (video vs image)
4. Generates prompts by task type
5. Performs generation for each model using appropriate prompts
6. Saves outputs with metadata

Args:
batch_size: Number of prompts to generate in each batch.
batch_size: Number of prompts/generations to create per model
"""
prompts = []
# Step 1: Sample source images to reuse across all generations
images = []
bt.logging.info(f"Generating {batch_size} prompts")

# Generate all prompts first
bt.logging.info(f"Starting batch generation with size {batch_size}")
for i in range(batch_size):
image_sample = self.image_cache.sample()
images.append(image_sample['image'])
bt.logging.info(f"Sampled image {i+1}/{batch_size} for captioning: {image_sample['path']}")
task = get_task(self.model_name) if self.model_name else None
prompts.append(self.generate_prompt(
image=image_sample['image'],
clear_gpu=i==batch_size-1,
task=task
))
bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")

# If specific model is set, use only that model

# Step 2: Determine which models to use
# Either use a single specified model or create a balanced random selection
if not self.use_random_model and self.model_name:
model_names = [self.model_name]
else:
# shuffle and interleave models to add stochasticity
# Shuffle each model type separately to maintain diversity
# Then interleave them to distribute different types evenly
i2i_model_names = random.sample(I2I_MODEL_NAMES, len(I2I_MODEL_NAMES))
t2i_model_names = random.sample(T2I_MODEL_NAMES, len(T2I_MODEL_NAMES))
t2v_model_names = random.sample(T2V_MODEL_NAMES, len(T2V_MODEL_NAMES))
i2v_model_names = random.sample(I2V_MODEL_NAMES, len(I2V_MODEL_NAMES))

# Interleave all model types to ensure variety in generation order
model_names = [
m for quad in zip_longest(t2v_model_names, t2i_model_names,
i2i_model_names, i2v_model_names)
for m in quad if m is not None
]

bt.logging.info(f"Using {'random models' if self.use_random_model else f'specific model: {self.model_name}'}")

# Generate for each model/prompt combination
# Step 3: Group models by type to minimize prompt generator reloading
# Video models (t2v, i2v) need enhanced prompts with motion
# Image models (t2i, i2i) use standard prompts
video_models = [] # t2v, i2v models
image_models = [] # t2i, i2i models
for model_name in model_names:
modality = get_modality(model_name)
task = get_task(model_name)
media_type = get_output_media_type(model_name)

for i, prompt in enumerate(prompts):
bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}")

# Generate image/video from current model and prompt
output = self._run_generation(prompt, task=task, model_name=model_name, image=images[i])

model_output_dir = self.output_dir / modality / media_type / model_name.split('/')[1]
model_output_dir.mkdir(parents=True, exist_ok=True)
base_path = model_output_dir / str(output['time'])

bt.logging.info(f'Writing to cache {model_output_dir}')

metadata = {k: v for k, v in output.items() if k != 'gen_output' and 'image' not in k}
base_path.with_suffix('.json').write_text(json.dumps(metadata))

if modality == 'image':
out_path = base_path.with_suffix('.png')
output['gen_output'].images[0].save(out_path)
elif modality == 'video':
bt.logging.info("Writing to cache")
out_path = str(base_path.with_suffix('.mp4'))
export_to_video(
output['gen_output'].frames[0],
out_path,
fps=30
)
bt.logging.info(f"Wrote to {out_path}")
if task in ['t2v', 'i2v']:
video_models.append((model_name, task))
else:
image_models.append((model_name, task))

# Log model distribution for monitoring
bt.logging.info(f"Model distribution:")
bt.logging.info(f"- Video models ({len(video_models)}): {[m[0] for m in video_models]}")
bt.logging.info(f"- Image models ({len(image_models)}): {[m[0] for m in image_models]}")

# Step 4: Process each group (video/image) separately
# This allows us to generate appropriate prompts once per type
for is_video_task, model_group in [
(True, video_models),
(False, image_models)
]:
if not model_group:
bt.logging.info(f"Skipping {'video' if is_video_task else 'image'} generation - no models in group")
continue

# Generate all prompts for this task type at once
prompts = []
bt.logging.info(f"Generating {batch_size} {'video' if is_video_task else 'image'} prompts")

try:
for i in range(batch_size):
# Use the first model's task type as reference for prompt generation
reference_task = model_group[0][1] # Each entry is (model_name, task)
prompts.append(self.generate_prompt(
image=images[i],
clear_gpu=False, # Keep prompt generator loaded until all prompts are done
task=reference_task
))
bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")
finally:
# Ensure prompt generator is cleared from GPU
self.prompt_generator.clear_gpu()
torch.cuda.empty_cache()
gc.collect()

# Step 5: Generate outputs for each model using the prepared prompts
for model_name, task in model_group:
modality = get_modality(model_name)
media_type = get_output_media_type(model_name)

for i, prompt in enumerate(prompts):
bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}")

# Generate the actual output (image/video)
output = self._run_generation(prompt, task=task, model_name=model_name, image=images[i])

# Step 6: Save the generated output and its metadata
# Organize outputs by modality/type/model for easy retrieval
model_output_dir = self.output_dir / modality / media_type / model_name.split('/')[1]
model_output_dir.mkdir(parents=True, exist_ok=True)
base_path = model_output_dir / str(output['time'])

bt.logging.info(f'Writing to cache {model_output_dir}')

# Save metadata separately from the actual output
metadata = {k: v for k, v in output.items() if k != 'gen_output' and 'image' not in k}
base_path.with_suffix('.json').write_text(json.dumps(metadata))

# Save the output in appropriate format based on modality
if modality == Modality.IMAGE:
out_path = base_path.with_suffix('.png')
output['gen_output'].images[0].save(out_path)
elif modality == Modality.VIDEO:
bt.logging.info("Writing to cache")
out_path = str(base_path.with_suffix('.mp4'))
export_to_video(output['gen_output'].frames[0], out_path, fps=30)
bt.logging.info(f"Wrote to {out_path}")

def generate(
self,
image: Optional[Image.Image] = None,
@@ -257,7 +306,7 @@ def generate_prompt(
else:
raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}")
return prompt

def _run_generation(
self,
prompt: str,
@@ -398,6 +447,29 @@ def _run_generation(
bt.logging.error(f"Image generation error: {e}")
raise RuntimeError(f"Failed to generate image: {e}")

if self.model_name == "DeepFloyd/IF":
max_retries = 3
attempt = 0
while attempt < max_retries:
img = gen_output.images[0]
if is_black_image(img):
bt.logging.warning("DeepFloyd/IF returned a black image (likely NSFW). Attempting to sanitize prompt and retry.")
# Ensure prompt generator models are loaded
self.prompt_generator.load_llm()
# Sanitize the prompt
prompt = self.prompt_generator.sanitize_prompt(prompt)
truncated_prompt = truncate_prompt_if_too_long(prompt, self.model)
bt.logging.info(f"Sanitized prompt: {truncated_prompt}")
self.prompt_generator.clear_gpu()
try:
gen_output = generate(truncated_prompt, **gen_args)
except Exception as e:
bt.logging.error(f"Sanitized prompt generation failed: {e}")
break
attempt += 1
else:
break

print(f"Finished generation in {gen_time/60} minutes")
return {
'prompt': truncated_prompt,
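One non-obvious piece of the refactored batch_generate is the zip_longest interleaving of the per-type model lists. A standalone sketch of just that idiom (shuffling is omitted so the output is deterministic, and the model names are placeholders):

from itertools import zip_longest

t2v = ["t2v-a"]
t2i = ["t2i-a", "t2i-b"]
i2i = ["i2i-a"]
i2v = ["i2v-a"]

# Zip the lists together, then flatten while dropping the None padding
# that zip_longest inserts for the shorter lists.
model_names = [
    m for quad in zip_longest(t2v, t2i, i2i, i2v)
    for m in quad if m is not None
]
print(model_names)  # ['t2v-a', 't2i-a', 'i2i-a', 'i2v-a', 't2i-b']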
26 changes: 22 additions & 4 deletions bitmind/validator/config.py
@@ -420,7 +420,27 @@ class Modality(StrEnum):
T2V_MODEL_NAMES: List[str] = list(T2V_MODELS.keys())

# Image-to-video model configurations
I2V_MODELS: Dict[str, Dict[str, Any]] = {}
I2V_MODELS: Dict[str, Dict[str, Any]] = {
"THUDM/CogVideoX1.5-5B-I2V": {
"pipeline_cls": CogVideoXImageToVideoPipeline,
"from_pretrained_args": {
"use_safetensors": True,
"torch_dtype": torch.bfloat16
},
"generate_args": {
"guidance_scale": 2,
"num_videos_per_prompt": 1,
"num_inference_steps": {"min": 50, "max": 125},
"num_frames": 49,
"height": 768,
"width": 768,
},
"save_args": {"fps": 8},
"enable_model_cpu_offload": True,
"vae_enable_slicing": True,
"vae_enable_tiling": True
}
}
I2V_MODEL_NAMES: List[str] = list(I2V_MODELS.keys())

# Combined model configurations
@@ -466,7 +486,7 @@ def select_random_model(task: Optional[str] = None) -> str:
NotImplementedError: If the specified modality is not supported.
"""
if task is None or task == 'random':
task = np.random.choice(['t2i', 'i2i', 't2v'])
task = np.random.choice(['t2i', 'i2i', 't2v', 'i2v'])

if task == 't2i':
return np.random.choice(T2I_MODEL_NAMES)
@@ -475,8 +495,6 @@
elif task == 'i2i':
return np.random.choice(I2I_MODEL_NAMES)
elif task == 'i2v':
if not I2V_MODEL_NAMES:
raise NotImplementedError("I2V models are not currently configured")
return np.random.choice(I2V_MODEL_NAMES)
else:
raise NotImplementedError(f"Unsupported task: {task}")
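With an I2V entry configured, the 'i2v' branch can now resolve a model instead of raising. A small usage sketch; the import path and model name come from this file, and the comments state only what the code above implies:

from bitmind.validator.config import select_random_model

# Only one I2V model is configured, so the 'i2v' choice is deterministic here.
assert select_random_model("i2v") == "THUDM/CogVideoX1.5-5B-I2V"

# 'random' now samples across all four task types, including i2v.
model_name = select_random_model("random")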