1
+ import copy
2
+
1
3
import pytest
2
4
import torch
3
5
from torch .nn import ReLU
4
6
5
7
import torch_frame
6
8
from torch_frame import NAStrategy , stype
7
9
from torch_frame .config import ModelConfig
8
- from torch_frame .config .text_embedder import TextEmbedderConfig
9
10
from torch_frame .config .text_tokenizer import TextTokenizerConfig
10
11
from torch_frame .data .dataset import Dataset
11
12
from torch_frame .data .stats import StatType
22
23
StackEncoder ,
23
24
TimestampEncoder ,
24
25
)
25
- from torch_frame .testing .text_embedder import HashTextEmbedder
26
26
from torch_frame .testing .text_tokenizer import (
27
27
RandomTextModel ,
28
28
WhiteSpaceHashTokenizer ,
@@ -44,10 +44,12 @@ def test_categorical_feature_encoder(encoder_cls_kwargs):
44
44
stype = stype .categorical ,
45
45
** encoder_cls_kwargs [1 ],
46
46
)
47
- feat_cat = tensor_frame .feat_dict [stype .categorical ]
47
+ feat_cat = tensor_frame .feat_dict [stype .categorical ]. clone ()
48
48
col_names = tensor_frame .col_names_dict [stype .categorical ]
49
49
x = encoder (feat_cat , col_names )
50
50
assert x .shape == (feat_cat .size (0 ), feat_cat .size (1 ), 8 )
51
+ # Make sure no in-place modification
52
+ assert torch .allclose (feat_cat , tensor_frame .feat_dict [stype .categorical ])
51
53
52
54
# Perturb the first column
53
55
num_categories = len (stats_list [0 ][StatType .COUNT ])
@@ -96,10 +98,12 @@ def test_numerical_feature_encoder(encoder_cls_kwargs):
96
98
stype = stype .numerical ,
97
99
** encoder_cls_kwargs [1 ],
98
100
)
99
- feat_num = tensor_frame .feat_dict [stype .numerical ]
101
+ feat_num = tensor_frame .feat_dict [stype .numerical ]. clone ()
100
102
col_names = tensor_frame .col_names_dict [stype .numerical ]
101
103
x = encoder (feat_num , col_names )
102
104
assert x .shape == (feat_num .size (0 ), feat_num .size (1 ), 8 )
105
+ # Make sure no in-place modification
106
+ assert torch .allclose (feat_num , tensor_frame .feat_dict [stype .numerical ])
103
107
if "post_module" in encoder_cls_kwargs [1 ]:
104
108
assert encoder .post_module is not None
105
109
else :
@@ -142,9 +146,16 @@ def test_multicategorical_feature_encoder(encoder_cls_kwargs):
142
146
stype = stype .multicategorical ,
143
147
** encoder_cls_kwargs [1 ],
144
148
)
145
- feat_multicat = tensor_frame .feat_dict [stype .multicategorical ]
149
+ feat_multicat = tensor_frame .feat_dict [stype .multicategorical ]. clone ()
146
150
col_names = tensor_frame .col_names_dict [stype .multicategorical ]
147
151
x = encoder (feat_multicat , col_names )
152
+ # Make sure no in-place modification
153
+ assert torch .allclose (
154
+ feat_multicat .values ,
155
+ tensor_frame .feat_dict [stype .multicategorical ].values )
156
+ assert torch .allclose (
157
+ feat_multicat .offset ,
158
+ tensor_frame .feat_dict [stype .multicategorical ].offset )
148
159
assert x .shape == (feat_multicat .size (0 ), feat_multicat .size (1 ), 8 )
149
160
150
161
# Perturb the first column
@@ -178,9 +189,12 @@ def test_timestamp_feature_encoder(encoder_cls_kwargs):
178
189
stype = stype .timestamp ,
179
190
** encoder_cls_kwargs [1 ],
180
191
)
181
- feat_timestamp = tensor_frame .feat_dict [stype .timestamp ]
192
+ feat_timestamp = tensor_frame .feat_dict [stype .timestamp ]. clone ()
182
193
col_names = tensor_frame .col_names_dict [stype .timestamp ]
183
194
x = encoder (feat_timestamp , col_names )
195
+ # Make sure no in-place modification
196
+ assert torch .allclose (feat_timestamp ,
197
+ tensor_frame .feat_dict [stype .timestamp ])
184
198
assert x .shape == (feat_timestamp .size (0 ), feat_timestamp .size (1 ), 8 )
185
199
186
200
@@ -324,40 +338,6 @@ def test_timestamp_feature_encoder_with_nan(encoder_cls_kwargs):
324
338
assert (~ torch .isnan (x )).all ()
325
339
326
340
327
- def test_text_embedded_encoder ():
328
- num_rows = 20
329
- text_emb_channels = 10
330
- out_channels = 5
331
- dataset = FakeDataset (
332
- num_rows = num_rows ,
333
- stypes = [
334
- torch_frame .text_embedded ,
335
- ],
336
- col_to_text_embedder_cfg = TextEmbedderConfig (
337
- text_embedder = HashTextEmbedder (text_emb_channels ),
338
- batch_size = None ),
339
- )
340
- dataset .materialize ()
341
- tensor_frame = dataset .tensor_frame
342
- stats_list = [
343
- dataset .col_stats [col_name ]
344
- for col_name in tensor_frame .col_names_dict [stype .embedding ]
345
- ]
346
- encoder = LinearEmbeddingEncoder (
347
- out_channels = out_channels ,
348
- stats_list = stats_list ,
349
- stype = stype .embedding ,
350
- )
351
- feat_text = tensor_frame .feat_dict [stype .embedding ]
352
- col_names = tensor_frame .col_names_dict [stype .embedding ]
353
- feat = encoder (feat_text , col_names )
354
- assert feat .shape == (
355
- num_rows ,
356
- len (tensor_frame .col_names_dict [stype .embedding ]),
357
- out_channels ,
358
- )
359
-
360
-
361
341
def test_embedding_encoder ():
362
342
num_rows = 20
363
343
out_channels = 5
@@ -378,9 +358,14 @@ def test_embedding_encoder():
378
358
stats_list = stats_list ,
379
359
stype = stype .embedding ,
380
360
)
381
- feat_text = tensor_frame .feat_dict [stype .embedding ]
361
+ feat_emb = tensor_frame .feat_dict [stype .embedding ]. clone ()
382
362
col_names = tensor_frame .col_names_dict [stype .embedding ]
383
- x = encoder (feat_text , col_names )
363
+ x = encoder (feat_emb , col_names )
364
+ # Make sure no in-place modification
365
+ assert torch .allclose (feat_emb .values ,
366
+ tensor_frame .feat_dict [stype .embedding ].values )
367
+ assert torch .allclose (feat_emb .offset ,
368
+ tensor_frame .feat_dict [stype .embedding ].offset )
384
369
assert x .shape == (
385
370
num_rows ,
386
371
len (tensor_frame .col_names_dict [stype .embedding ]),
@@ -421,11 +406,23 @@ def test_text_tokenized_encoder():
421
406
stype = stype .text_tokenized ,
422
407
col_to_model_cfg = col_to_model_cfg ,
423
408
)
424
- feat_text = tensor_frame .feat_dict [stype .text_tokenized ]
409
+ feat_text = copy . deepcopy ( tensor_frame .feat_dict [stype .text_tokenized ])
425
410
col_names = tensor_frame .col_names_dict [stype .text_tokenized ]
426
411
x = encoder (feat_text , col_names )
427
412
assert x .shape == (
428
413
num_rows ,
429
414
len (tensor_frame .col_names_dict [stype .text_tokenized ]),
430
415
out_channels ,
431
416
)
417
+ # Make sure no in-place modification
418
+ assert isinstance (feat_text , dict ) and isinstance (
419
+ tensor_frame .feat_dict [stype .text_tokenized ], dict )
420
+ assert feat_text .keys () == tensor_frame .feat_dict [
421
+ stype .text_tokenized ].keys ()
422
+ for key in feat_text .keys ():
423
+ assert torch .allclose (
424
+ feat_text [key ].values ,
425
+ tensor_frame .feat_dict [stype .text_tokenized ][key ].values )
426
+ assert torch .allclose (
427
+ feat_text [key ].offset ,
428
+ tensor_frame .feat_dict [stype .text_tokenized ][key ].offset )
0 commit comments