Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/reduce_dims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,20 @@ static std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
static shape mask_shape(const shape& s, const std::vector<std::size_t>& lens)
{
assert(s.lens().size() == lens.size());

std::vector<std::size_t> mlens;
std::transform(s.lens().begin(), s.lens().end(), lens.begin(), std::back_inserter(mlens), [](auto x, auto y) -> std::size_t {
if(x != y)
return 1;
return x;
});
shape base{s.type(), mlens};
std::vector<std::size_t> rstrides(lens.size());
std::size_t stride = 1;
for(std::size_t i = lens.size() - 1; i < lens.size(); i--)
for(std::size_t i = 0; i < lens.size(); i++)
{
if(lens[i] == s.lens()[i])
{
rstrides[i] = stride;
stride *= lens[i];
}
else if(lens[i] != 1 and s.lens()[i] != 1)
{
return shape{};
rstrides[i] = base.strides()[i];
}
}
return shape{s.type(), lens, rstrides};
Expand Down
2 changes: 1 addition & 1 deletion src/targets/gpu/compile_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ static std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
return (x + y - std::size_t{1}) / y;
}

static std::size_t compute_tile_factor(std::size_t r, std::size_t max_size = 64)
std::size_t compute_tile_factor(std::size_t r, std::size_t max_size)
{
std::size_t n = 1;
auto factors = make_array(2, 3, 5, 7, 11);
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ struct preload
bool is_preloading() const;
std::string str() const;
};

// Public declaration of compute_tile_factor; when `max_size` is omitted it
// defaults to 64.
// NOTE(review): the definition is not visible from this header; presumably it
// returns a tile factor for `r` that is bounded by `max_size` - confirm
// against the implementation.
std::size_t compute_tile_factor(std::size_t r, std::size_t max_size = 64);
struct tile
{
enum mode
Expand Down
38 changes: 35 additions & 3 deletions src/targets/gpu/jit/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat<${axis}>(${concat_args})(${post}, y, xs...);
concat::run<concat::${algo}, ${axis}>(${concat_args})(${post}, y, xs...);
});
}

Expand Down Expand Up @@ -81,6 +81,15 @@ struct concat_compiler : compiler<concat_compiler>
return result;
}

static std::size_t
max_size(const std::vector<shape>& inputs, std::size_t ninputs, std::size_t axis)
{
return std::max_element(inputs.begin(),
inputs.begin() + ninputs,
by(std::less<>{}, [&](const shape& s) { return s.lens()[axis]; }))
->lens()[axis];
}

operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
Expand All @@ -95,8 +104,8 @@ struct concat_compiler : compiler<concat_compiler>
vectorize vec{};
if(axis != concat_axis)
vec = vectorize::elements(ctx, axis, options.virtual_inputs);
auto nelements_per_op = options.virtual_inputs.back().elements() / op_names.size();
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
auto output = options.virtual_inputs.back();
auto nelements_per_op = output.elements() / op_names.size();
options.emplace_param("-Wno-float-equal");
std::vector<std::string> concat_params;
std::vector<std::string> concat_args;
Expand All @@ -114,6 +123,28 @@ struct concat_compiler : compiler<concat_compiler>
});
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
}
auto ninputs = concat_params.size();
auto max_elements_per_op =
max_size(options.virtual_inputs, ninputs, concat_axis) / vec.size;
auto avg_elements_per_op = output.lens()[concat_axis] / op_names.size();
std::string algo;
if(concat_axis == axis and max_elements_per_op < 64 and
max_elements_per_op == avg_elements_per_op)
{
std::size_t group = 1;
if(concat_axis > 0)
group = compute_tile_factor(output.lens()[concat_axis - 1], 16);
auto nslices = output.elements() / output.lens()[concat_axis];
auto block_size = compute_block_size(ctx, max_elements_per_op * group, 256);
algo = "block_tile<" + std::to_string(group) + ">";
options.set_launch_params(v, (nslices / group) * block_size, block_size);
}
else
{
algo = "simple";
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
}

auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
Expand All @@ -123,6 +154,7 @@ struct concat_compiler : compiler<concat_compiler>
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"algo", algo},
{"axis", std::to_string(concat_axis)}});
return compile_hip_code_object(ctx, src, options);
}
Expand Down
16 changes: 16 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,22 @@ struct integral_const_array : array<T, sizeof...(Xs)>
// Default constructor: initializes the base_array subobject from the values
// in the template parameter pack Xs.
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({Xs...}) {}

// Gives read-only access to this object through its base_array type.
constexpr const base_array& base() const
{
    return static_cast<const base_array&>(*this);
}

// Normalizes `result` against the extents stored in this array by propagating
// carries from the last element towards the first.
//
// For every index i > 0, any amount at or above this->d[i] is folded into the
// next more significant element (i - 1), leaving result[i] < this->d[i].
// Element 0 absorbs whatever carry remains and is not reduced.
//
// Returns the normalized copy; the argument is taken by value.
constexpr base_array carry(base_array result) const
{
    index_int overflow = 0;
    for(diff_int i = result.size() - 1; i > 0; i--)
    {
        auto z = result[i] + overflow;
        // The incoming carry is consumed by this digit. It must be cleared
        // here, otherwise it would skip this digit and leak into a more
        // significant one whenever z stays below the extent.
        overflow = 0;
        // Only pay for the division when the digit actually overflows.
        if(z >= this->d[i])
        {
            overflow = z / this->d[i];
            z        = z % this->d[i];
        }
        // Always write back so a carry that fits in this digit is kept.
        result[i] = z;
    }
    result[0] += overflow;
    return result;
}
};

template <class T, class... Ts>
Expand Down
Loading
Loading