diff --git a/hydra_vl4ai/tool/blip.py b/hydra_vl4ai/tool/blip.py index e3dd021..e20d1db 100644 --- a/hydra_vl4ai/tool/blip.py +++ b/hydra_vl4ai/tool/blip.py @@ -52,7 +52,7 @@ def __init__(self, gpu_number=0, blip_v2_model_type="blip2-flan-t5-xxl"): @torch.no_grad() def caption(self, image, prompt=None): - inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.dev) + inputs = self.processor(images=image, text=prompt, return_tensors="pt", do_rescale=False).to(self.dev) generated_ids = self.model.generate(**inputs, length_penalty=1., num_beams=5, max_length=30, min_length=1, do_sample=False, top_p=0.9, repetition_penalty=1.0, num_return_sequences=1, temperature=1) @@ -78,7 +78,7 @@ def pre_question(self, question): @torch.no_grad() def qa(self, image, question): - inputs = self.processor(images=image, text=question, return_tensors="pt", padding="longest").to(self.dev) + inputs = self.processor(images=image, text=question, return_tensors="pt", padding="longest", do_rescale=False).to(self.dev) generated_ids = self.model.generate(**inputs, length_penalty=-1, num_beams=5, max_length=10, min_length=1, do_sample=False, top_p=0.9, repetition_penalty=1.0, num_return_sequences=1, temperature=1)