diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp
index 0568a660..0e99dfbe 100644
--- a/src/evaltune_main.cpp
+++ b/src/evaltune_main.cpp
@@ -18,12 +18,99 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
 using namespace Clockwork;
 
-int main() {
+struct AdamWParams {  // AdamW hyperparameters; the defaults apply unless overridden by a flag.
+    f64 learning_rate{10.0};  // --lr
+    f64 beta1{0.9};           // --beta1
+    f64 beta2{0.999};         // --beta2
+    f64 weight_decay{0.0};    // --weight_decay
+};
+
+void print_help(char** argv) {  // Prints the usage text to stdout; argv[0] is the program name.
+    std::cout << "Usage: " << argv[0] << " [options]\n\n";
+    std::cout << "Options:\n";
+    std::cout << " -h, --help Show this help message and exit.\n";
+    std::cout << " -t, --threads Number of threads to use (type: uint32_t, default: "
+              << std::max(1U, std::thread::hardware_concurrency() / 2) << ").\n";
+    std::cout
+      << " -e, --epochs Number of training epochs (type: int32_t, default: 1000).\n";
+    std::cout << " -b, --batch Batch size for training (type: size_t, default: "
+              << static_cast<size_t>(16 * 16384) << ").\n";
+    std::cout
+      << " -d, --decay Learning rate decay factor per epoch (type: double, default: 0.91).\n";
+    std::cout << "\nAdamW Optimizer Parameters:\n";
+    std::cout << " --lr Learning rate (type: double, default: 10.0).\n";
+    std::cout << " --beta1 Beta1 parameter (type: double, default: 0.9).\n";
+    std::cout << " --beta2 Beta2 parameter (type: double, default: 0.999).\n";
+    std::cout << " --weight_decay Weight decay (type: double, default: 0.0).\n";
+}
+
+int main(int argc, char** argv) {
+
+    // Defaults; each value below can be overridden from the command line.
+
+    uint32_t thread_count_p = std::thread::hardware_concurrency() / 2;
+    int32_t epochs_p = 1000;
+    size_t batch_size_p = 16 * 16384;
+    f64 decay_p = 0.91;
+
+    AdamWParams adam_p;
+
+    // Parse the options; a value-taking flag consumes the following argv entry.
+    for (int i = 1; i < argc; ++i) {
+        std::string_view arg = argv[i];
+
+        if (arg == "--help" || arg == "-h") {
+            print_help(argv);
+            return 0;
+        }
+
+        // Thread count (clamped to at least 1 where thread_count is computed)
+        if ((arg == "--threads" || arg == "-t") && i + 1 < argc) {
+            thread_count_p = static_cast<uint32_t>(std::stoul(argv[++i]));
+        }
+        // Number of epochs
+        else if ((arg == "--epochs" || arg == "-e") && i + 1 < argc) {
+            epochs_p = static_cast<int32_t>(std::stoi(argv[++i]));
+        }
+        // Batch size
+        else if ((arg == "--batch" || arg == "-b") && i + 1 < argc) {
+            batch_size_p = static_cast<size_t>(std::stoull(argv[++i]));
+        }
+        // AdamW hyperparameters
+        else if (arg == "--lr" && i + 1 < argc) {
+            adam_p.learning_rate = std::stod(argv[++i]);
+        } else if (arg == "--beta1" && i + 1 < argc) {
+            adam_p.beta1 = std::stod(argv[++i]);
+        } else if (arg == "--beta2" && i + 1 < argc) {
+            adam_p.beta2 = std::stod(argv[++i]);
+        } else if (arg == "--weight_decay" && i + 1 < argc) {
+            adam_p.weight_decay = std::stod(argv[++i]);
+        }
+
+        // Per-epoch learning-rate decay factor
+        else if ((arg == "--decay" || arg == "-d") && i + 1 < argc) {
+            decay_p = std::stod(argv[++i]);
+        } else {
+            // Reached for a flag with no value after it, or for a flag that is not recognized.
+            if (arg.rfind("--", 0) == 0 || arg.rfind("-", 0) == 0) {
+                if (i + 1 >= argc || (argv[i + 1][0] == '-' && !std::isdigit(static_cast<unsigned char>(argv[i + 1][1])))) {
+                    std::cout << "Warning! Argument '" << argv[i] << "' has a missing value.\n Run "
+                              << argv[0] << " --help to list all arguments.\n";
+                    exit(-1);
+                } else {
+                    std::cout << "Warning! Arg not recognized: '" << argv[i] << "'\n Run "
+                              << argv[0] << " --help to list all arguments.\n";
+                    exit(-1);
+                }
+            }
+        }
+    }
 
     // Load fens from multiple files.
     std::vector positions;
@@ -35,7 +122,7 @@ int main() {
     };
 
     // Number of threads to use, default to half available
-    const u32 thread_count = std::max(1, std::thread::hardware_concurrency() / 2);
+    const u32 thread_count = std::max<u32>(1, thread_count_p);
 
     std::cout << "Running on " << thread_count << " threads" << std::endl;
 
@@ -95,11 +182,12 @@ int main() {
     const ParameterCountInfo parameter_count = Globals::get().get_parameter_counts();
     Parameters current_parameter_values = Graph::get().get_all_parameter_values();
 
-    AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0);
+    AdamW optim(parameter_count, adam_p.learning_rate, adam_p.beta1, adam_p.beta2, 1e-8,
+                adam_p.weight_decay);
 
-    const i32 epochs = 1000;
+    const i32 epochs = epochs_p;
     const f64 K = 1.0 / 400;
-    const size_t batch_size = 16 * 16384; // Set batch size here
+    const size_t batch_size = batch_size_p; // Set via -b/--batch (default 16 * 16384)
 
     std::mt19937 rng(std::random_device{}()); // Random number generator for shuffling
 
@@ -342,7 +430,7 @@ int main() {
                   << "s" << std::endl;
 
         if (epoch > 5) {
-            optim.set_lr(optim.get_lr() * 0.91);
+            optim.set_lr(optim.get_lr() * decay_p);
         }
     }
 