From a54bc37c6b4fb2595def0fd4195dd9dfe7b4928d Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 00:58:23 +0100 Subject: [PATCH 01/31] So it begins --- CMakeLists.txt | 116 +++---- src/eval_types.hpp | 119 ++++--- src/evaltune_main.cpp | 214 ++++++------ src/evaluation.cpp | 2 +- src/tuning/arena.hpp | 66 ++++ src/tuning/globals.hpp | 86 ++--- src/tuning/graph.cpp | 474 +++++++++++++++++++++++++- src/tuning/graph.hpp | 189 ++++------- src/tuning/info.hpp | 2 +- src/tuning/loss.hpp | 18 +- src/tuning/operations.hpp | 77 +++++ src/tuning/optim.hpp | 51 ++- src/tuning/value.cpp | 273 ++++++++++++++- src/tuning/value.hpp | 688 +++++--------------------------------- 14 files changed, 1314 insertions(+), 1061 deletions(-) create mode 100644 src/tuning/arena.hpp create mode 100644 src/tuning/operations.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index cc03e7af..25b1662a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,62 +81,66 @@ endfunction() # Sorted list of source files set(srcs - src/bench.cpp - src/bench.hpp - src/board.cpp - src/board.hpp - src/common.hpp - src/dbg_tools.cpp - src/dbg_tools.hpp - src/eval_constants.hpp - src/eval_types.hpp - src/evaluation.cpp - src/evaluation.hpp - src/geometry.cpp - src/geometry.hpp - src/history.cpp - src/history.hpp - src/move.cpp - src/move.hpp - src/movegen.cpp - src/movegen.hpp - src/movepick.cpp - src/movepick.hpp - src/perft.cpp - src/perft.hpp - src/position.cpp - src/position.hpp - src/psqt_state.hpp - src/repetition_info.cpp - src/repetition_info.hpp - src/search.cpp - src/search.hpp - src/speedtest.cpp - src/speedtest.hpp - src/square.hpp - src/tm.cpp - src/tm.hpp - src/tt.cpp - src/tt.hpp - src/tuned.cpp - src/tuned.hpp - src/tuning/globals.hpp - src/tuning/graph.cpp - src/tuning/graph.hpp - src/tuning/loss.hpp - src/tuning/optim.hpp - src/tuning/value.cpp - src/tuning/value.hpp - src/uci.cpp - src/uci.hpp - src/util/bit.hpp - src/util/parse.hpp - src/util/pretty.hpp - src/util/static_vector.hpp - src/util/types.hpp - src/util/vec/sse2.hpp - src/zobrist.cpp - src/zobrist.hpp + src/bench.cpp + src/bench.hpp + src/board.cpp + src/board.hpp + src/common.hpp + src/dbg_tools.cpp + src/dbg_tools.hpp + src/eval_constants.hpp + src/eval_types.hpp + src/evaluation.cpp + src/evaluation.hpp + src/geometry.cpp + src/geometry.hpp + src/history.cpp + src/history.hpp + src/move.cpp + src/move.hpp + src/movegen.cpp + src/movegen.hpp + src/movepick.cpp + src/movepick.hpp + src/perft.cpp + src/perft.hpp + src/position.cpp + src/position.hpp + src/psqt_state.hpp + src/repetition_info.cpp + src/repetition_info.hpp + src/search.cpp + src/search.hpp + src/speedtest.cpp + src/speedtest.hpp + src/square.hpp + src/tm.cpp + src/tm.hpp + src/tt.cpp + src/tt.hpp + src/tuned.cpp + src/tuned.hpp + src/tuning/arena.hpp + src/tuning/globals.hpp + src/tuning/graph.cpp + src/tuning/graph.hpp + src/tuning/info.hpp + + src/tuning/loss.hpp + src/tuning/operations.hpp + src/tuning/optim.hpp + src/tuning/value.cpp + src/tuning/value.hpp + src/uci.cpp + src/uci.hpp + src/util/bit.hpp + src/util/parse.hpp + src/util/pretty.hpp + src/util/static_vector.hpp + src/util/types.hpp + src/util/vec/sse2.hpp + src/zobrist.cpp + src/zobrist.hpp ) if(CLOCKWORK_ENABLE_EVALTUNE) diff --git a/src/eval_types.hpp b/src/eval_types.hpp index 8262a951..5fc1d202 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -14,14 +14,19 @@ #endif namespace Clockwork { -#ifndef EVAL_TUNING +#ifndef EVAL_TUNING +// ============================================================================ +// NORMAL BUILD (NO TUNING) +// ============================================================================ using Score = i16; + class PScore { private: i32 m_score; + explicit constexpr PScore(i32 score) : - m_score{score} { + m_score(score) { } public: @@ -30,99 +35,105 @@ class PScore { } constexpr PScore(Score midgame, Score endgame) : - m_score{static_cast(static_cast(endgame) << 16) + midgame} { + m_score(static_cast((u32(endgame) << 16) + u16(midgame))) { assert(std::numeric_limits::min() <= midgame - && std::numeric_limits::max() >= midgame); + && midgame <= std::numeric_limits::max()); assert(std::numeric_limits::min() <= endgame - && std::numeric_limits::max() >= endgame); + && endgame <= std::numeric_limits::max()); } - [[nodiscard]] inline auto mg() const { - const auto mg = static_cast(m_score); - - i16 v{}; + inline Score mg() const { + u16 mg = u16(m_score); + i16 v; std::memcpy(&v, &mg, sizeof(mg)); - - return static_cast(v); + return v; } - [[nodiscard]] inline auto eg() const { - const auto eg = static_cast(static_cast(m_score + 0x8000) >> 16); - - i16 v{}; + inline Score eg() const { + u16 eg = u16(u32(m_score + 0x8000) >> 16); + i16 v; std::memcpy(&v, &eg, sizeof(eg)); - - return static_cast(v); + return v; } - [[nodiscard]] constexpr auto operator+(const PScore& other) const { - return PScore{m_score + other.m_score}; + // Operators identical to original version + constexpr PScore operator+(const PScore& o) const { + return PScore(m_score + o.m_score); } - - constexpr auto operator+=(const PScore& other) -> auto& { - m_score += other.m_score; - return *this; + constexpr PScore operator-(const PScore& o) const { + return PScore(m_score - o.m_score); } - - [[nodiscard]] constexpr auto operator-(const PScore& other) const { - return PScore{m_score - other.m_score}; + constexpr PScore operator*(i32 v) const { + return PScore(m_score * v); } - - constexpr auto operator-=(const PScore& other) -> auto& { - m_score -= other.m_score; + constexpr PScore& operator+=(const PScore& o) { + m_score += o.m_score; return *this; } - - [[nodiscard]] constexpr auto operator*(i32 v) const { - return PScore{m_score * v}; + constexpr PScore& operator-=(const PScore& o) { + m_score -= o.m_score; + return *this; } - - constexpr auto operator*=(i32 v) -> auto& { + constexpr PScore& operator*=(i32 v) { m_score *= v; return *this; } - - [[nodiscard]] constexpr auto operator-() const { - return PScore{-m_score}; + constexpr PScore operator-() const { + return PScore(-m_score); } - [[nodiscard]] constexpr bool operator==(const PScore& other) const = default; + constexpr bool operator==(const PScore&) const = default; - [[nodiscard]] constexpr const PScore* operator->() const { + constexpr const PScore* operator->() const { return this; } - // Phasing between two scores + // Phase function (non-tuning: returns int) template - Value phase(i32 alpha) const { + inline Value phase(i32 alpha) const { assert(0 <= alpha && alpha <= max); - return static_cast((mg() * alpha + eg() * (max - alpha)) / max); + return Value((mg() * alpha + eg() * (max - alpha)) / max); } - friend std::ostream& operator<<(std::ostream& stream, const PScore& score) { - stream << "(" << score.mg() << "\t" << score.eg() << ")"; - return stream; + friend std::ostream& operator<<(std::ostream& os, const PScore& s) { + os << "(" << s.mg() << "\t" << s.eg() << ")"; + return os; } }; using PParam = PScore; #else +// ============================================================================ +// TUNING BUILD (NEW AUTOGRAD API) +// ============================================================================ -using Score = Autograd::ValuePtr; -using PScore = Autograd::PairPtr; -using PParam = Autograd::PairPlaceholder; +using Score = Autograd::ValueHandle; +using PScore = Autograd::PairHandle; // (mg, eg) handle +using PParam = Autograd::PairHandle; // tunable pair #endif + +// ============================================================================ +// Macro Definitions +// ============================================================================ + #ifdef EVAL_TUNING - #define S(a, b) Autograd::PairPlaceholder::create_tunable((a), (b)) // Defines a tunable pscore + // Tunable scalar pair (mg, eg) + #define S(a, b) Autograd::PairPlaceholder::create_tunable((a), (b)) + + // Constant (fixed) scalar pair (mg, eg) #define CS(a, b) Autograd::PairPlaceholder::create((a), (b)) - #define PSCORE_ZERO Autograd::Pair::create(0, 0) + + // Zero pair + #define PSCORE_ZERO Autograd::PairPlaceholder::create(0, 0) + #else - #define S(a, b) PScore((a), (b)) // Defines a constant pscore when not tuning - #define CS(a, b) S((a), (b)) - #define PSCORE_ZERO CS(0, 0) + // Non-tuning build: use fixed, non-autograd PScore + #define S(a, b) PScore((a), (b)) + #define CS(a, b) PScore((a), (b)) + #define PSCORE_ZERO PScore(0, 0) #endif -} // namespace Clockwork +} // namespace Clockwork \ No newline at end of file diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index 0568a660..b382de92 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -2,12 +2,15 @@ #include "eval_types.hpp" #include "evaluation.hpp" #include "position.hpp" + #include "tuning/graph.hpp" #include "tuning/loss.hpp" #include "tuning/optim.hpp" #include "tuning/value.hpp" + #include "util/pretty.hpp" #include "util/types.hpp" + #include #include #include @@ -22,194 +25,190 @@ #include using namespace Clockwork; +using namespace Clockwork::Autograd; int main() { - // Load fens from multiple files. + // ------------------------------ + // Load FENs + // ------------------------------ std::vector positions; std::vector results; - // List of files to load const std::vector fenFiles = { - "data/dfrc-1m.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3/v3.txt", + "data/dfrcv1.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3.txt", }; - // Number of threads to use, default to half available const u32 thread_count = std::max(1, std::thread::hardware_concurrency() / 2); - std::cout << "Running on " << thread_count << " threads" << std::endl; + std::cout << "Running on " << thread_count << " threads\n"; for (const auto& filename : fenFiles) { std::ifstream fenFile(filename); if (!fenFile) { - std::cerr << "Error opening " << filename << std::endl; + std::cerr << "Error opening " << filename << "\n"; return 1; } std::string line; while (std::getline(fenFile, line)) { size_t pos = line.find(';'); - if (pos != std::string::npos) { - std::string fen = line.substr(0, pos); - auto parsed = Position::parse(fen); - if (parsed) { - positions.push_back(*parsed); - } else { - std::cerr << "Failed to parse FEN in file " << filename << ": " << fen - << std::endl; - continue; - } + if (pos == std::string::npos) { + std::cerr << "Bad line in " << filename << ": " << line << "\n"; + continue; + } - std::string result = line.substr(pos + 1); - result.erase(std::remove_if(result.begin(), result.end(), ::isspace), result.end()); - - if (result == "w") { - results.push_back(1.0); - } else if (result == "d") { - results.push_back(0.5); - } else if (result == "b") { - results.push_back(0.0); - } else { - std::cerr << "Invalid result in file " << filename << " line: " << line - << " (result is '" << result << "')" << std::endl; - } + std::string fen = line.substr(0, pos); + auto parsed = Position::parse(fen); + + if (!parsed) { + std::cerr << "Failed to parse FEN in " << filename << ": " << fen << "\n"; + continue; + } + + positions.push_back(*parsed); + + std::string result = line.substr(pos + 1); + result.erase(std::remove_if(result.begin(), result.end(), ::isspace), result.end()); + + if (result == "w") { + results.push_back(1.0); + } else if (result == "d") { + results.push_back(0.5); + } else if (result == "b") { + results.push_back(0.0); } else { - std::cerr << "Invalid line format in " << filename << ": " << line << std::endl; + std::cerr << "Invalid result in " << filename << ": " << line << "\n"; } } - - fenFile.close(); } - // Print the number of positions loaded - std::cout << "Loaded " << positions.size() << " FENs from " << fenFiles.size() << " files." - << std::endl; - - if (positions.size() == 0) { - std::cerr << "No positions loaded!" << std::endl; + std::cout << "Loaded " << positions.size() << " FENs.\n"; + if (positions.empty()) { return 1; } - using namespace Clockwork::Autograd; + // ------------------------------ + // Setup Autograd system + // ------------------------------ + + const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts(); - const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts(); - Parameters current_parameter_values = Graph::get().get_all_parameter_values(); + Parameters current_parameter_values = Graph::get().get_all_parameter_values(); AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0); const i32 epochs = 1000; const f64 K = 1.0 / 400; - const size_t batch_size = 16 * 16384; // Set batch size here - - std::mt19937 rng(std::random_device{}()); // Random number generator for shuffling + const size_t batch_size = 16 * 16384; - const size_t total_batches = (positions.size() + batch_size - 1) / batch_size; + std::mt19937 rng(std::random_device{}()); std::vector indices(positions.size()); + const size_t total_batches = (positions.size() + batch_size - 1) / batch_size; + // Shared gradient accumulator Parameters batch_gradients = Parameters::zeros(parameter_count); - std::mutex mutex; + std::mutex mutex; + std::barrier epoch_barrier{thread_count + 1}; std::barrier batch_barrier{thread_count + 1, [&]() noexcept { - std::lock_guard guard{mutex}; + // Single-thread optimizer update optim.step(current_parameter_values, batch_gradients); batch_gradients = Parameters::zeros(parameter_count); }}; - for (u32 thread_idx = 0; thread_idx < thread_count; thread_idx++) { - std::thread([&, thread_idx] { - Graph::get().cleanup(); - - std::vector subbatch_outputs; - std::vector subbatch_targets; + // ------------------------------ + // Worker threads + // ------------------------------ + for (u32 t = 0; t < thread_count; ++t) { + std::thread([&, t]() { + // Each thread uses its own Graph arena - for (i32 epoch = 0; epoch < epochs; epoch++) { + for (int epoch = 0; epoch < epochs; ++epoch) { epoch_barrier.arrive_and_wait(); for (size_t batch_start = 0; batch_start < positions.size(); batch_start += batch_size) { + size_t batch_end = std::min(batch_start + batch_size, positions.size()); + size_t this_batch_size = batch_end - batch_start; - size_t batch_end = std::min(batch_start + batch_size, positions.size()); - size_t current_batch_size = batch_end - batch_start; - size_t subbatch_size = (current_batch_size + thread_count - 1) / thread_count; + size_t sub_size = (this_batch_size + thread_count - 1) / thread_count; - size_t subbatch_start = batch_start + subbatch_size * thread_idx; - size_t subbatch_end = std::min(subbatch_start + subbatch_size, batch_end); - size_t current_subbatch_size = subbatch_end - subbatch_start; - - subbatch_outputs.clear(); - subbatch_targets.clear(); - subbatch_outputs.reserve(current_subbatch_size); - subbatch_targets.reserve(current_subbatch_size); + size_t sub_start = batch_start + sub_size * t; + size_t sub_end = std::min(sub_start + sub_size, batch_end); Graph::get().copy_parameter_values(current_parameter_values); - uint32_t i = 0; - for (size_t j = subbatch_start; j < subbatch_end; ++j) { - size_t idx = indices[j]; - f64 y = results[idx]; - Position pos = positions[idx]; - auto result = (evaluate_white_pov(pos) * K)->sigmoid(); - subbatch_outputs.push_back(result); - subbatch_targets.push_back(y); - if (++i == 1024) { - i = 0; - auto subbatch_loss = - mse(subbatch_outputs, subbatch_targets) - * Autograd::Value::create(1.0 / static_cast(current_batch_size)); - Graph::get().backward(); - Graph::get().clear_backwardables(); - subbatch_outputs.clear(); - subbatch_targets.clear(); - } + std::vector outputs; + std::vector targets; + outputs.reserve(sub_end - sub_start); + targets.reserve(sub_end - sub_start); + + // ------------------------------ + // Forward pass + // ------------------------------ + for (size_t j = sub_start; j < sub_end; ++j) { + size_t idx = indices[j]; + + auto y = results[idx]; + ValueHandle v = (evaluate_white_pov(positions[idx]) * K).sigmoid(); + outputs.push_back(v); + targets.push_back(y); } - auto subbatch_loss = - mse(subbatch_outputs, subbatch_targets) - * Autograd::Value::create(1.0 / static_cast(current_batch_size)); + // ------------------------------ + // Loss and backward + // ------------------------------ + ValueHandle loss = mse(outputs, targets) + * ValueHandle::create(1.0 / double(this_batch_size)); + Graph::get().backward(); - Parameters subbatch_gradients = Graph::get().get_all_parameter_gradients(); + Parameters grads = Graph::get().get_all_parameter_gradients(); + // Accumulate { - std::lock_guard guard{mutex}; - batch_gradients.accumulate(subbatch_gradients); + std::lock_guard guard(mutex); + batch_gradients.accumulate(grads); } - batch_barrier.arrive_and_wait(); Graph::get().cleanup(); + Graph::get().zero_grad(); + + batch_barrier.arrive_and_wait(); } } }).detach(); } - for (i32 epoch = 0; epoch < epochs; epoch++) { - // Print epoch header - std::cout << "Epoch " << (epoch + 1) << "/" << epochs << std::endl; + // ------------------------------ + // Main thread: epoch coordinator + // ------------------------------ + for (int epoch = 0; epoch < epochs; ++epoch) { - const auto epoch_start_time = time::Clock::now(); + std::cout << "Epoch " << epoch + 1 << "/" << epochs << "\n"; + + const auto start = time::Clock::now(); std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), rng); epoch_barrier.arrive_and_wait(); - for (size_t batch_idx = 0, batch_start = 0; batch_start < positions.size(); - batch_start += batch_size, ++batch_idx) { - + for (size_t bi = 0, bstart = 0; bstart < positions.size(); bstart += batch_size, ++bi) { batch_barrier.arrive_and_wait(); - - // Print batch progress bar - print_progress(batch_idx + 1, total_batches); + print_progress(bi + 1, total_batches); } - const auto epoch_end_time = time::Clock::now(); - - std::cout << std::endl; // Finish progress bar line + std::cout << "\n"; - // Print current values + // Dump current parameter values Graph::get().copy_parameter_values(current_parameter_values); + + Graph::get().cleanup(); + Graph::get().zero_grad(); std::cout << "inline const PParam PAWN_MAT = " << PAWN_MAT << ";" << std::endl; std::cout << "inline const PParam KNIGHT_MAT = " << KNIGHT_MAT << ";" << std::endl; @@ -336,10 +335,11 @@ int main() { printPsqtArray("ROOK_PSQT", ROOK_PSQT); printPsqtArray("QUEEN_PSQT", QUEEN_PSQT); printPsqtArray("KING_PSQT", KING_PSQT); + std::cout << std::endl; - std::cout << "// Epoch duration: " - << time::cast(epoch_end_time - epoch_start_time).count() - << "s" << std::endl; + const auto end = time::Clock::now(); + std::cout << "// Epoch duration: " << time::cast(end - start).count() + << "s\n"; if (epoch > 5) { optim.set_lr(optim.get_lr() * 0.91); diff --git a/src/evaluation.cpp b/src/evaluation.cpp index 9ac677a1..9a8a0e7e 100644 --- a/src/evaluation.cpp +++ b/src/evaluation.cpp @@ -335,7 +335,7 @@ Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) { eval += evaluate_space(pos) - evaluate_space(pos); eval += evaluate_outposts(pos) - evaluate_outposts(pos); eval += (us == Color::White) ? TEMPO_VAL : -TEMPO_VAL; - return static_cast(eval->phase<24>(static_cast(phase))); + return static_cast(eval.phase<24>(static_cast(phase))); }; Score evaluate_stm_pov(const Position& pos, const PsqtState& psqt_state) { diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp new file mode 100644 index 00000000..f9780670 --- /dev/null +++ b/src/tuning/arena.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include "util/types.hpp" +#include +#include + +namespace Clockwork::Autograd { + +// A simple contiguous storage for types T. +// Returns indices (handles) instead of pointers. +template +class Arena { +private: + std::vector m_data; + +public: + // Allocates a new slot and returns its index + u32 alloc(const T& initial_value) { + u32 idx = static_cast(m_data.size()); + m_data.push_back(initial_value); + return idx; + } + + // Emplace version + template + u32 emplace(Args&&... args) { + u32 idx = static_cast(m_data.size()); + m_data.emplace_back(std::forward(args)...); + return idx; + } + + // Accessors + inline T& operator[](u32 index) { + assert(index < m_data.size()); + return m_data[index]; + } + + inline const T& operator[](u32 index) const { + assert(index < m_data.size()); + return m_data[index]; + } + + inline usize size() const { + return m_data.size(); + } + + // Resets the arena size, effectively clearing it. + // Note: Does not free memory (capacity remains) to reduce allocations next cycle. + void clear() { + m_data.clear(); + } + + // Keeps the first `n` elements, effectively clearing the rest. + // Useful for keeping parameters which are at the start of the arena. + void reset_to(usize n) { + if (n < m_data.size()) { + m_data.resize(n); + } + } + + std::vector& raw() { + return m_data; + } +}; + +} // namespace Clockwork::Autograd diff --git a/src/tuning/globals.hpp b/src/tuning/globals.hpp index e573930f..cec841bb 100644 --- a/src/tuning/globals.hpp +++ b/src/tuning/globals.hpp @@ -1,6 +1,6 @@ #pragma once -#include "tuning/graph.hpp" +#include "tuning/graph.hpp" // Required for Graph::get() #include "tuning/info.hpp" #include "tuning/value.hpp" #include "util/types.hpp" @@ -8,7 +8,7 @@ #include #include #include -#include +#include namespace Clockwork::Autograd { @@ -60,7 +60,6 @@ class Globals { } bool is_parameter_constant(usize i) const; - bool is_pair_parameter_constant(usize i) const; private: @@ -89,18 +88,17 @@ class ValuePlaceholder { return ValuePlaceholder(a, true); } - operator ValuePtr() const { + // Conversion to Handle: Delegates to the Graph + operator ValueHandle() const { return Graph::get().get_parameter(m_index); } usize index() const { return m_index; } - f64 default_value() const { return m_default_value; } - bool constant() const { return m_constant; } @@ -111,35 +109,6 @@ class ValuePlaceholder { bool m_constant; }; -inline bool Globals::is_parameter_constant(usize i) const { - return m_parameters[i]->constant(); -} - -inline std::ostream& operator<<(std::ostream& os, ValuePlaceholder a) { - os << static_cast(a); - return os; -} - -inline ValuePtr operator-(ValuePlaceholder a) { - return -static_cast(a); -} - -inline ValuePtr operator+(ValuePlaceholder a, ValuePlaceholder b) { - return static_cast(a) + static_cast(b); -} - -inline ValuePtr operator-(ValuePlaceholder a, ValuePlaceholder b) { - return static_cast(a) - static_cast(b); -} - -inline ValuePtr operator*(ValuePlaceholder a, i32 b) { - return static_cast(a) * b; -} - -inline ValuePtr operator/(ValuePlaceholder a, i32 b) { - return static_cast(a) / b; -} - class PairPlaceholder { public: explicit PairPlaceholder(f128 default_value, bool constant) : @@ -156,18 +125,17 @@ class PairPlaceholder { return PairPlaceholder(f128::make(a, b), true); } - operator PairPtr() const { + // Conversion to Handle: Delegates to the Graph + operator PairHandle() const { return Graph::get().get_pair_parameter(m_index); } usize index() const { return m_index; } - f128 default_value() const { return m_default_value; } - bool constant() const { return m_constant; } @@ -178,33 +146,41 @@ class PairPlaceholder { bool m_constant; }; +inline bool Globals::is_parameter_constant(usize i) const { + return m_parameters[i]->constant(); +} + inline bool Globals::is_pair_parameter_constant(usize i) const { return m_pair_parameters[i]->constant(); } -inline std::ostream& operator<<(std::ostream& os, PairPlaceholder a) { - os << static_cast(a); - return os; -} +// --- Helper Operators for Placeholders --- +// These allow Placeholders to be used directly in arithmetic expressions by implicit conversion to Handles. -inline PairPtr operator-(PairPlaceholder a) { - return -static_cast(a); +inline ValueHandle operator-(ValuePlaceholder a) { + return -static_cast(a); } - -inline PairPtr operator+(PairPlaceholder a, PairPlaceholder b) { - return static_cast(a) + static_cast(b); +inline ValueHandle operator+(ValuePlaceholder a, ValuePlaceholder b) { + return static_cast(a) + static_cast(b); } - -inline PairPtr operator-(PairPlaceholder a, PairPlaceholder b) { - return static_cast(a) - static_cast(b); +inline ValueHandle operator-(ValuePlaceholder a, ValuePlaceholder b) { + return static_cast(a) - static_cast(b); } - -inline PairPtr operator*(PairPlaceholder a, i32 b) { - return static_cast(a) * b; +inline ValueHandle operator*(ValuePlaceholder a, i32 b) { + return static_cast(a) * static_cast(b); +} +inline ValueHandle operator/(ValuePlaceholder a, i32 b) { + return static_cast(a) / static_cast(b); } -inline PairPtr operator/(PairPlaceholder a, i32 b) { - return static_cast(a) / b; +inline PairHandle operator-(PairPlaceholder a) { + return -static_cast(a); +} +inline PairHandle operator+(PairPlaceholder a, PairPlaceholder b) { + return static_cast(a) + static_cast(b); +} +inline PairHandle operator-(PairPlaceholder a, PairPlaceholder b) { + return static_cast(a) - static_cast(b); } } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 706ee5c9..3d87a1a6 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -1,16 +1,480 @@ #include "tuning/graph.hpp" #include "tuning/globals.hpp" +#include namespace Clockwork::Autograd { Graph::Graph() { - for (ValuePlaceholder* placeholder : Globals::get().get_parameters()) { - register_param(std::make_shared(placeholder->default_value())); + // Initialize arenas with global parameters + auto params = Globals::get().get_parameters(); + auto pair_params = Globals::get().get_pair_parameters(); + + m_global_param_count = params.size(); + m_global_pair_count = pair_params.size(); + + for (auto* p : params) { + m_values.alloc({p->default_value(), 0.0}); + } + for (auto* p : pair_params) { + m_pairs.alloc({p->default_value(), f128::zero()}); + } +} + +Graph& Graph::get() { + thread_local Graph instance; + return instance; +} + +ValueHandle Graph::create_value(f64 data, bool is_parameter) { + return ValueHandle(m_values.alloc({data, 0.0})); +} + +PairHandle Graph::create_pair(f128 data, bool is_parameter) { + return PairHandle(m_pairs.alloc({data, f128::zero()})); +} + +// ------------------ Recording ------------------ + +ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { + u32 out = m_values.alloc({0.0, 0.0}); + f64 l = m_values[lhs.index].value; + f64 r = m_values[rhs.index].value; + f64 res = 0.0; + + switch (op) { + case OpType::Add: + res = l + r; + break; + case OpType::Sub: + res = l - r; + break; + case OpType::Mul: + res = l * r; + break; + case OpType::Div: + res = l / r; + break; + case OpType::Pow: + res = std::pow(l, r); + break; + default: + break; + } + m_values[out].value = res; + m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); + return ValueHandle(out); +} + +ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { + u32 out = m_values.alloc({0.0, 0.0}); + f64 l = m_values[input.index].value; + f64 res = 0.0; + + switch (op) { + case OpType::Exp: + res = std::exp(l); + break; + case OpType::Log: + res = std::log(l); + break; + case OpType::Sigmoid: + res = 1.0 / (1.0 + std::exp(-l)); + break; + case OpType::Neg: + res = -l; + break; + case OpType::PowConst: + res = std::pow(l, scalar); + break; + case OpType::AddScalar: + res = l + scalar; + break; + case OpType::SubScalarVal: + res = scalar - l; + break; + case OpType::ValSubScalar: + res = l - scalar; + break; + case OpType::MulScalar: + res = l * scalar; + break; + case OpType::DivScalarVal: + res = scalar / l; + break; + case OpType::ValDivScalar: + res = l / scalar; + break; + default: + break; + } + m_values[out].value = res; + m_tape.push_back({op, out, input.index, 0, scalar}); + return ValueHandle(out); +} + +PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { + u32 out = m_pairs.alloc({f128::zero(), f128::zero()}); + f128 l = m_pairs[lhs.index].values; + f128 r = m_pairs[rhs.index].values; + f128 res = f128::zero(); + + switch (op) { + case OpType::PairAdd: + res = f128::add(l, r); + break; + case OpType::PairSub: + res = f128::sub(l, r); + break; + default: + break; + } + m_pairs[out].values = res; + m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); + return PairHandle(out); +} + +PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { + u32 out = m_pairs.alloc({f128::zero(), f128::zero()}); + f128 l = m_pairs[input.index].values; + f128 res = f128::zero(); + + switch (op) { + case OpType::PairNeg: + res = f128::neg(l); + break; + case OpType::PairMulScalar: + res = f128::mul_scalar(l, scalar); + break; + case OpType::PairDivScalar: + res = f128::div_scalar(l, scalar); + break; + case OpType::ScalarDivPair: + res = f128::scalar_div(scalar, l); + break; + default: + break; + } + m_pairs[out].values = res; + m_tape.push_back({op, out, input.index, 0, scalar}); + return PairHandle(out); +} + +PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) { + u32 out = m_pairs.alloc({f128::zero(), f128::zero()}); + f128 p = m_pairs[pair.index].values; + f64 v = m_values[val.index].value; + f128 res = f128::zero(); + + switch (op) { + case OpType::PairMulValue: + case OpType::ValueMulPair: + res = f128::mul_scalar(p, v); + break; + case OpType::PairDivValue: + res = f128::div_scalar(p, v); + break; + case OpType::ValueDivPair: + res = f128::scalar_div(v, p); + break; + default: + break; + } + m_pairs[out].values = res; + // Note: rhs is stored in rhs_idx, but it refers to m_values arena! + // The OpType tells us which arena to look at. + m_tape.push_back({op, out, pair.index, val.index, 0.0}); + return PairHandle(out); +} + +ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { + u32 out = m_values.alloc({0.0, 0.0}); + f128 p = m_pairs[input.index].values; + // Linear interpolation: alpha * first + (1-alpha) * second + f64 val = alpha * p.first() + (1.0 - alpha) * p.second(); + m_values[out].value = val; + m_tape.push_back({OpType::Phase, out, input.index, 0, alpha}); + return ValueHandle(out); +} + +// ------------------ Backward ------------------ + +void Graph::backward() { + if (m_tape.empty()) { + return; + } + + // Seed gradient of last node + const auto& last_node = m_tape.back(); + // Assuming the last node produces a Value (the loss) + m_values[last_node.output_idx].gradient = 1.0; + + // Reverse iterate + for (auto it = m_tape.rbegin(); it != m_tape.rend(); ++it) { + const Node& node = *it; + + // Fetch gradients and values based on op type + // Note: References to vector elements are risky if alloc happened, but backward pass doesn't alloc. + // Using pointers or indices is safer. + + switch (node.type) { + // --- Value Binary --- + case OpType::Add: { + f64 grad = m_values[node.output_idx].gradient; + m_values[node.lhs_idx].gradient += grad; + m_values[node.rhs_idx].gradient += grad; + break; + } + case OpType::Sub: { + f64 grad = m_values[node.output_idx].gradient; + m_values[node.lhs_idx].gradient += grad; + m_values[node.rhs_idx].gradient -= grad; + break; + } + case OpType::Mul: { + f64 grad = m_values[node.output_idx].gradient; + f64 l = m_values[node.lhs_idx].value; + f64 r = m_values[node.rhs_idx].value; + m_values[node.lhs_idx].gradient += r * grad; + m_values[node.rhs_idx].gradient += l * grad; + break; + } + case OpType::Div: { + f64 grad = m_values[node.output_idx].gradient; + f64 l = m_values[node.lhs_idx].value; + f64 r = m_values[node.rhs_idx].value; + m_values[node.lhs_idx].gradient += (1.0 / r) * grad; + m_values[node.rhs_idx].gradient += (-l / (r * r)) * grad; + break; + } + case OpType::Pow: { + f64 grad = m_values[node.output_idx].gradient; + f64 base = m_values[node.lhs_idx].value; + f64 exp = m_values[node.rhs_idx].value; + m_values[node.lhs_idx].gradient += exp * std::pow(base, exp - 1) * grad; + m_values[node.rhs_idx].gradient += std::pow(base, exp) * std::log(base) * grad; + break; + } + + // --- Value Unary --- + case OpType::Exp: { + // d/dx e^x = e^x = y + f64 grad = m_values[node.output_idx].gradient; + f64 val = m_values[node.output_idx].value; + m_values[node.lhs_idx].gradient += val * grad; + break; + } + case OpType::Log: { + f64 grad = m_values[node.output_idx].gradient; + f64 l = m_values[node.lhs_idx].value; + m_values[node.lhs_idx].gradient += (1.0 / l) * grad; + break; + } + case OpType::Sigmoid: { + f64 grad = m_values[node.output_idx].gradient; + f64 s = m_values[node.output_idx].value; + m_values[node.lhs_idx].gradient += (s * (1.0 - s)) * grad; + break; + } + case OpType::Neg: { + m_values[node.lhs_idx].gradient -= m_values[node.output_idx].gradient; + break; + } + case OpType::PowConst: { + f64 grad = m_values[node.output_idx].gradient; + f64 l = m_values[node.lhs_idx].value; + f64 exp = node.scalar_data; + m_values[node.lhs_idx].gradient += exp * std::pow(l, exp - 1.0) * grad; + break; + } + case OpType::AddScalar: + case OpType::SubScalarVal: { + m_values[node.lhs_idx].gradient += m_values[node.output_idx].gradient; + break; + } + case OpType::ValSubScalar: { + m_values[node.lhs_idx].gradient += m_values[node.output_idx].gradient; + break; + } + case OpType::MulScalar: { + m_values[node.lhs_idx].gradient += + node.scalar_data * m_values[node.output_idx].gradient; + break; + } + case OpType::ValDivScalar: { + m_values[node.lhs_idx].gradient += + (1.0 / node.scalar_data) * m_values[node.output_idx].gradient; + break; + } + case OpType::DivScalarVal: { + f64 grad = m_values[node.output_idx].gradient; + f64 l = m_values[node.lhs_idx].value; + m_values[node.lhs_idx].gradient += (-node.scalar_data / (l * l)) * grad; + break; + } + + // --- Pair Binary --- + case OpType::PairAdd: { + f128 grad = m_pairs[node.output_idx].gradients; + m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad); + m_pairs[node.rhs_idx].gradients = f128::add(m_pairs[node.rhs_idx].gradients, grad); + break; + } + case OpType::PairSub: { + f128 grad = m_pairs[node.output_idx].gradients; + m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad); + m_pairs[node.rhs_idx].gradients = f128::sub(m_pairs[node.rhs_idx].gradients, grad); + break; + } + + // --- Pair Scalar --- + case OpType::PairNeg: { + f128 grad = m_pairs[node.output_idx].gradients; + m_pairs[node.lhs_idx].gradients = f128::sub(m_pairs[node.lhs_idx].gradients, grad); + break; + } + case OpType::PairMulScalar: { + f128 grad = m_pairs[node.output_idx].gradients; + f128 scaled_grad = f128::mul_scalar(grad, node.scalar_data); + m_pairs[node.lhs_idx].gradients = + f128::add(m_pairs[node.lhs_idx].gradients, scaled_grad); + break; + } + case OpType::PairDivScalar: { + f128 grad = m_pairs[node.output_idx].gradients; + f128 scaled_grad = f128::div_scalar(grad, node.scalar_data); + m_pairs[node.lhs_idx].gradients = + f128::add(m_pairs[node.lhs_idx].gradients, scaled_grad); + break; + } + case OpType::ScalarDivPair: { + f128 grad = m_pairs[node.output_idx].gradients; + f128 l = m_pairs[node.lhs_idx].values; + f128 l_sq = f128::mul(l, l); + f128 neg_s_over_sq = f128::neg(f128::scalar_div(node.scalar_data, l_sq)); + f128 update = f128::mul(neg_s_over_sq, grad); + m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, update); + break; + } + + // --- Pair Value --- + case OpType::PairMulValue: + case OpType::ValueMulPair: { + // out = p * v + f128 grad_out = m_pairs[node.output_idx].gradients; + f128 p = m_pairs[node.lhs_idx].values; + f64 v = m_values[node.rhs_idx].value; + + // d/dp = v + f128 grad_p = f128::mul_scalar(grad_out, v); + m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); + + // d/dv = p.first * grad.first + p.second * grad.second + f128 contrib = f128::mul(p, grad_out); + m_values[node.rhs_idx].gradient += contrib.first() + contrib.second(); + break; + } + case OpType::PairDivValue: { + // out = p / v + f128 grad_out = m_pairs[node.output_idx].gradients; + f128 p = m_pairs[node.lhs_idx].values; + f64 v = m_values[node.rhs_idx].value; + + f128 grad_p = f128::div_scalar(grad_out, v); + m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); + + // d/dv = -p/v^2 * grad + f128 num = f128::mul(p, grad_out); + f64 sum_contr = num.first() + num.second(); + m_values[node.rhs_idx].gradient += -sum_contr / (v * v); + break; + } + case OpType::ValueDivPair: { + // out = v / p + f128 grad_out = m_pairs[node.output_idx].gradients; + f128 p = m_pairs[node.lhs_idx].values; + f64 v = m_values[node.rhs_idx].value; + + // d/dp = -v/p^2 + f128 p_sq = f128::mul(p, p); + f128 neg_v_sq = f128::neg(f128::scalar_div(v, p_sq)); + f128 grad_p = f128::mul(neg_v_sq, grad_out); + m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); + + // d/dv = 1/p + f128 v_contr = f128::div(grad_out, p); + m_values[node.rhs_idx].gradient += v_contr.first() + v_contr.second(); + break; + } + + // --- Phase --- + case OpType::Phase: { + f64 grad = m_values[node.output_idx].gradient; + f64 alpha = node.scalar_data; + // d/d_first = alpha, d/d_second = 1-alpha + f128 grad_upd = f128::make(alpha * grad, (1.0 - alpha) * grad); + m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_upd); + break; + } + default: + break; + } + } +} + +void Graph::cleanup() { + // Keep parameters, clear the rest + m_values.reset_to(m_global_param_count); + m_pairs.reset_to(m_global_pair_count); + m_tape.clear(); +} + +void Graph::zero_grad() { + for (usize i = 0; i < m_global_param_count; ++i) { + m_values[i].gradient = 0.0; + } + for (usize i = 0; i < m_global_pair_count; ++i) { + m_pairs[i].gradients = f128::zero(); + } +} + +void Graph::copy_parameter_values(const Parameters& source) { + if (source.parameters.size() != m_global_param_count + || source.pair_parameters.size() != m_global_pair_count) { + std::cerr << "Graph parameter count desync!" << std::endl; + std::terminate(); + } + for (usize i = 0; i < m_global_param_count; ++i) { + m_values[i].value = source.parameters[i]; + } + for (usize i = 0; i < m_global_pair_count; ++i) { + m_pairs[i].values = source.pair_parameters[i]; + } +} + +Parameters Graph::get_all_parameter_values() const { + Parameters p; + p.parameters.reserve(m_global_param_count); + p.pair_parameters.reserve(m_global_pair_count); + for (usize i = 0; i < m_global_param_count; ++i) { + p.parameters.push_back(m_values[i].value); + } + for (usize i = 0; i < m_global_pair_count; ++i) { + p.pair_parameters.push_back(m_pairs[i].values); + } + return p; +} + +Parameters Graph::get_all_parameter_gradients() const { + Parameters p; + p.parameters.reserve(m_global_param_count); + p.pair_parameters.reserve(m_global_pair_count); + for (usize i = 0; i < m_global_param_count; ++i) { + p.parameters.push_back(m_values[i].gradient); } - for (PairPlaceholder* placeholder : Globals::get().get_pair_parameters()) { - register_param( - std::make_shared(placeholder->default_value(), placeholder->constant())); + for (usize i = 0; i < m_global_pair_count; ++i) { + p.pair_parameters.push_back(m_pairs[i].gradients); } + return p; } } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 5f029677..f834e616 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -1,163 +1,94 @@ #pragma once -#include "tuning/info.hpp" -#include "tuning/value.hpp" -#include +#include "arena.hpp" +#include "info.hpp" +#include "operations.hpp" +#include "util/vec/sse2.hpp" +#include "value.hpp" #include #include #include namespace Clockwork::Autograd { -class Backwardable; -using BackwardablePtr = std::shared_ptr; - -template -class SmartBackwardable; -template -using SmartBackwardablePtr = std::shared_ptr>; - -class Value; -using ValuePtr = std::shared_ptr; +struct ValueData { + f64 value; + f64 gradient; +}; -class Pair; -using PairPtr = std::shared_ptr; +struct PairData { + f128 values; + f128 gradients; +}; class Graph { private: - // Tunable parameters - std::vector m_parameters; - std::vector m_pair_parameters; - - // All backwardable nodes in insertion order (intermediates + outputs + parameters) - std::vector m_backwardables; + // Storage + Arena m_values; + Arena m_pairs; - Graph(); + // Tape (Linear record of operations) + std::vector m_tape; - void register_param(const ValuePtr& param) { - m_parameters.push_back(param); - } + // Counts of global parameters (they sit at the start of the arenas) + usize m_global_param_count = 0; + usize m_global_pair_count = 0; - void register_param(const PairPtr& param) { - m_pair_parameters.push_back(param); - } + Graph(); public: - static Graph& get() { - thread_local Graph instance; - return instance; - } + static Graph& get(); - // ------------------ Registration ------------------ - void register_value(const BackwardablePtr& node) { - m_backwardables.push_back(node); - } + // ------------------ Creation ------------------ + ValueHandle create_value(f64 data, bool is_parameter = false); + PairHandle create_pair(f128 data, bool is_parameter = false); - void register_value(const ValuePtr& node) { - m_backwardables.push_back(std::static_pointer_cast(node)); - } + // ------------------ Operation Recording ------------------ + // Value-Value Binary + ValueHandle record_op(OpType op, ValueHandle lhs, ValueHandle rhs); + // Value Unary / Scalar + ValueHandle record_op(OpType op, ValueHandle input, f64 scalar = 0.0); + // Pair-Pair Binary + PairHandle record_pair_op(OpType op, PairHandle lhs, PairHandle rhs); + // Pair-Scalar + PairHandle record_pair_scalar(OpType op, PairHandle input, f64 scalar); + // Pair-Value + PairHandle record_pair_value(OpType op, PairHandle pair, ValueHandle val); - void register_value(const PairPtr& node) { - m_backwardables.push_back(std::static_pointer_cast(node)); - } + // Special Phase + ValueHandle record_phase(PairHandle input, f64 alpha); - // ------------------- Copy Values ------------------- - void copy_parameter_values(const Parameters& source) { - if (source.parameters.size() != m_parameters.size() - || source.pair_parameters.size() != m_pair_parameters.size()) { - std::cerr << "Graph parameters count have desynced" << std::endl; - std::terminate(); - } - - for (usize i = 0; i < m_parameters.size(); i++) { - m_parameters[i]->set_value(source.parameters[i]); - } - for (usize i = 0; i < m_pair_parameters.size(); i++) { - m_pair_parameters[i]->set_values(source.pair_parameters[i]); - } - } + // ------------------ Backend Logic ------------------ + void backward(); - Parameters get_all_parameter_values() const { - Parameters result; - result.parameters.resize(m_parameters.size()); - result.pair_parameters.resize(m_pair_parameters.size()); - for (usize i = 0; i < m_parameters.size(); i++) { - result.parameters[i] = m_parameters[i]->get_value(); - } - for (usize i = 0; i < m_pair_parameters.size(); i++) { - result.pair_parameters[i] = m_pair_parameters[i]->get_values(); - } - return result; - } + // ------------------ Management ------------------ + void cleanup(); + void zero_grad(); + void copy_parameter_values(const Parameters& source); + Parameters get_all_parameter_values() const; + Parameters get_all_parameter_gradients() const; - Parameters get_all_parameter_gradients() const { - Parameters result; - result.parameters.resize(m_parameters.size()); - result.pair_parameters.resize(m_pair_parameters.size()); - for (usize i = 0; i < m_parameters.size(); i++) { - result.parameters[i] = m_parameters[i]->get_gradient(); - } - for (usize i = 0; i < m_pair_parameters.size(); i++) { - result.pair_parameters[i] = m_pair_parameters[i]->get_graidents(); - } - return result; + // Accessors for Handles + ValueData& get_value_data(ValueHandle h) { + return m_values[h.index]; } - - // ------------------ Backward Pass ------------------ - void backward() { - if (m_backwardables.empty()) { - return; - } - - // Initialize gradient on last node (loss node) - auto last = std::static_pointer_cast(m_backwardables.back()); - last->m_gradient = static_cast(1); - - // Reverse pass - for (auto it = m_backwardables.rbegin(); it != m_backwardables.rend(); ++it) { - (*it)->backward(); - } + const ValueData& get_value_data(ValueHandle h) const { + return m_values[h.index]; } - void clear_backwardables() { - m_backwardables.clear(); + PairData& get_pair_data(PairHandle h) { + return m_pairs[h.index]; } - - // ------------------ Cleanup ------------------ - void cleanup() { - for (auto& param : m_parameters) { - param->zero_grad(); - } - for (auto& param : m_pair_parameters) { - param->zero_grad(); - } - - m_backwardables.clear(); + const PairData& get_pair_data(PairHandle h) const { + return m_pairs[h.index]; } - // ------------------ Reset ------------------ - void init_zeros() { - for (auto& param : m_parameters) { - param->set_value(0.0); - } - for (auto& param : m_pair_parameters) { - param->set_values(0.0, 0.0); - } - cleanup(); + ValueHandle get_parameter(usize global_index) const { + return ValueHandle(static_cast(global_index)); } - // ------------------ Accessors ------------------ - const std::vector& get_parameters() const { - return m_parameters; - } - const std::vector& get_pair_parameters() const { - return m_pair_parameters; - } - ValuePtr get_parameter(usize index) const { - return m_parameters[index]; - } - PairPtr get_pair_parameter(usize index) const { - return m_pair_parameters[index]; + PairHandle get_pair_parameter(usize global_index) const { + return PairHandle(static_cast(global_index)); } }; diff --git a/src/tuning/info.hpp b/src/tuning/info.hpp index ce6d01e4..717ab935 100644 --- a/src/tuning/info.hpp +++ b/src/tuning/info.hpp @@ -47,4 +47,4 @@ struct Parameters { } }; -} // namespace Clockwork::Autograd +} // namespace Clockwork::Autograd \ No newline at end of file diff --git a/src/tuning/loss.hpp b/src/tuning/loss.hpp index 79d620ce..41eb30bb 100644 --- a/src/tuning/loss.hpp +++ b/src/tuning/loss.hpp @@ -1,7 +1,7 @@ #pragma once #include "value.hpp" -#include +#include #include namespace Clockwork::Autograd { @@ -13,33 +13,31 @@ enum class Reduction { }; template -auto mse(const std::vector& predictions, const std::vector& targets) { +auto mse(const std::vector& predictions, const std::vector& targets) { if (predictions.size() != targets.size()) { throw std::invalid_argument("MSE: predictions and targets must have the same size."); } if constexpr (R == Reduction::None) { - // Return vector of squared errors (no reduction) - std::vector losses; + std::vector losses; losses.reserve(predictions.size()); for (size_t i = 0; i < predictions.size(); ++i) { - ValuePtr diff = predictions[i] - targets[i]; + ValueHandle diff = predictions[i] - targets[i]; losses.push_back(diff * diff); } return losses; } else { - // Compute sum of squared errors in one accumulated node - std::vector losses; + std::vector losses; losses.reserve(predictions.size()); for (size_t i = 0; i < predictions.size(); ++i) { - ValuePtr diff = predictions[i] - targets[i]; + ValueHandle diff = predictions[i] - targets[i]; losses.push_back(diff * diff); } - ValuePtr total_loss = Value::sum(losses); + ValueHandle total_loss = ValueHandle::sum(losses); if constexpr (R == Reduction::Mean) { f64 n = static_cast(predictions.size()); - return total_loss * Value::create(static_cast(1) / n); + return total_loss * (1.0 / n); } else { return total_loss; } diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp new file mode 100644 index 00000000..24900dfb --- /dev/null +++ b/src/tuning/operations.hpp @@ -0,0 +1,77 @@ +#pragma once + +#include "util/types.hpp" + +namespace Clockwork::Autograd { + +enum class OpType : u8 { + // --- Leaf / Special --- + None, + Parameter, // Created from a global parameter + Input, // Created manually (e.g. from data) + + // --- Value Binary Ops --- + Add, + Sub, + Mul, + Div, + Pow, + + // --- Value Unary Ops --- + Exp, + Log, + Sigmoid, + Neg, + PowConst, // x ^ scalar + AddScalar, // x + scalar + SubScalarVal, // scalar - x + ValSubScalar, // x - scalar + MulScalar, // x * scalar + DivScalarVal, // scalar / x + ValDivScalar, // x / scalar + + // --- Pair Ops --- + PairCreate, // (val, val) -> pair + PairAdd, + PairSub, + PairNeg, + + // --- Pair/Scalar Ops --- + PairMulScalar, + PairDivScalar, + ScalarDivPair, + + // --- Pair/Value Ops --- + PairMulValue, + ValueMulPair, // Commutative wrapper usually, but distinct op code helps + PairDivValue, + ValueDivPair, + + // --- Phasing --- + Phase, // Pair -> Value via alpha + + // --- Reduction --- + Sum // Sum of a vector of values +}; + +// A single node in the compute tape. +// Designed to be compact and fit in cache lines. +struct Node { + OpType type; + + u32 output_idx; // Index in the respective arena (Value or Pair) + u32 lhs_idx; // Index of first operand + u32 rhs_idx; // Index of second operand (if applicable) + + // Auxiliary data for scalar ops, constants, or specific parameters + f64 scalar_data; + + // Helper to handle the "Reduction" case (Sum) where we have >2 inputs. + // In a pure tape, we might handle Sum by chaining Adds, or storing an index + // to a side-table of indices. For simplicity/speed in this specific codebase, + // we can implement Sum as a sequence of Adds or use a special range. + // For now, we will implement Sum as a chain of binary adds in the graph + // builder to keep the Node struct fixed-size and simple. +}; + +} // namespace Clockwork::Autograd diff --git a/src/tuning/optim.hpp b/src/tuning/optim.hpp index 037e0598..5ae13062 100644 --- a/src/tuning/optim.hpp +++ b/src/tuning/optim.hpp @@ -6,7 +6,6 @@ #include "util/vec/sse2.hpp" #include -#include #include namespace Clockwork::Autograd { @@ -14,9 +13,8 @@ namespace Clockwork::Autograd { class SGD { private: ParameterCountInfo m_counts; - - f64 m_lr; - f64 m_momentum; + f64 m_lr; + f64 m_momentum; std::vector m_value_velocity; std::vector m_pair_velocity; @@ -31,9 +29,11 @@ class SGD { } void step(Parameters& values, const Parameters& gradients) { + const auto& globals = Globals::get(); + // ---- Value parameters ---- for (size_t i = 0; i < m_counts.parameter_count; ++i) { - if (Globals::get().is_parameter_constant(i)) { + if (globals.is_parameter_constant(i)) { continue; } @@ -42,13 +42,12 @@ class SGD { auto& v = m_value_velocity[i]; v = m_momentum * v - m_lr * p_grad; - p_value += v; } // ---- Pair parameters ---- for (size_t i = 0; i < m_counts.pair_parameter_count; ++i) { - if (Globals::get().is_pair_parameter_constant(i)) { + if (globals.is_pair_parameter_constant(i)) { continue; } @@ -56,13 +55,11 @@ class SGD { auto& p_grad = gradients.pair_parameters[i]; auto& v = m_pair_velocity[i]; - const f128 lr_grad = f128::mul_scalar(p_grad, m_lr); - + const f128 lr_grad = f128::mul_scalar(p_grad, m_lr); const f128 mom_v = f128::mul_scalar(v, m_momentum); const f128 neg_lr_grad = f128::neg(lr_grad); v = f128::add(mom_v, neg_lr_grad); - - p_value = f128::add(p_value, v); + p_value = f128::add(p_value, v); } } @@ -78,13 +75,12 @@ class SGD { class AdamW { private: ParameterCountInfo m_counts; - - f64 m_lr; - f64 m_beta1; - f64 m_beta2; - f64 m_eps; - f64 m_weight_decay; - long long m_t; + f64 m_lr; + f64 m_beta1; + f64 m_beta2; + f64 m_eps; + f64 m_weight_decay; + long long m_t; std::vector m_m; std::vector m_v; @@ -107,13 +103,13 @@ class AdamW { m_t(0) { m_m.resize(m_counts.parameter_count, 0.0); m_v.resize(m_counts.parameter_count, 0.0); - m_pair_m.resize(m_counts.pair_parameter_count, f128::zero()); m_pair_v.resize(m_counts.pair_parameter_count, f128::zero()); } void step(Parameters& values, const Parameters& gradients) { m_t += 1; + const auto& globals = Globals::get(); const f64 b1t = std::pow(m_beta1, static_cast(m_t)); const f64 b2t = std::pow(m_beta2, static_cast(m_t)); @@ -122,7 +118,7 @@ class AdamW { // ---------------- Value parameters ---------------- for (size_t i = 0; i < m_counts.parameter_count; ++i) { - if (Globals::get().is_parameter_constant(i)) { + if (globals.is_parameter_constant(i)) { continue; } @@ -130,24 +126,19 @@ class AdamW { auto& g = gradients.parameters[i]; m_m[i] = m_beta1 * m_m[i] + (1.0 - m_beta1) * g; - m_v[i] = m_beta2 * m_v[i] + (1.0 - m_beta2) * g * g; - const f64 m_hat = m_m[i] * inv1mb1t; - const f64 v_hat = m_v[i] * inv1mb2t; - - const f64 adam_update = m_lr * m_hat / (std::sqrt(v_hat) + m_eps); - + const f64 m_hat = m_m[i] * inv1mb1t; + const f64 v_hat = m_v[i] * inv1mb2t; + const f64 adam_update = m_lr * m_hat / (std::sqrt(v_hat) + m_eps); const f64 weight_decay_update = m_lr * m_weight_decay * p; - const f64 total_update = -(adam_update + weight_decay_update); - - p += total_update; + p += -(adam_update + weight_decay_update); } // ---------------- Pair parameters ---------------- for (size_t i = 0; i < m_counts.pair_parameter_count; ++i) { - if (Globals::get().is_pair_parameter_constant(i)) { + if (globals.is_pair_parameter_constant(i)) { continue; } diff --git a/src/tuning/value.cpp b/src/tuning/value.cpp index 422556d8..8fd8fd4a 100644 --- a/src/tuning/value.cpp +++ b/src/tuning/value.cpp @@ -1,27 +1,268 @@ -#include "value.hpp" -#include "../util/types.hpp" -#include "graph.hpp" -#include "util/vec/sse2.hpp" +#include "tuning/value.hpp" +#include "operations.hpp" +#include "tuning/graph.hpp" +#include "util/types.hpp" #include namespace Clockwork::Autograd { -ValuePtr Value::create(f64 data) { - ValuePtr res = std::make_shared(data); - Graph::get().register_value(res); - return res; +// ------------------ ValueHandle ------------------ + +ValueHandle ValueHandle::create(f64 data) { + return Graph::get().create_value(data); +} + +ValueHandle ValueHandle::sum(const std::vector& inputs) { + if (inputs.empty()) { + return ValueHandle::create(0.0); + } + // Simple linear accumulation on the tape + ValueHandle total = inputs[0]; + for (size_t i = 1; i < inputs.size(); ++i) { + total = total + inputs[i]; + } + return total; +} + +ValueHandle ValueHandle::exp() const { + return Graph::get().record_op(OpType::Exp, *this); +} +ValueHandle ValueHandle::log() const { + return Graph::get().record_op(OpType::Log, *this); +} +ValueHandle ValueHandle::sigmoid() const { + return Graph::get().record_op(OpType::Sigmoid, *this); +} +ValueHandle ValueHandle::pow(ValueHandle exponent) const { + return Graph::get().record_op(OpType::Pow, *this, exponent); +} +ValueHandle ValueHandle::pow(f64 exponent) const { + return Graph::get().record_op(OpType::PowConst, *this, exponent); +} + +void ValueHandle::add_gradient(f64 rhs) const { + if (is_valid()) { + Graph::get().get_value_data(*this).gradient += rhs; + } +} + +f64 ValueHandle::get_value() const { + return is_valid() ? Graph::get().get_value_data(*this).value : 0.0; +} + +f64 ValueHandle::get_gradient() const { + return is_valid() ? Graph::get().get_value_data(*this).gradient : 0.0; +} + +void ValueHandle::zero_grad() const { + if (is_valid()) { + Graph::get().get_value_data(*this).gradient = 0.0; + } +} + +void ValueHandle::set_value(f64 v) const { + if (is_valid()) { + Graph::get().get_value_data(*this).value = v; + } +} + +// ------------------ PairHandle ------------------ + +PairHandle PairHandle::create(f64 first, f64 second) { + return Graph::get().create_pair(f128::make(first, second)); +} + +PairHandle PairHandle::create(const f128& values) { + return Graph::get().create_pair(values); +} + +f128 PairHandle::get_values() const { + return Graph::get().get_pair_data(*this).values; +} +f128 PairHandle::get_gradients() const { + return Graph::get().get_pair_data(*this).gradients; +} +f64 PairHandle::first() const { + return get_values().first(); +} +f64 PairHandle::second() const { + return get_values().second(); +} + +void PairHandle::set_values(const f128& v) const { + Graph::get().get_pair_data(*this).values = v; +} +void PairHandle::set_values(f64 f, f64 s) const { + set_values(f128::make(f, s)); +} + +void PairHandle::zero_grad() const { + Graph::get().get_pair_data(*this).gradients = f128::zero(); +} + +// Implementation of the helper called by the template in the header +ValueHandle PairHandle::phase_impl(f64 scaled_alpha) const { + return Graph::get().record_phase(*this, scaled_alpha); +} + +// ------------------ Operator Implementations ------------------ + +// --- Value Unary/Binary --- +ValueHandle operator-(ValueHandle a) { + return Graph::get().record_op(OpType::Neg, a); +} +ValueHandle operator+(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Add, a, b); +} +ValueHandle operator-(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Sub, a, b); +} +ValueHandle operator*(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Mul, a, b); +} +ValueHandle operator/(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Div, a, b); } -PairPtr Pair::create(f64 first, f64 second) { - PairPtr res = std::make_shared(first, second, true); - Graph::get().register_value(res); - return res; +// --- Value Scalar --- +ValueHandle operator+(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::AddScalar, a, b); +} +ValueHandle operator-(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::ValSubScalar, a, b); +} +ValueHandle operator*(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::MulScalar, a, b); +} +ValueHandle operator/(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::ValDivScalar, a, b); +} + +ValueHandle operator+(f64 a, ValueHandle b) { + return b + a; +} +ValueHandle operator-(f64 a, ValueHandle b) { + return Graph::get().record_op(OpType::SubScalarVal, b, a); +} +ValueHandle operator*(f64 a, ValueHandle b) { + return b * a; +} +ValueHandle operator/(f64 a, ValueHandle b) { + return Graph::get().record_op(OpType::DivScalarVal, b, a); +} + +// --- Comparison --- +bool operator<(ValueHandle a, ValueHandle b) { + return a.get_value() < b.get_value(); +} +bool operator>(ValueHandle a, ValueHandle b) { + return a.get_value() > b.get_value(); } -PairPtr Pair::create(const f128& values) { - PairPtr res = std::make_shared(values, true); - Graph::get().register_value(res); - return res; +// --- Pair Binary --- +PairHandle operator+(PairHandle a, PairHandle b) { + return Graph::get().record_pair_op(OpType::PairAdd, a, b); +} +PairHandle operator-(PairHandle a, PairHandle b) { + return Graph::get().record_pair_op(OpType::PairSub, a, b); +} +PairHandle operator-(PairHandle a) { + return Graph::get().record_pair_scalar(OpType::PairNeg, a, 0.0); +} + +// --- Pair Scalar --- +PairHandle operator*(PairHandle a, f64 scalar) { + return Graph::get().record_pair_scalar(OpType::PairMulScalar, a, scalar); +} +PairHandle operator*(f64 scalar, PairHandle a) { + return a * scalar; +} +PairHandle operator/(PairHandle a, f64 scalar) { + return Graph::get().record_pair_scalar(OpType::PairDivScalar, a, scalar); +} +PairHandle operator/(f64 scalar, PairHandle a) { + return Graph::get().record_pair_scalar(OpType::ScalarDivPair, a, scalar); +} + +// --- Pair Value --- +PairHandle operator*(PairHandle a, ValueHandle v) { + return Graph::get().record_pair_value(OpType::PairMulValue, a, v); +} +PairHandle operator*(ValueHandle v, PairHandle a) { + return Graph::get().record_pair_value(OpType::ValueMulPair, a, v); +} +PairHandle operator/(PairHandle a, ValueHandle v) { + return Graph::get().record_pair_value(OpType::PairDivValue, a, v); +} +PairHandle operator/(ValueHandle v, PairHandle a) { + return Graph::get().record_pair_value(OpType::ValueDivPair, a, v); +} + +std::ostream& operator<<(std::ostream& os, const PairHandle& p) { + os << "S(" << static_cast(p.first() + 0.5) << ", " << static_cast(p.second() + 0.5) + << ")"; + return os; +} + +// --- Inplace Operators (Syntactic Sugar) --- + +ValueHandle& operator+=(ValueHandle& a, ValueHandle b) { + a = a + b; + return a; +} +ValueHandle& operator-=(ValueHandle& a, ValueHandle b) { + a = a - b; + return a; +} +ValueHandle& operator*=(ValueHandle& a, ValueHandle b) { + a = a * b; + return a; +} +ValueHandle& operator/=(ValueHandle& a, ValueHandle b) { + a = a / b; + return a; +} + +ValueHandle& operator+=(ValueHandle& a, f64 b) { + a = a + b; + return a; +} +ValueHandle& operator-=(ValueHandle& a, f64 b) { + a = a - b; + return a; +} +ValueHandle& operator*=(ValueHandle& a, f64 b) { + a = a * b; + return a; +} +ValueHandle& operator/=(ValueHandle& a, f64 b) { + a = a / b; + return a; +} + +PairHandle& operator+=(PairHandle& a, PairHandle b) { + a = a + b; + return a; +} +PairHandle& operator-=(PairHandle& a, PairHandle b) { + a = a - b; + return a; +} +PairHandle& operator*=(PairHandle& a, f64 scalar) { + a = a * scalar; + return a; +} +PairHandle& operator*=(PairHandle& a, ValueHandle v) { + a = a * v; + return a; +} +PairHandle& operator/=(PairHandle& a, f64 scalar) { + a = a / scalar; + return a; +} +PairHandle& operator/=(PairHandle& a, ValueHandle v) { + a = a / v; + return a; } } // namespace Clockwork::Autograd diff --git a/src/tuning/value.hpp b/src/tuning/value.hpp index 73c23015..64abab52 100644 --- a/src/tuning/value.hpp +++ b/src/tuning/value.hpp @@ -1,629 +1,123 @@ #pragma once -#include "../util/types.hpp" +#include "util/types.hpp" #include "util/vec/sse2.hpp" -#include - #include -#include -#include +#include #include namespace Clockwork::Autograd { -// Forward declarations - -class Backwardable; -template -class SmartBackwardable; -class Value; -class Pair; class Graph; - -using ValuePtr = std::shared_ptr; - -using PairPtr = std::shared_ptr; - -using BackwardablePtr = std::shared_ptr; - - -class Backwardable { -public: - friend class Graph; - std::function m_backward_func; - virtual ~Backwardable() = default; - virtual void backward() = 0; -}; - -template -class SmartBackwardable : public Backwardable, public std::enable_shared_from_this { -public: - virtual ~SmartBackwardable() = default; -}; - - -class Value : public SmartBackwardable { -private: - f64 m_value = 0; - f64 m_gradient = 0; - std::function m_backward_func; - -public: - friend class Graph; - friend class Pair; - - explicit Value(f64 data) : - m_value(data) {}; - - inline f64 get_value() const { - return m_value; - } - inline void change_value(f64 amount) { - m_value += amount; - } - inline void set_value(f64 value) { - m_value = value; - } - inline f64 get_gradient() const { - return m_gradient; - } - - inline void zero_grad() { - m_gradient = 0.0; - } - static ValuePtr create(f64 data); - - ValuePtr exp() { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::exp(this->m_value)); - result->m_backward_func = [this_value](ValuePtr out) { - f64 grad = out->m_value; // Avoid recomputing exp val - this_value->m_gradient += grad * out->m_gradient; - }; - return result; - } - - ValuePtr log() { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::log(this->m_value)); - result->m_backward_func = [this_value](ValuePtr out) { - f64 grad = (1 / this_value->m_value); - this_value->m_gradient += grad * out->m_gradient; - }; - return result; - } - - ValuePtr sigmoid() { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(1 / (1 + std::exp(-this_value->m_value))); - result->m_backward_func = [this_value](ValuePtr out) { - f64 grad = out->m_value - * (1 - out->m_value); // Same trick as before, avoid recomputing sigmoid(x) - this_value->m_gradient += grad * out->m_gradient; - }; - return result; - } - - ValuePtr pow(ValuePtr exponent) { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::pow(this_value->m_value, exponent->m_value)); - result->m_backward_func = [this_value, exponent](ValuePtr out) { - this_value->m_gradient += exponent->m_value - * std::pow(this_value->m_value, exponent->m_value - 1) - * out->m_gradient; - exponent->m_gradient += out->m_value * std::log(this_value->m_value) * out->m_gradient; - }; - return result; - } - - ValuePtr pow(f64 exponent) { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::pow(this_value->m_value, exponent)); - result->m_backward_func = [this_value, exponent](ValuePtr out) { - this_value->m_gradient += - exponent * std::pow(this_value->m_value, exponent - 1) * out->m_gradient; - }; - return result; - } - - void add_gradient(f64 rhs) { - m_gradient += rhs; - } - - friend ValuePtr operator-(ValuePtr a) { - ValuePtr result = Value::create(-a->m_value); - result->m_backward_func = [a](ValuePtr out) { - f64 grad = -out->m_gradient; - a->m_gradient += grad; - }; - return result; - } - - friend ValuePtr operator+(ValuePtr a, ValuePtr b) { - ValuePtr result = Value::create(a->m_value + b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += out->m_gradient; - b->m_gradient += out->m_gradient; - }; - return result; +struct ValueHandle { + u32 index; + ValueHandle() : + index(0xFFFFFFFF) { } - - friend ValuePtr operator-(ValuePtr a, - ValuePtr b) { // We are NOT cheaping out with a + (-b) - ValuePtr result = Value::create(a->m_value - b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += out->m_gradient; - b->m_gradient -= out->m_gradient; - }; - return result; - } - - friend ValuePtr operator*(ValuePtr a, ValuePtr b) { - ValuePtr result = Value::create(a->m_value * b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += b->m_value * out->m_gradient; - b->m_gradient += a->m_value * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator/(ValuePtr a, - ValuePtr b) { // We are NOT cheaping out with a * (std::pow(b,-1)) - ValuePtr result = Value::create(a->m_value / b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += 1.0 / b->m_value * out->m_gradient; - b->m_gradient += -a->m_value / (b->m_value * b->m_value) * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator+(ValuePtr a, f64 b) { - ValuePtr result = Value::create(a->m_value + b); - result->m_backward_func = [a](ValuePtr out) { - a->m_gradient += out->m_gradient; - }; - return result; - } - - friend ValuePtr operator-(ValuePtr a, f64 b) { - ValuePtr result = Value::create(a->m_value - b); - result->m_backward_func = [a](ValuePtr out) { - a->m_gradient += out->m_gradient; - }; - return result; - } - - friend ValuePtr operator*(ValuePtr a, f64 b) { - ValuePtr result = Value::create(a->m_value * b); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += b * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator/(ValuePtr a, - f64 b) { // We are NOT cheaping out with a * (std::pow(b,-1)) - ValuePtr result = Value::create(a->m_value / b); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += 1.0 / b * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator+(f64 a, ValuePtr b) { - return b + a; - } - - friend ValuePtr operator-(f64 a, ValuePtr b) { - ValuePtr result = Value::create(a - b->m_value); - result->m_backward_func = [b](ValuePtr out) { - b->m_gradient -= out->m_gradient; - }; - return result; - } - - friend ValuePtr operator*(f64 a, ValuePtr b) { - return b * a; + explicit ValueHandle(u32 idx) : + index(idx) { } - - friend ValuePtr operator/(f64 a, ValuePtr b) { - ValuePtr result = Value::create(a / b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - b->m_gradient += -a / (b->m_value * b->m_value) * out->m_gradient; - }; - return result; - } - - static ValuePtr sum(const std::vector& inputs) { - if (inputs.empty()) { - return Value::create(0.0); - } - - f64 sum = 0.0; - f64 c = 0.0; - for (auto& v : inputs) { - f64 y = v->m_value - c; - f64 t = sum + y; - c = (t - sum) - y; - sum = t; - } - - ValuePtr result = Value::create(sum); - - result->m_backward_func = [inputs](ValuePtr out) { - f64 grad = out->m_gradient; - for (auto& v : inputs) { - v->m_gradient += grad; - } - }; - - return result; - } - - - friend bool operator==(ValuePtr a, ValuePtr b) { - return a->m_value == b->m_value; + bool is_valid() const { + return index != 0xFFFFFFFF; } - friend bool operator!=(ValuePtr a, ValuePtr b) { - return a->m_value != b->m_value; - } - - friend bool operator<(ValuePtr a, ValuePtr b) { - return a->m_value < b->m_value; - } - - friend bool operator<=(ValuePtr a, ValuePtr b) { - return a->m_value <= b->m_value; - } - - friend bool operator>(ValuePtr a, ValuePtr b) { - return a->m_value > b->m_value; - } - - friend bool operator>=(ValuePtr a, ValuePtr b) { - return a->m_value >= b->m_value; - } + static ValueHandle create(f64 data); + static ValueHandle sum(const std::vector& inputs); - friend std::ostream& operator<<(std::ostream& os, const ValuePtr& value) { - os << "Value(data=" << value->get_value() << ", grad=" << value->get_gradient() << ")"; - return os; - } + ValueHandle exp() const; + ValueHandle log() const; + ValueHandle sigmoid() const; + ValueHandle pow(ValueHandle exponent) const; + ValueHandle pow(f64 exponent) const; - void backward() override { - auto this_value = this->shared_from_this(); - if (this_value->m_backward_func) { - this_value->m_backward_func(this_value); - } - } + void add_gradient(f64 rhs) const; + f64 get_value() const; + f64 get_gradient() const; + void zero_grad() const; + void set_value(f64 v) const; }; -class Pair : public SmartBackwardable { -private: - std::function m_backward_func; - bool m_constant; - -public: - friend class Graph; - friend class Value; - - f128 m_values; // stores (first, second) - f128 m_gradients; // stores (grad_first, grad_second) - - explicit Pair(f64 first, f64 second, bool constant = false) : - m_constant(constant), - m_values(f128::make(first, second)), - m_gradients(f128::zero()) { - } - - explicit Pair(const f128& values, bool constant = false) : - m_constant(constant), - m_values(values), - m_gradients(f128::zero()) { - } - - static PairPtr create(f64 first, f64 second); - - static PairPtr create(const f128& values); - - inline f64 first() const { - return m_values.first(); - } - - inline f64 second() const { - return m_values.second(); - } - - inline f128 get_values() const { - return m_values; - } - - inline f64 grad_first() const { - return m_gradients.first(); - } - - inline f64 grad_second() const { - return m_gradients.second(); - } - - inline f128 get_graidents() const { - return m_gradients; - } - - inline void zero_grad() { - m_gradients = f128::zero(); - } - - inline void set_values(const f128& values) { - m_values = values; - } - inline void set_values(f64 first, f64 second) { - m_values = f128::make(first, second); - } - - // ------------------- Backward ------------------- - void backward() override { - auto self = this->shared_from_this(); - if (m_backward_func) { - m_backward_func(self); - } - } - - // ------------------- Arithmetic ------------------- - friend PairPtr operator+(const PairPtr& a, const PairPtr& b) { - f128 result_values = f128::add(a->m_values, b->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, b](PairPtr out) { - a->m_gradients = f128::add(a->m_gradients, out->m_gradients); - b->m_gradients = f128::add(b->m_gradients, out->m_gradients); - }; - return result; - } - - friend PairPtr operator-(const PairPtr& a, const PairPtr& b) { - f128 result_values = f128::sub(a->m_values, b->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, b](PairPtr out) { - a->m_gradients = f128::add(a->m_gradients, out->m_gradients); - b->m_gradients = f128::sub(b->m_gradients, out->m_gradients); - }; - return result; - } - - // ------------------- Scalar multiplication ------------------- - friend PairPtr operator*(const PairPtr& a, f64 scalar) { - f128 result_values = f128::mul_scalar(a->m_values, scalar); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, scalar](PairPtr out) { - f128 scaled_grad = f128::mul_scalar(out->m_gradients, scalar); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); - }; - return result; - } - - friend PairPtr operator*(f64 scalar, const PairPtr& a) { - return a * scalar; - } - - // ------------------- Pair * Value ------------------- - friend PairPtr operator*(const PairPtr& a, const ValuePtr& v) { - f64 v_val = v->get_value(); - f128 result_values = f128::mul_scalar(a->m_values, v_val); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, v](PairPtr out) { - f64 v_val = v->get_value(); - f128 scaled_grad = f128::mul_scalar(out->m_gradients, v_val); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); - - f128 contribution = f128::mul(a->m_values, out->m_gradients); - v->add_gradient(contribution.first() + contribution.second()); - }; - return result; +struct PairHandle { + u32 index; + PairHandle() : + index(0xFFFFFFFF) { } - - friend PairPtr operator*(const ValuePtr& v, const PairPtr& a) { - return a * v; + explicit PairHandle(u32 idx) : + index(idx) { } - - // ------------------- Pair / scalar ------------------- - friend PairPtr operator/(const PairPtr& a, f64 scalar) { - f128 result_values = f128::div_scalar(a->m_values, scalar); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, scalar](PairPtr out) { - f128 scaled_grad = f128::div_scalar(out->m_gradients, scalar); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); - }; - return result; + bool is_valid() const { + return index != 0xFFFFFFFF; } - // ------------------- Scalar / Pair ------------------- - friend PairPtr operator/(f64 scalar, const PairPtr& a) { - f128 result_values = f128::scalar_div(scalar, a->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, scalar](PairPtr out) { - f128 a_squared = f128::mul(a->m_values, a->m_values); - f128 neg_scalar_over_a_sq = f128::neg(f128::scalar_div(scalar, a_squared)); - f128 grad_update = f128::mul(neg_scalar_over_a_sq, out->m_gradients); - a->m_gradients = f128::add(a->m_gradients, grad_update); - }; - return result; + static PairHandle create(f64 first, f64 second); + static PairHandle create(const f128& values); + static PairHandle create_tunable(f64 a, f64 b) { + return create(a, b); } - // ------------------- Pair / Value ------------------- - friend PairPtr operator/(const PairPtr& a, const ValuePtr& v) { - f64 v_val = v->get_value(); - f128 result_values = f128::div_scalar(a->m_values, v_val); - PairPtr result = Pair::create(result_values); + f128 get_values() const; + f128 get_gradients() const; + f64 first() const; + f64 second() const; + void set_values(const f128& v) const; + void set_values(f64 f, f64 s) const; + void zero_grad() const; - result->m_backward_func = [a, v](PairPtr out) { - f64 v_val = v->get_value(); - f128 scaled_grad = f128::div_scalar(out->m_gradients, v_val); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); + // Internal helper to avoid including Graph in header + ValueHandle phase_impl(f64 scaled_alpha) const; - f128 numerator = f128::mul(a->m_values, out->m_gradients); - f64 sum_contribution = numerator.first() + numerator.second(); - v->add_gradient(-sum_contribution / (v_val * v_val)); - }; - return result; - } - - // ------------------- Value / Pair ------------------- - friend PairPtr operator/(const ValuePtr& v, const PairPtr& a) { - f64 v_val = v->get_value(); - f128 result_values = f128::scalar_div(v_val, a->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, v](PairPtr out) { - f64 v_val = v->get_value(); - f128 a_squared = f128::mul(a->m_values, a->m_values); - f128 neg_v_over_a_sq = f128::neg(f128::scalar_div(v_val, a_squared)); - f128 grad_update = f128::mul(neg_v_over_a_sq, out->m_gradients); - a->m_gradients = f128::add(a->m_gradients, grad_update); - - f128 v_grad_contrib = f128::div(out->m_gradients, a->m_values); - v->add_gradient(v_grad_contrib.first() + v_grad_contrib.second()); - }; - return result; - } - - // ------------------- Unary negation ------------------- - friend PairPtr operator-(PairPtr a) { - f128 result_values = f128::neg(a->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a](PairPtr out) { - a->m_gradients = f128::sub(a->m_gradients, out->m_gradients); - }; - return result; - } - - // ------------------- Phasing ------------------- template - ValuePtr phase(f64 alpha) { - alpha /= max; - auto self = this->shared_from_this(); - - // Linear interpolation: alpha * first + (1-alpha) * second - f64 result_val = alpha * first() + (1.0 - alpha) * second(); - ValuePtr result = Value::create(result_val); - - result->m_backward_func = [self, alpha](ValuePtr out) { - // Gradient of output w.r.t first and second - f64 grad_first = alpha * out->m_gradient; - f64 grad_second = (1.0 - alpha) * out->m_gradient; - f128 grad_update = f128::make(grad_first, grad_second); - self->m_gradients = f128::add(self->m_gradients, grad_update); - }; - - return result; - } - - // ------------------- Print ------------------- - friend std::ostream& operator<<(std::ostream& os, const PairPtr& p) { -#ifdef VERBOSE_AUTOGRAD - os << "Pair(first=" << p->first() << ", second=" << p->second() - << ", grad_first=" << p->grad_first() << ", grad_second=" << p->grad_second() << ")"; -#else - os << (p->m_constant ? "CS" : "S"); - os << "(" << static_cast(p->first() + 0.5) << ", " - << static_cast(p->second() + 0.5) << ")"; -#endif - return os; + ValueHandle phase(f64 alpha) const { + return phase_impl(alpha / max); } }; -// Inplace ops (we can't do them as member functions because we use shared pointers) - -// ValuePtr += ValuePtr -inline ValuePtr& operator+=(ValuePtr& a, const ValuePtr& b) { - a = a + b; - return a; -} - -// ValuePtr -= ValuePtr -inline ValuePtr& operator-=(ValuePtr& a, const ValuePtr& b) { - a = a - b; - return a; -} - -// ValuePtr *= ValuePtr -inline ValuePtr& operator*=(ValuePtr& a, const ValuePtr& b) { - a = a * b; - return a; -} - -// ValuePtr /= ValuePtr -inline ValuePtr& operator/=(ValuePtr& a, const ValuePtr& b) { - a = a / b; - return a; -} - - -// ValuePtr += scalar -inline ValuePtr& operator+=(ValuePtr& a, f64 b) { - a = a + b; - return a; -} - -// ValuePtr -= scalar -inline ValuePtr& operator-=(ValuePtr& a, f64 b) { - a = a - b; - return a; -} - -// ValuePtr *= scalar -inline ValuePtr& operator*=(ValuePtr& a, f64 b) { - a = a * b; - return a; -} - -// ValuePtr /= scalar -inline ValuePtr& operator/=(ValuePtr& a, f64 b) { - a = a / b; - return a; -} - -// PairPtr += PairPtr -inline PairPtr& operator+=(PairPtr& a, const PairPtr& b) { - a = a + b; - return a; -} - -// PairPtr -= PairPtr -inline PairPtr& operator-=(PairPtr& a, const PairPtr& b) { - a = a - b; - return a; -} - -// PairPtr *= scalar -inline PairPtr& operator*=(PairPtr& a, f64 scalar) { - a = a * scalar; - return a; -} - -// PairPtr *= ValuePtr -inline PairPtr& operator*=(PairPtr& a, const ValuePtr& v) { - a = a * v; - return a; -} - -// PairPtr /= scalar -inline PairPtr& operator/=(PairPtr& a, f64 scalar) { - a = a / scalar; - return a; -} - -// PairPtr /= ValuePtr -inline PairPtr& operator/=(PairPtr& a, const ValuePtr& v) { - a = a / v; - return a; -} +// --- Operator Declarations --- +ValueHandle operator-(ValueHandle a); +ValueHandle operator+(ValueHandle a, ValueHandle b); +ValueHandle operator-(ValueHandle a, ValueHandle b); +ValueHandle operator*(ValueHandle a, ValueHandle b); +ValueHandle operator/(ValueHandle a, ValueHandle b); +ValueHandle operator+(ValueHandle a, f64 b); +ValueHandle operator-(ValueHandle a, f64 b); +ValueHandle operator*(ValueHandle a, f64 b); +ValueHandle operator/(ValueHandle a, f64 b); +ValueHandle operator+(f64 a, ValueHandle b); +ValueHandle operator-(f64 a, ValueHandle b); +ValueHandle operator*(f64 a, ValueHandle b); +ValueHandle operator/(f64 a, ValueHandle b); +bool operator<(ValueHandle a, ValueHandle b); +bool operator>(ValueHandle a, ValueHandle b); + +PairHandle operator+(PairHandle a, PairHandle b); +PairHandle operator-(PairHandle a, PairHandle b); +PairHandle operator-(PairHandle a); +PairHandle operator*(PairHandle a, f64 scalar); +PairHandle operator*(f64 scalar, PairHandle a); +PairHandle operator/(PairHandle a, f64 scalar); +PairHandle operator/(f64 scalar, PairHandle a); +PairHandle operator*(PairHandle a, ValueHandle v); +PairHandle operator*(ValueHandle v, PairHandle a); +PairHandle operator/(PairHandle a, ValueHandle v); +PairHandle operator/(ValueHandle v, PairHandle a); +std::ostream& operator<<(std::ostream& os, const PairHandle& p); + +// Inplace +ValueHandle& operator+=(ValueHandle& a, ValueHandle b); +ValueHandle& operator-=(ValueHandle& a, ValueHandle b); +ValueHandle& operator*=(ValueHandle& a, ValueHandle b); +ValueHandle& operator/=(ValueHandle& a, ValueHandle b); +ValueHandle& operator+=(ValueHandle& a, f64 b); +ValueHandle& operator-=(ValueHandle& a, f64 b); +ValueHandle& operator*=(ValueHandle& a, f64 b); +ValueHandle& operator/=(ValueHandle& a, f64 b); + +PairHandle& operator+=(PairHandle& a, PairHandle b); +PairHandle& operator-=(PairHandle& a, PairHandle b); +PairHandle& operator*=(PairHandle& a, f64 scalar); +PairHandle& operator*=(PairHandle& a, ValueHandle v); +PairHandle& operator/=(PairHandle& a, f64 scalar); +PairHandle& operator/=(PairHandle& a, ValueHandle v); } // namespace Clockwork::Autograd From fea63ea7b3d834c8c57ce0f335389536f5c931a1 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:21:15 +0100 Subject: [PATCH 02/31] Bench: 12044152 --- src/eval_constants.hpp | 176 +++++++++++++++++++------------------- src/eval_types.hpp | 16 ++-- src/evaltune_main.cpp | 9 +- src/tuning/arena.hpp | 15 ++-- src/tuning/globals.hpp | 28 ------ src/tuning/info.hpp | 2 +- src/tuning/operations.hpp | 16 +--- 7 files changed, 115 insertions(+), 147 deletions(-) diff --git a/src/eval_constants.hpp b/src/eval_constants.hpp index 03a660e7..482180df 100644 --- a/src/eval_constants.hpp +++ b/src/eval_constants.hpp @@ -5,154 +5,154 @@ namespace Clockwork { // clang-format off -inline const PParam PAWN_MAT = S(293, 314); -inline const PParam KNIGHT_MAT = S(1136, 911); -inline const PParam BISHOP_MAT = S(1229, 949); -inline const PParam ROOK_MAT = S(1712, 1681); -inline const PParam QUEEN_MAT = S(3665, 2887); -inline const PParam TEMPO_VAL = S(59, 15); +inline const PParam PAWN_MAT = S(179, 165); +inline const PParam KNIGHT_MAT = S(728, 532); +inline const PParam BISHOP_MAT = S(786, 535); +inline const PParam ROOK_MAT = S(749, 704); +inline const PParam QUEEN_MAT = S(1631, 1200); +inline const PParam TEMPO_VAL = S(58, 15); -inline const PParam BISHOP_PAIR_VAL = S(80, 177); -inline const PParam ROOK_OPEN_VAL = S(104, -28); -inline const PParam ROOK_SEMIOPEN_VAL = S(39, 13); +inline const PParam BISHOP_PAIR_VAL = S(77, 174); +inline const PParam ROOK_OPEN_VAL = S(101, -24); +inline const PParam ROOK_SEMIOPEN_VAL = S(38, 14); -inline const PParam DOUBLED_PAWN_VAL = S(-37, -78); +inline const PParam DOUBLED_PAWN_VAL = S(-36, -78); -inline const PParam POTENTIAL_CHECKER_VAL = S(-74, -2); +inline const PParam POTENTIAL_CHECKER_VAL = S(-75, 0); inline const PParam OUTPOST_KNIGHT_VAL = S(7, 51); -inline const PParam OUTPOST_BISHOP_VAL = S(43, 44); +inline const PParam OUTPOST_BISHOP_VAL = S(42, 43); inline const PParam PAWN_PUSH_THREAT_KNIGHT = S(54, 18); -inline const PParam PAWN_PUSH_THREAT_BISHOP = S(56, -14); -inline const PParam PAWN_PUSH_THREAT_ROOK = S(34, 33); -inline const PParam PAWN_PUSH_THREAT_QUEEN = S(55, -52); +inline const PParam PAWN_PUSH_THREAT_BISHOP = S(55, -14); +inline const PParam PAWN_PUSH_THREAT_ROOK = S(33, 33); +inline const PParam PAWN_PUSH_THREAT_QUEEN = S(52, -45); inline const std::array PAWN_PHALANX = { - S(20, 19), S(63, 31), S(74, 70), S(191, 140), S(561, 241), S(931, 1149), + S(20, 19), S(62, 31), S(72, 69), S(185, 139), S(520, 249), S(641, 696), }; inline const std::array DEFENDED_PAWN = { - S(64, 43), S(61, 31), S(67, 57), S(147, 120), S(689, -90), + S(63, 42), S(59, 31), S(65, 57), S(143, 117), S(607, -50), }; inline const std::array PASSED_PAWN = { - S(-70, -103), S(-60, -85), S(-34, -9), S(19, 70), S(42, 211), S(306, 312), + S(-68, -93), S(-57, -75), S(-31, -2), S(22, 75), S(50, 207), S(281, 301), }; inline const std::array DEFENDED_PASSED_PUSH = { - S(50, -44), S(36, -5), S(21, 27), S(24, 75), S(93, 151), S(144, 295), + S(49, -42), S(36, -4), S(20, 29), S(21, 78), S(83, 157), S(163, 282), }; inline const std::array BLOCKED_PASSED_PAWN = { - S(15, -45), S(3, 2), S(0, -27), S(7, -47), S(3, -96), S(-192, -145), + S(15, -45), S(4, 2), S(0, -26), S(6, -45), S(0, -90), S(-190, -140), }; inline const std::array FRIENDLY_KING_PASSED_PAWN_DISTANCE = { - CS(0, 0), S(12, 103), S(-21, 90), S(-13, 39), S(0, 10), S(9, 15), S(39, 12), S(18, 0), + S(0, 0), S(12, 98), S(-20, 86), S(-13, 36), S(0, 7), S(9, 12), S(39, 9), S(18, -3), }; inline const std::array ENEMY_KING_PASSED_PAWN_DISTANCE = { - CS(0, 0), S(-183, -50), S(33, -3), S(-9, 44), S(13, 75), S(18, 100), S(38, 99), S(-10, 119), + S(0, 0), S(-183, -53), S(27, -6), S(-12, 40), S(10, 69), S(15, 94), S(35, 93), S(-12, 113), }; inline const std::array KNIGHT_MOBILITY = { - S(-226, -223), S(-118, -60), S(-61, -8), S(-18, 25), S(28, 39), S(54, 77), S(91, 73), S(125, 76), S(171, 18), + S(16, -9), S(119, 152), S(174, 205), S(216, 239), S(262, 254), S(287, 292), S(323, 288), S(356, 291), S(401, 233), }; inline const std::array BISHOP_MOBILITY = { - S(-238, -283), S(-163, -98), S(-88, -36), S(-53, 13), S(-22, 43), S(-4, 64), S(14, 78), S(32, 83), S(52, 87), S(65, 83), S(91, 69), S(155, 23), S(187, 2), S(247, -30), + S(26, -65), S(98, 116), S(171, 177), S(205, 227), S(235, 258), S(252, 278), S(271, 292), S(289, 298), S(308, 301), S(322, 297), S(346, 284), S(410, 238), S(439, 218), S(503, 181), }; inline const std::array ROOK_MOBILITY = { - S(-305, -228), S(-151, -81), S(-99, -17), S(-67, -7), S(-40, 16), S(-25, 38), S(-7, 50), S(11, 56), S(28, 68), S(46, 77), S(64, 79), S(77, 81), S(98, 84), S(109, 70), S(255, -56), + S(157, 217), S(287, 410), S(338, 474), S(370, 484), S(396, 508), S(410, 530), S(428, 542), S(446, 548), S(462, 560), S(480, 569), S(498, 571), S(510, 573), S(531, 575), S(543, 561), S(685, 436), }; inline const std::array QUEEN_MOBILITY = { - S(-970, -879), S(-254, -656), S(-158, -523), S(-103, -314), S(-95, -106), S(-56, -1), S(-52, 108), S(-27, 122), S(-22, 177), S(-10, 203), S(-1, 227), S(3, 242), S(22, 234), S(33, 246), S(40, 240), S(53, 236), S(60, 228), S(59, 235), S(85, 190), S(107, 154), S(122, 135), S(165, 69), S(178, 61), S(337, -112), S(372, -155), S(615, -311), S(470, -245), S(793, -417), + S(0, 2), S(721, 361), S(817, 490), S(876, 653), S(888, 837), S(928, 929), S(933, 1032), S(958, 1040), S(965, 1090), S(977, 1113), S(987, 1134), S(993, 1145), S(1013, 1133), S(1025, 1140), S(1034, 1129), S(1050, 1119), S(1060, 1103), S(1065, 1100), S(1097, 1045), S(1127, 996), S(1155, 960), S(1223, 864), S(1222, 865), S(1326, 729), S(1282, 736), S(1243, 733), S(1006, 797), S(905, 780), }; inline const std::array KING_MOBILITY = { - S(447, 2), S(102, -119), S(-2, -28), S(-16, 9), S(-44, 12), S(-78, 18), S(-57, 19), S(-66, 13), S(-67, -35), + S(335, 145), S(140, -113), S(39, -27), S(25, 9), S(-1, 11), S(-34, 17), S(-14, 18), S(-23, 12), S(-24, -35), }; inline const std::array KNIGHT_KING_RING = { - CS(0, 0), S(88, -31), S(158, -78), + S(0, 0), S(85, -28), S(154, -73), }; inline const std::array BISHOP_KING_RING = { - CS(0, 0), S(37, -6), S(137, -43), + S(0, 0), S(35, -4), S(134, -40), }; inline const std::array ROOK_KING_RING = { - CS(0, 0), S(68, -48), S(53, -64), S(104, -63), S(161, -129), + S(0, 0), S(66, -45), S(48, -58), S(96, -55), S(143, -115), }; inline const std::array QUEEN_KING_RING = { - CS(0, 0), S(-37, 27), S(-55, 38), S(0, -8), S(162, -97), S(357, -238), + S(0, 0), S(-52, 64), S(-84, 99), S(-48, 78), S(81, 24), S(211, -61), }; -inline const PParam PAWN_THREAT_KNIGHT = S(240, 57); -inline const PParam PAWN_THREAT_BISHOP = S(205, 99); -inline const PParam PAWN_THREAT_ROOK = S(199, 56); -inline const PParam PAWN_THREAT_QUEEN = S(179, -62); +inline const PParam PAWN_THREAT_KNIGHT = S(234, 57); +inline const PParam PAWN_THREAT_BISHOP = S(200, 99); +inline const PParam PAWN_THREAT_ROOK = S(194, 57); +inline const PParam PAWN_THREAT_QUEEN = S(175, -63); -inline const PParam KNIGHT_THREAT_BISHOP = S(105, 72); -inline const PParam KNIGHT_THREAT_ROOK = S(244, 5); -inline const PParam KNIGHT_THREAT_QUEEN = S(156, -68); +inline const PParam KNIGHT_THREAT_BISHOP = S(102, 71); +inline const PParam KNIGHT_THREAT_ROOK = S(238, 6); +inline const PParam KNIGHT_THREAT_QUEEN = S(156, -77); -inline const PParam BISHOP_THREAT_KNIGHT = S(110, 34); -inline const PParam BISHOP_THREAT_ROOK = S(244, 55); -inline const PParam BISHOP_THREAT_QUEEN = S(192, 48); +inline const PParam BISHOP_THREAT_KNIGHT = S(108, 34); +inline const PParam BISHOP_THREAT_ROOK = S(237, 55); +inline const PParam BISHOP_THREAT_QUEEN = S(192, 35); inline const std::array BISHOP_PAWNS = { - S(1, -6), S(-1, 0), S(0, -11), S(-5, -21), S(-11, -26), S(-17, -33), S(-18, -40), S(-24, -37), S(-34, -43), + S(1, -6), S(-1, 0), S(0, -10), S(-5, -21), S(-11, -26), S(-16, -32), S(-17, -39), S(-23, -37), S(-33, -42), }; inline const std::array PAWN_PSQT = { - S(111, 162), S(101, 209), S(169, 174), S(231, 56), S(177, 50), S(162, 115), S(59, 139), S(118, 115), // - S(79, 44), S(192, 72), S(168, 15), S(169, -42), S(123, -59), S(67, -10), S(28, 36), S(-21, 40), // - S(-1, 12), S(18, 15), S(34, -28), S(21, -42), S(4, -46), S(-38, -40), S(-78, 7), S(-104, 28), // - S(-27, -36), S(-9, -8), S(-14, -41), S(-32, -38), S(-58, -47), S(-78, -39), S(-129, 10), S(-147, -1), // - S(-29, -65), S(34, -63), S(-16, -19), S(-47, -17), S(-67, -26), S(-107, -27), S(-125, -14), S(-149, -20), // - S(-19, -59), S(111, -54), S(66, -18), S(8, 0), S(-29, -12), S(-65, -17), S(-90, 6), S(-128, -6), // + S(234, 316), S(229, 361), S(287, 331), S(350, 215), S(300, 208), S(283, 273), S(182, 295), S(239, 272), // + S(179, 193), S(288, 221), S(266, 165), S(265, 110), S(220, 93), S(167, 140), S(129, 186), S(80, 188), // + S(101, 156), S(120, 160), S(137, 116), S(125, 103), S(108, 98), S(67, 105), S(27, 152), S(1, 173), // + S(76, 109), S(93, 136), S(88, 103), S(70, 107), S(46, 98), S(26, 106), S(-23, 155), S(-41, 143), // + S(74, 79), S(135, 82), S(85, 125), S(55, 127), S(35, 118), S(-2, 117), S(-20, 131), S(-43, 125), // + S(84, 87), S(210, 92), S(165, 127), S(109, 145), S(72, 132), S(38, 128), S(13, 152), S(-22, 139), // }; inline const std::array KNIGHT_PSQT = { - S(-400, -159), S(-350, 59), S(-461, 234), S(-126, 67), S(-256, 92), S(-340, 98), S(-573, 85), S(-545, -17), // - S(0, -2), S(65, 9), S(167, -57), S(112, 8), S(115, 14), S(52, -9), S(-7, 10), S(-26, -35), // - S(57, -29), S(99, 16), S(191, 8), S(142, 30), S(141, 20), S(61, 28), S(49, 3), S(-45, 11), // - S(111, 5), S(101, 27), S(132, 33), S(110, 60), S(119, 46), S(86, 40), S(60, -1), S(35, 4), // - S(100, -14), S(126, -17), S(120, 8), S(92, 25), S(84, 33), S(80, 27), S(53, 0), S(39, -55), // - S(12, -23), S(43, -35), S(36, -14), S(48, 29), S(55, 26), S(-2, 3), S(3, -35), S(-36, -41), // - S(14, -8), S(35, -36), S(19, -30), S(19, -11), S(6, -17), S(-21, -38), S(-10, -52), S(-67, -121), // - S(-34, -60), S(3, -16), S(21, -39), S(30, -33), S(22, -25), S(-24, -55), S(-37, -31), S(-86, -84), // + S(-255, -4), S(-195, 195), S(-271, 345), S(4, 213), S(-115, 234), S(-191, 238), S(-400, 215), S(-389, 125), // + S(123, 148), S(188, 159), S(287, 93), S(237, 154), S(238, 162), S(178, 139), S(118, 159), S(99, 116), // + S(180, 121), S(225, 160), S(311, 156), S(264, 179), S(263, 169), S(184, 178), S(173, 152), S(80, 160), // + S(234, 154), S(225, 174), S(254, 181), S(234, 207), S(241, 195), S(209, 189), S(184, 147), S(160, 154), // + S(223, 135), S(247, 132), S(242, 156), S(214, 174), S(206, 183), S(203, 177), S(177, 150), S(163, 95), // + S(137, 126), S(167, 113), S(161, 133), S(172, 177), S(179, 174), S(123, 152), S(129, 113), S(89, 108), // + S(139, 141), S(160, 112), S(144, 118), S(144, 137), S(131, 131), S(104, 111), S(115, 98), S(59, 32), // + S(91, 94), S(129, 134), S(146, 109), S(154, 116), S(147, 124), S(101, 94), S(88, 120), S(40, 71), // }; inline const std::array BISHOP_PSQT = { - S(-167, 80), S(-187, 61), S(-429, 88), S(-309, 101), S(-260, 104), S(-428, 130), S(-169, 107), S(-122, 82), // - S(5, -32), S(-10, 46), S(8, 27), S(-11, 31), S(-35, 46), S(0, 38), S(-20, 28), S(-57, 29), // - S(35, 24), S(82, 14), S(161, 23), S(89, 22), S(64, 25), S(40, 37), S(97, 9), S(-4, 26), // - S(52, -20), S(65, 11), S(101, 15), S(100, 40), S(106, 40), S(44, 38), S(32, 14), S(-15, 20), // - S(53, -49), S(61, -10), S(68, 6), S(68, 30), S(60, 47), S(20, 38), S(4, 0), S(0, -43), // - S(67, -39), S(113, -20), S(115, -8), S(57, 36), S(37, 40), S(38, 36), S(69, -6), S(31, -36), // - S(53, -75), S(103, -51), S(73, -38), S(45, -8), S(36, -24), S(37, -35), S(19, -20), S(38, -85), // - S(48, -56), S(35, -9), S(39, -3), S(50, -37), S(59, -48), S(55, -7), S(47, -34), S(30, -37), // + S(-24, 259), S(-45, 241), S(-275, 265), S(-159, 279), S(-112, 282), S(-272, 306), S(-28, 286), S(22, 259), // + S(142, 151), S(126, 228), S(146, 208), S(126, 212), S(102, 227), S(137, 219), S(117, 210), S(81, 210), // + S(169, 207), S(215, 197), S(292, 206), S(222, 205), S(198, 208), S(175, 220), S(231, 192), S(131, 209), // + S(186, 162), S(198, 195), S(234, 198), S(232, 223), S(238, 223), S(178, 221), S(166, 197), S(120, 203), // + S(187, 135), S(194, 173), S(202, 190), S(201, 214), S(193, 231), S(154, 221), S(139, 184), S(134, 141), // + S(201, 144), S(246, 163), S(247, 174), S(190, 218), S(171, 223), S(171, 219), S(202, 177), S(165, 148), // + S(187, 109), S(236, 132), S(206, 145), S(179, 174), S(170, 159), S(171, 148), S(153, 163), S(172, 100), // + S(182, 129), S(169, 173), S(174, 179), S(184, 146), S(193, 135), S(189, 176), S(181, 150), S(164, 148), // }; inline const std::array ROOK_PSQT = { - S(105, 11), S(171, 10), S(100, 40), S(100, 34), S(107, 23), S(56, 36), S(63, 39), S(71, 44), // - S(14, 69), S(100, 44), S(175, 21), S(100, 64), S(116, 52), S(63, 61), S(4, 80), S(-4, 86), // - S(2, 46), S(151, 4), S(181, -1), S(181, -5), S(136, 4), S(61, 46), S(78, 33), S(-40, 83), // - S(-28, 40), S(48, 33), S(79, 24), S(102, -11), S(71, 11), S(9, 60), S(-8, 59), S(-79, 67), // - S(-92, -8), S(-10, -3), S(-25, 13), S(-42, 14), S(-47, 11), S(-65, 51), S(-95, 48), S(-114, 37), // - S(-115, -28), S(-40, -55), S(-47, -25), S(-66, -24), S(-49, -42), S(-99, 10), S(-101, -6), S(-124, -9), // - S(-177, -18), S(-77, -78), S(-53, -63), S(-49, -61), S(-56, -56), S(-75, -40), S(-95, -62), S(-127, -46), // - S(-143, -16), S(-112, -11), S(-58, -47), S(-33, -64), S(-47, -49), S(-60, -37), S(-75, -45), S(-93, -29), // + S(556, 471), S(619, 470), S(550, 499), S(552, 493), S(561, 481), S(510, 494), S(518, 497), S(526, 502), // + S(473, 524), S(557, 500), S(629, 479), S(559, 519), S(574, 507), S(523, 516), S(464, 534), S(455, 541), // + S(461, 502), S(606, 463), S(636, 456), S(636, 452), S(592, 461), S(520, 502), S(535, 490), S(420, 539), // + S(431, 497), S(505, 490), S(537, 481), S(560, 445), S(528, 468), S(468, 516), S(450, 516), S(381, 523), // + S(368, 448), S(447, 453), S(432, 470), S(416, 470), S(411, 467), S(393, 507), S(365, 504), S(346, 493), // + S(345, 427), S(418, 401), S(412, 430), S(392, 432), S(409, 413), S(360, 467), S(359, 449), S(337, 447), // + S(285, 438), S(383, 378), S(406, 392), S(410, 395), S(403, 400), S(385, 416), S(365, 394), S(335, 410), // + S(318, 441), S(349, 445), S(401, 410), S(425, 394), S(412, 408), S(399, 420), S(384, 412), S(368, 429), // }; inline const std::array QUEEN_PSQT = { - S(37, 41), S(76, 7), S(82, 13), S(-38, 142), S(33, 61), S(-23, 88), S(42, 2), S(-22, 28), // - S(24, 80), S(-48, 174), S(-53, 230), S(-140, 262), S(-112, 207), S(-123, 206), S(-74, 111), S(-38, 51), // - S(-17, 111), S(72, 108), S(13, 184), S(-4, 195), S(-45, 177), S(-77, 176), S(-7, 71), S(-41, 43), // - S(51, 24), S(55, 92), S(18, 128), S(6, 194), S(-15, 178), S(-24, 112), S(13, 19), S(0, -13), // - S(15, 52), S(50, 14), S(24, 83), S(-15, 140), S(-27, 127), S(-21, 85), S(-10, 6), S(-7, -42), // - S(25, -106), S(50, -62), S(49, 7), S(-2, 33), S(9, -9), S(12, -4), S(23, -75), S(0, -68), // - S(10, -212), S(44, -312), S(31, -175), S(48, -102), S(21, -83), S(38, -154), S(13, -92), S(-4, -90), // - S(-45, -132), S(15, -382), S(12, -369), S(38, -275), S(41, -193), S(45, -238), S(31, -198), S(-21, -122), // + S(831, 784), S(881, 742), S(890, 752), S(787, 868), S(830, 821), S(775, 844), S(820, 779), S(748, 811), // + S(774, 890), S(708, 977), S(719, 1008), S(647, 1024), S(658, 996), S(633, 1014), S(670, 938), S(708, 872), // + S(736, 918), S(820, 921), S(770, 986), S(750, 1001), S(706, 991), S(666, 1007), S(733, 907), S(701, 876), // + S(792, 849), S(793, 922), S(759, 956), S(743, 1031), S(718, 1023), S(710, 956), S(743, 874), S(735, 831), // + S(751, 888), S(782, 856), S(755, 929), S(709, 1002), S(696, 993), S(704, 946), S(718, 863), S(721, 814), // + S(757, 742), S(780, 786), S(777, 862), S(722, 897), S(733, 858), S(735, 865), S(749, 789), S(729, 787), // + S(739, 647), S(769, 554), S(757, 688), S(771, 767), S(745, 786), S(762, 716), S(737, 777), S(724, 767), // + S(685, 730), S(743, 489), S(738, 497), S(763, 597), S(765, 682), S(768, 638), S(757, 672), S(707, 742), // }; inline const std::array KING_PSQT = { - S(-137, -378), S(99, -22), S(-60, 30), S(-168, 54), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // - S(193, -114), S(8, 142), S(26, 128), S(131, 58), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // - S(-33, 56), S(72, 134), S(111, 103), S(95, 58), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // - S(-252, 80), S(36, 95), S(27, 92), S(-25, 77), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // - S(-226, 37), S(-42, 68), S(-25, 73), S(-108, 109), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // - S(-131, 7), S(60, 10), S(-29, 61), S(-73, 87), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // - S(84, -83), S(137, -43), S(49, 3), S(-32, 46), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // - S(-39, -112), S(90, -102), S(-9, -66), S(-26, -65), S(20, -11), S(20, -11), S(20, -11), S(20, -11), // + S(-233, -319), S(30, 3), S(-94, 44), S(-169, 61), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(124, -85), S(-15, 152), S(3, 138), S(118, 67), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-62, 69), S(50, 145), S(88, 114), S(74, 69), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-275, 92), S(12, 106), S(4, 105), S(-47, 89), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-246, 48), S(-66, 80), S(-47, 85), S(-128, 120), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-151, 17), S(35, 23), S(-53, 74), S(-94, 99), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(59, -70), S(110, -28), S(24, 17), S(-54, 59), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-60, -101), S(66, -87), S(-31, -53), S(-48, -52), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // }; // Epoch duration: 61.8411s // clang-format on diff --git a/src/eval_types.hpp b/src/eval_types.hpp index 5fc1d202..ff867406 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -109,8 +109,8 @@ using PParam = PScore; // ============================================================================ using Score = Autograd::ValueHandle; -using PScore = Autograd::PairHandle; // (mg, eg) handle -using PParam = Autograd::PairHandle; // tunable pair +using PScore = Autograd::PairHandle; +using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter #endif @@ -123,16 +123,20 @@ using PParam = Autograd::PairHandle; // tunable pair // Tunable scalar pair (mg, eg) #define S(a, b) Autograd::PairPlaceholder::create_tunable((a), (b)) - // Constant (fixed) scalar pair (mg, eg) + // Constant scalar pair (mg, eg) #define CS(a, b) Autograd::PairPlaceholder::create((a), (b)) - // Zero pair - #define PSCORE_ZERO Autograd::PairPlaceholder::create(0, 0) + // Zero pair FOR PARAMETERS (e.g., in an array) + #define PPARAM_ZERO Autograd::PairPlaceholder::create(0, 0) + + // Zero pair FOR INTERMEDIATES (e.g., scores) + #define PSCORE_ZERO Autograd::PairHandle::create(0, 0) #else - // Non-tuning build: use fixed, non-autograd PScore + // ... (non-tuning definitions) ... #define S(a, b) PScore((a), (b)) #define CS(a, b) PScore((a), (b)) + #define PPARAM_ZERO PScore(0, 0) #define PSCORE_ZERO PScore(0, 0) #endif diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index b382de92..dee2a536 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -39,7 +39,7 @@ int main() { "data/dfrcv1.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3.txt", }; - const u32 thread_count = std::max(1, std::thread::hardware_concurrency() / 2); + const u32 thread_count = std::max(1, std::thread::hardware_concurrency()); std::cout << "Running on " << thread_count << " threads\n"; @@ -94,8 +94,13 @@ int main() { const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts(); + // This line loads the defaults from your S() macros Parameters current_parameter_values = Graph::get().get_all_parameter_values(); + // Uncomment for zero tune: Overwrite them all with zeros. + current_parameter_values = Parameters::zeros(parameter_count); + + // The optimizer will now start with all-zero parameters AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0); const i32 epochs = 1000; @@ -206,7 +211,7 @@ int main() { // Dump current parameter values Graph::get().copy_parameter_values(current_parameter_values); - + Graph::get().cleanup(); Graph::get().zero_grad(); diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index f9780670..83eea8df 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -6,8 +6,8 @@ namespace Clockwork::Autograd { -// A simple contiguous storage for types T. -// Returns indices (handles) instead of pointers. +/// ARENA IMPLEMENTATION \\\ +// Simple vector-based arena for storing values and pairs. Surely can be done better. Kek. template class Arena { private: @@ -21,7 +21,8 @@ class Arena { return idx; } - // Emplace version + // Emplace version we might want later for ops that return many values? + // Might be seeing things. template u32 emplace(Args&&... args) { u32 idx = static_cast(m_data.size()); @@ -44,23 +45,17 @@ class Arena { return m_data.size(); } - // Resets the arena size, effectively clearing it. - // Note: Does not free memory (capacity remains) to reduce allocations next cycle. + // Common std::vector W void clear() { m_data.clear(); } - // Keeps the first `n` elements, effectively clearing the rest. - // Useful for keeping parameters which are at the start of the arena. void reset_to(usize n) { if (n < m_data.size()) { m_data.resize(n); } } - std::vector& raw() { - return m_data; - } }; } // namespace Clockwork::Autograd diff --git a/src/tuning/globals.hpp b/src/tuning/globals.hpp index cec841bb..fa23f9e3 100644 --- a/src/tuning/globals.hpp +++ b/src/tuning/globals.hpp @@ -154,33 +154,5 @@ inline bool Globals::is_pair_parameter_constant(usize i) const { return m_pair_parameters[i]->constant(); } -// --- Helper Operators for Placeholders --- -// These allow Placeholders to be used directly in arithmetic expressions by implicit conversion to Handles. - -inline ValueHandle operator-(ValuePlaceholder a) { - return -static_cast(a); -} -inline ValueHandle operator+(ValuePlaceholder a, ValuePlaceholder b) { - return static_cast(a) + static_cast(b); -} -inline ValueHandle operator-(ValuePlaceholder a, ValuePlaceholder b) { - return static_cast(a) - static_cast(b); -} -inline ValueHandle operator*(ValuePlaceholder a, i32 b) { - return static_cast(a) * static_cast(b); -} -inline ValueHandle operator/(ValuePlaceholder a, i32 b) { - return static_cast(a) / static_cast(b); -} - -inline PairHandle operator-(PairPlaceholder a) { - return -static_cast(a); -} -inline PairHandle operator+(PairPlaceholder a, PairPlaceholder b) { - return static_cast(a) + static_cast(b); -} -inline PairHandle operator-(PairPlaceholder a, PairPlaceholder b) { - return static_cast(a) - static_cast(b); -} } // namespace Clockwork::Autograd diff --git a/src/tuning/info.hpp b/src/tuning/info.hpp index 717ab935..ce6d01e4 100644 --- a/src/tuning/info.hpp +++ b/src/tuning/info.hpp @@ -47,4 +47,4 @@ struct Parameters { } }; -} // namespace Clockwork::Autograd \ No newline at end of file +} // namespace Clockwork::Autograd diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index 24900dfb..f96e7a95 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -5,19 +5,19 @@ namespace Clockwork::Autograd { enum class OpType : u8 { - // --- Leaf / Special --- + // Leaf nodes None, Parameter, // Created from a global parameter Input, // Created manually (e.g. from data) - // --- Value Binary Ops --- + // Binary Ops Add, Sub, Mul, Div, Pow, - // --- Value Unary Ops --- + // Unary Ops Exp, Log, Sigmoid, @@ -54,8 +54,7 @@ enum class OpType : u8 { Sum // Sum of a vector of values }; -// A single node in the compute tape. -// Designed to be compact and fit in cache lines. +// A single node in the compute tape. Probably can be rewritten more compactly. struct Node { OpType type; @@ -65,13 +64,6 @@ struct Node { // Auxiliary data for scalar ops, constants, or specific parameters f64 scalar_data; - - // Helper to handle the "Reduction" case (Sum) where we have >2 inputs. - // In a pure tape, we might handle Sum by chaining Adds, or storing an index - // to a side-table of indices. For simplicity/speed in this specific codebase, - // we can implement Sum as a sequence of Adds or use a special range. - // For now, we will implement Sum as a chain of binary adds in the graph - // builder to keep the Node struct fixed-size and simple. }; } // namespace Clockwork::Autograd From ded6eb1cbd16e05a5c85702211785fe165a02fdc Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:39:03 +0100 Subject: [PATCH 03/31] Cleanup --- src/tuning/graph.cpp | 34 +++++++++++----------------------- src/tuning/graph.hpp | 7 ++++--- src/tuning/operations.hpp | 2 +- src/tuning/value.cpp | 25 +++++++++++-------------- src/tuning/value.hpp | 3 ++- 5 files changed, 29 insertions(+), 42 deletions(-) diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 3d87a1a6..3e8d5421 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -203,21 +203,18 @@ void Graph::backward() { return; } - // Seed gradient of last node + // Our backward model assumes the last operation produces a scalar loss value. const auto& last_node = m_tape.back(); - // Assuming the last node produces a Value (the loss) + + // Initialize gradient of final output to 1.0 to start backprop m_values[last_node.output_idx].gradient = 1.0; - // Reverse iterate + // Rev iterate for (auto it = m_tape.rbegin(); it != m_tape.rend(); ++it) { const Node& node = *it; - // Fetch gradients and values based on op type - // Note: References to vector elements are risky if alloc happened, but backward pass doesn't alloc. - // Using pointers or indices is safer. - switch (node.type) { - // --- Value Binary --- + // Value Binary case OpType::Add: { f64 grad = m_values[node.output_idx].gradient; m_values[node.lhs_idx].gradient += grad; @@ -255,9 +252,8 @@ void Graph::backward() { break; } - // --- Value Unary --- + // Value Unary case OpType::Exp: { - // d/dx e^x = e^x = y f64 grad = m_values[node.output_idx].gradient; f64 val = m_values[node.output_idx].value; m_values[node.lhs_idx].gradient += val * grad; @@ -312,7 +308,7 @@ void Graph::backward() { break; } - // --- Pair Binary --- + // Pair Binary case OpType::PairAdd: { f128 grad = m_pairs[node.output_idx].gradients; m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad); @@ -326,7 +322,7 @@ void Graph::backward() { break; } - // --- Pair Scalar --- + // Pair Scalar case OpType::PairNeg: { f128 grad = m_pairs[node.output_idx].gradients; m_pairs[node.lhs_idx].gradients = f128::sub(m_pairs[node.lhs_idx].gradients, grad); @@ -356,25 +352,21 @@ void Graph::backward() { break; } - // --- Pair Value --- + // Pair-Value case OpType::PairMulValue: case OpType::ValueMulPair: { - // out = p * v f128 grad_out = m_pairs[node.output_idx].gradients; f128 p = m_pairs[node.lhs_idx].values; f64 v = m_values[node.rhs_idx].value; - // d/dp = v f128 grad_p = f128::mul_scalar(grad_out, v); m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); - // d/dv = p.first * grad.first + p.second * grad.second f128 contrib = f128::mul(p, grad_out); m_values[node.rhs_idx].gradient += contrib.first() + contrib.second(); break; } case OpType::PairDivValue: { - // out = p / v f128 grad_out = m_pairs[node.output_idx].gradients; f128 p = m_pairs[node.lhs_idx].values; f64 v = m_values[node.rhs_idx].value; @@ -382,35 +374,31 @@ void Graph::backward() { f128 grad_p = f128::div_scalar(grad_out, v); m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); - // d/dv = -p/v^2 * grad f128 num = f128::mul(p, grad_out); f64 sum_contr = num.first() + num.second(); m_values[node.rhs_idx].gradient += -sum_contr / (v * v); break; } case OpType::ValueDivPair: { - // out = v / p f128 grad_out = m_pairs[node.output_idx].gradients; f128 p = m_pairs[node.lhs_idx].values; f64 v = m_values[node.rhs_idx].value; - // d/dp = -v/p^2 f128 p_sq = f128::mul(p, p); f128 neg_v_sq = f128::neg(f128::scalar_div(v, p_sq)); f128 grad_p = f128::mul(neg_v_sq, grad_out); m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); - // d/dv = 1/p f128 v_contr = f128::div(grad_out, p); m_values[node.rhs_idx].gradient += v_contr.first() + v_contr.second(); break; } - // --- Phase --- + // Special Case phase case OpType::Phase: { f64 grad = m_values[node.output_idx].gradient; f64 alpha = node.scalar_data; - // d/d_first = alpha, d/d_second = 1-alpha + f128 grad_upd = f128::make(alpha * grad, (1.0 - alpha) * grad); m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_upd); break; diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index f834e616..3c0e6afe 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -39,11 +39,12 @@ class Graph { public: static Graph& get(); - // ------------------ Creation ------------------ + // Creation ValueHandle create_value(f64 data, bool is_parameter = false); PairHandle create_pair(f128 data, bool is_parameter = false); - // ------------------ Operation Recording ------------------ + // Operation recording stuff + // Value-Value Binary ValueHandle record_op(OpType op, ValueHandle lhs, ValueHandle rhs); // Value Unary / Scalar @@ -55,7 +56,7 @@ class Graph { // Pair-Value PairHandle record_pair_value(OpType op, PairHandle pair, ValueHandle val); - // Special Phase + // Handling phasing separately due to its unique nature, probably can be done better ValueHandle record_phase(PairHandle input, f64 alpha); // ------------------ Backend Logic ------------------ diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index f96e7a95..7543b8e8 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -62,7 +62,7 @@ struct Node { u32 lhs_idx; // Index of first operand u32 rhs_idx; // Index of second operand (if applicable) - // Auxiliary data for scalar ops, constants, or specific parameters + // Auxiliary data for scalar ops, constants, or specific parameters. f64 scalar_data; }; diff --git a/src/tuning/value.cpp b/src/tuning/value.cpp index 8fd8fd4a..05741eda 100644 --- a/src/tuning/value.cpp +++ b/src/tuning/value.cpp @@ -16,7 +16,7 @@ ValueHandle ValueHandle::sum(const std::vector& inputs) { if (inputs.empty()) { return ValueHandle::create(0.0); } - // Simple linear accumulation on the tape + // Simple linear accumulation on the tape. I dropped the old optimized version for the arena rewrite, but its the top priority for future optimization. ValueHandle total = inputs[0]; for (size_t i = 1; i < inputs.size(); ++i) { total = total + inputs[i]; @@ -66,7 +66,7 @@ void ValueHandle::set_value(f64 v) const { } } -// ------------------ PairHandle ------------------ +// PairHandle implementations PairHandle PairHandle::create(f64 first, f64 second) { return Graph::get().create_pair(f128::make(first, second)); @@ -100,14 +100,13 @@ void PairHandle::zero_grad() const { Graph::get().get_pair_data(*this).gradients = f128::zero(); } -// Implementation of the helper called by the template in the header + +// Special phasing case ValueHandle PairHandle::phase_impl(f64 scaled_alpha) const { return Graph::get().record_phase(*this, scaled_alpha); } -// ------------------ Operator Implementations ------------------ - -// --- Value Unary/Binary --- +// ValueHandle Operators ValueHandle operator-(ValueHandle a) { return Graph::get().record_op(OpType::Neg, a); } @@ -124,7 +123,6 @@ ValueHandle operator/(ValueHandle a, ValueHandle b) { return Graph::get().record_op(OpType::Div, a, b); } -// --- Value Scalar --- ValueHandle operator+(ValueHandle a, f64 b) { return Graph::get().record_op(OpType::AddScalar, a, b); } @@ -151,7 +149,6 @@ ValueHandle operator/(f64 a, ValueHandle b) { return Graph::get().record_op(OpType::DivScalarVal, b, a); } -// --- Comparison --- bool operator<(ValueHandle a, ValueHandle b) { return a.get_value() < b.get_value(); } @@ -159,7 +156,7 @@ bool operator>(ValueHandle a, ValueHandle b) { return a.get_value() > b.get_value(); } -// --- Pair Binary --- +// PairHandle Operators PairHandle operator+(PairHandle a, PairHandle b) { return Graph::get().record_pair_op(OpType::PairAdd, a, b); } @@ -170,7 +167,6 @@ PairHandle operator-(PairHandle a) { return Graph::get().record_pair_scalar(OpType::PairNeg, a, 0.0); } -// --- Pair Scalar --- PairHandle operator*(PairHandle a, f64 scalar) { return Graph::get().record_pair_scalar(OpType::PairMulScalar, a, scalar); } @@ -184,7 +180,6 @@ PairHandle operator/(f64 scalar, PairHandle a) { return Graph::get().record_pair_scalar(OpType::ScalarDivPair, a, scalar); } -// --- Pair Value --- PairHandle operator*(PairHandle a, ValueHandle v) { return Graph::get().record_pair_value(OpType::PairMulValue, a, v); } @@ -198,13 +193,13 @@ PairHandle operator/(ValueHandle v, PairHandle a) { return Graph::get().record_pair_value(OpType::ValueDivPair, a, v); } +// Printing overloads for debugging std::ostream& operator<<(std::ostream& os, const PairHandle& p) { - os << "S(" << static_cast(p.first() + 0.5) << ", " << static_cast(p.second() + 0.5) - << ")"; + os << "S(" << std::round(p.first()) << ", " << std::round(p.second()) << ")"; return os; } -// --- Inplace Operators (Syntactic Sugar) --- +// Value Inplaces ValueHandle& operator+=(ValueHandle& a, ValueHandle b) { a = a + b; @@ -240,6 +235,8 @@ ValueHandle& operator/=(ValueHandle& a, f64 b) { return a; } +// Pair Inplaces + PairHandle& operator+=(PairHandle& a, PairHandle b) { a = a + b; return a; diff --git a/src/tuning/value.hpp b/src/tuning/value.hpp index 64abab52..3ab3fc5d 100644 --- a/src/tuning/value.hpp +++ b/src/tuning/value.hpp @@ -103,7 +103,7 @@ PairHandle operator/(PairHandle a, ValueHandle v); PairHandle operator/(ValueHandle v, PairHandle a); std::ostream& operator<<(std::ostream& os, const PairHandle& p); -// Inplace +// Value Inplaces ValueHandle& operator+=(ValueHandle& a, ValueHandle b); ValueHandle& operator-=(ValueHandle& a, ValueHandle b); ValueHandle& operator*=(ValueHandle& a, ValueHandle b); @@ -113,6 +113,7 @@ ValueHandle& operator-=(ValueHandle& a, f64 b); ValueHandle& operator*=(ValueHandle& a, f64 b); ValueHandle& operator/=(ValueHandle& a, f64 b); +// Pair Inplaces PairHandle& operator+=(PairHandle& a, PairHandle b); PairHandle& operator-=(PairHandle& a, PairHandle b); PairHandle& operator*=(PairHandle& a, f64 scalar); From 16a3283a560cf09f08029b84ea9d6f0401808027 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:41:12 +0100 Subject: [PATCH 04/31] Cleanup2 --- src/tuning/graph.cpp | 8 ++------ src/tuning/graph.hpp | 2 -- src/tuning/operations.hpp | 2 +- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 3e8d5421..be08b12c 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -180,8 +180,6 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) break; } m_pairs[out].values = res; - // Note: rhs is stored in rhs_idx, but it refers to m_values arena! - // The OpType tells us which arena to look at. m_tape.push_back({op, out, pair.index, val.index, 0.0}); return PairHandle(out); } @@ -189,15 +187,13 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { u32 out = m_values.alloc({0.0, 0.0}); f128 p = m_pairs[input.index].values; - // Linear interpolation: alpha * first + (1-alpha) * second + f64 val = alpha * p.first() + (1.0 - alpha) * p.second(); m_values[out].value = val; m_tape.push_back({OpType::Phase, out, input.index, 0, alpha}); return ValueHandle(out); } -// ------------------ Backward ------------------ - void Graph::backward() { if (m_tape.empty()) { return; @@ -398,7 +394,7 @@ void Graph::backward() { case OpType::Phase: { f64 grad = m_values[node.output_idx].gradient; f64 alpha = node.scalar_data; - + f128 grad_upd = f128::make(alpha * grad, (1.0 - alpha) * grad); m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_upd); break; diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 3c0e6afe..2a6a299f 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -59,10 +59,8 @@ class Graph { // Handling phasing separately due to its unique nature, probably can be done better ValueHandle record_phase(PairHandle input, f64 alpha); - // ------------------ Backend Logic ------------------ void backward(); - // ------------------ Management ------------------ void cleanup(); void zero_grad(); void copy_parameter_values(const Parameters& source); diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index 7543b8e8..81d38203 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -56,7 +56,7 @@ enum class OpType : u8 { // A single node in the compute tape. Probably can be rewritten more compactly. struct Node { - OpType type; + OpType type; // This tells us which arenas to look at and how to interpret lhs/rhs u32 output_idx; // Index in the respective arena (Value or Pair) u32 lhs_idx; // Index of first operand From 8ce66c3c498c105cac8d92f4131d193659804a40 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:44:14 +0100 Subject: [PATCH 05/31] Cleanup3 --- src/tuning/operations.hpp | 14 +++++++------- src/tuning/value.cpp | 2 -- src/tuning/value.hpp | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index 81d38203..0ccac494 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -30,27 +30,27 @@ enum class OpType : u8 { DivScalarVal, // scalar / x ValDivScalar, // x / scalar - // --- Pair Ops --- - PairCreate, // (val, val) -> pair + // Pair Ops + PairCreate, PairAdd, PairSub, PairNeg, - // --- Pair/Scalar Ops --- + // Pair-Scalar Ops PairMulScalar, PairDivScalar, ScalarDivPair, - // --- Pair/Value Ops --- + // Pair-Value Ops PairMulValue, - ValueMulPair, // Commutative wrapper usually, but distinct op code helps + ValueMulPair, PairDivValue, ValueDivPair, - // --- Phasing --- + // Phasing Phase, // Pair -> Value via alpha - // --- Reduction --- + // Reduction (TODO: optimize this) Sum // Sum of a vector of values }; diff --git a/src/tuning/value.cpp b/src/tuning/value.cpp index 05741eda..6bb05b1e 100644 --- a/src/tuning/value.cpp +++ b/src/tuning/value.cpp @@ -200,7 +200,6 @@ std::ostream& operator<<(std::ostream& os, const PairHandle& p) { } // Value Inplaces - ValueHandle& operator+=(ValueHandle& a, ValueHandle b) { a = a + b; return a; @@ -236,7 +235,6 @@ ValueHandle& operator/=(ValueHandle& a, f64 b) { } // Pair Inplaces - PairHandle& operator+=(PairHandle& a, PairHandle b) { a = a + b; return a; diff --git a/src/tuning/value.hpp b/src/tuning/value.hpp index 3ab3fc5d..452287e1 100644 --- a/src/tuning/value.hpp +++ b/src/tuning/value.hpp @@ -73,7 +73,7 @@ struct PairHandle { } }; -// --- Operator Declarations --- +// Operation decls ValueHandle operator-(ValueHandle a); ValueHandle operator+(ValueHandle a, ValueHandle b); ValueHandle operator-(ValueHandle a, ValueHandle b); From e6a25f82255fbb45e49651909acf713ce600f831 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:44:54 +0100 Subject: [PATCH 06/31] Format Bench: 12044152 --- src/eval_types.hpp | 4 ++-- src/tuning/arena.hpp | 3 +-- src/tuning/graph.cpp | 2 +- src/tuning/operations.hpp | 4 ++-- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/eval_types.hpp b/src/eval_types.hpp index ff867406..e2da0b85 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -110,7 +110,7 @@ using PParam = PScore; using Score = Autograd::ValueHandle; using PScore = Autograd::PairHandle; -using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter +using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter #endif @@ -140,4 +140,4 @@ using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter #define PSCORE_ZERO PScore(0, 0) #endif -} // namespace Clockwork \ No newline at end of file +} // namespace Clockwork diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index 83eea8df..f2f5266c 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -21,7 +21,7 @@ class Arena { return idx; } - // Emplace version we might want later for ops that return many values? + // Emplace version we might want later for ops that return many values? // Might be seeing things. template u32 emplace(Args&&... args) { @@ -55,7 +55,6 @@ class Arena { m_data.resize(n); } } - }; } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index be08b12c..b86c9ed4 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -187,7 +187,7 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { u32 out = m_values.alloc({0.0, 0.0}); f128 p = m_pairs[input.index].values; - + f64 val = alpha * p.first() + (1.0 - alpha) * p.second(); m_values[out].value = val; m_tape.push_back({OpType::Phase, out, input.index, 0, alpha}); diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index 0ccac494..09404d32 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -31,7 +31,7 @@ enum class OpType : u8 { ValDivScalar, // x / scalar // Pair Ops - PairCreate, + PairCreate, PairAdd, PairSub, PairNeg, @@ -56,7 +56,7 @@ enum class OpType : u8 { // A single node in the compute tape. Probably can be rewritten more compactly. struct Node { - OpType type; // This tells us which arenas to look at and how to interpret lhs/rhs + OpType type; // This tells us which arenas to look at and how to interpret lhs/rhs u32 output_idx; // Index in the respective arena (Value or Pair) u32 lhs_idx; // Index of first operand From 024971884ed82e407b25b14a175ce9caeabf9b6c Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:49:32 +0100 Subject: [PATCH 07/31] Cleanup + slight optim Bench: 12044152 --- src/evaltune_main.cpp | 29 +++++++++-------------------- src/tuning/graph.cpp | 4 ++-- src/tuning/graph.hpp | 4 ++-- 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index dee2a536..1538319d 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -39,7 +39,7 @@ int main() { "data/dfrcv1.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3.txt", }; - const u32 thread_count = std::max(1, std::thread::hardware_concurrency()); + const u32 thread_count = std::max(1, std::thread::hardware_concurrency() / 2); std::cout << "Running on " << thread_count << " threads\n"; @@ -88,10 +88,7 @@ int main() { return 1; } - // ------------------------------ - // Setup Autograd system - // ------------------------------ - + // Setup tuning const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts(); // This line loads the defaults from your S() macros @@ -109,7 +106,9 @@ int main() { std::mt19937 rng(std::random_device{}()); std::vector indices(positions.size()); - const size_t total_batches = (positions.size() + batch_size - 1) / batch_size; + // Initialize indices 1..N + std::iota(indices.begin(), indices.end(), 0); + const size_t total_batches = (positions.size() + batch_size - 1) / batch_size; // Shared gradient accumulator Parameters batch_gradients = Parameters::zeros(parameter_count); @@ -123,13 +122,10 @@ int main() { batch_gradients = Parameters::zeros(parameter_count); }}; - // ------------------------------ - // Worker threads - // ------------------------------ + // Spawn worker threads for (u32 t = 0; t < thread_count; ++t) { std::thread([&, t]() { // Each thread uses its own Graph arena - for (int epoch = 0; epoch < epochs; ++epoch) { epoch_barrier.arrive_and_wait(); @@ -151,9 +147,7 @@ int main() { outputs.reserve(sub_end - sub_start); targets.reserve(sub_end - sub_start); - // ------------------------------ - // Forward pass - // ------------------------------ + // Forward for (size_t j = sub_start; j < sub_end; ++j) { size_t idx = indices[j]; @@ -163,9 +157,7 @@ int main() { targets.push_back(y); } - // ------------------------------ - // Loss and backward - // ------------------------------ + // Backward ValueHandle loss = mse(outputs, targets) * ValueHandle::create(1.0 / double(this_batch_size)); @@ -188,16 +180,13 @@ int main() { }).detach(); } - // ------------------------------ - // Main thread: epoch coordinator - // ------------------------------ + // Epoch loop for (int epoch = 0; epoch < epochs; ++epoch) { std::cout << "Epoch " << epoch + 1 << "/" << epochs << "\n"; const auto start = time::Clock::now(); - std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), rng); epoch_barrier.arrive_and_wait(); diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index b86c9ed4..58c17ef5 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -25,11 +25,11 @@ Graph& Graph::get() { return instance; } -ValueHandle Graph::create_value(f64 data, bool is_parameter) { +ValueHandle Graph::create_value(f64 data) { return ValueHandle(m_values.alloc({data, 0.0})); } -PairHandle Graph::create_pair(f128 data, bool is_parameter) { +PairHandle Graph::create_pair(f128 data) { return PairHandle(m_pairs.alloc({data, f128::zero()})); } diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 2a6a299f..0b43a29a 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -40,8 +40,8 @@ class Graph { static Graph& get(); // Creation - ValueHandle create_value(f64 data, bool is_parameter = false); - PairHandle create_pair(f128 data, bool is_parameter = false); + ValueHandle create_value(f64 data); + PairHandle create_pair(f128 data); // Operation recording stuff From 3b014eb67cfdbdb1a5b9373020e074ba9f24172a Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:51:05 +0100 Subject: [PATCH 08/31] Useless comment begone Bench: 12044152 --- src/evaltune_main.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index 1538319d..37f48068 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -29,9 +29,6 @@ using namespace Clockwork::Autograd; int main() { - // ------------------------------ - // Load FENs - // ------------------------------ std::vector positions; std::vector results; From b0b96210080924a21b55204b5c0ca97171ed6a36 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 14:59:05 +0100 Subject: [PATCH 09/31] [[nodiscards]] are back on the menu Bench: 12044152 --- src/eval_types.hpp | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/eval_types.hpp b/src/eval_types.hpp index e2da0b85..50890fa9 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -16,9 +16,6 @@ namespace Clockwork { #ifndef EVAL_TUNING -// ============================================================================ -// NORMAL BUILD (NO TUNING) -// ============================================================================ using Score = i16; class PScore { @@ -42,14 +39,14 @@ class PScore { && endgame <= std::numeric_limits::max()); } - inline Score mg() const { + [[nodiscard]] inline Score mg() const { u16 mg = u16(m_score); i16 v; std::memcpy(&v, &mg, sizeof(mg)); return v; } - inline Score eg() const { + [[nodiscard]] inline Score eg() const { u16 eg = u16(u32(m_score + 0x8000) >> 16); i16 v; std::memcpy(&v, &eg, sizeof(eg)); @@ -90,7 +87,7 @@ class PScore { // Phase function (non-tuning: returns int) template - inline Value phase(i32 alpha) const { + [[nodiscard]] inline Value phase(i32 alpha) const { assert(0 <= alpha && alpha <= max); return Value((mg() * alpha + eg() * (max - alpha)) / max); } @@ -104,9 +101,6 @@ class PScore { using PParam = PScore; #else -// ============================================================================ -// TUNING BUILD (NEW AUTOGRAD API) -// ============================================================================ using Score = Autograd::ValueHandle; using PScore = Autograd::PairHandle; @@ -115,10 +109,6 @@ using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter #endif -// ============================================================================ -// Macro Definitions -// ============================================================================ - #ifdef EVAL_TUNING // Tunable scalar pair (mg, eg) #define S(a, b) Autograd::PairPlaceholder::create_tunable((a), (b)) From f68e3e07122369b16211806b944698a74ea43d98 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 15:05:38 +0100 Subject: [PATCH 10/31] f128 is evil Bench: 12044152 --- src/tuning/globals.hpp | 10 ++-- src/tuning/graph.cpp | 126 ++++++++++++++++++++--------------------- src/tuning/graph.hpp | 6 +- src/tuning/info.hpp | 8 +-- src/tuning/optim.hpp | 42 +++++++------- src/tuning/value.cpp | 14 ++--- src/tuning/value.hpp | 8 +-- src/util/vec/sse2.hpp | 56 +++++++++--------- 8 files changed, 135 insertions(+), 135 deletions(-) diff --git a/src/tuning/globals.hpp b/src/tuning/globals.hpp index fa23f9e3..6f969f0a 100644 --- a/src/tuning/globals.hpp +++ b/src/tuning/globals.hpp @@ -111,18 +111,18 @@ class ValuePlaceholder { class PairPlaceholder { public: - explicit PairPlaceholder(f128 default_value, bool constant) : + explicit PairPlaceholder(f64x2 default_value, bool constant) : m_index(Globals::get().register_param(this)), m_default_value(default_value), m_constant(constant) { } static PairPlaceholder create_tunable(f64 a, f64 b) { - return PairPlaceholder(f128::make(a, b), false); + return PairPlaceholder(f64x2::make(a, b), false); } static PairPlaceholder create(f64 a, f64 b) { - return PairPlaceholder(f128::make(a, b), true); + return PairPlaceholder(f64x2::make(a, b), true); } // Conversion to Handle: Delegates to the Graph @@ -133,7 +133,7 @@ class PairPlaceholder { usize index() const { return m_index; } - f128 default_value() const { + f64x2 default_value() const { return m_default_value; } bool constant() const { @@ -142,7 +142,7 @@ class PairPlaceholder { private: usize m_index; - f128 m_default_value; + f64x2 m_default_value; bool m_constant; }; diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 58c17ef5..48526c85 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -16,7 +16,7 @@ Graph::Graph() { m_values.alloc({p->default_value(), 0.0}); } for (auto* p : pair_params) { - m_pairs.alloc({p->default_value(), f128::zero()}); + m_pairs.alloc({p->default_value(), f64x2::zero()}); } } @@ -29,8 +29,8 @@ ValueHandle Graph::create_value(f64 data) { return ValueHandle(m_values.alloc({data, 0.0})); } -PairHandle Graph::create_pair(f128 data) { - return PairHandle(m_pairs.alloc({data, f128::zero()})); +PairHandle Graph::create_pair(f64x2 data) { + return PairHandle(m_pairs.alloc({data, f64x2::zero()})); } // ------------------ Recording ------------------ @@ -113,17 +113,17 @@ ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { } PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { - u32 out = m_pairs.alloc({f128::zero(), f128::zero()}); - f128 l = m_pairs[lhs.index].values; - f128 r = m_pairs[rhs.index].values; - f128 res = f128::zero(); + u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); + f64x2 l = m_pairs[lhs.index].values; + f64x2 r = m_pairs[rhs.index].values; + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairAdd: - res = f128::add(l, r); + res = f64x2::add(l, r); break; case OpType::PairSub: - res = f128::sub(l, r); + res = f64x2::sub(l, r); break; default: break; @@ -134,22 +134,22 @@ PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { } PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { - u32 out = m_pairs.alloc({f128::zero(), f128::zero()}); - f128 l = m_pairs[input.index].values; - f128 res = f128::zero(); + u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); + f64x2 l = m_pairs[input.index].values; + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairNeg: - res = f128::neg(l); + res = f64x2::neg(l); break; case OpType::PairMulScalar: - res = f128::mul_scalar(l, scalar); + res = f64x2::mul_scalar(l, scalar); break; case OpType::PairDivScalar: - res = f128::div_scalar(l, scalar); + res = f64x2::div_scalar(l, scalar); break; case OpType::ScalarDivPair: - res = f128::scalar_div(scalar, l); + res = f64x2::scalar_div(scalar, l); break; default: break; @@ -160,21 +160,21 @@ PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { } PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) { - u32 out = m_pairs.alloc({f128::zero(), f128::zero()}); - f128 p = m_pairs[pair.index].values; + u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); + f64x2 p = m_pairs[pair.index].values; f64 v = m_values[val.index].value; - f128 res = f128::zero(); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairMulValue: case OpType::ValueMulPair: - res = f128::mul_scalar(p, v); + res = f64x2::mul_scalar(p, v); break; case OpType::PairDivValue: - res = f128::div_scalar(p, v); + res = f64x2::div_scalar(p, v); break; case OpType::ValueDivPair: - res = f128::scalar_div(v, p); + res = f64x2::scalar_div(v, p); break; default: break; @@ -186,7 +186,7 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { u32 out = m_values.alloc({0.0, 0.0}); - f128 p = m_pairs[input.index].values; + f64x2 p = m_pairs[input.index].values; f64 val = alpha * p.first() + (1.0 - alpha) * p.second(); m_values[out].value = val; @@ -306,86 +306,86 @@ void Graph::backward() { // Pair Binary case OpType::PairAdd: { - f128 grad = m_pairs[node.output_idx].gradients; - m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad); - m_pairs[node.rhs_idx].gradients = f128::add(m_pairs[node.rhs_idx].gradients, grad); + f64x2 grad = m_pairs[node.output_idx].gradients; + m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad); + m_pairs[node.rhs_idx].gradients = f64x2::add(m_pairs[node.rhs_idx].gradients, grad); break; } case OpType::PairSub: { - f128 grad = m_pairs[node.output_idx].gradients; - m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad); - m_pairs[node.rhs_idx].gradients = f128::sub(m_pairs[node.rhs_idx].gradients, grad); + f64x2 grad = m_pairs[node.output_idx].gradients; + m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad); + m_pairs[node.rhs_idx].gradients = f64x2::sub(m_pairs[node.rhs_idx].gradients, grad); break; } // Pair Scalar case OpType::PairNeg: { - f128 grad = m_pairs[node.output_idx].gradients; - m_pairs[node.lhs_idx].gradients = f128::sub(m_pairs[node.lhs_idx].gradients, grad); + f64x2 grad = m_pairs[node.output_idx].gradients; + m_pairs[node.lhs_idx].gradients = f64x2::sub(m_pairs[node.lhs_idx].gradients, grad); break; } case OpType::PairMulScalar: { - f128 grad = m_pairs[node.output_idx].gradients; - f128 scaled_grad = f128::mul_scalar(grad, node.scalar_data); + f64x2 grad = m_pairs[node.output_idx].gradients; + f64x2 scaled_grad = f64x2::mul_scalar(grad, node.scalar_data); m_pairs[node.lhs_idx].gradients = - f128::add(m_pairs[node.lhs_idx].gradients, scaled_grad); + f64x2::add(m_pairs[node.lhs_idx].gradients, scaled_grad); break; } case OpType::PairDivScalar: { - f128 grad = m_pairs[node.output_idx].gradients; - f128 scaled_grad = f128::div_scalar(grad, node.scalar_data); + f64x2 grad = m_pairs[node.output_idx].gradients; + f64x2 scaled_grad = f64x2::div_scalar(grad, node.scalar_data); m_pairs[node.lhs_idx].gradients = - f128::add(m_pairs[node.lhs_idx].gradients, scaled_grad); + f64x2::add(m_pairs[node.lhs_idx].gradients, scaled_grad); break; } case OpType::ScalarDivPair: { - f128 grad = m_pairs[node.output_idx].gradients; - f128 l = m_pairs[node.lhs_idx].values; - f128 l_sq = f128::mul(l, l); - f128 neg_s_over_sq = f128::neg(f128::scalar_div(node.scalar_data, l_sq)); - f128 update = f128::mul(neg_s_over_sq, grad); - m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, update); + f64x2 grad = m_pairs[node.output_idx].gradients; + f64x2 l = m_pairs[node.lhs_idx].values; + f64x2 l_sq = f64x2::mul(l, l); + f64x2 neg_s_over_sq = f64x2::neg(f64x2::scalar_div(node.scalar_data, l_sq)); + f64x2 update = f64x2::mul(neg_s_over_sq, grad); + m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, update); break; } // Pair-Value case OpType::PairMulValue: case OpType::ValueMulPair: { - f128 grad_out = m_pairs[node.output_idx].gradients; - f128 p = m_pairs[node.lhs_idx].values; + f64x2 grad_out = m_pairs[node.output_idx].gradients; + f64x2 p = m_pairs[node.lhs_idx].values; f64 v = m_values[node.rhs_idx].value; - f128 grad_p = f128::mul_scalar(grad_out, v); - m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); + f64x2 grad_p = f64x2::mul_scalar(grad_out, v); + m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); - f128 contrib = f128::mul(p, grad_out); + f64x2 contrib = f64x2::mul(p, grad_out); m_values[node.rhs_idx].gradient += contrib.first() + contrib.second(); break; } case OpType::PairDivValue: { - f128 grad_out = m_pairs[node.output_idx].gradients; - f128 p = m_pairs[node.lhs_idx].values; + f64x2 grad_out = m_pairs[node.output_idx].gradients; + f64x2 p = m_pairs[node.lhs_idx].values; f64 v = m_values[node.rhs_idx].value; - f128 grad_p = f128::div_scalar(grad_out, v); - m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); + f64x2 grad_p = f64x2::div_scalar(grad_out, v); + m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); - f128 num = f128::mul(p, grad_out); + f64x2 num = f64x2::mul(p, grad_out); f64 sum_contr = num.first() + num.second(); m_values[node.rhs_idx].gradient += -sum_contr / (v * v); break; } case OpType::ValueDivPair: { - f128 grad_out = m_pairs[node.output_idx].gradients; - f128 p = m_pairs[node.lhs_idx].values; + f64x2 grad_out = m_pairs[node.output_idx].gradients; + f64x2 p = m_pairs[node.lhs_idx].values; f64 v = m_values[node.rhs_idx].value; - f128 p_sq = f128::mul(p, p); - f128 neg_v_sq = f128::neg(f128::scalar_div(v, p_sq)); - f128 grad_p = f128::mul(neg_v_sq, grad_out); - m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_p); + f64x2 p_sq = f64x2::mul(p, p); + f64x2 neg_v_sq = f64x2::neg(f64x2::scalar_div(v, p_sq)); + f64x2 grad_p = f64x2::mul(neg_v_sq, grad_out); + m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); - f128 v_contr = f128::div(grad_out, p); + f64x2 v_contr = f64x2::div(grad_out, p); m_values[node.rhs_idx].gradient += v_contr.first() + v_contr.second(); break; } @@ -395,8 +395,8 @@ void Graph::backward() { f64 grad = m_values[node.output_idx].gradient; f64 alpha = node.scalar_data; - f128 grad_upd = f128::make(alpha * grad, (1.0 - alpha) * grad); - m_pairs[node.lhs_idx].gradients = f128::add(m_pairs[node.lhs_idx].gradients, grad_upd); + f64x2 grad_upd = f64x2::make(alpha * grad, (1.0 - alpha) * grad); + m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_upd); break; } default: @@ -417,7 +417,7 @@ void Graph::zero_grad() { m_values[i].gradient = 0.0; } for (usize i = 0; i < m_global_pair_count; ++i) { - m_pairs[i].gradients = f128::zero(); + m_pairs[i].gradients = f64x2::zero(); } } diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 0b43a29a..28eef7a2 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -17,8 +17,8 @@ struct ValueData { }; struct PairData { - f128 values; - f128 gradients; + f64x2 values; + f64x2 gradients; }; class Graph { @@ -41,7 +41,7 @@ class Graph { // Creation ValueHandle create_value(f64 data); - PairHandle create_pair(f128 data); + PairHandle create_pair(f64x2 data); // Operation recording stuff diff --git a/src/tuning/info.hpp b/src/tuning/info.hpp index ce6d01e4..d0c53747 100644 --- a/src/tuning/info.hpp +++ b/src/tuning/info.hpp @@ -14,12 +14,12 @@ struct ParameterCountInfo { struct Parameters { std::vector parameters; - std::vector pair_parameters; + std::vector pair_parameters; static Parameters zeros(ParameterCountInfo counts) { Parameters result; result.parameters.resize(counts.parameter_count, 0.0); - result.pair_parameters.resize(counts.pair_parameter_count, f128::zero()); + result.pair_parameters.resize(counts.pair_parameter_count, f64x2::zero()); return result; } @@ -30,7 +30,7 @@ struct Parameters { parameters[i] += b.parameters[i]; } for (usize i = 0; i < pair_parameters.size(); i++) { - pair_parameters[i] = f128::add(pair_parameters[i], b.pair_parameters[i]); + pair_parameters[i] = f64x2::add(pair_parameters[i], b.pair_parameters[i]); } } @@ -42,7 +42,7 @@ struct Parameters { } for (usize i = 0; i < pair_parameters.size(); i++) { pair_parameters[i] = - f128::madd(pair_parameters[i], f128::broadcast(weight), b.pair_parameters[i]); + f64x2::madd(pair_parameters[i], f64x2::broadcast(weight), b.pair_parameters[i]); } } }; diff --git a/src/tuning/optim.hpp b/src/tuning/optim.hpp index 5ae13062..f0a1a958 100644 --- a/src/tuning/optim.hpp +++ b/src/tuning/optim.hpp @@ -17,7 +17,7 @@ class SGD { f64 m_momentum; std::vector m_value_velocity; - std::vector m_pair_velocity; + std::vector m_pair_velocity; public: explicit SGD(ParameterCountInfo counts, f64 lr, f64 momentum = 0.9) : @@ -25,7 +25,7 @@ class SGD { m_lr(lr), m_momentum(momentum) { m_value_velocity.resize(m_counts.parameter_count, 0.0); - m_pair_velocity.resize(m_counts.pair_parameter_count, f128::zero()); + m_pair_velocity.resize(m_counts.pair_parameter_count, f64x2::zero()); } void step(Parameters& values, const Parameters& gradients) { @@ -55,11 +55,11 @@ class SGD { auto& p_grad = gradients.pair_parameters[i]; auto& v = m_pair_velocity[i]; - const f128 lr_grad = f128::mul_scalar(p_grad, m_lr); - const f128 mom_v = f128::mul_scalar(v, m_momentum); - const f128 neg_lr_grad = f128::neg(lr_grad); - v = f128::add(mom_v, neg_lr_grad); - p_value = f128::add(p_value, v); + const f64x2 lr_grad = f64x2::mul_scalar(p_grad, m_lr); + const f64x2 mom_v = f64x2::mul_scalar(v, m_momentum); + const f64x2 neg_lr_grad = f64x2::neg(lr_grad); + v = f64x2::add(mom_v, neg_lr_grad); + p_value = f64x2::add(p_value, v); } } @@ -84,8 +84,8 @@ class AdamW { std::vector m_m; std::vector m_v; - std::vector m_pair_m; - std::vector m_pair_v; + std::vector m_pair_m; + std::vector m_pair_v; public: explicit AdamW(ParameterCountInfo counts, @@ -103,8 +103,8 @@ class AdamW { m_t(0) { m_m.resize(m_counts.parameter_count, 0.0); m_v.resize(m_counts.parameter_count, 0.0); - m_pair_m.resize(m_counts.pair_parameter_count, f128::zero()); - m_pair_v.resize(m_counts.pair_parameter_count, f128::zero()); + m_pair_m.resize(m_counts.pair_parameter_count, f64x2::zero()); + m_pair_v.resize(m_counts.pair_parameter_count, f64x2::zero()); } void step(Parameters& values, const Parameters& gradients) { @@ -147,18 +147,18 @@ class AdamW { auto& m = m_pair_m[i]; auto& v = m_pair_v[i]; - const f128 g2 = f128::mul(g, g); + const f64x2 g2 = f64x2::mul(g, g); - const f128 m_scaled = f128::mul_scalar(m, m_beta1); - const f128 g_scaled = f128::mul_scalar(g, (1.0 - m_beta1)); - m = f128::add(m_scaled, g_scaled); + const f64x2 m_scaled = f64x2::mul_scalar(m, m_beta1); + const f64x2 g_scaled = f64x2::mul_scalar(g, (1.0 - m_beta1)); + m = f64x2::add(m_scaled, g_scaled); - const f128 v_scaled = f128::mul_scalar(v, m_beta2); - const f128 g2_scaled = f128::mul_scalar(g2, (1.0 - m_beta2)); - v = f128::add(v_scaled, g2_scaled); + const f64x2 v_scaled = f64x2::mul_scalar(v, m_beta2); + const f64x2 g2_scaled = f64x2::mul_scalar(g2, (1.0 - m_beta2)); + v = f64x2::add(v_scaled, g2_scaled); - const f128 m_hat = f128::mul_scalar(m, inv1mb1t); - const f128 v_hat = f128::mul_scalar(v, inv1mb2t); + const f64x2 m_hat = f64x2::mul_scalar(m, inv1mb1t); + const f64x2 v_hat = f64x2::mul_scalar(v, inv1mb2t); const f64 adam_upd_f = m_lr * m_hat.first() / (std::sqrt(v_hat.first()) + m_eps); const f64 adam_upd_s = m_lr * m_hat.second() / (std::sqrt(v_hat.second()) + m_eps); @@ -169,7 +169,7 @@ class AdamW { const f64 total_upd_f = -(adam_upd_f + decay_upd_f); const f64 total_upd_s = -(adam_upd_s + decay_upd_s); - p = f128::add(p, f128::make(total_upd_f, total_upd_s)); + p = f64x2::add(p, f64x2::make(total_upd_f, total_upd_s)); } } diff --git a/src/tuning/value.cpp b/src/tuning/value.cpp index 6bb05b1e..78a914ab 100644 --- a/src/tuning/value.cpp +++ b/src/tuning/value.cpp @@ -69,17 +69,17 @@ void ValueHandle::set_value(f64 v) const { // PairHandle implementations PairHandle PairHandle::create(f64 first, f64 second) { - return Graph::get().create_pair(f128::make(first, second)); + return Graph::get().create_pair(f64x2::make(first, second)); } -PairHandle PairHandle::create(const f128& values) { +PairHandle PairHandle::create(const f64x2& values) { return Graph::get().create_pair(values); } -f128 PairHandle::get_values() const { +f64x2 PairHandle::get_values() const { return Graph::get().get_pair_data(*this).values; } -f128 PairHandle::get_gradients() const { +f64x2 PairHandle::get_gradients() const { return Graph::get().get_pair_data(*this).gradients; } f64 PairHandle::first() const { @@ -89,15 +89,15 @@ f64 PairHandle::second() const { return get_values().second(); } -void PairHandle::set_values(const f128& v) const { +void PairHandle::set_values(const f64x2& v) const { Graph::get().get_pair_data(*this).values = v; } void PairHandle::set_values(f64 f, f64 s) const { - set_values(f128::make(f, s)); + set_values(f64x2::make(f, s)); } void PairHandle::zero_grad() const { - Graph::get().get_pair_data(*this).gradients = f128::zero(); + Graph::get().get_pair_data(*this).gradients = f64x2::zero(); } diff --git a/src/tuning/value.hpp b/src/tuning/value.hpp index 452287e1..fe5b7b9f 100644 --- a/src/tuning/value.hpp +++ b/src/tuning/value.hpp @@ -51,16 +51,16 @@ struct PairHandle { } static PairHandle create(f64 first, f64 second); - static PairHandle create(const f128& values); + static PairHandle create(const f64x2& values); static PairHandle create_tunable(f64 a, f64 b) { return create(a, b); } - f128 get_values() const; - f128 get_gradients() const; + f64x2 get_values() const; + f64x2 get_gradients() const; f64 first() const; f64 second() const; - void set_values(const f128& v) const; + void set_values(const f64x2& v) const; void set_values(f64 f, f64 s) const; void zero_grad() const; diff --git a/src/util/vec/sse2.hpp b/src/util/vec/sse2.hpp index 0b212b68..3203398b 100644 --- a/src/util/vec/sse2.hpp +++ b/src/util/vec/sse2.hpp @@ -11,7 +11,7 @@ #define F128_USE_SSE2 0 #endif -struct f128 { +struct f64x2 { #if F128_USE_SSE2 __m128d v = _mm_setzero_pd(); #else @@ -20,9 +20,9 @@ struct f128 { #endif // ---- Constructors ---- - static inline f128 make(double a, double b) { + static inline f64x2 make(double a, double b) { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_set_pd(b, a); return r; #else @@ -30,9 +30,9 @@ struct f128 { #endif } - static inline f128 broadcast(double x) { + static inline f64x2 broadcast(double x) { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_set1_pd(x); return r; #else @@ -40,9 +40,9 @@ struct f128 { #endif } - static inline f128 zero() { + static inline f64x2 zero() { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_setzero_pd(); return r; #else @@ -72,9 +72,9 @@ struct f128 { } // ---- Arithmetic ---- - static inline f128 add(const f128& a, const f128& b) { + static inline f64x2 add(const f64x2& a, const f64x2& b) { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_add_pd(a.v, b.v); return r; #else @@ -82,9 +82,9 @@ struct f128 { #endif } - static inline f128 sub(const f128& a, const f128& b) { + static inline f64x2 sub(const f64x2& a, const f64x2& b) { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_sub_pd(a.v, b.v); return r; #else @@ -92,9 +92,9 @@ struct f128 { #endif } - static inline f128 mul(const f128& a, const f128& b) { + static inline f64x2 mul(const f64x2& a, const f64x2& b) { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_mul_pd(a.v, b.v); return r; #else @@ -102,9 +102,9 @@ struct f128 { #endif } - static inline f128 div(const f128& a, const f128& b) { + static inline f64x2 div(const f64x2& a, const f64x2& b) { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_div_pd(a.v, b.v); return r; #else @@ -112,10 +112,10 @@ struct f128 { #endif } - static inline f128 neg(const f128& a) { + static inline f64x2 neg(const f64x2& a) { #if F128_USE_SSE2 __m128d zero = _mm_setzero_pd(); - f128 r; + f64x2 r; r.v = _mm_sub_pd(zero, a.v); return r; #else @@ -124,26 +124,26 @@ struct f128 { } // ---- Scalar ops ---- - static inline f128 add_scalar(const f128& a, double s) { + static inline f64x2 add_scalar(const f64x2& a, double s) { return add(a, broadcast(s)); } - static inline f128 sub_scalar(const f128& a, double s) { + static inline f64x2 sub_scalar(const f64x2& a, double s) { return sub(a, broadcast(s)); } - static inline f128 mul_scalar(const f128& a, double s) { + static inline f64x2 mul_scalar(const f64x2& a, double s) { return mul(a, broadcast(s)); } - static inline f128 div_scalar(const f128& a, double s) { + static inline f64x2 div_scalar(const f64x2& a, double s) { return div(a, broadcast(s)); } - static inline f128 scalar_div(double s, const f128& a) { + static inline f64x2 scalar_div(double s, const f64x2& a) { #if F128_USE_SSE2 __m128d num = _mm_set1_pd(s); - f128 r; + f64x2 r; r.v = _mm_div_pd(num, a.v); return r; #else @@ -152,9 +152,9 @@ struct f128 { } // ---- Math functions ---- - static inline f128 sqrt(const f128& a) { + static inline f64x2 sqrt(const f64x2& a) { #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_sqrt_pd(a.v); return r; #else @@ -163,10 +163,10 @@ struct f128 { } // ---- FMA-style (useful for gradient updates) ---- - static inline f128 madd(const f128& a, const f128& b, const f128& c) { + static inline f64x2 madd(const f64x2& a, const f64x2& b, const f64x2& c) { // a + b*c #if F128_USE_SSE2 - f128 r; + f64x2 r; r.v = _mm_add_pd(a.v, _mm_mul_pd(b.v, c.v)); return r; #else @@ -175,7 +175,7 @@ struct f128 { } // ---- Printing ---- - friend std::ostream& operator<<(std::ostream& os, const f128& f) { + friend std::ostream& operator<<(std::ostream& os, const f64x2& f) { os << "(" << f.first() << ", " << f.second() << ")"; return os; } From b06cd299e29f735b0e01811becc3ef2ec25b6881 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 15:06:42 +0100 Subject: [PATCH 11/31] Format Bench: 12044152 --- src/tuning/globals.hpp | 2 +- src/tuning/graph.cpp | 46 +++++++++++++++++++++--------------------- src/tuning/info.hpp | 2 +- src/tuning/optim.hpp | 14 ++++++------- src/tuning/value.hpp | 10 ++++----- src/util/vec/sse2.hpp | 4 ++-- 6 files changed, 39 insertions(+), 39 deletions(-) diff --git a/src/tuning/globals.hpp b/src/tuning/globals.hpp index 6f969f0a..f9236e4b 100644 --- a/src/tuning/globals.hpp +++ b/src/tuning/globals.hpp @@ -142,7 +142,7 @@ class PairPlaceholder { private: usize m_index; - f64x2 m_default_value; + f64x2 m_default_value; bool m_constant; }; diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 48526c85..212ee38c 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -113,7 +113,7 @@ ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { } PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { - u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); + u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); f64x2 l = m_pairs[lhs.index].values; f64x2 r = m_pairs[rhs.index].values; f64x2 res = f64x2::zero(); @@ -134,7 +134,7 @@ PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { } PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { - u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); + u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); f64x2 l = m_pairs[input.index].values; f64x2 res = f64x2::zero(); @@ -160,9 +160,9 @@ PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { } PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) { - u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); + u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); f64x2 p = m_pairs[pair.index].values; - f64 v = m_values[val.index].value; + f64 v = m_values[val.index].value; f64x2 res = f64x2::zero(); switch (op) { @@ -185,7 +185,7 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) } ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { - u32 out = m_values.alloc({0.0, 0.0}); + u32 out = m_values.alloc({0.0, 0.0}); f64x2 p = m_pairs[input.index].values; f64 val = alpha * p.first() + (1.0 - alpha) * p.second(); @@ -306,13 +306,13 @@ void Graph::backward() { // Pair Binary case OpType::PairAdd: { - f64x2 grad = m_pairs[node.output_idx].gradients; + f64x2 grad = m_pairs[node.output_idx].gradients; m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad); m_pairs[node.rhs_idx].gradients = f64x2::add(m_pairs[node.rhs_idx].gradients, grad); break; } case OpType::PairSub: { - f64x2 grad = m_pairs[node.output_idx].gradients; + f64x2 grad = m_pairs[node.output_idx].gradients; m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad); m_pairs[node.rhs_idx].gradients = f64x2::sub(m_pairs[node.rhs_idx].gradients, grad); break; @@ -320,7 +320,7 @@ void Graph::backward() { // Pair Scalar case OpType::PairNeg: { - f64x2 grad = m_pairs[node.output_idx].gradients; + f64x2 grad = m_pairs[node.output_idx].gradients; m_pairs[node.lhs_idx].gradients = f64x2::sub(m_pairs[node.lhs_idx].gradients, grad); break; } @@ -339,11 +339,11 @@ void Graph::backward() { break; } case OpType::ScalarDivPair: { - f64x2 grad = m_pairs[node.output_idx].gradients; - f64x2 l = m_pairs[node.lhs_idx].values; - f64x2 l_sq = f64x2::mul(l, l); - f64x2 neg_s_over_sq = f64x2::neg(f64x2::scalar_div(node.scalar_data, l_sq)); - f64x2 update = f64x2::mul(neg_s_over_sq, grad); + f64x2 grad = m_pairs[node.output_idx].gradients; + f64x2 l = m_pairs[node.lhs_idx].values; + f64x2 l_sq = f64x2::mul(l, l); + f64x2 neg_s_over_sq = f64x2::neg(f64x2::scalar_div(node.scalar_data, l_sq)); + f64x2 update = f64x2::mul(neg_s_over_sq, grad); m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, update); break; } @@ -353,9 +353,9 @@ void Graph::backward() { case OpType::ValueMulPair: { f64x2 grad_out = m_pairs[node.output_idx].gradients; f64x2 p = m_pairs[node.lhs_idx].values; - f64 v = m_values[node.rhs_idx].value; + f64 v = m_values[node.rhs_idx].value; - f64x2 grad_p = f64x2::mul_scalar(grad_out, v); + f64x2 grad_p = f64x2::mul_scalar(grad_out, v); m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); f64x2 contrib = f64x2::mul(p, grad_out); @@ -365,24 +365,24 @@ void Graph::backward() { case OpType::PairDivValue: { f64x2 grad_out = m_pairs[node.output_idx].gradients; f64x2 p = m_pairs[node.lhs_idx].values; - f64 v = m_values[node.rhs_idx].value; + f64 v = m_values[node.rhs_idx].value; - f64x2 grad_p = f64x2::div_scalar(grad_out, v); + f64x2 grad_p = f64x2::div_scalar(grad_out, v); m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); f64x2 num = f64x2::mul(p, grad_out); - f64 sum_contr = num.first() + num.second(); + f64 sum_contr = num.first() + num.second(); m_values[node.rhs_idx].gradient += -sum_contr / (v * v); break; } case OpType::ValueDivPair: { f64x2 grad_out = m_pairs[node.output_idx].gradients; f64x2 p = m_pairs[node.lhs_idx].values; - f64 v = m_values[node.rhs_idx].value; + f64 v = m_values[node.rhs_idx].value; - f64x2 p_sq = f64x2::mul(p, p); - f64x2 neg_v_sq = f64x2::neg(f64x2::scalar_div(v, p_sq)); - f64x2 grad_p = f64x2::mul(neg_v_sq, grad_out); + f64x2 p_sq = f64x2::mul(p, p); + f64x2 neg_v_sq = f64x2::neg(f64x2::scalar_div(v, p_sq)); + f64x2 grad_p = f64x2::mul(neg_v_sq, grad_out); m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); f64x2 v_contr = f64x2::div(grad_out, p); @@ -395,7 +395,7 @@ void Graph::backward() { f64 grad = m_values[node.output_idx].gradient; f64 alpha = node.scalar_data; - f64x2 grad_upd = f64x2::make(alpha * grad, (1.0 - alpha) * grad); + f64x2 grad_upd = f64x2::make(alpha * grad, (1.0 - alpha) * grad); m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_upd); break; } diff --git a/src/tuning/info.hpp b/src/tuning/info.hpp index d0c53747..d3874d1c 100644 --- a/src/tuning/info.hpp +++ b/src/tuning/info.hpp @@ -13,7 +13,7 @@ struct ParameterCountInfo { }; struct Parameters { - std::vector parameters; + std::vector parameters; std::vector pair_parameters; static Parameters zeros(ParameterCountInfo counts) { diff --git a/src/tuning/optim.hpp b/src/tuning/optim.hpp index f0a1a958..ccad9614 100644 --- a/src/tuning/optim.hpp +++ b/src/tuning/optim.hpp @@ -16,7 +16,7 @@ class SGD { f64 m_lr; f64 m_momentum; - std::vector m_value_velocity; + std::vector m_value_velocity; std::vector m_pair_velocity; public: @@ -58,8 +58,8 @@ class SGD { const f64x2 lr_grad = f64x2::mul_scalar(p_grad, m_lr); const f64x2 mom_v = f64x2::mul_scalar(v, m_momentum); const f64x2 neg_lr_grad = f64x2::neg(lr_grad); - v = f64x2::add(mom_v, neg_lr_grad); - p_value = f64x2::add(p_value, v); + v = f64x2::add(mom_v, neg_lr_grad); + p_value = f64x2::add(p_value, v); } } @@ -82,8 +82,8 @@ class AdamW { f64 m_weight_decay; long long m_t; - std::vector m_m; - std::vector m_v; + std::vector m_m; + std::vector m_v; std::vector m_pair_m; std::vector m_pair_v; @@ -151,11 +151,11 @@ class AdamW { const f64x2 m_scaled = f64x2::mul_scalar(m, m_beta1); const f64x2 g_scaled = f64x2::mul_scalar(g, (1.0 - m_beta1)); - m = f64x2::add(m_scaled, g_scaled); + m = f64x2::add(m_scaled, g_scaled); const f64x2 v_scaled = f64x2::mul_scalar(v, m_beta2); const f64x2 g2_scaled = f64x2::mul_scalar(g2, (1.0 - m_beta2)); - v = f64x2::add(v_scaled, g2_scaled); + v = f64x2::add(v_scaled, g2_scaled); const f64x2 m_hat = f64x2::mul_scalar(m, inv1mb1t); const f64x2 v_hat = f64x2::mul_scalar(v, inv1mb2t); diff --git a/src/tuning/value.hpp b/src/tuning/value.hpp index fe5b7b9f..c8d60241 100644 --- a/src/tuning/value.hpp +++ b/src/tuning/value.hpp @@ -58,11 +58,11 @@ struct PairHandle { f64x2 get_values() const; f64x2 get_gradients() const; - f64 first() const; - f64 second() const; - void set_values(const f64x2& v) const; - void set_values(f64 f, f64 s) const; - void zero_grad() const; + f64 first() const; + f64 second() const; + void set_values(const f64x2& v) const; + void set_values(f64 f, f64 s) const; + void zero_grad() const; // Internal helper to avoid including Graph in header ValueHandle phase_impl(f64 scaled_alpha) const; diff --git a/src/util/vec/sse2.hpp b/src/util/vec/sse2.hpp index 3203398b..e8f302dc 100644 --- a/src/util/vec/sse2.hpp +++ b/src/util/vec/sse2.hpp @@ -115,7 +115,7 @@ struct f64x2 { static inline f64x2 neg(const f64x2& a) { #if F128_USE_SSE2 __m128d zero = _mm_setzero_pd(); - f64x2 r; + f64x2 r; r.v = _mm_sub_pd(zero, a.v); return r; #else @@ -143,7 +143,7 @@ struct f64x2 { static inline f64x2 scalar_div(double s, const f64x2& a) { #if F128_USE_SSE2 __m128d num = _mm_set1_pd(s); - f64x2 r; + f64x2 r; r.v = _mm_div_pd(num, a.v); return r; #else From 648d8110f2427d500fceed79256d869561a3ba9e Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 15:10:35 +0100 Subject: [PATCH 12/31] run meaner.py Bench: 12044152 --- src/eval_constants.hpp | 144 ++++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 68 deletions(-) diff --git a/src/eval_constants.hpp b/src/eval_constants.hpp index 482180df..a7394047 100644 --- a/src/eval_constants.hpp +++ b/src/eval_constants.hpp @@ -5,11 +5,11 @@ namespace Clockwork { // clang-format off -inline const PParam PAWN_MAT = S(179, 165); -inline const PParam KNIGHT_MAT = S(728, 532); -inline const PParam BISHOP_MAT = S(786, 535); -inline const PParam ROOK_MAT = S(749, 704); -inline const PParam QUEEN_MAT = S(1631, 1200); +inline const PParam PAWN_MAT = S(298, 321); +inline const PParam KNIGHT_MAT = S(1089, 895); +inline const PParam BISHOP_MAT = S(1206, 950); +inline const PParam ROOK_MAT = S(1640, 1672); +inline const PParam QUEEN_MAT = S(3372, 2936); inline const PParam TEMPO_VAL = S(58, 15); inline const PParam BISHOP_PAIR_VAL = S(77, 174); @@ -28,55 +28,62 @@ inline const PParam PAWN_PUSH_THREAT_ROOK = S(33, 33); inline const PParam PAWN_PUSH_THREAT_QUEEN = S(52, -45); inline const std::array PAWN_PHALANX = { - S(20, 19), S(62, 31), S(72, 69), S(185, 139), S(520, 249), S(641, 696), + S(20, 19), S(62, 31), S(72, 69), S(185, 139), S(520, 249), S(641, 696) }; inline const std::array DEFENDED_PAWN = { - S(63, 42), S(59, 31), S(65, 57), S(143, 117), S(607, -50), + S(63, 42), S(59, 31), S(65, 57), S(143, 117), S(607, -50) }; inline const std::array PASSED_PAWN = { - S(-68, -93), S(-57, -75), S(-31, -2), S(22, 75), S(50, 207), S(281, 301), + S(-68, -93), S(-57, -75), S(-31, -2), S(22, 75), S(50, 207), S(281, 301) }; inline const std::array DEFENDED_PASSED_PUSH = { - S(49, -42), S(36, -4), S(20, 29), S(21, 78), S(83, 157), S(163, 282), + S(49, -42), S(36, -4), S(20, 29), S(21, 78), S(83, 157), S(163, 282) }; inline const std::array BLOCKED_PASSED_PAWN = { - S(15, -45), S(4, 2), S(0, -26), S(6, -45), S(0, -90), S(-190, -140), + S(15, -45), S(4, 2), S(0, -26), S(6, -45), S(0, -90), S(-190, -140) }; inline const std::array FRIENDLY_KING_PASSED_PAWN_DISTANCE = { - S(0, 0), S(12, 98), S(-20, 86), S(-13, 36), S(0, 7), S(9, 12), S(39, 9), S(18, -3), + S(0, 0), S(12, 98), S(-20, 86), S(-13, 36), S(0, 7), S(9, 12), S(39, 9), S(18, -3) }; inline const std::array ENEMY_KING_PASSED_PAWN_DISTANCE = { - S(0, 0), S(-183, -53), S(27, -6), S(-12, 40), S(10, 69), S(15, 94), S(35, 93), S(-12, 113), + S(0, 0), S(-183, -53), S(27, -6), S(-12, 40), S(10, 69), S(15, 94), S(35, 93), S(-12, 113) }; inline const std::array KNIGHT_MOBILITY = { - S(16, -9), S(119, 152), S(174, 205), S(216, 239), S(262, 254), S(287, 292), S(323, 288), S(356, 291), S(401, 233), + S(-223, -225), S(-120, -64), S(-65, -11), S(-23, 23), S(23, 38), S(48, 76), S(84, 72), S(117, 75), + S(162, 17) }; inline const std::array BISHOP_MOBILITY = { - S(26, -65), S(98, 116), S(171, 177), S(205, 227), S(235, 258), S(252, 278), S(271, 292), S(289, 298), S(308, 301), S(322, 297), S(346, 284), S(410, 238), S(439, 218), S(503, 181), + S(-251, -286), S(-179, -105), S(-106, -44), S(-72, 6), S(-42, 37), S(-25, 57), S(-6, 71), S(12, 77), + S(31, 80), S(45, 76), S(69, 63), S(133, 17), S(162, -3), S(226, -40) }; inline const std::array ROOK_MOBILITY = { - S(157, 217), S(287, 410), S(338, 474), S(370, 484), S(396, 508), S(410, 530), S(428, 542), S(446, 548), S(462, 560), S(480, 569), S(498, 571), S(510, 573), S(531, 575), S(543, 561), S(685, 436), + S(-279, -287), S(-149, -94), S(-98, -30), S(-66, -20), S(-40, 4), S(-26, 26), S(-8, 38), S(10, 44), + S(26, 56), S(44, 65), S(62, 67), S(74, 69), S(95, 71), S(107, 57), S(249, -68) }; inline const std::array QUEEN_MOBILITY = { - S(0, 2), S(721, 361), S(817, 490), S(876, 653), S(888, 837), S(928, 929), S(933, 1032), S(958, 1040), S(965, 1090), S(977, 1113), S(987, 1134), S(993, 1145), S(1013, 1133), S(1025, 1140), S(1034, 1129), S(1050, 1119), S(1060, 1103), S(1065, 1100), S(1097, 1045), S(1127, 996), S(1155, 960), S(1223, 864), S(1222, 865), S(1326, 729), S(1282, 736), S(1243, 733), S(1006, 797), S(905, 780), + S(-996, -893), S(-275, -534), S(-179, -405), S(-120, -242), S(-108, -58), S(-68, 34), S(-63, 137), S(-38, 145), + S(-31, 195), S(-19, 218), S(-9, 239), S(-3, 250), S(17, 238), S(29, 245), S(38, 234), S(54, 224), + S(64, 208), S(69, 205), S(101, 150), S(131, 101), S(159, 65), S(227, -31), S(226, -30), S(330, -166), + S(286, -159), S(247, -162), S(10, -98), S(-91, -115) }; inline const std::array KING_MOBILITY = { - S(335, 145), S(140, -113), S(39, -27), S(25, 9), S(-1, 11), S(-34, 17), S(-14, 18), S(-23, 12), S(-24, -35), + S(286, 141), S(91, -117), S(-10, -31), S(-24, 5), S(-50, 7), S(-83, 13), S(-63, 14), S(-72, 8), + S(-73, -39) }; inline const std::array KNIGHT_KING_RING = { - S(0, 0), S(85, -28), S(154, -73), + S(0, 0), S(85, -28), S(154, -73) }; inline const std::array BISHOP_KING_RING = { - S(0, 0), S(35, -4), S(134, -40), + S(0, 0), S(35, -4), S(134, -40) }; inline const std::array ROOK_KING_RING = { - S(0, 0), S(66, -45), S(48, -58), S(96, -55), S(143, -115), + S(0, 0), S(66, -45), S(48, -58), S(96, -55), S(143, -115) }; inline const std::array QUEEN_KING_RING = { - S(0, 0), S(-52, 64), S(-84, 99), S(-48, 78), S(81, 24), S(211, -61), + S(0, 0), S(-52, 64), S(-84, 99), S(-48, 78), S(81, 24), S(211, -61) }; inline const PParam PAWN_THREAT_KNIGHT = S(234, 57); @@ -93,66 +100,67 @@ inline const PParam BISHOP_THREAT_ROOK = S(237, 55); inline const PParam BISHOP_THREAT_QUEEN = S(192, 35); inline const std::array BISHOP_PAWNS = { - S(1, -6), S(-1, 0), S(0, -10), S(-5, -21), S(-11, -26), S(-16, -32), S(-17, -39), S(-23, -37), S(-33, -42), + S(1, -6), S(-1, 0), S(0, -10), S(-5, -21), S(-11, -26), S(-16, -32), S(-17, -39), S(-23, -37), + S(-33, -42) }; inline const std::array PAWN_PSQT = { - S(234, 316), S(229, 361), S(287, 331), S(350, 215), S(300, 208), S(283, 273), S(182, 295), S(239, 272), // - S(179, 193), S(288, 221), S(266, 165), S(265, 110), S(220, 93), S(167, 140), S(129, 186), S(80, 188), // - S(101, 156), S(120, 160), S(137, 116), S(125, 103), S(108, 98), S(67, 105), S(27, 152), S(1, 173), // - S(76, 109), S(93, 136), S(88, 103), S(70, 107), S(46, 98), S(26, 106), S(-23, 155), S(-41, 143), // - S(74, 79), S(135, 82), S(85, 125), S(55, 127), S(35, 118), S(-2, 117), S(-20, 131), S(-43, 125), // - S(84, 87), S(210, 92), S(165, 127), S(109, 145), S(72, 132), S(38, 128), S(13, 152), S(-22, 139), // + S(115, 160), S(110, 205), S(168, 175), S(231, 59), S(181, 52), S(164, 117), S(63, 139), S(120, 116), + S(60, 37), S(169, 65), S(147, 9), S(146, -46), S(101, -63), S(48, -16), S(10, 30), S(-39, 32), + S(-18, 0), S(1, 4), S(18, -40), S(6, -53), S(-11, -58), S(-52, -51), S(-92, -4), S(-118, 17), + S(-43, -47), S(-26, -20), S(-31, -53), S(-49, -49), S(-73, -58), S(-93, -50), S(-142, -1), S(-160, -13), + S(-45, -77), S(16, -74), S(-34, -31), S(-64, -29), S(-84, -38), S(-121, -39), S(-139, -25), S(-162, -31), + S(-35, -69), S(91, -64), S(46, -29), S(-10, -11), S(-47, -24), S(-81, -28), S(-106, -4), S(-141, -17) }; inline const std::array KNIGHT_PSQT = { - S(-255, -4), S(-195, 195), S(-271, 345), S(4, 213), S(-115, 234), S(-191, 238), S(-400, 215), S(-389, 125), // - S(123, 148), S(188, 159), S(287, 93), S(237, 154), S(238, 162), S(178, 139), S(118, 159), S(99, 116), // - S(180, 121), S(225, 160), S(311, 156), S(264, 179), S(263, 169), S(184, 178), S(173, 152), S(80, 160), // - S(234, 154), S(225, 174), S(254, 181), S(234, 207), S(241, 195), S(209, 189), S(184, 147), S(160, 154), // - S(223, 135), S(247, 132), S(242, 156), S(214, 174), S(206, 183), S(203, 177), S(177, 150), S(163, 95), // - S(137, 126), S(167, 113), S(161, 133), S(172, 177), S(179, 174), S(123, 152), S(129, 113), S(89, 108), // - S(139, 141), S(160, 112), S(144, 118), S(144, 137), S(131, 131), S(104, 111), S(115, 98), S(59, 32), // - S(91, 94), S(129, 134), S(146, 109), S(154, 116), S(147, 124), S(101, 94), S(88, 120), S(40, 71), // + S(-377, -151), S(-317, 48), S(-393, 198), S(-118, 66), S(-237, 87), S(-313, 91), S(-522, 68), S(-511, -22), + S(1, 1), S(66, 12), S(165, -54), S(115, 7), S(116, 15), S(56, -8), S(-4, 12), S(-23, -31), + S(58, -26), S(103, 13), S(189, 9), S(142, 32), S(141, 22), S(62, 31), S(51, 5), S(-42, 13), + S(112, 7), S(103, 27), S(132, 34), S(112, 60), S(119, 48), S(87, 42), S(62, 0), S(38, 7), + S(101, -12), S(125, -15), S(120, 9), S(92, 27), S(84, 36), S(81, 30), S(55, 3), S(41, -52), + S(15, -21), S(45, -34), S(39, -14), S(50, 30), S(57, 27), S(1, 5), S(7, -34), S(-33, -39), + S(17, -6), S(38, -35), S(22, -29), S(22, -10), S(9, -16), S(-18, -36), S(-7, -49), S(-63, -115), + S(-31, -53), S(7, -13), S(24, -38), S(32, -31), S(25, -23), S(-21, -53), S(-34, -27), S(-82, -76) }; inline const std::array BISHOP_PSQT = { - S(-24, 259), S(-45, 241), S(-275, 265), S(-159, 279), S(-112, 282), S(-272, 306), S(-28, 286), S(22, 259), // - S(142, 151), S(126, 228), S(146, 208), S(126, 212), S(102, 227), S(137, 219), S(117, 210), S(81, 210), // - S(169, 207), S(215, 197), S(292, 206), S(222, 205), S(198, 208), S(175, 220), S(231, 192), S(131, 209), // - S(186, 162), S(198, 195), S(234, 198), S(232, 223), S(238, 223), S(178, 221), S(166, 197), S(120, 203), // - S(187, 135), S(194, 173), S(202, 190), S(201, 214), S(193, 231), S(154, 221), S(139, 184), S(134, 141), // - S(201, 144), S(246, 163), S(247, 174), S(190, 218), S(171, 223), S(171, 219), S(202, 177), S(165, 148), // - S(187, 109), S(236, 132), S(206, 145), S(179, 174), S(170, 159), S(171, 148), S(153, 163), S(172, 100), // - S(182, 129), S(169, 173), S(174, 179), S(184, 146), S(193, 135), S(189, 176), S(181, 150), S(164, 148), // + S(-167, 65), S(-188, 47), S(-418, 71), S(-302, 85), S(-255, 88), S(-415, 112), S(-171, 92), S(-121, 65), + S(-1, -43), S(-17, 34), S(3, 14), S(-17, 18), S(-41, 33), S(-6, 25), S(-26, 16), S(-62, 16), + S(26, 13), S(72, 3), S(149, 12), S(79, 11), S(55, 14), S(32, 26), S(88, -2), S(-12, 15), + S(43, -32), S(55, 1), S(91, 4), S(89, 29), S(95, 29), S(35, 27), S(23, 3), S(-23, 9), + S(44, -59), S(51, -21), S(59, -4), S(58, 20), S(50, 37), S(11, 27), S(-4, -10), S(-9, -53), + S(58, -50), S(103, -31), S(104, -20), S(47, 24), S(28, 29), S(28, 25), S(59, -17), S(22, -46), + S(44, -85), S(93, -62), S(63, -49), S(36, -20), S(27, -35), S(28, -46), S(10, -31), S(29, -94), + S(39, -65), S(26, -21), S(31, -15), S(41, -48), S(50, -59), S(46, -18), S(38, -44), S(21, -46) }; inline const std::array ROOK_PSQT = { - S(556, 471), S(619, 470), S(550, 499), S(552, 493), S(561, 481), S(510, 494), S(518, 497), S(526, 502), // - S(473, 524), S(557, 500), S(629, 479), S(559, 519), S(574, 507), S(523, 516), S(464, 534), S(455, 541), // - S(461, 502), S(606, 463), S(636, 456), S(636, 452), S(592, 461), S(520, 502), S(535, 490), S(420, 539), // - S(431, 497), S(505, 490), S(537, 481), S(560, 445), S(528, 468), S(468, 516), S(450, 516), S(381, 523), // - S(368, 448), S(447, 453), S(432, 470), S(416, 470), S(411, 467), S(393, 507), S(365, 504), S(346, 493), // - S(345, 427), S(418, 401), S(412, 430), S(392, 432), S(409, 413), S(360, 467), S(359, 449), S(337, 447), // - S(285, 438), S(383, 378), S(406, 392), S(410, 395), S(403, 400), S(385, 416), S(365, 394), S(335, 410), // - S(318, 441), S(349, 445), S(401, 410), S(425, 394), S(412, 408), S(399, 420), S(384, 412), S(368, 429), // + S(101, 7), S(164, 6), S(95, 35), S(97, 29), S(106, 17), S(55, 30), S(63, 33), S(71, 38), + S(18, 60), S(102, 36), S(174, 15), S(104, 55), S(119, 43), S(68, 52), S(9, 70), S(0, 77), + S(6, 38), S(151, -1), S(181, -8), S(181, -12), S(137, -3), S(65, 38), S(80, 26), S(-35, 75), + S(-24, 33), S(50, 26), S(82, 17), S(105, -19), S(73, 4), S(13, 52), S(-5, 52), S(-74, 59), + S(-87, -16), S(-8, -11), S(-23, 6), S(-39, 6), S(-44, 3), S(-62, 43), S(-90, 40), S(-109, 29), + S(-110, -37), S(-37, -63), S(-43, -34), S(-63, -32), S(-46, -51), S(-95, 3), S(-96, -15), S(-118, -17), + S(-170, -26), S(-72, -86), S(-49, -72), S(-45, -69), S(-52, -64), S(-70, -48), S(-90, -70), S(-120, -54), + S(-137, -23), S(-106, -19), S(-54, -54), S(-30, -70), S(-43, -56), S(-56, -44), S(-71, -52), S(-87, -35) }; inline const std::array QUEEN_PSQT = { - S(831, 784), S(881, 742), S(890, 752), S(787, 868), S(830, 821), S(775, 844), S(820, 779), S(748, 811), // - S(774, 890), S(708, 977), S(719, 1008), S(647, 1024), S(658, 996), S(633, 1014), S(670, 938), S(708, 872), // - S(736, 918), S(820, 921), S(770, 986), S(750, 1001), S(706, 991), S(666, 1007), S(733, 907), S(701, 876), // - S(792, 849), S(793, 922), S(759, 956), S(743, 1031), S(718, 1023), S(710, 956), S(743, 874), S(735, 831), // - S(751, 888), S(782, 856), S(755, 929), S(709, 1002), S(696, 993), S(704, 946), S(718, 863), S(721, 814), // - S(757, 742), S(780, 786), S(777, 862), S(722, 897), S(733, 858), S(735, 865), S(749, 789), S(729, 787), // - S(739, 647), S(769, 554), S(757, 688), S(771, 767), S(745, 786), S(762, 716), S(737, 777), S(724, 767), // - S(685, 730), S(743, 489), S(738, 497), S(763, 597), S(765, 682), S(768, 638), S(757, 672), S(707, 742), // + S(86, -57), S(136, -99), S(145, -89), S(42, 27), S(85, -20), S(30, 3), S(75, -62), S(3, -30), + S(29, 49), S(-37, 136), S(-26, 167), S(-98, 183), S(-87, 155), S(-112, 173), S(-75, 97), S(-37, 31), + S(-9, 77), S(75, 80), S(25, 145), S(5, 160), S(-39, 150), S(-79, 166), S(-12, 66), S(-44, 35), + S(47, 8), S(48, 81), S(14, 115), S(-2, 190), S(-27, 182), S(-35, 115), S(-2, 33), S(-10, -10), + S(6, 47), S(37, 15), S(10, 88), S(-36, 161), S(-49, 152), S(-41, 105), S(-27, 22), S(-24, -27), + S(12, -99), S(35, -55), S(32, 21), S(-23, 56), S(-12, 17), S(-10, 24), S(4, -52), S(-16, -54), + S(-6, -194), S(24, -287), S(12, -153), S(26, -74), S(0, -55), S(17, -125), S(-8, -64), S(-21, -74), + S(-60, -111), S(-2, -352), S(-7, -344), S(18, -244), S(20, -159), S(23, -203), S(12, -169), S(-38, -99) }; inline const std::array KING_PSQT = { - S(-233, -319), S(30, 3), S(-94, 44), S(-169, 61), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // - S(124, -85), S(-15, 152), S(3, 138), S(118, 67), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // - S(-62, 69), S(50, 145), S(88, 114), S(74, 69), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // - S(-275, 92), S(12, 106), S(4, 105), S(-47, 89), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // - S(-246, 48), S(-66, 80), S(-47, 85), S(-128, 120), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // - S(-151, 17), S(35, 23), S(-53, 74), S(-94, 99), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // - S(59, -70), S(110, -28), S(24, 17), S(-54, 59), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // - S(-60, -101), S(66, -87), S(-31, -53), S(-48, -52), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-216, -336), S(47, -14), S(-77, 27), S(-152, 44), S(17, -17), S(17, -17), S(17, -17), S(17, -17), + S(141, -102), S(2, 135), S(20, 121), S(135, 50), S(17, -17), S(17, -17), S(17, -17), S(17, -17), + S(-45, 52), S(67, 128), S(105, 97), S(91, 52), S(17, -17), S(17, -17), S(17, -17), S(17, -17), + S(-258, 75), S(29, 89), S(21, 88), S(-30, 72), S(17, -17), S(17, -17), S(17, -17), S(17, -17), + S(-229, 31), S(-49, 63), S(-30, 68), S(-111, 103), S(17, -17), S(17, -17), S(17, -17), S(17, -17), + S(-134, 0), S(52, 6), S(-36, 57), S(-77, 82), S(17, -17), S(17, -17), S(17, -17), S(17, -17), + S(76, -87), S(127, -45), S(41, 0), S(-37, 42), S(17, -17), S(17, -17), S(17, -17), S(17, -17), + S(-43, -118), S(83, -104), S(-14, -70), S(-31, -69), S(17, -17), S(17, -17), S(17, -17), S(17, -17) }; // Epoch duration: 61.8411s // clang-format on From d409b964e06e9eb70282d907e676beb246006fad Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Fri, 21 Nov 2025 15:36:13 +0100 Subject: [PATCH 13/31] Weird stuff Bench: 12044152 --- src/eval_constants.hpp | 144 +++++++++++++++++++---------------------- 1 file changed, 68 insertions(+), 76 deletions(-) diff --git a/src/eval_constants.hpp b/src/eval_constants.hpp index a7394047..482180df 100644 --- a/src/eval_constants.hpp +++ b/src/eval_constants.hpp @@ -5,11 +5,11 @@ namespace Clockwork { // clang-format off -inline const PParam PAWN_MAT = S(298, 321); -inline const PParam KNIGHT_MAT = S(1089, 895); -inline const PParam BISHOP_MAT = S(1206, 950); -inline const PParam ROOK_MAT = S(1640, 1672); -inline const PParam QUEEN_MAT = S(3372, 2936); +inline const PParam PAWN_MAT = S(179, 165); +inline const PParam KNIGHT_MAT = S(728, 532); +inline const PParam BISHOP_MAT = S(786, 535); +inline const PParam ROOK_MAT = S(749, 704); +inline const PParam QUEEN_MAT = S(1631, 1200); inline const PParam TEMPO_VAL = S(58, 15); inline const PParam BISHOP_PAIR_VAL = S(77, 174); @@ -28,62 +28,55 @@ inline const PParam PAWN_PUSH_THREAT_ROOK = S(33, 33); inline const PParam PAWN_PUSH_THREAT_QUEEN = S(52, -45); inline const std::array PAWN_PHALANX = { - S(20, 19), S(62, 31), S(72, 69), S(185, 139), S(520, 249), S(641, 696) + S(20, 19), S(62, 31), S(72, 69), S(185, 139), S(520, 249), S(641, 696), }; inline const std::array DEFENDED_PAWN = { - S(63, 42), S(59, 31), S(65, 57), S(143, 117), S(607, -50) + S(63, 42), S(59, 31), S(65, 57), S(143, 117), S(607, -50), }; inline const std::array PASSED_PAWN = { - S(-68, -93), S(-57, -75), S(-31, -2), S(22, 75), S(50, 207), S(281, 301) + S(-68, -93), S(-57, -75), S(-31, -2), S(22, 75), S(50, 207), S(281, 301), }; inline const std::array DEFENDED_PASSED_PUSH = { - S(49, -42), S(36, -4), S(20, 29), S(21, 78), S(83, 157), S(163, 282) + S(49, -42), S(36, -4), S(20, 29), S(21, 78), S(83, 157), S(163, 282), }; inline const std::array BLOCKED_PASSED_PAWN = { - S(15, -45), S(4, 2), S(0, -26), S(6, -45), S(0, -90), S(-190, -140) + S(15, -45), S(4, 2), S(0, -26), S(6, -45), S(0, -90), S(-190, -140), }; inline const std::array FRIENDLY_KING_PASSED_PAWN_DISTANCE = { - S(0, 0), S(12, 98), S(-20, 86), S(-13, 36), S(0, 7), S(9, 12), S(39, 9), S(18, -3) + S(0, 0), S(12, 98), S(-20, 86), S(-13, 36), S(0, 7), S(9, 12), S(39, 9), S(18, -3), }; inline const std::array ENEMY_KING_PASSED_PAWN_DISTANCE = { - S(0, 0), S(-183, -53), S(27, -6), S(-12, 40), S(10, 69), S(15, 94), S(35, 93), S(-12, 113) + S(0, 0), S(-183, -53), S(27, -6), S(-12, 40), S(10, 69), S(15, 94), S(35, 93), S(-12, 113), }; inline const std::array KNIGHT_MOBILITY = { - S(-223, -225), S(-120, -64), S(-65, -11), S(-23, 23), S(23, 38), S(48, 76), S(84, 72), S(117, 75), - S(162, 17) + S(16, -9), S(119, 152), S(174, 205), S(216, 239), S(262, 254), S(287, 292), S(323, 288), S(356, 291), S(401, 233), }; inline const std::array BISHOP_MOBILITY = { - S(-251, -286), S(-179, -105), S(-106, -44), S(-72, 6), S(-42, 37), S(-25, 57), S(-6, 71), S(12, 77), - S(31, 80), S(45, 76), S(69, 63), S(133, 17), S(162, -3), S(226, -40) + S(26, -65), S(98, 116), S(171, 177), S(205, 227), S(235, 258), S(252, 278), S(271, 292), S(289, 298), S(308, 301), S(322, 297), S(346, 284), S(410, 238), S(439, 218), S(503, 181), }; inline const std::array ROOK_MOBILITY = { - S(-279, -287), S(-149, -94), S(-98, -30), S(-66, -20), S(-40, 4), S(-26, 26), S(-8, 38), S(10, 44), - S(26, 56), S(44, 65), S(62, 67), S(74, 69), S(95, 71), S(107, 57), S(249, -68) + S(157, 217), S(287, 410), S(338, 474), S(370, 484), S(396, 508), S(410, 530), S(428, 542), S(446, 548), S(462, 560), S(480, 569), S(498, 571), S(510, 573), S(531, 575), S(543, 561), S(685, 436), }; inline const std::array QUEEN_MOBILITY = { - S(-996, -893), S(-275, -534), S(-179, -405), S(-120, -242), S(-108, -58), S(-68, 34), S(-63, 137), S(-38, 145), - S(-31, 195), S(-19, 218), S(-9, 239), S(-3, 250), S(17, 238), S(29, 245), S(38, 234), S(54, 224), - S(64, 208), S(69, 205), S(101, 150), S(131, 101), S(159, 65), S(227, -31), S(226, -30), S(330, -166), - S(286, -159), S(247, -162), S(10, -98), S(-91, -115) + S(0, 2), S(721, 361), S(817, 490), S(876, 653), S(888, 837), S(928, 929), S(933, 1032), S(958, 1040), S(965, 1090), S(977, 1113), S(987, 1134), S(993, 1145), S(1013, 1133), S(1025, 1140), S(1034, 1129), S(1050, 1119), S(1060, 1103), S(1065, 1100), S(1097, 1045), S(1127, 996), S(1155, 960), S(1223, 864), S(1222, 865), S(1326, 729), S(1282, 736), S(1243, 733), S(1006, 797), S(905, 780), }; inline const std::array KING_MOBILITY = { - S(286, 141), S(91, -117), S(-10, -31), S(-24, 5), S(-50, 7), S(-83, 13), S(-63, 14), S(-72, 8), - S(-73, -39) + S(335, 145), S(140, -113), S(39, -27), S(25, 9), S(-1, 11), S(-34, 17), S(-14, 18), S(-23, 12), S(-24, -35), }; inline const std::array KNIGHT_KING_RING = { - S(0, 0), S(85, -28), S(154, -73) + S(0, 0), S(85, -28), S(154, -73), }; inline const std::array BISHOP_KING_RING = { - S(0, 0), S(35, -4), S(134, -40) + S(0, 0), S(35, -4), S(134, -40), }; inline const std::array ROOK_KING_RING = { - S(0, 0), S(66, -45), S(48, -58), S(96, -55), S(143, -115) + S(0, 0), S(66, -45), S(48, -58), S(96, -55), S(143, -115), }; inline const std::array QUEEN_KING_RING = { - S(0, 0), S(-52, 64), S(-84, 99), S(-48, 78), S(81, 24), S(211, -61) + S(0, 0), S(-52, 64), S(-84, 99), S(-48, 78), S(81, 24), S(211, -61), }; inline const PParam PAWN_THREAT_KNIGHT = S(234, 57); @@ -100,67 +93,66 @@ inline const PParam BISHOP_THREAT_ROOK = S(237, 55); inline const PParam BISHOP_THREAT_QUEEN = S(192, 35); inline const std::array BISHOP_PAWNS = { - S(1, -6), S(-1, 0), S(0, -10), S(-5, -21), S(-11, -26), S(-16, -32), S(-17, -39), S(-23, -37), - S(-33, -42) + S(1, -6), S(-1, 0), S(0, -10), S(-5, -21), S(-11, -26), S(-16, -32), S(-17, -39), S(-23, -37), S(-33, -42), }; inline const std::array PAWN_PSQT = { - S(115, 160), S(110, 205), S(168, 175), S(231, 59), S(181, 52), S(164, 117), S(63, 139), S(120, 116), - S(60, 37), S(169, 65), S(147, 9), S(146, -46), S(101, -63), S(48, -16), S(10, 30), S(-39, 32), - S(-18, 0), S(1, 4), S(18, -40), S(6, -53), S(-11, -58), S(-52, -51), S(-92, -4), S(-118, 17), - S(-43, -47), S(-26, -20), S(-31, -53), S(-49, -49), S(-73, -58), S(-93, -50), S(-142, -1), S(-160, -13), - S(-45, -77), S(16, -74), S(-34, -31), S(-64, -29), S(-84, -38), S(-121, -39), S(-139, -25), S(-162, -31), - S(-35, -69), S(91, -64), S(46, -29), S(-10, -11), S(-47, -24), S(-81, -28), S(-106, -4), S(-141, -17) + S(234, 316), S(229, 361), S(287, 331), S(350, 215), S(300, 208), S(283, 273), S(182, 295), S(239, 272), // + S(179, 193), S(288, 221), S(266, 165), S(265, 110), S(220, 93), S(167, 140), S(129, 186), S(80, 188), // + S(101, 156), S(120, 160), S(137, 116), S(125, 103), S(108, 98), S(67, 105), S(27, 152), S(1, 173), // + S(76, 109), S(93, 136), S(88, 103), S(70, 107), S(46, 98), S(26, 106), S(-23, 155), S(-41, 143), // + S(74, 79), S(135, 82), S(85, 125), S(55, 127), S(35, 118), S(-2, 117), S(-20, 131), S(-43, 125), // + S(84, 87), S(210, 92), S(165, 127), S(109, 145), S(72, 132), S(38, 128), S(13, 152), S(-22, 139), // }; inline const std::array KNIGHT_PSQT = { - S(-377, -151), S(-317, 48), S(-393, 198), S(-118, 66), S(-237, 87), S(-313, 91), S(-522, 68), S(-511, -22), - S(1, 1), S(66, 12), S(165, -54), S(115, 7), S(116, 15), S(56, -8), S(-4, 12), S(-23, -31), - S(58, -26), S(103, 13), S(189, 9), S(142, 32), S(141, 22), S(62, 31), S(51, 5), S(-42, 13), - S(112, 7), S(103, 27), S(132, 34), S(112, 60), S(119, 48), S(87, 42), S(62, 0), S(38, 7), - S(101, -12), S(125, -15), S(120, 9), S(92, 27), S(84, 36), S(81, 30), S(55, 3), S(41, -52), - S(15, -21), S(45, -34), S(39, -14), S(50, 30), S(57, 27), S(1, 5), S(7, -34), S(-33, -39), - S(17, -6), S(38, -35), S(22, -29), S(22, -10), S(9, -16), S(-18, -36), S(-7, -49), S(-63, -115), - S(-31, -53), S(7, -13), S(24, -38), S(32, -31), S(25, -23), S(-21, -53), S(-34, -27), S(-82, -76) + S(-255, -4), S(-195, 195), S(-271, 345), S(4, 213), S(-115, 234), S(-191, 238), S(-400, 215), S(-389, 125), // + S(123, 148), S(188, 159), S(287, 93), S(237, 154), S(238, 162), S(178, 139), S(118, 159), S(99, 116), // + S(180, 121), S(225, 160), S(311, 156), S(264, 179), S(263, 169), S(184, 178), S(173, 152), S(80, 160), // + S(234, 154), S(225, 174), S(254, 181), S(234, 207), S(241, 195), S(209, 189), S(184, 147), S(160, 154), // + S(223, 135), S(247, 132), S(242, 156), S(214, 174), S(206, 183), S(203, 177), S(177, 150), S(163, 95), // + S(137, 126), S(167, 113), S(161, 133), S(172, 177), S(179, 174), S(123, 152), S(129, 113), S(89, 108), // + S(139, 141), S(160, 112), S(144, 118), S(144, 137), S(131, 131), S(104, 111), S(115, 98), S(59, 32), // + S(91, 94), S(129, 134), S(146, 109), S(154, 116), S(147, 124), S(101, 94), S(88, 120), S(40, 71), // }; inline const std::array BISHOP_PSQT = { - S(-167, 65), S(-188, 47), S(-418, 71), S(-302, 85), S(-255, 88), S(-415, 112), S(-171, 92), S(-121, 65), - S(-1, -43), S(-17, 34), S(3, 14), S(-17, 18), S(-41, 33), S(-6, 25), S(-26, 16), S(-62, 16), - S(26, 13), S(72, 3), S(149, 12), S(79, 11), S(55, 14), S(32, 26), S(88, -2), S(-12, 15), - S(43, -32), S(55, 1), S(91, 4), S(89, 29), S(95, 29), S(35, 27), S(23, 3), S(-23, 9), - S(44, -59), S(51, -21), S(59, -4), S(58, 20), S(50, 37), S(11, 27), S(-4, -10), S(-9, -53), - S(58, -50), S(103, -31), S(104, -20), S(47, 24), S(28, 29), S(28, 25), S(59, -17), S(22, -46), - S(44, -85), S(93, -62), S(63, -49), S(36, -20), S(27, -35), S(28, -46), S(10, -31), S(29, -94), - S(39, -65), S(26, -21), S(31, -15), S(41, -48), S(50, -59), S(46, -18), S(38, -44), S(21, -46) + S(-24, 259), S(-45, 241), S(-275, 265), S(-159, 279), S(-112, 282), S(-272, 306), S(-28, 286), S(22, 259), // + S(142, 151), S(126, 228), S(146, 208), S(126, 212), S(102, 227), S(137, 219), S(117, 210), S(81, 210), // + S(169, 207), S(215, 197), S(292, 206), S(222, 205), S(198, 208), S(175, 220), S(231, 192), S(131, 209), // + S(186, 162), S(198, 195), S(234, 198), S(232, 223), S(238, 223), S(178, 221), S(166, 197), S(120, 203), // + S(187, 135), S(194, 173), S(202, 190), S(201, 214), S(193, 231), S(154, 221), S(139, 184), S(134, 141), // + S(201, 144), S(246, 163), S(247, 174), S(190, 218), S(171, 223), S(171, 219), S(202, 177), S(165, 148), // + S(187, 109), S(236, 132), S(206, 145), S(179, 174), S(170, 159), S(171, 148), S(153, 163), S(172, 100), // + S(182, 129), S(169, 173), S(174, 179), S(184, 146), S(193, 135), S(189, 176), S(181, 150), S(164, 148), // }; inline const std::array ROOK_PSQT = { - S(101, 7), S(164, 6), S(95, 35), S(97, 29), S(106, 17), S(55, 30), S(63, 33), S(71, 38), - S(18, 60), S(102, 36), S(174, 15), S(104, 55), S(119, 43), S(68, 52), S(9, 70), S(0, 77), - S(6, 38), S(151, -1), S(181, -8), S(181, -12), S(137, -3), S(65, 38), S(80, 26), S(-35, 75), - S(-24, 33), S(50, 26), S(82, 17), S(105, -19), S(73, 4), S(13, 52), S(-5, 52), S(-74, 59), - S(-87, -16), S(-8, -11), S(-23, 6), S(-39, 6), S(-44, 3), S(-62, 43), S(-90, 40), S(-109, 29), - S(-110, -37), S(-37, -63), S(-43, -34), S(-63, -32), S(-46, -51), S(-95, 3), S(-96, -15), S(-118, -17), - S(-170, -26), S(-72, -86), S(-49, -72), S(-45, -69), S(-52, -64), S(-70, -48), S(-90, -70), S(-120, -54), - S(-137, -23), S(-106, -19), S(-54, -54), S(-30, -70), S(-43, -56), S(-56, -44), S(-71, -52), S(-87, -35) + S(556, 471), S(619, 470), S(550, 499), S(552, 493), S(561, 481), S(510, 494), S(518, 497), S(526, 502), // + S(473, 524), S(557, 500), S(629, 479), S(559, 519), S(574, 507), S(523, 516), S(464, 534), S(455, 541), // + S(461, 502), S(606, 463), S(636, 456), S(636, 452), S(592, 461), S(520, 502), S(535, 490), S(420, 539), // + S(431, 497), S(505, 490), S(537, 481), S(560, 445), S(528, 468), S(468, 516), S(450, 516), S(381, 523), // + S(368, 448), S(447, 453), S(432, 470), S(416, 470), S(411, 467), S(393, 507), S(365, 504), S(346, 493), // + S(345, 427), S(418, 401), S(412, 430), S(392, 432), S(409, 413), S(360, 467), S(359, 449), S(337, 447), // + S(285, 438), S(383, 378), S(406, 392), S(410, 395), S(403, 400), S(385, 416), S(365, 394), S(335, 410), // + S(318, 441), S(349, 445), S(401, 410), S(425, 394), S(412, 408), S(399, 420), S(384, 412), S(368, 429), // }; inline const std::array QUEEN_PSQT = { - S(86, -57), S(136, -99), S(145, -89), S(42, 27), S(85, -20), S(30, 3), S(75, -62), S(3, -30), - S(29, 49), S(-37, 136), S(-26, 167), S(-98, 183), S(-87, 155), S(-112, 173), S(-75, 97), S(-37, 31), - S(-9, 77), S(75, 80), S(25, 145), S(5, 160), S(-39, 150), S(-79, 166), S(-12, 66), S(-44, 35), - S(47, 8), S(48, 81), S(14, 115), S(-2, 190), S(-27, 182), S(-35, 115), S(-2, 33), S(-10, -10), - S(6, 47), S(37, 15), S(10, 88), S(-36, 161), S(-49, 152), S(-41, 105), S(-27, 22), S(-24, -27), - S(12, -99), S(35, -55), S(32, 21), S(-23, 56), S(-12, 17), S(-10, 24), S(4, -52), S(-16, -54), - S(-6, -194), S(24, -287), S(12, -153), S(26, -74), S(0, -55), S(17, -125), S(-8, -64), S(-21, -74), - S(-60, -111), S(-2, -352), S(-7, -344), S(18, -244), S(20, -159), S(23, -203), S(12, -169), S(-38, -99) + S(831, 784), S(881, 742), S(890, 752), S(787, 868), S(830, 821), S(775, 844), S(820, 779), S(748, 811), // + S(774, 890), S(708, 977), S(719, 1008), S(647, 1024), S(658, 996), S(633, 1014), S(670, 938), S(708, 872), // + S(736, 918), S(820, 921), S(770, 986), S(750, 1001), S(706, 991), S(666, 1007), S(733, 907), S(701, 876), // + S(792, 849), S(793, 922), S(759, 956), S(743, 1031), S(718, 1023), S(710, 956), S(743, 874), S(735, 831), // + S(751, 888), S(782, 856), S(755, 929), S(709, 1002), S(696, 993), S(704, 946), S(718, 863), S(721, 814), // + S(757, 742), S(780, 786), S(777, 862), S(722, 897), S(733, 858), S(735, 865), S(749, 789), S(729, 787), // + S(739, 647), S(769, 554), S(757, 688), S(771, 767), S(745, 786), S(762, 716), S(737, 777), S(724, 767), // + S(685, 730), S(743, 489), S(738, 497), S(763, 597), S(765, 682), S(768, 638), S(757, 672), S(707, 742), // }; inline const std::array KING_PSQT = { - S(-216, -336), S(47, -14), S(-77, 27), S(-152, 44), S(17, -17), S(17, -17), S(17, -17), S(17, -17), - S(141, -102), S(2, 135), S(20, 121), S(135, 50), S(17, -17), S(17, -17), S(17, -17), S(17, -17), - S(-45, 52), S(67, 128), S(105, 97), S(91, 52), S(17, -17), S(17, -17), S(17, -17), S(17, -17), - S(-258, 75), S(29, 89), S(21, 88), S(-30, 72), S(17, -17), S(17, -17), S(17, -17), S(17, -17), - S(-229, 31), S(-49, 63), S(-30, 68), S(-111, 103), S(17, -17), S(17, -17), S(17, -17), S(17, -17), - S(-134, 0), S(52, 6), S(-36, 57), S(-77, 82), S(17, -17), S(17, -17), S(17, -17), S(17, -17), - S(76, -87), S(127, -45), S(41, 0), S(-37, 42), S(17, -17), S(17, -17), S(17, -17), S(17, -17), - S(-43, -118), S(83, -104), S(-14, -70), S(-31, -69), S(17, -17), S(17, -17), S(17, -17), S(17, -17) + S(-233, -319), S(30, 3), S(-94, 44), S(-169, 61), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(124, -85), S(-15, 152), S(3, 138), S(118, 67), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-62, 69), S(50, 145), S(88, 114), S(74, 69), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-275, 92), S(12, 106), S(4, 105), S(-47, 89), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-246, 48), S(-66, 80), S(-47, 85), S(-128, 120), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-151, 17), S(35, 23), S(-53, 74), S(-94, 99), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(59, -70), S(110, -28), S(24, 17), S(-54, 59), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // + S(-60, -101), S(66, -87), S(-31, -53), S(-48, -52), S(0, 0), S(0, 0), S(0, 0), S(0, 0), // }; // Epoch duration: 61.8411s // clang-format on From 144fe3e0c7ad5410342b43ad47fabc41e338a182 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 12:36:32 +0100 Subject: [PATCH 14/31] SoA + pointer for backward loop --- src/tuning/arena.hpp | 199 ++++++++++++++++++++---- src/tuning/graph.cpp | 350 +++++++++++++++++++++++++------------------ src/tuning/graph.hpp | 71 +++++---- src/tuning/value.cpp | 20 ++- 4 files changed, 422 insertions(+), 218 deletions(-) diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index f2f5266c..89d5ab5a 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -1,4 +1,5 @@ #pragma once +#include "util/vec/sse2.hpp" #include "util/types.hpp" #include @@ -6,55 +7,199 @@ namespace Clockwork::Autograd { -/// ARENA IMPLEMENTATION \\\ -// Simple vector-based arena for storing values and pairs. Surely can be done better. Kek. -template -class Arena { +class ValueArena { +public: + ValueArena() = default; + + void reserve(usize n) { + values.reserve(n); + gradients.reserve(n); + } + + u32 alloc(f64 value, f64 grad = 0.0) { + u32 idx = static_cast(values.size()); + values.push_back(value); + gradients.push_back(grad); + return idx; + } + + u32 alloc_uninitialized() { + u32 idx = static_cast(values.size()); + values.push_back(0.0); + gradients.push_back(0.0); + return idx; + } + + // Mutating accessors + inline f64& val(u32 i) { + assert(i < values.size()); + return values[i]; + } + inline f64& grad(u32 i) { + assert(i < gradients.size()); + return gradients[i]; + } + + // Const accessors + inline const f64& val(u32 i) const { + assert(i < values.size()); + return values[i]; + } + inline const f64& grad(u32 i) const { + assert(i < gradients.size()); + return gradients[i]; + } + + inline usize size() const { + return values.size(); + } + + void clear() { + values.clear(); + gradients.clear(); + } + + void reset_to(usize n) { + if (n < values.size()) { + values.resize(n); + gradients.resize(n); + } + } + + inline f64* values_data() { + return values.data(); + } + inline f64* gradients_data() { + return gradients.data(); + } + inline const f64* values_data() const { + return values.data(); + } + inline const f64* gradients_data() const { + return gradients.data(); + } + private: - std::vector m_data; + std::vector values; + std::vector gradients; +}; + +class PairArena { public: - // Allocates a new slot and returns its index - u32 alloc(const T& initial_value) { - u32 idx = static_cast(m_data.size()); - m_data.push_back(initial_value); + PairArena() = default; + + void reserve(usize n) { + p0.reserve(n); + p1.reserve(n); + g0.reserve(n); + g1.reserve(n); + } + + u32 alloc(f64x2 v, f64x2 g = f64x2::zero()) { + u32 idx = static_cast(p0.size()); + p0.push_back(v.first()); + p1.push_back(v.second()); + g0.push_back(g.first()); + g1.push_back(g.second()); return idx; } - // Emplace version we might want later for ops that return many values? - // Might be seeing things. - template - u32 emplace(Args&&... args) { - u32 idx = static_cast(m_data.size()); - m_data.emplace_back(std::forward(args)...); + u32 alloc_uninitialized() { + u32 idx = static_cast(p0.size()); + p0.push_back(0.0); + p1.push_back(0.0); + g0.push_back(0.0); + g1.push_back(0.0); return idx; } - // Accessors - inline T& operator[](u32 index) { - assert(index < m_data.size()); - return m_data[index]; + // Mutating accessors + inline f64& p0_mut(u32 i) { + assert(i < p0.size()); + return p0[i]; + } + inline f64& p1_mut(u32 i) { + assert(i < p1.size()); + return p1[i]; + } + inline f64& g0_mut(u32 i) { + assert(i < g0.size()); + return g0[i]; + } + inline f64& g1_mut(u32 i) { + assert(i < g1.size()); + return g1[i]; } - inline const T& operator[](u32 index) const { - assert(index < m_data.size()); - return m_data[index]; + // Const accessors + inline const f64& p0_ref(u32 i) const { + assert(i < p0.size()); + return p0[i]; + } + inline const f64& p1_ref(u32 i) const { + assert(i < p1.size()); + return p1[i]; + } + inline const f64& g0_ref(u32 i) const { + assert(i < g0.size()); + return g0[i]; + } + inline const f64& g1_ref(u32 i) const { + assert(i < g1.size()); + return g1[i]; } inline usize size() const { - return m_data.size(); + return p0.size(); } - // Common std::vector W void clear() { - m_data.clear(); + p0.clear(); + p1.clear(); + g0.clear(); + g1.clear(); } void reset_to(usize n) { - if (n < m_data.size()) { - m_data.resize(n); + if (n < p0.size()) { + p0.resize(n); + p1.resize(n); + g0.resize(n); + g1.resize(n); } } + + inline f64* p0_data() { + return p0.data(); + } + inline f64* p1_data() { + return p1.data(); + } + inline f64* g0_data() { + return g0.data(); + } + inline f64* g1_data() { + return g1.data(); + } + inline const f64* p0_data() const { + return p0.data(); + } + inline const f64* p1_data() const { + return p1.data(); + } + inline const f64* g0_data() const { + return g0.data(); + } + inline const f64* g1_data() const { + return g1.data(); + } + +private: + std::vector p0; + std::vector p1; + std::vector g0; + std::vector g1; }; } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 212ee38c..8b95cbb9 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -12,11 +12,16 @@ Graph::Graph() { m_global_param_count = params.size(); m_global_pair_count = pair_params.size(); + // reserve some headroom + m_values.reserve(m_global_param_count + 1024); + m_pairs.reserve(m_global_pair_count + 1024); + m_tape.reserve(16384 * 2 * 16); + for (auto* p : params) { - m_values.alloc({p->default_value(), 0.0}); + m_values.alloc(p->default_value(), 0.0); } for (auto* p : pair_params) { - m_pairs.alloc({p->default_value(), f64x2::zero()}); + m_pairs.alloc(p->default_value(), f64x2::zero()); } } @@ -26,19 +31,19 @@ Graph& Graph::get() { } ValueHandle Graph::create_value(f64 data) { - return ValueHandle(m_values.alloc({data, 0.0})); + return ValueHandle(m_values.alloc(data, 0.0)); } PairHandle Graph::create_pair(f64x2 data) { - return PairHandle(m_pairs.alloc({data, f64x2::zero()})); + return PairHandle(m_pairs.alloc(data, f64x2::zero())); } -// ------------------ Recording ------------------ +// Recording ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { - u32 out = m_values.alloc({0.0, 0.0}); - f64 l = m_values[lhs.index].value; - f64 r = m_values[rhs.index].value; + u32 out = m_values.alloc_uninitialized(); + f64 l = m_values.val(lhs.index); + f64 r = m_values.val(rhs.index); f64 res = 0.0; switch (op) { @@ -60,14 +65,14 @@ ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { default: break; } - m_values[out].value = res; + m_values.val(out) = res; m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); return ValueHandle(out); } ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { - u32 out = m_values.alloc({0.0, 0.0}); - f64 l = m_values[input.index].value; + u32 out = m_values.alloc_uninitialized(); + f64 l = m_values.val(input.index); f64 res = 0.0; switch (op) { @@ -107,89 +112,99 @@ ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { default: break; } - m_values[out].value = res; + m_values.val(out) = res; m_tape.push_back({op, out, input.index, 0, scalar}); return ValueHandle(out); } PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { - u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); - f64x2 l = m_pairs[lhs.index].values; - f64x2 r = m_pairs[rhs.index].values; + u32 out = m_pairs.alloc_uninitialized(); + f64 l0 = m_pairs.p0_ref(lhs.index); + f64 l1 = m_pairs.p1_ref(lhs.index); + f64 r0 = m_pairs.p0_ref(rhs.index); + f64 r1 = m_pairs.p1_ref(rhs.index); f64x2 res = f64x2::zero(); switch (op) { case OpType::PairAdd: - res = f64x2::add(l, r); + res = f64x2::add(f64x2::make(l0, l1), f64x2::make(r0, r1)); break; case OpType::PairSub: - res = f64x2::sub(l, r); + res = f64x2::sub(f64x2::make(l0, l1), f64x2::make(r0, r1)); break; default: break; } - m_pairs[out].values = res; + + m_pairs.p0_mut(out) = res.first(); + m_pairs.p1_mut(out) = res.second(); m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); return PairHandle(out); } PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { - u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); - f64x2 l = m_pairs[input.index].values; + u32 out = m_pairs.alloc_uninitialized(); + f64 l0 = m_pairs.p0_ref(input.index); + f64 l1 = m_pairs.p1_ref(input.index); f64x2 res = f64x2::zero(); switch (op) { case OpType::PairNeg: - res = f64x2::neg(l); + res = f64x2::neg(f64x2::make(l0, l1)); break; case OpType::PairMulScalar: - res = f64x2::mul_scalar(l, scalar); + res = f64x2::mul_scalar(f64x2::make(l0, l1), scalar); break; case OpType::PairDivScalar: - res = f64x2::div_scalar(l, scalar); + res = f64x2::div_scalar(f64x2::make(l0, l1), scalar); break; case OpType::ScalarDivPair: - res = f64x2::scalar_div(scalar, l); + res = f64x2::scalar_div(scalar, f64x2::make(l0, l1)); break; default: break; } - m_pairs[out].values = res; + + m_pairs.p0_mut(out) = res.first(); + m_pairs.p1_mut(out) = res.second(); m_tape.push_back({op, out, input.index, 0, scalar}); return PairHandle(out); } PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) { - u32 out = m_pairs.alloc({f64x2::zero(), f64x2::zero()}); - f64x2 p = m_pairs[pair.index].values; - f64 v = m_values[val.index].value; + u32 out = m_pairs.alloc_uninitialized(); + f64 p0 = m_pairs.p0_ref(pair.index); + f64 p1 = m_pairs.p1_ref(pair.index); + f64 v = m_values.val(val.index); f64x2 res = f64x2::zero(); switch (op) { case OpType::PairMulValue: case OpType::ValueMulPair: - res = f64x2::mul_scalar(p, v); + res = f64x2::mul_scalar(f64x2::make(p0, p1), v); break; case OpType::PairDivValue: - res = f64x2::div_scalar(p, v); + res = f64x2::div_scalar(f64x2::make(p0, p1), v); break; case OpType::ValueDivPair: - res = f64x2::scalar_div(v, p); + res = f64x2::scalar_div(v, f64x2::make(p0, p1)); break; default: break; } - m_pairs[out].values = res; + m_pairs.p0_mut(out) = res.first(); + m_pairs.p1_mut(out) = res.second(); m_tape.push_back({op, out, pair.index, val.index, 0.0}); return PairHandle(out); } ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { - u32 out = m_values.alloc({0.0, 0.0}); - f64x2 p = m_pairs[input.index].values; + u32 out = m_values.alloc_uninitialized(); + f64 p0 = m_pairs.p0_ref(input.index); + f64 p1 = m_pairs.p1_ref(input.index); - f64 val = alpha * p.first() + (1.0 - alpha) * p.second(); - m_values[out].value = val; + f64 val = alpha * p0 + (1.0 - alpha) * p1; + m_values.val(out) = val; m_tape.push_back({OpType::Phase, out, input.index, 0, alpha}); return ValueHandle(out); } @@ -199,204 +214,216 @@ void Graph::backward() { return; } - // Our backward model assumes the last operation produces a scalar loss value. - const auto& last_node = m_tape.back(); + const auto& last_node = m_tape.back(); + m_values.grad(last_node.output_idx) = 1.0; - // Initialize gradient of final output to 1.0 to start backprop - m_values[last_node.output_idx].gradient = 1.0; + // Raw pointers for hot loops + f64* vals = m_values.values_data(); + f64* grads = m_values.gradients_data(); + + f64* p0 = m_pairs.p0_data(); + f64* p1 = m_pairs.p1_data(); + f64* g0 = m_pairs.g0_data(); + f64* g1 = m_pairs.g1_data(); - // Rev iterate for (auto it = m_tape.rbegin(); it != m_tape.rend(); ++it) { const Node& node = *it; switch (node.type) { // Value Binary case OpType::Add: { - f64 grad = m_values[node.output_idx].gradient; - m_values[node.lhs_idx].gradient += grad; - m_values[node.rhs_idx].gradient += grad; + f64 grad = grads[node.output_idx]; + grads[node.lhs_idx] += grad; + grads[node.rhs_idx] += grad; break; } case OpType::Sub: { - f64 grad = m_values[node.output_idx].gradient; - m_values[node.lhs_idx].gradient += grad; - m_values[node.rhs_idx].gradient -= grad; + f64 grad = grads[node.output_idx]; + grads[node.lhs_idx] += grad; + grads[node.rhs_idx] -= grad; break; } case OpType::Mul: { - f64 grad = m_values[node.output_idx].gradient; - f64 l = m_values[node.lhs_idx].value; - f64 r = m_values[node.rhs_idx].value; - m_values[node.lhs_idx].gradient += r * grad; - m_values[node.rhs_idx].gradient += l * grad; + f64 grad = grads[node.output_idx]; + f64 l = vals[node.lhs_idx]; + f64 r = vals[node.rhs_idx]; + grads[node.lhs_idx] += r * grad; + grads[node.rhs_idx] += l * grad; break; } case OpType::Div: { - f64 grad = m_values[node.output_idx].gradient; - f64 l = m_values[node.lhs_idx].value; - f64 r = m_values[node.rhs_idx].value; - m_values[node.lhs_idx].gradient += (1.0 / r) * grad; - m_values[node.rhs_idx].gradient += (-l / (r * r)) * grad; + f64 grad = grads[node.output_idx]; + f64 l = vals[node.lhs_idx]; + f64 r = vals[node.rhs_idx]; + grads[node.lhs_idx] += (1.0 / r) * grad; + grads[node.rhs_idx] += (-l / (r * r)) * grad; break; } case OpType::Pow: { - f64 grad = m_values[node.output_idx].gradient; - f64 base = m_values[node.lhs_idx].value; - f64 exp = m_values[node.rhs_idx].value; - m_values[node.lhs_idx].gradient += exp * std::pow(base, exp - 1) * grad; - m_values[node.rhs_idx].gradient += std::pow(base, exp) * std::log(base) * grad; + f64 grad = grads[node.output_idx]; + f64 base = vals[node.lhs_idx]; + f64 exp = vals[node.rhs_idx]; + grads[node.lhs_idx] += exp * std::pow(base, exp - 1) * grad; + grads[node.rhs_idx] += std::pow(base, exp) * std::log(base) * grad; break; } // Value Unary case OpType::Exp: { - f64 grad = m_values[node.output_idx].gradient; - f64 val = m_values[node.output_idx].value; - m_values[node.lhs_idx].gradient += val * grad; + f64 grad = grads[node.output_idx]; + f64 val = vals[node.output_idx]; + grads[node.lhs_idx] += val * grad; break; } case OpType::Log: { - f64 grad = m_values[node.output_idx].gradient; - f64 l = m_values[node.lhs_idx].value; - m_values[node.lhs_idx].gradient += (1.0 / l) * grad; + f64 grad = grads[node.output_idx]; + f64 l = vals[node.lhs_idx]; + grads[node.lhs_idx] += (1.0 / l) * grad; break; } case OpType::Sigmoid: { - f64 grad = m_values[node.output_idx].gradient; - f64 s = m_values[node.output_idx].value; - m_values[node.lhs_idx].gradient += (s * (1.0 - s)) * grad; + f64 grad = grads[node.output_idx]; + f64 s = vals[node.output_idx]; + grads[node.lhs_idx] += (s * (1.0 - s)) * grad; break; } case OpType::Neg: { - m_values[node.lhs_idx].gradient -= m_values[node.output_idx].gradient; + grads[node.lhs_idx] -= grads[node.output_idx]; break; } case OpType::PowConst: { - f64 grad = m_values[node.output_idx].gradient; - f64 l = m_values[node.lhs_idx].value; + f64 grad = grads[node.output_idx]; + f64 l = vals[node.lhs_idx]; f64 exp = node.scalar_data; - m_values[node.lhs_idx].gradient += exp * std::pow(l, exp - 1.0) * grad; + grads[node.lhs_idx] += exp * std::pow(l, exp - 1.0) * grad; break; } case OpType::AddScalar: - case OpType::SubScalarVal: { - m_values[node.lhs_idx].gradient += m_values[node.output_idx].gradient; - break; - } + case OpType::SubScalarVal: case OpType::ValSubScalar: { - m_values[node.lhs_idx].gradient += m_values[node.output_idx].gradient; + grads[node.lhs_idx] += grads[node.output_idx]; break; } case OpType::MulScalar: { - m_values[node.lhs_idx].gradient += - node.scalar_data * m_values[node.output_idx].gradient; + grads[node.lhs_idx] += node.scalar_data * grads[node.output_idx]; break; } case OpType::ValDivScalar: { - m_values[node.lhs_idx].gradient += - (1.0 / node.scalar_data) * m_values[node.output_idx].gradient; + grads[node.lhs_idx] += (1.0 / node.scalar_data) * grads[node.output_idx]; break; } case OpType::DivScalarVal: { - f64 grad = m_values[node.output_idx].gradient; - f64 l = m_values[node.lhs_idx].value; - m_values[node.lhs_idx].gradient += (-node.scalar_data / (l * l)) * grad; + f64 grad = grads[node.output_idx]; + f64 l = vals[node.lhs_idx]; + grads[node.lhs_idx] += (-node.scalar_data / (l * l)) * grad; break; } // Pair Binary case OpType::PairAdd: { - f64x2 grad = m_pairs[node.output_idx].gradients; - m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad); - m_pairs[node.rhs_idx].gradients = f64x2::add(m_pairs[node.rhs_idx].gradients, grad); + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + g0[node.lhs_idx] += g0_out; + g1[node.lhs_idx] += g1_out; + g0[node.rhs_idx] += g0_out; + g1[node.rhs_idx] += g1_out; break; } case OpType::PairSub: { - f64x2 grad = m_pairs[node.output_idx].gradients; - m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad); - m_pairs[node.rhs_idx].gradients = f64x2::sub(m_pairs[node.rhs_idx].gradients, grad); + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + g0[node.lhs_idx] += g0_out; + g1[node.lhs_idx] += g1_out; + g0[node.rhs_idx] -= g0_out; + g1[node.rhs_idx] -= g1_out; break; } // Pair Scalar case OpType::PairNeg: { - f64x2 grad = m_pairs[node.output_idx].gradients; - m_pairs[node.lhs_idx].gradients = f64x2::sub(m_pairs[node.lhs_idx].gradients, grad); + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + g0[node.lhs_idx] -= g0_out; + g1[node.lhs_idx] -= g1_out; break; } case OpType::PairMulScalar: { - f64x2 grad = m_pairs[node.output_idx].gradients; - f64x2 scaled_grad = f64x2::mul_scalar(grad, node.scalar_data); - m_pairs[node.lhs_idx].gradients = - f64x2::add(m_pairs[node.lhs_idx].gradients, scaled_grad); + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + g0[node.lhs_idx] += g0_out * node.scalar_data; + g1[node.lhs_idx] += g1_out * node.scalar_data; break; } case OpType::PairDivScalar: { - f64x2 grad = m_pairs[node.output_idx].gradients; - f64x2 scaled_grad = f64x2::div_scalar(grad, node.scalar_data); - m_pairs[node.lhs_idx].gradients = - f64x2::add(m_pairs[node.lhs_idx].gradients, scaled_grad); + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + g0[node.lhs_idx] += g0_out / node.scalar_data; + g1[node.lhs_idx] += g1_out / node.scalar_data; break; } case OpType::ScalarDivPair: { - f64x2 grad = m_pairs[node.output_idx].gradients; - f64x2 l = m_pairs[node.lhs_idx].values; - f64x2 l_sq = f64x2::mul(l, l); - f64x2 neg_s_over_sq = f64x2::neg(f64x2::scalar_div(node.scalar_data, l_sq)); - f64x2 update = f64x2::mul(neg_s_over_sq, grad); - m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, update); + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + f64 l0 = p0[node.lhs_idx]; + f64 l1 = p1[node.lhs_idx]; + f64 u0 = -node.scalar_data / (l0 * l0); + f64 u1 = -node.scalar_data / (l1 * l1); + g0[node.lhs_idx] += u0 * g0_out; + g1[node.lhs_idx] += u1 * g1_out; break; } // Pair-Value case OpType::PairMulValue: case OpType::ValueMulPair: { - f64x2 grad_out = m_pairs[node.output_idx].gradients; - f64x2 p = m_pairs[node.lhs_idx].values; - f64 v = m_values[node.rhs_idx].value; + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + f64 p0v = p0[node.lhs_idx]; + f64 p1v = p1[node.lhs_idx]; + f64 v = vals[node.rhs_idx]; - f64x2 grad_p = f64x2::mul_scalar(grad_out, v); - m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); + g0[node.lhs_idx] += g0_out * v; + g1[node.lhs_idx] += g1_out * v; - f64x2 contrib = f64x2::mul(p, grad_out); - m_values[node.rhs_idx].gradient += contrib.first() + contrib.second(); + f64 contrib = p0v * g0_out + p1v * g1_out; + grads[node.rhs_idx] += contrib; break; } case OpType::PairDivValue: { - f64x2 grad_out = m_pairs[node.output_idx].gradients; - f64x2 p = m_pairs[node.lhs_idx].values; - f64 v = m_values[node.rhs_idx].value; + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + f64 p0v = p0[node.lhs_idx]; + f64 p1v = p1[node.lhs_idx]; + f64 v = vals[node.rhs_idx]; - f64x2 grad_p = f64x2::div_scalar(grad_out, v); - m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); + g0[node.lhs_idx] += g0_out / v; + g1[node.lhs_idx] += g1_out / v; - f64x2 num = f64x2::mul(p, grad_out); - f64 sum_contr = num.first() + num.second(); - m_values[node.rhs_idx].gradient += -sum_contr / (v * v); + f64 num = p0v * g0_out + p1v * g1_out; + grads[node.rhs_idx] += -num / (v * v); break; } case OpType::ValueDivPair: { - f64x2 grad_out = m_pairs[node.output_idx].gradients; - f64x2 p = m_pairs[node.lhs_idx].values; - f64 v = m_values[node.rhs_idx].value; - - f64x2 p_sq = f64x2::mul(p, p); - f64x2 neg_v_sq = f64x2::neg(f64x2::scalar_div(v, p_sq)); - f64x2 grad_p = f64x2::mul(neg_v_sq, grad_out); - m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_p); - - f64x2 v_contr = f64x2::div(grad_out, p); - m_values[node.rhs_idx].gradient += v_contr.first() + v_contr.second(); + f64 g0_out = g0[node.output_idx]; + f64 g1_out = g1[node.output_idx]; + f64 p0v = p0[node.lhs_idx]; + f64 p1v = p1[node.lhs_idx]; + f64 v = vals[node.rhs_idx]; + + f64 u0 = -v / (p0v * p0v); + f64 u1 = -v / (p1v * p1v); + g0[node.lhs_idx] += u0 * g0_out; + g1[node.lhs_idx] += u1 * g1_out; + + grads[node.rhs_idx] += g0_out / p0v + g1_out / p1v; break; } - // Special Case phase case OpType::Phase: { - f64 grad = m_values[node.output_idx].gradient; + f64 grad = grads[node.output_idx]; f64 alpha = node.scalar_data; - - f64x2 grad_upd = f64x2::make(alpha * grad, (1.0 - alpha) * grad); - m_pairs[node.lhs_idx].gradients = f64x2::add(m_pairs[node.lhs_idx].gradients, grad_upd); + g0[node.lhs_idx] += alpha * grad; + g1[node.lhs_idx] += (1.0 - alpha) * grad; break; } default: @@ -406,7 +433,6 @@ void Graph::backward() { } void Graph::cleanup() { - // Keep parameters, clear the rest m_values.reset_to(m_global_param_count); m_pairs.reset_to(m_global_pair_count); m_tape.clear(); @@ -414,10 +440,11 @@ void Graph::cleanup() { void Graph::zero_grad() { for (usize i = 0; i < m_global_param_count; ++i) { - m_values[i].gradient = 0.0; + m_values.grad(i) = 0.0; } for (usize i = 0; i < m_global_pair_count; ++i) { - m_pairs[i].gradients = f64x2::zero(); + m_pairs.g0_mut(i) = 0.0; + m_pairs.g1_mut(i) = 0.0; } } @@ -428,10 +455,11 @@ void Graph::copy_parameter_values(const Parameters& source) { std::terminate(); } for (usize i = 0; i < m_global_param_count; ++i) { - m_values[i].value = source.parameters[i]; + m_values.val(i) = source.parameters[i]; } for (usize i = 0; i < m_global_pair_count; ++i) { - m_pairs[i].values = source.pair_parameters[i]; + m_pairs.p0_mut(i) = source.pair_parameters[i].first(); + m_pairs.p1_mut(i) = source.pair_parameters[i].second(); } } @@ -440,10 +468,10 @@ Parameters Graph::get_all_parameter_values() const { p.parameters.reserve(m_global_param_count); p.pair_parameters.reserve(m_global_pair_count); for (usize i = 0; i < m_global_param_count; ++i) { - p.parameters.push_back(m_values[i].value); + p.parameters.push_back(m_values.val(i)); } for (usize i = 0; i < m_global_pair_count; ++i) { - p.pair_parameters.push_back(m_pairs[i].values); + p.pair_parameters.push_back(f64x2::make(m_pairs.p0_ref(i), m_pairs.p1_ref(i))); } return p; } @@ -453,12 +481,36 @@ Parameters Graph::get_all_parameter_gradients() const { p.parameters.reserve(m_global_param_count); p.pair_parameters.reserve(m_global_pair_count); for (usize i = 0; i < m_global_param_count; ++i) { - p.parameters.push_back(m_values[i].gradient); + p.parameters.push_back(m_values.grad(i)); } for (usize i = 0; i < m_global_pair_count; ++i) { - p.pair_parameters.push_back(m_pairs[i].gradients); + p.pair_parameters.push_back(f64x2::make(m_pairs.g0_ref(i), m_pairs.g1_ref(i))); } return p; } +// Mutation Helpers + +void Graph::add_value_gradient(u32 idx, f64 delta) { + m_values.grad(idx) += delta; +} + +void Graph::set_value(u32 idx, f64 v) { + m_values.val(idx) = v; +} + +void Graph::zero_value_grad(u32 idx) { + m_values.grad(idx) = 0.0; +} + +void Graph::set_pair_values(u32 idx, const f64x2& v) { + m_pairs.p0_mut(idx) = v.first(); + m_pairs.p1_mut(idx) = v.second(); +} + +void Graph::zero_pair_grad(u32 idx) { + m_pairs.g0_mut(idx) = 0.0; + m_pairs.g1_mut(idx) = 0.0; +} + } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 28eef7a2..20876bf4 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -11,26 +11,15 @@ namespace Clockwork::Autograd { -struct ValueData { - f64 value; - f64 gradient; -}; - -struct PairData { - f64x2 values; - f64x2 gradients; -}; - class Graph { private: - // Storage - Arena m_values; - Arena m_pairs; + ValueArena m_values; + PairArena m_pairs; // Tape (Linear record of operations) std::vector m_tape; - // Counts of global parameters (they sit at the start of the arenas) + // Counts of global parameters usize m_global_param_count = 0; usize m_global_pair_count = 0; @@ -43,20 +32,13 @@ class Graph { ValueHandle create_value(f64 data); PairHandle create_pair(f64x2 data); - // Operation recording stuff - - // Value-Value Binary + // Operation recording ValueHandle record_op(OpType op, ValueHandle lhs, ValueHandle rhs); - // Value Unary / Scalar ValueHandle record_op(OpType op, ValueHandle input, f64 scalar = 0.0); - // Pair-Pair Binary PairHandle record_pair_op(OpType op, PairHandle lhs, PairHandle rhs); - // Pair-Scalar PairHandle record_pair_scalar(OpType op, PairHandle input, f64 scalar); - // Pair-Value PairHandle record_pair_value(OpType op, PairHandle pair, ValueHandle val); - // Handling phasing separately due to its unique nature, probably can be done better ValueHandle record_phase(PairHandle input, f64 alpha); void backward(); @@ -67,19 +49,46 @@ class Graph { Parameters get_all_parameter_values() const; Parameters get_all_parameter_gradients() const; - // Accessors for Handles - ValueData& get_value_data(ValueHandle h) { - return m_values[h.index]; + void add_value_gradient(u32 idx, f64 delta); + void set_value(u32 idx, f64 v); + void zero_value_grad(u32 idx); + + void set_pair_values(u32 idx, const f64x2& v); + void zero_pair_grad(u32 idx); + + // Direct SoA accessors + f64 get_value(u32 idx) const { + return m_values.val(idx); } - const ValueData& get_value_data(ValueHandle h) const { - return m_values[h.index]; + f64 get_gradient(u32 idx) const { + return m_values.grad(idx); } - PairData& get_pair_data(PairHandle h) { - return m_pairs[h.index]; + f64x2 get_pair_values(u32 idx) const { + return f64x2::make(m_pairs.p0_ref(idx), m_pairs.p1_ref(idx)); + } + f64x2 get_pair_gradients(u32 idx) const { + return f64x2::make(m_pairs.g0_ref(idx), m_pairs.g1_ref(idx)); + } + + // Pointer accessors + f64* values_data() { + return m_values.values_data(); + } + f64* gradients_data() { + return m_values.gradients_data(); + } + f64* p0_data() { + return m_pairs.p0_data(); + } + f64* p1_data() { + return m_pairs.p1_data(); + } + f64* g0_data() { + return m_pairs.g0_data(); } - const PairData& get_pair_data(PairHandle h) const { - return m_pairs[h.index]; + f64* g1_data() { + return m_pairs.g1_data(); } ValueHandle get_parameter(usize global_index) const { diff --git a/src/tuning/value.cpp b/src/tuning/value.cpp index 78a914ab..8968b153 100644 --- a/src/tuning/value.cpp +++ b/src/tuning/value.cpp @@ -16,7 +16,6 @@ ValueHandle ValueHandle::sum(const std::vector& inputs) { if (inputs.empty()) { return ValueHandle::create(0.0); } - // Simple linear accumulation on the tape. I dropped the old optimized version for the arena rewrite, but its the top priority for future optimization. ValueHandle total = inputs[0]; for (size_t i = 1; i < inputs.size(); ++i) { total = total + inputs[i]; @@ -42,27 +41,27 @@ ValueHandle ValueHandle::pow(f64 exponent) const { void ValueHandle::add_gradient(f64 rhs) const { if (is_valid()) { - Graph::get().get_value_data(*this).gradient += rhs; + Graph::get().add_value_gradient(index, rhs); } } f64 ValueHandle::get_value() const { - return is_valid() ? Graph::get().get_value_data(*this).value : 0.0; + return is_valid() ? Graph::get().get_value(index) : 0.0; } f64 ValueHandle::get_gradient() const { - return is_valid() ? Graph::get().get_value_data(*this).gradient : 0.0; + return is_valid() ? Graph::get().get_gradient(index) : 0.0; } void ValueHandle::zero_grad() const { if (is_valid()) { - Graph::get().get_value_data(*this).gradient = 0.0; + Graph::get().zero_value_grad(index); } } void ValueHandle::set_value(f64 v) const { if (is_valid()) { - Graph::get().get_value_data(*this).value = v; + Graph::get().set_value(index, v); } } @@ -77,10 +76,10 @@ PairHandle PairHandle::create(const f64x2& values) { } f64x2 PairHandle::get_values() const { - return Graph::get().get_pair_data(*this).values; + return Graph::get().get_pair_values(index); } f64x2 PairHandle::get_gradients() const { - return Graph::get().get_pair_data(*this).gradients; + return Graph::get().get_pair_gradients(index); } f64 PairHandle::first() const { return get_values().first(); @@ -90,17 +89,16 @@ f64 PairHandle::second() const { } void PairHandle::set_values(const f64x2& v) const { - Graph::get().get_pair_data(*this).values = v; + Graph::get().set_pair_values(index, v); } void PairHandle::set_values(f64 f, f64 s) const { set_values(f64x2::make(f, s)); } void PairHandle::zero_grad() const { - Graph::get().get_pair_data(*this).gradients = f64x2::zero(); + Graph::get().zero_pair_grad(index); } - // Special phasing case ValueHandle PairHandle::phase_impl(f64 scaled_alpha) const { return Graph::get().record_phase(*this, scaled_alpha); From b21e922e5e7210012eee0a1c1cad7be50a3813b4 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 13:06:49 +0100 Subject: [PATCH 15/31] 8 epochs for testing --- src/evaltune_main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index 37f48068..f491d69e 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -97,7 +97,7 @@ int main() { // The optimizer will now start with all-zero parameters AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0); - const i32 epochs = 1000; + const i32 epochs = 8; const f64 K = 1.0 / 400; const size_t batch_size = 16 * 16384; From 20b4ffcdfa2fc06f0a6077c0db71f0be4053b450 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 14:26:08 +0100 Subject: [PATCH 16/31] inline Graph::get --- CMakeLists.txt | 4 ++++ src/evaltune_main.cpp | 5 +++-- src/tuning/graph.cpp | 5 ----- src/tuning/graph.hpp | 5 ++++- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 25b1662a..400da671 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,11 +72,15 @@ function(target_add_flags target) # Libraries target_link_libraries(${target} PUBLIC git_hash lps) + # LTO + message(STATUS "LTO is set to: ${lto}") # LTO if (lto) set_target_properties(${target} PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) endif() + + endfunction() # Sorted list of source files diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index f491d69e..1c0fd242 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -200,7 +200,8 @@ int main() { Graph::get().cleanup(); Graph::get().zero_grad(); - +#define PROFILE_RUN +#ifndef PROFILE_RUN std::cout << "inline const PParam PAWN_MAT = " << PAWN_MAT << ";" << std::endl; std::cout << "inline const PParam KNIGHT_MAT = " << KNIGHT_MAT << ";" << std::endl; std::cout << "inline const PParam BISHOP_MAT = " << BISHOP_MAT << ";" << std::endl; @@ -327,7 +328,7 @@ int main() { printPsqtArray("QUEEN_PSQT", QUEEN_PSQT); printPsqtArray("KING_PSQT", KING_PSQT); std::cout << std::endl; - +#endif const auto end = time::Clock::now(); std::cout << "// Epoch duration: " << time::cast(end - start).count() << "s\n"; diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 8b95cbb9..84c409c1 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -25,11 +25,6 @@ Graph::Graph() { } } -Graph& Graph::get() { - thread_local Graph instance; - return instance; -} - ValueHandle Graph::create_value(f64 data) { return ValueHandle(m_values.alloc(data, 0.0)); } diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 20876bf4..240147ec 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -26,7 +26,10 @@ class Graph { Graph(); public: - static Graph& get(); + inline static Graph& get() { + thread_local Graph instance; + return instance; + } // Creation ValueHandle create_value(f64 data); From 71bd2ab75a0d3df0ec5f08b59deefdf2ce1cbf16 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 14:33:43 +0100 Subject: [PATCH 17/31] Bench: 12044152 From 4c68180e4e65df6f094a395522a50c15559633d2 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 15:02:25 +0100 Subject: [PATCH 18/31] Lazy node addition --- src/tuning/arena.hpp | 8 ++++++++ src/tuning/graph.cpp | 33 +++++++++++++++++---------------- src/tuning/graph.hpp | 6 +++--- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index 89d5ab5a..f6139abf 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -30,6 +30,10 @@ class ValueArena { return idx; } + inline u32 next_index() const { + return static_cast(values.size()); + } + // Mutating accessors inline f64& val(u32 i) { assert(i < values.size()); @@ -114,6 +118,10 @@ class PairArena { return idx; } + inline u32 next_index() const { + return static_cast(p0.size()); + } + // Mutating accessors inline f64& p0_mut(u32 i) { assert(i < p0.size()); diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 84c409c1..213fb6e9 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -36,7 +36,7 @@ PairHandle Graph::create_pair(f64x2 data) { // Recording ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { - u32 out = m_values.alloc_uninitialized(); + u32 out = m_values.next_index(); f64 l = m_values.val(lhs.index); f64 r = m_values.val(rhs.index); f64 res = 0.0; @@ -60,13 +60,14 @@ ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { default: break; } - m_values.val(out) = res; + + m_values.alloc(res, 0.0); m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); return ValueHandle(out); } ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { - u32 out = m_values.alloc_uninitialized(); + u32 out = m_values.next_index(); f64 l = m_values.val(input.index); f64 res = 0.0; @@ -107,13 +108,14 @@ ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { default: break; } - m_values.val(out) = res; + + m_values.alloc(res, 0.0); m_tape.push_back({op, out, input.index, 0, scalar}); return ValueHandle(out); } PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { - u32 out = m_pairs.alloc_uninitialized(); + u32 out = m_pairs.next_index(); f64 l0 = m_pairs.p0_ref(lhs.index); f64 l1 = m_pairs.p1_ref(lhs.index); f64 r0 = m_pairs.p0_ref(rhs.index); @@ -131,14 +133,13 @@ PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { break; } - m_pairs.p0_mut(out) = res.first(); - m_pairs.p1_mut(out) = res.second(); + m_pairs.alloc(res, f64x2::zero()); m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); return PairHandle(out); } PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { - u32 out = m_pairs.alloc_uninitialized(); + u32 out = m_pairs.next_index(); f64 l0 = m_pairs.p0_ref(input.index); f64 l1 = m_pairs.p1_ref(input.index); f64x2 res = f64x2::zero(); @@ -160,14 +161,13 @@ PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { break; } - m_pairs.p0_mut(out) = res.first(); - m_pairs.p1_mut(out) = res.second(); + m_pairs.alloc(res, f64x2::zero()); m_tape.push_back({op, out, input.index, 0, scalar}); return PairHandle(out); } PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) { - u32 out = m_pairs.alloc_uninitialized(); + u32 out = m_pairs.next_index(); f64 p0 = m_pairs.p0_ref(pair.index); f64 p1 = m_pairs.p1_ref(pair.index); f64 v = m_values.val(val.index); @@ -187,19 +187,20 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) default: break; } - m_pairs.p0_mut(out) = res.first(); - m_pairs.p1_mut(out) = res.second(); + + m_pairs.alloc(res, f64x2::zero()); m_tape.push_back({op, out, pair.index, val.index, 0.0}); return PairHandle(out); } ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { - u32 out = m_values.alloc_uninitialized(); + u32 out = m_values.next_index(); f64 p0 = m_pairs.p0_ref(input.index); f64 p1 = m_pairs.p1_ref(input.index); - f64 val = alpha * p0 + (1.0 - alpha) * p1; - m_values.val(out) = val; + f64 val = alpha * p0 + (1.0 - alpha) * p1; + + m_values.alloc(val, 0.0); m_tape.push_back({OpType::Phase, out, input.index, 0, alpha}); return ValueHandle(out); } diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 240147ec..00d422ec 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -38,9 +38,9 @@ class Graph { // Operation recording ValueHandle record_op(OpType op, ValueHandle lhs, ValueHandle rhs); ValueHandle record_op(OpType op, ValueHandle input, f64 scalar = 0.0); - PairHandle record_pair_op(OpType op, PairHandle lhs, PairHandle rhs); - PairHandle record_pair_scalar(OpType op, PairHandle input, f64 scalar); - PairHandle record_pair_value(OpType op, PairHandle pair, ValueHandle val); + PairHandle record_pair_op(OpType op, PairHandle lhs, PairHandle rhs); + PairHandle record_pair_scalar(OpType op, PairHandle input, f64 scalar); + PairHandle record_pair_value(OpType op, PairHandle pair, ValueHandle val); ValueHandle record_phase(PairHandle input, f64 alpha); From 69a78fd2142273aa19d63160f8fb872f668c1d06 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 15:26:26 +0100 Subject: [PATCH 19/31] f64x2 are back on the menu --- src/tuning/arena.hpp | 137 +++++++++++++++----------------- src/tuning/graph.cpp | 181 ++++++++++++++++++------------------------- src/tuning/graph.hpp | 12 --- 3 files changed, 137 insertions(+), 193 deletions(-) diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index f6139abf..55e3c842 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -88,126 +88,111 @@ class ValueArena { std::vector gradients; }; - class PairArena { public: PairArena() = default; void reserve(usize n) { - p0.reserve(n); - p1.reserve(n); - g0.reserve(n); - g1.reserve(n); + values.reserve(n); + gradients.reserve(n); } u32 alloc(f64x2 v, f64x2 g = f64x2::zero()) { - u32 idx = static_cast(p0.size()); - p0.push_back(v.first()); - p1.push_back(v.second()); - g0.push_back(g.first()); - g1.push_back(g.second()); + u32 idx = static_cast(values.size()); + values.push_back(v); + gradients.push_back(g); return idx; } u32 alloc_uninitialized() { - u32 idx = static_cast(p0.size()); - p0.push_back(0.0); - p1.push_back(0.0); - g0.push_back(0.0); - g1.push_back(0.0); + u32 idx = static_cast(values.size()); + values.push_back(f64x2::zero()); + gradients.push_back(f64x2::zero()); return idx; } inline u32 next_index() const { - return static_cast(p0.size()); + return static_cast(values.size()); } - // Mutating accessors - inline f64& p0_mut(u32 i) { - assert(i < p0.size()); - return p0[i]; + // Accessors + inline f64x2& val(u32 i) { + assert(i < values.size()); + return values[i]; } - inline f64& p1_mut(u32 i) { - assert(i < p1.size()); - return p1[i]; + + inline const f64x2& val(u32 i) const { + assert(i < values.size()); + return values[i]; } - inline f64& g0_mut(u32 i) { - assert(i < g0.size()); - return g0[i]; + + inline f64x2& grad(u32 i) { + assert(i < gradients.size()); + return gradients[i]; } - inline f64& g1_mut(u32 i) { - assert(i < g1.size()); - return g1[i]; + + inline const f64x2& grad(u32 i) const { + assert(i < gradients.size()); + return gradients[i]; } - // Const accessors - inline const f64& p0_ref(u32 i) const { - assert(i < p0.size()); - return p0[i]; + // Legacy component accessors + inline f64 p0_ref(u32 i) const { + assert(i < values.size()); + return values[i].first(); } - inline const f64& p1_ref(u32 i) const { - assert(i < p1.size()); - return p1[i]; + + inline f64 p1_ref(u32 i) const { + assert(i < values.size()); + return values[i].second(); } - inline const f64& g0_ref(u32 i) const { - assert(i < g0.size()); - return g0[i]; + + inline f64 g0_ref(u32 i) const { + assert(i < gradients.size()); + return gradients[i].first(); } - inline const f64& g1_ref(u32 i) const { - assert(i < g1.size()); - return g1[i]; + + inline f64 g1_ref(u32 i) const { + assert(i < gradients.size()); + return gradients[i].second(); } inline usize size() const { - return p0.size(); + return values.size(); } void clear() { - p0.clear(); - p1.clear(); - g0.clear(); - g1.clear(); + values.clear(); + gradients.clear(); } void reset_to(usize n) { - if (n < p0.size()) { - p0.resize(n); - p1.resize(n); - g0.resize(n); - g1.resize(n); + if (n < values.size()) { + values.resize(n); + gradients.resize(n); } } - inline f64* p0_data() { - return p0.data(); - } - inline f64* p1_data() { - return p1.data(); - } - inline f64* g0_data() { - return g0.data(); - } - inline f64* g1_data() { - return g1.data(); - } - inline const f64* p0_data() const { - return p0.data(); + // Pointer accessors for hot loops + inline f64x2* values_data() { + return values.data(); } - inline const f64* p1_data() const { - return p1.data(); + + inline f64x2* gradients_data() { + return gradients.data(); } - inline const f64* g0_data() const { - return g0.data(); + + inline const f64x2* values_data() const { + return values.data(); } - inline const f64* g1_data() const { - return g1.data(); + + inline const f64x2* gradients_data() const { + return gradients.data(); } private: - std::vector p0; - std::vector p1; - std::vector g0; - std::vector g1; + std::vector values; + std::vector gradients; }; } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 213fb6e9..f7b9ad3c 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -116,18 +116,16 @@ ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { u32 out = m_pairs.next_index(); - f64 l0 = m_pairs.p0_ref(lhs.index); - f64 l1 = m_pairs.p1_ref(lhs.index); - f64 r0 = m_pairs.p0_ref(rhs.index); - f64 r1 = m_pairs.p1_ref(rhs.index); + f64x2 l = m_pairs.val(lhs.index); + f64x2 r = m_pairs.val(rhs.index); f64x2 res = f64x2::zero(); switch (op) { case OpType::PairAdd: - res = f64x2::add(f64x2::make(l0, l1), f64x2::make(r0, r1)); + res = f64x2::add(l, r); break; case OpType::PairSub: - res = f64x2::sub(f64x2::make(l0, l1), f64x2::make(r0, r1)); + res = f64x2::sub(l, r); break; default: break; @@ -140,22 +138,21 @@ PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { u32 out = m_pairs.next_index(); - f64 l0 = m_pairs.p0_ref(input.index); - f64 l1 = m_pairs.p1_ref(input.index); + f64x2 l = m_pairs.val(input.index); f64x2 res = f64x2::zero(); switch (op) { case OpType::PairNeg: - res = f64x2::neg(f64x2::make(l0, l1)); + res = f64x2::neg(l); break; case OpType::PairMulScalar: - res = f64x2::mul_scalar(f64x2::make(l0, l1), scalar); + res = f64x2::mul_scalar(l, scalar); break; case OpType::PairDivScalar: - res = f64x2::div_scalar(f64x2::make(l0, l1), scalar); + res = f64x2::div_scalar(l, scalar); break; case OpType::ScalarDivPair: - res = f64x2::scalar_div(scalar, f64x2::make(l0, l1)); + res = f64x2::scalar_div(scalar, l); break; default: break; @@ -167,22 +164,21 @@ PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { } PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) { - u32 out = m_pairs.next_index(); - f64 p0 = m_pairs.p0_ref(pair.index); - f64 p1 = m_pairs.p1_ref(pair.index); - f64 v = m_values.val(val.index); - f64x2 res = f64x2::zero(); + u32 out = m_pairs.next_index(); + f64x2 pair_val = m_pairs.val(pair.index); + f64 v = m_values.val(val.index); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairMulValue: case OpType::ValueMulPair: - res = f64x2::mul_scalar(f64x2::make(p0, p1), v); + res = f64x2::mul_scalar(pair_val, v); break; case OpType::PairDivValue: - res = f64x2::div_scalar(f64x2::make(p0, p1), v); + res = f64x2::div_scalar(pair_val, v); break; case OpType::ValueDivPair: - res = f64x2::scalar_div(v, f64x2::make(p0, p1)); + res = f64x2::scalar_div(v, pair_val); break; default: break; @@ -194,11 +190,10 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) } ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { - u32 out = m_values.next_index(); - f64 p0 = m_pairs.p0_ref(input.index); - f64 p1 = m_pairs.p1_ref(input.index); + u32 out = m_values.next_index(); + f64x2 pair_val = m_pairs.val(input.index); - f64 val = alpha * p0 + (1.0 - alpha) * p1; + f64 val = alpha * pair_val.first() + (1.0 - alpha) * pair_val.second(); m_values.alloc(val, 0.0); m_tape.push_back({OpType::Phase, out, input.index, 0, alpha}); @@ -217,10 +212,8 @@ void Graph::backward() { f64* vals = m_values.values_data(); f64* grads = m_values.gradients_data(); - f64* p0 = m_pairs.p0_data(); - f64* p1 = m_pairs.p1_data(); - f64* g0 = m_pairs.g0_data(); - f64* g1 = m_pairs.g1_data(); + f64x2* pair_vals = m_pairs.values_data(); + f64x2* pair_grads = m_pairs.gradients_data(); for (auto it = m_tape.rbegin(); it != m_tape.rend(); ++it) { const Node& node = *it; @@ -315,111 +308,93 @@ void Graph::backward() { break; } - // Pair Binary case OpType::PairAdd: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - g0[node.lhs_idx] += g0_out; - g1[node.lhs_idx] += g1_out; - g0[node.rhs_idx] += g0_out; - g1[node.rhs_idx] += g1_out; + f64x2 grad_out = pair_grads[node.output_idx]; + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], grad_out); + pair_grads[node.rhs_idx] = f64x2::add(pair_grads[node.rhs_idx], grad_out); break; } case OpType::PairSub: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - g0[node.lhs_idx] += g0_out; - g1[node.lhs_idx] += g1_out; - g0[node.rhs_idx] -= g0_out; - g1[node.rhs_idx] -= g1_out; + f64x2 grad_out = pair_grads[node.output_idx]; + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], grad_out); + pair_grads[node.rhs_idx] = f64x2::sub(pair_grads[node.rhs_idx], grad_out); break; } - // Pair Scalar case OpType::PairNeg: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - g0[node.lhs_idx] -= g0_out; - g1[node.lhs_idx] -= g1_out; + f64x2 grad_out = pair_grads[node.output_idx]; + pair_grads[node.lhs_idx] = f64x2::sub(pair_grads[node.lhs_idx], grad_out); break; } case OpType::PairMulScalar: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - g0[node.lhs_idx] += g0_out * node.scalar_data; - g1[node.lhs_idx] += g1_out * node.scalar_data; + f64x2 grad_out = pair_grads[node.output_idx]; + f64x2 scaled = f64x2::mul_scalar(grad_out, node.scalar_data); + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], scaled); break; } case OpType::PairDivScalar: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - g0[node.lhs_idx] += g0_out / node.scalar_data; - g1[node.lhs_idx] += g1_out / node.scalar_data; + f64x2 grad_out = pair_grads[node.output_idx]; + f64x2 scaled = f64x2::div_scalar(grad_out, node.scalar_data); + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], scaled); break; } case OpType::ScalarDivPair: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - f64 l0 = p0[node.lhs_idx]; - f64 l1 = p1[node.lhs_idx]; - f64 u0 = -node.scalar_data / (l0 * l0); - f64 u1 = -node.scalar_data / (l1 * l1); - g0[node.lhs_idx] += u0 * g0_out; - g1[node.lhs_idx] += u1 * g1_out; + f64x2 grad_out = pair_grads[node.output_idx]; + f64x2 val = pair_vals[node.lhs_idx]; + f64x2 val_sq = f64x2::mul(val, val); + f64x2 deriv = f64x2::scalar_div(-node.scalar_data, val_sq); + f64x2 update = f64x2::mul(deriv, grad_out); + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], update); break; } - // Pair-Value case OpType::PairMulValue: case OpType::ValueMulPair: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - f64 p0v = p0[node.lhs_idx]; - f64 p1v = p1[node.lhs_idx]; - f64 v = vals[node.rhs_idx]; + f64x2 grad_out = pair_grads[node.output_idx]; + f64x2 pair_val = pair_vals[node.lhs_idx]; + f64 v = vals[node.rhs_idx]; - g0[node.lhs_idx] += g0_out * v; - g1[node.lhs_idx] += g1_out * v; + f64x2 pair_update = f64x2::mul_scalar(grad_out, v); + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], pair_update); - f64 contrib = p0v * g0_out + p1v * g1_out; + f64 contrib = + pair_val.first() * grad_out.first() + pair_val.second() * grad_out.second(); grads[node.rhs_idx] += contrib; break; } case OpType::PairDivValue: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - f64 p0v = p0[node.lhs_idx]; - f64 p1v = p1[node.lhs_idx]; - f64 v = vals[node.rhs_idx]; + f64x2 grad_out = pair_grads[node.output_idx]; + f64x2 pair_val = pair_vals[node.lhs_idx]; + f64 v = vals[node.rhs_idx]; - g0[node.lhs_idx] += g0_out / v; - g1[node.lhs_idx] += g1_out / v; + f64x2 pair_update = f64x2::div_scalar(grad_out, v); + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], pair_update); - f64 num = p0v * g0_out + p1v * g1_out; + f64 num = pair_val.first() * grad_out.first() + pair_val.second() * grad_out.second(); grads[node.rhs_idx] += -num / (v * v); break; } case OpType::ValueDivPair: { - f64 g0_out = g0[node.output_idx]; - f64 g1_out = g1[node.output_idx]; - f64 p0v = p0[node.lhs_idx]; - f64 p1v = p1[node.lhs_idx]; - f64 v = vals[node.rhs_idx]; - - f64 u0 = -v / (p0v * p0v); - f64 u1 = -v / (p1v * p1v); - g0[node.lhs_idx] += u0 * g0_out; - g1[node.lhs_idx] += u1 * g1_out; - - grads[node.rhs_idx] += g0_out / p0v + g1_out / p1v; + f64x2 grad_out = pair_grads[node.output_idx]; + f64x2 pair_val = pair_vals[node.lhs_idx]; + f64 v = vals[node.rhs_idx]; + + f64x2 pair_sq = f64x2::mul(pair_val, pair_val); + f64x2 deriv = f64x2::scalar_div(-v, pair_sq); + f64x2 update = f64x2::mul(deriv, grad_out); + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], update); + + f64x2 recip = f64x2::scalar_div(1.0, pair_val); + grads[node.rhs_idx] += + grad_out.first() * recip.first() + grad_out.second() * recip.second(); break; } case OpType::Phase: { - f64 grad = grads[node.output_idx]; - f64 alpha = node.scalar_data; - g0[node.lhs_idx] += alpha * grad; - g1[node.lhs_idx] += (1.0 - alpha) * grad; + f64 grad = grads[node.output_idx]; + f64 alpha = node.scalar_data; + f64x2 update = f64x2::make(alpha * grad, (1.0 - alpha) * grad); + pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], update); break; } default: @@ -439,8 +414,7 @@ void Graph::zero_grad() { m_values.grad(i) = 0.0; } for (usize i = 0; i < m_global_pair_count; ++i) { - m_pairs.g0_mut(i) = 0.0; - m_pairs.g1_mut(i) = 0.0; + m_pairs.grad(i) = f64x2::zero(); } } @@ -454,8 +428,7 @@ void Graph::copy_parameter_values(const Parameters& source) { m_values.val(i) = source.parameters[i]; } for (usize i = 0; i < m_global_pair_count; ++i) { - m_pairs.p0_mut(i) = source.pair_parameters[i].first(); - m_pairs.p1_mut(i) = source.pair_parameters[i].second(); + m_pairs.val(i) = source.pair_parameters[i]; } } @@ -467,7 +440,7 @@ Parameters Graph::get_all_parameter_values() const { p.parameters.push_back(m_values.val(i)); } for (usize i = 0; i < m_global_pair_count; ++i) { - p.pair_parameters.push_back(f64x2::make(m_pairs.p0_ref(i), m_pairs.p1_ref(i))); + p.pair_parameters.push_back(m_pairs.val(i)); } return p; } @@ -480,7 +453,7 @@ Parameters Graph::get_all_parameter_gradients() const { p.parameters.push_back(m_values.grad(i)); } for (usize i = 0; i < m_global_pair_count; ++i) { - p.pair_parameters.push_back(f64x2::make(m_pairs.g0_ref(i), m_pairs.g1_ref(i))); + p.pair_parameters.push_back(m_pairs.grad(i)); } return p; } @@ -500,13 +473,11 @@ void Graph::zero_value_grad(u32 idx) { } void Graph::set_pair_values(u32 idx, const f64x2& v) { - m_pairs.p0_mut(idx) = v.first(); - m_pairs.p1_mut(idx) = v.second(); + m_pairs.val(idx) = v; } void Graph::zero_pair_grad(u32 idx) { - m_pairs.g0_mut(idx) = 0.0; - m_pairs.g1_mut(idx) = 0.0; + m_pairs.grad(idx) = f64x2::zero(); } } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 00d422ec..7053f47a 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -81,18 +81,6 @@ class Graph { f64* gradients_data() { return m_values.gradients_data(); } - f64* p0_data() { - return m_pairs.p0_data(); - } - f64* p1_data() { - return m_pairs.p1_data(); - } - f64* g0_data() { - return m_pairs.g0_data(); - } - f64* g1_data() { - return m_pairs.g1_data(); - } ValueHandle get_parameter(usize global_index) const { return ValueHandle(static_cast(global_index)); From df02d72c7a4c42bdfd9bfeddc899caeb572480e8 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 15:48:14 +0100 Subject: [PATCH 20/31] inline allocs --- src/tuning/arena.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index 55e3c842..8815048a 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -16,14 +16,14 @@ class ValueArena { gradients.reserve(n); } - u32 alloc(f64 value, f64 grad = 0.0) { + inline u32 alloc(f64 value, f64 grad = 0.0) { u32 idx = static_cast(values.size()); values.push_back(value); gradients.push_back(grad); return idx; } - u32 alloc_uninitialized() { + inline u32 alloc_uninitialized() { u32 idx = static_cast(values.size()); values.push_back(0.0); gradients.push_back(0.0); @@ -97,14 +97,14 @@ class PairArena { gradients.reserve(n); } - u32 alloc(f64x2 v, f64x2 g = f64x2::zero()) { + inline u32 alloc(f64x2 v, f64x2 g = f64x2::zero()) { u32 idx = static_cast(values.size()); values.push_back(v); gradients.push_back(g); return idx; } - u32 alloc_uninitialized() { + inline u32 alloc_uninitialized() { u32 idx = static_cast(values.size()); values.push_back(f64x2::zero()); gradients.push_back(f64x2::zero()); From 6691b768b7cfdd3b561a089a0836058e8eeacfcf Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 18:00:26 +0100 Subject: [PATCH 21/31] tentative make op node 16 bytes --- src/tuning/arena.hpp | 10 ++ src/tuning/graph.cpp | 313 +++++++++++++++++++------------------- src/tuning/operations.hpp | 83 +++++++++- 3 files changed, 241 insertions(+), 165 deletions(-) diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index 8815048a..43d6a8ca 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -1,5 +1,7 @@ #pragma once #include "util/vec/sse2.hpp" +#include "value.hpp" +#include "value.hpp" #include "util/types.hpp" #include @@ -34,6 +36,10 @@ class ValueArena { return static_cast(values.size()); } + inline ValueHandle next_handle() const { + return ValueHandle(next_index()); + } + // Mutating accessors inline f64& val(u32 i) { assert(i < values.size()); @@ -115,6 +121,10 @@ class PairArena { return static_cast(values.size()); } + inline PairHandle next_handle() const { + return PairHandle(next_index()); + } + // Accessors inline f64x2& val(u32 i) { assert(i < values.size()); diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index f7b9ad3c..9b16f32e 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -36,10 +36,10 @@ PairHandle Graph::create_pair(f64x2 data) { // Recording ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { - u32 out = m_values.next_index(); - f64 l = m_values.val(lhs.index); - f64 r = m_values.val(rhs.index); - f64 res = 0.0; + ValueHandle out = m_values.next_handle(); + f64 l = m_values.val(lhs.index); + f64 r = m_values.val(rhs.index); + f64 res = 0.0; switch (op) { case OpType::Add: @@ -62,14 +62,16 @@ ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { } m_values.alloc(res, 0.0); - m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); - return ValueHandle(out); + + m_tape.push_back(Node::make_binary(op, out.index, lhs.index, rhs.index)); + + return out; } -ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { - u32 out = m_values.next_index(); - f64 l = m_values.val(input.index); - f64 res = 0.0; +ValueHandle Graph::record_op(OpType op, ValueHandle lhs, f64 scalar) { + ValueHandle out = m_values.next_handle(); + f64 l = m_values.val(lhs.index); + f64 res = 0.0; switch (op) { case OpType::Exp: @@ -110,15 +112,17 @@ ValueHandle Graph::record_op(OpType op, ValueHandle input, f64 scalar) { } m_values.alloc(res, 0.0); - m_tape.push_back({op, out, input.index, 0, scalar}); - return ValueHandle(out); + + m_tape.push_back(Node::make_scalar(op, out.index, lhs.index, scalar)); + + return out; } PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { - u32 out = m_pairs.next_index(); - f64x2 l = m_pairs.val(lhs.index); - f64x2 r = m_pairs.val(rhs.index); - f64x2 res = f64x2::zero(); + PairHandle out = m_pairs.next_handle(); + f64x2 l = m_pairs.val(lhs.index); + f64x2 r = m_pairs.val(rhs.index); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairAdd: @@ -132,14 +136,16 @@ PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { } m_pairs.alloc(res, f64x2::zero()); - m_tape.push_back({op, out, lhs.index, rhs.index, 0.0}); - return PairHandle(out); + + m_tape.push_back(Node::make_binary(op, out.index, lhs.index, rhs.index)); + + return out; } -PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { - u32 out = m_pairs.next_index(); - f64x2 l = m_pairs.val(input.index); - f64x2 res = f64x2::zero(); +PairHandle Graph::record_pair_scalar(OpType op, PairHandle lhs, f64 scalar) { + PairHandle out = m_pairs.next_handle(); + f64x2 l = m_pairs.val(lhs.index); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairNeg: @@ -159,15 +165,17 @@ PairHandle Graph::record_pair_scalar(OpType op, PairHandle input, f64 scalar) { } m_pairs.alloc(res, f64x2::zero()); - m_tape.push_back({op, out, input.index, 0, scalar}); - return PairHandle(out); + + m_tape.push_back(Node::make_scalar(op, out.index, lhs.index, scalar)); + + return out; } -PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) { - u32 out = m_pairs.next_index(); - f64x2 pair_val = m_pairs.val(pair.index); - f64 v = m_values.val(val.index); - f64x2 res = f64x2::zero(); +PairHandle Graph::record_pair_value(OpType op, PairHandle lhs, ValueHandle rhs) { + PairHandle out = m_pairs.next_handle(); + f64x2 pair_val = m_pairs.val(lhs.index); + f64 v = m_values.val(rhs.index); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairMulValue: @@ -185,218 +193,209 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle pair, ValueHandle val) } m_pairs.alloc(res, f64x2::zero()); - m_tape.push_back({op, out, pair.index, val.index, 0.0}); - return PairHandle(out); -} -ValueHandle Graph::record_phase(PairHandle input, f64 alpha) { - u32 out = m_values.next_index(); - f64x2 pair_val = m_pairs.val(input.index); + m_tape.push_back(Node::make_binary(op, out.index, lhs.index, rhs.index)); - f64 val = alpha * pair_val.first() + (1.0 - alpha) * pair_val.second(); + return out; +} + +ValueHandle Graph::record_phase(PairHandle lhs, f64 alpha) { + ValueHandle out = m_values.next_handle(); + f64x2 pair_val = m_pairs.val(lhs.index); + f64 val = alpha * pair_val.first() + (1.0 - alpha) * pair_val.second(); m_values.alloc(val, 0.0); - m_tape.push_back({OpType::Phase, out, input.index, 0, alpha}); - return ValueHandle(out); + + m_tape.push_back(Node::make_scalar(OpType::Phase, out.index, lhs.index, alpha)); + + return out; } void Graph::backward() { if (m_tape.empty()) { return; } + - const auto& last_node = m_tape.back(); - m_values.grad(last_node.output_idx) = 1.0; - - // Raw pointers for hot loops - f64* vals = m_values.values_data(); - f64* grads = m_values.gradients_data(); + // Initialize gradient of last output to 1 + m_values.grad(m_tape.back().out()) = 1.0; + f64* vals = m_values.values_data(); + f64* grads = m_values.gradients_data(); f64x2* pair_vals = m_pairs.values_data(); f64x2* pair_grads = m_pairs.gradients_data(); for (auto it = m_tape.rbegin(); it != m_tape.rend(); ++it) { const Node& node = *it; + const u32 out_idx = node.out(); + const f64 grad_out = grads[out_idx]; + switch (node.type) { - // Value Binary - case OpType::Add: { - f64 grad = grads[node.output_idx]; - grads[node.lhs_idx] += grad; - grads[node.rhs_idx] += grad; + // Value-Binary + + case OpType::Add: + grads[node.lhs()] += grad_out; + grads[node.rhs()] += grad_out; break; - } - case OpType::Sub: { - f64 grad = grads[node.output_idx]; - grads[node.lhs_idx] += grad; - grads[node.rhs_idx] -= grad; + + case OpType::Sub: + grads[node.lhs()] += grad_out; + grads[node.rhs()] -= grad_out; break; - } + case OpType::Mul: { - f64 grad = grads[node.output_idx]; - f64 l = vals[node.lhs_idx]; - f64 r = vals[node.rhs_idx]; - grads[node.lhs_idx] += r * grad; - grads[node.rhs_idx] += l * grad; + f64 l = vals[node.lhs()]; + f64 r = vals[node.rhs()]; + grads[node.lhs()] += r * grad_out; + grads[node.rhs()] += l * grad_out; break; } + case OpType::Div: { - f64 grad = grads[node.output_idx]; - f64 l = vals[node.lhs_idx]; - f64 r = vals[node.rhs_idx]; - grads[node.lhs_idx] += (1.0 / r) * grad; - grads[node.rhs_idx] += (-l / (r * r)) * grad; + f64 l = vals[node.lhs()]; + f64 r = vals[node.rhs()]; + grads[node.lhs()] += grad_out / r; + grads[node.rhs()] += -l * grad_out / (r * r); break; } + case OpType::Pow: { - f64 grad = grads[node.output_idx]; - f64 base = vals[node.lhs_idx]; - f64 exp = vals[node.rhs_idx]; - grads[node.lhs_idx] += exp * std::pow(base, exp - 1) * grad; - grads[node.rhs_idx] += std::pow(base, exp) * std::log(base) * grad; + f64 base = vals[node.lhs()]; + f64 exp = vals[node.rhs()]; + grads[node.lhs()] += exp * std::pow(base, exp - 1) * grad_out; + grads[node.rhs()] += std::pow(base, exp) * std::log(base) * grad_out; break; } - // Value Unary - case OpType::Exp: { - f64 grad = grads[node.output_idx]; - f64 val = vals[node.output_idx]; - grads[node.lhs_idx] += val * grad; + // Value-Scalar + case OpType::Exp: + grads[node.lhs()] += vals[out_idx] * grad_out; break; - } - case OpType::Log: { - f64 grad = grads[node.output_idx]; - f64 l = vals[node.lhs_idx]; - grads[node.lhs_idx] += (1.0 / l) * grad; + + case OpType::Log: + grads[node.lhs()] += grad_out / vals[node.lhs()]; break; - } + case OpType::Sigmoid: { - f64 grad = grads[node.output_idx]; - f64 s = vals[node.output_idx]; - grads[node.lhs_idx] += (s * (1.0 - s)) * grad; + f64 s = vals[out_idx]; + grads[node.lhs()] += s * (1.0 - s) * grad_out; break; } - case OpType::Neg: { - grads[node.lhs_idx] -= grads[node.output_idx]; + + case OpType::Neg: + grads[node.lhs()] -= grad_out; break; - } + case OpType::PowConst: { - f64 grad = grads[node.output_idx]; - f64 l = vals[node.lhs_idx]; - f64 exp = node.scalar_data; - grads[node.lhs_idx] += exp * std::pow(l, exp - 1.0) * grad; + f64 l = vals[node.lhs()]; + f64 exp = node.scalar(); + grads[node.lhs()] += exp * std::pow(l, exp - 1.0) * grad_out; break; } + case OpType::AddScalar: case OpType::SubScalarVal: - case OpType::ValSubScalar: { - grads[node.lhs_idx] += grads[node.output_idx]; + case OpType::ValSubScalar: + grads[node.lhs()] += grad_out; break; - } - case OpType::MulScalar: { - grads[node.lhs_idx] += node.scalar_data * grads[node.output_idx]; + + case OpType::MulScalar: + grads[node.lhs()] += node.scalar() * grad_out; break; - } - case OpType::ValDivScalar: { - grads[node.lhs_idx] += (1.0 / node.scalar_data) * grads[node.output_idx]; + + case OpType::ValDivScalar: + grads[node.lhs()] += grad_out / node.scalar(); break; - } + case OpType::DivScalarVal: { - f64 grad = grads[node.output_idx]; - f64 l = vals[node.lhs_idx]; - grads[node.lhs_idx] += (-node.scalar_data / (l * l)) * grad; + f64 l = vals[node.lhs()]; + grads[node.lhs()] += -node.scalar() * grad_out / (l * l); break; } + // Pair-Binary case OpType::PairAdd: { - f64x2 grad_out = pair_grads[node.output_idx]; - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], grad_out); - pair_grads[node.rhs_idx] = f64x2::add(pair_grads[node.rhs_idx], grad_out); + f64x2 grad_pair = f64x2::make(grad_out, grad_out); // same grad applied to both + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); + pair_grads[node.rhs()] = f64x2::add(pair_grads[node.rhs()], grad_pair); break; } + case OpType::PairSub: { - f64x2 grad_out = pair_grads[node.output_idx]; - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], grad_out); - pair_grads[node.rhs_idx] = f64x2::sub(pair_grads[node.rhs_idx], grad_out); + f64x2 grad_pair = f64x2::make(grad_out, grad_out); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); + pair_grads[node.rhs()] = f64x2::sub(pair_grads[node.rhs()], grad_pair); break; } - case OpType::PairNeg: { - f64x2 grad_out = pair_grads[node.output_idx]; - pair_grads[node.lhs_idx] = f64x2::sub(pair_grads[node.lhs_idx], grad_out); + case OpType::PairNeg: + pair_grads[node.lhs()] = + f64x2::sub(pair_grads[node.lhs()], f64x2::make(grad_out, grad_out)); break; - } + + // Pair-Scalar case OpType::PairMulScalar: { - f64x2 grad_out = pair_grads[node.output_idx]; - f64x2 scaled = f64x2::mul_scalar(grad_out, node.scalar_data); - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], scaled); + f64x2 scaled = f64x2::mul_scalar(f64x2::make(grad_out, grad_out), node.scalar()); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], scaled); break; } + case OpType::PairDivScalar: { - f64x2 grad_out = pair_grads[node.output_idx]; - f64x2 scaled = f64x2::div_scalar(grad_out, node.scalar_data); - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], scaled); + f64x2 scaled = f64x2::div_scalar(f64x2::make(grad_out, grad_out), node.scalar()); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], scaled); break; } + case OpType::ScalarDivPair: { - f64x2 grad_out = pair_grads[node.output_idx]; - f64x2 val = pair_vals[node.lhs_idx]; - f64x2 val_sq = f64x2::mul(val, val); - f64x2 deriv = f64x2::scalar_div(-node.scalar_data, val_sq); - f64x2 update = f64x2::mul(deriv, grad_out); - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], update); + f64x2 val = pair_vals[node.lhs()]; + f64x2 grad = f64x2::scalar_div(-node.scalar(), f64x2::mul(val, val)); + f64x2 update = f64x2::mul(grad, f64x2::make(grad_out, grad_out)); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); break; } + // Pair-Value case OpType::PairMulValue: case OpType::ValueMulPair: { - f64x2 grad_out = pair_grads[node.output_idx]; - f64x2 pair_val = pair_vals[node.lhs_idx]; - f64 v = vals[node.rhs_idx]; - - f64x2 pair_update = f64x2::mul_scalar(grad_out, v); - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], pair_update); + f64x2 grad_pair = f64x2::mul_scalar(f64x2::make(grad_out, grad_out), vals[node.rhs()]); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); f64 contrib = - pair_val.first() * grad_out.first() + pair_val.second() * grad_out.second(); - grads[node.rhs_idx] += contrib; + pair_vals[node.lhs()].first() * grad_out + pair_vals[node.lhs()].second() * grad_out; + grads[node.rhs()] += contrib; break; } - case OpType::PairDivValue: { - f64x2 grad_out = pair_grads[node.output_idx]; - f64x2 pair_val = pair_vals[node.lhs_idx]; - f64 v = vals[node.rhs_idx]; - f64x2 pair_update = f64x2::div_scalar(grad_out, v); - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], pair_update); + case OpType::PairDivValue: { + f64x2 grad_pair = f64x2::div_scalar(f64x2::make(grad_out, grad_out), vals[node.rhs()]); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); - f64 num = pair_val.first() * grad_out.first() + pair_val.second() * grad_out.second(); - grads[node.rhs_idx] += -num / (v * v); + f64 num = + pair_vals[node.lhs()].first() * grad_out + pair_vals[node.lhs()].second() * grad_out; + grads[node.rhs()] += -num / (vals[node.rhs()] * vals[node.rhs()]); break; } + case OpType::ValueDivPair: { - f64x2 grad_out = pair_grads[node.output_idx]; - f64x2 pair_val = pair_vals[node.lhs_idx]; - f64 v = vals[node.rhs_idx]; - - f64x2 pair_sq = f64x2::mul(pair_val, pair_val); - f64x2 deriv = f64x2::scalar_div(-v, pair_sq); - f64x2 update = f64x2::mul(deriv, grad_out); - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], update); - - f64x2 recip = f64x2::scalar_div(1.0, pair_val); - grads[node.rhs_idx] += - grad_out.first() * recip.first() + grad_out.second() * recip.second(); + f64x2 val = pair_vals[node.lhs()]; + f64x2 grad_pair = f64x2::scalar_div(-vals[node.rhs()], f64x2::mul(val, val)); + f64x2 update = f64x2::mul(grad_pair, f64x2::make(grad_out, grad_out)); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); + + f64x2 recip = f64x2::scalar_div(1.0, val); + grads[node.rhs()] += grad_out * (recip.first() + recip.second()); break; } + // Phase ops case OpType::Phase: { - f64 grad = grads[node.output_idx]; - f64 alpha = node.scalar_data; - f64x2 update = f64x2::make(alpha * grad, (1.0 - alpha) * grad); - pair_grads[node.lhs_idx] = f64x2::add(pair_grads[node.lhs_idx], update); + f64x2 update = + f64x2::make(node.scalar() * grad_out, (1.0 - node.scalar()) * grad_out); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); break; } + default: break; } diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index 09404d32..ebca103d 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -1,6 +1,9 @@ #pragma once #include "util/types.hpp" +#include "value.hpp" +#include +#include namespace Clockwork::Autograd { @@ -54,16 +57,80 @@ enum class OpType : u8 { Sum // Sum of a vector of values }; -// A single node in the compute tape. Probably can be rewritten more compactly. -struct Node { - OpType type; // This tells us which arenas to look at and how to interpret lhs/rhs - u32 output_idx; // Index in the respective arena (Value or Pair) - u32 lhs_idx; // Index of first operand - u32 rhs_idx; // Index of second operand (if applicable) +// Node layout (16 bytes): +// [0..3] : type (1), pad (1), lhs_offset (2) -> 4 bytes +// [4..7] : output_idx -> 4 bytes +// [8..15] : union { struct { u16 rhs_offset; u16 pad2; u32 pad3; } ; double scalar; } -> 8 bytes +struct alignas(8) Node { + using u8 = uint8_t; + using u16 = uint16_t; + using u32 = uint32_t; + using f64 = double; - // Auxiliary data for scalar ops, constants, or specific parameters. - f64 scalar_data; + OpType type; // u8 + u8 pad; // required to align lhs_offset + u16 lhs_offset; + u32 output_idx; + + union U { + u16 rhs_offset; // for unary/binary ops + f64 scalar_data; // for scalar ops + + constexpr U() : + rhs_offset(0) { + } + constexpr U(u16 rhs) : + rhs_offset(rhs) { + } + constexpr U(f64 scalar) : + scalar_data(scalar) { + } + } u; + + static constexpr Node make_binary(OpType t, u32 output_idx, u32 lhs_idx, u32 rhs_idx) { + Node n{}; + n.type = t; + n.pad = 0; + + // lhs & rhs indices are guaranteed to be <= out + n.lhs_offset = static_cast(output_idx - lhs_idx); + n.output_idx = output_idx; + n.u.rhs_offset = static_cast(output_idx - rhs_idx); + + return n; + } + + static constexpr Node make_scalar(OpType t, u32 output_idx, u32 lhs_idx, f64 scalar) { + Node n{}; + n.type = t; + n.pad = 0; + n.lhs_offset = static_cast(output_idx - lhs_idx); + n.output_idx = output_idx; + n.u.scalar_data = scalar; + return n; + } + + constexpr u32 lhs() const noexcept { + return output_idx - lhs_offset; + } + + constexpr u32 rhs() const noexcept { + return output_idx - u.rhs_offset; + } + + constexpr u32 out() const noexcept { + return output_idx; + } + + constexpr f64 scalar() const noexcept { + return u.scalar_data; + } }; +static_assert(sizeof(Node) == 16, "Node must be exactly 16 bytes"); +static_assert(alignof(Node) == alignof(double), + "Node alignment must match double alignment (8 bytes)"); +static_assert(offsetof(Node, u) == 8, "Union must begin at offset 8 to keep double aligned."); + } // namespace Clockwork::Autograd From 66970622bee3d95b04d647120af51810d443cf17 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 19:44:14 +0100 Subject: [PATCH 22/31] alignas(16) --- src/tuning/operations.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index ebca103d..ab7de266 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -62,7 +62,7 @@ enum class OpType : u8 { // [0..3] : type (1), pad (1), lhs_offset (2) -> 4 bytes // [4..7] : output_idx -> 4 bytes // [8..15] : union { struct { u16 rhs_offset; u16 pad2; u32 pad3; } ; double scalar; } -> 8 bytes -struct alignas(8) Node { +struct alignas(16) Node { using u8 = uint8_t; using u16 = uint16_t; using u32 = uint32_t; @@ -129,7 +129,7 @@ struct alignas(8) Node { }; static_assert(sizeof(Node) == 16, "Node must be exactly 16 bytes"); -static_assert(alignof(Node) == alignof(double), +static_assert(alignof(Node) == 16, "Node alignment must match double alignment (8 bytes)"); static_assert(offsetof(Node, u) == 8, "Union must begin at offset 8 to keep double aligned."); From 7462470e497c432abb7b4de2e19164415cb9918e Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sat, 22 Nov 2025 20:56:09 +0100 Subject: [PATCH 23/31] 16 bytes aligned operation node --- src/tuning/operations.hpp | 40 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index ab7de266..07695b69 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -7,7 +7,7 @@ namespace Clockwork::Autograd { -enum class OpType : u8 { +enum class OpType : u32 { // Leaf nodes None, Parameter, // Created from a global parameter @@ -63,40 +63,34 @@ enum class OpType : u8 { // [4..7] : output_idx -> 4 bytes // [8..15] : union { struct { u16 rhs_offset; u16 pad2; u32 pad3; } ; double scalar; } -> 8 bytes struct alignas(16) Node { - using u8 = uint8_t; - using u16 = uint16_t; - using u32 = uint32_t; - using f64 = double; - OpType type; // u8 - u8 pad; // required to align lhs_offset - u16 lhs_offset; - u32 output_idx; + OpType type; // u32 + u32 lhs_idx; + u32 output_idx; union U { - u16 rhs_offset; // for unary/binary ops - f64 scalar_data; // for scalar ops + u32 rhs_idx; // for unary/binary ops + f32 scalar_data; // for scalar ops constexpr U() : - rhs_offset(0) { + rhs_idx(0) { } - constexpr U(u16 rhs) : - rhs_offset(rhs) { + constexpr U(u16 rhs_idx) : + rhs_idx(rhs_idx) { } constexpr U(f64 scalar) : - scalar_data(scalar) { + scalar_data(static_cast(scalar)) { } } u; static constexpr Node make_binary(OpType t, u32 output_idx, u32 lhs_idx, u32 rhs_idx) { Node n{}; n.type = t; - n.pad = 0; // lhs & rhs indices are guaranteed to be <= out - n.lhs_offset = static_cast(output_idx - lhs_idx); + n.lhs_idx = lhs_idx; n.output_idx = output_idx; - n.u.rhs_offset = static_cast(output_idx - rhs_idx); + n.u.rhs_idx = rhs_idx; return n; } @@ -104,19 +98,18 @@ struct alignas(16) Node { static constexpr Node make_scalar(OpType t, u32 output_idx, u32 lhs_idx, f64 scalar) { Node n{}; n.type = t; - n.pad = 0; - n.lhs_offset = static_cast(output_idx - lhs_idx); + n.lhs_idx = lhs_idx; n.output_idx = output_idx; - n.u.scalar_data = scalar; + n.u.scalar_data = static_cast(scalar); return n; } constexpr u32 lhs() const noexcept { - return output_idx - lhs_offset; + return lhs_idx; } constexpr u32 rhs() const noexcept { - return output_idx - u.rhs_offset; + return u.rhs_idx; } constexpr u32 out() const noexcept { @@ -131,6 +124,5 @@ struct alignas(16) Node { static_assert(sizeof(Node) == 16, "Node must be exactly 16 bytes"); static_assert(alignof(Node) == 16, "Node alignment must match double alignment (8 bytes)"); -static_assert(offsetof(Node, u) == 8, "Union must begin at offset 8 to keep double aligned."); } // namespace Clockwork::Autograd From fe23d85e5c2110d8ac195132096baa30463c0f5a Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sun, 23 Nov 2025 00:59:59 +0100 Subject: [PATCH 24/31] Bugfix + unreachable --- src/tuning/graph.cpp | 175 ++++++++++++++++++++++++------------------- 1 file changed, 96 insertions(+), 79 deletions(-) diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 9b16f32e..8a72bbc9 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -1,5 +1,6 @@ #include "tuning/graph.hpp" #include "tuning/globals.hpp" +#include "util/types.hpp" #include namespace Clockwork::Autograd { @@ -215,9 +216,9 @@ void Graph::backward() { if (m_tape.empty()) { return; } - // Initialize gradient of last output to 1 + // (This assumes the final output is always a ValueHandle) m_values.grad(m_tape.back().out()) = 1.0; f64* vals = m_values.values_data(); @@ -228,175 +229,191 @@ void Graph::backward() { for (auto it = m_tape.rbegin(); it != m_tape.rend(); ++it) { const Node& node = *it; - const u32 out_idx = node.out(); - const f64 grad_out = grads[out_idx]; + const u32 out_idx = node.out(); switch (node.type) { // Value-Binary - case OpType::Add: + case OpType::Add: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] += grad_out; grads[node.rhs()] += grad_out; break; - - case OpType::Sub: + } + case OpType::Sub: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] += grad_out; grads[node.rhs()] -= grad_out; break; - + } case OpType::Mul: { - f64 l = vals[node.lhs()]; - f64 r = vals[node.rhs()]; + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; + f64 r = vals[node.rhs()]; grads[node.lhs()] += r * grad_out; grads[node.rhs()] += l * grad_out; break; } - case OpType::Div: { - f64 l = vals[node.lhs()]; - f64 r = vals[node.rhs()]; + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; + f64 r = vals[node.rhs()]; grads[node.lhs()] += grad_out / r; grads[node.rhs()] += -l * grad_out / (r * r); break; } - case OpType::Pow: { - f64 base = vals[node.lhs()]; - f64 exp = vals[node.rhs()]; + const f64 grad_out = grads[out_idx]; + f64 base = vals[node.lhs()]; + f64 exp = vals[node.rhs()]; grads[node.lhs()] += exp * std::pow(base, exp - 1) * grad_out; grads[node.rhs()] += std::pow(base, exp) * std::log(base) * grad_out; break; } // Value-Scalar - case OpType::Exp: + case OpType::Exp: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] += vals[out_idx] * grad_out; break; - - case OpType::Log: + } + case OpType::Log: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] += grad_out / vals[node.lhs()]; break; - + } case OpType::Sigmoid: { - f64 s = vals[out_idx]; + const f64 grad_out = grads[out_idx]; + f64 s = vals[out_idx]; grads[node.lhs()] += s * (1.0 - s) * grad_out; break; } - - case OpType::Neg: + case OpType::Neg: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] -= grad_out; break; - + } case OpType::PowConst: { - f64 l = vals[node.lhs()]; - f64 exp = node.scalar(); + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; + f64 exp = node.scalar(); grads[node.lhs()] += exp * std::pow(l, exp - 1.0) * grad_out; break; } - case OpType::AddScalar: - case OpType::SubScalarVal: - case OpType::ValSubScalar: + case OpType::ValSubScalar: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] += grad_out; break; - - case OpType::MulScalar: + } + + case OpType::SubScalarVal: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] -= grad_out; + break; + } + case OpType::MulScalar: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] += node.scalar() * grad_out; break; - - case OpType::ValDivScalar: + } + case OpType::ValDivScalar: { + const f64 grad_out = grads[out_idx]; grads[node.lhs()] += grad_out / node.scalar(); break; - + } case OpType::DivScalarVal: { - f64 l = vals[node.lhs()]; + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; grads[node.lhs()] += -node.scalar() * grad_out / (l * l); break; } + case OpType::Phase: { + const f64 grad_out = grads[out_idx]; + f64x2 update = f64x2::make(node.scalar() * grad_out, (1.0 - node.scalar()) * grad_out); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); + break; + } - // Pair-Binary case OpType::PairAdd: { - f64x2 grad_pair = f64x2::make(grad_out, grad_out); // same grad applied to both - pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); - pair_grads[node.rhs()] = f64x2::add(pair_grads[node.rhs()], grad_pair); + const f64x2 grad_out = pair_grads[out_idx]; + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_out); + pair_grads[node.rhs()] = f64x2::add(pair_grads[node.rhs()], grad_out); break; } - case OpType::PairSub: { - f64x2 grad_pair = f64x2::make(grad_out, grad_out); - pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); - pair_grads[node.rhs()] = f64x2::sub(pair_grads[node.rhs()], grad_pair); + const f64x2 grad_out = pair_grads[out_idx]; + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_out); + pair_grads[node.rhs()] = f64x2::sub(pair_grads[node.rhs()], grad_out); break; } - - case OpType::PairNeg: - pair_grads[node.lhs()] = - f64x2::sub(pair_grads[node.lhs()], f64x2::make(grad_out, grad_out)); + case OpType::PairNeg: { + const f64x2 grad_out = pair_grads[out_idx]; + pair_grads[node.lhs()] = f64x2::sub(pair_grads[node.lhs()], grad_out); break; - - // Pair-Scalar + } case OpType::PairMulScalar: { - f64x2 scaled = f64x2::mul_scalar(f64x2::make(grad_out, grad_out), node.scalar()); + const f64x2 grad_out = pair_grads[out_idx]; + f64x2 scaled = f64x2::mul_scalar(grad_out, node.scalar()); pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], scaled); break; } - case OpType::PairDivScalar: { - f64x2 scaled = f64x2::div_scalar(f64x2::make(grad_out, grad_out), node.scalar()); + const f64x2 grad_out = pair_grads[out_idx]; + f64x2 scaled = f64x2::div_scalar(grad_out, node.scalar()); pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], scaled); break; } - case OpType::ScalarDivPair: { - f64x2 val = pair_vals[node.lhs()]; - f64x2 grad = f64x2::scalar_div(-node.scalar(), f64x2::mul(val, val)); - f64x2 update = f64x2::mul(grad, f64x2::make(grad_out, grad_out)); + const f64x2 grad_out = pair_grads[out_idx]; + f64x2 val = pair_vals[node.lhs()]; + f64x2 grad = f64x2::scalar_div(-node.scalar(), f64x2::mul(val, val)); + f64x2 update = f64x2::mul(grad, grad_out); pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); break; } - - // Pair-Value case OpType::PairMulValue: case OpType::ValueMulPair: { - f64x2 grad_pair = f64x2::mul_scalar(f64x2::make(grad_out, grad_out), vals[node.rhs()]); + const f64x2 grad_out = pair_grads[out_idx]; + f64 val_rhs = vals[node.rhs()]; + f64x2 val_lhs = pair_vals[node.lhs()]; + + f64x2 grad_pair = f64x2::mul_scalar(grad_out, val_rhs); pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); - f64 contrib = - pair_vals[node.lhs()].first() * grad_out + pair_vals[node.lhs()].second() * grad_out; + f64 contrib = grad_out.first() * val_lhs.first() + grad_out.second() * val_lhs.second(); grads[node.rhs()] += contrib; break; } - case OpType::PairDivValue: { - f64x2 grad_pair = f64x2::div_scalar(f64x2::make(grad_out, grad_out), vals[node.rhs()]); + const f64x2 grad_out = pair_grads[out_idx]; + f64 val_rhs = vals[node.rhs()]; + f64x2 val_lhs = pair_vals[node.lhs()]; + + f64x2 grad_pair = f64x2::div_scalar(grad_out, val_rhs); pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); - f64 num = - pair_vals[node.lhs()].first() * grad_out + pair_vals[node.lhs()].second() * grad_out; - grads[node.rhs()] += -num / (vals[node.rhs()] * vals[node.rhs()]); + f64 num = grad_out.first() * val_lhs.first() + grad_out.second() * val_lhs.second(); + grads[node.rhs()] += -num / (val_rhs * val_rhs); break; } - case OpType::ValueDivPair: { - f64x2 val = pair_vals[node.lhs()]; - f64x2 grad_pair = f64x2::scalar_div(-vals[node.rhs()], f64x2::mul(val, val)); - f64x2 update = f64x2::mul(grad_pair, f64x2::make(grad_out, grad_out)); - pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); + const f64x2 grad_out = pair_grads[out_idx]; + f64 val_rhs = vals[node.rhs()]; + f64x2 val_lhs = pair_vals[node.lhs()]; - f64x2 recip = f64x2::scalar_div(1.0, val); - grads[node.rhs()] += grad_out * (recip.first() + recip.second()); - break; - } - - // Phase ops - case OpType::Phase: { - f64x2 update = - f64x2::make(node.scalar() * grad_out, (1.0 - node.scalar()) * grad_out); + f64x2 grad_pair = f64x2::scalar_div(-val_rhs, f64x2::mul(val_lhs, val_lhs)); + f64x2 update = f64x2::mul(grad_pair, grad_out); pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); + + f64x2 recip = f64x2::scalar_div(1.0, val_lhs); + grads[node.rhs()] += + grad_out.first() * recip.first() + grad_out.second() * recip.second(); break; } default: + unreachable(); break; } } From e774189b70849a068e721d9ec72ad07a25257871 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Sun, 23 Nov 2025 18:40:08 +0100 Subject: [PATCH 25/31] Format and prep for pgo tests --- src/evaltune_main.cpp | 8 +++++--- src/tuning/arena.hpp | 3 +-- src/tuning/graph.cpp | 36 ++++++++++++++++++------------------ src/tuning/operations.hpp | 19 +++++++++---------- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index 1c0fd242..44a5ec3b 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -96,8 +96,11 @@ int main() { // The optimizer will now start with all-zero parameters AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0); - - const i32 epochs = 8; +#ifdef PROFILE_RUN + const i32 epochs = 8; +#else + const i32 epochs = 1000; +#endif const f64 K = 1.0 / 400; const size_t batch_size = 16 * 16384; @@ -200,7 +203,6 @@ int main() { Graph::get().cleanup(); Graph::get().zero_grad(); -#define PROFILE_RUN #ifndef PROFILE_RUN std::cout << "inline const PParam PAWN_MAT = " << PAWN_MAT << ";" << std::endl; std::cout << "inline const PParam KNIGHT_MAT = " << KNIGHT_MAT << ";" << std::endl; diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index 43d6a8ca..bac8cc4f 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -1,7 +1,6 @@ #pragma once #include "util/vec/sse2.hpp" #include "value.hpp" -#include "value.hpp" #include "util/types.hpp" #include @@ -33,7 +32,7 @@ class ValueArena { } inline u32 next_index() const { - return static_cast(values.size()); + return static_cast(values.size()); } inline ValueHandle next_handle() const { diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 8a72bbc9..91ef6163 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -38,9 +38,9 @@ PairHandle Graph::create_pair(f64x2 data) { ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { ValueHandle out = m_values.next_handle(); - f64 l = m_values.val(lhs.index); - f64 r = m_values.val(rhs.index); - f64 res = 0.0; + f64 l = m_values.val(lhs.index); + f64 r = m_values.val(rhs.index); + f64 res = 0.0; switch (op) { case OpType::Add: @@ -71,8 +71,8 @@ ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { ValueHandle Graph::record_op(OpType op, ValueHandle lhs, f64 scalar) { ValueHandle out = m_values.next_handle(); - f64 l = m_values.val(lhs.index); - f64 res = 0.0; + f64 l = m_values.val(lhs.index); + f64 res = 0.0; switch (op) { case OpType::Exp: @@ -121,9 +121,9 @@ ValueHandle Graph::record_op(OpType op, ValueHandle lhs, f64 scalar) { PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { PairHandle out = m_pairs.next_handle(); - f64x2 l = m_pairs.val(lhs.index); - f64x2 r = m_pairs.val(rhs.index); - f64x2 res = f64x2::zero(); + f64x2 l = m_pairs.val(lhs.index); + f64x2 r = m_pairs.val(rhs.index); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairAdd: @@ -145,8 +145,8 @@ PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { PairHandle Graph::record_pair_scalar(OpType op, PairHandle lhs, f64 scalar) { PairHandle out = m_pairs.next_handle(); - f64x2 l = m_pairs.val(lhs.index); - f64x2 res = f64x2::zero(); + f64x2 l = m_pairs.val(lhs.index); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairNeg: @@ -174,9 +174,9 @@ PairHandle Graph::record_pair_scalar(OpType op, PairHandle lhs, f64 scalar) { PairHandle Graph::record_pair_value(OpType op, PairHandle lhs, ValueHandle rhs) { PairHandle out = m_pairs.next_handle(); - f64x2 pair_val = m_pairs.val(lhs.index); - f64 v = m_values.val(rhs.index); - f64x2 res = f64x2::zero(); + f64x2 pair_val = m_pairs.val(lhs.index); + f64 v = m_values.val(rhs.index); + f64x2 res = f64x2::zero(); switch (op) { case OpType::PairMulValue: @@ -202,8 +202,8 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle lhs, ValueHandle rhs) ValueHandle Graph::record_phase(PairHandle lhs, f64 alpha) { ValueHandle out = m_values.next_handle(); - f64x2 pair_val = m_pairs.val(lhs.index); - f64 val = alpha * pair_val.first() + (1.0 - alpha) * pair_val.second(); + f64x2 pair_val = m_pairs.val(lhs.index); + f64 val = alpha * pair_val.first() + (1.0 - alpha) * pair_val.second(); m_values.alloc(val, 0.0); @@ -232,7 +232,7 @@ void Graph::backward() { const u32 out_idx = node.out(); switch (node.type) { - // Value-Binary + // Value-Binary case OpType::Add: { const f64 grad_out = grads[out_idx]; @@ -306,7 +306,7 @@ void Graph::backward() { grads[node.lhs()] += grad_out; break; } - + case OpType::SubScalarVal: { const f64 grad_out = grads[out_idx]; grads[node.lhs()] -= grad_out; @@ -413,7 +413,7 @@ void Graph::backward() { } default: - unreachable(); + unreachable(); break; } } diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp index 07695b69..d01dfd7d 100644 --- a/src/tuning/operations.hpp +++ b/src/tuning/operations.hpp @@ -2,8 +2,8 @@ #include "util/types.hpp" #include "value.hpp" -#include #include +#include namespace Clockwork::Autograd { @@ -65,11 +65,11 @@ enum class OpType : u32 { struct alignas(16) Node { OpType type; // u32 - u32 lhs_idx; - u32 output_idx; + u32 lhs_idx; + u32 output_idx; union U { - u32 rhs_idx; // for unary/binary ops + u32 rhs_idx; // for unary/binary ops f32 scalar_data; // for scalar ops constexpr U() : @@ -88,9 +88,9 @@ struct alignas(16) Node { n.type = t; // lhs & rhs indices are guaranteed to be <= out - n.lhs_idx = lhs_idx; - n.output_idx = output_idx; - n.u.rhs_idx = rhs_idx; + n.lhs_idx = lhs_idx; + n.output_idx = output_idx; + n.u.rhs_idx = rhs_idx; return n; } @@ -98,7 +98,7 @@ struct alignas(16) Node { static constexpr Node make_scalar(OpType t, u32 output_idx, u32 lhs_idx, f64 scalar) { Node n{}; n.type = t; - n.lhs_idx = lhs_idx; + n.lhs_idx = lhs_idx; n.output_idx = output_idx; n.u.scalar_data = static_cast(scalar); return n; @@ -122,7 +122,6 @@ struct alignas(16) Node { }; static_assert(sizeof(Node) == 16, "Node must be exactly 16 bytes"); -static_assert(alignof(Node) == 16, - "Node alignment must match double alignment (8 bytes)"); +static_assert(alignof(Node) == 16, "Node alignment must match double alignment (8 bytes)"); } // namespace Clockwork::Autograd From 32317c5114c19df86c6af5b0b12bc60e8458e608 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Tue, 2 Dec 2025 02:49:38 +0100 Subject: [PATCH 26/31] Fix merge fuckery and format Bench: 11233646 --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e5d03e2..a2b6261f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,8 @@ set(srcs src/board.cpp src/board.hpp src/common.hpp + src/cuckoo.cpp + src/cuckoo.hpp src/dbg_tools.cpp src/dbg_tools.hpp src/eval_constants.hpp From b6dc03b448ecaf9b09b1e0b379f3dbea205aa6e2 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Wed, 3 Dec 2025 02:11:31 +0100 Subject: [PATCH 27/31] Bench: 11856625 --- src/eval_types.hpp | 80 +++++++++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/src/eval_types.hpp b/src/eval_types.hpp index 50890fa9..6391bb8c 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -21,9 +21,8 @@ using Score = i16; class PScore { private: i32 m_score; - explicit constexpr PScore(i32 score) : - m_score(score) { + m_score{score} { } public: @@ -32,69 +31,78 @@ class PScore { } constexpr PScore(Score midgame, Score endgame) : - m_score(static_cast((u32(endgame) << 16) + u16(midgame))) { + m_score{static_cast(static_cast(endgame) << 16) + midgame} { assert(std::numeric_limits::min() <= midgame - && midgame <= std::numeric_limits::max()); + && std::numeric_limits::max() >= midgame); assert(std::numeric_limits::min() <= endgame - && endgame <= std::numeric_limits::max()); + && std::numeric_limits::max() >= endgame); } - [[nodiscard]] inline Score mg() const { - u16 mg = u16(m_score); - i16 v; + [[nodiscard]] inline auto mg() const { + const auto mg = static_cast(m_score); + + i16 v{}; std::memcpy(&v, &mg, sizeof(mg)); - return v; + + return static_cast(v); } - [[nodiscard]] inline Score eg() const { - u16 eg = u16(u32(m_score + 0x8000) >> 16); - i16 v; + [[nodiscard]] inline auto eg() const { + const auto eg = static_cast(static_cast(m_score + 0x8000) >> 16); + + i16 v{}; std::memcpy(&v, &eg, sizeof(eg)); - return v; + + return static_cast(v); } - // Operators identical to original version - constexpr PScore operator+(const PScore& o) const { - return PScore(m_score + o.m_score); + [[nodiscard]] constexpr auto operator+(const PScore& other) const { + return PScore{m_score + other.m_score}; } - constexpr PScore operator-(const PScore& o) const { - return PScore(m_score - o.m_score); + + constexpr auto operator+=(const PScore& other) -> auto& { + m_score += other.m_score; + return *this; } - constexpr PScore operator*(i32 v) const { - return PScore(m_score * v); + + [[nodiscard]] constexpr auto operator-(const PScore& other) const { + return PScore{m_score - other.m_score}; } - constexpr PScore& operator+=(const PScore& o) { - m_score += o.m_score; + + constexpr auto operator-=(const PScore& other) -> auto& { + m_score -= other.m_score; return *this; } - constexpr PScore& operator-=(const PScore& o) { - m_score -= o.m_score; - return *this; + + [[nodiscard]] constexpr auto operator*(i32 v) const { + return PScore{m_score * v}; } - constexpr PScore& operator*=(i32 v) { + + constexpr auto operator*=(i32 v) -> auto& { m_score *= v; return *this; } - constexpr PScore operator-() const { - return PScore(-m_score); + + [[nodiscard]] constexpr auto operator-() const { + return PScore{-m_score}; } - constexpr bool operator==(const PScore&) const = default; + [[nodiscard]] constexpr bool operator==(const PScore& other) const = default; - constexpr const PScore* operator->() const { + [[nodiscard]] constexpr const PScore* operator->() const { return this; } - // Phase function (non-tuning: returns int) + // Phasing between two scores template - [[nodiscard]] inline Value phase(i32 alpha) const { + Value phase(i32 alpha) const { assert(0 <= alpha && alpha <= max); - return Value((mg() * alpha + eg() * (max - alpha)) / max); + return static_cast((mg() * alpha + eg() * (max - alpha)) / max); } - friend std::ostream& operator<<(std::ostream& os, const PScore& s) { - os << "(" << s.mg() << "\t" << s.eg() << ")"; - return os; + friend std::ostream& operator<<(std::ostream& stream, const PScore& score) { + stream << "(" << score.mg() << "\t" << score.eg() << ")"; + return stream; } }; From f06a4ffb79017bc548e776ec4e525011c76b0911 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Wed, 3 Dec 2025 02:28:35 +0100 Subject: [PATCH 28/31] Bench: 11856625 --- src/eval_types.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/eval_types.hpp b/src/eval_types.hpp index 6391bb8c..fd124109 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -14,9 +14,9 @@ #endif namespace Clockwork { +using Score = i16; #ifndef EVAL_TUNING -using Score = i16; class PScore { private: From 52d8006fbf959d8bb8f8950ec91b747660a6f3cb Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Wed, 3 Dec 2025 02:43:53 +0100 Subject: [PATCH 29/31] ahahhhahahahahHAhAhahhahahAHHaHaHAHAHAHAH Bench: 11856625 --- src/eval_types.hpp | 2 +- src/tuning/arena.hpp | 20 +++----------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/src/eval_types.hpp b/src/eval_types.hpp index fd124109..6391bb8c 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -14,9 +14,9 @@ #endif namespace Clockwork { -using Score = i16; #ifndef EVAL_TUNING +using Score = i16; class PScore { private: diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index bac8cc4f..63f24188 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -17,20 +17,13 @@ class ValueArena { gradients.reserve(n); } - inline u32 alloc(f64 value, f64 grad = 0.0) { + inline u32 alloc(f64 value = 0.0, f64 grad = 0.0) { u32 idx = static_cast(values.size()); values.push_back(value); gradients.push_back(grad); return idx; } - inline u32 alloc_uninitialized() { - u32 idx = static_cast(values.size()); - values.push_back(0.0); - gradients.push_back(0.0); - return idx; - } - inline u32 next_index() const { return static_cast(values.size()); } @@ -102,20 +95,13 @@ class PairArena { gradients.reserve(n); } - inline u32 alloc(f64x2 v, f64x2 g = f64x2::zero()) { + inline u32 alloc(f64x2 v = f64x2::zero(), f64x2 g = f64x2::zero()) { u32 idx = static_cast(values.size()); values.push_back(v); gradients.push_back(g); return idx; } - - inline u32 alloc_uninitialized() { - u32 idx = static_cast(values.size()); - values.push_back(f64x2::zero()); - gradients.push_back(f64x2::zero()); - return idx; - } - + inline u32 next_index() const { return static_cast(values.size()); } From a36f77cd62690cc9176a03f04d63095f48e9e8e0 Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Wed, 3 Dec 2025 03:04:54 +0100 Subject: [PATCH 30/31] Bench: 11856625 --- src/util/vec/sse2.hpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/util/vec/sse2.hpp b/src/util/vec/sse2.hpp index e8f302dc..86852ae2 100644 --- a/src/util/vec/sse2.hpp +++ b/src/util/vec/sse2.hpp @@ -6,13 +6,13 @@ #if defined(__SSE2__) #include - #define F128_USE_SSE2 1 + #define F64X2_USE_SSE2 1 #else - #define F128_USE_SSE2 0 + #define F64X2_USE_SSE2 0 #endif struct f64x2 { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 __m128d v = _mm_setzero_pd(); #else double lo = 0.0; @@ -21,7 +21,7 @@ struct f64x2 { // ---- Constructors ---- static inline f64x2 make(double a, double b) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_set_pd(b, a); return r; @@ -31,7 +31,7 @@ struct f64x2 { } static inline f64x2 broadcast(double x) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_set1_pd(x); return r; @@ -41,7 +41,7 @@ struct f64x2 { } static inline f64x2 zero() { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_setzero_pd(); return r; @@ -52,7 +52,7 @@ struct f64x2 { // ---- Extract ---- inline double first() const { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 alignas(16) double buf[2]; _mm_store_pd(buf, v); return buf[0]; @@ -62,7 +62,7 @@ struct f64x2 { } inline double second() const { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 alignas(16) double buf[2]; _mm_store_pd(buf, v); return buf[1]; @@ -73,7 +73,7 @@ struct f64x2 { // ---- Arithmetic ---- static inline f64x2 add(const f64x2& a, const f64x2& b) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_add_pd(a.v, b.v); return r; @@ -83,7 +83,7 @@ struct f64x2 { } static inline f64x2 sub(const f64x2& a, const f64x2& b) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_sub_pd(a.v, b.v); return r; @@ -93,7 +93,7 @@ struct f64x2 { } static inline f64x2 mul(const f64x2& a, const f64x2& b) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_mul_pd(a.v, b.v); return r; @@ -103,7 +103,7 @@ struct f64x2 { } static inline f64x2 div(const f64x2& a, const f64x2& b) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_div_pd(a.v, b.v); return r; @@ -113,7 +113,7 @@ struct f64x2 { } static inline f64x2 neg(const f64x2& a) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 __m128d zero = _mm_setzero_pd(); f64x2 r; r.v = _mm_sub_pd(zero, a.v); @@ -141,7 +141,7 @@ struct f64x2 { } static inline f64x2 scalar_div(double s, const f64x2& a) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 __m128d num = _mm_set1_pd(s); f64x2 r; r.v = _mm_div_pd(num, a.v); @@ -153,7 +153,7 @@ struct f64x2 { // ---- Math functions ---- static inline f64x2 sqrt(const f64x2& a) { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_sqrt_pd(a.v); return r; @@ -165,7 +165,7 @@ struct f64x2 { // ---- FMA-style (useful for gradient updates) ---- static inline f64x2 madd(const f64x2& a, const f64x2& b, const f64x2& c) { // a + b*c -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 f64x2 r; r.v = _mm_add_pd(a.v, _mm_mul_pd(b.v, c.v)); return r; From cd7b6b35a4f871f1a566922666696b2ea9653c3a Mon Sep 17 00:00:00 2001 From: TheRealGioviok <425gioviok@gmail.com> Date: Wed, 3 Dec 2025 03:11:13 +0100 Subject: [PATCH 31/31] format yay Bench: 11856625 --- src/tuning/arena.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp index 63f24188..0f510e4f 100644 --- a/src/tuning/arena.hpp +++ b/src/tuning/arena.hpp @@ -101,7 +101,7 @@ class PairArena { gradients.push_back(g); return idx; } - + inline u32 next_index() const { return static_cast(values.size()); }