diff --git a/graph_extractions/graph_sampler.py b/graph_extractions/graph_sampler.py index 96f8853..90703a9 100755 --- a/graph_extractions/graph_sampler.py +++ b/graph_extractions/graph_sampler.py @@ -29,19 +29,20 @@ FIX2 = False -def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = False, all_negs = False, sample_all_negs = False, all_candidate_negs = False, onek_negs = False, two_hun_negs = False, neg_triplet_as_task = False, subset = None, inductive = False, no_candidates = False): +def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = False, all_negs = False, sample_all_negs = False, all_candidate_negs = False, onek_negs = False, two_hun_negs = False, neg_triplet_as_task = False, subset = None, inductive = False, no_candidates = False, inductive_graph=False): raw_data_paths = os.path.join(root, dataset) if sample_neg or all_negs: rel2candidates = json.load(open(os.path.join(raw_data_paths, 'rel2candidates.json'))) e1rel_e2 = json.load(open(os.path.join(raw_data_paths, 'e1rel_e2.json'))) - + postfix = "" if not inductive else "_inductive" - - path_graph = json.load(open(os.path.join(raw_data_paths, f"path_graph{postfix}.json"))) - - adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(raw_data_paths, inductive = inductive) - + if inductive_graph: + path_graph = json.load(open(os.path.join(raw_data_paths, f"path_graph_inductive.json"))) + else: + path_graph = json.load(open(os.path.join(raw_data_paths, f"path_graph.json"))) + adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(raw_data_paths, inductive = inductive, inductive_graph=inductive_graph) + links = {} print(splits) for split_name in splits: @@ -49,13 +50,13 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa if no_candidates: assert split_name == "test" - + tasks = json.load(open(os.path.join(raw_data_paths, split_name + '_tasks.json'))) if split_name == "pretrain": tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks{postfix}.json'))) if split_name == "train" and inductive: tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks{postfix}.json'))) - + pos = {} if not all_negs: # don't need to extract pos again normally @@ -68,7 +69,7 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa print("nop") pos[rel] = t - + neg = {} if not all_negs: if not sample_neg: @@ -80,7 +81,7 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks_neg{postfix}.json'))) if split_name == "pretrain": tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks_neg{postfix}.json'))) - + for rel, task in tasks.items(): t = [] for e1, rel, e2 in task: @@ -97,7 +98,7 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa for rel, task in tqdm(tasks.items()): t = [] d[rel] = [] - for e1, rel, e2 in tqdm(task): + for e1, rel, e2 in tqdm(task): while True: if rel in rel2candidates and not no_candidates: negative = random.choice(rel2candidates[rel]) @@ -105,14 +106,14 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa else: negative = random.choice(list(entity2id.keys())) negative_condition = [e1, rel, negative] not in path_graph - + if (negative_condition) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[rel]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[rel].append([e1,rel, negative]) break - neg[rel] = t - if split_name == "pretrain": + neg[rel] = t + if split_name == "pretrain": json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_neg{postfix}.json'), "w")) elif split_name == "train" and inductive: json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_neg{postfix}.json'), "w")) @@ -120,8 +121,8 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_neg_nocandidates.json'), "w")) else: json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_neg.json'), "w")) - - + + elif neg_triplet_as_task: ## only for 50 negs print("50negs (neg_triplet_as_task) ") @@ -156,27 +157,27 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa d[e1+rel+e2].append([e1,rel, negative]) else: break - + num_current_negs = len(d[e1+rel+e2]) - + if num_current_negs < 50: - # sample new negs - d_e = [] + # sample new negs + d_e = [] for negative in rel2candidates[rel]: if (negative not in e1rel_e2[e1 + rel]) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d_e and [e1, rel, negative] not in d[e1+rel +e2]: d_e.append([e1,rel, negative]) - + # print(len(t)) indices = np.random.choice(range(len(d_e)), 50 - num_current_negs, replace = False) d[e1+rel+e2] = d[e1+rel+e2] + np.array(d_e)[indices].tolist() for e1,rel, negative in np.array(d_e)[indices].tolist(): all_triplets[e1 + negative] = [[e1,rel, negative]] neg[e1+negative] = [[entity2id[e1], entity2id[negative]]] - + json.dump(d,open(os.path.join(raw_data_paths, split_name + '_tasks_50neg.json'), "w")) json.dump(all_triplets,open(os.path.join(raw_data_paths, split_name + '_tasks_50neg_triplet_as_task.json'), "w")) - + else: if all_candidate_negs: print("all_candidate_negs") @@ -203,7 +204,7 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa for negative in rel2candidates[rel]: if (negative not in e1rel_e2[e1 + rel]) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) # print(len(t)) neg[e1+rel+e2] = t @@ -235,22 +236,22 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa for negative in rel2candidates[rel]: if (negative not in e1rel_e2[e1 + rel]) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) # print(len(t)) indices = np.random.choice(range(len(t)), min(1000, len(t)), replace = False) neg[e1+rel+e2] = np.array(t)[indices].tolist() d[e1+rel+e2] = np.array(d[e1+rel +e2])[indices].tolist() - + else: while len(t) < 1000: negative = random.choice(list(entity2id.keys())) if (negative not in e1rel_e2[e1 + rel]) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) - neg[e1+rel+e2] = t + neg[e1+rel+e2] = t json.dump(d,open(os.path.join(raw_data_paths, split_name + '_tasks_1000neg.json'), "w")) elif two_hun_negs: @@ -274,12 +275,12 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa for e1, rel, e2 in tqdm(task): t = [] d[e1+rel +e2] = [] - if rel in rel2candidates and dataset != "ConceptNet": + if rel in rel2candidates and dataset != "ConceptNet": # sample all negs for dev and test for negative in rel2candidates[rel]: if (negative not in e1rel_e2[e1 + rel]) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) # print(len(t)) indices = np.random.choice(range(len(t)), min(200, len(t)), replace = False) @@ -290,22 +291,22 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa negative = random.choice(list(entity2id.keys())) if ([e1, rel, negative] not in path_graph) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) - neg[e1+rel+e2] = t - json.dump(d,open(os.path.join(raw_data_paths, split_name + '_tasks_200neg.json'), "w")) + neg[e1+rel+e2] = t + json.dump(d,open(os.path.join(raw_data_paths, split_name + '_tasks_200neg.json'), "w")) else: print("50negs") if not sample_all_negs: print("reuse negatives") tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg.json'))) - if split_name == "pretrain": + if split_name == "pretrain": tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg{postfix}.json'))) if split_name == "train" and inductive: tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg{postfix}.json'))) if no_candidates: tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg_nocandidates.json'))) - + if dataset == "Wiki": print(f"subset {subset} triplets") tasks = json.load(open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg_subset{subset}.json'))) @@ -326,12 +327,12 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa for e1, rel, e2 in tqdm(task): t = [] d[e1+rel +e2] = [] - # sample all negs for dev and test - if rel in rel2candidates and dataset not in ["ConceptNet", "FB15K-237"] and not no_candidates: + # sample all negs for dev and test + if rel in rel2candidates and dataset not in ["ConceptNet", "FB15K-237"] and not no_candidates: for negative in rel2candidates[rel]: if (negative not in e1rel_e2[e1 + rel]) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) # print(len(t)) indices = np.random.choice(range(len(t)), 50, replace = False) @@ -342,11 +343,11 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa negative = random.choice(list(entity2id.keys())) if (negative not in e1rel_e2[e1 + rel]) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) - - neg[e1+rel+e2] = t - + + neg[e1+rel+e2] = t + else: print("no e1rel_e2") while len(t) < 50: @@ -355,12 +356,12 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa if (negative_condition) \ and negative != e2 and negative != e1 and [e1, rel, negative] not in d[e1+rel+e2]: - t.append([entity2id[e1], entity2id[negative]]) + t.append([entity2id[e1], entity2id[negative]]) d[e1+rel+e2].append([e1,rel, negative]) - neg[e1+rel+e2] = t - - - if split_name == "pretrain": + neg[e1+rel+e2] = t + + + if split_name == "pretrain": json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg{postfix}.json'), "w")) elif split_name == "train" and inductive: json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg{postfix}.json'), "w")) @@ -368,35 +369,35 @@ def generate_subgraph_datasets(root, dataset, splits, kind, hop, sample_neg = Fa json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg_nocandidates.json'), "w")) else: json.dump(d,open(os.path.join(raw_data_paths, split_name + f'_tasks_50neg.json'), "w")) - + split['pos'] = pos split['neg'] = neg - + links[split_name] = split - + if dataset == "Wiki": postfix += f"_{subset}" if all_negs: if neg_triplet_as_task: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_50negs_triplet_as_task_hop={hop}' + postfix) - + db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_50negs_triplet_as_task_hop={hop}' + postfix) + elif all_candidate_negs: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_allnegs_hop={hop}' + postfix) + db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_allnegs_hop={hop}' + postfix) elif onek_negs: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_1000negs_hop={hop}'+ postfix) + db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_1000negs_hop={hop}'+ postfix) elif two_hun_negs: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_200negs_hop={hop}'+ postfix) - else: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_50negs_hop={hop}'+ postfix) + db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_200negs_hop={hop}'+ postfix) + else: + db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_50negs_hop={hop}'+ postfix) else: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_hop={hop}'+ postfix) - - if FIX2: + db_path = os.path.join(raw_data_paths, f'subgraphs_fix_new_{kind}_hop={hop}'+ postfix) + + if FIX2: if all_negs: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix2_new_{kind}_50negs_hop={hop}'+ postfix) + db_path = os.path.join(raw_data_paths, f'subgraphs_fix2_new_{kind}_50negs_hop={hop}'+ postfix) else: - db_path = os.path.join(raw_data_paths, f'subgraphs_fix2_new_{kind}_hop={hop}'+ postfix) + db_path = os.path.join(raw_data_paths, f'subgraphs_fix2_new_{kind}_hop={hop}'+ postfix) print(db_path) links2subgraphs(adj_list, links, kind, hop, db_path) @@ -405,7 +406,7 @@ def links2subgraphs(A, links, kind, hop, db_path): ''' extract enclosing subgraphs, write map mode + named dbs ''' - + max_n_label = {'value': np.array([0, 0])} subgraph_sizes = [] enc_ratios = [] @@ -415,8 +416,8 @@ def links2subgraphs(A, links, kind, hop, db_path): # BYTES_PER_DATUM = get_average_subgraph_size(100, links['dev']['pos'], A, kind, hop) * 1.5 print(BYTES_PER_DATUM) - links_length = 0 - + links_length = 0 + for split_name, split in links.items(): for rel, task in split['pos'].items(): links_length += len(task) @@ -427,15 +428,15 @@ def links2subgraphs(A, links, kind, hop, db_path): map_size = links_length * BYTES_PER_DATUM * 1000 env = lmdb.open(db_path, map_size=map_size, max_dbs=8) - - A_ = ray.put(A) + + A_ = ray.put(A) def extraction_helper(A, links_all, r_label_all, g_labels_all, split_env, ids_all, hop, prefix_all): thread_n =6000000 for idx in tqdm(range(0, len(links_all), thread_n), leave = True): - + end = idx+thread_n if end > len(links_all): end = len(links_all) @@ -443,11 +444,11 @@ def extraction_helper(A, links_all, r_label_all, g_labels_all, split_env, ids_al links = links_all[idx:end] r_label = r_label_all[idx:end] g_labels = g_labels_all[idx:end] - prefix = prefix_all[idx:end] + prefix = prefix_all[idx:end] with mp.Pool(processes=None) as p: args_ = zip(ids, links, r_label,g_labels, [kind] *len(links), [hop] *len(links), prefix, [A_] * len(links)) - + for (str_id, datum) in tqdm(p.imap_unordered(extract_save_subgraph, list(args_)), total=len(links), leave = True): max_n_label['value'] = np.maximum(np.max(datum['n_labels'], axis=0), max_n_label['value']) subgraph_sizes.append(datum['subgraph_size']) @@ -455,10 +456,10 @@ def extraction_helper(A, links_all, r_label_all, g_labels_all, split_env, ids_al num_pruned_nodes.append(datum['num_pruned_nodes']) with env.begin(write=True, db=split_env) as txn: - txn.put(str_id, serialize(datum)) - - - + txn.put(str_id, serialize(datum)) + + + for split_name, split in links.items(): print(f"Extracting enclosing subgraphs for positive links in {split_name} set") db_name_pos = split_name + '_pos' @@ -467,20 +468,20 @@ def extraction_helper(A, links_all, r_label_all, g_labels_all, split_env, ids_al rs = [] prefix = [] ids = [] - count = 0 + count = 0 with env.begin(write=False, db=split_env) as txn: - for rel, task in split['pos'].items(): - + for rel, task in split['pos'].items(): + # missing = False # for idx in range(len(task)): # str_id = (rel).encode() + '{:08}'.format(idx).encode('ascii') -# if txn.get(str_id) is None: +# if txn.get(str_id) is None: # missing = True # break # if not missing: # print(rel, "already exists") -# continue - +# continue + ls.extend(task) rs.extend([rel] * len(task)) prefix.extend([rel] * len(task)) @@ -488,7 +489,7 @@ def extraction_helper(A, links_all, r_label_all, g_labels_all, split_env, ids_al count += len(task) labels= np.ones(count) extraction_helper(A, ls, rs, labels, split_env, ids, hop, prefix) - + print(f"Extracting enclosing subgraphs for negative links in {split_name} set") db_name_neg = split_name + '_neg' split_env = env.open_db(db_name_neg.encode()) @@ -496,24 +497,24 @@ def extraction_helper(A, links_all, r_label_all, g_labels_all, split_env, ids_al rs = [] prefix = [] ids = [] - count = 0 + count = 0 with env.begin(write=False, db=split_env) as txn: - for rel, task in split['neg'].items(): - + for rel, task in split['neg'].items(): + # more finegrained missing missing_ids = list(range(len(task))) # missing_ids = [] # missing = False # for idx in range(len(task)): # str_id = (rel).encode() + '{:08}'.format(idx).encode('ascii') -# if txn.get(str_id) is None: +# if txn.get(str_id) is None: # missing = True # missing_ids.append(idx) # # break # if not missing: # print(rel, "already exists") -# continue - +# continue + ls.extend(np.array(task)[missing_ids].tolist()) rs.extend([rel] * len(missing_ids)) prefix.extend([rel] * len(missing_ids)) @@ -523,7 +524,7 @@ def extraction_helper(A, links_all, r_label_all, g_labels_all, split_env, ids_al print(count) extraction_helper(A, ls, rs, labels, split_env, ids, hop, prefix) - + max_n_label['value'] = max_n_label['value'] with env.begin(write=True) as txn: @@ -561,7 +562,7 @@ def get_average_subgraph_size(sample_size, pos, A, kind, hop): def intialize_worker(A): global A_ A_ = A - + def extract_save_subgraph(args_): idx, (n1, n2), r_label, g_label, kind, hop, prefix, A_ = args_ A_ = ray.get(A_) @@ -610,24 +611,24 @@ def subgraph_extraction_labeling(ind, A_list, kind, h=1, max_nodes_per_hop=None, subgraph_nodes = list(ind) + list(subgraph_nei_nodes_int) else: subgraph_nodes = list(ind) + list(subgraph_nei_nodes_un) - - + + subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list] labels, enclosing_subgraph_nodes = node_label(incidence_matrix(subgraph), max_distance=h) - if kind == "union_prune" or kind == "union_prune_plus": + if kind == "union_prune" or kind == "union_prune_plus": while len(enclosing_subgraph_nodes) != len(subgraph_nodes): subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes] subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list] labels, enclosing_subgraph_nodes = node_label(incidence_matrix(subgraph), max_distance=h) - + pruned_subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes] - pruned_labels = labels[enclosing_subgraph_nodes] + pruned_labels = labels[enclosing_subgraph_nodes] else: pruned_subgraph_nodes = subgraph_nodes pruned_labels = labels - + if kind == "union_prune_plus": if not FIX2: root1_nei_1 = get_neighbor_nodes(set([ind[0]]), A_incidence, 1, 50) @@ -635,20 +636,20 @@ def subgraph_extraction_labeling(ind, A_list, kind, h=1, max_nodes_per_hop=None, else: root1_nei_1 = get_neighbor_nodes(set([ind[0]]), A_incidence, 2, 50) root2_nei_1 = get_neighbor_nodes(set([ind[1]]), A_incidence, 2, 50) - + root1_nei_1 = root1_nei_1 - set(pruned_subgraph_nodes) root2_nei_1 = root2_nei_1 - set(pruned_subgraph_nodes) - root1_nei_1 - + pruned_subgraph_nodes_after = np.array(list(pruned_subgraph_nodes) + list(root1_nei_1) + list(root2_nei_1)) - pruned_labels_after = np.zeros((len(pruned_subgraph_nodes_after), 2)) + pruned_labels_after = np.zeros((len(pruned_subgraph_nodes_after), 2)) pruned_labels_after[:len(pruned_subgraph_nodes)] = pruned_labels pruned_labels_after[len(pruned_subgraph_nodes): len(pruned_subgraph_nodes)+ len(root1_nei_1)] = [1, h] pruned_labels_after[len(pruned_subgraph_nodes)+ len(root1_nei_1):] = [h, 1] - + pruned_subgraph_nodes = pruned_subgraph_nodes_after pruned_labels = pruned_labels_after - - + + if max_node_label_value is not None: pruned_labels = np.array([np.minimum(label, max_node_label_value).tolist() for label in pruned_labels]) @@ -683,15 +684,21 @@ def node_label(subgraph, max_distance=1): # set sample_neg/sample_all_negs = True to resample negatives # by default, all cores are used in parallel; you can change this on L459 - # after the subgraph extraction is completed, run SubgraphFewshotDataset in load_kg_dataset + # after the subgraph extraction is completed, run SubgraphFewshotDataset in load_kg_dataset # with preprocess/preprocess_50negs = True to pre cache the dataset (generate the preprocessed_* dirs) # e.g. SubgraphFewshotDataset(".", shot = 3, dataset="NELL", mode="pretrain", kind="union_prune_plus", hop=2, preprocess = True, preprocess_50negs = True) # generate_subgraph_datasets(".", dataset="NELL", splits = ['pretrain', 'dev','test'], kind = "union_prune_plus", hop=2, sample_neg = False) # generate_subgraph_datasets(".", dataset="NELL", splits = ['dev','test'], kind = "union_prune_plus", hop=2, all_negs = False, sample_all_negs = False) - - # generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['pretrain', 'dev','test'], kind = "union_prune_plus", hop=1, sample_neg = False) - # generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['dev','test'], kind = "union_prune_plus", hop=1, all_negs = True, sample_all_negs = False) + generate_subgraph_datasets(".", dataset="FB15K-237", splits=['pretrain'], kind="union_prune_plus", + hop=1, sample_neg=False, inductive=True, inductive_graph=True) + generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['dev','test'], kind = "union_prune_plus", hop=1, sample_neg = False, inductive=True, inductive_graph=False) + generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['dev','test'], kind = "union_prune_plus", hop=1, all_negs = True, sample_all_negs = False, inductive=True, inductive_graph=False) + # then run following code to preprocess: + # SubgraphFewshotDataset(".", shot=3, dataset="FB15K-237", mode="pretrain", kind="union_prune_plus", hop=1, + # preprocess=True, preprocess_50neg=False, inductive=True) + # SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="dev", kind="union_prune_plus", hop=1, preprocess = True, preprocess_50neg = True, inductive = True) + # SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="test", kind="union_prune_plus", hop=1, preprocess = True, preprocess_50neg = True, inductive = True) # generate_subgraph_datasets(".", dataset="ConceptNet", splits = ['pretrain', 'dev','test'], kind = "union_prune_plus", hop=1, sample_neg = False) # generate_subgraph_datasets(".", dataset="ConceptNet", splits = ['dev','test'], kind = "union_prune_plus", hop=1, all_negs = True, sample_all_negs = False) diff --git a/load_kg_dataset.py b/load_kg_dataset.py index 03213e7..02f7b0f 100644 --- a/load_kg_dataset.py +++ b/load_kg_dataset.py @@ -19,7 +19,7 @@ from tqdm import tqdm -import lmdb +import lmdb from scipy.sparse import csc_matrix class Collater: @@ -29,21 +29,21 @@ def __init__(self): def __call__(self, batch): support_triples, support_subgraphs, support_negative_triples, support_negative_subgraphs, query_triples, query_subgraphs, negative_triples, negative_subgraphs, curr_rel = list(map(list, zip(*batch))) if support_subgraphs[0] is None: - return ((torch.tensor(support_triples), None, - torch.tensor(support_negative_triples), None, - torch.tensor(query_triples), None, - torch.tensor(negative_triples), None), + return ((torch.tensor(support_triples), None, + torch.tensor(support_negative_triples), None, + torch.tensor(query_triples), None, + torch.tensor(negative_triples), None), curr_rel) - + support_subgraphs = [item for sublist in support_subgraphs for item in sublist] support_negative_subgraphs = [item for sublist in support_negative_subgraphs for item in sublist] query_subgraphs = [item for sublist in query_subgraphs for item in sublist] - negative_subgraphs = [item for sublist in negative_subgraphs for item in sublist] - - return ((support_triples, Batch.from_data_list(support_subgraphs), - support_negative_triples, Batch.from_data_list(support_negative_subgraphs), - query_triples, Batch.from_data_list(query_subgraphs), - negative_triples, Batch.from_data_list(negative_subgraphs)), + negative_subgraphs = [item for sublist in negative_subgraphs for item in sublist] + + return ((support_triples, Batch.from_data_list(support_subgraphs), + support_negative_triples, Batch.from_data_list(support_negative_subgraphs), + query_triples, Batch.from_data_list(query_subgraphs), + negative_triples, Batch.from_data_list(negative_subgraphs)), curr_rel) @@ -63,11 +63,11 @@ def __init__( collate_fn=Collater(), **kwargs, ) - + def next_batch(self): return next(iter(self)) - - + + def serialize(data): data_tuple = tuple(data.values()) @@ -88,25 +88,25 @@ def ssp_multigraph_to_g(graph, cache = None): if cache and os.path.exists(cache): print("Use cache from: ", cache) g = torch.load(cache) - return g, g.edge_attr.max() + 1, g.num_nodes - - + return g, g.edge_attr.max() + 1, g.num_nodes + + edge_list = [[],[]] edge_features = [] for i in range(len(graph)): edge_list[0].append(graph[i].nonzero()[0]) edge_list[1].append(graph[i].nonzero()[1]) edge_features.append(torch.full((len(graph[i].nonzero()[0]),), i)) - + edge_list[0] = np.concatenate(edge_list[0]) edge_list[1] = np.concatenate(edge_list[1]) edge_index = torch.tensor(np.array(edge_list)) - + g = Data(x=None, edge_index=edge_index.long(), edge_attr= torch.cat(edge_features).long(), num_nodes=graph[0].shape[0]) if cache: torch.save(g, cache) - + return g, len(graph), g.num_nodes class SubgraphFewshotDataset(Dataset): @@ -114,25 +114,25 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 self.root = root if orig_test and mode == "test": mode = "orig_test" - self.mode = mode + self.mode = mode self.dataset = dataset self.inductive = inductive self.rev = rev raw_data_paths = os.path.join(root, dataset) - + postfix = "" if not inductive else "_inductive" if mode == "pretrain": self.tasks = json.load(open(os.path.join(raw_data_paths, mode + f'_tasks{postfix}.json'))) - self.tasks_neg = json.load(open(os.path.join(raw_data_paths, mode + f'_tasks_neg{postfix}.json'))) - print(os.path.join(raw_data_paths, mode + f'_tasks{postfix}.json')) + self.tasks_neg = json.load(open(os.path.join(raw_data_paths, mode + f'_tasks_neg{postfix}.json'))) + print(os.path.join(raw_data_paths, mode + f'_tasks{postfix}.json')) else: # dev and test self.tasks = json.load(open(os.path.join(raw_data_paths, mode + '_tasks.json'))) self.tasks_neg = json.load(open(os.path.join(raw_data_paths, mode + '_tasks_neg.json'))) print(os.path.join(raw_data_paths, mode + '_tasks.json')) - - if mode == "test" and inductive: + + if mode == "test" and inductive and not preprocess and not preprocess_50neg: print("subsample tasks!!!!!!!!!!!!!!!!!!!") self.test_tasks_idx = json.load(open(os.path.join(raw_data_paths, 'sample_test_tasks_idx.json'))) for r in list(self.tasks.keys()): @@ -140,39 +140,39 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 self.tasks[r] = [] else: self.tasks[r] = np.array(self.tasks[r])[self.test_tasks_idx[r]].tolist() - + self.e1rel_e2 = json.load(open(os.path.join(raw_data_paths,'e1rel_e2.json'))) self.all_rels = sorted(list(self.tasks.keys())) self.all_rels2id = { self.all_rels[i]:i for i in range(len(self.all_rels))} - - if mode == "test" and inductive: + + if mode == "test" and inductive and not preprocess and not preprocess_50neg: for idx, r in enumerate(list(self.all_rels)): if len(self.tasks[r]) == 0: del self.tasks[r] print("remove empty tasks!!!!!!!!!!!!!!!!!!!") self.all_rels = sorted(list(self.tasks.keys())) - + self.num_rels = len(self.all_rels) - - - + + + self.few = shot self.nq = n_query try: - if mode == "pretrain": - self.tasks_neg_all = json.load(open(os.path.join(raw_data_paths, mode + f'_tasks_{num_rank_negs}neg{postfix}.json'))) + if mode == "pretrain": + self.tasks_neg_all = json.load(open(os.path.join(raw_data_paths, mode + f'_tasks_{num_rank_negs}neg{postfix}.json'))) else: self.tasks_neg_all = json.load(open(os.path.join(raw_data_paths, mode + f'_tasks_{num_rank_negs}neg.json'))) - - + + self.all_negs = sorted(list(self.tasks_neg_all.keys())) self.all_negs2id = { self.all_negs[i]:i for i in range(len(self.all_negs))} self.num_all_negs = len(self.all_negs) except: print(mode + f'_tasks_{num_rank_negs}neg.json', "not exists") - + if mode not in ['train', 'pretrain']: - + self.eval_triples = [] self.eval_triples_ids = [] for rel in self.all_rels: @@ -185,21 +185,24 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 ###### backgroud KG ####### - cache = os.path.join(raw_data_paths, f'graph{postfix}.pt') + if mode=='pretrain': + cache = os.path.join(raw_data_paths, f'graph{postfix}.pt') + else: + cache = os.path.join(raw_data_paths, f'graph.pt') if os.path.exists(cache): print("Use cache from: ", cache) ssp_graph = None - + with open(os.path.join(raw_data_paths, f'relation2id{postfix}.json'), 'r') as f: - relation2id = json.load(f) + relation2id = json.load(f) with open(os.path.join(raw_data_paths, f'entity2id{postfix}.json'), 'r') as f: - entity2id = json.load(f) - + entity2id = json.load(f) + id2relation = {v: k for k, v in relation2id.items()} id2entity = {v: k for k, v in entity2id.items()} else: - ssp_graph, __, entity2id, relation2id, id2entity, id2relation = process_files(raw_data_paths, inductive = inductive) + ssp_graph, __, entity2id, relation2id, id2entity, id2relation = process_files(raw_data_paths, inductive = inductive, inductive_graph=mode == 'pretrain') # self.num_rels_bg = len(ssp_graph) # Add transpose matrices to handle both directions of relations. @@ -211,7 +214,7 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 # self.num_rels_bg = len(ssp_graph) self.graph, _, self.num_nodes_bg = ssp_multigraph_to_g(ssp_graph, cache) - + self.num_rels_bg = len(relation2id.keys()) if rev: self.num_rels_bg = self.num_rels_bg * 2 # add rev edges @@ -220,7 +223,7 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 self.relation2id = relation2id self.id2entity = id2entity self.id2relation = id2relation - + ###### preprocess subgraphs ####### if rev: @@ -244,7 +247,7 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 db_path = os.path.join(raw_data_paths, f"subgraphs_fix_new_{kind}_hop=" + str(hop)+ postfix) print(db_path) self.main_env = lmdb.open(db_path, readonly=True, max_dbs=4, lock=False) - + self.db_pos = self.main_env.open_db((mode + "_pos").encode()) self.db_neg = self.main_env.open_db((mode + "_neg").encode()) @@ -252,14 +255,14 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 self.max_n_label = np.array([3, 3]) self._preprocess() - - if preprocess_50neg: + + if preprocess_50neg: db_path_50negs = os.path.join(raw_data_paths, f"subgraphs_fix_new_{kind}_{num_rank_negs}negs_hop=" + str(hop)+ postfix) if use_fix2: db_path_50negs = os.path.join(raw_data_paths, f"subgraphs_fix2_new_{kind}_{num_rank_negs}negs_hop=" + str(hop)+ postfix) print(db_path_50negs) self.main_env = lmdb.open(db_path_50negs, readonly=True, max_dbs=3, lock=False) - + self.db_50negs = self.main_env.open_db((mode + "_neg").encode()) self.max_n_label = np.array([0, 0]) @@ -268,8 +271,8 @@ def __init__(self, root, add_traspose_rels=False, shot = 1, n_query = 3, hop = 2 self.max_n_label[1] = int.from_bytes(txn.get('max_n_label_obj'.encode()), byteorder='little') self._preprocess_50negs(num_rank_negs) - - + + if (not preprocess) and (not preprocess_50neg) and (not skip): try: self.pos_dict = torch.load(os.path.join(self.dict_save_path, "pos-%s.pt" % self.mode)) @@ -297,7 +300,7 @@ def _save_torch_geometric(self, index): pos_edge_index, pos_x, pos_x_id, pos_edge_attr, pos_n_size, pos_e_size = [], [], [], [], [], [] neg_edge_index, neg_x, neg_x_id, neg_edge_attr, neg_n_size, neg_e_size = [], [], [], [], [], [] - + with self.main_env.begin(db=self.db_pos) as txn: for idx, i in enumerate(curr_tasks_idx): str_id = curr_rel.encode()+'{:08}'.format(i).encode('ascii') @@ -311,9 +314,9 @@ def _save_torch_geometric(self, index): pos_edge_attr.append(d.edge_attr) pos_n_size.append(d.x.shape[0]) pos_e_size.append(d.edge_index.shape[1]) - - with self.main_env.begin(db=self.db_neg) as txn: - for idx, i in enumerate(curr_tasks_neg_idx): + + with self.main_env.begin(db=self.db_neg) as txn: + for idx, i in enumerate(curr_tasks_neg_idx): str_id = curr_rel.encode()+'{:08}'.format(i).encode('ascii') nodes_neg, r_label_neg, g_label_neg, n_labels_neg = deserialize(txn.get(str_id)).values() d = self._prepare_subgraphs(nodes_neg, r_label_neg, n_labels_neg) @@ -327,9 +330,9 @@ def _save_torch_geometric(self, index): neg_e_size.append(d.edge_index.shape[1]) return torch.cat(pos_edge_index, 1), torch.cat(pos_x, 0), torch.cat(pos_x_id, 0), torch.cat(pos_edge_attr, 0), torch.LongTensor(pos_n_size), torch.LongTensor(pos_e_size), torch.cat(neg_edge_index, 1), torch.cat(neg_x, 0), torch.cat(neg_x_id, 0), torch.cat(neg_edge_attr, 0), torch.LongTensor(neg_n_size), torch.LongTensor(neg_e_size) - + def dict_to_torch_geometric(self, index, data_dict): - + if index == 0: task_index = 0 start_e = 0 @@ -338,7 +341,7 @@ def dict_to_torch_geometric(self, index, data_dict): task_index = data_dict["task_offsets"][index-1] start_e = data_dict['e_size'][task_index - 1] start_n = data_dict['n_size'][task_index - 1] - + task_index_end = data_dict["task_offsets"][index] graphs = [] @@ -355,10 +358,10 @@ def dict_to_torch_geometric(self, index, data_dict): return graphs - + def _preprocess_50negs(self, num_rank_negs): print("start preprocessing 50negs for %s" % self.mode) - + all_neg_edge_index, all_neg_x, all_neg_x_id, all_neg_edge_attr, all_neg_n_size, all_neg_e_size = [], [], [], [], [], [] task_offsets_neg = [] for index in tqdm(range(self.num_all_negs)): @@ -368,9 +371,9 @@ def _preprocess_50negs(self, num_rank_negs): neg_edge_index, neg_x, neg_x_id, neg_edge_attr, neg_n_size, neg_e_size = [], [], [], [], [], [] - + with self.main_env.begin(db=self.db_50negs) as txn: - for idx, i in enumerate(curr_tasks_neg_idx): + for idx, i in enumerate(curr_tasks_neg_idx): str_id = curr_rel.encode()+'{:08}'.format(i).encode('ascii') nodes_neg, r_label_neg, g_label_neg, n_labels_neg = deserialize(txn.get(str_id)).values() d = self._prepare_subgraphs(nodes_neg, r_label_neg, n_labels_neg) @@ -381,7 +384,7 @@ def _preprocess_50negs(self, num_rank_negs): neg_n_size.append(d.x.shape[0]) neg_e_size.append(d.edge_index.shape[1]) - + all_neg_edge_index.append(torch.cat(neg_edge_index, 1)) all_neg_x.append(torch.cat(neg_x, 0)) all_neg_x_id.append(torch.cat(neg_x_id, 0)) @@ -393,8 +396,8 @@ def _preprocess_50negs(self, num_rank_negs): print("concat all") all_neg_edge_index = torch.cat(all_neg_edge_index, 1) - all_neg_x = torch.cat(all_neg_x, 0) - all_neg_x_id = torch.cat(all_neg_x_id, 0) + all_neg_x = torch.cat(all_neg_x, 0) + all_neg_x_id = torch.cat(all_neg_x_id, 0) all_neg_edge_attr = torch.cat(all_neg_edge_attr, 0) all_neg_n_size = torch.cat(all_neg_n_size) @@ -403,10 +406,10 @@ def _preprocess_50negs(self, num_rank_negs): all_neg_n_size = torch.cumsum(all_neg_n_size, 0) all_neg_e_size = torch.cumsum(all_neg_e_size, 0) - + task_offsets_neg = torch.tensor(task_offsets_neg) task_offsets_neg = torch.cumsum(task_offsets_neg, 0) - + save_path = self.dict_save_path neg_save_dict = { @@ -422,7 +425,7 @@ def _preprocess_50negs(self, num_rank_negs): print("saving to", os.path.join(save_path, f"neg_{num_rank_negs}negs-%s.pt" % self.mode)) torch.save(neg_save_dict, os.path.join(save_path, f"neg_{num_rank_negs}negs-%s.pt" % self.mode)) self.all_neg_dict = neg_save_dict - + def _preprocess(self): print("start preprocessing %s" % self.mode) all_pos_edge_index, all_pos_x, all_pos_x_id, all_pos_edge_attr, all_pos_n_size, all_pos_e_size = [], [], [], [], [], [] @@ -438,7 +441,7 @@ def _preprocess(self): all_pos_n_size.append(pos_n_size) all_pos_e_size.append(pos_e_size) task_offsets_pos.append(len(pos_n_size)) - + all_neg_edge_index.append(neg_edge_index) all_neg_x.append(neg_x) all_neg_x_id.append(neg_x_id) @@ -449,14 +452,14 @@ def _preprocess(self): print("concat all") all_pos_edge_index = torch.cat(all_pos_edge_index, 1) - all_pos_x = torch.cat(all_pos_x, 0) - all_pos_x_id = torch.cat(all_pos_x_id, 0) + all_pos_x = torch.cat(all_pos_x, 0) + all_pos_x_id = torch.cat(all_pos_x_id, 0) all_pos_edge_attr = torch.cat(all_pos_edge_attr, 0) all_neg_edge_index = torch.cat(all_neg_edge_index, 1) - all_neg_x = torch.cat(all_neg_x, 0) - all_neg_x_id = torch.cat(all_neg_x_id, 0) + all_neg_x = torch.cat(all_neg_x, 0) + all_neg_x_id = torch.cat(all_neg_x_id, 0) all_neg_edge_attr = torch.cat(all_neg_edge_attr, 0) @@ -470,12 +473,12 @@ def _preprocess(self): all_neg_n_size = torch.cumsum(all_neg_n_size, 0) all_neg_e_size = torch.cumsum(all_neg_e_size, 0) - + task_offsets_pos = torch.tensor(task_offsets_pos) task_offsets_pos = torch.cumsum(task_offsets_pos, 0) task_offsets_neg = torch.tensor(task_offsets_neg) task_offsets_neg = torch.cumsum(task_offsets_neg, 0) - + save_path = self.dict_save_path pos_save_dict = { 'edge_index': all_pos_edge_index, @@ -502,8 +505,8 @@ def _preprocess(self): torch.save(neg_save_dict, os.path.join(save_path, "neg-%s.pt" % self.mode)) self.pos_dict = pos_save_dict self.neg_dict = neg_save_dict - - def __getitem__(self, index): + + def __getitem__(self, index): # get current relation and current candidates curr_rel = self.all_rels[index] curr_tasks = self.tasks[curr_rel] @@ -511,15 +514,15 @@ def __getitem__(self, index): if self.nq is not None: curr_tasks_idx = np.random.choice(curr_tasks_idx, self.few+self.nq, replace = False) support_triples = [curr_tasks[i] for i in curr_tasks_idx[:self.few]] - query_triples = [curr_tasks[i] for i in curr_tasks_idx[self.few:]] - + query_triples = [curr_tasks[i] for i in curr_tasks_idx[self.few:]] + all_pos_graphs = self.dict_to_torch_geometric(self.all_rels2id[curr_rel], self.pos_dict) all_neg_graphs = self.dict_to_torch_geometric(self.all_rels2id[curr_rel], self.neg_dict) - - ### extract subgraphs + + ### extract subgraphs support_subgraphs = [] query_subgraphs = [] - for idx, i in enumerate(curr_tasks_idx): + for idx, i in enumerate(curr_tasks_idx): if self.mode == "test" and self.inductive: subgraph_pos = all_pos_graphs[self.test_tasks_idx[curr_rel][i]] @@ -530,104 +533,104 @@ def __getitem__(self, index): else: query_subgraphs.append(subgraph_pos) - + curr_tasks_neg = self.tasks_neg[curr_rel] curr_tasks_neg_idx = curr_tasks_idx - + support_negative_triples = [curr_tasks_neg[i] for i in curr_tasks_neg_idx[:self.few]] - negative_triples = [curr_tasks_neg[i] for i in curr_tasks_neg_idx[self.few:]] + negative_triples = [curr_tasks_neg[i] for i in curr_tasks_neg_idx[self.few:]] # construct support and query negative triples support_negative_subgraphs = [] negative_subgraphs = [] - for idx, i in enumerate(curr_tasks_neg_idx): + for idx, i in enumerate(curr_tasks_neg_idx): if self.mode == "test" and self.inductive: subgraph_neg = all_neg_graphs[self.test_tasks_idx[curr_rel][i]] else: subgraph_neg = all_neg_graphs[i] - + if (self.mode in ["train", "pretrain"] and self.dataset in ['NELL', 'FB15K-237'] and not self.inductive): #choose 1 neg from 50 e1, r, e2 = curr_tasks[i] all_50_neg_graphs = self.dict_to_torch_geometric(self.all_negs2id[e1 + r + e2], self.all_neg_dict) subgraph_neg = random.choice(all_50_neg_graphs) - + if idx < self.few: support_negative_subgraphs.append(subgraph_neg) else: negative_subgraphs.append(subgraph_neg) return support_triples, support_subgraphs, support_negative_triples, support_negative_subgraphs, query_triples, query_subgraphs, negative_triples, negative_subgraphs, curr_rel - - + + def next_one_on_eval(self, index): # get current triple query_triple = self.eval_triples[index] curr_rel = query_triple[1] curr_rel_neg = query_triple[0] + query_triple[1] + query_triple[2] curr_task = self.tasks[curr_rel] - + all_pos_graphs = self.dict_to_torch_geometric(self.all_rels2id[curr_rel], self.pos_dict) all_neg_graphs = self.dict_to_torch_geometric(self.all_rels2id[curr_rel], self.neg_dict) all_50_neg_graphs = self.dict_to_torch_geometric(self.all_negs2id[curr_rel_neg], self.all_neg_dict) - + # get support triples support_triples_idx = np.arange(0, len(curr_task), 1)[:self.few] support_triples = [] support_subgraphs = [] - for idx, i in enumerate(support_triples_idx): - support_triples.append(curr_task[i]) + for idx, i in enumerate(support_triples_idx): + support_triples.append(curr_task[i]) if self.mode == "test" and self.inductive: subgraph_pos = all_pos_graphs[self.test_tasks_idx[curr_rel][i]] else: subgraph_pos = all_pos_graphs[i] - support_subgraphs.append(subgraph_pos) + support_subgraphs.append(subgraph_pos) query_triples = [query_triple] query_subgraphs = [] - + if self.mode == "test" and self.inductive: subgraph_pos = all_pos_graphs[self.test_tasks_idx[curr_rel][self.eval_triples_ids[index]]] else: subgraph_pos = all_pos_graphs[self.eval_triples_ids[index]] - + query_subgraphs.append(subgraph_pos) - - + + # construct support negative - + curr_task_neg = self.tasks_neg[curr_rel] support_negative_triples_idx = support_triples_idx support_negative_triples = [] support_negative_subgraphs = [] - for idx, i in enumerate(support_negative_triples_idx): + for idx, i in enumerate(support_negative_triples_idx): support_negative_triples.append(curr_task_neg[i]) - + if self.mode == "test" and self.inductive: subgraph_neg = all_neg_graphs[self.test_tasks_idx[curr_rel][i]] else: subgraph_neg = all_neg_graphs[i] - + support_negative_subgraphs.append(subgraph_neg) - - + + ### 50 query negs curr_task_50neg = self.tasks_neg_all[curr_rel_neg] negative_triples_idx = np.arange(0, len(curr_task_50neg), 1) negative_triples = [] negative_subgraphs = [] - for idx, i in enumerate(negative_triples_idx): + for idx, i in enumerate(negative_triples_idx): negative_triples.append(curr_task_50neg[i]) negative_subgraphs.append(all_50_neg_graphs[i]) - + return support_triples, support_subgraphs, support_negative_triples, support_negative_subgraphs, query_triples, query_subgraphs, negative_triples, negative_subgraphs, curr_rel - - + + def _prepare_subgraphs(self, nodes, r_label, n_labels): # import pdb;pdb.set_trace() if nodes[0] == nodes[1]: @@ -637,32 +640,32 @@ def _prepare_subgraphs(self, nodes, r_label, n_labels): subgraph = Data(edge_index = torch.zeros([2, 0]), edge_attr = torch.zeros([0]), num_nodes = 2) else: subgraph = get_subgraph(self.graph, torch.tensor(nodes)) - # remove the (0,1) target edge + # remove the (0,1) target edge index = (torch.tensor([0, 1]) == subgraph.edge_index.transpose(0,1)).all(1) index = index & (subgraph.edge_attr == r_label) if index.any(): subgraph.edge_index = subgraph.edge_index.transpose(0,1)[~index].transpose(0,1) subgraph.edge_attr= subgraph.edge_attr[~index] - - - # add reverse edges + + + # add reverse edges if self.rev: subgraph.edge_index = torch.cat([subgraph.edge_index, subgraph.edge_index.flip(0)], 1) subgraph.edge_attr = torch.cat([subgraph.edge_attr, self.num_rels_bg - subgraph.edge_attr], 0) - + # One hot encode the node label feature and concat to n_featsure n_nodes = subgraph.num_nodes n_labels = n_labels.astype(int) label_feats = np.zeros((n_nodes, 6)) label_feats[0] = [1, 0, 0, 0, 1, 0] label_feats[1] = [0, 1, 0, 1, 0, 0] - - - + + + subgraph.x = torch.FloatTensor(label_feats) subgraph.x_id = torch.LongTensor(nodes) - - # sort it + + # sort it edge_index = subgraph.edge_index edge_attr = subgraph.edge_attr row = edge_index[0] @@ -675,45 +678,47 @@ def _prepare_subgraphs(self, nodes, r_label, n_labels): row = row[perm] col = col[perm] edge_attr = edge_attr[perm] - edge_index = torch.stack([row,col], 0) + edge_index = torch.stack([row,col], 0) subgraph.edge_index = edge_index subgraph.edge_attr = edge_attr return subgraph - - -def process_files(data_path, use_cache = True, inductive = False): + + +def process_files(data_path, use_cache = True, inductive = False, inductive_graph=False): entity2id = {} - relation2id = {} - + relation2id = {} + postfix = "" if not inductive else "_inductive" relation2id_path = os.path.join(data_path, f'relation2id{postfix}.json') if use_cache and os.path.exists(relation2id_path): print("Use cache from: ", relation2id_path) with open(relation2id_path, 'r') as f: - relation2id = json.load(f) - - - + relation2id = json.load(f) + + + entity2id_path = os.path.join(data_path, f'entity2id{postfix}.json') if use_cache and os.path.exists(entity2id_path): print("Use cache from: ", entity2id_path) with open(entity2id_path, 'r') as f: - entity2id = json.load(f) - + entity2id = json.load(f) + triplets = {} ent = 0 rel = 0 for mode in ['bg']: # assuming only one kind of background graph for now - - - file_path = os.path.join(data_path,f'path_graph{postfix}.json') + + if inductive_graph: + file_path = os.path.join(data_path,f'path_graph_inductive.json') + else: + file_path = os.path.join(data_path, f'path_graph.json') data = [] with open(file_path) as f: file_data = json.load(f) @@ -751,7 +756,7 @@ def process_files(data_path, use_cache = True, inductive = False): if not os.path.exists(entity2id_path): with open(entity2id_path, 'w') as f: - json.dump(entity2id, f) + json.dump(entity2id, f) return adj_list, triplets, entity2id, relation2id, id2entity, id2relation @@ -789,7 +794,7 @@ def get_subgraph(graph, nodes): node_idx[nodes] = torch.arange(subset.sum().item(), device=device) edge_index = node_idx[edge_index] - + num_nodes = nodes.size(0) data = copy.copy(graph) @@ -810,7 +815,7 @@ def get_subgraph(graph, nodes): class SubgraphFewshotDatasetRankTail(SubgraphFewshotDataset): def __len__(self): return len(self.eval_triples) - + def __getitem__(self, index): return self.next_one_on_eval(index)