Labels: bug
Describe the bug
I'm unable to simulate inference for the MobileNet-SSDLite model using the function:
```python
import torch
import torchvision
from torchvision.models.detection import SSDLite320_MobileNet_V3_Large_Weights


def run_mobilenet(batch, config):
    from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request

    # One request queue, FIFO engine, TOGSim config supplied by the caller
    scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config)
    device = scheduler.execution_engine.module.custom_device()

    # Pretrained SSDLite320 / MobileNetV3-Large detector in inference mode
    model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT).eval()
    input = torch.randn(batch, 3, 320, 320).to(device=device)

    # Compile for the simulated device with static shapes
    opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last))
    SchedulerDNNModel.register_model("mobilenet-ssd", opt_fn)

    request = Request("mobilenet-ssd", [input], [], request_queue_idx=0)
    scheduler.add_request(request, request_time=0)

    # Run scheduler
    while not scheduler.is_finished():
        with torch.no_grad():
            scheduler.schedule()

    print("Mobilenet SSD Simulation Done")
```

This gives the error `TypeError: cannot unpack non-iterable CSEVariable object`, with the trace:
```
  File "/workspace/PyTorchSim/experiments/mobilenet-ssd.py", line 28, in run_mobilenet
    scheduler.schedule()
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 461, in schedule
    result.append(self.per_schedule(i))
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 446, in per_schedule
    self.execution_engine.submit(request_list, request_queue_idx)
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 208, in submit
    self.prepare_model(batched_req_model)
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 233, in prepare_model
    ret = req_model.model(*input_tensor_list)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.....
  File "/workspace/PyTorchSim/PyTorchSimFrontend/mlir/mlir_common.py", line 864, in inner
    code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/common.py", line 1017, in inner
    getattr(parent_handler, name)(*args, **kwargs),  # type: ignore[has-type]
  File "/workspace/PyTorchSim/PyTorchSimFrontend/mlir/mlir_common.py", line 864, in inner
    code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: cannot unpack non-iterable CSEVariable object
```
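The failing line in `mlir_common.py` unpacks a handler's return value as a `(code, ret_info)` pair, and the trace shows Inductor's own `inner` in `common.py` sitting between two PyTorchSim `inner` frames. One plausible reading, which is an assumption the trace does not confirm, is that the middle layer returns a single `CSEVariable` instead of a tuple, so the outer unpack fails. The pure-Python sketch below reproduces only that failure mode; `FakeCSEVariable`, `mlir_handler`, `cse_wrapper`, and `mlir_wrapper` are hypothetical stand-ins, not PyTorchSim or Inductor classes.

```python
class FakeCSEVariable:
    """Hypothetical stand-in for a CSE variable: a single, non-iterable value."""

    def __init__(self, name):
        self.name = name


def mlir_handler(x, var_info=None):
    # Innermost handler: returns the (code, ret_info) pair the outer layer expects.
    return f"%{x} = arith.negf", {"dtype": "f32"}


def cse_wrapper(handler):
    # Models a CSE layer that collapses the handler's result into one variable.
    def inner(*args, **kwargs):
        code, _ = handler(*args, **kwargs)
        return FakeCSEVariable(code)
    return inner


def mlir_wrapper(handler):
    # Models the outer wrapper, which unpacks whatever it gets as a 2-tuple.
    def inner(*args):
        code, ret_info = handler(*args, var_info={})
        return code, ret_info
    return inner


def strict_call(x):
    # Stacking the layers in this order reproduces the unpack error.
    return mlir_wrapper(cse_wrapper(mlir_handler))(x)


try:
    strict_call("v0")
except TypeError as exc:
    print(exc)  # cannot unpack non-iterable FakeCSEVariable object
```

If this is what happens in PyTorchSim, the fix would be either to let the outer layer accept a single variable at that point or to change the order in which the handlers are wrapped; which one is appropriate depends on PyTorchSim internals.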
I have attempted to suppress this error and fall back to eager mode with the line:

```python
torch._dynamo.config.suppress_errors = True
```

However, this leads to another error:
```
  File "/workspace/PyTorchSim/experiments/mobilenet-ssd.py", line 29, in run_mobilenet
    scheduler.schedule()
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 461, in schedule
    result.append(self.per_schedule(i))
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 446, in per_schedule
    self.execution_engine.submit(request_list, request_queue_idx)
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 208, in submit
    self.prepare_model(batched_req_model)
  File "/workspace/PyTorchSim/Scheduler/scheduler.py", line 233, in prepare_model
    ret = req_model.model(*input_tensor_list)
....
  File "/opt/conda/lib/python3.10/site-packages/torchvision/models/detection/transform.py", line 142, in forward
    image, target_index = self.resize(image, target_index)
  File "/opt/conda/lib/python3.10/site-packages/torchvision/models/detection/transform.py", line 179, in resize
    def resize(
AttributeError: 'generator' object has no attribute 'shape'
```
This suggests an issue similar to #204, where a generator object is passed unexpectedly.
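This `AttributeError` is what Python raises when a lazily evaluated generator reaches code that reads an attribute of a materialized object. The sketch below illustrates that pattern in pure Python, assuming the root cause is a generator being passed where a single image tensor is expected; `FakeImage`, `resize`, and `preprocess` are hypothetical and do not mirror torchvision's detection transform code.

```python
class FakeImage:
    """Hypothetical stand-in for an image tensor: only exposes .shape."""

    def __init__(self, shape):
        self.shape = shape


def resize(image):
    # Like the failing frame, this reads image.shape, so it needs a real object.
    c, h, w = image.shape
    return FakeImage((c, h // 2, w // 2))


def preprocess(images, lazy):
    # lazy=True passes a generator expression instead of one image, mimicking the bug.
    batch = (img for img in images) if lazy else images[0]
    return resize(batch)


images = [FakeImage((3, 320, 320))]
print(preprocess(images, lazy=False).shape)  # (3, 160, 160)
try:
    preprocess(images, lazy=True)
except AttributeError as exc:
    print(exc)  # 'generator' object has no attribute 'shape'
```

In this sketch the eager path works only because a concrete object, not a generator, reaches `resize`; where the generator comes from in the fallback path is still open.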
To Reproduce
Here's a script to run the function: https://gist.github.com/oluwatimilehin/8d786e3fc87143525ff7305cd40552ab. Note that I have added the following lines to extension_device.cpp to get to this point:
```cpp
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  .....
  m.impl("clamp_min.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("clamp_max.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("ceil.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("upsample_bilinear2d.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
```
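Each op routed to the CPU fallback above needs its own `m.impl` line with the same boxed-function wrapper. If the model hits further unsupported ops, a small script such as this hypothetical `fallback_registrations` helper can emit the lines to paste into `extension_device.cpp`. It only generates text and assumes `custom_cpu_fallback` is already defined in that file, as in the snippet above.

```python
def fallback_registrations(op_names):
    # Emit one m.impl line per aten overload, routing it to custom_cpu_fallback.
    template = (
        'm.impl("{op}", '
        "torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());"
    )
    return [template.format(op=op) for op in op_names]


for line in fallback_registrations(["clamp_min.out", "ceil.out"]):
    print(line)
```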