Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ venv*
.venv

test_output/*

.ruff_cache/
172 changes: 89 additions & 83 deletions poetry.lock

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ version_files = ["pyproject.toml:version"]
[tool.poetry.dependencies]
python = "^3.11"
pytorch-lightning = "^2.5.0"
pandas = "^2.1.4"
tqdm = "^4.66.1"
pandas = "^2.2.3"
tqdm = "^4.67.1"
requests = "^2.32.3"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"
pytest-cov = "^6.0.0"
myst-nb = "^1.0.0"
sphinx-autoapi = "^3.0.0"
pytest-cov = "^6.1.1"
myst-nb = "^1.2.0"
sphinx-autoapi = "^3.6.0"
furo = "^2024.8.6"
pre-commit = "^4.1.0"
ruff = "^0.9.10"
commitizen = "^4.4.1"
pre-commit = "^4.2.0"
ruff = "^0.11.7"
commitizen = "^4.6.0"

[tool.ruff]
target-version = "py311"
Expand Down
129 changes: 42 additions & 87 deletions src/sc2_datasets/replay_data/sc2_replay_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict

from sc2_datasets.replay_parser.details.details import Details
from sc2_datasets.replay_parser.game_events.game_events_parser import GameEventsParser
Expand All @@ -19,16 +19,20 @@
)


@dataclass
class SC2ReplayData:
"""
Specifies a data type that holds information parsed from json representation of a replay.

Parameters
----------
loaded_replay_object : Any
Specifies a parsed Python deserialized json object\
loaded into memory
"""
filepath: Path
header: Header
initData: InitData
details: Details
metadata: Metadata
messageEvents: list = field(default_factory=list)
gameEvents: list = field(default_factory=list)
trackerEvents: list = field(default_factory=list)
toonPlayerDescMap: list = field(default_factory=list)
gameEventsErr: bool = False
messageEventsErr: bool = False
trackerEventsErr: bool = False

@staticmethod
def from_file(replay_filepath: str) -> "SC2ReplayData":
Expand Down Expand Up @@ -61,47 +65,35 @@ def from_file(replay_filepath: str) -> "SC2ReplayData":
logging.info(f"Attempting to parse: {str(replay_path)}")
with replay_path.open(mode="r", encoding="utf-8") as replay_file:
loaded_data = json.load(replay_file)
return SC2ReplayData(filepath=replay_path, loaded_replay_object=loaded_data)

def __init__(self, filepath: Path, loaded_replay_object: Any) -> None:
# Replay data must contain the path to the json it comes from
# to allow for debugging:
self._filepath = filepath

self._header = Header.from_dict(d=loaded_replay_object["header"])
self._initData = InitData.from_dict(d=loaded_replay_object["initData"])
self._details = Details.from_dict(d=loaded_replay_object["details"])
self._metadata = Metadata.from_dict(d=loaded_replay_object["metadata"])
# TODO: We might want this to be a IterableDataset using PyTorch class:
self._messageEvents = []
if loaded_replay_object["messageEvents"]:
for event_dict in loaded_replay_object["messageEvents"]:
self._messageEvents.append(MessageEventsParser.from_dict(d=event_dict))
# TODO: We might want this to be a IterableDataset using PyTorch class:
self._gameEvents = []
if loaded_replay_object["gameEvents"]:
for event_dict in loaded_replay_object["gameEvents"]:
self._gameEvents.append(GameEventsParser.from_dict(d=event_dict))
# TODO: We might want this to be a IterableDataset using PyTorch class:
self._trackerEvents = []
if loaded_replay_object["trackerEvents"]:
for event_dict in loaded_replay_object["trackerEvents"]:
self._trackerEvents.append(TrackerEventsParser.from_dict(d=event_dict))
# TODO: We might want this to be a IterableDataset using PyTorch class:
toon_player_desc_dict: Dict[str, Dict[str, Any]] = loaded_replay_object[
"ToonPlayerDescMap"
]

self._toonPlayerDescMap = [
ToonPlayerDesc.from_dict(toon=toon, d=player_dict)
for toon, player_dict in toon_player_desc_dict.items()
]

self._gameEventsErr: bool = loaded_replay_object["gameEventsErr"]
self._messageEventsErr: bool = loaded_replay_object["messageEventsErr"]
self._trackerEventsErr: bool = loaded_replay_object["trackerEvtsErr"]
return SC2ReplayData(
filepath=replay_path,
header=Header.from_dict(d=loaded_data["header"]),
initData=InitData.from_dict(d=loaded_data["initData"]),
details=Details.from_dict(d=loaded_data["details"]),
metadata=Metadata.from_dict(d=loaded_data["metadata"]),
messageEvents=[
MessageEventsParser.from_dict(d=event_dict)
for event_dict in loaded_data.get("messageEvents", [])
],
gameEvents=[
GameEventsParser.from_dict(d=event_dict)
for event_dict in loaded_data.get("gameEvents", [])
],
trackerEvents=[
TrackerEventsParser.from_dict(d=event_dict)
for event_dict in loaded_data.get("trackerEvents", [])
],
toonPlayerDescMap=[
ToonPlayerDesc.from_dict(toon=toon, d=player_dict)
for toon, player_dict in loaded_data.get(
"ToonPlayerDescMap", {}
).items()
],
gameEventsErr=loaded_data.get("gameEventsErr", False),
messageEventsErr=loaded_data.get("messageEventsErr", False),
trackerEventsErr=loaded_data.get("trackerEvtsErr", False),
)

# REVIEW: Should the __hash__ be tested?
def __hash__(self) -> int:
"""
Custom hashing function based on the fields that were read from replay.
Expand Down Expand Up @@ -133,40 +125,3 @@ def __hash__(self) -> int:
player_tuple_toon,
)
)

# REVIEW: Should the properties be documented?
@property
def filepath(self):
return self._filepath

@property
def initData(self):
return self._initData

@property
def header(self):
return self._header

@property
def details(self):
return self._details

@property
def metadata(self):
return self._metadata

@property
def messageEvents(self):
return self._messageEvents

@property
def gameEvents(self):
return self._gameEvents

@property
def trackerEvents(self):
return self._trackerEvents

@property
def toonPlayerDescMap(self):
return self._toonPlayerDescMap
17 changes: 6 additions & 11 deletions src/sc2_datasets/replay_parser/details/details.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Details:
"""
Data type containing details about a StarCraft II game.
Expand All @@ -18,7 +20,10 @@ class Details:
Denotes the time at which the game was started in Coordinated Universal Time.
"""

# REVIEW: Doctests for this:
gameSpeed: str
isBlizzardMap: bool
timeUTC: str

@staticmethod
def from_dict(d: Dict[str, Any]) -> "Details":
"""
Expand Down Expand Up @@ -62,13 +67,3 @@ def from_dict(d: Dict[str, Any]) -> "Details":
isBlizzardMap=d["isBlizzardMap"],
timeUTC=d["timeUTC"],
)

def __init__(
self,
gameSpeed: str,
isBlizzardMap: bool,
timeUTC: str,
) -> None:
self.gameSpeed = gameSpeed
self.isBlizzardMap = isBlizzardMap
self.timeUTC = timeUTC
25 changes: 9 additions & 16 deletions src/sc2_datasets/replay_parser/game_events/events/camera_save.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import Dict

from sc2_datasets.replay_parser.game_events.events.nested.target_2d import Target2D
from sc2_datasets.replay_parser.game_events.game_event import GameEvent


@dataclass
class CameraSave(GameEvent):
"""
CameraSave represents replay information regarding a saved camera location within the game.
Expand All @@ -14,15 +16,20 @@ class CameraSave(GameEvent):
Identifier for the CameraSave object. Multiple elements may share the same ID.
loop : int
Game loop number (game-engine tick) when the event occurred.
target : Target
target : Target2D
Target class object containing x and y coordinates where the camera location was set in the game.
userid : int
ID of the player who saved the camera location.
which : int
Hotkey [0-9] to which the camera location was set.
"""

# REVIEW: Doctests here:
id: int
loop: int
target: Target2D
userid: int
which: int

@staticmethod
def from_dict(d: Dict) -> "CameraSave":
"""
Expand Down Expand Up @@ -79,17 +86,3 @@ def from_dict(d: Dict) -> "CameraSave":
userid=d["userid"]["userId"],
which=d["which"],
)

def __init__(
self,
id: int,
loop: int,
target: Target2D,
userid: int,
which: int,
) -> None:
self.id = id
self.loop = loop
self.target = target
self.userid = userid
self.which = which
39 changes: 14 additions & 25 deletions src/sc2_datasets/replay_parser/game_events/events/camera_update.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dataclasses import dataclass
from types import NoneType
from typing import Dict

from sc2_datasets.replay_parser.game_events.events.nested.target_2d import Target2D
from sc2_datasets.replay_parser.game_events.game_event import GameEvent


@dataclass
class CameraUpdate(GameEvent):
"""
CameraUpdate represents replay data regarding updated camera locations in the game.
Expand All @@ -23,17 +25,26 @@ class CameraUpdate(GameEvent):
Angle in the vertical plane, representing the vertical elevation of the camera.
reason : None, str
No valuable information about this parameter.
target : Target
target : Target2D | None
Target class object containing x and y coordinates where the camera location was set.
userid : int
ID of the player who saved the camera location.
yaw : None, float, int
Angle in the horizontal plane of the camera.
"""

# REVIEW: Doctests here:
distance: NoneType | float | int
follow: bool
id: int
loop: int
pitch: NoneType | float | int
reason: NoneType | str
target: Target2D | None
userid: int
yaw: NoneType | float | int

@staticmethod
def from_dict(d: Dict):
def from_dict(d: Dict) -> "CameraUpdate":
"""
Static method returning initialized CameraUpdate class from a dictionary.
This aids in parsing the original JSON file extracted from a processed .SC2Replay file.
Expand Down Expand Up @@ -107,25 +118,3 @@ def from_dict(d: Dict):
userid=d["userid"]["userId"],
yaw=d["yaw"],
)

def __init__(
self,
distance: NoneType | float | int,
follow: bool,
id: int,
loop: int,
pitch: NoneType | float | int,
reason: NoneType | str,
target: Target2D | None,
userid: int,
yaw: NoneType | float | int,
) -> None:
self.distance = distance
self.follow = follow
self.id = id
self.loop = loop
self.pitch = pitch
self.reason = reason
self.target = target
self.userid = userid
self.yaw = yaw
25 changes: 9 additions & 16 deletions src/sc2_datasets/replay_parser/game_events/events/cmd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from types import NoneType
from typing import Dict

Expand All @@ -7,6 +8,7 @@
# Should this be encoded somehow if there is a NoneType detected?


@dataclass
class Cmd(GameEvent):
"""
Cmd contains specific details about command interface events.
Expand All @@ -31,6 +33,13 @@ class Cmd(GameEvent):

"""

id: int
loop: int
otherUnit: NoneType
sequence: int
unitGroup: NoneType | int
userid: int

@staticmethod
def from_dict(d: Dict) -> "Cmd":
"""
Expand All @@ -57,19 +66,3 @@ def from_dict(d: Dict) -> "Cmd":
unitGroup=d["unitGroup"],
userid=d["userid"]["userId"],
)

def __init__(
self,
id: int,
loop: int,
otherUnit: NoneType,
sequence: int,
unitGroup: NoneType | int,
userid: int,
) -> None:
self.id = id
self.loop = loop
self.otherUnit = otherUnit
self.sequence = sequence
self.unitGroup = unitGroup
self.userid = userid
Loading