Skip to content

Commit 82ac83d

Browse files
committed
(add): stats monitoring schema
1 parent 0c8a517 commit 82ac83d

File tree

5 files changed

+52
-40
lines changed

5 files changed

+52
-40
lines changed

stats-monitoring/monitoring/generate_animation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,23 @@
1111
from monitoring.frame import load_one_hour_density, set_frame
1212
from monitoring.logger import logger
1313
from monitoring.position_distribution import load_current_coordinate, set_histogram
14+
from monitoring.schema import StatsData
1415
from monitoring.stats_metrics import load_statistics, set_stats_metrics
1516

1617

1718
def update(
1819
frame_num: int,
1920
cfg: MonitoringConfig,
2021
axs: List,
21-
stats_dict: Dict,
22+
stats_data: StatsData,
2223
) -> None:
2324
"""Update function for movie frame
2425
2526
Args:
2627
frame_num (int): current frame number
2728
cfg (MonitoringConfig): config for monitoring environment
2829
axs (List): list of matplotlib axis
29-
stats_dict (Dict): dictionary containing each stats array
30+
stats_data (StatsData): instance of each statistics data
3031
"""
3132
# clear all axis
3233
for ax in axs:
@@ -50,7 +51,7 @@ def update(
5051
set_frame(cfg, frame_num, detected_coordinate_df, one_hour_density_df, axs[0])
5152

5253
# set statistics metrics
53-
set_stats_metrics(cfg, frame_num, stats_dict, axs[3], axs[4])
54+
set_stats_metrics(cfg, frame_num, stats_data, axs[3], axs[4])
5455

5556

5657
def generate_animation(cfg: MonitoringConfig, fig: Figure, axs: List) -> None:
@@ -62,15 +63,15 @@ def generate_animation(cfg: MonitoringConfig, fig: Figure, axs: List) -> None:
6263
axs (List): list of matplotlib axis
6364
"""
6465
logger.info("[START] Load Statistics Data ...")
65-
stats_dict = load_statistics(cfg)
66+
stats_data = load_statistics(cfg)
6667
logger.info("------------ [DONE] ------------")
6768

6869
anim = FuncAnimation(
6970
fig,
7071
update,
7172
frames=trange(cfg.animation.frame_number),
7273
interval=cfg.animation.interval,
73-
fargs=(cfg, axs, stats_dict),
74+
fargs=(cfg, axs, stats_data),
7475
)
7576
save_movie_path = str(DATA_DIR / cfg.path.save_movie_path)
7677
anim.save(save_movie_path, writer=cfg.animation.format)

stats-monitoring/monitoring/schema.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Annotated, Any
2+
3+
import numpy as np
4+
from pydantic import BaseModel, PlainSerializer, PlainValidator
5+
6+
7+
def validate(v: Any) -> np.ndarray:
8+
if isinstance(v, np.ndarray):
9+
return v
10+
else:
11+
raise TypeError(f"Expected numpy array, got {type(v)}")
12+
13+
14+
def serialize(v: np.ndarray) -> list[list[float]]:
15+
return v.tolist()
16+
17+
18+
DataArray = Annotated[
19+
np.ndarray,
20+
PlainValidator(validate),
21+
PlainSerializer(serialize),
22+
]
23+
24+
25+
class StatsData(BaseModel):
26+
mean: DataArray
27+
past_mean: DataArray
28+
acceleration: DataArray
29+
past_acceleration: DataArray

stats-monitoring/monitoring/stats_metrics.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
from pathlib import Path
3-
from typing import Dict
43

54
import numpy as np
65
from matplotlib.axes import Axes
@@ -13,18 +12,19 @@
1312
LABEL_IDX_MAX,
1413
TIME_LABELS,
1514
)
16-
from monitoring.exceptions import PathNotExistError, StatsKeyNotExistError
15+
from monitoring.exceptions import PathNotExistError
1716
from monitoring.logger import logger
17+
from monitoring.schema import StatsData
1818

1919

20-
def load_statistics(cfg: MonitoringConfig) -> Dict[str, np.ndarray]:
20+
def load_statistics(cfg: MonitoringConfig) -> StatsData:
2121
"""Load the statistics used in the monitoring environment
2222
2323
Args:
2424
cfg (MonitoringConfig): config for monitoring environment
2525
2626
Returns:
27-
dict: dictionary containing each stats array
27+
StatsData: instance of each statistics data
2828
"""
2929

3030
def load_value(path: Path) -> np.ndarray:
@@ -43,38 +43,20 @@ def load_value(path: Path) -> np.ndarray:
4343
logger.error(message)
4444
raise PathNotExistError(message)
4545

46-
stats_dict = {
47-
"mean": load_value(DATA_DIR / cfg.path.mean_speed_path),
48-
"past_mean": load_value(DATA_DIR / cfg.path.past_mean_speed_path),
49-
"acceleration": load_value(DATA_DIR / cfg.path.acceleration_count_path),
50-
"past_acceleration": load_value(DATA_DIR / cfg.path.past_acceleration_count_path),
51-
}
46+
stats_data = StatsData(
47+
mean=load_value(DATA_DIR / cfg.path.mean_speed_path),
48+
past_mean=load_value(DATA_DIR / cfg.path.past_mean_speed_path),
49+
acceleration=load_value(DATA_DIR / cfg.path.acceleration_count_path),
50+
past_acceleration=load_value(DATA_DIR / cfg.path.past_acceleration_count_path),
51+
)
5252

53-
return stats_dict
54-
55-
56-
def load_array(stats_dict: Dict[str, np.ndarray], key: str) -> np.ndarray:
57-
"""Load array from statistics dictionary
58-
59-
Args:
60-
stats_dict (Dict): dictionary containing each stats array
61-
key (str): dictionary key
62-
63-
Returns:
64-
np.ndarray: target stats array
65-
"""
66-
if key in stats_dict.keys():
67-
return stats_dict[key]
68-
else:
69-
message = f'key="{key}" is not exist in stats_dict.'
70-
logger.error(message)
71-
raise StatsKeyNotExistError(message)
53+
return stats_data
7254

7355

7456
def set_stats_metrics(
7557
cfg: MonitoringConfig,
7658
frame_num: int,
77-
stats_dict: Dict[str, np.ndarray],
59+
stats_data: StatsData,
7860
mean_ax: Axes,
7961
acc_ax: Axes,
8062
) -> None:
@@ -83,11 +65,11 @@ def set_stats_metrics(
8365
Args:
8466
cfg (MonitoringConfig): config for monitoring environment
8567
frame_num (int): current frame number
86-
stats_dict (Dict): dictionary containing each stats array
68+
stats_data (StatsData): instance of each statistics data
8769
mean_ax (Axes): matplotlib figure axis of mean speed
8870
acc_ax (Axes): matplotlib figure axis of acceleration count
8971
"""
90-
mean_arr = load_array(stats_dict, "mean")
72+
mean_arr = stats_data.mean
9173
x = [i for i in range(len(mean_arr))]
9274

9375
# plot mean speed
@@ -99,7 +81,7 @@ def set_stats_metrics(
9981
mean_ax.axvline(frame_num, 0, 100, color="black", linestyle="dashed")
10082

10183
# plot cumulate acceleration count
102-
acc_arr = load_array(stats_dict, "acceleration")
84+
acc_arr = stats_data.acceleration
10385
acc_ax.plot(x[: frame_num + 1], acc_arr[: frame_num + 1])
10486
acc_ax.set_xlim(0, LABEL_IDX_MAX)
10587
acc_ax.set_ylim(0, cfg.statistics.acceleration_max)

stats-monitoring/mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[mypy]
2-
python_version = 3.8
2+
python_version = 3.9.16
33
follow_imports = skip
44
ignore_missing_imports = True
55
disallow_untyped_defs = True

stats-monitoring/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ tqdm = "^4.62.3"
1717
seaborn = "^0.11.2"
1818
pandas = "^1.3.4"
1919
matplotlib = "^3.5.0"
20-
pydantic = "^2.3.0"
20+
pydantic = "2.8.2"
2121
pathlib = "^1.0.1"
2222
pyyaml = "^6.0.1"
2323
types-pyyaml = "^6.0.12.11"

0 commit comments

Comments
 (0)