Skip to content

Commit 1c2c8bf

Browse files
committed
[Fix] minor bugs
1 parent 8985ab8 commit 1c2c8bf

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

PyTorchSimFrontend/extension_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __getattr__(name):
6767
"multi_tile_conv",
6868
"subtile"
6969
}
70-
if opt_level == "all" or opt_level is "none":
70+
if opt_level == "all" or opt_level == "none":
7171
pass
7272
elif isinstance(opt_level, list):
7373
# Check if provided list contains only valid options

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,12 +1381,12 @@ def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index):
13811381
dim_list = []
13821382
for idx in range(len(tile_size)):
13831383
# Prepare initial values
1384-
offset = tile_desc.vlane_stride #* strides[idx]
1385-
outer_sz = tile_size[idx] // tile_desc.vlane_stride
1384+
offset = tile_desc.vmap.vlane_stride #* strides[idx]
1385+
outer_sz = tile_size[idx] // tile_desc.vmap.vlane_stride
13861386
with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse):
13871387
div_coeff = self.get_const_cse(strides[idx], "index")
13881388
mod_coeff = self.get_const_cse(tile_size[idx], "index")
1389-
vlane_stride_coeff = self.get_const_cse(tile_desc.vlane_stride, "index")
1389+
vlane_stride_coeff = self.get_const_cse(tile_desc.vmap.vlane_stride, "index")
13901390
vlane_outer_coeff = self.get_const_cse(outer_sz, "index")
13911391
nr_vector_lane = self.get_const_cse(self.vector_lane, "index")
13921392
vlane_coeff = self.get_const_cse(0, "i64")

Scheduler/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def setup_device():
179179
)
180180

181181
torch.utils.rename_privateuse1_backend("npu")
182-
torch._register_device_module("extension_device", module)
182+
torch._register_device_module("npu", module)
183183
from torch._inductor.codegen.common import (
184184
get_scheduling_for_device,
185185
get_wrapper_codegen_for_device,

0 commit comments

Comments
 (0)