diff --git a/CMakeLists.txt b/CMakeLists.txt index 62abe7b2..a2b6261f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,73 +72,80 @@ function(target_add_flags target) # Libraries target_link_libraries(${target} PUBLIC git_hash lps) + # LTO + message(STATUS "LTO is set to: ${lto}") # LTO if (lto) set_target_properties(${target} PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) endif() + + endfunction() # Sorted list of source files set(srcs - src/bench.cpp - src/bench.hpp - src/board.cpp - src/board.hpp - src/cuckoo.cpp - src/cuckoo.hpp - src/common.hpp - src/dbg_tools.cpp - src/dbg_tools.hpp - src/eval_constants.hpp - src/eval_types.hpp - src/evaluation.cpp - src/evaluation.hpp - src/geometry.cpp - src/geometry.hpp - src/history.cpp - src/history.hpp - src/move.cpp - src/move.hpp - src/movegen.cpp - src/movegen.hpp - src/movepick.cpp - src/movepick.hpp - src/perft.cpp - src/perft.hpp - src/position.cpp - src/position.hpp - src/psqt_state.hpp - src/repetition_info.cpp - src/repetition_info.hpp - src/search.cpp - src/search.hpp - src/speedtest.cpp - src/speedtest.hpp - src/square.hpp - src/tm.cpp - src/tm.hpp - src/tt.cpp - src/tt.hpp - src/tuned.cpp - src/tuned.hpp - src/tuning/globals.hpp - src/tuning/graph.cpp - src/tuning/graph.hpp - src/tuning/loss.hpp - src/tuning/optim.hpp - src/tuning/value.cpp - src/tuning/value.hpp - src/uci.cpp - src/uci.hpp - src/util/bit.hpp - src/util/parse.hpp - src/util/pretty.hpp - src/util/static_vector.hpp - src/util/types.hpp - src/util/vec/sse2.hpp - src/zobrist.cpp - src/zobrist.hpp + src/bench.cpp + src/bench.hpp + src/board.cpp + src/board.hpp + src/common.hpp + src/cuckoo.cpp + src/cuckoo.hpp + src/dbg_tools.cpp + src/dbg_tools.hpp + src/eval_constants.hpp + src/eval_types.hpp + src/evaluation.cpp + src/evaluation.hpp + src/geometry.cpp + src/geometry.hpp + src/history.cpp + src/history.hpp + src/move.cpp + src/move.hpp + src/movegen.cpp + src/movegen.hpp + src/movepick.cpp + src/movepick.hpp + src/perft.cpp + src/perft.hpp + src/position.cpp + src/position.hpp + src/psqt_state.hpp + src/repetition_info.cpp + src/repetition_info.hpp + src/search.cpp + src/search.hpp + src/speedtest.cpp + src/speedtest.hpp + src/square.hpp + src/tm.cpp + src/tm.hpp + src/tt.cpp + src/tt.hpp + src/tuned.cpp + src/tuned.hpp + src/tuning/arena.hpp + src/tuning/globals.hpp + src/tuning/graph.cpp + src/tuning/graph.hpp + src/tuning/info.hpp + src/tuning/loss.hpp + src/tuning/operations.hpp + src/tuning/optim.hpp + src/tuning/value.cpp + src/tuning/value.hpp + src/uci.cpp + src/uci.hpp + src/util/bit.hpp + src/util/parse.hpp + src/util/pretty.hpp + src/util/static_vector.hpp + src/util/types.hpp + src/util/vec/sse2.hpp + src/zobrist.cpp + src/zobrist.hpp ) if(CLOCKWORK_ENABLE_EVALTUNE) diff --git a/src/eval_types.hpp b/src/eval_types.hpp index 8262a951..6391bb8c 100644 --- a/src/eval_types.hpp +++ b/src/eval_types.hpp @@ -14,9 +14,10 @@ #endif namespace Clockwork { -#ifndef EVAL_TUNING +#ifndef EVAL_TUNING using Score = i16; + class PScore { private: i32 m_score; @@ -109,20 +110,32 @@ using PParam = PScore; #else -using Score = Autograd::ValuePtr; -using PScore = Autograd::PairPtr; -using PParam = Autograd::PairPlaceholder; +using Score = Autograd::ValueHandle; +using PScore = Autograd::PairHandle; +using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter #endif + #ifdef EVAL_TUNING - #define S(a, b) Autograd::PairPlaceholder::create_tunable((a), (b)) // Defines a tunable pscore + // Tunable scalar pair (mg, eg) + #define S(a, b) 
Autograd::PairPlaceholder::create_tunable((a), (b)) + + // Constant scalar pair (mg, eg) #define CS(a, b) Autograd::PairPlaceholder::create((a), (b)) - #define PSCORE_ZERO Autograd::Pair::create(0, 0) + + // Zero pair FOR PARAMETERS (e.g., in an array) + #define PPARAM_ZERO Autograd::PairPlaceholder::create(0, 0) + + // Zero pair FOR INTERMEDIATES (e.g., scores) + #define PSCORE_ZERO Autograd::PairHandle::create(0, 0) + #else - #define S(a, b) PScore((a), (b)) // Defines a constant pscore when not tuning - #define CS(a, b) S((a), (b)) - #define PSCORE_ZERO CS(0, 0) + // ... (non-tuning definitions) ... + #define S(a, b) PScore((a), (b)) + #define CS(a, b) PScore((a), (b)) + #define PPARAM_ZERO PScore(0, 0) + #define PSCORE_ZERO PScore(0, 0) #endif } // namespace Clockwork diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index 0568a660..44a5ec3b 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -2,12 +2,15 @@ #include "eval_types.hpp" #include "evaluation.hpp" #include "position.hpp" + #include "tuning/graph.hpp" #include "tuning/loss.hpp" #include "tuning/optim.hpp" #include "tuning/value.hpp" + #include "util/pretty.hpp" #include "util/types.hpp" + #include #include #include @@ -22,195 +25,185 @@ #include using namespace Clockwork; +using namespace Clockwork::Autograd; int main() { - // Load fens from multiple files. std::vector positions; std::vector results; - // List of files to load const std::vector fenFiles = { - "data/dfrc-1m.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3/v3.txt", + "data/dfrcv1.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3.txt", }; - // Number of threads to use, default to half available const u32 thread_count = std::max(1, std::thread::hardware_concurrency() / 2); - std::cout << "Running on " << thread_count << " threads" << std::endl; + std::cout << "Running on " << thread_count << " threads\n"; for (const auto& filename : fenFiles) { std::ifstream fenFile(filename); if (!fenFile) { - std::cerr << "Error opening " << filename << std::endl; + std::cerr << "Error opening " << filename << "\n"; return 1; } std::string line; while (std::getline(fenFile, line)) { size_t pos = line.find(';'); - if (pos != std::string::npos) { - std::string fen = line.substr(0, pos); - auto parsed = Position::parse(fen); - if (parsed) { - positions.push_back(*parsed); - } else { - std::cerr << "Failed to parse FEN in file " << filename << ": " << fen - << std::endl; - continue; - } + if (pos == std::string::npos) { + std::cerr << "Bad line in " << filename << ": " << line << "\n"; + continue; + } - std::string result = line.substr(pos + 1); - result.erase(std::remove_if(result.begin(), result.end(), ::isspace), result.end()); - - if (result == "w") { - results.push_back(1.0); - } else if (result == "d") { - results.push_back(0.5); - } else if (result == "b") { - results.push_back(0.0); - } else { - std::cerr << "Invalid result in file " << filename << " line: " << line - << " (result is '" << result << "')" << std::endl; - } + std::string fen = line.substr(0, pos); + auto parsed = Position::parse(fen); + + if (!parsed) { + std::cerr << "Failed to parse FEN in " << filename << ": " << fen << "\n"; + continue; + } + + positions.push_back(*parsed); + + std::string result = line.substr(pos + 1); + result.erase(std::remove_if(result.begin(), result.end(), ::isspace), result.end()); + + if (result == "w") { + results.push_back(1.0); + } else if (result == "d") { + results.push_back(0.5); + } else if (result == "b") { + 
results.push_back(0.0); } else { - std::cerr << "Invalid line format in " << filename << ": " << line << std::endl; + std::cerr << "Invalid result in " << filename << ": " << line << "\n"; } } - - fenFile.close(); } - // Print the number of positions loaded - std::cout << "Loaded " << positions.size() << " FENs from " << fenFiles.size() << " files." - << std::endl; - - if (positions.size() == 0) { - std::cerr << "No positions loaded!" << std::endl; + std::cout << "Loaded " << positions.size() << " FENs.\n"; + if (positions.empty()) { return 1; } - using namespace Clockwork::Autograd; + // Setup tuning + const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts(); - const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts(); - Parameters current_parameter_values = Graph::get().get_all_parameter_values(); + // This line loads the defaults from your S() macros + Parameters current_parameter_values = Graph::get().get_all_parameter_values(); - AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0); + // Uncomment for zero tune: Overwrite them all with zeros. + current_parameter_values = Parameters::zeros(parameter_count); - const i32 epochs = 1000; + // The optimizer will now start with all-zero parameters + AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0); +#ifdef PROFILE_RUN + const i32 epochs = 8; +#else + const i32 epochs = 1000; +#endif const f64 K = 1.0 / 400; - const size_t batch_size = 16 * 16384; // Set batch size here - - std::mt19937 rng(std::random_device{}()); // Random number generator for shuffling + const size_t batch_size = 16 * 16384; - const size_t total_batches = (positions.size() + batch_size - 1) / batch_size; + std::mt19937 rng(std::random_device{}()); std::vector indices(positions.size()); + // Initialize indices 1..N + std::iota(indices.begin(), indices.end(), 0); + const size_t total_batches = (positions.size() + batch_size - 1) / batch_size; + // Shared gradient accumulator Parameters batch_gradients = Parameters::zeros(parameter_count); - std::mutex mutex; + std::mutex mutex; + std::barrier epoch_barrier{thread_count + 1}; std::barrier batch_barrier{thread_count + 1, [&]() noexcept { - std::lock_guard guard{mutex}; + // Single-thread optimizer update optim.step(current_parameter_values, batch_gradients); batch_gradients = Parameters::zeros(parameter_count); }}; - for (u32 thread_idx = 0; thread_idx < thread_count; thread_idx++) { - std::thread([&, thread_idx] { - Graph::get().cleanup(); - - std::vector subbatch_outputs; - std::vector subbatch_targets; - - for (i32 epoch = 0; epoch < epochs; epoch++) { + // Spawn worker threads + for (u32 t = 0; t < thread_count; ++t) { + std::thread([&, t]() { + // Each thread uses its own Graph arena + for (int epoch = 0; epoch < epochs; ++epoch) { epoch_barrier.arrive_and_wait(); for (size_t batch_start = 0; batch_start < positions.size(); batch_start += batch_size) { + size_t batch_end = std::min(batch_start + batch_size, positions.size()); + size_t this_batch_size = batch_end - batch_start; - size_t batch_end = std::min(batch_start + batch_size, positions.size()); - size_t current_batch_size = batch_end - batch_start; - size_t subbatch_size = (current_batch_size + thread_count - 1) / thread_count; - - size_t subbatch_start = batch_start + subbatch_size * thread_idx; - size_t subbatch_end = std::min(subbatch_start + subbatch_size, batch_end); - size_t current_subbatch_size = subbatch_end - subbatch_start; + size_t sub_size = (this_batch_size + thread_count - 1) / thread_count; - 
subbatch_outputs.clear(); - subbatch_targets.clear(); - subbatch_outputs.reserve(current_subbatch_size); - subbatch_targets.reserve(current_subbatch_size); + size_t sub_start = batch_start + sub_size * t; + size_t sub_end = std::min(sub_start + sub_size, batch_end); Graph::get().copy_parameter_values(current_parameter_values); - uint32_t i = 0; - for (size_t j = subbatch_start; j < subbatch_end; ++j) { - size_t idx = indices[j]; - f64 y = results[idx]; - Position pos = positions[idx]; - auto result = (evaluate_white_pov(pos) * K)->sigmoid(); - subbatch_outputs.push_back(result); - subbatch_targets.push_back(y); - if (++i == 1024) { - i = 0; - auto subbatch_loss = - mse(subbatch_outputs, subbatch_targets) - * Autograd::Value::create(1.0 / static_cast(current_batch_size)); - Graph::get().backward(); - Graph::get().clear_backwardables(); - subbatch_outputs.clear(); - subbatch_targets.clear(); - } + std::vector outputs; + std::vector targets; + outputs.reserve(sub_end - sub_start); + targets.reserve(sub_end - sub_start); + + // Forward + for (size_t j = sub_start; j < sub_end; ++j) { + size_t idx = indices[j]; + + auto y = results[idx]; + ValueHandle v = (evaluate_white_pov(positions[idx]) * K).sigmoid(); + outputs.push_back(v); + targets.push_back(y); } - auto subbatch_loss = - mse(subbatch_outputs, subbatch_targets) - * Autograd::Value::create(1.0 / static_cast(current_batch_size)); + // Backward + ValueHandle loss = mse(outputs, targets) + * ValueHandle::create(1.0 / double(this_batch_size)); + Graph::get().backward(); - Parameters subbatch_gradients = Graph::get().get_all_parameter_gradients(); + Parameters grads = Graph::get().get_all_parameter_gradients(); + // Accumulate { - std::lock_guard guard{mutex}; - batch_gradients.accumulate(subbatch_gradients); + std::lock_guard guard(mutex); + batch_gradients.accumulate(grads); } - batch_barrier.arrive_and_wait(); Graph::get().cleanup(); + Graph::get().zero_grad(); + + batch_barrier.arrive_and_wait(); } } }).detach(); } - for (i32 epoch = 0; epoch < epochs; epoch++) { - // Print epoch header - std::cout << "Epoch " << (epoch + 1) << "/" << epochs << std::endl; + // Epoch loop + for (int epoch = 0; epoch < epochs; ++epoch) { - const auto epoch_start_time = time::Clock::now(); + std::cout << "Epoch " << epoch + 1 << "/" << epochs << "\n"; + + const auto start = time::Clock::now(); - std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), rng); epoch_barrier.arrive_and_wait(); - for (size_t batch_idx = 0, batch_start = 0; batch_start < positions.size(); - batch_start += batch_size, ++batch_idx) { - + for (size_t bi = 0, bstart = 0; bstart < positions.size(); bstart += batch_size, ++bi) { batch_barrier.arrive_and_wait(); - - // Print batch progress bar - print_progress(batch_idx + 1, total_batches); + print_progress(bi + 1, total_batches); } - const auto epoch_end_time = time::Clock::now(); - - std::cout << std::endl; // Finish progress bar line + std::cout << "\n"; - // Print current values + // Dump current parameter values Graph::get().copy_parameter_values(current_parameter_values); + Graph::get().cleanup(); + Graph::get().zero_grad(); +#ifndef PROFILE_RUN std::cout << "inline const PParam PAWN_MAT = " << PAWN_MAT << ";" << std::endl; std::cout << "inline const PParam KNIGHT_MAT = " << KNIGHT_MAT << ";" << std::endl; std::cout << "inline const PParam BISHOP_MAT = " << BISHOP_MAT << ";" << std::endl; @@ -336,10 +329,11 @@ int main() { printPsqtArray("ROOK_PSQT", ROOK_PSQT); printPsqtArray("QUEEN_PSQT", 
QUEEN_PSQT); printPsqtArray("KING_PSQT", KING_PSQT); - - std::cout << "// Epoch duration: " - << time::cast(epoch_end_time - epoch_start_time).count() - << "s" << std::endl; + std::cout << std::endl; +#endif + const auto end = time::Clock::now(); + std::cout << "// Epoch duration: " << time::cast(end - start).count() + << "s\n"; if (epoch > 5) { optim.set_lr(optim.get_lr() * 0.91); diff --git a/src/evaluation.cpp b/src/evaluation.cpp index 634a1665..835bbff0 100644 --- a/src/evaluation.cpp +++ b/src/evaluation.cpp @@ -337,7 +337,7 @@ Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) { eval += evaluate_space(pos) - evaluate_space(pos); eval += evaluate_outposts(pos) - evaluate_outposts(pos); eval += (us == Color::White) ? TEMPO_VAL : -TEMPO_VAL; - return static_cast(eval->phase<24>(static_cast(phase))); + return static_cast(eval.phase<24>(static_cast(phase))); }; Score evaluate_stm_pov(const Position& pos, const PsqtState& psqt_state) { diff --git a/src/tuning/arena.hpp b/src/tuning/arena.hpp new file mode 100644 index 00000000..0f510e4f --- /dev/null +++ b/src/tuning/arena.hpp @@ -0,0 +1,193 @@ +#pragma once +#include "util/vec/sse2.hpp" +#include "value.hpp" + +#include "util/types.hpp" +#include +#include + +namespace Clockwork::Autograd { + +class ValueArena { +public: + ValueArena() = default; + + void reserve(usize n) { + values.reserve(n); + gradients.reserve(n); + } + + inline u32 alloc(f64 value = 0.0, f64 grad = 0.0) { + u32 idx = static_cast(values.size()); + values.push_back(value); + gradients.push_back(grad); + return idx; + } + + inline u32 next_index() const { + return static_cast(values.size()); + } + + inline ValueHandle next_handle() const { + return ValueHandle(next_index()); + } + + // Mutating accessors + inline f64& val(u32 i) { + assert(i < values.size()); + return values[i]; + } + inline f64& grad(u32 i) { + assert(i < gradients.size()); + return gradients[i]; + } + + // Const accessors + inline const f64& val(u32 i) const { + assert(i < values.size()); + return values[i]; + } + inline const f64& grad(u32 i) const { + assert(i < gradients.size()); + return gradients[i]; + } + + inline usize size() const { + return values.size(); + } + + void clear() { + values.clear(); + gradients.clear(); + } + + void reset_to(usize n) { + if (n < values.size()) { + values.resize(n); + gradients.resize(n); + } + } + + inline f64* values_data() { + return values.data(); + } + inline f64* gradients_data() { + return gradients.data(); + } + inline const f64* values_data() const { + return values.data(); + } + inline const f64* gradients_data() const { + return gradients.data(); + } + +private: + std::vector values; + std::vector gradients; +}; + +class PairArena { +public: + PairArena() = default; + + void reserve(usize n) { + values.reserve(n); + gradients.reserve(n); + } + + inline u32 alloc(f64x2 v = f64x2::zero(), f64x2 g = f64x2::zero()) { + u32 idx = static_cast(values.size()); + values.push_back(v); + gradients.push_back(g); + return idx; + } + + inline u32 next_index() const { + return static_cast(values.size()); + } + + inline PairHandle next_handle() const { + return PairHandle(next_index()); + } + + // Accessors + inline f64x2& val(u32 i) { + assert(i < values.size()); + return values[i]; + } + + inline const f64x2& val(u32 i) const { + assert(i < values.size()); + return values[i]; + } + + inline f64x2& grad(u32 i) { + assert(i < gradients.size()); + return gradients[i]; + } + + inline const f64x2& grad(u32 i) const { + assert(i < 
gradients.size()); + return gradients[i]; + } + + // Legacy component accessors + inline f64 p0_ref(u32 i) const { + assert(i < values.size()); + return values[i].first(); + } + + inline f64 p1_ref(u32 i) const { + assert(i < values.size()); + return values[i].second(); + } + + inline f64 g0_ref(u32 i) const { + assert(i < gradients.size()); + return gradients[i].first(); + } + + inline f64 g1_ref(u32 i) const { + assert(i < gradients.size()); + return gradients[i].second(); + } + + inline usize size() const { + return values.size(); + } + + void clear() { + values.clear(); + gradients.clear(); + } + + void reset_to(usize n) { + if (n < values.size()) { + values.resize(n); + gradients.resize(n); + } + } + + // Pointer accessors for hot loops + inline f64x2* values_data() { + return values.data(); + } + + inline f64x2* gradients_data() { + return gradients.data(); + } + + inline const f64x2* values_data() const { + return values.data(); + } + + inline const f64x2* gradients_data() const { + return gradients.data(); + } + +private: + std::vector values; + std::vector gradients; +}; + +} // namespace Clockwork::Autograd diff --git a/src/tuning/globals.hpp b/src/tuning/globals.hpp index e573930f..f9236e4b 100644 --- a/src/tuning/globals.hpp +++ b/src/tuning/globals.hpp @@ -1,6 +1,6 @@ #pragma once -#include "tuning/graph.hpp" +#include "tuning/graph.hpp" // Required for Graph::get() #include "tuning/info.hpp" #include "tuning/value.hpp" #include "util/types.hpp" @@ -8,7 +8,7 @@ #include #include #include -#include +#include namespace Clockwork::Autograd { @@ -60,7 +60,6 @@ class Globals { } bool is_parameter_constant(usize i) const; - bool is_pair_parameter_constant(usize i) const; private: @@ -89,18 +88,17 @@ class ValuePlaceholder { return ValuePlaceholder(a, true); } - operator ValuePtr() const { + // Conversion to Handle: Delegates to the Graph + operator ValueHandle() const { return Graph::get().get_parameter(m_index); } usize index() const { return m_index; } - f64 default_value() const { return m_default_value; } - bool constant() const { return m_constant; } @@ -111,100 +109,50 @@ class ValuePlaceholder { bool m_constant; }; -inline bool Globals::is_parameter_constant(usize i) const { - return m_parameters[i]->constant(); -} - -inline std::ostream& operator<<(std::ostream& os, ValuePlaceholder a) { - os << static_cast(a); - return os; -} - -inline ValuePtr operator-(ValuePlaceholder a) { - return -static_cast(a); -} - -inline ValuePtr operator+(ValuePlaceholder a, ValuePlaceholder b) { - return static_cast(a) + static_cast(b); -} - -inline ValuePtr operator-(ValuePlaceholder a, ValuePlaceholder b) { - return static_cast(a) - static_cast(b); -} - -inline ValuePtr operator*(ValuePlaceholder a, i32 b) { - return static_cast(a) * b; -} - -inline ValuePtr operator/(ValuePlaceholder a, i32 b) { - return static_cast(a) / b; -} - class PairPlaceholder { public: - explicit PairPlaceholder(f128 default_value, bool constant) : + explicit PairPlaceholder(f64x2 default_value, bool constant) : m_index(Globals::get().register_param(this)), m_default_value(default_value), m_constant(constant) { } static PairPlaceholder create_tunable(f64 a, f64 b) { - return PairPlaceholder(f128::make(a, b), false); + return PairPlaceholder(f64x2::make(a, b), false); } static PairPlaceholder create(f64 a, f64 b) { - return PairPlaceholder(f128::make(a, b), true); + return PairPlaceholder(f64x2::make(a, b), true); } - operator PairPtr() const { + // Conversion to Handle: Delegates to the Graph + operator PairHandle() 
const { return Graph::get().get_pair_parameter(m_index); } usize index() const { return m_index; } - - f128 default_value() const { + f64x2 default_value() const { return m_default_value; } - bool constant() const { return m_constant; } private: usize m_index; - f128 m_default_value; + f64x2 m_default_value; bool m_constant; }; -inline bool Globals::is_pair_parameter_constant(usize i) const { - return m_pair_parameters[i]->constant(); -} - -inline std::ostream& operator<<(std::ostream& os, PairPlaceholder a) { - os << static_cast(a); - return os; -} - -inline PairPtr operator-(PairPlaceholder a) { - return -static_cast(a); -} - -inline PairPtr operator+(PairPlaceholder a, PairPlaceholder b) { - return static_cast(a) + static_cast(b); -} - -inline PairPtr operator-(PairPlaceholder a, PairPlaceholder b) { - return static_cast(a) - static_cast(b); +inline bool Globals::is_parameter_constant(usize i) const { + return m_parameters[i]->constant(); } -inline PairPtr operator*(PairPlaceholder a, i32 b) { - return static_cast(a) * b; +inline bool Globals::is_pair_parameter_constant(usize i) const { + return m_pair_parameters[i]->constant(); } -inline PairPtr operator/(PairPlaceholder a, i32 b) { - return static_cast(a) / b; -} } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.cpp b/src/tuning/graph.cpp index 706ee5c9..91ef6163 100644 --- a/src/tuning/graph.cpp +++ b/src/tuning/graph.cpp @@ -1,16 +1,499 @@ #include "tuning/graph.hpp" #include "tuning/globals.hpp" +#include "util/types.hpp" +#include namespace Clockwork::Autograd { Graph::Graph() { - for (ValuePlaceholder* placeholder : Globals::get().get_parameters()) { - register_param(std::make_shared(placeholder->default_value())); + // Initialize arenas with global parameters + auto params = Globals::get().get_parameters(); + auto pair_params = Globals::get().get_pair_parameters(); + + m_global_param_count = params.size(); + m_global_pair_count = pair_params.size(); + + // reserve some headroom + m_values.reserve(m_global_param_count + 1024); + m_pairs.reserve(m_global_pair_count + 1024); + m_tape.reserve(16384 * 2 * 16); + + for (auto* p : params) { + m_values.alloc(p->default_value(), 0.0); + } + for (auto* p : pair_params) { + m_pairs.alloc(p->default_value(), f64x2::zero()); + } +} + +ValueHandle Graph::create_value(f64 data) { + return ValueHandle(m_values.alloc(data, 0.0)); +} + +PairHandle Graph::create_pair(f64x2 data) { + return PairHandle(m_pairs.alloc(data, f64x2::zero())); +} + +// Recording + +ValueHandle Graph::record_op(OpType op, ValueHandle lhs, ValueHandle rhs) { + ValueHandle out = m_values.next_handle(); + f64 l = m_values.val(lhs.index); + f64 r = m_values.val(rhs.index); + f64 res = 0.0; + + switch (op) { + case OpType::Add: + res = l + r; + break; + case OpType::Sub: + res = l - r; + break; + case OpType::Mul: + res = l * r; + break; + case OpType::Div: + res = l / r; + break; + case OpType::Pow: + res = std::pow(l, r); + break; + default: + break; + } + + m_values.alloc(res, 0.0); + + m_tape.push_back(Node::make_binary(op, out.index, lhs.index, rhs.index)); + + return out; +} + +ValueHandle Graph::record_op(OpType op, ValueHandle lhs, f64 scalar) { + ValueHandle out = m_values.next_handle(); + f64 l = m_values.val(lhs.index); + f64 res = 0.0; + + switch (op) { + case OpType::Exp: + res = std::exp(l); + break; + case OpType::Log: + res = std::log(l); + break; + case OpType::Sigmoid: + res = 1.0 / (1.0 + std::exp(-l)); + break; + case OpType::Neg: + res = -l; + break; + case OpType::PowConst: + res = 
std::pow(l, scalar); + break; + case OpType::AddScalar: + res = l + scalar; + break; + case OpType::SubScalarVal: + res = scalar - l; + break; + case OpType::ValSubScalar: + res = l - scalar; + break; + case OpType::MulScalar: + res = l * scalar; + break; + case OpType::DivScalarVal: + res = scalar / l; + break; + case OpType::ValDivScalar: + res = l / scalar; + break; + default: + break; + } + + m_values.alloc(res, 0.0); + + m_tape.push_back(Node::make_scalar(op, out.index, lhs.index, scalar)); + + return out; +} + +PairHandle Graph::record_pair_op(OpType op, PairHandle lhs, PairHandle rhs) { + PairHandle out = m_pairs.next_handle(); + f64x2 l = m_pairs.val(lhs.index); + f64x2 r = m_pairs.val(rhs.index); + f64x2 res = f64x2::zero(); + + switch (op) { + case OpType::PairAdd: + res = f64x2::add(l, r); + break; + case OpType::PairSub: + res = f64x2::sub(l, r); + break; + default: + break; + } + + m_pairs.alloc(res, f64x2::zero()); + + m_tape.push_back(Node::make_binary(op, out.index, lhs.index, rhs.index)); + + return out; +} + +PairHandle Graph::record_pair_scalar(OpType op, PairHandle lhs, f64 scalar) { + PairHandle out = m_pairs.next_handle(); + f64x2 l = m_pairs.val(lhs.index); + f64x2 res = f64x2::zero(); + + switch (op) { + case OpType::PairNeg: + res = f64x2::neg(l); + break; + case OpType::PairMulScalar: + res = f64x2::mul_scalar(l, scalar); + break; + case OpType::PairDivScalar: + res = f64x2::div_scalar(l, scalar); + break; + case OpType::ScalarDivPair: + res = f64x2::scalar_div(scalar, l); + break; + default: + break; + } + + m_pairs.alloc(res, f64x2::zero()); + + m_tape.push_back(Node::make_scalar(op, out.index, lhs.index, scalar)); + + return out; +} + +PairHandle Graph::record_pair_value(OpType op, PairHandle lhs, ValueHandle rhs) { + PairHandle out = m_pairs.next_handle(); + f64x2 pair_val = m_pairs.val(lhs.index); + f64 v = m_values.val(rhs.index); + f64x2 res = f64x2::zero(); + + switch (op) { + case OpType::PairMulValue: + case OpType::ValueMulPair: + res = f64x2::mul_scalar(pair_val, v); + break; + case OpType::PairDivValue: + res = f64x2::div_scalar(pair_val, v); + break; + case OpType::ValueDivPair: + res = f64x2::scalar_div(v, pair_val); + break; + default: + break; + } + + m_pairs.alloc(res, f64x2::zero()); + + m_tape.push_back(Node::make_binary(op, out.index, lhs.index, rhs.index)); + + return out; +} + +ValueHandle Graph::record_phase(PairHandle lhs, f64 alpha) { + ValueHandle out = m_values.next_handle(); + f64x2 pair_val = m_pairs.val(lhs.index); + f64 val = alpha * pair_val.first() + (1.0 - alpha) * pair_val.second(); + + m_values.alloc(val, 0.0); + + m_tape.push_back(Node::make_scalar(OpType::Phase, out.index, lhs.index, alpha)); + + return out; +} + +void Graph::backward() { + if (m_tape.empty()) { + return; + } + + // Initialize gradient of last output to 1 + // (This assumes the final output is always a ValueHandle) + m_values.grad(m_tape.back().out()) = 1.0; + + f64* vals = m_values.values_data(); + f64* grads = m_values.gradients_data(); + f64x2* pair_vals = m_pairs.values_data(); + f64x2* pair_grads = m_pairs.gradients_data(); + + for (auto it = m_tape.rbegin(); it != m_tape.rend(); ++it) { + const Node& node = *it; + + const u32 out_idx = node.out(); + + switch (node.type) { + // Value-Binary + + case OpType::Add: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] += grad_out; + grads[node.rhs()] += grad_out; + break; + } + case OpType::Sub: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] += grad_out; + grads[node.rhs()] -= grad_out; 
+ break; + } + case OpType::Mul: { + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; + f64 r = vals[node.rhs()]; + grads[node.lhs()] += r * grad_out; + grads[node.rhs()] += l * grad_out; + break; + } + case OpType::Div: { + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; + f64 r = vals[node.rhs()]; + grads[node.lhs()] += grad_out / r; + grads[node.rhs()] += -l * grad_out / (r * r); + break; + } + case OpType::Pow: { + const f64 grad_out = grads[out_idx]; + f64 base = vals[node.lhs()]; + f64 exp = vals[node.rhs()]; + grads[node.lhs()] += exp * std::pow(base, exp - 1) * grad_out; + grads[node.rhs()] += std::pow(base, exp) * std::log(base) * grad_out; + break; + } + + // Value-Scalar + case OpType::Exp: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] += vals[out_idx] * grad_out; + break; + } + case OpType::Log: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] += grad_out / vals[node.lhs()]; + break; + } + case OpType::Sigmoid: { + const f64 grad_out = grads[out_idx]; + f64 s = vals[out_idx]; + grads[node.lhs()] += s * (1.0 - s) * grad_out; + break; + } + case OpType::Neg: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] -= grad_out; + break; + } + case OpType::PowConst: { + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; + f64 exp = node.scalar(); + grads[node.lhs()] += exp * std::pow(l, exp - 1.0) * grad_out; + break; + } + case OpType::AddScalar: + case OpType::ValSubScalar: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] += grad_out; + break; + } + + case OpType::SubScalarVal: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] -= grad_out; + break; + } + case OpType::MulScalar: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] += node.scalar() * grad_out; + break; + } + case OpType::ValDivScalar: { + const f64 grad_out = grads[out_idx]; + grads[node.lhs()] += grad_out / node.scalar(); + break; + } + case OpType::DivScalarVal: { + const f64 grad_out = grads[out_idx]; + f64 l = vals[node.lhs()]; + grads[node.lhs()] += -node.scalar() * grad_out / (l * l); + break; + } + case OpType::Phase: { + const f64 grad_out = grads[out_idx]; + f64x2 update = f64x2::make(node.scalar() * grad_out, (1.0 - node.scalar()) * grad_out); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); + break; + } + + case OpType::PairAdd: { + const f64x2 grad_out = pair_grads[out_idx]; + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_out); + pair_grads[node.rhs()] = f64x2::add(pair_grads[node.rhs()], grad_out); + break; + } + case OpType::PairSub: { + const f64x2 grad_out = pair_grads[out_idx]; + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_out); + pair_grads[node.rhs()] = f64x2::sub(pair_grads[node.rhs()], grad_out); + break; + } + case OpType::PairNeg: { + const f64x2 grad_out = pair_grads[out_idx]; + pair_grads[node.lhs()] = f64x2::sub(pair_grads[node.lhs()], grad_out); + break; + } + case OpType::PairMulScalar: { + const f64x2 grad_out = pair_grads[out_idx]; + f64x2 scaled = f64x2::mul_scalar(grad_out, node.scalar()); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], scaled); + break; + } + case OpType::PairDivScalar: { + const f64x2 grad_out = pair_grads[out_idx]; + f64x2 scaled = f64x2::div_scalar(grad_out, node.scalar()); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], scaled); + break; + } + case OpType::ScalarDivPair: { + const f64x2 grad_out = pair_grads[out_idx]; + f64x2 val = pair_vals[node.lhs()]; + f64x2 grad 
= f64x2::scalar_div(-node.scalar(), f64x2::mul(val, val)); + f64x2 update = f64x2::mul(grad, grad_out); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); + break; + } + case OpType::PairMulValue: + case OpType::ValueMulPair: { + const f64x2 grad_out = pair_grads[out_idx]; + f64 val_rhs = vals[node.rhs()]; + f64x2 val_lhs = pair_vals[node.lhs()]; + + f64x2 grad_pair = f64x2::mul_scalar(grad_out, val_rhs); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); + + f64 contrib = grad_out.first() * val_lhs.first() + grad_out.second() * val_lhs.second(); + grads[node.rhs()] += contrib; + break; + } + case OpType::PairDivValue: { + const f64x2 grad_out = pair_grads[out_idx]; + f64 val_rhs = vals[node.rhs()]; + f64x2 val_lhs = pair_vals[node.lhs()]; + + f64x2 grad_pair = f64x2::div_scalar(grad_out, val_rhs); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair); + + f64 num = grad_out.first() * val_lhs.first() + grad_out.second() * val_lhs.second(); + grads[node.rhs()] += -num / (val_rhs * val_rhs); + break; + } + case OpType::ValueDivPair: { + const f64x2 grad_out = pair_grads[out_idx]; + f64 val_rhs = vals[node.rhs()]; + f64x2 val_lhs = pair_vals[node.lhs()]; + + f64x2 grad_pair = f64x2::scalar_div(-val_rhs, f64x2::mul(val_lhs, val_lhs)); + f64x2 update = f64x2::mul(grad_pair, grad_out); + pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update); + + f64x2 recip = f64x2::scalar_div(1.0, val_lhs); + grads[node.rhs()] += + grad_out.first() * recip.first() + grad_out.second() * recip.second(); + break; + } + + default: + unreachable(); + break; + } + } +} + +void Graph::cleanup() { + m_values.reset_to(m_global_param_count); + m_pairs.reset_to(m_global_pair_count); + m_tape.clear(); +} + +void Graph::zero_grad() { + for (usize i = 0; i < m_global_param_count; ++i) { + m_values.grad(i) = 0.0; } - for (PairPlaceholder* placeholder : Globals::get().get_pair_parameters()) { - register_param( - std::make_shared(placeholder->default_value(), placeholder->constant())); + for (usize i = 0; i < m_global_pair_count; ++i) { + m_pairs.grad(i) = f64x2::zero(); } } +void Graph::copy_parameter_values(const Parameters& source) { + if (source.parameters.size() != m_global_param_count + || source.pair_parameters.size() != m_global_pair_count) { + std::cerr << "Graph parameter count desync!" 
<< std::endl; + std::terminate(); + } + for (usize i = 0; i < m_global_param_count; ++i) { + m_values.val(i) = source.parameters[i]; + } + for (usize i = 0; i < m_global_pair_count; ++i) { + m_pairs.val(i) = source.pair_parameters[i]; + } +} + +Parameters Graph::get_all_parameter_values() const { + Parameters p; + p.parameters.reserve(m_global_param_count); + p.pair_parameters.reserve(m_global_pair_count); + for (usize i = 0; i < m_global_param_count; ++i) { + p.parameters.push_back(m_values.val(i)); + } + for (usize i = 0; i < m_global_pair_count; ++i) { + p.pair_parameters.push_back(m_pairs.val(i)); + } + return p; +} + +Parameters Graph::get_all_parameter_gradients() const { + Parameters p; + p.parameters.reserve(m_global_param_count); + p.pair_parameters.reserve(m_global_pair_count); + for (usize i = 0; i < m_global_param_count; ++i) { + p.parameters.push_back(m_values.grad(i)); + } + for (usize i = 0; i < m_global_pair_count; ++i) { + p.pair_parameters.push_back(m_pairs.grad(i)); + } + return p; +} + +// Mutation Helpers + +void Graph::add_value_gradient(u32 idx, f64 delta) { + m_values.grad(idx) += delta; +} + +void Graph::set_value(u32 idx, f64 v) { + m_values.val(idx) = v; +} + +void Graph::zero_value_grad(u32 idx) { + m_values.grad(idx) = 0.0; +} + +void Graph::set_pair_values(u32 idx, const f64x2& v) { + m_pairs.val(idx) = v; +} + +void Graph::zero_pair_grad(u32 idx) { + m_pairs.grad(idx) = f64x2::zero(); +} + } // namespace Clockwork::Autograd diff --git a/src/tuning/graph.hpp b/src/tuning/graph.hpp index 5f029677..7053f47a 100644 --- a/src/tuning/graph.hpp +++ b/src/tuning/graph.hpp @@ -1,163 +1,93 @@ #pragma once -#include "tuning/info.hpp" -#include "tuning/value.hpp" -#include +#include "arena.hpp" +#include "info.hpp" +#include "operations.hpp" +#include "util/vec/sse2.hpp" +#include "value.hpp" #include #include #include namespace Clockwork::Autograd { -class Backwardable; -using BackwardablePtr = std::shared_ptr; - -template -class SmartBackwardable; -template -using SmartBackwardablePtr = std::shared_ptr>; - -class Value; -using ValuePtr = std::shared_ptr; - -class Pair; -using PairPtr = std::shared_ptr; - class Graph { private: - // Tunable parameters - std::vector m_parameters; - std::vector m_pair_parameters; + ValueArena m_values; + PairArena m_pairs; - // All backwardable nodes in insertion order (intermediates + outputs + parameters) - std::vector m_backwardables; + // Tape (Linear record of operations) + std::vector m_tape; - Graph(); + // Counts of global parameters + usize m_global_param_count = 0; + usize m_global_pair_count = 0; - void register_param(const ValuePtr& param) { - m_parameters.push_back(param); - } - - void register_param(const PairPtr& param) { - m_pair_parameters.push_back(param); - } + Graph(); public: - static Graph& get() { + inline static Graph& get() { thread_local Graph instance; return instance; } - // ------------------ Registration ------------------ - void register_value(const BackwardablePtr& node) { - m_backwardables.push_back(node); - } + // Creation + ValueHandle create_value(f64 data); + PairHandle create_pair(f64x2 data); - void register_value(const ValuePtr& node) { - m_backwardables.push_back(std::static_pointer_cast(node)); - } + // Operation recording + ValueHandle record_op(OpType op, ValueHandle lhs, ValueHandle rhs); + ValueHandle record_op(OpType op, ValueHandle input, f64 scalar = 0.0); + PairHandle record_pair_op(OpType op, PairHandle lhs, PairHandle rhs); + PairHandle record_pair_scalar(OpType op, PairHandle input, f64 
scalar); + PairHandle record_pair_value(OpType op, PairHandle pair, ValueHandle val); - void register_value(const PairPtr& node) { - m_backwardables.push_back(std::static_pointer_cast(node)); - } + ValueHandle record_phase(PairHandle input, f64 alpha); - // ------------------- Copy Values ------------------- - void copy_parameter_values(const Parameters& source) { - if (source.parameters.size() != m_parameters.size() - || source.pair_parameters.size() != m_pair_parameters.size()) { - std::cerr << "Graph parameters count have desynced" << std::endl; - std::terminate(); - } - - for (usize i = 0; i < m_parameters.size(); i++) { - m_parameters[i]->set_value(source.parameters[i]); - } - for (usize i = 0; i < m_pair_parameters.size(); i++) { - m_pair_parameters[i]->set_values(source.pair_parameters[i]); - } - } + void backward(); - Parameters get_all_parameter_values() const { - Parameters result; - result.parameters.resize(m_parameters.size()); - result.pair_parameters.resize(m_pair_parameters.size()); - for (usize i = 0; i < m_parameters.size(); i++) { - result.parameters[i] = m_parameters[i]->get_value(); - } - for (usize i = 0; i < m_pair_parameters.size(); i++) { - result.pair_parameters[i] = m_pair_parameters[i]->get_values(); - } - return result; - } - - Parameters get_all_parameter_gradients() const { - Parameters result; - result.parameters.resize(m_parameters.size()); - result.pair_parameters.resize(m_pair_parameters.size()); - for (usize i = 0; i < m_parameters.size(); i++) { - result.parameters[i] = m_parameters[i]->get_gradient(); - } - for (usize i = 0; i < m_pair_parameters.size(); i++) { - result.pair_parameters[i] = m_pair_parameters[i]->get_graidents(); - } - return result; - } + void cleanup(); + void zero_grad(); + void copy_parameter_values(const Parameters& source); + Parameters get_all_parameter_values() const; + Parameters get_all_parameter_gradients() const; - // ------------------ Backward Pass ------------------ - void backward() { - if (m_backwardables.empty()) { - return; - } + void add_value_gradient(u32 idx, f64 delta); + void set_value(u32 idx, f64 v); + void zero_value_grad(u32 idx); - // Initialize gradient on last node (loss node) - auto last = std::static_pointer_cast(m_backwardables.back()); - last->m_gradient = static_cast(1); + void set_pair_values(u32 idx, const f64x2& v); + void zero_pair_grad(u32 idx); - // Reverse pass - for (auto it = m_backwardables.rbegin(); it != m_backwardables.rend(); ++it) { - (*it)->backward(); - } + // Direct SoA accessors + f64 get_value(u32 idx) const { + return m_values.val(idx); } - - void clear_backwardables() { - m_backwardables.clear(); + f64 get_gradient(u32 idx) const { + return m_values.grad(idx); } - // ------------------ Cleanup ------------------ - void cleanup() { - for (auto& param : m_parameters) { - param->zero_grad(); - } - for (auto& param : m_pair_parameters) { - param->zero_grad(); - } - - m_backwardables.clear(); + f64x2 get_pair_values(u32 idx) const { + return f64x2::make(m_pairs.p0_ref(idx), m_pairs.p1_ref(idx)); } - - // ------------------ Reset ------------------ - void init_zeros() { - for (auto& param : m_parameters) { - param->set_value(0.0); - } - for (auto& param : m_pair_parameters) { - param->set_values(0.0, 0.0); - } - cleanup(); + f64x2 get_pair_gradients(u32 idx) const { + return f64x2::make(m_pairs.g0_ref(idx), m_pairs.g1_ref(idx)); } - // ------------------ Accessors ------------------ - const std::vector& get_parameters() const { - return m_parameters; + // Pointer accessors + f64* 
values_data() { + return m_values.values_data(); } - const std::vector& get_pair_parameters() const { - return m_pair_parameters; + f64* gradients_data() { + return m_values.gradients_data(); } - ValuePtr get_parameter(usize index) const { - return m_parameters[index]; + + ValueHandle get_parameter(usize global_index) const { + return ValueHandle(static_cast(global_index)); } - PairPtr get_pair_parameter(usize index) const { - return m_pair_parameters[index]; + + PairHandle get_pair_parameter(usize global_index) const { + return PairHandle(static_cast(global_index)); } }; diff --git a/src/tuning/info.hpp b/src/tuning/info.hpp index ce6d01e4..d3874d1c 100644 --- a/src/tuning/info.hpp +++ b/src/tuning/info.hpp @@ -13,13 +13,13 @@ struct ParameterCountInfo { }; struct Parameters { - std::vector parameters; - std::vector pair_parameters; + std::vector parameters; + std::vector pair_parameters; static Parameters zeros(ParameterCountInfo counts) { Parameters result; result.parameters.resize(counts.parameter_count, 0.0); - result.pair_parameters.resize(counts.pair_parameter_count, f128::zero()); + result.pair_parameters.resize(counts.pair_parameter_count, f64x2::zero()); return result; } @@ -30,7 +30,7 @@ struct Parameters { parameters[i] += b.parameters[i]; } for (usize i = 0; i < pair_parameters.size(); i++) { - pair_parameters[i] = f128::add(pair_parameters[i], b.pair_parameters[i]); + pair_parameters[i] = f64x2::add(pair_parameters[i], b.pair_parameters[i]); } } @@ -42,7 +42,7 @@ struct Parameters { } for (usize i = 0; i < pair_parameters.size(); i++) { pair_parameters[i] = - f128::madd(pair_parameters[i], f128::broadcast(weight), b.pair_parameters[i]); + f64x2::madd(pair_parameters[i], f64x2::broadcast(weight), b.pair_parameters[i]); } } }; diff --git a/src/tuning/loss.hpp b/src/tuning/loss.hpp index 79d620ce..41eb30bb 100644 --- a/src/tuning/loss.hpp +++ b/src/tuning/loss.hpp @@ -1,7 +1,7 @@ #pragma once #include "value.hpp" -#include +#include #include namespace Clockwork::Autograd { @@ -13,33 +13,31 @@ enum class Reduction { }; template -auto mse(const std::vector& predictions, const std::vector& targets) { +auto mse(const std::vector& predictions, const std::vector& targets) { if (predictions.size() != targets.size()) { throw std::invalid_argument("MSE: predictions and targets must have the same size."); } if constexpr (R == Reduction::None) { - // Return vector of squared errors (no reduction) - std::vector losses; + std::vector losses; losses.reserve(predictions.size()); for (size_t i = 0; i < predictions.size(); ++i) { - ValuePtr diff = predictions[i] - targets[i]; + ValueHandle diff = predictions[i] - targets[i]; losses.push_back(diff * diff); } return losses; } else { - // Compute sum of squared errors in one accumulated node - std::vector losses; + std::vector losses; losses.reserve(predictions.size()); for (size_t i = 0; i < predictions.size(); ++i) { - ValuePtr diff = predictions[i] - targets[i]; + ValueHandle diff = predictions[i] - targets[i]; losses.push_back(diff * diff); } - ValuePtr total_loss = Value::sum(losses); + ValueHandle total_loss = ValueHandle::sum(losses); if constexpr (R == Reduction::Mean) { f64 n = static_cast(predictions.size()); - return total_loss * Value::create(static_cast(1) / n); + return total_loss * (1.0 / n); } else { return total_loss; } diff --git a/src/tuning/operations.hpp b/src/tuning/operations.hpp new file mode 100644 index 00000000..d01dfd7d --- /dev/null +++ b/src/tuning/operations.hpp @@ -0,0 +1,127 @@ +#pragma once + +#include 
"util/types.hpp" +#include "value.hpp" +#include +#include + +namespace Clockwork::Autograd { + +enum class OpType : u32 { + // Leaf nodes + None, + Parameter, // Created from a global parameter + Input, // Created manually (e.g. from data) + + // Binary Ops + Add, + Sub, + Mul, + Div, + Pow, + + // Unary Ops + Exp, + Log, + Sigmoid, + Neg, + PowConst, // x ^ scalar + AddScalar, // x + scalar + SubScalarVal, // scalar - x + ValSubScalar, // x - scalar + MulScalar, // x * scalar + DivScalarVal, // scalar / x + ValDivScalar, // x / scalar + + // Pair Ops + PairCreate, + PairAdd, + PairSub, + PairNeg, + + // Pair-Scalar Ops + PairMulScalar, + PairDivScalar, + ScalarDivPair, + + // Pair-Value Ops + PairMulValue, + ValueMulPair, + PairDivValue, + ValueDivPair, + + // Phasing + Phase, // Pair -> Value via alpha + + // Reduction (TODO: optimize this) + Sum // Sum of a vector of values +}; + + +// Node layout (16 bytes): +// [0..3] : type (1), pad (1), lhs_offset (2) -> 4 bytes +// [4..7] : output_idx -> 4 bytes +// [8..15] : union { struct { u16 rhs_offset; u16 pad2; u32 pad3; } ; double scalar; } -> 8 bytes +struct alignas(16) Node { + + OpType type; // u32 + u32 lhs_idx; + u32 output_idx; + + union U { + u32 rhs_idx; // for unary/binary ops + f32 scalar_data; // for scalar ops + + constexpr U() : + rhs_idx(0) { + } + constexpr U(u16 rhs_idx) : + rhs_idx(rhs_idx) { + } + constexpr U(f64 scalar) : + scalar_data(static_cast(scalar)) { + } + } u; + + static constexpr Node make_binary(OpType t, u32 output_idx, u32 lhs_idx, u32 rhs_idx) { + Node n{}; + n.type = t; + + // lhs & rhs indices are guaranteed to be <= out + n.lhs_idx = lhs_idx; + n.output_idx = output_idx; + n.u.rhs_idx = rhs_idx; + + return n; + } + + static constexpr Node make_scalar(OpType t, u32 output_idx, u32 lhs_idx, f64 scalar) { + Node n{}; + n.type = t; + n.lhs_idx = lhs_idx; + n.output_idx = output_idx; + n.u.scalar_data = static_cast(scalar); + return n; + } + + constexpr u32 lhs() const noexcept { + return lhs_idx; + } + + constexpr u32 rhs() const noexcept { + return u.rhs_idx; + } + + constexpr u32 out() const noexcept { + return output_idx; + } + + constexpr f64 scalar() const noexcept { + return u.scalar_data; + } +}; + +static_assert(sizeof(Node) == 16, "Node must be exactly 16 bytes"); +static_assert(alignof(Node) == 16, "Node alignment must match double alignment (8 bytes)"); + +} // namespace Clockwork::Autograd diff --git a/src/tuning/optim.hpp b/src/tuning/optim.hpp index 037e0598..ccad9614 100644 --- a/src/tuning/optim.hpp +++ b/src/tuning/optim.hpp @@ -6,7 +6,6 @@ #include "util/vec/sse2.hpp" #include -#include #include namespace Clockwork::Autograd { @@ -14,12 +13,11 @@ namespace Clockwork::Autograd { class SGD { private: ParameterCountInfo m_counts; + f64 m_lr; + f64 m_momentum; - f64 m_lr; - f64 m_momentum; - - std::vector m_value_velocity; - std::vector m_pair_velocity; + std::vector m_value_velocity; + std::vector m_pair_velocity; public: explicit SGD(ParameterCountInfo counts, f64 lr, f64 momentum = 0.9) : @@ -27,13 +25,15 @@ class SGD { m_lr(lr), m_momentum(momentum) { m_value_velocity.resize(m_counts.parameter_count, 0.0); - m_pair_velocity.resize(m_counts.pair_parameter_count, f128::zero()); + m_pair_velocity.resize(m_counts.pair_parameter_count, f64x2::zero()); } void step(Parameters& values, const Parameters& gradients) { + const auto& globals = Globals::get(); + // ---- Value parameters ---- for (size_t i = 0; i < m_counts.parameter_count; ++i) { - if (Globals::get().is_parameter_constant(i)) { + if 
(globals.is_parameter_constant(i)) { continue; } @@ -42,13 +42,12 @@ class SGD { auto& v = m_value_velocity[i]; v = m_momentum * v - m_lr * p_grad; - p_value += v; } // ---- Pair parameters ---- for (size_t i = 0; i < m_counts.pair_parameter_count; ++i) { - if (Globals::get().is_pair_parameter_constant(i)) { + if (globals.is_pair_parameter_constant(i)) { continue; } @@ -56,13 +55,11 @@ class SGD { auto& p_grad = gradients.pair_parameters[i]; auto& v = m_pair_velocity[i]; - const f128 lr_grad = f128::mul_scalar(p_grad, m_lr); - - const f128 mom_v = f128::mul_scalar(v, m_momentum); - const f128 neg_lr_grad = f128::neg(lr_grad); - v = f128::add(mom_v, neg_lr_grad); - - p_value = f128::add(p_value, v); + const f64x2 lr_grad = f64x2::mul_scalar(p_grad, m_lr); + const f64x2 mom_v = f64x2::mul_scalar(v, m_momentum); + const f64x2 neg_lr_grad = f64x2::neg(lr_grad); + v = f64x2::add(mom_v, neg_lr_grad); + p_value = f64x2::add(p_value, v); } } @@ -78,18 +75,17 @@ class SGD { class AdamW { private: ParameterCountInfo m_counts; - - f64 m_lr; - f64 m_beta1; - f64 m_beta2; - f64 m_eps; - f64 m_weight_decay; - long long m_t; - - std::vector m_m; - std::vector m_v; - std::vector m_pair_m; - std::vector m_pair_v; + f64 m_lr; + f64 m_beta1; + f64 m_beta2; + f64 m_eps; + f64 m_weight_decay; + long long m_t; + + std::vector m_m; + std::vector m_v; + std::vector m_pair_m; + std::vector m_pair_v; public: explicit AdamW(ParameterCountInfo counts, @@ -107,13 +103,13 @@ class AdamW { m_t(0) { m_m.resize(m_counts.parameter_count, 0.0); m_v.resize(m_counts.parameter_count, 0.0); - - m_pair_m.resize(m_counts.pair_parameter_count, f128::zero()); - m_pair_v.resize(m_counts.pair_parameter_count, f128::zero()); + m_pair_m.resize(m_counts.pair_parameter_count, f64x2::zero()); + m_pair_v.resize(m_counts.pair_parameter_count, f64x2::zero()); } void step(Parameters& values, const Parameters& gradients) { m_t += 1; + const auto& globals = Globals::get(); const f64 b1t = std::pow(m_beta1, static_cast(m_t)); const f64 b2t = std::pow(m_beta2, static_cast(m_t)); @@ -122,7 +118,7 @@ class AdamW { // ---------------- Value parameters ---------------- for (size_t i = 0; i < m_counts.parameter_count; ++i) { - if (Globals::get().is_parameter_constant(i)) { + if (globals.is_parameter_constant(i)) { continue; } @@ -130,24 +126,19 @@ class AdamW { auto& g = gradients.parameters[i]; m_m[i] = m_beta1 * m_m[i] + (1.0 - m_beta1) * g; - m_v[i] = m_beta2 * m_v[i] + (1.0 - m_beta2) * g * g; - const f64 m_hat = m_m[i] * inv1mb1t; - const f64 v_hat = m_v[i] * inv1mb2t; - - const f64 adam_update = m_lr * m_hat / (std::sqrt(v_hat) + m_eps); - + const f64 m_hat = m_m[i] * inv1mb1t; + const f64 v_hat = m_v[i] * inv1mb2t; + const f64 adam_update = m_lr * m_hat / (std::sqrt(v_hat) + m_eps); const f64 weight_decay_update = m_lr * m_weight_decay * p; - const f64 total_update = -(adam_update + weight_decay_update); - - p += total_update; + p += -(adam_update + weight_decay_update); } // ---------------- Pair parameters ---------------- for (size_t i = 0; i < m_counts.pair_parameter_count; ++i) { - if (Globals::get().is_pair_parameter_constant(i)) { + if (globals.is_pair_parameter_constant(i)) { continue; } @@ -156,18 +147,18 @@ class AdamW { auto& m = m_pair_m[i]; auto& v = m_pair_v[i]; - const f128 g2 = f128::mul(g, g); + const f64x2 g2 = f64x2::mul(g, g); - const f128 m_scaled = f128::mul_scalar(m, m_beta1); - const f128 g_scaled = f128::mul_scalar(g, (1.0 - m_beta1)); - m = f128::add(m_scaled, g_scaled); + const f64x2 m_scaled = f64x2::mul_scalar(m, 
m_beta1); + const f64x2 g_scaled = f64x2::mul_scalar(g, (1.0 - m_beta1)); + m = f64x2::add(m_scaled, g_scaled); - const f128 v_scaled = f128::mul_scalar(v, m_beta2); - const f128 g2_scaled = f128::mul_scalar(g2, (1.0 - m_beta2)); - v = f128::add(v_scaled, g2_scaled); + const f64x2 v_scaled = f64x2::mul_scalar(v, m_beta2); + const f64x2 g2_scaled = f64x2::mul_scalar(g2, (1.0 - m_beta2)); + v = f64x2::add(v_scaled, g2_scaled); - const f128 m_hat = f128::mul_scalar(m, inv1mb1t); - const f128 v_hat = f128::mul_scalar(v, inv1mb2t); + const f64x2 m_hat = f64x2::mul_scalar(m, inv1mb1t); + const f64x2 v_hat = f64x2::mul_scalar(v, inv1mb2t); const f64 adam_upd_f = m_lr * m_hat.first() / (std::sqrt(v_hat.first()) + m_eps); const f64 adam_upd_s = m_lr * m_hat.second() / (std::sqrt(v_hat.second()) + m_eps); @@ -178,7 +169,7 @@ class AdamW { const f64 total_upd_f = -(adam_upd_f + decay_upd_f); const f64 total_upd_s = -(adam_upd_s + decay_upd_s); - p = f128::add(p, f128::make(total_upd_f, total_upd_s)); + p = f64x2::add(p, f64x2::make(total_upd_f, total_upd_s)); } } diff --git a/src/tuning/value.cpp b/src/tuning/value.cpp index 422556d8..8968b153 100644 --- a/src/tuning/value.cpp +++ b/src/tuning/value.cpp @@ -1,27 +1,261 @@ -#include "value.hpp" -#include "../util/types.hpp" -#include "graph.hpp" -#include "util/vec/sse2.hpp" +#include "tuning/value.hpp" +#include "operations.hpp" +#include "tuning/graph.hpp" +#include "util/types.hpp" #include namespace Clockwork::Autograd { -ValuePtr Value::create(f64 data) { - ValuePtr res = std::make_shared(data); - Graph::get().register_value(res); - return res; +// ------------------ ValueHandle ------------------ + +ValueHandle ValueHandle::create(f64 data) { + return Graph::get().create_value(data); +} + +ValueHandle ValueHandle::sum(const std::vector& inputs) { + if (inputs.empty()) { + return ValueHandle::create(0.0); + } + ValueHandle total = inputs[0]; + for (size_t i = 1; i < inputs.size(); ++i) { + total = total + inputs[i]; + } + return total; +} + +ValueHandle ValueHandle::exp() const { + return Graph::get().record_op(OpType::Exp, *this); +} +ValueHandle ValueHandle::log() const { + return Graph::get().record_op(OpType::Log, *this); +} +ValueHandle ValueHandle::sigmoid() const { + return Graph::get().record_op(OpType::Sigmoid, *this); +} +ValueHandle ValueHandle::pow(ValueHandle exponent) const { + return Graph::get().record_op(OpType::Pow, *this, exponent); +} +ValueHandle ValueHandle::pow(f64 exponent) const { + return Graph::get().record_op(OpType::PowConst, *this, exponent); +} + +void ValueHandle::add_gradient(f64 rhs) const { + if (is_valid()) { + Graph::get().add_value_gradient(index, rhs); + } +} + +f64 ValueHandle::get_value() const { + return is_valid() ? Graph::get().get_value(index) : 0.0; +} + +f64 ValueHandle::get_gradient() const { + return is_valid() ? 
Graph::get().get_gradient(index) : 0.0; +} + +void ValueHandle::zero_grad() const { + if (is_valid()) { + Graph::get().zero_value_grad(index); + } +} + +void ValueHandle::set_value(f64 v) const { + if (is_valid()) { + Graph::get().set_value(index, v); + } +} + +// PairHandle implementations + +PairHandle PairHandle::create(f64 first, f64 second) { + return Graph::get().create_pair(f64x2::make(first, second)); +} + +PairHandle PairHandle::create(const f64x2& values) { + return Graph::get().create_pair(values); +} + +f64x2 PairHandle::get_values() const { + return Graph::get().get_pair_values(index); +} +f64x2 PairHandle::get_gradients() const { + return Graph::get().get_pair_gradients(index); +} +f64 PairHandle::first() const { + return get_values().first(); +} +f64 PairHandle::second() const { + return get_values().second(); +} + +void PairHandle::set_values(const f64x2& v) const { + Graph::get().set_pair_values(index, v); +} +void PairHandle::set_values(f64 f, f64 s) const { + set_values(f64x2::make(f, s)); +} + +void PairHandle::zero_grad() const { + Graph::get().zero_pair_grad(index); +} + +// Special phasing case +ValueHandle PairHandle::phase_impl(f64 scaled_alpha) const { + return Graph::get().record_phase(*this, scaled_alpha); +} + +// ValueHandle Operators +ValueHandle operator-(ValueHandle a) { + return Graph::get().record_op(OpType::Neg, a); +} +ValueHandle operator+(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Add, a, b); +} +ValueHandle operator-(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Sub, a, b); +} +ValueHandle operator*(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Mul, a, b); +} +ValueHandle operator/(ValueHandle a, ValueHandle b) { + return Graph::get().record_op(OpType::Div, a, b); } -PairPtr Pair::create(f64 first, f64 second) { - PairPtr res = std::make_shared(first, second, true); - Graph::get().register_value(res); - return res; +ValueHandle operator+(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::AddScalar, a, b); +} +ValueHandle operator-(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::ValSubScalar, a, b); +} +ValueHandle operator*(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::MulScalar, a, b); +} +ValueHandle operator/(ValueHandle a, f64 b) { + return Graph::get().record_op(OpType::ValDivScalar, a, b); +} + +ValueHandle operator+(f64 a, ValueHandle b) { + return b + a; +} +ValueHandle operator-(f64 a, ValueHandle b) { + return Graph::get().record_op(OpType::SubScalarVal, b, a); +} +ValueHandle operator*(f64 a, ValueHandle b) { + return b * a; +} +ValueHandle operator/(f64 a, ValueHandle b) { + return Graph::get().record_op(OpType::DivScalarVal, b, a); +} + +bool operator<(ValueHandle a, ValueHandle b) { + return a.get_value() < b.get_value(); +} +bool operator>(ValueHandle a, ValueHandle b) { + return a.get_value() > b.get_value(); +} + +// PairHandle Operators +PairHandle operator+(PairHandle a, PairHandle b) { + return Graph::get().record_pair_op(OpType::PairAdd, a, b); +} +PairHandle operator-(PairHandle a, PairHandle b) { + return Graph::get().record_pair_op(OpType::PairSub, a, b); +} +PairHandle operator-(PairHandle a) { + return Graph::get().record_pair_scalar(OpType::PairNeg, a, 0.0); +} + +PairHandle operator*(PairHandle a, f64 scalar) { + return Graph::get().record_pair_scalar(OpType::PairMulScalar, a, scalar); +} +PairHandle operator*(f64 scalar, PairHandle a) { + return a * scalar; +} +PairHandle operator/(PairHandle a, 
f64 scalar) { + return Graph::get().record_pair_scalar(OpType::PairDivScalar, a, scalar); +} +PairHandle operator/(f64 scalar, PairHandle a) { + return Graph::get().record_pair_scalar(OpType::ScalarDivPair, a, scalar); +} + +PairHandle operator*(PairHandle a, ValueHandle v) { + return Graph::get().record_pair_value(OpType::PairMulValue, a, v); +} +PairHandle operator*(ValueHandle v, PairHandle a) { + return Graph::get().record_pair_value(OpType::ValueMulPair, a, v); +} +PairHandle operator/(PairHandle a, ValueHandle v) { + return Graph::get().record_pair_value(OpType::PairDivValue, a, v); +} +PairHandle operator/(ValueHandle v, PairHandle a) { + return Graph::get().record_pair_value(OpType::ValueDivPair, a, v); +} + +// Printing overloads for debugging +std::ostream& operator<<(std::ostream& os, const PairHandle& p) { + os << "S(" << std::round(p.first()) << ", " << std::round(p.second()) << ")"; + return os; +} + +// Value Inplaces +ValueHandle& operator+=(ValueHandle& a, ValueHandle b) { + a = a + b; + return a; +} +ValueHandle& operator-=(ValueHandle& a, ValueHandle b) { + a = a - b; + return a; +} +ValueHandle& operator*=(ValueHandle& a, ValueHandle b) { + a = a * b; + return a; +} +ValueHandle& operator/=(ValueHandle& a, ValueHandle b) { + a = a / b; + return a; } -PairPtr Pair::create(const f128& values) { - PairPtr res = std::make_shared(values, true); - Graph::get().register_value(res); - return res; +ValueHandle& operator+=(ValueHandle& a, f64 b) { + a = a + b; + return a; +} +ValueHandle& operator-=(ValueHandle& a, f64 b) { + a = a - b; + return a; +} +ValueHandle& operator*=(ValueHandle& a, f64 b) { + a = a * b; + return a; +} +ValueHandle& operator/=(ValueHandle& a, f64 b) { + a = a / b; + return a; +} + +// Pair Inplaces +PairHandle& operator+=(PairHandle& a, PairHandle b) { + a = a + b; + return a; +} +PairHandle& operator-=(PairHandle& a, PairHandle b) { + a = a - b; + return a; +} +PairHandle& operator*=(PairHandle& a, f64 scalar) { + a = a * scalar; + return a; +} +PairHandle& operator*=(PairHandle& a, ValueHandle v) { + a = a * v; + return a; +} +PairHandle& operator/=(PairHandle& a, f64 scalar) { + a = a / scalar; + return a; +} +PairHandle& operator/=(PairHandle& a, ValueHandle v) { + a = a / v; + return a; } } // namespace Clockwork::Autograd diff --git a/src/tuning/value.hpp b/src/tuning/value.hpp index 67a92378..c8d60241 100644 --- a/src/tuning/value.hpp +++ b/src/tuning/value.hpp @@ -1,628 +1,124 @@ #pragma once -#include "../util/types.hpp" +#include "util/types.hpp" #include "util/vec/sse2.hpp" -#include - #include -#include -#include +#include #include namespace Clockwork::Autograd { -// Forward declarations - -class Backwardable; -template -class SmartBackwardable; -class Value; -class Pair; class Graph; - -using ValuePtr = std::shared_ptr; - -using PairPtr = std::shared_ptr; - -using BackwardablePtr = std::shared_ptr; - - -class Backwardable { -public: - friend class Graph; - std::function m_backward_func; - virtual ~Backwardable() = default; - virtual void backward() = 0; -}; - -template -class SmartBackwardable : public Backwardable, public std::enable_shared_from_this { -public: - virtual ~SmartBackwardable() = default; -}; - - -class Value : public SmartBackwardable { -private: - f64 m_value = 0; - f64 m_gradient = 0; - std::function m_backward_func; - -public: - friend class Graph; - friend class Pair; - - explicit Value(f64 data) : - m_value(data) {}; - - inline f64 get_value() const { - return m_value; - } - inline void change_value(f64 amount) { - 
m_value += amount; - } - inline void set_value(f64 value) { - m_value = value; - } - inline f64 get_gradient() const { - return m_gradient; - } - - inline void zero_grad() { - m_gradient = 0.0; - } - static ValuePtr create(f64 data); - - ValuePtr exp() { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::exp(this->m_value)); - result->m_backward_func = [this_value](ValuePtr out) { - f64 grad = out->m_value; // Avoid recomputing exp val - this_value->m_gradient += grad * out->m_gradient; - }; - return result; - } - - ValuePtr log() { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::log(this->m_value)); - result->m_backward_func = [this_value](ValuePtr out) { - f64 grad = (1 / this_value->m_value); - this_value->m_gradient += grad * out->m_gradient; - }; - return result; - } - - ValuePtr sigmoid() { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(1 / (1 + std::exp(-this_value->m_value))); - result->m_backward_func = [this_value](ValuePtr out) { - f64 grad = out->m_value - * (1 - out->m_value); // Same trick as before, avoid recomputing sigmoid(x) - this_value->m_gradient += grad * out->m_gradient; - }; - return result; - } - - ValuePtr pow(ValuePtr exponent) { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::pow(this_value->m_value, exponent->m_value)); - result->m_backward_func = [this_value, exponent](ValuePtr out) { - this_value->m_gradient += exponent->m_value - * std::pow(this_value->m_value, exponent->m_value - 1) - * out->m_gradient; - exponent->m_gradient += out->m_value * std::log(this_value->m_value) * out->m_gradient; - }; - return result; - } - - ValuePtr pow(f64 exponent) { - auto this_value = this->shared_from_this(); - ValuePtr result = Value::create(std::pow(this_value->m_value, exponent)); - result->m_backward_func = [this_value, exponent](ValuePtr out) { - this_value->m_gradient += - exponent * std::pow(this_value->m_value, exponent - 1) * out->m_gradient; - }; - return result; - } - - void add_gradient(f64 rhs) { - m_gradient += rhs; - } - - friend ValuePtr operator-(ValuePtr a) { - ValuePtr result = Value::create(-a->m_value); - result->m_backward_func = [a](ValuePtr out) { - f64 grad = -out->m_gradient; - a->m_gradient += grad; - }; - return result; - } - - friend ValuePtr operator+(ValuePtr a, ValuePtr b) { - ValuePtr result = Value::create(a->m_value + b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += out->m_gradient; - b->m_gradient += out->m_gradient; - }; - return result; +struct ValueHandle { + u32 index; + ValueHandle() : + index(0xFFFFFFFF) { } - - friend ValuePtr operator-(ValuePtr a, - ValuePtr b) { // We are NOT cheaping out with a + (-b) - ValuePtr result = Value::create(a->m_value - b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += out->m_gradient; - b->m_gradient -= out->m_gradient; - }; - return result; - } - - friend ValuePtr operator*(ValuePtr a, ValuePtr b) { - ValuePtr result = Value::create(a->m_value * b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += b->m_value * out->m_gradient; - b->m_gradient += a->m_value * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator/(ValuePtr a, - ValuePtr b) { // We are NOT cheaping out with a * (std::pow(b,-1)) - ValuePtr result = Value::create(a->m_value / b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += 1.0 / b->m_value * 
out->m_gradient; - b->m_gradient += -a->m_value / (b->m_value * b->m_value) * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator+(ValuePtr a, f64 b) { - ValuePtr result = Value::create(a->m_value + b); - result->m_backward_func = [a](ValuePtr out) { - a->m_gradient += out->m_gradient; - }; - return result; - } - - friend ValuePtr operator-(ValuePtr a, f64 b) { - ValuePtr result = Value::create(a->m_value - b); - result->m_backward_func = [a](ValuePtr out) { - a->m_gradient += out->m_gradient; - }; - return result; - } - - friend ValuePtr operator*(ValuePtr a, f64 b) { - ValuePtr result = Value::create(a->m_value * b); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += b * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator/(ValuePtr a, - f64 b) { // We are NOT cheaping out with a * (std::pow(b,-1)) - ValuePtr result = Value::create(a->m_value / b); - result->m_backward_func = [a, b](ValuePtr out) { - a->m_gradient += 1.0 / b * out->m_gradient; - }; - return result; - } - - friend ValuePtr operator+(f64 a, ValuePtr b) { - return b + a; - } - - friend ValuePtr operator-(f64 a, ValuePtr b) { - ValuePtr result = Value::create(a - b->m_value); - result->m_backward_func = [b](ValuePtr out) { - b->m_gradient -= out->m_gradient; - }; - return result; - } - - friend ValuePtr operator*(f64 a, ValuePtr b) { - return b * a; + explicit ValueHandle(u32 idx) : + index(idx) { } - - friend ValuePtr operator/(f64 a, ValuePtr b) { - ValuePtr result = Value::create(a / b->m_value); - result->m_backward_func = [a, b](ValuePtr out) { - b->m_gradient += -a / (b->m_value * b->m_value) * out->m_gradient; - }; - return result; - } - - static ValuePtr sum(const std::vector& inputs) { - if (inputs.empty()) { - return Value::create(0.0); - } - - f64 sum = 0.0; - f64 c = 0.0; - for (auto& v : inputs) { - f64 y = v->m_value - c; - f64 t = sum + y; - c = (t - sum) - y; - sum = t; - } - - ValuePtr result = Value::create(sum); - - result->m_backward_func = [inputs](ValuePtr out) { - f64 grad = out->m_gradient; - for (auto& v : inputs) { - v->m_gradient += grad; - } - }; - - return result; - } - - - friend bool operator==(ValuePtr a, ValuePtr b) { - return a->m_value == b->m_value; + bool is_valid() const { + return index != 0xFFFFFFFF; } - friend bool operator!=(ValuePtr a, ValuePtr b) { - return a->m_value != b->m_value; - } - - friend bool operator<(ValuePtr a, ValuePtr b) { - return a->m_value < b->m_value; - } - - friend bool operator<=(ValuePtr a, ValuePtr b) { - return a->m_value <= b->m_value; - } - - friend bool operator>(ValuePtr a, ValuePtr b) { - return a->m_value > b->m_value; - } - - friend bool operator>=(ValuePtr a, ValuePtr b) { - return a->m_value >= b->m_value; - } + static ValueHandle create(f64 data); + static ValueHandle sum(const std::vector& inputs); - friend std::ostream& operator<<(std::ostream& os, const ValuePtr& value) { - os << "Value(data=" << value->get_value() << ", grad=" << value->get_gradient() << ")"; - return os; - } + ValueHandle exp() const; + ValueHandle log() const; + ValueHandle sigmoid() const; + ValueHandle pow(ValueHandle exponent) const; + ValueHandle pow(f64 exponent) const; - void backward() override { - auto this_value = this->shared_from_this(); - if (this_value->m_backward_func) { - this_value->m_backward_func(this_value); - } - } + void add_gradient(f64 rhs) const; + f64 get_value() const; + f64 get_gradient() const; + void zero_grad() const; + void set_value(f64 v) const; }; -class Pair : public 
SmartBackwardable { -private: - std::function m_backward_func; - bool m_constant; - -public: - friend class Graph; - friend class Value; - - f128 m_values; // stores (first, second) - f128 m_gradients; // stores (grad_first, grad_second) - - explicit Pair(f64 first, f64 second, bool constant = false) : - m_constant(constant), - m_values(f128::make(first, second)), - m_gradients(f128::zero()) { - } - - explicit Pair(const f128& values, bool constant = false) : - m_constant(constant), - m_values(values), - m_gradients(f128::zero()) { - } - - static PairPtr create(f64 first, f64 second); - - static PairPtr create(const f128& values); - - inline f64 first() const { - return m_values.first(); - } - - inline f64 second() const { - return m_values.second(); - } - - inline f128 get_values() const { - return m_values; - } - - inline f64 grad_first() const { - return m_gradients.first(); - } - - inline f64 grad_second() const { - return m_gradients.second(); - } - - inline f128 get_graidents() const { - return m_gradients; - } - - inline void zero_grad() { - m_gradients = f128::zero(); - } - - inline void set_values(const f128& values) { - m_values = values; - } - inline void set_values(f64 first, f64 second) { - m_values = f128::make(first, second); - } - - // ------------------- Backward ------------------- - void backward() override { - auto self = this->shared_from_this(); - if (m_backward_func) { - m_backward_func(self); - } - } - - // ------------------- Arithmetic ------------------- - friend PairPtr operator+(const PairPtr& a, const PairPtr& b) { - f128 result_values = f128::add(a->m_values, b->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, b](PairPtr out) { - a->m_gradients = f128::add(a->m_gradients, out->m_gradients); - b->m_gradients = f128::add(b->m_gradients, out->m_gradients); - }; - return result; - } - - friend PairPtr operator-(const PairPtr& a, const PairPtr& b) { - f128 result_values = f128::sub(a->m_values, b->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, b](PairPtr out) { - a->m_gradients = f128::add(a->m_gradients, out->m_gradients); - b->m_gradients = f128::sub(b->m_gradients, out->m_gradients); - }; - return result; - } - - // ------------------- Scalar multiplication ------------------- - friend PairPtr operator*(const PairPtr& a, f64 scalar) { - f128 result_values = f128::mul_scalar(a->m_values, scalar); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, scalar](PairPtr out) { - f128 scaled_grad = f128::mul_scalar(out->m_gradients, scalar); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); - }; - return result; - } - - friend PairPtr operator*(f64 scalar, const PairPtr& a) { - return a * scalar; - } - - // ------------------- Pair * Value ------------------- - friend PairPtr operator*(const PairPtr& a, const ValuePtr& v) { - f64 v_val = v->get_value(); - f128 result_values = f128::mul_scalar(a->m_values, v_val); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, v](PairPtr out) { - f64 v_val = v->get_value(); - f128 scaled_grad = f128::mul_scalar(out->m_gradients, v_val); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); - - f128 contribution = f128::mul(a->m_values, out->m_gradients); - v->add_gradient(contribution.first() + contribution.second()); - }; - return result; +struct PairHandle { + u32 index; + PairHandle() : + index(0xFFFFFFFF) { } - - friend PairPtr operator*(const ValuePtr& v, const 
PairPtr& a) { - return a * v; + explicit PairHandle(u32 idx) : + index(idx) { } - - // ------------------- Pair / scalar ------------------- - friend PairPtr operator/(const PairPtr& a, f64 scalar) { - f128 result_values = f128::div_scalar(a->m_values, scalar); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, scalar](PairPtr out) { - f128 scaled_grad = f128::div_scalar(out->m_gradients, scalar); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); - }; - return result; + bool is_valid() const { + return index != 0xFFFFFFFF; } - // ------------------- Scalar / Pair ------------------- - friend PairPtr operator/(f64 scalar, const PairPtr& a) { - f128 result_values = f128::scalar_div(scalar, a->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, scalar](PairPtr out) { - f128 a_squared = f128::mul(a->m_values, a->m_values); - f128 neg_scalar_over_a_sq = f128::neg(f128::scalar_div(scalar, a_squared)); - f128 grad_update = f128::mul(neg_scalar_over_a_sq, out->m_gradients); - a->m_gradients = f128::add(a->m_gradients, grad_update); - }; - return result; + static PairHandle create(f64 first, f64 second); + static PairHandle create(const f64x2& values); + static PairHandle create_tunable(f64 a, f64 b) { + return create(a, b); } - // ------------------- Pair / Value ------------------- - friend PairPtr operator/(const PairPtr& a, const ValuePtr& v) { - f64 v_val = v->get_value(); - f128 result_values = f128::div_scalar(a->m_values, v_val); - PairPtr result = Pair::create(result_values); + f64x2 get_values() const; + f64x2 get_gradients() const; + f64 first() const; + f64 second() const; + void set_values(const f64x2& v) const; + void set_values(f64 f, f64 s) const; + void zero_grad() const; - result->m_backward_func = [a, v](PairPtr out) { - f64 v_val = v->get_value(); - f128 scaled_grad = f128::div_scalar(out->m_gradients, v_val); - a->m_gradients = f128::add(a->m_gradients, scaled_grad); + // Internal helper to avoid including Graph in header + ValueHandle phase_impl(f64 scaled_alpha) const; - f128 numerator = f128::mul(a->m_values, out->m_gradients); - f64 sum_contribution = numerator.first() + numerator.second(); - v->add_gradient(-sum_contribution / (v_val * v_val)); - }; - return result; - } - - // ------------------- Value / Pair ------------------- - friend PairPtr operator/(const ValuePtr& v, const PairPtr& a) { - f64 v_val = v->get_value(); - f128 result_values = f128::scalar_div(v_val, a->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a, v](PairPtr out) { - f64 v_val = v->get_value(); - f128 a_squared = f128::mul(a->m_values, a->m_values); - f128 neg_v_over_a_sq = f128::neg(f128::scalar_div(v_val, a_squared)); - f128 grad_update = f128::mul(neg_v_over_a_sq, out->m_gradients); - a->m_gradients = f128::add(a->m_gradients, grad_update); - - f128 v_grad_contrib = f128::div(out->m_gradients, a->m_values); - v->add_gradient(v_grad_contrib.first() + v_grad_contrib.second()); - }; - return result; - } - - // ------------------- Unary negation ------------------- - friend PairPtr operator-(PairPtr a) { - f128 result_values = f128::neg(a->m_values); - PairPtr result = Pair::create(result_values); - - result->m_backward_func = [a](PairPtr out) { - a->m_gradients = f128::sub(a->m_gradients, out->m_gradients); - }; - return result; - } - - // ------------------- Phasing ------------------- template - ValuePtr phase(f64 alpha) { - alpha /= max; - auto self = 
this->shared_from_this(); - - // Linear interpolation: alpha * first + (1-alpha) * second - f64 result_val = alpha * first() + (1.0 - alpha) * second(); - ValuePtr result = Value::create(result_val); - - result->m_backward_func = [self, alpha](ValuePtr out) { - // Gradient of output w.r.t first and second - f64 grad_first = alpha * out->m_gradient; - f64 grad_second = (1.0 - alpha) * out->m_gradient; - f128 grad_update = f128::make(grad_first, grad_second); - self->m_gradients = f128::add(self->m_gradients, grad_update); - }; - - return result; - } - - // ------------------- Print ------------------- - friend std::ostream& operator<<(std::ostream& os, const PairPtr& p) { -#ifdef VERBOSE_AUTOGRAD - os << "Pair(first=" << p->first() << ", second=" << p->second() - << ", grad_first=" << p->grad_first() << ", grad_second=" << p->grad_second() << ")"; -#else - os << (p->m_constant ? "CS" : "S"); - os << "(" << std::round(p->first()) << "," << std::round(p->second()) << ")"; -#endif - return os; + ValueHandle phase(f64 alpha) const { + return phase_impl(alpha / max); } }; -// Inplace ops (we can't do them as member functions because we use shared pointers) - -// ValuePtr += ValuePtr -inline ValuePtr& operator+=(ValuePtr& a, const ValuePtr& b) { - a = a + b; - return a; -} - -// ValuePtr -= ValuePtr -inline ValuePtr& operator-=(ValuePtr& a, const ValuePtr& b) { - a = a - b; - return a; -} - -// ValuePtr *= ValuePtr -inline ValuePtr& operator*=(ValuePtr& a, const ValuePtr& b) { - a = a * b; - return a; -} - -// ValuePtr /= ValuePtr -inline ValuePtr& operator/=(ValuePtr& a, const ValuePtr& b) { - a = a / b; - return a; -} - - -// ValuePtr += scalar -inline ValuePtr& operator+=(ValuePtr& a, f64 b) { - a = a + b; - return a; -} - -// ValuePtr -= scalar -inline ValuePtr& operator-=(ValuePtr& a, f64 b) { - a = a - b; - return a; -} - -// ValuePtr *= scalar -inline ValuePtr& operator*=(ValuePtr& a, f64 b) { - a = a * b; - return a; -} - -// ValuePtr /= scalar -inline ValuePtr& operator/=(ValuePtr& a, f64 b) { - a = a / b; - return a; -} - -// PairPtr += PairPtr -inline PairPtr& operator+=(PairPtr& a, const PairPtr& b) { - a = a + b; - return a; -} - -// PairPtr -= PairPtr -inline PairPtr& operator-=(PairPtr& a, const PairPtr& b) { - a = a - b; - return a; -} - -// PairPtr *= scalar -inline PairPtr& operator*=(PairPtr& a, f64 scalar) { - a = a * scalar; - return a; -} - -// PairPtr *= ValuePtr -inline PairPtr& operator*=(PairPtr& a, const ValuePtr& v) { - a = a * v; - return a; -} - -// PairPtr /= scalar -inline PairPtr& operator/=(PairPtr& a, f64 scalar) { - a = a / scalar; - return a; -} - -// PairPtr /= ValuePtr -inline PairPtr& operator/=(PairPtr& a, const ValuePtr& v) { - a = a / v; - return a; -} +// Operation decls +ValueHandle operator-(ValueHandle a); +ValueHandle operator+(ValueHandle a, ValueHandle b); +ValueHandle operator-(ValueHandle a, ValueHandle b); +ValueHandle operator*(ValueHandle a, ValueHandle b); +ValueHandle operator/(ValueHandle a, ValueHandle b); +ValueHandle operator+(ValueHandle a, f64 b); +ValueHandle operator-(ValueHandle a, f64 b); +ValueHandle operator*(ValueHandle a, f64 b); +ValueHandle operator/(ValueHandle a, f64 b); +ValueHandle operator+(f64 a, ValueHandle b); +ValueHandle operator-(f64 a, ValueHandle b); +ValueHandle operator*(f64 a, ValueHandle b); +ValueHandle operator/(f64 a, ValueHandle b); +bool operator<(ValueHandle a, ValueHandle b); +bool operator>(ValueHandle a, ValueHandle b); + +PairHandle operator+(PairHandle a, PairHandle b); +PairHandle 
operator-(PairHandle a, PairHandle b); +PairHandle operator-(PairHandle a); +PairHandle operator*(PairHandle a, f64 scalar); +PairHandle operator*(f64 scalar, PairHandle a); +PairHandle operator/(PairHandle a, f64 scalar); +PairHandle operator/(f64 scalar, PairHandle a); +PairHandle operator*(PairHandle a, ValueHandle v); +PairHandle operator*(ValueHandle v, PairHandle a); +PairHandle operator/(PairHandle a, ValueHandle v); +PairHandle operator/(ValueHandle v, PairHandle a); +std::ostream& operator<<(std::ostream& os, const PairHandle& p); + +// Value Inplaces +ValueHandle& operator+=(ValueHandle& a, ValueHandle b); +ValueHandle& operator-=(ValueHandle& a, ValueHandle b); +ValueHandle& operator*=(ValueHandle& a, ValueHandle b); +ValueHandle& operator/=(ValueHandle& a, ValueHandle b); +ValueHandle& operator+=(ValueHandle& a, f64 b); +ValueHandle& operator-=(ValueHandle& a, f64 b); +ValueHandle& operator*=(ValueHandle& a, f64 b); +ValueHandle& operator/=(ValueHandle& a, f64 b); + +// Pair Inplaces +PairHandle& operator+=(PairHandle& a, PairHandle b); +PairHandle& operator-=(PairHandle& a, PairHandle b); +PairHandle& operator*=(PairHandle& a, f64 scalar); +PairHandle& operator*=(PairHandle& a, ValueHandle v); +PairHandle& operator/=(PairHandle& a, f64 scalar); +PairHandle& operator/=(PairHandle& a, ValueHandle v); } // namespace Clockwork::Autograd diff --git a/src/util/vec/sse2.hpp b/src/util/vec/sse2.hpp index 0b212b68..86852ae2 100644 --- a/src/util/vec/sse2.hpp +++ b/src/util/vec/sse2.hpp @@ -6,13 +6,13 @@ #if defined(__SSE2__) #include - #define F128_USE_SSE2 1 + #define F64X2_USE_SSE2 1 #else - #define F128_USE_SSE2 0 + #define F64X2_USE_SSE2 0 #endif -struct f128 { -#if F128_USE_SSE2 +struct f64x2 { +#if F64X2_USE_SSE2 __m128d v = _mm_setzero_pd(); #else double lo = 0.0; @@ -20,9 +20,9 @@ struct f128 { #endif // ---- Constructors ---- - static inline f128 make(double a, double b) { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 make(double a, double b) { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_set_pd(b, a); return r; #else @@ -30,9 +30,9 @@ struct f128 { #endif } - static inline f128 broadcast(double x) { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 broadcast(double x) { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_set1_pd(x); return r; #else @@ -40,9 +40,9 @@ struct f128 { #endif } - static inline f128 zero() { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 zero() { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_setzero_pd(); return r; #else @@ -52,7 +52,7 @@ struct f128 { // ---- Extract ---- inline double first() const { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 alignas(16) double buf[2]; _mm_store_pd(buf, v); return buf[0]; @@ -62,7 +62,7 @@ struct f128 { } inline double second() const { -#if F128_USE_SSE2 +#if F64X2_USE_SSE2 alignas(16) double buf[2]; _mm_store_pd(buf, v); return buf[1]; @@ -72,9 +72,9 @@ struct f128 { } // ---- Arithmetic ---- - static inline f128 add(const f128& a, const f128& b) { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 add(const f64x2& a, const f64x2& b) { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_add_pd(a.v, b.v); return r; #else @@ -82,9 +82,9 @@ struct f128 { #endif } - static inline f128 sub(const f128& a, const f128& b) { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 sub(const f64x2& a, const f64x2& b) { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_sub_pd(a.v, b.v); return r; #else @@ -92,9 +92,9 @@ struct f128 { #endif } - static inline f128 mul(const f128& a, const f128& b) { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 mul(const 
f64x2& a, const f64x2& b) { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_mul_pd(a.v, b.v); return r; #else @@ -102,9 +102,9 @@ struct f128 { #endif } - static inline f128 div(const f128& a, const f128& b) { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 div(const f64x2& a, const f64x2& b) { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_div_pd(a.v, b.v); return r; #else @@ -112,10 +112,10 @@ struct f128 { #endif } - static inline f128 neg(const f128& a) { -#if F128_USE_SSE2 + static inline f64x2 neg(const f64x2& a) { +#if F64X2_USE_SSE2 __m128d zero = _mm_setzero_pd(); - f128 r; + f64x2 r; r.v = _mm_sub_pd(zero, a.v); return r; #else @@ -124,26 +124,26 @@ struct f128 { } // ---- Scalar ops ---- - static inline f128 add_scalar(const f128& a, double s) { + static inline f64x2 add_scalar(const f64x2& a, double s) { return add(a, broadcast(s)); } - static inline f128 sub_scalar(const f128& a, double s) { + static inline f64x2 sub_scalar(const f64x2& a, double s) { return sub(a, broadcast(s)); } - static inline f128 mul_scalar(const f128& a, double s) { + static inline f64x2 mul_scalar(const f64x2& a, double s) { return mul(a, broadcast(s)); } - static inline f128 div_scalar(const f128& a, double s) { + static inline f64x2 div_scalar(const f64x2& a, double s) { return div(a, broadcast(s)); } - static inline f128 scalar_div(double s, const f128& a) { -#if F128_USE_SSE2 + static inline f64x2 scalar_div(double s, const f64x2& a) { +#if F64X2_USE_SSE2 __m128d num = _mm_set1_pd(s); - f128 r; + f64x2 r; r.v = _mm_div_pd(num, a.v); return r; #else @@ -152,9 +152,9 @@ struct f128 { } // ---- Math functions ---- - static inline f128 sqrt(const f128& a) { -#if F128_USE_SSE2 - f128 r; + static inline f64x2 sqrt(const f64x2& a) { +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_sqrt_pd(a.v); return r; #else @@ -163,10 +163,10 @@ struct f128 { } // ---- FMA-style (useful for gradient updates) ---- - static inline f128 madd(const f128& a, const f128& b, const f128& c) { + static inline f64x2 madd(const f64x2& a, const f64x2& b, const f64x2& c) { // a + b*c -#if F128_USE_SSE2 - f128 r; +#if F64X2_USE_SSE2 + f64x2 r; r.v = _mm_add_pd(a.v, _mm_mul_pd(b.v, c.v)); return r; #else @@ -175,7 +175,7 @@ struct f128 { } // ---- Printing ---- - friend std::ostream& operator<<(std::ostream& os, const f128& f) { + friend std::ostream& operator<<(std::ostream& os, const f64x2& f) { os << "(" << f.first() << ", " << f.second() << ")"; return os; }
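
For context on the two changes above (the handle-based autograd in src/tuning/value.* and the f128 -> f64x2 rename in src/util/vec/sse2.hpp), here is a minimal, self-contained sketch of the tape idea the refactor moves to: values and gradients live in flat arrays owned by one graph, a handle is just a u32 index, and each operator appends an op record that a reverse sweep later replays. This is an illustrative toy, not Clockwork's actual Graph API (graph.hpp/graph.cpp are not part of this excerpt); the names Tape, Op, leaf() and record() are invented for the example, and only Add, Mul and Sigmoid ops are shown.

// Toy index-based autograd tape (illustration only, not Clockwork's Graph).
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

using u32 = std::uint32_t;
using f64 = double;

enum class OpType { Add, Mul, Sigmoid };

struct Op {
    OpType type;
    u32    out, lhs, rhs;  // indices into the tape's value/grad arrays
};

struct Tape {
    std::vector<f64> value, grad;
    std::vector<Op>  ops;

    // Create a leaf node (e.g. a tunable parameter) and return its handle.
    u32 leaf(f64 v) {
        value.push_back(v);
        grad.push_back(0.0);
        return static_cast<u32>(value.size() - 1);
    }

    // Record an op: compute the forward value now, remember operands for backward.
    u32 record(OpType t, u32 a, u32 b) {
        f64 v = 0.0;
        switch (t) {
        case OpType::Add:     v = value[a] + value[b]; break;
        case OpType::Mul:     v = value[a] * value[b]; break;
        case OpType::Sigmoid: v = 1.0 / (1.0 + std::exp(-value[a])); break;
        }
        u32 out = leaf(v);
        ops.push_back({t, out, a, b});
        return out;
    }

    // Reverse sweep over the recorded ops, accumulating gradients in place.
    void backward(u32 root) {
        grad.assign(grad.size(), 0.0);
        grad[root] = 1.0;
        for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
            f64 g = grad[it->out];
            switch (it->type) {
            case OpType::Add:
                grad[it->lhs] += g;
                grad[it->rhs] += g;
                break;
            case OpType::Mul:
                grad[it->lhs] += value[it->rhs] * g;
                grad[it->rhs] += value[it->lhs] * g;
                break;
            case OpType::Sigmoid: {
                f64 s = value[it->out];  // reuse the stored forward value
                grad[it->lhs] += s * (1.0 - s) * g;
                break;
            }
            }
        }
    }
};

int main() {
    Tape tape;
    // loss = sigmoid(w * x) with w tunable, x constant
    u32 w    = tape.leaf(0.5);
    u32 x    = tape.leaf(2.0);
    u32 wx   = tape.record(OpType::Mul, w, x);
    u32 loss = tape.record(OpType::Sigmoid, wx, wx);  // rhs unused for unary ops
    tape.backward(loss);
    // d(loss)/dw = sigmoid'(w*x) * x
    std::cout << "loss = " << tape.value[loss] << ", dloss/dw = " << tape.grad[w] << "\n";
}

The same reverse sweep extends to the pair ops recorded above; phase_impl, for instance, records the blend alpha * first + (1 - alpha) * second, whose gradients with respect to (first, second) are (alpha, 1 - alpha), exactly what the removed Pair::phase lambda accumulated. The apparent payoff over the old shared_ptr nodes is that handles are trivially copyable u32s, values and gradients sit in contiguous storage, and backward is one reverse loop instead of a web of captured lambdas and reference counts.

A quick usage sketch of the renamed f64x2 wrapper follows, assuming it is compiled inside the Clockwork tree (so that "util/vec/sse2.hpp" resolves) and that f64x2 is reachable unqualified at this scope as in the hunk above; the operations used here (make, broadcast, madd, operator<<) are the ones defined in the header.

// Quick f64x2 usage sketch (same API as the old f128, new name and guard).
#include "util/vec/sse2.hpp"
#include <iostream>

int main() {
    f64x2 acc  = f64x2::make(10.0, 20.0);   // (mg, eg) style pair
    f64x2 step = f64x2::broadcast(0.5);
    acc = f64x2::madd(acc, step, f64x2::make(2.0, 4.0));  // acc + step * (2, 4)
    std::cout << acc << "\n";               // prints "(11, 22)"
}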