1 file changed: +14 −11 lines
@@ -1,9 +1,7 @@
 from typing import Any, TypeVar, cast

-import datasets
 import numpy as np
 import torch
-from datasets.table import table_iter
 from torch import Tensor
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -48,15 +46,20 @@ def load_tokenized_data(

     tokens = tokens_ds["input_ids"]

-    if isinstance(tokens, datasets.Column):
-        tokens = torch.cat(
-            [
-                torch.from_numpy(np.stack(table_chunk["input_ids"].to_numpy(), axis=0))
-                for table_chunk in table_iter(
-                    tokens.source._data, convert_to_tensor_chunk_size
-                )
-            ]
-        )
+    try:
+        from datasets import Column
+        if isinstance(tokens, Column):
+            from datasets.table import table_iter
+            tokens = torch.cat(
+                [
+                    torch.from_numpy(np.stack(table_chunk["input_ids"].to_numpy(), axis=0))
+                    for table_chunk in table_iter(
+                        tokens.source._data, convert_to_tensor_chunk_size
+                    )
+                ]
+            )
+    except ImportError:
+        assert len(tokens.shape) == 2

     return tokens
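For context, here is a minimal, self-contained sketch of the logic this patch leaves in place. The helper name to_token_tensor and the chunk_size default are hypothetical stand-ins; in the PR the code lives inside load_tokenized_data and uses its convert_to_tensor_chunk_size parameter. The lazy imports keep the module importable against datasets releases that predate Column, and the chunked Arrow iteration avoids materializing the whole token column at once.

import numpy as np
import torch


def to_token_tensor(tokens, chunk_size: int = 1000) -> torch.Tensor:
    # Convert the "input_ids" column of a tokenized dataset into a 2-D
    # (num_sequences, seq_len) tensor.
    try:
        # Column is only present in newer datasets releases, hence the lazy import.
        from datasets import Column

        if isinstance(tokens, Column):
            from datasets.table import table_iter

            # Walk the backing Arrow table in fixed-size chunks so the whole
            # column is never materialized as Python lists in one go.
            tokens = torch.cat(
                [
                    torch.from_numpy(
                        np.stack(chunk["input_ids"].to_numpy(), axis=0)
                    )
                    for chunk in table_iter(tokens.source._data, chunk_size)
                ]
            )
    except ImportError:
        # Older datasets versions return the column eagerly; it should already
        # be a 2-D array/tensor of token ids.
        assert len(tokens.shape) == 2
    return tokens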