Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/infinicore/ops/awq_marlin_gemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

// Declares the AwqMarlinGemm graph operator through the project's op-class macro.
// The type list after the class name matches, position by position, the parameter
// list of awq_marlin_gemm_() declared below: the output tensor, two read-only
// tensors, seven mutable tensor references, an int64_t type id and four bool flags.
// NOTE(review): the macro expansion is not visible here; presumably it declares the
// constructor and static execute() with this signature — confirm in common/op.hpp.
INFINICORE_GRAPH_OP_CLASS(AwqMarlinGemm, Tensor, const Tensor &, const Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, int64_t, bool, bool, bool, bool);

// In-place entry point for the AWQ Marlin GEMM operator: the result is written into `c`.
// `a` and `b` are the read-only operands; the remaining tensors (bias, scales, zero
// points, group index, permutation) are auxiliary inputs passed through unchanged.
// NOTE(review): the meaning of b_q_type_id and of the is_k_full / use_atomic_add /
// use_fp32_reduce / is_zp_float flags is defined by the backend kernel, which is not
// visible in this header — confirm against the kernel documentation.
void awq_marlin_gemm_(Tensor c, const Tensor &a, const Tensor &b, Tensor &b_bias, Tensor &b_scales, Tensor &a_scales, Tensor &global_scales, Tensor &b_zeros, Tensor &g_idx, Tensor &perm, int64_t b_q_type_id, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float);
} // namespace infinicore::op
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "infiniop/ops/attention.h"
#include "infiniop/ops/avg_pool1d.h"
#include "infiniop/ops/avg_pool3d.h"
#include "infiniop/ops/awq_marlin_gemm.h"
#include "infiniop/ops/axpy.h"
#include "infiniop/ops/binary_cross_entropy_with_logits.h"
#include "infiniop/ops/blas_amax.h"
Expand Down
46 changes: 46 additions & 0 deletions include/infiniop/ops/awq_marlin_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef __INFINIOP_AWQ_MARLIN_GEMM_API_H__
#define __INFINIOP_AWQ_MARLIN_GEMM_API_H__

#include "../operator_descriptor.h"
#include <cstdint>

/* Opaque handle to an AwqMarlinGemm operator descriptor. It is a pointer to the
 * generic InfiniopDescriptor; obtain one with infiniopCreateAwqMarlinGemmDescriptor
 * and release it with infiniopDestroyAwqMarlinGemmDescriptor. */
typedef struct InfiniopDescriptor *infiniopAwqMarlinGemmDescriptor_t;

/* Creates an AwqMarlinGemm descriptor from the tensor descriptors of the output and
 * of every tensor operand, and stores the new handle in *desc_ptr.
 *
 * handle        library handle the descriptor is created for
 * desc_ptr      out: receives the created descriptor
 * out_desc      descriptor of the output tensor
 * a_desc        descriptor of operand a
 * b_desc        descriptor of operand b
 * b_bias_desc .. perm_desc
 *               descriptors of the auxiliary tensors, in the same order as the
 *               corresponding data pointers of infiniopAwqMarlinGemm
 *
 * Returns an infiniStatus_t status code.
 * NOTE(review): whether any of the auxiliary descriptors may be null/empty is not
 * visible from this declaration — confirm against the backend implementation. */
__INFINI_C __export infiniStatus_t infiniopCreateAwqMarlinGemmDescriptor(infiniopHandle_t handle,
infiniopAwqMarlinGemmDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_bias_desc,
infiniopTensorDescriptor_t b_scales_desc,
infiniopTensorDescriptor_t a_scales_desc,
infiniopTensorDescriptor_t global_scales_desc,
infiniopTensorDescriptor_t b_zeros_desc,
infiniopTensorDescriptor_t g_idx_desc,
infiniopTensorDescriptor_t perm_desc);

/* Queries the workspace size required by `desc` and writes it to *size.
 * Returns an infiniStatus_t status code.
 * NOTE(review): the unit is presumably bytes — confirm against the implementation. */
__INFINI_C __export infiniStatus_t infiniopGetAwqMarlinGemmWorkspaceSize(infiniopAwqMarlinGemmDescriptor_t desc, size_t *size);

/* Runs the AwqMarlinGemm operator described by `desc`.
 *
 * workspace / workspace_size
 *               scratch buffer and its size; the required size is reported by
 *               infiniopGetAwqMarlinGemmWorkspaceSize
 * c             output buffer
 * a, b          read-only operand buffers
 * b_bias .. perm
 *               auxiliary buffers, in the same order as the descriptors given to
 *               infiniopCreateAwqMarlinGemmDescriptor
 * b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float
 *               per-call kernel options
 * stream        backend stream handle the work is issued on
 *
 * Returns an infiniStatus_t status code.
 * NOTE(review): the option semantics and which auxiliary pointers may be null are
 * defined by the backend kernel and are not visible here — confirm before relying
 * on them.
 * NOTE(review): this declaration uses bool and int64_t while the header includes
 * <cstdint>, which is a C++-only header — confirm that pure-C consumers are not
 * expected, or switch to <stdint.h>/<stdbool.h>. */
__INFINI_C __export infiniStatus_t infiniopAwqMarlinGemm(infiniopAwqMarlinGemmDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *b_bias,
void *b_scales,
void *a_scales,
void *global_scales,
void *b_zeros,
void *g_idx,
void *perm,
int64_t b_q_type_id,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float,
void *stream);

/* Destroys a descriptor previously created by infiniopCreateAwqMarlinGemmDescriptor.
 * Returns an infiniStatus_t status code. */
__INFINI_C __export infiniStatus_t infiniopDestroyAwqMarlinGemmDescriptor(infiniopAwqMarlinGemmDescriptor_t desc);

#endif
21 changes: 21 additions & 0 deletions src/infinicore/ops/awq_marlin_gemm/awq_marlin_gemm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "infinicore/ops/awq_marlin_gemm.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Emits the dispatcher definitions for AwqMarlinGemm via the project macro.
// NOTE(review): presumably these are the per-device plan/run/cleanup tables that
// INFINICORE_GRAPH_OP_DISPATCH consults — confirm in the macro definition.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AwqMarlinGemm);

/// Constructs the operator for one call.
///
/// First asserts that the output and all tensor operands are on the same device,
/// then dispatches on the device type of the output tensor `c`, forwarding every
/// argument unchanged.
AwqMarlinGemm::AwqMarlinGemm(
    Tensor c,
    const Tensor &a,
    const Tensor &b,
    Tensor &b_bias,
    Tensor &b_scales,
    Tensor &a_scales,
    Tensor &global_scales,
    Tensor &b_zeros,
    Tensor &g_idx,
    Tensor &perm,
    int64_t b_q_type_id,
    bool is_k_full,
    bool use_atomic_add,
    bool use_fp32_reduce,
    bool is_zp_float) {
    // Every tensor taking part in the call must be on one device.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(
        c, a, b,
        b_bias, b_scales, a_scales, global_scales,
        b_zeros, g_idx, perm);

    // The output tensor's device type selects the implementation.
    INFINICORE_GRAPH_OP_DISPATCH(
        c->device().getType(),
        c, a, b,
        b_bias, b_scales, a_scales, global_scales,
        b_zeros, g_idx, perm,
        b_q_type_id,
        is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float);
}
void AwqMarlinGemm::execute(Tensor c, const Tensor &a, const Tensor &b, Tensor &b_bias, Tensor &b_scales, Tensor &a_scales, Tensor &global_scales, Tensor &b_zeros, Tensor &g_idx, Tensor &perm, int64_t b_q_type_id, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(AwqMarlinGemm, c, a, b, b_bias, b_scales, a_scales, global_scales, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float);
}

void awq_marlin_gemm_(Tensor c, const Tensor &a, const Tensor &b, Tensor &b_bias, Tensor &b_scales, Tensor &a_scales, Tensor &global_scales, Tensor &b_zeros, Tensor &g_idx, Tensor &perm, int64_t b_q_type_id, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
AwqMarlinGemm::execute(c, a, b, b_bias, b_scales, a_scales, global_scales, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float);
}

} // namespace infinicore::op
82 changes: 82 additions & 0 deletions src/infinicore/ops/awq_marlin_gemm/awq_marlin_gemm_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/awq_marlin_gemm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

namespace infinicore::op::awq_marlin_gemm_impl::infiniop {

// Defines the cacheable `Descriptor` wrapper type for the AwqMarlinGemm infiniop
// descriptor; plan() below looks descriptors up by hash before creating them.
// NOTE(review): the literal 100 is presumably the cache capacity — confirm in the
// macro definition.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AwqMarlinGemm, 100);

/// Everything run() needs to replay one planned AwqMarlinGemm call.
///
/// Member order is significant: plan() fills this struct with positional
/// aggregate initialisation.
struct PlannedMeta {
    // Operator descriptor obtained in plan().
    std::shared_ptr<Descriptor> descriptor;

    // Scratch buffer followed by the output and the tensor operands.
    graph::GraphTensor workspace;
    graph::GraphTensor c;
    graph::GraphTensor a;
    graph::GraphTensor b;
    graph::GraphTensor b_bias;
    graph::GraphTensor b_scales;
    graph::GraphTensor a_scales;
    graph::GraphTensor global_scales;
    graph::GraphTensor b_zeros;
    graph::GraphTensor g_idx;
    graph::GraphTensor perm;

    // Scalar options passed through to infiniopAwqMarlinGemm unchanged.
    int64_t b_q_type_id;
    bool is_k_full;
    bool use_atomic_add;
    bool use_fp32_reduce;
    bool is_zp_float;
};

void *plan(Tensor c, const Tensor &a, const Tensor &b, Tensor &b_bias, Tensor &b_scales, Tensor &a_scales, Tensor &global_scales, Tensor &b_zeros, Tensor &g_idx, Tensor &perm, int64_t b_q_type_id, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
size_t seed = hash_combine(c, a, b, b_bias, b_scales, a_scales, global_scales, b_zeros, g_idx, perm);

INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, AwqMarlinGemm,
seed,
c->desc(), a->desc(),
b->desc(), b_bias->desc(), b_scales->desc(), a_scales->desc(), global_scales->desc(), b_zeros->desc(), g_idx->desc(), perm->desc());

INFINIOP_WORKSPACE_TENSOR(workspace, AwqMarlinGemm, descriptor);

return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(c),
graph::GraphTensor(a),
graph::GraphTensor(b),
graph::GraphTensor(b_bias),
graph::GraphTensor(b_scales),
graph::GraphTensor(a_scales),
graph::GraphTensor(global_scales),
graph::GraphTensor(b_zeros),
graph::GraphTensor(g_idx),
graph::GraphTensor(perm),
b_q_type_id,
is_k_full,
use_atomic_add,
use_fp32_reduce,
is_zp_float};
}

/// Executes a call prepared by plan().
///
/// @param planned_meta Pointer returned by plan(); it is read here and never freed.
///
/// Invokes infiniopAwqMarlinGemm on the current context stream and passes its
/// status to INFINICORE_CHECK_ERROR.
void run(void *planned_meta) {
    // plan() produced this pointer from a PlannedMeta *, so static_cast recovers it.
    auto *meta = static_cast<PlannedMeta *>(planned_meta);

    INFINICORE_CHECK_ERROR(infiniopAwqMarlinGemm(
        meta->descriptor->desc,
        // The workspace tensor's element count is passed as the workspace size.
        meta->workspace->data(),
        meta->workspace->numel(),
        meta->c->data(),
        meta->a->data(),
        meta->b->data(),
        meta->b_bias->data(),
        meta->b_scales->data(),
        meta->a_scales->data(),
        meta->global_scales->data(),
        meta->b_zeros->data(),
        meta->g_idx->data(),
        meta->perm->data(),
        meta->b_q_type_id,
        meta->is_k_full,
        meta->use_atomic_add,
        meta->use_fp32_reduce,
        meta->is_zp_float,
        context::getStream()));
}

/// Releases a PlannedMeta created by plan() and clears the caller's pointer.
///
/// @param planned_meta_ptr Address of the type-erased pointer returned by plan().
///                         A null address is ignored; a null pointee is a no-op
///                         delete. On return *planned_meta_ptr is nullptr.
void cleanup(void **planned_meta_ptr) {
    if (planned_meta_ptr == nullptr) {
        return;
    }
    // Load the stored void * and convert the value. Casting the void ** to
    // PlannedMeta ** and dereferencing it would read a void * object through a
    // PlannedMeta * lvalue, which is a type-punned access.
    delete static_cast<PlannedMeta *>(*planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

// Registers the plan/run/cleanup functions defined above as the AwqMarlinGemm
// implementation through the project's all-device registration macro.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AwqMarlinGemm, &plan, &run, &cleanup);

} // namespace infinicore::op::awq_marlin_gemm_impl::infiniop
57 changes: 57 additions & 0 deletions src/infiniop/ops/awq_marlin_gemm/awq_marlin_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef AWQ_MARLIN_GEMM_H
#define AWQ_MARLIN_GEMM_H

#include "../../operator.h"
#include "info.h"

// DESCRIPTOR(NAMESPACE) declares op::awq_marlin_gemm::NAMESPACE::Descriptor, the
// per-backend descriptor class for the AwqMarlinGemm operator.
//
// The generated class:
//   - derives from InfiniopDescriptor and is final;
//   - holds a pointer to a forward-declared Opaque struct, an AwqMarlinGemmInfo
//     and the workspace size;
//   - has a private constructor, so instances are obtained through the static
//     create() factory, which receives the same tensor descriptors as
//     infiniopCreateAwqMarlinGemmDescriptor;
//   - exposes workspaceSize() and a const calculate() whose parameters match
//     infiniopAwqMarlinGemm after the descriptor argument.
//
// The destructor, create() and calculate() are only declared here; each backend is
// expected to define them (and Opaque) in its own translation unit.
// NOTE(review): comments inside the macro body must use /* */ — a // comment would
// swallow the line continuation.
// NOTE(review): the macro name DESCRIPTOR is generic; presumably each backend
// header #undefs or scopes it — confirm to avoid clashes with other operators.
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::awq_marlin_gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
AwqMarlinGemmInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
AwqMarlinGemmInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t b_bias_desc, \
infiniopTensorDescriptor_t b_scales_desc, \
infiniopTensorDescriptor_t a_scales_desc, \
infiniopTensorDescriptor_t global_scales_desc, \
infiniopTensorDescriptor_t b_zeros_desc, \
infiniopTensorDescriptor_t g_idx_desc, \
infiniopTensorDescriptor_t perm_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *c, \
const void *a, const void *b, \
void *b_bias, void *b_scales, void *a_scales, void *global_scales, \
void *b_zeros, void *g_idx, void *perm, \
int64_t b_q_type_id, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float, \
void *stream) const; \
}; \
}

#endif // AWQ_MARLIN_GEMM_H
Loading
Loading