-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathdata_utils.py
More file actions
68 lines (49 loc) · 1.89 KB
/
data_utils.py
File metadata and controls
68 lines (49 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gc
import os
import time
import math
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from functools import partial
from collections import Counter
from typing import Dict, List, Optional, Tuple, Callable, Union
class DatasetLoader:
    """Iterator that yields column-oriented batches from a list of row dicts.

    Each batch is a dict mapping column name -> list of (optionally mapped)
    values, with at most ``batch_size`` entries per column. When the dataset
    is exhausted the underlying list is cleared and freed.
    """

    def __init__(self, batch_size: int, dataset: list, columns: list = None,
                 maps: dict = None, drop_last: bool = True):
        """
        Args:
            batch_size: number of rows per yielded batch.
            dataset: list of row dicts (each row maps column name -> value).
            columns: columns to emit; defaults to the union of keys seen in
                ``dataset``, in first-seen order (deterministic).
            maps: optional per-column transform functions; columns without an
                entry pass values through unchanged. (None default avoids the
                shared-mutable-default pitfall.)
            drop_last: if True (the original behavior), a trailing batch
                smaller than ``batch_size`` is discarded; if False, the final
                partial batch is returned.
        """
        self.batch_size = batch_size
        self.dataset = dataset
        if columns is None:
            # dict.fromkeys dedupes while preserving first-seen order, unlike
            # a set, whose iteration order varies run-to-run.
            columns = list(dict.fromkeys(k for row in dataset for k in row))
        self.columns = columns
        maps = maps if maps is not None else {}
        # Fill in identity transforms so __next__ can apply maps unconditionally.
        self.maps = {key: maps.get(key, lambda x: x) for key in self.columns}
        self.drop_last = drop_last
        self.n = 0
        self.max_n = len(self.dataset)

    def __iter__(self):
        return self

    def __next__(self):
        if self.n >= self.max_n:
            self.cleanup()
            raise StopIteration
        outputs = {key: [] for key in self.columns}
        for _ in range(self.batch_size):
            if self.n >= self.max_n:
                if self.drop_last:
                    # Original semantics: an incomplete batch is dropped.
                    self.cleanup()
                    raise StopIteration
                break  # drop_last=False: return the partial batch collected so far
            row = self.dataset[self.n]
            for key in self.columns:
                outputs[key].append(self.maps[key](row[key]))
            self.n += 1
        return outputs

    def cleanup(self):
        """Release the backing dataset list and trigger a GC pass."""
        if hasattr(self, 'dataset') and self.dataset is not None:
            if hasattr(self.dataset, 'clear'):
                self.dataset.clear()
            self.dataset = None
            gc.collect()

    def __del__(self):
        # Best-effort release if the loader is discarded before exhaustion.
        self.cleanup()
def load_parquet(base_path, dataset_name, batch_size: int, columns: list, maps: dict = None):
    """Read a parquet file into memory and wrap it in a DatasetLoader.

    Args:
        base_path: path prefix; NOTE(review): concatenated directly with
            ``dataset_name`` (no separator inserted), so callers must include
            a trailing '/' — confirm before switching to os.path.join.
        dataset_name: file name (or suffix) appended to ``base_path``.
        batch_size: batch size forwarded to DatasetLoader.
        columns: parquet columns to read and emit.
        maps: optional per-column transform functions (None avoids the
            shared-mutable-default pitfall).

    Returns:
        (DatasetLoader over the file's rows, number of rows read).
    """
    if maps is None:
        maps = {}
    full_path = f"{base_path}{dataset_name}"
    table = pq.read_table(full_path, columns=columns)
    buffer = table.to_pylist()
    # Drop the Arrow table and hand its memory back before building the loader.
    del table
    try:
        pa.default_memory_pool().release_unused()
    except Exception:
        # Best-effort release only; narrowed from a bare except so that
        # KeyboardInterrupt/SystemExit still propagate.
        pass
    gc.collect()
    return DatasetLoader(batch_size, buffer, columns, maps), len(buffer)