
Commit 4a6eeed

Remove implicit clone (#286)
Fixes #254

Summary:
1. Only clone when `na_strategy` is specified in the stype encoder.
2. Only clone the `values` of a `MultiTensor`.
1 parent 8fca093 commit 4a6eeed
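The gist of the change, as a minimal sketch (illustrative only, not the library code; `na_forward_sketch` and its fill value are placeholders): the encoder now copies its input only when it is actually about to overwrite missing entries.

import torch


def na_forward_sketch(feat: torch.Tensor, na_strategy=None) -> torch.Tensor:
    # No NA handling requested: hand the input back untouched, zero copies.
    if na_strategy is None:
        return feat
    # NaN marks missing floats; -1 marks missing categorical indices.
    na_mask = torch.isnan(feat) if feat.is_floating_point() else feat == -1
    # Nothing to fill: still no copy.
    if not na_mask.any():
        return feat
    # Copy only now, right before the in-place write.
    feat = feat.clone()
    feat[na_mask] = 0  # the actual fill value depends on the strategy
    return feat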

File tree

4 files changed: +93 −66 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

 ### Changed
+- Removed implicit clones in `StypeEncoder` ([#286](https://github.com/pyg-team/pytorch-frame/pull/286))

 ### Deprecated

test/nn/encoder/test_stype_encoder.py

Lines changed: 40 additions & 43 deletions
@@ -1,11 +1,12 @@
+import copy
+
 import pytest
 import torch
 from torch.nn import ReLU

 import torch_frame
 from torch_frame import NAStrategy, stype
 from torch_frame.config import ModelConfig
-from torch_frame.config.text_embedder import TextEmbedderConfig
 from torch_frame.config.text_tokenizer import TextTokenizerConfig
 from torch_frame.data.dataset import Dataset
 from torch_frame.data.stats import StatType
@@ -22,7 +23,6 @@
     StackEncoder,
     TimestampEncoder,
 )
-from torch_frame.testing.text_embedder import HashTextEmbedder
 from torch_frame.testing.text_tokenizer import (
     RandomTextModel,
     WhiteSpaceHashTokenizer,
@@ -44,10 +44,12 @@ def test_categorical_feature_encoder(encoder_cls_kwargs):
         stype=stype.categorical,
         **encoder_cls_kwargs[1],
     )
-    feat_cat = tensor_frame.feat_dict[stype.categorical]
+    feat_cat = tensor_frame.feat_dict[stype.categorical].clone()
     col_names = tensor_frame.col_names_dict[stype.categorical]
     x = encoder(feat_cat, col_names)
     assert x.shape == (feat_cat.size(0), feat_cat.size(1), 8)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_cat, tensor_frame.feat_dict[stype.categorical])

     # Perturb the first column
     num_categories = len(stats_list[0][StatType.COUNT])
@@ -96,10 +98,12 @@ def test_numerical_feature_encoder(encoder_cls_kwargs):
         stype=stype.numerical,
         **encoder_cls_kwargs[1],
     )
-    feat_num = tensor_frame.feat_dict[stype.numerical]
+    feat_num = tensor_frame.feat_dict[stype.numerical].clone()
     col_names = tensor_frame.col_names_dict[stype.numerical]
     x = encoder(feat_num, col_names)
     assert x.shape == (feat_num.size(0), feat_num.size(1), 8)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_num, tensor_frame.feat_dict[stype.numerical])
     if "post_module" in encoder_cls_kwargs[1]:
         assert encoder.post_module is not None
     else:
@@ -142,9 +146,16 @@ def test_multicategorical_feature_encoder(encoder_cls_kwargs):
         stype=stype.multicategorical,
         **encoder_cls_kwargs[1],
     )
-    feat_multicat = tensor_frame.feat_dict[stype.multicategorical]
+    feat_multicat = tensor_frame.feat_dict[stype.multicategorical].clone()
     col_names = tensor_frame.col_names_dict[stype.multicategorical]
     x = encoder(feat_multicat, col_names)
+    # Make sure no in-place modification
+    assert torch.allclose(
+        feat_multicat.values,
+        tensor_frame.feat_dict[stype.multicategorical].values)
+    assert torch.allclose(
+        feat_multicat.offset,
+        tensor_frame.feat_dict[stype.multicategorical].offset)
     assert x.shape == (feat_multicat.size(0), feat_multicat.size(1), 8)

     # Perturb the first column
@@ -178,9 +189,12 @@ def test_timestamp_feature_encoder(encoder_cls_kwargs):
         stype=stype.timestamp,
         **encoder_cls_kwargs[1],
     )
-    feat_timestamp = tensor_frame.feat_dict[stype.timestamp]
+    feat_timestamp = tensor_frame.feat_dict[stype.timestamp].clone()
     col_names = tensor_frame.col_names_dict[stype.timestamp]
     x = encoder(feat_timestamp, col_names)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_timestamp,
+                          tensor_frame.feat_dict[stype.timestamp])
     assert x.shape == (feat_timestamp.size(0), feat_timestamp.size(1), 8)


@@ -324,40 +338,6 @@ def test_timestamp_feature_encoder_with_nan(encoder_cls_kwargs):
     assert (~torch.isnan(x)).all()


-def test_text_embedded_encoder():
-    num_rows = 20
-    text_emb_channels = 10
-    out_channels = 5
-    dataset = FakeDataset(
-        num_rows=num_rows,
-        stypes=[
-            torch_frame.text_embedded,
-        ],
-        col_to_text_embedder_cfg=TextEmbedderConfig(
-            text_embedder=HashTextEmbedder(text_emb_channels),
-            batch_size=None),
-    )
-    dataset.materialize()
-    tensor_frame = dataset.tensor_frame
-    stats_list = [
-        dataset.col_stats[col_name]
-        for col_name in tensor_frame.col_names_dict[stype.embedding]
-    ]
-    encoder = LinearEmbeddingEncoder(
-        out_channels=out_channels,
-        stats_list=stats_list,
-        stype=stype.embedding,
-    )
-    feat_text = tensor_frame.feat_dict[stype.embedding]
-    col_names = tensor_frame.col_names_dict[stype.embedding]
-    feat = encoder(feat_text, col_names)
-    assert feat.shape == (
-        num_rows,
-        len(tensor_frame.col_names_dict[stype.embedding]),
-        out_channels,
-    )
-
-
 def test_embedding_encoder():
     num_rows = 20
     out_channels = 5
@@ -378,9 +358,14 @@ def test_embedding_encoder():
         stats_list=stats_list,
         stype=stype.embedding,
     )
-    feat_text = tensor_frame.feat_dict[stype.embedding]
+    feat_emb = tensor_frame.feat_dict[stype.embedding].clone()
     col_names = tensor_frame.col_names_dict[stype.embedding]
-    x = encoder(feat_text, col_names)
+    x = encoder(feat_emb, col_names)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_emb.values,
+                          tensor_frame.feat_dict[stype.embedding].values)
+    assert torch.allclose(feat_emb.offset,
+                          tensor_frame.feat_dict[stype.embedding].offset)
     assert x.shape == (
         num_rows,
         len(tensor_frame.col_names_dict[stype.embedding]),
@@ -421,11 +406,23 @@ def test_text_tokenized_encoder():
         stype=stype.text_tokenized,
         col_to_model_cfg=col_to_model_cfg,
     )
-    feat_text = tensor_frame.feat_dict[stype.text_tokenized]
+    feat_text = copy.deepcopy(tensor_frame.feat_dict[stype.text_tokenized])
     col_names = tensor_frame.col_names_dict[stype.text_tokenized]
     x = encoder(feat_text, col_names)
     assert x.shape == (
         num_rows,
         len(tensor_frame.col_names_dict[stype.text_tokenized]),
         out_channels,
     )
+    # Make sure no in-place modification
+    assert isinstance(feat_text, dict) and isinstance(
+        tensor_frame.feat_dict[stype.text_tokenized], dict)
+    assert feat_text.keys() == tensor_frame.feat_dict[
+        stype.text_tokenized].keys()
+    for key in feat_text.keys():
+        assert torch.allclose(
+            feat_text[key].values,
+            tensor_frame.feat_dict[stype.text_tokenized][key].values)
+        assert torch.allclose(
+            feat_text[key].offset,
+            tensor_frame.feat_dict[stype.text_tokenized][key].offset)
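One asymmetry worth noting in the tests above: `stype.text_tokenized` features are stored as a plain `dict` of multi-tensors, so the snapshot is taken with `copy.deepcopy` rather than a single `.clone()`. A small illustration of why a shallow dict copy would not be enough:

import copy

import torch

feat = {"input_ids": torch.tensor([1, 2, 3])}

shallow = dict(feat)        # new dict, but the tensors are shared
deep = copy.deepcopy(feat)  # new dict and new tensors

feat["input_ids"][0] = 99
assert shallow["input_ids"][0] == 99  # aliases the mutated tensor
assert deep["input_ids"][0] == 1      # the true snapshot survives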

test/nn/models/test_compile.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@
             gamma=0.1,
         ),
         None,
-        4,
+        7,
         id="TabNet",
     ),
     pytest.param(
@@ -54,7 +54,7 @@
         Trompt,
         dict(channels=8, num_prompts=2),
         None,
-        11,
+        16,
         id="Trompt",
     ),
     pytest.param(

torch_frame/nn/encoder/stype_encoder.py

Lines changed: 50 additions & 21 deletions
@@ -37,6 +37,19 @@ def reset_parameters_soft(module: Module):
         module.reset_parameters()


+def get_na_mask(tensor: Tensor) -> Tensor:
+    r"""Obtains the NA mask of the input :obj:`Tensor`.
+
+    Args:
+        tensor (Tensor): Input :obj:`Tensor`.
+    """
+    if tensor.is_floating_point():
+        na_mask = torch.isnan(tensor)
+    else:
+        na_mask = tensor == -1
+    return na_mask
+
+
 class StypeEncoder(Module, ABC):
     r"""Base class for stype encoder. This module transforms tensor of a
     specific stype, i.e., `TensorFrame.feat_dict[stype.xxx]` into 3-dimensional
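For reference, the two branches of `get_na_mask` correspond to the two missing-value encodings used here: NaN for floating-point tensors and -1 for integer (categorical) tensors. A quick standalone check of both:

import torch

# Floating-point columns encode missing values as NaN.
feat_float = torch.tensor([1.0, float("nan"), 3.0])
print(torch.isnan(feat_float))  # tensor([False,  True, False])

# Integer (categorical) columns encode missing values as -1.
feat_int = torch.tensor([2, -1, 0])
print(feat_int == -1)           # tensor([False,  True, False])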
@@ -121,11 +134,6 @@ def forward(
                 f"The number of columns in feat and the length of "
                 f"col_names must match (got {num_cols} and "
                 f"{len(col_names)}, respectively.)")
-        # Clone the tensor to avoid in-place modification
-        if not isinstance(feat, dict):
-            feat = feat.clone()
-        else:
-            feat = {key: value.clone() for key, value in feat.items()}
         # NaN handling of the input Tensor
         feat = self.na_forward(feat)
         # Main encoding into column embeddings
@@ -174,20 +182,36 @@ def na_forward(self, feat: TensorData) -> TensorData:
         """
         if self.na_strategy is None:
             return feat
-        for col in range(feat.size(1)):
-            column_data = feat[:, col]
-            if isinstance(feat, _MultiTensor):
-                column_data = column_data.values
-            if column_data.is_floating_point():
-                nan_mask = torch.isnan(column_data)
+
+        # Since we are not changing the number of items in each column, it's
+        # faster to just clone the values, while reusing the same offset
+        # object.
+        if isinstance(feat, Tensor):
+            if get_na_mask(feat).any():
+                feat = feat.clone()
+            else:
+                return feat
+        elif isinstance(feat, MultiEmbeddingTensor):
+            if get_na_mask(feat.values).any():
+                feat = MultiEmbeddingTensor(num_rows=feat.num_rows,
+                                            num_cols=feat.num_cols,
+                                            values=feat.values.clone(),
+                                            offset=feat.offset)
+            else:
+                return feat
+        elif isinstance(feat, MultiNestedTensor):
+            if get_na_mask(feat.values).any():
+                feat = MultiNestedTensor(num_rows=feat.num_rows,
+                                         num_cols=feat.num_cols,
+                                         values=feat.values.clone(),
+                                         offset=feat.offset)
             else:
-                nan_mask = column_data == -1
-            if nan_mask.ndim == 2:
-                nan_mask = nan_mask.any(dim=-1)
-            assert nan_mask.ndim == 1
-            assert len(nan_mask) == len(column_data)
-            if not nan_mask.any():
-                continue
+                return feat
+        else:
+            raise ValueError(f"Unrecognized type {type(feat)} in na_forward.")
+
+        # TODO: Remove for-loop over columns
+        for col in range(feat.size(1)):
             if self.na_strategy == NAStrategy.MOST_FREQUENT:
                 # Categorical index is sorted based on count,
                 # so 0-th index is always the most frequent.
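The comment in the hunk above states the key invariant: NA filling overwrites entries of `values` but never changes the ragged layout, so only the `values` buffer needs to be copied and the `offset` tensor can be shared. A plain-tensor sketch of that aliasing behavior:

import torch

values = torch.tensor([1.0, float("nan"), 3.0, 4.0])
offset = torch.tensor([0, 2, 4])  # row boundaries; never mutated by filling

new_values = values.clone()       # copy only the buffer we will write into
new_values[torch.isnan(new_values)] = 0.0

assert torch.isnan(values[1])     # the caller's buffer is untouched
assert new_values[1] == 0.0       # the filled copy carries the change
# `offset` is reused as-is by both the old and the new multi-tensor.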
@@ -210,7 +234,13 @@ def na_forward(self, feat: TensorData) -> TensorData:
             if isinstance(feat, _MultiTensor):
                 feat.fillna_col(col, fill_value)
             else:
-                column_data[nan_mask] = fill_value
+                column_data = feat[:, col]
+                na_mask = get_na_mask(column_data)
+                if na_mask.ndim == 2:
+                    na_mask = na_mask.any(dim=-1)
+                assert na_mask.ndim == 1
+                assert len(na_mask) == len(column_data)
+                column_data[na_mask] = fill_value
             # Add better safeguard here to make sure nans are actually
             # replaced, especially when nans are represented as -1's. They are
             # very hard to catch as they won't error out.
@@ -339,11 +369,10 @@ def encode_forward(
         # Increment the index by one so that NaN index (-1) becomes 0
         # (padding_idx)
         # feat: [batch_size, num_cols]
-        feat.values = feat.values + 1
         xs = []
         for i, emb in enumerate(self.embs):
             col_feat = feat[:, i]
-            xs.append(emb(col_feat.values, col_feat.offset[:-1]))
+            xs.append(emb(col_feat.values + 1, col_feat.offset[:-1]))
         # [batch_size, num_cols, hidden_channels]
         x = torch.stack(xs, dim=1)
         return x
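This last hunk fixes a subtler leak: `feat.values = feat.values + 1` rebound the `values` attribute of a `MultiNestedTensor` that the caller's `TensorFrame` still holds, so the +1 shift became visible to every later consumer. Computing `col_feat.values + 1` inside the loop works on a temporary instead. A minimal reproduction with a hypothetical stand-in class:

import torch


class Ragged:  # hypothetical stand-in for MultiNestedTensor
    def __init__(self, values: torch.Tensor):
        self.values = values


shared = Ragged(torch.tensor([-1, 0, 2]))  # -1 encodes a missing value

# Old behavior: rebinding the attribute leaks the shift to all consumers.
# shared.values = shared.values + 1

# New behavior: the shift lives on a temporary; the shared object is intact.
shifted = shared.values + 1                # tensor([0, 1, 3])
assert shared.values[0] == -1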
