Skip to content

Commit 53f8cbd

Browse files
authored
[BUG] Fix TimeSeriesDataSet wrong inferred tensor dtype when time_idx is included in features (#1950)
#### Reference Issues/PRs Fixes #1930. When `time_idx` or some other integer feature is included in the unknow reals, the dataset/dataloader elements are converted into the wrong dtype. #### What does this implement/fix? Explain your changes. Modify the function the method `TimeSeriesDataSet._data_to_tensors._to_tensor()`
1 parent 6074011 commit 53f8cbd

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

pytorch_forecasting/data/timeseries/_timeseries.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,25 +1430,34 @@ def _data_to_tensors(self, data: pd.DataFrame) -> dict[str, torch.Tensor]:
14301430
time index
14311431
"""
14321432

1433-
def _to_tensor(cols, long=True) -> torch.Tensor:
1433+
def _to_tensor(cols, long=True, real=False) -> torch.Tensor:
14341434
"""Convert data[cols] to torch tensor.
14351435
14361436
Converts sub-frames to numpy and then to torch tensor.
14371437
Makes the following choices for types:
14381438
1439-
* float columns are converted to torch.float
1440-
* integer columns are converted to torch.int64 or torch.long,
1441-
depending on the long argument
1439+
- real is True:
1440+
* the sub-frame is converted to a torch.float32 tensor
1441+
- long is True (and real is False):
1442+
* the sub-frame is converted to a torch.long tensor
1443+
- real is False and long is False:
1444+
* if all columns are integer or boolean, the sub-frame is
1445+
converted to a torch.int64 tensor
1446+
* if one column is a float, the sub-frame is converted to
1447+
a torch.float32 tensor
14421448
"""
14431449
if not isinstance(cols, list) and cols not in data.columns:
14441450
return None
14451451
if isinstance(cols, list) and len(cols) == 0:
14461452
dtypekind = "f"
14471453
elif isinstance(cols, list): # and len(cols) > 0
1448-
dtypekind = data.dtypes[cols[0]].kind
1454+
# dtypekind = data.dtypes[cols[0]].kind
1455+
dtypekind = np.result_type(*data[cols].dtypes.to_list()).kind
14491456
else:
14501457
dtypekind = data.dtypes[cols].kind
1451-
if not long:
1458+
if real:
1459+
return torch.tensor(data[cols].to_numpy(np.float64), dtype=torch.float)
1460+
elif not long:
14521461
return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.int64)
14531462
elif dtypekind in "bi":
14541463
return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.long)

tests/test_data/test_timeseries.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,3 +678,41 @@ def distance_to_weights(dist):
678678
if idx > 100:
679679
break
680680
print(a)
681+
682+
683+
def test_correct_dtype_inference():
684+
# Create a small dataset
685+
data = pd.DataFrame(
686+
{
687+
"time_idx": np.arange(30),
688+
"value": np.sin(np.arange(30) / 5) + np.random.normal(scale=1, size=30),
689+
"group": ["A"] * 30,
690+
}
691+
)
692+
693+
# Define the dataset
694+
dataset = TimeSeriesDataSet(
695+
data.copy(),
696+
time_idx="time_idx",
697+
target="value",
698+
group_ids=["group"],
699+
static_categoricals=["group"],
700+
max_encoder_length=4,
701+
max_prediction_length=2,
702+
time_varying_unknown_reals=["value"],
703+
target_normalizer=None,
704+
# WATCH THIS
705+
time_varying_known_reals=["time_idx"],
706+
scalers=dict(time_idx=None),
707+
)
708+
709+
# and the dataloader
710+
dataloader = dataset.to_dataloader(batch_size=8)
711+
712+
x, y = next(iter(dataset))
713+
# real features must be real
714+
assert x["x_cont"].dtype is torch.float
715+
716+
x, y = next(iter(dataloader))
717+
# real features must be real
718+
assert x["encoder_cont"].dtype is torch.float

0 commit comments

Comments
 (0)