diff --git a/src/learn/learner.cpp b/src/learn/learner.cpp index 1e51eeb5..a9e742a0 100644 --- a/src/learn/learner.cpp +++ b/src/learn/learner.cpp @@ -2537,8 +2537,26 @@ void shuffle_files_on_memory(const vector& filenames,const string output std::cout << "..shuffle_on_memory done." << std::endl; } -void convert_bin(const vector& filenames, const string& output_file_name, const int ply_minimum, const int ply_maximum, const int interpolate_eval, const bool check_illegal_move) +bool fen_is_ok(Position& pos, std::string input_fen) { + std::string pos_fen = pos.fen(); + std::istringstream ss_input(input_fen); + std::istringstream ss_pos(pos_fen); + + // example : "2r4r/4kpp1/nb1np3/p2p3p/B2P1BP1/PP6/4NPKP/2R1R3 w - h6 0 24" + // --> "2r4r/4kpp1/nb1np3/p2p3p/B2P1BP1/PP6/4NPKP/2R1R3" + std::string str_input, str_pos; + ss_input >> str_input; + ss_pos >> str_pos; + + // Only compare "Piece placement field" between input_fen and pos.fen(). + return str_input == str_pos; +} + +void convert_bin(const vector& filenames, const string& output_file_name, const int ply_minimum, const int ply_maximum, const int interpolate_eval, const bool check_invalid_fen, const bool check_illegal_move) { + std::cout << "check_invalid_fen=" << check_invalid_fen << std::endl; + std::cout << "check_illegal_move=" << check_illegal_move << std::endl; + std::fstream fs; uint64_t data_size=0; uint64_t filtered_size = 0; @@ -2571,16 +2589,16 @@ void convert_bin(const vector& filenames, const string& output_file_name std::string value; ss >> token; if (token == "fen") { - states = StateListPtr(new std::deque(1)); // Drop old and create a new one + states = StateListPtr(new std::deque(1)); // Drop old and create a new one std::string input_fen = line.substr(4); tpos.set(input_fen, false, &states->back(), Threads.main()); - if (!tpos.pos_is_ok() || tpos.fen() != input_fen) { + if (check_invalid_fen && !fen_is_ok(tpos, input_fen)) { ignore_flag_fen = true; filtered_size_fen++; } else { - tpos.sfen_pack(p.sfen); - } + tpos.sfen_pack(p.sfen); + } } else if (token == "move") { ss >> value; @@ -2607,7 +2625,7 @@ void convert_bin(const vector& filenames, const string& output_file_name } p.gamePly = uint16_t(temp); // No cast here? if (interpolate_eval != 0){ - p.score = min(3000, interpolate_eval * temp); + p.score = min(3000, interpolate_eval * temp); } } else if (token == "result") { @@ -2615,17 +2633,17 @@ void convert_bin(const vector& filenames, const string& output_file_name ss >> temp; p.game_result = int8_t(temp); // Do you need a cast here? if (interpolate_eval){ - p.score = p.score * p.game_result; + p.score = p.score * p.game_result; } } else if (token == "e") { - if(!(ignore_flag_fen || ignore_flag_move || ignore_flag_ply)){ - fs.write((char*)&p, sizeof(PackedSfenValue)); - data_size+=1; - // debug - // std::cout<& filenames, const string& output_file_name } } std::cout << "done " << data_size << " parsed " << filtered_size << " is filtered" - << " (illegal fen:" << filtered_size_fen << ", illegal move:" << filtered_size_move << ", illegal ply:" << filtered_size_ply << ")" << std::endl; + << " (invalid fen:" << filtered_size_fen << ", illegal move:" << filtered_size_move << ", invalid ply:" << filtered_size_ply << ")" << std::endl; ifs.close(); } std::cout << "all done" << std::endl; @@ -2983,6 +3001,7 @@ void learn(Position&, istringstream& is) int ply_minimum = 0; int ply_maximum = 114514; bool interpolate_eval = 0; + bool check_invalid_fen = false; bool check_illegal_move = false; // convert teacher in pgn-extract format to Yaneura King's bin bool use_convert_bin_from_pgn_extract = false; @@ -3123,6 +3142,7 @@ void learn(Position&, istringstream& is) else if (option == "convert_plain") use_convert_plain = true; else if (option == "convert_bin") use_convert_bin = true; else if (option == "interpolate_eval") is >> interpolate_eval; + else if (option == "check_invalid_fen") is >> check_invalid_fen; else if (option == "check_illegal_move") is >> check_illegal_move; else if (option == "convert_bin_from_pgn-extract") use_convert_bin_from_pgn_extract = true; else if (option == "pgn_eval_side_to_move") is >> pgn_eval_side_to_move; @@ -3235,7 +3255,7 @@ void learn(Position&, istringstream& is) { Eval::init_NNUE(); cout << "convert_bin.." << endl; - convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval, check_illegal_move); + convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval, check_invalid_fen, check_illegal_move); return; }