Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sim/simx/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ SRCS += $(SRC_DIR)/execute.cpp $(SRC_DIR)/func_unit.cpp
SRCS += $(SRC_DIR)/cache_sim.cpp $(SRC_DIR)/mem_sim.cpp $(SRC_DIR)/local_mem.cpp $(SRC_DIR)/mem_coalescer.cpp
SRCS += $(SRC_DIR)/dcrs.cpp $(SRC_DIR)/types.cpp

# sparse unit; add -DEXT_SPARSE_ENABLE flag later
SRCS += $(SRC_DIR)/sparse_unit.cpp

# Add V extension sources
ifneq ($(findstring -DEXT_V_ENABLE, $(CONFIGS)),)
Expand All @@ -42,6 +40,11 @@ endif
ifneq ($(findstring -DEXT_TCU_ENABLE, $(CONFIGS)),)
SRCS += $(SRC_DIR)/tensor_unit.cpp
endif
# Add VEGETA extension sources
ifneq ($(findstring -DEXT_VEGETA_ENABLE, $(CONFIGS)),)
SRCS += $(SRC_DIR)/vegeta_lsu.cpp
SRCS += $(SRC_DIR)/sparse_unit.cpp
endif

# Debugging
ifdef DEBUG
Expand Down
1 change: 1 addition & 0 deletions sim/simx/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Core::Core(const SimContext& ctx,
#endif
#ifdef EXT_VEGETA_ENABLE
, sparse_unit_(SparseUnit::Create("spu", arch, this))
, vegeta_lsu_(VegetaLsu::Create("vegeta_lsu", this, 1))
#endif
, emulator_(arch, dcrs, this)
, ibuffers_(arch.num_warps(), IBUF_SIZE)
Expand Down
6 changes: 6 additions & 0 deletions sim/simx/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#endif
#ifdef EXT_VEGETA_ENABLE
#include "sparse_unit.h"
#include "vegeta_lsu.h"
#endif

#include "dispatcher.h"
Expand Down Expand Up @@ -184,6 +185,10 @@ class Core : public SimObject<Core> {
SparseUnit::Ptr& sparse_unit() {
return sparse_unit_;
}

VegetaLsu::Ptr& vegeta_lsu() {
return vegeta_lsu_;
}
#endif

auto& trace_pool() {
Expand Down Expand Up @@ -217,6 +222,7 @@ class Core : public SimObject<Core> {

#ifdef EXT_VEGETA_ENABLE
SparseUnit::Ptr sparse_unit_;
VegetaLsu::Ptr vegeta_lsu_;
#endif

Emulator emulator_;
Expand Down
61 changes: 57 additions & 4 deletions sim/simx/execute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_T: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1625,7 +1624,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_U: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1640,7 +1638,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_V: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1654,7 +1651,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_R: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1665,6 +1661,63 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
sparse_unit_->tile_gemm_r(dst_reg, src1_reg, src2_reg, src1_reg);
rd_write = false;
} break;
case VegetaTcuType::WMMA: {
auto tpuArgs = std::get<IntrVegetaTcuArgs>(instrArgs);
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;

// Get metadata from integer registers a0-a7 (x10-x17) for sparse fragA
// These contain metadata values loaded by mma_sync into a0-a7
DTH(3, "WMMA: current regfile values:" << std::hex << std::endl);
for (uint32_t i = 0; i < 32; ++i) {
DTN(3, " x" << std::setfill('0') << std::setw(2) << i << ": 0x" << warp.ireg_file.at(i).at(0) << std::dec << std::endl);
}
uint32_t metadata[8] = {0};
for (uint32_t reg = 0; reg < 8; ++reg) {
// a0-a7 correspond to x10-x17 in RISC-V
uint32_t a_reg = 10 + reg; // a0=10, a1=11, ..., a7=17

// Get value from integer register a_reg for thread 0 (all threads should have same metadata)
if (warp.tmask.test(0) && a_reg < warp.ireg_file.size()) {
metadata[reg] = warp.ireg_file.at(a_reg).at(0);
}
}
DTH(3, "WMMA: metadata values:" << std::hex << std::endl);
for (uint32_t i = 0; i < 8; ++i) {
DTN(3, " a" << std::setfill('0') << std::setw(1) << i << ": 0x" << metadata[i] << std::dec << std::endl);
}

// Extract sparsity degree from register t0 (x5)
uint32_t sparsity_degree = 2; // default to 2:4
const uint32_t t0_reg = 5; // t0 is x5
if (warp.tmask.test(0) && t0_reg < warp.ireg_file.size()) {
sparsity_degree = static_cast<uint32_t>(warp.ireg_file.at(t0_reg).at(0));
// Validate sparsity degree (should be 1 or 2)
if (sparsity_degree != 1 && sparsity_degree != 2) {
std::cerr << "Warning: Invalid sparsity degree " << sparsity_degree << " in register t0 (x5), using default 2" << std::endl;
sparsity_degree = 2;
}
}

// Create updated args with the extracted sparsity degree
// Note: The sparsity degree is also used by the hardware instruction via register t0 in mma_sync
IntrVegetaTcuArgs updated_args{tpuArgs.fmt_s, tpuArgs.fmt_d, tpuArgs.step_m, tpuArgs.step_n, static_cast<uint32_t>(sparsity_degree)};

// Resize rd_data and rs3_data to accommodate WMMA output (tcM * tcN elements)
// For sparse WMMA, we need at least tcM * tcN elements
namespace vt = vortex::sparse;
using cfg = vt::wmma_config_t<NUM_THREADS>;
uint32_t wmma_size = cfg::tcM * cfg::tcN;
if (rd_data.size() < wmma_size) {
rd_data.resize(wmma_size);
}
if (rs3_data.size() < wmma_size) {
rs3_data.resize(wmma_size);
}

sparse_unit_->wmma(wid, updated_args.fmt_s, updated_args.fmt_d, updated_args.step_m, updated_args.step_n, rs1_data, rs2_data, rs3_data, rd_data, trace_data.get(), metadata, updated_args.sparsity_degree);
rd_write = true;
} break;
default:
std::abort();
}
Expand Down
145 changes: 82 additions & 63 deletions sim/simx/sparse_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "sparse_cfg.h"
#include <rvfloats.h>
#include "core.h"
#include "vegeta_lsu.h"
#include <cstring>

using namespace vortex;
Expand Down Expand Up @@ -409,90 +410,97 @@ class SparseUnit::Impl {
uint32_t tile_reg_idx = vd;
assert(tile_reg_idx < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[tile_reg_idx];
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32
base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads

// Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
// Use VegetaLsu for bulk tile load (1KB)
constexpr uint32_t T_TILE_SIZE = TILE_DIM * TILE_DIM * sizeof(float);
float tile_buffer[TILE_DIM * TILE_DIM];
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::T_TILE,
tile_reg_idx, wid, tid, tile_buffer);

// Copy from linear buffer to 2D tile register
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
uint32_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});

// Interpret as float and store in tile register
float value;
std::memcpy(&value, &mem_data, ELEMENT_SIZE);
tile_reg[row][col] = value;
tile_reg[row][col] = tile_buffer[row * TILE_DIM + col];
}
}

DP(2, "TILE_LOAD_T: wid=" << wid << ", tid=" << tid
// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < TILE_DIM * TILE_DIM; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_LOAD_T (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", tile_reg_idx=" << tile_reg_idx << ", base_addr=0x" << std::hex << base_addr << std::dec);
break;
}
case VegetaLsuType::TILE_LOAD_U: {
// tile_load_u: DestReg contains ureg index, map to tile registers
// ureg 0 -> tile reg 0, 1
// ureg 0 -> tile reg 0, 1 (2KB total = 2 T-tiles)
std::vector<uint32_t> target_tregs = map_ureg_to_treg(vd);
base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype);

uint64_t current_addr = base_addr;
for (uint32_t treg_idx : target_tregs) {
// Use VegetaLsu for bulk U-tile load (2KB)
constexpr uint32_t T_TILE_ELEMENTS = TILE_DIM * TILE_DIM;
float tile_buffer[T_TILE_ELEMENTS * 2]; // 2 T-tiles for U-reg
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::U_TILE,
vd, wid, tid, tile_buffer);

// Copy from linear buffer to 2D tile registers
for (uint32_t t = 0; t < target_tregs.size(); ++t) {
uint32_t treg_idx = target_tregs[t];
assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[treg_idx];

// Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
uint32_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});

float value;
std::memcpy(&value, &mem_data, ELEMENT_SIZE);
tile_reg[row][col] = value;
tile_reg[row][col] = tile_buffer[t * T_TILE_ELEMENTS + row * TILE_DIM + col];
}
}
current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB)
}

DP(2, "TILE_LOAD_U: wid=" << wid << ", tid=" << tid
// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < T_TILE_ELEMENTS * 2; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_LOAD_U (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", ureg_idx=" << vd << ", target_tregs=["
<< target_tregs[0] << ", " << target_tregs[1] << "], base_addr=0x" << std::hex << base_addr << std::dec);
break;
}
case VegetaLsuType::TILE_LOAD_V: {
// tile_load_v: DestReg contains vreg index, map to tile registers
// vreg 0 -> tile reg 0, 1, 2, 3
// vreg 0 -> tile reg 0, 1, 2, 3 (4KB total = 4 T-tiles)
std::vector<uint32_t> target_tregs = map_vreg_to_treg(vd);
base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype);

uint64_t current_addr = base_addr;
for (uint32_t treg_idx : target_tregs) {
// Use VegetaLsu for bulk V-tile load (4KB)
constexpr uint32_t T_TILE_ELEMENTS = TILE_DIM * TILE_DIM;
float tile_buffer[T_TILE_ELEMENTS * 4]; // 4 T-tiles for V-reg
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::V_TILE,
vd, wid, tid, tile_buffer);

// Copy from linear buffer to 2D tile registers
for (uint32_t t = 0; t < target_tregs.size(); ++t) {
uint32_t treg_idx = target_tregs[t];
assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[treg_idx];

// Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
uint32_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});

float value;
std::memcpy(&value, &mem_data, ELEMENT_SIZE);
tile_reg[row][col] = value;
tile_reg[row][col] = tile_buffer[t * T_TILE_ELEMENTS + row * TILE_DIM + col];
}
}
current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB)
}

DP(2, "TILE_LOAD_V: wid=" << wid << ", tid=" << tid
// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < T_TILE_ELEMENTS * 4; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_LOAD_V (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", vreg_idx=" << vd << ", target_tregs=["
<< target_tregs[0] << ", " << target_tregs[1] << ", "
<< target_tregs[2] << ", " << target_tregs[3] << "], base_addr=0x" << std::hex << base_addr << std::dec);
Expand All @@ -504,22 +512,28 @@ class SparseUnit::Impl {
assert(meta_reg_idx < metadata_reg_file_.size() && "Metadata register index out of bounds");
auto &metadata_reg = metadata_reg_file_[meta_reg_idx];

// Load metadata from memory: 16 rows x 16 columns = 256 uint4 elements = 128 bytes
// Use VegetaLsu for bulk M-tile load (128 bytes)
constexpr uint32_t M_TILE_SIZE = 128;
uint8_t meta_buffer[M_TILE_SIZE];
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::M_TILE,
meta_reg_idx, wid, tid, meta_buffer);

// Parse nibbles from linear buffer into metadata register
// Each byte stores two uint4 values: upper nibble for col N, lower nibble for col N+1
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; col += 2) {
uint64_t mem_addr = base_addr + (row * (TILE_DIM / 2) + col / 2);
uint8_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, 1);
trace_data->mem_addrs.at(tid).push_back({mem_addr, 1});

// Upper nibble for col N, lower nibble for col N+1
metadata_reg[row][col] = (mem_data >> 4) & 0x0F;
metadata_reg[row][col + 1] = mem_data & 0x0F;
uint8_t byte = meta_buffer[row * (TILE_DIM / 2) + col / 2];
metadata_reg[row][col] = (byte >> 4) & 0x0F;
metadata_reg[row][col + 1] = byte & 0x0F;
}
}

// Record trace for all bytes
for (uint32_t i = 0; i < M_TILE_SIZE; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i, 1});
}

DP(2, "TILE_LOAD_M: wid=" << wid << ", tid=" << tid
DP(2, "TILE_LOAD_M (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", metadata_reg_idx=" << meta_reg_idx << ", base_addr=0x" << std::hex << base_addr << std::dec);
break;
}
Expand Down Expand Up @@ -548,21 +562,26 @@ class SparseUnit::Impl {
assert(vs3 < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[vs3];
constexpr uint32_t TILE_DIM = 16;
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32

// Store tile to memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
// Copy 2D tile register to linear buffer for VegetaLsu
float tile_buffer[TILE_DIM * TILE_DIM];
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
float value = tile_reg[row][col];
uint32_t mem_data = 0;
std::memcpy(&mem_data, &value, ELEMENT_SIZE);
core_->dcache_write(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});
tile_buffer[row * TILE_DIM + col] = tile_reg[row][col];
}
}

// Use VegetaLsu for bulk tile store (1KB)
core_->vegeta_lsu()->store_tile(base_addr, VegetaLsu::TileType::T_TILE,
vs3, wid, tid, tile_buffer);

// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < TILE_DIM * TILE_DIM; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_STORE: wid=" << wid << ", tid=" << tid << ", vs3=" << vs3
DP(2, "TILE_STORE (via VegetaLsu): wid=" << wid << ", tid=" << tid << ", vs3=" << vs3
<< ", base_addr=0x" << std::hex << base_addr << std::dec);
#else
std::abort(); // EXT_VEGETA_ENABLE required for store operations
Expand Down
Loading