Skip to content

Commit 6ae8350

Browse files
committed
Make it work with datasets<4.0
1 parent 4fc9ba6 commit 6ae8350

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

delphi/utils.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
11
from typing import Any, TypeVar, cast
22

3-
import datasets
43
import numpy as np
54
import torch
6-
from datasets.table import table_iter
75
from torch import Tensor
86
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
97

@@ -48,15 +46,20 @@ def load_tokenized_data(
4846

4947
tokens = tokens_ds["input_ids"]
5048

51-
if isinstance(tokens, datasets.Column):
52-
tokens = torch.cat(
53-
[
54-
torch.from_numpy(np.stack(table_chunk["input_ids"].to_numpy(), axis=0))
55-
for table_chunk in table_iter(
56-
tokens.source._data, convert_to_tensor_chunk_size
57-
)
58-
]
59-
)
49+
try:
50+
from datasets import Column
51+
if isinstance(tokens, Column):
52+
from datasets.table import table_iter
53+
tokens = torch.cat(
54+
[
55+
torch.from_numpy(np.stack(table_chunk["input_ids"].to_numpy(), axis=0))
56+
for table_chunk in table_iter(
57+
tokens.source._data, convert_to_tensor_chunk_size
58+
)
59+
]
60+
)
61+
except ImportError:
62+
assert len(tokens.shape) == 2
6063

6164
return tokens
6265

0 commit comments

Comments
 (0)