Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions thrml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .factor import AbstractFactor as AbstractFactor
from .factor import FactorSamplingProgram as FactorSamplingProgram
from .factor import WeightedFactor as WeightedFactor
from .graph_utils import make_graph as make_graph
from .interaction import InteractionGroup as InteractionGroup
from .observers import AbstractObserver as AbstractObserver
from .observers import MomentAccumulatorObserver as MomentAccumulatorObserver
Expand Down
70 changes: 70 additions & 0 deletions thrml/graph_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import math

import jax
import jax.numpy as jnp

from .block_management import Block
from .pgm import SpinNode


def make_graph(
side_len: int,
torus: bool,
) -> tuple:
jumps = [(1, 0), (2, 1), (3, 2), (1, 4)]
side_len = math.ceil(side_len / 2) * 2
size = side_len**2

def get_idx(i, j):
if torus:
i = (i + 10 * side_len) % side_len
j = (j + 10 * side_len) % side_len

cond = (i >= side_len) | (j >= side_len) | (i < 0) | (j < 0)
return jnp.where(cond, -1, i * side_len + j)

def get_coords(idx):
return idx // side_len, (idx + side_len) % side_len

@jax.jit
def make_edge_single(idx, di, dj):
i, j = get_coords(idx)
return jnp.array([idx, get_idx(i + di, j + dj)])

make_edge_arr = jax.jit(jax.vmap(make_edge_single, in_axes=(0, None, None), out_axes=0))

indices = jnp.arange(size)
edge_arrs_list = []

for dx, dy in jumps:
edges_pos = make_edge_arr(indices, dx, dy)
edges_neg = make_edge_arr(indices, -dx, -dy)
edge_arrs_list.append(edges_pos)
edge_arrs_list.append(edges_neg)

edge_array = jnp.concatenate(edge_arrs_list, axis=0)

nodes_upper = []
nodes_lower = []
all_nodes = []
for i in range(size):
new_node = SpinNode()
all_nodes.append(new_node)
if (i // side_len + i % side_len) % 2 == 0:
nodes_upper.append(new_node)
else:
nodes_lower.append(new_node)

edges = set()
edge_array = edge_array.tolist()
for i, j in edge_array:
if i == -1 or j == -1:
continue
edges.add((all_nodes[i], all_nodes[j]))

edges = list(edges)

upper_block = Block(nodes_upper)
lower_block = Block(nodes_lower)

return all_nodes, edges, upper_block, lower_block