diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 4f2a4452b89f..1f8200a1ba31 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -35,6 +35,8 @@ #include #include +#include "tvm/tir/var.h" + namespace tvm { namespace tir { @@ -158,6 +160,22 @@ class Layout : public ObjectRef { return undef; } + /*! + * \brief Packs the Given Array of IterVars into a Single IterVar. Each IterVar in the Array + * should represent either a single primal axis or one or more subordinate axis + * \param iters Array of iter vars to be packed + * \return A packed iter var + */ + static IterVar PackIterVar(ffi::Array iters); + + /*! + * \brief Unpacks a Packed IterVar into its constituents + * \param packed_iter A Packed IterVar containing a single primal axis or one or more subordinate + * axis + * \return Constituent IterVars + */ + static ffi::Array UnpackIterVar(IterVar packed_iter); + /*! * \brief Returns a sub-layout which is the portion of the object * that starts at dimension \p pos and spans \p len dimensions @@ -187,9 +205,12 @@ class Layout : public ObjectRef { inline size_t ndim_primal() const { if (!defined()) return 0; size_t ct = 0; - for (auto x : operator->()->axes) { - if (LayoutAxis::Get(x).IsPrimal()) { - ct++; + for (auto px : operator->()->axes) { + auto iter_vars = UnpackIterVar(px); + for (auto x : iter_vars) { + if (LayoutAxis::Get(x).IsPrimal()) { + ct++; + } } } return ct; @@ -204,10 +225,13 @@ class Layout : public ObjectRef { Layout new_src_layout; // 1) Find the axis which are missing in the current layout. Make them the prefix. std::string new_src_layout_str = ""; - for (auto dst_axis : dst_layout->axes) { - if (LayoutAxis::Get(dst_axis).IsPrimal()) { - if (!this->Contains(LayoutAxis::Get(dst_axis))) { - new_src_layout_str += dst_axis->var->name_hint; + for (auto packed_axis : dst_layout->axes) { + auto iter_vars = UnpackIterVar(packed_axis); + for (auto dst_axis : iter_vars) { + if (LayoutAxis::Get(dst_axis).IsPrimal()) { + if (!this->Contains(LayoutAxis::Get(dst_axis))) { + new_src_layout_str += dst_axis->var->name_hint; + } } } } @@ -221,18 +245,36 @@ class Layout : public ObjectRef { * \brief return the index of the input axis. * If it is not found in the layout or the layout is undefined, * return -1. - * \param axis the input axis. + * \param axis The input axis either a layout axis, or a packed axis * \return the index or -1 if not found. */ - inline int32_t IndexOf(const LayoutAxis& axis) const { + inline int32_t IndexOf(const std::string& axis) const { if (!this->defined()) return -1; const auto axes = operator->()->axes; for (size_t i = 0; i < axes.size(); ++i) { - if (axes[i]->var->name_hint == axis.name()) return static_cast(i); + if (axes[i]->var->name_hint == axis) return static_cast(i); } return -1; } + /*! + * \brief return the index of the input axis. + * If it is not found in the layout or the layout is undefined, + * return -1. + * \param axis the input layout axis. + * \return the index or -1 if not found. + */ + inline int32_t IndexOf(const LayoutAxis& axis) const { return IndexOf(axis.name()); } + + /*! + * \brief return the index of the input axis. + * If it is not found in the layout or the layout is undefined, + * return -1. + * \param iter the input iter var. + * \return the index or -1 if not found. + */ + inline int32_t IndexOf(const tir::IterVar& iter) const { return IndexOf(iter->var->name_hint); } + /*! * \brief Get the factor size of the subordinate axis. * \param axis the input primal-axis or subordinate-axis. @@ -249,9 +291,12 @@ class Layout : public ObjectRef { */ bool Contains(const LayoutAxis& axis) const { if (!defined()) return false; - for (const tir::IterVar var : operator->()->axes) { - if (var->var->name_hint == axis.name()) { - return true; + for (const tir::IterVar packed_var : operator->()->axes) { + auto iter_vars = UnpackIterVar(packed_var); + for (auto var : iter_vars) { + if (var->var->name_hint == axis.name()) { + return true; + } } } return false; @@ -265,6 +310,14 @@ class Layout : public ObjectRef { return LayoutAxis::Get(axis); } + IterVar PackedAxisAt(int32_t i) const { + ICHECK(defined()) << "Try to access axis from an undefined layout."; + int32_t index = i < 0 ? static_cast(ndim() + i) : i; + ICHECK(index >= 0 && static_cast(index) < ndim()) << "Invalid index " << i; + const tir::IterVar axis = operator->()->axes[index]; + return axis; + } + /*! \return the string description of the layout */ inline std::string name() const { if (!defined()) return "__undef__"; diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index f9c0e0cdc7ce..14f02373d991 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -41,7 +41,8 @@ def __len__(self): return _ffi_api.LayoutNdim(self) # type: ignore def __contains__(self, axis): - return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name + # Note: We do a weaker check for packed axis assuming layout is valid + return not any(bkt in axis for bkt in "[]") and axis in self.name def __getitem__(self, index): if index >= len(self): @@ -54,7 +55,7 @@ def index_of(self, axis): Parameters ---------- axis : str - The axis name, need to be [a-z,A-Z] + The axis name, needs to be [a-z,A-Z] or a packed axis Returns ------- diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 75f9bb50d15e..c3d91029db8e 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -24,9 +24,17 @@ #include #include #include +#include +#include +#include +#include +#include #include +#include #include +#include +#include #include namespace tvm { @@ -78,17 +86,28 @@ Layout::Layout(const ffi::Array& axes) { auto node = ffi::make_object(); node->axes = axes; std::ostringstream repr; - for (const IterVar& axis : axes) { - if (const auto* factor = axis->dom->extent.as()) { - ICHECK_GT(factor->value, 0); - repr << factor->value; + + for (const IterVar& packed_axis : axes) { + auto unpacked_axes = UnpackIterVar(packed_axis); + bool is_grouped = unpacked_axes.size() > 1; + + if (is_grouped) repr << "["; + for (const IterVar& axis : unpacked_axes) { + if (const auto* factor = axis->dom->extent.as()) { + ICHECK_GT(factor->value, 0); + repr << factor->value; + } else { + ICHECK(!is_grouped) << "Only Subordinate Axes with extent is allowed within a packed dim"; + } + ICHECK_EQ(axis->var.get()->name_hint.size(), 1) + << "Invalid layout axis " << axis->var.get()->name_hint; + char c = axis->var.get()->name_hint.operator std::string()[0]; + ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; + repr << axis->var.get()->name_hint; } - ICHECK_EQ(axis->var.get()->name_hint.size(), 1) - << "Invalid layout axis " << axis->var.get()->name_hint; - char c = axis->var.get()->name_hint.operator std::string()[0]; - ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; - repr << axis->var.get()->name_hint; + if (is_grouped) repr << "]"; } + node->name = repr.str(); data_ = std::move(node); } @@ -104,46 +123,91 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) // parse layout string int32_t factor = 0; + bool in_packing = false; + std::vector unpacked_axes; + for (char c : name) { if (c >= 'A' && c <= 'Z') { ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; - std::string shape_name("_shape"); - shape_name.insert(0, 1, c); - IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype), - tir::kDataPar); - node->axes.push_back(axis); + IterVar axis(Range(IntImm(dtype, 0), Var(std::string(1, c), dtype)), + Var(std::string(1, c), dtype), tir::kDataPar); + if (!in_packing) { + node->axes.push_back(axis); + } else { + unpacked_axes.push_back(axis); + } } else if (c >= 'a' && c <= 'z') { ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " for dimension " << c; - IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype), + std::stringstream name; + name << factor << c; + IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(name.str(), dtype), tir::kDataPar); - node->axes.push_back(axis); + if (!in_packing) { + node->axes.push_back(axis); + } else { + unpacked_axes.push_back(axis); + } factor = 0; } else if (c >= '0' && c <= '9') { ICHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number."; factor = factor * 10 + c - '0'; + } else if (c == '[') { + ICHECK(!in_packing) << "Invalid layout " << name << ": can't do nested packing"; + in_packing = true; + } else if (c == ']') { + ICHECK(in_packing) << "Invalid layout " << name << ": encountered ] without matching bracket"; + ICHECK(unpacked_axes.size() > 1) + << "Invalid layout " << name << ": found empty/single packed axis"; + std::stringstream ss; + int64_t extent = 1; + for (auto& axis : unpacked_axes) { + ICHECK(axis->dom->extent.as()) + << "Invalid Layout " << name << ": can't have variable sized node(" + << axis->var->name_hint << ") within a packed axis"; + auto axis_name = axis->var->name_hint.operator std::string(); + auto factor = axis->dom->extent.as().value(); + ss << axis_name; + extent = extent * factor->value; + } + std::string grouped_name = ss.str(); + IterVar grouped_axis(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(grouped_name, dtype), + tir::kDataPar); + node->axes.push_back(grouped_axis); + + in_packing = false; + unpacked_axes.clear(); } else { LOG(FATAL) << "Invalid layout " << name; } } + ICHECK(in_packing == false) << "Invalid Layout " << name + << ": haven't terminated the packing sequence"; // validate layout - std::vector exist_axis(256, false); - for (const IterVar& v : node->axes) { - auto axis_str = v->var.get()->name_hint.operator std::string(); - ICHECK_EQ(axis_str.size(), 1); - char axis = axis_str[0]; - ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); - exist_axis[axis] = true; + std::vector axis_cnt(256, 0); + for (const IterVar& pv : node->axes) { + for (const IterVar& v : UnpackIterVar(pv)) { + auto axis_str = v->var.get()->name_hint.operator std::string(); + ICHECK_EQ(axis_str.size(), 1); + char axis = axis_str[0]; + ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); + axis_cnt[axis] += 1; + } } - for (const IterVar& v : node->axes) { - char axis = v->var.get()->name_hint.operator std::string()[0]; - if (axis >= 'a' && axis <= 'z') { - ICHECK(exist_axis[axis - 'a' + 'A']) - << "Invalid layout " << name << ": missing axis " << std::toupper(axis); + for (const IterVar& pv : node->axes) { + for (const IterVar& v : UnpackIterVar(pv)) { + char axis = v->var.get()->name_hint.operator std::string()[0]; + if (axis >= 'a' && axis <= 'z') { + ICHECK(axis_cnt[axis - 'a' + 'A']) + << "Invalid layout " << name << ": missing axis " << std::toupper(axis); + ICHECK(axis_cnt[axis] == 1) << "Invalid layout " << name + << ": found more than one subordinate " << std::toupper(axis); + } } } + data_ = std::move(node); } @@ -159,27 +223,45 @@ Layout Layout::SubLayout(size_t pos, size_t len) const { return Layout(new_layout); } -Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const { - if (!defined()) return Layout::Undef(); - const std::string& name = operator->()->name; - const auto axes = operator->()->axes; - ICHECK(target_pos <= this->ndim()) - << "Invalid split position " << target_pos << " for layout " << name; - ICHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis; - ICHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name; - ICHECK(!this->Contains(axis.ToSubordinate())) - << "Axis " << axis << " has already been split in " << name; - ICHECK(factor > 0) << "Invalid split size " << factor; - ffi::Array new_layout; - for (size_t i = 0; i <= this->ndim(); ++i) { - if (i == target_pos) { - new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)), - Var(axis.ToSubordinate().name()), tir::kDataPar)); +ffi::Array Layout::UnpackIterVar(IterVar packed_iter) { + ffi::Array result; + int64_t factor = 0, final_factor = 1; + + std::string name(packed_iter->var->name_hint.c_str()); + DataType dtype = packed_iter->var.dtype(); + + for (auto ch : name) { + if (ch >= '0' && ch <= '9') { + factor = factor * 10 + (ch - '0'); + } else if (ch >= 'a' && ch <= 'z') { + ICHECK(factor != 0) << "Invalid Factor Size"; + result.push_back(IterVar(Range(IntImm(dtype, 0), IntImm(dtype, factor)), + Var(std::string(1, ch), dtype), tir::kDataPar)); + final_factor *= factor; + factor = 0; + } else if (ch >= 'A' && ch <= 'Z') { + ICHECK(factor == 0) << "Can't have non-zero factors for primal axis"; + result.push_back(IterVar(Range(IntImm(dtype, 0), Var(std::string(1, ch), dtype)), + Var(std::string(1, ch), dtype), tir::kDataPar)); } - if (i == this->ndim()) break; - new_layout.push_back(axes[i]); } - return Layout(new_layout); + + return result; +} + +IterVar Layout::PackIterVar(ffi::Array iter_vars) { + std::stringstream name; + size_t extent = 1; + + DataType dtype = iter_vars[0]->dom->extent.as().value()->dtype; + for (auto itvar : iter_vars) { + ICHECK(itvar->dom->extent.as()) << "Packed Axis can contain only Subordinate Axes"; + name << itvar->dom->extent.as().value() << itvar->var->name_hint; + extent = extent * itvar->dom->extent.as().value()->value; + } + + return IterVar(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(name.str(), dtype), + tir::kDataPar); } int32_t Layout::FactorOf(const LayoutAxis& axis) const { @@ -188,12 +270,13 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { int32_t factor = 1; bool has_sub = false; - for (const IterVar& itvar : operator->()->axes) { - if (sub == LayoutAxis::Get(itvar)) { - has_sub = true; - int32_t val = itvar->dom->extent.as()->value; - ICHECK(val); - factor *= val; + for (const IterVar& packed_itvar : operator->()->axes) { + for (auto itvar : UnpackIterVar(packed_itvar)) { + if (sub == LayoutAxis::Get(itvar)) { + has_sub = true; + int32_t val = itvar->dom->extent.as()->value; + factor *= val; + } } } factor = has_sub ? factor : -1; @@ -218,63 +301,120 @@ inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* return false; } - for (size_t i = 0; i < dst_layout.ndim(); ++i) { - const auto& store_axis = dst_layout[i]; - const IterVar& store_axis_impl = dst_layout->axes[i]; - PrimExpr index_store(0); - - for (size_t j = 0; j < src_layout.ndim(); ++j) { - const auto& orig_axis = src_layout[j]; - const IterVar& orig_axis_impl = src_layout->axes[j]; - if (store_axis.ToPrimal() == orig_axis.ToPrimal()) { - if (orig_axis.IsPrimal()) { - PrimExpr orig_var = orig_axis_impl->var; - const int32_t factor = src_layout.FactorOf(orig_axis); - if (factor > 0) { - orig_var = orig_var * factor; - } - index_store = index_store + orig_var; - } else { - PrimExpr factor(1); - for (size_t k = j + 1; k < src_layout.ndim(); ++k) { - if (LayoutAxis::Get(orig_axis_impl) == LayoutAxis::Get(src_layout->axes[k])) { - factor = factor * src_layout->axes[k]->dom->extent; + std::vector exists(128, false); + PrimExpr norm_indexes[128]; + for (auto& it : norm_indexes) it = PrimExpr(0); + + for (size_t i = 0; i < src_layout.ndim(); i++) { + auto factor = src_layout.PackedAxisAt(i)->dom->extent; + auto src_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(i)); + + if (src_unpacked_axes.size() == 1 && LayoutAxis::Get(src_unpacked_axes[0]).IsPrimal()) { + const auto& prim_axis = LayoutAxis::Get(src_unpacked_axes[0]); + int64_t offset = src_layout.FactorOf(prim_axis); + if (offset == -1) + norm_indexes[prim_axis.name()[0] - 'A'] = + norm_indexes[prim_axis.name()[0] - 'A'] + src_layout.PackedAxisAt(i); + else + norm_indexes[prim_axis.name()[0] - 'A'] = + norm_indexes[prim_axis.name()[0] - 'A'] + + src_layout.PackedAxisAt(i) * src_layout.FactorOf(prim_axis); + exists[prim_axis.name()[0]] = true; + } else { + int64_t value = 1; + std::vector index_divs(src_unpacked_axes.size()); + for (size_t j = 0; j < src_unpacked_axes.size(); j++) { + index_divs[j] = value; + const auto* extent = src_unpacked_axes[j]->dom->extent.as(); + ICHECK(extent) << "Expected Integer Extents for Offset Calculation"; + index_divs.push_back(value); + value = value * extent->value; + } + std::reverse(index_divs.begin(), index_divs.end()); + + for (size_t j = 0; j < src_unpacked_axes.size(); j++) { + const int extent = src_unpacked_axes[j]->dom->extent.as()->value; + const LayoutAxis& store_axis_impl = LayoutAxis::Get(src_unpacked_axes[j]); + const LayoutAxis& sub_axis = store_axis_impl.ToSubordinate(); /* Not Needed */ + const LayoutAxis& prim_axis = store_axis_impl.ToPrimal(); + + PrimExpr factor_ij = indexdiv(src_layout.PackedAxisAt(i), index_divs[j]); + if (j != 0) factor_ij = indexmod(factor_ij, extent); + + for (size_t k = i; k < src_layout.ndim(); k++) { + size_t l = 0; + if (k == i) l = j + 1; + + auto inter_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(k)); + for (; l < inter_unpacked_axes.size(); l++) { + const LayoutAxis& axis = LayoutAxis::Get(inter_unpacked_axes[l]); + if (axis == sub_axis) { + const auto* sub_extent = inter_unpacked_axes[l]->dom->extent.as(); + ICHECK(sub_extent) << "Expected Integer Extents for Offset Calculation"; + factor_ij = factor_ij * IntImm(sub_extent->dtype, sub_extent->value); } } - index_store = index_store + orig_axis_impl->var * factor; } + + norm_indexes[prim_axis.name()[0] - 'A'] = + norm_indexes[prim_axis.name()[0] - 'A'] + factor_ij; } } - if (tir::is_zero(index_store)) { - LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" << dst_layout.name() - << "' is not convertible."; - return false; - } + } + + arith::Analyzer ana; - PrimExpr shape_store = index_store; - if (store_axis.IsPrimal()) { - const int32_t factor = dst_layout.FactorOf(store_axis); - if (factor > 0) { - shape_store = shapediv(index_store, PrimExpr(factor)); - index_store = indexdiv(index_store, PrimExpr(factor)); + for (size_t i = 0; i < dst_layout.ndim(); i++) { + const auto dst_unpacked_axes = Layout::UnpackIterVar(dst_layout.PackedAxisAt(i)); + + if (dst_unpacked_axes.size() == 1 && LayoutAxis::Get(dst_unpacked_axes[0]).IsPrimal()) { + const auto& prim_axis = LayoutAxis::Get(dst_unpacked_axes[0]); + if (!exists[prim_axis.name()[0]]) return false; + int64_t offset = dst_layout.FactorOf(prim_axis); + if (offset != -1) { + index_rule->push_back( + indexdiv(norm_indexes[prim_axis.name()[0] - 'A'], dst_layout.FactorOf(prim_axis))); + shape_rule->push_back( + indexdiv(norm_indexes[prim_axis.name()[0] - 'A'] + (dst_layout.FactorOf(prim_axis) - 1), + dst_layout.FactorOf(prim_axis))); + } else { + index_rule->push_back(norm_indexes[prim_axis.name()[0] - 'A']); + shape_rule->push_back(norm_indexes[prim_axis.name()[0] - 'A']); } } else { - PrimExpr stride(1); - PrimExpr factor(1); - for (size_t j = i; j < dst_layout.ndim(); ++j) { - if (LayoutAxis::Get(store_axis_impl) == LayoutAxis::Get(dst_layout->axes[j])) { - stride = stride * dst_layout->axes[j]->dom->extent; - if (j > i) { - factor = factor * dst_layout->axes[j]->dom->extent; + PrimExpr factor(0); + for (size_t j = 0; j < dst_unpacked_axes.size(); j++) { + const auto& prim_axis = LayoutAxis::Get(dst_unpacked_axes[j]).ToPrimal(); + const auto& sub_axis = LayoutAxis::Get(dst_unpacked_axes[j]).ToSubordinate(); + const auto* extent = dst_unpacked_axes[j]->dom->extent.as(); + ICHECK(extent) << "Expected extent to be IntImmNode"; + + size_t divfactor = 1; + for (size_t k = i; k < dst_layout.ndim(); k++) { + size_t l = 0; + if (k == i) l = j + 1; + + const auto inter_unpacked_axes = Layout::UnpackIterVar(dst_layout.PackedAxisAt(k)); + for (; l < inter_unpacked_axes.size(); l++) { + const auto& axis = LayoutAxis::Get(inter_unpacked_axes[l]); + if (sub_axis == axis) { + const auto* sub_extent = inter_unpacked_axes[l]->dom->extent.as(); + ICHECK(sub_extent) << "Expected Integer Extents for Offset Calculation"; + divfactor = divfactor * sub_extent->value; + } } } + + factor = factor + indexmod(indexdiv(norm_indexes[prim_axis.name()[0] - 'A'], divfactor), + extent->value); + for (size_t k = j + 1; k < dst_unpacked_axes.size(); k++) { + factor = factor * dst_unpacked_axes[k]->dom->extent.as().value(); + } } - shape_store = indexdiv(indexmod(index_store, stride), factor); - index_store = indexdiv(indexmod(index_store, stride), factor); + ana.Simplify(factor); + index_rule->push_back(factor); + shape_rule->push_back(factor); } - - index_rule->push_back(index_store); - shape_rule->push_back(shape_store); } std::stringstream ss; @@ -289,7 +429,7 @@ inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* ss << r << ", "; } ss << "]" << std::endl; - VLOG(1) << std::endl << ss.str(); + VLOG(1) << ss.str() << std::endl; return true; } @@ -341,7 +481,8 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape for (size_t i = 0; i < src_shape.size(); ++i) { PrimExpr orig_shape = src_shape[i]; IterVar orig_axis = src_axis[i]; - if (!LayoutAxis::Get(orig_axis).IsPrimal()) { + auto layout = Layout::UnpackIterVar(orig_axis); + if (layout.size() != 1 || !LayoutAxis::Get(layout[0]).IsPrimal()) { if (orig_shape.defined()) { const auto* orig_shape_const = orig_shape.as(); const auto* orig_axis_extent = orig_axis->dom->extent.as(); @@ -366,7 +507,8 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape for (size_t i = 0; i < transform_rule.size(); ++i) { PrimExpr rule = transform_rule[i]; IterVar axis = target_axis[i]; - if (!LayoutAxis::Get(axis).IsPrimal()) { + auto layout = Layout::UnpackIterVar(axis); + if (layout.size() != 1 || !LayoutAxis::Get(layout[0]).IsPrimal()) { result.push_back(axis->dom->extent); } else { result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); @@ -435,9 +577,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def("tir.Layout", [](std::string name, DataType dtype) { return Layout(name, dtype); }) .def("tir.LayoutIndexOf", - [](Layout layout, std::string axis) -> int { - return layout.IndexOf(LayoutAxis::Get(axis)); - }) + [](Layout layout, std::string axis) -> int { return layout.IndexOf(axis); }) .def("tir.LayoutFactorOf", [](Layout layout, std::string axis) -> int { return layout.FactorOf(LayoutAxis::Get(axis)); @@ -445,8 +585,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("tir.LayoutNdim", [](Layout layout) -> int { return layout.ndim(); }) .def("tir.LayoutGetItem", [](Layout layout, int idx) -> std::string { - const LayoutAxis& axis = layout[idx]; - return axis.name(); + const auto& axis = layout.PackedAxisAt(idx); + return axis->var->name_hint; }) .def("tir.BijectiveLayout", [](Layout src_layout, Layout dst_layout) -> BijectiveLayout { diff --git a/tests/python/tir-base/test_tir_data_layout.py b/tests/python/tir-base/test_tir_data_layout.py index a76cb50da3bd..df63e70143d5 100644 --- a/tests/python/tir-base/test_tir_data_layout.py +++ b/tests/python/tir-base/test_tir_data_layout.py @@ -18,8 +18,9 @@ import pytest import tvm -import tvm.error +import tvm.testing from tvm.topi.utils import get_const_tuple +from tvm.error import InternalError def test_layout(): @@ -35,7 +36,7 @@ def test_layout(): assert layout.index_of("C") == 1 assert layout.index_of("H") == 2 assert layout.index_of("W") == 3 - assert layout.index_of("c") == 4 + assert layout.index_of("16c") == 4 assert layout.index_of("O") == -1 assert "N" in layout @@ -49,8 +50,50 @@ def test_layout(): assert layout[1] == "C" assert layout[2] == "H" assert layout[3] == "W" - assert layout[4] == "c" - assert layout[-1] == "c" + assert layout[4] == "16c" + + layout = tvm.tir.layout("OIHW[4o4i]") + assert layout is not None + assert isinstance(layout, tvm.tir.Layout) + + assert layout.factor_of("o") == 4 + assert layout.factor_of("i") == 4 + assert layout.factor_of("H") == -1 + assert layout.factor_of("W") == -1 + assert layout.factor_of("N") == -1 + + assert layout.index_of("O") == 0 + assert layout.index_of("I") == 1 + assert layout.index_of("H") == 2 + assert layout.index_of("W") == 3 + assert layout.index_of("4o4i") == 4 + assert layout.index_of("i") == -1 + assert layout.index_of("o") == -1 + + assert "O" in layout + assert "I" in layout + assert "H" in layout + assert "W" in layout + assert "4o4i" in layout + assert "i" in layout + assert "o" in layout + + assert layout[0] == "O" + assert layout[1] == "I" + assert layout[2] == "H" + assert layout[3] == "W" + assert layout[4] == "4o4i" + + with pytest.raises(InternalError): + layout = tvm.tir.layout("[N4o]C") + with pytest.raises(InternalError): + layout = tvm.tir.layout("[O4o]") + with pytest.raises(InternalError): + layout = tvm.tir.layout("C4o") + with pytest.raises(InternalError): + layout = tvm.tir.layout("OI[4o4i][]") + with pytest.raises(InternalError): + layout = tvm.tir.layout("C4c[4c]") def test_layout_dtype(): @@ -84,6 +127,8 @@ def test_bilayout_convertible(): assert tvm.tir.bijective_layout("__undef__", "__undef__") is None assert tvm.tir.bijective_layout("", "NCHW") is None assert tvm.tir.bijective_layout("NCHW", "") is None + assert tvm.tir.bijective_layout("OIHW", "OIHW[4o4i]") is not None + assert tvm.tir.bijective_layout("OIHW[2o4i]", "OIHW") is not None assert tvm.tir.bijective_layout("", "") is None # convertible assert tvm.tir.bijective_layout("NCHW", "NCHW16c") is not None @@ -99,6 +144,14 @@ def test_bilayout_shape(): src_shape = bilayout.backward_shape(dst_shape) assert get_const_tuple(src_shape) == (1, 32, 7, 7) + bilayout = tvm.tir.bijective_layout("OIHW", "OIHW[4o4i]") + + dst_shape = bilayout.forward_shape((64, 28, 7, 7)) + assert get_const_tuple(dst_shape) == (16, 7, 7, 7, 16) + + src_shape = bilayout.backward_shape((2, 11, 4, 4, 16)) + assert get_const_tuple(src_shape) == (8, 44, 4, 4) + def test_bilayout_index(): bilayout = tvm.tir.bijective_layout("NCHW", "NCHW16c") @@ -109,10 +162,14 @@ def test_bilayout_index(): src_index = bilayout.backward_index([0, 1, 6, 6, 2]) assert get_const_tuple(src_index) == (0, 18, 6, 6) + bilayout = tvm.tir.bijective_layout("OIHW", "OIHW[4o4i]") + + dst_index = bilayout.forward_index((63, 29, 7, 7)) + assert get_const_tuple(dst_index) == (15, 7, 7, 7, 13) + + src_index = bilayout.backward_index((4, 7, 4, 4, 13)) + assert get_const_tuple(src_index) == (19, 29, 4, 4) + if __name__ == "__main__": - test_layout() - test_layout_dtype() - test_bilayout_convertible() - test_bilayout_shape() - test_bilayout_index() + tvm.testing.main()