README.md: 2 changes (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ try net.addReLU();
try net.addLinear(3);
try net.addSoftmax();

try net.train(300, 0.01, features, labels);
try net.train(300, 0.01, 32, features, labels);
```

## Development
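Reviewer note: in the updated README example, the new third argument (`32`) is the mini-batch size; the first two arguments remain the number of epochs and the learning rate, matching the extended `train` signature in src/net.zig below.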
src/layer/linear.zig: 36 changes (27 additions, 9 deletions)
@@ -93,15 +93,31 @@ pub fn Linear(comptime T: type) type {
return self.activations;
}

/// Zero out accumulated gradients. Call this at the start of each batch.
pub fn zeroGradients(self: *Self) void {
self.weights_grad.zeros();
self.biases_grad.zeros();
}

/// Compute input, weight and bias gradients given upstream gradient of
/// the followup layers.
/// the followup layers. Gradients are accumulated (added to existing).
pub fn backward(self: *Self, input: Matrix(T), err_grad: Matrix(T)) Matrix(T) {
// dC/db = err_grad
self.biases_grad.copy(err_grad);
// dC/db = err_grad (accumulate)
for (self.biases_grad.elements, err_grad.elements) |*bg, eg| {
bg.* += eg;
}

// dC/dw = input^T @ err_grad
// dC/dw = input^T @ err_grad (accumulate)
input.transpose(&self.inputs_t);
self.inputs_t.multiply(err_grad, &self.weights_grad);
for (0..self.inputs_t.rows) |i| {
for (0..err_grad.columns) |j| {
var v: T = 0;
for (0..self.inputs_t.columns) |k| {
v += self.inputs_t.get(i, k) * err_grad.get(k, j);
}
self.weights_grad.set(i, j, self.weights_grad.get(i, j) + v);
}
}

// dC/di = err_grad @ weights^T
self.weights.transpose(&self.weights_t);
@@ -110,13 +126,15 @@ pub fn Linear(comptime T: type) type {
return self.inputs_grad;
}

// Apply weight and bias gradients to layer.
pub fn applyGradients(self: *Self, learning_rate: f32) void {
/// Apply accumulated weight and bias gradients to layer, scaled by
/// learning rate and batch size.
pub fn applyGradients(self: *Self, learning_rate: f32, batch_size: usize) void {
const scale = learning_rate / @as(f32, @floatFromInt(batch_size));
for (self.weights.elements, self.weights_grad.elements) |*w, g| {
w.* += g * learning_rate;
w.* += g * scale;
}
for (self.biases.elements, self.biases_grad.elements) |*b, g| {
b.* += g * learning_rate;
b.* += g * scale;
}
}
};
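Reviewer note: taken together, the new `zeroGradients`, the accumulating `backward`, and the rescaled `applyGradients` amount to one averaged update per batch. In my own notation (not from the code), with g_i the gradient contributed by sample i, B the batch size, and η the learning rate, and with the sign following the existing `w.* += g * scale` convention:

```latex
G \;=\; \sum_{i=1}^{B} g_i
\qquad\qquad
w \;\leftarrow\; w \;+\; \frac{\eta}{B}\, G
```

Dividing by B keeps the effective step size independent of the batch size, which is why `applyGradients` now needs the `batch_size` argument.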
src/main.zig: 2 changes (1 addition, 1 deletion)
@@ -35,7 +35,7 @@ pub fn main() !void {
try net.addLinear(3);
try net.addSoftmax();

try net.train(300, 0.01, X_train, y_train, stdout);
try net.train(300, 0.01, 32, X_train, y_train, stdout);

const predictions = try net.predictBatch(X_test);
const acc = accuracy(f32, predictions, y_test);
src/net.zig: 60 changes (45 additions, 15 deletions)
@@ -102,8 +102,16 @@ pub fn Network(comptime T: type) type {
return predictions;
}

/// Propagate gradient through layers and adjust parameters.
pub fn backward(self: Self, input: Matrix(T), grad: Matrix(T), learning_rate: f32) void {
/// Zero gradients in all linear layers. Call at start of each batch.
pub fn zeroGradients(self: Self) void {
for (self.layers.items) |*layer| {
if (layer.* == .linear)
layer.linear.zeroGradients();
}
}

/// Propagate gradient through layers (accumulates gradients).
pub fn backward(self: Self, input: Matrix(T), grad: Matrix(T)) void {
var err_grad = grad;
var i = self.layers.items.len;
while (i > 0) : (i -= 1) {
@@ -113,17 +121,23 @@
layer.backward(self.layers.items[i - 2].activation(), err_grad)
else
layer.backward(input, err_grad);
}
}

if (layer == .linear)
layer.linear.applyGradients(learning_rate);
/// Apply accumulated gradients to all linear layers.
pub fn applyGradients(self: Self, learning_rate: f32, batch_size: usize) void {
for (self.layers.items) |*layer| {
if (layer.* == .linear)
layer.linear.applyGradients(learning_rate, batch_size);
}
}

/// Train the network for fixed number of epochs.
pub fn train(self: Self, epochs: usize, learning_rate: f32, input: Matrix(T), labels: Matrix(T), writer: *std.Io.Writer) !void {
/// Train the network for fixed number of epochs using mini-batch gradient descent.
pub fn train(self: Self, epochs: usize, learning_rate: f32, batch_size: usize, input: Matrix(T), labels: Matrix(T), writer: *std.Io.Writer) !void {
assert(input.rows == labels.rows);
assert(input.columns == self.inputs);
assert(labels.columns == self.outputs);
assert(batch_size > 0);

// TODO Make loss function configurable
var cost_fn = try MeanSquaredError(f32).init(self.allocator, self.outputs);
@@ -135,14 +149,30 @@
for (0..epochs) |e| {
var loss_per_epoch: f32 = 0;

for (0..num_samples) |r| {
const X = input.getRow(r);
const y = labels.getRow(r);
const prediction = self.predict(X);
loss_per_epoch += cost_fn.computeLoss(prediction, y);
// Process data in mini-batches
var batch_start: usize = 0;
while (batch_start < num_samples) {
const batch_end = @min(batch_start + batch_size, num_samples);
const current_batch_size = batch_end - batch_start;

// Zero gradients at start of batch
self.zeroGradients();

// Accumulate gradients over batch
for (batch_start..batch_end) |r| {
const X = input.getRow(r);
const y = labels.getRow(r);
const prediction = self.predict(X);
loss_per_epoch += cost_fn.computeLoss(prediction, y);

const err_grad = cost_fn.computeGradient(prediction, y);
self.backward(X, err_grad);
}

// Apply accumulated gradients once per batch
self.applyGradients(learning_rate, current_batch_size);

const err_grad = cost_fn.computeGradient(prediction, y);
self.backward(X, err_grad, learning_rate);
batch_start = batch_end;
}

const avg_loss_per_epoch = loss_per_epoch / @as(f32, @floatFromInt(num_samples));
Expand All @@ -151,7 +181,7 @@ pub fn Network(comptime T: type) type {
}

const elapsed_seconds: f32 = @as(f32, @floatFromInt(std.time.milliTimestamp() - start)) / 1000;
try writer.print("Training took {d:.2} seconds.\n", .{elapsed_seconds});
try writer.print("Training took {d:.2} seconds (batch_size={d}).\n", .{ elapsed_seconds, batch_size });
try writer.flush();
}
};
@@ -224,5 +254,5 @@ test "Train network" {
const labels = Matrix(f32).fromSlice(1, 4, &labels_data);

var discarding = std.Io.Writer.Discarding.init(&.{});
try net.train(40, 0.001, input, labels, &discarding.writer);
try net.train(40, 0.001, 1, input, labels, &discarding.writer);
}
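Reviewer note: a minimal, self-contained sketch (not part of this PR; sample count and batch size are made up) of the batch-partitioning pattern used by the new `train` loop. `@min` clamps the final batch, so a sample count that is not a multiple of `batch_size` still ends with a smaller last batch, and that actual size is what gets passed to `applyGradients`:

```zig
const std = @import("std");

// Sketch of the index bookkeeping in the mini-batch loop:
// zero gradients, accumulate over the batch, then apply once per batch.
pub fn main() void {
    const num_samples: usize = 10;
    const batch_size: usize = 4;

    var batch_start: usize = 0;
    while (batch_start < num_samples) {
        const batch_end = @min(batch_start + batch_size, num_samples);
        const current_batch_size = batch_end - batch_start;

        // net.zeroGradients() would go here.
        for (batch_start..batch_end) |r| {
            // predict row r, add its loss, then net.backward(...) to accumulate.
            _ = r;
        }
        // net.applyGradients(learning_rate, current_batch_size) would go here.

        std.debug.print("batch [{d}, {d}) size={d}\n", .{ batch_start, batch_end, current_batch_size });
        batch_start = batch_end;
    }
}
```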