From 59fe1fd4caae7b8e46feb42893e4e337c21fcb54 Mon Sep 17 00:00:00 2001
From: 87
Date: Thu, 28 Aug 2025 21:10:02 +0100
Subject: [PATCH] Implement C-AdamW for eval tuning

Bench: 8584309
---
Reviewer notes (text in this section is not part of the commit):
 * This patch was recovered from a copy in which line breaks, indentation and
   all template argument lists had been stripped. The argument lists
   (<ValuePtr>, <PairPtr>, <f64>, <f128>, <bool>, <std::array<..., 2>>) are
   inferred from how the variables are used; confirm them against the
   existing AdamW class and the Graph declarations, and confirm that <array>
   is already reachable from optim.hpp.
 * Indentation, and the position of the blank context line after the changed
   line in the evaltune_main.cpp hunk, are reconstructed. Applying may need
   `git apply --ignore-whitespace`.
 * Comment lines were added to CAdamW, so the line counts below were updated
   and the optim.hpp post-image blob id on its index line is stale.

 src/evaltune_main.cpp |   2 +-
 src/tuning/optim.hpp  | 177 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 178 insertions(+), 1 deletion(-)

diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp
index 4410242d..56819415 100644
--- a/src/evaltune_main.cpp
+++ b/src/evaltune_main.cpp
@@ -75,7 +75,7 @@ int main(int argc, char* argv[]) {
     std::cout << "Loaded " << positions.size() << " FENs from " << fenFiles.size() << " files."
               << std::endl;
 
-    Clockwork::Autograd::AdamW optim(10, 0.9, 0.999, 1e-8, 0.0);
+    Clockwork::Autograd::CAdamW optim(10, 0.9, 0.999, 1e-8, 0.0);
 
     i32 epochs = 1000;
     const f64 K = 1.0 / 400;
diff --git a/src/tuning/optim.hpp b/src/tuning/optim.hpp
index 61434bed..8105416c 100644
--- a/src/tuning/optim.hpp
+++ b/src/tuning/optim.hpp
@@ -177,5 +177,182 @@ class AdamW {
     }
 };
 
+// Adam optimizer with decoupled weight decay and a "cautious" update mask (C-AdamW, per the
+// commit subject). It copies the parameter lists returned by Graph::get() at construction
+// and keeps first/second moment estimates for every scalar and pair parameter.
+class CAdamW {
+private:
+    std::vector<ValuePtr> m_value_params;
+    std::vector<PairPtr> m_pair_params;
+
+    f64 m_lr;
+    f64 m_beta1;
+    f64 m_beta2;
+    f64 m_eps;
+    f64 m_weight_decay;
+    long long m_t;  // number of step() calls so far; drives the bias correction
+
+    std::vector<f64> m_m;        // running average of the gradient, per scalar parameter
+    std::vector<f64> m_v;        // running average of the squared gradient, per scalar parameter
+    std::vector<f128> m_pair_m;  // running average of the gradient, per pair parameter
+    std::vector<f128> m_pair_v;  // running average of the squared gradient, per pair parameter
+
+public:
+    // lr: step size; beta1 / beta2: decay rates of the first / second moment averages;
+    // eps: added to the denominator of the Adam update; weight_decay: decoupled decay factor.
+    // All moment buffers start at zero.
+    explicit CAdamW(
+      f64 lr = 1e-3, f64 beta1 = 0.9, f64 beta2 = 0.999, f64 eps = 1e-8, f64 weight_decay = 0.01) :
+        m_lr(lr),
+        m_beta1(beta1),
+        m_beta2(beta2),
+        m_eps(eps),
+        m_weight_decay(weight_decay),
+        m_t(0) {
+        auto graph = Graph::get();
+        m_value_params = graph->get_parameters();
+        m_pair_params = graph->get_pair_parameters();
+
+        m_m.resize(m_value_params.size(), 0.0);
+        m_v.resize(m_value_params.size(), 0.0);
+
+        m_pair_m.resize(m_pair_params.size(), f128::zero());
+        m_pair_v.resize(m_pair_params.size(), f128::zero());
+    }
+
+    // Applies one optimizer step using the gradients currently stored on the parameters.
+    // Scalar and pair parameters are masked and rescaled independently of each other.
+    void step() {
+        m_t += 1;
+
+        // Bias-correction factors 1 / (1 - beta^t) for the two moment estimates.
+        const f64 b1t = std::pow(m_beta1, static_cast<f64>(m_t));
+        const f64 b2t = std::pow(m_beta2, static_cast<f64>(m_t));
+        const f64 inv1mb1t = 1.0 / (1.0 - b1t);
+        const f64 inv1mb2t = 1.0 / (1.0 - b2t);
+
+        // ---------------- Value parameters ----------------
+        {
+            std::vector<f64> adam_updates;
+            std::vector<bool> alignment_mask;
+            i64 alignment_mask_nnz = 0;
+
+            const size_t N = m_value_params.size();
+
+            adam_updates.resize(N);
+            alignment_mask.resize(N);
+
+            // Pass 1: update the moments, compute the raw Adam update, and count aligned entries.
+            for (size_t i = 0; i < N; ++i) {
+                auto& p = m_value_params[i];
+                const f64 g = p->get_gradient();
+
+                m_m[i] = m_beta1 * m_m[i] + (1.0 - m_beta1) * g;
+
+                m_v[i] = m_beta2 * m_v[i] + (1.0 - m_beta2) * g * g;
+
+                const f64 m_hat = m_m[i] * inv1mb1t;
+                const f64 v_hat = m_v[i] * inv1mb2t;
+
+                adam_updates[i] = m_lr * m_hat / (std::sqrt(v_hat) + m_eps);
+
+                // Aligned means the update and the current gradient have the same non-zero sign.
+                bool alignment = adam_updates[i] * g > 0;
+                alignment_mask[i] = alignment;
+                alignment_mask_nnz += alignment;
+            }
+
+            // Scale surviving updates by N / (aligned + 1); the +1 keeps the divisor non-zero.
+            f64 cautiousness_rescale = static_cast<f64>(N) / (alignment_mask_nnz + 1);
+
+            // Pass 2: apply the masked, rescaled update plus the (unmasked) weight decay.
+            for (size_t i = 0; i < N; ++i) {
+                auto& p = m_value_params[i];
+
+                const f64 adam_update = alignment_mask[i] * cautiousness_rescale * adam_updates[i];
+                const f64 weight_decay_update = m_lr * m_weight_decay * p->get_value();
+
+                const f64 total_update = -(adam_update + weight_decay_update);
+
+                p->change_value(total_update);
+            }
+        }
+
+        // ---------------- Pair parameters ----------------
+        {
+            std::vector<std::array<f64, 2>> adam_updates;
+            std::vector<std::array<bool, 2>> alignment_mask;
+            i64 alignment_mask_nnz = 0;
+
+            const size_t N = m_pair_params.size();
+
+            adam_updates.resize(N);
+            alignment_mask.resize(N);
+
+            for (size_t i = 0; i < N; ++i) {
+                auto& p = m_pair_params[i];
+                auto& m = m_pair_m[i];
+                auto& v = m_pair_v[i];
+
+                const f128 g_vec = f128::make(p->grad_first(), p->grad_second());
+                const f128 g2_vec = f128::make(p->grad_first() * p->grad_first(),
+                                               p->grad_second() * p->grad_second());
+
+                const f128 m_scaled = f128::mul_scalar(m, m_beta1);
+                const f128 g_scaled = f128::mul_scalar(g_vec, (1.0 - m_beta1));
+                m = f128::add(m_scaled, g_scaled);
+
+                const f128 v_scaled = f128::mul_scalar(v, m_beta2);
+                const f128 g2_scaled = f128::mul_scalar(g2_vec, (1.0 - m_beta2));
+                v = f128::add(v_scaled, g2_scaled);
+
+                const f128 m_hat = f128::mul_scalar(m, inv1mb1t);
+                const f128 v_hat = f128::mul_scalar(v, inv1mb2t);
+
+                const f64 adam_upd_f = m_lr * m_hat.first() / (std::sqrt(v_hat.first()) + m_eps);
+                const f64 adam_upd_s = m_lr * m_hat.second() / (std::sqrt(v_hat.second()) + m_eps);
+
+                adam_updates[i] = {adam_upd_f, adam_upd_s};
+
+                bool align_f = adam_upd_f * p->grad_first() > 0;
+                bool align_s = adam_upd_s * p->grad_second() > 0;
+
+                alignment_mask[i] = {align_f, align_s};
+                alignment_mask_nnz += align_f + align_s;
+            }
+
+            // Both components of every pair count towards the mask, hence 2 * N entries.
+            f64 cautiousness_rescale = static_cast<f64>(N * 2) / (alignment_mask_nnz + 1);
+
+            for (size_t i = 0; i < N; ++i) {
+                auto& p = m_pair_params[i];
+
+                bool align_f = alignment_mask[i][0];
+                bool align_s = alignment_mask[i][1];
+
+                const f64 adam_upd_f = align_f * cautiousness_rescale * adam_updates[i][0];
+                const f64 adam_upd_s = align_s * cautiousness_rescale * adam_updates[i][1];
+
+                const f64 decay_upd_f = m_lr * m_weight_decay * p->first();
+                const f64 decay_upd_s = m_lr * m_weight_decay * p->second();
+
+                const f64 total_upd_f = -(adam_upd_f + decay_upd_f);
+                const f64 total_upd_s = -(adam_upd_s + decay_upd_s);
+
+                p->m_values = f128::add(p->m_values, f128::make(total_upd_f, total_upd_s));
+            }
+        }
+    }
+
+    // Replaces the learning rate used by subsequent step() calls.
+    void set_lr(f64 lr) {
+        m_lr = lr;
+    }
+    // Returns the current learning rate.
+    f64 get_lr() const {
+        return m_lr;
+    }
+};
+
 } // namespace Autograd
 } // namespace Clockwork