diff --git a/src/sc2_datasets/transforms/mmr_vs_result.py b/src/sc2_datasets/transforms/mmr_vs_result.py index 1f52879..b594a25 100644 --- a/src/sc2_datasets/transforms/mmr_vs_result.py +++ b/src/sc2_datasets/transforms/mmr_vs_result.py @@ -29,7 +29,22 @@ def mmr_vs_result(sc2_replay: SC2ReplayData) -> Tuple[torch.Tensor, torch.Tensor dtype=torch.float, ) - result_dict = {"Loss": 0, "Win": 1, "Victory": 1, "Defeat": 0} + result_dict = { + "Loss": 0, + "Win": 1, + "Victory": 1, + "Defeat": 0, + "Undecided": -1, + "Draw": -1, + "Tie": -1, + } + + # Check if result is "Undecided", "Draw", or "Tie" and return None to skip this replay + skip_results = ["Undecided", "Draw", "Tie"] + if sc2_replay.toonPlayerDescMap[0].toon_player_info.result in skip_results: + return None, None + + # Map result to label tensor label_tensor = torch.tensor( result_dict[sc2_replay.toonPlayerDescMap[0].toon_player_info.result], dtype=torch.int8, diff --git a/src/sc2_datasets/transforms/pytorch/economy_vs_outcome.py b/src/sc2_datasets/transforms/pytorch/economy_vs_outcome.py index 647b62e..f234088 100644 --- a/src/sc2_datasets/transforms/pytorch/economy_vs_outcome.py +++ b/src/sc2_datasets/transforms/pytorch/economy_vs_outcome.py @@ -59,7 +59,24 @@ def economy_average_vs_outcome( # Creating feature tensor: feature_tensor = torch.tensor(feature_list, dtype=torch.float32) - result_dict = {"Loss": 0, "Win": 1, "Victory": 1, "Defeat": 0} + result_dict = { + "Loss": 0, + "Win": 1, + "Victory": 1, + "Defeat": 0, + "Undecided": -1, + "Draw": -1, + "Tie": -1, + } + + # Check if any player's result is "Undecided", "Draw", or "Tie" and return None to skip this replay + skip_results = ["Undecided", "Draw", "Tie"] + if any( + player_desc.toon_player_info.result in skip_results + for player_desc in sc2_replay.toonPlayerDescMap + ): + return None, None + target = result_dict[sc2_replay.toonPlayerDescMap[0].toon_player_info.result] return feature_tensor, target diff --git a/src/sc2_datasets/transforms/utils.py b/src/sc2_datasets/transforms/utils.py index c3cb062..e74313a 100644 --- a/src/sc2_datasets/transforms/utils.py +++ b/src/sc2_datasets/transforms/utils.py @@ -235,7 +235,22 @@ def select_outcome_1v1(sc2_replay: SC2ReplayData) -> Dict[str, int]: player_outcome = {"1": 0, "2": 0} - result_dict = {"Loss": 0, "Win": 1, "Victory": 1, "Defeat": 0} + result_dict = { + "Loss": 0, + "Win": 1, + "Victory": 1, + "Defeat": 0, + "Undecided": -1, + "Draw": -1, + "Tie": -1, + } + + # Check if any player has an "Undecided", "Draw", or "Tie" result and return None to indicate skipping + skip_results = ["Undecided", "Draw", "Tie"] + for toon_desc_map in sc2_replay.toonPlayerDescMap: + if toon_desc_map.toon_player_info.result in skip_results: + return None + for toon_desc_map in sc2_replay.toonPlayerDescMap: result = result_dict[toon_desc_map.toon_player_info.result] player_outcome[toon_desc_map.toon_player_info.playerID] = result