diff --git a/README.md b/README.md
index 40389212..4d0580f8 100644
--- a/README.md
+++ b/README.md
@@ -10,8 +10,32 @@ Library for analyzing source code with graphs and NLP. What this repository can
### Installation
+You need conda. Create a virtual environment named `SourceCodeTools` with Python 3.8:
+```bash
+conda create -n SourceCodeTools python=3.8
+```
+
+If you plan to use graphviz, install pygraphviz from conda-forge:
+```bash
+conda install -c conda-forge pygraphviz
+```
+
+If you need GPU support, install the CUDA 11.1 toolkit:
+```bash
+conda install -c nvidia cudatoolkit=11.1
+```
+
+To install the SourceCodeTools library, run
```bash
git clone https://github.com/VitalyRomanov/method-embedding.git
cd method-embedding
pip install -e .
-```
\ No newline at end of file
+# or, to install with GPU extras: pip install -e .[gpu]
+```
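+
+To check that the installation succeeded, try importing the package (a quick sanity check; this assumes the package is importable as `SourceCodeTools`):
+```bash
+python -c "import SourceCodeTools"
+```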
+
+### Installing Sourcetrail
+Download a release from the [GitHub repo](https://github.com/CoatiSoftware/Sourcetrail/releases) (the latest tested version is 2020.1.117), then add the Sourcetrail location to `PATH`:
+```bash
+echo 'export PATH=/path/to/Sourcetrail_2020_1_117:$PATH' >> ~/.bashrc
+```
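+Reload the shell configuration so that the updated `PATH` takes effect:
+```bash
+source ~/.bashrc
+```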
+Scripts that use Sourcetrail work on Linux; some issues were observed on macOS.
\ No newline at end of file
diff --git a/SourceCodeTools/code/ast_tools.py b/SourceCodeTools/code/ast_tools.py
index 98b5571d..5ed78eb3 100644
--- a/SourceCodeTools/code/ast_tools.py
+++ b/SourceCodeTools/code/ast_tools.py
@@ -30,9 +30,18 @@ def get_mentions(function, root, mention):
def get_descendants(function, children):
+ """
+    :param function: function source code as a string
+    :param children: list of assignment target nodes
+    :return: List of (name, offset) pairs for attributes or names used as assignment targets. Subscript, Tuple
+    and List targets are skipped.
+ """
descendants = []
+ # if isinstance(children, ast.Tuple):
+ # descendants.extend(get_descendants(function, children.elts))
+ # else:
for chld in children:
# for node in ast.walk(chld):
node = chld
@@ -42,6 +51,10 @@ def get_descendants(function, children):
[(node.lineno-1, node.end_lineno-1, node.col_offset, node.end_col_offset, "new_var")], as_bytes=True)
# descendants.append((node.id, offset[-1]))
descendants.append((function[offset[-1][0]:offset[-1][1]], offset[-1]))
+ # elif isinstance(node, ast.Tuple):
+ # descendants.extend(get_descendants(function, node.elts))
+ elif isinstance(node, ast.Subscript) or isinstance(node, ast.Tuple) or isinstance(node, ast.List):
+ pass # skip for now
else:
raise Exception("")
diff --git a/SourceCodeTools/code/data/sourcetrail/Dataset.py b/SourceCodeTools/code/data/sourcetrail/Dataset.py
index 56f3de34..a13c6bae 100644
--- a/SourceCodeTools/code/data/sourcetrail/Dataset.py
+++ b/SourceCodeTools/code/data/sourcetrail/Dataset.py
@@ -8,7 +8,7 @@
from os.path import join
-from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker, NodeNameMasker
+from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker, NodeNameMasker, NodeClfMasker
from SourceCodeTools.code.data.sourcetrail.file_utils import *
from SourceCodeTools.code.python_ast import PythonSharedNodes
from SourceCodeTools.nlp.embed.bpe import make_tokenizer, load_bpe_model
@@ -163,6 +163,15 @@ def get_global_edges():
return types
+def get_embeddable_name(name):
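+    """
+    Strip the mention ("@...") or hexadecimal id ("_0x...") suffix from a node name,
+    leaving only the part of the name that is used for embedding.
+    """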
+ if "@" in name:
+ return name.split("@")[0]
+ elif "_0x" in name:
+ return name.split("_0x")[0]
+ else:
+ return name
+
+
class SourceGraphDataset:
g = None
nodes = None
@@ -182,7 +191,8 @@ def __init__(self, data_path,
label_from, use_node_types=False,
use_edge_types=False, filter=None, self_loops=False,
train_frac=0.6, random_seed=None, tokenizer_path=None, min_count_for_objectives=1,
- no_global_edges=False, remove_reverse=False, package_names=None):
+ no_global_edges=False, remove_reverse=False, custom_reverse=None, package_names=None,
+ restricted_id_pool=None, use_ns_groups=False):
"""
Prepares the data for training GNN model. The graph is prepared in the following way:
1. Edges are split into the train set and holdout set. Holdout set is used in the future experiments.
@@ -195,10 +205,6 @@ def __init__(self, data_path,
node_types flag to True.
4. Graphs require contiguous indexing of nodes. For this reason additional mapping is created that tracks
the relationship between the new graph id and the original node id from the training data.
- :param nodes_path: path to csv or compressed csv for nodes witch columns
- "id", "type", {"name", "serialized_name"}, {any column with labels}
- :param edges_path: path to csv or compressed csv for edges with columns
- "id", "type", {"source_node_id", "src"}, {"target_node_id", "dst"}
:param label_from: the column where the labels are taken from
:param use_node_types: boolean value, whether to use node types or not
(node-heterogeneous graph}
@@ -216,12 +222,17 @@ def __init__(self, data_path,
self.min_count_for_objectives = min_count_for_objectives
self.no_global_edges = no_global_edges
self.remove_reverse = remove_reverse
+ self.custom_reverse = None if custom_reverse is None else custom_reverse.split(",")
+
+ self.use_ns_groups = use_ns_groups
nodes_path = join(data_path, "nodes.bz2")
edges_path = join(data_path, "edges.bz2")
self.nodes, self.edges = load_data(nodes_path, edges_path)
+ # self.nodes, self.edges, self.holdout = self.holdout(self.nodes, self.edges)
+
# index is later used for sampling and is assumed to be unique
assert len(self.nodes) == len(self.nodes.index.unique())
assert len(self.edges) == len(self.edges.index.unique())
@@ -240,6 +251,9 @@ def __init__(self, data_path,
if self.no_global_edges:
self.remove_global_edges()
+ if self.custom_reverse is not None:
+ self.add_custom_reverse()
+
if use_node_types is False and use_edge_types is False:
new_nodes, new_edges = self.create_nodetype_edges()
self.nodes = self.nodes.append(new_nodes, ignore_index=True)
@@ -250,7 +264,8 @@ def __init__(self, data_path,
self.nodes['type'] = "node_"
self.nodes = self.nodes.astype({'type': 'category'})
- self.add_embeddable_flag()
+ self.add_embedding_names()
+ # self.add_embeddable_flag()
# need to do this to avoid issues inside dgl library
self.edges['type'] = self.edges['type'].apply(lambda x: f"{x}_")
@@ -271,7 +286,7 @@ def __init__(self, data_path,
# self.nodes, self.label_map = self.add_compact_labels()
self.add_typed_ids()
- self.add_splits(train_frac=train_frac, package_names=package_names)
+ self.add_splits(train_frac=train_frac, package_names=package_names, restricted_id_pool=restricted_id_pool)
# self.mark_leaf_nodes()
@@ -310,8 +325,12 @@ def compress_edge_types(self):
self.edges['type'] = self.edges['type'].apply(lambda x: edge_type_map[x])
+ def add_embedding_names(self):
+ self.nodes["embeddable"] = True
+ self.nodes["embeddable_name"] = self.nodes["name"].apply(get_embeddable_name)
+
def add_embeddable_flag(self):
- embeddable_types = PythonSharedNodes.shared_node_types
+ embeddable_types = PythonSharedNodes.shared_node_types | set(list(node_types.values()))
if len(self.nodes.query("type_backup == 'subword'")) > 0:
# some of the types should not be embedded if subwords were generated
@@ -326,6 +345,8 @@ def add_embeddable_flag(self):
inplace=True
)
+ self.nodes["embeddable_name"] = self.nodes["name"].apply(get_embeddable_name)
+
def op_tokens(self):
if self.tokenizer_path is None:
from SourceCodeTools.code.python_tokens_to_bpe_subwords import python_ops_to_bpe
@@ -351,7 +372,7 @@ def op_tokens(self):
# self.nodes.eval("name_alter_tokens = name.map(@op_tokenize)",
# local_dict={"op_tokenize": op_tokenize}, inplace=True)
- def add_splits(self, train_frac, package_names=None):
+ def add_splits(self, train_frac, package_names=None, restricted_id_pool=None):
"""
Generates train, validation, and test masks
Store the masks is pandas table for nodes
@@ -359,6 +380,7 @@ def add_splits(self, train_frac, package_names=None):
:return:
"""
+ self.nodes.reset_index(drop=True, inplace=True)
assert len(self.nodes.index) == self.nodes.index.max() + 1
# generate splits for all nodes, additional filtering will be applied later
# by an objective
@@ -376,6 +398,14 @@ def add_splits(self, train_frac, package_names=None):
create_train_val_test_masks(self.nodes, *splits)
+ if restricted_id_pool is not None:
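+            # restrict the train/val/test masks to nodes from the provided id pool
+            # (FunctionDef and mention nodes are always kept)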
+ node_ids = set(pd.read_csv(restricted_id_pool)["node_id"].tolist()) | \
+ set(self.nodes.query("type_backup == 'FunctionDef' or type_backup == 'mention'")["id"].tolist())
+ to_keep = self.nodes["id"].apply(lambda id_: id_ in node_ids)
+ self.nodes["train_mask"] = self.nodes["train_mask"] & to_keep
+ self.nodes["test_mask"] = self.nodes["test_mask"] & to_keep
+ self.nodes["val_mask"] = self.nodes["val_mask"] & to_keep
+
def add_typed_ids(self):
nodes = self.nodes.copy()
@@ -464,7 +494,8 @@ def remove_ast_edges(self):
def remove_global_edges(self):
global_edges = get_global_edges()
- global_edges.add("global_mention")
+ # global_edges.add("global_mention")
+ # global_edges |= set(edge + "_rev" for edge in global_edges)
is_ast = lambda type: type not in global_edges
edges = self.edges.query("type.map(@is_ast)", local_dict={"is_ast": is_ast})
self.edges = edges
@@ -472,12 +503,25 @@ def remove_global_edges(self):
def remove_reverse_edges(self):
from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import special_mapping
- global_reverse = {val for _, val in special_mapping.items()}
+ # TODO test this change
+ global_reverse = {key for key, val in special_mapping.items()}
not_reverse = lambda type: not (type.endswith("_rev") or type in global_reverse)
edges = self.edges.query("type.map(@not_reverse)", local_dict={"not_reverse": not_reverse})
self.edges = edges
+ def add_custom_reverse(self):
+        to_reverse = self.edges[
+            self.edges["type"].apply(lambda type_: type_ in self.custom_reverse)
+        ].copy()  # copy so that the swap below does not modify a view of self.edges
+
+ to_reverse["type"] = to_reverse["type"].apply(lambda type_: type_ + "_rev")
+ tmp = to_reverse["src"]
+ to_reverse["src"] = to_reverse["dst"]
+ to_reverse["dst"] = tmp
+
+ self.edges = self.edges.append(to_reverse[["src", "dst", "type"]])
+
def update_global_id(self):
orig_id = []
graph_id = []
@@ -599,24 +643,52 @@ def assess_need_for_self_loops(cls, nodes, edges):
return nodes, edges
- # @classmethod
- # def holdout(cls, nodes, edges, holdout_frac, random_seed):
- # """
- # Create a set of holdout edges, ensure that there are no orphan nodes after these edges are removed.
- # :param nodes:
- # :param edges:
- # :param holdout_frac:
- # :param random_seed:
- # :return:
- # """
- #
- # train, test = split(edges, holdout_frac, random_seed=random_seed)
- #
- # nodes, train_edges = ensure_connectedness(nodes, train)
- #
- # nodes, test_edges = ensure_valid_edges(nodes, test)
- #
- # return nodes, train_edges, test_edges
+ @classmethod
+ def holdout(cls, nodes: pd.DataFrame, edges: pd.DataFrame, holdout_size=10000, random_seed=42):
+ """
+ Create a set of holdout edges, ensure that there are no orphan nodes after these edges are removed.
+ :param nodes:
+ :param edges:
+        :param holdout_size: number of edges to hold out
+ :param random_seed:
+ :return:
+ """
+
+ from collections import Counter
+
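+        # the Counter union keeps, for each node, the larger of its out-/in-degree counts;
+        # an edge is held out only while its source node's count stays above two,
+        # so removing it cannot leave the node orphaned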
+ degree_count = Counter(edges["src"].tolist()) | Counter(edges["dst"].tolist())
+
+ heldout = []
+
+ edges = edges.reset_index(drop=True)
+ index = edges.index.to_numpy()
+ numpy.random.seed(random_seed)
+ numpy.random.shuffle(index)
+
+ for i in index:
+ src_id = edges.loc[i].src
+ if degree_count[src_id] > 2:
+ heldout.append(edges.loc[i].id)
+ degree_count[src_id] -= 1
+ if len(heldout) >= holdout_size:
+ break
+
+ heldout = set(heldout)
+
+ def is_held(id_):
+ return id_ in heldout
+
+ train_edges = edges[
+ edges["id"].apply(lambda id_: not is_held(id_))
+ ]
+
+ heldout_edges = edges[
+ edges["id"].apply(is_held)
+ ]
+
+ assert len(edges) == edges["id"].unique().size
+
+ return nodes, train_edges, heldout_edges
# def mark_leaf_nodes(self):
# leaf_types = {'subword', "Op", "Constant", "Name"} # the last is used in graphs without subwords
@@ -638,6 +710,11 @@ def load_node_names(self):
][['id', 'type_backup', 'name']]\
.rename({"name": "serialized_name", "type_backup": "type"}, axis=1)
+ global_node_types = set(node_types.values())
+ for_training = for_training[
+ for_training["type"].apply(lambda x: x not in global_node_types)
+ ]
+
node_names = extract_node_names(for_training, 2)
node_names = filter_dst_by_freq(node_names, freq=self.min_count_for_objectives)
@@ -707,6 +784,76 @@ def load_global_edges_prediction(self):
return edges[["src", "dst"]]
+ def load_edge_prediction(self):
+
+ nodes_path = join(self.data_path, "nodes.bz2")
+ edges_path = join(self.data_path, "edges.bz2")
+
+ _, edges = load_data(nodes_path, edges_path)
+
+ edges.rename(
+ {
+ "source_node_id": "src",
+ "target_node_id": "dst"
+ }, inplace=True, axis=1
+ )
+
+ global_edges = {"global_mention", "subword", "next", "prev"}
+ global_edges = global_edges | {"mention_scope", "defined_in_module", "defined_in_class", "defined_in_function"}
+
+ if self.no_global_edges:
+ global_edges = global_edges | get_global_edges()
+
+ global_edges = global_edges | set(edge + "_rev" for edge in global_edges)
+ is_ast = lambda type: type not in global_edges
+ edges = edges.query("type.map(@is_ast)", local_dict={"is_ast": is_ast})
+ edges = edges[edges["type"].apply(lambda type_: not type_.endswith("_rev"))]
+
+ valid_nodes = set(edges["src"].tolist())
+ valid_nodes = valid_nodes.intersection(set(edges["dst"].tolist()))
+
+ # if self.use_ns_groups:
+ # groups = self.get_negative_sample_groups()
+ # valid_nodes = valid_nodes.intersection(set(groups["id"].tolist()))
+
+ edges = edges[
+ edges["src"].apply(lambda id_: id_ in valid_nodes)
+ ]
+ edges = edges[
+ edges["dst"].apply(lambda id_: id_ in valid_nodes)
+ ]
+
+ return edges[["src", "dst", "type"]]
+
+ def load_type_prediction(self):
+
+ type_ann = unpersist(join(self.data_path, "type_annotations.bz2"))
+
+ filter_rule = lambda name: "0x" not in name
+
+ type_ann = type_ann[
+ type_ann["dst"].apply(filter_rule)
+ ]
+
+ node2id = dict(zip(self.nodes["id"], self.nodes["type_backup"]))
+ type_ann = type_ann[
+ type_ann["src"].apply(lambda id_: id_ in node2id)
+ ]
+
+ type_ann["src_type"] = type_ann["src"].apply(lambda x: node2id[x])
+
+ type_ann = type_ann[
+ type_ann["src_type"].apply(lambda type_: type_ in {"mention"}) # FunctionDef {"arg", "AnnAssign"})
+ ]
+
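+        # normalize annotation strings: strip surrounding quotes, drop type parameters ("[...]")
+        # and module prefixes, e.g. "typing.List[int]" -> "List"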
+ norm = lambda x: x.strip("\"").strip("'").split("[")[0].split(".")[-1]
+
+ type_ann["dst"] = type_ann["dst"].apply(norm)
+ type_ann = filter_dst_by_freq(type_ann, self.min_count_for_objectives)
+ type_ann = type_ann[["src", "dst"]]
+
+ return type_ann
+
def load_docstring(self):
docstrings_path = os.path.join(self.data_path, "common_source_graph_bodies.bz2")
@@ -730,6 +877,18 @@ def normalize(text):
return dosctrings
+ def load_node_classes(self):
+ have_inbound = set(self.edges["dst"].tolist())
+ labels = self.nodes.query("train_mask == True or test_mask == True or val_mask == True")[["id", "type_backup"]].rename({
+ "id": "src",
+ "type_backup": "dst"
+ }, axis=1)
+
+ labels = labels[
+ labels["src"].apply(lambda id_: id_ in have_inbound)
+ ]
+ return labels
+
def buckets_from_pretrained_embeddings(self, pretrained_path, n_buckets):
from SourceCodeTools.nlp.embed.fasttext import load_w2v_map
@@ -783,6 +942,25 @@ def create_node_name_masker(self, tokenizer_path):
"""
return NodeNameMasker(self.nodes, self.edges, self.load_node_names(), tokenizer_path)
+ def create_node_clf_masker(self):
+ """
+        :return: NodeClfMasker over graph nodes, suitable for the node classification objective
+ """
+ return NodeClfMasker(self.nodes, self.edges)
+
+ def get_negative_sample_groups(self):
+ return self.nodes[["id", "mentioned_in"]].dropna(axis=0)
+
+
+ @classmethod
+ def load(cls, path, args):
+ dataset = pickle.load(open(path, "rb"))
+ dataset.data_path = args.data_path
+ if dataset.tokenizer_path is not None:
+ dataset.tokenizer_path = args.tokenizer
+ return dataset
+
def ensure_connectedness(nodes: pandas.DataFrame, edges: pandas.DataFrame):
"""
@@ -842,10 +1020,10 @@ def ensure_valid_edges(nodes, edges, ignore_src=False):
return nodes, edges
-def read_or_create_dataset(args, model_base, labels_from="type"):
- if args.restore_state:
+def read_or_create_dataset(args, model_base, labels_from="type", force_new=False):
+ if args.restore_state and not force_new:
# i'm not happy with this behaviour that differs based on the flag status
- dataset = pickle.load(open(join(model_base, "dataset.pkl"), "rb"))
+ dataset = SourceGraphDataset.load(join(model_base, "dataset.pkl"), args)
else:
dataset = SourceGraphDataset(
# args.node_path, args.edge_path,
@@ -861,7 +1039,10 @@ def read_or_create_dataset(args, model_base, labels_from="type"):
min_count_for_objectives=args.min_count_for_objectives,
no_global_edges=args.no_global_edges,
remove_reverse=args.remove_reverse,
- package_names=open(args.packages_file).readlines() if args.packages_file is not None else None
+ custom_reverse=args.custom_reverse,
+ package_names=open(args.packages_file).readlines() if args.packages_file is not None else None,
+ restricted_id_pool=args.restricted_id_pool,
+ use_ns_groups=args.use_ns_groups
)
# save dataset state for recovery
diff --git a/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py b/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py
index 88217468..a2950b8f 100644
--- a/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py
+++ b/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py
@@ -1,7 +1,11 @@
+import shelve
+import shutil
+import tempfile
from os.path import join
from tqdm import tqdm
from SourceCodeTools.code.data.sourcetrail.common import map_offsets
+from SourceCodeTools.code.data.sourcetrail.sourcetrail_filter_type_edges import filter_type_edges
from SourceCodeTools.code.data.sourcetrail.sourcetrail_map_id_columns import map_columns
from SourceCodeTools.code.data.sourcetrail.sourcetrail_merge_graphs import get_global_node_info, merge_global_with_local
from SourceCodeTools.code.data.sourcetrail.sourcetrail_node_local2global import get_local2global
@@ -23,7 +27,8 @@ def __init__(
self, path, lang,
bpe_tokenizer, create_subword_instances,
connect_subwords, only_with_annotations,
- do_extraction=False, visualize=False, track_offsets=False
+ do_extraction=False, visualize=False, track_offsets=False, remove_type_annotations=False,
+ recompute_l2g=False
):
self.indexed_path = path
self.lang = lang
@@ -34,15 +39,35 @@ def __init__(
self.extract = do_extraction
self.visualize = visualize
self.track_offsets = track_offsets
+ self.remove_type_annotations = remove_type_annotations
+ self.recompute_l2g = recompute_l2g
paths = (os.path.join(path, dir) for dir in os.listdir(path))
self.environments = sorted(list(filter(lambda path: os.path.isdir(path), paths)), key=lambda x: x.lower())
- self.local2global_cache = {}
+ self.init_cache()
from SourceCodeTools.code.data.sourcetrail.common import UNRESOLVED_SYMBOL
self.unsolved_symbol = UNRESOLVED_SYMBOL
+ def init_cache(self):
+        from random import sample
+        from string import ascii_letters, digits
+        rnd_name = "".join(sample(ascii_letters + digits, k=10))
+
+ self.tmp_dir = os.path.join(tempfile.gettempdir(), rnd_name)
+ if os.path.isdir(self.tmp_dir):
+ shutil.rmtree(self.tmp_dir)
+ os.mkdir(self.tmp_dir)
+
+ self.local2global_cache_filename = os.path.join(self.tmp_dir,"local2global_cache.db")
+ self.local2global_cache = shelve.open(self.local2global_cache_filename)
+
+ def __del__(self):
+ self.local2global_cache.close()
+ shutil.rmtree(self.tmp_dir)
+ # os.remove(self.local2global_cache_filename) # TODO nofile on linux, need to check
+
def merge(self, output_directory):
if self.extract:
@@ -62,8 +87,6 @@ def merge(self, output_directory):
join(no_ast_path, "common_function_variable_pairs.bz2"), "Merging variables")
self.create_global_file("call_seq.bz2", "local2global.bz2", ['src', 'dst'],
join(no_ast_path, "common_call_seq.bz2"), "Merging call seq")
- # self.create_global_file("name_groups.bz2", "local2global.bz2", [],
- # join(no_ast_path, "name_groups.bz2"), "Merging name groups")
global_nodes = self.filter_orphaned_nodes(
unpersist(join(no_ast_path, "common_nodes.bz2")), no_ast_path
@@ -75,6 +98,8 @@ def merge(self, output_directory):
if node_names is not None:
persist(node_names, join(no_ast_path, "node_names.bz2"))
+ self.handle_parallel_edges(join(no_ast_path, "common_edges.bz2"))
+
if self.visualize:
self.visualize_func(
unpersist(join(no_ast_path, "common_nodes.bz2")),
@@ -92,8 +117,24 @@ def merge(self, output_directory):
join(with_ast_path, "common_function_variable_pairs.bz2"), "Merging variables with ast")
self.create_global_file("call_seq.bz2", "local2global_with_ast.bz2", ['src', 'dst'],
join(with_ast_path, "common_call_seq.bz2"), "Merging call seq with ast")
- # self.create_global_file("name_groups.bz2", "local2global_with_ast.bz2", [],
- # join(with_ast_path, "name_groups.bz2"), "Merging name groups")
+ self.create_global_file("offsets.bz2", "local2global_with_ast.bz2", ['node_id'],
+ join(with_ast_path, "common_offsets.bz2"), "Merging offsets with ast",
+ columns_special=[("mentioned_in", map_offsets)])
+ self.create_global_file("filecontent_with_package.bz2", "local2global_with_ast.bz2", [],
+ join(with_ast_path, "common_filecontent.bz2"), "Merging filecontents")
+
+ if self.remove_type_annotations:
+ no_annotations, annotations = filter_type_edges(
+ unpersist(join(with_ast_path, "common_nodes.bz2")),
+ unpersist(join(with_ast_path, "common_edges.bz2"))
+ )
+ persist(no_annotations, join(with_ast_path, "common_edges.bz2"))
+ if annotations is not None:
+ persist(annotations, join(with_ast_path, "type_annotations.bz2"))
+
+ self.handle_parallel_edges(join(with_ast_path, "common_edges.bz2"))
+
+ self.post_pruning(join(with_ast_path, "common_nodes.bz2"), join(with_ast_path, "common_edges.bz2"))
global_nodes = self.filter_orphaned_nodes(
unpersist(join(with_ast_path, "common_nodes.bz2")), with_ast_path
@@ -112,11 +153,76 @@ def merge(self, output_directory):
join(with_ast_path, "visualization.pdf")
)
+ def handle_parallel_edges(self, path):
+ edges = unpersist(path)
+ edges["id"] = list(range(len(edges)))
+
+ edge_priority = {
+ "next": -1, "prev": -1, "global_mention": -1, "global_mention_rev": -1,
+ "calls": 0,
+ "called_by": 0,
+ "defines": 1,
+ "defined_in": 1,
+ "inheritance": 1,
+ "inherited_by": 1,
+ "imports": 1,
+ "imported_by": 1,
+ "uses": 2,
+ "used_by": 2,
+ "uses_type": 2,
+        "type_used_by": 2,
+        "mention_scope": 10, "mention_scope_rev": 10,
+        "defined_in_function": 4, "defined_in_function_rev": 4,
+        "defined_in_class": 5, "defined_in_class_rev": 5,
+        "defined_in_module": 6, "defined_in_module_rev": 6
+ }
+
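+        # lower value means higher priority; when several edges connect the same (src, dst)
+        # pair, only the highest-priority one is kept (types not listed default to priority 3)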
+ edge_bank = {}
+ for id, type, src, dst in edges[["id", "type", "source_node_id", "target_node_id"]].values:
+ key = (src, dst)
+ if key in edge_bank:
+ edge_bank[key].append((id, type))
+ else:
+ edge_bank[key] = [(id, type)]
+
+ ids_to_remove = set()
+ for key, parallel_edges in edge_bank.items():
+ if len(parallel_edges) > 1:
+ parallel_edges = sorted(parallel_edges, key=lambda x: edge_priority.get(x[1],3))
+ ids_to_remove.update(pe[0] for pe in parallel_edges[1:])
+
+ edges = edges[
+ edges["id"].apply(lambda id_: id_ not in ids_to_remove)
+ ]
+
+ edges["id"] = range(len(edges))
+
+ persist(edges, path)
+
+ def post_pruning(self, npath, epath):
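+        # drop edges of restricted types, as well as edges whose target is a node of a
+        # restricted "leaf" type (operators, constants, subwords, etc.)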
+ nodes = unpersist(npath)
+ edges = unpersist(epath)
+
+ restricted_edges = {"global_mention_rev"}
+ restricted_in_types = {
+ "Op", "Constant", "#attr#", "#keyword#",
+ 'CtlFlow', 'JoinedStr', 'Name', 'ast_Literal',
+ 'subword', 'type_annotation'
+ }
+
+ restricted_nodes = set(nodes[
+ nodes["type"].apply(lambda type_: type_ in restricted_in_types)
+ ]["id"].tolist())
+
+ edges = edges[
+ edges["type"].apply(lambda type_: type_ not in restricted_edges)
+ ]
+
+ edges = edges[
+            edges["target_node_id"].apply(lambda node_id: node_id not in restricted_nodes)
+ ]
+
+ persist(edges, epath)
def do_extraction(self):
global_nodes = None
global_nodes_with_ast = None
- name_groups = None
for env_path in self.environments:
logging.info(f"Found {os.path.basename(env_path)}")
@@ -125,43 +231,50 @@ def do_extraction(self):
logging.info("Package not indexed")
continue
- nodes, edges, source_location, occurrence, filecontent, element_component = \
- self.read_sourcetrail_files(env_path)
+ if not self.recompute_l2g:
- if nodes is None:
- logging.info("Index is empty")
- continue
+ nodes, edges, source_location, occurrence, filecontent, element_component = \
+ self.read_sourcetrail_files(env_path)
- edges = filter_ambiguous_edges(edges, element_component)
+ if nodes is None:
+ logging.info("Index is empty")
+ continue
- nodes, edges = self.filter_unsolved_symbols(nodes, edges)
+ edges = filter_ambiguous_edges(edges, element_component)
- bodies = process_bodies(nodes, edges, source_location, occurrence, filecontent, self.lang)
- call_seq = extract_call_seq(nodes, edges, source_location, occurrence)
+ nodes, edges = self.filter_unsolved_symbols(nodes, edges)
- edges = add_reverse_edges(edges)
+ bodies = process_bodies(nodes, edges, source_location, occurrence, filecontent, self.lang)
+ call_seq = extract_call_seq(nodes, edges, source_location, occurrence)
- # if bodies is not None:
- ast_nodes, ast_edges, offsets, name_group_tracker = get_ast_from_modules(
- nodes, edges, source_location, occurrence, filecontent,
- self.bpe_tokenizer, self.create_subword_instances, self.connect_subwords, self.lang,
- track_offsets=self.track_offsets
- )
- nodes_with_ast = nodes.append(ast_nodes)
- edges_with_ast = edges.append(ast_edges)
- if bodies is not None:
- vars = extract_var_names(nodes, bodies, self.lang)
- else:
- vars = None
- if name_groups is None:
- name_groups = name_group_tracker
+ edges = add_reverse_edges(edges)
+
+ # if bodies is not None:
+ ast_nodes, ast_edges, offsets = get_ast_from_modules(
+ nodes, edges, source_location, occurrence, filecontent,
+ self.bpe_tokenizer, self.create_subword_instances, self.connect_subwords, self.lang,
+ track_offsets=self.track_offsets
+ )
+
+ if offsets is not None:
+ offsets["package"] = os.path.basename(env_path)
+
+ # need this check in situations when module has a single file and this file cannot be parsed
+ nodes_with_ast = nodes.append(ast_nodes) if ast_nodes is not None else nodes
+ edges_with_ast = edges.append(ast_edges) if ast_edges is not None else edges
+
+ if bodies is not None:
+ vars = extract_var_names(nodes, bodies, self.lang)
+ else:
+ vars = None
else:
- name_groups = name_groups.append(name_group_tracker)
- # else:
- # nodes_with_ast = nodes
- # edges_with_ast = edges
- # vars = None
- # offsets = None
+ nodes = unpersist_if_present(join(env_path, "nodes.bz2"))
+ nodes_with_ast = unpersist_if_present(join(env_path, "nodes_with_ast.bz2"))
+
+ if nodes is None or nodes_with_ast is None:
+ continue
+
+ edges = bodies = call_seq = vars = edges_with_ast = offsets = None
global_nodes = self.merge_with_global(global_nodes, nodes)
global_nodes_with_ast = self.merge_with_global(global_nodes_with_ast, nodes_with_ast)
@@ -175,7 +288,7 @@ def do_extraction(self):
self.write_local(env_path, nodes, edges, bodies, call_seq, vars,
nodes_with_ast, edges_with_ast, offsets,
- local2global, local2global_with_ast, name_groups)
+ local2global, local2global_with_ast)
def get_local2global(self, path):
if path in self.local2global_cache:
@@ -239,29 +352,37 @@ def read_sourcetrail_files(self, env_path):
def write_local(self, dir, nodes, edges, bodies, call_seq, vars,
nodes_with_ast, edges_with_ast, offsets,
- local2global, local2global_with_ast, name_groups):
- write_nodes(nodes, dir)
- write_edges(edges, dir)
- if bodies is not None:
- write_processed_bodies(bodies, dir)
- if call_seq is not None:
- persist(call_seq, join(dir, filenames['call_seq']))
- if vars is not None:
- persist(vars, join(dir, filenames['function_variable_pairs']))
- persist(nodes_with_ast, join(dir, "nodes_with_ast.bz2"))
- persist(edges_with_ast, join(dir, "edges_with_ast.bz2"))
- if offsets is not None:
- persist(offsets, join(dir, "offsets.bz2"))
- if len(edges_with_ast.query("type == 'annotation_for' or type == 'returned_by'")) > 0:
- with open(join(dir, "has_annotations"), "w") as has_annotations:
- pass
+ local2global, local2global_with_ast):
+ if not self.recompute_l2g:
+ write_nodes(nodes, dir)
+ write_edges(edges, dir)
+ if bodies is not None:
+ write_processed_bodies(bodies, dir)
+ if call_seq is not None:
+ persist(call_seq, join(dir, filenames['call_seq']))
+ if vars is not None:
+ persist(vars, join(dir, filenames['function_variable_pairs']))
+ persist(nodes_with_ast, join(dir, "nodes_with_ast.bz2"))
+ persist(edges_with_ast, join(dir, "edges_with_ast.bz2"))
+ if offsets is not None:
+ persist(offsets, join(dir, "offsets.bz2"))
+ if len(edges_with_ast.query("type == 'annotation_for' or type == 'returned_by'")) > 0:
+ with open(join(dir, "has_annotations"), "w") as has_annotations:
+ pass
+
+ # add package name to filecontent
+ filecontent = read_filecontent(dir)
+ filecontent["package"] = os.path.basename(dir)
+ persist(filecontent, join(dir, "filecontent_with_package.bz2"))
+
persist(local2global, join(dir, "local2global.bz2"))
persist(local2global_with_ast, join(dir, "local2global_with_ast.bz2"))
- if name_groups is not None:
- persist(name_groups, join(dir, "name_groups.bz2"))
-
def get_global_node_info(self, global_nodes):
+ """
+ :param global_nodes: nodes from a global merged graph
+        :return: Set of existing nodes represented as (type, node_name) pairs, and the minimum available free id
+ """
if global_nodes is None:
existing_nodes, next_valid_id = set(), 0
else:
@@ -269,6 +390,12 @@ def get_global_node_info(self, global_nodes):
return existing_nodes, next_valid_id
def merge_with_global(self, global_nodes, local_nodes):
+ """
+ Merge nodes obtained from the source code with the previously existing nodes.
+ :param global_nodes: Nodes from a global inter-package graph
+ :param local_nodes: Nodes from a local file-level graph
+ :return: Updated version of the global inter-package graph
+ """
existing_nodes, next_valid_id = self.get_global_node_info(global_nodes)
new_nodes = merge_global_with_local(existing_nodes, next_valid_id, local_nodes)
@@ -345,14 +472,19 @@ def visualize_func(self, nodes, edges, output_path):
parser.add_argument('--do_extraction', action='store_true', default=False, help="")
parser.add_argument('--visualize', action='store_true', default=False, help="")
parser.add_argument('--track_offsets', action='store_true', default=False, help="")
-
+ parser.add_argument('--recompute_l2g', action='store_true', default=False, help="")
+ parser.add_argument('--remove_type_annotations', action='store_true', default=False, help="")
args = parser.parse_args()
+ if args.recompute_l2g:
+ args.do_extraction = True
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(message)s")
- dataset = DatasetCreator(args.indexed_environments, args.language,
- args.bpe_tokenizer, args.create_subword_instances,
- args.connect_subwords, args.only_with_annotations,
- args.do_extraction, args.visualize, args.track_offsets)
+ dataset = DatasetCreator(
+ args.indexed_environments, args.language, args.bpe_tokenizer, args.create_subword_instances,
+ args.connect_subwords, args.only_with_annotations, args.do_extraction, args.visualize, args.track_offsets,
+ args.remove_type_annotations, args.recompute_l2g
+ )
dataset.merge(args.output_directory)
diff --git a/SourceCodeTools/code/data/sourcetrail/SubwordMasker.py b/SourceCodeTools/code/data/sourcetrail/SubwordMasker.py
index 89607ad3..509ac997 100644
--- a/SourceCodeTools/code/data/sourcetrail/SubwordMasker.py
+++ b/SourceCodeTools/code/data/sourcetrail/SubwordMasker.py
@@ -103,4 +103,25 @@ def instantiate(self, nodes, orig_edges, **kwargs):
self.lookup[key].extend([subword2key[sub] for sub in subwords if sub in subword2key])
+class NodeClfMasker(SubwordMasker):
+ """
+    Masker for the node classification objective. Performs no subword lookup and simply groups the provided node ids for masking.
+ """
+ def __init__(self, nodes: pd.DataFrame, edges: pd.DataFrame, **kwargs):
+ super(NodeClfMasker, self).__init__(nodes, edges, **kwargs)
+
+ def instantiate(self, nodes, orig_edges, **kwargs):
+ pass
+
+ def get_mask(self, ids):
+ """
+ Accepts node ids that represent embeddable tokens as an input
+ :param ids:
+ :return:
+ """
+ if isinstance(ids, dict):
+ for_masking = ids
+ else:
+ for_masking = {"node_": ids}
+ return for_masking
diff --git a/SourceCodeTools/code/data/sourcetrail/common.py b/SourceCodeTools/code/data/sourcetrail/common.py
index 8cc28b4c..f0f5b2d5 100644
--- a/SourceCodeTools/code/data/sourcetrail/common.py
+++ b/SourceCodeTools/code/data/sourcetrail/common.py
@@ -26,6 +26,14 @@ def __del__(self):
def get_occurrence_groups(nodes, edges, source_location, occurrence):
+ """
+ Group nodes based on file id. Return dataset that contains node ids and their offsets in the source code.
+ :param nodes: dataframe with nodes
+ :param edges: dataframe with edges
+ :param source_location: dataframe with sources
+    :param occurrence: dataframe with occurrence offsets
+    :return: occurrences grouped by file id
+ """
edges = edges.rename(columns={'type': 'e_type'})
edges = edges.query("id >= 0") # filter reverse edges
diff --git a/SourceCodeTools/code/data/sourcetrail/deprecated/sourcetrail_extract_type_information.py b/SourceCodeTools/code/data/sourcetrail/deprecated/sourcetrail_extract_type_information.py
deleted file mode 100644
index 590f4be9..00000000
--- a/SourceCodeTools/code/data/sourcetrail/deprecated/sourcetrail_extract_type_information.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import sys
-
-from SourceCodeTools.code.data.sourcetrail.file_utils import *
-
-
-def main():
- edges = unpersist(sys.argv[1])
- out_annotations = sys.argv[2]
- out_no_annotations = sys.argv[3]
-
- annotations = edges.query(f"type == 'annotation_for' or type == 'returned_by'")
- no_annotations = edges.query(f"type != 'annotation_for' and type != 'returned_by'")
-
- persist(annotations, out_annotations)
- persist(no_annotations, out_no_annotations)
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/SourceCodeTools/code/data/sourcetrail/file_utils.py b/SourceCodeTools/code/data/sourcetrail/file_utils.py
index 05287e26..4893c9d5 100644
--- a/SourceCodeTools/code/data/sourcetrail/file_utils.py
+++ b/SourceCodeTools/code/data/sourcetrail/file_utils.py
@@ -139,7 +139,7 @@ def write_processed_bodies(df, base_path):
def persist(df: pd.DataFrame, path: str, **kwargs):
- if path.endswith(".csv"):
+ if path.endswith(".csv") or path.endswith(".tsv"):
write_csv(df, path, **kwargs)
elif path.endswith(".pkl") or path.endswith(".bz2"):
write_pickle(df, path, **kwargs)
diff --git a/SourceCodeTools/code/data/sourcetrail/k_hop_graph.py b/SourceCodeTools/code/data/sourcetrail/k_hop_graph.py
new file mode 100644
index 00000000..45b59259
--- /dev/null
+++ b/SourceCodeTools/code/data/sourcetrail/k_hop_graph.py
@@ -0,0 +1,66 @@
+from os.path import join
+
+import networkx as nx
+from tqdm import tqdm
+
+from SourceCodeTools.code.data.sourcetrail.Dataset import load_data
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("working_directory")
+ parser.add_argument("k_hops", type=int)
+ parser.add_argument("output")
+
+ args = parser.parse_args()
+
+ nodes, edges = load_data(
+ join(args.working_directory, "common_nodes.bz2"), join(args.working_directory, "common_edges.bz2")
+ )
+
+ edge_types = {}
+ edge_lists = {}
+ for s, d, t in edges[["src", "dst", "type"]].values:
+ edge_types[(s,d)] = t
+ if s not in edge_lists:
+ edge_lists[s] = []
+ edge_lists[s].append(d)
+
+ g = nx.from_pandas_edgelist(
+ edges, source="src", target="dst", create_using=nx.DiGraph, edge_attr="type"
+ )
+
+ # def expand_edges(node_id, view, edge_prefix, level=0):
+ # edges = []
+ # if level <= args.k_hops:
+ # if edge_prefix != "":
+ # edge_prefix += "|"
+ # for e in view:
+ # edges.append((node_id, e, edge_prefix + view[e]["type"]))
+ # edges.extend(expand_edges(node_id, g[e], edge_prefix + view[e]["type"], level=level+1))
+ # return edges
+ #
+ # edges = []
+ # for node in tqdm(g.nodes):
+ # edges.extend(expand_edges(node, g[node], "", level=0))
+
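+    # depth-first expansion: starting from every node with outgoing edges, follow edges for up
+    # to k_hops levels and emit a composite edge whose type is the "|"-joined path of edge types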
+ def expand_edges(node_id, s, dlist, edge_prefix, level=0):
+ edges = []
+ if level <= args.k_hops:
+ if edge_prefix != "":
+ edge_prefix += "|"
+ for d in dlist:
+ etype = edge_prefix + edge_types[(s,d)]
+ edges.append((node_id, d, etype))
+                edges.extend(expand_edges(node_id, d, edge_lists.get(d, []), etype, level=level+1))  # nodes without outgoing edges stop the recursion
+ return edges
+
+ edges = []
+ for node in tqdm(edge_lists):
+ edges.extend(expand_edges(node, node, edge_lists[node], "", level=0))
+
+    # persist the expanded edges; assuming a simple CSV with src, dst, and composite edge type is the desired output
+    import pandas as pd
+    pd.DataFrame(edges, columns=["src", "dst", "type"]).to_csv(args.output, index=False)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/code/data/sourcetrail/nodes_of_interest_from_dataset.py b/SourceCodeTools/code/data/sourcetrail/nodes_of_interest_from_dataset.py
new file mode 100644
index 00000000..d9e69bcc
--- /dev/null
+++ b/SourceCodeTools/code/data/sourcetrail/nodes_of_interest_from_dataset.py
@@ -0,0 +1,29 @@
+import json
+
+
+def get_node_ids_from_dataset(dataset_path):
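+    # each JSONL entry stores "replacements" as triples whose last element is a node id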
+ node_ids = []
+ with open(dataset_path, "r") as dataset:
+ for line in dataset:
+ entry = json.loads(line)
+ for _, _, id_ in entry["replacements"]:
+ node_ids.append(int(id_))
+ return node_ids
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("dataset")
+ parser.add_argument("output")
+ args = parser.parse_args()
+
+ node_ids = get_node_ids_from_dataset(args.dataset)
+ with open(args.output, "w") as sink:
+ sink.write("node_id\n")
+ for id_ in node_ids:
+ sink.write(f"{id_}\n")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_analyze_tree_depth.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_analyze_tree_depth.py
new file mode 100644
index 00000000..f96df56f
--- /dev/null
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_analyze_tree_depth.py
@@ -0,0 +1,46 @@
+import ast
+import os
+from typing import Iterable
+
+from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist
+import numpy as np
+
+class DepthEstimator:
+ def __init__(self):
+ self.depth = 0
+
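+    # recursively walk an AST node, its fields, and nested lists, tracking the maximum depth reached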
+ def go(self, node, depth=0):
+ depth += 1
+ if depth > self.depth:
+ self.depth = depth
+ if isinstance(node, Iterable) and not isinstance(node, str):
+ for subnode in node:
+ self.go(subnode, depth=depth)
+ if hasattr(node, "_fields"):
+ for field in node._fields:
+ self.go(getattr(node, field), depth=depth)
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("bodies")
+ args = parser.parse_args()
+
+ bodies = unpersist(args.bodies)
+
+ depths = []
+
+ for ind, row in bodies.iterrows():
+ body = row.body
+ body_ast = ast.parse(body.strip())
+ de = DepthEstimator()
+ de.go(body_ast)
+ depths.append(de.depth)
+
+ print(f"Average depth: {sum(depths)/len(depths)}")
+ depths = np.array(depths, dtype=np.int32)
+ np.savetxt(os.path.join(os.path.dirname(args.bodies), "bodies_depths.txt"), depths, "%d")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py
index 14cf586d..7a5e4ad7 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py
@@ -423,6 +423,7 @@ def make_reverse_edge(edge):
rev_edge['type'] = edge['type'] + "_rev"
rev_edge['src'] = edge['dst']
rev_edge['dst'] = edge['src']
+ rev_edge['offsets'] = None
return rev_edge
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py
index d852163b..7f796ff8 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py
@@ -9,6 +9,7 @@
from SourceCodeTools.code.data.sourcetrail.sourcetrail_add_reverse_edges import add_reverse_edges
from SourceCodeTools.code.data.sourcetrail.sourcetrail_ast_edges import NodeResolver, make_reverse_edge
from SourceCodeTools.code.python_ast import AstGraphGenerator, GNode, PythonSharedNodes
+# from SourceCodeTools.code.python_ast_cf import AstGraphGenerator
from SourceCodeTools.nlp.entity.annotator.annotator_utils import adjust_offsets2
from SourceCodeTools.nlp.entity.annotator.annotator_utils import overlap as range_overlap
from SourceCodeTools.nlp.entity.annotator.annotator_utils import to_offsets, get_cum_lens
@@ -27,6 +28,11 @@ def __init__(self, bpe_tokenizer_path, create_subword_instances, connect_subword
self.connect_subwords = connect_subwords
def replace_mentions_with_subwords(self, edges):
+ """
+ Process edges and tokenize certain node types
+ :param edges: List of edges
+ :return: List of edges, including new edges for subword tokenization
+ """
if self.create_subword_instances:
def produce_subw_edges(subwords, dst):
@@ -201,6 +207,13 @@ def get_node_types(nodes):
return id2type
def match_with_global_nodes(self, nodes, edges):
+ """
+        Match AST nodes that represent functions, classes, and modules with nodes from the global graph.
+        This information is not stored explicitly, so the graph has to be walked to resolve it.
+ :param nodes:
+ :param edges:
+ :return: Mapping from AST nodes to corresponding global nodes
+ """
nodes = pd.DataFrame([node for node in nodes if node["type"] in self.allowed_ast_node_types])
edges = pd.DataFrame([edge for edge in edges if edge["type"] in self.allowed_ast_edge_types])
@@ -216,8 +229,8 @@ def match_with_global_nodes(self, nodes, edges):
def find_global_id(graph, def_id, motif):
"""
- Do function or class global id lookup
- :param graph:
+        Perform function or class global id lookup. Needed because the correspondence is not stored explicitly and has to be recovered by walking the graph.
+ :param graph: nx graph
:param def_id:
:param motif: list with edge types, return None if path does not exist
:return:
@@ -309,36 +322,42 @@ def merge_global_references(global_references, module_global_references):
class ReplacementNodeResolver(NodeResolver):
- def __init__(self, nodes):
-
- self.nodeid2name = dict(zip(nodes['id'].tolist(), nodes['serialized_name'].tolist()))
- self.nodeid2type = dict(zip(nodes['id'].tolist(), nodes['type'].tolist()))
+ def __init__(self, nodes=None):
+
+ if nodes is not None:
+ self.nodeid2name = dict(zip(nodes['id'].tolist(), nodes['serialized_name'].tolist()))
+ self.nodeid2type = dict(zip(nodes['id'].tolist(), nodes['type'].tolist()))
+ self.valid_new_node = nodes['id'].max() + 1
+ self.old_nodes = nodes.copy()
+ self.old_nodes['mentioned_in'] = pd.NA
+ self.old_nodes['string'] = pd.NA
+ self.old_nodes = self.old_nodes.astype({'mentioned_in': 'Int32'})
+ self.old_nodes = self.old_nodes.astype({'string': 'string'})
+ else:
+ self.nodeid2name = dict()
+ self.nodeid2type = dict()
+ self.valid_new_node = 0
+ self.old_nodes = pd.DataFrame()
- self.valid_new_node = nodes['id'].max() + 1
self.node_ids = {}
self.new_nodes = []
self.stashed_nodes = []
- self.old_nodes = nodes.copy()
- self.old_nodes['mentioned_in'] = pd.NA
- self.old_nodes = self.old_nodes.astype({'mentioned_in': 'Int32'})
-
def stash_new_nodes(self):
+ """
+ Put new nodes into temporary storage.
+ :return: Nothing
+ """
self.stashed_nodes.extend(self.new_nodes)
self.new_nodes = []
- def resolve_substrings(self, node, replacement2srctrl):
-
- decorated = "@" in node.name
- assert not decorated
-
- name_ = copy(node.name)
-
+ def recover_original_string(self, name_, replacement2srctrl):
replacements = dict()
global_node_id = []
global_name = []
global_type = []
- for name in re.finditer("srctrlrpl_[0-9]+", name_):
+
+ for name in re.finditer("srctrlrpl_[0-9]{19}", name_):
if isinstance(name, re.Match):
name = name.group()
elif isinstance(name, str):
@@ -363,9 +382,25 @@ def resolve_substrings(self, node, replacement2srctrl):
real_name = name_
for r, v in replacements.items():
real_name = real_name.replace(r, v["name"])
+ return real_name, global_name, global_node_id, global_type
+
+ def resolve_substrings(self, node, replacement2srctrl):
+
+ decorated = "@" in node.name
+
+ name_ = copy(node.name)
+ if decorated:
+ name_, decorator = name_.split("@")
+ # assert not decorated
+
+ real_name, global_name, global_node_id, global_type = self.recover_original_string(name_, replacement2srctrl)
+
+ if decorated:
+ real_name = real_name + "@" + decorator
return GNode(
- name=real_name, type=node.type, global_name=global_name, global_id=global_node_id, global_type=global_type
+ name=real_name, type=node.type, global_name=global_name, global_id=global_node_id, global_type=global_type,
+ string=node.string
)
def resolve_regular_replacement(self, node, replacement2srctrl):
@@ -408,28 +443,49 @@ def resolve_regular_replacement(self, node, replacement2srctrl):
type_ = node.type
if hasattr(node, "scope"):
new_node = GNode(name=real_name, type=type_, global_name=global_name, global_id=global_node_id,
- global_type=global_type, scope=node.scope)
+ global_type=global_type, scope=node.scope, string=node.string)
else:
new_node = GNode(name=real_name, type=type_, global_name=global_name, global_id=global_node_id,
- global_type=global_type)
+ global_type=global_type, string=node.string)
else:
new_node = node
return new_node
def resolve(self, node, replacement2srctrl):
+ """
+        :param node: GNode whose name may contain replacement tokens
+        :param replacement2srctrl: dictionary mapping replacement tokens to the original sourcetrail node information
+        :return: node with the original name restored
+ """
+ # this function is defined in this class instead of SourceTrail resolver because it needs access to
+ # global node ids
if node.type == "type_annotation":
new_node = self.resolve_substrings(node, replacement2srctrl)
+ if "srctrlrpl_" in new_node.name:
+ new_node.name = "Any"
else:
new_node = self.resolve_regular_replacement(node, replacement2srctrl)
- if "srctrlrpl_" in new_node.name: # hack to process imports
+ if "srctrlrpl_" in new_node.name: # hack to process imports TODO maybe need to use while to resolve everything?
new_node = self.resolve_substrings(node, replacement2srctrl)
+ if "srctrlrpl_" in new_node.name: # if still failed to resolve
+ new_node.name = "unresolved_name"
+
+ if node.string is not None:
+ string,_,_,_ = self.recover_original_string(node.string, replacement2srctrl)
+ node.string = string
- assert "srctrlrpl_" not in new_node.name
+ # assert "srctrlrpl_" not in new_node.name
return new_node
def resolve_node_id(self, node, **kwargs):
+ """
+        Resolve node id from name and type; create a new node if no matching node is found.
+        :param node: node
+        :param kwargs:
+        :return: updated node (the returned object is the same reference as the input)
+ """
if not hasattr(node, "id"):
node_repr = (node.name.strip(), node.type.strip())
@@ -440,7 +496,7 @@ def resolve_node_id(self, node, **kwargs):
new_id = self.get_new_node_id()
self.node_ids[node_repr] = new_id
- if not PythonSharedNodes.is_shared(node):
+ if not PythonSharedNodes.is_shared(node) and not node.name == "unresolved_name":
assert "0x" in node.name
self.new_nodes.append(
@@ -448,7 +504,8 @@ def resolve_node_id(self, node, **kwargs):
"id": new_id,
"type": node.type,
"serialized_name": node.name,
- "mentioned_in": pd.NA
+ "mentioned_in": pd.NA,
+ "string": node.string
}
)
if hasattr(node, "scope"):
@@ -459,15 +516,19 @@ def resolve_node_id(self, node, **kwargs):
def prepare_for_write(self, from_stashed=False):
nodes = pd.concat([self.old_nodes, self.new_nodes_for_write(from_stashed)])[
- ['id', 'type', 'serialized_name', 'mentioned_in']
+ ['id', 'type', 'serialized_name', 'mentioned_in', 'string']
]
return nodes
def new_nodes_for_write(self, from_stashed=False):
- new_nodes = pd.DataFrame(self.new_nodes if not from_stashed else self.stashed_nodes)[
- ['id', 'type', 'serialized_name', 'mentioned_in']
+ new_nodes = pd.DataFrame(self.new_nodes if not from_stashed else self.stashed_nodes)
+ if len(new_nodes) == 0:
+ return None
+
+ new_nodes = new_nodes[
+ ['id', 'type', 'serialized_name', 'mentioned_in', 'string']
].astype({"mentioned_in": "Int32"})
return new_nodes
@@ -609,6 +670,11 @@ def into_offset(range):
class SourcetrailResolver:
+ """
+    Helper class to work with source code stored in a Sourcetrail database. Provides functionality for:
+ - Iterating over files
+ - Preserving Sourcetrail nodes
+ """
def __init__(self, nodes, edges, source_location, occurrence, file_content, lang):
self.nodes = nodes
self.node2name = dict(zip(nodes['id'], nodes['serialized_name']))
@@ -622,6 +688,9 @@ def __init__(self, nodes, edges, source_location, occurrence, file_content, lang
@property
def occurrence_groups(self):
+ """
+ :return: Iterator for occurrences grouped by file id.
+ """
if self._occurrence_groups is None:
self._occurrence_groups = get_occurrence_groups(self.nodes, self.edges, self.source_location, self.occurrence)
return self._occurrence_groups
@@ -709,11 +778,18 @@ def add_names_from_edges(self, edges):
def to_df(self):
return pd.DataFrame(self.group2names)
+
def global_mention_edges_from_node(node):
+ """
+ Construct a new edge that will link to a global node.
+ :param node:
+ :return:
+ """
global_edges = []
if type(node.global_id) is int:
id_type = [(node.global_id, node.global_type)]
else:
+ # in case there are many global references
id_type = zip(node.global_id, node.global_type)
for gid, gtype in id_type:
@@ -727,7 +803,13 @@ def global_mention_edges_from_node(node):
global_edges.append(make_reverse_edge(global_mention))
return global_edges
+
def add_global_mentions(edges):
+ """
+ :param edges: List of dictionaries that represent edges
+ :return: Original edges with additional edges for global references (e.g. global_mention edge
+ for function calls)
+ """
new_edges = []
for edge in edges:
if edge['src'].type in {"#attr#", "Name"}:
@@ -777,6 +859,13 @@ def produce_nodes_without_name(global_nodes, ast_edges):
def standardize_new_edges(edges, node_resolver, mention_tokenizer):
+ """
+ Tokenize relevant node names, assign id to every node, collapse edge representation to id-based
+ :param edges: list of edges
+ :param node_resolver: helper class that tracks node ids
+ :param mention_tokenizer: helper class that performs tokenization of relevant nodes
+ :return:
+ """
edges = mention_tokenizer.replace_mentions_with_subwords(edges)
@@ -801,13 +890,28 @@ def standardize_new_edges(edges, node_resolver, mention_tokenizer):
return edges
-def process_code(source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, named_group_tracker, track_offsets=False):
+def process_code(source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, track_offsets=False):
+ """
+
+    :param source_file_content: String with the content of a file from the Sourcetrail database
+ :param offsets: List of global occurrences in this file
+ :param node_resolver:
+ :param mention_tokenizer:
+ :param node_matcher:
+ :param track_offsets: Flag that tells whether to perform offset tracking.
+    :return: Tuple that stores edges, a list of global and AST offsets for the current file in the format
+    (start, end, node_id, set(offsets of functions where the current offset occurs)),
+    and a mapping from AST nodes to corresponding global nodes (used to replace one with the other).
+ """
# replace global occurrences with special tokens to help further parsing with ast package
replacer = OccurrenceReplacer()
replacer.perform_replacements(source_file_content, offsets)
# compute ast edges
- ast_processor = AstProcessor(replacer.source_with_replacements)
+ try:
+ ast_processor = AstProcessor(replacer.source_with_replacements)
+ except:
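+        # the file could not be parsed into an AST; skip it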
+ return None, None, None
try: # TODO recursion error does not appear consistently. The issue is probably with library versions...
edges = ast_processor.get_edges(as_dataframe=False)
except RecursionError:
@@ -816,7 +920,7 @@ def process_code(source_file_content, offsets, node_resolver, mention_tokenizer,
if len(edges) == 0:
return None, None, None
- # resolve existing node names (primarily for subwords)
+    # resolve sourcetrail nodes and replace them with the original names
resolve = lambda node: node_resolver.resolve(node, replacer.replacement_index)
for edge in edges:
@@ -828,23 +932,32 @@ def process_code(source_file_content, offsets, node_resolver, mention_tokenizer,
# insert global mentions using replacements that were created on the first step
edges = add_global_mentions(edges)
- named_group_tracker.add_names_from_edges(edges)
+ # It seems I do not need this feature
+ # named_group_tracker.add_names_from_edges(edges)
+ # tokenize names, replace nodes with their ids
edges = standardize_new_edges(edges, node_resolver, mention_tokenizer)
if track_offsets:
def get_valid_offsets(edges):
+ """
+            :param edges: List of edge dictionaries. Offset information is stored on edges but relates to the source node
+            :return: Information about the location of each edge (offset) in the source file in the format (start, end, node_id)
+ """
return [(edge["offsets"][0], edge["offsets"][1], edge["src"]) for edge in edges if edge["offsets"] is not None]
+ # recover ast offsets for the current file
ast_offsets = replacer.recover_offsets_with_edits2(get_valid_offsets(edges))
def merge_global_and_ast_offsets(ast_offsets, global_offsets, definitions):
"""
- Merge local and global offsets and add information about the scope
- :param ast_offsets:
- :param global_offsets:
- :param definitions:
- :return:
+ Merge local and global offsets and add information about the scope. Preserve the information that
+ indicates a function where the occurrence takes place.
+ :param ast_offsets: List of offsets in format (start, end, node_id)
+ :param global_offsets: List of offsets in format (start, end, node_id)
+ :param definitions: offsets for function declarations
+ :return: List of offsets in format [start, end, node_id, set(function_offset)]. Each offset can belong to
+ several functions in the case of nested functions.
"""
offsets = [[*offset, set()] for offset in ast_offsets]
offsets = offsets + [[*offset, set()] for offset in global_offsets]
@@ -864,6 +977,7 @@ def merge_global_and_ast_offsets(ast_offsets, global_offsets, definitions):
else:
global_and_ast_offsets = None
+ # Get mapping from AST nodes to global nodes
ast_nodes_to_srctrl_nodes = node_matcher.match_with_global_nodes(node_resolver.new_nodes, edges)
return edges, global_and_ast_offsets, ast_nodes_to_srctrl_nodes
@@ -873,12 +987,27 @@ def get_ast_from_modules(
nodes, edges, source_location, occurrence, file_content,
bpe_tokenizer_path, create_subword_instances, connect_subwords, lang, track_offsets=False
):
-
+ """
+    Create edges from source code and merge them with the global graph. Prepare all offsets in a uniform format.
+ :param nodes: DataFrame with nodes
+ :param edges: DataFrame with edges
+ :param source_location: Dataframe that links nodes to source files
+ :param occurrence: Dataframe with records about occurrences
+ :param file_content: Dataframe with sources
+ :param bpe_tokenizer_path: path to sentencepiece model
+ :param create_subword_instances:
+ :param connect_subwords:
+ :param lang:
+ :param track_offsets:
+ :return: Tuple:
+ - Dataframe with all nodes. Schema: id, type, name, mentioned_in (global and AST)
+ - Dataframe with all edges. Schema: id, type, src, dst, file_id, mentioned_in
+ - Dataframe with all_offsets. Schema: file_id, start, end, node_id, mentioned_in
+ """
srctrl_resolver = SourcetrailResolver(nodes, edges, source_location, occurrence, file_content, lang)
node_resolver = ReplacementNodeResolver(nodes)
node_matcher = GlobalNodeMatcher(nodes, edges) # add_reverse_edges(edges))
mention_tokenizer = MentionTokenizer(bpe_tokenizer_path, create_subword_instances, connect_subwords)
- name_group_tracker = NameGroupTracker()
all_ast_edges = []
all_global_references = {}
all_offsets = []
@@ -898,7 +1027,7 @@ def get_ast_from_modules(
# process code
# try:
edges, global_and_ast_offsets, ast_nodes_to_srctrl_nodes = process_code(
- source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, name_group_tracker, track_offsets=track_offsets
+ source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, track_offsets=track_offsets
)
# except SyntaxError:
# logging.warning(f"Error processing file_id {file_id}")
@@ -918,6 +1047,12 @@ def get_ast_from_modules(
node_matcher.merge_global_references(all_global_references, ast_nodes_to_srctrl_nodes)
def format_offsets(global_and_ast_offsets, target):
+ """
+ Format offset as a record and add to the common storage for offsets
+ :param global_and_ast_offsets:
+ :param target: List where all other offsets are stored.
+ :return: Nothing
+ """
if global_and_ast_offsets is not None:
for offset in global_and_ast_offsets:
target.append({
@@ -939,7 +1074,7 @@ def replace_ast_node_to_global(edges, mapping):
if "scope" in edge:
edge["scope"] = mapping.get(edge["scope"], edge["scope"])
- replace_ast_node_to_global(all_ast_edges, all_global_references)
+ # replace_ast_node_to_global(all_ast_edges, all_global_references) # disabled to keep global nodes separate
def create_subwords_for_global_nodes():
all_ast_edges.extend(
@@ -947,7 +1082,7 @@ def create_subwords_for_global_nodes():
node_resolver, mention_tokenizer))
node_resolver.stash_new_nodes()
- create_subwords_for_global_nodes()
+ # create_subwords_for_global_nodes() # disabled to keep global nodes separate
def prepare_new_nodes(node_resolver):
@@ -960,11 +1095,20 @@ def prepare_new_nodes(node_resolver):
}, from_stashed=True
)
node_resolver.map_mentioned_in_to_global(all_global_references, from_stashed=True)
+ # TODO since some function definitions and modules are dropped, need to rename decorated nodes as well
+ # find old name in node_resolver.stashed nodes, go over all nodes and replace corresponding substrings in names
node_resolver.drop_nodes(set(all_global_references.keys()), from_stashed=True)
- prepare_new_nodes(node_resolver)
+ for offset in all_offsets:
+ offset["node_id"] = all_global_references.get(offset["node_id"], offset["node_id"])
+ offset["mentioned_in"] = [(e[0], e[1], all_global_references.get(e[2], e[2])) for e in offset["mentioned_in"]]
+
+
+ # prepare_new_nodes(node_resolver) # disabled to keep global nodes separate
all_ast_nodes = node_resolver.new_nodes_for_write(from_stashed=True)
+ if all_ast_nodes is None:
+ return None, None, None
def prepare_edges(all_ast_edges):
all_ast_edges = pd.DataFrame(all_ast_edges)
@@ -983,7 +1127,7 @@ def prepare_edges(all_ast_edges):
else:
all_offsets = None
- return all_ast_nodes, all_ast_edges, all_offsets, name_group_tracker.to_df()
+ return all_ast_nodes, all_ast_edges, all_offsets
class OccurrenceReplacer:
@@ -1190,7 +1334,7 @@ def recover_offsets_with_edits2(self, offsets):
edges = read_edges(working_directory)
file_content = read_filecontent(working_directory)
- ast_nodes, ast_edges, offsets, ng = get_ast_from_modules(nodes, edges, source_location, occurrence, file_content,
+ ast_nodes, ast_edges, offsets = get_ast_from_modules(nodes, edges, source_location, occurrence, file_content,
args.bpe_tokenizer, args.create_subword_instances,
args.connect_subwords, args.lang)
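The hunks above leave `get_ast_from_modules` keeping global nodes separate and remapping AST node ids in the collected offsets through `all_global_references`. The sketch below is an editor's illustration of that record layout and remapping with made-up ids; it is not code from the patch.

```python
# Illustrative only: an offset record with the schema file_id, start, end,
# node_id, mentioned_in, and the same id substitution that get_ast_from_modules
# applies after global references have been merged. Ids are invented.
ast_to_global = {1001: 17, 1002: 42}  # stands in for all_global_references

offset = {
    "file_id": 3,
    "start": 120,
    "end": 127,
    "node_id": 1001,                     # AST-level node id
    "mentioned_in": [(100, 250, 1002)],  # enclosing mention spans with their node ids
}

offset["node_id"] = ast_to_global.get(offset["node_id"], offset["node_id"])
offset["mentioned_in"] = [
    (s, e, ast_to_global.get(nid, nid)) for s, e, nid in offset["mentioned_in"]
]

print(offset)
# {'file_id': 3, 'start': 120, 'end': 127, 'node_id': 17, 'mentioned_in': [(100, 250, 42)]}
```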
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py
index 731b93ef..63d50e43 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py
@@ -2,6 +2,7 @@
import json
import logging
import os
+from os.path import join
import pandas as pd
from tqdm import tqdm
@@ -82,19 +83,34 @@ def correct_entities(entities, removed_offsets):
for offset_len, offset in zip(offset_lens, offsets_sorted):
new_entities = []
for entity in for_correction:
- if offset[0] <= entity[0] and offset[1] <= entity[0]:
+            if offset[0] <= entity[0] and offset[1] <= entity[0]:  # removed span is to the left of the entity
if len(entity) == 2:
new_entities.append((entity[0] - offset_len, entity[1] - offset_len))
elif len(entity) == 3:
new_entities.append((entity[0] - offset_len, entity[1] - offset_len, entity[2]))
else:
raise Exception("Invalid entity size")
- elif offset[0] >= entity[1] and offset[1] >= entity[1]:
+            elif offset[0] >= entity[1] and offset[1] >= entity[1]:  # removed span is to the right of the entity
new_entities.append(entity)
- elif offset[0] <= entity[1] <= offset[1] or offset[0] <= entity[0] <= offset[1]:
- pass # likely to be a type annotation being removed
+ elif offset[0] <= entity[0] <= offset[1] and offset[0] <= entity[1] <= offset[1]: # removed span covers the entity
+ pass
+ elif offset[0] <= entity[0] <= offset[1] and entity[0] <= offset[1] <= entity[1]: # removed span overlaps on the left
+ if len(entity) == 3:
+ new_entities.append((entity[0] - offset_len + offset[1] - entity[1], entity[1] - offset_len, entity[2]))
+ elif len(entity) == 2:
+ new_entities.append((entity[0] - offset_len + offset[1] - entity[1], entity[1] - offset_len))
+ else:
+ raise Exception("Invalid entity size")
+ elif entity[0] <= offset[0] <= entity[1] and entity[0] <= entity[1] <= offset[1]: # removed span overlaps on the right
+ if len(entity) == 3:
+ new_entities.append((entity[0], entity[1] - offset_len + offset[1] - entity[1], entity[2]))
+ elif len(entity) == 2:
+ new_entities.append((entity[0], entity[1] - offset_len + offset[1] - entity[1]))
+ else:
+ raise Exception("Invalid entity size")
else:
- raise Exception("Invalid data?")
+ logging.warning(f"Encountered invalid offset: {entity}")
+ # raise Exception("Invalid data?")
for_correction = new_entities
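The extended overlap handling above boils down to shifting or clipping entity spans after a piece of text has been cut out. A self-contained toy example of the simplest case, an entity located to the right of the removed span (illustrative strings only, not the repository function):

```python
# Toy illustration of the offset correction: entities to the right of a removed
# span shift left by the length of the removal.
text = "def f(x: int) -> int: return x"
removed = (7, 12)                      # the ": int" annotation of the argument
removed_len = removed[1] - removed[0]  # 5 characters

entity = (29, 30)                      # the final "x", located right of the removal

stripped = text[:removed[0]] + text[removed[1]:]
shifted = (entity[0] - removed_len, entity[1] - removed_len)

assert text[entity[0]:entity[1]] == stripped[shifted[0]:shifted[1]] == "x"
print(stripped)   # def f(x) -> int: return x
print(shifted)    # (24, 25)
```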
@@ -160,6 +176,9 @@ def unpack_returns(body: str, labels: pd.DataFrame):
:param labels: DataFrame with information about return type annotation
:return: Trimmed body and list of return types (normally one).
"""
+ if labels is None:
+ return [], []
+
returns = []
for ind, row in labels.iterrows():
@@ -193,6 +212,22 @@ def unpack_returns(body: str, labels: pd.DataFrame):
return ret, cuts
+def get_defaults_spans(body):
+ root = ast.parse(body)
+ defaults_offsets = to_offsets(
+ body,
+ [(arg.lineno-1, arg.end_lineno-1, arg.col_offset, arg.end_col_offset, "default") for arg in root.body[0].args.defaults],
+ as_bytes=True
+ )
+
+ extended = []
+ for start, end, label in defaults_offsets:
+ while body[start] != "=":
+ start -= 1
+ extended.append((start, end))
+ return extended
+
+
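`get_defaults_spans` above relies on `ast` to locate default values and then walks left to the `=` sign so the whole `= <default>` region can be cut. Below is a standalone sketch of the same idea, restricted to a single-line ASCII signature so character and byte offsets coincide (the real code converts offsets with `to_offsets`):

```python
# Minimal sketch, assuming a one-line ASCII signature: locate each default with
# ast and extend the span left until the "=" sign is included.
import ast

def default_spans_single_line(src: str):
    args = ast.parse(src).body[0].args
    spans = []
    for default in args.defaults:
        start, end = default.col_offset, default.end_col_offset
        while src[start] != "=":   # walk left to the assignment sign
            start -= 1
        spans.append((start, end))
    return spans

print(default_spans_single_line("def f(a: int = 5, b='x'): pass"))
# [(13, 16), (19, 23)]  -> the "= 5" and "='x'" regions
```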
def unpack_annotations(body, labels):
"""
     Use information from the ast package to strip type annotations from the function body
@@ -200,6 +235,11 @@ def unpack_annotations(body, labels):
:param labels: DataFrame with information about type annotations
:return: Trimmed body and list of annotations.
"""
+ if labels is None:
+ return [], []
+
+ global remove_default
+
variables = []
annotations = []
@@ -215,6 +255,7 @@ def unpack_annotations(body, labels):
 # but type annotations usually appear at the end of a signature and at the beginning of a line
variables = to_offsets(body, variables, as_bytes=True)
annotations = to_offsets(body, annotations, as_bytes=True)
+ defaults_spans = get_defaults_spans(body)
cuts = []
vars = []
@@ -236,9 +277,20 @@ def unpack_annotations(body, labels):
assert offset_var[0] != len(head)
vars.append((offset_var[0], beginning, preprocess(body[offset_ann[0]: offset_ann[1]])))
+ if remove_default:
+ cuts.extend(defaults_spans)
+
return vars, cuts
+def body_valid(body):
+ try:
+ ast.parse(body)
+ return True
+ except:
+ return False
+
+
def process_body(nlp, body: str, replacements=None):
"""
Extract annotation information, strip documentation and type annotations.
@@ -267,16 +319,20 @@ def process_body(nlp, body: str, replacements=None):
body_, replacements, docstrings = remove_offsets(body_, replacements, docsting_offsets)
entry['docstrings'].extend(docstrings)
+ was_valid = body_valid(body_)
initial_labels = get_initial_labels(body_)
- if initial_labels is None:
- return None
+ # if initial_labels is None:
+ # return None
returns, return_cuts = unpack_returns(body_, initial_labels)
annotations, annotation_cuts = unpack_annotations(body_, initial_labels)
body_, replacements_annotations, _ = remove_offsets(body_, replacements + annotations,
return_cuts + annotation_cuts)
+ is_valid = body_valid(body_)
+ if was_valid != is_valid:
+ raise Exception()
replacements_annotations = adjust_offsets2(replacements_annotations, len(initial_strip))
body_ = initial_strip + body_
@@ -439,11 +495,86 @@ def load_local2global(working_directory):
return data
-def main():
+def iterate_functions(offsets, nodes, filecontent):
+
+ allowed_entity_types = {"class_method", "function"}
+
+ for package_id in offsets:
+ content = filecontent[package_id]
+
+ # entry is a function or a class
+ for (entity_start, entity_end, entity_node_id), entity_offsets in offsets[package_id].items():
+ if nodes[entity_node_id][1] in allowed_entity_types:
+ body = content[entity_start: entity_end]
+ adjusted_entity_offsets = adjust_offsets2(entity_offsets, -entity_start)
+
+ yield body, adjusted_entity_offsets
+
+
+def get_node_maps(nodes):
+ return dict(zip(nodes["id"], zip(nodes["serialized_name"], nodes["type"])))
+
+
+def get_filecontent_maps(filecontent):
+ return dict(zip(zip(filecontent["package"], filecontent["id"]), filecontent["content"]))
+
+
+def group_offsets(offsets):
+ """
+ :param offsets: Dataframe with offsets
+ :return: offsets grouped first by package name and file id, and then by the entity in which they occur.
+ """
+    # This function processes all functions that have graph annotations. Functions without
+    # annotations are not processed.
+ offsets_grouped = {}
+
+ for file_id, start, end, node_id, mentioned_in, package in offsets.values:
+ package_id = (package, file_id)
+ if package_id not in offsets_grouped:
+ offsets_grouped[package_id] = {}
+
+ offset_ent = (start, end, node_id)
+
+ for e in mentioned_in:
+ if e not in offsets_grouped[package_id]:
+ offsets_grouped[package_id][e] = []
+
+ offsets_grouped[package_id][e].append(offset_ent)
+
+ return offsets_grouped
+
+
+def create_from_dataset():
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("dataset_path", type=str, help="")
+ parser.add_argument("output_path", type=str, help="")
+ parser.add_argument("--format", "-f", dest="format", default="jsonl", help="jsonl|csv")
+ parser.add_argument("--remove_default", action="store_true", default=False)
+
+ args = parser.parse_args()
+
+ global remove_default
+ remove_default = args.remove_default
+
+ node_maps = get_node_maps(unpersist(join(args.dataset_path, "common_nodes.bz2")))
+ filecontent = get_filecontent_maps(unpersist(join(args.dataset_path, "common_filecontent.bz2")))
+ offsets = group_offsets(unpersist(join(args.dataset_path, "common_offsets.bz2")))
+
+ data = []
+ nlp = create_tokenizer("spacy")
+
+ for ind, (f_body, f_offsets) in enumerate(iterate_functions(offsets, node_maps, filecontent)):
+ data.append(process_body(nlp, f_body, replacements=f_offsets))
+
+ store(data, args)
+
+
+def create_from_environments():
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("packages", type=str, help="")
- parser.add_argument("output_dataset", type=str, help="")
+ parser.add_argument("output_path", type=str, help="")
parser.add_argument("--format", "-f", dest="format", default="jsonl", help="jsonl|csv")
parser.add_argument("--global_nodes", "-g", dest="global_nodes", default=None)
@@ -460,17 +591,23 @@ def main():
data.extend(process_package(working_directory=pkg_path, global_names=global_names))
+ store(data, args)
+
+
+def store(data, args):
if args.format == "jsonl": # jsonl format is used by spacy
- with open(args.output_dataset, "w") as sink:
+ with open(args.output_path, "w") as sink:
for entry in data:
sink.write(f"{json.dumps(entry)}\n")
elif args.format == "csv":
- if os.path.isfile(args.output_dataset):
+ if os.path.isfile(args.output_path):
header = False
else:
header = True
- pd.DataFrame(data).to_csv(args.output_dataset, index=False, header=header)
+ pd.DataFrame(data).to_csv(args.output_path, index=False, header=header)
if __name__ == "__main__":
- main()
+ # create_from_environments()
+ remove_default = False
+ create_from_dataset()
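`group_offsets` and `iterate_functions` above assume a nested structure: offsets keyed by (package, file id) and then by the enclosing entity span. A hypothetical example of that structure and of the shift applied before yielding a function body (made-up numbers, not repository data):

```python
# Hypothetical grouped-offsets structure consumed by iterate_functions.
offsets_grouped = {
    ("pkg_a", 3): {
        (100, 400, 7): [        # an entity spanning positions 100..400, node id 7
            (120, 127, 11),     # an occurrence of node 11 inside that entity
            (250, 260, 12),
        ],
    },
}

# Offsets are shifted so they become relative to the extracted entity body.
entity_start = 100
relative = [(s - entity_start, e - entity_start, nid)
            for s, e, nid in offsets_grouped[("pkg_a", 3)][(100, 400, 7)]]
print(relative)   # [(20, 27, 11), (150, 160, 12)]
```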
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py
index bbf490bd..6a20e114 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py
@@ -20,6 +20,22 @@ def extract_node_names(nodes, min_count):
data['src'] = data['id']
data['dst'] = data['serialized_name'].apply(get_node_name)
+ corrected_names = []
+ for type_, name_ in data[["type", "dst"]].values:
+ if type_ == "mention":
+ corrected_names.append(name_.split("@")[0])
+ else:
+ corrected_names.append(name_)
+
+ data["dst"] = corrected_names
+
+ def not_contains(name):
+ return "0x" not in name
+
+ data = data[
+ data["dst"].apply(not_contains)
+ ]
+
counts = data['dst'].value_counts()
data['counts'] = data['dst'].apply(lambda x: counts[x])
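The normalization added above strips the scope part of mention names and drops synthetic names. A toy demonstration with invented names:

```python
# Toy data: mention names have the form "<name>@<scope>"; only the part before
# "@" is kept, and synthetic names containing a hex suffix ("0x...") are dropped.
import pandas as pd

data = pd.DataFrame({
    "type": ["mention", "mention", "function"],
    "dst":  ["counter@my_pkg.loop", "Expression_0x5f3a", "my_pkg.run"],
})

data["dst"] = [
    name.split("@")[0] if type_ == "mention" else name
    for type_, name in data[["type", "dst"]].values
]
data = data[data["dst"].apply(lambda name: "0x" not in name)]

print(data["dst"].tolist())   # ['counter', 'my_pkg.run']
```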
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_type_edges.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_type_edges.py
new file mode 100644
index 00000000..daae3d7f
--- /dev/null
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_type_edges.py
@@ -0,0 +1,51 @@
+import sys
+from os.path import join
+
+from SourceCodeTools.code.data.sourcetrail.file_utils import *
+
+
+def filter_type_edges(nodes, edges, keep_proportion=0.0):
+ annotations = edges.query(f"type == 'annotation_for' or type == 'returned_by' or type == 'annotation_for_rev' or type == 'returned_by_rev'")
+ no_annotations = edges.query(f"type != 'annotation_for' and type != 'returned_by' and type != 'annotation_for_rev' and type != 'returned_by_rev'")
+
+ to_keep = int(len(annotations) * keep_proportion)
+ if to_keep == 0:
+ annotations_removed = annotations
+ annotations_kept = None
+ elif to_keep == len(annotations):
+ annotations_removed = None
+ annotations_kept = annotations
+ else:
+ annotations = annotations.sample(frac=1.)
+ annotations_kept, annotations_removed = annotations.iloc[:to_keep], annotations.iloc[to_keep:]
+
+ if annotations_kept is not None:
+ no_annotations = no_annotations.append(annotations_kept)
+
+ annotations = annotations_removed
+ if annotations is not None:
+ node2name = dict(zip(nodes["id"], nodes["serialized_name"]))
+ get_name = lambda id_: node2name[id_]
+ annotations["source_node_id"] = annotations["source_node_id"].apply(get_name)
+ # rename columns to use as a dataset
+ annotations.rename({"source_node_id": "dst", "target_node_id": "src"}, axis=1, inplace=True)
+ annotations = annotations[["src","dst"]]
+
+ return no_annotations, annotations
+
+
+def main():
+ working_directory = sys.argv[1]
+ nodes = unpersist(join(working_directory, "nodes.bz2"))
+ edges = unpersist(join(working_directory, "edges.bz2"))
+ out_annotations = sys.argv[2]
+ out_no_annotations = sys.argv[3]
+
+ no_annotations, annotations = filter_type_edges(nodes, edges)
+
+ persist(annotations, out_annotations)
+ persist(no_annotations, out_no_annotations)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
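The sketch below mirrors what the new `filter_type_edges` helper does on toy data (pandas only, invented frames); it does not import the module and is only an illustration of the split with `keep_proportion=0`:

```python
# Toy mirror of filter_type_edges with keep_proportion=0: all annotation edges
# are removed from the graph and returned separately as (src, dst) labels.
import pandas as pd

nodes = pd.DataFrame({"id": [1, 2, 3], "serialized_name": ["f.x", "int", "f"]})
edges = pd.DataFrame({
    "type": ["annotation_for", "calls"],
    "source_node_id": [2, 3],
    "target_node_id": [1, 1],
})

annotations = edges.query("type == 'annotation_for' or type == 'returned_by'")
no_annotations = edges.query("type != 'annotation_for' and type != 'returned_by'")

node2name = dict(zip(nodes["id"], nodes["serialized_name"]))
annotations = annotations.assign(
    src=annotations["target_node_id"],                 # the annotated node
    dst=annotations["source_node_id"].map(node2name),  # the type name becomes the label
)[["src", "dst"]]

print(no_annotations["type"].tolist())   # ['calls']
print(annotations.to_dict("records"))    # [{'src': 1, 'dst': 'int'}]
```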
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_graph_complexity_analysis.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_graph_complexity_analysis.py
new file mode 100644
index 00000000..461d50aa
--- /dev/null
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_graph_complexity_analysis.py
@@ -0,0 +1,142 @@
+import dgl
+import pandas as pd
+import argparse
+import ast
+import numpy as np
+
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+
+from SourceCodeTools.code.data.sourcetrail.file_utils import read_processed_bodies, unpersist
+from SourceCodeTools.code.data.sourcetrail.sourcetrail_ast_edges2 import AstProcessor, standardize_new_edges, \
+ ReplacementNodeResolver, MentionTokenizer
+from SourceCodeTools.code.data.sourcetrail.sourcetrail_compute_function_diameter import compute_diameter
+from SourceCodeTools.code.data.sourcetrail.sourcetrail_parse_bodies2 import has_valid_syntax
+from SourceCodeTools.code.python_ast import AstGraphGenerator
+from SourceCodeTools.nlp import create_tokenizer
+
+
+def process_code(source_file_content, node_resolver, mention_tokenizer):
+ ast_processor = AstProcessor(source_file_content)
+ try: # TODO recursion error does not appear consistently. The issue is probably with library versions...
+ edges = ast_processor.get_edges(as_dataframe=False)
+ except RecursionError:
+ return None
+
+ if len(edges) == 0:
+ return None
+
+ edges = standardize_new_edges(edges, node_resolver, mention_tokenizer)
+
+ return edges
+
+
+def compute_gnn_passings(body, mention_tokenizer):
+
+ node_resolver = ReplacementNodeResolver()
+
+ source_file_content = body.lstrip()
+
+ edges = process_code(
+ source_file_content, node_resolver, mention_tokenizer
+ )
+
+ if edges is None:
+ return None
+
+ edges = pd.DataFrame(edges).rename({"src": "source_node_id", "dst": "target_node_id"}, axis=1)
+ diameter = compute_diameter(edges, func_id=0)
+
+ G = dgl.DGLGraph()
+ G.add_edges(edges["source_node_id"], edges["target_node_id"])
+
+ sampler = dgl.dataloading.MultiLayerFullNeighborSampler(diameter)
+
+ loader = dgl.dataloading.NodeDataLoader(
+ G, G.nodes(), sampler, batch_size=1, shuffle=True, num_workers=0)
+ for input_nodes, seeds, blocks in loader:
+ num_edges = 0
+ for block in blocks:
+ num_edges += block.num_edges()
+ if num_edges != 0:
+ break
+
+ return num_edges
+
+def compute_transformer_passings(body, bpe):
+ num_tokens = len(bpe(body))
+ return num_tokens
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("bodies")
+ parser.add_argument("bpe_path")
+ parser.add_argument("--num_layers", default=8, type=int)
+
+ args = parser.parse_args()
+
+ bodies = unpersist(args.bodies)
+ bpe = create_tokenizer(type="bpe", bpe_path=args.bpe_path)
+ mention_tokenizer = MentionTokenizer(args.bpe_path, create_subword_instances=True, connect_subwords=False)
+
+ lengths_tr = {}
+ lengths_gnn = {}
+ ratio = []
+
+ for body in tqdm(bodies["body"]):
+ if not has_valid_syntax(body):
+ continue
+
+ n_tokens = compute_transformer_passings(body, bpe)
+        n_edges = compute_gnn_passings(body, mention_tokenizer)
+        if n_edges is None:  # parsing can fail (e.g. RecursionError), skip such bodies
+            continue
+
+ if n_tokens not in lengths_tr:
+ lengths_tr[n_tokens] = []
+ if n_tokens not in lengths_gnn:
+ lengths_gnn[n_tokens] = []
+
+ lengths_tr[n_tokens].append(n_tokens ** 2 * args.num_layers)
+ lengths_gnn[n_tokens].append(n_edges)# * args.num_layers)
+ ratio.append((n_tokens, n_edges))
+
+ for key in lengths_tr:
+ data_tr = np.array(lengths_tr[key])
+ data_gnn = np.array(lengths_gnn[key])
+
+ lengths_tr[key] = np.mean(data_tr)#, np.std(data_tr))
+ lengths_gnn[key] = np.mean(data_gnn)#, np.std(data_gnn))
+
+ data_ratios = np.array(ratio)
+
+ plt.plot(data_ratios[:, 0], data_ratios[:, 1], "*")
+ plt.xlabel("Number of Tokens")
+ plt.ylabel("Number of Edges")
+ plt.savefig("tokens_edges.png")
+ plt.close()
+
+ plt.hist(data_ratios[:, 1] / data_ratios[:, 0], bins=20)
+ plt.xlabel("Number of edges / Number of tokens")
+ plt.savefig("ratio.png")
+ plt.close()
+
+ ratio = data_ratios[:, 1] / data_ratios[:, 0]
+ ratio = (np.mean(ratio), np.std(ratio))
+
+ plt.plot(list(lengths_tr.keys()), np.log10(np.array(list(lengths_tr.values()))), "*")
+ plt.plot(list(lengths_gnn.keys()), np.log10(np.array(list(lengths_gnn.values()))), "*")
+ plt.plot(list(lengths_gnn.keys()), np.log10(np.array(list(lengths_gnn.values())) * args.num_layers), "*")
+ plt.legend([f"Transformer {args.num_layers} layers", "GNN L layers", f"GNN L*{args.num_layers} layers"])
+ plt.xlabel("Number of Tokens")
+ plt.ylabel("log10(Number of Message Exchanges)")
+ plt.savefig("avg_passings.png")
+ plt.close()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
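The analysis script above compares quadratic transformer attention with per-edge GNN message passing. A back-of-the-envelope version of that comparison with invented numbers:

```python
# Illustrative arithmetic only: a transformer layer attends over all token
# pairs, while a GNN layer sends one message per edge.
num_tokens = 120     # hypothetical BPE length of a function body
num_edges = 900      # hypothetical number of AST edges for the same body
num_layers = 8

transformer_passings = num_tokens ** 2 * num_layers   # 115200
gnn_passings = num_edges * num_layers                  # 7200

print(transformer_passings, gnn_passings)
print(f"ratio: {transformer_passings / gnn_passings:.1f}x")   # 16.0x
```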
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py
index 4d68195d..8457d6e8 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py
@@ -43,7 +43,7 @@ def merge_global_with_local(existing_nodes, next_valid_id, local_nodes):
assert len(new_nodes) == len(set(new_nodes['node_repr'].to_list()))
- ids_start = next_valid_id
+ ids_start = int(next_valid_id)
ids_end = ids_start + len(new_nodes)
new_nodes['id'] = list(range(ids_start, ids_end))
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py
index a127ff49..d6143c5b 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py
@@ -60,7 +60,10 @@ def get_function_body(file_content, file_id, start, end, s_col, e_col) -> str:
source_lines = file_content[file_id].split("\n")
- body_lines = source_lines[start: end]
+ if start == end: # handle situations when the entire function takes only one line
+        body_lines = [source_lines[start]]
+ else:
+ body_lines = source_lines[start: end]
initial_strip = body_lines[0][0:len(body_lines[0]) - len(body_lines[0].lstrip())]
body = initial_strip + file_content[file_id][offsets[0][0]: offsets[0][1]]
@@ -100,6 +103,15 @@ def get_range_for_replacement(occurrence, start_col, end_col, line, nodeid2name)
def process_body(body, local_occurrences, nodeid2name, f_id, f_start):
+ """
+    Extract the list of Sourcetrail node offsets for a single function body.
+    :param body: function body as a string
+    :param local_occurrences: occurrence records that fall within this body
+    :param nodeid2name: mapping from node ids to serialized names
+    :param f_id: id of the file the function comes from
+    :param f_start: position of the function start within the file
+ :return:
+ """
body_lines = body.split("\n")
local_occurrences = sort_occurrences(local_occurrences)
@@ -141,6 +153,16 @@ def process_body(body, local_occurrences, nodeid2name, f_id, f_start):
def process_bodies(nodes, edges, source_location, occurrence, file_content, lang):
+ """
+    :param nodes: DataFrame with nodes
+    :param edges: DataFrame with edges
+    :param source_location: DataFrame that links nodes to source files
+    :param occurrence: DataFrame with records about occurrences
+    :param file_content: DataFrame with sources
+    :param lang: language of the source files
+ :return: Dataframe with columns id, body, sourcetrail_node_offsets. Node offsets are not resolved from the
+ global graph.
+ """
occurrence_groups = get_occurrence_groups(nodes, edges, source_location, occurrence)
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py
index bbad1d77..dc442261 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py
@@ -15,7 +15,8 @@
2: "uses_type", # from user to type
16: "inheritance",
4: "uses", # from user to item
- 512: "imports" # from module to imported object
+ 512: "imports", # from module to imported object
+ 32: "non-idexed-package"
}
diff --git a/SourceCodeTools/code/experiments/Experiments.py b/SourceCodeTools/code/experiments/Experiments.py
index 85bb8295..d1cf45f7 100644
--- a/SourceCodeTools/code/experiments/Experiments.py
+++ b/SourceCodeTools/code/experiments/Experiments.py
@@ -1,4 +1,6 @@
# %%
+from typing import Iterable
+
import pandas
from os.path import join
@@ -6,6 +8,9 @@
import numpy as np
# from graphtools import Embedder
+from SourceCodeTools.code.data.sourcetrail.Dataset import get_train_val_test_indices, create_train_val_test_masks, \
+ filter_dst_by_freq
+from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist
from SourceCodeTools.models.Embedder import Embedder
import pickle
@@ -24,6 +29,13 @@ def get_nodes_with_split_ids(embedder, split_ids):
# ]['id'].to_numpy()
+def shuffle(X, y):
+ np.random.seed(42)
+ ind_shuffle = np.arange(0, X.shape[0])
+ np.random.shuffle(ind_shuffle)
+ return X[ind_shuffle], y[ind_shuffle]
+
+
class Experiments:
"""
Provides convenient interface for creating experiments.
@@ -47,7 +59,8 @@ def __init__(self,
variable_use_path=None,
function_name_path=None,
type_ann=None,
- gnn_layer=-1):
+ gnn_layer=-1,
+ embeddings_path=None):
"""
         :param base_path: path to trained gnn model
@@ -78,7 +91,9 @@ def __init__(self,
# self.splits = torch.load(os.path.join(self.base_path, "state_dict.pt"))["splits"]
if base_path is not None:
- self.embed = pickle.load(open(join(self.base_path, "embeddings.pkl"), "rb"))[gnn_layer]
+ self.embed = pickle.load(open(embeddings_path, "rb"))
+ if isinstance(self.embed, Iterable):
+ self.embed = self.embed[gnn_layer]
# alternative_nodes = pickle.load(open("nodes.pkl", "rb"))
# self.embed.e = alternative_nodes
# self.embed.e = np.random.randn(self.embed.e.shape[0], self.embed.e.shape[1])
@@ -112,12 +127,14 @@ def __getitem__(self, type: str):
:param type: str description of the experiment
:return: Experiment object
"""
- nodes = pandas.read_csv(join(self.base_path, "nodes.csv"))
- edges = pandas.read_csv(join(self.base_path, "held.csv")).astype({"src": "int32", "dst": "int32"})
+
+ nodes = unpersist(join(self.base_path, "nodes.bz2"))
+ # edges = pandas.read_csv(join(self.base_path, "held.csv")).astype({"src": "int32", "dst": "int32"})
+ edges = unpersist(join(self.base_path, "edges.bz2"))
from SourceCodeTools.code.data.sourcetrail.Dataset import SourceGraphDataset
- self.splits = SourceGraphDataset.get_global_graph_id_splits(nodes)
+ # self.splits = SourceGraphDataset.get_global_graph_id_splits(nodes)
# global_ids = nodes['global_graph_id'].values
# self.splits = (
# global_ids[nodes['train_mask'].values],
@@ -200,22 +217,111 @@ def __getitem__(self, type: str):
compact_dst=False)
elif type == "typeann":
- type_ann = pandas.read_csv(self.experiments['typeann']).astype({"src": "int32", "dst": "str"})
+ import os
+ from SourceCodeTools.nlp.entity.utils.data import read_data
- # node_pool = set(self.splits[2]).union(self.splits[1]).union(self.splits[0])
- # node_pool = set(nodes['id'].values.tolist())
- # node_pool = set(get_nodes_with_split_ids(nodes, set(self.splits[2]).union(self.splits[1]).union(self.splits[0])))
- node_pool = set(
- get_nodes_with_split_ids(self.embed, set(self.splits[2]).union(self.splits[1]).union(self.splits[0])))
+ random_seed = 42
+ min_entity_count = 3
+
+ type_ann = unpersist(self.experiments['typeann'])
+
+ filter_rule = lambda name: "0x" not in name
type_ann = type_ann[
- type_ann['src'].apply(lambda nid: nid in node_pool)
+ type_ann["dst"].apply(filter_rule)
]
+            node_id_to_type = dict(zip(nodes["id"], nodes["type"]))
+            type_ann["src_type"] = type_ann["src"].apply(lambda x: node_id_to_type[x])
+
+ type_ann = type_ann[
+ type_ann["src_type"].apply(lambda type_: type_ in {"mention"})
+ ]
+
+ norm = lambda x: x.strip("\"").strip("'").split("[")[0].split(".")[-1]
+
+ type_ann["dst"] = type_ann["dst"].apply(norm)
+ type_ann = filter_dst_by_freq(type_ann, min_entity_count)
+
+ # this is used for codebert embeddings
+ filter_rule = lambda id_: id_ in self.embed.ind
+
+ type_ann = type_ann[
+ type_ann["src"].apply(filter_rule)
+ ]
+
+ # allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray',
+ # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description',
+ # 'Type'}
+ test_only_popular_types = False
+ if test_only_popular_types:
+ allowed = {
+ 'str', 'Optional', 'int', 'Any', 'Union', 'bool', 'Other', 'Callable', 'Dict', 'bytes', 'float',
+ 'Description',
+ 'List', 'Sequence', 'Namespace', 'T', 'Type', 'object', 'HTTPServerRequest', 'Future'
+ }
+ type_ann = type_ann[
+ type_ann["dst"].apply(lambda type_: type_ in allowed)
+ ]
+ else:
+ allowed = None
+
+ from pathlib import Path
+ dataset_dir = Path(self.experiments['typeann']).parent
+ from SourceCodeTools.nlp.entity.type_prediction import filter_labels
+ train_data = filter_labels(
+ pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_train.pkl"), "rb")),
+ allowed=allowed
+ )
+ test_data = filter_labels(
+ pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_test.pkl"), "rb")),
+ allowed=allowed
+ )
+
+ def get_ids(typeann_dataset):
+ ids = []
+ for sent, annotations in typeann_dataset:
+ for _, _, r in annotations["replacements"]:
+ ids.append(r)
+
+ return set(ids)
+
+ train_ids = get_ids(train_data)
+ test_ids = get_ids(test_data)
+
+ type_ann = type_ann[["src", "dst"]]
+
+ # splits = get_train_val_test_indices(nodes.index)
+ # create_train_val_test_masks(nodes, *splits)
+
+ # train_data, test_data = read_data(
+ # open(os.path.join(self.base_path, "function_annotations.jsonl"), "r").readlines(),
+ # normalize=True, allowed=None, include_replacements=True,
+ # include_only="entities",
+ # min_entity_count=min_entity_count, random_seed=random_seed
+ # )
+ #
+ # train_nodes = set()
+ # for _, ann in train_data:
+ # for _, _, nid in ann["replacements"]:
+ # train_nodes.add(nid)
+ #
+ # test_nodes = set()
+ # for _, ann in test_data:
+ # for _, _, nid in ann["replacements"]:
+ # test_nodes.add(nid)
+ #
+ # type_nodes = set(type_ann["src"])
+ #
+ # train_nodes = train_nodes.intersection(type_nodes)
+ # test_nodes = test_nodes.intersection(type_nodes)
+
# return Experiment2(self.embed, nodes, edges, type_ann, split_on="nodes", neg_sampling_strategy="word2vec")
- return Experiment3(self.embed, nodes, edges, type_ann, split_on="edges",
- neg_sampling_strategy="word2vec", compact_dst=True)
+ return Experiment3(
+ self.embed, nodes, edges, type_ann, split_on="edges", neg_sampling_strategy="word2vec",
+ compact_dst=True, train_ids=train_ids, test_ids=test_ids
+ )
elif type == "typeann_name":
type_ann = pandas.read_csv(self.experiments['typeann']).astype({"src": "int32", "dst": "str"})
@@ -298,7 +404,8 @@ def __init__(self,
split_on="nodes",
neg_sampling_strategy="word2vec",
K=1,
- test_frac=0.1, compact_dst=True):
+ test_frac=0.1, compact_dst=True,
+ train_ids=None, test_ids=None):
# store local copies"
self.embed = embeddings
@@ -310,6 +417,8 @@ def __init__(self,
self.neg_smpl_strategy = neg_sampling_strategy
self.K = K
self.TEST_FRAC = test_frac
+ self.train_ids = train_ids
+ self.test_ids = test_ids
# make sure to drop duplicate edges to prevent leakage into the test set
# do it before creating experiment?
@@ -391,11 +500,6 @@ def get_training_data(self):
assert X_train.shape[0] == y_train.shape[0]
assert X_test.shape[0] == y_test.shape[0]
- def shuffle(X, y):
- ind_shuffle = np.arange(0, X.shape[0])
- np.random.shuffle(ind_shuffle)
- return X[ind_shuffle], y[ind_shuffle]
-
self.X_train, self.y_train = shuffle(X_train, y_train)
self.X_test, self.y_test = shuffle(X_test, y_test)
@@ -557,19 +661,20 @@ def __init__(self, embeddings: Embedder,
split_on="nodes",
neg_sampling_strategy="word2vec",
K=1,
- test_frac=0.1, compact_dst=True):
+ test_frac=0.1, compact_dst=True,
+ train_ids=None, test_ids=None):
super(Experiment2, self).__init__(embeddings, nodes, edges, target, split_on=split_on,
neg_sampling_strategy=neg_sampling_strategy, K=K, test_frac=test_frac,
- compact_dst=compact_dst)
+ compact_dst=compact_dst, train_ids=train_ids, test_ids=test_ids)
# TODO
# make sure that compact_property work always identically between runs
- self.name_map = compact_property(target['dst'])
+ self.name_map, self.inv_index = compact_property(target['dst'], return_order=True)
self.dst_orig = target['dst']
target['orig_dst'] = target['dst']
target['dst'] = target['dst'].apply(lambda name: self.name_map[name])
- print(f"Doing experiment with {len(self.name_map)} distinct target targets")
+ print(f"Doing experiment with {len(self.name_map)} distinct targets")
self.unique_src = self.target['src'].unique()
self.unique_dst = self.target['dst'].unique()
@@ -684,33 +789,40 @@ def get_negative_out(self, num):
class Experiment3(Experiment2):
- def __init__(self, embeddings: Embedder,
- nodes: pandas.DataFrame,
- edges: pandas.DataFrame,
- target: pandas.DataFrame,
- split_on="nodes",
- neg_sampling_strategy="word2vec",
- K=1,
- test_frac=0.1, compact_dst=True
- ):
+ def __init__(
+ self, embeddings: Embedder,
+ nodes: pandas.DataFrame,
+ edges: pandas.DataFrame,
+ target: pandas.DataFrame,
+ split_on="nodes",
+ neg_sampling_strategy="word2vec",
+ K=1,
+ test_frac=0.1, compact_dst=True,
+ train_ids=None, test_ids=None
+ ):
super(Experiment3, self).__init__(embeddings, nodes, edges, target, split_on=split_on,
neg_sampling_strategy=neg_sampling_strategy, K=K, test_frac=test_frac,
- compact_dst=compact_dst)
+ compact_dst=compact_dst, train_ids=train_ids, test_ids=test_ids)
def get_training_data(self):
# self.get_train_test_split()
- train_positive = self.target.iloc[self.train_edge_ind].values
- test_positive = self.target.iloc[self.test_edge_ind].values
+ if self.train_ids is not None and self.test_ids is not None:
+ test_positive = self.target[
+ self.target["src"].apply(lambda x: x in self.test_ids)
+ ].values
+ chosen_for_test = set(test_positive[:, 0].tolist())
+ train_positive = self.target[
+ self.target["src"].apply(lambda x: x in self.train_ids and x not in chosen_for_test)
+ ].values
- X_train, y_train = train_positive[:, 0].reshape(-1, 1), train_positive[:, 1].reshape(-1, 1)
- X_test, y_test = test_positive[:, 0].reshape(-1, 1), test_positive[:, 1].reshape(-1, 1)
+ else:
+ train_positive = self.target.iloc[self.train_edge_ind].values
+ test_positive = self.target.iloc[self.test_edge_ind].values
- def shuffle(X, y):
- ind_shuffle = np.arange(0, X.shape[0])
- np.random.shuffle(ind_shuffle)
- return X[ind_shuffle], y[ind_shuffle]
+ X_train, y_train = train_positive[:, 0].reshape(-1, 1).astype(np.int32), train_positive[:, 1].reshape(-1, 1).astype(np.int32)
+ X_test, y_test = test_positive[:, 0].reshape(-1, 1).astype(np.int32), test_positive[:, 1].reshape(-1, 1).astype(np.int32)
self.X_train, self.y_train = shuffle(X_train, y_train)
self.X_test, self.y_test = shuffle(X_test, y_test)
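`Experiment3.get_training_data` above now accepts precomputed node id sets; rows whose source node belongs to the test set are held out first, and training rows must not reuse those nodes. A toy pandas version of that selection (invented data):

```python
# Toy illustration of the id-based split: test rows are chosen first, training
# rows must be in train_ids and must not reuse a node already placed in test.
import pandas as pd

target = pd.DataFrame({"src": [1, 2, 3, 4], "dst": ["int", "str", "int", "bool"]})
train_ids, test_ids = {1, 2, 3}, {3, 4}

test = target[target["src"].isin(test_ids)]
chosen_for_test = set(test["src"])
train = target[target["src"].apply(lambda x: x in train_ids and x not in chosen_for_test)]

print(train["src"].tolist())   # [1, 2]
print(test["src"].tolist())    # [3, 4]
```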
diff --git a/SourceCodeTools/code/experiments/classifiers.py b/SourceCodeTools/code/experiments/classifiers.py
index ff037f9b..3969021f 100644
--- a/SourceCodeTools/code/experiments/classifiers.py
+++ b/SourceCodeTools/code/experiments/classifiers.py
@@ -1,3 +1,5 @@
+import logging
+
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input, Embedding, concatenate
@@ -95,6 +97,9 @@ class NodeClassifier(Model):
def __init__(self, node_emb_size, n_classes, h_size=None):
super(NodeClassifier, self).__init__()
+ self.proj = Dense(node_emb_size, use_bias=False)
+ logging.warning("Using projection matrix")
+
if h_size is None:
h_size = [30, 15]
@@ -115,7 +120,7 @@ def __init__(self, node_emb_size, n_classes, h_size=None):
# self.logits = Dense(n_classes, input_shape=(h_size[1],))
def __call__(self, x, **kwargs):
- h = x
+ h = self.proj(x)
for l in self.layers_:
h = l(h)
return h
diff --git a/SourceCodeTools/code/experiments/run_experiment.py b/SourceCodeTools/code/experiments/run_experiment.py
index 7c287c8d..f84a7e7d 100644
--- a/SourceCodeTools/code/experiments/run_experiment.py
+++ b/SourceCodeTools/code/experiments/run_experiment.py
@@ -1,5 +1,7 @@
import os
+from sklearn import metrics
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
@@ -22,6 +24,166 @@
all_experiments = link_prediction_experiments + name_prediction_experiments + node_classification_experiments + name_subword_predictor
+class Tracker:
+ def __init__(self, inverted_index=None):
+ self.all_true = []
+ self.all_estimated = []
+ self.all_emb = []
+ self.inv_index = inverted_index
+
+ def add(self, embs, pred, true):
+ self.all_true.append(true)
+ self.all_estimated.append(pred)
+ self.all_emb.append(embs)
+
+ def decode_label_names(self, ids):
+ assert self.inv_index is not None, "Need inverted index"
+ return list(map(lambda x: self.inv_index[x], ids))
+
+ @property
+ def true_labels(self):
+ return np.concatenate(self.all_true, axis=0).reshape(-1,).tolist()
+
+ @property
+ def true_label_names(self):
+ return self.decode_label_names(self.true_labels)
+
+ @property
+ def pred_labels(self):
+ return np.argmax(np.concatenate(self.all_estimated, axis=0), axis=-1).reshape(-1, ).tolist()
+
+ @property
+ def pred_label_names(self):
+ return self.decode_label_names(self.pred_labels)
+
+ @property
+ def pred_scores(self):
+ return np.concatenate(self.all_estimated, axis=0)
+
+ @property
+ def embeddings(self):
+ return np.concatenate(self.all_emb, axis=0)
+
+ def clear(self):
+ self.all_true.clear()
+ self.all_estimated.clear()
+ self.all_emb.clear()
+
+ def save_embs_for_tb(self, save_name):
+ assert self.inv_index is not None, "Cannot export for tensorboard without metadata"
+ np.savetxt(f"{save_name}_embeddings.tsv", self.embeddings, delimiter="\t")
+ with open(f"{save_name}_meta.tsv", "w") as meta_sink:
+ for label in list(map(lambda x: self.inv_index[x], self.pred_labels)):
+ meta_sink.write(f"{label}\n")
+
+ def save_umap(self, save_name):
+ type_freq = {
+ "str": 532,
+ "Optional": 232,
+ "int": 206,
+ "Any": 171,
+ "Union": 156,
+ "bool": 143,
+ "Callable": 80,
+ "Dict": 77,
+ "bytes": 58,
+ "float": 48
+ }
+ from umap import UMAP
+ import matplotlib.pyplot as plt
+ plt.rcParams.update({'font.size': 5})
+ reducer = UMAP(50)
+ embedding = reducer.fit_transform(self.embeddings)
+
+ labels = list(map(lambda x: self.inv_index[x], self.pred_labels))
+ unique_labels = list(set(labels))
+
+ plt.figure(figsize=(4,4))
+ for label in unique_labels:
+ if label not in type_freq:
+ continue
+ xs = []
+ ys = []
+ for lbl, (x, y) in zip(labels, embedding):
+ if lbl == label:
+ xs.append(x)
+ ys.append(y)
+ plt.scatter(xs, ys, 1.)
+ plt.axis('off')
+ plt.legend(unique_labels)
+ plt.savefig(f"{save_name}_umap.pdf")
+ plt.close()
+ # plt.show()
+
+
+
+ def get_metrics(self):
+ all_true = self.true_labels
+ all_scores = self.pred_scores
+
+ metric_dict = {}
+
+ for k in [1,3,5]:
+ metric_dict[f"Acc@{k}"] = metrics.top_k_accuracy_score(y_true=all_true, y_score=all_scores, k=k, labels=list(range(all_scores.shape[1])))
+
+ return metric_dict
+
+ def save_confusion_matrix(self, save_path):
+ estimate_confusion(
+ self.pred_label_names,
+ self.true_label_names,
+ save_path=save_path
+ )
+
+
+def estimate_confusion(pred, true, save_path):
+ pred_filtered = pred
+ true_filtered = true
+
+ import matplotlib.pyplot as plt
+ plt.rcParams.update({'font.size': 50})
+
+ labels = sorted(list(set(true_filtered + pred_filtered)))
+ label2ind = dict(zip(labels, range(len(labels))))
+
+ confusion = np.zeros((len(labels), len(labels)))
+
+ for pred, true in zip(pred_filtered, true_filtered):
+ confusion[label2ind[true], label2ind[pred]] += 1
+
+ norm = np.array([x if x != 0 else 1. for x in np.sum(confusion, axis=1)]).reshape(-1,1)
+ confusion /= norm
+
+ fig, ax = plt.subplots(figsize=(45,45))
+ from matplotlib.pyplot import cm
+ im = ax.imshow(confusion, interpolation='nearest', cmap=cm.Blues)
+
+ # We want to show all ticks...
+ ax.set_xticks(np.arange(len(labels)))
+ ax.set_yticks(np.arange(len(labels)))
+ # ... and label them with the respective list entries
+ ax.set_xticklabels(labels)
+ ax.set_yticklabels(labels)
+
+ # Rotate the tick labels and set their alignment.
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
+ rotation_mode="anchor")
+
+ # Loop over data dimensions and create text annotations.
+ for i in range(len(labels)):
+ for j in range(len(labels)):
+ text = ax.text(j, i, f"{confusion[i, j]: .2f}",
+ ha="center", va="center", color="w")
+
+ ax.set_title("Confusion matrix for Python type prediction")
+ fig.tight_layout()
+ # plt.show()
+ plt.savefig(save_path)
+ plt.close()
+
+
def run_experiment(e, experiment_name, args):
experiment = e[experiment_name]
@@ -70,8 +232,8 @@ def run_experiment(e, experiment_name, args):
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
- @tf.function
- def train_step(batch):
+ # @tf.function
+ def train_step(batch, tracker=None):
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
@@ -80,22 +242,30 @@ def train_step(batch):
gradients = tape.gradient(loss, clf.trainable_variables)
optimizer.apply_gradients(zip(gradients, clf.trainable_variables))
+ if tracker is not None:
+ tracker.add(batch["x"], predictions.numpy(), batch["y"])
+
train_loss(loss)
train_accuracy(batch["y"], predictions)
- @tf.function
- def test_step(batch):
+ # @tf.function
+ def test_step(batch, tracker=None):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
predictions = clf(**batch, training=False)
t_loss = loss_object(batch["y"], predictions)
+ if tracker is not None:
+ tracker.add(batch["x"], predictions.numpy(), batch["y"])
+
test_loss(t_loss)
test_accuracy(batch["y"], predictions)
- args.epochs = 500
-
+ trains = []
tests = []
+ metrics = []
+
+ test_tracker = Tracker(inverted_index=experiment.inv_index if hasattr(experiment, "inv_index") else None)
for epoch in range(args.epochs):
# Reset the metrics at the start of the next epoch
@@ -108,12 +278,17 @@ def test_step(batch):
train_step(batch)
if epoch % 1 == 0:
+
+ test_tracker.clear()
+
for batch in experiment.test_batches():
- test_step(batch)
+ test_step(batch, tracker=test_tracker)
ma_train = train_accuracy.result() * 100 * ma_alpha + ma_train * (1 - ma_alpha)
ma_test = test_accuracy.result() * 100 * ma_alpha + ma_test * (1 - ma_alpha)
+ trains.append(ma_train)
tests.append(ma_test)
+ metrics.append(test_tracker.get_metrics())
# template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.4f}, Test Loss: {:.4f}, Test Accuracy: {:.4f}, Average Test {:.4f}'
# print(template.format(epoch+1,
@@ -123,9 +298,26 @@ def test_step(batch):
# test_accuracy.result()*100,
# ma_test))
+ # plot confusion matrix
+ if hasattr(experiment, "inv_index") and args.confusion_out_path is not None:
+ test_tracker.save_confusion_matrix(save_path=args.confusion_out_path)
+
+ if hasattr(experiment, "inv_index") and args.emb_out is not None:
+ test_tracker.save_embs_for_tb(save_name=args.emb_out)
+ test_tracker.save_umap(save_name=args.emb_out)
+
+ print(metrics[tests.index(max(tests))])
+
# ma_train = train_accuracy.result() * 100 * ma_alpha + ma_train * (1 - ma_alpha)
# ma_test = test_accuracy.result() * 100 * ma_alpha + ma_test * (1 - ma_alpha)
+ import matplotlib.pyplot as plt
+ plt.plot(trains)
+ plt.plot(tests)
+ plt.legend(["Train", "Test"])
+ # plt.savefig(os.path.join(args.base_path, f"{args.experiment}.png"))
+ plt.show()
+
return ma_train, max(tests)
@@ -151,6 +343,10 @@ def test_step(batch):
parser.add_argument("--element_predictor_h_size", default=50, type=int, help="")
parser.add_argument("--link_predictor_h_size", default="[20]", type=str, help="")
parser.add_argument("--node_classifier_h_size", default="[30,15]", type=str, help="")
+ parser.add_argument("--confusion_out_path", default=None, type=str, help="")
+ parser.add_argument("--trials", default=1, type=int, help="")
+ parser.add_argument("--emb_out", default=None, type=str, help="")
+ parser.add_argument("--embeddings", default=None)
parser.add_argument('--random', action='store_true')
parser.add_argument('--test_embedder', action='store_true')
args = parser.parse_args()
@@ -166,15 +362,17 @@ def test_step(batch):
function_name_path=None,
type_ann=args.type_ann,
gnn_layer=-1,
+ embeddings_path=args.embeddings
)
experiments = args.experiment.split(",")
for experiment_name in experiments:
- print(f"\n{experiment_name}:")
- try:
- train_acc, test_acc = run_experiment(e, experiment_name, args)
- print("Train Accuracy: {:.4f}, Test Accuracy: {:.4f}".format(train_acc, test_acc))
- except ValueError as err:
- print(err)
- print("\n")
+ for trial in range(args.trials):
+ print(f"\n{experiment_name}, trial {trial}:")
+ try:
+ train_acc, test_acc = run_experiment(e, experiment_name, args)
+ print(f"Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}")
+ except ValueError as err:
+ print(err)
+ print("\n")
diff --git a/SourceCodeTools/code/python_ast.py b/SourceCodeTools/code/python_ast.py
index 07b8f7aa..fcc510d7 100644
--- a/SourceCodeTools/code/python_ast.py
+++ b/SourceCodeTools/code/python_ast.py
@@ -76,6 +76,7 @@ class GNode:
# id = None
def __init__(self, **kwargs):
+ self.string = None
for k, v in kwargs.items():
setattr(self, k, v)
@@ -104,24 +105,28 @@ def __init__(self, source):
self.condition_status = []
self.scope = []
- def get_source_from_ast_range(self, start_line, end_line, start_col, end_col):
+ def get_source_from_ast_range(self, start_line, end_line, start_col, end_col, strip=True):
source = ""
num_lines = end_line - start_line + 1
if start_line == end_line:
- source += self.source[start_line - 1].encode("utf8")[start_col:end_col].decode(
- "utf8").strip()
+ section = self.source[start_line - 1].encode("utf8")[start_col:end_col].decode(
+ "utf8")
+ source += section.strip() if strip else section + "\n"
else:
for ind, lineno in enumerate(range(start_line - 1, end_line)):
if ind == 0:
- source += self.source[lineno].encode("utf8")[start_col:].decode(
- "utf8").strip()
+ section = self.source[lineno].encode("utf8")[start_col:].decode(
+ "utf8")
+ source += section.strip() if strip else section + "\n"
elif ind == num_lines - 1:
- source += self.source[lineno].encode("utf8")[:end_col].decode(
- "utf8").strip()
+ section = self.source[lineno].encode("utf8")[:end_col].decode(
+ "utf8")
+ source += section.strip() if strip else section + "\n"
else:
- source += self.source[lineno].strip()
+ section = self.source[lineno]
+ source += section.strip() if strip else section + "\n"
- return source
+ return source.rstrip()
def get_name(self, *, node=None, name=None, type=None, add_random_identifier=False):
@@ -132,10 +137,17 @@ def get_name(self, *, node=None, name=None, type=None, add_random_identifier=Fal
if add_random_identifier:
name += f"_{str(hex(int(time_ns())))}"
+ if hasattr(node, "lineno"):
+ node_string = self.get_source_from_ast_range(node.lineno, node.end_lineno, node.col_offset, node.end_col_offset, strip=False)
+ # if "\n" in node_string:
+ # node_string = None
+ else:
+ node_string = None
+
if len(self.scope) > 0:
- return GNode(name=name, type=type, scope=copy(self.scope[-1]))
+ return GNode(name=name, type=type, scope=copy(self.scope[-1]), string=node_string)
else:
- return GNode(name=name, type=type)
+ return GNode(name=name, type=type, string=node_string)
# return (node.__class__.__name__ + "_" + str(hex(int(time_ns()))), node.__class__.__name__)
def get_edges(self, as_dataframe=True):
@@ -170,6 +182,8 @@ def parse_body(self, nodes):
for node in nodes:
s = self.parse(node)
if isinstance(s, tuple):
+                if s[1].type == "Constant":  # happens when processing a docstring; otherwise a lot of nodes end up connected to the Constant_ node
+                    continue  # a constant expression in a body has no effect, so it can be skipped
# some parsers return edges and names?
edges.extend(s[0])
@@ -278,9 +292,9 @@ def generic_parse(self, node, operands, with_name=None, ensure_iterables=False):
else:
node_name = with_name
- if len(self.scope) > 0:
- edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": self.scope[-1], "type": "mention_scope"})
- edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": node_name, "type": "mention_scope_rev"})
+ # if len(self.scope) > 0:
+ # edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": self.scope[-1], "type": "mention_scope"})
+ # edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": node_name, "type": "mention_scope_rev"})
for operand in operands:
if operand in ["body", "orelse", "finalbody"]:
@@ -476,9 +490,11 @@ def parse_arg(self, node):
)
annotation = GNode(name=annotation_string,
type="type_annotation")
- edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
- # do not use reverse edges for types, will result in leak from function to function
- # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'})
+ mention_name = GNode(name=node.arg + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1]))
+ edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": mention_name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
+ # edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
+ # # do not use reverse edges for types, will result in leak from function to function
+ # # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'})
return edges, name
# return self.generic_parse(node, ["arg", "annotation"])
@@ -503,9 +519,15 @@ def parse_AnnAssign(self, node):
annotation = GNode(name=annotation_string,
type="type_annotation")
edges, name = self.generic_parse(node, ["target"])
- edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
- # do not use reverse edges for types, will result in leak from function to function
- # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'})
+ try:
+ mention_name = GNode(name=node.target.id + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1]))
+ edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": mention_name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
+ except Exception as e:
+ edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
+ # print(e) # don't know how I should parse this "Attribute(value=Name(id='self', ctx=Load()), attr='srctrlrpl_1631733463030025000@#attr#', ctx=Store())"
+ # edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
+ # # do not use reverse edges for types, will result in leak from function to function
+ # # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'})
return edges, name
# return self.generic_parse(node, ["target", "annotation"])
@@ -854,7 +876,7 @@ def parse_arguments(self, node):
# arguments(args=[arg(arg='self', annotation=None), arg(arg='tqdm_cls', annotation=None), arg(arg='sleep_interval', annotation=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
# vararg constains type annotations
- return self.generic_parse(node, ["args", "vararg"])
+ return self.generic_parse(node, ["args", "vararg"]) # kwarg, kwonlyargs, posonlyargs???
def parse_comprehension(self, node):
edges = []
@@ -862,9 +884,9 @@ def parse_comprehension(self, node):
cph_name = self.get_name(name="comprehension", type="comprehension", add_random_identifier=True)
# cph_name = GNode(name="comprehension_" + str(hex(int(time_ns()))), type="comprehension")
- if len(self.scope) > 0:
- edges.append({"scope": copy(self.scope[-1]), "src": cph_name, "dst": self.scope[-1], "type": "mention_scope"})
- edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": cph_name, "type": "mention_scope_rev"})
+ # if len(self.scope) > 0:
+ # edges.append({"scope": copy(self.scope[-1]), "src": cph_name, "dst": self.scope[-1], "type": "mention_scope"})
+ # edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": cph_name, "type": "mention_scope_rev"})
target, ext_edges = self.parse_operand(node.target)
edges.extend(ext_edges)
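The `get_source_from_ast_range` changes above keep slicing lines as UTF-8 bytes. That matters because `ast` reports column offsets in bytes, not characters; a standalone demonstration with a non-ASCII identifier:

```python
# ast col_offset / end_col_offset are UTF-8 byte offsets, so the line must be
# encoded before slicing; slicing the unicode string directly drifts.
import ast

line = 'prefix_é = "hello"'
node = ast.parse(line).body[0].value   # the string constant on the right-hand side

by_bytes = line.encode("utf8")[node.col_offset:node.end_col_offset].decode("utf8")
by_chars = line[node.col_offset:node.end_col_offset]

print(by_bytes)   # "hello"   -> correct, quotes included
print(by_chars)   # hello"    -> shifted, because "é" takes two bytes
```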
diff --git a/SourceCodeTools/code/python_ast_cf.py b/SourceCodeTools/code/python_ast_cf.py
new file mode 100644
index 00000000..d5c85192
--- /dev/null
+++ b/SourceCodeTools/code/python_ast_cf.py
@@ -0,0 +1,508 @@
+import ast
+from copy import copy
+from enum import Enum
+from pprint import pprint
+from time import time_ns
+from collections.abc import Iterable
+import pandas as pd
+
+from SourceCodeTools.nlp.entity.annotator.annotator_utils import to_offsets
+
+
+class GNode:
+ # name = None
+ # type = None
+ # id = None
+
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ def __eq__(self, other):
+ if self.name == other.name and self.type == other.type:
+ return True
+ else:
+ return False
+
+ def __repr__(self):
+ return self.__dict__.__repr__()
+
+ def __hash__(self):
+ return (self.name, self.type).__hash__()
+
+ def setprop(self, key, value):
+ setattr(self, key, value)
+
+
+class AstGraphGenerator(object):
+
+ def __init__(self, source):
+ self.source = source.split("\n") # lines of the source code
+ self.full_source = source
+ self.root = ast.parse(source)
+ self.current_condition = []
+ self.condition_status = []
+ self.scope = []
+
+ def get_source_from_ast_range(self, start_line, end_line, start_col, end_col):
+ source = ""
+ num_lines = end_line - start_line + 1
+ if start_line == end_line:
+ source += self.source[start_line - 1].encode("utf8")[start_col:end_col].decode(
+ "utf8").strip()
+ else:
+ for ind, lineno in enumerate(range(start_line - 1, end_line)):
+ if ind == 0:
+ source += self.source[lineno].encode("utf8")[start_col:].decode(
+ "utf8").strip()
+ elif ind == num_lines - 1:
+ source += self.source[lineno].encode("utf8")[:end_col].decode(
+ "utf8").strip()
+ else:
+ source += self.source[lineno].strip()
+
+ return source
+
+ def get_name(self, *, node=None, name=None, type=None, add_random_identifier=False):
+
+ if node is not None:
+ name = node.__class__.__name__ + "_" + str(hex(int(time_ns())))
+ type = node.__class__.__name__
+ else:
+ if add_random_identifier:
+ name += f"_{str(hex(int(time_ns())))}"
+
+ if len(self.scope) > 0:
+ return GNode(name=name, type=type, scope=copy(self.scope[-1]))
+ else:
+ return GNode(name=name, type=type)
+ # return (node.__class__.__name__ + "_" + str(hex(int(time_ns()))), node.__class__.__name__)
+
+ def get_edges(self, as_dataframe=False):
+ edges = []
+ edges.extend(self.parse(self.root)[0])
+ # for f_def_node in ast.iter_child_nodes(self.root):
+ # if type(f_def_node) == ast.FunctionDef:
+ # edges.extend(self.parse(f_def_node))
+ # break # to avoid going through nested definitions
+
+ if not as_dataframe:
+ return edges
+ df = pd.DataFrame(edges)
+ return df.astype({col: "Int32" for col in df.columns if col not in {"src", "dst", "type"}})
+
+ def parse(self, node):
+ n_type = type(node)
+ method_name = "parse_" + n_type.__name__
+ if hasattr(self, method_name):
+ return self.__getattribute__(method_name)(node)
+ else:
+ # print(type(node))
+ # print(ast.dump(node))
+ # print(node._fields)
+ # pprint(self.source)
+ return self.parse_as_expression(node)
+ # return self.generic_parse(node, node._fields)
+ # raise Exception()
+ # return [type(node)]
+
+ def parse_body(self, nodes):
+ edges = []
+ last_node = None
+ for node in nodes:
+ s = self.parse(node)
+ if isinstance(s, tuple):
+ # some parsers return edges and names?
+ edges.extend(s[0])
+
+ if last_node is not None:
+ edges.append({"dst": s[1], "src": last_node, "type": "next"})
+ edges.append({"dst": last_node, "src": s[1], "type": "prev"})
+
+ last_node = s[1]
+
+ for cond_name, cond_stat in zip(self.current_condition[-1:], self.condition_status[-1:]):
+ edges.append({"scope": copy(self.scope[-1]), "src": last_node, "dst": cond_name, "type": "defined_in_" + cond_stat})
+ edges.append({"scope": copy(self.scope[-1]), "src": cond_name, "dst": last_node, "type": "defined_in_" + cond_stat + "_rev"})
+ # edges.append({"scope": copy(self.scope[-1]), "src": cond_name, "dst": last_node, "type": "execute_when_" + cond_stat})
+ else:
+ edges.extend(s)
+ return edges
+
+ def parse_in_context(self, cond_name, cond_stat, edges, body):
+ if isinstance(cond_name, str):
+ cond_name = [cond_name]
+ cond_stat = [cond_stat]
+ elif isinstance(cond_name, GNode):
+ cond_name = [cond_name]
+ cond_stat = [cond_stat]
+
+ for cn, cs in zip(cond_name, cond_stat):
+ self.current_condition.append(cn)
+ self.condition_status.append(cs)
+
+ edges.extend(self.parse_body(body))
+
+ for i in range(len(cond_name)):
+ self.current_condition.pop(-1)
+ self.condition_status.pop(-1)
+
+ def parse_as_expression(self, node, *args, **kwargs):
+ offset = to_offsets(self.full_source,
+ [(node.lineno - 1, node.end_lineno - 1, node.col_offset, node.end_col_offset, "expression")],
+ as_bytes=True)
+ offset, = offset
+ line = self.full_source[offset[0]: offset[1]].replace("@","##at##")
+ name = GNode(name=line, type="Name")
+ expr = GNode(name="Expression" + "_" + str(hex(int(time_ns()))), type="mention")
+ edges = [
+ {"scope": copy(self.scope[-1]), "src": name, "dst": expr, "type": "local_mention", "line": node.lineno - 1, "end_line": node.end_lineno - 1, "col_offset": node.col_offset, "end_col_offset": node.end_col_offset},
+ ]
+
+ return edges, expr
+
+ def parse_as_mention(self, name):
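+        # create a scope-local "mention" node, e.g. (hypothetical name) a variable x used inside a
+        # function becomes a mention node "x@FunctionDef_0x..." linked to the Name node "x" via a local_mention edge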
+ mention_name = GNode(name=name + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1]))
+ name = GNode(name=name, type="Name")
+ # mention_name = (name + "@" + self.scope[-1], "mention")
+
+ # edge from name to mention in a function
+ edges = [
+ {"scope": copy(self.scope[-1]), "src": name, "dst": mention_name, "type": "local_mention"},
+ # {"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": mention_name, "type": "mention_scope"}
+ ]
+ return edges, mention_name
+
+ def parse_operand(self, node):
+ # need to make sure upper level name is correct; handle @keyword; type placeholder for sourcetrail nodes?
+ edges = []
+ if isinstance(node, str):
+            # we end up here when parsing attributes, which are given as strings; should attributes be parsed into subwords?
+ if "@" in node:
+ parts = node.split("@")
+ node = GNode(name=parts[0], type=parts[1])
+ else:
+ node = GNode(name=node, type="Name")
+ iter_ = node
+ elif isinstance(node, int) or node is None:
+ iter_ = GNode(name=str(node), type="ast_Literal")
+ # iter_ = str(node)
+ elif isinstance(node, GNode):
+ iter_ = node
+ else:
+ iter_e = self.parse(node)
+ if type(iter_e) == str:
+ iter_ = iter_e
+ elif isinstance(iter_e, GNode):
+ iter_ = iter_e
+ elif type(iter_e) == tuple:
+ ext_edges, name = iter_e
+ assert isinstance(name, GNode)
+ edges.extend(ext_edges)
+ iter_ = name
+ else:
+ # unexpected scenario
+ print(node)
+ print(ast.dump(node))
+ print(iter_e)
+ print(type(iter_e))
+ pprint(self.source)
+ print(self.source[node.lineno - 1].strip())
+ raise Exception()
+
+ return iter_, edges
+
+ def parse_and_add_operand(self, node_name, operand, type, edges):
+
+ operand_name, ext_edges = self.parse_operand(operand)
+ edges.extend(ext_edges)
+
+ if hasattr(operand, "lineno"):
+ edges.append({"scope": copy(self.scope[-1]), "src": operand_name, "dst": node_name, "type": type, "line": operand.lineno-1, "end_line": operand.end_lineno-1, "col_offset": operand.col_offset, "end_col_offset": operand.end_col_offset})
+ else:
+ edges.append({"scope": copy(self.scope[-1]), "src": operand_name, "dst": node_name, "type": type})
+
+ # if len(ext_edges) > 0: # need this to avoid adding reverse edges to operation names and other highly connected nodes
+ edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": operand_name, "type": type + "_rev"})
+
+ def generic_parse(self, node, operands, with_name=None, ensure_iterables=False):
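+        """
+        Default parser: create a node for `node` itself (unless with_name is given), parse the listed
+        operand fields and connect them to it; "body", "orelse" and "finalbody" fields are parsed as
+        nested contexts via parse_in_context.
+        """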
+
+ edges = []
+
+ if with_name is None:
+ node_name = self.get_name(node=node)
+ else:
+ node_name = with_name
+
+ if len(self.scope) > 0:
+ edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": self.scope[-1], "type": "mention_scope"})
+ edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": node_name, "type": "mention_scope_rev"})
+
+ for operand in operands:
+ if operand in ["body", "orelse", "finalbody"]:
+ self.parse_in_context(node_name, "operand", edges, node.__getattribute__(operand))
+ else:
+ operand_ = node.__getattribute__(operand)
+ if operand_ is not None:
+ if isinstance(operand_, Iterable) and not isinstance(operand_, str):
+ # TODO:
+ # appears as leaf node if the iterable is empty. suggest adding an "EMPTY" element
+ for oper_ in operand_:
+ self.parse_and_add_operand(node_name, oper_, operand, edges)
+ else:
+ self.parse_and_add_operand(node_name, operand_, operand, edges)
+
+ # TODO
+ # need to identify the benefit of this node
+ # maybe it is better to use node types in the graph
+ # edges.append({"scope": copy(self.scope[-1]), "src": node.__class__.__name__, "dst": node_name, "type": "node_type"})
+
+ return edges, node_name
+
+ def parse_type_node(self, node):
+ # node.lineno, node.col_offset, node.end_lineno, node.end_col_offset
+ if node.lineno == node.end_lineno:
+ type_str = self.source[node.lineno][node.col_offset - 1: node.end_col_offset]
+ # print(type_str)
+ else:
+ type_str = ""
+ for ln in range(node.lineno - 1, node.end_lineno):
+ if ln == node.lineno - 1:
+ type_str += self.source[ln][node.col_offset - 1:].strip()
+ elif ln == node.end_lineno - 1:
+ type_str += self.source[ln][:node.end_col_offset].strip()
+ else:
+ type_str += self.source[ln].strip()
+ return type_str
+
+ def parse_Module(self, node):
+ edges, module_name = self.generic_parse(node, [])
+ self.scope.append(module_name)
+ self.parse_in_context(module_name, "module", edges, node.body)
+ self.scope.pop(-1)
+ return edges, module_name
+
+ def parse_FunctionDef(self, node):
+ # edges, f_name = self.generic_parse(node, ["name", "args", "returns", "decorator_list"])
+ # edges, f_name = self.generic_parse(node, ["args", "returns", "decorator_list"])
+
+        # need to create the function name before generic_parse so that the scope is set up correctly
+ # scope is used to create local mentions of variable and function names
+ fdef_node_name = self.get_name(node=node)
+ self.scope.append(fdef_node_name)
+
+ to_parse = []
+ if len(node.args.args) > 0 or node.args.vararg is not None:
+ to_parse.append("args")
+ if len("decorator_list") > 0:
+ to_parse.append("decorator_list")
+
+ edges, f_name = self.generic_parse(node, to_parse, with_name=fdef_node_name)
+
+ if node.returns is not None:
+ # returns stores return type annotation
+ # can contain quotes
+ # https://stackoverflow.com/questions/46458470/should-you-put-quotes-around-type-annotations-in-python
+ # https://www.python.org/dev/peps/pep-0484/#forward-references
+ annotation_string = self.get_source_from_ast_range(
+ node.returns.lineno, node.returns.end_lineno,
+ node.returns.col_offset, node.returns.end_col_offset
+ )
+ annotation = GNode(name=annotation_string,
+ type="type_annotation")
+ edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": f_name, "type": "returned_by", "line": node.returns.lineno - 1, "end_line": node.returns.end_lineno - 1, "col_offset": node.returns.col_offset, "end_col_offset": node.returns.end_col_offset})
+ # do not use reverse edges for types, will result in leak from function to function
+ # edges.append({"scope": copy(self.scope[-1]), "src": f_name, "dst": annotation, "type": 'returns'})
+
+ # if node.returns:
+ # print(self.source[node.lineno -1]) # can get definition string here
+
+ self.parse_in_context(f_name, "function", edges, node.body)
+
+ self.scope.pop(-1)
+
+ # func_name = GNode(name=node.name, type="Name")
+ ext_edges, func_name = self.parse_as_mention(name=node.name)
+ edges.extend(ext_edges)
+
+ assert isinstance(node.name, str)
+ edges.append({"scope": copy(self.scope[-1]), "src": f_name, "dst": func_name, "type": "function_name"})
+ edges.append({"scope": copy(self.scope[-1]), "src": func_name, "dst": f_name, "type": "function_name_rev"})
+
+ return edges, f_name
+
+ def parse_AsyncFunctionDef(self, node):
+ return self.parse_FunctionDef(node)
+
+ def parse_ClassDef(self, node):
+
+ edges, class_node_name = self.generic_parse(node, [])
+
+ self.scope.append(class_node_name)
+ self.parse_in_context(class_node_name, "class", edges, node.body)
+ self.scope.pop(-1)
+
+ ext_edges, cls_name = self.parse_as_mention(name=node.name)
+ edges.extend(ext_edges)
+ edges.append({"scope": copy(self.scope[-1]), "src": class_node_name, "dst": cls_name, "type": "class_name"})
+ edges.append({"scope": copy(self.scope[-1]), "src": cls_name, "dst": class_node_name, "type": "class_name_rev"})
+
+ return edges, class_node_name
+
+ def parse_arg(self, node):
+ # node.annotation stores type annotation
+ # if node.annotation:
+ # print(self.source[node.lineno-1]) # can get definition string here
+ # print(node.arg)
+
+ # # included mention
+ name = self.get_name(node=node)
+ edges, mention_name = self.parse_as_mention(node.arg)
+ edges.append({"scope": copy(self.scope[-1]), "src": mention_name, "dst": name, "type": 'arg'})
+ edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": mention_name, "type": 'arg_rev'})
+ # edges, name = self.generic_parse(node, ["arg"])
+ if node.annotation is not None:
+ # can contain quotes
+ # https://stackoverflow.com/questions/46458470/should-you-put-quotes-around-type-annotations-in-python
+ # https://www.python.org/dev/peps/pep-0484/#forward-references
+ annotation_string = self.get_source_from_ast_range(
+ node.annotation.lineno, node.annotation.end_lineno,
+ node.annotation.col_offset, node.annotation.end_col_offset
+ )
+ annotation = GNode(name=annotation_string,
+ type="type_annotation")
+ edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset})
+ # do not use reverse edges for types, will result in leak from function to function
+ # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'})
+ return edges, name
+ # return self.generic_parse(node, ["arg", "annotation"])
+
+ def parse_arguments(self, node):
+        # may have missing fields, for example:
+ # arguments(args=[arg(arg='self', annotation=None), arg(arg='tqdm_cls', annotation=None), arg(arg='sleep_interval', annotation=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
+
+        # vararg contains type annotations
+ return self.generic_parse(node, ["args", "vararg"])
+
+ def parse_With(self, node):
+ edges, with_name = self.generic_parse(node, ["items"])
+
+ self.parse_in_context(with_name, "with", edges, node.body)
+
+ return edges, with_name
+
+ def parse_AsyncWith(self, node):
+ return self.parse_With(node)
+
+ def parse_withitem(self, node):
+ return self.generic_parse(node, ['context_expr', 'optional_vars'])
+
+ def parse_ExceptHandler(self, node):
+        # may have missing fields; the "name" field is not parsed
+        # an except handler node is unique for every function
+ return self.generic_parse(node, ["type"])
+
+ def parse_name(self, node):
+ edges = []
+ # if type(node) == ast.Attribute:
+ # left, ext_edges = self.parse(node.value)
+ # right = node.attr
+ # return self.parse(node.value) + "___" + node.attr
+ if type(node) == ast.Name:
+ return self.parse_as_mention(str(node.id))
+ elif type(node) == ast.NameConstant:
+ return GNode(name=str(node.value), type="NameConstant")
+
+ def parse_Name(self, node):
+ return self.parse_name(node)
+
+ def parse_op_name(self, node):
+ return GNode(name=node.__class__.__name__, type="Op")
+ # return node.__class__.__name__
+
+ def parse_Str(self, node):
+ return self.generic_parse(node, [])
+ # return node.s
+
+ def parse_If(self, node):
+
+ edges, if_name = self.generic_parse(node, ["test"])
+
+ self.parse_in_context(if_name, "if_true", edges, node.body)
+ self.parse_in_context(if_name, "if_false", edges, node.orelse)
+
+ return edges, if_name
+
+ def parse_For(self, node):
+
+ edges, for_name = self.generic_parse(node, ["target", "iter"])
+
+ self.parse_in_context(for_name, "for", edges, node.body)
+ self.parse_in_context(for_name, "for_orelse", edges, node.orelse)
+
+ return edges, for_name
+
+ def parse_AsyncFor(self, node):
+ return self.parse_For(node)
+
+ def parse_Try(self, node):
+
+ edges, try_name = self.generic_parse(node, [])
+
+ self.parse_in_context(try_name, "try", edges, node.body)
+
+ for h in node.handlers:
+
+ handler_name, ext_edges = self.parse_operand(h)
+ edges.extend(ext_edges)
+ self.parse_in_context([try_name, handler_name], ["try_except", "try_handler"], edges, h.body)
+
+ self.parse_in_context(try_name, "try_final", edges, node.finalbody)
+ self.parse_in_context(try_name, "try_else", edges, node.orelse)
+
+ return edges, try_name
+
+ def parse_While(self, node):
+
+ edges, while_name = self.generic_parse(node, [])
+
+ cond_name, ext_edges = self.parse_operand(node.test)
+ edges.extend(ext_edges)
+
+ self.parse_in_context([while_name, cond_name], ["while", "if_true"], edges, node.body)
+
+ return edges, while_name
+
+ def parse_Expr(self, node):
+ edges = []
+ expr_name, ext_edges = self.parse_operand(node.value)
+ edges.extend(ext_edges)
+
+ # for cond_name, cons_stat in zip(self.current_condition, self.condition_status):
+ # edges.append({"scope": copy(self.scope[-1]), "src": expr_name, "dst": cond_name, "type": "depends_on_" + cons_stat})
+ return edges, expr_name
+
+
+if __name__ == "__main__":
+ import sys
+ f_bodies = pd.read_pickle(sys.argv[1])
+ failed = 0
+ for ind, c in enumerate(f_bodies['body_normalized']):
+        if isinstance(c, str):
+            # c = """def g():
+            # yield 1"""
+            try:
+                g = AstGraphGenerator(c.lstrip())
+            except SyntaxError as e:
+                # count bodies that could not be parsed and move on
+                print(e)
+                failed += 1
+                continue
+            edges = g.get_edges()
+            # edges.to_csv(os.path.join(os.path.dirname(sys.argv[1]), "body_edges.csv"), mode="a", index=False, header=(ind==0))
+            print("\r%d/%d" % (ind, len(f_bodies['body_normalized'])), end="")
+ else:
+ print("Skipped not a string")
+
+ print(" " * 30, end="\r")
+ print(failed, len(f_bodies['body_normalized']))
diff --git a/SourceCodeTools/mltools/torch/__init__.py b/SourceCodeTools/mltools/torch/__init__.py
new file mode 100644
index 00000000..47e7becd
--- /dev/null
+++ b/SourceCodeTools/mltools/torch/__init__.py
@@ -0,0 +1,5 @@
+import torch
+
+
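+# Fraction of positions where predictions equal labels. A minimal sanity check (hypothetical tensors):
+#   compute_accuracy(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))  # -> 2/3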
+def compute_accuracy(pred_, true_):
+ return torch.sum(pred_ == true_).item() / len(true_)
\ No newline at end of file
diff --git a/SourceCodeTools/models/graph/ElementEmbedder.py b/SourceCodeTools/models/graph/ElementEmbedder.py
index 56b111a7..4ade9c23 100644
--- a/SourceCodeTools/models/graph/ElementEmbedder.py
+++ b/SourceCodeTools/models/graph/ElementEmbedder.py
@@ -8,17 +8,25 @@
from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase
from SourceCodeTools.models.graph.train.Scorer import Scorer
-from SourceCodeTools.models.nlp.Encoder import LSTMEncoder, Encoder
+from SourceCodeTools.models.nlp.TorchEncoder import LSTMEncoder, Encoder
class GraphLinkSampler(ElementEmbedderBase, Scorer):
- def __init__(self, elements, nodes, compact_dst=True, dst_to_global=True, emb_size=None):
+ def __init__(
+ self, elements, nodes, compact_dst=True, dst_to_global=True, emb_size=None, device="cpu",
+ method="inner_prod", nn_index="brute", ns_groups=None
+ ):
assert emb_size is not None
- ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes, compact_dst=compact_dst, dst_to_global=dst_to_global)
- Scorer.__init__(self, num_embs=len(self.elements["dst"].unique()), emb_size=emb_size, src2dst=self.element_lookup)
+ ElementEmbedderBase.__init__(
+ self, elements=elements, nodes=nodes, compact_dst=compact_dst, dst_to_global=dst_to_global
+ )
+ Scorer.__init__(
+ self, num_embs=len(self.elements["dst"].unique()), emb_size=emb_size, src2dst=self.element_lookup,
+ device=device, method=method, index_backend=nn_index, ns_groups=ns_groups
+ )
def sample_negative(self, size, ids=None, strategy="closest"):
- if strategy == "w2v":
+ if strategy == "w2v" or self.scorer_index is None:
negative = ElementEmbedderBase.sample_negative(self, size)
else:
negative = Scorer.sample_closest_negative(self, ids, k=size // len(ids))
@@ -26,10 +34,12 @@ def sample_negative(self, size, ids=None, strategy="closest"):
return negative
-class ElementEmbedder(ElementEmbedderBase, nn.Module):
+class ElementEmbedder(ElementEmbedderBase, nn.Module, Scorer):
def __init__(self, elements, nodes, emb_size, compact_dst=True):
ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes, compact_dst=compact_dst)
nn.Module.__init__(self)
+ Scorer.__init__(self, num_embs=len(self.elements["dst"].unique()), emb_size=emb_size,
+ src2dst=self.element_lookup)
self.emb_size = emb_size
n_elems = self.elements['emb_id'].unique().size
@@ -39,9 +49,30 @@ def __init__(self, elements, nodes, emb_size, compact_dst=True):
def __getitem__(self, ids):
return torch.LongTensor(ElementEmbedderBase.__getitem__(self, ids=ids))
+ def sample_negative(self, size, ids=None, strategy="closest"):
+ # TODO
+ # Try other distributions
+ if strategy == "w2v":
+ negative = ElementEmbedderBase.sample_negative(self, size)
+ else:
+ ### negative = random.choices(Scorer.sample_closest_negative(self, ids), k=size)
+ negative = Scorer.sample_closest_negative(self, ids, k=size // len(ids))
+ assert len(negative) == size
+
+ return torch.LongTensor(negative)
+
def forward(self, input, **kwargs):
return self.norm(self.embed(input))
+ def set_embed(self):
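+        # refresh the cached embedding matrix used for nearest-neighbour queries; prepare_index calls
+        # this before rebuilding the index so that negative sampling scores current embeddings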
+ all_keys = self.get_keys_for_scoring()
+ with torch.set_grad_enabled(False):
+ self.scorer_all_emb = self(torch.LongTensor(all_keys).to(self.embed.weight.device)).detach().cpu().numpy()
+
+ def prepare_index(self):
+ self.set_embed()
+ Scorer.prepare_index(self)
+
# def hashstr(s, num_buckets):
# return int(hashlib.md5(s.encode('utf8')).hexdigest(), 16) % num_buckets
diff --git a/SourceCodeTools/models/graph/ElementEmbedderBase.py b/SourceCodeTools/models/graph/ElementEmbedderBase.py
index 50d829b5..2b2727bd 100644
--- a/SourceCodeTools/models/graph/ElementEmbedderBase.py
+++ b/SourceCodeTools/models/graph/ElementEmbedderBase.py
@@ -16,7 +16,8 @@ def __init__(self, elements, nodes, compact_dst=True, dst_to_global=False):
def init(self, compact_dst):
if compact_dst:
elem2id, self.inverse_dst_map = compact_property(self.elements['dst'], return_order=True)
- self.elements['emb_id'] = self.elements['dst'].apply(lambda x: elem2id[x])
+ self.elements['emb_id'] = self.elements['dst'].apply(lambda x: elem2id.get(x, -1))
+ assert -1 not in self.elements['emb_id'].tolist()
else:
self.elements['emb_id'] = self.elements['dst']
@@ -49,7 +50,11 @@ def preprocess_element_data(self, element_data, nodes, compact_dst, dst_to_globa
id2typedid = dict(zip(nodes['id'].tolist(), nodes['typed_id'].tolist()))
id2type = dict(zip(nodes['id'].tolist(), nodes['type'].tolist()))
- def get_node_pools(element_data):
+ self.id2nodeid = id2nodeid
+ self.id2typedid = id2typedid
+ self.id2type = id2type
+
+ def get_node_pools(element_data): # create typed node list for possible use with dgl
node_typed_pools = {}
for orig_node_id in element_data['src']:
global_id = id2nodeid.get(orig_node_id, None)
@@ -128,9 +133,9 @@ def get_src_pool(self, ntypes=None):
# return {ntype: set(self.elements.query(f"src_type == '{ntype}'")['src_typed_id'].tolist()) for ntype in ntypes}
def _create_pools(self, train_idx, val_idx, test_idx, pool) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
- train_idx = np.fromiter(pool.intersection(train_idx.tolist()), dtype=np.int64)
- val_idx = np.fromiter(pool.intersection(val_idx.tolist()), dtype=np.int64)
- test_idx = np.fromiter(pool.intersection(test_idx.tolist()), dtype=np.int64)
+ train_idx = np.fromiter(pool.intersection(train_idx.reshape((-1,)).tolist()), dtype=np.int64)
+ val_idx = np.fromiter(pool.intersection(val_idx.reshape((-1,)).tolist()), dtype=np.int64)
+ test_idx = np.fromiter(pool.intersection(test_idx.reshape((-1,)).tolist()), dtype=np.int64)
return train_idx, val_idx, test_idx
def create_idx_pools(self, train_idx, val_idx, test_idx):
diff --git a/SourceCodeTools/models/graph/LinkPredictor.py b/SourceCodeTools/models/graph/LinkPredictor.py
index 340803c6..5320e772 100644
--- a/SourceCodeTools/models/graph/LinkPredictor.py
+++ b/SourceCodeTools/models/graph/LinkPredictor.py
@@ -49,21 +49,74 @@ class BilinearLinkPedictor(nn.Module):
def __init__(self, embedding_dim_1, embedding_dim_2, target_classes=2):
super(BilinearLinkPedictor, self).__init__()
- self.bilinear = nn.Bilinear(embedding_dim_1, embedding_dim_2, target_classes)
+ self.l1 = nn.Linear(embedding_dim_1, 300)
+ self.l2 = nn.Linear(embedding_dim_2, 300)
+ self.act = nn.Sigmoid()
+ self.bilinear = nn.Bilinear(300, 300, target_classes)
def forward(self, x1, x2):
- return self.bilinear(x1, x2)
+ return self.bilinear(self.act(self.l1(x1)), self.act(self.l2(x2)))
class CosineLinkPredictor(nn.Module):
- def __init__(self):
+ def __init__(self, margin=0.):
+ """
+        Dummy link predictor, used to keep the API the same
+ """
super(CosineLinkPredictor, self).__init__()
self.cos = nn.CosineSimilarity()
- self.max_margin = torch.Tensor([0.4])
+ self.max_margin = torch.Tensor([margin])[0]
def forward(self, x1, x2):
if self.max_margin.device != x1.device:
self.max_margin = self.max_margin.to(x1.device)
+ # this will not train
logit = (self.cos(x1, x2) > self.max_margin).float().unsqueeze(1)
- return torch.cat([1 - logit, logit], dim=1)
\ No newline at end of file
+ return torch.cat([1 - logit, logit], dim=1)
+
+
+class L2LinkPredictor(nn.Module):
+ def __init__(self, margin=1.):
+ super(L2LinkPredictor, self).__init__()
+ self.margin = torch.Tensor([margin])[0]
+
+ def forward(self, x1, x2):
+ if self.margin.device != x1.device:
+ self.margin = self.margin.to(x1.device)
+ # this will not train
+ logit = (torch.norm(x1 - x2, dim=-1, keepdim=True) < self.margin).float()
+ return torch.cat([1 - logit, logit], dim=1)
+
+
+class TransRLinkPredictor(nn.Module):
+ def __init__(self, input_dim, rel_dim, num_relations, margin=0.3):
+ super(TransRLinkPredictor, self).__init__()
+ self.rel_dim = rel_dim
+ self.input_dim = input_dim
+ self.margin = margin
+
+ self.rel_emb = nn.Embedding(num_embeddings=num_relations, embedding_dim=rel_dim)
+ self.proj_matr = nn.Embedding(num_embeddings=num_relations, embedding_dim=input_dim * rel_dim)
+ self.triplet_loss = nn.TripletMarginLoss(margin=margin)
+
+ def forward(self, a, p, n, labels):
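+        # project anchor, positive and negative embeddings into the relation-specific space
+        # (batched matrix-vector product with the per-relation projection matrices)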
+ weights = self.proj_matr(labels).reshape((-1, self.rel_dim, self.input_dim))
+ rels = self.rel_emb(labels)
+ m_a = (weights * a.unsqueeze(1)).sum(-1)
+ m_p = (weights * p.unsqueeze(1)).sum(-1)
+ m_n = (weights * n.unsqueeze(1)).sum(-1)
+
+ transl = m_a + rels
+
+ sim = torch.norm(torch.cat([transl - m_p, transl - m_n], dim=0), dim=-1)
+
+ # pos_diff = torch.norm(transl - m_p, dim=-1)
+ # neg_diff = torch.norm(transl - m_n, dim=-1)
+
+ # loss = pos_diff + torch.maximum(torch.tensor([0.]).to(neg_diff.device), self.margin - neg_diff)
+ # return loss.mean(), sim
+ return self.triplet_loss(transl, m_p, m_n), sim < self.margin
+
+
+
diff --git a/SourceCodeTools/models/graph/NodeEmbedder.py b/SourceCodeTools/models/graph/NodeEmbedder.py
index c6149723..28ab6015 100644
--- a/SourceCodeTools/models/graph/NodeEmbedder.py
+++ b/SourceCodeTools/models/graph/NodeEmbedder.py
@@ -7,6 +7,9 @@ class NodeEmbedder(nn.Module):
def __init__(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=None):
super(NodeEmbedder, self).__init__()
+ self.init(nodes, emb_size, dtype, n_buckets, pretrained)
+
+ def init(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=None):
self.emb_size = emb_size
self.dtype = dtype
if dtype is None:
@@ -15,11 +18,13 @@ def __init__(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=Non
self.buckets = None
+ embedding_field = "embeddable_name"
+
nodes_with_embeddings = nodes.query("embeddable == True")[
- ['global_graph_id', 'typed_id', 'type', 'type_backup', 'name']
+ ['global_graph_id', 'typed_id', 'type', 'type_backup', embedding_field]
]
- type_name = list(zip(nodes_with_embeddings['type_backup'], nodes_with_embeddings['name']))
+ type_name = list(zip(nodes_with_embeddings['type_backup'], nodes_with_embeddings[embedding_field]))
self.node_info = dict(zip(
list(zip(nodes_with_embeddings['type'], nodes_with_embeddings['typed_id'])),
@@ -39,7 +44,7 @@ def __init__(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=Non
self._create_buckets_from_pretrained(pretrained)
def _create_buckets(self):
- self.buckets = nn.Embedding(self.n_buckets + 1, self.emb_size, padding_idx=self.n_buckets)
+ self.buckets = nn.Embedding(self.n_buckets + 1, self.emb_size, padding_idx=self.n_buckets, sparse=True)
def _create_buckets_from_pretrained(self, pretrained):
@@ -49,7 +54,7 @@ def _create_buckets_from_pretrained(self, pretrained):
weights_with_pad = torch.tensor(np.vstack([pretrained, np.zeros((1, self.emb_size), dtype=np.float32)]))
- self.buckets = nn.Embedding.from_pretrained(weights_with_pad, freeze=False, padding_idx=self.n_buckets)
+ self.buckets = nn.Embedding.from_pretrained(weights_with_pad, freeze=False, padding_idx=self.n_buckets, sparse=True)
def _get_embedding_from_node_info(self, keys, node_info, masked=None):
idxs = []
@@ -93,6 +98,42 @@ def forward(self, node_type=None, node_ids=None, train_embeddings=True, masked=N
return self.get_embeddings(node_type, node_ids.tolist(), masked=masked)
+class NodeIdEmbedder(NodeEmbedder):
+ def __init__(self, nodes=None, emb_size=None, dtype=None, n_buckets=500000, pretrained=None):
+ super(NodeIdEmbedder, self).__init__(nodes, emb_size, dtype, n_buckets, pretrained)
+
+ def init(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=None):
+ self.emb_size = emb_size
+ self.dtype = dtype
+ if dtype is None:
+ self.dtype = torch.float32
+ self.n_buckets = n_buckets
+
+ self.buckets = None
+
+ embedding_field = "embeddable_name"
+
+ nodes_with_embeddings = nodes.query("embeddable == True")[
+ ['global_graph_id', 'typed_id', 'type', 'type_backup', embedding_field]
+ ]
+
+ self.to_global_map = {}
+ for global_graph_id, typed_id, type_, type_backup, name in nodes_with_embeddings.values:
+ if type_ not in self.to_global_map:
+ self.to_global_map[type_] = {}
+
+ self.to_global_map[type_][typed_id] = global_graph_id
+
+ self._create_buckets()
+
+ def get_embeddings(self, node_type=None, node_ids=None, masked=None):
+ assert node_ids is not None
+ if node_type is not None:
+ node_ids = list(map(lambda local_id: self.to_global_map[node_type][local_id], node_ids))
+
+ return self.buckets(torch.LongTensor(node_ids))
+
+
# class SimpleNodeEmbedder(nn.Module):
# def __init__(self, dataset, emb_size, dtype=None, n_buckets=500000, pretrained=None):
# super(SimpleNodeEmbedder, self).__init__()
diff --git a/SourceCodeTools/models/graph/basis_gatconv.py b/SourceCodeTools/models/graph/basis_gatconv.py
new file mode 100644
index 00000000..641f7ba9
--- /dev/null
+++ b/SourceCodeTools/models/graph/basis_gatconv.py
@@ -0,0 +1,237 @@
+from dgl import DGLError
+from dgl.nn.functional import edge_softmax
+from dgl.nn.pytorch import GATConv
+from dgl.nn.pytorch.utils import Identity
+from dgl.utils import expand_as_pair
+import dgl.function as fn
+from torch import nn, softmax
+import torch as th
+from torch.utils import checkpoint
+
+
+class BasisGATConv(GATConv):
+ """
+    Does not seem to reduce memory requirements
+ """
+ def __init__(self,
+ in_feats,
+ out_feats,
+ num_heads,
+ basis,
+ attn_basis,
+ basis_coef,
+ feat_drop=0.,
+ attn_drop=0.,
+ negative_slope=0.2,
+ residual=False,
+ activation=None,
+ allow_zero_in_degree=False,
+ bias=True,
+ use_checkpoint=False):
+ super(GATConv, self).__init__()
+ self._basis = basis
+ self._basis_coef = basis_coef
+ self._attn_basis = attn_basis
+
+ self._num_heads = num_heads
+ self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
+ self._out_feats = out_feats
+ self._allow_zero_in_degree = allow_zero_in_degree
+ # if isinstance(in_feats, tuple):
+ # self.fc_src = nn.Linear(
+ # self._in_src_feats, out_feats * num_heads, bias=False)
+ # self.fc_dst = nn.Linear(
+ # self._in_dst_feats, out_feats * num_heads, bias=False)
+ # else:
+ # self.fc = nn.Linear(
+ # self._in_src_feats, out_feats * num_heads, bias=False)
+ # self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
+ # self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
+ self.feat_drop = nn.Dropout(feat_drop)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.leaky_relu = nn.LeakyReLU(negative_slope)
+ if bias:
+ self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_feats,)))
+ else:
+ self.register_buffer('bias', None)
+ if residual:
+ if self._in_dst_feats != out_feats:
+ self.res_fc = nn.Linear(
+ self._in_dst_feats, num_heads * out_feats, bias=False)
+ else:
+ self.res_fc = Identity()
+ else:
+ self.register_buffer('res_fc', None)
+ self.reset_parameters()
+ self.activation = activation
+
+ self.dummy_tensor = th.ones(1, dtype=th.float32, requires_grad=True)
+ self.use_checkpoint = use_checkpoint
+
+ def reset_parameters(self):
+ """
+
+ Description
+ -----------
+ Reinitialize learnable parameters.
+
+ Note
+ ----
+ The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
+        The attention weights are initialized using the Xavier method.
+ """
+ gain = nn.init.calculate_gain('relu')
+ # if hasattr(self, 'fc'):
+ # nn.init.xavier_normal_(self.fc.weight, gain=gain)
+ # else:
+ # nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
+ # nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
+ # nn.init.xavier_normal_(self.attn_l, gain=gain)
+ # nn.init.xavier_normal_(self.attn_r, gain=gain)
+ nn.init.constant_(self.bias, 0)
+ if isinstance(self.res_fc, nn.Linear):
+ nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
+
+ def set_allow_zero_in_degree(self, set_value):
+ r"""
+
+ Description
+ -----------
+ Set allow_zero_in_degree flag.
+
+ Parameters
+ ----------
+ set_value : bool
+ The value to be set to the flag.
+ """
+ self._allow_zero_in_degree = set_value
+
+
+ def _forward(self, graph, feat, get_attention=False):
+ r"""
+
+ Description
+ -----------
+ Compute graph attention network layer.
+
+ Parameters
+ ----------
+ graph : DGLGraph
+ The graph.
+ feat : torch.Tensor or pair of torch.Tensor
+ If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
+ :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
+ If a pair of torch.Tensor is given, the pair must contain two tensors of shape
+ :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+ get_attention : bool, optional
+ Whether to return the attention values. Default to False.
+
+ Returns
+ -------
+ torch.Tensor
+ The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+ is the number of heads, and :math:`D_{out}` is size of output feature.
+ torch.Tensor, optional
+ The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+ edges. This is returned only when :attr:`get_attention` is ``True``.
+
+ Raises
+ ------
+ DGLError
+ If there are 0-in-degree nodes in the input graph, it will raise DGLError
+ since no message will be passed to those nodes. This will cause invalid output.
+ The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
+ """
+ with graph.local_scope():
+ if not self._allow_zero_in_degree:
+ if (graph.in_degrees() == 0).any():
+ raise DGLError('There are 0-in-degree nodes in the graph, '
+ 'output for those nodes will be invalid. '
+ 'This is harmful for some applications, '
+ 'causing silent performance regression. '
+ 'Adding self-loop on the input graph by '
+ 'calling `g = dgl.add_self_loop(g)` will resolve '
+ 'the issue. Setting ``allow_zero_in_degree`` '
+ 'to be `True` when constructing this module will '
+ 'suppress the check and let the code run.')
+
+ if isinstance(feat, tuple):
+ h_src = self.feat_drop(feat[0])
+ h_dst = self.feat_drop(feat[1])
+ basis_coef = softmax(self._basis_coef, dim=-1).reshape(-1, 1, 1)
+ # if not hasattr(self, 'fc_src'):
+ params_src = (self._basis[0] * basis_coef).sum(dim=0)
+ params_dst = (self._basis[1] * basis_coef).sum(dim=0)
+ feat_src = (params_src @ h_src.T).view(-1, self._num_heads, self._out_feats)
+ feat_dst = (params_dst @ h_dst.T).view(-1, self._num_heads, self._out_feats)
+ # # feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
+ # # feat_dst = self.fc(h_dst).view(-1, self._num_heads, self._out_feats)
+ # else:
+ # params = self._basis * basis_coef
+ # feat_src = (params @ h_src.T).view(-1, self._num_heads, self._out_feats)
+ # feat_dst = (params @ h_dst.T).view(-1, self._num_heads, self._out_feats)
+ # # feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
+ # # feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
+ else:
+ h_src = h_dst = self.feat_drop(feat)
+ basis_coef = softmax(self._basis_coef, dim=-1).reshape(-1, 1, 1)
+ params = (self._basis * basis_coef).sum(dim=0)
+ feat_src = feat_dst = (params @ h_src.T).view(-1, self._num_heads, self._out_feats)
+ # feat_src = feat_dst = self.fc(h_src).view(
+ # -1, self._num_heads, self._out_feats)
+ if graph.is_block:
+ feat_dst = feat_src[:graph.number_of_dst_nodes()]
+ # NOTE: GAT paper uses "first concatenation then linear projection"
+ # to compute attention scores, while ours is "first projection then
+ # addition", the two approaches are mathematically equivalent:
+ # We decompose the weight vector a mentioned in the paper into
+ # [a_l || a_r], then
+ # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
+            # Our implementation is much more efficient because we do not need to
+ # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
+ # addition could be optimized with DGL's built-in function u_add_v,
+ # which further speeds up computation and saves memory footprint.
+ attn_l_param = (self._attn_basis[0] * basis_coef).sum(dim=0)
+ attn_r_param = (self._attn_basis[1] * basis_coef).sum(dim=0)
+ el = (feat_src * attn_l_param).sum(dim=-1).unsqueeze(-1)
+ er = (feat_dst * attn_r_param).sum(dim=-1).unsqueeze(-1)
+ # el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
+ # er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
+ graph.srcdata.update({'ft': feat_src, 'el': el})
+ graph.dstdata.update({'er': er})
+ # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
+ graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
+ e = self.leaky_relu(graph.edata.pop('e'))
+ # compute softmax
+ graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
+ # message passing
+ graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
+ fn.sum('m', 'ft'))
+ rst = graph.dstdata['ft']
+ # residual
+ if self.res_fc is not None:
+ resval = self.res_fc(h_dst).view(h_dst.shape[0], self._num_heads, self._out_feats)
+ rst = rst + resval
+ # bias
+ if self.bias is not None:
+ rst = rst + self.bias.view(1, self._num_heads, self._out_feats)
+ # activation
+ if self.activation:
+ rst = self.activation(rst)
+
+ if get_attention:
+ return rst, graph.edata['a']
+ else:
+ return rst
+
+ def custom(self, graph, get_attention):
+ def custom_forward(*inputs):
+ feat0, feat1, dummy = inputs
+ return self._forward(graph, (feat0, feat1), get_attention=get_attention)
+ return custom_forward
+
+ def forward(self, graph, feat, get_attention=False):
+ if self.use_checkpoint:
+ return checkpoint.checkpoint(self.custom(graph, get_attention), feat[0], feat[1], self.dummy_tensor).squeeze(1)
+ else:
+ return self._forward(graph, feat, get_attention=get_attention).squeeze(1)
\ No newline at end of file
diff --git a/SourceCodeTools/models/graph/rgcn_sampling.py b/SourceCodeTools/models/graph/rgcn_sampling.py
index 42ed4536..ae71a788 100644
--- a/SourceCodeTools/models/graph/rgcn_sampling.py
+++ b/SourceCodeTools/models/graph/rgcn_sampling.py
@@ -7,6 +7,7 @@
import dgl.nn as dglnn
# import tqdm
from torch.utils import checkpoint
+from tqdm import tqdm
from SourceCodeTools.models.Embedder import Embedder
@@ -38,9 +39,9 @@ def custom_forward(*inputs):
def forward(self, graph, feat):
if self.use_checkpoint:
- return checkpoint.checkpoint(self.custom(graph), feat[0], feat[1], self.dummy_tensor)
+ return checkpoint.checkpoint(self.custom(graph), feat[0], feat[1], self.dummy_tensor) #.squeeze(1)
else:
- return super(CkptGATConv, self).forward(graph, feat)
+ return super(CkptGATConv, self).forward(graph, feat) #.squeeze(1)
class RelGraphConvLayer(nn.Module):
@@ -71,6 +72,7 @@ def __init__(self,
in_feat,
out_feat,
rel_names,
+ ntype_names,
num_bases,
*,
weight=True,
@@ -82,18 +84,13 @@ def __init__(self,
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
+ self.ntype_names = ntype_names
self.num_bases = num_bases
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.use_gcn_checkpoint = use_gcn_checkpoint
- # TODO
- # think of possibility switching to GAT
- # rel : dglnn.GATConv(in_feat, out_feat, num_heads=4)
- # rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False, allow_zero_in_degree=True)
- self.create_conv(in_feat, out_feat, rel_names)
-
self.use_weight = weight
self.use_basis = num_bases < len(self.rel_names) and weight
if self.use_weight:
@@ -104,17 +101,33 @@ def __init__(self,
# nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_normal_(self.weight)
+ # TODO
+ # think of possibility switching to GAT
+ # rel : dglnn.GATConv(in_feat, out_feat, num_heads=4)
+ # rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False, allow_zero_in_degree=True)
+ self.create_conv(in_feat, out_feat, rel_names)
+
# bias
if bias:
- self.h_bias = nn.Parameter(th.Tensor(out_feat))
- nn.init.zeros_(self.h_bias)
+ self.bias_dict = nn.ParameterDict()
+ for ntype_name in self.ntype_names:
+ self.bias_dict[ntype_name] = nn.Parameter(torch.Tensor(1, out_feat))
+ nn.init.normal_(self.bias_dict[ntype_name])
+ # self.h_bias = nn.Parameter(th.Tensor(1, out_feat))
+ # nn.init.normal_(self.h_bias)
# weight for self loop
if self.self_loop:
+ # if self.use_basis:
+ # self.loop_weight_basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.ntype_names))
+ # else:
+ # self.loop_weight = nn.Parameter(th.Tensor(len(self.ntype_names), in_feat, out_feat))
+ # # nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
+ # nn.init.xavier_normal_(self.loop_weight)
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
- # nn.init.xavier_uniform_(self.loop_weight,
- # gain=nn.init.calculate_gain('relu'))
- nn.init.xavier_normal_(self.loop_weight)
+ nn.init.xavier_uniform_(self.loop_weight,
+ gain=nn.init.calculate_gain('tanh'))
+ # # nn.init.xavier_normal_(self.loop_weight)
self.dropout = nn.Dropout(dropout)
@@ -145,6 +158,10 @@ def forward(self, g, inputs):
weight = self.basis() if self.use_basis else self.weight
wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))}
+ # if self.self_loop:
+ # self_loop_weight = self.loop_weight_basis() if self.use_basis else self.loop_weight
+ # self_loop_wdict = {self.ntype_names[i]: w.squeeze(0)
+ # for i, w in enumerate(th.split(self_loop_weight, 1, dim=0))}
else:
wdict = {}
@@ -160,9 +177,11 @@ def forward(self, g, inputs):
def _apply(ntype, h):
if self.self_loop:
+ # h = h + th.matmul(inputs_dst[ntype], self_loop_wdict[ntype])
h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
if self.bias:
- h = h + self.h_bias
+ h = h + self.bias_dict[ntype]
+ # h = h + self.h_bias
if self.activation:
h = self.activation(h)
return self.dropout(h)
@@ -171,47 +190,47 @@ def _apply(ntype, h):
# return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
return {ntype : _apply(ntype, h).mean(1) for ntype, h in hs.items()}
-class RelGraphEmbed(nn.Module):
- r"""Embedding layer for featureless heterograph."""
- def __init__(self,
- g,
- embed_size,
- embed_name='embed',
- activation=None,
- dropout=0.0):
- super(RelGraphEmbed, self).__init__()
- self.g = g
- self.embed_size = embed_size
- self.embed_name = embed_name
- # self.activation = activation
- # self.dropout = nn.Dropout(dropout)
-
- # create weight embeddings for each node for each relation
- self.embeds = nn.ParameterDict()
- for ntype in g.ntypes:
- embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
- # TODO
- # watch for activation in init
- # nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
- nn.init.xavier_normal_(embed)
- self.embeds[ntype] = embed
-
- def forward(self, block=None):
- """Forward computation
-
- Parameters
- ----------
- block : DGLHeteroGraph, optional
- If not specified, directly return the full graph with embeddings stored in
- :attr:`embed_name`. Otherwise, extract and store the embeddings to the block
- graph and return.
-
- Returns
- -------
- DGLHeteroGraph
- The block graph fed with embeddings.
- """
- return self.embeds
+# class RelGraphEmbed(nn.Module):
+# r"""Embedding layer for featureless heterograph."""
+# def __init__(self,
+# g,
+# embed_size,
+# embed_name='embed',
+# activation=None,
+# dropout=0.0):
+# super(RelGraphEmbed, self).__init__()
+# self.g = g
+# self.embed_size = embed_size
+# self.embed_name = embed_name
+# # self.activation = activation
+# # self.dropout = nn.Dropout(dropout)
+#
+# # create weight embeddings for each node for each relation
+# self.embeds = nn.ParameterDict()
+# for ntype in g.ntypes:
+# embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
+# # TODO
+# # watch for activation in init
+# # nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
+# nn.init.xavier_normal_(embed)
+# self.embeds[ntype] = embed
+#
+# def forward(self, block=None):
+# """Forward computation
+#
+# Parameters
+# ----------
+# block : DGLHeteroGraph, optional
+# If not specified, directly return the full graph with embeddings stored in
+# :attr:`embed_name`. Otherwise, extract and store the embeddings to the block
+# graph and return.
+#
+# Returns
+# -------
+# DGLHeteroGraph
+# The block graph fed with embeddings.
+# """
+# return self.embeds
class RGCNSampling(nn.Module):
def __init__(self,
@@ -231,6 +250,8 @@ def __init__(self,
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
+        self.ntype_names = list(set(g.ntypes))
+ self.ntype_names.sort()
if num_bases < 0 or num_bases > len(self.rel_names):
self.num_bases = len(self.rel_names)
else:
@@ -244,14 +265,14 @@ def __init__(self,
self.layer_norm = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvLayer(
- self.h_dim, self.h_dim, self.rel_names,
+ self.h_dim, self.h_dim, self.rel_names, self.ntype_names,
self.num_bases, activation=self.activation, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint))
self.layer_norm.append(nn.LayerNorm([self.h_dim]))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvLayer(
- self.h_dim, self.h_dim, self.rel_names,
+ self.h_dim, self.h_dim, self.rel_names, self.ntype_names,
self.num_bases, activation=self.activation, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint)) # changed weight for GATConv
self.layer_norm.append(nn.LayerNorm([self.h_dim]))
@@ -260,7 +281,7 @@ def __init__(self,
# weight=False
# h2o
self.layers.append(RelGraphConvLayer(
- self.h_dim, self.out_dim, self.rel_names,
+ self.h_dim, self.out_dim, self.rel_names, self.ntype_names,
self.num_bases, activation=None,
self_loop=self.use_self_loop, weight=False, use_gcn_checkpoint=use_gcn_checkpoint)) # changed weight for GATConv
self.layer_norm.append(nn.LayerNorm([self.out_dim]))
@@ -314,9 +335,10 @@ def forward(self, h, blocks=None,
# all_layers.append(h) # added this as an experimental feature for intermediate supervision
# else:
# minibatch training
+ h0 = h
for layer, norm, block in zip(self.layers, self.layer_norm, blocks):
# h = checkpoint.checkpoint(self.custom(layer), block, h)
- h = layer(block, h)
+ h = layer(block, h, h0)
h = self.normalize(h, norm)
all_layers.append(h) # added this as an experimental feature for intermediate supervision
@@ -332,6 +354,7 @@ def inference(self, batch_size, device, num_workers, x=None):
For node classification, the model is trained to predict on only one node type's
label. Therefore, only that type's final representation is meaningful.
"""
+ h0 = x
with th.set_grad_enabled(False):
@@ -355,7 +378,7 @@ def inference(self, batch_size, device, num_workers, x=None):
drop_last=False,
num_workers=num_workers)
- for input_nodes, output_nodes, blocks in dataloader:#tqdm.tqdm(dataloader):
+ for input_nodes, output_nodes, blocks in tqdm(dataloader, desc=f"Layer {l}"):
block = blocks[0].to(device)
if not isinstance(input_nodes, dict):
@@ -363,8 +386,9 @@ def inference(self, batch_size, device, num_workers, x=None):
input_nodes = {key: input_nodes}
output_nodes = {key: output_nodes}
+ _h0 = {k: h0[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
- h = layer(block, h)
+ h = layer(block, h, _h0)
h = self.normalize(h, norm)
for k in h.keys():
diff --git a/SourceCodeTools/models/graph/rggan.py b/SourceCodeTools/models/graph/rggan.py
index ec5fc160..dd6a2654 100644
--- a/SourceCodeTools/models/graph/rggan.py
+++ b/SourceCodeTools/models/graph/rggan.py
@@ -1,5 +1,7 @@
+import torch
from torch.utils import checkpoint
+# from SourceCodeTools.models.graph.basis_gatconv import BasisGATConv
from SourceCodeTools.models.graph.rgcn_sampling import RGCNSampling, RelGraphConvLayer, CkptGATConv
import torch as th
@@ -37,7 +39,7 @@ def do_stuff(self, query, key, value, dummy=None):
def forward(self, list_inputs, dsttype): # pylint: disable=unused-argument
if len(list_inputs) == 1:
return list_inputs[0]
- key = value = th.stack(list_inputs).squeeze(dim=1)
+ key = value = th.stack(list_inputs)#.squeeze(dim=1)
query = self.query_emb(th.LongTensor([token_hasher(dsttype, self.num_query_buckets)]).to(self.att.in_proj_bias.device)).unsqueeze(0).repeat(1, key.shape[1], 1)
# query = self.query_emb[token_hasher(dsttype, self.num_query_buckets)].unsqueeze(0).repeat(1, key.shape[1], 1)
if self.use_checkpoint:
@@ -45,14 +47,42 @@ def forward(self, list_inputs, dsttype): # pylint: disable=unused-argument
else:
att_out, att_w = self.do_stuff(query, key, value)
# att_out, att_w = self.att(query, key, value)
+ # return att_out.mean(0)#.unsqueeze(1)
return att_out.mean(0).unsqueeze(1)
+from torch.nn import init
+class NZBiasGraphConv(dglnn.GraphConv):
+ def __init__(self, *args, **kwargs):
+ super(NZBiasGraphConv, self).__init__(*args, **kwargs)
+ self.use_checkpoint = True
+ self.dummy_tensor = th.ones(1, dtype=th.float32, requires_grad=True)
+
+ def reset_parameters(self):
+ if self.weight is not None:
+ init.xavier_uniform_(self.weight)
+ if self.bias is not None:
+ init.normal_(self.bias)
+
+ def custom(self, graph):
+ def custom_forward(*inputs):
+ feat0, feat1, weight, _ = inputs
+ return super(NZBiasGraphConv, self).forward(graph, (feat0, feat1), weight=weight)
+ return custom_forward
+
+ def forward(self, graph, feat, weight=None):
+ if self.use_checkpoint:
+ return checkpoint.checkpoint(self.custom(graph), feat[0], feat[1], weight, self.dummy_tensor) #.squeeze(1)
+ else:
+ return super(NZBiasGraphConv, self).forward(graph, feat, weight=weight) #.squeeze(1)
+
+
class RGANLayer(RelGraphConvLayer):
def __init__(self,
in_feat,
out_feat,
rel_names,
+ ntype_names,
num_bases,
*,
weight=True,
@@ -62,17 +92,40 @@ def __init__(self,
dropout=0.0, use_gcn_checkpoint=False, use_att_checkpoint=False):
self.use_att_checkpoint = use_att_checkpoint
super(RGANLayer, self).__init__(
- in_feat, out_feat, rel_names, num_bases,
+ in_feat, out_feat, rel_names, ntype_names, num_bases,
weight=weight, bias=bias, activation=activation, self_loop=self_loop, dropout=dropout,
use_gcn_checkpoint=use_gcn_checkpoint
)
def create_conv(self, in_feat, out_feat, rel_names):
- self.attentive_aggregator = AttentiveAggregator(out_feat, use_checkpoint=self.use_att_checkpoint)
+ # self.attentive_aggregator = AttentiveAggregator(out_feat, use_checkpoint=self.use_att_checkpoint)
+ #
+ # num_heads = 1
+ # basis_size = 10
+ # basis = torch.nn.Parameter(torch.Tensor(2, basis_size, in_feat, out_feat * num_heads))
+ # attn_basis = torch.nn.Parameter(torch.Tensor(2, basis_size, num_heads, out_feat))
+ # basis_coef = nn.ParameterDict({rel: torch.nn.Parameter(torch.rand(basis_size,)) for rel in rel_names})
+ #
+ # torch.nn.init.xavier_normal_(basis, gain=1.)
+ # torch.nn.init.xavier_normal_(attn_basis, gain=1.)
+
self.conv = dglnn.HeteroGraphConv({
- rel: CkptGATConv((in_feat, in_feat), out_feat, num_heads=1, use_checkpoint=self.use_gcn_checkpoint)
+ rel: NZBiasGraphConv(
+ in_feat, out_feat, norm='right', weight=False, bias=True, allow_zero_in_degree=True,# activation=self.activation
+ )
+ # rel: BasisGATConv(
+ # (in_feat, in_feat), out_feat, num_heads=num_heads,
+ # basis=basis,
+ # attn_basis=attn_basis,
+ # basis_coef=basis_coef[rel],
+ # use_checkpoint=self.use_gcn_checkpoint
+ # )
for rel in rel_names
- }, aggregate=self.attentive_aggregator)
+ }, aggregate="mean") #self.attentive_aggregator)
+ # self.conv = dglnn.HeteroGraphConv({
+ # rel: CkptGATConv((in_feat, in_feat), out_feat, num_heads=num_heads, use_checkpoint=self.use_gcn_checkpoint)
+ # for rel in rel_names
+ # }, aggregate="mean") #self.attentive_aggregator)
class RGAN(RGCNSampling):
@@ -92,6 +145,8 @@ def __init__(self,
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
+ self.ntype_names = list(set(g.ntypes))
+ self.ntype_names.sort()
if num_bases < 0 or num_bases > len(self.rel_names):
self.num_bases = len(self.rel_names)
else:
@@ -103,14 +158,14 @@ def __init__(self,
self.layers = nn.ModuleList()
# i2h
self.layers.append(RGANLayer(
- self.h_dim, self.h_dim, self.rel_names,
+ self.h_dim, self.h_dim, self.rel_names, self.ntype_names,
self.num_bases, activation=self.activation, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint,
use_att_checkpoint=use_att_checkpoint))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RGANLayer(
- self.h_dim, self.h_dim, self.rel_names,
+ self.h_dim, self.h_dim, self.rel_names, self.ntype_names,
self.num_bases, activation=self.activation, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint,
use_att_checkpoint=use_att_checkpoint)) # changed weight for GATConv
@@ -119,7 +174,7 @@ def __init__(self,
# weight=False
# h2o
self.layers.append(RGANLayer(
- self.h_dim, self.out_dim, self.rel_names,
+ self.h_dim, self.out_dim, self.rel_names, self.ntype_names,
self.num_bases, activation=None,
self_loop=self.use_self_loop, weight=False, use_gcn_checkpoint=use_gcn_checkpoint,
use_att_checkpoint=use_att_checkpoint)) # changed weight for GATConv
@@ -147,6 +202,7 @@ def __init__(self, dim, use_checkpoint=False):
self.dummy_tensor = th.ones(1, dtype=th.float32, requires_grad=True)
def do_stuff(self, x, h, dummy_tensor=None):
+ # x = x.unsqueeze(1)
r = self.act_r(self.gru_rx(x) + self.gru_rh(h))
z = self.act_z(self.gru_zx(x) + self.gru_zh(h))
n = self.act_n(self.gru_nx(x) + self.gru_nh(r * h))
@@ -165,6 +221,7 @@ def __init__(self,
in_feat,
out_feat,
rel_names,
+ ntype_names,
num_bases,
*,
weight=True,
@@ -173,14 +230,15 @@ def __init__(self,
self_loop=False,
dropout=0.0, use_gcn_checkpoint=False, use_att_checkpoint=False, use_gru_checkpoint=False):
super(RGGANLayer, self).__init__(
- in_feat, out_feat, rel_names, num_bases, weight=weight, bias=bias, activation=activation,
+ in_feat, out_feat, rel_names, ntype_names, num_bases, weight=weight, bias=bias, activation=activation,
self_loop=self_loop, dropout=dropout, use_gcn_checkpoint=use_gcn_checkpoint,
use_att_checkpoint=use_att_checkpoint
)
+ # self.mix_weights = nn.Parameter(torch.randn(3).reshape((3,1,1)))
- self.gru = OneStepGRU(out_feat, use_checkpoint=use_gru_checkpoint)
+ # self.gru = OneStepGRU(out_feat, use_checkpoint=use_gru_checkpoint)
- def forward(self, g, inputs):
+ def forward(self, g, inputs, h0):
"""Forward computation
Parameters
@@ -200,6 +258,10 @@ def forward(self, g, inputs):
weight = self.basis() if self.use_basis else self.weight
wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))}
+ # if self.self_loop:
+ # self_loop_weight = self.loop_weight_basis() if self.use_basis else self.loop_weight
+ # self_loop_wdict = {self.ntype_names[i]: w.squeeze(0)
+ # for i, w in enumerate(th.split(self_loop_weight, 1, dim=0))}
else:
wdict = {}
@@ -215,9 +277,15 @@ def forward(self, g, inputs):
def _apply(ntype, h):
if self.self_loop:
+ # h = h + th.matmul(inputs_dst[ntype], self_loop_wdict[ntype])
h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
+ # mix = nn.functional.softmax(self.mix_weights, dim=0)
+ # h = h - inputs_dst[ntype] + th.matmul(inputs_dst[ntype], self.loop_weight) #+ h0[ntype][:h.size(0), :]
+ # h = h + th.matmul(inputs_dst[ntype], self.loop_weight) + h0[ntype][:h.size(0), :]
+ # h = (torch.stack([h, th.matmul(inputs_dst[ntype], self.loop_weight), h0[ntype][:h.size(0), :]], dim=0) * mix).sum(0)
if self.bias:
- h = h + self.h_bias
+ h = h + self.bias_dict[ntype]
+ # h = h + self.h_bias
if self.activation:
h = self.activation(h)
return self.dropout(h)
@@ -226,10 +294,10 @@ def _apply(ntype, h):
# TODO
# think of possibility switching to GAT
- # return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
- h_gru_input = {ntype : _apply(ntype, h) for ntype, h in hs.items()}
-
- return {dsttype: self.gru(h_dst, inputs_dst[dsttype].unsqueeze(1)).squeeze(dim=1) for dsttype, h_dst in h_gru_input.items()}
+ return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
+ # h_gru_input = {ntype : _apply(ntype, h) for ntype, h in hs.items()}
+ #
+ # return {dsttype: self.gru(h_dst, inputs_dst[dsttype].unsqueeze(1)).squeeze(dim=1) for dsttype, h_dst in h_gru_input.items()}
class RGGAN(RGAN):
"""A gated recurrent unit (GRU) cell
@@ -262,7 +330,9 @@ def __init__(self,
assert h_dim == num_classes, f"Parameter h_dim and num_classes should be equal in {self.__class__.__name__}"
self.rel_names = list(set(g.etypes))
+ self.ntype_names = list(set(g.ntypes))
self.rel_names.sort()
+ self.ntype_names.sort()
if num_bases < 0 or num_bases > len(self.rel_names):
self.num_bases = len(self.rel_names)
else:
@@ -273,9 +343,9 @@ def __init__(self,
# i2h
self.layer = RGGANLayer(
- self.h_dim, self.h_dim, self.rel_names,
+ self.h_dim, self.h_dim, self.rel_names, self.ntype_names,
self.num_bases, activation=self.activation, self_loop=self.use_self_loop,
- dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint,
+ dropout=self.dropout, weight=True, use_gcn_checkpoint=use_gcn_checkpoint, # : )
use_att_checkpoint=use_att_checkpoint, use_gru_checkpoint=use_gru_checkpoint
)
# TODO
diff --git a/SourceCodeTools/models/graph/train/Scorer.py b/SourceCodeTools/models/graph/train/Scorer.py
index 11fb736f..eba46e3a 100644
--- a/SourceCodeTools/models/graph/train/Scorer.py
+++ b/SourceCodeTools/models/graph/train/Scorer.py
@@ -1,47 +1,161 @@
import random
-from collections import Iterable
+import time
+from collections import Iterable, defaultdict
+from typing import Dict, List
import torch
import numpy as np
-from sklearn.metrics import ndcg_score
+from sklearn.metrics import ndcg_score, top_k_accuracy_score
+from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors._ball_tree import BallTree
from sklearn.preprocessing import normalize
class FaissIndex:
- def __init__(self, X, *args, **kwargs):
+ def __init__(self, X, method="inner_prod", *args, **kwargs):
import faiss
- self.index = faiss.IndexFlatL2(X.shape[1])
+ self.method = method
+ if method == "inner_prod":
+ self.index = faiss.IndexFlatIP(X.shape[1])
+ elif method == "l2":
+ self.index = faiss.IndexFlatL2(X.shape[1])
+ else:
+ raise NotImplementedError()
self.index.add(X.astype(np.float32))
- def query(self, X, k):
+ def query_inner_prod(self, X, k):
+ X = normalize(X, axis=1)
return self.index.search(X.astype(np.float32), k=k)
+ def query_l2(self, X, k):
+ return self.index.search(X.astype(np.float32), k=k)
+
+ def query(self, X, k):
+ if self.method == "inner_prod":
+ return self.query_inner_prod(X, k)
+ elif self.method == "l2":
+ return self.query_l2(X, k)
+ else:
+ raise NotImplementedError()
+
+
+class Brute:
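+    """
+    Exhaustive nearest-neighbour search over the full embedding matrix; exposes the same
+    query(X, k) interface as FaissIndex so the two backends are interchangeable.
+    """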
+ def __init__(self, X, method="inner_prod", device="cpu", *args, **kwargs):
+ self.vectors = X
+ self.method = method
+ self.device = device
+
+ def query_inner_prod(self, X, k):
+ X = torch.Tensor(normalize(X, axis=1)).to(self.device)
+ vectors = torch.Tensor(self.vectors).to(self.device)
+ score = (vectors @ X.T).reshape(-1,).to("cpu").numpy()
+ ind = np.flip(np.argsort(score))[:k]
+ return score[ind][:k], ind
+
+ def query_l2(self, X, k):
+ vectors = torch.Tensor(self.vectors).to(self.device)
+ X = torch.Tensor(X).to(self.device)
+ score = torch.norm(vectors - X, dim=-1).to("cpu").numpy()
+ # score = np.linalg.norm(vectors - X).reshape(-1,)
+ ind = np.argsort(score)[:k]
+ return score[ind][:k], ind
+
+ def query(self, X, k):
+ assert X.shape[0] == 1
+ if self.method == "inner_prod":
+ dist, ind = self.query_inner_prod(X, k)
+ elif self.method == "l2":
+ dist, ind = self.query_l2(X, k)
+ else:
+ raise NotImplementedError()
+
+ return dist, ind#.reshape((-1,1))
+
class Scorer:
"""
- Implements sampler for hard triplet loss. This sampler is useful when the loss is based on the neighbourhood
+ Implements sampler for triplet loss. This sampler is useful when the loss is based on the neighbourhood
     similarity. It becomes less useful when the decision is made by a neural network, because it does not need to
     move points to learn how to make correct decisions.
"""
- def __init__(self, num_embs, emb_size, src2dst, neighbours_to_sample=5, index_backend="sklearn"):
+ def __init__(
+ self, num_embs, emb_size, src2dst: Dict[int, List[int]], neighbours_to_sample=5, index_backend="brute",
+ method = "inner_prod", device="cpu", ns_groups=None
+ ):
+ """
+        Creates an embedding table whose entries are updated once per epoch. Embeddings from this
+        table are used for nearest-neighbour queries during negative sampling. We avoid keeping track of all
+        possible embeddings because only a subset of them is eligible as DST.
+        :param num_embs: number of unique DST
+        :param emb_size: embedding dimensionality
+        :param src2dst: mapping from SRC to all DST, needed to find the hardest negative examples for all DST at once
+        :param neighbours_to_sample: default number of neighbours
+        :param index_backend: one of sklearn|faiss|brute
+        :param method: similarity used for queries and scoring, "inner_prod" or "l2" ("nn" skips building the index)
+        :param device: device used by the brute-force backend
+        :param ns_groups: optional (node, group) pairs used to restrict negative sampling to nodes from the same group
+ """
self.scorer_num_emb = num_embs
self.scorer_emb_size = emb_size
- self.scorer_src2dst = src2dst
+ self.scorer_src2dst = src2dst # mapping from src to all possible dst
self.scorer_index_backend = index_backend
+ self.scorer_method = method
+ self.scorer_device = device
- self.scorer_all_emb = np.ones((num_embs, emb_size))
+ self.scorer_all_emb = normalize(np.ones((num_embs, emb_size)), axis=1) # unique dst embedding table
self.scorer_all_keys = self.get_cand_to_score_against(None)
self.scorer_key_order = dict(zip(self.scorer_all_keys, range(len(self.scorer_all_keys))))
self.scorer_index = None
self.neighbours_to_sample = min(neighbours_to_sample, self.scorer_num_emb)
+ self.prepare_ns_groups(ns_groups)
+
+ def prepare_ns_groups(self, ns_groups):
+ if ns_groups is None:
+ return
+
+ self.scorer_node2ns_group = {}
+ self.scorer_ns_group2nodes = defaultdict(list)
+
+ unique_dst = set()
+ for dsts in self.scorer_src2dst.values():
+ for dst in dsts:
+ if isinstance(dst, tuple):
+ unique_dst.add(dst[0])
+ else:
+ unique_dst.add(dst)
+
+ for id, mentioned_in in ns_groups.values:
+ id_ = self.id2nodeid[id]
+ mentioned_in_ = self.id2nodeid[mentioned_in]
+ self.scorer_node2ns_group[id_] = mentioned_in_
+ if id_ in unique_dst:
+ self.scorer_ns_group2nodes[mentioned_in_].append(id_)
+
+ def sample_negative_from_groups(self, key_groups, k):
+ possible_targets = []
+ for key_group in key_groups:
+ any_key = key_group[0]
+ possible_negative = self.scorer_ns_group2nodes[self.scorer_node2ns_group[any_key]]
+ possible_negative_ = list(set(possible_negative) - set(key_group))
+ if len(possible_negative_) < k:
+ # backup strategy
+ possible_negative_.extend(
+ random.choices(list(set(self.scorer_all_keys) - set(key_group)), k=k - len(possible_negative_))
+ )
+ possible_targets.append(possible_negative_)
+ return possible_targets
- def prepare_index(self):
+
+ def prepare_index(self, override_strategy=None):
+ if self.scorer_method == "nn":
+ self.scorer_index = None
+ return
if self.scorer_index_backend == "sklearn":
- self.scorer_index = BallTree(self.scorer_all_emb, leaf_size=1)
+ self.scorer_index = NearestNeighbors()
+ self.scorer_index.fit(self.scorer_all_emb)
+ # self.scorer_index = BallTree(self.scorer_all_emb, leaf_size=1)
# self.scorer_index = BallTree(normalize(self.scorer_all_emb, axis=1), leaf_size=1)
elif self.scorer_index_backend == "faiss":
- self.scorer_index = FaissIndex(self.scorer_all_emb)
+ self.scorer_index = FaissIndex(self.scorer_all_emb, method=self.scorer_method)
+ elif self.scorer_index_backend == "brute":
+ self.scorer_index = Brute(self.scorer_all_emb, method=self.scorer_method, device=self.scorer_device)
else:
raise ValueError(f"Unsupported backend: {self.scorer_index_backend}. Supported backends are: sklearn|faiss")
@@ -51,8 +165,18 @@ def sample_closest_negative(self, ids, k=None):
assert ids is not None
seed_pool = []
- [seed_pool.append(self.scorer_src2dst[id]) for id in ids]
- nested_negative = self.get_closest_to_keys(seed_pool, k=k+1)
+ for id in ids:
+ pool = self.scorer_src2dst[id]
+ if len(pool) > 0 and isinstance(pool[0], tuple):
+ pool = [x[0] for x in pool]
+ seed_pool.append(pool)
+ if id in self.scorer_key_order:
+ seed_pool[-1] = seed_pool[-1] + [id] # make sure that original list is not changed
+ # [seed_pool.append(self.scorer_src2dst[id]) for id in ids]
+ if hasattr(self, "scorer_ns_group2nodes"):
+ nested_negative = self.sample_negative_from_groups(seed_pool, k=k+1)
+ else:
+ nested_negative = self.get_closest_to_keys(seed_pool, k=k+1)
negative = []
for neg in nested_negative:
@@ -68,10 +192,11 @@ def get_closest_to_keys(self, key_groups, k=None):
self.scorer_all_emb[self.scorer_key_order[key]].reshape(1, -1), k=k
)
closest_keys.extend(self.scorer_all_keys[c] for c in closest.ravel())
+ # ensure that negative samples do not come from positive edges
closest_keys_ = list(set(closest_keys) - set(key_group))
if len(closest_keys_) == 0:
# backup strategy
- closest_keys_ = random.choices(list(set(self.scorer_all_keys) - set(key_group)), k=k-1)
+ closest_keys_ = random.choices(list(set(self.scorer_all_keys) - set(key_group)), k=k)
possible_targets.append(closest_keys_)
# possible_targets.extend(self.scorer_all_keys[c] for c in closest.ravel())
return possible_targets
@@ -80,75 +205,263 @@ def get_closest_to_keys(self, key_groups, k=None):
def set_embed(self, ids, embs):
ids = np.array(list(map(self.scorer_key_order.get, ids.tolist())))
- self.scorer_all_emb[ids, :] = embs# normalize(embs, axis=1)
+ self.scorer_all_emb[ids, :] = normalize(embs, axis=1) if self.scorer_method == "inner_prod" else embs
# for ind, id in enumerate(ids):
# self.all_embs[self.key_order[id], :] = embs[ind, :]
def score_candidates_cosine(self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=None):
- score_matr = (to_score_embs @ embs_to_score_against.t()) / \
- to_score_embs.norm(p=2, dim=1, keepdim=True) / \
- embs_to_score_against.norm(p=2, dim=1, keepdim=True).t()
- y_pred = score_matr.tolist()
+ to_score_embs = to_score_embs / to_score_embs.norm(p=2, dim=1, keepdim=True)
+ embs_to_score_against = embs_to_score_against / embs_to_score_against.norm(p=2, dim=1, keepdim=True)
+
+ score_matr = (to_score_embs @ embs_to_score_against.t())
+ score_matr = (score_matr + 1.) / 2.
+ # score_matr = score_matr - self.margin
+ # score_matr[score_matr < 0.] = 0.
+ y_pred = score_matr.cpu().tolist()
return y_pred
- def score_candidates_lp(self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, link_predictor, at=None):
+ def set_margin(self, margin):
+ self.margin = margin
+
+ def score_candidates_l2(self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=None):
y_pred = []
for i in range(len(to_score_ids)):
- input_embs = to_score_embs[i, :].repeat((embs_to_score_against.shape[0], 1))
- # predictor_input = torch.cat([input_embs, all_emb], dim=1)
- y_pred.append(torch.nn.functional.softmax(link_predictor(input_embs, embs_to_score_against), dim=1)[:, 1].tolist()) # 0 - negative, 1 - positive
+ input_embs = to_score_embs[i, :].reshape(1, -1)
+ score_matr = torch.norm(embs_to_score_against - input_embs, dim=-1)
+ score_matr = 1. / (1. + score_matr)
+ # score_matr = score_matr + self.margin
+ # score_matr[score_matr < 0.] = 0
+ y_pred.append(score_matr.cpu().tolist())
+
+ # embs_to_score_against = embs_to_score_against.unsqueeze(0)
+ # to_score_embs = to_score_embs.unsqueeze(1)
+ #
+ # score_matr = -torch.norm(embs_to_score_against - to_score_embs, dim=-1)
+ # score_matr = score_matr + self.margin
+ # score_matr[score_matr < 0.] = 0
+ # y_pred = score_matr.cpu().tolist()
return y_pred
+ def score_candidates_lp(
+ self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, link_predictor, at=None,
+ with_types=None
+ ):
+
+ if with_types is None:
+ y_pred = []
+ for i in range(len(to_score_ids)):
+ input_embs = to_score_embs[i, :].repeat((embs_to_score_against.shape[0], 1))
+ # predictor_input = torch.cat([input_embs, all_emb], dim=1)
+ y_pred.append(
+ torch.nn.functional.softmax(link_predictor(input_embs, embs_to_score_against), dim=1)[:, 1].tolist()
+ ) # 0 - negative, 1 - positive
+
+ return y_pred
+
+ else:
+ y_pred = []
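+            # TransR-style typed scoring: project both embeddings into the relation space with a
+            # relation-specific matrix, translate the source by the relation embedding and rank
+            # candidates by inverse distance.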
+ for i in range(len(to_score_ids)):
+ y_pred.append(dict())
+ for type in with_types[i]:
+ input_embs = to_score_embs[i, :].unsqueeze(0)
+ # predictor_input = torch.cat([input_embs, all_emb], dim=1)
+
+ labels = torch.LongTensor([type]).to(link_predictor.proj_matr.weight.device)
+ weights = link_predictor.proj_matr(labels).reshape((-1, link_predictor.rel_dim, link_predictor.input_dim))
+ rels = link_predictor.rel_emb(labels)
+ m_a = (weights * input_embs.unsqueeze(1)).sum(-1)
+ m_s = (weights * embs_to_score_against.unsqueeze(1)).sum(-1)
+
+ transl = m_a + rels
+ sim = torch.norm(transl - m_s, dim=-1)
+ sim = 1./ (1. + sim)
+ y_pred[-1][type] = sim.cpu().tolist()
+
+ return y_pred
+
+
def get_gt_candidates(self, ids):
candidates = [set(list(self.scorer_src2dst[id])) for id in ids]
return candidates
def get_cand_to_score_against(self, ids):
+ """
+        Generate a sorted list of all possible DST. These are used as candidate targets when computing ranking metrics.
+        :param ids: not used; candidates are collected from the full src2dst mapping
+        :return: sorted list of unique DST ids
+ """
all_keys = set()
- [all_keys.update(self.scorer_src2dst[key]) for key in self.scorer_src2dst]
+ for key in self.scorer_src2dst:
+ cand = self.scorer_src2dst[key]
+ for c in cand:
+ if isinstance(c, tuple): # happens with graphlinkclassifier objectives
+ all_keys.add(c[0])
+ else:
+ all_keys.add(c)
+
+ # [all_keys.update(self.scorer_src2dst[key]) for key in self.scorer_src2dst]
return sorted(list(all_keys)) # list(self.elem2id[a] for a in all_keys)
def get_embeddings_for_scoring(self, device, **kwargs):
+ """
+        Get all DST embeddings as a single tensor
+        :param device: device to place the tensor on
+        :param kwargs: unused, kept for interface compatibility
+        :return: tensor of shape (num_embs, emb_size)
+ """
return torch.Tensor(self.scorer_all_emb).to(device)
def get_keys_for_scoring(self):
return self.scorer_all_keys
+ def hits_at_k(self, y_true, y_pred, k):
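+        # For example: y_true=[[1., 0., 1., 0.]], y_pred=[[0.9, 0.1, 0.2, 0.8]], k=2 ->
+        # the top-2 predictions are items 0 and 3, one of which is relevant, so hits@2 = 1 / min(2, 2) = 0.5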
+ correct = y_true
+ predicted = y_pred
+ result = []
+ for y_true, y_pred in zip(correct, predicted):
+ ind_true = set([ind for ind, y_t in enumerate(y_true) if y_t == 1.])
+ ind_pred = set(list(np.flip(np.argsort(y_pred))[:k]))
+ result.append(len(ind_pred.intersection(ind_true)) / min(len(ind_true), k))
+
+ return sum(result) / len(result)
+
+ def mean_rank(self, y_true, y_pred):
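+        # For example: y_true=[[0., 1., 0.]], y_pred=[[0.5, 0.1, 0.2]] -> the relevant item is ranked
+        # last, so MR = 3 and MRR = 1/3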
+ true_ranks = []
+ reciprocal_ranks = []
+
+ correct = y_true
+ predicted = y_pred
+ for y_true, y_pred in zip(correct, predicted):
+ ranks = sorted(zip(y_true, y_pred), key=lambda x: x[1], reverse=True)
+ for ind, (true, pred) in enumerate(ranks):
+ if true > 0.:
+ true_ranks.append(ind + 1)
+ reciprocal_ranks.append(1 / (ind + 1))
+ break # should consider only first result https://en.wikipedia.org/wiki/Mean_reciprocal_rank
+
+ return sum(true_ranks) / len(true_ranks), sum(reciprocal_ranks) / len(reciprocal_ranks)
+
+ def mean_average_precision(self, y_true, y_pred):
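+        # For example: y_true=[[1., 0., 1.]], y_pred=[[0.9, 0.8, 0.7]] -> precision@1 = 1,
+        # precision@3 = 2/3, AveP = (1 + 2/3) / 2 ≈ 0.83, so MAP ≈ 0.83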
+ correct = y_true
+ predicted = y_pred
+ map = 0.
+ for y_true, y_pred in zip(correct, predicted):
+ ranks = sorted(zip(y_true, y_pred), key=lambda x: x[1], reverse=True)
+ found_relevant = 0
+ avep = 0.
+ for ind, (true, pred) in enumerate(ranks):
+ if true > 0.:
+ found_relevant += 1
+ avep += found_relevant / (ind + 1) # precision@k
+ avep /= found_relevant
+
+ map += avep
+
+ map /= len(correct)
+
+ return map
+
+ def get_y_true_from_candidate_list(self, candidates, keys_to_score_against):
+ y_true = [[1. if key in cand else 0. for key in keys_to_score_against] for cand in candidates]
+ return y_true
+
+ def get_y_true_from_candidates(self, candidates, keys_to_score_against):
+
+ has_types = isinstance(list(candidates[0])[0], tuple)
+
+ if not has_types:
+ return self.get_y_true_from_candidate_list(candidates, keys_to_score_against)
+
+ candidate_dicts = []
+ for cand in candidates:
+ candidate_dicts.append(dict())
+ for ent, type in cand:
+ if type not in candidate_dicts[-1]:
+ candidate_dicts[-1][type] = []
+ candidate_dicts[-1][type].append(ent)
+
+ y_true = []
+ for cand in candidate_dicts:
+ y_true.append(dict())
+ for type, ents in cand.items():
+ y_true[-1][type] = [1. if key in ents else 0. for key in keys_to_score_against]
+ assert sum(y_true[-1][type]) > 0.
+
+ return y_true
+
+ def flatten_pred(self, y):
+
+ flattened = []
+ for x in y:
+ for key, scores in x.items():
+ flattened.append(scores)
+
+ return flattened
+
def score_candidates(self, to_score_ids, to_score_embs, link_predictor=None, at=None, type=None, device="cpu"):
if at is None:
at = [1, 3, 5, 10]
+ start = time.time()
+
to_score_ids = to_score_ids.tolist()
- candidates = self.get_gt_candidates(to_score_ids)
+ candidates = self.get_gt_candidates(to_score_ids) # positive candidates
# keys_to_score_against = self.get_cand_to_score_against(to_score_ids)
keys_to_score_against = self.get_keys_for_scoring()
- y_true = [[1. if key in cand else 0. for key in keys_to_score_against] for cand in candidates]
+ y_true = self.get_y_true_from_candidates(candidates, keys_to_score_against)
embs_to_score_against = self.get_embeddings_for_scoring(device=to_score_embs.device)
if type == "nn":
+ has_types = isinstance(y_true[0], dict)
+
y_pred = self.score_candidates_lp(
to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against,
- link_predictor, at=at
+ link_predictor, at=at, with_types=y_true if has_types else None
)
elif type == "inner_prod":
y_pred = self.score_candidates_cosine(
to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=at
)
+ elif type == "l2":
+ y_pred = self.score_candidates_l2(
+ to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=at
+ )
else:
raise ValueError(f"`type` can be either `nn` or `inner_prod` but `{type}` given")
+ has_types = isinstance(y_pred[0], dict)
+
+ if has_types:
+ y_true = self.flatten_pred(y_true)
+ y_pred = self.flatten_pred(y_pred)
+
+ scores = {}
+ # y_true_onehot = np.array(y_true)
+ # labels=list(range(y_true_onehot.shape[1]))
+
if isinstance(at, Iterable):
- scores = {f"ndcg@{k}": ndcg_score(y_true, y_pred, k=k) for k in at}
+ scores.update({f"hits@{k}": self.hits_at_k(y_true, y_pred, k=k) for k in at})
+ scores.update({f"ndcg@{k}": ndcg_score(y_true, y_pred, k=k) for k in at})
+ # scores = {f"ndcg@{k}": ndcg_score(y_true, y_pred, k=k) for k in at}
else:
- scores = {f"ndcg@{at}": ndcg_score(y_true, y_pred, k=at)}
- return scores
\ No newline at end of file
+ scores.update({f"hits@{at}": self.hits_at_k(y_true, y_pred, k=at)})
+ scores.update({f"ndcg@{at}": ndcg_score(y_true, y_pred, k=at)})
+ # scores = {f"ndcg@{at}": ndcg_score(y_true, y_pred, k=at)}
+
+ mr, mrr = self.mean_rank(y_true, y_pred)
+ map = self.mean_average_precision(y_true, y_pred)
+ scores["mr"] = mr
+ scores["mrr"] = mrr
+ scores["map"] = map
+ scores["scoring_time"] = time.time() - start
+ return scores
diff --git a/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py b/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py
index f17197cf..d1f361b7 100644
--- a/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py
+++ b/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py
@@ -11,6 +11,7 @@
from os.path import join
import logging
+from SourceCodeTools.mltools.torch import compute_accuracy
from SourceCodeTools.models.Embedder import Embedder
from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedderWithBpeSubwords
from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase
@@ -19,10 +20,6 @@
from SourceCodeTools.models.graph.NodeEmbedder import NodeEmbedder
-def _compute_accuracy(pred_, true_):
- return torch.sum(pred_ == true_).item() / len(true_)
-
-
class SamplingMultitaskTrainer:
def __init__(self,
@@ -382,7 +379,7 @@ def _evaluate_embedder(self, ee, lp, loader, neg_sampling_factor=1):
logp = nn.functional.log_softmax(logits, 1)
loss = nn.functional.cross_entropy(logp, labels)
- acc = _compute_accuracy(logp.argmax(dim=1), labels)
+ acc = compute_accuracy(logp.argmax(dim=1), labels)
total_loss += loss.item()
total_acc += acc
@@ -405,7 +402,7 @@ def _evaluate_nodes(self, ee, lp, create_api_call_loader, loader,
logp = nn.functional.log_softmax(logits, 1)
loss = nn.functional.cross_entropy(logp, labels)
- acc = _compute_accuracy(logp.argmax(dim=1), labels)
+ acc = compute_accuracy(logp.argmax(dim=1), labels)
total_loss += loss.item()
total_acc += acc
@@ -560,9 +557,9 @@ def train_all(self):
input_nodes_api_call, seeds_api_call, blocks_api_call
)
- train_acc_node_name = _compute_accuracy(logits_node_name.argmax(dim=1), labels_node_name)
- train_acc_var_use = _compute_accuracy(logits_var_use.argmax(dim=1), labels_var_use)
- train_acc_api_call = _compute_accuracy(logits_api_call.argmax(dim=1), labels_api_call)
+ train_acc_node_name = compute_accuracy(logits_node_name.argmax(dim=1), labels_node_name)
+ train_acc_var_use = compute_accuracy(logits_var_use.argmax(dim=1), labels_var_use)
+ train_acc_api_call = compute_accuracy(logits_api_call.argmax(dim=1), labels_api_call)
train_logits = torch.cat([logits_node_name, logits_var_use, logits_api_call], 0)
train_labels = torch.cat([labels_node_name, labels_var_use, labels_api_call], 0)
diff --git a/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py b/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py
index 93b10249..1edaaa12 100644
--- a/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py
+++ b/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py
@@ -1,52 +1,97 @@
import logging
from abc import abstractmethod
-from collections import OrderedDict
-from itertools import chain
+from collections import defaultdict
import dgl
import torch
from torch.nn import CosineEmbeddingLoss
+from tqdm import tqdm
from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker
-from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedderWithBpeSubwords, NameEmbedderWithGroups, \
- GraphLinkSampler
+from SourceCodeTools.mltools.torch import compute_accuracy
+from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedderWithBpeSubwords, GraphLinkSampler
from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase
-from SourceCodeTools.models.graph.LinkPredictor import LinkPredictor, CosineLinkPredictor, BilinearLinkPedictor
+from SourceCodeTools.models.graph.LinkPredictor import CosineLinkPredictor, BilinearLinkPedictor, L2LinkPredictor
import torch.nn as nn
-def _compute_accuracy(pred_: torch.Tensor, true_: torch.Tensor):
- return torch.sum(pred_ == true_).item() / len(true_)
+class ZeroEdges(Exception):
+ def __init__(self, *args):
+ super(ZeroEdges, self).__init__(*args)
+
+
+class EarlyStoppingTracker:
+ def __init__(self, early_stopping_tolerance):
+ self.early_stopping_tolerance = early_stopping_tolerance
+ self.early_stopping_counter = 0
+ self.early_stopping_value = 0.
+ self.early_stopping_trigger = False
+
+ def should_stop(self, metric):
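+        # Returns True once the metric has failed to improve for `early_stopping_tolerance` consecutive
+        # calls, e.g. with tolerance=2 the sequence of metrics 0.5, 0.4, 0.4 yields False, False, True.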
+ if metric <= self.early_stopping_value:
+ self.early_stopping_counter += 1
+ if self.early_stopping_counter >= self.early_stopping_tolerance:
+ return True
+ else:
+ self.early_stopping_counter = 0
+ self.early_stopping_value = metric
+ return False
+
+ def reset(self):
+ self.early_stopping_counter = 0
+ self.early_stopping_value = 0.
+
+
+def sum_scores(s):
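+    # Despite the name, this returns the mean of the collected scores; an empty list yields 0.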
+ n = len(s)
+ if n == 0:
+ n += 1
+ return sum(s) / n
class AbstractObjective(nn.Module):
+ # # set in the init
+ # name = None
+ # graph_model = None
+ # sampling_neighbourhood_size = None
+ # batch_size = None
+ # target_emb_size = None
+ # node_embedder = None
+ # device = None
+ # masker = None
+ # link_predictor_type = None
+ # measure_scores = None
+ # dilate_scores = None
+ # early_stopping_tracker = None
+ # early_stopping_trigger = None
+ #
+ # # set elsewhere
+ # target_embedder = None
+ # link_predictor = None
+ # positive_label = None
+ # negative_label = None
+ # label_dtype = None
+ #
+ # train_loader = None
+ # test_loader = None
+ # val_loader = None
+ # num_train_batches = None
+ # num_test_batches = None
+ # num_val_batches = None
+ #
+ # ntypes = None
+
def __init__(
- # self, name, objective_type, graph_model, node_embedder, nodes, data_loading_func, device,
self, name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
- tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker=None,
- measure_ndcg=False, dilate_ndcg=1
+ tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
+ measure_scores=False, dilate_scores=1, early_stopping=False, early_stopping_tolerance=20, nn_index="brute",
+ ns_groups=None
):
- """
- :param name: name for reference
- :param objective_type: one of: graph_link_prediction|graph_link_classification|subword_ranker|node_classification
- :param graph_model:
- :param nodes:
- :param data_loading_func:
- :param device:
- :param sampling_neighbourhood_size:
- :param tokenizer_path:
- :param target_emb_size:
- :param link_predictor_type:
- """
super(AbstractObjective, self).__init__()
- # if objective_type not in {"graph_link_prediction", "graph_link_classification", "subword_ranker", "classification"}:
- # raise NotImplementedError()
-
self.name = name
- # self.type = objective_type
self.graph_model = graph_model
self.sampling_neighbourhood_size = sampling_neighbourhood_size
self.batch_size = batch_size
@@ -55,8 +100,12 @@ def __init__(
self.device = device
self.masker = masker
self.link_predictor_type = link_predictor_type
- self.measure_ndcg = measure_ndcg
- self.dilate_ndcg = dilate_ndcg
+ self.measure_scores = measure_scores
+ self.dilate_scores = dilate_scores
+ self.nn_index = nn_index
+ self.early_stopping_tracker = EarlyStoppingTracker(early_stopping_tolerance) if early_stopping else None
+ self.early_stopping_trigger = False
+ self.ns_groups = ns_groups
self.verify_parameters()
@@ -64,10 +113,12 @@ def __init__(
self.create_link_predictor()
self.create_loaders()
+ self.target_embedding_fn = self.get_targets_from_embedder
+ self.negative_factor = 1
+ self.update_embeddings_for_queries = True
+
@abstractmethod
def verify_parameters(self):
- # if self.link_predictor_type == "inner_prod": # TODO incorrect
- # assert self.target_emb_size == self.graph_model.emb_size, "Graph embedding and target embedder dimensionality should match for `inner_prod` type of link predictor."
pass
def create_base_element_sampler(self, data_loading_func, nodes):
@@ -78,7 +129,8 @@ def create_base_element_sampler(self, data_loading_func, nodes):
def create_graph_link_sampler(self, data_loading_func, nodes):
self.target_embedder = GraphLinkSampler(
elements=data_loading_func(), nodes=nodes, compact_dst=False, dst_to_global=True,
- emb_size=self.target_emb_size
+ emb_size=self.target_emb_size, device=self.device, method=self.link_predictor_type, nn_index=self.nn_index,
+ ns_groups=self.ns_groups
)
def create_subword_embedder(self, data_loading_func, nodes, tokenizer_path):
@@ -89,54 +141,71 @@ def create_subword_embedder(self, data_loading_func, nodes, tokenizer_path):
@abstractmethod
def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
- # # create target embedder
- # if self.type == "graph_link_prediction" or self.type == "graph_link_classification":
- # self.create_base_element_sampler(data_loading_func, nodes)
- # elif self.type == "subword_ranker":
- # self.create_subword_embedder(data_loading_func, nodes, tokenizer_path)
- # elif self.type == "node_classification":
- # self.target_embedder = None
+ # self.create_base_element_sampler(data_loading_func, nodes)
+ # self.create_graph_link_sampler(data_loading_func, nodes)
+ # self.create_subword_embedder(data_loading_func, nodes, tokenizer_path)
raise NotImplementedError()
def create_nn_link_predictor(self):
- # self.link_predictor = LinkPredictor(self.target_emb_size + self.graph_model.emb_size).to(self.device)
self.link_predictor = BilinearLinkPedictor(self.target_emb_size, self.graph_model.emb_size, 2).to(self.device)
self.positive_label = 1
self.negative_label = 0
self.label_dtype = torch.long
def create_inner_prod_link_predictor(self):
- self.link_predictor = CosineLinkPredictor().to(self.device)
- self.cosine_loss = CosineEmbeddingLoss(margin=0.4)
+ self.margin = -0.2
+ self.target_embedder.set_margin(self.margin)
+ self.link_predictor = CosineLinkPredictor(margin=self.margin).to(self.device)
+ self.hinge_loss = nn.HingeEmbeddingLoss(margin=1. - self.margin)
+
+ def cosine_loss(x1, x2, label):
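+            # hinge loss on the cosine distance: positive pairs minimize (1 - cos), negative pairs are
+            # penalized only while (1 - cos) stays below the hinge margin (1 - self.margin)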
+ sim = nn.CosineSimilarity()
+ dist = 1. - sim(x1, x2)
+ return self.hinge_loss(dist, label)
+
+ # self.cosine_loss = CosineEmbeddingLoss(margin=self.margin)
+ self.cosine_loss = cosine_loss
+ self.positive_label = 1.
+ self.negative_label = -1.
+ self.label_dtype = torch.float32
+
+ def create_l2_link_predictor(self):
+ self.margin = 2.0
+ self.target_embedder.set_margin(self.margin)
+ self.link_predictor = L2LinkPredictor().to(self.device)
+ # self.hinge_loss = nn.HingeEmbeddingLoss(margin=self.margin)
+ self.triplet_loss = nn.TripletMarginLoss(margin=self.margin)
+
+ def l2_loss(x1, x2, label):
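+            # the batch is laid out as [positives; negatives] along dim 0 (see prepare_for_prediction with
+            # negative_factor=1), so the first half of x2 gives the positive targets and the second half the
+            # negatives for the triplet loss; the label argument is not used here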
+ half = x1.shape[0] // 2
+ pos = x2[:half, :]
+ neg = x2[half:, :]
+
+ return self.triplet_loss(x1[:half, :], pos, neg)
+ # dist = torch.norm(x1 - x2, dim=-1)
+ # return self.hinge_loss(dist, label)
+
+ self.l2_loss = l2_loss
self.positive_label = 1.
self.negative_label = -1.
self.label_dtype = torch.float32
- @abstractmethod
def create_link_predictor(self):
- # # create link predictors
- # if self.type in {"graph_link_prediction", "graph_link_classification", "subword_ranker"}:
- # if self.link_predictor_type == "nn":
- # self.create_nn_link_predictor()
- # elif self.link_predictor_type == "inner_prod":
- # self.create_inner_prod_link_predictor()
- # else:
- # raise NotImplementedError()
- # else:
- # # for node classifier
- # self.link_predictor = LinkPredictor(self.graph_model.emb_size).to(self.device)
- raise NotImplementedError()
+ if self.link_predictor_type == "nn":
+ self.create_nn_link_predictor()
+ elif self.link_predictor_type == "inner_prod":
+ self.create_inner_prod_link_predictor()
+ elif self.link_predictor_type == "l2":
+ self.create_l2_link_predictor()
+ else:
+ raise NotImplementedError()
- # @abstractmethod
def create_loaders(self):
- # create loaders
- # if self.type in {"graph_link_prediction", "graph_link_classification", "subword_ranker"}:
+ print("Number of nodes", self.graph_model.g.number_of_nodes())
train_idx, val_idx, test_idx = self._get_training_targets()
train_idx, val_idx, test_idx = self.target_embedder.create_idx_pools(
train_idx=train_idx, val_idx=val_idx, test_idx=test_idx
)
- # else:
- # raise NotImplementedError()
logging.info(
f"Pool sizes for {self.name}: train {self._idx_len(train_idx)}, "
f"val {self._idx_len(val_idx)}, "
@@ -147,6 +216,13 @@ def create_loaders(self):
batch_size=self.batch_size # batch_size_node_name
)
+        def get_num_batches(ids):
+            return sum(len(ids[key_]) for key_ in ids) // self.batch_size + 1
+
+        self.num_train_batches = get_num_batches(train_idx)
+        self.num_test_batches = get_num_batches(test_idx)
+        self.num_val_batches = get_num_batches(val_idx)
+
def _idx_len(self, idx):
if isinstance(idx, dict):
length = 0
@@ -179,49 +255,49 @@ def _get_training_targets(self):
# labels = labels[key]
self.use_types = False
- train_idx = {
- ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['train_mask'], as_tuple=False).squeeze()
- for ntype in self.ntypes
- }
- val_idx = {
- ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['val_mask'], as_tuple=False).squeeze()
- for ntype in self.ntypes
- }
- test_idx = {
- ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['test_mask'], as_tuple=False).squeeze()
- for ntype in self.ntypes
- }
+ def get_targets(data_label):
+ return {
+ ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data[data_label], as_tuple=False).squeeze()
+ for ntype in self.ntypes
+ }
+
+ train_idx = get_targets("train_mask")
+ val_idx = get_targets("val_mask")
+ test_idx = get_targets("test_mask")
+
+ # train_idx = {
+ # ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['train_mask'], as_tuple=False).squeeze()
+ # for ntype in self.ntypes
+ # }
+ # val_idx = {
+ # ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['val_mask'], as_tuple=False).squeeze()
+ # for ntype in self.ntypes
+ # }
+ # test_idx = {
+ # ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['test_mask'], as_tuple=False).squeeze()
+ # for ntype in self.ntypes
+ # }
else:
# not sure when this is called
raise NotImplementedError()
- # self.ntypes = None
- # # labels = g.ndata['labels']
- # train_idx = self.graph_model.g.ndata['train_mask']
- # val_idx = self.graph_model.g.ndata['val_mask']
- # test_idx = self.graph_model.g.ndata['test_mask']
- # self.use_types = False
return train_idx, val_idx, test_idx
def _get_loaders(self, train_idx, val_idx, test_idx, batch_size):
- # train sampler
+
layers = self.graph_model.num_layers
- sampler = dgl.dataloading.MultiLayerNeighborSampler([self.sampling_neighbourhood_size] * layers)
- loader = dgl.dataloading.NodeDataLoader(
- self.graph_model.g, train_idx, sampler, batch_size=batch_size, shuffle=False, num_workers=0)
- # validation sampler
- # we do not use full neighbor to save computation resources
- val_sampler = dgl.dataloading.MultiLayerNeighborSampler([self.sampling_neighbourhood_size] * layers)
- val_loader = dgl.dataloading.NodeDataLoader(
- self.graph_model.g, val_idx, val_sampler, batch_size=batch_size, shuffle=False, num_workers=0)
+ def create_loader(ids):
+ sampler = dgl.dataloading.MultiLayerFullNeighborSampler(layers)
+ loader = dgl.dataloading.NodeDataLoader(
+ self.graph_model.g, ids, sampler, batch_size=batch_size, shuffle=False, num_workers=0)
+ return loader
- # we do not use full neighbor to save computation resources
- test_sampler = dgl.dataloading.MultiLayerNeighborSampler([self.sampling_neighbourhood_size] * layers)
- test_loader = dgl.dataloading.NodeDataLoader(
- self.graph_model.g, test_idx, test_sampler, batch_size=batch_size, shuffle=False, num_workers=0)
+ train_loader = create_loader(train_idx)
+ val_loader = create_loader(val_idx)
+ test_loader = create_loader(test_idx)
- return loader, val_loader, test_loader
+ return train_loader, val_loader, test_loader
def reset_iterator(self, data_split):
iter_name = f"{data_split}_loader_iter"
@@ -234,8 +310,7 @@ def loader_next(self, data_split):
return next(getattr(self, iter_name))
def _create_loader(self, indices):
- sampler = dgl.dataloading.MultiLayerNeighborSampler(
- [self.sampling_neighbourhood_size] * self.graph_model.num_layers)
+ sampler = dgl.dataloading.MultiLayerFullNeighborSampler(self.graph_model.num_layers)
return dgl.dataloading.NodeDataLoader(
self.graph_model.g, indices, sampler, batch_size=len(indices), num_workers=0)
@@ -257,15 +332,30 @@ def compute_acc_loss(self, node_embs_, element_embs_, labels):
elif self.link_predictor_type == "inner_prod":
loss = self.cosine_loss(node_embs_, element_embs_, labels)
labels[labels < 0] = 0
+ elif self.link_predictor_type == "l2":
+ loss = self.l2_loss(node_embs_, element_embs_, labels)
+ labels[labels < 0] = 0
+ # num_examples = len(labels) // 2
+ # anchor = node_embs_[:num_examples, :]
+ # positive = element_embs_[:num_examples, :]
+ # negative = element_embs_[num_examples:, :]
+ # # pos_labels_ = labels[:num_examples]
+ # # neg_labels_ = labels[num_examples:]
+ # margin = 1.
+ # triplet = nn.TripletMarginLoss(margin=margin)
+ # self.target_embedder.set_margin(margin)
+ # loss = triplet(anchor, positive, negative)
+ # logits = (torch.norm(node_embs_ - element_embs_, keepdim=True) < 1.).float()
+ # logits = torch.cat([1 - logits, logits], dim=1)
+ # labels[labels < 0] = 0
else:
raise NotImplementedError()
- acc = _compute_accuracy(logits.argmax(dim=1), labels)
+ acc = compute_accuracy(logits.argmax(dim=1), labels)
return acc, loss
-
- def _logits_batch(self, input_nodes, blocks, train_embeddings=True, masked=None):
+ def _graph_embeddings(self, input_nodes, blocks, train_embeddings=True, masked=None):
cumm_logits = []
@@ -312,91 +402,173 @@ def seeds_to_global(self, seeds):
else:
return seeds
- def _logits_embedder(self, node_embeddings, elem_embedder, link_predictor, seeds, negative_factor=1):
- k = negative_factor
- indices = self.seeds_to_global(seeds).tolist()
- batch_size = len(indices)
+ def sample_negative(self, ids, k, neg_sampling_strategy):
+ if neg_sampling_strategy is not None:
+ negative = self.target_embedder.sample_negative(
+ k, ids=ids, strategy=neg_sampling_strategy
+ )
+ else:
+ negative = self.target_embedder.sample_negative(
+ k, ids=ids,
+ )
+ return negative
- node_embeddings_batch = node_embeddings
- element_embeddings = elem_embedder(elem_embedder[indices].to(self.device))
+ def get_targets_from_nodes(
+ self, positive_indices, negative_indices=None, train_embeddings=True
+ ):
+ negative_indices = torch.tensor(negative_indices, dtype=torch.long) if negative_indices is not None else None
- # labels_pos = torch.ones(batch_size, dtype=torch.long)
- labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype)
+ def get_embeddings_for_targets(dst):
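+            # dst may contain duplicates: embed each unique target once and scatter the results back via slice_map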
+ unique_dst, slice_map = self._handle_non_unique(dst)
+ assert unique_dst[slice_map].tolist() == dst.tolist()
- node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
- # negative_random = elem_embedder(elem_embedder.sample_negative(batch_size * k).to(self.device))
- negative_random = elem_embedder(elem_embedder.sample_negative(batch_size * k, ids=indices).to(self.device)) # closest negative
+ dataloader = self._create_loader(unique_dst)
+ input_nodes, dst_seeds, blocks = next(iter(dataloader))
+ blocks = [blk.to(self.device) for blk in blocks]
+ assert dst_seeds.shape == unique_dst.shape
+ assert dst_seeds.tolist() == unique_dst.tolist()
+ unique_dst_embeddings = self._graph_embeddings(input_nodes, blocks, train_embeddings) # use_types, ntypes)
+ dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
+
+ if self.update_embeddings_for_queries:
+ self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(),
+ unique_dst_embeddings.detach().cpu().numpy())
+
+ return dst_embeddings
- # labels_neg = torch.zeros(batch_size * k, dtype=torch.long)
- labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype)
+ positive_dst = get_embeddings_for_targets(positive_indices)
+ negative_dst = get_embeddings_for_targets(negative_indices) if negative_indices is not None else None
+ return positive_dst, negative_dst
- # positive_batch = torch.cat([node_embeddings_batch, element_embeddings], 1)
- # negative_batch = torch.cat([node_embeddings_neg_batch, negative_random], 1)
- # batch = torch.cat([positive_batch, negative_batch], 0)
- # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
+ def get_targets_from_embedder(
+ self, positive_indices, negative_indices=None, train_embeddings=True
+ ):
+
+ # def get_embeddings_for_targets(dst):
+ # unique_dst, slice_map = self._handle_non_unique(dst)
+ # assert unique_dst[slice_map].tolist() == dst.tolist()
+ # unique_dst_embeddings = self.target_embedder(unique_dst.to(self.device))
+ # dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
#
- # logits = link_predictor(batch)
+ # if self.update_embeddings_for_queries:
+ # self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(),
+ # unique_dst_embeddings.detach().cpu().numpy())
#
- # return logits, labels
- nodes = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0)
- embs = torch.cat([element_embeddings, negative_random], dim=0)
- labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
- return nodes, embs, labels
+ # return dst_embeddings
+
+ positive_dst = self.target_embedder(positive_indices.to(self.device))
+ negative_dst = self.target_embedder(negative_indices.to(self.device)) if negative_indices is not None else None
+ #
+ # positive_dst = get_embeddings_for_targets(positive_indices)
+ # negative_dst = get_embeddings_for_targets(negative_indices) if negative_indices is not None else None
- def _logits_nodes(self, node_embeddings,
- elem_embedder, link_predictor, create_dataloader,
- src_seeds, negative_factor=1, train_embeddings=True):
+ return positive_dst, negative_dst
+
+ def create_positive_labels(self, ids):
+ return torch.full((len(ids),), self.positive_label, dtype=self.label_dtype)
+
+ def create_negative_labels(self, ids, k):
+ return torch.full((len(ids) * k,), self.negative_label, dtype=self.label_dtype)
+
+ def prepare_for_prediction(
+ self, node_embeddings, seeds, target_embedding_fn, negative_factor=1,
+ neg_sampling_strategy=None, train_embeddings=True,
+ ):
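+        # Builds (src_embs, dst_embs, labels) for the link predictor: positive DST come from the target
+        # embedder, negatives from the configured sampling strategy, with sources repeated negative_factor times.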
k = negative_factor
- indices = self.seeds_to_global(src_seeds).tolist()
+ indices = self.seeds_to_global(seeds).tolist()
batch_size = len(indices)
node_embeddings_batch = node_embeddings
- next_call_indices = elem_embedder[indices] # this assumes indices is torch tensor
-
- # dst targets are not unique
- unique_dst, slice_map = self._handle_non_unique(next_call_indices)
- assert unique_dst[slice_map].tolist() == next_call_indices.tolist()
-
- dataloader = create_dataloader(unique_dst)
- input_nodes, dst_seeds, blocks = next(iter(dataloader))
- blocks = [blk.to(self.device) for blk in blocks]
- assert dst_seeds.shape == unique_dst.shape
- assert dst_seeds.tolist() == unique_dst.tolist()
- unique_dst_embeddings = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes)
- next_call_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
- # labels_pos = torch.ones(batch_size, dtype=torch.long)
- labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype)
-
node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
- # negative_indices = torch.tensor(elem_embedder.sample_negative(
- # batch_size * k), dtype=torch.long) # embeddings are sampled from 3/4 unigram distribution
- negative_indices = torch.tensor(elem_embedder.sample_negative(
- batch_size * k, ids=indices), dtype=torch.long) # closest negative
- unique_negative, slice_map = self._handle_non_unique(negative_indices)
- assert unique_negative[slice_map].tolist() == negative_indices.tolist()
-
- dataloader = create_dataloader(unique_negative)
- input_nodes, dst_seeds, blocks = next(iter(dataloader))
- blocks = [blk.to(self.device) for blk in blocks]
- assert dst_seeds.shape == unique_negative.shape
- assert dst_seeds.tolist() == unique_negative.tolist()
- unique_negative_random = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes)
- negative_random = unique_negative_random[slice_map.to(self.device)]
- # labels_neg = torch.zeros(batch_size * k, dtype=torch.long)
- labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype)
-
- # positive_batch = torch.cat([node_embeddings_batch, next_call_embeddings], 1)
- # negative_batch = torch.cat([node_embeddings_neg_batch, negative_random], 1)
- # batch = torch.cat([positive_batch, negative_batch], 0)
- # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
- #
- # logits = link_predictor(batch)
- #
- # return logits, labels
- nodes = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0)
- embs = torch.cat([next_call_embeddings, negative_random], dim=0)
+
+ positive_indices = self.target_embedder[indices]
+ negative_indices = self.sample_negative(
+ k=batch_size * k, ids=indices, neg_sampling_strategy=neg_sampling_strategy
+ )
+
+ positive_dst, negative_dst = target_embedding_fn(
+ positive_indices, negative_indices, train_embeddings
+ )
+
+ # TODO breaks cache in
+ # SourceCodeTools.models.graph.train.objectives.GraphLinkClassificationObjective.TargetLinkMapper.get_labels
+ labels_pos = self.create_positive_labels(indices)
+ labels_neg = self.create_negative_labels(indices, k)
+
+ src_embs = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0)
+ dst_embs = torch.cat([positive_dst, negative_dst], dim=0)
labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
- return nodes, embs, labels
+ return src_embs, dst_embs, labels
+
+ # def _logits_embedder(
+ # self, node_embeddings, elem_embedder, link_predictor, seeds, negative_factor=1, neg_sampling_strategy=None
+ # ):
+ # k = negative_factor
+ # indices = self.seeds_to_global(seeds).tolist()
+ # batch_size = len(indices)
+ #
+ # node_embeddings_batch = node_embeddings
+ # node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
+ #
+ # element_embeddings = elem_embedder(elem_embedder[indices].to(self.device))
+ # negative_random = elem_embedder(self.sample_negative(
+ # k=batch_size * k, ids=indices, neg_sampling_strategy=neg_sampling_strategy
+ # ).to(self.device))
+ #
+ # labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype)
+ # labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype)
+ #
+ # src_embs = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0)
+ # dst_embs = torch.cat([element_embeddings, negative_random], dim=0)
+ # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
+ # return src_embs, dst_embs, labels
+ #
+ # def _logits_nodes(
+ # self, node_embeddings, elem_embedder, link_predictor, create_dataloader,
+ # src_seeds, negative_factor=1, train_embeddings=True, neg_sampling_strategy=None,
+ # update_embeddings_for_queries=False
+ # ):
+ # k = negative_factor
+ # indices = self.seeds_to_global(src_seeds).tolist()
+ # batch_size = len(indices)
+ #
+ # node_embeddings_batch = node_embeddings
+ # node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
+ #
+ # next_call_indices = elem_embedder[indices] # this assumes indices is torch tensor
+ # negative_indices = torch.tensor(self.sample_negative(
+ # k=batch_size * k, ids=indices, neg_sampling_strategy=neg_sampling_strategy
+ # ), dtype=torch.long)
+ #
+ # # dst targets are not unique
+ # def get_embeddings_for_targets(dst, update_embeddings_for_queries):
+ # unique_dst, slice_map = self._handle_non_unique(dst)
+ # assert unique_dst[slice_map].tolist() == dst.tolist()
+ #
+ # dataloader = create_dataloader(unique_dst)
+ # input_nodes, dst_seeds, blocks = next(iter(dataloader))
+ # blocks = [blk.to(self.device) for blk in blocks]
+ # assert dst_seeds.shape == unique_dst.shape
+ # assert dst_seeds.tolist() == unique_dst.tolist()
+ # unique_dst_embeddings = self._graph_embeddings(input_nodes, blocks, train_embeddings) # use_types, ntypes)
+ # dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
+ #
+ # if update_embeddings_for_queries:
+ # self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(),
+ # unique_dst_embeddings.detach().cpu().numpy())
+ #
+ # return dst_embeddings
+ #
+ # next_call_embeddings = get_embeddings_for_targets(next_call_indices, update_embeddings_for_queries)
+ # negative_random = get_embeddings_for_targets(negative_indices, update_embeddings_for_queries)
+ #
+ # labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype)
+ # labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype)
+ #
+ # src_embs = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0)
+ # dst_embs = torch.cat([next_call_embeddings, negative_random], dim=0)
+ # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
+ # return src_embs, dst_embs, labels
def seeds_to_python(self, seeds):
if isinstance(seeds, dict):
@@ -407,47 +579,32 @@ def seeds_to_python(self, seeds):
python_seeds = seeds.tolist()
return python_seeds
- @abstractmethod
- def forward(self, input_nodes, seeds, blocks, train_embeddings=True):
- # masked = None
- # if self.type in {"subword_ranker"}:
- # masked = self.masker.get_mask(self.seeds_to_python(seeds))
- # graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked)
- # if self.type in {"subword_ranker"}:
- # # logits, labels = self._logits_embedder(graph_emb, self.target_embedder, self.link_predictor, seeds)
- # node_embs_, element_embs_, labels = self._logits_embedder(graph_emb, self.target_embedder, self.link_predictor, seeds)
- # elif self.type in {"graph_link_prediction", "graph_link_classification"}:
- # # logits, labels = self._logits_nodes(graph_emb, self.target_embedder, self.link_predictor,
- # # self._create_loader, seeds, train_embeddings=train_embeddings)
- # node_embs_, element_embs_, labels = self._logits_nodes(graph_emb, self.target_embedder, self.link_predictor,
- # self._create_loader, seeds, train_embeddings=train_embeddings)
- # else:
- # raise NotImplementedError()
- #
- # # logits = self.link_predictor(node_embs_, element_embs_)
- # #
- # # acc = _compute_accuracy(logits.argmax(dim=1), labels)
- # # logp = nn.functional.log_softmax(logits, 1)
- # # loss = nn.functional.nll_loss(logp, labels)
- # acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
- #
- # return loss, acc
- raise NotImplementedError()
+ def forward(self, input_nodes, seeds, blocks, train_embeddings=True, neg_sampling_strategy=None):
+ masked = self.masker.get_mask(self.seeds_to_python(seeds)) if self.masker is not None else None
+ graph_emb = self._graph_embeddings(input_nodes, blocks, train_embeddings, masked=masked)
+ node_embs_, element_embs_, labels = self.prepare_for_prediction(
+ graph_emb, seeds, self.target_embedding_fn, negative_factor=self.negative_factor,
+ neg_sampling_strategy=neg_sampling_strategy,
+ train_embeddings=train_embeddings
+ )
+
+ acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
- def _evaluate_embedder(self, ee, lp, data_split, neg_sampling_factor=1):
+ return loss, acc
- total_loss = 0
- total_acc = 0
- ndcg_at = [1, 3, 5, 10]
- total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at}
- ndcg_count = 0
+ def evaluate_objective(self, data_split, neg_sampling_strategy=None, negative_factor=1):
+ # total_loss = 0
+ # total_acc = 0
+ at = [1, 3, 5, 10]
+ # total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at}
+ # ndcg_count = 0
count = 0
- # if self.measure_ndcg:
- # if self.link_predictor == "inner_prod":
- # self.target_embedder.prepare_index()
+ scores = defaultdict(list)
- for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"):
+ for input_nodes, seeds, blocks in tqdm(
+ getattr(self, f"{data_split}_loader"), total=getattr(self, f"num_{data_split}_batches")
+ ):
blocks = [blk.to(self.device) for blk in blocks]
if self.masker is None:
@@ -455,134 +612,151 @@ def _evaluate_embedder(self, ee, lp, data_split, neg_sampling_factor=1):
else:
masked = self.masker.get_mask(self.seeds_to_python(seeds))
- src_embs = self._logits_batch(input_nodes, blocks, masked=masked)
- # logits, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor)
- node_embs_, element_embs_, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor)
-
- if self.measure_ndcg:
- if count % self.dilate_ndcg == 0:
- ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device)
- for key, val in ndcg.items():
- total_ndcg[key] = total_ndcg[key] + val
- ndcg_count += 1
-
- # logits = self.link_predictor(node_embs_, element_embs_)
- #
- # logp = nn.functional.log_softmax(logits, 1)
- # loss = nn.functional.cross_entropy(logp, labels)
- # acc = _compute_accuracy(logp.argmax(dim=1), labels)
+ src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked)
+ node_embs_, element_embs_, labels = self.prepare_for_prediction(
+ src_embs, seeds, self.target_embedding_fn, negative_factor=negative_factor,
+ neg_sampling_strategy=neg_sampling_strategy,
+ train_embeddings=False
+ )
+
+ if self.measure_scores:
+ if count % self.dilate_scores == 0:
+ scores_ = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs,
+ self.link_predictor, at=at,
+ type=self.link_predictor_type, device=self.device)
+ for key, val in scores_.items():
+ scores[key].append(val)
acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
- total_loss += loss.item()
- total_acc += acc
+ scores["Loss"].append(loss.item())
+ scores["Accuracy"].append(acc)
count += 1
- return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_ndcg else None
- def _evaluate_nodes(self, ee, lp, create_api_call_loader, data_split, neg_sampling_factor=1):
-
- total_loss = 0
- total_acc = 0
- ndcg_at = [1, 3, 5, 10]
- total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at}
- ndcg_count = 0
- count = 0
-
- for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"):
- blocks = [blk.to(self.device) for blk in blocks]
-
- if self.masker is None:
- masked = None
- else:
- masked = self.masker.get_mask(self.seeds_to_python(seeds))
-
- src_embs = self._logits_batch(input_nodes, blocks, masked=masked)
- # logits, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor)
- node_embs_, element_embs_, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor)
-
- if self.measure_ndcg:
- if count % self.dilate_ndcg == 0:
- ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device)
- for key, val in ndcg.items():
- total_ndcg[key] = total_ndcg[key] + val
- ndcg_count += 1
-
- # logits = self.link_predictor(node_embs_, element_embs_)
- #
- # logp = nn.functional.log_softmax(logits, 1)
- # loss = nn.functional.cross_entropy(logp, labels)
- # acc = _compute_accuracy(logp.argmax(dim=1), labels)
-
- acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
-
- total_loss += loss.item()
- total_acc += acc
- count += 1
- return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_ndcg else None
+ scores = {key: sum_scores(val) for key, val in scores.items()}
+ return scores
+ # return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in
+ # total_ndcg.items()} if self.measure_scores else None
+
+ # def _evaluate_embedder(self, ee, lp, data_split, neg_sampling_factor=1):
+ #
+ # total_loss = 0
+ # total_acc = 0
+ # ndcg_at = [1, 3, 5, 10]
+ # total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at}
+ # ndcg_count = 0
+ # count = 0
+ #
+ # for input_nodes, seeds, blocks in tqdm(
+ # getattr(self, f"{data_split}_loader"), total=getattr(self, f"num_{data_split}_batches")
+ # ):
+ # blocks = [blk.to(self.device) for blk in blocks]
+ #
+ # if self.masker is None:
+ # masked = None
+ # else:
+ # masked = self.masker.get_mask(self.seeds_to_python(seeds))
+ #
+ # src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked)
+ # # logits, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor)
+ # node_embs_, element_embs_, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor)
+ #
+ # if self.measure_scores:
+ # if count % self.dilate_scores == 0:
+ # ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device)
+ # for key, val in ndcg.items():
+ # total_ndcg[key] = total_ndcg[key] + val
+ # ndcg_count += 1
+ #
+ # # logits = self.link_predictor(node_embs_, element_embs_)
+ # #
+ # # logp = nn.functional.log_softmax(logits, 1)
+ # # loss = nn.functional.cross_entropy(logp, labels)
+ # # acc = _compute_accuracy(logp.argmax(dim=1), labels)
+ #
+ # acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
+ #
+ # total_loss += loss.item()
+ # total_acc += acc
+ # count += 1
+ # return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_scores else None
+ #
+ # def _evaluate_nodes(self, ee, lp, create_api_call_loader, data_split, neg_sampling_factor=1):
+ #
+ # total_loss = 0
+ # total_acc = 0
+ # ndcg_at = [1, 3, 5, 10]
+ # total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at}
+ # ndcg_count = 0
+ # count = 0
+ #
+ # for input_nodes, seeds, blocks in tqdm(
+ # getattr(self, f"{data_split}_loader"), total=getattr(self, f"num_{data_split}_batches")
+ # ):
+ # blocks = [blk.to(self.device) for blk in blocks]
+ #
+ # if self.masker is None:
+ # masked = None
+ # else:
+ # masked = self.masker.get_mask(self.seeds_to_python(seeds))
+ #
+ # src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked)
+ # # logits, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor)
+ # # node_embs_, element_embs_, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor)
+ #
+ # node_embs_, element_embs_, labels = self.prepare_for_prediction(
+ # src_embs, seeds, self.target_embedding_fn, negative_factor=1,
+ # neg_sampling_strategy=None, train_embeddings=True,
+ # update_embeddings_for_queries=False
+ # )
+ #
+ # if self.measure_scores:
+ # if count % self.dilate_scores == 0:
+ # ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device)
+ # for key, val in ndcg.items():
+ # total_ndcg[key] = total_ndcg[key] + val
+ # ndcg_count += 1
+ #
+ # # logits = self.link_predictor(node_embs_, element_embs_)
+ # #
+ # # logp = nn.functional.log_softmax(logits, 1)
+ # # loss = nn.functional.cross_entropy(logp, labels)
+ # # acc = _compute_accuracy(logp.argmax(dim=1), labels)
+ #
+ # acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
+ #
+ # total_loss += loss.item()
+ # total_acc += acc
+ # count += 1
+ # return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_scores else None
+
+ def check_early_stopping(self, metric):
+ """
+        Checks the metric value and sets the early stopping trigger when the metric stops improving.
+        Assumes that a larger metric value is better. Accuracy is used by default; see implementations in child classes.
+ :param metric: metric value
+ :return: Nothing
+ """
+ if self.early_stopping_tracker is not None:
+ self.early_stopping_trigger = self.early_stopping_tracker.should_stop(metric)
- @abstractmethod
- def evaluate(self, data_split, neg_sampling_factor=1):
- # if self.type in {"subword_ranker"}:
- # loss, acc, ndcg = self._evaluate_embedder(
- # self.target_embedder, self.link_predictor, data_split=data_split, neg_sampling_factor=neg_sampling_factor
- # )
- # elif self.type in {"graph_link_prediction", "graph_link_classification"}:
- # loss, acc = self._evaluate_nodes(
- # self.target_embedder, self.link_predictor, self._create_loader, data_split=data_split,
- # neg_sampling_factor=neg_sampling_factor
- # )
- # ndcg = None
- # else:
- # raise NotImplementedError()
- #
- # return loss, acc, ndcg
- raise NotImplementedError()
+ def evaluate(self, data_split, *, neg_sampling_strategy=None, early_stopping=False, early_stopping_tolerance=20):
+ # negative factor is 1 for evaluation
+        scores = self.evaluate_objective(data_split, neg_sampling_strategy=neg_sampling_strategy, negative_factor=1)
+ if data_split == "val":
+ self.check_early_stopping(scores["Accuracy"])
+ return scores
@abstractmethod
def parameters(self, recurse: bool = True):
- # if self.type in {"subword_ranker"}:
- # return chain(self.target_embedder.parameters(), self.link_predictor.parameters())
- # elif self.type in {"graph_link_prediction", "graph_link_classification"}:
- # return self.link_predictor.parameters()
- # else:
- # raise NotImplementedError()
raise NotImplementedError()
@abstractmethod
def custom_state_dict(self):
- # state_dict = OrderedDict()
- # if self.type in {"subword_ranker"}:
- # for k, v in self.target_embedder.state_dict().items():
- # state_dict[f"target_embedder.{k}"] = v
- # for k, v in self.link_predictor.state_dict().items():
- # state_dict[f"link_predictor.{k}"] = v
- # # state_dict["target_embedder"] = self.target_embedder.state_dict()
- # # state_dict["link_predictor"] = self.link_predictor.state_dict()
- # elif self.type in {"graph_link_prediction", "graph_link_classification"}:
- # for k, v in self.link_predictor.state_dict().items():
- # state_dict[f"link_predictor.{k}"] = v
- # # state_dict["link_predictor"] = self.link_predictor.state_dict()
- # else:
- # raise NotImplementedError()
- #
- # return state_dict
raise NotImplementedError()
@abstractmethod
def custom_load_state_dict(self, state_dicts):
- # if self.type in {"subword_ranker"}:
- # self.target_embedder.load_state_dict(
- # self.get_prefix("target_embedder", state_dicts)
- # )
- # self.link_predictor.load_state_dict(
- # self.get_prefix("link_predictor", state_dicts)
- # )
- # elif self.type in {"graph_link_prediction", "graph_link_classification"}:
- # self.link_predictor.load_state_dict(
- # self.get_prefix("link_predictor", state_dicts)
- # )
- # else:
- # raise NotImplementedError()
raise NotImplementedError()
def get_prefix(self, prefix, state_dict):
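The `check_early_stopping`/`evaluate` pair above only flips `early_stopping_trigger`; the trainer later raises `EarlyStopping` when it sees the flag. The tracker itself is not part of this diff, so the sketch below is a hypothetical illustration of the interface the objective assumes (`should_stop` returning `True` once a growing metric has stalled for `tolerance` checks):

```python
# Hypothetical sketch only: the real EarlyStoppingTracker is not shown in this diff.
class EarlyStoppingTracker:
    def __init__(self, tolerance=20):
        self.tolerance = tolerance        # checks without improvement to tolerate
        self.best_metric = float("-inf")  # the metric is assumed to grow (e.g. accuracy)
        self.stale_checks = 0

    def should_stop(self, metric):
        if metric > self.best_metric:
            self.best_metric = metric
            self.stale_checks = 0
        else:
            self.stale_checks += 1
        return self.stale_checks >= self.tolerance
```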
diff --git a/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py b/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py
index 0cd50be7..183ab41d 100644
--- a/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py
+++ b/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py
@@ -1,15 +1,15 @@
-from collections import OrderedDict
import numpy as np
import random as rnd
import torch
from torch.nn import CrossEntropyLoss
from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker
+from SourceCodeTools.mltools.torch import compute_accuracy
+from SourceCodeTools.models.graph.ElementEmbedder import GraphLinkSampler
from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase
-from SourceCodeTools.models.graph.LinkPredictor import LinkClassifier, BilinearLinkPedictor
+from SourceCodeTools.models.graph.LinkPredictor import BilinearLinkPedictor, TransRLinkPredictor
from SourceCodeTools.models.graph.train.Scorer import Scorer
-from SourceCodeTools.models.graph.train.objectives import GraphLinkObjective
-from SourceCodeTools.models.graph.train.sampling_multitask2 import _compute_accuracy
+from SourceCodeTools.models.graph.train.objectives.GraphLinkObjective import GraphLinkObjective
from SourceCodeTools.tabular.common import compact_property
@@ -18,69 +18,32 @@ def __init__(
self, name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1, ns_groups=None
):
super().__init__(
name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, ns_groups=ns_groups
+ )
+ self.measure_scores = True
+ self.update_embeddings_for_queries = True
+
+ def create_graph_link_sampler(self, data_loading_func, nodes):
+ self.target_embedder = TargetLinkMapper(
+ elements=data_loading_func(), nodes=nodes, emb_size=self.target_emb_size, ns_groups=self.ns_groups
)
- self.measure_ndcg = False
def create_link_predictor(self):
- self.link_predictor = BilinearLinkPedictor(self.target_emb_size, self.graph_model.emb_size, self.target_embedder.num_classes).to(self.device)
+ self.link_predictor = BilinearLinkPedictor(
+ self.target_emb_size, self.graph_model.emb_size, self.target_embedder.num_classes
+ ).to(self.device)
# self.positive_label = 1
- self.negative_label = 0
+ self.negative_label = self.target_embedder.null_class
self.label_dtype = torch.long
- def _logits_nodes(self, node_embeddings,
- elem_embedder, link_predictor, create_dataloader,
- src_seeds, negative_factor=1, train_embeddings=True):
- k = negative_factor
- indices = self.seeds_to_global(src_seeds).tolist()
- batch_size = len(indices)
-
- node_embeddings_batch = node_embeddings
- dst_indices, labels_pos = elem_embedder[indices] # this assumes indices is torch tensor
-
- # dst targets are not unique
- unique_dst, slice_map = self._handle_non_unique(dst_indices)
- assert unique_dst[slice_map].tolist() == dst_indices.tolist()
-
- dataloader = create_dataloader(unique_dst)
- input_nodes, dst_seeds, blocks = next(iter(dataloader))
- blocks = [blk.to(self.device) for blk in blocks]
- assert dst_seeds.shape == unique_dst.shape
- assert dst_seeds.tolist() == unique_dst.tolist()
- unique_dst_embeddings = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes)
- dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
-
- self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(), unique_dst_embeddings.detach().cpu().numpy())
-
- node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
- # negative_indices = torch.tensor(elem_embedder.sample_negative(
- # batch_size * k), dtype=torch.long) # embeddings are sampled from 3/4 unigram distribution
- negative_indices = torch.tensor(elem_embedder.sample_negative(
- batch_size * k, ids=indices), dtype=torch.long) # closest negative
- unique_negative, slice_map = self._handle_non_unique(negative_indices)
- assert unique_negative[slice_map].tolist() == negative_indices.tolist()
-
- dataloader = create_dataloader(unique_negative)
- input_nodes, dst_seeds, blocks = next(iter(dataloader))
- blocks = [blk.to(self.device) for blk in blocks]
- assert dst_seeds.shape == unique_negative.shape
- assert dst_seeds.tolist() == unique_negative.tolist()
- unique_negative_random = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes)
- negative_random = unique_negative_random[slice_map.to(self.device)]
- labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype)
-
- self.target_embedder.set_embed(unique_negative.detach().cpu().numpy(), unique_negative_random.detach().cpu().numpy())
-
- nodes = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0)
- embs = torch.cat([dst_embeddings, negative_random], dim=0)
- labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
- return nodes, embs, labels
+ def create_positive_labels(self, ids):
+ return torch.LongTensor(self.target_embedder.get_labels(ids))
def compute_acc_loss(self, node_embs_, element_embs_, labels):
logits = self.link_predictor(node_embs_, element_embs_)
@@ -89,14 +52,52 @@ def compute_acc_loss(self, node_embs_, element_embs_, labels):
loss = loss_fct(logits.reshape(-1, logits.size(-1)),
labels.reshape(-1))
- acc = _compute_accuracy(logits.argmax(dim=1), labels)
+ acc = compute_accuracy(logits.argmax(dim=1), labels)
return acc, loss
-class TargetLinkMapper(ElementEmbedderBase, Scorer):
- def __init__(self, elements, nodes):
- ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes, )
+class TransRObjective(GraphLinkClassificationObjective):
+ def __init__(
+ self, graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
+ measure_scores=False, dilate_scores=1, ns_groups=None
+ ):
+ super().__init__(
+ "TransR", graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, ns_groups=ns_groups
+ )
+
+ def create_link_predictor(self):
+ self.link_predictor = TransRLinkPredictor(
+ input_dim=self.target_emb_size, rel_dim=30,
+ num_relations=self.target_embedder.num_classes
+ ).to(self.device)
+ # self.positive_label = 1
+ self.negative_label = -1
+ self.label_dtype = torch.long
+
+ def compute_acc_loss(self, node_embs_, element_embs_, labels):
+
+ num_examples = len(labels) // 2
+ anchor = node_embs_[:num_examples, :]
+ positive = element_embs_[:num_examples, :]
+ negative = element_embs_[num_examples:, :]
+ labels_ = labels[:num_examples]
+
+ loss, sim = self.link_predictor(anchor, positive, negative, labels_)
+ acc = compute_accuracy(sim, labels >= 0)
+
+ return acc, loss
+
+class TargetLinkMapper(GraphLinkSampler):
+ def __init__(self, elements, nodes, emb_size=1, ns_groups=None):
+ super(TargetLinkMapper, self).__init__(
+ elements, nodes, compact_dst=False, dst_to_global=True, emb_size=emb_size, ns_groups=ns_groups
+ )
def init(self, compact_dst):
if compact_dst:
@@ -106,7 +107,7 @@ def init(self, compact_dst):
self.elements['emb_id'] = self.elements['dst']
self.link_type2id, self.inverse_link_type_map = compact_property(self.elements['type'], return_order=True, index_from_one=True)
- self.elements["link_type"] = self.elements["type"].apply(lambda x: self.link_type2id[x])
+ self.elements["link_type"] = list(map(lambda x: self.link_type2id[x], self.elements["type"].tolist()))
self.element_lookup = {}
for id_, emb_id, link_type in self.elements[["id", "emb_id", "link_type"]].values:
@@ -117,12 +118,22 @@ def init(self, compact_dst):
self.init_neg_sample()
self.num_classes = len(self.inverse_link_type_map)
+ self.null_class = 0
def __getitem__(self, ids):
+ self.cached_ids = ids
node_ids, labels = zip(*(rnd.choice(self.element_lookup[id]) for id in ids))
- return np.array(node_ids, dtype=np.int32), np.array(labels, dtype=np.int32)
+ self.cached_labels = list(labels)
+ return np.array(node_ids, dtype=np.int32)#, torch.LongTensor(np.array(labels, dtype=np.int32))
+
+ def get_labels(self, ids):
+ if self.cached_ids == ids:
+ return self.cached_labels
+ else:
+ node_ids, labels = zip(*(rnd.choice(self.element_lookup[id]) for id in ids))
+ return list(labels)
- def sample_negative(self, size, ids=None, strategy="closest"): # TODO switch to w2v?
+    def sample_negative(self, size, ids=None, strategy="w2v"):  # default strategy changed from "closest" to "w2v"
if strategy == "w2v":
negative = ElementEmbedderBase.sample_negative(self, size)
else:
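`TransRObjective.compute_acc_loss` above splits the batch into an anchor/positive half and a negative half and delegates scoring to `TransRLinkPredictor`, whose implementation is not part of this diff. The following is a rough, hypothetical sketch of a TransR-style predictor with the same call signature (`forward(anchor, positive, negative, labels)` returning `(loss, sim)`), offered only to illustrate the idea of projecting node embeddings into a per-relation space:

```python
import torch
import torch.nn as nn

class TransRLinkPredictorSketch(nn.Module):
    """Illustrative only; the real TransRLinkPredictor in
    SourceCodeTools.models.graph.LinkPredictor may differ in every detail."""

    def __init__(self, input_dim, rel_dim, num_relations, margin=1.0):
        super().__init__()
        # one projection matrix and one translation vector per relation type
        self.proj = nn.Parameter(torch.randn(num_relations, input_dim, rel_dim) * 0.01)
        self.rel_emb = nn.Embedding(num_relations, rel_dim)
        self.margin = margin

    def forward(self, anchor, positive, negative, labels):
        proj = self.proj[labels]                               # (B, input_dim, rel_dim)
        h = torch.bmm(anchor.unsqueeze(1), proj).squeeze(1)    # heads in relation space
        t_pos = torch.bmm(positive.unsqueeze(1), proj).squeeze(1)
        t_neg = torch.bmm(negative.unsqueeze(1), proj).squeeze(1)
        r = self.rel_emb(labels)
        d_pos = torch.norm(h + r - t_pos, p=2, dim=-1)         # small distance = plausible link
        d_neg = torch.norm(h + r - t_neg, p=2, dim=-1)
        loss = torch.relu(self.margin + d_pos - d_neg).mean()  # margin ranking loss
        sim = torch.cat([d_pos, d_neg]) < self.margin          # plausibility prediction per pair
        return loss, sim
```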
diff --git a/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py b/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py
index 3aaaedde..0b14105f 100644
--- a/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py
+++ b/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py
@@ -1,9 +1,6 @@
from collections import OrderedDict
-import torch
-
from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker
-from SourceCodeTools.models.graph.LinkPredictor import LinkClassifier
from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective
@@ -12,14 +9,18 @@ def __init__(
self, name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1, nn_index="brute", ns_groups=None
):
super(GraphLinkObjective, self).__init__(
name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index,
+ ns_groups=ns_groups
)
+ self.target_embedding_fn = self.get_targets_from_nodes
+ self.negative_factor = 1
+ self.update_embeddings_for_queries = True
def verify_parameters(self):
pass
@@ -27,34 +28,6 @@ def verify_parameters(self):
def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
self.create_graph_link_sampler(data_loading_func, nodes)
- def create_link_predictor(self):
- if self.link_predictor_type == "nn":
- self.create_nn_link_predictor()
- elif self.link_predictor_type == "inner_prod":
- self.create_inner_prod_link_predictor()
- else:
- raise NotImplementedError()
-
- def forward(self, input_nodes, seeds, blocks, train_embeddings=True):
- masked = None
- graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked)
- node_embs_, element_embs_, labels = self._logits_nodes(
- graph_emb, self.target_embedder, self.link_predictor,
- self._create_loader, seeds, train_embeddings=train_embeddings
- )
-
- acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
-
- return loss, acc
-
- def evaluate(self, data_split, neg_sampling_factor=1):
- loss, acc, ndcg = self._evaluate_nodes(
- self.target_embedder, self.link_predictor, self._create_loader, data_split=data_split,
- neg_sampling_factor=neg_sampling_factor
- )
- # ndcg = None
- return loss, acc, ndcg
-
def parameters(self, recurse: bool = True):
return self.link_predictor.parameters()
@@ -69,86 +42,37 @@ def custom_load_state_dict(self, state_dicts):
self.get_prefix("link_predictor", state_dicts)
)
- def _logits_nodes(self, node_embeddings,
- elem_embedder, link_predictor, create_dataloader,
- src_seeds, negative_factor=1, train_embeddings=True):
- k = negative_factor
- indices = self.seeds_to_global(src_seeds).tolist()
- batch_size = len(indices)
-
- node_embeddings_batch = node_embeddings
- next_call_indices = elem_embedder[indices] # this assumes indices is torch tensor
-
- # dst targets are not unique
- unique_dst, slice_map = self._handle_non_unique(next_call_indices)
- assert unique_dst[slice_map].tolist() == next_call_indices.tolist()
-
- dataloader = create_dataloader(unique_dst)
- input_nodes, dst_seeds, blocks = next(iter(dataloader))
- blocks = [blk.to(self.device) for blk in blocks]
- assert dst_seeds.shape == unique_dst.shape
- assert dst_seeds.tolist() == unique_dst.tolist()
- unique_dst_embeddings = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes)
- next_call_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
- labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype)
-
- node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
- # negative_indices = torch.tensor(elem_embedder.sample_negative(
- # batch_size * k), dtype=torch.long) # embeddings are sampled from 3/4 unigram distribution
- negative_indices = torch.tensor(elem_embedder.sample_negative(
- batch_size * k, ids=indices), dtype=torch.long) # closest negative
- unique_negative, slice_map = self._handle_non_unique(negative_indices)
- assert unique_negative[slice_map].tolist() == negative_indices.tolist()
-
- self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(), unique_dst_embeddings.detach().cpu().numpy())
-
- dataloader = create_dataloader(unique_negative)
- input_nodes, dst_seeds, blocks = next(iter(dataloader))
- blocks = [blk.to(self.device) for blk in blocks]
- assert dst_seeds.shape == unique_negative.shape
- assert dst_seeds.tolist() == unique_negative.tolist()
- unique_negative_random = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes)
- negative_random = unique_negative_random[slice_map.to(self.device)]
- labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype)
-
- self.target_embedder.set_embed(unique_negative.detach().cpu().numpy(), unique_negative_random.detach().cpu().numpy())
-
- nodes = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0)
- embs = torch.cat([next_call_embeddings, negative_random], dim=0)
- labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)
- return nodes, embs, labels
-
-
-class GraphLinkTypeObjective(GraphLinkObjective):
- def __init__(
- self, name, graph_model, node_embedder, nodes, data_loading_func, device,
- sampling_neighbourhood_size, batch_size,
- tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
- ):
- self.set_num_classes(data_loading_func)
-
- super(GraphLinkObjective, self).__init__(
- name, graph_model, node_embedder, nodes, data_loading_func, device,
- sampling_neighbourhood_size, batch_size,
- tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
- )
-
- def set_num_classes(self, data_loading_func):
- pass
-
- def create_nn_link_type_predictor(self):
- self.link_predictor = LinkClassifier(2 * self.graph_model.emb_size, self.num_classes).to(self.device)
- self.positive_label = 1
- self.negative_label = 0
- self.label_dtype = torch.long
- def create_link_predictor(self):
- if self.link_predictor_type == "nn":
- self.create_nn_link_predictor()
- elif self.link_predictor_type == "inner_prod":
- self.create_inner_prod_link_predictor()
- else:
- raise NotImplementedError()
+# class GraphLinkTypeObjective(GraphLinkObjective):
+# def __init__(
+# self, name, graph_model, node_embedder, nodes, data_loading_func, device,
+# sampling_neighbourhood_size, batch_size,
+# tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
+# measure_scores=False, dilate_scores=1
+# ):
+# self.set_num_classes(data_loading_func)
+#
+# super(GraphLinkObjective, self).__init__(
+# name, graph_model, node_embedder, nodes, data_loading_func, device,
+# sampling_neighbourhood_size, batch_size,
+# tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
+# masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
+# )
+#
+# def set_num_classes(self, data_loading_func):
+# pass
+#
+# def create_nn_link_type_predictor(self):
+# self.link_predictor = LinkClassifier(2 * self.graph_model.emb_size, self.num_classes).to(self.device)
+# self.positive_label = 1
+# self.negative_label = 0
+# self.label_dtype = torch.long
+#
+# def create_link_predictor(self):
+# if self.link_predictor_type == "nn":
+# self.create_nn_link_predictor()
+# elif self.link_predictor_type == "inner_prod":
+# self.create_inner_prod_link_predictor()
+# else:
+# raise NotImplementedError()
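Several objectives in this diff switch from the removed local `_compute_accuracy` to `compute_accuracy` imported from `SourceCodeTools.mltools.torch`. That module is not shown here; judging by the helper deleted further down in `sampling_multitask2.py`, it is presumably the same one-liner, sketched below for reference:

```python
import torch

def compute_accuracy(pred_, true_):
    # fraction of positions where the prediction matches the ground truth
    return torch.sum(pred_ == true_).item() / len(true_)
```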
diff --git a/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py b/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py
index cd568ec8..799470c1 100644
--- a/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py
+++ b/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py
@@ -1,31 +1,32 @@
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
from itertools import chain
import torch
-from sklearn.metrics import ndcg_score
+from sklearn.metrics import ndcg_score, top_k_accuracy_score
from torch import nn
from torch.nn import CrossEntropyLoss
from SourceCodeTools.code.data.sourcetrail import SubwordMasker
-from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, _compute_accuracy
+from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, compute_accuracy, \
+ sum_scores
from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase
from SourceCodeTools.models.graph.train.Scorer import Scorer
import numpy as np
-class NodeNameClassifier(AbstractObjective):
+class NodeClassifierObjective(AbstractObjective):
def __init__(
- self, graph_model, node_embedder, nodes, data_loading_func, device,
+ self, name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type=None, masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1, early_stopping=False, early_stopping_tolerance=20
):
super().__init__(
- "NodeNameClassifier", graph_model, node_embedder, nodes, data_loading_func, device,
+ name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, early_stopping=early_stopping, early_stopping_tolerance=early_stopping_tolerance
)
def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
@@ -36,63 +37,76 @@ def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
def create_link_predictor(self):
self.classifier = NodeClassifier(self.target_emb_size, self.target_embedder.num_classes).to(self.device)
- def compute_acc_loss(self, graph_emb, labels, return_logits=False):
+ def compute_acc_loss(self, graph_emb, element_emb, labels, return_logits=False):
logits = self.classifier(graph_emb)
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.reshape(-1, logits.size(-1)),
labels.reshape(-1))
- acc = _compute_accuracy(logits.argmax(dim=1), labels)
+ acc = compute_accuracy(logits.argmax(dim=1), labels)
if return_logits:
return acc, loss, logits
return acc, loss
- def forward(self, input_nodes, seeds, blocks, train_embeddings=True):
- masked = self.masker.get_mask(self.seeds_to_python(seeds)) if self.masker is not None else None
- graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked)
+ def prepare_for_prediction(
+ self, node_embeddings, seeds, target_embedding_fn, negative_factor=1,
+ neg_sampling_strategy=None, train_embeddings=True,
+ ):
indices = self.seeds_to_global(seeds).tolist()
labels = torch.LongTensor(self.target_embedder[indices]).to(self.device)
- acc, loss = self.compute_acc_loss(graph_emb, labels)
- return loss, acc
+ return node_embeddings, None, labels
- def evaluate_generation(self, data_split):
- total_loss = 0
- total_acc = 0
- ndcg_at = [1, 3, 5, 10]
- total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at}
- ndcg_count = 0
+ def evaluate_objective(self, data_split, neg_sampling_strategy=None, negative_factor=1):
+ at = [1, 3, 5, 10]
count = 0
+ scores = defaultdict(list)
for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"):
blocks = [blk.to(self.device) for blk in blocks]
- src_embs = self._logits_batch(input_nodes, blocks)
- indices = self.seeds_to_global(seeds).tolist()
- labels = self.target_embedder[indices]
- labels = torch.LongTensor(labels).to(self.device)
- acc, loss, logits = self.compute_acc_loss(src_embs, labels, return_logits=True)
+ if self.masker is None:
+ masked = None
+ else:
+ masked = self.masker.get_mask(self.seeds_to_python(seeds))
+
+ src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked)
+ node_embs_, element_embs_, labels = self.prepare_for_prediction(
+ src_embs, seeds, self.target_embedding_fn, negative_factor=negative_factor,
+ neg_sampling_strategy=neg_sampling_strategy,
+ train_embeddings=False
+ )
+ # indices = self.seeds_to_global(seeds).tolist()
+ # labels = self.target_embedder[indices]
+ # labels = torch.LongTensor(labels).to(self.device)
+ acc, loss, logits = self.compute_acc_loss(node_embs_, element_embs_, labels, return_logits=True)
y_pred = nn.functional.softmax(logits, dim=-1).to("cpu").numpy()
y_true = np.zeros(y_pred.shape)
y_true[np.arange(0, y_true.shape[0]), labels.to("cpu").numpy()] = 1.
- if self.measure_ndcg:
- if count % self.dilate_ndcg == 0:
- for k in ndcg_at:
- total_ndcg[f"ndcg@{k}"] += ndcg_score(y_true, y_pred, k=k)
- ndcg_count += 1
+ if self.measure_scores:
+ if count % self.dilate_scores == 0:
+ y_true_onehot = np.array(y_true)
+ labels = list(range(y_true_onehot.shape[1]))
- total_loss += loss.item()
- total_acc += acc
+ for k in at:
+ scores[f"ndcg@{k}"].append(ndcg_score(y_true, y_pred, k=k))
+ scores[f"acc@{k}"].append(
+ top_k_accuracy_score(y_true_onehot.argmax(-1), y_pred, k=k, labels=labels)
+ )
+
+ scores["Loss"].append(loss.item())
+ scores["Accuracy"].append(acc)
count += 1
- return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_ndcg else None
- def evaluate(self, data_split, neg_sampling_factor=1):
- loss, acc, bleu = self.evaluate_generation(data_split)
- return loss, acc, bleu
+ if count == 0:
+ count += 1
+
+ scores = {key: sum_scores(val) for key, val in scores.items()}
+ return scores
def parameters(self, recurse: bool = True):
return chain(self.classifier.parameters())
@@ -100,10 +114,25 @@ def parameters(self, recurse: bool = True):
def custom_state_dict(self):
state_dict = OrderedDict()
for k, v in self.classifier.state_dict().items():
- state_dict[f"target_embedder.{k}"] = v
+ state_dict[f"classifier.{k}"] = v
return state_dict
+class NodeNameClassifier(NodeClassifierObjective):
+ def __init__(
+ self, graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ tokenizer_path=None, target_emb_size=None, link_predictor_type=None, masker: SubwordMasker = None,
+ measure_scores=False, dilate_scores=1, early_stopping=False, early_stopping_tolerance=20
+ ):
+ super().__init__(
+ "NodeNameClassifier", graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, early_stopping=early_stopping, early_stopping_tolerance=early_stopping_tolerance
+ )
+
+
class ClassifierTargetMapper(ElementEmbedderBase, Scorer):
def __init__(self, elements, nodes):
ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes,)
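The reworked `evaluate_objective` above builds a one-hot relevance matrix and scores each batch with sklearn's ranking metrics. A small standalone illustration of those two calls with toy numbers (not taken from the repository):

```python
import numpy as np
from sklearn.metrics import ndcg_score, top_k_accuracy_score

# Toy batch: 3 nodes, 4 classes; softmaxed classifier outputs
y_pred = np.array([[0.1, 0.6, 0.2, 0.1],
                   [0.3, 0.2, 0.4, 0.1],
                   [0.7, 0.1, 0.1, 0.1]])
true_classes = np.array([1, 0, 0])

# One-hot relevance matrix, mirroring the construction in evaluate_objective
y_true = np.zeros(y_pred.shape)
y_true[np.arange(y_true.shape[0]), true_classes] = 1.0

print(ndcg_score(y_true, y_pred, k=3))
print(top_k_accuracy_score(true_classes, y_pred, k=3, labels=list(range(y_pred.shape[1]))))
```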
diff --git a/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py b/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py
index ff83d6e0..86a3b1a2 100644
--- a/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py
+++ b/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py
@@ -10,14 +10,17 @@ def __init__(
self, name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1, nn_index="brute"
):
super(SubwordEmbedderObjective, self).__init__(
name, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index
)
+ self.target_embedding_fn = self.get_targets_from_embedder
+ self.negative_factor = 1
+ self.update_embeddings_for_queries = False
def verify_parameters(self):
if self.link_predictor_type == "inner_prod":
@@ -26,34 +29,6 @@ def verify_parameters(self):
def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
self.create_subword_embedder(data_loading_func, nodes, tokenizer_path)
- def create_link_predictor(self):
- if self.link_predictor_type == "nn":
- self.create_nn_link_predictor()
- elif self.link_predictor_type == "inner_prod":
- self.create_inner_prod_link_predictor()
- else:
- raise NotImplementedError()
-
- def forward(self, input_nodes, seeds, blocks, train_embeddings=True):
- masked = self.masker.get_mask(self.seeds_to_python(seeds)) if self.masker is not None else None
- graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked)
- node_embs_, element_embs_, labels = self._logits_embedder(
- graph_emb, self.target_embedder, self.link_predictor, seeds
- )
- acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels)
-
- return loss, acc
-
- # def train(self):
- # pass
-
- def evaluate(self, data_split, neg_sampling_factor=1):
- loss, acc, ndcg = self._evaluate_embedder(
- self.target_embedder, self.link_predictor, data_split=data_split,
- neg_sampling_factor=neg_sampling_factor
- )
- return loss, acc, ndcg
-
def parameters(self, recurse: bool = True):
return chain(self.target_embedder.parameters(), self.link_predictor.parameters())
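`parameters()` above chains the target embedder and the link predictor so the trainer optimizes both jointly. A minimal sketch of that pattern with stand-in modules (the module shapes are made up):

```python
from itertools import chain

import torch
import torch.nn as nn

# Stand-ins for target_embedder and link_predictor (shapes are illustrative)
target_embedder = nn.Embedding(1000, 32)
link_predictor = nn.Bilinear(32, 32, 2)

# Both sub-modules end up in a single optimizer, as in SubwordEmbedderObjective.parameters()
optimizer = torch.optim.AdamW(chain(target_embedder.parameters(), link_predictor.parameters()), lr=1e-3)
```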
diff --git a/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py b/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py
index 56ff5d8a..3a627273 100644
--- a/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py
+++ b/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py
@@ -1,17 +1,17 @@
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
from itertools import chain
import datasets
import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss, NLLLoss
+from torch.nn import CrossEntropyLoss
from SourceCodeTools.code.data.sourcetrail import SubwordMasker
from SourceCodeTools.models.graph.ElementEmbedder import DocstringEmbedder, create_fixed_length, \
ElementEmbedderWithBpeSubwords
from SourceCodeTools.models.graph.train.objectives import SubwordEmbedderObjective
-from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, _compute_accuracy
-from SourceCodeTools.models.nlp.Decoder import LSTMDecoder, Decoder
+from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, compute_accuracy, \
+ sum_scores
+from SourceCodeTools.models.nlp.TorchDecoder import LSTMDecoder, Decoder
from SourceCodeTools.models.nlp.Vocabulary import Vocabulary
from SourceCodeTools.nlp.embed.bpe import load_bpe_model
import numpy as np
@@ -22,13 +22,13 @@ def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1
):
super().__init__(
"GraphTextPrediction", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
)
def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
@@ -43,14 +43,14 @@ def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1, max_len=20
+ measure_scores=False, dilate_scores=1, max_len=20
):
self.max_len = max_len + 2 # add pad and eos
super().__init__(
"GraphTextGeneration", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
)
def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
@@ -85,17 +85,21 @@ def compute_acc_loss(self, graph_emb, labels, lengths, return_logits=False):
max_len = logits.shape[1]
length_mask = torch.arange(max_len).to(self.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
+ logits_unrolled = logits[length_mask, :]
+ labels_unrolled = labels[length_mask]
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(logits_unrolled, labels_unrolled)
# mask_ = length_mask.reshape(-1,)
# loss_fct = NLLLoss(reduction="sum")
- loss = loss_fct(logits.reshape(-1, logits.size(-1)),#[mask_, :],
- labels.reshape(-1)) #[mask_])
+ # loss = loss_fct(logits.reshape(-1, logits.size(-1)),#[mask_, :],
+ # labels.reshape(-1)) #[mask_])
def masked_accuracy(pred, true, mask):
mask = mask.reshape(-1,)
pred = pred.reshape(-1,)[mask]
true = true.reshape(-1,)[mask]
- return _compute_accuracy(pred, true)
+ return compute_accuracy(pred, true)
acc = masked_accuracy(logits.argmax(dim=2), labels, length_mask)
@@ -103,9 +107,8 @@ def masked_accuracy(pred, true, mask):
return acc, loss, logits
return acc, loss
-
- def forward(self, input_nodes, seeds, blocks, train_embeddings=True):
- graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings)
+ def forward(self, input_nodes, seeds, blocks, train_embeddings=True, neg_sampling_strategy=None):
+ graph_emb = self._graph_embeddings(input_nodes, blocks, train_embeddings)
indices = self.seeds_to_global(seeds).tolist()
labels, lengths = self.target_embedder[indices]
labels = labels.to(self.device)
@@ -129,16 +132,13 @@ def get_generated(self, tokens):
return sents
def evaluate_generation(self, data_split):
- total_loss = 0
- total_acc = 0
- total_bleu = {f"bleu": 0.}
- bleu_count = 0
+ scores = defaultdict(list)
count = 0
for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"):
blocks = [blk.to(self.device) for blk in blocks]
- src_embs = self._logits_batch(input_nodes, blocks)
+ src_embs = self._graph_embeddings(input_nodes, blocks)
indices = self.seeds_to_global(seeds).tolist()
labels, lengths = self.target_embedder[indices]
labels = labels.to(self.device)
@@ -154,16 +154,19 @@ def evaluate_generation(self, data_split):
# bleu = sacrebleu.corpus_bleu(pred, true)
bleu = self.bleu_metric.compute(predictions=pred, references=[[t] for t in true])
- bleu_count += 1
- total_bleu["bleu"] += bleu['score']
+ scores["bleu"].append(bleu['score'])
- total_loss += loss.item()
- total_acc += acc
+ scores["Loss"].append(loss.item())
+ scores["Accuracy"].append(acc)
count += 1
- return total_loss / count, total_acc / count, {"bleu": total_bleu["bleu"] / bleu_count}
- def evaluate(self, data_split, neg_sampling_factor=1):
+ scores = {key: sum_scores(val) for key, val in scores.items()}
+ return scores
+
+ def evaluate(self, data_split, *, neg_sampling_strategy=None, early_stopping=False, early_stopping_tolerance=20):
-        loss, acc, bleu = self.evaluate_generation(data_split)
+        scores = self.evaluate_generation(data_split)
+        if data_split == "val":
+            self.check_early_stopping(scores["Accuracy"])
-        return loss, acc, bleu
+        return scores
def parameters(self, recurse: bool = True):
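The updated `compute_acc_loss` for text generation now indexes logits and labels with a length mask before computing the loss, instead of flattening everything and relying on `ignore_index`. A self-contained sketch of that masking step with toy shapes (numbers are illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy decoder output: batch of 2 sequences, max_len 4, vocabulary of 5
logits = torch.randn(2, 4, 5)
labels = torch.randint(0, 5, (2, 4))
lengths = torch.tensor([3, 2])

max_len = logits.shape[1]
# True for real token positions, False for padding, as in compute_acc_loss
length_mask = torch.arange(max_len).expand(len(lengths), max_len) < lengths.unsqueeze(1)

logits_unrolled = logits[length_mask, :]   # (num_real_tokens, vocab)
labels_unrolled = labels[length_mask]      # (num_real_tokens,)

loss = CrossEntropyLoss(ignore_index=-100)(logits_unrolled, labels_unrolled)
```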
diff --git a/SourceCodeTools/models/graph/train/objectives/__init__.py b/SourceCodeTools/models/graph/train/objectives/__init__.py
index 57b55612..b98486bb 100644
--- a/SourceCodeTools/models/graph/train/objectives/__init__.py
+++ b/SourceCodeTools/models/graph/train/objectives/__init__.py
@@ -1,6 +1,8 @@
from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker
-from SourceCodeTools.models.graph.train.objectives.GraphLinkObjective import GraphLinkObjective, GraphLinkTypeObjective
-from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeNameClassifier
+from SourceCodeTools.models.graph.train.objectives.GraphLinkClassificationObjective import \
+ GraphLinkClassificationObjective
+from SourceCodeTools.models.graph.train.objectives.GraphLinkObjective import GraphLinkObjective
+from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeNameClassifier, NodeClassifierObjective
from SourceCodeTools.models.graph.train.objectives.SubwordEmbedderObjective import SubwordEmbedderObjective
from SourceCodeTools.models.graph.train.objectives.TextPredictionObjective import GraphTextPrediction, GraphTextGeneration
@@ -10,13 +12,13 @@ def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1
):
super(TokenNamePrediction, self).__init__(
"TokenNamePrediction", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
)
@@ -25,13 +27,29 @@ def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1, nn_index="brute"
):
super(NodeNamePrediction, self).__init__(
"NodeNamePrediction", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index
+ )
+
+
+class TypeAnnPrediction(NodeClassifierObjective):
+ def __init__(
+ self, graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ tokenizer_path=None, target_emb_size=None, link_predictor_type=None, masker: SubwordMasker = None,
+ measure_scores=False, dilate_scores=1
+ ):
+ super(TypeAnnPrediction, self).__init__(
+ "TypeAnnPrediction", graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ # tokenizer_path=tokenizer_path,
+ target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
+ masker=masker, dilate_scores=dilate_scores
)
@@ -40,13 +58,13 @@ def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1
):
super(VariableNameUsePrediction, self).__init__(
"VariableNamePrediction", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
)
@@ -55,13 +73,13 @@ def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1
):
super(NextCallPrediction, self).__init__(
"NextCallPrediction", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
)
@@ -70,26 +88,68 @@ def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1
):
super(GlobalLinkPrediction, self).__init__(
"GlobalLinkPrediction", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
)
-class LinkTypePrediction(GraphLinkTypeObjective):
+# class LinkTypePrediction(GraphLinkTypeObjective):
+# def __init__(
+# self, graph_model, node_embedder, nodes, data_loading_func, device,
+# sampling_neighbourhood_size, batch_size,
+# tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
+# measure_scores=False, dilate_scores=1
+# ):
+# super(GraphLinkTypeObjective, self).__init__(
+# "LinkTypePrediction", graph_model, node_embedder, nodes, data_loading_func, device,
+# sampling_neighbourhood_size, batch_size,
+# tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
+# masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
+# )
+
+
+class EdgePrediction(GraphLinkObjective):
def __init__(
self, graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
- measure_ndcg=False, dilate_ndcg=1
+ measure_scores=False, dilate_scores=1, nn_index="brute", ns_groups=None
):
- super(GraphLinkTypeObjective, self).__init__(
- "LinkTypePrediction", graph_model, node_embedder, nodes, data_loading_func, device,
+ super(EdgePrediction, self).__init__(
+ "EdgePrediction", graph_model, node_embedder, nodes, data_loading_func, device,
sampling_neighbourhood_size, batch_size,
tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
- masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index,
+ ns_groups=ns_groups
)
+ # super(EdgePrediction, self).__init__(
+ # "LinkTypePrediction", graph_model, node_embedder, nodes, data_loading_func, device,
+ # sampling_neighbourhood_size, batch_size,
+ # tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
+ # masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
+ # )
+
+class EdgePrediction2(SubwordEmbedderObjective):
+ def __init__(
+ self, graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None,
+ measure_scores=False, dilate_scores=1
+ ):
+ super(EdgePrediction2, self).__init__(
+ "EdgePrediction2", graph_model, node_embedder, nodes, data_loading_func, device,
+ sampling_neighbourhood_size, batch_size,
+ tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type,
+ masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores
+ )
+
+ def create_target_embedder(self, data_loading_func, nodes, tokenizer_path):
+ from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedder
+ self.target_embedder = ElementEmbedder(
+ elements=data_loading_func(), nodes=nodes, emb_size=self.target_emb_size,
+ ).to(self.device)
\ No newline at end of file
diff --git a/SourceCodeTools/models/graph/train/sampling_multitask2.py b/SourceCodeTools/models/graph/train/sampling_multitask2.py
index 46c9d873..0b311ab6 100644
--- a/SourceCodeTools/models/graph/train/sampling_multitask2.py
+++ b/SourceCodeTools/models/graph/train/sampling_multitask2.py
@@ -1,7 +1,9 @@
import os
+from collections import defaultdict
from copy import copy
from typing import Tuple
+import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
@@ -10,15 +12,25 @@
from os.path import join
import logging
+from tqdm import tqdm
+
from SourceCodeTools.models.Embedder import Embedder
from SourceCodeTools.models.graph.train.objectives import VariableNameUsePrediction, TokenNamePrediction, \
NextCallPrediction, NodeNamePrediction, GlobalLinkPrediction, GraphTextPrediction, GraphTextGeneration, \
- NodeNameClassifier
+ NodeNameClassifier, EdgePrediction, TypeAnnPrediction, EdgePrediction2, NodeClassifierObjective
from SourceCodeTools.models.graph.NodeEmbedder import NodeEmbedder
+from SourceCodeTools.models.graph.train.objectives.GraphLinkClassificationObjective import TransRObjective
+
+class EarlyStopping(Exception):
+ def __init__(self, *args, **kwargs):
+ super(EarlyStopping, self).__init__(*args, **kwargs)
-def _compute_accuracy(pred_, true_):
- return torch.sum(pred_ == true_).item() / len(true_)
+
+def add_to_summary(summary, partition, objective_name, scores, postfix):
+ summary.update({
+ f"{key}/{partition}/{objective_name}_{postfix}": val for key, val in scores.items()
+ })
class SamplingMultitaskTrainer:
@@ -27,7 +39,7 @@ def __init__(self,
dataset=None, model_name=None, model_params=None,
trainer_params=None, restore=None, device=None,
pretrained_embeddings_path=None,
- tokenizer_path=None
+ tokenizer_path=None, load_external_dataset=None
):
self.graph_model = model_name(dataset.g, **model_params).to(device)
@@ -35,24 +47,32 @@ def __init__(self,
self.trainer_params = trainer_params
self.device = device
self.epoch = 0
+ self.restore_epoch = 0
self.batch = 0
self.dtype = torch.float32
+ if load_external_dataset is not None:
+ logging.info("Loading external dataset")
+ external_args, external_dataset = load_external_dataset()
+ self.graph_model.g = external_dataset.g
+ dataset = external_dataset
self.create_node_embedder(
dataset, tokenizer_path, n_dims=model_params["h_dim"],
pretrained_path=pretrained_embeddings_path, n_buckets=trainer_params["embedding_table_size"]
)
- self.summary_writer = SummaryWriter(self.model_base_path)
-
self.create_objectives(dataset, tokenizer_path)
if restore:
self.restore_from_checkpoint(self.model_base_path)
- self.optimizer = self._create_optimizer()
+ if load_external_dataset is not None:
+ self.trainer_params["model_base_path"] = external_args.external_model_base
+
+ self._create_optimizer()
self.lr_scheduler = ExponentialLR(self.optimizer, gamma=1.0)
# self.lr_scheduler = ReduceLROnPlateau(self.optimizer, patience=10, cooldown=20)
+ self.summary_writer = SummaryWriter(self.model_base_path)
def create_objectives(self, dataset, tokenizer_path):
self.objectives = nn.ModuleList()
@@ -66,12 +86,18 @@ def create_objectives(self, dataset, tokenizer_path):
self.create_api_call_objective(dataset, tokenizer_path)
if "global_link_pred" in self.trainer_params["objectives"]:
self.create_global_link_objective(dataset, tokenizer_path)
+ if "edge_pred" in self.trainer_params["objectives"]:
+ self.create_edge_objective(dataset, tokenizer_path)
+ if "transr" in self.trainer_params["objectives"]:
+ self.create_transr_objective(dataset, tokenizer_path)
if "doc_pred" in self.trainer_params["objectives"]:
self.create_text_prediction_objective(dataset, tokenizer_path)
if "doc_gen" in self.trainer_params["objectives"]:
self.create_text_generation_objective(dataset, tokenizer_path)
if "node_clf" in self.trainer_params["objectives"]:
- self.create_node_name_classifier_objective(dataset, tokenizer_path)
+ self.create_node_classifier_objective(dataset, tokenizer_path)
+ if "type_ann_pred" in self.trainer_params["objectives"]:
+ self.create_type_ann_objective(dataset, tokenizer_path)
def create_token_pred_objective(self, dataset, tokenizer_path):
self.objectives.append(
@@ -80,21 +106,40 @@ def create_token_pred_objective(self, dataset, tokenizer_path):
dataset.load_token_prediction, self.device,
self.sampling_neighbourhood_size, self.batch_size,
tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod",
- masker=dataset.create_subword_masker(), measure_ndcg=self.trainer_params["measure_ndcg"],
- dilate_ndcg=self.trainer_params["dilate_ndcg"]
+ masker=dataset.create_subword_masker(), measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"]
)
)
def create_node_name_objective(self, dataset, tokenizer_path):
self.objectives.append(
+ # GraphTextGeneration(
+ # self.graph_model, self.node_embedder, dataset.nodes,
+ # dataset.load_node_names, self.device,
+ # self.sampling_neighbourhood_size, self.batch_size,
+ # tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size,
+ # )
NodeNamePrediction(
self.graph_model, self.node_embedder, dataset.nodes,
dataset.load_node_names, self.device,
self.sampling_neighbourhood_size, self.batch_size,
tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod",
masker=dataset.create_node_name_masker(tokenizer_path),
- measure_ndcg=self.trainer_params["measure_ndcg"],
- dilate_ndcg=self.trainer_params["dilate_ndcg"]
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"], nn_index=self.trainer_params["nn_index"]
+ )
+ )
+
+ def create_type_ann_objective(self, dataset, tokenizer_path):
+ self.objectives.append(
+ TypeAnnPrediction(
+ self.graph_model, self.node_embedder, dataset.nodes,
+ dataset.load_type_prediction, self.device,
+ self.sampling_neighbourhood_size, self.batch_size,
+ tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod",
+ masker=None,
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"]
)
)
@@ -106,8 +151,8 @@ def create_node_name_classifier_objective(self, dataset, tokenizer_path):
self.sampling_neighbourhood_size, self.batch_size,
tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size,
masker=dataset.create_node_name_masker(tokenizer_path),
- measure_ndcg=self.trainer_params["measure_ndcg"],
- dilate_ndcg=self.trainer_params["dilate_ndcg"]
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"]
)
)
@@ -119,8 +164,8 @@ def create_var_use_objective(self, dataset, tokenizer_path):
self.sampling_neighbourhood_size, self.batch_size,
tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod",
masker=dataset.create_variable_name_masker(tokenizer_path),
- measure_ndcg=self.trainer_params["measure_ndcg"],
- dilate_ndcg=self.trainer_params["dilate_ndcg"]
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"]
)
)
@@ -131,8 +176,8 @@ def create_api_call_objective(self, dataset, tokenizer_path):
dataset.load_api_call, self.device,
self.sampling_neighbourhood_size, self.batch_size,
tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod",
- measure_ndcg=self.trainer_params["measure_ndcg"],
- dilate_ndcg=self.trainer_params["dilate_ndcg"]
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"]
)
)
@@ -145,8 +190,35 @@ def create_global_link_objective(self, dataset, tokenizer_path):
dataset.load_global_edges_prediction, self.device,
self.sampling_neighbourhood_size, self.batch_size,
tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="nn",
- measure_ndcg=self.trainer_params["measure_ndcg"],
- dilate_ndcg=self.trainer_params["dilate_ndcg"]
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"]
+ )
+ )
+
+ def create_edge_objective(self, dataset, tokenizer_path):
+ self.objectives.append(
+ EdgePrediction(
+ self.graph_model, self.node_embedder, dataset.nodes,
+ dataset.load_edge_prediction, self.device,
+ self.sampling_neighbourhood_size, self.batch_size,
+ tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type=self.trainer_params["metric"],
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"], nn_index=self.trainer_params["nn_index"],
+ # ns_groups=dataset.get_negative_sample_groups()
+ )
+ )
+
+ def create_transr_objective(self, dataset, tokenizer_path):
+ self.objectives.append(
+ TransRObjective(
+ self.graph_model, self.node_embedder, dataset.nodes,
+ dataset.load_edge_prediction, self.device,
+ self.sampling_neighbourhood_size, self.batch_size,
+ tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size,
+ link_predictor_type=self.trainer_params["metric"],
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"],
+ # ns_groups=dataset.get_negative_sample_groups()
)
)
@@ -157,8 +229,8 @@ def create_text_prediction_objective(self, dataset, tokenizer_path):
dataset.load_docstring, self.device,
self.sampling_neighbourhood_size, self.batch_size,
tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod",
- measure_ndcg=self.trainer_params["measure_ndcg"],
- dilate_ndcg=self.trainer_params["dilate_ndcg"]
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"]
)
)
@@ -172,6 +244,22 @@ def create_text_generation_objective(self, dataset, tokenizer_path):
)
)
+ def create_node_classifier_objective(self, dataset, tokenizer_path):
+ self.objectives.append(
+ NodeClassifierObjective(
+ "NodeTypeClassifier",
+ self.graph_model, self.node_embedder, dataset.nodes,
+ dataset.load_node_classes, self.device,
+ self.sampling_neighbourhood_size, self.batch_size,
+ tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size,
+ masker=dataset.create_node_clf_masker(),
+ measure_scores=self.trainer_params["measure_scores"],
+ dilate_scores=self.trainer_params["dilate_scores"],
+ early_stopping=self.trainer_params["early_stopping"],
+ early_stopping_tolerance=self.trainer_params["early_stopping_tolerance"]
+ )
+ )
+
def create_node_embedder(self, dataset, tokenizer_path, n_dims=None, pretrained_path=None, n_buckets=500000):
from SourceCodeTools.nlp.embed.fasttext import load_w2v_map
@@ -271,14 +359,31 @@ def write_hyperparams(self, scores, epoch):
def _create_optimizer(self):
parameters = nn.ParameterList(self.graph_model.parameters())
- parameters.extend(self.node_embedder.parameters())
+ nodeembedder_params = list(self.node_embedder.parameters())
+ # parameters.extend(self.node_embedder.parameters())
[parameters.extend(objective.parameters()) for objective in self.objectives]
# AdaHessian TODO could not run
# optimizer = Yogi(parameters, lr=self.lr)
- optimizer = torch.optim.AdamW(
- [{"params": parameters}], lr=self.lr
+ self.optimizer = torch.optim.AdamW(
+ [{"params": parameters}], lr=self.lr, weight_decay=0.5
+ )
+ self.sparse_optimizer = torch.optim.SparseAdam(
+ [{"params": nodeembedder_params}], lr=self.lr
)
- return optimizer
+
+ def compute_embeddings_for_scorer(self, objective):
+ if hasattr(objective.target_embedder, "scorer_all_keys") and objective.update_embeddings_for_queries:
+ def chunks(lst, n):
+ for i in range(0, len(lst), n):
+ yield torch.LongTensor(lst[i:i + n])
+
+ batches = chunks(objective.target_embedder.scorer_all_keys, self.trainer_params["batch_size"])
+ for batch in tqdm(
+ batches,
+ total=len(objective.target_embedder.scorer_all_keys) // self.trainer_params["batch_size"] + 1,
+ desc="Precompute Target Embeddings", leave=True
+ ):
+ _ = objective.target_embedding_fn(batch) # scorer embedding updated inside
def train_all(self):
"""
@@ -287,9 +392,12 @@ def train_all(self):
"""
summary_dict = {}
+ best_val_loss = float("inf")
+ write_best_model = False
for objective in self.objectives:
with torch.set_grad_enabled(False):
+ self.compute_embeddings_for_scorer(objective)
objective.target_embedder.prepare_index() # need this to update sampler for the next epoch
for epoch in range(self.epoch, self.epochs):
@@ -297,10 +405,18 @@ def train_all(self):
start = time()
- keep_training = True
-
summary_dict = {}
- while keep_training:
+ num_batches = min([objective.num_train_batches for objective in self.objectives])
+
+ train_losses = defaultdict(list)
+ train_accs = defaultdict(list)
+
+ # def append_metric(destination, name, metric):
+ # if name not in destination:
+ # destination[name] = []
+ # destination[name].append(metric)
+
+ for step in tqdm(range(num_batches), total=num_batches, desc=f"Epoch {self.epoch}"):
loss_accum = 0
@@ -312,10 +428,24 @@ def train_all(self):
break
self.optimizer.zero_grad()
- for objective, (input_nodes, seeds, blocks) in zip(self.objectives, loaders):
+ self.sparse_optimizer.zero_grad()
+ for ind, (objective, (input_nodes, seeds, blocks)) in enumerate(zip(self.objectives, loaders)):
blocks = [blk.to(self.device) for blk in blocks]
- loss, acc = objective(input_nodes, seeds, blocks, train_embeddings=self.finetune)
+ objective.target_embedder.prepare_index()
+
+ do_break = False
+ for block in blocks:
+ if block.num_edges() == 0:
+ do_break = True
+ if do_break:
+ break
+
+ # try:
+ loss, acc = objective(
+ input_nodes, seeds, blocks, train_embeddings=self.finetune,
+ neg_sampling_strategy="w2v" if self.trainer_params["force_w2v_ns"] else None
+ )
loss = loss / len(self.objectives) # assumes the same batch size for all objectives
loss_accum += loss.item()
@@ -323,21 +453,34 @@ def train_all(self):
# for param in groups["params"]:
# torch.nn.utils.clip_grad_norm_(param, max_norm=1.)
loss.backward() # create_graph = True
+
+ summary = {}
+ add_to_summary(
+ summary=summary, partition="train", objective_name=objective.name,
+ scores={"Loss": loss.item(), "Accuracy": acc}, postfix=""
+ )
+
+ train_losses[f"Loss/train_avg/{objective.name}"].append(loss.item())
+ train_accs[f"Accuracy/train_avg/{objective.name}"].append(acc)
- summary.update({
- f"Loss/train/{objective.name}_vs_batch": loss.item(),
- f"Accuracy/train/{objective.name}_vs_batch": acc,
- })
+ # except ZeroEdges as e:
+ # logging.warning(f"Zero edges in loader in step {step}")
+ # except Exception as e:
+ # raise e
self.optimizer.step()
+ self.sparse_optimizer.step()
+ step += 1
self.write_summary(summary, self.batch)
summary_dict.update(summary)
self.batch += 1
- summary = {
- f"Loss/train": loss_accum,
- }
+ # summary = {
+ # f"Loss/train": loss_accum,
+ # }
+ summary = {key: sum(val) / len(val) for key, val in train_losses.items()}
+ summary.update({key: sum(val) / len(val) for key, val in train_accs.items()})
self.write_summary(summary, self.batch)
summary_dict.update(summary)
@@ -350,21 +493,29 @@ def train_all(self):
with torch.set_grad_enabled(False):
objective.target_embedder.prepare_index() # need this to update sampler for the next epoch
- val_loss, val_acc, val_ndcg = objective.evaluate("val")
- test_loss, test_acc, test_ndcg = objective.evaluate("test")
-
- summary = {
- f"Accuracy/test/{objective.name}_vs_batch": test_acc,
- f"Accuracy/val/{objective.name}_vs_batch": val_acc,
- }
- if test_ndcg is not None:
- summary.update({f"{key}/test/{objective.name}_vs_batch": val for key, val in test_ndcg.items()})
- if val_ndcg is not None:
- summary.update({f"{key}/val/{objective.name}_vs_batch": val for key, val in val_ndcg.items()})
+ val_scores = objective.evaluate("val")
+ test_scores = objective.evaluate("test")
+
+ summary = {}
+ add_to_summary(summary, "val", objective.name, val_scores, postfix="")
+ add_to_summary(summary, "test", objective.name, test_scores, postfix="")
self.write_summary(summary, self.batch)
summary_dict.update(summary)
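+                # Track the best model by average validation loss across objectives;
+                # an improvement triggers writing best_model.pt in save_checkpoint below.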
+ val_losses = [item for key, item in summary_dict.items() if key.startswith("Loss/val")]
+ avg_val_loss = sum(val_losses) / len(val_losses)
+ if avg_val_loss < best_val_loss:
+ best_val_loss = avg_val_loss
+ write_best_model = True
+
+ if self.do_save:
+ self.save_checkpoint(self.model_base_path, write_best_model=write_best_model)
+ write_best_model = False
+
+ if objective.early_stopping_trigger is True:
+ raise EarlyStopping()
+
objective.train()
# self.write_hyperparams({k.replace("vs_batch", "vs_epoch"): v for k, v in summary_dict.items()}, self.epoch)
@@ -374,14 +525,11 @@ def train_all(self):
print(f"Epoch: {self.epoch}, Time: {int(end - start)} s", end="\t")
print(summary_dict)
- if self.do_save:
- self.save_checkpoint(self.model_base_path)
-
self.lr_scheduler.step()
- def save_checkpoint(self, checkpoint_path=None, checkpoint_name=None, **kwargs):
+ def save_checkpoint(self, checkpoint_path=None, checkpoint_name=None, write_best_model=False, **kwargs):
- checkpoint_path = join(checkpoint_path, "saved_state.pt")
+        model_path = join(checkpoint_path, "saved_state.pt")
param_dict = {
'graph_model': self.graph_model.state_dict(),
@@ -396,15 +544,21 @@ def save_checkpoint(self, checkpoint_path=None, checkpoint_name=None, **kwargs):
if len(kwargs) > 0:
param_dict.update(kwargs)
- torch.save(param_dict, checkpoint_path)
+ torch.save(param_dict, model_path)
+ if self.trainer_params["save_each_epoch"]:
+ torch.save(param_dict, join(checkpoint_path, f"saved_state_{self.epoch}.pt"))
+
+ if write_best_model:
+            torch.save(param_dict, join(checkpoint_path, "best_model.pt"))
def restore_from_checkpoint(self, checkpoint_path):
- checkpoint = torch.load(join(checkpoint_path, "saved_state.pt"))
+ checkpoint = torch.load(join(checkpoint_path, "saved_state.pt"), map_location=torch.device('cpu'))
self.graph_model.load_state_dict(checkpoint['graph_model'])
self.node_embedder.load_state_dict(checkpoint['node_embedder'])
for objective in self.objectives:
objective.custom_load_state_dict(checkpoint[objective.name])
self.epoch = checkpoint['epoch']
+ self.restore_epoch = checkpoint['epoch']
self.batch = checkpoint['batch']
logging.info(f"Restored from epoch {checkpoint['epoch']}")
# TODO needs test
@@ -417,27 +571,24 @@ def final_evaluation(self):
objective.reset_iterator("train")
objective.reset_iterator("val")
objective.reset_iterator("test")
+ # objective.early_stopping = False
+ self.compute_embeddings_for_scorer(objective)
+ objective.target_embedder.prepare_index()
+ objective.update_embeddings_for_queries = False
with torch.set_grad_enabled(False):
for objective in self.objectives:
- train_loss, train_acc, train_ndcg = objective.evaluate("train")
- val_loss, val_acc, val_ndcg = objective.evaluate("val")
- test_loss, test_acc, test_ndcg = objective.evaluate("test")
-
- summary = {
- f"Accuracy/train/{objective.name}_final": train_acc,
- f"Accuracy/test/{objective.name}_final": test_acc,
- f"Accuracy/val/{objective.name}_final": val_acc,
- }
- if train_ndcg is not None:
- summary.update({f"{key}/train/{objective.name}_final": val for key, val in train_ndcg.items()})
- if val_ndcg is not None:
- summary.update({f"{key}/val/{objective.name}_final": val for key, val in val_ndcg.items()})
- if test_ndcg is not None:
- summary.update({f"{key}/test/{objective.name}_final": val for key, val in test_ndcg.items()})
-
+ # train_scores = objective.evaluate("train")
+ val_scores = objective.evaluate("val")
+ test_scores = objective.evaluate("test")
+
+ summary = {}
+ # add_to_summary(summary, "train", objective.name, train_scores, postfix="final")
+ add_to_summary(summary, "val", objective.name, val_scores, postfix="final")
+ add_to_summary(summary, "test", objective.name, test_scores, postfix="final")
+
summary_dict.update(summary)
# self.write_hyperparams(summary_dict, self.epoch)
@@ -474,7 +625,8 @@ def get_embeddings(self):
for ntype in self.graph_model.g.ntypes
}
- h = self.graph_model.inference(batch_size=256, device='cpu', num_workers=0, x=node_embs)
+ logging.info("Computing all embeddings")
+ h = self.graph_model.inference(batch_size=2048, device='cpu', num_workers=0, x=node_embs)
original_id = []
global_id = []
@@ -499,9 +651,12 @@ def select_device(args):
def training_procedure(
- dataset, model_name, model_params, args, model_base_path
+ dataset, model_name, model_params, args, model_base_path, trainer=None, load_external_dataset=None
) -> Tuple[SamplingMultitaskTrainer, dict]:
+ if trainer is None:
+ trainer = SamplingMultitaskTrainer
+
device = select_device(args)
model_params['num_classes'] = args.node_emb_size
@@ -509,15 +664,16 @@ def training_procedure(
model_params['use_att_checkpoint'] = args.use_att_checkpoint
model_params['use_gru_checkpoint'] = args.use_gru_checkpoint
+ if len(args.objectives.split(",")) > 1 and args.early_stopping is True:
+ print("Early stopping disabled when several objectives are used")
+ args.early_stopping = False
+
trainer_params = {
'lr': model_params.pop('lr'),
'batch_size': args.batch_size,
'sampling_neighbourhood_size': args.num_per_neigh,
'neg_sampling_factor': args.neg_sampling_factor,
'epochs': args.epochs,
- # 'node_name_file': args.fname_file,
- # 'var_use_file': args.varuse_file,
- # 'call_seq_file': args.call_seq_file,
'elem_emb_size': args.elem_emb_size,
'model_base_path': model_base_path,
'pretraining_phase': args.pretraining_phase,
@@ -525,12 +681,18 @@ def training_procedure(
'schedule_layers_every': args.schedule_layers_every,
'embedding_table_size': args.embedding_table_size,
'save_checkpoints': args.save_checkpoints,
- 'measure_ndcg': args.measure_ndcg,
- 'dilate_ndcg': args.dilate_ndcg,
- "objectives": args.objectives.split(",")
+ 'measure_scores': args.measure_scores,
+ 'dilate_scores': args.dilate_scores,
+ "objectives": args.objectives.split(","),
+ "early_stopping": args.early_stopping,
+ "early_stopping_tolerance": args.early_stopping_tolerance,
+ "force_w2v_ns": args.force_w2v_ns,
+ "metric": args.metric,
+ "nn_index": args.nn_index,
+ "save_each_epoch": args.save_each_epoch
}
- trainer = SamplingMultitaskTrainer(
+ trainer = trainer(
dataset=dataset,
model_name=model_name,
model_params=model_params,
@@ -538,15 +700,18 @@ def training_procedure(
restore=args.restore_state,
device=device,
pretrained_embeddings_path=args.pretrained,
- tokenizer_path=args.tokenizer
+ tokenizer_path=args.tokenizer,
+ load_external_dataset=load_external_dataset
)
try:
trainer.train_all()
except KeyboardInterrupt:
- print("Training interrupted")
+ logging.info("Training interrupted")
+ except EarlyStopping:
+ logging.info("Early stopping triggered")
except Exception as e:
- raise e
+ print("There was an exception", e)
trainer.eval()
scores = trainer.final_evaluation()
@@ -557,8 +722,11 @@ def training_procedure(
def evaluation_procedure(
- dataset, model_name, model_params, args, model_base_path
-) -> Tuple[SamplingMultitaskTrainer, dict]:
+ dataset, model_name, model_params, args, model_base_path, trainer=None
+):
+
+ if trainer is None:
+ trainer = SamplingMultitaskTrainer
device = select_device(args)
@@ -575,9 +743,6 @@ def evaluation_procedure(
'sampling_neighbourhood_size': args.num_per_neigh,
'neg_sampling_factor': args.neg_sampling_factor,
'epochs': args.epochs,
- # 'node_name_file': args.fname_file,
- # 'var_use_file': args.varuse_file,
- # 'call_seq_file': args.call_seq_file,
'elem_emb_size': args.elem_emb_size,
'model_base_path': model_base_path,
'pretraining_phase': args.pretraining_phase,
@@ -585,11 +750,11 @@ def evaluation_procedure(
'schedule_layers_every': args.schedule_layers_every,
'embedding_table_size': args.embedding_table_size,
'save_checkpoints': args.save_checkpoints,
- 'measure_ndcg': args.measure_ndcg,
- 'dilate_ndcg': args.dilate_ndcg
+ 'measure_scores': args.measure_scores,
+ 'dilate_scores': args.dilate_scores
}
- trainer = SamplingMultitaskTrainer(
+    trainer = trainer(  # defaults to SamplingMultitaskTrainer
dataset=dataset,
model_name=model_name,
model_params=model_params,
diff --git a/SourceCodeTools/models/graph/train/test_rggan.py b/SourceCodeTools/models/graph/train/test_rggan.py
new file mode 100644
index 00000000..3def8e85
--- /dev/null
+++ b/SourceCodeTools/models/graph/train/test_rggan.py
@@ -0,0 +1,541 @@
+import itertools
+import json
+import logging
+from copy import copy
+from datetime import datetime
+from os import mkdir
+from os.path import isdir, join
+from typing import Tuple
+
+import dgl
+import pandas as pd
+
+from dgl.data import WN18Dataset, FB15kDataset, FB15k237Dataset, AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
+import torch
+from sklearn.model_selection import ParameterGrid
+
+from SourceCodeTools.code.data.sourcetrail.Dataset import SourceGraphDataset
+from SourceCodeTools.models.graph import RGGAN
+from SourceCodeTools.models.graph.NodeEmbedder import NodeIdEmbedder
+from SourceCodeTools.models.graph.train.objectives import GraphLinkObjective
+from SourceCodeTools.models.graph.train.sampling_multitask2 import SamplingMultitaskTrainer
+from SourceCodeTools.models.graph.train.utils import get_name, get_model_base
+from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments
+
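+# Hyperparameter grid(s) for RGGAN; ParameterGrid below expands them into a flat list of configurations.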
+rggan_grids = [
+ {
+ 'h_dim': [15],
+ 'num_bases': [-1],
+ 'num_steps': [5],
+ 'dropout': [0.0],
+ 'use_self_loop': [False],
+ 'activation': [torch.tanh], # torch.nn.functional.hardswish], #[torch.nn.functional.hardtanh], #torch.nn.functional.leaky_relu
+ 'lr': [1e-3], # 1e-4]
+ }
+]
+
+rggan_params = list(
+ itertools.chain.from_iterable(
+ [ParameterGrid(p) for p in rggan_grids]
+ )
+)
+
+
+class TestNodeClfGraph(SourceGraphDataset):
+ def __init__(self, data_loader,
+ label_from=None, use_node_types=True,
+ use_edge_types=True, filter=None, self_loops=False,
+ train_frac=0.6, random_seed=None, tokenizer_path=None, min_count_for_objectives=1,
+ no_global_edges=False, remove_reverse=False, package_names=None):
+ self.random_seed = random_seed
+ self.nodes_have_types = use_node_types
+ self.edges_have_types = use_edge_types
+ self.labels_from = label_from
+ # self.data_path = data_path
+ self.tokenizer_path = tokenizer_path
+ self.min_count_for_objectives = min_count_for_objectives
+ self.no_global_edges = no_global_edges
+ self.remove_reverse = remove_reverse
+
+ # dataset = AIFBDataset()
+ # dataset = MUTAGDataset()
+ # dataset = BGSDataset()
+ # dataset = AMDataset()
+ dataset = data_loader()
+
+ self.nodes, self.edges, self.typed_id_map = self.create_nodes_edges_df(dataset)
+
+ # index is later used for sampling and is assumed to be unique
+ assert len(self.nodes) == len(self.nodes.index.unique())
+ assert len(self.edges) == len(self.edges.index.unique())
+
+ if self_loops:
+ self.nodes, self.edges = SourceGraphDataset.assess_need_for_self_loops(self.nodes, self.edges)
+
+ if filter is not None:
+ for e_type in filter.split(","):
+ logging.info(f"Filtering edge type {e_type}")
+ self.edges = self.edges.query(f"type != '{e_type}'")
+
+ # if self.remove_reverse:
+ # self.remove_reverse_edges()
+ #
+ # if self.no_global_edges:
+ # self.remove_global_edges()
+
+ if use_node_types is False and use_edge_types is False:
+ new_nodes, new_edges = self.create_nodetype_edges()
+ self.nodes = self.nodes.append(new_nodes, ignore_index=True)
+ self.edges = self.edges.append(new_edges, ignore_index=True)
+
+ self.nodes['type_backup'] = self.nodes['type']
+ if not self.nodes_have_types:
+ self.nodes['type'] = "node_"
+ self.nodes = self.nodes.astype({'type': 'category'})
+
+ self.add_embeddable_flag()
+
+        # need to do this to avoid issues inside the dgl library
+ self.edges['type'] = self.edges['type'].apply(lambda x: f"{x}_")
+ self.edges['type_backup'] = self.edges['type']
+ if not self.edges_have_types:
+ self.edges['type'] = "edge_"
+ self.edges = self.edges.astype({'type': 'category'})
+
+ # compact labels
+ # self.nodes['label'] = self.nodes[label_from]
+ # self.nodes = self.nodes.astype({'label': 'category'})
+ # self.label_map = compact_property(self.nodes['label'])
+ # assert any(pandas.isna(self.nodes['label'])) is False
+
+ logging.info(f"Unique nodes: {len(self.nodes)}, node types: {len(self.nodes['type'].unique())}")
+ logging.info(f"Unique edges: {len(self.edges)}, edge types: {len(self.edges['type'].unique())}")
+
+ # self.nodes, self.label_map = self.add_compact_labels()
+ self.add_typed_ids()
+
+ # self.add_splits(train_frac=train_frac, package_names=package_names)
+
+ # self.mark_leaf_nodes()
+
+ self.create_hetero_graph()
+
+ self.update_global_id()
+
+ self.nodes.sort_values('global_graph_id', inplace=True)
+
+ # self.splits = SourceGraphDataset.get_global_graph_id_splits(self.nodes)
+
+ def create_nodes_edges_df(self, dataset):
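+        # Convert a DGL heterogeneous node-classification dataset into nodes/edges DataFrames,
+        # keeping per-type id maps and per-node train/val/test masks and labels.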
+ graph = dataset[0]
+ nodes_df = None
+
+ node_id_map = {}
+ typed_id_map = {}
+
+ for ntype in graph.ntypes:
+ typed_id = graph.nodes(ntype=ntype).tolist()
+ type = [ntype] * len(typed_id)
+ name = list(map(lambda x: ntype+f"_{x}", typed_id))
+ id_ = graph.nodes[ntype].data["_ID"].tolist()
+ node_dict = {"id": id_, "type": type, "name": name, "typed_id": typed_id}
+
+ node_id_map[ntype] = dict(zip(typed_id, id_))
+ typed_id_map[ntype] = dict(zip(id_, typed_id))
+
+ if "labels" in graph.nodes[ntype].data:
+ node_dict["labels"] = graph.nodes[ntype].data["labels"].tolist()
+ else:
+ node_dict["labels"] = [-1] * len(typed_id)
+
+ if "train_mask" in graph.nodes[ntype].data:
+ node_dict["train_mask"] = graph.nodes[ntype].data["train_mask"].bool().tolist()
+ else:
+ node_dict["train_mask"] = [False] * len(typed_id)
+
+ if "test_mask" in graph.nodes[ntype].data:
+ node_dict["test_mask"] = graph.nodes[ntype].data["test_mask"].bool().tolist()
+ else:
+ node_dict["test_mask"] = [False] * len(typed_id)
+
+ if "val_mask" in graph.nodes[ntype].data:
+ node_dict["val_mask"] = graph.nodes[ntype].data["val_mask"].bool().tolist()
+ else:
+ node_dict["val_mask"] = [False] * len(typed_id)
+
+ if nodes_df is None:
+ nodes_df = pd.DataFrame.from_dict(node_dict)
+ else:
+ nodes_df = nodes_df.append(pd.DataFrame.from_dict(node_dict))
+
+        assert len(nodes_df["id"]) == len(nodes_df["id"].unique())  # this is a must-have assert
+
+ nodes_df = nodes_df.reset_index(drop=True)
+ # contiguous_node_index = dict(zip(nodes_df["index"], nodes_df.index))
+
+ edges_df = None
+ for srctype, etype, dsttype in graph.canonical_etypes:
+ src, dst = graph.edges(etype=(srctype, etype, dsttype))
+ src = list(map(lambda x: node_id_map[srctype][x], src.tolist()))
+ dst = list(map(lambda x: node_id_map[dsttype][x], dst.tolist()))
+
+ edge_data = {
+ "id": graph.edata["_ID"][(srctype, etype, dsttype)].tolist(),
+ "type": [etype] * len(src),
+ "src": src,
+ "dst": dst
+ }
+
+ if edges_df is None:
+ edges_df = pd.DataFrame.from_dict(edge_data)
+ else:
+ edges_df = edges_df.append(pd.DataFrame.from_dict(edge_data))
+
+ # assert len(edges_df["id"]) == len(edges_df["id"].unique()) # fails with AMDataset
+
+ edges_df = edges_df.reset_index(drop=True)
+
+ return nodes_df, edges_df, typed_id_map
+
+ def add_embeddable_flag(self):
+ self.nodes['embeddable'] = True
+
+ def add_typed_ids(self):
+ pass
+
+ def load_node_classes(self):
+ labels = self.nodes.query("train_mask == True or test_mask == True or val_mask == True")[["id", "labels"]].rename({
+ "id": "src",
+ "labels": "dst"
+ }, axis=1)
+ return labels
+
+
+class TestLinkPredGraph(TestNodeClfGraph):
+ def __init__(
+ self, data_loader, label_from=None, use_node_types=True,
+ use_edge_types=True, filter=None, self_loops=False,
+ train_frac=0.6, random_seed=None, tokenizer_path=None, min_count_for_objectives=1,
+ no_global_edges=False, remove_reverse=False, package_names=None
+ ):
+ self.random_seed = random_seed
+ self.nodes_have_types = use_node_types
+ self.edges_have_types = use_edge_types
+ self.labels_from = label_from
+ # self.data_path = data_path
+ self.tokenizer_path = tokenizer_path
+ self.min_count_for_objectives = min_count_for_objectives
+ self.no_global_edges = no_global_edges
+ self.remove_reverse = remove_reverse
+
+ dataset = data_loader()
+
+ self.nodes, self.edges, self.typed_id_map = self.create_nodes_edges_df(dataset)
+
+ # index is later used for sampling and is assumed to be unique
+ assert len(self.nodes) == len(self.nodes.index.unique())
+ assert len(self.edges) == len(self.edges.index.unique())
+
+ if self_loops:
+ self.nodes, self.edges = SourceGraphDataset.assess_need_for_self_loops(self.nodes, self.edges)
+ self.nodes, self.val_edges = SourceGraphDataset.assess_need_for_self_loops(self.nodes, self.val_edges)
+ self.nodes, self.test_edges = SourceGraphDataset.assess_need_for_self_loops(self.nodes, self.test_edges)
+
+ if filter is not None:
+ for e_type in filter.split(","):
+ logging.info(f"Filtering edge type {e_type}")
+ self.edges = self.edges.query(f"type != '{e_type}'")
+ self.val_edges = self.val_edges.query(f"type != '{e_type}'")
+ self.test_edges = self.test_edges.query(f"type != '{e_type}'")
+
+ if use_node_types is False and use_edge_types is False:
+ new_nodes, new_edges = self.create_nodetype_edges()
+ self.nodes = self.nodes.append(new_nodes, ignore_index=True)
+ self.edges = self.edges.append(new_edges, ignore_index=True)
+
+ self.nodes['type_backup'] = self.nodes['type']
+ if not self.nodes_have_types:
+ self.nodes['type'] = "node_"
+ self.nodes = self.nodes.astype({'type': 'category'})
+
+ self.add_embeddable_flag()
+
+        # need to do this to avoid issues inside the dgl library
+ self.edges['type'] = self.edges['type'].apply(lambda x: f"{x}_")
+ self.edges['type_backup'] = self.edges['type']
+ if not self.edges_have_types:
+ self.edges['type'] = "edge_"
+ self.edges = self.edges.astype({'type': 'category'})
+
+ # compact labels
+ # self.nodes['label'] = self.nodes[label_from]
+ # self.nodes = self.nodes.astype({'label': 'category'})
+ # self.label_map = compact_property(self.nodes['label'])
+ # assert any(pandas.isna(self.nodes['label'])) is False
+
+ logging.info(f"Unique nodes: {len(self.nodes)}, node types: {len(self.nodes['type'].unique())}")
+ logging.info(f"Unique edges: {len(self.edges)}, edge types: {len(self.edges['type'].unique())}")
+
+ # self.nodes, self.label_map = self.add_compact_labels()
+ self.add_typed_ids()
+
+ # self.add_splits(train_frac=train_frac, package_names=package_names)
+
+ # self.mark_leaf_nodes()
+
+ self.create_hetero_graph()
+
+ self.update_global_id()
+
+ self.nodes.sort_values('global_graph_id', inplace=True)
+
+ # self.splits = SourceGraphDataset.get_global_graph_id_splits(self.nodes)
+
+ def create_nodes_edges_df(self, dataset):
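+        # Link-prediction datasets come as a homogeneous graph with edge masks; build
+        # nodes/edges DataFrames and keep train/val/test edge splits on the instance.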
+ graph = dataset[0]
+ nodes_df = None
+
+ node_id_map = {}
+ typed_id_map = {}
+ typed_id = []
+
+ id_ = graph.nodes().tolist()
+ type = graph.ndata["ntype"].tolist()
+ src, dst = graph.edges()
+ src = src.tolist()
+ dst = dst.tolist()
+ edge_type = graph.edata["etype"].tolist()
+ train_mask = graph.edata["train_mask"].tolist()
+ val_mask = graph.edata["val_mask"].tolist()
+ test_mask = graph.edata["test_mask"].tolist()
+ # node2train_mask = dict(zip(src, train_mask))
+ # node2val_mask = dict(zip(src, val_mask))
+ # node2test_mask = dict(zip(src, test_mask))
+
+ for nid_, ntype_ in zip(id_, type):
+ if ntype_ not in node_id_map:
+ node_id_map[ntype_] = dict()
+ typed_id_map[ntype_] = dict()
+
+ node_id_map[ntype_][len(node_id_map[ntype_])] = nid_
+ typed_id_map[ntype_][nid_] = len(typed_id_map[ntype_])
+ typed_id.append(typed_id_map[ntype_][nid_])
+
+ name = list(map(lambda x: f"{x[0]}_{x[1]}", zip(type, typed_id)))
+
+ node_dict = {
+ "id": id_, "type": type, "name": name, "typed_id": typed_id,
+ # "train_mask": list(map(lambda x: node2train_mask[x], id_)),
+ # "val_mask": list(map(lambda x: node2val_mask[x], id_)),
+ # "test_mask": list(map(lambda x: node2test_mask[x], id_))
+ }
+
+ nodes_df = pd.DataFrame.from_dict(node_dict)
+
+ assert len(nodes_df["id"]) == len(nodes_df["id"].unique())
+
+ nodes_df = nodes_df.reset_index(drop=True)
+ # contiguous_node_index = dict(zip(nodes_df["index"], nodes_df.index))
+
+ assert len(node_id_map) == 1
+
+ edges_df = []
+ for e_ind, (src, etype, dst) in enumerate(zip(src, edge_type, dst)):
+ edge_data = {
+ "id": e_ind,
+ "type": etype,
+ "src": src,
+ "dst": dst,
+ "train_mask": train_mask[e_ind],
+ "val_mask": val_mask[e_ind],
+ "test_mask": test_mask[e_ind]
+ }
+ edges_df.append(edge_data)
+
+ edges_df = pd.DataFrame(edges_df)
+
+ assert len(edges_df["id"]) == len(edges_df["id"].unique())
+
+ edges_df = edges_df.reset_index(drop=True)
+ self.train_edges = edges_df.query("train_mask == True")
+ self.val_edges = edges_df.query("train_mask == True or val_mask == True")
+ self.test_edges = edges_df
+
+ return nodes_df, self.train_edges, typed_id_map
+
+
+class TestTrainer(SamplingMultitaskTrainer):
+ def __init__(self, *args, **kwargs):
+ super(TestTrainer, self).__init__(*args, **kwargs)
+
+ def create_node_embedder(self, dataset, tokenizer_path, n_dims=None, pretrained_path=None, n_buckets=500000):
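+        # Use a simple id-based embedding table for these datasets; the bucket count
+        # is tied to the number of nodes instead of a subword vocabulary.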
+ self.node_embedder = NodeIdEmbedder(
+ nodes=dataset.nodes,
+ emb_size=n_dims,
+ dtype=self.dtype,
+ n_buckets=len(dataset.nodes) + 1 # override this because bucket size should be the same as number of nodes for this embedder
+ )
+
+
+def main_node_clf(models, args, data_loader):
+ for model, param_grid in models.items():
+ for params in param_grid:
+
+ if args.h_dim is None:
+ params["h_dim"] = args.node_emb_size
+ else:
+ params["h_dim"] = args.h_dim
+
+ params["num_steps"] = args.n_layers
+
+ date_time = str(datetime.now())
+ print("\n\n")
+ print(date_time)
+ print(f"Model: {model.__name__}, Params: {params}")
+
+ model_attempt = get_name(model, date_time)
+
+ model_base = get_model_base(args, model_attempt)
+
+ dataset = TestNodeClfGraph(data_loader=data_loader)
+
+ args.objectives = "node_clf"
+
+ from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure
+
+ trainer, scores = training_procedure(dataset, model, copy(params), args, model_base, trainer=TestTrainer)
+
+ return scores
+
+
+def main_link_pred(models, args, data_loader):
+ for model, param_grid in models.items():
+ for params in param_grid:
+
+ if args.h_dim is None:
+ params["h_dim"] = args.node_emb_size
+ else:
+ params["h_dim"] = args.h_dim
+
+ params["num_steps"] = args.n_layers
+
+ date_time = str(datetime.now())
+ print("\n\n")
+ print(date_time)
+ print(f"Model: {model.__name__}, Params: {params}")
+
+ model_attempt = get_name(model, date_time)
+
+ model_base = get_model_base(args, model_attempt)
+
+ dataset = TestLinkPredGraph(data_loader=data_loader)
+
+ args.objectives = "link_pred"
+
+ from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure
+
+ trainer, scores = training_procedure(dataset, model, copy(params), args, model_base, trainer=TestTrainer)
+
+ return scores
+
+def format_data(dataset):
+ graph = dataset[0]
+ category = dataset.predict_category
+ train_mask = graph.nodes[category].data.pop('train_mask')
+ test_mask = graph.nodes[category].data.pop('test_mask')
+ labels = graph.nodes[category].data.pop('labels')
+ train_labels = labels[train_mask]
+ test_labels = labels[test_mask]
+
+def node_clf(args):
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s")
+
+ models_ = {
+ # GCNSampling: gcnsampling_params,
+ # GATSampler: gatsampling_params,
+ # RGCNSampling: rgcnsampling_params,
+ # RGAN: rgcnsampling_params,
+ RGGAN: rggan_params
+
+ }
+
+ if not isdir(args.model_output_dir):
+ mkdir(args.model_output_dir)
+ args.save_checkpoints = False
+
+ # data_loaders = [AIFBDataset, MUTAGDataset, BGSDataset, AMDataset]
+ data_loaders = [AMDataset]
+
+ for dl in data_loaders:
+ print(dl.__name__)
+ scores = main_node_clf(models_, args, data_loader=dl)
+ print("\t", scores)
+
+
+def link_pred(args):
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s")
+
+ models_ = {
+ # GCNSampling: gcnsampling_params,
+ # GATSampler: gatsampling_params,
+ # RGCNSampling: rgcnsampling_params,
+ # RGAN: rgcnsampling_params,
+ RGGAN: rggan_params
+
+ }
+
+ if not isdir(args.model_output_dir):
+ mkdir(args.model_output_dir)
+ args.save_checkpoints = False
+
+ data_loaders = [FB15kDataset, FB15k237Dataset]
+
+ for dl in data_loaders:
+ print(dl.__name__)
+ scores = main_link_pred(models_, args, data_loader=dl)
+ print("\t", scores)
+
+
+ # dataset = AIFBDataset()
+ # graph = dataset[0]
+ # category = dataset.predict_category
+ # num_classes = dataset.num_classes
+ # train_mask = graph.nodes[category].data.pop('train_mask')
+ # test_mask = graph.nodes[category].data.pop('test_mask')
+ # labels = graph.nodes[category].data.pop('labels')
+ #
+ # print()
+
+# def main():
+# dataset = FB15k237Dataset()
+# g = dataset.graph
+# e_type = g.edata['e_type']
+#
+# train_mask = g.edata['train_mask']
+# val_mask = g.edata['val_mask']
+# test_mask = g.edata['test_mask']
+#
+# # graph = dataset[0]
+# # train_mask = graph.edata['train_mask']
+# # test_mask = g.edata['test_mask']
+#
+# train_set = torch.arange(g.number_of_edges())[train_mask]
+# val_set = torch.arange(g.number_of_edges())[val_mask]
+#
+# train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
+# src, dst = graph.edges(train_idx)
+# rel = graph.edata['etype'][train_idx]
+#
+# print()
+
+if __name__=="__main__":
+ import argparse
+
+    parser = argparse.ArgumentParser(description='Run RGGAN tests on standard DGL datasets.')
+ add_gnn_train_args(parser)
+
+ args = parser.parse_args()
+ verify_arguments(args)
+
+ node_clf(args)
+ # link_pred(args)
\ No newline at end of file
diff --git a/SourceCodeTools/models/graph/train/utils.py b/SourceCodeTools/models/graph/train/utils.py
index 95d1a14c..50fc134d 100644
--- a/SourceCodeTools/models/graph/train/utils.py
+++ b/SourceCodeTools/models/graph/train/utils.py
@@ -18,8 +18,8 @@ def get_name(model, timestamp):
return "{} {}".format(model.__name__, timestamp).replace(":", "-").replace(" ", "-").replace(".", "-")
-def get_model_base(args, model_attempt):
- if args.restore_state:
+def get_model_base(args, model_attempt, force_new=False):
+ if args.restore_state and not force_new:
model_base = args.model_output_dir
else:
model_base = join(args.model_output_dir, model_attempt)
diff --git a/SourceCodeTools/models/graph/utils/dglke_to_embedder.py b/SourceCodeTools/models/graph/utils/dglke_to_embedder.py
new file mode 100644
index 00000000..093b1a00
--- /dev/null
+++ b/SourceCodeTools/models/graph/utils/dglke_to_embedder.py
@@ -0,0 +1,38 @@
+import pickle
+
+import numpy
+import pandas as pd
+
+from SourceCodeTools.models.Embedder import Embedder
+
+
+def read_entities(path):
+ return pd.read_csv(path, sep="\t", header=None)[1]
+
+
+def read_vectors(path):
+ return numpy.load(path)
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("entities")
+ parser.add_argument("vectors")
+ parser.add_argument("output")
+
+ args = parser.parse_args()
+
+ entities = read_entities(args.entities)
+ vectors = read_vectors(args.vectors)
+
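+    # Pair each entity name with its row index in the vector matrix, wrap both in an
+    # Embedder, and pickle the result for downstream use.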
+ embedder = Embedder(dict(zip(entities, range(len(entities)))), vectors)
+ with open(args.output, "wb") as sink:
+ pickle.dump(embedder, sink)
+
+
+
+
+
+if __name__=="__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/models/graph/utils/export4visualization.py b/SourceCodeTools/models/graph/utils/export4visualization.py
index 9533a5d1..69d28454 100644
--- a/SourceCodeTools/models/graph/utils/export4visualization.py
+++ b/SourceCodeTools/models/graph/utils/export4visualization.py
@@ -1,89 +1,130 @@
import pickle
-import sys
-import pandas
+from os.path import join
+
import numpy as np
import os
from collections import Counter
-
-model_path = sys.argv[1]
-
-if len(sys.argv) > 2:
- max_embs = int(sys.argv[2])
-else:
- max_embs = 5000
-
-nodes_path = os.path.join(model_path, "nodes.csv")
-edges_path = os.path.join(model_path, "edges.csv")
-embedders_path = os.path.join(model_path, "embeddings.pkl")
-
-embedders = pickle.load(open(embedders_path, "rb"))
-nodes = pandas.read_csv(nodes_path)
-
-edges = pandas.read_csv(edges_path)
-degrees = Counter(edges['src'].tolist()) + Counter(edges['dst'].tolist())
-
-ids = nodes['id'].values
-names = nodes['name']#.apply(lambda x: x.split(".")[-1]).values
-
-id_name_map = list(zip(ids, names))
-id_name_map_d = dict(id_name_map)
-
-ind_mapper = embedders[0].ind
-
-id_name_map = sorted(id_name_map, key=lambda x: ind_mapper[x[0]])
-
-ids, names = zip(*id_name_map)
-
-print(f"Limiting to {max_embs} embeddings")
-
-
-# emb0 = []
-# emb1 = []
-# emb2 = []
-
-for group, gr_ in nodes.groupby("type_backup"):
- c_ = 0
-
- names = []
- embs = [[] for _ in embedders]
-
- nodes_in_group = set(gr_['id'].tolist())
- for ind, (id_, count) in enumerate(degrees.most_common()):
- if id_ not in nodes_in_group: continue
- names.append(id_name_map_d[id_])
- for emb_, embedder in zip(embs, embedders):
- emb_.append(embedder.e[embedder.ind[id_]])
- # emb0.append(embedders[0].e[embedders[0].ind[id_]])
- # emb1.append(embedders[1].e[embedders[1].ind[id_]])
- # emb2.append(embedders[2].e[embedders[2].ind[id_]])
- c_ += 1
- if c_ >= max_embs - 1: break
- # if ind >= max_embs-1: break
-
- # np.savetxt(os.path.join(model_path, "emb4proj_meta.tsv"), np.array(names).reshape(-1, 1))
- for ind, emb_ in enumerate(embs):
- np.savetxt(os.path.join(model_path, f"emb4proj{ind}_{group}.tsv"), np.array(emb_), delimiter="\t")
- # np.savetxt(os.path.join(model_path, "emb4proj0.tsv"), np.array(emb0), delimiter="\t")
- # np.savetxt(os.path.join(model_path, "emb4proj1.tsv"), np.array(emb1), delimiter="\t")
- # np.savetxt(os.path.join(model_path, "emb4proj2.tsv"), np.array(emb2), delimiter="\t")
-
- print("Writing meta...", end="")
- with open(os.path.join(model_path, f"emb4proj_meta_{group}.tsv"), "w") as meta:
- for name in names[:max_embs]:
- meta.write(f"{name}\n")
- print("done")
-
-# with open(os.path.join(model_path, "emb4proj2w2v.txt"), "w") as w2v:
-# for ind, name in enumerate(names):
-# w2v.write("%s " % name)
-# for j, v in enumerate(emb2[ind]):
-# if j < len(emb2[ind]) - 1:
-# w2v.write("%f " % v)
-# else:
-# w2v.write("%f\n" % v)
-
-
-# for ind, e in enumerate(embedders):
-# print(f"Writing embedding layer {ind}...", end="")
-# np.savetxt(os.path.join(model_path, f"emb4proj{ind}.tsv"), e.e[:max_embs], delimiter="\t")
-# print("done")
+import argparse
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("model_path")
+ parser.add_argument("output_path")
+ parser.add_argument("--max_embs", type=int, default=5000)
+ parser.add_argument("--into_groups", action="store_true")
+ args = parser.parse_args()
+ model_path = args.model_path
+
+ dataset = pickle.load(open(join(args.model_path, "dataset.pkl"), "rb"))
+ embedders = pickle.load(open(join(args.model_path, "embeddings.pkl"), "rb"))
+
+ nodes = dataset.nodes
+ edges = dataset.edges
+
+ degrees = Counter(edges['src'].tolist()) + Counter(edges['dst'].tolist())
+
+ ids = nodes['id'].values
+ names = nodes['name']#.apply(lambda x: x.split(".")[-1]).values
+
+ id_name_map = list(zip(ids, names))
+ id_name_map_d = dict(id_name_map)
+
+ ind_mapper = embedders[0].ind
+
+ id_name_map = sorted(id_name_map, key=lambda x: ind_mapper[x[0]])
+
+ ids, names = zip(*id_name_map)
+
+ print(f"Limiting to {args.max_embs} embeddings")
+
+ # emb0 = []
+ # emb1 = []
+ # emb2 = []
+
+ def write_in_groups():
+
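+        # Write one embeddings/metadata TSV pair per node type (grouped by type_backup),
+        # keeping at most max_embs highest-degree nodes per group.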
+ for group, gr_ in nodes.groupby("type_backup"):
+ c_ = 0
+
+ names = []
+ embs = [[] for _ in embedders]
+
+ nodes_in_group = set(gr_['id'].tolist())
+ for ind, (id_, count) in enumerate(degrees.most_common()):
+ if id_ not in nodes_in_group: continue
+ names.append(id_name_map_d[id_])
+ for emb_, embedder in zip(embs, embedders):
+ emb_.append(embedder.e[embedder.ind[id_]])
+ # emb0.append(embedders[0].e[embedders[0].ind[id_]])
+ # emb1.append(embedders[1].e[embedders[1].ind[id_]])
+ # emb2.append(embedders[2].e[embedders[2].ind[id_]])
+ c_ += 1
+ if c_ >= args.max_embs - 1: break
+ # if ind >= max_embs-1: break
+
+ # np.savetxt(os.path.join(model_path, "emb4proj_meta.tsv"), np.array(names).reshape(-1, 1))
+ for ind, emb_ in enumerate(embs):
+ np.savetxt(os.path.join(args.output_path, f"emb4proj{ind}_{group}.tsv"), np.array(emb_), delimiter="\t")
+ # np.savetxt(os.path.join(model_path, "emb4proj0.tsv"), np.array(emb0), delimiter="\t")
+ # np.savetxt(os.path.join(model_path, "emb4proj1.tsv"), np.array(emb1), delimiter="\t")
+ # np.savetxt(os.path.join(model_path, "emb4proj2.tsv"), np.array(emb2), delimiter="\t")
+
+ print("Writing meta...", end="")
+ with open(os.path.join(args.output_path, f"emb4proj_meta_{group}.tsv"), "w") as meta:
+ for name in names[:args.max_embs]:
+ meta.write(f"{name}\n")
+ print("done")
+
+ def write_all_together():
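+        # Write a single embeddings/metadata TSV pair over all nodes, limited to the
+        # max_embs highest-degree nodes.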
+ c_ = 0
+
+ names = []
+ embs = [[] for _ in embedders]
+
+ nodes_in_group = set(nodes['id'].tolist())
+ for ind, (id_, count) in enumerate(degrees.most_common()):
+ if id_ not in nodes_in_group: continue
+ names.append(id_name_map_d[id_])
+ for emb_, embedder in zip(embs, embedders):
+ emb_.append(embedder.e[embedder.ind[id_]])
+ c_ += 1
+ if c_ >= args.max_embs - 1: break
+ # if ind >= max_embs-1: break
+
+ # np.savetxt(os.path.join(model_path, "emb4proj_meta.tsv"), np.array(names).reshape(-1, 1))
+ for ind, emb_ in enumerate(embs):
+ np.savetxt(os.path.join(args.output_path, f"emb4proj{ind}.tsv"), np.array(emb_), delimiter="\t")
+ # np.savetxt(os.path.join(model_path, "emb4proj0.tsv"), np.array(emb0), delimiter="\t")
+ # np.savetxt(os.path.join(model_path, "emb4proj1.tsv"), np.array(emb1), delimiter="\t")
+ # np.savetxt(os.path.join(model_path, "emb4proj2.tsv"), np.array(emb2), delimiter="\t")
+
+ print("Writing meta...", end="")
+ with open(os.path.join(args.output_path, f"emb4proj_meta.tsv"), "w") as meta:
+ for name in names[:args.max_embs]:
+ meta.write(f"{name}\n")
+ print("done")
+
+ if args.into_groups:
+ write_in_groups()
+ else:
+ write_all_together()
+
+
+ # with open(os.path.join(model_path, "emb4proj2w2v.txt"), "w") as w2v:
+ # for ind, name in enumerate(names):
+ # w2v.write("%s " % name)
+ # for j, v in enumerate(emb2[ind]):
+ # if j < len(emb2[ind]) - 1:
+ # w2v.write("%f " % v)
+ # else:
+ # w2v.write("%f\n" % v)
+
+
+ # for ind, e in enumerate(embedders):
+ # print(f"Writing embedding layer {ind}...", end="")
+ # np.savetxt(os.path.join(model_path, f"emb4proj{ind}.tsv"), e.e[:max_embs], delimiter="\t")
+ # print("done")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/models/graph/utils/prepare_dglke_format.py b/SourceCodeTools/models/graph/utils/prepare_dglke_format.py
index 3482f452..cc12ea5b 100644
--- a/SourceCodeTools/models/graph/utils/prepare_dglke_format.py
+++ b/SourceCodeTools/models/graph/utils/prepare_dglke_format.py
@@ -1,79 +1,85 @@
+import logging
import os
-import pandas as pd
-from SourceCodeTools.code.data.sourcetrail.Dataset import SourceGraphDataset, load_data, compact_property
+from os.path import isdir, join, isfile
+
+from SourceCodeTools.code.data.sourcetrail.Dataset import load_data, compact_property, SourceGraphDataset
import argparse
-parser = argparse.ArgumentParser(description='Process some integers.')
-parser.add_argument('--nodes_path', dest='nodes_path', default=None,
- help='Path to the file with nodes')
-parser.add_argument('--edges_path', dest='edges_path', default=None,
- help='Path to the file with edges')
-parser.add_argument('--fname_path', dest='fname_path', default=None,
- help='')
-parser.add_argument('--varuse_path', dest='varuse_path', default=None,
- help='')
-parser.add_argument('--apicall_path', dest='apicall_path', default=None,
- help='')
-parser.add_argument('--out_path', dest='out_path', default=None,
- help='')
+from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist, persist
+
+
+def get_paths(dataset_path, use_extra_objectives):
+ extra_objectives = ["node_names.bz2", "common_function_variable_pairs.bz2", "common_call_seq.bz2", "type_annotations.bz2"]
+
+ largest_component = join(dataset_path, "largest_component")
+ if isdir(largest_component):
+ logging.info("Using graph from largest_component directory")
+ nodes_path = join(largest_component, "nodes.bz2")
+ edges_path = join(largest_component, "edges.bz2")
+ else:
+ nodes_path = join(dataset_path, "common_nodes.bz2")
+ edges_path = join(dataset_path, "common_edges.bz2")
+
+
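+    # Optional objective files are used only when requested and present on disk;
+    # missing files map to None and are skipped when building the edge list.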
+ extra_paths = list(map(
+ lambda file: file if use_extra_objectives and isfile(file) else None,
+ (join(dataset_path, objective) for objective in extra_objectives)
+ ))
-args = parser.parse_args()
+ return nodes_path, edges_path, extra_paths
-nodes, edges = load_data(args.nodes_path, args.edges_path)
-node2graph_id = compact_property(nodes['id'])
-nodes['global_graph_id'] = nodes['id'].apply(lambda x: node2graph_id[x])
-nodes, edges, held = SourceGraphDataset.holdout(nodes, edges, 0.005)
-edges.to_csv(os.path.join(args.out_path, "edges_train.csv"), index=False)
-held.to_csv(os.path.join(args.out_path, "held.csv"), index=False)
+def filter_relevant(data, node_ids):
+ return data.query("src in @allowed", local_dict={"allowed": node_ids})
-edges = edges.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']]
-node_ids = set(nodes['id'].unique())
+def main():
+ parser = argparse.ArgumentParser(description='Process some integers.')
+ parser.add_argument('dataset_path', default=None, help='Path to the dataset')
+ parser.add_argument('output_path', default=None, help='')
+ parser.add_argument("--extra_objectives", action="store_true", default=False)
+ parser.add_argument("--eval_frac", dest="eval_frac", default=0.05, type=float)
-if args.fname_path is not None:
- fname = pd.read_csv(args.fname_path).astype({"src": 'int32', "dst": "str"})
- fname['type'] = 'fname'
+ args = parser.parse_args()
- fname = fname[
- fname['src'].apply(lambda x: x in node_ids)
- ]
+ nodes_path, edges_path, extra_paths = get_paths(
+ args.dataset_path, use_extra_objectives=args.extra_objectives
+ )
- edges = pd.concat([edges, fname])
+ nodes, edges = load_data(nodes_path, edges_path)
+ nodes, edges, holdout = SourceGraphDataset.holdout(nodes, edges)
+ edges = edges.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']]
+ holdout = holdout.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']]
-if args.varuse_path is not None:
- varuse = pd.read_csv(args.varuse_path).astype({"src": 'int32', "dst": "str"})
- varuse['type'] = 'varuse'
+ node2graph_id = compact_property(nodes['id'])
+ nodes['global_graph_id'] = nodes['id'].apply(lambda x: node2graph_id[x])
- varuse = varuse[
- varuse['src'].apply(lambda x: x in node_ids)
- ]
+ node_ids = set(nodes['id'].unique())
- edges = pd.concat([edges, varuse])
+ if args.extra_objectives:
+        for objective_path in extra_paths:
+            if objective_path is None:
+                continue
+            data = unpersist(objective_path)
+ data = filter_relevant(data, node_ids)
+ data["type"] = objective_path.split(".")[0]
+ edges = edges.append(data)
-if args.apicall_path is not None:
- apicall = pd.read_csv(args.apicall_path).astype({"src": 'int32', "dst": "int32"})
- apicall['type'] = 'nextcall'
+ if not os.path.isdir(args.output_path):
+ os.mkdir(args.output_path)
- apicall = apicall[
- apicall['src'].apply(lambda x: x in node_ids)
- ]
+ edges = edges[['src','dst','type']]
+ eval_sample = edges.sample(frac=args.eval_frac)
- edges = pd.concat([edges, apicall])
+ persist(nodes, join(args.output_path, "nodes_dglke.csv"))
+ persist(edges, join(args.output_path, "edges_train_dglke.tsv"), header=False, sep="\t")
+ persist(edges, join(args.output_path, "edges_train_node2vec.tsv"), header=False, sep=" ")
+ persist(eval_sample, join(args.output_path, "edges_eval_dglke.tsv"), header=False, sep="\t")
+ persist(eval_sample, join(args.output_path, "edges_eval_node2vec.tsv"), header=False, sep=" ")
+ persist(holdout, join(args.output_path, "edges_eval_dglke_10000.tsv"), header=False, sep="\t")
+ persist(holdout, join(args.output_path, "edges_eval_node2vec_10000.tsv"), header=False, sep=" ")
-# splits = get_train_test_val_indices(edges.index, train_frac=0.6)
-nodes['label'] = nodes['type']
-if not os.path.isdir(args.out_path):
- os.mkdir(args.out_path)
-nodes.to_csv(os.path.join(args.out_path, "nodes.csv"), index=False)
-edges.to_csv(os.path.join(args.out_path, "edges_train_dglke.tsv"), index=False, header=False, sep="\t")
-edges[['src','dst']].to_csv(os.path.join(args.out_path, "edges_train_node2vec.csv"), index=False, header=False, sep=" ")
-# edges.iloc[splits[0]].to_csv(os.path.join(args.out_path, "edges_train.csv"), index=False, header=False, sep="\t")
-# edges.iloc[splits[1]].to_csv(os.path.join(args.out_path, "edges_val.csv"), index=False, header=False, sep="\t")
-# edges.iloc[splits[2]].to_csv(os.path.join(args.out_path, "edges_test.csv"), index=False, header=False, sep="\t")
-held[['src','dst','type']].to_csv(os.path.join(args.out_path, "held_dglkg.tsv"), index=False, sep='\t', header=False)
-held[['src','dst']].to_csv(os.path.join(args.out_path, "held_node2vec.csv"), index=False, sep=' ', header=False)
\ No newline at end of file
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/models/nlp/TFDecoder.py b/SourceCodeTools/models/nlp/TFDecoder.py
new file mode 100644
index 00000000..0c313c8f
--- /dev/null
+++ b/SourceCodeTools/models/nlp/TFDecoder.py
@@ -0,0 +1,122 @@
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer, Dense, Dropout
+from tensorflow_addons.layers import MultiHeadAttention
+
+from SourceCodeTools.models.nlp.common import positional_encoding
+
+
+class ConditionalDecoderLayer(tf.keras.layers.Layer):
+ def __init__(self, d_model, num_heads, dff, rate=0.1):
+ super(ConditionalDecoderLayer, self).__init__()
+
+ def point_wise_feed_forward_network(d_model, dff):
+ return tf.keras.Sequential([
+ tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff)
+ tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model)
+ ])
+
+ self.mha1 = MultiHeadAttention(d_model, num_heads, return_attn_coef=True)
+ self.mha2 = MultiHeadAttention(d_model, num_heads, return_attn_coef=True)
+
+ self.ffn = point_wise_feed_forward_network(d_model, dff)
+
+ self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+
+ self.dropout1 = tf.keras.layers.Dropout(rate)
+ self.dropout2 = tf.keras.layers.Dropout(rate)
+ self.dropout3 = tf.keras.layers.Dropout(rate)
+
+ def call(self, inputs, look_ahead_mask=None, mask=None, training=None):
+ encoder_out, x = inputs
+ # enc_output.shape == (batch_size, input_seq_len, d_model)
+
+ attn1, attn_weights_block1 = self.mha1((x, x, x), mask=look_ahead_mask) # (batch_size, target_seq_len, d_model)
+ attn1 = self.dropout1(attn1, training=training)
+ out1 = self.layernorm1(attn1 + x) # skip connection
+
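+        # Cross-attention over the encoder output; the encoder padding mask is tiled across
+        # target positions so every decoder step attends with the same key mask.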
+ attn2, attn_weights_block2 = self.mha2(
+            (out1, encoder_out, encoder_out),
+            mask=tf.tile(tf.expand_dims(encoder_out._keras_mask, axis=1), (1, out1.shape[1], 1))
+        )  # (batch_size, target_seq_len, d_model)
+ attn2 = self.dropout2(attn2, training=training)
+ out2 = self.layernorm2(attn2 + out1) # skip connection (batch_size, target_seq_len, d_model)
+
+ ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model)
+ ffn_output = self.dropout3(ffn_output, training=training)
+ out3 = self.layernorm3(ffn_output + out2) # skip connection (batch_size, target_seq_len, d_model)
+
+ return out3, attn_weights_block1, attn_weights_block2
+
+
+class ConditionalAttentionDecoder(tf.keras.layers.Layer):
+ def __init__(self, input_dim, out_dim, num_layers, num_heads, ff_hidden, target_vocab_size,
+ maximum_position_encoding, rate=0.1):
+ super(ConditionalAttentionDecoder, self).__init__()
+
+ self.d_model = out_dim
+ self.num_layers = num_layers
+
+ self.embedding = tf.keras.layers.Embedding(target_vocab_size, input_dim)
+ self.pos_encoding = positional_encoding(maximum_position_encoding, input_dim)
+
+ self.dec_layers = [ConditionalDecoderLayer(input_dim, num_heads, ff_hidden, rate)
+ for _ in range(num_layers)]
+ self.dropout = tf.keras.layers.Dropout(rate)
+
+ self.look_ahead_mask = self.create_look_ahead_mask(1)
+ self.fc_out = Dense(out_dim)
+
+ def create_look_ahead_mask(self, size):
+ mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
+ return mask # (seq_len, seq_len)
+
+ def compute_mask(self, inputs, mask=None):
+ # encoder_out, target = inputs
+ return mask
+
+ def call(self, inputs, training=None, mask=None):
+ if len(inputs) == 2:
+ encoder_out, target = inputs
+ elif len(inputs) == 1:
+ encoder_out = inputs
+ target = None
+ else:
+ raise ValueError("Incorrect number of parameters")
+
+ if target is None:
+ seq_len = tf.shape(encoder_out)[1]
+ x = encoder_out
+ else:
+ seq_len = tf.shape(target)[1]
+ x = self.embedding(target) # (batch_size, target_seq_len, d_model)
+ x += self.pos_encoding[:, :seq_len, :]
+
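+        # Rebuild the causal look-ahead mask lazily whenever the target sequence length changes.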
+ if self.look_ahead_mask.shape[0] != seq_len:
+ self.look_ahead_mask = self.create_look_ahead_mask(seq_len)
+
+ attention_weights = {}
+
+ x = self.dropout(x, training=training)
+
+ for i in range(self.num_layers):
+ x, block1, block2 = self.dec_layers[i](
+ (encoder_out, x), look_ahead_mask=self.look_ahead_mask, mask=mask, training=training
+ )
+
+ attention_weights[f'decoder_layer{i+1}_block1'] = block1
+ attention_weights[f'decoder_layer{i+1}_block2'] = block2
+
+ # x.shape == (batch_size, target_seq_len, d_model)
+ return self.fc_out(x), attention_weights
+
+
+class FlatDecoder(Layer):
+ def __init__(self, out_dims, hidden=100, dropout=0.1):
+ super(FlatDecoder, self).__init__()
+ self.fc1 = Dense(hidden, activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.HeNormal())
+ self.drop = Dropout(rate=dropout)
+ self.fc2 = Dense(out_dims)
+
+ def call(self, inputs, training=None, mask=None):
+ encoder_out, target = inputs
+ return self.fc2(self.drop(self.fc1(encoder_out, training=training), training=training), training=training), None
\ No newline at end of file
diff --git a/SourceCodeTools/models/nlp/TFEncoder.py b/SourceCodeTools/models/nlp/TFEncoder.py
new file mode 100644
index 00000000..c6cf6af0
--- /dev/null
+++ b/SourceCodeTools/models/nlp/TFEncoder.py
@@ -0,0 +1,238 @@
+from copy import copy
+
+from scipy.linalg import toeplitz
+from tensorflow.keras.layers import Layer, Dense, Conv2D, Flatten, Input, Embedding, concatenate, GRU
+from tensorflow.keras import Model
+import tensorflow as tf
+from tensorflow.python.keras.layers import Dropout
+
+
+class DefaultEmbedding(Layer):
+ """
+ Creates an embedder that provides the default value for the index -1. The default value is a zero-vector
+ """
+ def __init__(self, init_vectors=None, shape=None, trainable=True):
+ super(DefaultEmbedding, self).__init__()
+
+ if init_vectors is not None:
+ self.embs = tf.Variable(init_vectors, dtype=tf.float32,
+ trainable=trainable, name="default_embedder_var")
+ shape = init_vectors.shape
+ else:
+ # TODO
+            # The default value is no longer constant; this should be replaced with a standard embedder.
+ self.embs = tf.Variable(tf.random.uniform(shape=(shape[0], shape[1]), dtype=tf.float32),
+ name="default_embedder_pad")
+ self.pad = tf.zeros(shape=(1, shape[1]), name="default_embedder_pad")
+ # self.pad = tf.random.uniform(shape=(1, init_vectors.shape[1]), name="default_embedder_pad")
+ # self.pad = tf.Variable(tf.random.uniform(shape=(1, shape[1]), dtype=tf.float32),
+ # name="default_embedder_pad")
+
+ # def compute_mask(self, inputs, mask=None):
+ # ids, lengths = inputs
+ # # position with value -1 is a pad
+ # return tf.sequence_mask(lengths, ids.shape[1])
+ # # return inputs != self.embs.shape[0]
+
+ def call(self, ids):
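+        # The zero pad row is appended as the last row of the embedding matrix,
+        # so the padding index resolves to a constant zero vector.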
+ emb_matr = tf.concat([self.embs, self.pad], axis=0)
+ return tf.nn.embedding_lookup(params=emb_matr, ids=ids)
+ # return tf.expand_dims(tf.nn.embedding_lookup(params=self.emb_matr, ids=ids), axis=3)
+
+
+class PositionalEncoding(Model):
+ def __init__(self, seq_len, pos_emb_size):
+ """
+        Create positional embeddings with a trainable embedding matrix. Currently not used because it results
+        in N^2 computational complexity. This functionality should be moved to batch preparation.
+ :param seq_len: maximum sequence length
+ :param pos_emb_size: the dimensionality of positional embeddings
+ """
+ super(PositionalEncoding, self).__init__()
+
+ positions = list(range(seq_len * 2))
+ position_splt = positions[:seq_len]
+ position_splt.reverse()
+ self.position_encoding = tf.constant(toeplitz(position_splt, positions[seq_len:]),
+ dtype=tf.int32,
+ name="position_encoding")
+ # self.position_embedding = tf.random.uniform(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32)
+ self.position_embedding = tf.Variable(tf.random.uniform(shape=(seq_len * 2, pos_emb_size), dtype=tf.float32),
+ name="position_embedding")
+ # self.position_embedding = tf.Variable(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32)
+
+ def call(self):
+ # return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup")
+ return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup")
+
+
+class TextCnnLayer(Model):
+ def __init__(self, out_dim, kernel_shape, activation=None):
+ super(TextCnnLayer, self).__init__()
+
+ self.kernel_shape = kernel_shape
+ self.out_dim = out_dim
+
+ self.textConv = Conv2D(filters=out_dim, kernel_size=kernel_shape,
+ activation=activation, data_format='channels_last')
+
+ padding_size = (self.kernel_shape[0] - 1) // 2
+ assert padding_size * 2 + 1 == self.kernel_shape[0]
+ self.pad_constant = tf.constant([[0, 0], [padding_size, padding_size], [0, 0], [0, 0]])
+
+ self.supports_masking = True
+
+ def call(self, x, training=None, mask=None):
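+        # Zero-pad along the sequence dimension so the convolution preserves seq_len,
+        # then squeeze out the width dimension that the convolution collapses to 1.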
+ padded = tf.pad(x, self.pad_constant)
+ # emb_sent_exp = tf.expand_dims(input, axis=3)
+ convolve = self.textConv(padded)
+ return tf.squeeze(convolve, axis=-2)
+
+
+class TextCnnEncoder(Model):
+ """
+    TextCnnEncoder model for classifying tokens in a sequence. The model uses the following pipeline:
+
+ token_embeddings (provided from outside) ->
+ several convolutional layers, get representations for all tokens ->
+ pass representation for all tokens through a dense network ->
+ classify each token
+ """
+ def __init__(self,
+ # input_size,
+ h_sizes, seq_len,
+ pos_emb_size, cnn_win_size, dense_size, out_dim,
+ activation=None, dense_activation=None, drop_rate=0.2):
+ """
+
+        :param input_size: dimensionality of input embeddings (now inferred from the input shape in build)
+ :param h_sizes: sizes of hidden CNN layers, internal dimensionality of token embeddings
+ :param seq_len: maximum sequence length
+ :param pos_emb_size: dimensionality of positional embeddings
+ :param cnn_win_size: width of cnn window
+        :param dense_size: number of units in the dense network
+        :param out_dim: number of output units
+ :param activation: activation for cnn
+ :param dense_activation: activation for dense layers
+ :param drop_rate: dropout rate for dense network
+ """
+ super(TextCnnEncoder, self).__init__()
+
+ self.seq_len = seq_len
+ self.h_sizes = h_sizes
+ self.dense_size = dense_size
+ self.out_dim = out_dim
+ self.pos_emb_size = pos_emb_size
+ self.cnn_win_size = cnn_win_size
+ self.activation = activation
+ self.dense_activation = dense_activation
+ self.drop_rate = drop_rate
+
+ self.supports_masking = True
+
+ def build(self, input_shape):
+ assert len(input_shape) == 3
+
+ input_size = input_shape[2]
+
+ def infer_kernel_sizes(h_sizes):
+ """
+ Compute kernel sizes from the desired dimensionality of hidden layers
+ :param h_sizes:
+ :return:
+ """
+ kernel_sizes = copy(h_sizes)
+ kernel_sizes.pop(-1) # pop last because it is the output of the last CNN layer
+ kernel_sizes.insert(0, input_size) # the first kernel size should be (cnn_win_size, input_size)
+ kernel_sizes = [(self.cnn_win_size, ks) for ks in kernel_sizes]
+ return kernel_sizes
+
+ kernel_sizes = infer_kernel_sizes(self.h_sizes)
+
+ self.layers_tok = [TextCnnLayer(out_dim=h_size, kernel_shape=kernel_size, activation=self.activation)
+ for h_size, kernel_size in zip(self.h_sizes, kernel_sizes)]
+
+ # self.layers_pos = [TextCnnLayer(out_dim=h_size, kernel_shape=(cnn_win_size, pos_emb_size), activation=activation)
+ # for h_size, _ in zip(h_sizes, kernel_sizes)]
+
+ # self.positional = PositionalEncoding(seq_len=seq_len, pos_emb_size=pos_emb_size)
+
+ if self.dense_activation is None:
+            self.dense_activation = self.activation
+
+ # self.attention = tfa.layers.MultiHeadAttention(head_size=200, num_heads=1)
+
+ self.dense_1 = Dense(self.dense_size, activation=self.dense_activation)
+ self.dropout_1 = tf.keras.layers.Dropout(rate=self.drop_rate)
+ self.dense_2 = Dense(self.out_dim, activation=None) # logits
+ self.dropout_2 = tf.keras.layers.Dropout(rate=self.drop_rate)
+
+ def compute_mask(self, inputs, mask=None):
+ return mask
+
+ def call(self, embs, training=True, mask=None):
+
+ temp_cnn_emb = embs # shape (?, seq_len, input_size)
+
+ # pass embeddings through several CNN layers
+ for l in self.layers_tok:
+ temp_cnn_emb = l(tf.expand_dims(temp_cnn_emb, axis=3)) # shape (?, seq_len, h_size)
+
+ # TODO
+ # simplify to one CNN and one attention
+
+ # pos_cnn = self.positional()
+ # for l in self.layers_pos:
+ # pos_cnn = l(tf.expand_dims(pos_cnn, axis=3))
+ #
+ # cnn_pool_feat = []
+ # for i in range(self.seq_len):
+ # # slice tensor for the line that corresponds to the current position in the sentence
+ # position_features = tf.expand_dims(pos_cnn[i, ...], axis=0, name="exp_dim_%d" % i)
+ # # convolution without activation can be combined later, hence: temp_cnn_emb + position_features
+ # cnn_pool_feat.append(
+ # tf.expand_dims(tf.nn.tanh(tf.reduce_max(temp_cnn_emb + position_features, axis=1)), axis=1))
+ # # cnn_pool_feat.append(
+ # # tf.expand_dims(tf.nn.tanh(tf.reduce_max(tf.concat([temp_cnn_emb, position_features], axis=-1), axis=1)), axis=1))
+ #
+ # cnn_pool_features = tf.concat(cnn_pool_feat, axis=1)
+ cnn_pool_features = temp_cnn_emb
+
+ # cnn_pool_features = self.attention([cnn_pool_features, cnn_pool_features])
+
+ # token_features = self.dropout_1(
+ # tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1]))
+ # , training=training)
+
+ # reshape before passing through a dense network
+ # token_features = tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1])) # shape (? * seq_len, h_size[-1])
+
+ # local_h2 = self.dropout_2(
+ # self.dense_1(token_features)
+ # , training=training)
+        local_h2 = self.dense_1(cnn_pool_features) # shape (?, seq_len, dense_size)
+        tag_logits = self.dense_2(local_h2) # shape (?, seq_len, out_dim)
+
+ return tag_logits # tf.reshape(tag_logits, (-1, seq_len, self.out_dim)) # reshape back, shape (?, seq_len, num_classes)
+
+
+class GRUEncoder(Model):
+ def __init__(self, input_dim, out_dim=100, num_layers=1, dropout=0.1):
+ super(GRUEncoder, self).__init__()
+ self.num_layers = num_layers
+
+ self.gru_layers = [
+ tf.keras.layers.Bidirectional(GRU(out_dim, dropout=dropout, return_sequences=True)) for _ in range(num_layers)
+ ]
+
+ self.dropout = Dropout(dropout)
+ self.supports_masking = True
+
+ def call(self, inputs, training=None, mask=None):
+ x = inputs
+
+ for layer in self.gru_layers:
+ x = layer(x, training=training, mask=mask)
+ x = self.dropout(x, training=training)
+
+ return x
\ No newline at end of file
diff --git a/SourceCodeTools/models/nlp/Decoder.py b/SourceCodeTools/models/nlp/TorchDecoder.py
similarity index 90%
rename from SourceCodeTools/models/nlp/Decoder.py
rename to SourceCodeTools/models/nlp/TorchDecoder.py
index 1ba6f128..dae8fc21 100644
--- a/SourceCodeTools/models/nlp/Decoder.py
+++ b/SourceCodeTools/models/nlp/TorchDecoder.py
@@ -1,25 +1,33 @@
+import logging
+
import torch
from torch import nn
from torch.nn import Embedding
import torch.nn.functional as F
from torch.autograd import Variable
+from SourceCodeTools.models.nlp.common import positional_encoding
+
class Decoder(nn.Module):
- def __init__(self, encoder_out_dim, decoder_dim, out_dim, vocab_size, nheads=1, layers=1):
+ def __init__(self, encoder_out_dim, decoder_dim, out_dim, vocab_size, seq_len, nheads=1, layers=1):
super(Decoder, self).__init__()
self.encoder_adapter = nn.Linear(encoder_out_dim, decoder_dim)
self.embed = nn.Embedding(vocab_size, decoder_dim)
self.decoder_layer = nn.TransformerDecoderLayer(decoder_dim, nheads, dim_feedforward=decoder_dim)
self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=layers)
self.mask = self.generate_square_subsequent_mask(1)
+ self.seq_len = seq_len
+
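+        # fixed (non-trainable) sinusoidal positional encodings added to the decoder inputs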
+        self.position_encoding = torch.from_numpy(positional_encoding(self.seq_len, decoder_dim)).float().permute(1, 0, 2)
self.fc = nn.Linear(decoder_dim, out_dim)
def forward(self, encoder_out, target):
encoder_out = self.encoder_adapter(encoder_out).permute(1, 0, 2)
target = self.embed(target).permute(1, 0, 2)
- if self.mask.size(0) != target.size(0):
+        target = target + self.position_encoding[:target.shape[0]].to(target.device)
+ if self.mask.size(0) != target.size(0): # for self-attention
self.mask = self.generate_square_subsequent_mask(target.size(0)).to(encoder_out.device)
out = self.decoder(target, encoder_out, tgt_mask=self.mask)
@@ -64,6 +72,7 @@ class LSTMDecoder(nn.Module):
def __init__(self, num_buckets, padding=0, encoder_embed_dim=100, embed_dim=100,
out_embed_dim=100, num_layers=1, dropout_in=0.1,
dropout_out=0.1, use_cuda=True):
+ embed_dim = out_embed_dim = encoder_embed_dim
super(LSTMDecoder, self).__init__()
self.use_cuda = use_cuda
self.dropout_in = dropout_in
@@ -73,7 +82,7 @@ def __init__(self, num_buckets, padding=0, encoder_embed_dim=100, embed_dim=100,
padding_idx = padding
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
- self.create_layers(embed_dim, embed_dim, num_layers)
+ self.create_layers(encoder_embed_dim, embed_dim, num_layers)
self.attention = AttentionLayer(out_embed_dim, encoder_embed_dim, embed_dim)
if embed_dim != out_embed_dim:
@@ -82,10 +91,14 @@ def __init__(self, num_buckets, padding=0, encoder_embed_dim=100, embed_dim=100,
def create_layers(self, encoder_embed_dim, embed_dim, num_layers):
self.layers = nn.ModuleList([
- LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
+ LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else encoder_embed_dim, encoder_embed_dim if layer == 0 else embed_dim)
for layer in range(num_layers)
])
+ self.norms = nn.ModuleList([
+ nn.LayerNorm(encoder_embed_dim + embed_dim if layer == 0 else encoder_embed_dim) for layer in range(num_layers)
+ ])
+
def forward(self, prev_output_tokens, encoder_out, incremental_state=None, inference=False):
if incremental_state is not None: # TODO what is this?
prev_output_tokens = prev_output_tokens[:, -1:]
@@ -112,9 +125,9 @@ def forward(self, prev_output_tokens, encoder_out, incremental_state=None, infer
else:
# _, encoder_hiddens, encoder_cells = encoder_out
num_layers = len(self.layers)
- prev_hiddens = [Variable(x.data.new(bsz, embed_dim).zero_()) for i in range(num_layers)]
- prev_cells = [Variable(x.data.new(bsz, embed_dim).zero_()) for i in range(num_layers)]
- input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
+ prev_hiddens = [Variable(x.data.new(bsz, embed_dim).zero_()) if i != 0 else encoder_outs[0] for i in range(num_layers)]
+ prev_cells = [Variable(x.data.new(bsz, embed_dim if i!=0 else encoder_outs.size(-1)).zero_()) for i in range(num_layers)]
+ input_feed = Variable(x.data.new(bsz, encoder_outs.size(-1)).zero_())
attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
outs = []
@@ -122,8 +135,9 @@ def forward(self, prev_output_tokens, encoder_out, incremental_state=None, infer
# input feeding: concatenate context vector from previous time step
input = torch.cat((x[j, :, :], input_feed), dim=1)
- for i, rnn in enumerate(self.layers):
+ for i, (rnn, lnorm) in enumerate(zip(self.layers, self.norms)):
# recurrent cell
+ input = lnorm(input)
hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
# hidden state becomes the input to the next layer
diff --git a/SourceCodeTools/models/nlp/Encoder.py b/SourceCodeTools/models/nlp/TorchEncoder.py
similarity index 97%
rename from SourceCodeTools/models/nlp/Encoder.py
rename to SourceCodeTools/models/nlp/TorchEncoder.py
index 82adf0c3..938d9efb 100644
--- a/SourceCodeTools/models/nlp/Encoder.py
+++ b/SourceCodeTools/models/nlp/TorchEncoder.py
@@ -9,7 +9,7 @@ class Encoder(nn.Module):
def __init__(self, encoder_dim, out_dim, nheads=1, layers=1):
super(Encoder, self).__init__()
# self.embed = nn.Embedding(vocab_size, encoder_dim)
- self.encoder_lauer = nn.TransformerEncoderLayer(encoder_dim, nheads, dim_feedforward=encoder_dim)
+ self.encoder_layer = nn.TransformerEncoderLayer(encoder_dim, nheads, dim_feedforward=encoder_dim)
-        self.encoder = nn.TransformerEncoder(self.encoder_lauer, num_layers=layers)
+        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=layers)
self.out_adapter = nn.Linear(encoder_dim, out_dim)
@@ -29,7 +29,6 @@ def forward(self, input, lengths=None):
return out.permute(1, 0, 2)
-
class LSTMEncoder(nn.Module):
"""LSTM encoder."""
def __init__(self, embed_dim=100, num_layers=1, dropout_in=0.1, dropout_out=0.1):
diff --git a/SourceCodeTools/models/nlp/common.py b/SourceCodeTools/models/nlp/common.py
new file mode 100644
index 00000000..1004b0b1
--- /dev/null
+++ b/SourceCodeTools/models/nlp/common.py
@@ -0,0 +1,20 @@
+import numpy as np
+
+def positional_encoding(position, d_model):
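+    """
+    Sinusoidal positional encoding (as in "Attention Is All You Need"): even dimensions use sin,
+    odd dimensions use cos of position-dependent angles.
+    :param position: number of positions to encode
+    :param d_model: dimensionality of the encoding
+    :return: numpy array of shape (1, position, d_model)
+    """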
+ def get_angles(pos, i, d_model):
+ angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
+ return pos * angle_rates
+
+ angle_rads = get_angles(np.arange(position)[:, np.newaxis],
+ np.arange(d_model)[np.newaxis, :],
+ d_model)
+
+ # apply sin to even indices in the array; 2i
+ angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
+
+ # apply cos to odd indices in the array; 2i+1
+ angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
+
+ pos_encoding = angle_rads[np.newaxis, ...]
+ return pos_encoding
+ # return tf.cast(pos_encoding, dtype=tf.float32)
\ No newline at end of file
diff --git a/SourceCodeTools/models/training_options.py b/SourceCodeTools/models/training_options.py
new file mode 100644
index 00000000..e6388e09
--- /dev/null
+++ b/SourceCodeTools/models/training_options.py
@@ -0,0 +1,86 @@
+def add_data_arguments(parser):
+ parser.add_argument('--data_path', '-d', dest='data_path', default=None, help='Path to the files')
+ parser.add_argument('--train_frac', dest='train_frac', default=0.9, type=float, help='')
+ parser.add_argument('--filter_edges', dest='filter_edges', default=None, help='Edges filtered before training')
+ parser.add_argument('--min_count_for_objectives', dest='min_count_for_objectives', default=5, type=int, help='')
+ parser.add_argument('--packages_file', dest='packages_file', default=None, type=str, help='')
+ parser.add_argument('--self_loops', action='store_true')
+ parser.add_argument('--use_node_types', action='store_true')
+ parser.add_argument('--use_edge_types', action='store_true')
+ parser.add_argument('--restore_state', action='store_true')
+ parser.add_argument('--no_global_edges', action='store_true')
+ parser.add_argument('--remove_reverse', action='store_true')
+ parser.add_argument('--custom_reverse', dest='custom_reverse', default=None, help='')
+ parser.add_argument('--restricted_id_pool', dest='restricted_id_pool', default=None, help='')
+
+
+def add_pretraining_arguments(parser):
+ parser.add_argument('--pretrained', '-p', dest='pretrained', default=None, help='')
+ parser.add_argument('--tokenizer', '-t', dest='tokenizer', default=None, help='')
+ parser.add_argument('--pretraining_phase', dest='pretraining_phase', default=0, type=int, help='')
+
+
+def add_training_arguments(parser):
+    parser.add_argument('--embedding_table_size', dest='embedding_table_size', default=200000, type=int, help='Embedding table size')
+ parser.add_argument('--random_seed', dest='random_seed', default=None, type=int, help='')
+
+ parser.add_argument('--node_emb_size', dest='node_emb_size', default=100, type=int, help='')
+ parser.add_argument('--elem_emb_size', dest='elem_emb_size', default=100, type=int, help='')
+ parser.add_argument('--num_per_neigh', dest='num_per_neigh', default=10, type=int, help='')
+ parser.add_argument('--neg_sampling_factor', dest='neg_sampling_factor', default=3, type=int, help='')
+
+ parser.add_argument('--use_layer_scheduling', action='store_true')
+ parser.add_argument('--schedule_layers_every', dest='schedule_layers_every', default=10, type=int, help='')
+
+ parser.add_argument('--epochs', dest='epochs', default=100, type=int, help='Number of epochs')
+ parser.add_argument('--batch_size', dest='batch_size', default=128, type=int, help='Batch size')
+
+ parser.add_argument("--h_dim", dest="h_dim", default=None, type=int)
+ parser.add_argument("--n_layers", dest="n_layers", default=5, type=int)
+ parser.add_argument("--objectives", dest="objectives", default=None, type=str)
+
+ parser.add_argument("--save_each_epoch", action="store_true")
+ parser.add_argument("--early_stopping", action="store_true")
+ parser.add_argument("--early_stopping_tolerance", default=20, type=int)
+ parser.add_argument("--force_w2v_ns", action="store_true")
+ parser.add_argument("--use_ns_groups", action="store_true")
+
+ parser.add_argument("--metric", default="inner_prod", type=str)
+ parser.add_argument("--nn_index", default="brute", type=str)
+
+ parser.add_argument("--external_dataset", default=None, type=str)
+
+def add_scoring_arguments(parser):
+ parser.add_argument('--measure_scores', action='store_true')
+ parser.add_argument('--dilate_scores', dest='dilate_scores', default=200, type=int, help='')
+
+
+def add_performance_arguments(parser):
+ parser.add_argument('--no_checkpoints', dest="save_checkpoints", action='store_false')
+
+ parser.add_argument('--use_gcn_checkpoint', action='store_true')
+ parser.add_argument('--use_att_checkpoint', action='store_true')
+ parser.add_argument('--use_gru_checkpoint', action='store_true')
+
+
+def add_gnn_train_args(parser):
+ parser.add_argument(
+ '--training_mode', '-tr', dest='training_mode', default=None,
+ help='Selects one of training procedures [multitask]'
+ )
+
+ add_data_arguments(parser)
+ add_pretraining_arguments(parser)
+ add_training_arguments(parser)
+ add_scoring_arguments(parser)
+ add_performance_arguments(parser)
+
+ parser.add_argument('--note', dest='note', default="", help='Note, added to metadata')
+ parser.add_argument('model_output_dir', help='Location of the final model')
+
+ # parser.add_argument('--intermediate_supervision', action='store_true')
+ parser.add_argument('--gpu', dest='gpu', default=-1, type=int, help='')
+
+
+def verify_arguments(args):
+ pass
\ No newline at end of file
diff --git a/SourceCodeTools/nlp/batchers/PythonBatcher.py b/SourceCodeTools/nlp/batchers/PythonBatcher.py
index e6ff2049..0b3cbdcd 100644
--- a/SourceCodeTools/nlp/batchers/PythonBatcher.py
+++ b/SourceCodeTools/nlp/batchers/PythonBatcher.py
@@ -1,21 +1,49 @@
import json
import os
import shelve
+import shutil
import tempfile
+from collections import defaultdict
+from time import time
from functools import lru_cache
from math import ceil
from typing import Dict, Optional, List
-from spacy.gold import biluo_tags_from_offsets
+import spacy
+from spacy.gold import biluo_tags_from_offsets as spacy_biluo_tags_from_offsets
from SourceCodeTools.code.ast_tools import get_declarations
from SourceCodeTools.models.ClassWeightNormalizer import ClassWeightNormalizer
from SourceCodeTools.nlp import create_tokenizer, tag_map_from_sentences, TagMap, token_hasher, try_int
from SourceCodeTools.nlp.entity import fix_incorrect_tags
+from SourceCodeTools.nlp.entity.annotator.annotator_utils import adjust_offsets
from SourceCodeTools.nlp.entity.utils import overlap
import numpy as np
+
+def biluo_tags_from_offsets(doc, ents, no_localization):
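+    """
+    Wrapper around spacy's biluo_tags_from_offsets. When `no_localization` is set, the BILUO prefixes
+    are dropped: only tokens that begin an entity (B- or U-) keep the bare entity label, all other
+    tokens become "O".
+    """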
+ ent_tags = spacy_biluo_tags_from_offsets(doc, ents)
+
+ if no_localization:
+ tags = []
+ for ent in ent_tags:
+ parts = ent.split("-")
+
+ assert len(parts) <= 2
+
+ if len(parts) == 2:
+ if parts[0] == "B" or parts[0] == "U":
+ tags.append(parts[1])
+ else:
+ tags.append("O")
+ else:
+ tags.append("O")
+
+ ent_tags = tags
+
+ return ent_tags
+
def filter_unlabeled(entities, declarations):
"""
Get a list of declarations that were not mentioned in `entities`
@@ -33,37 +61,46 @@ def filter_unlabeled(entities, declarations):
return for_mask
+def print_token_tag(doc, tags):
+ for t, tag in zip(doc, tags):
+ print(t, "\t", tag)
+
+
class PythonBatcher:
def __init__(
self, data, batch_size: int, seq_len: int,
- graphmap: Dict[str, int], wordmap: Dict[str, int], tagmap: Optional[TagMap] = None,
+ wordmap: Dict[str, int], *, graphmap: Optional[Dict[str, int]], tagmap: Optional[TagMap] = None,
mask_unlabeled_declarations=True,
- class_weights=False, element_hash_size=1000
+ class_weights=False, element_hash_size=1000, len_sort=True, tokenizer="spacy", no_localization=False
):
self.create_cache()
- self.data = data
+ self.data = sorted(data, key=lambda x: len(x[0])) if len_sort else data
self.batch_size = batch_size
self.seq_len = seq_len
self.class_weights = None
self.mask_unlabeled_declarations = mask_unlabeled_declarations
+ self.tokenizer = tokenizer
+ if tokenizer == "codebert":
+ self.vocab = spacy.blank("en").vocab
+ self.no_localization = no_localization
- self.nlp = create_tokenizer("spacy")
+ self.nlp = create_tokenizer(tokenizer)
if tagmap is None:
self.tagmap = tag_map_from_sentences(list(zip(*[self.prepare_sent(sent) for sent in data]))[1])
else:
self.tagmap = tagmap
- self.graphpad = len(graphmap)
+ self.graphpad = len(graphmap) if graphmap is not None else None
self.wordpad = len(wordmap)
self.tagpad = self.tagmap["O"]
self.prefpad = element_hash_size
self.suffpad = element_hash_size
- self.graphmap_func = lambda g: graphmap.get(g, len(graphmap))
+ self.graphmap_func = (lambda g: graphmap.get(g, len(graphmap))) if graphmap is not None else None
self.wordmap_func = lambda w: wordmap.get(w, len(wordmap))
- self.tagmap_func = lambda t: self.tagmap.get(t, len(self.tagmap))
+ self.tagmap_func = lambda t: self.tagmap.get(t, self.tagmap["O"])
self.prefmap_func = lambda w: token_hasher(w[:3], element_hash_size)
self.suffmap_func = lambda w: token_hasher(w[-3:], element_hash_size)
@@ -81,10 +118,26 @@ def __init__(
else:
self.classw_func = lambda t: 1.
+ def __del__(self):
+ self.sent_cache.close()
+ self.batch_cache.close()
+
+        shutil.rmtree(self.tmp_dir, ignore_errors=True)
+
def create_cache(self):
- self.tmp_dir = tempfile.TemporaryDirectory()
- self.sent_cache = shelve.open(os.path.join(self.tmp_dir.name, "sent_cache.db"))
- self.batch_cache = shelve.open(os.path.join(self.tmp_dir.name, "batch_cache.db"))
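+        # create a uniquely named temporary directory for the shelve-based sentence and batch caches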
+ char_ranges = [chr(i) for i in range(ord("a"), ord("a")+26)] + [chr(i) for i in range(ord("A"), ord("A")+26)] + [chr(i) for i in range(ord("0"), ord("0")+10)]
+ from random import sample
+ rnd_name = "".join(sample(char_ranges, k=10)) + str(int(time() * 1e6))
+
+ self.tmp_dir = os.path.join(tempfile.gettempdir(), rnd_name)
+ if os.path.isdir(self.tmp_dir):
+ shutil.rmtree(self.tmp_dir)
+ os.mkdir(self.tmp_dir)
+
+ self.sent_cache = shelve.open(os.path.join(self.tmp_dir, "sent_cache"))
+ self.batch_cache = shelve.open(os.path.join(self.tmp_dir, "batch_cache"))
def num_classes(self):
return len(self.tagmap)
@@ -106,16 +159,48 @@ def prepare_sent(self, sent):
unlabeled_dec = filter_unlabeled(ents, get_declarations(text))
tokens = [t.text for t in doc]
- ents_tags = biluo_tags_from_offsets(doc, ents)
- repl_tags = biluo_tags_from_offsets(doc, repl)
+
+ if self.tokenizer == "codebert":
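+            # CodeBERT BPE pieces mark a preceding space with "Ġ"; rebuild a spaCy Doc from the pieces
+            # so that character offsets can be converted to BILUO tags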
+ backup_tokens = doc
+ fixed_spaces = [False]
+            fixed_words = ["<s>"]
+
+ for ind, t in enumerate(doc):
+ if len(t.text) > 1:
+ fixed_words.append(t.text.strip("Ġ"))
+ else:
+ fixed_words.append(t.text)
+ if ind != 0:
+ fixed_spaces.append(t.text.startswith("Ġ") and len(t.text) > 1)
+ fixed_spaces.append(False)
+ fixed_spaces.append(False)
+            fixed_words.append("</s>")
+
+ assert len(fixed_spaces) == len(fixed_words)
+
+ from spacy.tokens import Doc
+ doc = Doc(self.vocab, fixed_words, fixed_spaces)
+
+ assert len(doc) - 2 == len(backup_tokens)
+ assert len(doc.text) - 7 == len(backup_tokens.text)
+ ents = adjust_offsets(ents, -3)
+ repl = adjust_offsets(repl, -3)
+ if self.mask_unlabeled_declarations:
+ unlabeled_dec = adjust_offsets(unlabeled_dec, -3)
+
+ ents_tags = biluo_tags_from_offsets(doc, ents, self.no_localization)
+ repl_tags = biluo_tags_from_offsets(doc, repl, self.no_localization)
if self.mask_unlabeled_declarations:
- unlabeled_dec = biluo_tags_from_offsets(doc, unlabeled_dec)
+ unlabeled_dec = biluo_tags_from_offsets(doc, unlabeled_dec, self.no_localization)
fix_incorrect_tags(ents_tags)
fix_incorrect_tags(repl_tags)
if self.mask_unlabeled_declarations:
fix_incorrect_tags(unlabeled_dec)
+ if self.tokenizer == "codebert":
+            tokens = ["<s>"] + [t.text for t in backup_tokens] + ["</s>"]
+
assert len(tokens) == len(ents_tags) == len(repl_tags)
if self.mask_unlabeled_declarations:
assert len(tokens) == len(unlabeled_dec)
@@ -151,7 +236,7 @@ def encode(seq, encode_func, pad):
pref = encode(sent, self.prefmap_func, self.prefpad)
suff = encode(sent, self.suffmap_func, self.suffpad)
s = encode(sent, self.wordmap_func, self.wordpad)
- r = encode(repl, self.graphmap_func, self.graphpad) # TODO test
+ r = encode(repl, self.graphmap_func, self.graphpad) if self.graphmap_func is not None else None
# labels
t = encode(tags, self.tagmap_func, self.tagpad)
@@ -160,38 +245,55 @@ def encode(seq, encode_func, pad):
hidem = encode(
list(range(len(sent))) if unlabeled_decls is None else unlabeled_decls,
self.mask_unlbl_func, self.mask_unlblpad
- )
+ ).astype(np.bool)
# class weights
classw = encode(tags, self.classw_func, self.classwpad)
- assert len(s) == len(r) == len(pref) == len(suff) == len(t) == len(classw) == len(hidem)
+ assert len(s) == len(pref) == len(suff) == len(t) == len(classw) == len(hidem)
+ if r is not None:
+ assert len(r) == len(s)
+
+ no_localization_mask = np.array([tag != self.tagpad for tag in t]).astype(np.bool)
output = {
"tok_ids": s,
- "graph_ids": r,
+ "replacements": repl,
+ # "graph_ids": r,
"prefix": pref,
"suffix": suff,
"tags": t,
"class_weights": classw,
"hide_mask": hidem,
- "lens": len(s) if len(s) < self.seq_len else self.seq_len
+ "no_loc_mask": no_localization_mask,
+ "lens": len(sent) if len(sent) < self.seq_len else self.seq_len
}
+ if r is not None:
+ output["graph_ids"] = r
+
self.batch_cache[input_json] = output
return output
def format_batch(self, batch):
- fbatch = {
- "tok_ids": [], "graph_ids": [], "prefix": [], "suffix": [],
- "tags": [], "class_weights": [], "hide_mask": [], "lens": []
- }
+ # fbatch = {
+ # "tok_ids": [], "graph_ids": [], "prefix": [], "suffix": [],
+ # "tags": [], "class_weights": [], "hide_mask": [], "lens": [], "replacements": []
+ # }
+ fbatch = defaultdict(list)
for sent in batch:
for key, val in sent.items():
fbatch[key].append(val)
- return {key: np.stack(val) if key != "lens" else np.array(val, dtype=np.int32) for key, val in fbatch.items()}
+ if len(fbatch["graph_ids"]) == 0:
+ fbatch.pop("graph_ids")
+
+ max_len = max(fbatch["lens"])
+
+ return {
+ key: np.stack(val)[:,:max_len] if key != "lens" and key != "replacements"
+ else (np.array(val, dtype=np.int32) if key != "replacements" else np.array(val)) for key, val in fbatch.items()}
def generate_batches(self):
batch = []
@@ -200,7 +302,9 @@ def generate_batches(self):
if len(batch) >= self.batch_size:
yield self.format_batch(batch)
batch = []
- yield self.format_batch(batch)
+ if len(batch) > 0:
+ yield self.format_batch(batch)
+ # yield self.format_batch(batch)
def __iter__(self):
return self.generate_batches()
diff --git a/SourceCodeTools/nlp/codebert/codebert.py b/SourceCodeTools/nlp/codebert/codebert.py
new file mode 100644
index 00000000..cf841028
--- /dev/null
+++ b/SourceCodeTools/nlp/codebert/codebert.py
@@ -0,0 +1,114 @@
+import pickle
+
+import torch
+from tqdm import tqdm
+from transformers import RobertaTokenizer, RobertaModel
+
+from SourceCodeTools.models.Embedder import Embedder
+from SourceCodeTools.nlp.entity.type_prediction import get_type_prediction_arguments, ModelTrainer
+from SourceCodeTools.nlp.entity.utils.data import read_data
+
+
+def load_typed_nodes(path):
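+    """
+    Load type annotation edges and return the set of source node ids that have a real type label
+    (destinations whose name contains "0x" are filtered out).
+    """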
+ from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist
+ type_ann = unpersist(path)
+
+ filter_rule = lambda name: "0x" not in name
+
+ type_ann = type_ann[
+ type_ann["dst"].apply(filter_rule)
+ ]
+
+ typed_nodes = set(type_ann["src"].tolist())
+ return typed_nodes
+
+
+class CodeBertModelTrainer(ModelTrainer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def set_type_ann_edges(self, path):
+ self.type_ann_edges = path
+
+ def get_batcher(self, *args, **kwargs):
+ kwargs.update({"tokenizer": "codebert"})
+ return self.batcher(*args, **kwargs)
+
+ def train_model(self):
+ # graph_emb = load_pkl_emb(self.graph_emb_path) if self.graph_emb_path is not None else None
+
+ typed_nodes = load_typed_nodes(self.type_ann_edges)
+
+ decoder_mapping = RobertaTokenizer.from_pretrained("microsoft/codebert-base").decoder
+ tok_ids, words = zip(*decoder_mapping.items())
+ vocab_mapping = dict(zip(words, tok_ids))
+ batcher = self.get_batcher(
+ self.train_data + self.test_data, self.batch_size, seq_len=self.seq_len,
+ graphmap=None,
+ wordmap=vocab_mapping, tagmap=None,
+ class_weights=False, element_hash_size=1
+ )
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = RobertaModel.from_pretrained("microsoft/codebert-base")
+ model.to(device)
+
+ node_ids = []
+ embeddings = []
+
+ for ind, batch in enumerate(tqdm(batcher)):
+ # token_ids, graph_ids, labels, class_weights, lengths = b
+            token_ids = torch.LongTensor(batch["tok_ids"]).to(device)
+            lens = torch.LongTensor(batch["lens"]).to(device)
+
+            token_ids[token_ids == len(vocab_mapping)] = vocab_mapping["<pad>"]
+
+ def get_length_mask(target, lens):
+ mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None]
+ return mask
+
+ mask = get_length_mask(token_ids, lens)
+ with torch.no_grad():
+ embs = model(input_ids=token_ids, attention_mask=mask)
+
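+        # for every node with a type annotation, keep the contextual embedding of its first mention in the sequence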
+ for s_emb, s_repl in zip(embs.last_hidden_state, batch["replacements"]):
+ unique_repls = set(list(s_repl))
+ repls_for_ann = [r for r in unique_repls if r in typed_nodes]
+
+ for r in repls_for_ann:
+ position = s_repl.index(r)
+ if position > 512:
+ continue
+ node_ids.append(r)
+ embeddings.append(s_emb[position])
+
+    all_embs = torch.stack(embeddings, dim=0).cpu().numpy()
+ embedder = Embedder(dict(zip(node_ids, range(len(node_ids)))), all_embs)
+ pickle.dump(embedder, open("codebert_embeddings.pkl", "wb"), fix_imports=False)
+ print(node_ids)
+
+
+def main():
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
+ model = RobertaModel.from_pretrained("microsoft/codebert-base")
+ # model.to(device)
+ args = get_type_prediction_arguments()
+
+ # allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray',
+ # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', 'Type'}
+
+ train_data, test_data = read_data(
+ open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True,
+ include_only="entities",
+ min_entity_count=args.min_entity_count, random_seed=args.random_seed
+ )
+
+ trainer = CodeBertModelTrainer(train_data, test_data, params={}, seq_len=512)
+ trainer.set_type_ann_edges(args.type_ann_edges)
+ trainer.train_model()
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/nlp/codebert/codebert_train.py b/SourceCodeTools/nlp/codebert/codebert_train.py
new file mode 100644
index 00000000..e166786f
--- /dev/null
+++ b/SourceCodeTools/nlp/codebert/codebert_train.py
@@ -0,0 +1,421 @@
+import json
+import os
+import pickle
+from datetime import datetime
+from time import time
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from torch.version import cuda
+from tqdm import tqdm
+from transformers import RobertaTokenizer, RobertaModel
+
+from SourceCodeTools.models.Embedder import Embedder
+from SourceCodeTools.nlp.codebert.codebert import CodeBertModelTrainer, load_typed_nodes
+from SourceCodeTools.nlp.entity.type_prediction import get_type_prediction_arguments, ModelTrainer, load_pkl_emb, \
+ scorer, filter_labels
+from SourceCodeTools.nlp.entity.utils.data import read_data
+
+import torch.nn as nn
+
+class CodebertHybridModel(nn.Module):
+ def __init__(
+ self, codebert_model, graph_emb, padding_idx, num_classes, dense_hidden=100, dropout=0.1, bert_emb_size=768,
+ no_graph=False
+ ):
+ super(CodebertHybridModel, self).__init__()
+
+ self.codebert_model = codebert_model
+ self.use_graph = not no_graph
+
+ num_emb = padding_idx + 1 # padding id is usually not a real embedding
+ emb_dim = graph_emb.shape[1]
+ self.graph_emb = nn.Embedding(num_embeddings=num_emb, embedding_dim=emb_dim, padding_idx=padding_idx)
+
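+        # initialize the embedding table with the pretrained graph vectors plus a zero row for padding, and keep it frozen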
+ import numpy as np
+ pretrained_embeddings = torch.from_numpy(np.concatenate([graph_emb, np.zeros((1, emb_dim))], axis=0)).float()
+ new_param = torch.nn.Parameter(pretrained_embeddings)
+ assert self.graph_emb.weight.shape == new_param.shape
+ self.graph_emb.weight = new_param
+ self.graph_emb.weight.requires_grad = False
+
+ self.fc1 = nn.Linear(
+ bert_emb_size + (emb_dim if self.use_graph else 0),
+ dense_hidden
+ )
+ self.drop = nn.Dropout(dropout)
+ self.fc2 = nn.Linear(dense_hidden, num_classes)
+
+ self.loss_f = nn.CrossEntropyLoss(reduction="mean")
+
+ def forward(self, token_ids, graph_ids, mask, finetune=False):
+ if finetune:
+ x = self.codebert_model(input_ids=token_ids, attention_mask=mask).last_hidden_state
+ else:
+ with torch.no_grad():
+ x = self.codebert_model(input_ids=token_ids, attention_mask=mask).last_hidden_state
+
+ if self.use_graph:
+ graph_emb = self.graph_emb(graph_ids)
+ x = torch.cat([x, graph_emb], dim=-1)
+
+ x = torch.relu(self.fc1(x))
+ x = self.drop(x)
+ x = self.fc2(x)
+
+ return x
+
+ def loss(self, logits, labels, mask, class_weights=None, extra_mask=None):
+ if extra_mask is not None:
+ mask = torch.logical_and(mask, extra_mask)
+ logits = logits[mask, :]
+ labels = labels[mask]
+ loss = self.loss_f(logits, labels)
+ # if class_weights is None:
+ # loss = tf.reduce_mean(tf.boolean_mask(losses, seq_mask))
+ # else:
+ # loss = tf.reduce_mean(tf.boolean_mask(losses * class_weights, seq_mask))
+
+ return loss
+
+ def score(self, logits, labels, mask, scorer=None, extra_mask=None):
+ if extra_mask is not None:
+ mask = torch.logical_and(mask, extra_mask)
+ true_labels = labels[mask]
+ argmax = logits.argmax(-1)
+ estimated_labels = argmax[mask]
+
+ p, r, f1 = scorer(to_numpy(estimated_labels), to_numpy(true_labels))
+
+ return p, r, f1
+
+
+def get_length_mask(target, lens):
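+    """Boolean mask of shape (batch, seq_len) that is True for positions within each sequence's length."""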
+ mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None]
+ return mask
+
+
+def batch_to_torch(batch, device):
+ key_types = {
+ 'tok_ids': torch.LongTensor,
+ 'tags': torch.LongTensor,
+ 'hide_mask': torch.BoolTensor,
+ 'no_loc_mask': torch.BoolTensor,
+ 'lens': torch.LongTensor,
+ 'graph_ids': torch.LongTensor
+ }
+ for key, tf in key_types.items():
+ batch[key] = tf(batch[key]).to(device)
+
+
+def to_numpy(tensor):
+ return tensor.cpu().detach().numpy()
+
+
+def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids, labels, lengths,
+ extra_mask=None, class_weights=None, scorer=None, finetune=False, vocab_mapping=None):
+    token_ids[token_ids == len(vocab_mapping)] = vocab_mapping["<pad>"]
+ seq_mask = get_length_mask(token_ids, lengths)
+ logits = model(token_ids, graph_ids, mask=seq_mask, finetune=finetune)
+ loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask)
+ p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ return loss, p, r, f1
+
+
+def test_step(
+ model, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None,
+ vocab_mapping=None
+):
+ with torch.no_grad():
+        token_ids[token_ids == len(vocab_mapping)] = vocab_mapping["<pad>"]
+ seq_mask = get_length_mask(token_ids, lengths)
+ logits = model(token_ids, graph_ids, mask=seq_mask)
+ loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask)
+ p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask)
+
+ return loss, p, r, f1
+
+
+class CodeBertModelTrainer2(CodeBertModelTrainer):
+ def __init__(self, *args, gpu_id=-1, **kwargs):
+ self.gpu_id = gpu_id
+ super().__init__(*args, **kwargs)
+ self.set_gpu()
+
+ def get_dataloaders(self, word_emb, graph_emb, suffix_prefix_buckets):
+ decoder_mapping = RobertaTokenizer.from_pretrained("microsoft/codebert-base").decoder
+ tok_ids, words = zip(*decoder_mapping.items())
+ self.vocab_mapping = dict(zip(words, tok_ids))
+
+ train_batcher = self.get_batcher(
+ self.train_data, self.batch_size, seq_len=self.seq_len,
+ graphmap=graph_emb.ind if graph_emb is not None else None,
+ wordmap=self.vocab_mapping, tagmap=None,
+ class_weights=False, element_hash_size=suffix_prefix_buckets, no_localization=self.no_localization
+ )
+ test_batcher = self.get_batcher(
+ self.test_data, self.batch_size, seq_len=self.seq_len,
+ graphmap=graph_emb.ind if graph_emb is not None else None,
+ wordmap=self.vocab_mapping,
+ tagmap=train_batcher.tagmap, # use the same mapping
+ class_weights=False, element_hash_size=suffix_prefix_buckets, # class_weights are not used for testing
+ no_localization=self.no_localization
+ )
+ return train_batcher, test_batcher
+
+ def train(
+ self, model, train_batches, test_batches, epochs, report_every=10, scorer=None, learning_rate=0.01,
+ learning_rate_decay=1., finetune=False, summary_writer=None, save_ckpt_fn=None, no_localization=False
+ ):
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=learning_rate_decay)
+
+ train_losses = []
+ test_losses = []
+ train_f1s = []
+ test_f1s = []
+
+ num_train_batches = len(train_batches)
+ num_test_batches = len(test_batches)
+
+ best_f1 = 0.
+
+ try:
+ for e in range(epochs):
+ losses = []
+ ps = []
+ rs = []
+ f1s = []
+
+ start = time()
+ model.train()
+
+ for ind, batch in enumerate(tqdm(train_batches)):
+ batch_to_torch(batch, self.device)
+ # token_ids, graph_ids, labels, class_weights, lengths = b
+ loss, p, r, f1 = train_step_finetune(
+ model=model, optimizer=optimizer, token_ids=batch['tok_ids'],
+ prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'],
+ labels=batch['tags'], lengths=batch['lens'],
+ extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'],
+ # class_weights=batch['class_weights'],
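+                        # fine-tuning of the CodeBERT encoder is enabled only after 60% of the epochs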
+ scorer=scorer, finetune=finetune and e / epochs > 0.6,
+ vocab_mapping=self.vocab_mapping
+ )
+ losses.append(loss.cpu().item())
+ ps.append(p)
+ rs.append(r)
+ f1s.append(f1)
+
+ self.summary_writer.add_scalar("Loss/Train", loss, global_step=e * num_train_batches + ind)
+ self.summary_writer.add_scalar("Precision/Train", p, global_step=e * num_train_batches + ind)
+ self.summary_writer.add_scalar("Recall/Train", r, global_step=e * num_train_batches + ind)
+ self.summary_writer.add_scalar("F1/Train", f1, global_step=e * num_train_batches + ind)
+
+ test_alosses = []
+ test_aps = []
+ test_ars = []
+ test_af1s = []
+
+ model.eval()
+
+ for ind, batch in enumerate(test_batches):
+ batch_to_torch(batch, self.device)
+ # token_ids, graph_ids, labels, class_weights, lengths = b
+ test_loss, test_p, test_r, test_f1 = test_step(
+ model=model, token_ids=batch['tok_ids'],
+ prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'],
+ labels=batch['tags'], lengths=batch['lens'],
+ extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'],
+ # class_weights=batch['class_weights'],
+ scorer=scorer, vocab_mapping=self.vocab_mapping
+ )
+
+ self.summary_writer.add_scalar("Loss/Test", test_loss, global_step=e * num_test_batches + ind)
+ self.summary_writer.add_scalar("Precision/Test", test_p, global_step=e * num_test_batches + ind)
+ self.summary_writer.add_scalar("Recall/Test", test_r, global_step=e * num_test_batches + ind)
+ self.summary_writer.add_scalar("F1/Test", test_f1, global_step=e * num_test_batches + ind)
+ test_alosses.append(test_loss.cpu().item())
+ test_aps.append(test_p)
+ test_ars.append(test_r)
+ test_af1s.append(test_f1)
+
+ epoch_time = time() - start
+
+ train_losses.append(float(sum(losses) / len(losses)))
+ train_f1s.append(float(sum(f1s) / len(f1s)))
+ test_losses.append(float(sum(test_alosses) / len(test_alosses)))
+ test_f1s.append(float(sum(test_af1s) / len(test_af1s)))
+
+ print(
+ f"Epoch: {e}, {epoch_time: .2f} s, Train Loss: {train_losses[-1]: .4f}, Train P: {sum(ps) / len(ps): .4f}, Train R: {sum(rs) / len(rs): .4f}, Train F1: {sum(f1s) / len(f1s): .4f}, "
+ f"Test loss: {test_losses[-1]: .4f}, Test P: {sum(test_aps) / len(test_aps): .4f}, Test R: {sum(test_ars) / len(test_ars): .4f}, Test F1: {test_f1s[-1]: .4f}")
+
+ if save_ckpt_fn is not None and float(test_f1s[-1]) > best_f1:
+ save_ckpt_fn()
+ best_f1 = float(test_f1s[-1])
+
+ scheduler.step(epoch=e)
+
+ except KeyboardInterrupt:
+ pass
+
+ return train_losses, train_f1s, test_losses, test_f1s
+
+ def create_summary_writer(self, path):
+ self.summary_writer = SummaryWriter(path)
+
+ def set_gpu(self):
+ # torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if self.gpu_id == -1:
+ self.use_cuda = False
+ self.device = "cpu"
+ else:
+ torch.cuda.set_device(self.gpu_id)
+ self.use_cuda = True
+ self.device = f"cuda:{self.gpu_id}"
+
+ def train_model(self):
+
+ print(f"\n\n{self.model_params}")
+ lr = self.model_params.pop("learning_rate")
+ lr_decay = self.model_params.pop("learning_rate_decay")
+ suffix_prefix_buckets = self.model_params.pop("suffix_prefix_buckets")
+
+ graph_emb = load_pkl_emb(self.graph_emb_path) if self.graph_emb_path is not None else None
+
+ train_batcher, test_batcher = self.get_dataloaders(None, graph_emb, suffix_prefix_buckets=suffix_prefix_buckets)
+
+ codebert_model = RobertaModel.from_pretrained("microsoft/codebert-base")
+ model = CodebertHybridModel(
+ codebert_model, graph_emb.e, padding_idx=train_batcher.graphpad, num_classes=train_batcher.num_classes(),
+ no_graph=self.no_graph
+ )
+ if self.use_cuda:
+ model.cuda()
+
+ trial_dir = os.path.join(self.output_dir, "codebert_" + str(datetime.now())).replace(":", "-").replace(" ", "_")
+ os.mkdir(trial_dir)
+ self.create_summary_writer(trial_dir)
+
+ def save_ckpt_fn():
+ checkpoint_path = os.path.join(trial_dir, "checkpoint")
+ torch.save(model, open(checkpoint_path, 'wb'))
+
+ train_losses, train_f1, test_losses, test_f1 = self.train(
+ model=model, train_batches=train_batcher, test_batches=test_batcher,
+ epochs=self.epochs, learning_rate=lr,
+ scorer=lambda pred, true: scorer(pred, true, train_batcher.tagmap, no_localization=self.no_localization),
+ learning_rate_decay=lr_decay, finetune=self.finetune, save_ckpt_fn=save_ckpt_fn,
+ no_localization=self.no_localization
+ )
+
+ # checkpoint_path = os.path.join(trial_dir, "checkpoint")
+ # model.save_weights(checkpoint_path)
+
+ metadata = {
+ "train_losses": train_losses,
+ "train_f1": train_f1,
+ "test_losses": test_losses,
+ "test_f1": test_f1,
+ "learning_rate": lr,
+ "learning_rate_decay": lr_decay,
+ "epochs": self.epochs,
+ "suffix_prefix_buckets": suffix_prefix_buckets,
+ "seq_len": self.seq_len,
+ "batch_size": self.batch_size,
+ "no_localization": self.no_localization
+ }
+
+ print("Maximum f1:", max(test_f1))
+
+ # write_config(trial_dir, params, extra_params={"suffix_prefix_buckets": suffix_prefix_buckets, "seq_len": seq_len})
+
+ metadata.update(self.model_params)
+
+ with open(os.path.join(trial_dir, "params.json"), "w") as metadata_sink:
+ metadata_sink.write(json.dumps(metadata, indent=4))
+
+ pickle.dump(train_batcher.tagmap, open(os.path.join(trial_dir, "tag_types.pkl"), "wb"))
+
+ # for ind, batch in enumerate(tqdm(batcher)):
+ # # token_ids, graph_ids, labels, class_weights, lengths = b
+ # token_ids = torch.LongTensor(batch["tok_ids"])
+ # lens = torch.LongTensor(batch["lens"])
+ #
+ # token_ids[token_ids == len(vocab_mapping)] = vocab_mapping[""]
+ #
+ # def get_length_mask(target, lens):
+ # mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None]
+ # return mask
+ #
+ # mask = get_length_mask(token_ids, lens)
+ # with torch.no_grad():
+ # embs = model(input_ids=token_ids, attention_mask=mask)
+ #
+ # for s_emb, s_repl in zip(embs.last_hidden_state, batch["replacements"]):
+ # unique_repls = set(list(s_repl))
+ # repls_for_ann = [r for r in unique_repls if r in typed_nodes]
+ #
+ # for r in repls_for_ann:
+ # position = s_repl.index(r)
+ # if position > 512:
+ # continue
+ # node_ids.append(r)
+ # embeddings.append(s_emb[position])
+ #
+ # all_embs = torch.stack(embeddings, dim=0).numpy()
+ # embedder = Embedder(dict(zip(node_ids, range(len(node_ids)))), all_embs)
+ # pickle.dump(embedder, open("codebert_embeddings.pkl", "wb"), fix_imports=False)
+ # print(node_ids)
+
+
+def main():
+ args = get_type_prediction_arguments()
+
+ # allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray',
+ # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', 'Type'}
+ if args.restrict_allowed:
+ allowed = {
+ 'str', 'Optional', 'int', 'Any', 'Union', 'bool', 'Callable', 'Dict', 'bytes', 'float', 'Description',
+ 'List', 'Sequence', 'Namespace', 'T', 'Type', 'object', 'HTTPServerRequest', 'Future', "Matcher"
+ }
+ else:
+ allowed = None
+
+ # train_data, test_data = read_data(
+ # open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True,
+ # include_only="entities",
+ # min_entity_count=args.min_entity_count, random_seed=args.random_seed
+ # )
+
+ from pathlib import Path
+ dataset_dir = Path(args.data_path).parent
+ train_data = filter_labels(
+ pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_train.pkl"), "rb")),
+ allowed=allowed
+ )
+ test_data = filter_labels(
+ pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_test.pkl"), "rb")),
+ allowed=allowed
+ )
+
+ trainer = CodeBertModelTrainer2(
+ train_data, test_data, params={"learning_rate": 1e-4, "learning_rate_decay": 0.99, "suffix_prefix_buckets": 1},
+ graph_emb_path=args.graph_emb_path, word_emb_path=args.word_emb_path,
+ output_dir=args.model_output, epochs=args.epochs, batch_size=args.batch_size, gpu_id=args.gpu,
+ finetune=args.finetune, trials=args.trials, seq_len=args.max_seq_len, no_localization=args.no_localization,
+ no_graph=args.no_graph
+ )
+ trainer.set_type_ann_edges(args.type_ann_edges)
+ trainer.train_model()
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/SourceCodeTools/nlp/entity/annotator/annotator_utils.py b/SourceCodeTools/nlp/entity/annotator/annotator_utils.py
index 3dfba4c6..d3e8a847 100644
--- a/SourceCodeTools/nlp/entity/annotator/annotator_utils.py
+++ b/SourceCodeTools/nlp/entity/annotator/annotator_utils.py
@@ -57,10 +57,22 @@ def to_offsets(body: str, entities: Iterable[Iterable], as_bytes=False, cum_lens
def adjust_offsets(offsets, amount):
+ """
+ Adjust offset by subtracting certain amount from the start and end positions
+ :param offsets: iterable with offsets
+ :param amount: adjustment amount
+ :return: list of adjusted offsets
+ """
return [(offset[0] - amount, offset[1] - amount, offset[2]) for offset in offsets]
def adjust_offsets2(offsets, amount):
+ """
+ Adjust offset by adding certain amount to the start and end positions
+ :param offsets: iterable with offsets
+ :param amount: adjustment amount
+ :return: list of adjusted offsets
+ """
return [(offset[0] + amount, offset[1] + amount, offset[2]) for offset in offsets]
diff --git a/SourceCodeTools/nlp/entity/apply_model.py b/SourceCodeTools/nlp/entity/apply_model.py
index d6d75e68..9ddd64bd 100644
--- a/SourceCodeTools/nlp/entity/apply_model.py
+++ b/SourceCodeTools/nlp/entity/apply_model.py
@@ -4,11 +4,17 @@
import os
import pickle
+import numpy as np
+from tqdm import tqdm
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
+
import tensorflow as tf
from SourceCodeTools.nlp.entity import parse_biluo
from SourceCodeTools.nlp.entity.tf_models.tf_model import TypePredictor
-from SourceCodeTools.nlp.entity.type_prediction import PythonBatcher, load_pkl_emb
+from SourceCodeTools.nlp.entity.type_prediction import PythonBatcher, load_pkl_emb, span_f1
from SourceCodeTools.nlp.entity.utils.data import read_data
@@ -42,14 +48,14 @@ def predict_one(model, input, tagmap):
return logits_to_annotations(
tf.math.argmax(model(
token_ids=input['tok_ids'], prefix_ids=input['prefix'], suffix_ids=input['suffix'],
- graph_ids=input['graph_ids'], training=False
+ graph_ids=input['graph_ids'], training=False, mask=tf.sequence_mask(input["lens"], input["tok_ids"].shape[1])
), axis=-1),
input['lens'],
tagmap
)
-def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=None, checkpoint_path=None, batch_size=1):
+def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=None, checkpoint_path=None):
graph_emb = load_pkl_emb(graph_emb_path)
word_emb = load_pkl_emb(word_emb_path)
@@ -61,10 +67,11 @@ def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=No
seq_len = params.pop("seq_len")
suffix_prefix_buckets = params.pop("suffix_prefix_buckets")
+ batch_size = params.pop("batch_size")
data_batcher = Batcher(
data, batch_size, seq_len=seq_len, wordmap=word_emb.ind, graphmap=graph_emb.ind, tagmap=tagmap,
- mask_unlabeled_declarations=True, class_weights=False, element_hash_size=suffix_prefix_buckets
+ mask_unlabeled_declarations=True, class_weights=False, element_hash_size=suffix_prefix_buckets, len_sort=True
)
params.pop("train_losses")
@@ -76,26 +83,96 @@ def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=No
params.pop("learning_rate_decay")
model = Model(
- word_emb, graph_emb, train_embeddings=False, num_classes=len(tagmap), seq_len=seq_len, **params
+ word_emb, graph_emb, train_embeddings=False, num_classes=len(tagmap), seq_len=seq_len,
+ suffix_prefix_buckets=suffix_prefix_buckets, **params
)
model.load_weights(os.path.join(checkpoint_path, "checkpoint"))
all_true = []
all_estimated = []
+ true_scoring = []
+ pred_scoring = []
- for batch in data_batcher:
+ for batch_ind, batch in enumerate(data_batcher):
true_annotations = logits_to_annotations(batch['tags'], batch['lens'], tagmap)
est_annotations = predict_one(model, batch, tagmap)
all_true.extend(true_annotations)
all_estimated.extend(est_annotations)
+ for s_ind in range(len(batch["lens"])):
+ for ent in true_annotations[s_ind]:
+ true_scoring.append((batch_ind, s_ind, ent))
+ for ent in est_annotations[s_ind]:
+ pred_scoring.append((batch_ind, s_ind, ent))
+
+ precision, recall, f1 = span_f1(set(pred_scoring), set(true_scoring))
+
+ with open(os.path.join(args.checkpoint_path, "scores.txt"), "w") as scores_sink:
+ scores_str = f"Precision: {precision: .2f}, Recall: {recall: .2f}, f1: {f1: .2f}"
+ print(scores_str)
+ scores_sink.write(f"{scores_str}\n")
+
from SourceCodeTools.nlp.entity.entity_render import render_annotations
- html = render_annotations(zip(data, all_estimated, all_true))
- with open("render.html", "w") as render:
+ html = render_annotations(zip(data_batcher.data, all_estimated, all_true))
+ with open(os.path.join(args.checkpoint_path, "render.html"), "w") as render:
render.write(html)
+ estimate_confusion(all_estimated, all_true)
+
+
+def estimate_confusion(pred, true):
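+    """
+    Build a row-normalized confusion matrix over entity labels, using only spans where a predicted
+    span exactly matches a true span, and save it as a heatmap (confusion.png) in the checkpoint directory.
+    """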
+ pred_filtered = []
+ true_filtered = []
+ for p, t in zip(pred, true):
+ for e in p:
+ e_span = e[:2]
+ for e_t in t:
+ t_span = e_t[:2]
+ if e_span == t_span:
+ pred_filtered.append(e[2])
+ true_filtered.append(e_t[2])
+ break
+
+ labels = sorted(list(set(true_filtered + pred_filtered)))
+ label2ind = dict(zip(labels, range(len(labels))))
+
+ confusion = np.zeros((len(labels), len(labels)))
+
+ for pred, true in zip(pred_filtered, true_filtered):
+ confusion[label2ind[true], label2ind[pred]] += 1
+
+ norm = np.array([x if x != 0 else 1. for x in np.sum(confusion, axis=1)]).reshape(-1,1)
+ confusion /= norm
+
+ import matplotlib.pyplot as plt
+
+ fig, ax = plt.subplots(figsize=(45,45))
+ im = ax.imshow(confusion)
+
+ # We want to show all ticks...
+ ax.set_xticks(np.arange(len(labels)))
+ ax.set_yticks(np.arange(len(labels)))
+ # ... and label them with the respective list entries
+ ax.set_xticklabels(labels)
+ ax.set_yticklabels(labels)
+
+ # Rotate the tick labels and set their alignment.
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
+ rotation_mode="anchor")
+
+ # Loop over data dimensions and create text annotations.
+ for i in range(len(labels)):
+ for j in range(len(labels)):
+ text = ax.text(j, i, f"{confusion[i, j]: .2f}",
+ ha="center", va="center", color="w")
+
+ ax.set_title("Confusion matrix for Python type prediction")
+ fig.tight_layout()
+ # plt.show()
+ plt.savefig(os.path.join(args.checkpoint_path, "confusion.png"))
+
if __name__ == "__main__":
import argparse
@@ -108,16 +185,18 @@ def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=No
parser.add_argument('--word_emb_path', dest='word_emb_path', default=None,
help='Path to the file with edges')
parser.add_argument('checkpoint_path', default=None, help='')
+ parser.add_argument('--random_seed', dest='random_seed', default=None, type=int,
+ help='')
args = parser.parse_args()
train_data, test_data = read_data(
open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True,
- include_only="entities",
- min_entity_count=5
+ include_only="entities", random_seed=args.random_seed,
+ min_entity_count=3
)
apply_to_dataset(
test_data, PythonBatcher, TypePredictor, graph_emb_path=args.graph_emb_path, word_emb_path=args.word_emb_path,
- checkpoint_path=args.checkpoint_path, batch_size=args.batch_size
+ checkpoint_path=args.checkpoint_path
)
diff --git a/SourceCodeTools/nlp/entity/entity_render.py b/SourceCodeTools/nlp/entity/entity_render.py
index 9b708142..8dcf73ed 100644
--- a/SourceCodeTools/nlp/entity/entity_render.py
+++ b/SourceCodeTools/nlp/entity/entity_render.py
@@ -1,5 +1,5 @@
import spacy
-from SourceCodeTools.nlp.entity.util import inject_tokenizer
+from SourceCodeTools.nlp import create_tokenizer
html_template = """
@@ -57,7 +57,7 @@ def annotate(doc, entities):
def render_annotations(annotations):
- nlp = inject_tokenizer(spacy.blank("en"))
+ nlp = create_tokenizer("spacy")
entries = ""
for annotation in annotations:
text, predicted, annotated = annotation
diff --git a/SourceCodeTools/nlp/entity/map_args_to_mentions.py b/SourceCodeTools/nlp/entity/map_args_to_mentions.py
new file mode 100644
index 00000000..2fabba9c
--- /dev/null
+++ b/SourceCodeTools/nlp/entity/map_args_to_mentions.py
@@ -0,0 +1,49 @@
+import argparse
+import json
+import pickle
+from os.path import join
+
+from SourceCodeTools.code.data.sourcetrail.Dataset import load_data
+from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("working_directory")
+ parser.add_argument("output")
+ args = parser.parse_args()
+
+ nodes, edges = load_data(join(args.working_directory, "nodes.bz2"), join(args.working_directory, "edges.bz2"))
+ type_annotated = set(unpersist(join(args.working_directory, "type_annotations.bz2"))["src"].tolist())
+ arguments = set(nodes.query("type == 'arg'")["id"].tolist())
+ mentions = set(nodes.query("type == 'mention'")["id"].tolist())
+
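+    # keep only edges that go from a mention node to an argument node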
+ edges["in_mentions"] = edges["src"].apply(lambda src: src in mentions)
+
+ edges["in_args"] = edges["dst"].apply(lambda dst: dst in arguments)
+
+ edges = edges.query("in_mentions == True and in_args == True")
+
+ mapping = {}
+ for src, dst in edges[["src", "dst"]].values:
+ if dst in mapping:
+            print(f"Warning: argument node {dst} is linked to more than one mention")
+ mapping[dst] = src
+
+ with open(args.output, "w") as sink:
+ with open(join(args.working_directory, "function_annotations.jsonl")) as fa:
+ for line in fa:
+ entry = json.loads(line)
+ new_repl = [[s, e, int(mapping.get(r, r))] for s, e, r in entry["replacements"]]
+ entry["replacements"] = new_repl
+
+ sink.write(f"{json.dumps(entry)}\n")
+
+
+
+ # pickle.dump(mapping, open(args.output, "wb"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SourceCodeTools/nlp/entity/split_dataset.py b/SourceCodeTools/nlp/entity/split_dataset.py
new file mode 100644
index 00000000..2aa75596
--- /dev/null
+++ b/SourceCodeTools/nlp/entity/split_dataset.py
@@ -0,0 +1,34 @@
+import pickle
+from collections import Counter
+
+from SourceCodeTools.nlp.entity.type_prediction import get_type_prediction_arguments
+from SourceCodeTools.nlp.entity.utils.data import read_data
+
+
+def get_all_annotations(dataset):
+ ann = []
+ for _, annotations in dataset:
+ for _, _, e in annotations["entities"]:
+ ann.append(e)
+ return ann
+
+
+if __name__ == "__main__":
+
+ args = get_type_prediction_arguments()
+
+ train_data, test_data = read_data(
+ open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, include_only="entities",
+ min_entity_count=args.min_entity_count, random_seed=args.random_seed
+ )
+
+ pickle.dump(train_data, open("type_prediction_dataset_no_defaults_train.pkl", "wb"))
+ pickle.dump(test_data, open("type_prediction_dataset_no_defaults_test.pkl", "wb"))
+
+ ent_counts = Counter(get_all_annotations(train_data)) | Counter(get_all_annotations(test_data))
+
+ with open("type_prediction_dataset_argument_annotations_counts.txt", "w") as sink:
+ for ent, count in ent_counts.most_common():
+ sink.write(f"{ent}\t{count}\n")
+
\ No newline at end of file
diff --git a/SourceCodeTools/nlp/entity/tf_models/params.py b/SourceCodeTools/nlp/entity/tf_models/params.py
index f8560eeb..d8a1b8ad 100644
--- a/SourceCodeTools/nlp/entity/tf_models/params.py
+++ b/SourceCodeTools/nlp/entity/tf_models/params.py
@@ -115,4 +115,4 @@ def flatten(params):
att_params = list(
map(flatten, ParameterGrid(att_params_grids))
-)
\ No newline at end of file
+)
diff --git a/SourceCodeTools/nlp/entity/tf_models/tf_model.py b/SourceCodeTools/nlp/entity/tf_models/tf_model.py
index 4a1ba2b9..100d52ec 100644
--- a/SourceCodeTools/nlp/entity/tf_models/tf_model.py
+++ b/SourceCodeTools/nlp/entity/tf_models/tf_model.py
@@ -1,17 +1,21 @@
# import tensorflow as tf
# import sys
# from gensim.models import Word2Vec
+import logging
+from time import time
+
import numpy as np
# from collections import Counter
from scipy.linalg import toeplitz
# from gensim.models import KeyedVectors
-from copy import copy
+from copy import copy, deepcopy
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Input, Embedding, concatenate
from tensorflow.keras import Model
+from tensorflow.keras.layers import Layer
# from tensorflow.keras import regularizers
# from spacy.gold import offsets_from_biluo_tags
@@ -21,194 +25,45 @@
# https://github.com/dhiraa/tener/tree/master/src/tener/models
# https://arxiv.org/pdf/1903.07785v1.pdf
# https://github.com/tensorflow/models/tree/master/research/cvt_text/model
-
-
-class DefaultEmbedding(Model):
- """
- Creates an embedder that provides the default value for the index -1. The default value is a zero-vector
- """
- def __init__(self, init_vectors=None, shape=None, trainable=True):
- super(DefaultEmbedding, self).__init__()
-
- if init_vectors is not None:
- self.embs = tf.Variable(init_vectors, dtype=tf.float32,
- trainable=trainable, name="default_embedder_var")
- shape = init_vectors.shape
- else:
- # TODO
- # the default value is no longer constant. need to replace this with a standard embedder
- self.embs = tf.Variable(tf.random.uniform(shape=(shape[0], shape[1]), dtype=tf.float32),
- name="default_embedder_pad")
- # self.pad = tf.zeros(shape=(1, init_vectors.shape[1]), name="default_embedder_pad")
- # self.pad = tf.random.uniform(shape=(1, init_vectors.shape[1]), name="default_embedder_pad")
- self.pad = tf.Variable(tf.random.uniform(shape=(1, shape[1]), dtype=tf.float32),
- name="default_embedder_pad")
-
-
- def __call__(self, ids):
- emb_matr = tf.concat([self.embs, self.pad], axis=0)
- return tf.nn.embedding_lookup(params=emb_matr, ids=ids)
- # return tf.expand_dims(tf.nn.embedding_lookup(params=self.emb_matr, ids=ids), axis=3)
-
-
-class PositionalEncoding(Model):
- def __init__(self, seq_len, pos_emb_size):
- """
- Create positional embedding matrix for tokens. Currently not using because it results in N^2 computational
- complexity. Should move this functionality to batch preparation.
- :param seq_len: maximum sequence length
- :param pos_emb_size: the dimensionality of positional embeddings
- """
- super(PositionalEncoding, self).__init__()
-
- positions = list(range(seq_len * 2))
- position_splt = positions[:seq_len]
- position_splt.reverse()
- self.position_encoding = tf.constant(toeplitz(position_splt, positions[seq_len:]),
- dtype=tf.int32,
- name="position_encoding")
- # self.position_embedding = tf.random.uniform(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32)
- self.position_embedding = tf.Variable(tf.random.uniform(shape=(seq_len * 2, pos_emb_size), dtype=tf.float32),
- name="position_embedding")
- # self.position_embedding = tf.Variable(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32)
-
- def __call__(self):
- # return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup")
- return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup")
-
-
-class TextCnnLayer(Model):
- """
-
- """
- def __init__(self, out_dim, kernel_shape, activation=None):
- super(TextCnnLayer, self).__init__()
-
- self.kernel_shape = kernel_shape
- self.out_dim = out_dim
-
- self.textConv = Conv2D(filters=out_dim, kernel_size=kernel_shape,
- activation=activation, data_format='channels_last')
-
- padding_size = (self.kernel_shape[0] - 1) // 2
- assert padding_size * 2 + 1 == self.kernel_shape[0]
- self.pad_constant = tf.constant([[0, 0], [padding_size, padding_size], [0, 0], [0, 0]])
-
- def __call__(self, x):
- padded = tf.pad(x, self.pad_constant)
- # emb_sent_exp = tf.expand_dims(input, axis=3)
- convolve = self.textConv(padded)
- return tf.squeeze(convolve, axis=-2)
-
-
-class TextCnn(Model):
- """
- TextCnn model for classifying tokens in a sequence. The model uses following pipeline:
-
- token_embeddings (provided from outside) ->
- several convolutional layers, get representations for all tokens ->
- pass representation for all tokens through a dense network ->
- classify each token
- """
- def __init__(self, input_size, h_sizes, seq_len,
- pos_emb_size, cnn_win_size, dense_size, num_classes,
- activation=None, dense_activation=None, drop_rate=0.2):
- """
-
- :param input_size: dimensionality of input embeddings
- :param h_sizes: sizes of hidden CNN layers, internal dimensionality of token embeddings
- :param seq_len: maximum sequence length
- :param pos_emb_size: dimensionality of positional embeddings
- :param cnn_win_size: width of cnn window
- :param dense_size: number of unius in dense network
- :param num_classes: number of output units
- :param activation: activation for cnn
- :param dense_activation: activation for dense layers
- :param drop_rate: dropout rate for dense network
- """
- super(TextCnn, self).__init__()
-
- self.seq_len = seq_len
- self.h_sizes = h_sizes
- self.dense_size = dense_size
- self.num_classes = num_classes
-
- def infer_kernel_sizes(h_sizes):
- """
- Compute kernel sizes from the desired dimensionality of hidden layers
- :param h_sizes:
- :return:
- """
- kernel_sizes = copy(h_sizes)
- kernel_sizes.pop(-1) # pop last because it is the output of the last CNN layer
- kernel_sizes.insert(0, input_size) # the first kernel size should be (cnn_win_size, input_size)
- kernel_sizes = [(cnn_win_size, ks) for ks in kernel_sizes]
- return kernel_sizes
-
- kernel_sizes = infer_kernel_sizes(h_sizes)
-
- self.layers_tok = [ TextCnnLayer(out_dim=h_size, kernel_shape=kernel_size, activation=activation)
- for h_size, kernel_size in zip(h_sizes, kernel_sizes)]
-
- # self.layers_pos = [TextCnnLayer(out_dim=h_size, kernel_shape=(cnn_win_size, pos_emb_size), activation=activation)
- # for h_size, _ in zip(h_sizes, kernel_sizes)]
-
- # self.positional = PositionalEncoding(seq_len=seq_len, pos_emb_size=pos_emb_size)
-
- if dense_activation is None:
- dense_activation = activation
-
- # self.attention = tfa.layers.MultiHeadAttention(head_size=200, num_heads=1)
-
- self.dense_1 = Dense(dense_size, activation=dense_activation)
- self.dropout_1 = tf.keras.layers.Dropout(rate=drop_rate)
- self.dense_2 = Dense(num_classes, activation=None) # logits
- self.dropout_2 = tf.keras.layers.Dropout(rate=drop_rate)
-
- def __call__(self, embs, training=True):
-
- temp_cnn_emb = embs # shape (?, seq_len, input_size)
-
- # pass embeddings through several CNN layers
- for l in self.layers_tok:
- temp_cnn_emb = l(tf.expand_dims(temp_cnn_emb, axis=3)) # shape (?, seq_len, h_size)
-
- # TODO
- # simplify to one CNN and one attention
-
- # pos_cnn = self.positional()
- # for l in self.layers_pos:
- # pos_cnn = l(tf.expand_dims(pos_cnn, axis=3))
- #
- # cnn_pool_feat = []
- # for i in range(self.seq_len):
- # # slice tensor for the line that corresponds to the current position in the sentence
- # position_features = tf.expand_dims(pos_cnn[i, ...], axis=0, name="exp_dim_%d" % i)
- # # convolution without activation can be combined later, hence: temp_cnn_emb + position_features
- # cnn_pool_feat.append(
- # tf.expand_dims(tf.nn.tanh(tf.reduce_max(temp_cnn_emb + position_features, axis=1)), axis=1))
- # # cnn_pool_feat.append(
- # # tf.expand_dims(tf.nn.tanh(tf.reduce_max(tf.concat([temp_cnn_emb, position_features], axis=-1), axis=1)), axis=1))
- #
- # cnn_pool_features = tf.concat(cnn_pool_feat, axis=1)
- cnn_pool_features = temp_cnn_emb
-
- # cnn_pool_features = self.attention([cnn_pool_features, cnn_pool_features])
-
- # token_features = self.dropout_1(
- # tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1]))
- # , training=training)
-
- # reshape before passing through a dense network
- token_features = tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1])) # shape (? * seq_len, h_size[-1])
-
- # local_h2 = self.dropout_2(
- # self.dense_1(token_features)
- # , training=training)
- local_h2 = self.dense_1(token_features) # shape (? * seq_len, dense_size)
- tag_logits = self.dense_2(local_h2) # shape (? * seq_len, num_classes)
-
- return tf.reshape(tag_logits, (-1, self.seq_len, self.num_classes)) # reshape back, shape (?, seq_len, num_classes)
+from tensorflow_addons.layers import MultiHeadAttention
+from tqdm import tqdm
+
+from SourceCodeTools.models.nlp.TFDecoder import ConditionalAttentionDecoder, FlatDecoder
+from SourceCodeTools.models.nlp.TFEncoder import DefaultEmbedding, TextCnnEncoder, GRUEncoder
+
+
+class T5Encoder(Model):
+ def __init__(self):
+ super(T5Encoder, self).__init__()
+ from transformers.models.t5.modeling_tf_t5 import T5Config, TFT5MainLayer
+ self.adapter = Dense(40, activation=tf.nn.relu)
+        self.t5_config = T5Config(d_model=40, d_ff=40, num_layers=1, num_decoder_layers=1, num_heads=1, is_encoder_decoder=False, d_kv=40)
+        self.encoder = TFT5MainLayer(self.t5_config)
+
+ def call(self, embs, training=True, mask=None):
+ from transformers.modeling_tf_utils import input_processing
+ inputs = input_processing(
+ func=self.call,
+ config=self.t5_config,
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ encoder_outputs=None,
+ past_key_values=None,
+ inputs_embeds=self.adapter(embs),
+ decoder_inputs_embeds=None,
+ use_cache=False,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ training=training,
+ kwargs_call={},
+ )
+ encoded = self.encoder(inputs["inputs"], inputs_embeds=inputs["inputs_embeds"], training=training)
+ return encoded["last_hidden_state"]
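The `T5Encoder` above drives a single-layer T5 stack through `inputs_embeds`, with the `Dense(40)` adapter projecting whatever embedding width arrives from the concatenated token/prefix/suffix vectors down to the configured `d_model=40`. A minimal smoke test, assuming TensorFlow 2.x and a `transformers` version that still exposes `TFT5MainLayer` and `input_processing` (shapes below are made up):

```python
import tensorflow as tf

# Hypothetical batch: 2 sequences, 16 tokens, 300-dim input embeddings.
encoder = T5Encoder()
embs = tf.random.uniform((2, 16, 300))
encoded = encoder(embs, training=False)
print(encoded.shape)  # expected (2, 16, 40): the adapter maps 300 -> d_model=40
```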
class TypePredictor(Model):
@@ -222,9 +77,10 @@ class TypePredictor(Model):
def __init__(self, tok_embedder, graph_embedder, train_embeddings=False,
h_sizes=None, dense_size=100, num_classes=None,
seq_len=100, pos_emb_size=30, cnn_win_size=3,
- crf_transitions=None, suffix_prefix_dims=50, suffix_prefix_buckets=1000):
+ crf_transitions=None, suffix_prefix_dims=50, suffix_prefix_buckets=1000,
+ no_graph=False):
"""
- Initialize TypePredictor. Model initializes embedding layers and then passes embeddings to TextCnn model
+ Initialize TypePredictor. Model initializes embedding layers and then passes embeddings to TextCnnEncoder model
:param tok_embedder: Embedder for tokens
:param graph_embedder: Embedder for graph nodes
:param train_embeddings: whether to finetune embeddings
@@ -255,31 +111,38 @@ def __init__(self, tok_embedder, graph_embedder, train_embeddings=False,
self.prefix_emb = DefaultEmbedding(shape=(suffix_prefix_buckets, suffix_prefix_dims))
self.suffix_emb = DefaultEmbedding(shape=(suffix_prefix_buckets, suffix_prefix_dims))
- # self.tok_emb = Embedding(input_dim=tok_embedder.e.shape[0],
- # output_dim=tok_embedder.e.shape[1],
- # weights=tok_embedder.e, trainable=train_embeddings,
- # mask_zero=True)
- #
- # self.graph_emb = Embedding(input_dim=graph_embedder.e.shape[0],
- # output_dim=graph_embedder.e.shape[1],
- # weights=graph_embedder.e, trainable=train_embeddings,
- # mask_zero=True)
-
# compute final embedding size after concatenation
- input_dim = tok_embedder.e.shape[1] + suffix_prefix_dims * 2 + graph_embedder.e.shape[1]
-
- self.text_cnn = TextCnn(input_size=input_dim, h_sizes=h_sizes,
- seq_len=seq_len, pos_emb_size=pos_emb_size,
- cnn_win_size=cnn_win_size, dense_size=dense_size,
- num_classes=num_classes, activation=tf.nn.relu,
- dense_activation=tf.nn.tanh)
+ input_dim = tok_embedder.e.shape[1] + suffix_prefix_dims * 2 #+ graph_embedder.e.shape[1]
+
+ if cnn_win_size % 2 == 0:
+ cnn_win_size += 1
+ logging.info(f"Window size should be odd. Setting to {cnn_win_size}")
+
+ self.encoder = TextCnnEncoder(
+ # input_size=input_dim,
+ h_sizes=h_sizes,
+ seq_len=seq_len, pos_emb_size=pos_emb_size,
+ cnn_win_size=cnn_win_size, dense_size=dense_size,
+ out_dim=input_dim, activation=tf.nn.relu,
+ dense_activation=tf.nn.tanh)
+ # self.encoder = GRUEncoder(input_dim=input_dim, out_dim=input_dim, num_layers=1, dropout=0.1)
+ # self.encoder = T5Encoder()
+
+ # self.decoder = ConditionalAttentionDecoder(
+ # input_dim, out_dim=num_classes, num_layers=1, num_heads=1,
+ # ff_hidden=100, target_vocab_size=num_classes, maximum_position_encoding=self.seq_len
+ # )
+ self.decoder = FlatDecoder(out_dims=num_classes)
self.crf_transition_params = None
+ self.supports_masking = True
+ self.use_graph = not no_graph
- def __call__(self, token_ids, prefix_ids, suffix_ids, graph_ids, training=True):
+ # @tf.function
+ def __call__(self, token_ids, prefix_ids, suffix_ids, graph_ids, target=None, training=False, mask=None):
"""
- Inference
+        Forward pass. Returns logits for token classes.
:param token_ids: ids for tokens, shape (?, seq_len)
:param prefix_ids: ids for prefixes, shape (?, seq_len)
:param suffix_ids: ids for suffixes, shape (?, seq_len)
@@ -287,22 +150,35 @@ def __call__(self, token_ids, prefix_ids, suffix_ids, graph_ids, training=True):
:param training: whether to finetune embeddings
:return: logits for token classes, shape (?, seq_len, num_classes)
"""
+ assert mask is not None, "Mask is required"
+
tok_emb = self.tok_emb(token_ids)
- graph_emb = self.graph_emb(graph_ids)
prefix_emb = self.prefix_emb(prefix_ids)
suffix_emb = self.suffix_emb(suffix_ids)
embs = tf.concat([tok_emb,
- graph_emb,
+ # graph_emb,
prefix_emb,
suffix_emb], axis=-1)
- logits = self.text_cnn(embs, training=training)
+ encoded = self.encoder(embs, training=training, mask=mask)
+ # if target is None:
+ # logits = self.decoder.seq_decode(encoded, training=training, mask=mask)
+ # else:
+ if self.use_graph:
+ graph_emb = self.graph_emb(graph_ids)
+ encoded = tf.concat([encoded, graph_emb], axis=-1)
+ logits, _ = self.decoder((encoded, target), training=training, mask=mask) # consider sending input instead of target
return logits
+ def compute_mask(self, inputs, mask=None):
+ mask = self.encoder.compute_mask(None, mask=mask)
+ return self.decoder.compute_mask(None, mask=mask)
- def loss(self, logits, labels, lengths, class_weights=None, extra_mask=None):
+
+ # @tf.function
+ def loss(self, logits, labels, mask, class_weights=None, extra_mask=None):
"""
Compute cross-entropy loss for each meaningful tokens. Mask padded tokens.
:param logits: shape (?, seq_len, num_classes)
@@ -313,7 +189,7 @@ def loss(self, logits, labels, lengths, class_weights=None, extra_mask=None):
:return: average cross-entropy loss
"""
losses = tf.nn.softmax_cross_entropy_with_logits(tf.one_hot(labels, depth=logits.shape[-1]), logits, axis=-1)
- seq_mask = tf.sequence_mask(lengths, self.seq_len)
+        seq_mask = mask  # was: tf.sequence_mask(lengths, self.seq_len) (logits._keras_mask is another option)
if extra_mask is not None:
seq_mask = tf.math.logical_and(seq_mask, extra_mask)
if class_weights is None:
@@ -329,28 +205,35 @@ def loss(self, logits, labels, lengths, class_weights=None, extra_mask=None):
return loss
- def score(self, logits, labels, lengths, scorer=None, extra_mask=None):
+ def score(self, logits, labels, mask, scorer=None, extra_mask=None):
"""
Compute precision, recall and f1 scores using the provided scorer function
:param logits: shape (?, seq_len, num_classes)
:param labels: ids of token labels, shape (?, seq_len)
-        :param lengths: tensor of actual sentence lengths, shape (?,)
+        :param mask: boolean mask of valid (non-padded) positions, shape (?, seq_len)
- :param scorer: scorer function, takes `pred_labels` and `true_labels` as aguments
+ :param scorer: scorer function, takes `pred_labels` and `true_labels` as arguments
:param extra_mask: mask for hiding some of the token labels, not counting them towards the score, shape (?, seq_len)
:return:
"""
- mask = tf.sequence_mask(lengths, self.seq_len)
+ # mask = logits._keras_mask # tf.sequence_mask(lengths, self.seq_len)
if extra_mask is not None:
mask = tf.math.logical_and(mask, extra_mask)
true_labels = tf.boolean_mask(labels, mask)
argmax = tf.math.argmax(logits, axis=-1)
estimated_labels = tf.cast(tf.boolean_mask(argmax, mask), tf.int32)
- p, r, f1 = scorer(estimated_labels.numpy(), true_labels.numpy())
+ p, r, f1 = scorer(to_numpy(estimated_labels), to_numpy(true_labels))
return p, r, f1
+def to_numpy(tensor):
+ if hasattr(tensor, "numpy"):
+ return tensor.numpy()
+ else:
+ return tf.make_ndarray(tf.make_tensor_proto(tensor))
+
+
# def estimate_crf_transitions(batches, n_tags):
# transitions = []
# for _, _, labels, lengths in batches:
@@ -359,7 +242,7 @@ def score(self, logits, labels, lengths, scorer=None, extra_mask=None):
#
# return np.stack(transitions, axis=0).mean(axis=0)
-
+# @tf.function
def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids, labels, lengths,
extra_mask=None, class_weights=None, scorer=None, finetune=False):
"""
@@ -379,9 +262,11 @@ def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids,
:return: values for loss, precision, recall and f1-score
"""
with tf.GradientTape() as tape:
- logits = model(token_ids, prefix, suffix, graph_ids, training=True)
- loss = model.loss(logits, labels, lengths, class_weights=class_weights, extra_mask=extra_mask)
- p, r, f1 = model.score(logits, labels, lengths, scorer=scorer, extra_mask=extra_mask)
+ seq_mask = tf.sequence_mask(lengths, token_ids.shape[1])
+ logits = model(token_ids, prefix, suffix, graph_ids, target=None, training=True, mask=seq_mask)
+ loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask)
+ # token_acc = tf.reduce_sum(tf.cast(tf.argmax(logits, axis=-1) == labels, tf.float32)) / (token_ids.shape[0] * token_ids.shape[1])
+ p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask)
gradients = tape.gradient(loss, model.trainable_variables)
if not finetune:
# do not update embeddings
@@ -393,7 +278,10 @@ def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids,
return loss, p, r, f1
# @tf.function
-def test_step(model, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None):
+def test_step(
+ model, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None,
+ no_localization=False
+):
"""
:param model: TypePrediction model instance
@@ -408,14 +296,44 @@ def test_step(model, token_ids, prefix, suffix, graph_ids, labels, lengths, extr
:param scorer: scorer function, takes `pred_labels` and `true_labels` as aguments
:return: values for loss, precision, recall and f1-score
"""
- logits = model(token_ids, prefix, suffix, graph_ids, training=False)
- loss = model.loss(logits, labels, lengths, class_weights=class_weights, extra_mask=extra_mask)
- p, r, f1 = model.score(logits, labels, lengths, scorer=scorer, extra_mask=extra_mask)
+ seq_mask = tf.sequence_mask(lengths, token_ids.shape[1])
+ logits = model(token_ids, prefix, suffix, graph_ids, target=None, training=False, mask=seq_mask)
+ loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask)
+ p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask)
return loss, p, r, f1
-def train(model, train_batches, test_batches, epochs, report_every=10, scorer=None, learning_rate=0.01, learning_rate_decay=1., finetune=False, summary_writer=None):
+def test(model, test_batches, scorer=None):
+ test_alosses = []
+ test_aps = []
+ test_ars = []
+ test_af1s = []
+
+ for ind, batch in enumerate(test_batches):
+ # token_ids, graph_ids, labels, class_weights, lengths = b
+ test_loss, test_p, test_r, test_f1 = test_step(
+ model=model, token_ids=batch['tok_ids'],
+ prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'],
+ labels=batch['tags'], lengths=batch['lens'], extra_mask=batch['hide_mask'],
+ # class_weights=batch['class_weights'],
+ scorer=scorer
+ )
+
+ test_alosses.append(test_loss)
+ test_aps.append(test_p)
+ test_ars.append(test_r)
+ test_af1s.append(test_f1)
+
+ def avg(arr):
+ return sum(arr) / len(arr)
+
+ return avg(test_alosses), avg(test_aps), avg(test_ars), avg(test_af1s)
+
+def train(
+ model, train_batches, test_batches, epochs, report_every=10, scorer=None, learning_rate=0.01,
+ learning_rate_decay=1., finetune=False, summary_writer=None, save_ckpt_fn=None, no_localization=False
+):
assert summary_writer is not None
@@ -430,6 +348,8 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No
num_train_batches = len(train_batches)
num_test_batches = len(test_batches)
+ best_f1 = 0.
+
try:
with summary_writer.as_default():
@@ -440,12 +360,15 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No
rs = []
f1s = []
- for ind, batch in enumerate(train_batches):
+ start = time()
+
+ for ind, batch in enumerate(tqdm(train_batches)):
# token_ids, graph_ids, labels, class_weights, lengths = b
loss, p, r, f1 = train_step_finetune(
model=model, optimizer=optimizer, token_ids=batch['tok_ids'],
prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'],
- labels=batch['tags'], lengths=batch['lens'], extra_mask=batch['hide_mask'],
+ labels=batch['tags'], lengths=batch['lens'],
+ extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'],
# class_weights=batch['class_weights'],
scorer=scorer, finetune=finetune and e/epochs > 0.6
)
@@ -459,12 +382,18 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No
tf.summary.scalar("Recall/Train", r, step=e * num_train_batches + ind)
tf.summary.scalar("F1/Train", f1, step=e * num_train_batches + ind)
+ test_alosses = []
+ test_aps = []
+ test_ars = []
+ test_af1s = []
+
for ind, batch in enumerate(test_batches):
# token_ids, graph_ids, labels, class_weights, lengths = b
test_loss, test_p, test_r, test_f1 = test_step(
model=model, token_ids=batch['tok_ids'],
prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'],
- labels=batch['tags'], lengths=batch['lens'], extra_mask=batch['hide_mask'],
+ labels=batch['tags'], lengths=batch['lens'],
+ extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'],
# class_weights=batch['class_weights'],
scorer=scorer
)
@@ -473,14 +402,24 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No
tf.summary.scalar("Precision/Test", test_p, step=e * num_test_batches + ind)
tf.summary.scalar("Recall/Test", test_r, step=e * num_test_batches + ind)
tf.summary.scalar("F1/Test", test_f1, step=e * num_test_batches + ind)
+ test_alosses.append(test_loss)
+ test_aps.append(test_p)
+ test_ars.append(test_r)
+ test_af1s.append(test_f1)
- print(f"Epoch: {e}, Train Loss: {sum(losses) / len(losses)}, Train P: {sum(ps) / len(ps)}, Train R: {sum(rs) / len(rs)}, Train F1: {sum(f1s) / len(f1s)}, "
- f"Test loss: {test_loss}, Test P: {test_p}, Test R: {test_r}, Test F1: {test_f1}")
+ epoch_time = time() - start
train_losses.append(float(sum(losses) / len(losses)))
train_f1s.append(float(sum(f1s) / len(f1s)))
- test_losses.append(float(test_loss))
- test_f1s.append(float(test_f1))
+ test_losses.append(float(sum(test_alosses) / len(test_alosses)))
+ test_f1s.append(float(sum(test_af1s) / len(test_af1s)))
+
+ print(f"Epoch: {e}, {epoch_time: .2f} s, Train Loss: {train_losses[-1]: .4f}, Train P: {sum(ps) / len(ps): .4f}, Train R: {sum(rs) / len(rs): .4f}, Train F1: {sum(f1s) / len(f1s): .4f}, "
+ f"Test loss: {test_losses[-1]: .4f}, Test P: {sum(test_aps) / len(test_aps): .4f}, Test R: {sum(test_ars) / len(test_ars): .4f}, Test F1: {test_f1s[-1]: .4f}")
+
+ if save_ckpt_fn is not None and float(test_f1s[-1]) > best_f1:
+ save_ckpt_fn()
+ best_f1 = float(test_f1s[-1])
lr.assign(lr * learning_rate_decay)
diff --git a/SourceCodeTools/nlp/entity/type_prediction.py b/SourceCodeTools/nlp/entity/type_prediction.py
index b15851b6..3394d0ef 100644
--- a/SourceCodeTools/nlp/entity/type_prediction.py
+++ b/SourceCodeTools/nlp/entity/type_prediction.py
@@ -1,25 +1,32 @@
from __future__ import unicode_literals, print_function
import json
+import logging
import os
import pickle
from copy import copy
from datetime import datetime
+# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
+
+from pathlib import Path
+
import tensorflow
from SourceCodeTools.nlp.batchers import PythonBatcher
from SourceCodeTools.nlp.entity import parse_biluo
from SourceCodeTools.nlp.entity.tf_models.params import cnn_params
-from SourceCodeTools.nlp.entity.utils import get_unique_entities
+from SourceCodeTools.nlp.entity.utils import get_unique_entities, overlap
from SourceCodeTools.nlp.entity.utils.data import read_data
def load_pkl_emb(path):
"""
-
- :param path:
- :return:
+    Load graph embeddings from a pickle file. Embeddings are stored in an Embedder object or in a list of Embedders; the
+    last embedder in the list is returned.
+    :param path: path to graph embeddings stored as an Embedder pickle
+ :return: Embedder object
"""
embedder = pickle.load(open(path, "rb"))
if isinstance(embedder, list):
@@ -27,7 +34,41 @@ def load_pkl_emb(path):
return embedder
-def scorer(pred, labels, tagmap, eps=1e-8):
+def compute_precision_recall_f1(tp, fp, fn, eps=1e-8):
+ precision = tp / (tp + fp + eps)
+ recall = tp / (tp + fn + eps)
+ f1 = 2 * precision * recall / (precision + recall + eps)
+ return precision, recall, f1
+
+
+def localized_f1(pred_spans, true_spans, eps=1e-8):
+
+ tp = 0.
+ fp = 0.
+ fn = 0.
+
+ for pred, true in zip(pred_spans, true_spans):
+ if true != "O":
+ if true == pred:
+ tp += 1
+ else:
+ if pred == "O":
+ fn += 1
+ else:
+ fp += 1
+
+    return compute_precision_recall_f1(tp, fp, fn, eps=eps)
+
+
+def span_f1(pred_spans, true_spans, eps=1e-8):
+ tp = len(pred_spans.intersection(true_spans))
+ fp = len(pred_spans - true_spans)
+ fn = len(true_spans - pred_spans)
+
+    return compute_precision_recall_f1(tp, fp, fn, eps=eps)
+
+
+def scorer(pred, labels, tagmap, no_localization=False, eps=1e-8):
"""
Compute f1 score, precision, and recall from BILUO labels
:param pred: predicted BILUO labels
@@ -42,16 +83,13 @@ def scorer(pred, labels, tagmap, eps=1e-8):
pred_biluo = [tagmap.inverse(p) for p in pred]
labels_biluo = [tagmap.inverse(p) for p in labels]
- pred_spans = set(parse_biluo(pred_biluo))
- true_spans = set(parse_biluo(labels_biluo))
+ if not no_localization:
+ pred_spans = set(parse_biluo(pred_biluo))
+ true_spans = set(parse_biluo(labels_biluo))
- tp = len(pred_spans.intersection(true_spans))
- fp = len(pred_spans - true_spans)
- fn = len(true_spans - pred_spans)
-
- precision = tp / (tp + fp + eps)
- recall = tp / (tp + fn + eps)
- f1 = 2 * precision * recall / (precision + recall + eps)
+ precision, recall, f1 = span_f1(pred_spans, true_spans, eps=eps)
+ else:
+ precision, recall, f1 = localized_f1(pred_biluo, labels_biluo, eps=eps)
return precision, recall, f1
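With `--no_localization` the scorer switches from exact-span matching (`span_f1` over `parse_biluo` output) to a per-token comparison that only asks whether the predicted type at an annotated position is correct, ignoring span boundaries. A small sketch of the difference, assuming `parse_biluo` returns `(start, end, label)` spans:

```python
pred_biluo = ["O", "B-List", "L-List", "O"]
true_biluo = ["O", "B-List", "L-List", "U-int"]

# Localized (token-level) scoring: positions 1 and 2 match (tp=2),
# the "int" at position 3 is missed (fn=1), nothing spurious (fp=0).
p, r, f1 = localized_f1(pred_biluo, true_biluo)   # p ~ 1.0, r ~ 0.67, f1 ~ 0.8

# Span scoring compares whole (start, end, label) spans instead:
# pred_spans = set(parse_biluo(pred_biluo)); true_spans = set(parse_biluo(true_biluo))
# p, r, f1 = span_f1(pred_spans, true_spans)
```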
@@ -74,7 +112,8 @@ def write_config(trial_dir, params, extra_params=None):
class ModelTrainer:
def __init__(self, train_data, test_data, params, graph_emb_path=None, word_emb_path=None,
- output_dir=None, epochs=30, batch_size=32, seq_len=100, finetune=False, trials=1):
+ output_dir=None, epochs=30, batch_size=32, seq_len=100, finetune=False, trials=1,
+ no_localization=False, ckpt_path=None, no_graph=False):
self.set_batcher_class()
self.set_model_class()
@@ -89,6 +128,9 @@ def __init__(self, train_data, test_data, params, graph_emb_path=None, word_emb_
self.finetune = finetune
self.trials = trials
self.seq_len = seq_len
+ self.no_localization = no_localization
+ self.ckpt_path = ckpt_path
+ self.no_graph = no_graph
def set_batcher_class(self):
self.batcher = PythonBatcher
@@ -97,16 +139,23 @@ def set_model_class(self):
from SourceCodeTools.nlp.entity.tf_models.tf_model import TypePredictor
self.model = TypePredictor
- def get_batcher(self, *args, **kwards):
- return self.batcher(*args, **kwards)
+ def get_batcher(self, *args, **kwargs):
+ return self.batcher(*args, **kwargs)
def get_model(self, *args, **kwargs):
- return self.model(*args, **kwargs)
+ model = self.model(*args, **kwargs)
+ if self.ckpt_path is not None:
+ model.load_weights(os.path.join(self.ckpt_path, "checkpoint"))
+ return model
def train(self, *args, **kwargs):
from SourceCodeTools.nlp.entity.tf_models.tf_model import train
return train(*args, summary_writer=self.summary_writer, **kwargs)
+ def test(self, *args, **kwargs):
+ from SourceCodeTools.nlp.entity.tf_models.tf_model import test
+ return test(*args, **kwargs)
+
def create_summary_writer(self, path):
self.summary_writer = tensorflow.summary.create_file_writer(path)
@@ -114,48 +163,69 @@ def create_summary_writer(self, path):
# with self.summary_writer.as_default():
# tensorflow.summary.scalar(value_name, value, step=step)
- def train_model(self):
+ def get_dataloaders(self, word_emb, graph_emb, suffix_prefix_buckets):
- graph_emb = load_pkl_emb(self.graph_emb_path)
- word_emb = load_pkl_emb(self.word_emb_path)
-
- suffix_prefix_buckets = params.pop("suffix_prefix_buckets")
+ if self.ckpt_path is not None:
+ tagmap = pickle.load(open(os.path.join(self.ckpt_path, "tag_types.pkl"), "rb"))
+ else:
+ tagmap = None
train_batcher = self.get_batcher(
- train_data, self.batch_size, seq_len=self.seq_len, graphmap=graph_emb.ind, wordmap=word_emb.ind, tagmap=None,
- class_weights=False, element_hash_size=suffix_prefix_buckets
+ self.train_data, self.batch_size, seq_len=self.seq_len,
+ graphmap=graph_emb.ind if graph_emb is not None else None,
+ wordmap=word_emb.ind, tagmap=tagmap,
+ class_weights=False, element_hash_size=suffix_prefix_buckets, no_localization=self.no_localization
)
test_batcher = self.get_batcher(
- test_data, self.batch_size, seq_len=self.seq_len, graphmap=graph_emb.ind, wordmap=word_emb.ind,
+ self.test_data, self.batch_size, seq_len=self.seq_len,
+ graphmap=graph_emb.ind if graph_emb is not None else None,
+ wordmap=word_emb.ind,
tagmap=train_batcher.tagmap, # use the same mapping
- class_weights=False, element_hash_size=suffix_prefix_buckets # class_weights are not used for testing
+ class_weights=False, element_hash_size=suffix_prefix_buckets, # class_weights are not used for testing
+ no_localization=self.no_localization
)
+ return train_batcher, test_batcher
+
+ def train_model(self):
+
+ graph_emb = load_pkl_emb(self.graph_emb_path) if self.graph_emb_path is not None else None
+ word_emb = load_pkl_emb(self.word_emb_path)
+
+ suffix_prefix_buckets = params.pop("suffix_prefix_buckets")
+
+ train_batcher, test_batcher = self.get_dataloaders(word_emb, graph_emb, suffix_prefix_buckets)
print(f"\n\n{params}")
lr = params.pop("learning_rate")
lr_decay = params.pop("learning_rate_decay")
- param_dir = os.path.join(output_dir, str(datetime.now()))
+ timestamp = str(datetime.now()).replace(":","-").replace(" ","_")
+ param_dir = os.path.join(output_dir, timestamp)
os.mkdir(param_dir)
for trial_ind in range(self.trials):
trial_dir = os.path.join(param_dir, repr(trial_ind))
+ logging.info(f"Running trial: {timestamp}")
os.mkdir(trial_dir)
self.create_summary_writer(trial_dir)
model = self.get_model(
word_emb, graph_emb, train_embeddings=self.finetune, suffix_prefix_buckets=suffix_prefix_buckets,
- num_classes=train_batcher.num_classes(), seq_len=self.seq_len, **params
+ num_classes=train_batcher.num_classes(), seq_len=self.seq_len, no_graph=self.no_graph, **params
)
+ def save_ckpt_fn():
+ checkpoint_path = os.path.join(trial_dir, "checkpoint")
+ model.save_weights(checkpoint_path)
+
train_losses, train_f1, test_losses, test_f1 = self.train(
model=model, train_batches=train_batcher, test_batches=test_batcher,
- epochs=self.epochs, learning_rate=lr, scorer=lambda pred, true: scorer(pred, true, train_batcher.tagmap),
- learning_rate_decay=lr_decay, finetune=self.finetune
+ epochs=self.epochs, learning_rate=lr, scorer=lambda pred, true: scorer(pred, true, train_batcher.tagmap, no_localization=self.no_localization),
+ learning_rate_decay=lr_decay, finetune=self.finetune, save_ckpt_fn=save_ckpt_fn, no_localization=self.no_localization
)
- checkpoint_path = os.path.join(trial_dir, "checkpoint")
- model.save_weights(checkpoint_path)
+ # checkpoint_path = os.path.join(trial_dir, "checkpoint")
+ # model.save_weights(checkpoint_path)
metadata = {
"train_losses": train_losses,
@@ -166,9 +236,13 @@ def train_model(self):
"learning_rate_decay": lr_decay,
"epochs": self.epochs,
"suffix_prefix_buckets": suffix_prefix_buckets,
- "seq_len": self.seq_len
+ "seq_len": self.seq_len,
+ "batch_size": self.batch_size,
+ "no_localization": self.no_localization
}
+ print("Maximum f1:", max(test_f1))
+
# write_config(trial_dir, params, extra_params={"suffix_prefix_buckets": suffix_prefix_buckets, "seq_len": seq_len})
metadata.update(params)
@@ -184,11 +258,13 @@ def get_type_prediction_arguments():
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--data_path', dest='data_path', default=None,
- help='Path to the file with nodes')
+ help='Path to the dataset file')
parser.add_argument('--graph_emb_path', dest='graph_emb_path', default=None,
- help='Path to the file with edges')
+ help='Path to the file with graph embeddings')
parser.add_argument('--word_emb_path', dest='word_emb_path', default=None,
- help='Path to the file with edges')
+ help='Path to the file with token embeddings')
+ parser.add_argument('--type_ann_edges', dest='type_ann_edges', default=None,
+ help='Path to type annotation edges')
parser.add_argument('--learning_rate', dest='learning_rate', default=0.01, type=float,
help='')
parser.add_argument('--learning_rate_decay', dest='learning_rate_decay', default=1.0, type=float,
@@ -199,17 +275,34 @@ def get_type_prediction_arguments():
help='')
parser.add_argument('--max_seq_len', dest='max_seq_len', default=100, type=int,
help='')
- # parser.add_argument('--pretrain_phase', dest='pretrain_phase', default=20, type=int,
- # help='')
+ parser.add_argument('--min_entity_count', dest='min_entity_count', default=3, type=int,
+ help='')
+ parser.add_argument('--pretraining_epochs', dest='pretraining_epochs', default=0, type=int,
+ help='')
+ parser.add_argument('--ckpt_path', dest='ckpt_path', default=None, type=str,
+ help='')
parser.add_argument('--epochs', dest='epochs', default=500, type=int,
help='')
parser.add_argument('--trials', dest='trials', default=1, type=int,
help='')
+ parser.add_argument('--gpu', dest='gpu', default=-1, type=int,
+ help='Does not work with Tensorflow backend')
parser.add_argument('--finetune', action='store_true')
+ parser.add_argument('--no_localization', action='store_true')
+ parser.add_argument('--restrict_allowed', action='store_true', default=False)
+ parser.add_argument('--no_graph', action='store_true', default=False)
parser.add_argument('model_output',
help='')
args = parser.parse_args()
+
+ if args.finetune is False and args.pretraining_epochs > 0:
+        logging.info(f"Finetuning is disabled, but the number of pretraining epochs is {args.pretraining_epochs}. Setting pretraining epochs to 0.")
+ args.pretraining_epochs = 0
+
+ if args.graph_emb_path is not None and not os.path.isfile(args.graph_emb_path):
+ logging.warning(f"File with graph embeddings does not exist: {args.graph_emb_path}")
+ args.graph_emb_path = None
return args
@@ -219,6 +312,22 @@ def save_entities(path, entities):
entitiesfile.write(f"{e}\n")
+def filter_labels(dataset, allowed=None, field=None):
+ if allowed is None:
+ return dataset
+ dataset = copy(dataset)
+ for sent, annotations in dataset:
+ annotations["entities"] = [e for e in annotations["entities"] if e[2] in allowed]
+ return dataset
+
+
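`filter_labels` drops every annotation whose label is outside the `allowed` set (note it copies only the outer list, so the annotation dicts themselves are filtered in place). A usage sketch with made-up offsets:

```python
data = [("x: List[int] = []", {"entities": [(3, 12, "List"), (0, 1, "Foo")]})]
filtered = filter_labels(data, allowed={"List"})
# filtered[0][1]["entities"] -> [(3, 12, "List")]; the "Foo" annotation is discarded
```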
+def find_example(dataset, needed_label):
+ for sent, annotations in dataset:
+ for start, end, e in annotations["entities"]:
+ if e == needed_label:
+ print(f"{sent}: {sent[start: end]}")
+
+
if __name__ == "__main__":
args = get_type_prediction_arguments()
@@ -228,10 +337,27 @@ def save_entities(path, entities):
# allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray',
# 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', 'Type'}
-
- train_data, test_data = read_data(
- open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, include_only="entities",
- min_entity_count=1, random_seed=args.random_seed
+ if args.restrict_allowed:
+ allowed = {
+ 'str', 'Optional', 'int', 'Any', 'Union', 'bool', 'Callable', 'Dict', 'bytes', 'float', 'Description',
+ 'List', 'Sequence', 'Namespace', 'T', 'Type', 'object', 'HTTPServerRequest', 'Future', "Matcher"
+ }
+ else:
+ allowed = None
+
+ # train_data, test_data = read_data(
+ # open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, include_only="entities",
+ # min_entity_count=args.min_entity_count, random_seed=args.random_seed
+ # )
+
+ dataset_dir = Path(args.data_path).parent
+ train_data = filter_labels(
+ pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_train.pkl"), "rb")),
+ allowed=allowed
+ )
+ test_data = filter_labels(
+ pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_test.pkl"), "rb")),
+ allowed=allowed
)
unique_entities = get_unique_entities(train_data, field="entities")
@@ -241,6 +367,7 @@ def save_entities(path, entities):
trainer = ModelTrainer(
train_data, test_data, params, graph_emb_path=args.graph_emb_path, word_emb_path=args.word_emb_path,
output_dir=output_dir, epochs=args.epochs, batch_size=args.batch_size,
- finetune=args.finetune, trials=args.trials, seq_len=args.max_seq_len,
+ finetune=args.finetune, trials=args.trials, seq_len=args.max_seq_len, no_localization=args.no_localization,
+ ckpt_path=args.ckpt_path, no_graph=args.no_graph
)
trainer.train_model()
diff --git a/SourceCodeTools/nlp/entity/utils/__init__.py b/SourceCodeTools/nlp/entity/utils/__init__.py
index 52df6d10..8cf742ca 100644
--- a/SourceCodeTools/nlp/entity/utils/__init__.py
+++ b/SourceCodeTools/nlp/entity/utils/__init__.py
@@ -24,7 +24,7 @@
def normalize_entities(entities):
- norm = lambda x: x.split("[")[0].split(".")[-1]
+ norm = lambda x: x.strip("\"").strip("'").split("[")[0].split(".")[-1]
if len(entities) == 0:
return entities
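The extended normalizer now strips surrounding quotes before applying the old rules (drop generic parameters, keep the last dotted component). Roughly:

```python
norm = lambda x: x.strip("\"").strip("'").split("[")[0].split(".")[-1]

norm("List[int]")          # -> "List"
norm("pandas.DataFrame")   # -> "DataFrame"
norm("'typing.Optional'")  # -> "Optional"  (new: quotes removed first)
```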
diff --git a/SourceCodeTools/nlp/entity/utils/data.py b/SourceCodeTools/nlp/entity/utils/data.py
index 8ef6652c..8135b607 100644
--- a/SourceCodeTools/nlp/entity/utils/data.py
+++ b/SourceCodeTools/nlp/entity/utils/data.py
@@ -55,7 +55,7 @@ def read_data(
logging.info("Splitting dataset randomly")
else:
random.seed(random_seed)
- logging.warning(f"Using ransom seed {random_seed} for dataset split")
+ logging.warning(f"Using random seed {random_seed} for dataset split")
filter_infrequent(
train_data, entities_in_dataset=Counter(entities_in_dataset),
diff --git a/SourceCodeTools/nlp/entity/utils/visualize_dataset.py b/SourceCodeTools/nlp/entity/utils/visualize_dataset.py
new file mode 100644
index 00000000..7098156e
--- /dev/null
+++ b/SourceCodeTools/nlp/entity/utils/visualize_dataset.py
@@ -0,0 +1,36 @@
+import os.path
+
+import spacy
+import sys
+import json
+from spacy.gold import biluo_tags_from_offsets
+from spacy.tokenizer import Tokenizer
+import re
+
+from SourceCodeTools.nlp import create_tokenizer
+from SourceCodeTools.nlp.entity import parse_biluo
+from SourceCodeTools.nlp.entity.entity_render import render_annotations
+
+annotations_path = sys.argv[1]
+output_path = os.path.dirname(annotations_path)
+
+data = []
+references = []
+results = []
+
+nlp = create_tokenizer("spacy")
+
+with open(annotations_path) as annotations:
+ for line in annotations:
+ entry = json.loads(line.strip())
+
+ doc = nlp(entry['text'])
+ tags_r = parse_biluo(biluo_tags_from_offsets(doc, entry['replacements']))
+ tags_e = parse_biluo(biluo_tags_from_offsets(doc, entry['ents']))
+
+ data.append([entry['text']])
+ references.append(tags_r)
+ results.append(tags_e)
+
+html = render_annotations(zip(data, references, results))
+open(os.path.join(output_path, "annotations.html"), "w").write(html)
\ No newline at end of file
diff --git a/SourceCodeTools/nlp/generation/data.py b/SourceCodeTools/nlp/generation/data.py
new file mode 100644
index 00000000..9f8f0c68
--- /dev/null
+++ b/SourceCodeTools/nlp/generation/data.py
@@ -0,0 +1 @@
+# dataloader for documentation generation
\ No newline at end of file
diff --git a/SourceCodeTools/nlp/tokenizers.py b/SourceCodeTools/nlp/tokenizers.py
index e20fa9e7..6da01efe 100644
--- a/SourceCodeTools/nlp/tokenizers.py
+++ b/SourceCodeTools/nlp/tokenizers.py
@@ -75,6 +75,20 @@ def default_tokenizer(text):
from SourceCodeTools.nlp.embed.bpe import load_bpe_model, make_tokenizer
return make_tokenizer(load_bpe_model(bpe_path))
+ elif type == "codebert":
+ from transformers import RobertaTokenizer
+ import spacy
+ from spacy.tokens import Doc
+
+ tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
+ nlp = spacy.blank("en")
+
+ def tokenize(text):
+ tokens = tokenizer.tokenize(text)
+ doc = Doc(nlp.vocab, tokens, spaces=[False] * len(tokens))
+ return doc
+
+ return tokenize
else:
-        raise Exception("Supported tokenizer types: spacy, regex, bpe")
+        raise Exception("Supported tokenizer types: spacy, regex, bpe, codebert")
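The new `codebert` branch wraps the `microsoft/codebert-base` BPE tokenizer in a blank spaCy pipeline so downstream code keeps receiving `Doc` objects. Assuming the dispatching function is the `create_tokenizer` used elsewhere in the package, usage would look roughly like:

```python
from SourceCodeTools.nlp import create_tokenizer

tokenize = create_tokenizer("codebert")          # requires `transformers` and `spacy`
doc = tokenize("def add(a, b): return a + b")
print([t.text for t in doc])                     # CodeBERT subword pieces exposed as Doc tokens
```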
diff --git a/SourceCodeTools/tabular/common.py b/SourceCodeTools/tabular/common.py
index 771945fa..ec4a74ac 100644
--- a/SourceCodeTools/tabular/common.py
+++ b/SourceCodeTools/tabular/common.py
@@ -2,6 +2,13 @@
def compact_property(values, return_order=False, index_from_one=False):
+ """
+ Returns a map from the original value to the compact index
+ :param values:
+ :param return_order:
+ :param index_from_one:
+ :return:
+ """
uniq = numpy.unique(values)
if index_from_one:
index = range(1, uniq.size + 1)
@@ -12,5 +19,5 @@ def compact_property(values, return_order=False, index_from_one=False):
inv_index = uniq.tolist()
if index_from_one:
inv_index.insert(0, "NA")
- return prop2pid,
+ return prop2pid, inv_index
return prop2pid
\ No newline at end of file
diff --git a/res/python_testdata/example_code/example/ExampleModule.py b/res/python_testdata/example_code/example/ExampleModule.py
new file mode 100644
index 00000000..44700482
--- /dev/null
+++ b/res/python_testdata/example_code/example/ExampleModule.py
@@ -0,0 +1,24 @@
+class ExampleClass:
+ def __init__(self, argument: int):
+ """
+ Initialize. Инициализация
+ :param argument:
+ """
+ self.field = argument
+
+ def method1(self) -> str:
+ """
+ Call another method. Вызов другого метода.
+ :return:
+ """
+ return self.method2()
+
+ def method2(self) -> str:
+ """
+ Simple operations.
+ Простые операции.
+ :return:
+ """
+ variable1: int = self.field
+ variable2: str = str(variable1)
+ return variable2
diff --git a/res/python_testdata/example_code/example/main.py b/res/python_testdata/example_code/example/main.py
new file mode 100644
index 00000000..30caa43d
--- /dev/null
+++ b/res/python_testdata/example_code/example/main.py
@@ -0,0 +1,9 @@
+from ExampleModule import ExampleClass as EC
+
+instance = EC(5)
+
+def main() -> None:
+ print(instance.method1())
+
+main()
+
diff --git a/res/python_testdata/example_code/example2/Module.py b/res/python_testdata/example_code/example2/Module.py
new file mode 100644
index 00000000..6c949d86
--- /dev/null
+++ b/res/python_testdata/example_code/example2/Module.py
@@ -0,0 +1,24 @@
+class Number:
+ def __init__(self, value: int):
+ """
+ Initialize. Инициализация
+ :param argument:
+ """
+ self.val = value
+
+ def __add__(self, value):
+ """
+ Add two numbers.
+ Сложить 2 числа
+ :param value:
+ :return:
+ """
+ return Number(self.val + value.val)
+
+ def __repr__(self) -> str:
+ """
+ Return representation
+ :return: Получить представление
+ """
+ return f"Number({self.val})"
+
diff --git a/res/python_testdata/example_code/example2/main.py b/res/python_testdata/example_code/example2/main.py
new file mode 100644
index 00000000..77b3999d
--- /dev/null
+++ b/res/python_testdata/example_code/example2/main.py
@@ -0,0 +1,9 @@
+from Module import Number
+
+def main():
+ a = Number(4)
+ b = Number(5)
+ print(a+b)
+
+main()
+
diff --git a/scripts/data_collection/python/process_authors.sh b/scripts/data_collection/python/process_authors.sh
new file mode 100644
index 00000000..b7180997
--- /dev/null
+++ b/scripts/data_collection/python/process_authors.sh
@@ -0,0 +1,51 @@
+# this script processes arbitrary code with sourcetrails, no dependencies are pulled
+
+conda activate python37
+
+while read repo
+do
+
+ if [ ! -f "$repo/$repo.srctrlprj" ]; then
+ echo "Creating Sourcetrail project for $repo"
+ echo "
+
+
+
+ Python Source Group
+ .
+
+ .py
+
+
+ .
+
+ enabled
+ Python Source Group
+
+
+ 8
+" > $repo/$repo.srctrlprj
+ fi
+
+ if [ ! -f "$repo/sourcetrail.log" ]; then
+ run_indexing=true
+ else
+ find_edges=$(cat "$repo/sourcetrail.log" | grep " Edges")
+ if [ -z "$find_edges" ]; then
+ echo "Indexing was interrupted, recovering..."
+ run_indexing=true
+ else
+ run_indexing=false
+ fi
+ fi
+
+ if $run_indexing; then
+ echo "Begin indexing"
+ Sourcetrail.sh index -i $repo/$repo.srctrlprj >> $repo/sourcetrail.log
+ else
+ echo "Already indexed"
+ fi
+
+done < "${1:-/dev/stdin}"
+
+conda deactivate
diff --git a/scripts/data_collection/python/process_folders.sh b/scripts/data_collection/python/process_folders.sh
new file mode 100644
index 00000000..09f42a41
--- /dev/null
+++ b/scripts/data_collection/python/process_folders.sh
@@ -0,0 +1,59 @@
+# this script processes arbitrary code with sourcetrails, no dependencies are pulled
+
+conda activate SourceCodeTools
+
+create_sourcetrail_project_if_not_exist () {
+ if [ ! -f $1 ]; then
+ echo "Creating Sourcetrail project for $repo"
+ echo "
+
+
+
+ Python Source Group
+ .
+
+ .py
+
+
+ .
+
+ enabled
+ Python Source Group
+
+
+ 8
+" > $1
+ fi
+}
+
+
+run_indexer () {
+ repo=$1
+ if [ ! -f "$repo/sourcetrail.log" ]; then
+ run_indexing=true
+ else
+ find_edges=$(cat "$repo/sourcetrail.log" | grep " Edges")
+ if [ -z "$find_edges" ]; then
+ echo "Indexing was interrupted, recovering..."
+ run_indexing=true
+ else
+ run_indexing=false
+ fi
+ fi
+
+ if $run_indexing; then
+ echo "Begin indexing"
+ Sourcetrail.sh index -i $repo/$repo.srctrlprj >> $repo/sourcetrail.log
+ else
+ echo "Already indexed"
+ fi
+}
+
+
+while read repo
+do
+ create_sourcetrail_project_if_not_exist "$repo/$repo.srctrlprj"
+ run_indexer "$repo"
+done < "${1:-/dev/stdin}"
+
+conda deactivate
diff --git a/scripts/data_collection/requirements.txt b/scripts/data_collection/requirements.txt
index 9f0f1d16..0ed01413 100644
--- a/scripts/data_collection/requirements.txt
+++ b/scripts/data_collection/requirements.txt
@@ -10,7 +10,7 @@ mkl-service==2.3.0
msgpack-python==0.5.6
nltk==3.5
numba==0.50.1
-numpy==1.18.1
+numpy==1.21.0
pandas==1.0.3
ply==3.11
plyj==0.1
diff --git a/scripts/data_extraction/process_sourcetrail.sh b/scripts/data_extraction/process_sourcetrail.sh
index 52ff257c..4dfbc28e 100644
--- a/scripts/data_extraction/process_sourcetrail.sh
+++ b/scripts/data_extraction/process_sourcetrail.sh
@@ -18,21 +18,21 @@ for dir in "$ENVS_DIR"/*; do
echo "Found package $package_name"
- if [ -f "$dir/source_graph_bodies.csv" ]; then
- rm "$dir/source_graph_bodies.csv"
- fi
- if [ -f "$dir/nodes_with_ast.csv" ]; then
- rm "$dir/nodes_with_ast.csv"
- fi
- if [ -f "$dir/edges_with_ast.csv" ]; then
- rm "$dir/edges_with_ast.csv"
- fi
- if [ -f "$dir/call_seq.csv" ]; then
- rm "$dir/call_seq.csv"
- fi
- if [ -f "$dir/source_graph_function_variable_pairs.csv" ]; then
- rm "$dir/source_graph_function_variable_pairs.csv"
- fi
+# if [ -f "$dir/source_graph_bodies.csv" ]; then
+# rm "$dir/source_graph_bodies.csv"
+# fi
+# if [ -f "$dir/nodes_with_ast.csv" ]; then
+# rm "$dir/nodes_with_ast.csv"
+# fi
+# if [ -f "$dir/edges_with_ast.csv" ]; then
+# rm "$dir/edges_with_ast.csv"
+# fi
+# if [ -f "$dir/call_seq.csv" ]; then
+# rm "$dir/call_seq.csv"
+# fi
+# if [ -f "$dir/source_graph_function_variable_pairs.csv" ]; then
+# rm "$dir/source_graph_function_variable_pairs.csv"
+# fi
if [ -f "$dir/$package_name.srctrldb" ]; then
@@ -40,25 +40,25 @@ for dir in "$ENVS_DIR"/*; do
sqlite3 "$dir/$package_name.srctrldb" < "$SQL_Q"
cd "$RUN_DIR"
sourcetrail_verify_files.py "$dir"
- sourcetrail_node_name_merge.py "$dir/nodes.csv"
- sourcetrail_decode_edge_types.py "$dir/edges.csv"
- sourcetrail_filter_ambiguous_edges.py $dir
- sourcetrail_parse_bodies.py "$dir"
- sourcetrail_call_seq_extractor.py "$dir"
+# sourcetrail_node_name_merge.py "$dir/nodes.csv"
+# sourcetrail_decode_edge_types.py "$dir/edges.csv"
+# sourcetrail_filter_ambiguous_edges.py $dir
+# sourcetrail_parse_bodies.py "$dir"
+# sourcetrail_call_seq_extractor.py "$dir"
- sourcetrail_add_reverse_edges.py "$dir/edges.bz2"
- if [ -n "$BPE_PATH" ]; then
- BPE_PATH=$(realpath "$BPE_PATH")
- sourcetrail_ast_edges.py "$dir" -bpe $BPE_PATH --create_subword_instances
- else
- sourcetrail_ast_edges.py "$dir"
- fi
- sourcetrail_extract_variable_names.py python "$dir"
+# sourcetrail_add_reverse_edges.py "$dir/edges.bz2"
+# if [ -n "$BPE_PATH" ]; then
+# BPE_PATH=$(realpath "$BPE_PATH")
+# sourcetrail_ast_edges.py "$dir" -bpe $BPE_PATH --create_subword_instances
+# else
+# sourcetrail_ast_edges.py "$dir"
+# fi
+# sourcetrail_extract_variable_names.py python "$dir"
- if [ -f "$dir"/edges_with_ast_temp.csv ]; then
- rm "$dir"/edges_with_ast_temp.csv
- fi
+# if [ -f "$dir"/edges_with_ast_temp.csv ]; then
+# rm "$dir"/edges_with_ast_temp.csv
+# fi
else
echo "Package not indexed"
diff --git a/scripts/data_extraction/process_sourcetrail_v2.sh b/scripts/data_extraction/process_sourcetrail_v2.sh
new file mode 100644
index 00000000..a9679b14
--- /dev/null
+++ b/scripts/data_extraction/process_sourcetrail_v2.sh
@@ -0,0 +1,24 @@
+conda activate SourceCodeTools
+
+ENVS_DIR=$(realpath "$1")
+RUN_DIR=$(realpath "$(dirname "$0")")
+SQL_Q=$(realpath "$RUN_DIR/extract.sql")
+
+for dir in "$ENVS_DIR"/*; do
+ if [ -d "$dir" ]; then
+ package_name="$(basename "$dir")"
+
+ echo "Found package $package_name"
+
+ if [ -f "$dir/$package_name.srctrldb" ]; then
+ cd "$dir"
+ sqlite3 "$dir/$package_name.srctrldb" < "$SQL_Q"
+ cd "$RUN_DIR"
+ sourcetrail_verify_files.py "$dir"
+ else
+ echo "Package not indexed"
+ fi
+ fi
+done
+
+conda deactivate
\ No newline at end of file
diff --git a/scripts/training/dglke_log_parser.py b/scripts/training/dglke_log_parser.py
new file mode 100644
index 00000000..859ac38f
--- /dev/null
+++ b/scripts/training/dglke_log_parser.py
@@ -0,0 +1,56 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("logfile")
+
+ args = parser.parse_args()
+
+ with open(args.logfile, "r") as logfile:
+ steps = [[]]
+ buffer_pos = [[]]
+ buffer_neg = [[]]
+ model_names = []
+ for line in logfile:
+ if "pos_loss" in line:
+ buffer_pos[-1].append(float(line.split("pos_loss: ")[-1]))
+ steps[-1].append(float(line.split("(")[-1].split("/")[0]))
+ elif "neg_loss" in line:
+ buffer_neg[-1].append(float(line.split("neg_loss: ")[-1]))
+ elif "Save model" in line:
+ model_names.append(line.split("/")[-2])
+ buffer_pos.append([])
+ buffer_neg.append([])
+ steps.append([])
+
+ for s, v in zip(steps, buffer_pos):
+ plt.plot(np.log10(s), np.log10(v))
+ plt.xlabel("Step")
+ plt.ylabel("Loss")
+ plt.legend(model_names)
+ plt.savefig("positive_loss.png")
+ plt.close()
+
+ for s, v in zip(steps, buffer_neg):
+ plt.plot(np.log10(s), np.log10(v))
+ plt.xlabel("Step")
+ plt.ylabel("Loss")
+ plt.legend(model_names)
+ plt.savefig("negative_loss.png")
+ plt.close()
+
+ for s, v, n in zip(steps, buffer_pos, buffer_neg):
+ plt.plot(np.log10(s), np.log10(np.array(v) + np.array(n)))
+ plt.xlabel("Step")
+ plt.ylabel("Loss")
+ plt.legend(model_names)
+ plt.savefig("overall_loss.png")
+ plt.close()
+
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/scripts/training/evaluate.py b/scripts/training/evaluate.py
index 69e23432..cf2b1aec 100644
--- a/scripts/training/evaluate.py
+++ b/scripts/training/evaluate.py
@@ -1,6 +1,7 @@
import json
import logging
import os
+import shutil
from datetime import datetime
from os import mkdir
from os.path import isdir, join
@@ -8,13 +9,37 @@
from SourceCodeTools.code.data.sourcetrail.Dataset import read_or_create_dataset
from SourceCodeTools.models.graph import RGCNSampling, RGAN, RGGAN
from SourceCodeTools.models.graph.train.utils import get_name, get_model_base
+from SourceCodeTools.models.training_options import add_gnn_train_args
from params import rgcnsampling_params, rggan_params
+def detect_checkpoint_files(path):
+ checkpoints = []
+ for file in os.listdir(path):
+ filepath = os.path.join(path, file)
+        if not os.path.isfile(filepath):
+ continue
+
+ if not file.startswith("saved_state_"):
+ continue
+
+ epoch = int(file.split("_")[2].split(".")[0])
+ checkpoints.append((epoch, filepath))
+
+    checkpoints = sorted(checkpoints, key=lambda x: x[0])
+
+    return checkpoints
+
+
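`detect_checkpoint_files` assumes checkpoints were written as `saved_state_<epoch>.pt` inside the model directory and returns them sorted by epoch so the evaluation loop can replay them one by one. For example (hypothetical paths):

```python
# Given model_base/ containing saved_state_0.pt, saved_state_20.pt, saved_state_10.pt:
detect_checkpoint_files("model_base")
# -> [(0, "model_base/saved_state_0.pt"),
#     (10, "model_base/saved_state_10.pt"),
#     (20, "model_base/saved_state_20.pt")]
```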
def main(models, args):
for model, param_grid in models.items():
for params in param_grid:
+ if args.h_dim is None:
+ params["h_dim"] = args.node_emb_size
+ else:
+ params["h_dim"] = args.h_dim
+
date_time = str(datetime.now())
print("\n\n")
print(date_time)
@@ -26,27 +51,24 @@ def main(models, args):
dataset = read_or_create_dataset(args=args, model_base=model_base)
- if args.training_mode == "multitask":
+ from SourceCodeTools.models.graph.train.sampling_multitask2 import evaluation_procedure
+
+ checkpoints = detect_checkpoint_files(model_base)
- if args.intermediate_supervision:
- # params['use_self_loop'] = True # ????
- from SourceCodeTools.models.graph.train.sampling_multitask_intermediate_supervision import evaluation_procedure
- else:
- from SourceCodeTools.models.graph.train.sampling_multitask2 import evaluation_procedure
+ for epoch, ckpt_path in checkpoints:
+ shutil.copy(ckpt_path, os.path.join(model_base, "saved_state.pt"))
evaluation_procedure(dataset, model, params, args, model_base)
- else:
- raise ValueError("Issue! ", args.training_mode)
if __name__ == "__main__":
import argparse
- from train import add_train_args
+ # from train import add_train_args
parser = argparse.ArgumentParser(description='Process some integers.')
- add_train_args(parser)
+ add_gnn_train_args(parser)
args = parser.parse_args()
@@ -54,11 +76,17 @@ def main(models, args):
model_out = args.model_output_dir
- saved_args = json.loads(open(os.path.join(model_out, "metadata.json")).read())
-
- models_ = {
- eval(saved_args.pop("name").split("-")[0]): [saved_args.pop("parameters")]
- }
+ try:
+ saved_args = json.loads(open(os.path.join(model_out, "metadata.json")).read())
+ models_ = {
+ eval(saved_args.pop("name").split("-")[0]): [saved_args.pop("parameters")]
+ }
+    except Exception:
+ param_keys = 'activation', 'use_self_loop', 'num_steps', 'dropout', 'num_bases', 'lr'
+ saved_args = json.loads(open(os.path.join(model_out, "params.json")).read())
+ models_ = {
+ RGGAN: [{key: saved_args[key] for key in param_keys}]
+ }
args.__dict__.update(saved_args)
args.restore_state = True
diff --git a/scripts/training/evaluate_only.py b/scripts/training/evaluate_only.py
new file mode 100644
index 00000000..699e88a2
--- /dev/null
+++ b/scripts/training/evaluate_only.py
@@ -0,0 +1,105 @@
+import json
+import logging
+from copy import copy
+from datetime import datetime
+from os import mkdir
+from os.path import isdir, join
+
+from SourceCodeTools.code.data.sourcetrail.Dataset import read_or_create_dataset
+from SourceCodeTools.models.graph import RGCNSampling, RGAN, RGGAN
+from SourceCodeTools.models.graph.train.utils import get_name, get_model_base
+from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments
+from params import rgcnsampling_params, rggan_params
+
+
+def main(models, args):
+ for model, param_grid in models.items():
+ for params in param_grid:
+
+ if args.h_dim is None:
+ params["h_dim"] = args.node_emb_size
+ else:
+ params["h_dim"] = args.h_dim
+
+ params["num_steps"] = args.n_layers
+
+ date_time = str(datetime.now())
+ print("\n\n")
+ print(date_time)
+ print(f"Model: {model.__name__}, Params: {params}")
+
+ model_attempt = get_name(model, date_time)
+
+ model_base = get_model_base(args, model_attempt)
+
+ dataset = read_or_create_dataset(args=args, model_base=model_base)
+
+ def write_params(args, params):
+ args = copy(args.__dict__)
+ args.update(params)
+ args['activation'] = args['activation'].__name__
+ with open(join(model_base, "params.json"), "w") as mdata:
+ mdata.write(json.dumps(args, indent=4))
+
+ if not args.restore_state:
+ write_params(args, params)
+
+ from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure
+
+ trainer, scores = \
+ training_procedure(dataset, model, copy(params), args, model_base)
+
+ trainer.save_checkpoint(model_base)
+
+ print("Saving...")
+
+ params['activation'] = params['activation'].__name__
+
+ metadata = {
+ "base": model_base,
+ "name": model_attempt,
+ "parameters": params,
+ "layers": "embeddings.pkl",
+ "mappings": "nodes.csv",
+ "state": "state_dict.pt",
+ "scores": scores,
+ "time": date_time,
+ }
+
+ metadata.update(args.__dict__)
+
+ # pickle.dump(dataset, open(join(model_base, "dataset.pkl"), "wb"))
+ import pickle
+ pickle.dump(trainer.get_embeddings(), open(join(model_base, metadata['layers']), "wb"))
+
+ with open(join(model_base, "metadata.json"), "w") as mdata:
+ mdata.write(json.dumps(metadata, indent=4))
+
+ print("Done saving")
+
+
+if __name__ == "__main__":
+
+ import argparse
+
+ parser = argparse.ArgumentParser(description='')
+ add_gnn_train_args(parser)
+
+ args = parser.parse_args()
+ verify_arguments(args)
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s")
+
+ models_ = {
+ # GCNSampling: gcnsampling_params,
+ # GATSampler: gatsampling_params,
+ # RGCNSampling: rgcnsampling_params,
+ # RGAN: rgcnsampling_params,
+ RGGAN: rggan_params
+
+ }
+
+ if not isdir(args.model_output_dir):
+ mkdir(args.model_output_dir)
+
+ main(models_, args)
diff --git a/scripts/training/params.py b/scripts/training/params.py
index 85444fce..0a5a350c 100644
--- a/scripts/training/params.py
+++ b/scripts/training/params.py
@@ -68,12 +68,12 @@
rggan_grids = [
{
- 'h_dim': [100],
- 'num_bases': [-1],
- 'num_steps': [5],
- 'dropout': [0.0],
- 'use_self_loop': [False],
- 'activation': [torch.nn.functional.leaky_relu], # torch.nn.functional.hardswish], #[torch.nn.functional.hardtanh], #torch.nn.functional.leaky_relu
+ 'h_dim': [100], # set from cli
+ 'num_bases': [10],
+ 'num_steps': [9], # set from cli
+ 'dropout': [0.2],
+ 'use_self_loop': [True],
+ 'activation': [torch.tanh], # torch.nn.functional.hardswish], #[torch.nn.functional.hardtanh], #torch.nn.functional.leaky_relu
'lr': [1e-3], # 1e-4]
}
]
diff --git a/scripts/training/substitute_model.py b/scripts/training/substitute_model.py
new file mode 100644
index 00000000..6002b04b
--- /dev/null
+++ b/scripts/training/substitute_model.py
@@ -0,0 +1,120 @@
+import json
+import logging
+from copy import copy
+from datetime import datetime
+from os import mkdir
+from os.path import isdir, join
+
+from SourceCodeTools.code.data.sourcetrail.Dataset import read_or_create_dataset
+from SourceCodeTools.models.graph import RGCNSampling, RGAN, RGGAN
+from SourceCodeTools.models.graph.train.utils import get_name, get_model_base
+from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments
+from params import rgcnsampling_params, rggan_params
+
+
+def main(models, args):
+ for model, param_grid in models.items():
+ for params in param_grid:
+
+ if args.h_dim is None:
+ params["h_dim"] = args.node_emb_size
+ else:
+ params["h_dim"] = args.h_dim
+
+ params["num_steps"] = args.n_layers
+
+ date_time = str(datetime.now())
+ print("\n\n")
+ print(date_time)
+ print(f"Model: {model.__name__}, Params: {params}")
+
+ model_attempt = get_name(model, date_time)
+
+ model_base = get_model_base(args, model_attempt)
+
+ dataset = read_or_create_dataset(args=args, model_base=model_base)
+
+ if args.external_dataset is not None:
+ external_args = copy(args)
+ external_args.data_path = external_args.external_dataset
+ external_args.external_model_base = get_model_base(external_args, model_attempt, force_new=True)
+ def load_external_dataset():
+ return external_args, read_or_create_dataset(args=external_args, model_base=external_args.external_model_base, force_new=True)
+ else:
+ load_external_dataset = None
+
+ def write_params(args, params):
+ args = copy(args.__dict__)
+ args.update(params)
+ args['activation'] = args['activation'].__name__
+ with open(join(model_base, "params.json"), "w") as mdata:
+ mdata.write(json.dumps(args, indent=4))
+
+ if not args.restore_state:
+ write_params(args, params)
+
+ from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure
+
+ trainer, scores = \
+ training_procedure(dataset, model, copy(params), args, model_base, load_external_dataset=load_external_dataset)
+
+ if load_external_dataset is not None:
+ model_base = external_args.external_model_base
+
+ trainer.save_checkpoint(model_base)
+
+ print("Saving...")
+
+ params['activation'] = params['activation'].__name__
+
+ metadata = {
+ "base": model_base,
+ "name": model_attempt,
+ "parameters": params,
+ "layers": "embeddings.pkl",
+ "mappings": "nodes.csv",
+ "state": "state_dict.pt",
+ "scores": scores,
+ "time": date_time,
+ }
+
+ metadata.update(args.__dict__)
+
+            # pickle.dump(dataset, open(join(model_base, "dataset.pkl"), "wb"))
+            import pickle
+            with open(join(model_base, metadata['layers']), "wb") as embeddings_file:
+                pickle.dump(trainer.get_embeddings(), embeddings_file)
+
+ with open(join(model_base, "metadata.json"), "w") as mdata:
+ mdata.write(json.dumps(metadata, indent=4))
+
+ print("Done saving")
+
+
+if __name__ == "__main__":
+
+ import argparse
+
+ parser = argparse.ArgumentParser(description='')
+ add_gnn_train_args(parser)
+
+ args = parser.parse_args()
+ verify_arguments(args)
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s")
+
+ models_ = {
+ # GCNSampling: gcnsampling_params,
+ # GATSampler: gatsampling_params,
+ # RGCNSampling: rgcnsampling_params,
+ # RGAN: rgcnsampling_params,
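+        # only RGGAN is trained here; the commented entries are alternative models and their grids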
+        RGGAN: rggan_params,
+    }
+
+ if not isdir(args.model_output_dir):
+ mkdir(args.model_output_dir)
+
+ main(models_, args)
diff --git a/scripts/training/train.py b/scripts/training/train.py
index 840fa42f..c0ec8f09 100644
--- a/scripts/training/train.py
+++ b/scripts/training/train.py
@@ -8,6 +8,7 @@
from SourceCodeTools.code.data.sourcetrail.Dataset import read_or_create_dataset
from SourceCodeTools.models.graph import RGCNSampling, RGAN, RGGAN
from SourceCodeTools.models.graph.train.utils import get_name, get_model_base
+from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments
from params import rgcnsampling_params, rggan_params
@@ -40,24 +41,17 @@ def write_params(args, params):
with open(join(model_base, "params.json"), "w") as mdata:
mdata.write(json.dumps(args, indent=4))
- write_params(args, params)
+ if not args.restore_state:
+ write_params(args, params)
- if args.training_mode == "multitask":
+ from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure
- # if args.intermediate_supervision:
- # # params['use_self_loop'] = True # ????
- # from SourceCodeTools.models.graph.train.sampling_multitask_intermediate_supervision import training_procedure
- # else:
- from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure
+ trainer, scores = \
+ training_procedure(dataset, model, copy(params), args, model_base)
- trainer, scores = \
- training_procedure(dataset, model, copy(params), args, model_base)
+ trainer.save_checkpoint(model_base)
- trainer.save_checkpoint(model_base)
- else:
- raise ValueError("Unknown training mode:", args.training_mode)
-
- print("Saving...", end="")
+ print("Saving...")
params['activation'] = params['activation'].__name__
@@ -81,91 +75,18 @@ def write_params(args, params):
with open(join(model_base, "metadata.json"), "w") as mdata:
mdata.write(json.dumps(metadata, indent=4))
- print("done")
-
-
-def add_data_arguments(parser):
- parser.add_argument('--data_path', '-d', dest='data_path', default=None, help='Path to the files')
- parser.add_argument('--train_frac', dest='train_frac', default=0.9, type=float, help='')
- parser.add_argument('--filter_edges', dest='filter_edges', default=None, help='Edges filtered before training')
- parser.add_argument('--min_count_for_objectives', dest='min_count_for_objectives', default=5, type=int, help='')
- parser.add_argument('--packages_file', dest='packages_file', default=None, type=str, help='')
- parser.add_argument('--self_loops', action='store_true')
- parser.add_argument('--use_node_types', action='store_true')
- parser.add_argument('--use_edge_types', action='store_true')
- parser.add_argument('--restore_state', action='store_true')
- parser.add_argument('--no_global_edges', action='store_true')
- parser.add_argument('--remove_reverse', action='store_true')
-
-
-def add_pretraining_arguments(parser):
- parser.add_argument('--pretrained', '-p', dest='pretrained', default=None, help='')
- parser.add_argument('--tokenizer', '-t', dest='tokenizer', default=None, help='')
- parser.add_argument('--pretraining_phase', dest='pretraining_phase', default=0, type=int, help='')
-
-
-def add_training_arguments(parser):
- parser.add_argument('--embedding_table_size', dest='embedding_table_size', default=200000, type=int, help='Batch size')
- parser.add_argument('--random_seed', dest='random_seed', default=None, type=int, help='')
-
- parser.add_argument('--node_emb_size', dest='node_emb_size', default=100, type=int, help='')
- parser.add_argument('--elem_emb_size', dest='elem_emb_size', default=100, type=int, help='')
- parser.add_argument('--num_per_neigh', dest='num_per_neigh', default=10, type=int, help='')
- parser.add_argument('--neg_sampling_factor', dest='neg_sampling_factor', default=3, type=int, help='')
-
- parser.add_argument('--use_layer_scheduling', action='store_true')
- parser.add_argument('--schedule_layers_every', dest='schedule_layers_every', default=10, type=int, help='')
-
- parser.add_argument('--epochs', dest='epochs', default=100, type=int, help='Number of epochs')
- parser.add_argument('--batch_size', dest='batch_size', default=128, type=int, help='Batch size')
-
- parser.add_argument("--h_dim", dest="h_dim", default=None, type=int)
- parser.add_argument("--n_layers", dest="n_layers", default=5, type=int)
- parser.add_argument("--objectives", dest="objectives", default=None, type=str)
-
-
-def add_scoring_arguments(parser):
- parser.add_argument('--measure_ndcg', action='store_true')
- parser.add_argument('--dilate_ndcg', dest='dilate_ndcg', default=200, type=int, help='')
-
-
-def add_performance_arguments(parser):
- parser.add_argument('--no_checkpoints', dest="save_checkpoints", action='store_false')
-
- parser.add_argument('--use_gcn_checkpoint', action='store_true')
- parser.add_argument('--use_att_checkpoint', action='store_true')
- parser.add_argument('--use_gru_checkpoint', action='store_true')
-
-
-def add_train_args(parser):
- parser.add_argument(
- '--training_mode', '-tr', dest='training_mode', default=None,
- help='Selects one of training procedures [multitask]'
- )
-
- add_data_arguments(parser)
- add_pretraining_arguments(parser)
- add_training_arguments(parser)
- add_scoring_arguments(parser)
- add_performance_arguments(parser)
-
- parser.add_argument('--note', dest='note', default="", help='Note, added to metadata')
- parser.add_argument('model_output_dir', help='Location of the final model')
+ print("Done saving")
- # parser.add_argument('--intermediate_supervision', action='store_true')
- parser.add_argument('--gpu', dest='gpu', default=-1, type=int, help='')
-def verify_arguments(args):
- pass
if __name__ == "__main__":
import argparse
- parser = argparse.ArgumentParser(description='Process some integers.')
- add_train_args(parser)
+ parser = argparse.ArgumentParser(description='')
+ add_gnn_train_args(parser)
args = parser.parse_args()
verify_arguments(args)
diff --git a/scripts/training/train_all_emb_types.sh b/scripts/training/train_all_emb_types.sh
index 4ddd0545..1e2dd656 100755
--- a/scripts/training/train_all_emb_types.sh
+++ b/scripts/training/train_all_emb_types.sh
@@ -1,7 +1,16 @@
+DATASET=$1
+DGLKE_OUT=$2
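+# $1 - dataset location (consumed by prepare_dglke_format.py)
+# $2 - directory with the dgl-ke edge splits; trained embeddings are saved under it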
+
+#conda activate SourceCodeTools
+#prepare_dglke_format.py $DATASET $DGLKE_OUT
+#conda deactivate
+
conda activate dglke
-dglke_train --model_name TransR --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path transr_u --hidden_dim 100 --num_thread 4 --gpu 0
-dglke_train --model_name RESCAL --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path RESCAL_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 160000
-dglke_train --model_name DistMult --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path DistMult_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 400000
-dglke_train --model_name ComplEx --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path ComplEx_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 400000
-dglke_train --model_name RotatE --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path RotatE_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 200000 -de
+
+dglke_train --model_name TransR --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/transr_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name RESCAL --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/RESCAL_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name DistMult --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/DistMult_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name ComplEx --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/ComplEx_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name RotatE --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/RotatE_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -de -adv --test
+
conda deactivate
diff --git a/setup.py b/setup.py
index 7262d3af..bbff5d40 100644
--- a/setup.py
+++ b/setup.py
@@ -1,31 +1,33 @@
from distutils.core import setup
-requitements = [
- 'nltk',
- 'tensorflow>=2.4.0',
- 'torch>=1.7.1',
- 'pandas>=1.1.1',
- 'sklearn',
- 'sentencepiece',
- 'gensim',
- 'numpy>=1.19.2',
- 'scipy',
- 'networkx',
- 'sacrebleu',
- 'datasets',
+requirements = [
+ 'nltk==3.6',
+ 'tensorflow==2.6.0',
+ 'torch==1.9.0',
+ 'pandas==1.1.1',
+ 'scikit-learn==1.0',
+ 'sentencepiece==0.1.96',
+ 'gensim==3.8',
+ 'numpy==1.19.5',
+ 'scipy==1.4.1',
+ 'networkx==2.5',
+ 'sacrebleu==1.5.1',
+ 'datasets==1.5.0',
'spacy==2.3.2',
- 'pytest'
+ 'pytest==6.1.2',
+ 'faiss-cpu==1.7.0'
# 'pygraphviz'
# 'javac_parser'
]
-
+# conda install pytorch cudatoolkit=11.1 dgl-cuda11.1 -c dglteam -c pytorch -c nvidia
setup(name='SourceCodeTools',
- version='0.0.2',
+ version='0.0.3',
py_modules=['SourceCodeTools'],
- install_requires=requitements + ["dgl"],
+ install_requires=requirements + ["dgl==0.7.1"],
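+    # `pip install -e .[gpu]` pulls the CUDA 11.1 build of DGL instead of the CPU one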
extras_require={
- "gpu": requitements + ["dgl-cu110"]
+ "gpu": requirements + ["dgl-cu111==0.7.1"]
},
+ dependency_links=['https://data.dgl.ai/wheels/repo.html'],
scripts=[
'SourceCodeTools/code/data/sourcetrail/sourcetrail_call_seq_extractor.py',
'SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py',
@@ -47,5 +49,6 @@
'SourceCodeTools/code/data/sourcetrail/pandas_format_converter.py',
'SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py',
'SourceCodeTools/nlp/embed/converters/convert_fasttext_format_bin_to_vec.py',
+ 'SourceCodeTools/models/graph/utils/prepare_dglke_format.py',
],
-)
+ )