From 1e078ffa256a56eb3d92681677f84d49a2b6bf31 Mon Sep 17 00:00:00 2001
From: Craig Kolb
Date: Fri, 13 Jun 2025 17:20:26 -0700
Subject: [PATCH] Add LinearLayer.set_weights() and .set_biases()

---
Review notes (ignored by git-am, not part of the commit message):
- set_weights() now coerces its input to the layer dtype and a C-contiguous
  layout before converting/uploading it.
- set_biases() reports the expected shape as a 1-tuple in its error message.
- Both setters gained docstrings and an explicit "-> None" return annotation.
- The post-image blob id on the "index" line predates these review edits.

 .../neuralnetworks/components/LinearLayer.py  | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/experiments/neuralnetwork/neuralnetworks/components/LinearLayer.py b/experiments/neuralnetwork/neuralnetworks/components/LinearLayer.py
index 9cdd5e1..b0e4dae 100644
--- a/experiments/neuralnetwork/neuralnetworks/components/LinearLayer.py
+++ b/experiments/neuralnetwork/neuralnetworks/components/LinearLayer.py
@@ -33,6 +33,32 @@ def __init__(self, num_inputs: AutoSettable[int], num_outputs: int, dtype: AutoS
         self._dtype = dtype
         self._use_coopvec = use_coopvec
 
+    def set_weights(self, weights_np: np.ndarray[Any, Any]) -> None:
+        """Upload a (num_outputs, num_inputs) weight matrix into this layer.
+
+        When use_coopvec is set, the matrix is first converted on the host to
+        the CoopVecMatrixLayout.training_optimal layout and that is uploaded.
+
+        :raises ValueError: if ``weights_np`` does not have the expected shape.
+        """
+        self.check_initialized()
+        if weights_np.shape != (self.num_outputs, self.num_inputs):
+            raise ValueError(f"LinearLayer weights must have shape ({self.num_outputs}, {self.num_inputs}), rather than {weights_np.shape}")
+
+        # Normalise to the layer's element type and a C-contiguous layout; the
+        # staging buffer below is allocated with the same dtype.
+        weights_np = np.ascontiguousarray(weights_np, dtype=self.dtype.numpy())
+
+        if self.use_coopvec:
+            layout = CoopVecMatrixLayout.training_optimal
+            desc = self.weights.device.coopvec_create_matrix_desc(self.num_outputs, self.num_inputs, layout, self.dtype.sgl(), 0)
+            weight_count = desc.size // self.dtype.size()
+
+            params_np = np.zeros((weight_count, ), dtype=self.dtype.numpy())
+            self.weights.device.coopvec_convert_matrix_host(weights_np, params_np, dst_layout=layout)
+            self.weights.storage.copy_from_numpy(params_np)
+        else:
+            self.weights.storage.copy_from_numpy(weights_np)
+
     def get_weights(self) -> np.ndarray[Any, Any]:
         self.check_initialized()
         weights_np = self.weights.to_numpy()
@@ -48,6 +74,17 @@ def get_weights(self) -> np.ndarray[Any, Any]:
         else:
             return weights_np
 
+    def set_biases(self, biases_np: np.ndarray[Any, Any]) -> None:
+        """Upload a (num_outputs,) bias vector into this layer.
+
+        :raises ValueError: if ``biases_np`` does not have the expected shape.
+        """
+        self.check_initialized()
+        if biases_np.shape != (self.num_outputs,):
+            raise ValueError(f"LinearLayer biases must have shape ({self.num_outputs},), rather than {biases_np.shape}")
+
+        self.biases.storage.copy_from_numpy(biases_np)
+
     def get_biases(self) -> np.ndarray[Any, Any]:
         self.check_initialized()
         return self.biases.to_numpy()