Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ build: Changes to the build process or tools.

- Updated optimization tests to reflect line search changes
- Updated unit tests to reflect changes to synthetic data functions
- Updated unit tests to reflect changes to `Parameters.load()` method

### Build

Expand All @@ -50,6 +51,10 @@ build: Changes to the build process or tools.
- _Modified signature and tweaked almost all plotting functions (adjusted notebooks accordingly)_
- _Added plotting function to display histogram of prediction error (and added it to goodness of fit summary plots)_
- Changed name of `utils` to `utilities` and added `_sample.py` and `_finite_difference.py` modules
- Converted `NeuralNet.load(...)` to a classmethod such that `reloaded = NeuralNet.load("save_params.json")`
- _Previous pattern: `reloaded = NeuralNet(layer_sizes=[1, 2, 3]).load("save_params.json")`_
- Converted `Parameters.load(...)` to a classmethod such that `reloaded = Parameters.load("save_params.json")`
- _Previous pattern: `reloaded = Parameters(layer_sizes=[1, 2, 3]).load("save_params.json")`_

## v1.0.8 (2024-06-26)

Expand Down
29 changes: 11 additions & 18 deletions docs/examples/.ipynb_checkpoints/demo_1_sinusoid-checkpoint.ipynb

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions docs/examples/.ipynb_checkpoints/demo_2_rastrigin-checkpoint.ipynb

Large diffs are not rendered by default.

29 changes: 11 additions & 18 deletions docs/examples/demo_1_sinusoid.ipynb

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions docs/examples/demo_2_rastrigin.ipynb

Large diffs are not rendered by default.

17 changes: 12 additions & 5 deletions src/jenn/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# This work is licensed under the MIT License.

from pathlib import Path
from typing import Any
from typing import Any, Self

import numpy as np

Expand Down Expand Up @@ -230,7 +230,14 @@ def save(self, file: str | Path = "parameters.json") -> None:
"""Serialize parameters and save to JSON file."""
self.parameters.save(file)

def load(self, file: str | Path = "parameters.json") -> "NeuralNet":
"""Load previously saved parameters from json file."""
self.parameters.load(file)
return self
@classmethod
def load(cls, file: str | Path = "parameters.json") -> Self:
"""Load serialized parameters into a new NeuralNet instance."""
parameters = Parameters.load(file)
neural_net = cls(
layer_sizes=parameters.layer_sizes,
hidden_activation=parameters.hidden_activation,
output_activation=parameters.output_activation,
)
neural_net.parameters = parameters
return neural_net
54 changes: 37 additions & 17 deletions src/jenn/core/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Self

import jsonpointer
import jsonschema
Expand Down Expand Up @@ -333,31 +334,50 @@ def validate_parameters(self) -> None: # noqa: C901
raise ValueError(f"W[{i}] has the wrong shape (expected {(n, m)})")
m = n

def _deserialize(self, saved_parameters: bytes) -> None:
@classmethod
def _deserialize(cls, saved_parameters: bytes) -> dict:
"""Deserialize and apply saved parameters."""
params = orjson.loads(saved_parameters)
jsonschema.validate(params, SCHEMA)
self.W = [np.array(value) for value in params["W"]]
self.b = [np.array(value) for value in params["b"]]
self.a = params["a"]
self.mu_x = np.array(params["mu_x"])
self.mu_y = np.array(params["mu_y"])
self.sigma_x = np.array(params["sigma_x"])
self.sigma_y = np.array(params["sigma_y"])
self.layer_sizes = [W.shape[0] for W in self.W]
self.output_activation = self.a[-1]
self.hidden_activation = self.a[-2]
self.dW = [np.zeros(array.shape) for array in self.W]
self.db = [np.zeros(array.shape) for array in self.b]
self.validate_parameters()
return dict(
W=[np.array(value) for value in params["W"]],
b=[np.array(value) for value in params["b"]],
a=params["a"],
mu_x=np.array(params["mu_x"]),
mu_y=np.array(params["mu_y"]),
sigma_x=np.array(params["sigma_x"]),
sigma_y=np.array(params["sigma_y"]),
layer_sizes=[np.array(value).shape[0] for value in params["W"]],
output_activation=params["a"][-1],
hidden_activation=params["a"][-2],
dW=[np.zeros(np.array(value).shape) for value in params["W"]],
db=[np.zeros(np.array(value).shape) for value in params["b"]],
)

def save(self, binary_file: str | Path = "parameters.json") -> None:
"""Save parameters to specified json file."""
with Path(binary_file).open("wb") as file:
file.write(self._serialize())

def load(self, binary_file: str | Path = "parameters.json") -> None:
"""Load parameters from specified json file."""
@classmethod
def load(cls, binary_file: str | Path = "parameters.json") -> Self:
"""Load serialized parameters into a new Parameters instance.

:param binary_file: JSON file containing saved parameters
:type binary_file: str | Path
:return: a new instance of Parameters
:rtype: Parameters
"""
with Path(binary_file).open("rb") as file:
byte_stream = file.read()
self._deserialize(byte_stream)
attrs = cls._deserialize(byte_stream)
obj = cls(
layer_sizes=attrs["layer_sizes"],
hidden_activation=attrs["hidden_activation"],
output_activation=attrs["output_activation"],
)
obj.initialize()
for key, val in attrs.items():
setattr(obj, key, val)
obj.validate_parameters()
return obj
15 changes: 8 additions & 7 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TestSerialization:
"""Check that parameters can be saved and reloaded."""

@pytest.fixture
def params(self) -> jenn.core.parameters.Parameters:
def parameters(self) -> jenn.core.parameters.Parameters:
"""Return XOR parameters."""
parameters = jenn.core.parameters.Parameters(
layer_sizes=[2, 2, 1],
Expand All @@ -28,12 +28,13 @@ def params(self) -> jenn.core.parameters.Parameters:
parameters.W[2][:] = np.array([[1, -2]]) # layer 2
return parameters

def test_serialization(self, params: jenn.core.parameters.Parameters) -> None:
def test_serialization(self, parameters: jenn.core.parameters.Parameters) -> None:
"""Test that saved parameters can be reloaded into a new object."""
with tempfile.TemporaryDirectory() as tmpdirname:
tmpfile = pathlib.Path(tmpdirname) / "params.json"
params.save(tmpfile)
parameters = jenn.core.parameters.Parameters(params.layer_sizes)
assert params != parameters
parameters.load(tmpfile)
assert params == parameters
parameters.save(tmpfile)
new_instance = jenn.core.parameters.Parameters(parameters.layer_sizes)
new_instance.initialize()
assert parameters != new_instance
reloaded = jenn.core.parameters.Parameters.load(tmpfile)
assert parameters == reloaded