diff --git a/graph_playground.py b/graph_playground.py new file mode 100644 index 000000000..5b995feef --- /dev/null +++ b/graph_playground.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from graph import Grammar, mutations, parse, select, to_string + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + { + "s": (["a", "b", "p a", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + } +) + +root = parse(grammar_1, "s(p(s(a), a))") + +selections = list(select(root, how=("climb", range(1, 3)))) +mutants = mutations( + root=root, + grammar=grammar_1, + which=selections, + max_mutation_depth=3, +) +mutants = list(mutants) + +import rich + +rich.print("grammar", grammar_1) +rich.print("root", f"{to_string(root)}") +rich.print("selections", [to_string(s) for s in selections]) +rich.print("mutants", [to_string(m) for m in mutants]) diff --git a/neps/__init__.py b/neps/__init__.py index 756217609..408a33fc3 100644 --- a/neps/__init__.py +++ b/neps/__init__.py @@ -4,7 +4,7 @@ from neps.optimizers.optimizer import SampledConfig from neps.plot.plot import plot from neps.plot.tensorboard_eval import tblogger -from neps.space import Categorical, Constant, Float, Integer, SearchSpace +from neps.space import Categorical, Constant, Float, Grammar, Integer, SearchSpace from neps.state import BudgetInfo, Trial from neps.status.status import status from neps.utils.files import load_and_merge_yamls as load_yamls @@ -15,6 +15,7 @@ "Categorical", "Constant", "Float", + "Grammar", "Integer", "SampledConfig", "SearchSpace", diff --git a/neps/api.py b/neps/api.py index 77c8ebcf2..ae38ff59b 100644 --- a/neps/api.py +++ b/neps/api.py @@ -18,7 +18,7 @@ from ConfigSpace import ConfigurationSpace from neps.optimizers.algorithms import CustomOptimizer - from neps.space import Parameter, SearchSpace + from neps.space import Constant, Grammar, Parameter, SearchSpace from neps.state import EvaluatePipelineReturn logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ def run( # noqa: PLR0913 evaluate_pipeline: Callable[..., EvaluatePipelineReturn] | str, pipeline_space: ( - Mapping[str, dict | str | int | float | Parameter] + Mapping[str, dict | str | int | float | Parameter | Constant | Grammar] | SearchSpace | ConfigurationSpace ), diff --git a/neps/optimizers/algorithms.py b/neps/optimizers/algorithms.py index 6de8f67be..6be6ef2b2 100644 --- a/neps/optimizers/algorithms.py +++ b/neps/optimizers/algorithms.py @@ -82,7 +82,7 @@ def _bo( f" Got: {pipeline_space.fidelities}" ) - parameters = pipeline_space.searchables + parameters = {**pipeline_space.numerical, **pipeline_space.categoricals} match initial_design_size: case "ndim": @@ -126,9 +126,6 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 sampler: Literal["uniform", "prior", "priorband"] | PriorBandSampler | Sampler, bayesian_optimization_kick_in_point: int | float | None, sample_prior_first: bool | Literal["highest_fidelity"], - # NOTE: This is the only argument to get a default, since it - # is not required for hyperband style algorithms, only single bracket - # style ones. 
    early_stopping_rate: int | None,
    device: torch.device | None,
) -> BracketOptimizer:
@@ -183,7 +180,7 @@ def _bracket_optimizer(  # noqa: C901, PLR0912, PLR0915
     """
     assert pipeline_space.fidelity is not None
     fidelity_name, fidelity = pipeline_space.fidelity
-    parameters = pipeline_space.searchables
+    parameters = {**pipeline_space.numerical, **pipeline_space.categoricals}
 
     if len(pipeline_space.fidelities) != 1:
         raise ValueError(
@@ -324,9 +321,8 @@ def _bracket_optimizer(  # noqa: C901, PLR0912, PLR0915
 
 
 def determine_optimizer_automatically(space: SearchSpace) -> str:
-    has_prior = any(
-        parameter.prior is not None for parameter in space.searchables.values()
-    )
+    parameters = {**space.numerical, **space.categoricals}
+    has_prior = any(parameter.prior is not None for parameter in parameters.values())
     has_fidelity = len(space.fidelities) > 0
 
     match (has_prior, has_fidelity):
@@ -360,14 +356,18 @@ def random_search(
         In this case, the max fidelity is always used.
     """
     if ignore_fidelity:
-        parameters = pipeline_space.searchables
+        parameters = {**pipeline_space.numerical, **pipeline_space.categoricals}
     else:
-        parameters = {**pipeline_space.searchables, **pipeline_space.fidelities}
+        parameters = {
+            **pipeline_space.numerical,
+            **pipeline_space.categoricals,
+            **pipeline_space.fidelities,
+        }
 
     return RandomSearch(
         space=pipeline_space,
         encoder=ConfigEncoder.from_parameters(parameters),
-        sampler=(
+        numerical_sampler=(
             Prior.from_parameters(parameters)
             if use_priors
             else Uniform(ndim=len(parameters))
@@ -384,6 +384,9 @@ def grid_search(pipeline_space: SearchSpace) -> GridSearch:
     """
     from neps.optimizers.utils.grid import make_grid
 
+    if pipeline_space.grammar is not None:
+        raise NotImplementedError("Grammars not supported for `grid_search` yet.")
+
     return GridSearch(configs_list=make_grid(pipeline_space))
 
 
@@ -445,7 +448,7 @@ def ifbo(
     space, fid_bins = _adjust_space_to_match_stepsize(pipeline_space, step_size)
     assert space.fidelity is not None
     fidelity_name, fidelity = space.fidelity
-    parameters = space.searchables
+    parameters = {**space.numerical, **space.categoricals}
 
     match initial_design_size:
         case "ndim":
diff --git a/neps/optimizers/bayesian_optimization.py b/neps/optimizers/bayesian_optimization.py
index ec556803d..ac4dd1b59 100644
--- a/neps/optimizers/bayesian_optimization.py
+++ b/neps/optimizers/bayesian_optimization.py
@@ -86,7 +86,11 @@ def __call__(
         n: int | None = None,
     ) -> SampledConfig | list[SampledConfig]:
         assert self.space.fidelity is None, "Fidelity not supported yet."
-        parameters = self.space.searchables
+        parameters = {
+            **self.space.numerical,
+            **self.space.categoricals,
+            **self.space.grammars,
+        }
 
         n_to_sample = 1 if n is None else n
         n_sampled = len(trials)
diff --git a/neps/optimizers/bracket_optimizer.py b/neps/optimizers/bracket_optimizer.py
index d5317c0df..5279e0b68 100644
--- a/neps/optimizers/bracket_optimizer.py
+++ b/neps/optimizers/bracket_optimizer.py
@@ -249,6 +249,12 @@ class BracketOptimizer:
     fid_name: str
     """The name of the fidelity in the space."""
 
+    def __post_init__(self) -> None:
+        if self.space.grammar is not None:
+            raise NotImplementedError(
+                "Grammars not supported for `BracketOptimizer` yet."
+ ) + def __call__( # noqa: C901, PLR0912 self, trials: Mapping[str, Trial], @@ -257,7 +263,7 @@ def __call__( # noqa: C901, PLR0912 ) -> SampledConfig | list[SampledConfig]: assert n is None, "TODO" space = self.space - parameters = space.searchables + parameters = {**self.space.numerical, **self.space.categoricals} # If we have no trials, we either go with the prior or just a sampled config if len(trials) == 0: diff --git a/neps/optimizers/ifbo.py b/neps/optimizers/ifbo.py index 4e7d90726..416cd66ab 100755 --- a/neps/optimizers/ifbo.py +++ b/neps/optimizers/ifbo.py @@ -129,6 +129,10 @@ class IFBO: Each one will be treated as an individual fidelity level. """ + def __post_init__(self) -> None: + if self.space.grammar is not None: + raise NotImplementedError("Grammars not supported for `IFBO` yet.") + def __call__( self, trials: Mapping[str, Trial], @@ -137,7 +141,7 @@ def __call__( ) -> SampledConfig | list[SampledConfig]: assert self.space.fidelity is not None fidelity_name, fidelity = self.space.fidelity - parameters = self.space.searchables + parameters = {**self.space.numerical, **self.space.categoricals} assert n is None, "TODO" ids = [int(config_id.split("_", maxsplit=1)[0]) for config_id in trials] diff --git a/neps/cli/__init__.py b/neps/optimizers/models/graphs/__init__.py similarity index 100% rename from neps/cli/__init__.py rename to neps/optimizers/models/graphs/__init__.py diff --git a/neps/optimizers/models/graphs/context_managers.py b/neps/optimizers/models/graphs/context_managers.py new file mode 100644 index 000000000..56b643595 --- /dev/null +++ b/neps/optimizers/models/graphs/context_managers.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING + +from botorch.models import SingleTaskGP + +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, compute_kernel + +if TYPE_CHECKING: + import networkx as nx + from botorch.models.gp_regression_mixed import Kernel + + +@contextmanager +def set_graph_lookup( + kernel_or_gp: Kernel | SingleTaskGP, + new_graphs: list[nx.Graph], + *, + append: bool = True, +) -> Iterator[None]: + """Context manager to temporarily set the graph lookup for a kernel or GP model. + + Args: + kernel_or_gp (Kernel | SingleTaskGP): The kernel or GP model whose graph lookup is + to be set. + new_graphs (list[nx.Graph]): The new graphs to set in the graph lookup. + append (bool, optional): Whether to append the new graphs to the existing graph + lookup. Defaults to True. + """ + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + + # Determine the modules to update based on the input type + if isinstance(kernel_or_gp, SingleTaskGP): + modules = [ + k + for k in kernel_or_gp.covar_module.sub_kernels() + if isinstance(k, BoTorchWLKernel) + ] + elif isinstance(kernel_or_gp, BoTorchWLKernel): + modules = [kernel_or_gp] + else: + assert hasattr(kernel_or_gp, "sub_kernels"), ( + "Kernel module must have sub_kernels method." 
+ ) + modules = [ + k for k in kernel_or_gp.sub_kernels() if isinstance(k, BoTorchWLKernel) + ] + + # Save the current graph lookup and set the new graph lookup + for kern in modules: + compute_kernel.cache_clear() + + kernel_prev_graphs.append((kern, kern.graph_lookup)) + if append: + kern.set_graph_lookup([*kern.graph_lookup, *new_graphs]) + else: + kern.set_graph_lookup(new_graphs) + + yield + + # Restore the original graph lookup after the context manager exits + for kern, prev_graphs in kernel_prev_graphs: + kern.set_graph_lookup(prev_graphs) diff --git a/neps/optimizers/models/graphs/kernels.py b/neps/optimizers/models/graphs/kernels.py new file mode 100644 index 000000000..64f4a7f71 --- /dev/null +++ b/neps/optimizers/models/graphs/kernels.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +import torch +from botorch.models.gp_regression_mixed import Kernel +from torch import Tensor +from torch.nn import Module + +from neps.optimizers.models.graphs.utils import graphs_to_tensors + +if TYPE_CHECKING: + import networkx as nx + + +@lru_cache(maxsize=128) +def compute_kernel( + adjacency_cache: tuple[Tensor, ...], + label_cache: tuple[Tensor, ...], + indices1: tuple[int, ...], + indices2: tuple[int, ...], + n_iter: int, + *, + diag: bool, + normalize: bool, +) -> Tensor: + """Compute the kernel matrix. + + This function is defined outside the class to leverage the `lru_cache` decorator, + which caches the results of expensive function calls and reuses them when the same + inputs occur again. + + Args: + adjacency_cache: Tuple of adjacency matrices for the graphs. + label_cache: Tuple of initial node labels for the graphs. + indices1: Tuple of indices for the first set of graphs. + indices2: Tuple of indices for the second set of graphs. + n_iter: Number of WL iterations. + diag: Whether to return only the diagonal of the kernel matrix. + normalize: Whether to normalize the kernel matrix. + + Returns: + A Tensor representing the kernel matrix. + """ + all_graphs = list(set(indices1).union(indices2)) + adj_matrices = [adjacency_cache[i] for i in all_graphs] + label_tensors = [label_cache[i] for i in all_graphs] + + # Compute full kernel matrix + _kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + K_full = _kernel(adj_matrices, label_tensors) + + # Map indices to their positions in all_graphs + idx1 = [all_graphs.index(i) for i in indices1] + idx2 = [all_graphs.index(i) for i in indices2] + + # Extract the relevant submatrix + K = K_full[idx1][:, idx2] + + # Return the diagonal if requested + if diag: + return torch.diag(K) + + return K + + +class BoTorchWLKernel(Kernel): + """A custom kernel for Gaussian Processes using the Weisfeiler-Lehman (WL) algorithm. + + This kernel computes similarities between graphs based on their structural properties + using the WL algorithm. It is designed to be used with BoTorch and GPyTorch for + Gaussian Process regression. + + Args: + graph_lookup (list[nx.Graph]): List of NetworkX graphs. + n_iter (int, optional): Number of WL iterations to perform. Default is 5. + normalize (bool, optional): Whether to normalize the kernel matrix. + Default is True. + active_dims (tuple[int, ...]): Dimensions of the input to consider. + Not used in this kernel but included for compatibility with the base Kernel class. + **kwargs (Any): Additional arguments for the base Kernel class. + + Attributes: + graph_lookup (list[nx.Graph]): List of graphs used for kernel computation. 
+ n_iter (int): Number of WL iterations. + normalize (bool): Whether to normalize the kernel matrix. + adjacency_cache (list[Tensor]): Cached adjacency matrices of the graphs. + label_cache (list[Tensor]): Cached initial node labels of the graphs. + """ + + has_lengthscale = False + + def __init__( + self, + graph_lookup: list[nx.Graph], + n_iter: int = 5, + *, + normalize: bool = True, + active_dims: tuple[int, ...], + **kwargs: Any, + ) -> None: + super().__init__(active_dims=active_dims, **kwargs) + self.graph_lookup = graph_lookup + self.n_iter = n_iter + self.normalize = normalize + self._precompute_graph_data() + + def _precompute_graph_data(self) -> None: + """Precompute and cache adjacency matrices and initial node labels.""" + self.adjacency_cache, self.label_cache = graphs_to_tensors( + self.graph_lookup, device=self.device + ) + + def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None: + """Update the graph lookup and refresh the cached data.""" + self.graph_lookup = graph_lookup + self._precompute_graph_data() + + def forward( + self, + x1: Tensor, + x2: Tensor, + *, + diag: bool = False, + last_dim_is_batch: bool = False, + **params: Any, + ) -> Tensor: + """Compute kernel matrix containing pairwise similarities between graphs.""" + if last_dim_is_batch: + raise NotImplementedError("Batch dimension handling is not implemented.") + + if x1.ndim == 3: + return self._handle_batched_input(x1=x1, x2=x2, diag=diag) + + indices1, indices2 = self._prepare_indices(x1, x2) + + return compute_kernel( + adjacency_cache=tuple(self.adjacency_cache), + label_cache=tuple(self.label_cache), + indices1=tuple(indices1), + indices2=tuple(indices2), + n_iter=self.n_iter, + diag=diag, + normalize=self.normalize, + ) + + def _handle_batched_input(self, x1: Tensor, x2: Tensor, *, diag: bool) -> Tensor: + """Handle computation for batched input tensors.""" + q_dim_size = x1.shape[0] + assert x2.shape[0] == q_dim_size + + out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device) + for q in range(q_dim_size): + out[q] = self.forward(x1[q], x2[q], diag=diag) + return out + + def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]: + """Convert tensor indices to integer lists.""" + indices1 = x1.flatten().to(torch.int64).tolist() + indices2 = x2.flatten().to(torch.int64).tolist() + + # Check for missing graph indices (-1) and handle them + # Explanation: The index `-1` is used as a placeholder for "missing" or "invalid" + # graphs. This can occur when a graph feature is missing or undefined, such as + # during the exploration of new candidates where no corresponding graph is + # available in the `graph_lookup`. The kernel expects non-negative indices, so we + # need to convert `-1` to the index of the last graph in the lookup. + + # Use the last graph in the lookup as a placeholder + last_graph_idx = len(self.graph_lookup) - 1 + + if -1 in indices1: + # Replace any `-1` indices with the index of the last graph. + indices1 = [last_graph_idx if i == -1 else i for i in indices1] + + if -1 in indices2: + # Replace any `-1` indices with the index of the last graph. + indices2 = [last_graph_idx if i == -1 else i for i in indices2] + + return indices1, indices2 + + +class TorchWLKernel(Module): + """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch. + + The WL Kernel is a graph kernel that measures similarity between graphs based on + their structural properties. 
It works by iteratively updating node labels based on + their neighborhoods and computing feature vectors from label distributions. + + Args: + n_iter: Number of WL iterations to perform + normalize: bool, optional. Whether to normalize the kernel matrix + + Attributes: + device: torch.device for computation (CPU/GPU) + label_dict: Mapping from node labels to numerical indices + label_counter: Counter for generating new label indices + """ + + def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: + super().__init__() + self.n_iter = n_iter + self.normalize = normalize + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Keep track of labels across iterations + self.label_dict: dict[str, int] = {} + self.label_counter: int = 0 + + def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]: + """Extract neighborhood information from adjacency matrix.""" + if adj.layout == torch.sparse_csr: + adj = adj.to_sparse_coo() + + adj = adj.coalesce() + rows, cols = adj.indices() + num_nodes = adj.size(0) + + neighbors: list[list[int]] = [[] for _ in range(num_nodes)] + for row, col in zip(rows.tolist(), cols.tolist(), strict=False): + neighbors[row].append(col) + + return neighbors + + def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor: + """Perform one WL iteration.""" + if not self.label_dict: + # Start new labels after initial ones + self.label_counter = int(labels.max().item()) + 1 + + num_nodes = labels.size(0) + new_labels: list[int] = [] + neighbors = self._get_node_neighbors(adj) + + for node_idx in range(num_nodes): + # Get current node label + node_label = int(labels[node_idx].item()) + neighbor_labels = sorted([int(labels[n].item()) for n in neighbors[node_idx]]) + + credential = f"{node_label},{neighbor_labels}" + + # Update label dictionary + new_labels.append( + self.label_dict.setdefault(credential, len(self.label_dict)) + ) + + return torch.tensor(new_labels, dtype=torch.int64, device=self.device) + + def _compute_feature_vector(self, all_labels: list[list[Tensor]]) -> Tensor: + """Compute the histogram feature vector for all graphs.""" + batch_size = len(all_labels[0]) + features: list[Tensor] = [] + + for iteration_labels in all_labels: + # Find maximum label value across all graphs in this iteration + max_label = int(max(label.max().item() for label in iteration_labels)) + 1 + + iter_features = torch.zeros((batch_size, max_label), device=self.device) + + # Compute label frequencies + for graph_idx, labels in enumerate(iteration_labels): + counts = torch.bincount(labels, minlength=max_label) + iter_features[graph_idx] = counts + + features.append(iter_features) + + return torch.cat(features, dim=1) + + def forward(self, adj_matrices: list[Tensor], label_tensors: list[Tensor]) -> Tensor: + """Compute WL kernel matrix for a list of graphs. + + Args: + adj_matrices: Precomputed sparse adjacency matrices for graphs. + label_tensors: Precomputed node label tensors for graphs. + + Returns: + Kernel matrix containing pairwise graph similarities. 
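+
+        Example:
+            An illustrative sketch (not from the original code); the inputs
+            come from `graphs_to_tensors` in
+            `neps.optimizers.models.graphs.utils`:
+
+            ```python
+            import networkx as nx
+
+            from neps.optimizers.models.graphs.kernels import TorchWLKernel
+            from neps.optimizers.models.graphs.utils import graphs_to_tensors
+
+            graphs = [nx.path_graph(4), nx.cycle_graph(4), nx.star_graph(3)]
+            adjacency, labels = graphs_to_tensors(graphs)
+
+            kernel = TorchWLKernel(n_iter=3, normalize=True)
+            K = kernel(adjacency, labels)  # (3, 3) similarity matrix
+            # With normalize=True, the diagonal entries are all 1.0
+            ```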
+ """ + if len(adj_matrices) != len(label_tensors): + raise ValueError("Mismatch between adjacency matrices and label tensors.") + + # Reset label dictionary for new computation + self.label_dict = {} + # Store all label iterations + all_labels: list[list[Tensor]] = [label_tensors] + + # Perform WL iterations + for _ in range(self.n_iter): + new_labels = [ + self._wl_iteration(adj, labels) + for adj, labels in zip(adj_matrices, all_labels[-1], strict=False) + ] + all_labels.append(new_labels) + + # Compute feature vectors and kernel matrix (similarity matrix) + final_features = self._compute_feature_vector(all_labels) + kernel_matrix = torch.mm(final_features, final_features.t()) + + if self.normalize: + diag = torch.sqrt(torch.diag(kernel_matrix)) + kernel_matrix /= torch.outer(diag, diag) + + return kernel_matrix diff --git a/neps/optimizers/models/graphs/optimization.py b/neps/optimizers/models/graphs/optimization.py new file mode 100644 index 000000000..c036f5816 --- /dev/null +++ b/neps/optimizers/models/graphs/optimization.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from botorch.optim import optimize_acqf_mixed + +from neps.optimizers.models.graphs.context_managers import set_graph_lookup +from neps.optimizers.models.graphs.utils import sample_graphs + +if TYPE_CHECKING: + import networkx as nx + from botorch.acquisition import AcquisitionFunction + + +def optimize_acqf_graph( + acq_function: AcquisitionFunction, + bounds: torch.Tensor, + fixed_features_list: list[dict[int, int]] | None = None, + num_graph_samples: int = 10, + train_graphs: list[nx.Graph] | None = None, + num_restarts: int = 10, + raw_samples: int = 1024, + q: int = 1, +) -> tuple[torch.Tensor, nx.Graph, float]: + """Optimize an acquisition function with graph sampling. + + This function optimizes the acquisition function by sampling graphs from the training + set, temporarily updating the kernel's graph lookup, and evaluating the acquisition + function for each sampled graph. The best candidate, the best graph, and its + corresponding acquisition score are returned. + + Args: + acq_function (AcquisitionFunction): The acquisition function to optimize. + bounds (torch.Tensor): A 2 x d tensor of bounds for numerical and categorical + features, where d is the number of features. + fixed_features_list (list[dict[int, float]] | None): A list of dictionaries + specifying fixed categorical feature configurations. Each dictionary maps + feature indices to their fixed values. Defaults to None. + num_graph_samples (int): The number of graphs to sample from the training set. + Defaults to 10. + train_graphs (list[nx.Graph] | None): The original training graphs. If None, a + ValueError is raised. + num_restarts (int): The number of optimization restarts. Defaults to 10. + raw_samples (int): The number of raw samples to generate for optimization. + Defaults to 1024. + q (int): The number of candidates to generate. Defaults to 1. + + Returns: + tuple[torch.Tensor, nx.Graph, float]: A tuple containing the best candidate + (as a tensor), the best graph, and its corresponding acquisition score. + + Raises: + ValueError: If `train_graphs` is None. 
+ """ + if train_graphs is None: + raise ValueError("train_graphs cannot be None.") + + sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) + + best_candidates, best_graphs, best_scores = [], [], [] + + # Get the index of the graph feature in the bounds + graph_idx = bounds.shape[1] - 1 + + # Todo: Instead of iterating over the graphs, optimize by putting all + # sampled graphs into the kernel and compute the scores in a single batch. + # Update the caching logic accordingly. + for graph in sampled_graphs: + with set_graph_lookup(acq_function.model.covar_module, [graph], append=True): + # Iterate through each fixed feature configuration (if provided) + for fixed_features in fixed_features_list or [{}]: + # Add the graph index to the fixed features, indicating that the last + # graph in the lookup should be used + updated_fixed_features = {**fixed_features, graph_idx: -1.0} + + # Optimize the acquisition function with the updated fixed features + candidates, scores = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=[updated_fixed_features], + num_restarts=num_restarts, + raw_samples=raw_samples, + q=q, + ) + + # Store the candidates, graphs, and their scores + best_candidates.append(candidates) + best_graphs.append(graph) + best_scores.append(scores) + + # Find the index of the best score + best_idx = torch.argmax(torch.tensor(best_scores)) + + # Return the best candidate (without the graph index), the best graph, and its score + return ( + best_candidates[best_idx][:, :-1], + best_graphs[best_idx], + best_scores[best_idx].item(), + ) diff --git a/neps/optimizers/models/graphs/utils.py b/neps/optimizers/models/graphs/utils.py new file mode 100644 index 000000000..056921554 --- /dev/null +++ b/neps/optimizers/models/graphs/utils.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import random + +import networkx as nx +import numpy as np +import torch + + +def seed_all(seed: int = 100) -> None: + """Seed all random generators for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # Ensure reproducibility with CuDNN (may reduce performance) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def min_max_scale(tensor: torch.Tensor) -> torch.Tensor: + """Scale the input tensor to the range [0, 1].""" + min_vals = tensor.min(dim=0, keepdim=True).values + max_vals = tensor.max(dim=0, keepdim=True).values + return (tensor - min_vals) / (max_vals - min_vals) + + +def graphs_to_tensors( + graphs: list[nx.Graph], device: torch.device | None = None +) -> tuple[list[torch.sparse.Tensor], list[torch.Tensor]]: + """Convert a list of NetworkX graphs into sparse adjacency matrices and label tensors. + + Args: + graphs (List[nx.Graph]): A list of NetworkX graphs. + device (torch.device | None): The device to place the tensors on. + Default is CPU. + + Returns: + Tuple[List[torch.sparse.Tensor], List[torch.Tensor]]: + A tuple containing: + - A list of sparse adjacency matrices. + - A list of label tensors. 
+ """ + if device is None: + device = torch.device("cpu") + + adjacency_matrices = [] + label_tensors = [] + + # Create a consistent label mapping across all graphs + label_dict: dict[str, int] = {} + label_counter: int = 0 + + for graph in graphs: + # Create adjacency matrix + edges = list(graph.edges()) + num_nodes = graph.number_of_nodes() + + if not edges: + adj = torch.sparse_coo_tensor( + indices=torch.empty((2, 0), dtype=torch.long), + values=torch.empty(0), + size=(num_nodes, num_nodes), + device=device, + ).to_sparse_csr() + else: + edge_indices = edges + [(v, u) for u, v in edges] + rows, cols = zip(*edge_indices, strict=False) + indices = torch.tensor([rows, cols], dtype=torch.long, device=device) + values = torch.ones(len(edge_indices), dtype=torch.float, device=device) + adj = torch.sparse_coo_tensor( + indices, values, (num_nodes, num_nodes), device=device + ).to_sparse_csr() + + adjacency_matrices.append(adj) + + # Create label tensor + node_labels: list[int] = [] + for node in range(graph.number_of_nodes()): + if "label" in graph.nodes[node]: + label = graph.nodes[node]["label"] + if label not in label_dict: + label_dict[label] = label_counter + label_counter += 1 + node_labels.append(label_dict[label]) + else: + node_labels.append(node) + + label_tensors.append(torch.tensor(node_labels, dtype=torch.long, device=device)) + + return adjacency_matrices, label_tensors + + +def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: + """Sample graphs using random walks or edge modifications. + + Args: + graphs (list[nx.Graph]): Existing training graphs. + num_samples (int): Number of graph samples to generate. + + Returns: + list[nx.Graph]: Sampled graphs. + """ + sampled_graphs = [] + for _ in range(num_samples): + base_graph = random.choice(graphs) + sampled_graph = base_graph.copy() + + # More aggressive modifications + num_modifications = random.randint(2, 5) # Increase minimum modifications + for _ in range(num_modifications): + if random.random() > 0.3: # 70% chance to add edge + nodes = list(sampled_graph.nodes) + if len(nodes) >= 2: + u, v = random.sample(nodes, 2) + if not sampled_graph.has_edge(u, v): + sampled_graph.add_edge(u, v) + elif sampled_graph.edges: # 30% chance to remove edge + u, v = random.choice(list(sampled_graph.edges)) + sampled_graph.remove_edge(u, v) + + # Ensure the graph stays connected + if not nx.is_connected(sampled_graph): + components = list(nx.connected_components(sampled_graph)) + for i in range(len(components) - 1): + u = random.choice(list(components[i])) + v = random.choice(list(components[i + 1])) + sampled_graph.add_edge(u, v) + + sampled_graphs.append(sampled_graph) + + return sampled_graphs diff --git a/neps/optimizers/priorband.py b/neps/optimizers/priorband.py index 9d6d23e4b..eacd46181 100644 --- a/neps/optimizers/priorband.py +++ b/neps/optimizers/priorband.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, assert_never import numpy as np import torch @@ -103,13 +103,18 @@ def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]: or spent_one_sh_bracket_worth_of_fidelity is False or any_rung_with_eta_evals is False ): - policy = np.random.choice(["prior", "random"], p=[w_prior, w_random]) + policy: Literal["prior", "random"] = np.random.choice( + ["prior", "random"], + p=[w_prior, w_random], + ) match policy: case "prior": config = 
prior_dist.sample_config(to=self.encoder) case "random": _sampler = Sampler.uniform(ndim=self.encoder.ndim) config = _sampler.sample_config(to=self.encoder) + case _: + assert_never(policy) return config diff --git a/neps/optimizers/random_search.py b/neps/optimizers/random_search.py index 5b6742a6a..44d11bf90 100644 --- a/neps/optimizers/random_search.py +++ b/neps/optimizers/random_search.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +import numpy as np + from neps.optimizers.optimizer import SampledConfig if TYPE_CHECKING: @@ -18,7 +20,7 @@ class RandomSearch: space: SearchSpace encoder: ConfigEncoder - sampler: Sampler + numerical_sampler: Sampler def __call__( self, @@ -28,12 +30,21 @@ def __call__( ) -> SampledConfig | list[SampledConfig]: n_trials = len(trials) _n = 1 if n is None else n - configs = self.sampler.sample(_n, to=self.encoder.domains) + configs_tensor = self.numerical_sampler.sample(_n, to=self.encoder) - config_dicts = self.encoder.decode(configs) + config_dicts = self.encoder.decode(configs_tensor) for config in config_dicts: config.update(self.space.constants) + # TODO: We should probably have a grammar sampler class, not do it manually here + # This works for now but should be updated. + if self.space.grammar is not None: + rng = np.random.default_rng() # TODO: We should be able to seed this. + grammar_key, grammar = self.space.grammar + for config in config_dicts: + sample = grammar.sample(rng=rng) + config.update({grammar_key: sample.to_string()}) + if n is None: config = config_dicts[0] config_id = str(n_trials + 1) diff --git a/neps/optimizers/utils/initial_design.py b/neps/optimizers/utils/initial_design.py index 615a5a257..6f55fa267 100644 --- a/neps/optimizers/utils/initial_design.py +++ b/neps/optimizers/utils/initial_design.py @@ -5,20 +5,28 @@ import torch +from neps.optimizers.priorband import mutate_config from neps.sampling import Prior, Sampler +from neps.space import Grammar +from neps.space.grammar import RandomSampler, MutationSampler, GrammarSampler if TYPE_CHECKING: - from neps.space import ConfigEncoder - from neps.space.parameters import Parameter + from neps.space import ConfigEncoder, Parameter def make_initial_design( *, - parameters: Mapping[str, Parameter], + parameters: Mapping[str, Parameter | Grammar], encoder: ConfigEncoder, sampler: Literal["sobol", "prior", "uniform"] | Sampler, sample_size: int | Literal["ndim"] | None = "ndim", sample_prior_first: bool = True, + grammar_mutant_selector:( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ) = ("climb", range(1, 4)), + grammar_max_mutation_depth: int = 3, seed: torch.Generator | None = None, ) -> list[dict[str, Any]]: """Generate the initial design of the optimization process. @@ -41,37 +49,64 @@ def make_initial_design( If None, no configurations will be sampled. sample_prior_first: Whether to sample the prior configuration first. + grammar_mutant_selector: Please see [`select()`][neps.space.grammar.select]. + grammar_max_mutation_depth: How deep to enumerate mutants of a prior for the + grammar. seed: The seed to use for the random number generation. 
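+
+    Example:
+        A minimal sketch of typical usage (assuming a purely numeric `space`;
+        the grammar handling below is still in flux):
+
+        ```python
+        encoder = ConfigEncoder.from_parameters(space.numerical)
+        configs = make_initial_design(
+            parameters=space.numerical,
+            encoder=encoder,
+            sampler="sobol",
+            sample_size="ndim",
+            sample_prior_first=False,
+        )
+        ```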
""" configs: list[dict[str, Any]] = [] + numerics = {k: p for k, p in parameters.items() if not isinstance(p, Grammar)} + grammars = {k: p for k, p in parameters.items() if isinstance(p, Grammar)} if sample_prior_first: - configs.append( - { - name: p.prior if p.prior is not None else p.center - for name, p in parameters.items() - } - ) - - ndims = len(parameters) + grammar_priors: dict[str, str] = { + k: ( + g.prior + if g.prior is not None + # Ew sorry + else RandomSampler(g).sample(1)[0].to_string() + ) + for k, g in grammars.items() + } + numeric_priors: dict[str, Any] = { + name: p.prior if p.prior is not None else p.center + for name, p in numerics.items() + } + configs.append({**numeric_priors, **grammar_priors}) + + numeric_ndims = len(numerics) + grammar_expansion_count = sum(g._expansion_count for g in grammars.values()) if sample_size == "ndim": - sample_size = ndims + # TODO: Not sure how to handle graphs here properly here to be honest + sample_size = numeric_ndims + grammar_expansion_count elif sample_size is not None and not sample_size > 0: raise ValueError( "The sample size should be a positive integer if passing an int." ) if sample_size is not None: + # Numeric sampling match sampler: case "sobol": - sampler = Sampler.sobol(ndim=ndims) + numeric_sampler = Sampler.sobol(ndim=numeric_ndims) + grammar_sampler = GrammarSampler.random(grammars) case "uniform": - sampler = Sampler.uniform(ndim=ndims) + numeric_sampler = Sampler.uniform(ndim=numeric_ndims) + grammar_sampler = GrammarSampler.random(grammars) case "prior": - sampler = Prior.from_parameters(parameters) + numeric_sampler = Prior.from_parameters(numerics) + grammar_sampler = GrammarSampler.prior( + grammars, + mutant_selector=grammar_mutant_selector, + max_mutation_depth=grammar_max_mutation_depth + ) case _: pass + # TODO: Replace with something more solid + # Grammar sampling + for k, g in grammars.items(): + encoded_configs = sampler.sample(sample_size * 2, to=encoder.domains, seed=seed) uniq_x = torch.unique(encoded_configs, dim=0) sample_configs = encoder.decode(uniq_x[:sample_size]) diff --git a/neps/space/__init__.py b/neps/space/__init__.py index f2bbc55ca..e5d8889e3 100644 --- a/neps/space/__init__.py +++ b/neps/space/__init__.py @@ -1,5 +1,6 @@ from neps.space.domain import Domain from neps.space.encoding import ConfigEncoder +from neps.space.grammar import Grammar from neps.space.parameters import Categorical, Constant, Float, Integer, Parameter from neps.space.search_space import SearchSpace @@ -9,6 +10,7 @@ "Constant", "Domain", "Float", + "Grammar", "Integer", "Parameter", "SearchSpace", diff --git a/neps/space/grammar.py b/neps/space/grammar.py new file mode 100644 index 000000000..df90d8b0f --- /dev/null +++ b/neps/space/grammar.py @@ -0,0 +1,1440 @@ +"""A module containing the [`Grammar`][neps.space.grammar.Grammar] parameter. + +A `Grammar` contains a list of production `rules`, which produce a _string_ from +the grammar, as well as some `start_symbol` which is used by optimizers. + +!!! note + + We make a distinction that **string** is not a python `str`, and represents + an expanded set of rules from the grammar. + +Each rule, either a [`Terminal`][neps.space.grammar.Grammar.Terminal] or +[`NonTerminal`][neps.space.grammar.Grammar.NonTerminal], is a key-value pair, +where the key is a symbol, such as `"S"` and the value is what the symbol represents. +See the example below. + +You can create a `Grammar` conveninetly using +[`Grammar.from_dict({...})`][neps.space.grammar.Grammar.from_dict]. + +!!! 
+
+    ```python
+    from neps import Grammar
+
+    # Using bare types
+    grammar = Grammar.from_dict("S", {
+        "S": (["OP OP OP", "OP OP"], nn.Sequential),  # A seq with either 3 or 2 children
+        "OP": ["linear golu", "linear relu"],  # A choice between linear with a golu/relu
+        "linear": partial(nn.LazyLinear, out_features=10, bias=False),  # A linear layer
+        "relu": nn.ReLU,  # A relu activation
+        "golu": nn.GoLU,  # A golu activation
+    })
+
+    # Explicitly
+    grammar = Grammar(
+        start_symbol="S",
+        rules={
+            "S": Grammar.NonTerminal(choices=["OP OP OP"], op=nn.Sequential, shared=False),
+            "OP": Grammar.NonTerminal(choices=["linear golu", "linear relu"], op=None, shared=False),
+            "relu": Grammar.Terminal(nn.ReLU),
+            "linear": Grammar.Terminal(partial(nn.LazyLinear, out_features=10, bias=False)),
+            "golu": Grammar.Terminal(nn.GoLU),
+        },
+    )
+    ```
+
+A _string_ from a `Grammar` can be produced in several ways:
+
+* [`grammar.parse()`][neps.space.grammar.Grammar.parse] - parse a `str` into a _string_
+  of the grammar, which is represented by a [`Node`][neps.space.grammar.Node] tree.
+  The inverse of this operation is to call `node.to_string()`.
+* [`grammar.sample()`][neps.space.grammar.Grammar.sample] - Sample a random string from
+  the grammar.
+* [`grammar.mutations()`][neps.space.grammar.Grammar.mutations] - This takes in a `Node`,
+  which represents a _string_ from the grammar, and can mutate selected points of the
+  string. You can use the function [`node.select()`][neps.space.grammar.select] for
+  different strategies to select parts of the string to mutate, for example, all
+  parents of a leaf with `node.select(how=("climb", 1))` or specific symbols using
+  `node.select(how=("symbol", "OP"))`.
+* [`grammar.bfs()`][neps.space.grammar.Grammar.bfs] - This iterates through all possible
+  strings producible from the grammar, using a max-depth to prevent infinite recursion.
+
+As mentioned in the above methods, a string from the `Grammar` is represented as a tree
+of [`Node`][neps.space.grammar.Node] objects, which also carry the associated meaning of
+the string parts, i.e. what operation each symbol should perform.
+
+* [`Leaf`][neps.space.grammar.Leaf] - A symbol with no children and an operation.
+* [`Container`][neps.space.grammar.Container] - A symbol with children and some containing
+  operation, for example an `nn.Sequential`.
+* [`Passthrough`][neps.space.grammar.Passthrough] - A symbol with children but **no**
+  operation. Its children will be passed up to its parent until it hits a `Container`.
+
+Please see the associated docstrings for more information.
+
+For the most part, you can consider all of these as a [`Node`][neps.space.grammar.Node],
+which has the following attached functions:
+
+* [`to_string()`][neps.space.grammar.to_string] - Convert to its Python `str`
+  representation.
+* [`to_model()`][neps.space.grammar.to_model] - Convert it into some kind of model,
+  defined by its operations. Normally this represents some `nn.Module` structure but it
+  is not necessarily torch-specific.
+* [`to_nxgraph()`][neps.space.grammar.to_nxgraph] - Convert it into an `nx.DiGraph` which
+  can be useful for optimization or other applications such as plotting. The inverse
+  operation is called from the grammar,
+  [`grammar.node_from_graph()`][neps.space.grammar.Grammar.node_from_graph]
+* [`select()`][neps.space.grammar.select] - Select certain nodes of the string by
+  a criterion.
+* [`dfs()`][neps.space.grammar.dfs] - DFS iteration over the nodes of the string.
+* [`bfs()`][neps.space.grammar.bfs] - BFS iteration over the nodes of the string.
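+
+!!! example "Putting it together"
+
+    An illustrative end-to-end sketch, re-using the grammar from the example
+    above (this mirrors `graph_playground.py` from this PR):
+
+    ```python
+    import numpy as np
+
+    root = grammar.parse("S(OP(linear, relu), OP(linear, golu))")
+    model = root.to_model()  # an nn.Sequential
+
+    # Select mutation points: all nodes within 1-2 steps of a leaf
+    selections = list(root.select(how=("climb", range(1, 3))))
+    mutants = list(
+        grammar.mutations(root, which=selections, max_mutation_depth=3)
+    )
+
+    # Or sample a fresh string
+    sampled = grammar.sample(rng=np.random.default_rng(0))
+    print(sampled.to_string())
+    ```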
+""" + +from __future__ import annotations + +import itertools +from collections import defaultdict +from collections.abc import Callable, Iterable, Iterator, Mapping +from dataclasses import dataclass, field +from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias +from typing_extensions import assert_never + +import more_itertools +import networkx as nx +import numpy as np + +from neps.exceptions import NePSError + + +class ParseError(NePSError): + """An error occured while parsing a grammar string.""" + + +@dataclass +class _BufferedRandInts: + rng: np.random.Generator + buffer_size: int = 50 + _cur_ix: int = 0 + + MAX_INT: ClassVar[int] = np.iinfo(np.int64).max + _nums: list[int] = field(default_factory=list) + + def next(self, n: int) -> int: + if self._cur_ix >= len(self._nums): + self._nums = self.rng.integers( + self.MAX_INT, size=self.buffer_size, dtype=np.int64 + ).tolist() + + self._cur_ix = 0 + + i = self._nums[self._cur_ix] % n + + self._cur_ix += 1 + return i + + +def dfs_node(node: Node) -> Iterator[Node]: + """Perform a depth-first search iteration on the node.""" + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + stack.extend(reversed(children)) + case _: + assert_never(nxt) + + +def bfs_node(node: Node) -> Iterator[Node]: + """Perform a breadth-first search iteration on the node.""" + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + queue.extend(children) + case _: + assert_never(nxt) + + +def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: # noqa: C901 + """Convert a node and it's children into an `nx.DiGraph`. + + Args: + root: The node to start from. + include_passthroughs: Whether to include passthrough symbols into the + produced graph. 
+ """ + nodes: list[tuple[int, dict]] = [] + edges: list[tuple[int, int]] = [] + id_generator: Iterator[int] = itertools.count() + + def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: + node_id = next(id_generator) + match node: + # Atoms are just a node with an edge to its parent + case Leaf(symbol): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + # If we have a passthrough and shouldn't include them, we simply + # forward on the `parent_id` we recieved to the children + case Passthrough(_, children) if include_passthroughs is False: + for child in children: + _recurse_fill_lists(child, parent_id=parent_id) + + # Containers are a node in the graph, with edges to its + # children (direct, or through passthrough) + case Container(symbol, children, _) | Passthrough(symbol, children): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + for child in children: + _recurse_fill_lists(child, parent_id=node_id) + + case _: + assert_never(root.kind) + + graph = nx.DiGraph() + root_id = next(id_generator) + nodes.append((root_id, {"label": root.symbol})) + match root: + case Leaf(): + pass + case Passthrough(_, children) if include_passthroughs is False: + raise ValueError( + f"Can't create a graph starting from a `Passthrough` {root.symbol}, " + " unless `include_passthrough`" + ) + case Container(_, children, _) | Passthrough(_, children): + for child in children: + _recurse_fill_lists(child, parent_id=root_id) + case _: + assert_never(root) + + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def to_model(node: Node) -> Any: + """Convert a parse tree node and its children into some object it represents.""" + + def _build(_n: Node) -> list[Any] | Any: + match _n: + case Leaf(_, op): + return op() + case Container(_, children, op): + # The problem is that each child could be either: + # * A single 'thing', in the case of Leaf or Container + # * Multiple things, in case it's a passthrough + # Hence we flatten them out into a single big children itr + built_children = more_itertools.collapse( + (_build(child) for child in children), + base_type=(op if isinstance(op, type) else None), + ) + return op(*built_children) + case Passthrough(_, children): + return [_build(child) for child in children] + case _: + assert_never(node) + + match node: + case Leaf() | Container(): + obj = _build(node) + assert not isinstance(obj, list) + return obj + case Passthrough(symbol): + raise ValueError(f"Can not call build on a `Passthrough` {symbol}") + case _: + assert_never(node) + + +def select( # noqa: C901, PLR0912, PLR0915 + root: Node, + *, + how: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ), +) -> Iterator[Node]: + """Iterate through the tree and select nodes according to `how=`. + + Args: + root: the root node to start from. + how: which nodes to select. In the case of `"depth"` and `"climb"`, you can either + provide a specific value `int`, or else a `range`, where anything that has + a value in that `range` is included. Note that this follows the same + convention that `4 in range(3, 5)` but `5 not in range(3, 5)`, + i.e. that the stop boundary is non-inclusive. + + * `"symbol"` - Select all nodes which have the given symbol. + * `"depth"`- Select all nodes which are at a given depth, either a particular + depth value or a range of depth values. 
+              The `root` is defined to be at `depth == 0` while its direct children
+              are defined to be at `depth == 1`.
+            * `"climb"` - Select all nodes which are at a given distance away from a
+              leaf. Leafs are defined to be at `climb == 0`, while any direct parents
+              of a leaf are `climb == 1`.
+    """
+    match how:
+        case ("symbol", symbol):
+            for node in bfs_node(root):
+                if node.symbol == symbol:
+                    yield node
+        case ("depth", depth):
+            if isinstance(depth, int):
+                depth = range(depth, depth + 1)
+
+            queue_depth: list[tuple[Node, int]] = [(root, 0)]
+            while queue_depth:
+                nxt, d = queue_depth.pop(0)
+                if d in depth:
+                    yield nxt
+
+                if d >= depth.stop:
+                    continue
+
+                match nxt:
+                    case Leaf():
+                        pass
+                    case Passthrough(children=children) | Container(children=children):
+                        queue_depth.extend([(child, d + 1) for child in children])
+                    case _:
+                        assert_never(nxt)
+
+        case ("climb", climb):
+            if isinstance(climb, int):
+                climb = range(climb, climb + 1)
+
+            # First, we iterate downwards, populating parent paths back up.
+            # Because the same `Leaf` instance is re-used across identical
+            # leafs, and shared nodes likewise share a node id, a single
+            # child id can have multiple parents.
+            parents: defaultdict[int, list[Node]] = defaultdict(list)
+
+            # We remove duplicates using a dict and the shared ids, a list would
+            # end up with duplicates for every leaf. We use this later to begin
+            # the climb iteration
+            leafs: dict[int, Node] = {}
+
+            queue_climb: list[Node] = [root]
+            while queue_climb:
+                nxt = queue_climb.pop(0)
+                this_id = id(nxt)
+                match nxt:
+                    case Leaf():
+                        leafs[this_id] = nxt
+                    case Passthrough(children=children) | Container(children=children):
+                        for child in children:
+                            parents[id(child)].append(nxt)
+                        queue_climb.extend(children)
+                    case _:
+                        assert_never(nxt)
+
+            # Now we work backwards from the leafs for each of the possible parents
+            # for the node id, yielding if we're within the climb path. If we've gone
+            # past the climb value, we can stop iterating there.
+            climb_queue: list[tuple[Node, int]] = []
+            climb_queue.extend([(leaf, 0) for leaf in leafs.values()])
+            seen: set[int] = set()
+            while climb_queue:
+                node, climb_value = climb_queue.pop(0)
+                node_id = id(node)
+                if node_id in seen:
+                    continue
+
+                if climb_value in climb:
+                    seen.add(node_id)
+                    yield node
+
+                if climb_value < climb.stop:
+                    possible_node_parents = parents[id(node)]
+                    climb_queue.extend(
+                        [
+                            (p, climb_value + 1)
+                            for p in possible_node_parents
+                            if id(p) not in seen
+                        ]
+                    )
+
+        case _:
+            assert_never(how)
+
+
+# TODO: Optimization, we don't need to recompute shared substrings.
+# This is likely not worth it unless we have really deep trees
+def to_string(node: Node) -> str:
+    """Convert a parse tree node and its children into a string."""
+    match node:
+        case Leaf(symbol):
+            return symbol
+        case Passthrough(symbol, children) | Container(symbol, children):
+            return f"{symbol}({', '.join(to_string(c) for c in children)})"
+        case _:
+            assert_never(node)
+
+
+class Leaf(NamedTuple):
+    """A node which has no children.
+
+    !!! note
+
+        As we only ever have one kind of leaf per symbol, we tend to re-use the
+        same `Leaf` instance wherever it is needed. In contrast, a `Container`
+        and a `Passthrough` may have different children per symbol and a new
+        instance is made each time.
+
+    Args:
+        symbol: The string symbol associated with this `Leaf`.
+        op: The associated operation with this `symbol`.
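+
+    Example:
+        An illustrative sketch:
+
+        ```python
+        from torch import nn
+
+        leaf = Leaf("relu", nn.ReLU)
+        leaf.to_string()  # "relu"
+        leaf.to_model()   # nn.ReLU()
+        ```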
+ """ + + symbol: str + op: Callable + + def __hash__(self) -> int: + return hash(self.symbol) + + dfs = dfs_node + bfs = bfs_node + to_string = to_string + to_nxgraph = to_nxgraph + to_model = to_model + select = select + + +class Container(NamedTuple): + """A node which contains children and has an associated operation. + + Args: + symobl: The string symbol associated with this `Container`. + children: The direct children of this node. When instantiating this container, + it will be called with it's instantiated children with `op(*children)`. + op: The associated operation with this node, such as an `nn.Sequential`. + """ + + symbol: str + children: list[Node] + op: Callable + + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) + + dfs = dfs_node + bfs = bfs_node + to_string = to_string + to_nxgraph = to_nxgraph + to_model = to_model + select = select + + +class Passthrough(NamedTuple): + """A node which contains children but has no associated operation. + + This is used for things such as `"OP": ["conv2d", "conv3d", "identity"]`, where + `"OP"` does not have some kind of container operation and is used to make a choice + between various symbols. + + Args: + symbol: The associated symbol with this `Passthrough`. + children: The direct children of this node. As this node can not be instantiated, + the children of this `Passthrough` are forward on to this nodes parents. + """ + + symbol: str + children: list[Node] + + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) + + dfs = dfs_node + bfs = bfs_node + to_string = to_string + to_nxgraph = to_nxgraph + to_model = to_model + select = select + + +Node: TypeAlias = Container | Passthrough | Leaf +"""The possible nodes in a constructed instance of a string from the grammar. + +Please see the associated types for their description or the docstring of a +[`Grammar`][neps.space.grammar.Grammar]. +""" + + +@dataclass +class Grammar: + """A grammar defines a search space of symbols which may contain other symbols. + + !!! tip + + You most likely want to create one of these using + [`from_dict()`][neps.space.grammar.Grammar.from_dict]. + + A grammar consists of `rules: dict[str, Grammar.Terminal | Grammar.NonTerminal]` + where the key is a string symbol, and the values are what that string symbol + represents. The initial symbol used by optimizers is specified using `start_symbol`. + + The [`Grammar.Terminal`][neps.space.Grammar.Terminal] represents some kind of leaf + node of a computation graph, such as a function call or some operation which + does not have any children dependancies, for example an `nn.Linear`. This is + modeled as a [`Node`][neps.space.grammar.Node], specifically the + [`Leaf`][neps.space.grammar.Leaf] type. + + The [`Grammar.NonTerminal`][neps.space.Grammar.NonTerminal] represents some kind of + intermediate operation, which contains sub-symbols which are sub-computations of + a computation graph. A common example of this is when `op=nn.Sequential`, which by + itself does not really do any computations but relies on the computation of it's + children which it performs one after another. If there is an associated `op=`, then + we consider this be a [`Container`][neps.space.grammar.Container] kind of + [`Node`][neps.space.grammar.Node]. If there is **no** associated `op=`, then we + consider this to be a [`Passthrough`][neps.space.grammar.Passthrough] kind of + [`Node`][neps.space.grammar.Node]. 
+
+    For a `Grammar.NonTerminal`, you may also specify if it is `shared: bool`, which is
+    by default `False`. When explicitly set as `True`, all choices made for its children
+    will be shared through the generated/sampled/parsed string. In contrast, if
+    `shared=False`, then any specific instance of this symbol may have different
+    children.
+
+    Args:
+        start_symbol: The starting symbol used by optimizers.
+        rules: The possible grammar rules which define the structure of the grammar.
+        prior: Some prior string producible from the grammar that can be used as a
+            user prior.
+    """
+
+    start_symbol: str
+    rules: dict[str, Terminal | NonTerminal]
+    prior: str | None = None
+    _expansion_count: int = field(init=False)
+    _shared: dict[str, NonTerminal] = field(init=False)
+    _leafs: dict[str, Leaf] = field(init=False)
+    _prior_node: Node | None = field(init=False)
+
+    class Terminal(NamedTuple):
+        """A symbol which has no children and an associated operation.
+
+        When a specific instance of a string from this grammar is made, this
+        rule will create a [`Leaf`][neps.space.grammar.Leaf].
+
+        Args:
+            op: The associated operation.
+        """
+
+        op: Callable
+
+    class NonTerminal(NamedTuple):
+        """A symbol which has different possible children.
+
+        Depending on whether `op=` is specified or not, this will either be a
+        [`Container`][neps.space.grammar.Container] or a
+        [`Passthrough`][neps.space.grammar.Passthrough].
+
+        Args:
+            choices: The list of possible children to place inside this `NonTerminal`.
+                Different possibilities are specified by the elements of the list.
+                When a `str` contains multiple symbols that are space separated, all
+                of them become children.
+
+                ```
+                # The following says that we have a choice between "a", "b" and "c d".
+                # In the case that "c d" is chosen, both of those will be children of
+                # the created node.
+                ["a", "b", "c d"]
+                ```
+
+            op: The associated operation with this node, if any.
+            shared: Whether the choices made for this symbol should be shared throughout
+                the tree, or whether they should be considered independent.
+        """
+
+        choices: list[str]
+        op: Callable | None = None
+        shared: bool = False
+
+    def __post_init__(self) -> None:
+        start_rule = self.rules.get(self.start_symbol, None)
+        if start_rule is None:
+            raise ValueError(
+                f"The start_symbol '{self.start_symbol}' should be one of the symbols"
+                f" in rules, which are {self.rules.keys()}"
+            )
+        self._shared = {
+            s: r
+            for s, r in self.rules.items()
+            if isinstance(r, Grammar.NonTerminal) and r.shared
+        }
+        self._leafs = {
+            s: Leaf(s, r.op)
+            for s, r in self.rules.items()
+            if isinstance(r, Grammar.Terminal)
+        }
+
+        # In lieu of a good proxy for 'size', which we might need in some scenarios,
+        # such as in initial design where you can specify sample size is of size
+        # `"ndim"`, we use this proxy. Totally unscientific.
+        # Things to consider if you want to change it:
+        # * Recursive elements (e.g. A -> [A, a, b, c]), where recursion by uniform
+        #   sampling is 1/4. This would need to propagate if for example `b` could also
+        #   recurse on itself....
+        # * That we can have multiple children, i.e. `A -> [A a, A b, A c, A A]`
+        # * Leafs do not expand the size
+        self._expansion_count = sum(
+            len(rule.choices)
+            for rule in self.rules.values()
+            if isinstance(rule, Grammar.NonTerminal)
+        )
+
+        if self.prior is not None:
+            try:
+                prior_node = self.parse(self.prior)
+            except ParseError as e:
+                raise ValueError(
+                    f"The prior '{self.prior}' given for this grammar could"
+                    " not be parsed properly."
+                ) from e
+        else:
+            prior_node = None
+
+        self._prior_node = prior_node
+
+    @classmethod
+    def from_dict(
+        cls,
+        start_symbol: str,
+        grammar: dict[
+            str,
+            Callable
+            | list[str]
+            | tuple[list[str], Callable]
+            | Grammar.Terminal
+            | Grammar.NonTerminal,
+        ],
+        *,
+        prior: str | None = None,
+    ) -> Grammar:
+        """Create a `Grammar` from a dictionary.
+
+        Please see the module doc for more.
+
+        Args:
+            start_symbol: The starting symbol from which to produce strings.
+            grammar: The rules of the grammar.
+            prior: Some prior string producible from the grammar that can be used as a
+                user prior.
+        """
+        rules: dict[str, Grammar.Terminal | Grammar.NonTerminal] = {}
+        for symbol, rule in grammar.items():
+            match rule:
+                case Grammar.Terminal() | Grammar.NonTerminal():
+                    rules[symbol] = rule
+                case (choices, op) if isinstance(choices, list) and callable(op):
+                    # > e.g. "S": (["A", "A B", "C"], op)
+                    rhs = set(itertools.chain(*(choice.split(" ") for choice in choices)))
+                    missing = rhs - grammar.keys()
+                    if missing:
+                        raise ValueError(
+                            f"Symbols {missing} not in grammar {grammar.keys()}"
+                        )
+
+                    rules[symbol] = Grammar.NonTerminal(choices, op, shared=False)
+
+                case choices if isinstance(choices, list):
+                    # > e.g. "S": ["A", "A B", "C"]
+                    rhs = set(itertools.chain(*(choice.split(" ") for choice in choices)))
+                    missing = rhs - grammar.keys()
+                    if missing:
+                        raise ValueError(
+                            f"Symbols {missing} not in grammar {grammar.keys()}"
+                        )
+
+                    rules[symbol] = Grammar.NonTerminal(choices, op=None, shared=False)
+
+                case op if callable(op):
+                    # > e.g. "S": op
+                    rules[symbol] = Grammar.Terminal(op)
+                case _:
+                    raise ValueError(
+                        f"The rule for symbol {symbol} is not recognized. Should be"
+                        " a list of symbols, a callable, or a tuple with both."
+                        f"\n Got {rule}"
+                    )
+
+        return Grammar(start_symbol=start_symbol, rules=rules, prior=prior)
+
+    def sample(  # noqa: C901, PLR0912
+        self,
+        symbol: str | None = None,
+        *,
+        rng: np.random.Generator | _BufferedRandInts,
+        variables: dict[str, Node] | None = None,
+    ) -> Node:
+        """Sample a random string from this grammar.
+
+        Args:
+            symbol: The symbol to start from. If not provided, this will use
+                the `start_symbol`.
+            rng: The random generator by which sampling is done.
+            variables: Any shared variables to use in the case that a sampled
+                rule has `shared=True`.
+
+        Returns:
+            The root of the sampled string.
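+
+        Example:
+            An illustrative sketch:
+
+            ```python
+            import numpy as np
+
+            rng = np.random.default_rng(42)
+            node = grammar.sample(rng=rng)
+            print(node.to_string())
+            ```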
+ """ + if isinstance(rng, np.random.Generator): + rng = _BufferedRandInts(rng=rng) + + if symbol is None: + symbol = self.start_symbol + + variables = variables or {} + rule = self.rules.get(symbol) + if rule is None: + raise KeyError(f"'{symbol}' not in grammar keys {self.rules.keys()}") + + stack: list[Container | Passthrough] = [] + match rule: + case Grammar.Terminal(): + return self._leafs[symbol] + case Grammar.NonTerminal(choices, op, shared): + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + i = rng.next(len(rule.choices)) + initial_sample = rule.choices[i] + children_symbols = initial_sample.split(" ") + root = ( + Passthrough(symbol, []) if op is None else Container(symbol, [], op) + ) + stack.append(root) + case _: + assert_never(rule) + + while stack: + parent = stack.pop() + i = rng.next(len(choices)) + choice = choices[i] + children_symbols = choice.split(" ") + + for child_symbol in children_symbols: + rule = self.rules[child_symbol] + match rule: + case Grammar.Terminal(): + parent.children.append(self._leafs[child_symbol]) + case Grammar.NonTerminal(choices, op, shared): + shared_node = variables.get(child_symbol) + if shared_node is not None: + parent.children.append(shared_node) + continue + + sub_parent = ( + Passthrough(child_symbol, []) + if op is None + else Container(child_symbol, [], op) + ) + parent.children.append(sub_parent) + stack.append(sub_parent) + + if shared: + variables[child_symbol] = sub_parent + case _: + assert_never(rule) + + return root + + def node_from_graph(self, graph: nx.DiGraph) -> Node: + """Convert an `nx.DiGraph` into a string. + + Args: + graph: The graph, produced by + [`to_nxgraph()`][neps.space.grammar.Grammar.to_nxgraph] + + Returns: + The root of the string produced from the graph. + """ + _root = next((n for n, d in graph.in_degree if d == 0), None) + if _root is None: + raise ValueError( + "Could not find a root in the given graph (a node with indegree 1)." + ) + + variables: dict[str, Node] = {} + + def _recurse(node_id: int) -> Node: + symbol = graph.nodes[node_id].get("label") + if symbol is None: + raise ValueError(f"Node {node_id} does not have a 'label' property.") + + rule = self.rules.get(symbol) + if rule is None: + raise ValueError( + f"Symbol '{symbol}' not found in grammar rules: {self.rules.keys()}" + ) + + # Based on the type of rule, construct the proper node + match rule: + case Grammar.Terminal(op=op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(op=op): + if (shared_node := variables.get(symbol)) is not None: + return shared_node + + children = [ + _recurse(child_id) for child_id in graph.successors(node_id) + ] + node = ( + Passthrough(symbol, children) + if op is None + else Container(symbol, children, op) + ) + if rule.shared: + variables[symbol] = node + case _: + raise ValueError( + f"Unexpected rule type for symbol '{symbol}': {rule}" + ) + + return node + + # Start with the root node + return _recurse(_root) + + def mutations( + self, + root: Node, + *, + which: Iterable[Node], + max_mutation_depth: int, + rng_shuffle: np.random.Generator | None = None, + variables: dict[str, Node] | None = None, + ) -> Iterator[Node]: + """Mutate nodes, returning all the different possibilities for them. + + Args: + root: The root from which to operate. + which: What nodes to mutate, look at `select()`. + max_mutation_depth: The maximum depth allowed for bfs iteration + on the mutant nodes. + rng_shuffle: Whether to shuffle the return order. 
+                Shuffling happens when considering the possibilities for a given
+                node; the output does not follow the order of `NonTerminal.choices`.
+            variables: Any predefined values you'd like for different symbols.
+
+        Returns:
+            A new tree per possible mutation
+        """
+        if isinstance(root, Leaf):
+            raise ValueError(f"Can't mutate `Leaf`: {root}")
+
+        variables = variables or {}
+        mutation_ids = {id(n) for n in which}
+
+        def _inner(node: Node) -> Iterator[Node]:
+            match node:
+                case Leaf():
+                    # We can't mutate leafs, as by definition they have no possible
+                    # choices to choose from, so we ignore it even if it's in the
+                    # set of `mutation_ids`.
+                    yield node
+                case Passthrough(children=children) | Container(children=children):
+                    rule = self.rules.get(node.symbol)
+                    if not isinstance(rule, Grammar.NonTerminal):
+                        raise ValueError(
+                            f"Expected a `NonTerminal` for symbol '{node.symbol}' from"
+                            f" the grammar but got rule {rule}"
+                        )
+
+                    # If we've already determined the value of this shared symbol
+                    if (existing := variables.get(node.symbol)) is not None:
+                        yield existing
+                        return
+
+                    # If mutating, we return all possible bfs values from that node.
+                    if id(node) in mutation_ids:
+                        yield from self.bfs(
+                            node.symbol,
+                            rng_shuffle=rng_shuffle,
+                            max_depth=max_mutation_depth,
+                            variables=variables,
+                        )
+                    else:
+                        children_itrs: list[Iterator[Node]] = [
+                            _inner(c) for c in children
+                        ]
+                        for new_children in itertools.product(*children_itrs):
+                            new_node = node._replace(children=list(new_children))
+                            if rule.shared:
+                                variables[new_node.symbol] = new_node
+                            yield new_node
+                case _:
+                    assert_never(node)
+
+        yield from _inner(root)
+
+    def parse(self, s: str) -> Node:  # noqa: C901, PLR0912, PLR0915
+        """Parse a `str` into a string of the `Grammar`.
+
+        !!! note
+
+            The initial symbol does not necessarily need to match the
+            `start_symbol` of the grammar.
+
+        Args:
+            s: the `str` to convert into a string of the `Grammar`.
+
+        Returns:
+            The node that represents the string.
+        """
+        # Chunk up the str
+        string_tokens: list[str] = []
+        brace_count = 0
+        symbol = ""
+        for tok in s:
+            match tok:
+                case " ":
+                    continue
+                case "(":
+                    brace_count += 1
+                    if len(symbol) == 0:
+                        raise ParseError(
+                            f"Opening bracket '(' must be preceded by a symbol"
+                            f" but was not.\n{s}"
+                        )
+
+                    string_tokens.append(symbol)
+                    string_tokens.append(tok)
+                    symbol = ""
+                case ")":
+                    brace_count -= 1
+                    if len(symbol) == 0:
+                        string_tokens.append(tok)
+                        continue
+
+                    string_tokens.append(symbol)
+                    string_tokens.append(tok)
+                    symbol = ""
+                case ",":
+                    if len(symbol) == 0:
+                        string_tokens.append(tok)
+                        continue
+
+                    string_tokens.append(symbol)
+                    string_tokens.append(tok)
+                    symbol = ""
+                case _:
+                    symbol += tok
+
+        if brace_count != 0:
+            raise ParseError(
+                f"Imbalanced parentheses, got {abs(brace_count)} too many"
+                f" {'(' if brace_count > 0 else ')'}."
+            )
+
+        if len(symbol) > 0:
+            string_tokens.append(symbol)
+
+        # Convert to concrete tokens
+        tokens: list[Literal[")", "(", ","] | tuple[str, Leaf | Grammar.NonTerminal]] = []
+        for symbol in string_tokens:
+            if symbol in "(),":
+                tokens.append(symbol)  # type: ignore
+                continue
+
+            rule = self.rules.get(symbol)
+            match rule:
+                case Grammar.Terminal():
+                    tokens.append((symbol, self._leafs[symbol]))
+                case Grammar.NonTerminal():
+                    tokens.append((symbol, rule))
+                case None:
+                    raise ParseError(
+                        f"Invalid symbol '{symbol}', must be either '(', ')', ',' or"
+                        f" a symbol in {self.rules.keys()}"
+                    )
+                case _:
+                    assert_never(rule)
+
+        # If we're being strict that shared elements must be the same, then
+        # we can do so more cheaply at the beginning by just comparing subtokens
+        # before we parse. This also takes care of subnesting of shared nodes
+        # and allows us to skip over parts of the token stream as we encounter
+        # shared variables.
+        shared_token_sizes: dict[str, int] = {}
+        _shared_locs: dict[str, list[int]] = {s: [] for s in self._shared}
+
+        # We figure out the substrings of where each shared symbol begins and ends
+        if _shared_locs:
+            bracket_stack: list[int] = []
+            bracket_pairs: dict[int, int] = {}
+            for i, tok in enumerate(tokens):
+                match tok:
+                    case "," | (_, Leaf()):
+                        continue
+                    case ")":
+                        start = bracket_stack.pop(-1)
+                        bracket_pairs[start] = i
+                    case "(":
+                        bracket_stack.append(i)
+                    case (symbol, Grammar.NonTerminal(shared=shared)):
+                        if i + 1 >= len(tokens):
+                            raise ParseError(
+                                f"Symbol '{tok}' is a `NonTerminal`, implying that it"
+                                " should contain some inner elements. However we found"
+                                f" it at the last index of the {tokens=}"
+                            )
+                        if tokens[i + 1] != "(":
+                            raise ParseError(
+                                f"Symbol '{tok}' at position {i} is a `NonTerminal`,"
+                                " implying that it should contain some inner elements."
+                                " However it was not followed by a '(' at position"
+                                f" {i + 1} in {tokens=}"
+                            )
+                        if shared is True:
+                            _shared_locs[symbol].append(i)
+                    case _:
+                        assert_never(tok)
+
+            # If we have more than one occurrence of a shared symbol,
+            # we validate that their subtokens match
+            for symbol, symbol_positions in _shared_locs.items():
+                # The shared symbol may not occur in this particular string at all.
+                if not symbol_positions:
+                    continue
+
+                first_pos, rest = symbol_positions[0], symbol_positions[1:]
+
+                # Calculate the inner tokens and length
+                bracket_first_start = first_pos + 1
+                bracket_first_end = bracket_pairs[bracket_first_start]
+
+                inner_tokens = tokens[bracket_first_start + 1 : bracket_first_end]
+                shared_symbol_token_size = len(inner_tokens)
+                shared_token_sizes[symbol] = shared_symbol_token_size
+
+                for symbol_start in rest:
+                    # +2, skip symbol_start and skip opening bracket '('
+                    symbol_tokens = tokens[
+                        symbol_start + 2 : symbol_start + 2 + shared_symbol_token_size
+                    ]
+                    if symbol_tokens != inner_tokens:
+                        raise ParseError(
+                            f"Found mismatch in shared symbol '{symbol}'"
+                            f" with {symbol=} starting at token `{symbol_start}`"
+                            f" and the same symbol at token `{first_pos}` which has"
+                            f" {inner_tokens=}.\n{tokens=}"
+                        )
+
+        if len(tokens) == 0:
+            raise ParseError("Received an empty string")
+
+        match tokens[0]:
+            case (symbol, Leaf()):
+                if len(tokens) > 1:
+                    raise ParseError(
+                        f"First token was symbol '{symbol}' which is"
+                        f" a `Terminal`, but it was followed by more tokens."
+                        f"\n{tokens=}"
+                    )
+                _, root = tokens[0]
+            case (symbol, Grammar.NonTerminal(op=op)):
+                if op is None:
+                    raise ParseError(
+                        f"First token was symbol '{symbol}' which is a `NonTerminal`"
+                        " that is `passthrough`, i.e. it has no associated"
+                        " operation and cannot be the root."
+                    )
+                if len(tokens) < 4:
+                    raise ParseError(
+                        f"First token was symbol '{symbol}' which is"
+                        f" a `NonTerminal`, but should have at least 3 more tokens"
+                        " for a '(', 'child' and a closing ')'"
+                    )
+
+                # NOTE: We don't care about shared here as we validate above that
+                # a shared variable cannot contain itself, and there are no other
+                # symbols above or on the same level as this one (as it's the root).
+                # Hence we do not need to interact with `shared` here.
+                root = Container(symbol=symbol, children=[], op=op)
+            case "(" | ")" | ",":
+                raise ParseError("First token cannot be a '(', ')' or a ','")
+            case rule:
+                assert_never(rule)
+
+        if isinstance(root, Leaf):
+            return root
+
+        variables: dict[str, Container | Passthrough] = {}
+        parent_stack: list[Container | Passthrough] = []
+        current: Node = root
+
+        token_stream = iter(tokens[1:])
+
+        for tok in token_stream:
+            match tok:
+                case ",":
+                    parent_stack[-1].children.append(current)
+                case ")":
+                    parent = parent_stack.pop()
+                    parent.children.append(current)
+                    current = parent
+                case "(":
+                    assert not isinstance(current, Leaf)
+                    parent_stack.append(current)
+                case (symbol, rule):
+                    if isinstance(rule, Leaf):
+                        current = rule
+                        continue
+
+                    if rule.shared and (existing := variables.get(symbol)):
+                        # Re-using a previous one so we can skip ahead in the tokens.
+                        current = existing
+                        token_size_of_tok = shared_token_sizes[symbol]
+                        # `islice` is lazy, so we must exhaust it to actually skip.
+                        # The `+ 2` also skips the surrounding '(' and ')' tokens.
+                        for _ in itertools.islice(token_stream, token_size_of_tok + 2):
+                            pass
+                        continue
+
+                    if rule.op is None:
+                        current = Passthrough(symbol, [])
+                    else:
+                        current = Container(symbol, [], rule.op)
+
+                    if rule.shared:
+                        variables[symbol] = current
+                case _:
+                    assert_never(tok)
+
+        return current
+
+    # TODO: The variables thing can mess up the max depth
+    def bfs(  # noqa: C901
+        self,
+        symbol: str,
+        *,
+        max_depth: int,
+        current_depth: int = 0,
+        variables: dict[str, Node] | None = None,
+        rng_shuffle: np.random.Generator | None = None,
+    ) -> Iterator[Node]:
+        """Iterate over all possible strings in a breadth first manner.
+
+        Args:
+            symbol: The symbol to start the string from.
+            max_depth: The maximum depth of the produced string. This may not
+                be fully guaranteed given shared `NonTerminal`s. This is required
+                to prevent infinite recursion. Any non-terminated strings, i.e. those
+                which still require expansion, but have exceeded the depth, will not be
+                returned.
+            current_depth: What depth this call of the function is acting at. This is
+                used recursively and can mostly be left at `0`.
+            variables: Any instantiated shared variables used for a `shared=`
+                `NonTerminal`.
+            rng_shuffle: Whether to shuffle the order of the children when doing breadth
+                first search. This may only be required if you are not consuming the full
+                iterator this returns. For the most part this can be ignored.
+
+        Returns:
+            An iterator over the valid strings in the grammar.
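+
+        Example:
+            A minimal sketch, assuming a small non-recursive grammar built with
+            the `T`/`join` helpers from the module examples:
+
+            ```python
+            g = Grammar.from_dict(
+                "s", {"s": (["a", "b"], join), "a": T("a"), "b": T("b")}
+            )
+            [to_string(n) for n in g.bfs("s", max_depth=2)]
+            # -> ["s(a)", "s(b)"]
+            ```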
+ """ + if current_depth > max_depth: + return + + variables = variables or {} + shared_node = variables.get(symbol) + if shared_node is not None: + yield shared_node + return # TODO: check + + nxt_depth = current_depth + 1 + + rule = self.rules.get(symbol) + match rule: + case Grammar.Terminal(op=op): + node = Leaf(symbol, op) + yield node + case Grammar.NonTerminal(choices=choices, op=op): + for choice in choices: + children = choice.split(" ") + child_expansions: list[Iterator] = [ + self.bfs( + child_symbol, + max_depth=max_depth, + current_depth=nxt_depth, + rng_shuffle=rng_shuffle, + variables=variables, + ) + for child_symbol in children + ] + + if rng_shuffle: + # Works correctly with python lists, but typing for numpy is off + rng_shuffle.shuffle(child_expansions) # type: ignore + + for possible in itertools.product(*child_expansions): + if op is None: + node = Passthrough(symbol, children=list(possible)) + else: + node = Container(symbol, op=op, children=list(possible)) + + if rule.shared: + variables[symbol] = node + + yield node + case None: + raise ValueError(f"No symbol {symbol} in rules {self.rules.keys()}") + case _: + assert_never(rule) + + def is_valid( + self, + node: Node, + *, + already_shared: set[str] | None = None, + ) -> bool: + """Check if a given string is valid. + + Args: + node: The start of the string. + already_shared: Use for recursion, can mostly be kept as `None`. + Used to ensure that `NonTerminal`s that are `shared=True`, do + not contain themselves. + """ + rule = self.rules.get(node.symbol) + if rule is None: + raise ValueError( + f"Node has unknown symbol {node.symbol}, valid symbols are" + f" {self.rules.keys()}" + ) + + # We should never encounter a situtation where we have some nesting of shared + # nodes, for example, consider the following, where L2 is shared. + # L2 -> x -> ... -> L1 -> x -> ... + already_shared = already_shared or set() + if ( + isinstance(rule, Grammar.NonTerminal) + and rule.shared + and node.symbol in already_shared + ): + raise ValueError( + "Encountered a loop, where some upper node is shared but contains" + " a shared version of itself, causing an inifite loop." + ) + + match node: + case Leaf(symbol): + return symbol in self.rules + case Container(symbol, children, _) | Passthrough(symbol, children): + s = " ".join(child.symbol for child in children) + + match rule: + case Grammar.Terminal(_): + return s in self.rules and all( + self.is_valid(child, already_shared=already_shared.copy()) + for child in children + ) + case Grammar.NonTerminal(choices, _): + return s in choices and all( + self.is_valid(child, already_shared=already_shared.copy()) + for child in children + ) + case _: + assert_never(rule) + return None + case _: + assert_never(node) + return None + + def to_model(self, string: str) -> Any: + """Convert a string form this grammar into its model form.""" + node = self.parse(string) + return node.to_model() + + +# TODO: This is just for plotting, not sure where it should go +# https://stackoverflow.com/a/29597210 +def hierarchy_pos( + G: nx.DiGraph, + root: int, + width: float = 2.0, + vert_gap: float = 1.2, + vert_loc: float = 1, + xcenter: float = 1.5, +) -> dict[int, tuple[float, float]]: + """From Joel's answer at https://stackoverflow.com/a/29597210/2966723. + Licensed under Creative Commons Attribution-Share Alike. + + If the graph is a tree this will return the positions to plot this in a + hierarchical layout. 
+
+    G: the graph (must be a tree)
+
+    root: the root node of current branch
+    - if the tree is directed and this is not given,
+      the root will be found and used
+    - if the tree is directed and this is given, then
+      the positions will be just for the descendants of this node.
+    - if the tree is undirected and not given,
+      then a random choice will be used.
+
+    width: horizontal space allocated for this branch - avoids overlap with other branches
+
+    vert_gap: gap between levels of hierarchy
+
+    vert_loc: vertical location of root
+
+    xcenter: horizontal location of root
+    """
+    if not nx.is_tree(G):
+        raise TypeError("cannot use hierarchy_pos on a graph that is not a tree")
+
+    def _hierarchy_pos(
+        G,
+        root,
+        width=2.0,
+        vert_gap=1.2,
+        vert_loc: float = 1,
+        xcenter=1.5,
+        pos: dict[int, tuple[float, float]] | None = None,
+        parent=None,
+    ) -> dict[int, tuple[float, float]]:
+        """See hierarchy_pos docstring for most arguments.
+
+        pos: a dict saying where all nodes go if they have been assigned
+        parent: parent of this branch. - only affects it if non-directed
+
+        """
+        if pos is None:
+            pos = {root: (xcenter, vert_loc)}
+        else:
+            pos[root] = (xcenter, vert_loc)
+        children = list(G.neighbors(root))
+        if not isinstance(G, nx.DiGraph) and parent is not None:
+            children.remove(parent)
+        if len(children) != 0:
+            dx = width / len(children)
+            nextx = xcenter - width / 2 - dx / 2
+            for child in children:
+                nextx += dx
+                pos = _hierarchy_pos(
+                    G,
+                    child,
+                    width=dx,
+                    vert_gap=vert_gap,
+                    vert_loc=vert_loc - vert_gap,
+                    xcenter=nextx,
+                    pos=pos,
+                    parent=root,
+                )
+        return pos
+
+    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
+
+
+# TODO: Everything below this point should probably be moved.
+
+
+@dataclass
+class RandomSampler:
+    grammar: Grammar
+
+    def sample(
+        self,
+        n: int,
+        *,
+        rng: np.random.Generator | _BufferedRandInts | None = None,
+    ) -> list[Node]:
+        match rng:
+            case None:
+                rng = _BufferedRandInts(rng=np.random.default_rng())
+            case np.random.Generator():
+                rng = _BufferedRandInts(rng=rng)
+            case _BufferedRandInts():
+                pass
+            case _:
+                assert_never(rng)
+
+        return [self.grammar.sample(rng=rng) for _ in range(n)]
+
+
+@dataclass
+class MutationSampler:
+    grammar: Grammar
+    ref_point: Node
+    max_mutation_depth: int
+    mutant_selector: (
+        tuple[Literal["symbol"], str]
+        | tuple[Literal["depth"], int | range]
+        | tuple[Literal["climb"], int | range]
+    )
+
+    def sample(self, n: int, *, rng: np.random.Generator | None = None) -> list[Node]:
+        if rng is None:
+            rng = np.random.default_rng()
+
+        nodes_to_mutate_from = self.ref_point.select(how=self.mutant_selector)
+        all_possible_mutants = self.grammar.mutations(
+            self.ref_point,
+            which=nodes_to_mutate_from,
+            max_mutation_depth=self.max_mutation_depth,
+        )
+        all_possible_mutants = list(all_possible_mutants)
+        # Sample indices rather than the nodes themselves; `rng.choice` would
+        # otherwise try to coerce the (iterable) `Node` tuples into an array.
+        indices = rng.choice(len(all_possible_mutants), size=n, replace=False)
+        return [all_possible_mutants[i] for i in indices]
+
+
+@dataclass
+class GrammarSampler:
+    samplers: Mapping[str, RandomSampler | MutationSampler]
+
+    def sample(
+        self, n: int, *, rng: np.random.Generator | None = None
+    ) -> list[dict[str, Node]]:
+        """Sample n dictionaries of nodes from the underlying grammar samplers.
+
+        Args:
+            n: the number of samples to generate.
+            rng: the random number generator to use.
+
+        Returns:
+            A list of dictionaries mapping each sampler's key to a sampled Node.
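+
+        Example:
+            A minimal sketch; `grammar_1` is assumed to be some `Grammar` as in
+            the module examples:
+
+            ```python
+            sampler = GrammarSampler.random({"arch": grammar_1})
+            configs = sampler.sample(3)  # -> [{"arch": <Node>}, ...]
+            ```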
+ """ + if rng is None: + rng = np.random.default_rng() + + samples: dict[str, list[Node]] = { + k: sampler.sample(n, rng=rng) for k, sampler in self.samplers.items() + } + return [{k: samples[k][i] for k in samples} for i in range(n)] + + @classmethod + def random(cls, grammars: Mapping[str, Grammar]) -> GrammarSampler: + return cls(samplers={k: RandomSampler(g) for k, g in grammars.items()}) + + @classmethod + def prior( + cls, + grammars: Mapping[str, Grammar], + *, + mutant_selector: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ) = ("climb", range(1, 3)), + max_mutation_depth: int = 3, + ) -> GrammarSampler: + """Creates samplers for the grammars, using the prior where possible. + + Grammars without a prior use the `RandomSampler` while those with a prior + have mutations done around the prior. + + Args: + grammars: the grammars to build samplers for. + mutant_selector: Please take a look at [`select()`][neps.space.grammar.select] + max_mutation_depth: Dictates how deep mutations of grammars can go to prevent + overly large configurations due to recursive rules. + """ + samplers: dict[str, RandomSampler | MutationSampler] = {} + for k, g in grammars.items(): + if g._prior_node is not None: + samplers[k] = MutationSampler( + g, + ref_point=g._prior_node, + max_mutation_depth=max_mutation_depth, + mutant_selector=mutant_selector, + ) + else: + samplers[k] = RandomSampler(g) + + return cls(samplers) diff --git a/neps/space/search_space.py b/neps/space/search_space.py index 2b0659f6a..166378421 100644 --- a/neps/space/search_space.py +++ b/neps/space/search_space.py @@ -9,7 +9,14 @@ from dataclasses import dataclass, field from typing import Any -from neps.space.parameters import Categorical, Constant, Float, Integer, Parameter +from neps.space.grammar import Grammar +from neps.space.parameters import ( + Categorical, + Constant, + Float, + Integer, + Parameter, +) # NOTE: The use of `Mapping` instead of `dict` is so that type-checkers @@ -19,12 +26,15 @@ class SearchSpace(Mapping[str, Parameter | Constant]): """A container for parameters.""" - elements: Mapping[str, Parameter | Constant] = field(default_factory=dict) + elements: Mapping[str, Parameter | Grammar | Constant] = field(default_factory=dict) """All items in the search space.""" categoricals: Mapping[str, Categorical] = field(init=False) """The categorical hyperparameters in the search space.""" + grammars: Mapping[str, Grammar] = field(init=False) + """The grammar parameters of the search space.""" + numerical: Mapping[str, Integer | Float] = field(init=False) """The numerical hyperparameters in the search space. @@ -43,14 +53,9 @@ class SearchSpace(Mapping[str, Parameter | Constant]): """The constants in the search space.""" @property - def searchables(self) -> Mapping[str, Parameter]: - """The hyperparameters that can be searched over. - - !!! note - - This does not include either constants or fidelities. 
- """ - return {**self.numerical, **self.categoricals} + def grammar(self) -> tuple[str, Grammar] | None: + """The grammar parameter for the search space if any.""" + return None if len(self.grammars) == 0 else next(iter(self.grammars.items())) @property def fidelity(self) -> tuple[str, Float | Integer] | None: @@ -65,6 +70,7 @@ def __post_init__(self) -> None: numerical: dict[str, Float | Integer] = {} categoricals: dict[str, Categorical] = {} constants: dict[str, Any] = {} + grammars: dict[str, Grammar] = {} # Process the hyperparameters for name, hp in self.elements.items(): @@ -86,7 +92,14 @@ def __post_init__(self) -> None: categoricals[name] = hp case Constant(): constants[name] = hp.value - + case Grammar(): + if len(grammars) >= 1: + raise ValueError( + "neps only supports one grammar parameter in the" + " pipeline space, but multiple were given." + f" Grammars: {grammars}, new: {name}" + ) + grammars[name] = hp case _: raise ValueError(f"Unknown hyperparameter type: {hp}") @@ -94,8 +107,9 @@ def __post_init__(self) -> None: self.numerical = numerical self.constants = constants self.fidelities = fidelities + self.grammars = grammars - def __getitem__(self, key: str) -> Parameter | Constant: + def __getitem__(self, key: str) -> Parameter | Constant | Grammar: return self.elements[key] def __iter__(self) -> Iterator[str]: diff --git a/pyproject.toml b/pyproject.toml index a1ce332e0..3824542a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ dev = [ "mkdocs-literate-nav", "mike", "black", # This allows mkdocstrings to format signatures in the docs + "grakel==0.1.10", ] [tool.setuptools.packages.find] diff --git a/t.py b/t.py deleted file mode 100644 index 5fbd25dbb..000000000 --- a/t.py +++ /dev/null @@ -1,13 +0,0 @@ - -import rich -import neps - -space = neps.SearchSpace( - { - "a": neps.Integer(0, 10), - "b": neps.Categorical(["a", "b", "c"]), - "c": neps.Float(1e-5, 1e0, log=True, prior=1e-3), - } - ) - -rich.print(space) diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 000000000..198d517de --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from functools import partial +from typing import Literal + +import numpy as np +import pytest +import torch +from torch import nn + +from neps.space.grammar import ( + Container, + Grammar, + Leaf, + Node, + ParseError, + Passthrough, + bfs_node, + dfs_node, + select, + to_model, + to_nxgraph, + to_string, +) + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + start_symbol="s", + grammar={ + "s": (["a", "b", "p", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + }, +) + +grammar_2 = Grammar.from_dict( + start_symbol="L1", + grammar={ + "L1": (["L2 L2 L3"], join), + "L2": Grammar.NonTerminal(["L3"], join, shared=True), + "L3": Grammar.NonTerminal(["a", "b"], None, shared=True), + "a": T("a"), + "b": T("a"), + }, +) + +grammar_3 = Grammar.from_dict( + start_symbol="S", + grammar={ + "S": (["mlp", "O"], nn.Sequential), + "mlp": (["L", "O", "S O"], nn.Sequential), + "L": ( + ["linear64 linear128 relu O linear64 relu O", "linear64 elu linear64"], + nn.Sequential, + ), + "O": (["linear64", "linear64 relu", "linear128 elu"], nn.Sequential), + "linear64": partial(nn.LazyLinear, out_features=64), + "linear128": partial(nn.LazyLinear, 
out_features=128),
+        "relu": nn.ReLU,
+        "elu": nn.ELU,
+    },
+)
+
+
+@pytest.mark.parametrize(
+    ("grammar", "string", "built", "node"),
+    [
+        (grammar_1, "a", "a", Leaf("a", T("a"))),
+        (grammar_1, "b", "b", Leaf("b", T("b"))),
+        (
+            grammar_1,
+            "s(a)",
+            "[a]",
+            Container("s", op=join, children=[Leaf("a", T("a"))]),
+        ),
+        (
+            grammar_1,
+            "s(p(a, b))",
+            "[ab]",
+            Container(
+                "s",
+                children=[
+                    Passthrough(
+                        "p",
+                        children=[Leaf("a", T("a")), Leaf("b", T("b"))],
+                    ),
+                ],
+                op=join,
+            ),
+        ),
+        (
+            grammar_1,
+            "s(p(a, b), p(s(a)))",
+            "[ab[a]]",
+            Container(
+                "s",
+                children=[
+                    Passthrough(
+                        "p",
+                        children=[Leaf("a", T("a")), Leaf("b", T("b"))],
+                    ),
+                    Passthrough(
+                        "p",
+                        children=[Container("s", children=[Leaf("a", T("a"))], op=join)],
+                    ),
+                ],
+                op=join,
+            ),
+        ),
+        (
+            grammar_1,
+            "s(p(s(a)))",
+            "[[a]]",
+            Container(
+                "s",
+                children=[
+                    Passthrough(
+                        "p",
+                        children=[
+                            Container(
+                                "s",
+                                children=[Leaf("a", T("a"))],
+                                op=join,
+                            )
+                        ],
+                    ),
+                ],
+                op=join,
+            ),
+        ),
+    ],
+)
+def test_string_serialization_and_deserialization_correct(
+    grammar: Grammar,
+    string: str,
+    built: str,
+    node: Node,
+) -> None:
+    # Test parsing
+    parsed = grammar.parse(string)
+    assert parsed == node
+
+    # Test serialization
+    serialized_again = to_string(parsed)
+    assert serialized_again == string
+
+    # Test building
+    assert to_model(parsed) == built
+
+    # Test graph and back again
+    graph = to_nxgraph(parsed, include_passthroughs=True)
+
+    node_again = grammar.node_from_graph(graph)
+    assert parsed == node_again
+
+
+@pytest.mark.parametrize(
+    ("grammar", "string"),
+    [
+        (grammar_1, "c"),
+        (grammar_1, ""),
+        (grammar_1, "s(a"),
+        (grammar_1, "p(a, b)"),
+        (grammar_1, "("),
+        (grammar_1, "s(a))"),
+        (grammar_1, "s((a)"),
+        (grammar_1, "s("),
+        (grammar_1, "s)"),
+        (grammar_1, "a, a"),
+        (grammar_1, "a,"),
+        (grammar_1, "s, s"),
+        # Invalid due to shared rule but not sharing values
+        (grammar_2, "L1(L2(L3(a)), L2(L3(a)), L3(b))"),
+    ],
+)
+def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None:
+    with pytest.raises(ParseError):
+        grammar.parse(string)
+
+
+def test_dfs_node_container() -> None:
+    node = Container(
+        "s",
+        children=[
+            Container(
+                "s_left",
+                children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))],
+                op=join,
+            ),
+            Container(
+                "s_right",
+                children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))],
+                op=join,
+            ),
+        ],
+        op=join,
+    )
+    outcome = list(dfs_node(node))
+    expected = [
+        # First
+        Container(
+            "s",
+            children=[
+                Container(
+                    "s_left",
+                    children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))],
+                    op=join,
+                ),
+                Container(
+                    "s_right",
+                    children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))],
+                    op=join,
+                ),
+            ],
+            op=join,
+        ),
+        # go down left depth first
+        Container(
+            "s_left",
+            children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))],
+            op=join,
+        ),
+        Leaf("a_left", T("a")),
+        Leaf("b_left", T("b")),
+        # go down right depth first
+        Container(
+            "s_right",
+            children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))],
+            op=join,
+        ),
+        Leaf("a_right", T("a")),
+        Leaf("b_right", T("b")),
+    ]
+    for i, (e, o) in enumerate(zip(expected, outcome, strict=True)):
+        assert e == o, f"Failed at index {i}"
+
+
+def test_bfs_node_container() -> None:
+    node = Container(
+        "s",
+        children=[
+            Container(
+                "s_left",
+                children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))],
+                op=join,
+            ),
+            Container(
+                "s_right",
+                children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))],
+                op=join,
+            ),
+        ],
+        op=join,
+    )
+
outcome = list(bfs_node(node)) + expected = [ + # First + Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ), + # Second level first + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + # Then 3rd level + Leaf("a_left", T("a")), + Leaf("b_left", T("b")), + Leaf("a_right", T("a")), + Leaf("b_right", T("b")), + ] + for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): + assert e == o, f"Failed at index {i}" + + +def test_select_symbol() -> None: + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("symbol", "d"))) + assert selected == [ + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ] + + +def test_select_depth() -> None: + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("depth", 1))) + assert selected == root.children + + selected = list(select(root, how=("depth", range(1, 3)))) + expected = [ + # Depth 1 + *root.children, + # Depth 2 + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + Leaf("l2", op=T("l2")), + Leaf("l4", op=T("l4")), + ] + assert selected == expected + + +def test_select_climb() -> None: + # NOTE: The order is rather arbitrary and not much thought has been given to it. + # However the test still tests a particular order that was done by trial and + # error. Feel free to redo the order if this changes. 
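+    # (A hedged reading, inferred from the expectations below: "climb" appears to
+    # select nodes by their minimum distance from a leaf, so `0` yields the leaves
+    # themselves and `range(1, 3)` yields ancestors one or two steps further up.)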
+    root = Container(
+        "a",
+        children=[
+            Container(
+                "b",
+                children=[
+                    Container(
+                        "d",
+                        children=[Leaf("l1", op=T("l1"))],
+                        op=join,
+                    ),
+                ],
+                op=join,
+            ),
+            Container("c", children=[Leaf("l2", op=T("l2"))], op=join),
+            Leaf("l3", op=T("l3")),
+            Container(
+                "d",
+                children=[Leaf("l4", op=T("l4"))],
+                op=join,
+            ),
+        ],
+        op=join,
+    )
+    selected = list(select(root, how=("climb", 0)))
+    assert selected == [
+        Leaf("l3", op=T("l3")),
+        Leaf("l2", op=T("l2")),
+        Leaf("l4", op=T("l4")),
+        Leaf("l1", op=T("l1")),
+    ]
+
+    selected = list(select(root, how=("climb", range(1, 3))))
+    expected = [
+        root,
+        Container("c", children=[Leaf("l2", op=T("l2"))], op=join),
+        Container(
+            "d",
+            children=[Leaf("l4", op=T("l4"))],
+            op=join,
+        ),
+        Container(
+            "d",
+            children=[Leaf("l1", op=T("l1"))],
+            op=join,
+        ),
+        Container(
+            "b",
+            children=[
+                Container(
+                    "d",
+                    children=[Leaf("l1", op=T("l1"))],
+                    op=join,
+                ),
+            ],
+            op=join,
+        ),
+    ]
+    for i, (sel, exp) in enumerate(zip(selected, expected, strict=True)):
+        assert sel == exp, f"Mismatch at pos {i}:\nExpected: {exp}\n\nGot: {sel}"
+
+
+@pytest.mark.parametrize("grammar", [grammar_3])
+def test_sample_grammar_and_build_model(grammar: Grammar):
+    rng = np.random.default_rng(seed=42)
+
+    x = torch.randn(32, 100)
+
+    t0 = time.perf_counter()
+    samples = 1_000
+    for _ in range(samples):
+        sample: Node = grammar.sample("S", rng=rng)
+        model: nn.Module = to_model(sample)
+        model(x)
+        assert sum(p.numel() for p in model.parameters()) > 0
+
+    # feel free to increase the time limit here, based on running this on an M4 Mac
+    assert time.perf_counter() - t0 < 1
+
+
+@pytest.mark.parametrize(
+    ("grammar", "how"),
+    [
+        (grammar_3, ("symbol", "S")),
+        (grammar_3, ("depth", 2)),
+        (grammar_3, ("depth", range(1, 3))),
+        (grammar_3, ("climb", 2)),
+        (grammar_3, ("climb", range(1, 3))),
+    ],
+)
+def test_sample_grammar_and_mutate(
+    grammar: Grammar,
+    how: (
+        tuple[Literal["symbol"], str]
+        | tuple[Literal["depth"], int | range]
+        | tuple[Literal["climb"], int | range]
+    ),
+):
+    rng = np.random.default_rng(seed=42)
+
+    x = torch.randn(32, 100)
+
+    samples = 1_000
+    for _ in range(samples):
+        sample: Node = grammar.sample("S", rng=rng)
+        # Materialize the mutants once; iterating the generator again after
+        # `len(list(muts))` would otherwise find it already exhausted.
+        muts = list(
+            grammar.mutations(
+                root=sample,
+                which=select(root=sample, how=how),
+                max_mutation_depth=3,
+            )
+        )
+
+        assert len(muts) > 0
+
+        for _mut in muts:
+            model: nn.Module = to_model(_mut)
+            model(x)
+            assert sum(p.numel() for p in model.parameters()) > 0
diff --git a/tests/test_graphs/__init__.py b/tests/test_graphs/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_graphs/test_botorch_wl_kernel.py b/tests/test_graphs/test_botorch_wl_kernel.py
new file mode 100644
index 000000000..2ded6237e
--- /dev/null
+++ b/tests/test_graphs/test_botorch_wl_kernel.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+import networkx as nx
+import pytest
+import torch
+from botorch.models.gp_regression_mixed import Kernel
+
+from neps.optimizers.models.graphs.kernels import BoTorchWLKernel
+
+
+def create_simple_graphs(num_graphs: int) -> list[nx.Graph]:
+    """Helper function to create a list of graphs."""
+    graphs = []
+    for _i in range(num_graphs):
+        G = nx.Graph()
+        G.add_nodes_from([0, 1, 2])
+        G.add_edges_from([(0, 1), (1, 2)])
+        graphs.append(G)
+    return graphs
+
+
+class TestBoTorchWLKernel:
+    @pytest.fixture
+    def simple_graphs(self) -> list[nx.Graph]:
+        return create_simple_graphs(3)
+
+    @pytest.fixture
+    def wl_kernel(self, simple_graphs: list[nx.Graph]) ->
BoTorchWLKernel: + return BoTorchWLKernel( + graph_lookup=simple_graphs, + n_iter=2, + normalize=True, + active_dims=(0,), + ) + + def test_initialization( + self, wl_kernel: BoTorchWLKernel, simple_graphs: list[nx.Graph] + ) -> None: + """Test that the kernel is initialized correctly.""" + assert isinstance(wl_kernel, Kernel) + assert len(wl_kernel.graph_lookup) == len(simple_graphs) + assert wl_kernel.n_iter == 2 + assert wl_kernel.normalize is True + assert torch.equal(wl_kernel.active_dims, torch.tensor([0])) + + def test_precompute_graph_data(self, wl_kernel: BoTorchWLKernel) -> None: + """Test that graph data is precomputed correctly.""" + assert hasattr(wl_kernel, "adjacency_cache") + assert hasattr(wl_kernel, "label_cache") + assert len(wl_kernel.adjacency_cache) == len(wl_kernel.graph_lookup) + assert len(wl_kernel.label_cache) == len(wl_kernel.graph_lookup) + + def test_set_graph_lookup(self, wl_kernel: BoTorchWLKernel) -> None: + """Test that the graph lookup can be updated.""" + new_graphs = create_simple_graphs(2) + wl_kernel.set_graph_lookup(new_graphs) + assert len(wl_kernel.graph_lookup) == 2 + assert len(wl_kernel.adjacency_cache) == 2 + assert len(wl_kernel.label_cache) == 2 + + def test_forward_self_kernel(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for self-similarity.""" + x = torch.tensor([[0], [1], [2]], dtype=torch.float64) + K = wl_kernel.forward(x, x) + assert K.shape == (3, 3) # Kernel matrix should be 3x3 + assert torch.allclose(K, K.T) # Kernel matrix should be symmetric + + def test_forward_cross_kernel(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for cross-similarity.""" + x1 = torch.tensor([[0], [1]], dtype=torch.float64) + x2 = torch.tensor([[1], [2]], dtype=torch.float64) + K = wl_kernel.forward(x1, x2) + assert K.shape == (2, 2) # Kernel matrix should be 2x2 + + def test_forward_diagonal(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for diagonal only.""" + x = torch.tensor([[0], [1], [2]], dtype=torch.float64) + K = wl_kernel.forward(x, x, diag=True) + assert K.shape == (3,) # Diagonal should be a vector of length 3 + + def test_handle_negative_one_index(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the handling of the -1 index.""" + x = torch.tensor([[-1], [0], [1]], dtype=torch.float64) + K = wl_kernel.forward(x, x) + assert K.shape == (3, 3) # Kernel matrix should be 3x3 + # Ensure that -1 refers to the last graph + last_graph_idx = len(wl_kernel.graph_lookup) - 1 + assert torch.allclose(K[0, 0], K[last_graph_idx, last_graph_idx]) + + def test_forward_batched_input(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for batched input.""" + x1 = torch.tensor([[[0], [1]], [[1], [2]]], dtype=torch.float64) + x2 = torch.tensor([[[1], [2]], [[0], [1]]], dtype=torch.float64) + K = wl_kernel.forward(x1, x2) + assert K.shape == (2, 2, 2) # Batched kernel matrix should be 2x2x2 + + def test_forward_invalid_input(self, wl_kernel: BoTorchWLKernel) -> None: + """Test that invalid input raises an error.""" + x1 = torch.tensor([[0], [1], [2]], dtype=torch.float64) + x2 = torch.tensor([[0], [1]], dtype=torch.float64) + with pytest.raises(NotImplementedError): + wl_kernel.forward(x1, x2, last_dim_is_batch=True) diff --git a/tests/test_graphs/test_optimization_over_graphs.py b/tests/test_graphs/test_optimization_over_graphs.py new file mode 100644 index 000000000..958031af0 --- /dev/null +++ b/tests/test_graphs/test_optimization_over_graphs.py @@ 
-0,0 +1,296 @@ +from __future__ import annotations + +from itertools import product + +import networkx as nx +import pytest +import torch +from botorch import fit_gpytorch_mll +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP +from botorch.models.kernels import CategoricalKernel +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.kernels import AdditiveKernel, MaternKernel, ScaleKernel + +from neps.optimizers.models.graphs.context_managers import set_graph_lookup +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel +from neps.optimizers.models.graphs.optimization import optimize_acqf_graph, sample_graphs +from neps.optimizers.models.graphs.utils import min_max_scale + + +class TestGraphOptimizationPipeline: + @pytest.fixture + def setup_data(self) -> dict: + """Fixture to set up common data for tests.""" + TRAIN_CONFIGS = 50 + TEST_CONFIGS = 10 + TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + + N_NUMERICAL = 2 + N_CATEGORICAL = 1 + N_CATEGORICAL_VALUES_PER_CATEGORY = 2 + N_GRAPH = 1 + + # Generate random data + X = torch.cat( + [ + torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), + torch.randint( + 0, + N_CATEGORICAL_VALUES_PER_CATEGORY, + (TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64, + ), + torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1), + ], + dim=1, + ) + + # Generate random graphs + graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] + + # Generate random target values + y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 + + # Split into train and test sets + train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] + train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] + train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) + + # Scale the data + train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) + + return { + "train_x": train_x, + "test_x": test_x, + "train_graphs": train_graphs, + "test_graphs": test_graphs, + "train_y": train_y, + "test_y": test_y, + "N_NUMERICAL": N_NUMERICAL, + "N_CATEGORICAL": N_CATEGORICAL, + "N_CATEGORICAL_VALUES_PER_CATEGORY": N_CATEGORICAL_VALUES_PER_CATEGORY, + "N_GRAPH": N_GRAPH, + } + + def test_gp_fit_and_predict(self, setup_data: dict) -> None: + """Test fitting the GP and making predictions.""" + train_x = setup_data["train_x"] + train_y = setup_data["train_y"] + test_x = setup_data["test_x"] + train_graphs = setup_data["train_graphs"] + setup_data["test_graphs"] + + # Define the kernels + kernels = [ + ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=setup_data["N_NUMERICAL"], + active_dims=range(setup_data["N_NUMERICAL"]), + ) + ), + ScaleKernel( + CategoricalKernel( + ard_num_dims=setup_data["N_CATEGORICAL"], + active_dims=range( + setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + setup_data["N_CATEGORICAL"], + ), + ) + ), + ScaleKernel( + BoTorchWLKernel( + graph_lookup=train_graphs, + n_iter=5, + normalize=True, + active_dims=(train_x.shape[1] - 1,), + ) + ), + ] + + # Create the GP model + gp = SingleTaskGP( + train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels) + ) + + # Fit the GP + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) + fit_gpytorch_mll(mll) + + # Make predictions on the test set + with torch.no_grad(): + posterior = gp.forward(test_x) + predictions = posterior.mean + uncertainties = posterior.variance.sqrt() + + # Ensure predictions are in the correct shape (10, 1) + predictions 
= predictions.unsqueeze(-1) # Reshape to (10, 1) + + # Basic checks + assert predictions.shape == (setup_data["test_x"].shape[0], 1) + assert uncertainties.shape == (setup_data["test_x"].shape[0],) + + def test_acquisition_function_optimization(self, setup_data: dict) -> None: + """Test optimizing the acquisition function with graph sampling.""" + train_x = setup_data["train_x"] + train_y = setup_data["train_y"] + train_graphs = setup_data["train_graphs"] + + # Define the kernels + kernels = [ + ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=setup_data["N_NUMERICAL"], + active_dims=range(setup_data["N_NUMERICAL"]), + ) + ), + ScaleKernel( + CategoricalKernel( + ard_num_dims=setup_data["N_CATEGORICAL"], + active_dims=range( + setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + setup_data["N_CATEGORICAL"], + ), + ) + ), + ScaleKernel( + BoTorchWLKernel( + graph_lookup=train_graphs, + n_iter=5, + normalize=True, + active_dims=(train_x.shape[1] - 1,), + ) + ), + ] + + # Create the GP model + gp = SingleTaskGP( + train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels) + ) + + # Fit the GP + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) + fit_gpytorch_mll(mll) + + # Define the acquisition function + acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, + ) + + # Define bounds for optimization + bounds = torch.tensor( + [ + [0.0] * setup_data["N_NUMERICAL"] + + [0.0] * setup_data["N_CATEGORICAL"] + + [-1.0] * setup_data["N_GRAPH"], + [1.0] * setup_data["N_NUMERICAL"] + + [float(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"] - 1)] + * setup_data["N_CATEGORICAL"] + + [len(train_x) - 1] * setup_data["N_GRAPH"], + ] + ) + + # Define fixed categorical features + cats_per_column = { + i: list(range(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"])) + for i in range( + setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + setup_data["N_CATEGORICAL"], + ) + } + fixed_cats = [ + dict(zip(cats_per_column.keys(), combo, strict=False)) + for combo in product(*cats_per_column.values()) + ] + + # Optimize the acquisition function + best_candidate, best_graph, best_score = optimize_acqf_graph( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + train_graphs=train_graphs, + num_graph_samples=2, + num_restarts=2, + raw_samples=16, + q=1, + ) + + # Assertions for the acquisition function optimization + assert isinstance(best_candidate, torch.Tensor), ( + "Best candidate should be a tensor" + ) + assert best_candidate.shape == (1, train_x.shape[1] - 1), ( + "Best candidate should have the correct shape (excluding the graph index)" + ) + assert isinstance(best_graph, nx.Graph), "Best graph should be a NetworkX graph" + assert isinstance(best_score, float), "Best score should be a float" + + # Ensure the best candidate does not contain the graph index column + assert best_candidate.shape[1] == train_x.shape[1] - 1, ( + "Best candidate should not include the graph index column" + ) + + def test_graph_sampling(self, setup_data: dict) -> None: + """Test the graph sampling functionality.""" + train_graphs = setup_data["train_graphs"] + num_samples = 5 + + # Sample graphs + sampled_graphs = sample_graphs(train_graphs, num_samples=num_samples) + + # Basic checks + assert len(sampled_graphs) == num_samples, ( + f"Expected {num_samples} sampled graphs, got {len(sampled_graphs)}" + ) + assert all(isinstance(graph, nx.Graph) for graph in sampled_graphs), 
( + "All sampled graphs should be NetworkX graphs" + ) + assert all(nx.is_connected(graph) for graph in sampled_graphs), ( + "All sampled graphs should be connected" + ) + + def test_min_max_scaling(self, setup_data: dict) -> None: + """Test the min-max scaling utility.""" + train_x = setup_data["train_x"] + + # Apply min-max scaling + scaled_train_x = min_max_scale(train_x) + + # Assertions for min-max scaling + assert torch.all(scaled_train_x >= 0), "Scaled values should be >= 0" + assert torch.all(scaled_train_x <= 1), "Scaled values should be <= 1" + assert scaled_train_x.shape == train_x.shape, ( + "Scaled data should have the same shape as the input data" + ) + + # Check that the scaling is correct + for i in range(train_x.shape[1]): + col_min = torch.min(train_x[:, i]) + col_max = torch.max(train_x[:, i]) + if col_min != col_max: # Avoid division by zero + expected_scaled_col = (train_x[:, i] - col_min) / (col_max - col_min) + assert torch.allclose(scaled_train_x[:, i], expected_scaled_col), ( + f"Scaling is incorrect for column {i}" + ) + + def test_set_graph_lookup(self, setup_data: dict) -> None: + """Test the set_graph_lookup context manager.""" + train_graphs = setup_data["train_graphs"] + test_graphs = setup_data["test_graphs"] + + # Define the kernel + kernel = BoTorchWLKernel( + graph_lookup=train_graphs, n_iter=5, normalize=True, active_dims=(0,) + ) + + # Use the context manager to temporarily set the graph lookup + with set_graph_lookup(kernel, test_graphs, append=True): + assert len(kernel.graph_lookup) == len(train_graphs) + len(test_graphs) + + # Check that the original graph lookup is restored + assert len(kernel.graph_lookup) == len(train_graphs) diff --git a/tests/test_graphs/test_torch_wl_kernel.py b/tests/test_graphs/test_torch_wl_kernel.py new file mode 100644 index 000000000..3e2f3c1f4 --- /dev/null +++ b/tests/test_graphs/test_torch_wl_kernel.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +import networkx as nx +import numpy as np +import pytest +import torch +from grakel import WeisfeilerLehman, graph_from_networkx + +from neps.optimizers.models.graphs.kernels import TorchWLKernel +from neps.optimizers.models.graphs.utils import graphs_to_tensors + + +class TestTorchWLKernel: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @pytest.fixture + def example_graphs_set(self) -> list[nx.Graph]: + # Create example graphs for testing + G1 = nx.Graph() + G1.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 3), (3, 4)]) + for node in G1.nodes(): + G1.nodes[node]["label"] = str(node) + + G2 = nx.Graph() + G2.add_edges_from([(0, 1), (1, 2), (3, 4), (4, 0)]) + for node in G2.nodes(): + G2.nodes[node]["label"] = str(node) + + G3 = nx.Graph() + G3.add_edges_from([(0, 1), (1, 3), (3, 2), (2, 4), (4, 0), (1, 2)]) + for node in G3.nodes(): + G3.nodes[node]["label"] = str(node) + + return [G1, G2, G3] + + @pytest.fixture + def random_graphs_sets(self) -> list[list[nx.Graph]]: + # Set a seed for reproducibility + seed = 100 + np.random.seed(seed) + torch.manual_seed(seed) + random_graph_sets = [] + + # Generate 10 random sets of graphs + for _ in range(10): + # Number of graphs in the set (2 to 10) + num_graphs = np.random.randint(2, 11) + graph_set = [] + + for _ in range(num_graphs): + # Number of nodes in the graph (3 to 50) + num_nodes = np.random.randint(3, 51) + G = nx.Graph() + + # Add nodes with labels + for node in range(num_nodes): + G.add_node(node, label=str(node)) + + # Add random edges + for u in range(num_nodes): + for v in range(u 
+ 1, num_nodes): + if np.random.rand() > 0.5: # 50% chance to add an edge + G.add_edge(u, v) + + graph_set.append(G) + + random_graph_sets.append(graph_set) + + return random_graph_sets + + @pytest.mark.parametrize("n_iter", [1, 2, 3, 5, 10]) + @pytest.mark.parametrize("normalize", [True, False]) + def test_wl_kernel_against_grakel( + self, n_iter: int, normalize: bool, random_graphs_sets: list[list[nx.Graph]] + ) -> None: + for graph_set in random_graphs_sets: + adjacency_matrices, label_tensors = graphs_to_tensors( + graph_set, device=self.device + ) + + # Initialize Torch WL Kernel + torch_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = ( + torch_kernel(adjacency_matrices, label_tensors).cpu().numpy() + ) + + # Initialize GraKel WL Kernel + grakel_graphs = list( + graph_from_networkx(graph_set, node_labels_tag="label", as_Graph=True) + ) + grakel_kernel = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_kernel.fit_transform(grakel_graphs) + + # Compare the kernel matrices + np.testing.assert_allclose( + torch_kernel_matrix, + grakel_kernel_matrix, + rtol=1e-5, + atol=1e-8, + err_msg=f"Kernel matrices differ for graph={graph_set}, n_iter={n_iter}", + ) + + def test_empty_graph(self) -> None: + G_empty = nx.Graph() + G_empty.add_node(0) + G_empty.nodes[0]["label"] = "0" + + adjacency_matrices, label_tensors = graphs_to_tensors( + [G_empty], device=self.device + ) + + # Initialize kernel and compute + kernel = TorchWLKernel(n_iter=3, normalize=True) + kernel_matrix = kernel(adjacency_matrices, label_tensors) + + # For a single graph, should get a 1x1 matrix with value 1.0 + expected = torch.ones(1, 1, device=self.device) + torch.testing.assert_close(kernel_matrix, expected) + + def test_invalid_input(self) -> None: + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + with pytest.raises( + ValueError, match="Mismatch between adjacency matrices and label tensors" + ): + wl_kernel([], [torch.tensor([0])]) + + def test_kernel_on_single_node_graph(self) -> None: + G_single = nx.Graph() + G_single.add_node(0) + G_single.nodes[0]["label"] = "0" + + adjacency_matrices, label_tensors = graphs_to_tensors( + [G_single], device=self.device + ) + + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(adjacency_matrices, label_tensors) + + expected = torch.ones(1, 1, device=self.device) + torch.testing.assert_close(K, expected) + + def test_wl_kernel_with_empty_graph_and_reordered_edges( + self, random_graphs_sets: list[list[nx.Graph]] + ) -> None: + """Test the TorchWLKernel with an empty graph and a graph with reordered edges.""" + for graph_set in random_graphs_sets: + # Create an empty graph + G_empty = nx.Graph() + G_empty.add_node(0) + G_empty.nodes[0]["label"] = "0" + + # Select the first graph from the set to reorder its edges + G = graph_set[0] + G_reordered = nx.Graph() + + # Add all nodes from the original graph to G_reordered + for node in G.nodes(): + G_reordered.add_node(node, label=G.nodes[node]["label"]) + + # Reorder edges randomly + edges = list(G.edges()) + np.random.shuffle(edges) # Randomly shuffle the edges + G_reordered.add_edges_from(edges) + + # Combine the empty graph, original graph, and reordered graph + graphs = [G_empty, G, G_reordered] + adjacency_matrices, label_tensors = graphs_to_tensors( + graphs, device=self.device + ) + + # Initialize and compute the kernel + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(adjacency_matrices, label_tensors) + + assert 
K.shape == (3, 3), "Kernel matrix shape is incorrect"
+            assert torch.allclose(K[1, 1], K[2, 2]), (
+                "Kernel value for original and reordered graphs should be the same"
+            )
+
+    @pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7])
+    @pytest.mark.parametrize("normalize", [True, False])
+    def test_wl_kernel_with_different_node_labels(
+        self, n_iter: int, normalize: bool, example_graphs_set: list[nx.Graph]
+    ) -> None:
+        graphs = []
+        for i, G in enumerate(example_graphs_set):
+            G_copy = G.copy()
+            prefix = ["node_", "vertex_", "n"][i]
+            for node in G_copy.nodes():
+                G_copy.nodes[node]["label"] = f"{prefix}{node}"
+            graphs.append(G_copy)
+
+        adjacency_matrices, label_tensors = graphs_to_tensors(graphs, device=self.device)
+
+        wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize)
+        torch_kernel_matrix = wl_kernel(adjacency_matrices, label_tensors).cpu().numpy()
+
+        grakel_graphs = graph_from_networkx(graphs, node_labels_tag="label")
+        grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize)
+        grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs)
+
+        np.testing.assert_allclose(
+            torch_kernel_matrix,
+            grakel_kernel_matrix,
+            rtol=1e-5,
+            atol=1e-8,
+            err_msg=f"Kernel matrices differ for n_iter={n_iter}, normalize={normalize}",
+        )
+
+    def test_wl_kernel_with_same_node_labels(
+        self, example_graphs_set: list[nx.Graph]
+    ) -> None:
+        """Test WL kernel behavior with same node labels but different structures.
+
+        Even when all nodes have the same label, the WL kernel should:
+        1. Produce a symmetric matrix
+        2. Have 1.0 on the diagonal (self-similarity)
+        3. Have off-diagonal values less than 1.0 (different structures)
+        4. Maintain non-negative values (it's a valid kernel)
+        """
+        graphs = []
+        for G in example_graphs_set:
+            G_copy = G.copy()
+            for node in G_copy.nodes():
+                G_copy.nodes[node]["label"] = "A"
+            graphs.append(G_copy)
+
+        adjacency_matrices, label_tensors = graphs_to_tensors(graphs, device=self.device)
+
+        wl_kernel = TorchWLKernel(n_iter=3, normalize=True)
+        K = wl_kernel(adjacency_matrices, label_tensors)
+
+        # Check basic properties
+        assert K.shape == (3, 3), "Kernel matrix shape is incorrect"
+        assert torch.allclose(K, K.T, atol=1e-4), "Kernel matrix is not symmetric"
+
+        # Check diagonal elements are 1 (normalized self-similarity)
+        assert torch.allclose(torch.diag(K), torch.ones_like(torch.diag(K)), atol=1e-4), (
+            "Diagonal elements should be 1.0"
+        )
+
+        # Check off-diagonal elements are less than 1 (different structures)
+        off_diag_mask = ~torch.eye(K.shape[0], dtype=torch.bool, device=self.device)
+        assert torch.all(K[off_diag_mask] < 1.0), (
+            "Off-diagonal elements should be less than 1.0 for different structures"
+        )
+
+        # Check all elements are non-negative (valid kernel)
+        assert torch.all(K >= 0), "Kernel values should be non-negative"
diff --git a/tests/test_search_space.py b/tests/test_search_space.py
index 73073a0cc..560b0af02 100644
--- a/tests/test_search_space.py
+++ b/tests/test_search_space.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from neps import Categorical, Constant, Float, Integer, SearchSpace
+from neps import Categorical, Constant, Float, Grammar, Integer, SearchSpace
 
 
 def test_search_space_orders_parameters_by_name():
@@ -19,6 +19,16 @@ def test_multipe_fidelities_raises_error():
     )
+
+def test_multiple_grammars_raises_error():
+    with pytest.raises(ValueError, match="neps only supports one grammar parameter"):
+        SearchSpace(
+            {
+                "a": Grammar.from_dict("s", {"s": lambda _: None}),
+                "b": Grammar.from_dict("s", {"s": lambda _:
None}), + } + ) + + def test_sorting_of_parameters_into_subsets(): elements = { "a": Float(0, 1), @@ -26,6 +36,7 @@ def test_sorting_of_parameters_into_subsets(): "c": Categorical(["a", "b", "c"]), "d": Float(0, 1, is_fidelity=True), "x": Constant("x"), + "g": Grammar.from_dict("s", {"s": lambda _: None}), } space = SearchSpace(elements) assert space.elements == elements @@ -33,10 +44,13 @@ def test_sorting_of_parameters_into_subsets(): assert space.numerical == {"a": elements["a"], "b": elements["b"]} assert space.fidelities == {"d": elements["d"]} assert space.constants == {"x": "x"} + assert space.grammars == {"g": elements["g"]} - assert space.searchables == { + parameters = {**space.numerical, **space.categoricals} + assert parameters == { "a": elements["a"], "b": elements["b"], "c": elements["c"], } assert space.fidelity == ("d", elements["d"]) + assert space.grammar == ("g", elements["g"]) diff --git a/tests/test_state/test_neps_state.py b/tests/test_state/test_neps_state.py index 57b6db946..7dde50da9 100644 --- a/tests/test_state/test_neps_state.py +++ b/tests/test_state/test_neps_state.py @@ -132,7 +132,7 @@ def optimizer_and_key_and_search_space( if key in JUST_SKIP: pytest.xfail(f"{key} is not instantiable") - if key in REQUIRES_PRIOR and search_space.searchables["a"].prior is None: + if key in REQUIRES_PRIOR and search_space.numerical["a"].prior is None: pytest.xfail(f"{key} requires a prior") if len(search_space.fidelities) > 0 and key in OPTIMIZER_FAILS_WITH_FIDELITY: