Skip to content
17 changes: 16 additions & 1 deletion src/sc2_datasets/transforms/mmr_vs_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,22 @@ def mmr_vs_result(sc2_replay: SC2ReplayData) -> Tuple[torch.Tensor, torch.Tensor
dtype=torch.float,
)

result_dict = {"Loss": 0, "Win": 1, "Victory": 1, "Defeat": 0}
result_dict = {
"Loss": 0,
"Win": 1,
"Victory": 1,
"Defeat": 0,
"Undecided": -1,
"Draw": -1,
"Tie": -1,
}

# Check if result is "Undecided", "Draw", or "Tie" and return None to skip this replay
skip_results = ["Undecided", "Draw", "Tie"]
if sc2_replay.toonPlayerDescMap[0].toon_player_info.result in skip_results:
return None, None

# Map result to label tensor
label_tensor = torch.tensor(
result_dict[sc2_replay.toonPlayerDescMap[0].toon_player_info.result],
dtype=torch.int8,
Expand Down
19 changes: 18 additions & 1 deletion src/sc2_datasets/transforms/pytorch/economy_vs_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,24 @@ def economy_average_vs_outcome(
# Creating feature tensor:
feature_tensor = torch.tensor(feature_list, dtype=torch.float32)

result_dict = {"Loss": 0, "Win": 1, "Victory": 1, "Defeat": 0}
result_dict = {
"Loss": 0,
"Win": 1,
"Victory": 1,
"Defeat": 0,
"Undecided": -1,
"Draw": -1,
"Tie": -1,
}

# Check if any player's result is "Undecided", "Draw", or "Tie" and return None to skip this replay
skip_results = ["Undecided", "Draw", "Tie"]
if any(
player_desc.toon_player_info.result in skip_results
for player_desc in sc2_replay.toonPlayerDescMap
):
return None, None

target = result_dict[sc2_replay.toonPlayerDescMap[0].toon_player_info.result]

return feature_tensor, target
17 changes: 16 additions & 1 deletion src/sc2_datasets/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,22 @@ def select_outcome_1v1(sc2_replay: SC2ReplayData) -> Dict[str, int]:

player_outcome = {"1": 0, "2": 0}

result_dict = {"Loss": 0, "Win": 1, "Victory": 1, "Defeat": 0}
result_dict = {
"Loss": 0,
"Win": 1,
"Victory": 1,
"Defeat": 0,
"Undecided": -1,
"Draw": -1,
"Tie": -1,
}

# Check if any player has an "Undecided", "Draw", or "Tie" result and return None to indicate skipping
skip_results = ["Undecided", "Draw", "Tie"]
for toon_desc_map in sc2_replay.toonPlayerDescMap:
if toon_desc_map.toon_player_info.result in skip_results:
return None

for toon_desc_map in sc2_replay.toonPlayerDescMap:
result = result_dict[toon_desc_map.toon_player_info.result]
player_outcome[toon_desc_map.toon_player_info.playerID] = result
Expand Down