Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bitmind/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# DEALINGS IN THE SOFTWARE.


__version__ = "2.2.8"
__version__ = "2.2.9"
version_split = __version__.split(".")
__spec_version__ = (
(1000 * int(version_split[0]))
Expand Down
44 changes: 22 additions & 22 deletions bitmind/synthetic_data_generation/prompt_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def clear_gpu(self) -> None:
def generate(
self,
image: Image.Image,
task: Optional[str] = None,
max_new_tokens: int = 20,
verbose: bool = False
) -> str:
Expand All @@ -127,6 +128,8 @@ def generate(

Args:
image: The image for which the description is to be generated.
task: The generation task ('t2i', 't2v', 'i2i', 'i2v'). If video task,
motion descriptions will be added.
max_new_tokens: The maximum number of tokens to generate for each
prompt.
verbose: If True, additional logging information is printed.
Expand Down Expand Up @@ -185,7 +188,10 @@ def generate(
description += '.'

moderated_description = self.moderate(description)
return self.enhance(moderated_description)

if task in ['t2v', 'i2v']:
return self.enhance(moderated_description)
return moderated_description

def moderate(self, description: str, max_new_tokens: int = 80) -> str:
"""
Expand Down Expand Up @@ -233,34 +239,28 @@ def enhance(self, description: str, max_new_tokens: int = 80) -> str:
"""
Enhance a static image description to make it suitable for video generation
by adding dynamic elements and motion.

Args:
description: The static image description to enhance.
max_new_tokens: Maximum number of new tokens to generate in the enhanced text.

Returns:
An enhanced description suitable for video generation, or the original
description if enhancement fails.
An enhanced description suitable for video generation.
"""
messages = [
{
"role": "system",
"content": (
"[INST]You are an expert at converting static image descriptions "
"into dynamic video prompts. Enhance the given description by "
"adding natural motion and temporal elements while preserving the "
"core scene. Follow these rules:\n"
"1. Maintain the essential elements of the original description\n"
"2. Add smooth, continuous motions that work well in video\n"
"3. For portraits: Add natural facial movements or expressions\n"
"4. For non-portrait images with people: Add contextually appropriate "
"actions (e.g., for a beach scene, people might be walking along "
"the shoreline or playing in the waves; for a cafe scene, people "
"might be sipping drinks or engaging in conversation)\n"
"5. For landscapes: Add environmental motion like wind or water\n"
"6. For urban scenes: Add dynamic elements like people or traffic\n"
"7. Keep the description concise but descriptive\n"
"8. Focus on gradual, natural transitions\n"
"[INST]You are an expert at converting image descriptions into video prompts. "
"Analyze the existing motion in the scene and enhance it naturally:\n"
"1. If motion exists in the image (falling, throwing, running, etc.):\n"
" - Maintain and emphasize that existing motion\n"
" - Add smooth continuation of the movement\n"
"2. If the subject is static (sitting, standing, placed):\n"
" - Keep it stable\n"
" - Add minimal environmental motion if appropriate\n"
"3. Add ONE subtle camera motion that complements the scene\n"
"4. Keep the description concise and natural\n"
"Only respond with the enhanced description.[/INST]"
)
},
Expand All @@ -280,5 +280,5 @@ def enhance(self, description: str, max_new_tokens: int = 80) -> str:
return enhanced_text[0]['generated_text']

except Exception as e:
print(f"An error occurred during motion enhancement: {e}")
return description
bt.logging.error(f"An error occurred during motion enhancement: {e}")
return description
126 changes: 117 additions & 9 deletions bitmind/synthetic_data_generation/synthetic_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
T2V_MODEL_NAMES,
T2I_MODEL_NAMES,
I2I_MODEL_NAMES,
I2V_MODEL_NAMES,
TARGET_IMAGE_SIZE,
select_random_model,
get_task,
Expand Down Expand Up @@ -152,7 +153,12 @@ def batch_generate(self, batch_size: int = 5) -> None:
image_sample = self.image_cache.sample()
images.append(image_sample['image'])
bt.logging.info(f"Sampled image {i+1}/{batch_size} for captioning: {image_sample['path']}")
prompts.append(self.generate_prompt(image=image_sample['image'], clear_gpu=i==batch_size-1))
task = get_task(self.model_name) if self.model_name else None
prompts.append(self.generate_prompt(
image=image_sample['image'],
clear_gpu=i==batch_size-1,
task=task
))
bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")

# If specific model is set, use only that model
Expand All @@ -163,9 +169,12 @@ def batch_generate(self, batch_size: int = 5) -> None:
i2i_model_names = random.sample(I2I_MODEL_NAMES, len(I2I_MODEL_NAMES))
t2i_model_names = random.sample(T2I_MODEL_NAMES, len(T2I_MODEL_NAMES))
t2v_model_names = random.sample(T2V_MODEL_NAMES, len(T2V_MODEL_NAMES))
i2v_model_names = random.sample(I2V_MODEL_NAMES, len(I2V_MODEL_NAMES))

model_names = [
m for triple in zip_longest(t2v_model_names, t2i_model_names, i2i_model_names)
for m in triple if m is not None
m for quad in zip_longest(t2v_model_names, t2i_model_names,
i2i_model_names, i2v_model_names)
for m in quad if m is not None
]

# Generate for each model/prompt combination
Expand Down Expand Up @@ -222,7 +231,7 @@ def generate(
ValueError: If real_image is None when using annotation prompt type.
NotImplementedError: If prompt type is not supported.
"""
prompt = self.generate_prompt(image, clear_gpu=True)
prompt = self.generate_prompt(image, clear_gpu=True, task=task)
bt.logging.info("Generating synthetic data...")
gen_data = self._run_generation(prompt, task, model_name, image)
self.clear_gpu()
Expand All @@ -231,7 +240,8 @@ def generate(
def generate_prompt(
self,
image: Optional[Image.Image] = None,
clear_gpu: bool = True
clear_gpu: bool = True,
task: Optional[str] = None
) -> str:
"""Generate a prompt based on the specified strategy."""
bt.logging.info("Generating prompt")
Expand All @@ -241,7 +251,7 @@ def generate_prompt(
"image can't be None if self.prompt_type is 'annotation'"
)
self.prompt_generator.load_models()
prompt = self.prompt_generator.generate(image)
prompt = self.prompt_generator.generate(image, task=task)
if clear_gpu:
self.prompt_generator.clear_gpu()
else:
Expand All @@ -261,9 +271,9 @@ def _run_generation(

Args:
prompt: The text prompt used to inspire the generation.
task: The generation task type ('t2i', 't2v', 'i2i', or None).
task: The generation task type ('t2i', 't2v', 'i2i', 'i2v', or None).
model_name: Optional model name to use for generation.
image: Optional input image for image-to-image generation.
image: Optional input image for image-to-image or image-to-video generation.
generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions.

Returns:
Expand All @@ -272,6 +282,10 @@ def _run_generation(
Raises:
RuntimeError: If generation fails.
"""
# Clear CUDA cache before loading model
torch.cuda.empty_cache()
gc.collect()

self.load_model(model_name)
model_config = MODELS[self.model_name]
task = get_task(model_name) if task is None else task
Expand All @@ -289,14 +303,38 @@ def _run_generation(

gen_args['mask_image'], mask_center = create_random_mask(image.size)
gen_args['image'] = image
# prep image-to-video generation args
elif task == 'i2v':
if image is None:
raise ValueError("image cannot be None for image-to-video generation")
# Get target size from gen_args if specified, otherwise use default
target_size = (
gen_args.get('height', 768),
gen_args.get('width', 768)
)
if image.size[0] > target_size[0] or image.size[1] > target_size[1]:
image = image.resize(target_size, Image.Resampling.LANCZOS)
gen_args['image'] = image

# Prepare generation arguments
for k, v in gen_args.items():
if isinstance(v, dict):
if "min" in v and "max" in v:
gen_args[k] = np.random.randint(v['min'], v['max'])
# For i2v, use minimum values to save memory
if task == 'i2v':
gen_args[k] = v['min']
else:
gen_args[k] = np.random.randint(v['min'], v['max'])
if "options" in v:
gen_args[k] = random.choice(v['options'])
# Ensure num_frames is always an integer
if k == 'num_frames' and isinstance(v, dict):
if "min" in v:
gen_args[k] = v['min']
elif "max" in v:
gen_args[k] = v['max']
else:
gen_args[k] = 24 # Default value

try:
if generate_at_target_size:
Expand All @@ -307,6 +345,10 @@ def _run_generation(
gen_args['width'] = gen_args['resolution'][1]
del gen_args['resolution']

# Ensure num_frames is an integer before generation
if 'num_frames' in gen_args:
gen_args['num_frames'] = int(gen_args['num_frames'])

truncated_prompt = truncate_prompt_if_too_long(prompt, self.model)
bt.logging.info(f"Generating media from prompt: {truncated_prompt}")
bt.logging.info(f"Generation args: {gen_args}")
Expand All @@ -321,8 +363,14 @@ def _run_generation(
pretrained_args = model_config.get('from_pretrained_args', {})
torch_dtype = pretrained_args.get('torch_dtype', torch.bfloat16)
with torch.autocast(self.device, torch_dtype, cache_enabled=False):
# Clear CUDA cache before generation
torch.cuda.empty_cache()
gc.collect()
gen_output = generate(truncated_prompt, **gen_args)
else:
# Clear CUDA cache before generation
torch.cuda.empty_cache()
gc.collect()
gen_output = generate(truncated_prompt, **gen_args)

gen_time = time.time() - start_time
Expand All @@ -334,6 +382,8 @@ def _run_generation(
f"default dimensions. Error: {e}"
)
try:
# Clear CUDA cache before retry
torch.cuda.empty_cache()
gen_output = self.model(prompt=truncated_prompt)
gen_time = time.time() - start_time
except Exception as fallback_error:
Expand Down Expand Up @@ -463,3 +513,61 @@ def clear_gpu(self) -> None:
gc.collect()
torch.cuda.empty_cache()

def generate_from_prompt(
    self,
    prompt: str,
    task: Optional[str] = None,
    image: Optional[Image.Image] = None,
    generate_at_target_size: bool = False
) -> Dict[str, Any]:
    """Generate synthetic data from a caller-supplied prompt.

    Unlike the image-captioning path, no prompt is generated here: the
    given ``prompt`` is passed straight to ``_run_generation``.

    Args:
        prompt: The text prompt to use for generation.
        task: Optional task type ('t2i', 't2v', 'i2i', 'i2v'). Defaults
            to 't2i' when omitted.
        image: Optional input image; required for 'i2i' and 'i2v'.
        generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE
            dimensions.

    Returns:
        Dictionary containing generated data information, as returned by
        ``_run_generation``.

    Raises:
        ValueError: If ``image`` is None for an 'i2i'/'i2v' task, or if a
            random model must be selected and ``task`` is not one of the
            supported task types.
    """
    bt.logging.info(f"Generating synthetic data from provided prompt: {prompt}")

    # Default to t2i if task is not specified
    if task is None:
        task = 't2i'

    # Validate the input image up front so an invalid call fails before
    # any instance state (self.model_name) is modified.
    if task in ('i2i', 'i2v') and image is None:
        raise ValueError(f"Input image is required for {task} generation")

    # If no model is configured, pick a random one that matches the task
    if self.model_name is None and self.use_random_model:
        bt.logging.warning("No model configured. Using random model.")
        candidates_by_task = {
            't2i': T2I_MODEL_NAMES,
            't2v': T2V_MODEL_NAMES,
            'i2i': I2I_MODEL_NAMES,
            'i2v': I2V_MODEL_NAMES,
        }
        if task not in candidates_by_task:
            raise ValueError(f"Unsupported task: {task}")

        self.model_name = random.choice(candidates_by_task[task])

    try:
        # Run the generation with the provided prompt
        return self._run_generation(
            prompt=prompt,
            task=task,
            model_name=self.model_name,
            image=image,
            generate_at_target_size=generate_at_target_size
        )
    finally:
        # Clean up GPU memory even when generation raises, so a failed
        # run does not skip the cleanup step.
        self.clear_gpu()

6 changes: 3 additions & 3 deletions bitmind/validator/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from bitmind.utils.uids import get_random_uids
from bitmind.validator.reward import get_rewards
from bitmind.validator.config import (
TARGET_IMAGE_SIZE,
TARGET_IMAGE_SIZE,
MIN_FRAMES,
MAX_FRAMES,
P_STITCH,
Expand Down Expand Up @@ -136,7 +136,7 @@ def sample_video_frames(self, video_cache):
sample['video'] = sample_A['video'] + sample_B['video']

return sample

def process_metadata(self, sample) -> bool:
"""Prepare challenge metadata and media for logging to Weights & Biases """
self.metadata = {
Expand Down Expand Up @@ -179,4 +179,4 @@ def create_wandb_video(video_frames, fps):
except Exception as e:
bt.logging.error(e)
bt.logging.error(f"{self.modality} is truncated or corrupt. Challenge skipped.")
return False
return False
Loading