diff --git a/junctiontree/computation.py b/junctiontree/computation.py index 5388ba3..66ccdc4 100644 --- a/junctiontree/computation.py +++ b/junctiontree/computation.py @@ -1,5 +1,6 @@ from junctiontree.sum_product import SumProduct import numpy as np +import itertools # setting optimize to true allows einsum to benefit from speed up due to # contraction order optimization but at the cost of memory usage @@ -56,6 +57,7 @@ def get_message(sepset_ix, tree, beliefs, clique_vars): :return: message: potential with scope defined by sepset (or None if tree includes root) ''' + cluster_ix = tree[0] messages = [ comp # message or sepset variables in order processed for sepset_ix, subtree in tree[1:] @@ -65,79 +67,30 @@ def get_message(sepset_ix, tree, beliefs, clique_vars): ] ] - neighbor_vars = [ - var - for ss_ix, subtree in tree[1:] - for var in clique_vars[ss_ix] - ] - - neighbor_vars = list(set(neighbor_vars)) - # multiply neighbor messages - messages = messages if len(messages) else [1] - - msg_prod = dl.einsum( - *messages, - neighbor_vars - ) + #messages = messages if len(messages) else [1,[]] - args = [msg_prod, neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]] + args = messages + [beliefs[cluster_ix], clique_vars[cluster_ix], clique_vars[cluster_ix]] # compute message as marginalization over non-sepset values # multiplied by product of messages with output being vars in input sepset - message = dl.einsum(*args) + + #update clique belief + beliefs[cluster_ix] = dl.einsum(*args) try: - # attempt to update belief - beliefs[sepset_ix] = message - return message + # attempt to update sepset belief + args = [beliefs[cluster_ix], clique_vars[cluster_ix], clique_vars[sepset_ix]] + beliefs[sepset_ix] = dl.einsum(*args) + + # send sepset belief as message + return beliefs[sepset_ix] except TypeError: # function called on full tree so no message to send return None - def remove_message(msg_prod, prod_ixs, msg, msg_ixs, out_ixs): - '''Removes (divides out) sepset message from - product of all neighbor sepset messages for a clique - - :param msg_prod: product of all messages for clique - :param prod_ixs: variable indices in clique - :param msg: sepset message to be removed from product - :param msg_ixs: variable indices in sepset - :param out_ixs: variables indices expected in result - :return: the product of messages with sepset msg removed (divided out) - ''' - - exp_mask = np.in1d(prod_ixs, msg_ixs) - - # use mask to specify expanded dimensions in message - exp_ixs = np.full(msg_prod.ndim, None) - exp_ixs[exp_mask] = slice(None) - - # use mask to select slice dimensions - slice_mask = np.in1d(prod_ixs, out_ixs) - slice_ixs = np.full(msg_prod.ndim, slice(None)) - slice_ixs[~slice_mask] = 0 - - if all(exp_mask) and msg_ixs != prod_ixs: - # axis must be labeled starting at 0 - var_map = {var:i for i, var in enumerate(set(msg_ixs + prod_ixs))} - - # axis must be re-ordered if all variables shared but order is different - msg = np.moveaxis(msg, [var_map[var] for var in prod_ixs], [var_map[var] for var in msg_ixs]) - - # create dummy dimensions for performing division (with exp_ix) - # slice out dimensions of sepset variables from division result (with slice_ixs) - return np.divide( - msg_prod, - msg[ tuple(exp_ixs) ], - out = np.zeros_like(msg_prod), - where = msg[ tuple(exp_ixs) ] != 0 - )[ tuple(slice_ixs) ] - - - - def send_message(message, sepset_ix, tree, beliefs, clique_vars): + def send_message(message, sepset_ix, tree, beliefs, pots, clique_vars): '''Sends message from clique at root of tree :param message: message sent by neighbor @@ -145,10 +98,12 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars): :param sepset_ix: index of sepset scope in which message sent (use slice(0) for no sepset) :param tree: list representation of tree rooted by cluster receiving message - :param beliefs: list of numpy arrays for cliques in junction tree + :param beliefs: beliefs to update for cliques in junction tree + :param pots: list of original numpy arrays for cliques in junction tree :param clique_vars: list of variables included in each clique in potentials list ''' + cluster_ix = tree[0] # computed messages stored in beliefs for neighbor sepsets messages = [ comp @@ -157,92 +112,61 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars): # adding message sent ] + [message, clique_vars[sepset_ix]] - all_neighbor_vars = [ - var - for vars in messages[1::2] - for var in vars - ] - - neighbor_vars = list(set(all_neighbor_vars)) - - # multiply neighbor messages - msg_prod = dl.einsum( - *messages, - neighbor_vars - ) - # send message to each neighbor - ss_num = 0 for ss_ix, subtree in tree[1:]: - # divide product of messages by current sepset message for this neighbor - output_vars = list( - set( - [ - var - for vars in messages[1::2][0:ss_num] + messages[1::2][ss_num+1:] - for var in vars - ] - ) - ) - - mask = np.in1d( - neighbor_vars, - output_vars - - ) - - mod_neighbor_vars = np.array(neighbor_vars)[mask].tolist() - - mod_msg_prod = remove_message( - msg_prod, - neighbor_vars, - beliefs[ss_ix], - clique_vars[ss_ix], - mod_neighbor_vars - ) - - args = [mod_msg_prod, mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]] + # collect all messages (excluding those from ss_ix) + # using id() as sepset variables can be same even when sepsets are unique + mod_messages = [ + comp + for i in range(1,len(messages), 2) + for comp in messages[i-1:i+1] if id(messages[i]) != id(clique_vars[ss_ix]) + ] + # calculate message to be sent - message = dl.einsum( *args ) + args = mod_messages + [pots[cluster_ix], clique_vars[cluster_ix], clique_vars[ss_ix]] + msg = dl.einsum( *args ) # update sepset belief - beliefs[ss_ix] *= message + args = [beliefs[ss_ix], clique_vars[ss_ix], msg, clique_vars[ss_ix], clique_vars[ss_ix]] + beliefs[ss_ix] = dl.einsum(*args) + + # send message to neighbor (excludes message from subtree) + send_message(msg, ss_ix, subtree, beliefs, pots, clique_vars) - send_message(message, ss_ix, subtree, beliefs, clique_vars) - ss_num += 1 # update belief for clique args = [ - beliefs[tree[0]], - clique_vars[tree[0]], - msg_prod, - neighbor_vars, - clique_vars[tree[0]] + beliefs[cluster_ix], + clique_vars[cluster_ix], + message, + clique_vars[sepset_ix], + clique_vars[cluster_ix] ] - beliefs[tree[0]] = dl.einsum(*args) + beliefs[cluster_ix] = dl.einsum(*args) - def __run(tree, beliefs, clique_vars): + def __run(tree, potentials, clique_vars): '''Collect messages from neighbors recursively. Then, send messages recursively. Updated beliefs when this :param tree: list representing the structure of the junction tree - :param briefs: list of numpy arrays for cliques in junction tree + :param potentials: list of numpy arrays for cliques in junction tree :param clique_vars: list of variables included in each clique in potentials list :return beliefs: consistent beliefs after Shafer-Shenoy updates applied ''' + beliefs = [np.copy(p) for p in potentials] + # get messages from each neighbor get_message(slice(0), tree, beliefs, clique_vars) # send message to each neighbor - send_message(np.array(1), slice(0), tree, beliefs, clique_vars) + send_message(np.array(1), slice(0), tree, beliefs, potentials, clique_vars) - return beliefs - beliefs = [np.copy(p) for p in potentials] - return __run(tree, beliefs, clique_vars) + return beliefs + return __run(tree, potentials, clique_vars) diff --git a/junctiontree/sum_product.py b/junctiontree/sum_product.py index 8ed78c9..bcff9f4 100644 --- a/junctiontree/sum_product.py +++ b/junctiontree/sum_product.py @@ -19,6 +19,14 @@ def einsum(self, *args, **kwargs): :return: the resulting calculation based on the summation performed ''' + try: + if len(args[0]) == 0: + # represents identity element (empty list/array) for distributive law + return 1 + except TypeError: + # __len__ not defined for type (e.g. scalar) + pass + args_list = list(args) var_indices = args_list[1::2] + [args_list[-1]] if len(args_list) % 2 == 1 else []