From 259ab72c775d6c974a160ac2de97fb72e596bdd8 Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Tue, 22 Dec 2015 16:18:06 +0100 Subject: [PATCH] Bugfix: brainstorm.tools.extract crashed when a mask was provided by the data iterator. Also reimplemented extract_and_save in terms of extract (fixes the very same bug). --- brainstorm/tests/test_tools.py | 81 ++++++++++++++++++++++++++++++++++ brainstorm/tools.py | 33 ++++---------- 2 files changed, 89 insertions(+), 25 deletions(-) create mode 100644 brainstorm/tests/test_tools.py diff --git a/brainstorm/tests/test_tools.py b/brainstorm/tests/test_tools.py new file mode 100644 index 0000000..101d019 --- /dev/null +++ b/brainstorm/tests/test_tools.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# coding=utf-8 + +from __future__ import division, print_function, unicode_literals + +import numpy as np +import pytest + +from brainstorm import Network +from brainstorm.handlers import NumpyHandler +from brainstorm.tools import extract +from brainstorm.data_iterators import Minibatches +from brainstorm.initializers import Gaussian +from brainstorm.layers import Input,FullyConnected + +def get_simple_net(): + inp = Input(out_shapes={'default': ('T', 'B', 4)}) + out = FullyConnected(2, name='Output', activation='tanh') + simple_net = Network.from_layer(inp >> out) + return simple_net + +def get_simple_net_with_mask(): + inp = Input(out_shapes={'default': ('T', 'B', 4),'mask': ('T','B',1)}) + out = FullyConnected(2, name='Output', activation='tanh') + simple_net = Network.from_layer(inp >> out) + return simple_net + +def test_extract(): + net = get_simple_net() + net.initialize(Gaussian(0.1)) + batch_size = 6 + input_data = np.random.rand(10,6,4).astype(np.float32) + + # compute expected result + layer_W = net.buffer['Output']['parameters']['W'] + layer_bias = net.buffer['Output']['parameters']['bias'] + + expected_result = np.tanh(np.dot(input_data,layer_W.T) + layer_bias) + + # run extract + data_iterator = Minibatches(batch_size,default=input_data) + + extracted_data = extract(net,data_iterator,['Output.outputs.default']) + + assert expected_result.shape == extracted_data['Output.outputs.default'].shape + assert np.allclose(expected_result,extracted_data['Output.outputs.default']) + +def test_extract_with_mask(): + net = get_simple_net_with_mask() +# pytest.set_trace() + net.initialize(Gaussian(0.1)) + batch_size = 6 + input_data = np.random.rand(10,6,4).astype(np.float32) + # set some mask + input_mask = np.zeros((10,6,1),dtype=np.float32) + input_mask[0:5,0,0] = 1 + input_mask[0:3,0,0] = 1 + input_mask[0:8,0,0] = 1 + input_mask[0:7,0,0] = 1 + input_mask[0:2,0,0] = 1 + input_mask[0:3,0,0] = 1 + + # compute expected result WITHOUT mask + layer_W = net.buffer['Output']['parameters']['W'] + layer_bias = net.buffer['Output']['parameters']['bias'] + + expected_result = np.tanh(np.dot(input_data,layer_W.T) + layer_bias) + + # run extract + data_iterator = Minibatches(batch_size,default=input_data,mask=input_mask) + + extracted_data = extract(net,data_iterator,['Output.outputs.default']) + +# pytest.set_trace() + assert expected_result.shape == extracted_data['Output.outputs.default'].shape + # where the mask is 0, we don't care for the result + assert np.allclose(expected_result[input_mask.astype(bool)],extracted_data['Output.outputs.default'][input_mask.astype(bool)]) + + + + diff --git a/brainstorm/tools.py b/brainstorm/tools.py index fbfb1ed..0bcc0dc 100644 --- a/brainstorm/tools.py +++ b/brainstorm/tools.py @@ -103,6 +103,7 @@ def extract(network, iter, buffer_names): if isinstance(buffer_names, six.string_types): buffer_names = [buffer_names] + time_steps = iter.data_shapes.values()[0][0] nr_examples = iter.data_shapes.values()[0][1] return_data = {} nr_items = 0 @@ -114,9 +115,9 @@ def extract(network, iter, buffer_names): if num == 0: nr_items += data.shape[1] if first_pass: - data_shape = (data.shape[0], nr_examples) + data.shape[2:] + data_shape = (time_steps, nr_examples) + data.shape[2:] return_data[buffer_name] = np.zeros(data_shape, data.dtype) - return_data[buffer_name][:, nr_items - data.shape[1]:nr_items] = \ + return_data[buffer_name][0:data.shape[0], nr_items - data.shape[1]:nr_items] = \ data return return_data @@ -149,32 +150,14 @@ def extract_and_save(network, iter, buffer_names, file_name): Name of the hdf5 file (including extension) in which the features should be saved. """ - iterator = iter(handler=network.handler) - if isinstance(buffer_names, six.string_types): - buffer_names = [buffer_names] - nr_items = 0 - ds = [] - + + extracted_data = extract(network,iter,buffer_names) with h5py.File(file_name, 'w') as f: f.attrs.create('info', get_brainstorm_info()) f.attrs.create('format', b'Buffers file v1.0') - - for _ in run_network(network, iterator, all_inputs=False): - network.forward_pass() - first_pass = False if len(ds) > 0 else True - for num, buffer_name in enumerate(buffer_names): - data = network.get(buffer_name) - if num == 0: - nr_items += data.shape[1] - if first_pass: - ds.append(f.create_dataset( - buffer_name, data.shape, data.dtype, chunks=data.shape, - maxshape=(data.shape[0], None) + data.shape[2:])) - ds[-1][:] = data - else: - ds[num].resize(size=nr_items, axis=1) - ds[num][:, nr_items - data.shape[1]:nr_items, ...] = data - + + for buffer_name in buffer_names: + f.create_dataset(buffer_name,data=extracted_data[buffer_name]) def get_in_out_layers(task_type, in_shape, out_shape, data_name='default', targets_name='targets', projection_name=None,