Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions bcolz_array_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class BcolzArrayIterator(object):
y = bcolz.open('file_path/label_file.bc', mode='r')
trn_batches = BcolzArrayIterator(X, y, batch_size=64, shuffle=True)
model.fit_generator(generator=trn_batches, samples_per_epoch=trn_batches.N, nb_epoch=1)
:param X: Input features
:param X_ftrs: Array of input features
:param y: (optional) Input labels
:param w: (optional) Input feature weights
:param batch_size: (optional) Batch size, defaults to 32
Expand All @@ -31,7 +31,15 @@ class BcolzArrayIterator(object):
True
"""

def __init__(self, X, y=None, w=None, batch_size=32, shuffle=False, seed=None):
def __init__(self, X_ftrs, y=None, w=None, batch_size=32, shuffle=False, seed=None):
if isinstance(X_ftrs, bcolz.carray):
self.onefeature = True
X_ftrs = [X_ftrs]
else:
self.onefeature = False
for X in X_ftrs:
if X is None or len(X) != len(X_ftrs[0]):
raise ValueError('X (features) should have the same length')
if y is not None and len(X) != len(y):
raise ValueError('X (features) and y (labels) should have the same length'
'Found: X.shape = %s, y.shape = %s' % (X.shape, y.shape))
Expand All @@ -41,8 +49,9 @@ def __init__(self, X, y=None, w=None, batch_size=32, shuffle=False, seed=None):
if batch_size % X.chunklen != 0:
raise ValueError('batch_size needs to be a multiple of X.chunklen')

self.X_ftrs = X_ftrs
self.chunks_per_batch = batch_size // X.chunklen
self.X = X
self.nchunks = X.nchunks
self.y = y if y is not None else None
self.w = w if w is not None else None
self.N = X.shape[0]
Expand All @@ -62,37 +71,42 @@ def next(self):
if self.batch_index == 0:
if self.seed is not None:
np.random.seed(self.seed + self.total_batches_seen)
self.index_array = (np.random.permutation(self.X.nchunks + 1) if self.shuffle
else np.arange(self.X.nchunks + 1))
self.index_array = (np.random.permutation(self.nchunks + 1) if self.shuffle
else np.arange(self.nchunks + 1))

#batches_x = np.zeros((self.batch_size,)+self.X.shape[1:])
batches_x, batches_y, batches_w = [],[],[]
batches_x, batches_y, batches_w = [], [], []
for i in range(len(self.X_ftrs)):
batches_x.append([])
for i in range(self.chunks_per_batch):
current_index = self.index_array[self.batch_index]
if current_index == self.X.nchunks:
batches_x.append(self.X.leftover_array[:self.X.leftover_elements])
current_batch_size = self.X.leftover_elements
if current_index == self.nchunks:
for idx, X in enumerate(self.X_ftrs):
batches_x[idx].append(X.leftover_array[:X.leftover_elements])
current_batch_size = X.leftover_elements
else:
batches_x.append(self.X.chunks[current_index][:])
current_batch_size = self.X.chunklen
for idx, X in enumerate(self.X_ftrs):
batches_x[idx].append(X.chunks[current_index][:])
current_batch_size = X.chunklen
self.batch_index += 1
self.total_batches_seen += 1

idx = current_index * self.X.chunklen
idx = current_index * X.chunklen
if not self.y is None: batches_y.append(self.y[idx: idx + current_batch_size])
if not self.w is None: batches_w.append(self.w[idx: idx + current_batch_size])
if self.batch_index >= len(self.index_array):
self.batch_index = 0
break

batch_x = np.concatenate(batches_x)
if self.y is None: return batch_x
batches_x = [np.concatenate(b) for b in batches_x]
if self.onefeature:
batches_x = batches_x[0]
if self.y is None: return batches_x

batch_y = np.concatenate(batches_y)
if self.w is None: return batch_x, batch_y
if self.w is None: return batches_x, batch_y

batch_w = np.concatenate(batches_w)
return batch_x, batch_y, batch_w
return batches_x, batch_y, batch_w


def __iter__(self): return self
Expand Down