Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 45 additions & 121 deletions junctiontree/computation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from junctiontree.sum_product import SumProduct
import numpy as np
import itertools

# setting optimize to true allows einsum to benefit from speed up due to
# contraction order optimization but at the cost of memory usage
Expand Down Expand Up @@ -56,6 +57,7 @@ def get_message(sepset_ix, tree, beliefs, clique_vars):
:return: message: potential with scope defined by sepset (or None if tree includes root)
'''

cluster_ix = tree[0]
messages = [
comp # message or sepset variables in order processed
for sepset_ix, subtree in tree[1:]
Expand All @@ -65,90 +67,43 @@ def get_message(sepset_ix, tree, beliefs, clique_vars):
]
]

neighbor_vars = [
var
for ss_ix, subtree in tree[1:]
for var in clique_vars[ss_ix]
]

neighbor_vars = list(set(neighbor_vars))

# multiply neighbor messages
messages = messages if len(messages) else [1]

msg_prod = dl.einsum(
*messages,
neighbor_vars
)
#messages = messages if len(messages) else [1,[]]

args = [msg_prod, neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]]
args = messages + [beliefs[cluster_ix], clique_vars[cluster_ix], clique_vars[cluster_ix]]

# compute message as marginalization over non-sepset values
# multiplied by product of messages with output being vars in input sepset
message = dl.einsum(*args)

#update clique belief
beliefs[cluster_ix] = dl.einsum(*args)

try:
# attempt to update belief
beliefs[sepset_ix] = message
return message
# attempt to update sepset belief
args = [beliefs[cluster_ix], clique_vars[cluster_ix], clique_vars[sepset_ix]]
beliefs[sepset_ix] = dl.einsum(*args)

# send sepset belief as message
return beliefs[sepset_ix]
except TypeError:
# function called on full tree so no message to send
return None


def remove_message(msg_prod, prod_ixs, msg, msg_ixs, out_ixs):
'''Removes (divides out) sepset message from
product of all neighbor sepset messages for a clique

:param msg_prod: product of all messages for clique
:param prod_ixs: variable indices in clique
:param msg: sepset message to be removed from product
:param msg_ixs: variable indices in sepset
:param out_ixs: variables indices expected in result
:return: the product of messages with sepset msg removed (divided out)
'''

exp_mask = np.in1d(prod_ixs, msg_ixs)

# use mask to specify expanded dimensions in message
exp_ixs = np.full(msg_prod.ndim, None)
exp_ixs[exp_mask] = slice(None)

# use mask to select slice dimensions
slice_mask = np.in1d(prod_ixs, out_ixs)
slice_ixs = np.full(msg_prod.ndim, slice(None))
slice_ixs[~slice_mask] = 0

if all(exp_mask) and msg_ixs != prod_ixs:
# axis must be labeled starting at 0
var_map = {var:i for i, var in enumerate(set(msg_ixs + prod_ixs))}

# axis must be re-ordered if all variables shared but order is different
msg = np.moveaxis(msg, [var_map[var] for var in prod_ixs], [var_map[var] for var in msg_ixs])

# create dummy dimensions for performing division (with exp_ix)
# slice out dimensions of sepset variables from division result (with slice_ixs)
return np.divide(
msg_prod,
msg[ tuple(exp_ixs) ],
out = np.zeros_like(msg_prod),
where = msg[ tuple(exp_ixs) ] != 0
)[ tuple(slice_ixs) ]



def send_message(message, sepset_ix, tree, beliefs, clique_vars):
def send_message(message, sepset_ix, tree, beliefs, pots, clique_vars):
'''Sends message from clique at root of tree

:param message: message sent by neighbor
(use np.array(1) for no message)
:param sepset_ix: index of sepset scope in which message sent
(use slice(0) for no sepset)
:param tree: list representation of tree rooted by cluster receiving message
:param beliefs: list of numpy arrays for cliques in junction tree
:param beliefs: beliefs to update for cliques in junction tree
:param pots: list of original numpy arrays for cliques in junction tree
:param clique_vars: list of variables included in each clique in potentials list
'''

cluster_ix = tree[0]
# computed messages stored in beliefs for neighbor sepsets
messages = [
comp
Expand All @@ -157,92 +112,61 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars):
# adding message sent
] + [message, clique_vars[sepset_ix]]

all_neighbor_vars = [
var
for vars in messages[1::2]
for var in vars
]

neighbor_vars = list(set(all_neighbor_vars))

# multiply neighbor messages
msg_prod = dl.einsum(
*messages,
neighbor_vars
)

# send message to each neighbor
ss_num = 0
for ss_ix, subtree in tree[1:]:

# divide product of messages by current sepset message for this neighbor
output_vars = list(
set(
[
var
for vars in messages[1::2][0:ss_num] + messages[1::2][ss_num+1:]
for var in vars
]
)
)

mask = np.in1d(
neighbor_vars,
output_vars

)

mod_neighbor_vars = np.array(neighbor_vars)[mask].tolist()

mod_msg_prod = remove_message(
msg_prod,
neighbor_vars,
beliefs[ss_ix],
clique_vars[ss_ix],
mod_neighbor_vars
)

args = [mod_msg_prod, mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]]
# collect all messages (excluding those from ss_ix)
# using id() as sepset variables can be same even when sepsets are unique
mod_messages = [
comp
for i in range(1,len(messages), 2)
for comp in messages[i-1:i+1] if id(messages[i]) != id(clique_vars[ss_ix])
]

# calculate message to be sent
message = dl.einsum( *args )
args = mod_messages + [pots[cluster_ix], clique_vars[cluster_ix], clique_vars[ss_ix]]
msg = dl.einsum( *args )

# update sepset belief
beliefs[ss_ix] *= message
args = [beliefs[ss_ix], clique_vars[ss_ix], msg, clique_vars[ss_ix], clique_vars[ss_ix]]
beliefs[ss_ix] = dl.einsum(*args)

# send message to neighbor (excludes message from subtree)
send_message(msg, ss_ix, subtree, beliefs, pots, clique_vars)

send_message(message, ss_ix, subtree, beliefs, clique_vars)
ss_num += 1

# update belief for clique
args = [
beliefs[tree[0]],
clique_vars[tree[0]],
msg_prod,
neighbor_vars,
clique_vars[tree[0]]
beliefs[cluster_ix],
clique_vars[cluster_ix],
message,
clique_vars[sepset_ix],
clique_vars[cluster_ix]
]

beliefs[tree[0]] = dl.einsum(*args)
beliefs[cluster_ix] = dl.einsum(*args)


def __run(tree, beliefs, clique_vars):
def __run(tree, potentials, clique_vars):
'''Collect messages from neighbors recursively. Then, send messages
recursively. Updated beliefs when this

:param tree: list representing the structure of the junction tree
:param briefs: list of numpy arrays for cliques in junction tree
:param potentials: list of numpy arrays for cliques in junction tree
:param clique_vars: list of variables included in each clique in potentials list
:return beliefs: consistent beliefs after Shafer-Shenoy updates applied
'''

beliefs = [np.copy(p) for p in potentials]

# get messages from each neighbor
get_message(slice(0), tree, beliefs, clique_vars)

# send message to each neighbor
send_message(np.array(1), slice(0), tree, beliefs, clique_vars)
send_message(np.array(1), slice(0), tree, beliefs, potentials, clique_vars)

return beliefs

beliefs = [np.copy(p) for p in potentials]
return __run(tree, beliefs, clique_vars)
return beliefs


return __run(tree, potentials, clique_vars)
8 changes: 8 additions & 0 deletions junctiontree/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def einsum(self, *args, **kwargs):
:return: the resulting calculation based on the summation performed
'''

try:
if len(args[0]) == 0:
# represents identity element (empty list/array) for distributive law
return 1
except TypeError:
# __len__ not defined for type (e.g. scalar)
pass

args_list = list(args)

var_indices = args_list[1::2] + [args_list[-1]] if len(args_list) % 2 == 1 else []
Expand Down