diff --git a/src/extra/nnue_data_binpack_format.h b/src/extra/nnue_data_binpack_format.h index b9e45c3e..ceb5c415 100644 --- a/src/extra/nnue_data_binpack_format.h +++ b/src/extra/nnue_data_binpack_format.h @@ -3132,7 +3132,11 @@ namespace chess }; } - static const EnumArray2 pseudoAttacks = generatePseudoAttacks(); + static const EnumArray2& pseudoAttacks() + { + static const EnumArray2 s_pseudoAttacks = generatePseudoAttacks(); + return s_pseudoAttacks; + } [[nodiscard]] static Bitboard generatePositiveRayAttacks(Direction dir, Square fromSq) { @@ -3187,24 +3191,29 @@ namespace chess return bbs; } - static const std::array, 8> positiveRayAttacks = generatePositiveRayAttacks(); + + static const std::array, 8>& positiveRayAttacks() + { + static const std::array, 8> s_positiveRayAttacks = generatePositiveRayAttacks(); + return s_positiveRayAttacks; + } template [[nodiscard]] static Bitboard slidingAttacks(Square sq, Bitboard occupied) { assert(sq.isOk()); - Bitboard attacks = positiveRayAttacks[DirV][sq]; + Bitboard attacks = positiveRayAttacks()[DirV][sq]; if constexpr (DirV == NorthWest || DirV == North || DirV == NorthEast || DirV == East) { Bitboard blocker = (attacks & occupied) | h8; // set highest bit (H8) so msb never fails - return attacks ^ positiveRayAttacks[DirV][blocker.first()]; + return attacks ^ positiveRayAttacks()[DirV][blocker.first()]; } else { Bitboard blocker = (attacks & occupied) | a1; - return attacks ^ positiveRayAttacks[DirV][blocker.last()]; + return attacks ^ positiveRayAttacks()[DirV][blocker.last()]; } } @@ -3290,10 +3299,10 @@ namespace chess { for (PieceType pt : { PieceType::Bishop, PieceType::Rook }) { - const Bitboard s1Attacks = pseudoAttacks[pt][s1]; + const Bitboard s1Attacks = pseudoAttacks()[pt][s1]; if (s1Attacks.isSet(s2)) { - const Bitboard s2Attacks = pseudoAttacks[pt][s2]; + const Bitboard s2Attacks = pseudoAttacks()[pt][s2]; return (s1Attacks & s2Attacks) | s1 | s2; } } @@ -3420,14 +3429,14 @@ namespace chess assert(sq.isOk()); - return detail::pseudoAttacks[PieceTypeV][sq]; + return detail::pseudoAttacks()[PieceTypeV][sq]; } [[nodiscard]] inline Bitboard pseudoAttacks(PieceType pt, Square sq) { assert(sq.isOk()); - return detail::pseudoAttacks[pt][sq]; + return detail::pseudoAttacks()[pt][sq]; } [[nodiscard]] inline Bitboard pawnAttacks(Bitboard pawns, Color color) @@ -4373,6 +4382,22 @@ namespace chess std::uint64_t low; }; + struct Position; + + struct MoveLegalityChecker + { + MoveLegalityChecker(const Position& position); + + [[nodiscard]] bool isPseudoLegalMoveLegal(const Move& move) const; + + private: + const Position* m_position; + Bitboard m_checkers; + Bitboard m_ourBlockersForKing; + Bitboard m_potentialCheckRemovals; + Square m_ksq; + }; + struct Position : public Board { using BaseType = Board; @@ -4412,6 +4437,11 @@ namespace chess [[nodiscard]] inline std::string fen() const; + [[nodiscard]] MoveLegalityChecker moveLegalityChecker() const + { + return { *this }; + } + constexpr void setEpSquareUnchecked(Square sq) { m_epSquare = sq; @@ -4498,6 +4528,8 @@ namespace chess [[nodiscard]] inline bool isCheckAfterMove(Move move) const; + [[nodiscard]] inline bool isMoveLegal(Move move) const; + [[nodiscard]] inline bool isPseudoLegalMoveLegal(Move move) const; [[nodiscard]] inline bool isMovePseudoLegal(Move move) const; @@ -4665,6 +4697,592 @@ namespace chess std::uint8_t m_packedState[16]; }; + namespace movegen + { + // For a pseudo-legal move the following are true: + // - the moving piece has the pos.sideToMove() color + // - the destination square is either empty or has a piece of the opposite color + // - if it is a pawn move it is valid (but may be illegal due to discovered checks) + // - if it is not a pawn move then the destination square is contained in attacks() + // - if it is a castling it is legal + // - a move other than castling may create a discovered attack on the king + // - a king may walk into a check + + template + inline void forEachPseudoLegalPawnMove(const Position& pos, Square from, FuncT&& f) + { + const Color sideToMove = pos.sideToMove(); + const Square epSquare = pos.epSquare(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + + Bitboard attackTargets = theirPieces; + if (epSquare != Square::none()) + { + attackTargets |= epSquare; + } + + const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), sideToMove) & attackTargets; + + const Rank secondToLastRank = sideToMove == Color::White ? rank7 : rank2; + const auto forward = sideToMove == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); + + // promotions + if (from.rank() == secondToLastRank) + { + // capture promotions + for (Square toSq : attacks) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, toSq, MoveType::Promotion, Piece(pt, sideToMove) }; + f(move); + } + } + + // push promotions + const Square toSq = from + forward; + if (!occupied.isSet(toSq)) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, toSq, MoveType::Promotion, Piece(pt, sideToMove) }; + f(move); + } + } + } + else + { + // captures + for (Square toSq : attacks) + { + Move move{ from, toSq, (toSq == epSquare) ? MoveType::EnPassant : MoveType::Normal }; + f(move); + } + + const Square toSq = from + forward; + + // single push + if (!occupied.isSet(toSq)) + { + const Rank startRank = sideToMove == Color::White ? rank2 : rank7; + if (from.rank() == startRank) + { + // double push + const Square toSq2 = toSq + forward; + if (!occupied.isSet(toSq2)) + { + Move move{ from, toSq2 }; + f(move); + } + } + + Move move{ from, toSq }; + f(move); + } + } + } + + template + inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) + { + const Square epSquare = pos.epSquare(); + const Bitboard ourPieces = pos.piecesBB(SideToMoveV); + const Bitboard theirPieces = pos.piecesBB(!SideToMoveV); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard pawns = pos.piecesBB(Piece(PieceType::Pawn, SideToMoveV)); + + const Bitboard secondToLastRank = SideToMoveV == Color::White ? bb::rank7 : bb::rank2; + const Bitboard secondRank = SideToMoveV == Color::White ? bb::rank2 : bb::rank7; + + const auto singlePawnMoveDestinationOffset = SideToMoveV == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); + const auto doublePawnMoveDestinationOffset = SideToMoveV == Color::White ? FlatSquareOffset(0, 2) : FlatSquareOffset(0, -2); + + { + const int backward = SideToMoveV == Color::White ? -1 : 1; + const int backward2 = backward * 2; + + const Bitboard doublePawnMoveStarts = + pawns + & secondRank + & ~(occupied.shiftedVertically(backward) | occupied.shiftedVertically(backward2)); + + const Bitboard singlePawnMoveStarts = + pawns + & ~secondToLastRank + & ~occupied.shiftedVertically(backward); + + for (Square from : doublePawnMoveStarts) + { + const Square to = from + doublePawnMoveDestinationOffset; + f(Move::normal(from, to)); + } + + for (Square from : singlePawnMoveStarts) + { + const Square to = from + singlePawnMoveDestinationOffset; + f(Move::normal(from, to)); + } + } + + { + const Bitboard lastRank = SideToMoveV == Color::White ? bb::rank8 : bb::rank1; + const FlatSquareOffset westCaptureOffset = SideToMoveV == Color::White ? FlatSquareOffset(-1, 1) : FlatSquareOffset(-1, -1); + const FlatSquareOffset eastCaptureOffset = SideToMoveV == Color::White ? FlatSquareOffset(1, 1) : FlatSquareOffset(1, -1); + + const Bitboard pawnsWithWestCapture = bb::eastPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; + const Bitboard pawnsWithEastCapture = bb::westPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; + + for (Square from : pawnsWithWestCapture) + { + f(Move::normal(from, from + westCaptureOffset)); + } + + for (Square from : pawnsWithEastCapture) + { + f(Move::normal(from, from + eastCaptureOffset)); + } + } + + if (epSquare != Square::none()) + { + const Bitboard pawnsThatCanCapture = bb::pawnAttacks(Bitboard::square(epSquare), !SideToMoveV) & pawns; + for (Square from : pawnsThatCanCapture) + { + f(Move::enPassant(from, epSquare)); + } + } + + for (Square from : pawns & secondToLastRank) + { + const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), SideToMoveV) & theirPieces; + + // capture promotions + for (Square to : attacks) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, to, MoveType::Promotion, Piece(pt, SideToMoveV) }; + f(move); + } + } + + // push promotions + const Square to = from + singlePawnMoveDestinationOffset; + if (!occupied.isSet(to)) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, to, MoveType::Promotion, Piece(pt, SideToMoveV) }; + f(move); + } + } + } + } + + template + inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) + { + if (pos.sideToMove() == Color::White) + { + forEachPseudoLegalPawnMove(pos, std::forward(f)); + } + else + { + forEachPseudoLegalPawnMove(pos, std::forward(f)); + } + } + + template + inline void forEachPseudoLegalPieceMove(const Position& pos, Square from, FuncT&& f) + { + static_assert(PieceTypeV != PieceType::None); + + if constexpr (PieceTypeV == PieceType::Pawn) + { + forEachPseudoLegalPawnMove(pos, from, f); + } + else + { + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard attacks = bb::attacks(from, occupied) & ~ourPieces; + + for (Square toSq : attacks) + { + Move move{ from, toSq }; + f(move); + } + } + } + + template + inline void forEachPseudoLegalPieceMove(const Position& pos, FuncT&& f) + { + static_assert(PieceTypeV != PieceType::None); + + if constexpr (PieceTypeV == PieceType::Pawn) + { + forEachPseudoLegalPawnMove(pos, f); + } + else + { + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard pieces = pos.piecesBB(Piece(PieceTypeV, sideToMove)); + for (Square fromSq : pieces) + { + const Bitboard attacks = bb::attacks(fromSq, occupied) & ~ourPieces; + for (Square toSq : attacks) + { + Move move{ fromSq, toSq }; + f(move); + } + } + } + } + + template + inline void forEachCastlingMove(const Position& pos, FuncT&& f) + { + CastlingRights rights = pos.castlingRights(); + if (rights == CastlingRights::None) + { + return; + } + + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + + // we first reduce the set of legal castlings by checking the paths for pieces + if (sideToMove == Color::White) + { + if ((CastlingTraits::castlingPath[Color::White][CastleType::Short] & occupied).any()) rights &= ~CastlingRights::WhiteKingSide; + if ((CastlingTraits::castlingPath[Color::White][CastleType::Long] & occupied).any()) rights &= ~CastlingRights::WhiteQueenSide; + rights &= ~CastlingRights::Black; + } + else + { + if ((CastlingTraits::castlingPath[Color::Black][CastleType::Short] & occupied).any()) rights &= ~CastlingRights::BlackKingSide; + if ((CastlingTraits::castlingPath[Color::Black][CastleType::Long] & occupied).any()) rights &= ~CastlingRights::BlackQueenSide; + rights &= ~CastlingRights::White; + } + + if (rights == CastlingRights::None) + { + return; + } + + // King must not be in check. Done here because it is quite expensive. + const Square ksq = pos.kingSquare(sideToMove); + if (pos.isSquareAttacked(ksq, !sideToMove)) + { + return; + } + + // Loop through all possible castlings. + for (CastleType castlingType : values()) + { + const CastlingRights right = CastlingTraits::castlingRights[sideToMove][castlingType]; + + if (!contains(rights, right)) + { + continue; + } + + // If we have this castling right + // we check whether the king passes an attacked square. + const Square passedSquare = CastlingTraits::squarePassedByKing[sideToMove][castlingType]; + if (pos.isSquareAttacked(passedSquare, !sideToMove)) + { + continue; + } + + // If it's a castling move then the change in square occupation + // cannot have an effect because otherwise there would be + // a slider attacker attacking the castling king. + if (pos.isSquareAttacked(CastlingTraits::kingDestination[sideToMove][castlingType], !sideToMove)) + { + continue; + } + + // If not we can castle. + Move move = Move::castle(castlingType, sideToMove); + f(move); + } + } + + // Calls a given function for all pseudo legal moves for the position. + // `pos` must be a legal chess position + template + inline void forEachPseudoLegalMove(const Position& pos, FuncT&& func) + { + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachCastlingMove(pos, func); + } + + // Calls a given function for all legal moves for the position. + // `pos` must be a legal chess position + template + inline void forEachLegalMove(const Position& pos, FuncT&& func) + { + auto funcIfLegal = [&func, checker = pos.moveLegalityChecker()](Move move) { + if (checker.isPseudoLegalMoveLegal(move)) + { + func(move); + } + }; + + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachCastlingMove(pos, func); + } + + // Generates all pseudo legal moves for the position. + // `pos` must be a legal chess position + [[nodiscard]] std::vector generatePseudoLegalMoves(const Position& pos); + + // Generates all legal moves for the position. + // `pos` must be a legal chess position + [[nodiscard]] std::vector generateLegalMoves(const Position& pos); + } + + [[nodiscard]] inline bool Position::isCheck() const + { + return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove); + } + + [[nodiscard]] inline Bitboard Position::checkers() const + { + return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove); + } + + [[nodiscard]] inline bool Position::isCheckAfterMove(Move move) const + { + return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove); + } + + [[nodiscard]] inline bool Position::isMoveLegal(Move move) const + { + return + isMovePseudoLegal(move) + && isPseudoLegalMoveLegal(move); + } + + [[nodiscard]] inline bool Position::isPseudoLegalMoveLegal(Move move) const + { + return + (move.type == MoveType::Castle) + || !isOwnKingAttackedAfterMove(move); + } + + [[nodiscard]] inline bool Position::isMovePseudoLegal(Move move) const + { + if (!move.from.isOk() || !move.to.isOk()) + { + return false; + } + + if (move.from == move.to) + { + return false; + } + + if (move.type != MoveType::Promotion && move.promotedPiece != Piece::none()) + { + return false; + } + + const Piece movedPiece = pieceAt(move.from); + if (movedPiece == Piece::none()) + { + return false; + } + + if (movedPiece.color() != m_sideToMove) + { + return false; + } + + const Bitboard occupied = piecesBB(); + const Bitboard ourPieces = piecesBB(m_sideToMove); + const bool isNormal = move.type == MoveType::Normal; + + switch (movedPiece.type()) + { + case PieceType::Pawn: + { + bool isValid = false; + // TODO: use iterators so we don't loop over all moves + // when we can avoid it. + movegen::forEachPseudoLegalPawnMove(*this, move.from, [&isValid, &move](const Move& genMove) { + if (move == genMove) + { + isValid = true; + } + }); + return isValid; + } + + case PieceType::Bishop: + return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::Knight: + return isNormal && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); + + case PieceType::Rook: + return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::Queen: + return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::King: + { + if (move.type == MoveType::Castle) + { + bool isValid = false; + movegen::forEachCastlingMove(*this, [&isValid, &move](const Move& genMove) { + if (move == genMove) + { + isValid = true; + } + }); + return isValid; + } + else + { + return isNormal && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); + } + } + + default: + return false; + } + } + + [[nodiscard]] inline Bitboard Position::blockersForKing(Color color) const + { + const Color attackerColor = !color; + + const Bitboard occupied = piecesBB(); + + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + + const Square ksq = kingSquare(color); + + const Bitboard opponentBishopLikePieces = (bishops | queens); + const Bitboard bishopPseudoAttacks = bb::pseudoAttacks(ksq); + + const Bitboard opponentRookLikePieces = (rooks | queens); + const Bitboard rookPseudoAttacks = bb::pseudoAttacks(ksq); + + const Bitboard xrayers = + (bishopPseudoAttacks & opponentBishopLikePieces) + | (rookPseudoAttacks & opponentRookLikePieces); + + Bitboard allBlockers = Bitboard::none(); + + for (Square xrayer : xrayers) + { + const Bitboard blockers = bb::between(xrayer, ksq) & occupied; + if (blockers.exactlyOne()) + { + allBlockers |= blockers; + } + } + + return allBlockers; + } + + inline MoveLegalityChecker::MoveLegalityChecker(const Position& position) : + m_position(&position), + m_checkers(position.checkers()), + m_ourBlockersForKing( + position.blockersForKing(position.sideToMove()) + & position.piecesBB(position.sideToMove()) + ), + m_ksq(position.kingSquare(position.sideToMove())) + { + if (m_checkers.exactlyOne()) + { + const Bitboard knightCheckers = m_checkers & bb::pseudoAttacks(m_ksq); + if (knightCheckers.any()) + { + // We're checked by a knight, we have to remove it or move the king. + m_potentialCheckRemovals = knightCheckers; + } + else + { + // If we're not checked by a knight we can block it. + m_potentialCheckRemovals = bb::between(m_ksq, m_checkers.first()) | m_checkers; + } + } + else + { + // Double check, king has to move. + m_potentialCheckRemovals = Bitboard::none(); + } + } + + [[nodiscard]] inline bool MoveLegalityChecker::isPseudoLegalMoveLegal(const Move& move) const + { + if (m_checkers.any()) + { + if (move.from == m_ksq || move.type == MoveType::EnPassant) + { + return m_position->isPseudoLegalMoveLegal(move); + } + else + { + // This means there's only one check and we either + // blocked it or removed the piece that attacked + // our king. So the only threat is if it's a discovered check. + return + m_potentialCheckRemovals.isSet(move.to) + && !m_ourBlockersForKing.isSet(move.from); + } + } + else + { + if (move.from == m_ksq) + { + return m_position->isPseudoLegalMoveLegal(move); + } + else if (move.type == MoveType::EnPassant) + { + return !m_position->createsDiscoveredAttackOnOwnKing(move); + } + else if (m_ourBlockersForKing.isSet(move.from)) + { + // If it was a blocker it may have only moved in line with our king. + // Otherwise it's a discovered check. + return bb::line(m_ksq, move.from).isSet(move.to); + } + else + { + return true; + } + } + } + static_assert(sizeof(CompressedPosition) == 24); static_assert(std::is_trivially_copyable_v); @@ -5483,57 +6101,6 @@ namespace chess return { move, captured, oldEpSquare, oldCastlingRights }; } - [[nodiscard]] inline bool Position::isCheck() const - { - return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove); - } - - [[nodiscard]] inline Bitboard Position::checkers() const - { - return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove); - } - - [[nodiscard]] bool Position::isCheckAfterMove(Move move) const - { - return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove); - } - - [[nodiscard]] inline Bitboard Position::blockersForKing(Color color) const - { - const Color attackerColor = !color; - - const Bitboard occupied = piecesBB(); - - const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); - const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); - const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); - - const Square ksq = kingSquare(color); - - const Bitboard opponentBishopLikePieces = (bishops | queens); - const Bitboard bishopPseudoAttacks = bb::pseudoAttacks(ksq); - - const Bitboard opponentRookLikePieces = (rooks | queens); - const Bitboard rookPseudoAttacks = bb::pseudoAttacks(ksq); - - const Bitboard xrayers = - (bishopPseudoAttacks & opponentBishopLikePieces) - | (rookPseudoAttacks & opponentRookLikePieces); - - Bitboard allBlockers = Bitboard::none(); - - for (Square xrayer : xrayers) - { - const Bitboard blockers = bb::between(xrayer, ksq) & occupied; - if (blockers.exactlyOne()) - { - allBlockers |= blockers; - } - } - - return allBlockers; - } - [[nodiscard]] inline Position Position::afterMove(Move move) const { Position cpy(*this); @@ -5756,6 +6323,25 @@ namespace binpack return chess::Move{from, to, type}; } + [[nodiscard]] std::string toString() const + { + const chess::Square to = static_cast((m_raw & (0b111111 << 0) >> 0)); + const chess::Square from = static_cast((m_raw & (0b111111 << 6)) >> 6); + + const unsigned promotionIndex = (m_raw & (0b11 << 12)) >> 12; + const chess::PieceType promotionType = static_cast(static_cast(chess::PieceType::Knight) + promotionIndex); + + std::string r; + chess::parser_bits::appendSquareToString(from, r); + chess::parser_bits::appendSquareToString(to, r); + if (promotionType != chess::PieceType::None) + { + r += chess::EnumTraits::toChar(promotionType, chess::Color::Black); + } + + return r; + } + private: std::uint16_t m_raw; }; @@ -6233,6 +6819,11 @@ namespace binpack std::int16_t score; std::uint16_t ply; std::int16_t result; + + [[nodiscard]] bool isValid() const + { + return pos.isMoveLegal(move); + } }; [[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv) @@ -6921,7 +7512,7 @@ namespace binpack buffer.insert(buffer.end(), data, data+sizeof(psv)); } - inline void convertPlainToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om) + inline void convertPlainToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) { constexpr std::size_t reportEveryNPositions = 100'000; @@ -6949,6 +7540,11 @@ namespace binpack if (key == "e"sv) { e.move = chess::uci::uciToMove(e.pos, move); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } writer.addTrainingDataEntry(e); @@ -6975,7 +7571,7 @@ namespace binpack std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } - inline void convertBinpackToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om) + inline void convertBinpackToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) { constexpr std::size_t bufferSize = MiB; @@ -6990,7 +7586,14 @@ namespace binpack while(reader.hasNext()) { - emitPlainEntry(buffer, reader.next()); + auto e = reader.next(); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + emitPlainEntry(buffer, e); ++numProcessedPositions; @@ -7016,7 +7619,7 @@ namespace binpack } - inline void convertBinToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om) + inline void convertBinToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) { constexpr std::size_t reportEveryNPositions = 100'000; @@ -7037,7 +7640,15 @@ namespace binpack break; } - writer.addTrainingDataEntry(packedSfenValueToTrainingDataEntry(psv)); + auto e = packedSfenValueToTrainingDataEntry(psv); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + std::cerr << static_cast(e.move.type) << '\n'; + return; + } + + writer.addTrainingDataEntry(e); ++numProcessedPositions; const auto cur = inputFile.tellg(); @@ -7050,7 +7661,7 @@ namespace binpack std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } - inline void convertBinpackToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om) + inline void convertBinpackToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) { constexpr std::size_t bufferSize = MiB; @@ -7065,7 +7676,14 @@ namespace binpack while(reader.hasNext()) { - emitBinEntry(buffer, reader.next()); + auto e = reader.next(); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + emitBinEntry(buffer, e); ++numProcessedPositions; @@ -7090,7 +7708,7 @@ namespace binpack std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } - inline void convertBinToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om) + inline void convertBinToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) { constexpr std::size_t bufferSize = MiB; @@ -7113,7 +7731,14 @@ namespace binpack break; } - emitPlainEntry(buffer, packedSfenValueToTrainingDataEntry(psv)); + auto e = packedSfenValueToTrainingDataEntry(psv); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + emitPlainEntry(buffer, e); ++numProcessedPositions; @@ -7138,7 +7763,7 @@ namespace binpack std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } - inline void convertPlainToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om) + inline void convertPlainToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) { constexpr std::size_t bufferSize = MiB; @@ -7169,6 +7794,11 @@ namespace binpack if (key == "e"sv) { e.move = chess::uci::uciToMove(e.pos, move); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } emitBinEntry(buffer, e); diff --git a/src/learn/convert.cpp b/src/learn/convert.cpp index dfd30509..5fe7ea1d 100644 --- a/src/learn/convert.cpp +++ b/src/learn/convert.cpp @@ -525,7 +525,7 @@ namespace Learner && ends_with(output_path, expected_output_extension); } - using ConvertFunctionType = void(std::string inputPath, std::string outputPath, std::ios_base::openmode om); + using ConvertFunctionType = void(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate); static ConvertFunctionType* get_convert_function(const std::string& input_path, const std::string& output_path) { @@ -547,7 +547,7 @@ namespace Learner return nullptr; } - static void convert(const std::string& input_path, const std::string& output_path, std::ios_base::openmode om) + static void convert(const std::string& input_path, const std::string& output_path, std::ios_base::openmode om, bool validate) { if(!file_exists(input_path)) { @@ -558,7 +558,7 @@ namespace Learner auto func = get_convert_function(input_path, output_path); if (func != nullptr) { - func(input_path, output_path, om); + func(input_path, output_path, om, validate); } else { @@ -568,20 +568,22 @@ namespace Learner static void convert(const std::vector& args) { - if (args.size() < 2 || args.size() > 3) + if (args.size() < 2 || args.size() > 4) { std::cerr << "Invalid arguments.\n"; - std::cerr << "Usage: convert from_path to_path [append]\n"; + std::cerr << "Usage: convert from_path to_path [append] [validate]\n"; return; } - const bool append = (args.size() == 3) && (args[2] == "append"); + const bool append = std::find(args.begin() + 2, args.end(), "append") != args.end(); + const bool validate = std::find(args.begin() + 2, args.end(), "validate") != args.end(); + const std::ios_base::openmode openmode = append ? std::ios_base::app : std::ios_base::trunc; - convert(args[0], args[1], openmode); + convert(args[0], args[1], openmode, validate); } void convert(istringstream& is)