Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 104 additions & 9 deletions darts/tests/utils/test_timeseries_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import numpy as np
import pandas as pd
import pytest
from pandas.tseries.frequencies import to_offset

from darts import TimeSeries
from darts.utils.timeseries_generation import (
DATETIME_ATT_WITH_VARIABLE_MAX,
ONE_INDEXED_FREQS,
_build_forecast_series_from_schema,
autoregressive_timeseries,
Expand All @@ -18,6 +20,7 @@
linear_timeseries,
random_walk_timeseries,
sine_timeseries,
unique_datetime_value_freq_aware,
)
from darts.utils.utils import freqs

Expand Down Expand Up @@ -385,6 +388,94 @@ def test_datetime_attribute_timeseries_wrong_args(self):
)
assert "`time_index` must be time zone naive." == str(err.value)

@pytest.mark.parametrize(
    "attribute,freq,start,expected",
    [
        pytest.param(
            "minute",
            to_offset("1min"),
            pd.Timestamp(year=2000, month=1, day=1),
            np.arange(60),
            id="minute_minutely",
        ),
        # same frequency as above, but the start is shifted by one minute;
        # the expected values are unchanged
        pytest.param(
            "minute",
            to_offset("1min"),
            pd.Timestamp(year=2000, month=1, day=1, minute=1),
            np.arange(60),
            id="minute_minutely_one_minute_shifted",
        ),
        pytest.param(
            "minute",
            to_offset("1h"),
            pd.Timestamp(year=2000, month=1, day=1),
            np.arange(1),
            id="minute_hourly",
        ),
        pytest.param(
            "minute",
            to_offset("15min"),
            pd.Timestamp(year=2000, month=1, day=1),
            np.array([0, 15, 30, 45]),
            id="minute_quarter_hourly",
        ),
        pytest.param(
            "day",
            to_offset("1D"),
            pd.Timestamp(year=2025, month=1, day=1),
            np.arange(31),
            id="day_daily_january",
        ),
        # a February start expects the same 31 values as the January case
        pytest.param(
            "day",
            to_offset("1D"),
            pd.Timestamp(year=2025, month=2, day=1),
            np.arange(31),
            id="day_daily_february",
        ),
        pytest.param(
            "day_of_week",
            to_offset("YS"),
            pd.Timestamp(year=2025, month=1, day=1),
            np.arange(7),
            id="dayofweek_yearly",
        ),
        pytest.param(
            "day",
            to_offset("YS"),
            pd.Timestamp(year=2025, month=1, day=1),
            np.arange(1),
            id="day_yearly",
        ),
        pytest.param(
            "day",
            to_offset("B"),
            pd.Timestamp(year=2025, month=1, day=1),
            np.arange(31),
            id="day_business_daily",
        ),
        pytest.param(
            "nanosecond",
            to_offset("999999ns"),
            pd.Timestamp(year=2000, month=1, day=1),
            np.arange(1000),
            # explicit id so this case is reported with a readable name,
            # like every other case in this list
            id="nanosecond_999999ns",
        ),
    ],
)
def test_unique_datetime_value_freq_aware(
    self,
    attribute: str,
    freq: pd.DateOffset,
    start: pd.Timestamp,
    expected: np.ndarray[int] | type[Exception],
):
    """Check the output of `unique_datetime_value_freq_aware` for one case.

    Parameters
    ----------
    attribute
        Name of the datetime attribute passed to the function under test.
    freq
        Frequency offset passed to the function under test.
    start
        Start timestamp passed to the function under test.
    expected
        Either the array the call must return (compared element-wise), or
        an exception class the call must raise. None of the cases listed
        above currently passes an exception class.
    """
    if isinstance(expected, type) and issubclass(expected, Exception):
        # `expected` is an exception class: the call itself must fail
        with pytest.raises(expected):
            unique_datetime_value_freq_aware(attribute, freq, start)
    else:
        unique_values = unique_datetime_value_freq_aware(attribute, freq, start)
        np.testing.assert_array_equal(unique_values, expected)

def test_datetime_attribute_timeseries(self):
idx = generate_index(
start=pd.Timestamp("2000-01-01"), length=48, freq=freqs["h"]
Expand Down Expand Up @@ -424,8 +515,8 @@ def test_datetime_attribute_timeseries(self):
(freqs["h"], "hour", 24),
("D", "weekday", 7),
(freqs["s"], "second", 60),
("W", "weekofyear", 52),
("D", "dayofyear", 365),
("W", "weekofyear", 53),
("D", "dayofyear", 366),
(freqs["QE"], "quarter", 4),
],
)
Expand Down Expand Up @@ -479,8 +570,10 @@ def test_datetime_attribute_timeseries_one_hot(self, config):
# first quarter/year, month/year, week/year, day/year, day/week, hour/day, second/hour
simple_start = pd.Timestamp("2001-01-01 00:00:00")
idx = generate_index(start=simple_start, length=period, freq=base_freq)
vals = np.eye(period)

expected_dim = period
if attribute_freq in DATETIME_ATT_WITH_VARIABLE_MAX:
expected_dim += 1
vals = np.eye(period, expected_dim)
# simple start
self.helper_routine(idx, attribute_freq, vals_exp=vals, one_hot=True)
# with time-zone
Expand All @@ -492,7 +585,7 @@ def test_datetime_attribute_timeseries_one_hot(self, config):
# missing values
cut_period = period // 3
idx = generate_index(start=simple_start, length=cut_period, freq=base_freq)
vals = np.eye(period)
vals = np.eye(period, expected_dim)
# removing missing rows
vals = vals[:cut_period]
# mask missing attribute values
Expand All @@ -519,7 +612,7 @@ def test_datetime_attribute_timeseries_one_hot(self, config):
shift -= 1

idx = generate_index(start=shifted_start, length=period, freq=base_freq)
vals = np.eye(period)
vals = np.eye(period, expected_dim)
# shift values
vals = np.roll(vals, shift=-shift, axis=0)

Expand Down Expand Up @@ -617,9 +710,11 @@ def test_datetime_attribute_timeseries_special_years(self, year):
# the 53rd week is omitted from the index when created with freq="W"
index_weeks = pd.date_range(start=start_date, end=end_date, freq="W")
assert len(index_weeks) == weeks_special_year - 1
# and 53rd week properly excluded from the encoding
vals_exp = np.eye(weeks_special_year - 1)[: len(index_weeks)]
assert vals_exp.shape[1] == weeks_special_year - 1
# and the 53rd week should still be part of the encoding
vals_exp = np.eye(weeks_special_year - 1, weeks_special_year)[
: len(index_weeks)
]
assert vals_exp.shape[1] == weeks_special_year
self.helper_routine(
index_weeks, "week_of_year", vals_exp=vals_exp, one_hot=True
)
Expand Down
Loading