diff --git a/CMakeLists.txt b/CMakeLists.txt index 50c768a..bcfd3bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,7 +57,7 @@ include(ExternalProject) ExternalProject_Add( pysa-branching GIT_REPOSITORY https://github.com/nasa/pysa.git - GIT_TAG pysa-branching + GIT_TAG pysa-branching-earlystop GIT_SHALLOW 1 GIT_PROGRESS 1 CONFIGURE_COMMAND "" diff --git a/examples/simple.cpp b/examples/simple.cpp index b085054..9b8591e 100644 --- a/examples/simple.cpp +++ b/examples/simple.cpp @@ -59,7 +59,7 @@ int main() { #ifdef USE_MPI pysa::dpll::sat::mpi::optimize(mpi_comm_world, formula_, 2, true); #else - pysa::dpll::sat::optimize(formula_, 2, true); + pysa::dpll::sat::optimize(formula_, 2, true, {}, false); #endif // Output configurations diff --git a/include/pysa/dpll/dpll.hpp b/include/pysa/dpll/dpll.hpp index 6028f3d..2604ca3 100644 --- a/include/pysa/dpll/dpll.hpp +++ b/include/pysa/dpll/dpll.hpp @@ -32,7 +32,7 @@ namespace pysa::dpll { using namespace std::chrono_literals; -template auto DPLL(Init &&init, Get &&get, bool verbose = false, @@ -55,11 +55,14 @@ auto DPLL(Init &&init, Get &&get, bool verbose = false, if (!branch.partial()) { const std::scoped_lock lock_(mutex_); collected_.push_back(get(branch)); + return true; + } else { + return false; } }; // Get brancher - auto brancher_ = pysa::branching::DPLL( + auto brancher_ = pysa::branching::DPLL( branches_type{{root_}}, collect_, n_threads, sleep_time); // Get initial time diff --git a/include/pysa/python/main.cpp b/include/pysa/python/main.cpp index 6b7b9b7..af5c144 100644 --- a/include/pysa/python/main.cpp +++ b/include/pysa/python/main.cpp @@ -191,6 +191,9 @@ max_n_unsat: int, optional Number of maximum allowed unsatisfied clauses. n_threads: int, optional Number of threads to use (by default, all cores will be used) +stop_on_first: bool, optional + Stop when the first solution is found. By default, branch until all solutions + are found or walltime is reached. walltime: float, optional Maximum number of seconds to run the optimization. verbose: bool, optional @@ -201,7 +204,9 @@ verbose: bool, optional [](const std::vector> cnf, const std::size_t max_n_unsat, const std::optional n_threads, - const std::optional walltime, const bool verbose) { + const bool stop_on_first, + const std::optional walltime, + const bool verbose) { #ifdef USE_MPI if (walltime) throw std::logic_error( @@ -213,16 +218,17 @@ verbose: bool, optional #else if (walltime) return sat::optimize( - cnf, max_n_unsat, verbose, n_threads, + cnf, max_n_unsat, verbose, n_threads, stop_on_first, std::chrono::milliseconds( static_cast(walltime.value() * 1e3))); else - return sat::optimize(cnf, max_n_unsat, verbose, n_threads, nullptr); + return sat::optimize(cnf, max_n_unsat, verbose, n_threads, stop_on_first, nullptr); #endif }, py::arg("cnf"), py::pos_only(), py::arg("max_n_unsat") = 0, py::kw_only(), py::arg("n_threads") = py::none(), - py::arg("walltime") = std::numeric_limits::infinity(), + py::arg("stop_on_first") = false, + py::arg("walltime") = py::none(), py::arg("verbose") = false, __doc__); } diff --git a/include/pysa/sat/sat.hpp b/include/pysa/sat/sat.hpp index 8794255..22777e6 100644 --- a/include/pysa/sat/sat.hpp +++ b/include/pysa/sat/sat.hpp @@ -47,7 +47,22 @@ template n_threads = std::nullopt, + bool stop_on_first = false, WallTime &&walltime = nullptr, SleepTime &&sleep_time = 1ms) { +#ifndef NDEBUG + std::cout << "Entering pysa::dpll:sat::optimized\n" + <<"max_n_unsat = " << max_n_unsat << "\n" + << "verbose = " << verbose << "\n" + << "n_threads = " << n_threads.value_or(0) << "\n" + << "stop_on_first = " << stop_on_first << "\n"; + if constexpr (std::is_null_pointer_v){ + std::cout << "wall_time = none\n"; + } else { + std::cout << "wall_time = " + << std::chrono::duration_cast(walltime).count() << " ms\n"; + } + std::cout << "sleep_time = " << std::chrono::duration_cast(sleep_time).count() << " ms\n"; +#endif // Get root initializer const auto init_ = [&formula, max_n_unsat]() { return Branch<>(formula, max_n_unsat); @@ -59,9 +74,15 @@ auto optimize(Formula &&formula, std::size_t max_n_unsat = 0, }; // Get configurations from dpll - return DPLL(init_, get_, verbose, + if(stop_on_first){ + return DPLL(init_, get_, verbose, n_threads.value_or(std::thread::hardware_concurrency()), walltime, sleep_time); + } else { + return DPLL(init_, get_, verbose, + n_threads.value_or(std::thread::hardware_concurrency()), walltime, + sleep_time); + } } #ifdef USE_MPI diff --git a/pysa_dpll/_app/_main.py b/pysa_dpll/_app/_main.py index 8918b94..760ebf9 100644 --- a/pysa_dpll/_app/_main.py +++ b/pysa_dpll/_app/_main.py @@ -45,6 +45,10 @@ def main( help= "Number of threads to use (by default, all available cores are used)." )] = None, + stop_on_first: Annotated[ + bool, + Option("--stop-on-first", + help="Stop search when the first solution is found.")] = False, verbose: Annotated[ bool, Option("--verbose", "-v", help="Verbose output.")] = False): # Update parameters diff --git a/pysa_dpll/_app/_sat.py b/pysa_dpll/_app/_sat.py index 5b8615a..9b81e8f 100644 --- a/pysa_dpll/_app/_sat.py +++ b/pysa_dpll/_app/_sat.py @@ -66,6 +66,7 @@ def sat(max_n_unsat: Annotated[ collected_, branches_ = optimize(cnf_, max_n_unsat=__params['max_n_unsat'], n_threads=__params['n_threads'], + stop_on_first=__params['stop_on_first'], walltime=__params['walltime'], verbose=__params['verbose']) diff --git a/src/dpll-sat.cpp b/src/dpll-sat.cpp index 15693bd..cab2736 100644 --- a/src/dpll-sat.cpp +++ b/src/dpll-sat.cpp @@ -52,7 +52,7 @@ int main(int argc, char* argv[]) { #endif // Print the required arguments - if (argc < 2 || argc > 5) { + if (argc < 2 || argc > 6) { std::cerr << "Usage: " << std::filesystem::path(argv[0]).filename().string() << " cnf_file [max_unsat = 0] [n_threads = 0] [verbose = 0]" << std::endl; @@ -64,6 +64,9 @@ int main(int argc, char* argv[]) { std::cerr << " n_threads Number of threads to use (default = 0, " "that is suggested by the implementation)" << std::endl; + std::cerr << " stop_on_first Stop search when the first solution is found. " + "(default = 0, find all possible solutions)" + << std::endl; std::cerr << " verbose Level of verbosity (default = 0)" << std::endl; return EXIT_FAILURE; @@ -78,13 +81,17 @@ int main(int argc, char* argv[]) { // Set default value for number of threads (0 = implementation specific) std::size_t n_threads = 0; + // Default + bool stop_on_first = false; // Set default value for verbosity std::size_t verbose = false; // Assign provided values switch (argc) { + case 6: + verbose = std::stoull(argv[5]); case 5: - verbose = std::stoull(argv[4]); + stop_on_first = std::stoull(argv[4]); case 4: n_threads = std::stoull(argv[3]); case 3: @@ -105,7 +112,7 @@ int main(int argc, char* argv[]) { pysa::dpll::sat::mpi::optimize(mpi_comm_world, formula, max_unsat, verbose, n_threads); #else - pysa::dpll::sat::optimize(formula, max_unsat, verbose, n_threads); + pysa::dpll::sat::optimize(formula, max_unsat, verbose, n_threads, stop_on_first); #endif #ifdef USE_MPI diff --git a/tests/test_dpll_sat.cpp b/tests/test_dpll_sat.cpp index 3ca8307..f780845 100644 --- a/tests/test_dpll_sat.cpp +++ b/tests/test_dpll_sat.cpp @@ -51,6 +51,8 @@ int main() { TestDPLLSAT(3, 21, 60, std::size_t{1} << 20, true); TestDPLLSAT(2, 21, 60, 10, true); TestDPLLSAT(3, 21, 60, 0, true); + TestDPLLSAT(3, 21, 60, 0, true, true); + TestDPLLSAT(3, 21, 60, 1, true, true); } #ifdef USE_MPI diff --git a/tests/test_dpll_sat.hpp b/tests/test_dpll_sat.hpp index f65fe83..320df6e 100644 --- a/tests/test_dpll_sat.hpp +++ b/tests/test_dpll_sat.hpp @@ -190,7 +190,7 @@ auto TestBranch() { } auto TestDPLLSAT(std::size_t k, std::size_t n, std::size_t m, - std::size_t max_n_unsat, bool verbose = false) { + std::size_t max_n_unsat, bool verbose = false, bool stop_on_first = false) { // Get random SAT problem const auto formula_ = sat::GetRandomInstance(k, n, m); @@ -200,10 +200,13 @@ auto TestDPLLSAT(std::size_t k, std::size_t n, std::size_t m, if (verbose) std::cerr << "Done!" << std::endl; // Get configurations from dpll - const auto [dpll_, branches_] = sat::optimize(formula_, max_n_unsat, verbose); + const auto [dpll_, branches_] = sat::optimize(formula_, max_n_unsat, verbose, {}, stop_on_first); + if(stop_on_first) + std::cerr << "Stopped with " << std::size(dpll_) << " solutions and " + << std::size(branches_) << " remaining branches." << std::endl; // Branches should be empty - if (std::size(branches_)) + if (!stop_on_first && std::size(branches_)) throw std::runtime_error("Some branches are not used"); // Check number of unsat @@ -212,7 +215,7 @@ auto TestDPLLSAT(std::size_t k, std::size_t n, std::size_t m, throw std::runtime_error("Wrong number of unsat clauses"); // Check size - if (std::size(dpll_) != std::size(all_)) + if (!stop_on_first && (std::size(dpll_) != std::size(all_))) throw std::runtime_error("Collected wrong number of states"); // Check if everything has been properly collected