From 8eda218638f41f6c97940aaa4e082d1e60a95bb6 Mon Sep 17 00:00:00 2001 From: lockwo Date: Sat, 15 Nov 2025 22:11:53 -0800 Subject: [PATCH 1/2] example graph --- thrml/graph_utils.py | 69 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 thrml/graph_utils.py diff --git a/thrml/graph_utils.py b/thrml/graph_utils.py new file mode 100644 index 0000000..ac33774 --- /dev/null +++ b/thrml/graph_utils.py @@ -0,0 +1,69 @@ +import math +import jax +import jax.numpy as jnp +from .pgm import SpinNode +from .block_management import Block + +def make_graph( + side_len: int, + torus: bool, +) -> tuple: + jumps = [(1,0), (2, 1), (3, 2), (1, 4)] + side_len = math.ceil(side_len / 2) * 2 + size = side_len**2 + + def get_idx(i, j): + if torus: + i = (i + 10 * side_len) % side_len + j = (j + 10 * side_len) % side_len + + cond = (i >= side_len) | (j >= side_len) | (i < 0) | (j < 0) + return jnp.where(cond, -1, i * side_len + j) + + def get_coords(idx): + return idx // side_len, (idx + side_len) % side_len + + @jax.jit + def make_edge_single(idx, di, dj): + i, j = get_coords(idx) + return jnp.array([idx, get_idx(i + di, j + dj)]) + + make_edge_arr = jax.jit( + jax.vmap(make_edge_single, in_axes=(0, None, None), out_axes=0) + ) + + indices = jnp.arange(size) + edge_arrs_list = [] + + for dx, dy in jumps: + edges_pos = make_edge_arr(indices, dx, dy) + edges_neg = make_edge_arr(indices, -dx, -dy) + edge_arrs_list.append(edges_pos) + edge_arrs_list.append(edges_neg) + + edge_array = jnp.concatenate(edge_arrs_list, axis=0) + + nodes_upper = [] + nodes_lower = [] + all_nodes = [] + for i in range(size): + new_node = SpinNode() + all_nodes.append(new_node) + if (i // side_len + i % side_len) % 2 == 0: + nodes_upper.append(new_node) + else: + nodes_lower.append(new_node) + + edges = set() + edge_array = edge_array.tolist() + for i, j in edge_array: + if i == -1 or j == -1: + continue + edges.add((all_nodes[i], all_nodes[j])) + + edges = list(edges) + + upper_block = Block(nodes_upper) + lower_block = Block(nodes_lower) + + return all_nodes, edges, upper_block, lower_block \ No newline at end of file From d72c27581107b0cf3438cdec4f6f6e603df31349 Mon Sep 17 00:00:00 2001 From: lockwo Date: Sat, 15 Nov 2025 22:12:30 -0800 Subject: [PATCH 2/2] init --- thrml/__init__.py | 1 + thrml/graph_utils.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/thrml/__init__.py b/thrml/__init__.py index 8996d0b..5477ff9 100644 --- a/thrml/__init__.py +++ b/thrml/__init__.py @@ -22,6 +22,7 @@ from .factor import AbstractFactor as AbstractFactor from .factor import FactorSamplingProgram as FactorSamplingProgram from .factor import WeightedFactor as WeightedFactor +from .graph_utils import make_graph as make_graph from .interaction import InteractionGroup as InteractionGroup from .observers import AbstractObserver as AbstractObserver from .observers import MomentAccumulatorObserver as MomentAccumulatorObserver diff --git a/thrml/graph_utils.py b/thrml/graph_utils.py index ac33774..863c7d6 100644 --- a/thrml/graph_utils.py +++ b/thrml/graph_utils.py @@ -1,14 +1,17 @@ import math + import jax import jax.numpy as jnp -from .pgm import SpinNode + from .block_management import Block +from .pgm import SpinNode + def make_graph( side_len: int, torus: bool, ) -> tuple: - jumps = [(1,0), (2, 1), (3, 2), (1, 4)] + jumps = [(1, 0), (2, 1), (3, 2), (1, 4)] side_len = math.ceil(side_len / 2) * 2 size = side_len**2 @@ -28,9 +31,7 @@ def make_edge_single(idx, di, dj): i, j = get_coords(idx) return jnp.array([idx, get_idx(i + di, j + dj)]) - make_edge_arr = jax.jit( - jax.vmap(make_edge_single, in_axes=(0, None, None), out_axes=0) - ) + make_edge_arr = jax.jit(jax.vmap(make_edge_single, in_axes=(0, None, None), out_axes=0)) indices = jnp.arange(size) edge_arrs_list = [] @@ -66,4 +67,4 @@ def make_edge_single(idx, di, dj): upper_block = Block(nodes_upper) lower_block = Block(nodes_lower) - return all_nodes, edges, upper_block, lower_block \ No newline at end of file + return all_nodes, edges, upper_block, lower_block