diff --git a/CMakeLists.txt b/CMakeLists.txt index cc1a3f1c..0bca6cd9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -138,6 +138,7 @@ else() add_test(${name} ${name}) endfunction() + do_test(test_move) do_test(test_piece_count) do_test(test_see) do_test(test_see_prototype) diff --git a/src/common.hpp b/src/common.hpp index 2e444d90..92137f5a 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -2,6 +2,7 @@ #include "util/types.hpp" #include +#include namespace Clockwork { @@ -78,6 +79,32 @@ constexpr char piece_char(PieceType piece) { unreachable(); } +constexpr std::optional parse_piece_char(char ch) { + using enum PieceType; + switch (ch) { + case 'P': + case 'p': + return Pawn; + case 'N': + case 'n': + return Knight; + case 'B': + case 'b': + return Bishop; + case 'R': + case 'r': + return Rook; + case 'Q': + case 'q': + return Queen; + case 'K': + case 'k': + return King; + default: + return std::nullopt; + } +} + constexpr bool is_slider(PieceType ptype) { return ptype >= PieceType::Bishop && ptype <= PieceType::Queen; } diff --git a/src/move.cpp b/src/move.cpp index 27bbee1e..2d0ec65a 100644 --- a/src/move.cpp +++ b/src/move.cpp @@ -1,4 +1,5 @@ #include "move.hpp" +#include "movegen.hpp" #include "position.hpp" @@ -76,4 +77,128 @@ std::optional Move::parse(std::string_view str, const Position& ctx) { return Move(*from, *to, *mf); } +std::optional Move::parseSan(std::string_view san, const Position& ctx) { + Color stm = ctx.active_color(); + + if (san.size() < 2) { + return std::nullopt; + } + + if (san.ends_with('#') || san.ends_with('+')) { + san.remove_suffix(1); + } + + if (san == "O-O") { + Square rook_hside = ctx.rook_info(stm).hside; + if (!rook_hside.is_valid()) { + return std::nullopt; + } + return Move(ctx.king_sq(stm), rook_hside, MoveFlags::Castle); + } + + if (san == "O-O-O") { + Square rook_aside = ctx.rook_info(stm).aside; + if (!rook_aside.is_valid()) { + return std::nullopt; + } + return Move(ctx.king_sq(stm), rook_aside, MoveFlags::Castle); + } + + PieceType promo = PieceType::None; + if (san.size() >= 4 && san[san.size() - 2] == '=') { + if (auto p = parse_piece_char(san.back())) { + promo = *p; + } else { + return std::nullopt; + } + san.remove_suffix(2); + } + + if (san.size() < 2) { + return std::nullopt; + } + + bool is_capture = san.size() > 3 && san[san.size() - 3] == 'x'; + auto to = Square::parse(san.substr(san.size() - 2)); + if (!to) { + return std::nullopt; + } + san.remove_suffix(2 + is_capture); + + PieceType src_ptype = PieceType::None; + switch (san.size()) { + case 0: + // e.g. e4 + { + i32 delta = stm == Color::White ? -8 : 8; + Square push_src{static_cast(to->raw + delta)}; + Square double_src{static_cast(to->raw + delta * 2)}; + + if (to->relative_sq(stm).rank() < 7 && ctx.board()[push_src].ptype() == PieceType::Pawn + && ctx.board()[push_src].color() == stm) { + if (auto mf = build_move_flags(false, false, false, promo)) { + return Move(push_src, *to, *mf); + } + } else if (to->relative_sq(stm).rank() == 3 + && ctx.board()[double_src].ptype() == PieceType::Pawn + && ctx.board()[double_src].color() == stm) { + if (auto mf = build_move_flags(false, false, false, promo)) { + return Move(double_src, *to, *mf); + } + } + + return std::nullopt; + } + + case 1: + if (san[0] >= 'a' && san[0] <= 'h') { + // e.g. axb3 + if (!is_capture) { + return std::nullopt; + } + src_ptype = PieceType::Pawn; + } else if (auto p = parse_piece_char(san[0])) { + // e.g. Bb3, Bxb3 + src_ptype = *p; + san.remove_prefix(1); + } else { + return std::nullopt; + } + break; + case 2: // e.g. Qhxa3, Q3xb7, Qba4, Q6a3 + case 3: // e.g. Qa1b2, Qa1xb2 + if (auto p = parse_piece_char(san[0])) { + src_ptype = *p; + } else { + return std::nullopt; + } + san.remove_prefix(1); + break; + default: + return std::nullopt; + } + + bool is_en_passant = src_ptype == PieceType::Pawn && is_capture && *to == ctx.en_passant(); + + MoveGen movegen{ctx}; + std::vector candidates; + PieceMask piece_mask = ctx.attack_table(stm).read(*to) & ctx.piece_list(stm).mask_eq(src_ptype); + for (PieceId id : piece_mask) { + Square from = ctx.piece_list_sq(stm)[id]; + if (from.to_string().find(san) != std::string::npos) { + if (auto mf = build_move_flags(false, is_en_passant, is_capture, promo)) { + Move move = Move(from, *to, *mf); + if (movegen.is_legal(move)) { + candidates.push_back(move); + } + } + } + } + + if (candidates.size() != 1) { + return std::nullopt; + } + return candidates[0]; +} + } diff --git a/src/move.hpp b/src/move.hpp index 4f3d60e7..5177cff1 100644 --- a/src/move.hpp +++ b/src/move.hpp @@ -4,7 +4,6 @@ #include "square.hpp" #include "util/types.hpp" - #include #include @@ -31,6 +30,29 @@ enum class MoveFlags : u16 { PromoQueenCapture = (0b1100 | (static_cast(PieceType::Queen) - 2)) << 12, }; +inline std::optional +build_move_flags(bool castle, bool en_passant, bool capture, PieceType promo) { + using enum MoveFlags; + if (castle) { + return Castle; + } + if (en_passant) { + return EnPassant; + } + + u16 flags = 0; + if (capture) { + flags |= static_cast(CaptureBit); + } + if (promo != PieceType::None) { + if (promo < PieceType::Knight || promo > PieceType::Queen) { + return std::nullopt; + } + flags |= (static_cast(promo) - 2) << 12; + } + return static_cast(flags); +} + struct Move { u16 raw = 0; constexpr Move() = default; @@ -84,8 +106,14 @@ struct Move { return static_cast(((raw >> 12) & 0b0011) + 2); } + // Parse UCI move notation + // All legal moves will be parsed, but a successfully parsed move is not guaranteed to be legal. static std::optional parse(std::string_view str, const Position& context); + // Parse Standard Algebraic notation (SAN) moves + // All legal moves will be parsed, but a successfully parsed move is not guaranteed to be legal. + static std::optional parseSan(std::string_view san, const Position& context); + [[nodiscard]] constexpr bool operator==(const Move& other) const = default; [[nodiscard]] constexpr bool operator!=(const Move& other) const = default; diff --git a/src/position.cpp b/src/position.cpp index c6f8718d..6175d56c 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -967,6 +967,12 @@ std::ostream& operator<<(std::ostream& os, const Position& position) { return os; } +std::string Position::to_string() const { + std::stringstream ss; + ss << *this; + return ss.str(); +} + bool Position::is_reversible(Move move) { return !(move.is_capture() || move.is_promotion() || move.is_castle() || (m_board[move.from()].ptype() == PieceType::Pawn)); diff --git a/src/position.hpp b/src/position.hpp index f2e291a5..2f3456a6 100644 --- a/src/position.hpp +++ b/src/position.hpp @@ -282,6 +282,8 @@ struct Position { bool operator==(const Position&) const = default; friend std::ostream& operator<<(std::ostream& os, const Position& position); + std::string to_string() const; + private: std::array m_attack_table{}; std::array, 2> m_piece_list_sq{}; diff --git a/src/square.hpp b/src/square.hpp index 51a1cf70..983828df 100644 --- a/src/square.hpp +++ b/src/square.hpp @@ -67,6 +67,13 @@ struct Square { return Square{static_cast(raw ^ 56)}; } + constexpr std::string to_string() { + std::string result; + result += static_cast('a' + file()); + result += static_cast('1' + rank()); + return result; + } + friend std::ostream& operator<<(std::ostream& os, Square sq) { char file = static_cast('a' + sq.file()); return os << file << sq.rank() + 1; diff --git a/tests/test_move.cpp b/tests/test_move.cpp new file mode 100644 index 00000000..d291f0be --- /dev/null +++ b/tests/test_move.cpp @@ -0,0 +1,105 @@ +#include "move.hpp" +#include "movegen.hpp" +#include "position.hpp" +#include "test.hpp" +#include +#include +#include +#include + +using namespace Clockwork; + +void game1() { + std::vector game1_record{{ + "e4", "c5", // + "Nf3", "e6", // + "d4", "cxd4", // + "Nxd4", "Nc6", // + "Nb5", "d6", // + "c4", "Nf6", // + "N1c3", "a6", // + "Na3", "d5", // + "cxd5", "exd5", // + "exd5", "Nb4", // + "Be2", "Bc5", // + "O-O", "O-O", // + "Bf3", "Bf5", // + "Bg5", "Re8", // + "Qd2", "b5", // + "Rad1", "Nd3", // + "Nab1", "h6", // + "Bh4", "b4", // + "Na4", "Bd6", // + "Bg3", "Rc8", // + "b3", "g5", // + "Bxd6", "Qxd6", // + "g3", "Nd7", // + "Bg2", "Qf6", // + "a3", "a5", // + "axb4", "axb4", // + "Qa2", "Bg6", // + "d6", "g4", // + "Qd2", "Kg7", // + "f3", "Qxd6", // + "fxg4", "Qd4+", // + "Kh1", "Nf6", // + "Rf4", "Ne4", // + "Qxd3", "Nf2+", // + "Rxf2", "Bxd3", // + "Rfd2", "Qe3", // + "Rxd3", "Rc1", // + "Nb2", "Qf2", // + "Nd2", "Rxd1+", // + "Nxd1", "Re1+", // + }}; + + Position pos = + Position::parse("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1").value(); + for (std::string_view move_str : game1_record) { + std::cout << move_str << " : "; + + auto move = Move::parseSan(move_str, pos); + REQUIRE(move); + + std::cout << *move << " : "; + + MoveGen movegen{pos}; + REQUIRE(movegen.is_legal(*move)); + + pos = pos.move(*move); + + std::cout << pos << std::endl; + } + + REQUIRE(pos.to_string() == "8/5pk1/7p/8/1p4P1/1P1R2P1/3N1qBP/3Nr2K w - - 1 41"); +} + +void cases() { + std::vector> cases{{ + {"7r/3r1p1p/6p1/1p6/2B5/5PP1/1Q5P/1K1k4 b - - 0 38", "bxc4", + "7r/3r1p1p/6p1/8/2p5/5PP1/1Q5P/1K1k4 w - - 0 39"}, + {"2n1r1n1/1p1k1p2/6pp/R2pP3/3P4/8/5PPP/2R3K1 b - - 0 30", "Nge7", + "2n1r3/1p1knp2/6pp/R2pP3/3P4/8/5PPP/2R3K1 w - - 1 31"}, + {"8/5p2/1kn1r1n1/1p1pP3/6K1/8/4R3/5R2 b - - 9 60", "Ngxe5+", + "8/5p2/1kn1r3/1p1pn3/6K1/8/4R3/5R2 w - - 0 61"}, + {"r3k2r/pp1bnpbp/1q3np1/3p4/3N1P2/1PP1Q2P/P1B3P1/RNB1K2R b KQkq - 5 15", "Ng8", + "r3k1nr/pp1bnpbp/1q4p1/3p4/3N1P2/1PP1Q2P/P1B3P1/RNB1K2R w KQkq - 6 16"}, + }}; + for (auto [before, san, after] : cases) { + std::cout << before << " + " << san << ":"; + auto pos1 = Position::parse(before); + REQUIRE(pos1); + auto move = Move::parseSan(san, *pos1); + REQUIRE(move); + std::cout << *move << " = "; + Position pos2 = pos1->move(*move); + std::cout << pos2 << std::endl; + REQUIRE(pos2.to_string() == after); + } +} + +int main() { + game1(); + cases(); + return 0; +}