28 changes: 14 additions & 14 deletions src/onnx/parse_dequantizelinear.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -21,12 +21,12 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/onnx/quantize_dequantize_linear.hpp>
#include <migraphx/op/builder/insert.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -36,10 +36,10 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
std::vector<op_desc> operators() const { return {{"DequantizeLinear"}}; }

instruction_ref parse(const op_desc& opd,
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
const std::vector<instruction_ref>& args) const
{
if(args.size() < 2 or args.size() > 3)
{
@@ -63,18 +63,18 @@
}
}

int axis = 1;
value options = {};
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
{
options.insert({"axis", info.attributes.at("axis").i()});
}

int block_size = 0;
if(contains(info.attributes, "block_size"))
block_size = info.attributes.at("block_size").i();

args = transform_quantize_dequantize_linear_inputs(
info, opd.onnx_name, block_size, axis, args);
{
options.insert({"block_size", info.attributes.at("block_size").i()});
}

return info.add_instruction(make_op("dequantizelinear"), args);
return op::builder::add("dequantizelinear", *info.mod, args, options).at(0);
}
};
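
A hedged sketch of the new lowering path, assuming the op::builder::add overload used in this diff: the parser now only collects the ONNX attributes into a migraphx::value, and the builder resolves the scale/zero-point broadcasting and block handling internally. The attribute values below are hypothetical.

// Hypothetical attribute values; `info` and `args` as in the parser above.
value options = {};
options.insert({"axis", int64_t{1}});
options.insert({"block_size", int64_t{32}});
auto dq = op::builder::add("dequantizelinear", *info.mod, args, options).at(0);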

86 changes: 16 additions & 70 deletions src/onnx/parse_quantizelinear.cpp
@@ -22,12 +22,11 @@
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>

#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
#include <migraphx/onnx/quantize_dequantize_linear.hpp>
#include <migraphx/op/builder/insert.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -37,10 +36,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }

instruction_ref parse(const op_desc& opd,
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
const std::vector<instruction_ref>& args) const
{
if(args.size() < 2 or args.size() > 3)
{
@@ -65,18 +64,23 @@
", provided y_zero_point shape: " + to_string_range(args[2]->get_shape().lens()));
}

int axis = 1;
value options = {};
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
{
options.insert({"axis", info.attributes.at("axis").i()});
}

int block_size = 0;
if(contains(info.attributes, "block_size"))
block_size = info.attributes.at("block_size").i();
{
options.insert({"block_size", info.attributes.at("block_size").i()});
}

std::optional<migraphx::shape::type_t> output_type;
if(contains(info.attributes, "output_dtype"))
{
output_type = get_type(info.attributes.at("output_dtype").i());
const auto& out_type = get_type(info.attributes.at("output_dtype").i());
output_type = out_type;
options.insert({"output_type", out_type});
}

if(output_type.has_value() and args.size() == 3 and
@@ -88,65 +92,7 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
+", y_zero_point type: " + to_string(args[2]->get_shape().type()));
}

args = transform_quantize_dequantize_linear_inputs(
info, opd.onnx_name, block_size, axis, args);

if(output_type == migraphx::shape::fp4x2_type)
{
// Parsing in pack_fp4 and unpack_fp4 for the FP4 case
auto q_ins = info.add_instruction(
make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), args);

// packing axis set to fastest dimension
auto quantized_shape = q_ins->get_shape();
const auto& qs_strides = quantized_shape.strides();
if(qs_strides.empty())
{
MIGRAPHX_THROW("QuantizeLinear: MX type quantized_shape has no strides");
}
int fast_axis =
std::min_element(qs_strides.cbegin(), qs_strides.cend()) - qs_strides.cbegin();
bool odd_fast_axis = (quantized_shape.lens().at(fast_axis) % 2 == 1);
if(odd_fast_axis)
{
// pad fastest dimension by 1 if it is odd
std::vector<int64_t> padding(2 * quantized_shape.ndim(), 0);
padding.at(fast_axis * 2 + 1) = 1;
q_ins = info.add_instruction(make_op("pad", {{"pads", padding}}), q_ins);
}
// output is fp4x2_type
auto pack_ins = info.add_instruction(make_op("pack_fp4"), q_ins);
// output is fp8e4m3fn_type
auto unpack_ins = info.add_instruction(make_op("unpack_fp4"), pack_ins);
if(odd_fast_axis)
{
// slice off padded values
unpack_ins = info.add_instruction(
make_op("slice",
{{"axes", {fast_axis}},
{"starts", {0}},
{"ends", {quantized_shape.lens().at(fast_axis)}}}),
unpack_ins);
}
return unpack_ins;
}
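// Worked example (hypothetical shape) of the fast-axis handling above: for a
// standard-layout quantized shape {2, 3, 5} the strides are {15, 5, 1}, so
// fast_axis = 2; lens[2] = 5 is odd, so that dimension is padded to 6 before
// pack_fp4 and the extra column is sliced off again after unpack_fp4.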

if(parser.opset_version < 19)
{
auto common_type = common_shape({args[0]->get_shape(), args[1]->get_shape()}).type();
std::transform(args.begin(), args.begin() + 2, args.begin(), [&](auto ins) {
if(ins->get_shape().type() != common_type)
ins = info.add_instruction(make_op("convert", {{"target_type", common_type}}),
ins);
return ins;
});
}
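// Example of the pre-opset-19 promotion above (hypothetical input types): if x
// is half_type and y_scale is float_type, the common type is float_type and x
// is converted before quantizelinear is inserted.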

if(output_type.has_value())
return info.add_instruction(make_op("quantizelinear", {{"out_type", *output_type}}),
args);
else
return info.add_instruction(make_op("quantizelinear"), args);
return op::builder::add("quantizelinear", *info.mod, args, options).at(0);
}
};

105 changes: 4 additions & 101 deletions src/onnx/quantize_dequantize_linear.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -23,10 +23,7 @@
*/

#include <migraphx/onnx/quantize_dequantize_linear.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
#include <migraphx/op/builder/quantize_dequantize_linear.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -39,102 +36,8 @@ transform_quantize_dequantize_linear_inputs(const onnx_parser::node_info& info,
int axis,
std::vector<instruction_ref> args)
{
const auto x = args.at(0);
const auto x_lens = x->get_shape().lens();
const auto x_rank = x_lens.size();

instruction_ref y_scale = args.at(1);
const auto y_scale_lens = y_scale->get_shape().lens();
const auto y_scale_rank = y_scale_lens.size();

// Per-tensor (per-layer) granularity
if(y_scale->get_shape().elements() == 1)
{
std::transform(args.begin() + 1, args.end(), args.begin() + 1, [&](auto ins) {
return info.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}), ins);
});
}
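// Example (hypothetical shapes): for x of shape {2, 3, 4} and a one-element
// y_scale, the scale (and zero point, if given) is multibroadcast to {2, 3, 4}
// so the quantize/dequantize op receives identically shaped inputs.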
// Per-axis granularity
else if(y_scale_rank == 1)
{
axis = tune_axis(x_rank, axis, onnx_name);
if(x_lens[axis] != y_scale_lens[0])
{
MIGRAPHX_THROW(onnx_name +
": For per axis granularity the length of y_scale (actual: " +
to_string(y_scale_lens[0]) + ") must be equal to size of x on axis " +
to_string(axis) + "(actual: " + to_string(x_lens[axis]) + ")");
}

std::transform(args.begin() + 1, args.end(), args.begin() + 1, [&](auto ins) {
return info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"out_lens", x_lens}}), ins);
});
}
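// Example (hypothetical shapes): for x of shape {2, 3, 4}, y_scale of shape {3}
// and axis = 1, y_scale is broadcast along axis 1 to {2, 3, 4}; any y_scale
// length other than x_lens[1] = 3 triggers the error above.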
// Blocked granularity
else
{
axis = tune_axis(x_rank, axis, onnx_name);

if(x_rank != y_scale_rank)
{
MIGRAPHX_THROW(onnx_name + ": x(rank: " + to_string(x_rank) +
") and y_scale(rank: " + to_string(y_scale_rank) +
") must be of same rank for block granularity");
}

for(auto i = 0u; i < x_lens.size(); ++i)
{
if(x_lens[i] != y_scale_lens[i] and i != axis)
{
MIGRAPHX_THROW(onnx_name + ": x(shape: " + to_string_range(x_lens) +
") and y_scale(shape: " + to_string_range(y_scale_lens) +
") shapes may only differ along provided axis(" + to_string(axis) +
")");
}
}

// Given x shape (D0, ..., Di, ..., Dn), y_scale shape (S0, ... Si, ...Sn) and
// axis=i, the accepted range is [ceil(Di/Si), ceil(Di/(Si-1))-1]
float di = x_lens[axis];
float si = y_scale_lens[axis];
int block_size_min = std::ceil(di / si);
int block_size_max = std::ceil(di / (si - 1)) - 1;
// default block_size if not given is calculated (to support quark generated models):
if(block_size == 0)
block_size = block_size_min;
if(block_size < block_size_min or block_size > block_size_max)
MIGRAPHX_THROW(onnx_name + ": Block size(actual: " + to_string(block_size) +
") must be within range [" + to_string(block_size_min) + ", " +
to_string(block_size_max) + "]");
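// Worked example of the accepted range (hypothetical shapes): for x of shape
// {4, 16}, y_scale of shape {4, 3} and axis = 1,
// block_size_min = ceil(16 / 3) = 6 and block_size_max = ceil(16 / 2) - 1 = 7,
// so block_size must lie in [6, 7] and defaults to 6 when the attribute is 0.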

std::transform(args.begin() + 1, args.end(), args.begin() + 1, [&](auto ins) {
if(block_size == 1)
return ins;

ins = info.add_instruction(make_op("unsqueeze", {{"axes", {axis + 1}}}), ins);

auto bc_lens = ins->get_shape().lens();
bc_lens[axis + 1] = block_size;
ins = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), ins);

auto reshape_lens = x_lens;
reshape_lens[axis] = ins->get_shape().lens()[axis] * block_size;
ins = info.add_instruction(make_op("reshape", {{"dims", reshape_lens}}), ins);

// Detect runt block
if(x_lens[axis] < reshape_lens[axis])
{
ins = info.add_instruction(
make_op("slice", {{"axes", {axis}}, {"starts", {0}}, {"ends", {x_lens[axis]}}}),
ins);
}

return ins;
});
}
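// Shape trace for the blocked broadcast above (hypothetical shapes): with x of
// shape {4, 16}, y_scale of shape {4, 3}, axis = 1 and block_size = 6, y_scale
// is unsqueezed to {4, 3, 1}, multibroadcast to {4, 3, 6}, reshaped to {4, 18},
// and sliced back to {4, 16} because the final block is a runt.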

return args;
return op::builder::transform_quantize_dequantize_linear_inputs(
info, onnx_name, block_size, axis, std::move(args));
}

} // namespace onnx
56 changes: 56 additions & 0 deletions src/op/builder/dequantizelinear.cpp
@@ -0,0 +1,56 @@
/* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/op/builder/op_builder.hpp>
#include <migraphx/op/builder/quantize_dequantize_linear.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
namespace builder {

struct dequantizelinear : op_builder<dequantizelinear>
{
int axis = 1;
int block_size = 0;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.block_size, "block_size"));
}

std::vector<instruction_ref>
insert(module& m, instruction_ref /*ins*/, const std::vector<instruction_ref>& args) const
{
auto args_new =
transform_quantize_dequantize_linear_inputs(m, name(), block_size, axis, args);

return {m.add_instruction(make_op(name()), args_new)};
}
};

} // namespace builder
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
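
For context, a hedged usage sketch of the builder added above (the module, arguments, and option values are hypothetical): op::builder::add reflects the "axis" and "block_size" entries of the options value onto the struct's fields before insert runs.

// `m` is a migraphx::module and `args` holds {x, y_scale[, y_zero_point]}.
migraphx::value options = {{"axis", 0}, {"block_size", 16}};
auto dq = migraphx::op::builder::add("dequantizelinear", m, args, options).at(0);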