Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 66 additions & 13 deletions include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include <utility>
#include <vector>

#include "tvm/tir/var.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -158,6 +160,22 @@ class Layout : public ObjectRef {
return undef;
}

/*!
* \brief Packs the Given Array of IterVars into a Single IterVar. Each IterVar in the Array
* should represent either a single primal axis or one or more subordinate axis
* \param iters Array of iter vars to be packed
* \return A packed iter var
*/
static IterVar PackIterVar(ffi::Array<IterVar> iters);

/*!
* \brief Unpacks a Packed IterVar into its constituents
* \param packed_iter A Packed IterVar containing a single primal axis or one or more subordinate
* axis
* \return Constituent IterVars
*/
static ffi::Array<IterVar> UnpackIterVar(IterVar packed_iter);

/*!
* \brief Returns a sub-layout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
Expand Down Expand Up @@ -187,9 +205,12 @@ class Layout : public ObjectRef {
inline size_t ndim_primal() const {
if (!defined()) return 0;
size_t ct = 0;
for (auto x : operator->()->axes) {
if (LayoutAxis::Get(x).IsPrimal()) {
ct++;
for (auto px : operator->()->axes) {
auto iter_vars = UnpackIterVar(px);
for (auto x : iter_vars) {
if (LayoutAxis::Get(x).IsPrimal()) {
ct++;
}
}
}
return ct;
Expand All @@ -204,10 +225,13 @@ class Layout : public ObjectRef {
Layout new_src_layout;
// 1) Find the axis which are missing in the current layout. Make them the prefix.
std::string new_src_layout_str = "";
for (auto dst_axis : dst_layout->axes) {
if (LayoutAxis::Get(dst_axis).IsPrimal()) {
if (!this->Contains(LayoutAxis::Get(dst_axis))) {
new_src_layout_str += dst_axis->var->name_hint;
for (auto packed_axis : dst_layout->axes) {
auto iter_vars = UnpackIterVar(packed_axis);
for (auto dst_axis : iter_vars) {
if (LayoutAxis::Get(dst_axis).IsPrimal()) {
if (!this->Contains(LayoutAxis::Get(dst_axis))) {
new_src_layout_str += dst_axis->var->name_hint;
}
}
}
}
Expand All @@ -221,18 +245,36 @@ class Layout : public ObjectRef {
* \brief return the index of the input axis.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param axis the input axis.
* \param axis The input axis either a layout axis, or a packed axis
* \return the index or -1 if not found.
*/
inline int32_t IndexOf(const LayoutAxis& axis) const {
inline int32_t IndexOf(const std::string& axis) const {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
if (axes[i]->var->name_hint == axis) return static_cast<int32_t>(i);
}
return -1;
}

/*!
* \brief return the index of the input axis.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param axis the input layout axis.
* \return the index or -1 if not found.
*/
inline int32_t IndexOf(const LayoutAxis& axis) const { return IndexOf(axis.name()); }

/*!
* \brief return the index of the input axis.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param iter the input iter var.
* \return the index or -1 if not found.
*/
inline int32_t IndexOf(const tir::IterVar& iter) const { return IndexOf(iter->var->name_hint); }

/*!
* \brief Get the factor size of the subordinate axis.
* \param axis the input primal-axis or subordinate-axis.
Expand All @@ -249,9 +291,12 @@ class Layout : public ObjectRef {
*/
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const tir::IterVar var : operator->()->axes) {
if (var->var->name_hint == axis.name()) {
return true;
for (const tir::IterVar packed_var : operator->()->axes) {
auto iter_vars = UnpackIterVar(packed_var);
for (auto var : iter_vars) {
if (var->var->name_hint == axis.name()) {
return true;
}
}
}
return false;
Expand All @@ -265,6 +310,14 @@ class Layout : public ObjectRef {
return LayoutAxis::Get(axis);
}

IterVar PackedAxisAt(int32_t i) const {
ICHECK(defined()) << "Try to access axis from an undefined layout.";
int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
const tir::IterVar axis = operator->()->axes[index];
return axis;
}

/*! \return the string description of the layout */
inline std::string name() const {
if (!defined()) return "__undef__";
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/tir/data_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __len__(self):
return _ffi_api.LayoutNdim(self) # type: ignore

def __contains__(self, axis):
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
# Note: We do a weaker check for packed axis assuming layout is valid
return not any(bkt in axis for bkt in "[]") and axis in self.name

def __getitem__(self, index):
if index >= len(self):
Expand All @@ -54,7 +55,7 @@ def index_of(self, axis):
Parameters
----------
axis : str
The axis name, need to be [a-z,A-Z]
The axis name, needs to be [a-z,A-Z] or a packed axis

Returns
-------
Expand Down
Loading
Loading