From b6580e89d08ab5212acefdc11d60fcd89fa25b14 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Thu, 11 Nov 2021 15:15:54 +0800 Subject: [PATCH 01/21] =?UTF-8?q?experiment-emb.sh=20=E4=B8=AD=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E4=B8=BAnohup=E5=91=BD=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/fb15k-237-distmult.sh | 4 ++-- experiment-emb.sh | 12 +++++++----- requirements.txt | 10 ++++++++-- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/configs/fb15k-237-distmult.sh b/configs/fb15k-237-distmult.sh index 83f7d63..b809883 100644 --- a/configs/fb15k-237-distmult.sh +++ b/configs/fb15k-237-distmult.sh @@ -9,8 +9,8 @@ num_rollouts=1 bucket_interval=10 num_epochs=1000 num_wait_epochs=20 -batch_size=512 -train_batch_size=512 +batch_size=256 +train_batch_size=256 dev_batch_size=128 learning_rate=0.003 grad_norm=5 diff --git a/experiment-emb.sh b/experiment-emb.sh index 7394403..8597c86 100755 --- a/experiment-emb.sh +++ b/experiment-emb.sh @@ -17,7 +17,7 @@ if [[ $group_examples_by_query = *"True"* ]]; then group_examples_by_query_flag="--group_examples_by_query" fi -cmd="python3 -m src.experiments \ +cmd="python3 -u -m src.experiments \ --data_dir $data_dir \ $exp \ --model $model \ @@ -38,9 +38,11 @@ cmd="python3 -m src.experiments \ --beam_size $beam_size \ $group_examples_by_query_flag \ $add_reversed_training_edges_flag \ - --gpu $gpu \ - $ARGS" + --gpu $gpu " -echo "Executing $cmd" +echo $cmd + +LOG_FILE="logs/emb_"$model"_GPU_"$gpu".log" + +nohup $cmd>$LOG_FILE 2>&1 & -$cmd diff --git a/requirements.txt b/requirements.txt index a54c112..57c88ad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,8 @@ -tqdm==4.9.0 -matplotlib==2.1.2 +certifi==2021.10.8 +numpy==1.18.5 +pip==21.3.1 +setuptools==58.5.3 +torch==1.9.1+cu111 +tqdm==4.61.1 +typing-extensions==3.10.0.2 +wheel==0.37.0 \ No newline at end of file From d31eddc54268eaa2ef0e8e758fd3580a6e7b7cfd Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 12:10:49 +0800 Subject: [PATCH 02/21] =?UTF-8?q?beam=5Fsearch.py=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=8F=AF=E8=AF=BB=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/beam_search.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rl/graph_search/beam_search.py b/src/rl/graph_search/beam_search.py index d480779..7f58147 100644 --- a/src/rl/graph_search/beam_search.py +++ b/src/rl/graph_search/beam_search.py @@ -140,12 +140,12 @@ def adjust_search_trace(search_trace, action_offset): action = init_action for t in range(num_steps): - last_r, e = action + last_r, current_entity = action assert(q.size() == e_s.size()) assert(q.size() == e_t.size()) - assert(e.size()[0] % batch_size == 0) + assert(current_entity.size()[0] % batch_size == 0) assert(q.size()[0] % batch_size == 0) - k = int(e.size()[0] / batch_size) + k = int(current_entity.size()[0] / batch_size) # => [batch_size*k] q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k) e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k) @@ -153,7 +153,7 @@ def adjust_search_trace(search_trace, action_offset): obs = [e_s, q, e_t, t==(num_steps-1), last_r, seen_nodes] # one step forward in search db_outcomes, _, _ = pn.transit( - e, obs, kg, use_action_space_bucketing=True, merge_aspace_batching_outcome=True) + current_entity, obs, kg, use_action_space_bucketing=True, merge_aspace_batching_outcome=True) action_space, action_dist = db_outcomes[0] # => [batch_size*k, action_space_size] log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(action_dist) From 529a4bb6f965acc525769f2a611a4ebb79c22846 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 16:17:56 +0800 Subject: [PATCH 03/21] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=97=A0=E6=95=88?= =?UTF-8?q?=E5=8F=98=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 8528df7..88a39f4 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -75,7 +75,7 @@ def transit(self, e, obs, kg, use_action_space_bucketing=True, merge_aspace_batc action_dist: (Batch) distribution over actions. entropy: (Batch) entropy of action distribution. """ - e_s, q, e_t, last_step, last_r, seen_nodes = obs + e_s, q, e_t, last_step, last_r, _ = obs # Representation of the current state (current node and other observations) Q = kg.get_relation_embeddings(q) From 7c997d54420455fb58ef03cde2f8fe29ec089b0f Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 16:28:57 +0800 Subject: [PATCH 04/21] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/beam_search.py | 32 +++++++++++++++--------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/rl/graph_search/beam_search.py b/src/rl/graph_search/beam_search.py index 7f58147..57ae930 100644 --- a/src/rl/graph_search/beam_search.py +++ b/src/rl/graph_search/beam_search.py @@ -13,21 +13,21 @@ from src.utils.ops import unique_max, var_cuda, zeros_var_cuda, int_var_cuda, int_fill_var_cuda, var_to_numpy -def beam_search(pn, e_s, q, e_t, kg, num_steps, beam_size, return_path_components=False): +def beam_search(pn, source_entity, query_relation, target_entity, kg, num_steps, beam_size, return_path_components=False): """ Beam search from source. :param pn: Policy network. - :param e_s: (Variable:batch) source entity indices. - :param q: (Variable:batch) query relation indices. - :param e_t: (Variable:batch) target entity indices. + :param source_entity: (Variable:batch) source entity indices. + :param query_relation: (Variable:batch) query relation indices. + :param target_entity: (Variable:batch) target entity indices. :param kg: Knowledge graph environment. :param num_steps: Number of search steps. :param beam_size: Beam size used in search. :param return_path_components: If set, return all path components at the end of search. """ assert (num_steps >= 1) - batch_size = len(e_s) + batch_size = len(source_entity) def top_k_action(log_action_dist, action_space): """ @@ -124,13 +124,13 @@ def adjust_search_trace(search_trace, action_offset): search_trace[i] = (new_r, new_e) # Initialization - r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r) - seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1) - init_action = (r_s, e_s) + r_s = int_fill_var_cuda(source_entity.size(), kg.dummy_start_r) + seen_nodes = int_fill_var_cuda(source_entity.size(), kg.dummy_e).unsqueeze(1) + init_action = (r_s, source_entity) # path encoder pn.initialize_path(init_action, kg) if kg.args.save_beam_search_paths: - search_trace = [(r_s, e_s)] + search_trace = [(r_s, source_entity)] # Run beam search for num_steps # [batch_size*k], k=1 @@ -141,16 +141,16 @@ def adjust_search_trace(search_trace, action_offset): action = init_action for t in range(num_steps): last_r, current_entity = action - assert(q.size() == e_s.size()) - assert(q.size() == e_t.size()) + assert(query_relation.size() == source_entity.size()) + assert(query_relation.size() == target_entity.size()) assert(current_entity.size()[0] % batch_size == 0) - assert(q.size()[0] % batch_size == 0) + assert(query_relation.size()[0] % batch_size == 0) k = int(current_entity.size()[0] / batch_size) # => [batch_size*k] - q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k) - e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k) - e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k) - obs = [e_s, q, e_t, t==(num_steps-1), last_r, seen_nodes] + query_relation = ops.tile_along_beam(query_relation.view(batch_size, -1)[:, 0], k) + source_entity = ops.tile_along_beam(source_entity.view(batch_size, -1)[:, 0], k) + target_entity = ops.tile_along_beam(target_entity.view(batch_size, -1)[:, 0], k) + obs = [source_entity, query_relation, target_entity, t == (num_steps - 1), last_r, seen_nodes] # one step forward in search db_outcomes, _, _ = pn.transit( current_entity, obs, kg, use_action_space_bucketing=True, merge_aspace_batching_outcome=True) From 23fe64bf6fb6f3de2683f3fa61277b2b82a83bce Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 16:54:42 +0800 Subject: [PATCH 05/21] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 88a39f4..5a3a5b2 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -46,17 +46,17 @@ def __init__(self, args): self.fn = None self.fn_kg = None - def transit(self, e, obs, kg, use_action_space_bucketing=True, merge_aspace_batching_outcome=False): + def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merge_aspace_batching_outcome=False): """ Compute the next action distribution based on (a) the current node (entity) in KG and the query relation (b) action history representation - :param e: agent location (node) at step t. + :param current_entity: agent location (node) at step t. :param obs: agent observation at step t. e_s: source node q: query relation e_t: target node - last_step: If set, the agent is carrying out the last step. + is_last_step: If set, the agent is carrying out the last step. last_r: label of edge traversed in the previous step seen_nodes: notes seen on the paths :param kg: Knowledge graph environment. @@ -75,7 +75,7 @@ def transit(self, e, obs, kg, use_action_space_bucketing=True, merge_aspace_batc action_dist: (Batch) distribution over actions. entropy: (Batch) entropy of action distribution. """ - e_s, q, e_t, last_step, last_r, _ = obs + e_s, q, e_t, is_last_step, last_r, seen_nodes = obs # Representation of the current state (current node and other observations) Q = kg.get_relation_embeddings(q) @@ -84,10 +84,10 @@ def transit(self, e, obs, kg, use_action_space_bucketing=True, merge_aspace_batc X = torch.cat([H, Q], dim=-1) elif self.relation_only_in_path: E_s = kg.get_entity_embeddings(e_s) - E = kg.get_entity_embeddings(e) + E = kg.get_entity_embeddings(current_entity) X = torch.cat([E, H, E_s, Q], dim=-1) else: - E = kg.get_entity_embeddings(e) + E = kg.get_entity_embeddings(current_entity) X = torch.cat([E, H, Q], dim=-1) # MLP @@ -124,7 +124,7 @@ def pad_and_cat_action_space(action_spaces, inv_offset): db_outcomes = [] entropy_list = [] references = [] - db_action_spaces, db_references = self.get_action_space_in_buckets(e, obs, kg) + db_action_spaces, db_references = self.get_action_space_in_buckets(current_entity, obs, kg) for action_space_b, reference_b in zip(db_action_spaces, db_references): X2_b = X2[reference_b, :] action_dist_b, entropy_b = policy_nn_fun(X2_b, action_space_b) @@ -142,7 +142,7 @@ def pad_and_cat_action_space(action_spaces, inv_offset): db_outcomes = [(action_space, action_dist)] inv_offset = None else: - action_space = self.get_action_space(e, obs, kg) + action_space = self.get_action_space(current_entity, obs, kg) action_dist, entropy = policy_nn_fun(X2, action_space) db_outcomes = [(action_space, action_dist)] inv_offset = None From 8de4853cc91f50c08b0f072686dae29bcac04f13 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 17:15:51 +0800 Subject: [PATCH 06/21] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 5a3a5b2..0c975ca 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -223,7 +223,7 @@ def get_action_space_in_buckets(self, e, obs, kg, collapse_entities=False): l_batch_refsi stores the indices of the examples in bucket i in the current batch, which is used later to restore the output results to the original order. """ - e_s, q, e_t, last_step, last_r, seen_nodes = obs + e_s, q, e_t, is_last_step, last_r, seen_nodes = obs assert(len(e) == len(last_r)) assert(len(e) == len(e_s)) assert(len(e) == len(q)) @@ -257,7 +257,7 @@ def get_action_space_in_buckets(self, e, obs, kg, collapse_entities=False): q_b = q[l_batch_refs] e_t_b = e_t[l_batch_refs] seen_nodes_b = seen_nodes[l_batch_refs] - obs_b = [e_s_b, q_b, e_t_b, last_step, last_r_b, seen_nodes_b] + obs_b = [e_s_b, q_b, e_t_b, is_last_step, last_r_b, seen_nodes_b] action_space_b = ((r_space_b, e_space_b), action_mask_b) action_space_b = self.apply_action_masks(action_space_b, e_b, obs_b, kg) db_action_spaces.append(action_space_b) @@ -271,18 +271,18 @@ def get_action_space(self, e, obs, kg): action_space = ((r_space, e_space), action_mask) return self.apply_action_masks(action_space, e, obs, kg) - def apply_action_masks(self, action_space, e, obs, kg): + def apply_action_masks(self, action_space, current_entity, obs, kg): (r_space, e_space), action_mask = action_space - e_s, q, e_t, last_step, last_r, seen_nodes = obs + source_entity, query_relation, target_entity, is_last_step, last_r, seen_nodes = obs # Prevent the agent from selecting the ground truth edge - ground_truth_edge_mask = self.get_ground_truth_edge_mask(e, r_space, e_space, e_s, q, e_t, kg) + ground_truth_edge_mask = self.get_ground_truth_edge_mask(current_entity, r_space, e_space, source_entity, query_relation, target_entity, kg) action_mask -= ground_truth_edge_mask self.validate_action_mask(action_mask) # Mask out false negatives in the final step - if last_step: - false_negative_mask = self.get_false_negative_mask(e_space, e_s, q, e_t, kg) + if is_last_step: + false_negative_mask = self.get_false_negative_mask(e_space, source_entity, query_relation, target_entity, kg) action_mask *= (1 - false_negative_mask) self.validate_action_mask(action_mask) From 9ff7e159f4971a04947dd284ade7ee1424d977f4 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 17:28:30 +0800 Subject: [PATCH 07/21] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 0c975ca..d864403 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -297,13 +297,13 @@ def apply_action_masks(self, action_space, current_entity, obs, kg): # action_mask *= (1 - loop_mask_b) return (r_space, e_space), action_mask - def get_ground_truth_edge_mask(self, e, r_space, e_space, e_s, q, e_t, kg): + def get_ground_truth_edge_mask(self, current_entity, r_space, e_space, source_entity, query_relation, target_entity, kg): ground_truth_edge_mask = \ - ((e == e_s).unsqueeze(1) * (r_space == q.unsqueeze(1)) * (e_space == e_t.unsqueeze(1))) - inv_q = kg.get_inv_relation_id(q) + ((current_entity == source_entity).unsqueeze(1) * (r_space == query_relation.unsqueeze(1)) * (e_space == target_entity.unsqueeze(1))) + inv_q = kg.get_inv_relation_id(query_relation) inv_ground_truth_edge_mask = \ - ((e == e_t).unsqueeze(1) * (r_space == inv_q.unsqueeze(1)) * (e_space == e_s.unsqueeze(1))) - return ((ground_truth_edge_mask + inv_ground_truth_edge_mask) * (e_s.unsqueeze(1) != kg.dummy_e)).float() + ((current_entity == target_entity).unsqueeze(1) * (r_space == inv_q.unsqueeze(1)) * (e_space == source_entity.unsqueeze(1))) + return ((ground_truth_edge_mask + inv_ground_truth_edge_mask) * (source_entity.unsqueeze(1) != kg.dummy_e)).float() def get_answer_mask(self, e_space, e_s, q, kg): if kg.args.mask_test_false_negatives: From dc663831daa8d25d5fe53966f0b5c0eb0ae6e070 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 17:43:43 +0800 Subject: [PATCH 08/21] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index d864403..03245db 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -75,15 +75,15 @@ def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merg action_dist: (Batch) distribution over actions. entropy: (Batch) entropy of action distribution. """ - e_s, q, e_t, is_last_step, last_r, seen_nodes = obs + source_entity, query_relation, target_entity, is_last_step, last_r, seen_nodes = obs # Representation of the current state (current node and other observations) - Q = kg.get_relation_embeddings(q) + Q = kg.get_relation_embeddings(query_relation) H = self.path[-1][0][-1, :, :] if self.relation_only: X = torch.cat([H, Q], dim=-1) elif self.relation_only_in_path: - E_s = kg.get_entity_embeddings(e_s) + E_s = kg.get_entity_embeddings(source_entity) E = kg.get_entity_embeddings(current_entity) X = torch.cat([E, H, E_s, Q], dim=-1) else: From ee1cd89044909ec02e0f5681c2d773b0b894766e Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 17:51:27 +0800 Subject: [PATCH 09/21] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 03245db..ae24838 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -53,9 +53,9 @@ def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merg (b) action history representation :param current_entity: agent location (node) at step t. :param obs: agent observation at step t. - e_s: source node - q: query relation - e_t: target node + source_entity: source node + query_relation: query relation + target_entity: target node is_last_step: If set, the agent is carrying out the last step. last_r: label of edge traversed in the previous step seen_nodes: notes seen on the paths @@ -78,17 +78,17 @@ def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merg source_entity, query_relation, target_entity, is_last_step, last_r, seen_nodes = obs # Representation of the current state (current node and other observations) - Q = kg.get_relation_embeddings(query_relation) + relation_embeddings = kg.get_relation_embeddings(query_relation) H = self.path[-1][0][-1, :, :] if self.relation_only: - X = torch.cat([H, Q], dim=-1) + X = torch.cat([H, relation_embeddings], dim=-1) elif self.relation_only_in_path: - E_s = kg.get_entity_embeddings(source_entity) - E = kg.get_entity_embeddings(current_entity) - X = torch.cat([E, H, E_s, Q], dim=-1) + source_entity_embeddings = kg.get_entity_embeddings(source_entity) + current_entity_embeddings = kg.get_entity_embeddings(current_entity) + X = torch.cat([current_entity_embeddings, H, source_entity_embeddings, relation_embeddings], dim=-1) else: - E = kg.get_entity_embeddings(current_entity) - X = torch.cat([E, H, Q], dim=-1) + current_entity_embeddings = kg.get_entity_embeddings(current_entity) + X = torch.cat([current_entity_embeddings, H, relation_embeddings], dim=-1) # MLP X = self.W1(X) From fa84b3b53694e311810eba7207ce7b04a5fb6037 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 19:50:35 +0800 Subject: [PATCH 10/21] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index ae24838..48a9c0c 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -79,16 +79,16 @@ def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merg # Representation of the current state (current node and other observations) relation_embeddings = kg.get_relation_embeddings(query_relation) - H = self.path[-1][0][-1, :, :] + hide_embeddings = self.path[-1][0][-1, :, :] if self.relation_only: - X = torch.cat([H, relation_embeddings], dim=-1) + X = torch.cat([hide_embeddings, relation_embeddings], dim=-1) elif self.relation_only_in_path: source_entity_embeddings = kg.get_entity_embeddings(source_entity) current_entity_embeddings = kg.get_entity_embeddings(current_entity) - X = torch.cat([current_entity_embeddings, H, source_entity_embeddings, relation_embeddings], dim=-1) + X = torch.cat([current_entity_embeddings, hide_embeddings, source_entity_embeddings, relation_embeddings], dim=-1) else: current_entity_embeddings = kg.get_entity_embeddings(current_entity) - X = torch.cat([current_entity_embeddings, H, relation_embeddings], dim=-1) + X = torch.cat([current_entity_embeddings, hide_embeddings, relation_embeddings], dim=-1) # MLP X = self.W1(X) From 1255bacb9fcf1aafde50195f92468197302451bb Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 20:00:19 +0800 Subject: [PATCH 11/21] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 48a9c0c..71bf928 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -81,21 +81,21 @@ def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merg relation_embeddings = kg.get_relation_embeddings(query_relation) hide_embeddings = self.path[-1][0][-1, :, :] if self.relation_only: - X = torch.cat([hide_embeddings, relation_embeddings], dim=-1) + state_embeddings = torch.cat([hide_embeddings, relation_embeddings], dim=-1) elif self.relation_only_in_path: source_entity_embeddings = kg.get_entity_embeddings(source_entity) current_entity_embeddings = kg.get_entity_embeddings(current_entity) - X = torch.cat([current_entity_embeddings, hide_embeddings, source_entity_embeddings, relation_embeddings], dim=-1) + state_embeddings = torch.cat([current_entity_embeddings, hide_embeddings, source_entity_embeddings, relation_embeddings], dim=-1) else: current_entity_embeddings = kg.get_entity_embeddings(current_entity) - X = torch.cat([current_entity_embeddings, hide_embeddings, relation_embeddings], dim=-1) + state_embeddings = torch.cat([current_entity_embeddings, hide_embeddings, relation_embeddings], dim=-1) # MLP - X = self.W1(X) - X = F.relu(X) - X = self.W1Dropout(X) - X = self.W2(X) - X2 = self.W2Dropout(X) + state_embeddings = self.W1(state_embeddings) + state_embeddings = F.relu(state_embeddings) + state_embeddings = self.W1Dropout(state_embeddings) + state_embeddings = self.W2(state_embeddings) + state_embeddings = self.W2Dropout(state_embeddings) def policy_nn_fun(X2, action_space): (r_space, e_space), action_mask = action_space @@ -126,8 +126,8 @@ def pad_and_cat_action_space(action_spaces, inv_offset): references = [] db_action_spaces, db_references = self.get_action_space_in_buckets(current_entity, obs, kg) for action_space_b, reference_b in zip(db_action_spaces, db_references): - X2_b = X2[reference_b, :] - action_dist_b, entropy_b = policy_nn_fun(X2_b, action_space_b) + temp_state_embeddings = state_embeddings[reference_b, :] + action_dist_b, entropy_b = policy_nn_fun(temp_state_embeddings, action_space_b) references.extend(reference_b) db_outcomes.append((action_space_b, action_dist_b)) entropy_list.append(entropy_b) @@ -143,7 +143,7 @@ def pad_and_cat_action_space(action_spaces, inv_offset): inv_offset = None else: action_space = self.get_action_space(current_entity, obs, kg) - action_dist, entropy = policy_nn_fun(X2, action_space) + action_dist, entropy = policy_nn_fun(state_embeddings, action_space) db_outcomes = [(action_space, action_dist)] inv_offset = None From f6102b0d4d7822f8522cbeaf3058ca9cda2564c5 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 20:14:18 +0800 Subject: [PATCH 12/21] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 71bf928..7212c8e 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -97,11 +97,11 @@ def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merg state_embeddings = self.W2(state_embeddings) state_embeddings = self.W2Dropout(state_embeddings) - def policy_nn_fun(X2, action_space): + def policy_nn_fun(state_embeddings, action_space): (r_space, e_space), action_mask = action_space A = self.get_action_embedding((r_space, e_space), kg) action_dist = F.softmax( - torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) - (1 - action_mask) * ops.HUGE_INT, dim=-1) + torch.squeeze(A @ torch.unsqueeze(state_embeddings, 2), 2) - (1 - action_mask) * ops.HUGE_INT, dim=-1) # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask) return action_dist, ops.entropy(action_dist) From 4be7fa0a5ed7a2a70eb3c23043d859b43e91e92c Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Fri, 12 Nov 2021 20:55:24 +0800 Subject: [PATCH 13/21] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rl/graph_search/pn.py b/src/rl/graph_search/pn.py index 7212c8e..d25becd 100644 --- a/src/rl/graph_search/pn.py +++ b/src/rl/graph_search/pn.py @@ -99,9 +99,9 @@ def transit(self, current_entity, obs, kg, use_action_space_bucketing=True, merg def policy_nn_fun(state_embeddings, action_space): (r_space, e_space), action_mask = action_space - A = self.get_action_embedding((r_space, e_space), kg) + action_space_embeddings = self.get_action_embedding((r_space, e_space), kg) action_dist = F.softmax( - torch.squeeze(A @ torch.unsqueeze(state_embeddings, 2), 2) - (1 - action_mask) * ops.HUGE_INT, dim=-1) + torch.squeeze(action_space_embeddings @ torch.unsqueeze(state_embeddings, 2), 2) - (1 - action_mask) * ops.HUGE_INT, dim=-1) # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask) return action_dist, ops.entropy(action_dist) From 747d92e9b3f2f5602e290edf30f023b6ba923fa7 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Mon, 15 Nov 2021 15:51:37 +0800 Subject: [PATCH 14/21] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=EF=BC=9Aaction=5Fspace=5Fbuckets=5Fdiscrete=20=E8=A1=8C?= =?UTF-8?q?=E5=8A=A8=E7=A9=BA=E9=97=B4=E5=AF=B9=E9=BD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/knowledge_graph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/knowledge_graph.py b/src/knowledge_graph.py index 32792d7..14d44fd 100644 --- a/src/knowledge_graph.py +++ b/src/knowledge_graph.py @@ -178,6 +178,7 @@ def vectorize_unique_r_space(unique_r_space_list, unique_r_space_size, volatile) num_facts_saved_in_action_table - self.num_entities)) for key in action_space_buckets_discrete: print('Vectorizing action spaces bucket {}...'.format(key)) + # key * self.args.bucket_interval 做了一次行动空间的对齐 self.action_space_buckets[key] = vectorize_action_space( action_space_buckets_discrete[key], key * self.args.bucket_interval) else: From 7e01e0ffc9dec2c48b226e63e74c545b9b5f7e99 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Mon, 15 Nov 2021 16:03:47 +0800 Subject: [PATCH 15/21] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=97=A0=E6=95=88?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/pg.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/rl/graph_search/pg.py b/src/rl/graph_search/pg.py index dac9d48..a44fbeb 100644 --- a/src/rl/graph_search/pg.py +++ b/src/rl/graph_search/pg.py @@ -199,14 +199,12 @@ def sample(action_space, action_dist): if inv_offset is not None: next_r_list = [] next_e_list = [] - action_dist_list = [] action_prob_list = [] for action_space, action_dist in db_outcomes: sample_outcome = sample(action_space, action_dist) next_r_list.append(sample_outcome['action_sample'][0]) next_e_list.append(sample_outcome['action_sample'][1]) action_prob_list.append(sample_outcome['action_prob']) - action_dist_list.append(action_dist) next_r = torch.cat(next_r_list, dim=0)[inv_offset] next_e = torch.cat(next_e_list, dim=0)[inv_offset] action_sample = (next_r, next_e) From c586d19803e99cb74e7011156c2d01794791b9e5 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Mon, 15 Nov 2021 16:41:51 +0800 Subject: [PATCH 16/21] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/beam_search.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rl/graph_search/beam_search.py b/src/rl/graph_search/beam_search.py index 57ae930..62d5f08 100644 --- a/src/rl/graph_search/beam_search.py +++ b/src/rl/graph_search/beam_search.py @@ -124,13 +124,13 @@ def adjust_search_trace(search_trace, action_offset): search_trace[i] = (new_r, new_e) # Initialization - r_s = int_fill_var_cuda(source_entity.size(), kg.dummy_start_r) + start_relation = int_fill_var_cuda(source_entity.size(), kg.dummy_start_r) seen_nodes = int_fill_var_cuda(source_entity.size(), kg.dummy_e).unsqueeze(1) - init_action = (r_s, source_entity) + init_action = (start_relation, source_entity) # path encoder pn.initialize_path(init_action, kg) if kg.args.save_beam_search_paths: - search_trace = [(r_s, source_entity)] + search_trace = [(start_relation, source_entity)] # Run beam search for num_steps # [batch_size*k], k=1 From ba7b6f507ca3b14859c1bbeb045e458a1a424da4 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Mon, 15 Nov 2021 16:48:01 +0800 Subject: [PATCH 17/21] =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/beam_search.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rl/graph_search/beam_search.py b/src/rl/graph_search/beam_search.py index 62d5f08..d7eef73 100644 --- a/src/rl/graph_search/beam_search.py +++ b/src/rl/graph_search/beam_search.py @@ -141,10 +141,12 @@ def adjust_search_trace(search_trace, action_offset): action = init_action for t in range(num_steps): last_r, current_entity = action + assert(query_relation.size() == source_entity.size()) assert(query_relation.size() == target_entity.size()) assert(current_entity.size()[0] % batch_size == 0) assert(query_relation.size()[0] % batch_size == 0) + k = int(current_entity.size()[0] / batch_size) # => [batch_size*k] query_relation = ops.tile_along_beam(query_relation.view(batch_size, -1)[:, 0], k) From 1e457895754969ba728631e06da2787de2ddbbf7 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Tue, 16 Nov 2021 09:25:47 +0800 Subject: [PATCH 18/21] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=97=A0=E6=95=88?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/utils/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/ops.py b/src/utils/ops.py index 0bc0a5a..106fea6 100644 --- a/src/utils/ops.py +++ b/src/utils/ops.py @@ -187,7 +187,7 @@ def pack(l, a): l.pop(0) -def unique_max(unique_x, x, values, marker_2D=None): +def unique_max(unique_x, x, values): unique_interval = 100 unique_values, unique_indices = [], [] # prevent memory explotion during decoding From 760fde7fd1c7688c835fa42bec7fc854d211c988 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Tue, 16 Nov 2021 10:41:41 +0800 Subject: [PATCH 19/21] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=EF=BC=9A=20File=20"/content/MultiHopKG-master/src/rl/graph=5Fs?= =?UTF-8?q?earch/pn.py",=20line=20176,=20in=20new=5Ftuple=20=3D=20tuple([?= =?UTF-8?q?=5Fx[:,=20offset,=20:]=20for=20=5Fx=20in=20x])=20IndexError:=20?= =?UTF-8?q?tensors=20used=20as=20indices=20must=20be=20long,=20byte=20or?= =?UTF-8?q?=20bool=20tensors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl/graph_search/beam_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rl/graph_search/beam_search.py b/src/rl/graph_search/beam_search.py index d7eef73..68d6603 100644 --- a/src/rl/graph_search/beam_search.py +++ b/src/rl/graph_search/beam_search.py @@ -59,7 +59,7 @@ def top_k_action(log_action_dist, action_space): log_action_prob = log_action_prob.view(-1) # *** compute parent offset # [batch_size, k] - action_beam_offset = action_ind / action_space_size + action_beam_offset = action_ind // action_space_size # [batch_size, 1] action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1) # [batch_size, k] => [batch_size*k] @@ -102,7 +102,7 @@ def top_k_answer_unique(log_action_dist, action_space): k_prime = min(len(unique_e_space_b), k) top_unique_log_action_dist, top_unique_idx2 = torch.topk(unique_log_action_dist, k_prime) top_unique_idx = unique_idx[top_unique_idx2] - top_unique_beam_offset = top_unique_idx / action_space_size + top_unique_beam_offset = top_unique_idx // action_space_size top_r = r_space_b[top_unique_idx] top_e = e_space_b[top_unique_idx] next_r_list.append(top_r.unsqueeze(0)) From e50e96eacc2ae5a4026536337e26092d58ab8d59 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Tue, 16 Nov 2021 11:33:04 +0800 Subject: [PATCH 20/21] =?UTF-8?q?experiment-rs.sh=E4=B8=AD=E6=89=A7?= =?UTF-8?q?=E8=A1=8C=E5=91=BD=E4=BB=A4=E5=8F=98=E4=B8=BAnohup=E5=BD=A2?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- experiment-rs.sh | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/experiment-rs.sh b/experiment-rs.sh index 80e9af1..4f5bb5d 100755 --- a/experiment-rs.sh +++ b/experiment-rs.sh @@ -21,7 +21,7 @@ if [[ $use_action_space_bucketing = *"True"* ]]; then use_action_space_bucketing_flag='--use_action_space_bucketing' fi -cmd="python3 -m src.experiments \ +cmd="python3 -u -m src.experiments \ --data_dir $data_dir \ $exp \ --model $model \ @@ -60,6 +60,13 @@ cmd="python3 -m src.experiments \ --gpu $gpu \ $ARGS" -echo "Executing $cmd" +echo $cmd + +arr=(${data_dir//// }) +dataset=${arr[1]} + +LOG_FILE="logs/rs_"$dataset"_"$model"_GPU_"$gpu".log" + +nohup $cmd>$LOG_FILE 2>&1 & + -$cmd From b8ad8efb37e5bba11e11b25e5dc38fa9e87f8403 Mon Sep 17 00:00:00 2001 From: Sun Kai Date: Tue, 16 Nov 2021 11:53:32 +0800 Subject: [PATCH 21/21] =?UTF-8?q?=E8=8E=B7=E5=8F=96KGE=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/experiments.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/experiments.py b/src/experiments.py index 3343a36..e47bf20 100644 --- a/src/experiments.py +++ b/src/experiments.py @@ -189,6 +189,7 @@ def construct_model(args): lf = PolicyGradient(args, kg, pn) elif args.model.startswith('point.rs'): pn = GraphSearchPolicy(args) + #获取KGE模型 fn_model = args.model.split('.')[2] fn_args = copy.deepcopy(args) fn_args.model = fn_model