From 11ac47c470b7b708f98c888d2bf5d2572ef0ee06 Mon Sep 17 00:00:00 2001 From: Darryl Reeves Date: Wed, 16 Mar 2022 23:40:58 -0400 Subject: [PATCH 1/6] Fixed send message computations in Shafer-Shenoy --- junctiontree/computation.py | 103 +++++++++--------------------------- 1 file changed, 26 insertions(+), 77 deletions(-) diff --git a/junctiontree/computation.py b/junctiontree/computation.py index 5388ba3..21bd206 100644 --- a/junctiontree/computation.py +++ b/junctiontree/computation.py @@ -71,22 +71,25 @@ def get_message(sepset_ix, tree, beliefs, clique_vars): for var in clique_vars[ss_ix] ] - neighbor_vars = list(set(neighbor_vars)) + neighbor_vars = np.unique(neighbor_vars) # multiply neighbor messages - messages = messages if len(messages) else [1] - msg_prod = dl.einsum( - *messages, - neighbor_vars + msg_prod = 1 if len(messages) == 0 else dl.einsum( + *messages, + neighbor_vars ) + + args = [msg_prod, neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]] # compute message as marginalization over non-sepset values # multiplied by product of messages with output being vars in input sepset + message = dl.einsum(*args) + try: # attempt to update belief beliefs[sepset_ix] = message @@ -96,47 +99,6 @@ def get_message(sepset_ix, tree, beliefs, clique_vars): return None - def remove_message(msg_prod, prod_ixs, msg, msg_ixs, out_ixs): - '''Removes (divides out) sepset message from - product of all neighbor sepset messages for a clique - - :param msg_prod: product of all messages for clique - :param prod_ixs: variable indices in clique - :param msg: sepset message to be removed from product - :param msg_ixs: variable indices in sepset - :param out_ixs: variables indices expected in result - :return: the product of messages with sepset msg removed (divided out) - ''' - - exp_mask = np.in1d(prod_ixs, msg_ixs) - - # use mask to specify expanded dimensions in message - exp_ixs = np.full(msg_prod.ndim, None) - exp_ixs[exp_mask] = slice(None) - 
- # use mask to select slice dimensions - slice_mask = np.in1d(prod_ixs, out_ixs) - slice_ixs = np.full(msg_prod.ndim, slice(None)) - slice_ixs[~slice_mask] = 0 - - if all(exp_mask) and msg_ixs != prod_ixs: - # axis must be labeled starting at 0 - var_map = {var:i for i, var in enumerate(set(msg_ixs + prod_ixs))} - - # axis must be re-ordered if all variables shared but order is different - msg = np.moveaxis(msg, [var_map[var] for var in prod_ixs], [var_map[var] for var in msg_ixs]) - - # create dummy dimensions for performing division (with exp_ix) - # slice out dimensions of sepset variables from division result (with slice_ixs) - return np.divide( - msg_prod, - msg[ tuple(exp_ixs) ], - out = np.zeros_like(msg_prod), - where = msg[ tuple(exp_ixs) ] != 0 - )[ tuple(slice_ixs) ] - - - def send_message(message, sepset_ix, tree, beliefs, clique_vars): '''Sends message from clique at root of tree @@ -157,13 +119,14 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars): # adding message sent ] + [message, clique_vars[sepset_ix]] + all_neighbor_vars = [ var for vars in messages[1::2] for var in vars ] - neighbor_vars = list(set(all_neighbor_vars)) + neighbor_vars = np.unique(all_neighbor_vars) # multiply neighbor messages msg_prod = dl.einsum( @@ -175,34 +138,20 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars): ss_num = 0 for ss_ix, subtree in tree[1:]: - # divide product of messages by current sepset message for this neighbor - output_vars = list( - set( - [ - var - for vars in messages[1::2][0:ss_num] + messages[1::2][ss_num+1:] - for var in vars - ] - ) - ) - - mask = np.in1d( - neighbor_vars, - output_vars - - ) - - mod_neighbor_vars = np.array(neighbor_vars)[mask].tolist() - - mod_msg_prod = remove_message( - msg_prod, - neighbor_vars, - beliefs[ss_ix], - clique_vars[ss_ix], - mod_neighbor_vars - ) - - args = [mod_msg_prod, mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]] + # remove sepset ix vars 
from neighbor vars + mod_neighbor_vars = np.setdiff1d(neighbor_vars, clique_vars[ss_ix]) + + + + + # create product of messages that excludes the message from this sepset + mod_messages = [ + comp + for i in range(1,len(messages), 2) + for comp in messages[i-1:i+1] if messages[i] != clique_vars[ss_ix] + ] + args = [dl.einsum(*mod_messages, mod_neighbor_vars), mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]] + # calculate message to be sent message = dl.einsum( *args ) @@ -221,6 +170,7 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars): clique_vars[tree[0]] ] + beliefs[tree[0]] = dl.einsum(*args) @@ -242,7 +192,6 @@ def __run(tree, beliefs, clique_vars): return beliefs + beliefs = [np.copy(p) for p in potentials] return __run(tree, beliefs, clique_vars) - - From efd3757d239457a90f8111ef910c8ccfdc92fe4c Mon Sep 17 00:00:00 2001 From: Darryl Reeves Date: Sun, 27 Mar 2022 17:35:11 -0400 Subject: [PATCH 2/6] adding identity element this is for Sum Product distributive law to avoid special case handling in get_message function --- computation.py | 191 +++++++++++++++++++++++++++++++++++++++++++++++++ sum_product.py | 43 +++++++++++ 2 files changed, 234 insertions(+) create mode 100644 computation.py create mode 100644 sum_product.py diff --git a/computation.py b/computation.py new file mode 100644 index 0000000..5674dcf --- /dev/null +++ b/computation.py @@ -0,0 +1,191 @@ +from junctiontree.sum_product import SumProduct +import numpy as np + +# setting optimize to true allows einsum to benefit from speed up due to +# contraction order optimization but at the cost of memory usage +# need to evaulate tradeoff within library +#sum_product = SumProduct(np.einsum,optimize=True) + +sum_product = SumProduct(np.einsum) + +def apply_evidence(potentials, variables, evidence): + ''' Shrink potentials based on given evidence + + :param potentials: list of numpy arrays subject to evidence + :param variables: list of variables in 
corresponding to potentials + :param evidence: dictionary with variables as keys and assigned value as value + :return: a new list of potentials after evidence applied + ''' + + return [ + [ + # index array based on evidence value when evidence provided otherwise use full array + pot[ + tuple( + [ + slice(evidence.get(var, 0), evidence.get(var, pot.shape[i]) + 1) + for i, var in enumerate(vars) + ] + ) + # workaround for scalar factors + ] if not np.isscalar(pot) else pot + ] + for pot, vars in zip(potentials, variables) + ] + + +def compute_beliefs(tree, potentials, clique_vars, dl=sum_product): + '''Computes beliefs for clique potentials in a junction tree + using Shafer-Shenoy updates. + + :param tree: list representing the structure of the junction tree + :param potentials: list of numpy arrays for cliques in junction tree + :param clique_vars: list of variables included in each clique in potentials list + :return: list of numpy arrays defining computed beliefs of each clique + ''' + + def get_message(sepset_ix, tree, beliefs, clique_vars): + '''Computes message from root of tree with scope defined by sepset + + :param sepset_ix: index of sepset scope in which to return message + (use slice(0) for no sepset) + :param tree: list representation of tree rooted by cluster for which message + will be computed + :param beliefs: list of numpy arrays for cliques in junction tree + :param clique_vars: list of variables included in each clique in potentials list + :return: message: potential with scope defined by sepset (or None if tree includes root) + ''' + + messages = [ + comp # message or sepset variables in order processed + for sepset_ix, subtree in tree[1:] + for comp in [ + get_message(sepset_ix, subtree, beliefs, clique_vars), + clique_vars[sepset_ix] + ] + ] + + neighbor_vars = [ + var + for ss_ix, subtree in tree[1:] + for var in clique_vars[ss_ix] + ] + + neighbor_vars = np.unique(neighbor_vars) + + # multiply neighbor messages + msg_prod = 
dl.einsum(*messages, neighbor_vars) + + args = [msg_prod, neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]] + + # compute message as marginalization over non-sepset values + # multiplied by product of messages with output being vars in input sepset + + message = dl.einsum(*args) + + + try: + # attempt to update belief + beliefs[sepset_ix] = message + return message + except TypeError: + # function called on full tree so no message to send + return None + + + def send_message(message, sepset_ix, tree, beliefs, clique_vars): + '''Sends message from clique at root of tree + + :param message: message sent by neighbor + (use np.array(1) for no message) + :param sepset_ix: index of sepset scope in which message sent + (use slice(0) for no sepset) + :param tree: list representation of tree rooted by cluster receiving message + :param beliefs: list of numpy arrays for cliques in junction tree + :param clique_vars: list of variables included in each clique in potentials list + ''' + + # computed messages stored in beliefs for neighbor sepsets + messages = [ + comp + for ss_ix, _ in tree[1:] + for comp in [beliefs[ss_ix], clique_vars[ss_ix]] + # adding message sent + ] + [message, clique_vars[sepset_ix]] + + + all_neighbor_vars = [ + var + for vars in messages[1::2] + for var in vars + ] + + neighbor_vars = np.unique(all_neighbor_vars) + + # multiply neighbor messages + msg_prod = dl.einsum( + *messages, + neighbor_vars + ) + + # send message to each neighbor + ss_num = 0 + for ss_ix, subtree in tree[1:]: + + # remove sepset ix vars from neighbor vars + mod_neighbor_vars = np.setdiff1d(neighbor_vars, clique_vars[ss_ix]) + + + + + # create product of messages that excludes the message from this sepset + mod_messages = [ + comp + for i in range(1,len(messages), 2) + for comp in messages[i-1:i+1] if messages[i] != clique_vars[ss_ix] + ] + args = [dl.einsum(*mod_messages, mod_neighbor_vars), mod_neighbor_vars] + [beliefs[tree[0]], 
clique_vars[tree[0]], clique_vars[ss_ix]] + + # calculate message to be sent + message = dl.einsum( *args ) + + # update sepset belief + beliefs[ss_ix] *= message + + send_message(message, ss_ix, subtree, beliefs, clique_vars) + ss_num += 1 + + # update belief for clique + args = [ + beliefs[tree[0]], + clique_vars[tree[0]], + msg_prod, + neighbor_vars, + clique_vars[tree[0]] + ] + + + beliefs[tree[0]] = dl.einsum(*args) + + + def __run(tree, beliefs, clique_vars): + '''Collect messages from neighbors recursively. Then, send messages + recursively. Updated beliefs when this + + :param tree: list representing the structure of the junction tree + :param briefs: list of numpy arrays for cliques in junction tree + :param clique_vars: list of variables included in each clique in potentials list + :return beliefs: consistent beliefs after Shafer-Shenoy updates applied + ''' + + # get messages from each neighbor + get_message(slice(0), tree, beliefs, clique_vars) + + # send message to each neighbor + send_message(np.array(1), slice(0), tree, beliefs, clique_vars) + + return beliefs + + + beliefs = [np.copy(p) for p in potentials] + return __run(tree, beliefs, clique_vars) diff --git a/sum_product.py b/sum_product.py new file mode 100644 index 0000000..bcff9f4 --- /dev/null +++ b/sum_product.py @@ -0,0 +1,43 @@ + +class SumProduct(): + ''' Sum-product distributive law ''' + + + def __init__(self, einsum, *args, **kwargs): + # Perhaps support for different frameworks (TensorFlow, Theano) could + # be provided by giving the necessary functions. 
+ self.func = einsum + self.args = args + self.kwargs = kwargs + return + + def einsum(self, *args, **kwargs): + '''Performs Einstein summation based on input arguments + + :param args: the required positional arguments passed to underlying einsum function + :param kwargs: provides ability to pass key-word args to underlying function + :return: the resulting calculation based on the summation performed + ''' + + try: + if len(args[0]) == 0: + # represents identity element (empty list/array) for distributive law + return 1 + except TypeError: + # __len__ not defined for type (e.g. scalar) + pass + + args_list = list(args) + + var_indices = args_list[1::2] + [args_list[-1]] if len(args_list) % 2 == 1 else [] + + var_map = {var:i for i, var in enumerate(set([var for vars in var_indices for var in vars]))} + + args_list[1::2] = [ + [ var_map[var] for var in vars ] if len(vars) > 0 else [] for vars in args_list[1::2] + ] + + # explicit output indices may be provided requiring one additional mapping + args_list[-1] = [var_map[var] for var in args_list[-1]] + + return self.func(*args_list, *self.args, **kwargs, **self.kwargs) From 09427f0bb28153197fcece2bc026e5a797e2d457 Mon Sep 17 00:00:00 2001 From: Darryl Reeves Date: Mon, 4 Apr 2022 21:40:23 -0400 Subject: [PATCH 3/6] Delete computation.py --- computation.py | 191 ------------------------------------------------- 1 file changed, 191 deletions(-) delete mode 100644 computation.py diff --git a/computation.py b/computation.py deleted file mode 100644 index 5674dcf..0000000 --- a/computation.py +++ /dev/null @@ -1,191 +0,0 @@ -from junctiontree.sum_product import SumProduct -import numpy as np - -# setting optimize to true allows einsum to benefit from speed up due to -# contraction order optimization but at the cost of memory usage -# need to evaulate tradeoff within library -#sum_product = SumProduct(np.einsum,optimize=True) - -sum_product = SumProduct(np.einsum) - -def apply_evidence(potentials, variables, evidence): - 
''' Shrink potentials based on given evidence - - :param potentials: list of numpy arrays subject to evidence - :param variables: list of variables in corresponding to potentials - :param evidence: dictionary with variables as keys and assigned value as value - :return: a new list of potentials after evidence applied - ''' - - return [ - [ - # index array based on evidence value when evidence provided otherwise use full array - pot[ - tuple( - [ - slice(evidence.get(var, 0), evidence.get(var, pot.shape[i]) + 1) - for i, var in enumerate(vars) - ] - ) - # workaround for scalar factors - ] if not np.isscalar(pot) else pot - ] - for pot, vars in zip(potentials, variables) - ] - - -def compute_beliefs(tree, potentials, clique_vars, dl=sum_product): - '''Computes beliefs for clique potentials in a junction tree - using Shafer-Shenoy updates. - - :param tree: list representing the structure of the junction tree - :param potentials: list of numpy arrays for cliques in junction tree - :param clique_vars: list of variables included in each clique in potentials list - :return: list of numpy arrays defining computed beliefs of each clique - ''' - - def get_message(sepset_ix, tree, beliefs, clique_vars): - '''Computes message from root of tree with scope defined by sepset - - :param sepset_ix: index of sepset scope in which to return message - (use slice(0) for no sepset) - :param tree: list representation of tree rooted by cluster for which message - will be computed - :param beliefs: list of numpy arrays for cliques in junction tree - :param clique_vars: list of variables included in each clique in potentials list - :return: message: potential with scope defined by sepset (or None if tree includes root) - ''' - - messages = [ - comp # message or sepset variables in order processed - for sepset_ix, subtree in tree[1:] - for comp in [ - get_message(sepset_ix, subtree, beliefs, clique_vars), - clique_vars[sepset_ix] - ] - ] - - neighbor_vars = [ - var - for ss_ix, subtree in 
tree[1:] - for var in clique_vars[ss_ix] - ] - - neighbor_vars = np.unique(neighbor_vars) - - # multiply neighbor messages - msg_prod = dl.einsum(*messages, neighbor_vars) - - args = [msg_prod, neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]] - - # compute message as marginalization over non-sepset values - # multiplied by product of messages with output being vars in input sepset - - message = dl.einsum(*args) - - - try: - # attempt to update belief - beliefs[sepset_ix] = message - return message - except TypeError: - # function called on full tree so no message to send - return None - - - def send_message(message, sepset_ix, tree, beliefs, clique_vars): - '''Sends message from clique at root of tree - - :param message: message sent by neighbor - (use np.array(1) for no message) - :param sepset_ix: index of sepset scope in which message sent - (use slice(0) for no sepset) - :param tree: list representation of tree rooted by cluster receiving message - :param beliefs: list of numpy arrays for cliques in junction tree - :param clique_vars: list of variables included in each clique in potentials list - ''' - - # computed messages stored in beliefs for neighbor sepsets - messages = [ - comp - for ss_ix, _ in tree[1:] - for comp in [beliefs[ss_ix], clique_vars[ss_ix]] - # adding message sent - ] + [message, clique_vars[sepset_ix]] - - - all_neighbor_vars = [ - var - for vars in messages[1::2] - for var in vars - ] - - neighbor_vars = np.unique(all_neighbor_vars) - - # multiply neighbor messages - msg_prod = dl.einsum( - *messages, - neighbor_vars - ) - - # send message to each neighbor - ss_num = 0 - for ss_ix, subtree in tree[1:]: - - # remove sepset ix vars from neighbor vars - mod_neighbor_vars = np.setdiff1d(neighbor_vars, clique_vars[ss_ix]) - - - - - # create product of messages that excludes the message from this sepset - mod_messages = [ - comp - for i in range(1,len(messages), 2) - for comp in messages[i-1:i+1] if messages[i] 
!= clique_vars[ss_ix] - ] - args = [dl.einsum(*mod_messages, mod_neighbor_vars), mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]] - - # calculate message to be sent - message = dl.einsum( *args ) - - # update sepset belief - beliefs[ss_ix] *= message - - send_message(message, ss_ix, subtree, beliefs, clique_vars) - ss_num += 1 - - # update belief for clique - args = [ - beliefs[tree[0]], - clique_vars[tree[0]], - msg_prod, - neighbor_vars, - clique_vars[tree[0]] - ] - - - beliefs[tree[0]] = dl.einsum(*args) - - - def __run(tree, beliefs, clique_vars): - '''Collect messages from neighbors recursively. Then, send messages - recursively. Updated beliefs when this - - :param tree: list representing the structure of the junction tree - :param briefs: list of numpy arrays for cliques in junction tree - :param clique_vars: list of variables included in each clique in potentials list - :return beliefs: consistent beliefs after Shafer-Shenoy updates applied - ''' - - # get messages from each neighbor - get_message(slice(0), tree, beliefs, clique_vars) - - # send message to each neighbor - send_message(np.array(1), slice(0), tree, beliefs, clique_vars) - - return beliefs - - - beliefs = [np.copy(p) for p in potentials] - return __run(tree, beliefs, clique_vars) From 325dcfb0b6bc08de9da79f53e27dcea077d93b66 Mon Sep 17 00:00:00 2001 From: Darryl Reeves Date: Mon, 4 Apr 2022 21:40:31 -0400 Subject: [PATCH 4/6] Delete sum_product.py --- sum_product.py | 43 ------------------------------------------- 1 file changed, 43 deletions(-) delete mode 100644 sum_product.py diff --git a/sum_product.py b/sum_product.py deleted file mode 100644 index bcff9f4..0000000 --- a/sum_product.py +++ /dev/null @@ -1,43 +0,0 @@ - -class SumProduct(): - ''' Sum-product distributive law ''' - - - def __init__(self, einsum, *args, **kwargs): - # Perhaps support for different frameworks (TensorFlow, Theano) could - # be provided by giving the necessary functions. 
- self.func = einsum - self.args = args - self.kwargs = kwargs - return - - def einsum(self, *args, **kwargs): - '''Performs Einstein summation based on input arguments - - :param args: the required positional arguments passed to underlying einsum function - :param kwargs: provides ability to pass key-word args to underlying function - :return: the resulting calculation based on the summation performed - ''' - - try: - if len(args[0]) == 0: - # represents identity element (empty list/array) for distributive law - return 1 - except TypeError: - # __len__ not defined for type (e.g. scalar) - pass - - args_list = list(args) - - var_indices = args_list[1::2] + [args_list[-1]] if len(args_list) % 2 == 1 else [] - - var_map = {var:i for i, var in enumerate(set([var for vars in var_indices for var in vars]))} - - args_list[1::2] = [ - [ var_map[var] for var in vars ] if len(vars) > 0 else [] for vars in args_list[1::2] - ] - - # explicit output indices may be provided requiring one additional mapping - args_list[-1] = [var_map[var] for var in args_list[-1]] - - return self.func(*args_list, *self.args, **kwargs, **self.kwargs) From f7d81989e6437bc623001f2468cc0efaa77ad901 Mon Sep 17 00:00:00 2001 From: Darryl Reeves Date: Mon, 4 Apr 2022 21:41:45 -0400 Subject: [PATCH 5/6] Adding computation.py and sum_product.py --- junctiontree/computation.py | 10 ++-------- junctiontree/sum_product.py | 8 ++++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/junctiontree/computation.py b/junctiontree/computation.py index 21bd206..5674dcf 100644 --- a/junctiontree/computation.py +++ b/junctiontree/computation.py @@ -73,14 +73,8 @@ def get_message(sepset_ix, tree, beliefs, clique_vars): neighbor_vars = np.unique(neighbor_vars) - # multiply neighbor messages - - msg_prod = 1 if len(messages) == 0 else dl.einsum( - *messages, - neighbor_vars - ) - - + # multiply neighbor messages + msg_prod = dl.einsum(*messages, neighbor_vars) args = [msg_prod, neighbor_vars] + 
[beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]] diff --git a/junctiontree/sum_product.py b/junctiontree/sum_product.py index 8ed78c9..bcff9f4 100644 --- a/junctiontree/sum_product.py +++ b/junctiontree/sum_product.py @@ -19,6 +19,14 @@ def einsum(self, *args, **kwargs): :return: the resulting calculation based on the summation performed ''' + try: + if len(args[0]) == 0: + # represents identity element (empty list/array) for distributive law + return 1 + except TypeError: + # __len__ not defined for type (e.g. scalar) + pass + args_list = list(args) var_indices = args_list[1::2] + [args_list[-1]] if len(args_list) % 2 == 1 else [] From b92e56693b63262d3a0bc08d6c87b2779a52e0b1 Mon Sep 17 00:00:00 2001 From: Darryl Reeves Date: Sat, 23 Apr 2022 11:58:29 -0400 Subject: [PATCH 6/6] direct implementation of Shafer-Shenoy updates --- junctiontree/computation.py | 99 +++++++++++++++---------------------- 1 file changed, 40 insertions(+), 59 deletions(-) diff --git a/junctiontree/computation.py b/junctiontree/computation.py index 5674dcf..66ccdc4 100644 --- a/junctiontree/computation.py +++ b/junctiontree/computation.py @@ -1,5 +1,6 @@ from junctiontree.sum_product import SumProduct import numpy as np +import itertools # setting optimize to true allows einsum to benefit from speed up due to # contraction order optimization but at the cost of memory usage @@ -56,6 +57,7 @@ def get_message(sepset_ix, tree, beliefs, clique_vars): :return: message: potential with scope defined by sepset (or None if tree includes root) ''' + cluster_ix = tree[0] messages = [ comp # message or sepset variables in order processed for sepset_ix, subtree in tree[1:] @@ -65,35 +67,30 @@ def get_message(sepset_ix, tree, beliefs, clique_vars): ] ] - neighbor_vars = [ - var - for ss_ix, subtree in tree[1:] - for var in clique_vars[ss_ix] - ] - - neighbor_vars = np.unique(neighbor_vars) - - # multiply neighbor messages - msg_prod = dl.einsum(*messages, neighbor_vars) + # multiply 
neighbor messages + #messages = messages if len(messages) else [1,[]] - args = [msg_prod, neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]] + args = messages + [beliefs[cluster_ix], clique_vars[cluster_ix], clique_vars[cluster_ix]] # compute message as marginalization over non-sepset values # multiplied by product of messages with output being vars in input sepset - message = dl.einsum(*args) - + #update clique belief + beliefs[cluster_ix] = dl.einsum(*args) try: - # attempt to update belief - beliefs[sepset_ix] = message - return message + # attempt to update sepset belief + args = [beliefs[cluster_ix], clique_vars[cluster_ix], clique_vars[sepset_ix]] + beliefs[sepset_ix] = dl.einsum(*args) + + # send sepset belief as message + return beliefs[sepset_ix] except TypeError: # function called on full tree so no message to send return None - def send_message(message, sepset_ix, tree, beliefs, clique_vars): + def send_message(message, sepset_ix, tree, beliefs, pots, clique_vars): '''Sends message from clique at root of tree :param message: message sent by neighbor @@ -101,10 +98,12 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars): :param sepset_ix: index of sepset scope in which message sent (use slice(0) for no sepset) :param tree: list representation of tree rooted by cluster receiving message - :param beliefs: list of numpy arrays for cliques in junction tree + :param beliefs: beliefs to update for cliques in junction tree + :param pots: list of original numpy arrays for cliques in junction tree :param clique_vars: list of variables included in each clique in potentials list ''' + cluster_ix = tree[0] # computed messages stored in beliefs for neighbor sepsets messages = [ comp @@ -113,79 +112,61 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars): # adding message sent ] + [message, clique_vars[sepset_ix]] - - all_neighbor_vars = [ - var - for vars in messages[1::2] - for var in vars - ] - - 
neighbor_vars = np.unique(all_neighbor_vars) - - # multiply neighbor messages - msg_prod = dl.einsum( - *messages, - neighbor_vars - ) - # send message to each neighbor - ss_num = 0 for ss_ix, subtree in tree[1:]: - # remove sepset ix vars from neighbor vars - mod_neighbor_vars = np.setdiff1d(neighbor_vars, clique_vars[ss_ix]) - - - - - # create product of messages that excludes the message from this sepset + # collect all messages (excluding those from ss_ix) + # using id() as sepset variables can be same even when sepsets are unique mod_messages = [ comp for i in range(1,len(messages), 2) - for comp in messages[i-1:i+1] if messages[i] != clique_vars[ss_ix] + for comp in messages[i-1:i+1] if id(messages[i]) != id(clique_vars[ss_ix]) ] - args = [dl.einsum(*mod_messages, mod_neighbor_vars), mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]] # calculate message to be sent - message = dl.einsum( *args ) + args = mod_messages + [pots[cluster_ix], clique_vars[cluster_ix], clique_vars[ss_ix]] + msg = dl.einsum( *args ) # update sepset belief - beliefs[ss_ix] *= message + args = [beliefs[ss_ix], clique_vars[ss_ix], msg, clique_vars[ss_ix], clique_vars[ss_ix]] + beliefs[ss_ix] = dl.einsum(*args) + + # send message to neighbor (excludes message from subtree) + send_message(msg, ss_ix, subtree, beliefs, pots, clique_vars) - send_message(message, ss_ix, subtree, beliefs, clique_vars) - ss_num += 1 # update belief for clique args = [ - beliefs[tree[0]], - clique_vars[tree[0]], - msg_prod, - neighbor_vars, - clique_vars[tree[0]] + beliefs[cluster_ix], + clique_vars[cluster_ix], + message, + clique_vars[sepset_ix], + clique_vars[cluster_ix] ] - - beliefs[tree[0]] = dl.einsum(*args) + beliefs[cluster_ix] = dl.einsum(*args) - def __run(tree, beliefs, clique_vars): + def __run(tree, potentials, clique_vars): '''Collect messages from neighbors recursively. Then, send messages recursively. 
Updated beliefs when this :param tree: list representing the structure of the junction tree - :param briefs: list of numpy arrays for cliques in junction tree + :param potentials: list of numpy arrays for cliques in junction tree :param clique_vars: list of variables included in each clique in potentials list :return beliefs: consistent beliefs after Shafer-Shenoy updates applied ''' + beliefs = [np.copy(p) for p in potentials] + # get messages from each neighbor get_message(slice(0), tree, beliefs, clique_vars) # send message to each neighbor - send_message(np.array(1), slice(0), tree, beliefs, clique_vars) + send_message(np.array(1), slice(0), tree, beliefs, potentials, clique_vars) + return beliefs - beliefs = [np.copy(p) for p in potentials] - return __run(tree, beliefs, clique_vars) + return __run(tree, potentials, clique_vars)