Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 94 additions & 6 deletions src/evaltune_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,111 @@
#include <numeric>
#include <random>
#include <sstream>
#include <string_view>
#include <thread>
#include <tuple>

using namespace Clockwork;

int main() {
struct AdamWParams {
f64 learning_rate{10.0};
f64 beta1{0.9};
f64 beta2{0.999};
f64 weight_decay{0.0};
};

void print_help(char** argv) {
std::cout << "Usage: " << argv[0] << " [options]\n\n";
std::cout << "Options:\n";
std::cout << " -h, --help Show this help message and exit.\n";
std::cout << " -t, --threads <number> Number of threads to use (type: uint32_t, default: "
<< std::thread::hardware_concurrency() / 2 << ".\n";
std::cout
<< " -e, --epochs <number> Number of training epochs (type: int32_t, default: 1000).\n";
std::cout << " -b, --batch <number> Batch size for training (type: size_t, default: "
<< static_cast<size_t>(16 * 16384) << ").\n";
std::cout
<< " -d, --decay <value> Learning rate decay factor per epoch (type: double, default: 0.91).\n";
std::cout << "\nAdamW Optimizer Parameters:\n";
std::cout << " --lr <value> Learning rate (type: double, default: 10.0).\n";
std::cout << " --beta1 <value> Beta1 parameter (type: double, default: 0.9).\n";
std::cout << " --beta2 <value> Beta2 parameter (type: double, default: 0.999).\n";
std::cout << " --weight_decay <value> Weight decay (type: double, default: 0.0).\n";
}

int main(int argc, char** argv) {

//Default params

uint32_t thread_count_p = std::thread::hardware_concurrency() / 2;
int32_t epochs_p = 1000;
size_t batch_size_p = 16 * 16384;
f64 decay_p = 0.91;

AdamWParams adam_p;

//Args parsing
for (int i = 1; i < argc; ++i) {
std::string_view arg = argv[i];

if (arg == "--help" || arg == "-h") {
print_help(argv);
return 0;
}

// Thread Count check
if ((arg == "--threads" || arg == "-t") && i + 1 < argc) {
thread_count_p = static_cast<uint32_t>(std::stoul(argv[++i]));
}
// Epochs check
else if ((arg == "--epochs" || arg == "-e") && i + 1 < argc) {
epochs_p = static_cast<int32_t>(std::stoi(argv[++i]));
}
// Batch Size check
else if ((arg == "--batch" || arg == "-b") && i + 1 < argc) {
batch_size_p = static_cast<size_t>(std::stoull(argv[++i]));
}
// AdamW Params check
else if (arg == "--lr" && i + 1 < argc) {
adam_p.learning_rate = std::stod(argv[++i]);
} else if (arg == "--beta1" && i + 1 < argc) {
adam_p.beta1 = std::stod(argv[++i]);
} else if (arg == "--beta2" && i + 1 < argc) {
adam_p.beta2 = std::stod(argv[++i]);
} else if (arg == "--weight_decay" && i + 1 < argc) {
adam_p.weight_decay = std::stod(argv[++i]);
}

//Decay check
else if ((arg == "--decay" || arg == "-d") && i + 1 < argc) {
decay_p = std::stod(argv[++i]);
} else {
// Check if it's a flag without a value or an unknown flag
if (arg.rfind("--", 0) == 0 || arg.rfind("-", 0) == 0) {
if (i + 1 >= argc || (argv[i + 1][0] == '-' && !std::isdigit(argv[i + 1][1]))) {
std::cout << "Warning! Argument '" << argv[i] << "' has a missing value.\n Run "
<< argv[0] << " --help to list all arguments.";
exit(-1);
} else {
std::cout << "Warning! Arg not recognized: '" << argv[i] << "'\n Run "
<< argv[0] << " --help to list all arguments.\n";
exit(-1);
}
}
}
}

// Load fens from multiple files.
std::vector<Position> positions;
std::vector<f64> results;

// List of files to load
const std::vector<std::string> fenFiles = {

Check warning on line 120 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:120:36 [readability-identifier-naming]

invalid case style for constant 'fenFiles'
"data/dfrc-1m.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3/v3.txt",
};

// Number of threads to use, default to half available
const u32 thread_count = std::max<u32>(1, std::thread::hardware_concurrency() / 2);
const u32 thread_count = std::max<u32>(1, thread_count_p);

Check warning on line 125 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:125:15 [readability-identifier-naming]

invalid case style for constant 'thread_count'

std::cout << "Running on " << thread_count << " threads" << std::endl;

Expand Down Expand Up @@ -92,18 +179,19 @@

using namespace Clockwork::Autograd;

const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts();

Check warning on line 182 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:182:30 [readability-identifier-naming]

invalid case style for constant 'parameter_count'
Parameters current_parameter_values = Graph::get().get_all_parameter_values();

AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0);
AdamW optim(parameter_count, adam_p.learning_rate, adam_p.beta1, adam_p.beta2, 1e-8,
adam_p.weight_decay);

const i32 epochs = 1000;
const i32 epochs = epochs_p;

Check warning on line 188 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:188:18 [readability-identifier-naming]

invalid case style for constant 'epochs'
const f64 K = 1.0 / 400;
const size_t batch_size = 16 * 16384; // Set batch size here
const size_t batch_size = batch_size_p; // Set batch size here

Check warning on line 190 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:190:18 [readability-identifier-naming]

invalid case style for constant 'batch_size'

std::mt19937 rng(std::random_device{}()); // Random number generator for shuffling

const size_t total_batches = (positions.size() + batch_size - 1) / batch_size;

Check warning on line 194 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:194:25 [readability-identifier-naming]

invalid case style for constant 'total_batches'
std::vector<size_t> indices(positions.size());

Parameters batch_gradients = Parameters::zeros(parameter_count);
Expand Down Expand Up @@ -188,7 +276,7 @@
// Print epoch header
std::cout << "Epoch " << (epoch + 1) << "/" << epochs << std::endl;

const auto epoch_start_time = time::Clock::now();

Check warning on line 279 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:279:20 [readability-identifier-naming]

invalid case style for constant 'epoch_start_time'

std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), rng);
Expand All @@ -204,7 +292,7 @@
print_progress(batch_idx + 1, total_batches);
}

const auto epoch_end_time = time::Clock::now();

Check warning on line 295 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:295:20 [readability-identifier-naming]

invalid case style for constant 'epoch_end_time'

std::cout << std::endl; // Finish progress bar line

Expand Down Expand Up @@ -313,7 +401,7 @@
print_table("BISHOP_PAWNS", BISHOP_PAWNS);
std::cout << std::endl;

auto printPsqtArray = [](const std::string& name, const auto& arr) {

Check warning on line 404 in src/evaltune_main.cpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/evaltune_main.cpp:404:14 [readability-identifier-naming]

invalid case style for variable 'printPsqtArray'
std::cout << "inline const std::array<PParam, " << arr.size() << "> " << name << " = {"
<< std::endl;
for (std::size_t i = 0; i < arr.size(); ++i) {
Expand Down Expand Up @@ -342,7 +430,7 @@
<< "s" << std::endl;

if (epoch > 5) {
optim.set_lr(optim.get_lr() * 0.91);
optim.set_lr(optim.get_lr() * decay_p);
}
}

Expand Down
Loading