diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py index e2122376..45eab8e4 100644 --- a/bitmind/synthetic_data_generation/synthetic_data_generator.py +++ b/bitmind/synthetic_data_generation/synthetic_data_generator.py @@ -511,63 +511,4 @@ def clear_gpu(self) -> None: del self.model self.model = None gc.collect() - torch.cuda.empty_cache() - - def generate_from_prompt( - self, - prompt: str, - task: Optional[str] = None, - image: Optional[Image.Image] = None, - generate_at_target_size: bool = False - ) -> Dict[str, Any]: - """Generate synthetic data based on a provided prompt. - - Args: - prompt: The text prompt to use for generation - task: Optional task type ('t2i', 't2v', 'i2i', 'i2v') - image: Optional input image for i2i or i2v generation - generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions - - Returns: - Dictionary containing generated data information - """ - bt.logging.info(f"Generating synthetic data from provided prompt: {prompt}") - - # Default to t2i if task is not specified - if task is None: - task = 't2i' - - # If model_name is not specified, select one based on the task - if self.model_name is None and self.use_random_model: - bt.logging.warning(f"No model configured. Using random model.") - if task == 't2i': - model_candidates = T2I_MODEL_NAMES - elif task == 't2v': - model_candidates = T2V_MODEL_NAMES - elif task == 'i2i': - model_candidates = I2I_MODEL_NAMES - elif task == 'i2v': - model_candidates = I2V_MODEL_NAMES - else: - raise ValueError(f"Unsupported task: {task}") - - self.model_name = random.choice(model_candidates) - - # Validate input image for tasks that require it - if task in ['i2i', 'i2v'] and image is None: - raise ValueError(f"Input image is required for {task} generation") - - # Run the generation with the provided prompt - gen_data = self._run_generation( - prompt=prompt, - task=task, - model_name=self.model_name, - image=image, - generate_at_target_size=generate_at_target_size - ) - - # Clean up GPU memory - self.clear_gpu() - - return gen_data - + torch.cuda.empty_cache() \ No newline at end of file