diff --git a/README.md b/README.md
index 1bd6a33..a304251 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ try net.addReLU();
 try net.addLinear(3);
 try net.addSoftmax();
 
-try net.train(300, 0.01, features, labels);
+try net.train(300, 0.01, 32, features, labels);
 ```
 
 ## Development
diff --git a/src/layer/linear.zig b/src/layer/linear.zig
index 9443d36..d326604 100644
--- a/src/layer/linear.zig
+++ b/src/layer/linear.zig
@@ -93,15 +93,31 @@ pub fn Linear(comptime T: type) type {
             return self.activations;
         }
 
+        /// Zero out accumulated gradients. Call this at the start of each batch.
+        pub fn zeroGradients(self: *Self) void {
+            self.weights_grad.zeros();
+            self.biases_grad.zeros();
+        }
+
         /// Compute input, weight and bias gradients given upstream gradient of
-        /// the followup layers.
+        /// the followup layers. Gradients are accumulated (added to existing).
         pub fn backward(self: *Self, input: Matrix(T), err_grad: Matrix(T)) Matrix(T) {
-            // dC/db = err_grad
-            self.biases_grad.copy(err_grad);
+            // dC/db = err_grad (accumulate)
+            for (self.biases_grad.elements, err_grad.elements) |*bg, eg| {
+                bg.* += eg;
+            }
 
-            // dC/dw = input^T @ err_grad
+            // dC/dw = input^T @ err_grad (accumulate)
             input.transpose(&self.inputs_t);
-            self.inputs_t.multiply(err_grad, &self.weights_grad);
+            for (0..self.inputs_t.rows) |i| {
+                for (0..err_grad.columns) |j| {
+                    var v: T = 0;
+                    for (0..self.inputs_t.columns) |k| {
+                        v += self.inputs_t.get(i, k) * err_grad.get(k, j);
+                    }
+                    self.weights_grad.set(i, j, self.weights_grad.get(i, j) + v);
+                }
+            }
 
             // dC/di = err_grad @ weights^T
             self.weights.transpose(&self.weights_t);
@@ -110,13 +126,15 @@ pub fn Linear(comptime T: type) type {
             return self.inputs_grad;
         }
 
-        // Apply weight and bias gradients to layer.
-        pub fn applyGradients(self: *Self, learning_rate: f32) void {
+        /// Apply accumulated weight and bias gradients to layer, scaled by
+        /// learning rate and batch size.
+        pub fn applyGradients(self: *Self, learning_rate: f32, batch_size: usize) void {
+            const scale = learning_rate / @as(f32, @floatFromInt(batch_size));
             for (self.weights.elements, self.weights_grad.elements) |*w, g| {
-                w.* += g * learning_rate;
+                w.* += g * scale;
             }
             for (self.biases.elements, self.biases_grad.elements) |*b, g| {
-                b.* += g * learning_rate;
+                b.* += g * scale;
             }
         }
     };
diff --git a/src/main.zig b/src/main.zig
index ad80e44..603cc10 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -35,7 +35,7 @@ pub fn main() !void {
     try net.addLinear(3);
     try net.addSoftmax();
 
-    try net.train(300, 0.01, X_train, y_train, stdout);
+    try net.train(300, 0.01, 32, X_train, y_train, stdout);
 
     const predictions = try net.predictBatch(X_test);
     const acc = accuracy(f32, predictions, y_test);
diff --git a/src/net.zig b/src/net.zig
index 6623d38..12d9621 100644
--- a/src/net.zig
+++ b/src/net.zig
@@ -102,8 +102,16 @@ pub fn Network(comptime T: type) type {
             return predictions;
         }
 
-        /// Propagate gradient through layers and adjust parameters.
-        pub fn backward(self: Self, input: Matrix(T), grad: Matrix(T), learning_rate: f32) void {
+        /// Zero gradients in all linear layers. Call at start of each batch.
+        pub fn zeroGradients(self: Self) void {
+            for (self.layers.items) |*layer| {
+                if (layer.* == .linear)
+                    layer.linear.zeroGradients();
+            }
+        }
+
+        /// Propagate gradient through layers (accumulates gradients).
+        pub fn backward(self: Self, input: Matrix(T), grad: Matrix(T)) void {
             var err_grad = grad;
             var i = self.layers.items.len;
             while (i > 0) : (i -= 1) {
@@ -113,17 +121,23 @@ pub fn Network(comptime T: type) type {
                     layer.backward(self.layers.items[i - 2].activation(), err_grad)
                 else
                     layer.backward(input, err_grad);
+            }
+        }
 
-                if (layer == .linear)
-                    layer.linear.applyGradients(learning_rate);
+        /// Apply accumulated gradients to all linear layers.
+        pub fn applyGradients(self: Self, learning_rate: f32, batch_size: usize) void {
+            for (self.layers.items) |*layer| {
+                if (layer.* == .linear)
+                    layer.linear.applyGradients(learning_rate, batch_size);
             }
         }
 
-        /// Train the network for fixed number of epochs.
-        pub fn train(self: Self, epochs: usize, learning_rate: f32, input: Matrix(T), labels: Matrix(T), writer: *std.Io.Writer) !void {
+        /// Train the network for fixed number of epochs using mini-batch gradient descent.
+        pub fn train(self: Self, epochs: usize, learning_rate: f32, batch_size: usize, input: Matrix(T), labels: Matrix(T), writer: *std.Io.Writer) !void {
             assert(input.rows == labels.rows);
             assert(input.columns == self.inputs);
             assert(labels.columns == self.outputs);
+            assert(batch_size > 0);
 
             // TODO Make loss function configurable
             var cost_fn = try MeanSquaredError(f32).init(self.allocator, self.outputs);
@@ -135,14 +149,30 @@ pub fn Network(comptime T: type) type {
             for (0..epochs) |e| {
                 var loss_per_epoch: f32 = 0;
 
-                for (0..num_samples) |r| {
-                    const X = input.getRow(r);
-                    const y = labels.getRow(r);
-                    const prediction = self.predict(X);
-                    loss_per_epoch += cost_fn.computeLoss(prediction, y);
+                // Process data in mini-batches
+                var batch_start: usize = 0;
+                while (batch_start < num_samples) {
+                    const batch_end = @min(batch_start + batch_size, num_samples);
+                    const current_batch_size = batch_end - batch_start;
+
+                    // Zero gradients at start of batch
+                    self.zeroGradients();
+
+                    // Accumulate gradients over batch
+                    for (batch_start..batch_end) |r| {
+                        const X = input.getRow(r);
+                        const y = labels.getRow(r);
+                        const prediction = self.predict(X);
+                        loss_per_epoch += cost_fn.computeLoss(prediction, y);
+
+                        const err_grad = cost_fn.computeGradient(prediction, y);
+                        self.backward(X, err_grad);
+                    }
+
+                    // Apply accumulated gradients once per batch
+                    self.applyGradients(learning_rate, current_batch_size);
 
-                    const err_grad = cost_fn.computeGradient(prediction, y);
-                    self.backward(X, err_grad, learning_rate);
+                    batch_start = batch_end;
                 }
 
                 const avg_loss_per_epoch = loss_per_epoch / @as(f32, @floatFromInt(num_samples));
@@ -151,7 +181,7 @@ pub fn Network(comptime T: type) type {
             }
 
             const elapsed_seconds: f32 = @as(f32, @floatFromInt(std.time.milliTimestamp() - start)) / 1000;
-            try writer.print("Training took {d:.2} seconds.\n", .{elapsed_seconds});
+            try writer.print("Training took {d:.2} seconds (batch_size={d}).\n", .{ elapsed_seconds, batch_size });
             try writer.flush();
         }
     };
@@ -224,5 +254,5 @@ test "Train network" {
     const labels = Matrix(f32).fromSlice(1, 4, &labels_data);
 
     var discarding = std.Io.Writer.Discarding.init(&.{});
-    try net.train(40, 0.001, input, labels, &discarding.writer);
+    try net.train(40, 0.001, 1, input, labels, &discarding.writer);
 }