diff --git a/README.md b/README.md
index 40389212..1dea0b64 100644
--- a/README.md
+++ b/README.md
@@ -8,10 +8,100 @@ Library for analyzing source code with graphs and NLP. What this repository can
 5. Train Graph Neural Network for learning representations for source code
 6. Predict Python types using NLP and graph embeddings
-### Installation
+For more details, consult our [wiki](https://github.com/VitalyRomanov/method-embedding/wiki).
+### Installing Python Libraries
+
+Use conda to create a virtual environment `SourceCodeTools` with Python 3.8
+```bash
+conda create -n SourceCodeTools python=3.8
+```
+
+If you plan to use graphviz
+```bash
+conda install -c conda-forge pygraphviz graphviz
+```
+
+Install CUDA 11.1 if needed
+```bash
+conda install -c nvidia cudatoolkit=11.1
+```
+
+To install the SourceCodeTools library, run
 ```bash
 git clone https://github.com/VitalyRomanov/method-embedding.git
 cd method-embedding
 pip install -e .
-```
\ No newline at end of file
+# pip install -e .[gpu]
+```
+
+### Processing Python Code
+
+Source code should be structured in the following way
+```
+source_code_data
+│
+└───package1
+│   │───source_file_1.py
+│   │───source_file_2.py
+│   └───subfolder_if_needed
+│       │───source_file_3.py
+│       └───source_file_4.py
+│
+└───package2
+    │───source_file_1.py
+    └───source_file_2.py
+```
+An example of source code data can be found in this repository at `method-embedding/res/python_testdata/example_code`. A package should contain self-sufficient code together with its dependencies. Unmet dependencies will be labeled as non-indexed symbols.
+
+#### Indexing with Docker
+To create a dataset, you first need to index the source code with Sourcetrail. The easiest way to do this is with a Docker container
+```bash
+docker run -it -v "/full/path/to/data/folder":/dataset mortiv16/sourcetrail_indexer
+```
+
+#### Creating the graph
+You need to provide a sentencepiece model for subtokenization. A model trained on CodeSearchNet can be downloaded [here](https://www.dropbox.com/s/cw7oxkzicgnkzgb/sentencepiece_bpe.model?dl=1).
+```bash
+SCT=/path/to/SourceCodeTool_repository
+SOURCE_CODE=/path/to/source/code/indexed/with/sourcetrail
+DATASET_OUTPUT=/path/to/dataset/output
+python $SCT/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py --bpe_tokenizer sentencepiece_bpe.model --track_offsets --do_extraction $SOURCE_CODE $DATASET_OUTPUT
+```
+
+The graph dataset format is [described in the wiki](https://github.com/VitalyRomanov/method-embedding/wiki/04.-Graph-Format-Description)
+```
+graph_dataset
+│
+└───no_ast
+│   │───common_call_seq.bz2
+│   │───common_edges.bz2
+│   │───common_function_variable_pairs.bz2
+│   │───common_nodes.bz2
+│   │───common_source_graph_bodies.bz2
+│   └───node_names.bz2
+│
+└───with_ast
+    │───common_call_seq.bz2
+    │───common_edges.bz2
+    │───common_function_variable_pairs.bz2
+    │───common_nodes.bz2
+    │───common_source_graph_bodies.bz2
+    └───node_names.bz2
+```
+
+`no_ast` contains the graph built from global relationships only. `with_ast` contains the graph with AST nodes and edges. The two main files for building the graph are `common_nodes.bz2` and `common_edges.bz2`. The files are stored as pickled pandas tables (read with `pandas.read_pickle`) and are probably not portable between platforms.
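Since these are ordinary pickled pandas DataFrames, they can also be inspected directly from Python. A minimal sketch is shown below; the dataset paths are an assumption and should point at wherever the dataset was generated:

```python
import pandas as pd

# Assumed paths, following the graph_dataset layout above; adjust as needed
nodes = pd.read_pickle("graph_dataset/with_ast/common_nodes.bz2")
edges = pd.read_pickle("graph_dataset/with_ast/common_edges.bz2")

print(nodes.head())  # peek at the node table
print(edges.head())  # peek at the edge table
```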
One can view the content by converting table into the `csv` format +```bash +python $SCT/SourceCodeTools/code/data/sourcetrail/pandas_format_converter.py common_nodes.bz2 csv +``` + +The graph data can be loaded as pandas tables using `load_data` function + +```python +from SourceCodeTools.code.data.dataset.Dataset import load_data + +nodes, edges = load_data( + node_path="path/to/common_nodes.bz2", + edge_path="path/to/common_edges.bz2" +) +``` diff --git a/SourceCodeTools/cli_arguments/__init__.py b/SourceCodeTools/cli_arguments/__init__.py new file mode 100644 index 00000000..e36ad3df --- /dev/null +++ b/SourceCodeTools/cli_arguments/__init__.py @@ -0,0 +1,52 @@ +import argparse + + +class AstDatasetCreatorArguments: + parser = None + + def __init__(self): + parser = argparse.ArgumentParser(description='Merge indexed environments into a single graph') + parser.add_argument('--language', "-l", dest="language", default="python", help='') + parser.add_argument('--bpe_tokenizer', '-bpe', dest='bpe_tokenizer', type=str, help='') + parser.add_argument('--create_subword_instances', action='store_true', default=False, help="") + parser.add_argument('--connect_subwords', action='store_true', default=False, + help="Takes effect only when `create_subword_instances` is False") + parser.add_argument('--only_with_annotations', action='store_true', default=False, help="") + parser.add_argument('--do_extraction', action='store_true', default=False, help="") + parser.add_argument('--visualize', action='store_true', default=False, help="") + parser.add_argument('--track_offsets', action='store_true', default=False, help="") + parser.add_argument('--recompute_l2g', action='store_true', default=False, help="") + parser.add_argument('--remove_type_annotations', action='store_true', default=False, help="") + + self.parser = parser + self.add_positional_argument() + self.additional_arguments() + + def add_positional_argument(self): + self.parser.add_argument( + 'source_code', help='Path to DataFrame csv.' + ) + self.parser.add_argument('output_directory', help='') + + def additional_arguments(self): + self.parser.add_argument('--chunksize', default=10000, type=int, help='Chunksize for preparing dataset. Larger chunks are faster to process, but they take more memory.') + self.parser.add_argument('--keep_frac', default=1.0, type=float, help="Fraction of the dataset to keep") + + def parse(self): + return self.parser.parse_args() + + +class DatasetCreatorArguments(AstDatasetCreatorArguments): + + def add_positional_argument(self): + self.parser.add_argument('indexed_environments', help='Path to environments indexed by sourcetrail') + self.parser.add_argument('output_directory', help='') + + def additional_arguments(self): + pass + + + + + + diff --git a/SourceCodeTools/code/IdentifierPool.py b/SourceCodeTools/code/IdentifierPool.py new file mode 100644 index 00000000..6b2667be --- /dev/null +++ b/SourceCodeTools/code/IdentifierPool.py @@ -0,0 +1,38 @@ +from os import urandom +from time import time_ns + + +class IdentifierPool: + """ + Creates identifier that is almost guaranteed to be unique. Beginning of identifier is based on + current time, and the tail of identifier is randomly generated. 
+ """ + def __init__(self): + self._used_identifiers = set() + + @staticmethod + def _get_candidate(): + return str(hex(int(time_ns())))[:12] + str(urandom(3).hex()) + # return "0x" + str(urandom(8).hex()) + + def get_new_identifier(self): + candidate = self._get_candidate() + while candidate in self._used_identifiers: + candidate = self._get_candidate() + self._used_identifiers.add(candidate) + return candidate + + +class IntIdentifierPool(IdentifierPool): + def __init__(self): + super().__init__() + + @staticmethod + def _get_candidate(): + candidate = str(int((str(hex(int(time_ns())))[:12] + str(urandom(3).hex())), 16)) + # assert len(candidate) == 19 + return candidate + # candidate = str(int(urandom(10).hex(), 16)) + # while len(candidate) < 19: + # candidate = str(int(urandom(10).hex(), 16)) + # return candidate[:19] \ No newline at end of file diff --git a/SourceCodeTools/nlp/entity/annotator/annotator_utils.py b/SourceCodeTools/code/annotator_utils.py similarity index 67% rename from SourceCodeTools/nlp/entity/annotator/annotator_utils.py rename to SourceCodeTools/code/annotator_utils.py index 3dfba4c6..469d95a7 100644 --- a/SourceCodeTools/nlp/entity/annotator/annotator_utils.py +++ b/SourceCodeTools/code/annotator_utils.py @@ -1,6 +1,34 @@ -from copy import copy +from copy import copy, deepcopy from typing import List, Tuple, Iterable +from SourceCodeTools.nlp import create_tokenizer +from spacy.gold import biluo_tags_from_offsets as spacy_biluo_tags_from_offsets + +from SourceCodeTools.nlp.tokenizers import codebert_to_spacy + + +def biluo_tags_from_offsets(doc, ents, no_localization): + ent_tags = spacy_biluo_tags_from_offsets(doc, ents) + + if no_localization: + tags = [] + for ent in ent_tags: + parts = ent.split("-") + + assert len(parts) <= 2 + + if len(parts) == 2: + if parts[0] == "B" or parts[0] == "U": + tags.append(parts[1]) + else: + tags.append("O") + else: + tags.append("O") + + ent_tags = tags + + return ent_tags + def get_cum_lens(body, as_bytes=False): """ @@ -57,10 +85,22 @@ def to_offsets(body: str, entities: Iterable[Iterable], as_bytes=False, cum_lens def adjust_offsets(offsets, amount): + """ + Adjust offset by subtracting certain amount from the start and end positions + :param offsets: iterable with offsets + :param amount: adjustment amount + :return: list of adjusted offsets + """ return [(offset[0] - amount, offset[1] - amount, offset[2]) for offset in offsets] def adjust_offsets2(offsets, amount): + """ + Adjust offset by adding certain amount to the start and end positions + :param offsets: iterable with offsets + :param amount: adjustment amount + :return: list of adjusted offsets + """ return [(offset[0] + amount, offset[1] + amount, offset[2]) for offset in offsets] @@ -143,4 +183,33 @@ def resolve_self_collisions2(offsets): no_collisions = list(set(no_collisions)) - return no_collisions \ No newline at end of file + return no_collisions + + +def align_tokens_with_graph(doc, spans, tokenzer_name): + spans = deepcopy(spans) + if tokenzer_name == "codebert": + backup_tokens = doc + doc, adjustment = codebert_to_spacy(doc) + spans = adjust_offsets(spans, adjustment) + + node_tags = biluo_tags_from_offsets(doc, spans, no_localization=False) + + if tokenzer_name == "codebert": + doc = [""] + [t.text for t in backup_tokens] + [""] + return doc, node_tags + + +def source_code_graph_alignment(source_codes, node_spans, tokenizer="codebert"): + supported_tokenizers = ["spacy", "codebert"] + assert tokenizer in supported_tokenizers, f"Only these tokenizers 
supported for alignment: {supported_tokenizers}" + nlp = create_tokenizer(tokenizer) + + for code, spans in zip(source_codes, node_spans): + yield align_tokens_with_graph(nlp(code), resolve_self_collisions2(spans), tokenzer_name=tokenizer) + + +def map_offsets(column, id_map): + def map_entry(entry): + return [(e[0], e[1], id_map[e[2]]) for e in entry] + return [map_entry(entry) for entry in column] \ No newline at end of file diff --git a/SourceCodeTools/code/ast/__init__.py b/SourceCodeTools/code/ast/__init__.py new file mode 100644 index 00000000..2e067ec7 --- /dev/null +++ b/SourceCodeTools/code/ast/__init__.py @@ -0,0 +1,9 @@ +import ast + + +def has_valid_syntax(function_body): + try: + ast.parse(function_body.lstrip()) + return True + except SyntaxError: + return False diff --git a/SourceCodeTools/code/ast_tools.py b/SourceCodeTools/code/ast/ast_tools.py similarity index 85% rename from SourceCodeTools/code/ast_tools.py rename to SourceCodeTools/code/ast/ast_tools.py index 98b5571d..ee3d09d6 100644 --- a/SourceCodeTools/code/ast_tools.py +++ b/SourceCodeTools/code/ast/ast_tools.py @@ -1,6 +1,6 @@ import ast -from SourceCodeTools.nlp.entity.annotator.annotator_utils import to_offsets, resolve_self_collision, adjust_offsets2 +from SourceCodeTools.code.annotator_utils import to_offsets, resolve_self_collision, adjust_offsets2 def get_mentions(function, root, mention): @@ -30,9 +30,18 @@ def get_mentions(function, root, mention): def get_descendants(function, children): + """ + :param function: function string + :param children: List of targets. + :return: Offsets for attributes or names that are used as target for assignment operation. Subscript, Tuple and List + targets are skipped. + """ descendants = [] + # if isinstance(children, ast.Tuple): + # descendants.extend(get_descendants(function, children.elts)) + # else: for chld in children: # for node in ast.walk(chld): node = chld @@ -42,6 +51,10 @@ def get_descendants(function, children): [(node.lineno-1, node.end_lineno-1, node.col_offset, node.end_col_offset, "new_var")], as_bytes=True) # descendants.append((node.id, offset[-1])) descendants.append((function[offset[-1][0]:offset[-1][1]], offset[-1])) + # elif isinstance(node, ast.Tuple): + # descendants.extend(get_descendants(function, node.elts)) + elif isinstance(node, ast.Subscript) or isinstance(node, ast.Tuple) or isinstance(node, ast.List): + pass # skip for now else: raise Exception("") diff --git a/SourceCodeTools/code/python_ast.py b/SourceCodeTools/code/ast/python_ast.py similarity index 87% rename from SourceCodeTools/code/python_ast.py rename to SourceCodeTools/code/ast/python_ast.py index 07b8f7aa..58927e49 100644 --- a/SourceCodeTools/code/python_ast.py +++ b/SourceCodeTools/code/ast/python_ast.py @@ -6,6 +6,7 @@ from collections.abc import Iterable import pandas as pd # import os +from SourceCodeTools.code.IdentifierPool import IdentifierPool class PythonSyntheticNodeTypes(Enum): # TODO NOT USED @@ -76,6 +77,7 @@ class GNode: # id = None def __init__(self, **kwargs): + self.string = None for k, v in kwargs.items(): setattr(self, k, v) @@ -104,38 +106,53 @@ def __init__(self, source): self.condition_status = [] self.scope = [] - def get_source_from_ast_range(self, start_line, end_line, start_col, end_col): + self._identifier_pool = IdentifierPool() + + def get_source_from_ast_range(self, start_line, end_line, start_col, end_col, strip=True): source = "" num_lines = end_line - start_line + 1 if start_line == end_line: - source += self.source[start_line - 
1].encode("utf8")[start_col:end_col].decode( - "utf8").strip() + section = self.source[start_line - 1].encode("utf8")[start_col:end_col].decode( + "utf8") + source += section.strip() if strip else section + "\n" else: for ind, lineno in enumerate(range(start_line - 1, end_line)): if ind == 0: - source += self.source[lineno].encode("utf8")[start_col:].decode( - "utf8").strip() + section = self.source[lineno].encode("utf8")[start_col:].decode( + "utf8") + source += section.strip() if strip else section + "\n" elif ind == num_lines - 1: - source += self.source[lineno].encode("utf8")[:end_col].decode( - "utf8").strip() + section = self.source[lineno].encode("utf8")[:end_col].decode( + "utf8") + source += section.strip() if strip else section + "\n" else: - source += self.source[lineno].strip() + section = self.source[lineno] + source += section.strip() if strip else section + "\n" - return source + return source.rstrip() def get_name(self, *, node=None, name=None, type=None, add_random_identifier=False): + random_identifier = self._identifier_pool.get_new_identifier() + if node is not None: - name = node.__class__.__name__ + "_" + str(hex(int(time_ns()))) + name = f"{node.__class__.__name__}_{random_identifier}" type = node.__class__.__name__ else: if add_random_identifier: - name += f"_{str(hex(int(time_ns())))}" + name = f"{name}_{random_identifier}" + + if hasattr(node, "lineno"): + node_string = self.get_source_from_ast_range(node.lineno, node.end_lineno, node.col_offset, node.end_col_offset, strip=False) + # if "\n" in node_string: + # node_string = None + else: + node_string = None if len(self.scope) > 0: - return GNode(name=name, type=type, scope=copy(self.scope[-1])) + return GNode(name=name, type=type, scope=copy(self.scope[-1]), string=node_string) else: - return GNode(name=name, type=type) + return GNode(name=name, type=type, string=node_string) # return (node.__class__.__name__ + "_" + str(hex(int(time_ns()))), node.__class__.__name__) def get_edges(self, as_dataframe=True): @@ -170,12 +187,14 @@ def parse_body(self, nodes): for node in nodes: s = self.parse(node) if isinstance(s, tuple): + if s[1].type == "Constant": # this happens when processing docstring, as a result a lot of nodes are connected to the node Constant_ + continue # in general, constant node has no affect as a body expression, can skip # some parsers return edges and names? 
edges.extend(s[0]) if last_node is not None: - edges.append({"dst": s[1], "src": last_node, "type": "next"}) - edges.append({"dst": last_node, "src": s[1], "type": "prev"}) + edges.append({"dst": s[1], "src": last_node, "type": "next", "scope": copy(self.scope[-1])}) + edges.append({"dst": last_node, "src": s[1], "type": "prev", "scope": copy(self.scope[-1])}) last_node = s[1] @@ -278,9 +297,9 @@ def generic_parse(self, node, operands, with_name=None, ensure_iterables=False): else: node_name = with_name - if len(self.scope) > 0: - edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": self.scope[-1], "type": "mention_scope"}) - edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": node_name, "type": "mention_scope_rev"}) + # if len(self.scope) > 0: + # edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": self.scope[-1], "type": "mention_scope"}) + # edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": node_name, "type": "mention_scope_rev"}) for operand in operands: if operand in ["body", "orelse", "finalbody"]: @@ -414,6 +433,8 @@ def parse_ImportFrom(self, node): # # edges.extend(edges_from) # # edges.append({"scope": copy(self.scope[-1]), "src": name_from, "dst": name, "type": "module"}) # return edges, name + if node.module is not None: + node.module = ast.Name(node.module) return self.generic_parse(node, ["module", "names"]) def parse_Import(self, node): @@ -452,6 +473,10 @@ def parse_alias(self, node): # # # if node.asname: # # # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": node.asname, "type": "alias"}) # return edges, name + if node.name is not None: + node.name = ast.Name(node.name) + if node.asname is not None: + node.asname = ast.Name(node.asname) return self.generic_parse(node, ["name", "asname"]) def parse_arg(self, node): @@ -476,9 +501,11 @@ def parse_arg(self, node): ) annotation = GNode(name=annotation_string, type="type_annotation") - edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) - # do not use reverse edges for types, will result in leak from function to function - # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'}) + mention_name = GNode(name=node.arg + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1])) + edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": mention_name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) + # edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) + # # do not use reverse 
edges for types, will result in leak from function to function + # # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'}) return edges, name # return self.generic_parse(node, ["arg", "annotation"]) @@ -503,9 +530,15 @@ def parse_AnnAssign(self, node): annotation = GNode(name=annotation_string, type="type_annotation") edges, name = self.generic_parse(node, ["target"]) - edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) - # do not use reverse edges for types, will result in leak from function to function - # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'}) + try: + mention_name = GNode(name=node.target.id + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1])) + edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": mention_name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) + except Exception as e: + edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) + # print(e) # don't know how I should parse this "Attribute(value=Name(id='self', ctx=Load()), attr='srctrlrpl_1631733463030025000@#attr#', ctx=Store())" + # edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) + # # do not use reverse edges for types, will result in leak from function to function + # # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'}) return edges, name # return self.generic_parse(node, ["target", "annotation"]) @@ -854,7 +887,7 @@ def parse_arguments(self, node): # arguments(args=[arg(arg='self', annotation=None), arg(arg='tqdm_cls', annotation=None), arg(arg='sleep_interval', annotation=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]) # vararg constains type annotations - return self.generic_parse(node, ["args", "vararg"]) + return self.generic_parse(node, ["args", "vararg"]) # kwarg, kwonlyargs, posonlyargs??? 
def parse_comprehension(self, node): edges = [] @@ -862,9 +895,9 @@ def parse_comprehension(self, node): cph_name = self.get_name(name="comprehension", type="comprehension", add_random_identifier=True) # cph_name = GNode(name="comprehension_" + str(hex(int(time_ns()))), type="comprehension") - if len(self.scope) > 0: - edges.append({"scope": copy(self.scope[-1]), "src": cph_name, "dst": self.scope[-1], "type": "mention_scope"}) - edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": cph_name, "type": "mention_scope_rev"}) + # if len(self.scope) > 0: + # edges.append({"scope": copy(self.scope[-1]), "src": cph_name, "dst": self.scope[-1], "type": "mention_scope"}) + # edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": cph_name, "type": "mention_scope_rev"}) target, ext_edges = self.parse_operand(node.target) edges.extend(ext_edges) diff --git a/SourceCodeTools/code/ast/python_ast2.py b/SourceCodeTools/code/ast/python_ast2.py new file mode 100644 index 00000000..0c271c0a --- /dev/null +++ b/SourceCodeTools/code/ast/python_ast2.py @@ -0,0 +1,1005 @@ +import ast +import logging +from copy import copy +from enum import Enum +from itertools import chain +from pprint import pprint +from time import time_ns +from collections.abc import Iterable +import pandas as pd +# import os +from SourceCodeTools.code.IdentifierPool import IdentifierPool + + +class PythonNodeEdgeDefinitions: + ast_node_type_edges = { + "Assign": ["value", "targets"], + "AugAssign": ["target", "op", "value"], + "Import": ["names"], + "alias": ["name", "asname"], + "ImportFrom": ["module", "names"], + "Delete": ["targets"], + "Global": ["names"], + "Nonlocal": ["names"], + "withitem": ["context_expr", "optional_vars"], + "Subscript": ["value", "slice"], + "Slice": ["lower", "upper", "step"], + "ExtSlice": ["dims"], + "Index": ["value"], + "Starred": ["value"], + "Yield": ["value"], + "ExceptHandler": ["type"], + "Call": ["func", "args", "keywords"], + "Compare": ["left", "ops", "comparators"], + "BoolOp": ["values", "op"], + "Assert": ["test", "msg"], + "List": ["elts"], + "Tuple": ["elts"], + "Set": ["elts"], + "UnaryOp": ["operand", "op"], + "BinOp": ["left", "right", "op"], + "Await": ["value"], + "GeneratorExp": ["elt", "generators"], + "ListComp": ["elt", "generators"], + "SetComp": ["elt", "generators"], + "DictComp": ["key", "value", "generators"], + "Return": ["value"], + "Raise": ["exc", "cause"], + "YieldFrom": ["value"], + } + + overriden_node_type_edges = { + "Module": [], # overridden + "FunctionDef": ["function_name", "args", "decorator_list", "returned_by"], # overridden, `function_name` replaces `name`, `returned_by` replaces `returns` + "AsyncFunctionDef": ["function_name", "args", "decorator_list", "returned_by"], # overridden, `function_name` replaces `name`, `returned_by` replaces `returns` + "ClassDef": ["class_name"], # overridden, `class_name` replaces `name` + "AnnAssign": ["target", "value", "annotation_for"], # overridden, `annotation_for` replaces `annotation` + "With": ["items"], # overridden + "AsyncWith": ["items"], # overridden + "arg": ["arg", "annotation_for"], # overridden, `annotation_for` is custom + "Lambda": [], # overridden + "IfExp": ["test", "if_true", "if_false"], # overridden, `if_true` renamed from `body`, `if_false` renamed from `orelse` + "keyword": ["arg", "value"], # overridden + "Attribute": ["value", "attr"], # overridden + "Num": [], # overridden + "Str": [], # overridden + "Bytes": [], # overridden + "If": ["test"], # overridden + 
"For": ["target", "iter"], # overridden + "AsyncFor": ["target", "iter"], # overridden + "Try": [], # overridden + "While": [], # overridden + "Expr": ["value"], # overridden + "Dict": ["keys", "values"], # overridden + "JoinedStr": [], # overridden + "FormattedValue": ["value"], # overridden + "arguments": ["args", "vararg", "kwarg", "kwonlyargs", "posonlyargs"], # overridden + "comprehension": ["target", "iter", "ifs"], # overridden, `target_for` is custom, `iter_for` is customm `ifs_rev` is custom + } + + context_edge_names = { + "Module": ["defined_in_module"], + "FunctionDef": ["defined_in_function"], + "ClassDef": ["defined_in_class"], + "With": ["executed_inside_with"], + "AsyncWith": ["executed_inside_with"], + "If": ["executed_if_true", "executed_if_false"], + "For": ["executed_in_for", "executed_in_for_orelse"], + "AsyncFor": ["executed_in_for", "executed_in_for_orelse"], + "While": ["executed_in_while", "executed_while_true"], + "Try": ["executed_in_try", "executed_in_try_final", "executed_in_try_else", "executed_in_try_except", "executed_with_try_handler"], + } + + extra_edge_types = { + "control_flow", "next", "local_mention", + } + + # exceptions needed when we do not want to filter some edge types using a simple rule `_rev` + reverse_edge_exceptions = { + # "target": "target_for", + # "iter": "iter_for", # mainly used in comprehension + # "ifs": "ifs_for", # mainly used in comprehension + "next": "prev", + "local_mention": None, # from name to variable mention + "returned_by": None, # for type annotations + "annotation_for": None, # for type annotations + "control_flow": None, # for control flow + "op": None, # for operations + "attr": None, # for attributes + # "arg": None # for keywords ??? + } + + iterable_nodes = { # parse_iterable + "List", "Tuple", "Set" + } + + named_nodes = { + "Name", "NameConstant" # parse_name + } + + constant_nodes = { + "Constant" # parse_Constant + } + + operand_nodes = { # parse_op_name + "And", "Or", "Not", "Is", "Gt", "Lt", "GtE", "LtE", "Eq", "NotEq", "Ellipsis", "Add", "Mod", + "Sub", "UAdd", "USub", "Div", "Mult", "MatMult", "Pow", "FloorDiv", "RShift", "LShift", "BitAnd", + "BitOr", "BitXor", "IsNot", "NotIn", "In", "Invert" + } + + control_flow_nodes = { # parse_control_flow + "Continue", "Break", "Pass" + } + + # extra node types exist for keywords and attributes to prevent them from + # getting mixed with local variable mentions + extra_node_types = { + "#keyword#", + "#attr#" + } + + @classmethod + def regular_node_types(cls): + return set(cls.ast_node_type_edges.keys()) + + @classmethod + def overridden_node_types(cls): + return set(cls.overriden_node_type_edges.keys()) + + @classmethod + def node_types(cls): + return list( + cls.regular_node_types() | + cls.overridden_node_types() | + cls.iterable_nodes | cls.named_nodes | cls.constant_nodes | + cls.operand_nodes | cls.control_flow_nodes | cls.extra_node_types + ) + + @classmethod + def scope_edges(cls): + return set(map(lambda x: x, chain(*cls.context_edge_names.values()))) # "defined_in_" + + + @classmethod + def auxiliary_edges(cls): + direct_edges = cls.scope_edges() | cls.extra_edge_types + reverse_edges = cls.compute_reverse_edges(direct_edges) + return direct_edges | reverse_edges + + @classmethod + def compute_reverse_edges(cls, direct_edges): + reverse_edges = set() + for edge in direct_edges: + if edge in cls.reverse_edge_exceptions: + reverse = cls.reverse_edge_exceptions[edge] + if reverse is not None: + reverse_edges.add(reverse) + else: + reverse_edges.add(edge + 
"_rev") + return reverse_edges + + @classmethod + def edge_types(cls): + direct_edges = list( + set(chain(*cls.ast_node_type_edges.values())) | + set(chain(*cls.overriden_node_type_edges.values())) | + cls.scope_edges() | + cls.extra_edge_types | cls.named_nodes | cls.constant_nodes | + cls.operand_nodes | cls.control_flow_nodes | cls.extra_node_types + ) + + reverse_edges = list(cls.compute_reverse_edges(direct_edges)) + return direct_edges + reverse_edges + + +PythonAstNodeTypes = Enum( + "PythonAstNodeTypes", + " ".join( + PythonNodeEdgeDefinitions.node_types() + ) +) + + +PythonAstEdgeTypes = Enum( + "PythonAstEdgeTypes", + " ".join( + PythonNodeEdgeDefinitions.edge_types() + ) +) + + +class PythonCodeExamplesForNodes: + examples = { + "FunctionDef": + "def f(a):\n" + " return a\n", + "ClassDef": + "class C:\n" + " def m():\n" + " pass\n", + "AnnAssign": "a: int = 5\n", + "With": + "with open(a) as b:\n" + " do_stuff(b)\n", + "arg": + "def f(a: int = 5):\n" + " return a\n", + "Lambda": "lambda x: x + 3\n", + "IfExp": "a = 5 if True else 0\n", + "keyword": "fn(a=5, b=4)\n", + "Attribute": "a.b.c\n", + "If": + "if d is True:\n" + " a = b\n" + "else:\n" + " a = c\n", + "For": + "for i in list:\n" + " k = fn(i)\n" + " if k == 4:\n" + " fn2(k)\n" + " break\n" + "else:\n" + " fn2(0)\n", + "Try": + "try:\n" + " a = b\n" + "except Exception as e:\n" + " a = c\n" + "else:\n" + " a = d\n" + "finally:\n" + " print(a)\n", + "While": + "while b = c:\n" + " do_iter(b)\n", + "Dict": "{a:b, c:d}\n", + "comprehension": "[i for i in list if i != 5]\n", + "BinOp": "c = a + b\n", + "ImportFrom": "from module import Class\n", + "alias": "import module as m\n", + "List": "a = [1, 2, 3, 4]\n" + } + + +def generate_available_edges(): + node_types = PythonNodeEdgeDefinitions.node_types() + for nt in sorted(node_types): + if hasattr(ast, nt): + fl = sorted(getattr(ast, nt)._fields) + if len(fl) == 0: + print(nt, ) + else: + for f in fl: + print(nt, f, sep=" ") + + +def generate_utilized_edges(): + d = dict() + d.update(PythonNodeEdgeDefinitions.ast_node_type_edges) + d.update(PythonNodeEdgeDefinitions.overriden_node_type_edges) + for nt in sorted(d.keys()): + if hasattr(ast, nt): + fl = sorted(d[nt]) + if len(fl) == 0: + print(nt, ) + else: + for f in fl: + print(nt, f, sep=" ") + + +class PythonSharedNodes: + annotation_types = {"type_annotation", "returned_by"} + tokenizable_types = {"Name", "#attr#", "#keyword#"} + python_token_types = {"Op", "Constant", "JoinedStr", "CtlFlow", "ast_Literal"} + subword_types = {'subword'} + + subword_leaf_types = annotation_types | subword_types | python_token_types + named_leaf_types = annotation_types | tokenizable_types | python_token_types + tokenizable_types_and_annotations = annotation_types | tokenizable_types + + shared_node_types = annotation_types | subword_types | tokenizable_types | python_token_types + + # leaf_types = {'subword', "Op", "Constant", "JoinedStr", "CtlFlow", "ast_Literal", "Name", "type_annotation", "returned_by"} + # shared_node_types = {'subword', "Op", "Constant", "JoinedStr", "CtlFlow", "ast_Literal", "Name", "type_annotation", "returned_by", "#attr#", "#keyword#"} + + @classmethod + def is_shared(cls, node): + # nodes that are of stared type + # nodes that are subwords of keyword arguments + return cls.is_shared_name_type(node.name, node.type) + + @classmethod + def is_shared_name_type(cls, name, type): + if type in cls.shared_node_types or \ + (type == "subword_instance" and "0x" not in name): + return True + return False + + +class GNode: + 
def __init__(self, **kwargs): + self.string = None + for k, v in kwargs.items(): + setattr(self, k, v) + + def __eq__(self, other): + return self.name == other.name and self.type == other.type + + def __repr__(self): + return self.__dict__.__repr__() + + def __hash__(self): + return (self.name, self.type).__hash__() + + def setprop(self, key, value): + setattr(self, key, value) + + +class GEdge: + def __init__(self, src, dst, type, scope=None, line=None, end_line=None, col_offset=None, end_col_offset=None): + self.src = src + self.dst = dst + self.type = type + self.line = line + self.scope = scope + self.end_line = end_line + self.col_offset = col_offset + self.end_col_offset = end_col_offset + + def __getitem__(self, item): + return self.__dict__[item] + + +class AstGraphGenerator(object): + + def __init__(self, source, add_reverse_edges=True): + self.source = source.split("\n") # lines of the source code + self.root = ast.parse(source) + self.current_condition = [] + self.condition_status = [] + self.scope = [] + self._add_reverse_edges = add_reverse_edges + + self._identifier_pool = IdentifierPool() + + def get_source_from_ast_range(self, node, strip=True): + start_line = node.lineno + end_line = node.end_lineno + start_col = node.col_offset + end_col = node.end_col_offset + + source = "" + num_lines = end_line - start_line + 1 + if start_line == end_line: + section = self.source[start_line - 1].encode("utf8")[start_col:end_col].decode( + "utf8") + source += section.strip() if strip else section + "\n" + else: + for ind, lineno in enumerate(range(start_line - 1, end_line)): + if ind == 0: + section = self.source[lineno].encode("utf8")[start_col:].decode( + "utf8") + source += section.strip() if strip else section + "\n" + elif ind == num_lines - 1: + section = self.source[lineno].encode("utf8")[:end_col].decode( + "utf8") + source += section.strip() if strip else section + "\n" + else: + section = self.source[lineno] + source += section.strip() if strip else section + "\n" + + return source.rstrip() + + def get_name(self, *, node=None, name=None, type=None, add_random_identifier=False): + + random_identifier = self._identifier_pool.get_new_identifier() + + if node is not None: + name = f"{node.__class__.__name__}_{random_identifier}" + type = node.__class__.__name__ + else: + if add_random_identifier: + name = f"{name}_{random_identifier}" + + if hasattr(node, "lineno"): + node_string = self.get_source_from_ast_range(node, strip=False) + # if "\n" in node_string: + # node_string = None + else: + node_string = None + + if len(self.scope) > 0: + return GNode(name=name, type=type, scope=copy(self.scope[-1]), string=node_string) + else: + return GNode(name=name, type=type, string=node_string) + # return (node.__class__.__name__ + "_" + str(hex(int(time_ns()))), node.__class__.__name__) + + def get_edges(self, as_dataframe=True): + edges = [] + for f_def_node in ast.iter_child_nodes(self.root): + if type(f_def_node) == ast.FunctionDef: + edges.extend(self.parse(f_def_node)) + break # to avoid going through nested definitions + + if not as_dataframe: + return edges + df = pd.DataFrame(edges) + return df.astype({col: "Int32" for col in df.columns if col not in {"src", "dst", "type"}}) + + def parse(self, node): + n_type = type(node).__name__ + if n_type in PythonNodeEdgeDefinitions.ast_node_type_edges: + return self.generic_parse( + node, + PythonNodeEdgeDefinitions.ast_node_type_edges[n_type] + ) + elif n_type in PythonNodeEdgeDefinitions.overriden_node_type_edges: + method_name = "parse_" + 
n_type + return self.__getattribute__(method_name)(node) + elif n_type in PythonNodeEdgeDefinitions.iterable_nodes: + return self.parse_iterable(node) + elif n_type in PythonNodeEdgeDefinitions.named_nodes: + return self.parse_name(node) + elif n_type in PythonNodeEdgeDefinitions.constant_nodes: + return self.parse_Constant(node) + elif n_type in PythonNodeEdgeDefinitions.operand_nodes: + return self.parse_op_name(node) + elif n_type in PythonNodeEdgeDefinitions.control_flow_nodes: + return self.parse_control_flow(node) + else: + print(type(node)) + print(ast.dump(node)) + print(node._fields) + pprint(self.source) + return self.generic_parse(node, node._fields) + # raise Exception() + # return [type(node)] + + def add_edge( + self, edges, src, dst, type, scope=None, + position_node=None, var_position_node=None + ): + edges.append({ + "src": src, "dst": dst, "type": type, "scope": scope, + }) + + def get_positions(node): + if node is not None and hasattr(node, "lineno"): + line = node.lineno-1 + end_line = node.end_lineno - 1 + col_offset = node.col_offset + end_col_offset = node.end_col_offset + else: + line = end_line = col_offset = end_col_offset = None + return line, end_line, col_offset, end_col_offset + + line, end_line, col_offset, end_col_offset = get_positions(position_node) + + if line is not None: + edges[-1].update({ + "line": line, "end_line": end_line, "col_offset": col_offset, "end_col_offset": end_col_offset + }) + + var_line, var_end_line, var_col_offset, var_end_col_offset = get_positions(var_position_node) + + if var_line is not None: + edges[-1].update({ + "var_line": var_line, "var_end_line": var_end_line, "var_col_offset": var_col_offset, "var_end_col_offset": var_end_col_offset + }) + + reverse_type = PythonNodeEdgeDefinitions.reverse_edge_exceptions.get(type, type + "_rev") + if self._add_reverse_edges is True and reverse_type is not None: + edges.append({ + "src": dst, "dst": src, "type": reverse_type, "scope": scope + }) + + def parse_body(self, nodes): + edges = [] + last_node = None + for node in nodes: + s = self.parse(node) + if isinstance(s, tuple): + if s[1].type == "Constant": # this happens when processing docstring, as a result a lot of nodes are connected to the node Constant_ + continue # in general, constant node has no affect as a body expression, can skip + # some parsers return edges and names? 
+ edges.extend(s[0]) + + if last_node is not None: + self.add_edge(edges, src=last_node, dst=s[1], type="next", scope=self.scope[-1]) + + last_node = s[1] + + for cond_name, cond_stat in zip(self.current_condition[-1:], self.condition_status[-1:]): + self.add_edge(edges, src=last_node, dst=cond_name, type=cond_stat, scope=self.scope[-1]) # "defined_in_" + + else: + edges.extend(s) + return edges + + def parse_in_context(self, cond_name, cond_stat, edges, body): + if isinstance(cond_name, str): + cond_name = [cond_name] + cond_stat = [cond_stat] + elif isinstance(cond_name, GNode): + cond_name = [cond_name] + cond_stat = [cond_stat] + + for cn, cs in zip(cond_name, cond_stat): + self.current_condition.append(cn) + self.condition_status.append(cs) + + edges.extend(self.parse_body(body)) + + for i in range(len(cond_name)): + self.current_condition.pop(-1) + self.condition_status.pop(-1) + + def parse_as_mention(self, name): + mention_name = GNode(name=name + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1])) + name = GNode(name=name, type="Name") + # mention_name = (name + "@" + self.scope[-1], "mention") + + # edge from name to mention in a function + edges = [] + self.add_edge(edges, src=name, dst=mention_name, type="local_mention", scope=self.scope[-1]) + return edges, mention_name + + def parse_operand(self, node): + # need to make sure upper level name is correct; handle @keyword; type placeholder for sourcetrail nodes? + edges = [] + if isinstance(node, str): + # fall here when parsing attributes, they are given as strings; should attributes be parsed into subwords? + if "@" in node: + parts = node.split("@") + node = GNode(name=parts[0], type=parts[1]) + else: + # node = GNode(name=node, type="Name") + node = ast.Name(node) + edges_, node = self.parse(node) + edges.extend(edges_) + iter_ = node + elif isinstance(node, int) or node is None: + iter_ = GNode(name=str(node), type="ast_Literal") + # iter_ = str(node) + elif isinstance(node, GNode): + iter_ = node + else: + iter_e = self.parse(node) + if type(iter_e) == str: + iter_ = iter_e + elif isinstance(iter_e, GNode): + iter_ = iter_e + elif type(iter_e) == tuple: + ext_edges, name = iter_e + assert isinstance(name, GNode) + edges.extend(ext_edges) + iter_ = name + else: + # unexpected scenario + print(node) + print(ast.dump(node)) + print(iter_e) + print(type(iter_e)) + pprint(self.source) + print(self.source[node.lineno - 1].strip()) + raise Exception() + + return iter_, edges + + def parse_and_add_operand(self, node_name, operand, type, edges): + + operand_name, ext_edges = self.parse_operand(operand) + edges.extend(ext_edges) + + self.add_edge( + edges, src=operand_name, dst=node_name, type=type, scope=self.scope[-1], + position_node=operand + ) + + def generic_parse(self, node, operands, with_name=None, ensure_iterables=False): + + edges = [] + + if with_name is None: + node_name = self.get_name(node=node) + else: + node_name = with_name + + for operand in operands: + if operand in ["body", "orelse", "finalbody"]: + logging.warning(f"Not clear which node is processed here {ast.dump(node)}") + self.parse_in_context(node_name, operand, edges, node.__getattribute__(operand)) + else: + operand_ = node.__getattribute__(operand) + if operand_ is not None: + if isinstance(operand_, Iterable) and not isinstance(operand_, str): + # TODO: + # appears as leaf node if the iterable is empty. 
suggest adding an "EMPTY" element + for oper_ in operand_: + self.parse_and_add_operand(node_name, oper_, operand, edges) + else: + self.parse_and_add_operand(node_name, operand_, operand, edges) + + # TODO + # need to identify the benefit of this node + # maybe it is better to use node types in the graph + # edges.append({"scope": copy(self.scope[-1]), "src": node.__class__.__name__, "dst": node_name, "type": "node_type"}) + + return edges, node_name + + def parse_type_node(self, node): + if node.lineno == node.end_lineno: + type_str = self.source[node.lineno][node.col_offset - 1: node.end_col_offset] + else: + type_str = "" + for ln in range(node.lineno - 1, node.end_lineno): + if ln == node.lineno - 1: + type_str += self.source[ln][node.col_offset - 1:].strip() + elif ln == node.end_lineno - 1: + type_str += self.source[ln][:node.end_col_offset].strip() + else: + type_str += self.source[ln].strip() + return type_str + + def parse_Module(self, node): + edges, module_name = self.generic_parse(node, []) + self.scope.append(module_name) + self.parse_in_context(module_name, "defined_in_module", edges, node.body) + self.scope.pop(-1) + return edges, module_name + + def parse_FunctionDef(self, node): + # need to create function name before generic_parse so that the scope is set up correctly + # scope is used to create local mentions of variable and function names + fdef_node_name = self.get_name(node=node) + self.scope.append(fdef_node_name) + + to_parse = [] + if len(node.args.args) > 0 or node.args.vararg is not None: + to_parse.append("args") + if len("decorator_list") > 0: + to_parse.append("decorator_list") + + edges, f_name = self.generic_parse(node, to_parse, with_name=fdef_node_name) + + if node.returns is not None: + # returns stores return type annotation + # can contain quotes + # https://stackoverflow.com/questions/46458470/should-you-put-quotes-around-type-annotations-in-python + # https://www.python.org/dev/peps/pep-0484/#forward-references + annotation_string = self.get_source_from_ast_range(node.returns) + annotation = GNode(name=annotation_string, + type="type_annotation") + self.add_edge( + edges, src=annotation, dst=f_name, type="returned_by", scope=self.scope[-1], + position_node=node.returns + ) + + self.parse_in_context(f_name, "defined_in_function", edges, node.body) + + self.scope.pop(-1) + + ext_edges, func_name = self.parse_as_mention(name=node.name) + edges.extend(ext_edges) + + assert isinstance(node.name, str) + self.add_edge( + edges, src=f_name, dst=func_name, type="function_name", scope=self.scope[-1], + ) + + return edges, f_name + + def parse_AsyncFunctionDef(self, node): + return self.parse_FunctionDef(node) + + def parse_ClassDef(self, node): + + edges, class_node_name = self.generic_parse(node, []) + + self.scope.append(class_node_name) + self.parse_in_context(class_node_name, "defined_in_class", edges, node.body) + self.scope.pop(-1) + + ext_edges, cls_name = self.parse_as_mention(name=node.name) + edges.extend(ext_edges) + self.add_edge( + edges, src=class_node_name, dst=cls_name, type="class_name", scope=self.scope[-1], + ) + + return edges, class_node_name + + def parse_With(self, node): + edges, with_name = self.generic_parse(node, ["items"]) + + self.parse_in_context(with_name, "executed_inside_with", edges, node.body) + + return edges, with_name + + def parse_AsyncWith(self, node): + return self.parse_With(node) + + def parse_arg(self, node): + # node.annotation stores type annotation + # if node.annotation: + # print(self.source[node.lineno-1]) # 
can get definition string here + # print(node.arg) + + # # included mention + name = self.get_name(node=node) + edges, mention_name = self.parse_as_mention(node.arg) + self.add_edge( + edges, src=mention_name, dst=name, type="arg", scope=self.scope[-1], + ) + + if node.annotation is not None: + # can contain quotes + # https://stackoverflow.com/questions/46458470/should-you-put-quotes-around-type-annotations-in-python + # https://www.python.org/dev/peps/pep-0484/#forward-references + annotation_string = self.get_source_from_ast_range(node.annotation) + annotation = GNode(name=annotation_string, + type="type_annotation") + mention_name = GNode(name=node.arg + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1])) + self.add_edge( + edges, src=annotation, dst=mention_name, type="annotation_for", scope=self.scope[-1], + position_node=node.annotation, var_position_node=node + ) + return edges, name + + def parse_AnnAssign(self, node): + # stores annotation information for variables + # + # paths: List[Path] = [] + # AnnAssign(target=Name(id='paths', ctx=Store()), annotation=Subscript(value=Name(id='List', ctx=Load()), + # slice=Index(value=Name(id='Path', ctx=Load())), + # ctx=Load()), value=List(elts=[], ctx=Load()), simple=1) + + # TODO + # parse value?? + + # can contain quotes + # https://stackoverflow.com/questions/46458470/should-you-put-quotes-around-type-annotations-in-python + # https://www.python.org/dev/peps/pep-0484/#forward-references + annotation_string = self.get_source_from_ast_range(node.annotation) + annotation = GNode(name=annotation_string, + type="type_annotation") + edges, name = self.generic_parse(node, ["target", "value"]) + try: + mention_name = GNode(name=node.target.id + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1])) + except Exception as e: + mention_name = name + + self.add_edge( + edges, src=annotation, dst=mention_name, type="annotation_for", scope=self.scope[-1], + position_node=node.annotation, var_position_node=node + ) + return edges, name + + def parse_Lambda(self, node): + # this is too ambiguous + edges, lmb_name = self.generic_parse(node, []) + self.parse_and_add_operand(lmb_name, node.body, "lambda", edges) + + return edges, lmb_name + + def parse_IfExp(self, node): + edges, ifexp_name = self.generic_parse(node, ["test"]) + self.parse_and_add_operand(ifexp_name, node.body, "if_true", edges) + self.parse_and_add_operand(ifexp_name, node.orelse, "if_false", edges) + return edges, ifexp_name + + def parse_ExceptHandler(self, node): + # have missing fields. 
example: + # not parsing "name" field + # except handler is unique for every function + return self.generic_parse(node, ["type"]) + + def parse_keyword(self, node): + if isinstance(node.arg, str): + # change arg name so that it does not mix with variable names + node.arg += "@#keyword#" + return self.generic_parse(node, ["arg", "value"]) + else: + return self.generic_parse(node, ["value"]) + + def parse_name(self, node): + edges = [] + # if type(node) == ast.Attribute: + # left, ext_edges = self.parse(node.value) + # right = node.attr + # return self.parse(node.value) + "___" + node.attr + if type(node) == ast.Name: + return self.parse_as_mention(str(node.id)) + elif type(node) == ast.NameConstant: + return GNode(name=str(node.value), type="NameConstant") + + def parse_Attribute(self, node): + if node.attr is not None: + # change attr name so that it does not mix with variable names + node.attr += "@#attr#" + return self.generic_parse(node, ["value", "attr"]) + + def parse_Constant(self, node): + # TODO + # decide whether this name should be unique or not + name = GNode(name="Constant_", type="Constant") + # name = "Constant_" + # if node.kind is not None: + # name += "" + return name + + def parse_op_name(self, node): + return GNode(name=node.__class__.__name__, type="Op") + # return node.__class__.__name__ + + def parse_Num(self, node): + return str(node.n) + + def parse_Str(self, node): + return self.generic_parse(node, []) + # return node.s + + def parse_Bytes(self, node): + return repr(node.s) + + def parse_If(self, node): + + edges, if_name = self.generic_parse(node, ["test"]) + + self.parse_in_context(if_name, "executed_if_true", edges, node.body) + self.parse_in_context(if_name, "executed_if_false", edges, node.orelse) + + return edges, if_name + + def parse_For(self, node): + + edges, for_name = self.generic_parse(node, ["target", "iter"]) + + self.parse_in_context(for_name, "executed_in_for", edges, node.body) + self.parse_in_context(for_name, "executed_in_for_orelse", edges, node.orelse) + + return edges, for_name + + def parse_AsyncFor(self, node): + return self.parse_For(node) + + def parse_Try(self, node): + + edges, try_name = self.generic_parse(node, []) + + self.parse_in_context(try_name, "executed_in_try", edges, node.body) + + for h in node.handlers: + + handler_name, ext_edges = self.parse_operand(h) + edges.extend(ext_edges) + self.parse_in_context([try_name, handler_name], ["executed_in_try_except", "executed_with_try_handler"], edges, h.body) + + self.parse_in_context(try_name, "executed_in_try_final", edges, node.finalbody) + self.parse_in_context(try_name, "executed_in_try_else", edges, node.orelse) + + return edges, try_name + + def parse_While(self, node): + + edges, while_name = self.generic_parse(node, []) + + cond_name, ext_edges = self.parse_operand(node.test) + edges.extend(ext_edges) + + self.parse_in_context([while_name, cond_name], ["executed_in_while", "executed_while_true"], edges, node.body) + + return edges, while_name + + # def parse_Compare(self, node): + # return self.generic_parse(node, ["left", "ops", "comparators"]) + # + # def parse_BoolOp(self, node): + # return self.generic_parse(node, ["values", "op"]) + + def parse_Expr(self, node): + edges = [] + expr_name, ext_edges = self.parse_operand(node.value) + edges.extend(ext_edges) + + return edges, expr_name + + def parse_control_flow(self, node): + edges = [] + ctrlflow_name = self.get_name(name="ctrl_flow", type="CtlFlowInstance", add_random_identifier=True) + self.add_edge( + edges, + 
src=GNode(name=node.__class__.__name__, type="CtlFlow"), dst=ctrlflow_name, + type="control_flow", scope=self.scope[-1] + ) + + return edges, ctrlflow_name + + def parse_iterable(self, node): + return self.generic_parse(node, ["elts"], ensure_iterables=True) + + def parse_Dict(self, node): + return self.generic_parse(node, ["keys", "values"], ensure_iterables=True) + + def parse_JoinedStr(self, node): + joinedstr_name = GNode(name="JoinedStr_", type="JoinedStr") + return [], joinedstr_name + # return self.generic_parse(node, []) + # return self.generic_parse(node, ["values"]) + + def parse_FormattedValue(self, node): + # have missing fields. example: + # FormattedValue(value=Subscript(value=Name(id='args', ctx=Load()), slice=Index(value=Num(n=0)), ctx=Load()),conversion=-1, format_spec=None) + # 'conversion', 'format_spec' not parsed + return self.generic_parse(node, ["value"]) + + def parse_arguments(self, node): + # have missing fields. example: + # arguments(args=[arg(arg='self', annotation=None), arg(arg='tqdm_cls', annotation=None), arg(arg='sleep_interval', annotation=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]) + + # vararg constains type annotations + return self.generic_parse(node, ["args", "vararg", "kwarg", "kwonlyargs", "posonlyargs"]) + + def parse_comprehension(self, node): + edges = [] + + cph_name = self.get_name(name="comprehension", type="comprehension", add_random_identifier=True) + + target, ext_edges = self.parse_operand(node.target) + edges.extend(ext_edges) + + self.add_edge( + edges, src=target, dst=cph_name, type="target", scope=self.scope[-1], + position_node=node.target + ) + + iter_, ext_edges = self.parse_operand(node.iter) + edges.extend(ext_edges) + + self.add_edge( + edges, src=iter_, dst=cph_name, type="iter", scope=self.scope[-1], + position_node=node.iter + ) + + for if_ in node.ifs: + if_n, ext_edges = self.parse_operand(if_) + edges.extend(ext_edges) + self.add_edge( + edges, src=if_n, dst=cph_name, type="ifs", scope=self.scope[-1], + ) + + return edges, cph_name + + # def parse_alias(self, node): + # if isinstance(node.name, str): + # node.name = ast.Name(node.name) + # if isinstance(node.asname, str): + # node.asname = ast.Name(node.asname) + # return self.generic_parse(node, ["name", "asname"]) + # + # def parse_ImportFrom(self, node): + # if node.module is not None: + # node.module = ast.Name(node.module) + # return self.generic_parse(node, ["module", "names"]) + +if __name__ == "__main__": + c = "def f(a=5): f(a=4)" + g = AstGraphGenerator(c.lstrip()) + g.parse(g.root) + # import sys + # f_bodies = pd.read_csv(sys.argv[1]) + # failed = 0 + # for ind, c in enumerate(f_bodies['body_normalized']): + # if isinstance(c, str): + # try: + # g = AstGraphGenerator(c.lstrip()) + # except SyntaxError as e: + # print(e) + # continue + # failed += 1 + # edges = g.get_edges() + # # edges.to_csv(os.path.join(os.path.dirname(sys.argv[1]), "body_edges.csv"), mode="a", index=False, header=(ind==0)) + # print("\r%d/%d" % (ind, len(f_bodies['normalized_body'])), end="") + # else: + # print("Skipped not a string") + # + # print(" " * 30, end="\r") + # print(failed, len(f_bodies['normalized_body'])) diff --git a/SourceCodeTools/code/ast/python_ast_cf.py b/SourceCodeTools/code/ast/python_ast_cf.py new file mode 100644 index 00000000..64e63ba4 --- /dev/null +++ b/SourceCodeTools/code/ast/python_ast_cf.py @@ -0,0 +1,509 @@ +import ast +from copy import copy +from pprint import pprint +from time import time_ns +from collections.abc import 
Iterable +import pandas as pd + +from SourceCodeTools.code.annotator_utils import to_offsets + + +class GNode: + # name = None + # type = None + # id = None + + def __init__(self, **kwargs): + self.string = None + for k, v in kwargs.items(): + setattr(self, k, v) + + def __eq__(self, other): + if self.name == other.name and self.type == other.type: + return True + else: + return False + + def __repr__(self): + return self.__dict__.__repr__() + + def __hash__(self): + return (self.name, self.type).__hash__() + + def setprop(self, key, value): + setattr(self, key, value) + + +class AstGraphGenerator(object): + + def __init__(self, source): + self.source = source.split("\n") # lines of the source code + self.full_source = source + self.root = ast.parse(source) + self.current_condition = [] + self.condition_status = [] + self.scope = [] + + def get_source_from_ast_range(self, start_line, end_line, start_col, end_col): + source = "" + num_lines = end_line - start_line + 1 + if start_line == end_line: + source += self.source[start_line - 1].encode("utf8")[start_col:end_col].decode( + "utf8").strip() + else: + for ind, lineno in enumerate(range(start_line - 1, end_line)): + if ind == 0: + source += self.source[lineno].encode("utf8")[start_col:].decode( + "utf8").strip() + elif ind == num_lines - 1: + source += self.source[lineno].encode("utf8")[:end_col].decode( + "utf8").strip() + else: + source += self.source[lineno].strip() + + return source + + def get_name(self, *, node=None, name=None, type=None, add_random_identifier=False): + + if node is not None: + name = node.__class__.__name__ + "_" + str(hex(int(time_ns()))) + type = node.__class__.__name__ + else: + if add_random_identifier: + name += f"_{str(hex(int(time_ns())))}" + + if len(self.scope) > 0: + return GNode(name=name, type=type, scope=copy(self.scope[-1])) + else: + return GNode(name=name, type=type) + # return (node.__class__.__name__ + "_" + str(hex(int(time_ns()))), node.__class__.__name__) + + def get_edges(self, as_dataframe=False): + edges = [] + edges.extend(self.parse(self.root)[0]) + # for f_def_node in ast.iter_child_nodes(self.root): + # if type(f_def_node) == ast.FunctionDef: + # edges.extend(self.parse(f_def_node)) + # break # to avoid going through nested definitions + + if not as_dataframe: + return edges + df = pd.DataFrame(edges) + return df.astype({col: "Int32" for col in df.columns if col not in {"src", "dst", "type"}}) + + def parse(self, node): + n_type = type(node) + method_name = "parse_" + n_type.__name__ + if hasattr(self, method_name): + return self.__getattribute__(method_name)(node) + else: + # print(type(node)) + # print(ast.dump(node)) + # print(node._fields) + # pprint(self.source) + return self.parse_as_expression(node) + # return self.generic_parse(node, node._fields) + # raise Exception() + # return [type(node)] + + def parse_body(self, nodes): + edges = [] + last_node = None + for node in nodes: + s = self.parse(node) + if isinstance(s, tuple): + # some parsers return edges and names? 
+ edges.extend(s[0]) + + if last_node is not None: + edges.append({"dst": s[1], "src": last_node, "type": "next"}) + edges.append({"dst": last_node, "src": s[1], "type": "prev"}) + + last_node = s[1] + + for cond_name, cond_stat in zip(self.current_condition[-1:], self.condition_status[-1:]): + edges.append({"scope": copy(self.scope[-1]), "src": last_node, "dst": cond_name, "type": "defined_in_" + cond_stat}) + edges.append({"scope": copy(self.scope[-1]), "src": cond_name, "dst": last_node, "type": "defined_in_" + cond_stat + "_rev"}) + # edges.append({"scope": copy(self.scope[-1]), "src": cond_name, "dst": last_node, "type": "execute_when_" + cond_stat}) + else: + edges.extend(s) + return edges + + def parse_in_context(self, cond_name, cond_stat, edges, body): + if isinstance(cond_name, str): + cond_name = [cond_name] + cond_stat = [cond_stat] + elif isinstance(cond_name, GNode): + cond_name = [cond_name] + cond_stat = [cond_stat] + + for cn, cs in zip(cond_name, cond_stat): + self.current_condition.append(cn) + self.condition_status.append(cs) + + edges.extend(self.parse_body(body)) + + for i in range(len(cond_name)): + self.current_condition.pop(-1) + self.condition_status.pop(-1) + + def parse_as_expression(self, node, *args, **kwargs): + offset = to_offsets(self.full_source, + [(node.lineno - 1, node.end_lineno - 1, node.col_offset, node.end_col_offset, "expression")], + as_bytes=True) + offset, = offset + line = self.full_source[offset[0]: offset[1]].replace("@","##at##") + name = GNode(name=line, type="Name") + # expr = GNode(name="Expression" + "_" + str(hex(int(time_ns()))), type="mention") + expr = GNode(name=f"{line}@{self.scope[-1].name}", type="mention") + edges = [ + {"scope": copy(self.scope[-1]), "src": name, "dst": expr, "type": "local_mention", "line": node.lineno - 1, "end_line": node.end_lineno - 1, "col_offset": node.col_offset, "end_col_offset": node.end_col_offset}, + ] + + return edges, expr + + def parse_as_mention(self, name): + mention_name = GNode(name=name + "@" + self.scope[-1].name, type="mention", scope=copy(self.scope[-1])) + name = GNode(name=name, type="Name") + # mention_name = (name + "@" + self.scope[-1], "mention") + + # edge from name to mention in a function + edges = [ + {"scope": copy(self.scope[-1]), "src": name, "dst": mention_name, "type": "local_mention"}, + # {"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": mention_name, "type": "mention_scope"} + ] + return edges, mention_name + + def parse_operand(self, node): + # need to make sure upper level name is correct; handle @keyword; type placeholder for sourcetrail nodes? + edges = [] + if isinstance(node, str): + # fall here when parsing attributes, they are given as strings; should attributes be parsed into subwords? 
+ if "@" in node: + parts = node.split("@") + node = GNode(name=parts[0], type=parts[1]) + else: + node = GNode(name=node, type="Name") + iter_ = node + elif isinstance(node, int) or node is None: + iter_ = GNode(name=str(node), type="ast_Literal") + # iter_ = str(node) + elif isinstance(node, GNode): + iter_ = node + else: + iter_e = self.parse(node) + if type(iter_e) == str: + iter_ = iter_e + elif isinstance(iter_e, GNode): + iter_ = iter_e + elif type(iter_e) == tuple: + ext_edges, name = iter_e + assert isinstance(name, GNode) + edges.extend(ext_edges) + iter_ = name + else: + # unexpected scenario + print(node) + print(ast.dump(node)) + print(iter_e) + print(type(iter_e)) + pprint(self.source) + print(self.source[node.lineno - 1].strip()) + raise Exception() + + return iter_, edges + + def parse_and_add_operand(self, node_name, operand, type, edges): + + operand_name, ext_edges = self.parse_operand(operand) + edges.extend(ext_edges) + + if hasattr(operand, "lineno"): + edges.append({"scope": copy(self.scope[-1]), "src": operand_name, "dst": node_name, "type": type, "line": operand.lineno-1, "end_line": operand.end_lineno-1, "col_offset": operand.col_offset, "end_col_offset": operand.end_col_offset}) + else: + edges.append({"scope": copy(self.scope[-1]), "src": operand_name, "dst": node_name, "type": type}) + + # if len(ext_edges) > 0: # need this to avoid adding reverse edges to operation names and other highly connected nodes + edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": operand_name, "type": type + "_rev"}) + + def generic_parse(self, node, operands, with_name=None, ensure_iterables=False): + + edges = [] + + if with_name is None: + node_name = self.get_name(node=node) + else: + node_name = with_name + + if len(self.scope) > 0: + edges.append({"scope": copy(self.scope[-1]), "src": node_name, "dst": self.scope[-1], "type": "mention_scope"}) + edges.append({"scope": copy(self.scope[-1]), "src": self.scope[-1], "dst": node_name, "type": "mention_scope_rev"}) + + for operand in operands: + if operand in ["body", "orelse", "finalbody"]: + self.parse_in_context(node_name, "operand", edges, node.__getattribute__(operand)) + else: + operand_ = node.__getattribute__(operand) + if operand_ is not None: + if isinstance(operand_, Iterable) and not isinstance(operand_, str): + # TODO: + # appears as leaf node if the iterable is empty. 
suggest adding an "EMPTY" element + for oper_ in operand_: + self.parse_and_add_operand(node_name, oper_, operand, edges) + else: + self.parse_and_add_operand(node_name, operand_, operand, edges) + + # TODO + # need to identify the benefit of this node + # maybe it is better to use node types in the graph + # edges.append({"scope": copy(self.scope[-1]), "src": node.__class__.__name__, "dst": node_name, "type": "node_type"}) + + return edges, node_name + + def parse_type_node(self, node): + # node.lineno, node.col_offset, node.end_lineno, node.end_col_offset + if node.lineno == node.end_lineno: + type_str = self.source[node.lineno][node.col_offset - 1: node.end_col_offset] + # print(type_str) + else: + type_str = "" + for ln in range(node.lineno - 1, node.end_lineno): + if ln == node.lineno - 1: + type_str += self.source[ln][node.col_offset - 1:].strip() + elif ln == node.end_lineno - 1: + type_str += self.source[ln][:node.end_col_offset].strip() + else: + type_str += self.source[ln].strip() + return type_str + + def parse_Module(self, node): + edges, module_name = self.generic_parse(node, []) + self.scope.append(module_name) + self.parse_in_context(module_name, "module", edges, node.body) + self.scope.pop(-1) + return edges, module_name + + def parse_FunctionDef(self, node): + # edges, f_name = self.generic_parse(node, ["name", "args", "returns", "decorator_list"]) + # edges, f_name = self.generic_parse(node, ["args", "returns", "decorator_list"]) + + # need to creare function name before generic_parse so that the scope is set up correctly + # scope is used to create local mentions of variable and function names + fdef_node_name = self.get_name(node=node) + self.scope.append(fdef_node_name) + + to_parse = [] + if len(node.args.args) > 0 or node.args.vararg is not None: + to_parse.append("args") + if len("decorator_list") > 0: + to_parse.append("decorator_list") + + edges, f_name = self.generic_parse(node, to_parse, with_name=fdef_node_name) + + if node.returns is not None: + # returns stores return type annotation + # can contain quotes + # https://stackoverflow.com/questions/46458470/should-you-put-quotes-around-type-annotations-in-python + # https://www.python.org/dev/peps/pep-0484/#forward-references + annotation_string = self.get_source_from_ast_range( + node.returns.lineno, node.returns.end_lineno, + node.returns.col_offset, node.returns.end_col_offset + ) + annotation = GNode(name=annotation_string, + type="type_annotation") + edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": f_name, "type": "returned_by", "line": node.returns.lineno - 1, "end_line": node.returns.end_lineno - 1, "col_offset": node.returns.col_offset, "end_col_offset": node.returns.end_col_offset}) + # do not use reverse edges for types, will result in leak from function to function + # edges.append({"scope": copy(self.scope[-1]), "src": f_name, "dst": annotation, "type": 'returns'}) + + # if node.returns: + # print(self.source[node.lineno -1]) # can get definition string here + + self.parse_in_context(f_name, "function", edges, node.body) + + self.scope.pop(-1) + + # func_name = GNode(name=node.name, type="Name") + ext_edges, func_name = self.parse_as_mention(name=node.name) + edges.extend(ext_edges) + + assert isinstance(node.name, str) + edges.append({"scope": copy(self.scope[-1]), "src": f_name, "dst": func_name, "type": "function_name"}) + edges.append({"scope": copy(self.scope[-1]), "src": func_name, "dst": f_name, "type": "function_name_rev"}) + + return edges, f_name + + def 
parse_AsyncFunctionDef(self, node): + return self.parse_FunctionDef(node) + + def parse_ClassDef(self, node): + + edges, class_node_name = self.generic_parse(node, []) + + self.scope.append(class_node_name) + self.parse_in_context(class_node_name, "class", edges, node.body) + self.scope.pop(-1) + + ext_edges, cls_name = self.parse_as_mention(name=node.name) + edges.extend(ext_edges) + edges.append({"scope": copy(self.scope[-1]), "src": class_node_name, "dst": cls_name, "type": "class_name"}) + edges.append({"scope": copy(self.scope[-1]), "src": cls_name, "dst": class_node_name, "type": "class_name_rev"}) + + return edges, class_node_name + + def parse_arg(self, node): + # node.annotation stores type annotation + # if node.annotation: + # print(self.source[node.lineno-1]) # can get definition string here + # print(node.arg) + + # # included mention + name = self.get_name(node=node) + edges, mention_name = self.parse_as_mention(node.arg) + edges.append({"scope": copy(self.scope[-1]), "src": mention_name, "dst": name, "type": 'arg'}) + edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": mention_name, "type": 'arg_rev'}) + # edges, name = self.generic_parse(node, ["arg"]) + if node.annotation is not None: + # can contain quotes + # https://stackoverflow.com/questions/46458470/should-you-put-quotes-around-type-annotations-in-python + # https://www.python.org/dev/peps/pep-0484/#forward-references + annotation_string = self.get_source_from_ast_range( + node.annotation.lineno, node.annotation.end_lineno, + node.annotation.col_offset, node.annotation.end_col_offset + ) + annotation = GNode(name=annotation_string, + type="type_annotation") + edges.append({"scope": copy(self.scope[-1]), "src": annotation, "dst": name, "type": 'annotation_for', "line": node.annotation.lineno-1, "end_line": node.annotation.end_lineno-1, "col_offset": node.annotation.col_offset, "end_col_offset": node.annotation.end_col_offset, "var_line": node.lineno-1, "var_end_line": node.end_lineno-1, "var_col_offset": node.col_offset, "var_end_col_offset": node.end_col_offset}) + # do not use reverse edges for types, will result in leak from function to function + # edges.append({"scope": copy(self.scope[-1]), "src": name, "dst": annotation, "type": 'annotation'}) + return edges, name + # return self.generic_parse(node, ["arg", "annotation"]) + + def parse_arguments(self, node): + # have missing fields. example: + # arguments(args=[arg(arg='self', annotation=None), arg(arg='tqdm_cls', annotation=None), arg(arg='sleep_interval', annotation=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]) + + # vararg constains type annotations + return self.generic_parse(node, ["args", "vararg"]) + + def parse_With(self, node): + edges, with_name = self.generic_parse(node, ["items"]) + + self.parse_in_context(with_name, "with", edges, node.body) + + return edges, with_name + + def parse_AsyncWith(self, node): + return self.parse_With(node) + + def parse_withitem(self, node): + return self.generic_parse(node, ['context_expr', 'optional_vars']) + + def parse_ExceptHandler(self, node): + # have missing fields. 
example: + # not parsing "name" field + # except handler is unique for every function + return self.generic_parse(node, ["type"]) + + def parse_name(self, node): + edges = [] + # if type(node) == ast.Attribute: + # left, ext_edges = self.parse(node.value) + # right = node.attr + # return self.parse(node.value) + "___" + node.attr + if type(node) == ast.Name: + return self.parse_as_mention(str(node.id)) + elif type(node) == ast.NameConstant: + return GNode(name=str(node.value), type="NameConstant") + + def parse_Name(self, node): + return self.parse_name(node) + + def parse_op_name(self, node): + return GNode(name=node.__class__.__name__, type="Op") + # return node.__class__.__name__ + + def parse_Str(self, node): + return self.generic_parse(node, []) + # return node.s + + def parse_If(self, node): + + edges, if_name = self.generic_parse(node, ["test"]) + + self.parse_in_context(if_name, "if_true", edges, node.body) + self.parse_in_context(if_name, "if_false", edges, node.orelse) + + return edges, if_name + + def parse_For(self, node): + + edges, for_name = self.generic_parse(node, ["target", "iter"]) + + self.parse_in_context(for_name, "for", edges, node.body) + self.parse_in_context(for_name, "for_orelse", edges, node.orelse) + + return edges, for_name + + def parse_AsyncFor(self, node): + return self.parse_For(node) + + def parse_Try(self, node): + + edges, try_name = self.generic_parse(node, []) + + self.parse_in_context(try_name, "try", edges, node.body) + + for h in node.handlers: + + handler_name, ext_edges = self.parse_operand(h) + edges.extend(ext_edges) + self.parse_in_context([try_name, handler_name], ["try_except", "try_handler"], edges, h.body) + + self.parse_in_context(try_name, "try_final", edges, node.finalbody) + self.parse_in_context(try_name, "try_else", edges, node.orelse) + + return edges, try_name + + def parse_While(self, node): + + edges, while_name = self.generic_parse(node, []) + + cond_name, ext_edges = self.parse_operand(node.test) + edges.extend(ext_edges) + + self.parse_in_context([while_name, cond_name], ["while", "if_true"], edges, node.body) + + return edges, while_name + + def parse_Expr(self, node): + edges = [] + expr_name, ext_edges = self.parse_operand(node.value) + edges.extend(ext_edges) + + # for cond_name, cons_stat in zip(self.current_condition, self.condition_status): + # edges.append({"scope": copy(self.scope[-1]), "src": expr_name, "dst": cond_name, "type": "depends_on_" + cons_stat}) + return edges, expr_name + + +if __name__ == "__main__": + import sys + f_bodies = pd.read_pickle(sys.argv[1]) + failed = 0 + for ind, c in enumerate(f_bodies['body_normalized']): + if isinstance(c, str): + # c = """def g(): + # yield 1""" + try: + g = AstGraphGenerator(c.lstrip()) + except SyntaxError as e: + print(e) + continue + failed += 1 + edges = g.get_edges() + # edges.to_csv(os.path.join(os.path.dirname(sys.argv[1]), "body_edges.csv"), mode="a", index=False, header=(ind==0)) + print("\r%d/%d" % (ind, len(f_bodies['body_normalized'])), end="") + else: + print("Skipped not a string") + + print(" " * 30, end="\r") + print(failed, len(f_bodies['body_normalized'])) diff --git a/SourceCodeTools/code/python_tokens_to_bpe_subwords.py b/SourceCodeTools/code/ast/python_tokens_to_bpe_subwords.py similarity index 90% rename from SourceCodeTools/code/python_tokens_to_bpe_subwords.py rename to SourceCodeTools/code/ast/python_tokens_to_bpe_subwords.py index 9d6ad7c5..648839a8 100644 --- a/SourceCodeTools/code/python_tokens_to_bpe_subwords.py +++ 
b/SourceCodeTools/code/ast/python_tokens_to_bpe_subwords.py @@ -59,9 +59,9 @@ 'Pow': "**", 'FloorDiv': "//", 'GtE': ">=", - 'USub': "-", # this appearsto be the operator to chenge number sign -5 + 'USub': "-", # this appears to be the operator to change number sign -5 'Invert': "~", - 'UAdd': "+", # this appearsto be the operator to chenge number sign +5, method __pos__ + 'UAdd': "+", # this appears to be the operator to change number sign +5, method __pos__ 'MatMult': "@", 'BitXor': "^", 'LShift': "<<", @@ -71,9 +71,11 @@ "Continue": "continue", } + def op_tokenizer(op: str): return python_ops_to_bpe.get(op, op) + @lru_cache def op_tokenize_or_none(op, tokenize_func): return tokenize_func(python_ops_to_literal[op]) if op in python_ops_to_literal else None \ No newline at end of file diff --git a/SourceCodeTools/code/common.py b/SourceCodeTools/code/common.py new file mode 100644 index 00000000..299a8aac --- /dev/null +++ b/SourceCodeTools/code/common.py @@ -0,0 +1,141 @@ +import hashlib +import os +import sqlite3 + +import pandas as pd +from tqdm import tqdm + +from SourceCodeTools.code.data.file_utils import unpersist + + +class SQLTable: + def __init__(self, df, filename, table_name): + self.conn = sqlite3.connect(filename) + self.path = filename + self.table_name = table_name + + df.to_sql(self.table_name, con=self.conn, if_exists='replace', index=False, index_label=df.columns) + + def query(self, query_string): + return pd.read_sql(query_string, self.conn) + + def __del__(self): + self.conn.close() + if os.path.isfile(self.path): + os.remove(self.path) + + +def create_node_repr(nodes): + return list(zip(nodes['serialized_name'], nodes['type'])) + + +def compute_long_id(obj): + return hashlib.md5(repr(obj).encode('utf-8')).hexdigest() + + +def map_id_columns(df, column_names, mapper): + df = df.copy() + for col in column_names: + if col in df.columns: + df[col] = df[col].apply(lambda x: mapper.get(x, pd.NA)) + return df + + +def merge_with_file_if_exists(df, merge_with_file): + if os.path.isfile(merge_with_file): + original_data = unpersist(merge_with_file) + data = pd.concat([original_data, df], axis=0) + else: + data = df + return data + + +def custom_tqdm(iterable, total, message): + return tqdm(iterable, total=total, desc=message, leave=False, dynamic_ncols=True) + + +def map_columns(input_table, id_map, columns, columns_special=None): + + input_table = map_id_columns(input_table, columns, id_map) + + if columns_special is not None: + assert isinstance(columns_special, list), "`columns_special` should be iterable" + for column, map_func in columns_special: + input_table[column] = map_func(input_table[column], id_map) + + if len(input_table) == 0: + return None + else: + return input_table + + +def grow_with_chunks(chunks, additional_dtypes): + dtypes = {} + + table = None + + for chunk in chunks: + for col, type_ in additional_dtypes.items(): + if col in chunk.columns: + dtypes[col] = type_ + + chunk = chunk.astype(dtypes) + + if table is None: + table = chunk + else: + table = pd.concat([table, chunk], copy=False) + return table + + +def return_chunks(chunks, additional_dtypes): + dtypes = {} + + for chunk in chunks: + for col, type_ in additional_dtypes.items(): + if col in chunk.columns: + dtypes[col] = type_ + + chunk = chunk.astype(dtypes, copy=False) + yield chunk + + +def read_nodes(node_path, as_chunks=False): + dtypes = { + "id": "int32", + "serialized_name": "string", + } + + nodes_chunks = unpersist(node_path, dtype=dtypes, chunksize=100000) + + additional_dtypes = { + 
'type': 'category', + "mentioned_in": "Int32", + "string": "string" + } + + if as_chunks is False: + return grow_with_chunks(nodes_chunks, additional_dtypes) + else: + return return_chunks(nodes_chunks, additional_dtypes) + + +def read_edges(edge_path, as_chunks=False): + dtypes = { + "id": "int32", + "source_node_id": "int32", + "target_node_id": "int32", + } + + edge_chunks = unpersist(edge_path, dtype=dtypes, chunksize=100000) + + additional_types = { + "type": 'category', + "mentioned_in": "Int32", + "file_id": "Int32" + } + + if as_chunks: + return return_chunks(edge_chunks, additional_types) + else: + return grow_with_chunks(edge_chunks, additional_types) diff --git a/SourceCodeTools/code/data/AbstractDatasetCreator.py b/SourceCodeTools/code/data/AbstractDatasetCreator.py new file mode 100644 index 00000000..f9e2dee5 --- /dev/null +++ b/SourceCodeTools/code/data/AbstractDatasetCreator.py @@ -0,0 +1,458 @@ +import logging +import os +import shelve +import shutil +import tempfile +from abc import abstractmethod +from collections import defaultdict +from copy import copy +from functools import partial +from os.path import join + +from tqdm import tqdm + +from SourceCodeTools.code.annotator_utils import map_offsets +from SourceCodeTools.code.common import map_columns, read_edges, read_nodes +from SourceCodeTools.code.data.file_utils import get_random_name, unpersist, persist, unpersist_if_present +from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import special_mapping + + +class AbstractDatasetCreator: + """ + Merges several environments indexed with Sourcetrail into a single graph. + """ + + merging_specification = { + "nodes.bz2": {"columns": ['id'], "output_path": "common_nodes.jsonl", "ensure_unique_with": ['type', 'serialized_name']}, + "edges.bz2": {"columns": ['target_node_id', 'source_node_id'], "output_path": "common_edges.jsonl"}, + "source_graph_bodies.bz2": {"columns": ['id'], "output_path": "common_source_graph_bodies.jsonl", "columns_special": [("replacement_list", map_offsets)]}, + "function_variable_pairs.bz2": {"columns": ['src'], "output_path": "common_function_variable_pairs.jsonl"}, + "call_seq.bz2": {"columns": ['src', 'dst'], "output_path": "common_call_seq.jsonl"}, + + "nodes_with_ast.bz2": {"columns": ['id', 'mentioned_in'], "output_path": "common_nodes.jsonl", "ensure_unique_with": ['type', 'serialized_name']}, + "edges_with_ast.bz2": {"columns": ['target_node_id', 'source_node_id', 'mentioned_in'], "output_path": "common_edges.jsonl"}, + "offsets.bz2": {"columns": ['node_id'], "output_path": "common_offsets.jsonl", "columns_special": [("mentioned_in", map_offsets)]}, + "filecontent_with_package.bz2": {"columns": [], "output_path": "common_filecontent.jsonl"}, + "name_mappings.bz2": {"columns": [], "output_path": "common_name_mappings.jsonl"}, + } + + files_for_merging = [ + "nodes.bz2", "edges.bz2", "source_graph_bodies.bz2", "function_variable_pairs.bz2", "call_seq.bz2" + ] + files_for_merging_with_ast = [ + "nodes_with_ast.bz2", "edges_with_ast.bz2", "source_graph_bodies.bz2", "function_variable_pairs.bz2", + "call_seq.bz2", "offsets.bz2", "filecontent_with_package.bz2", "name_mappings.bz2" + ] + + restricted_edges = {} + restricted_in_types = {} + + type_annotation_edge_types = [] + + environments = None + edge_priority = dict() + + def __init__( + self, path, lang, bpe_tokenizer, create_subword_instances, connect_subwords, only_with_annotations, + do_extraction=False, visualize=False, track_offsets=False, remove_type_annotations=False, + 
recompute_l2g=False + ): + """ + :param path: path to source code dataset + :param lang: language to use for AST parser (only Python for now) + :param bpe_tokenizer: path to bpe tokenizer model + :param create_subword_instances: whether to create nodes that represent subword instances (doubles the + number of nodes) + :param connect_subwords: whether to connect subword instances so that the order of subwords is stored + in the graph. Has effect only when create_subword_instances=True + :param only_with_annotations: include only packages that have type annotations into the final graph + :param do_extraction: when True, process source code and extract AT edges. Otherwise, existing files are + used. + :param visualize: visualize graph using pygraphviz and store as PDF (infeasible for large graphs) + :param track_offsets: store offset information and map node occurrences to global graph ids + :param remove_type_annotations: when True, removes all type annotations from the graph and stores then + in a file called `type_annotations.bz2` + :param recompute_l2g: when True, run merging operation again, without extrcting AST nodes and edges second time + """ + self.indexed_path = path + self.lang = lang + self.bpe_tokenizer = bpe_tokenizer + self.create_subword_instances = create_subword_instances + self.connect_subwords = connect_subwords + self.only_with_annotations = only_with_annotations + self.extract = do_extraction + self.visualize = visualize + self.track_offsets = track_offsets + self.remove_type_annotations = remove_type_annotations + self.recompute_l2g = recompute_l2g + + self.path = path + self._prepare_environments() + + self._init_cache() + + def _init_cache(self): + # TODO this is wrong, use standard utilities + rnd_name = get_random_name() + + self.tmp_dir = os.path.join(tempfile.gettempdir(), rnd_name) + if os.path.isdir(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + os.mkdir(self.tmp_dir) + + self.local2global_cache_filename = os.path.join(self.tmp_dir, "local2global_cache.db") + self.local2global_cache = shelve.open(self.local2global_cache_filename) + + def __del__(self): + self.local2global_cache.close() + shutil.rmtree(self.tmp_dir) + # os.remove(self.local2global_cache_filename) # TODO nofile on linux, need to check + + def handle_parallel_edges(self, edges_path): + logging.info("Handle parallel edges") + last_id = 0 + + global_edge_types = set(special_mapping.keys()) | set(special_mapping.values()) + + existing_global_edges = set() + + temp_edges = join(os.path.dirname(edges_path), "temp_" + os.path.basename(edges_path)) + + for ind, edges in enumerate(read_edges(edges_path, as_chunks=True)): + edges["id"] = range(last_id, len(edges) + last_id) + + edge_bank = defaultdict(list) + ids_to_remove = set() + for id_, type_, src, dst in edges[["id", "type", "source_node_id", "target_node_id"]].values: + if type_ in global_edge_types: + global_edge = (type_, src, dst) + if global_edge not in existing_global_edges: + existing_global_edges.add(global_edge) + else: + ids_to_remove.add(id_) + else: + edge_bank[(src, dst)].append((id_, type_)) + + for key, parallel_edges in edge_bank.items(): + if len(parallel_edges) > 1: + parallel_edges = sorted(parallel_edges, key=lambda x: self.edge_priority.get(x[1], 3)) + ids_to_remove.update(pe[0] for pe in parallel_edges[1:]) + + edges = edges[ + edges["id"].apply(lambda id_: id_ not in ids_to_remove) + ] + + edges["id"] = range(last_id, len(edges) + last_id) + last_id = len(edges) + last_id + + kwargs = 
self.get_writing_mode(temp_edges.endswith("csv"), first_written=ind != 0) + persist(edges, temp_edges, **kwargs) + + os.remove(edges_path) + os.rename(temp_edges, edges_path) + + def post_pruning(self, nodes_path, edges_path): + logging.info("Post pruning") + + restricted_nodes = set() + + for nodes in read_nodes(nodes_path, as_chunks=True): + restricted_nodes.update( + nodes[ + nodes["type"].apply(lambda type_: type_ in self.restricted_in_types) + ]["id"] + ) + + temp_edges = join(os.path.dirname(edges_path), "temp_" + os.path.basename(edges_path)) + + for ind, edges in enumerate(read_edges(edges_path, as_chunks=True)): + edges = edges[ + edges["type"].apply(lambda type_: type_ not in self.restricted_edges) + ] + + edges = edges[ + edges["target_node_id"].apply(lambda type_: type_ not in restricted_nodes) + ] + + kwargs = self.get_writing_mode(temp_edges.endswith("csv"), first_written=ind != 0) + persist(edges, temp_edges, **kwargs) + + os.remove(edges_path) + os.rename(temp_edges, edges_path) + + def compact_mapping_for_l2g(self, global_nodes, filename): + if len(global_nodes) > 0: + self.update_l2g_file( + mapping=self.create_compact_mapping(global_nodes), filename=filename + ) + + @staticmethod + def create_compact_mapping(node_ids): + return dict(zip(node_ids, range(len(node_ids)))) + + def update_l2g_file(self, mapping, filename): + for env_path in tqdm(self.environments, desc=f"Fixing {filename}"): + filepath = os.path.join(env_path, filename) + if not os.path.isfile(filepath): + continue + l2g = unpersist(filepath) + l2g["global_id"] = l2g["global_id"].apply(lambda id_: mapping.get(id_, None)) + persist(l2g, filepath) + + def get_local2global(self, path): + if path in self.local2global_cache: + return self.local2global_cache[path] + else: + local2global_df = unpersist_if_present(path) + if local2global_df is None: + return None + else: + local2global = dict(zip(local2global_df['id'], local2global_df['global_id'])) + self.local2global_cache[path] = local2global + return local2global + + @staticmethod + def persist_if_not_none(table, dir, name): + if table is not None: + path = os.path.join(dir, name) + persist(table, path) + + def write_type_annotation_flag(self, edges, output_dir): + if len(self.type_annotation_edge_types) > 0: + query_str = " or ".join(f"type == '{edge_type}'" for edge_type in self.type_annotation_edge_types) + if len(edges.query(query_str)) > 0: + with open(os.path.join(output_dir, "has_annotations"), "w") as has_annotations: + pass + + def write_local(self, dir, local2global=None, local2global_with_ast=None, **kwargs): + + if not self.recompute_l2g: + for var_name, var_ in kwargs.items(): + self.persist_if_not_none(var_, dir, var_name + ".bz2") + + self.persist_if_not_none(local2global, dir, "local2global.bz2") + self.persist_if_not_none(local2global_with_ast, dir, "local2global_with_ast.bz2") + + def merge_files(self, env_path, filename, map_filename, columns_to_map, original, columns_special=None): + input_table_path = join(env_path, filename) + local2global = self.get_local2global(join(env_path, map_filename)) + if os.path.isfile(input_table_path) and local2global is not None: + input_table = unpersist(input_table_path) + if self.only_with_annotations: + if not os.path.isfile(join(env_path, "has_annotations")): + return original + new_table = map_columns(input_table, local2global, columns_to_map, columns_special=columns_special) + if original is None: + return new_table + else: + return original.append(new_table) + else: + return original + + def 
read_mapped_local(self, env_path, filename, map_filename, columns_to_map, columns_special=None): + input_table_path = join(env_path, filename) + local2global = self.get_local2global(join(env_path, map_filename)) + if os.path.isfile(input_table_path) and local2global is not None: + if self.only_with_annotations: + if not os.path.isfile(join(env_path, "has_annotations")): + return None + input_table = unpersist(input_table_path) + new_table = map_columns(input_table, local2global, columns_to_map, columns_special=columns_special) + return new_table + else: + return None + + def get_writing_mode(self, is_csv, first_written): + kwargs = {} + if first_written is True: + kwargs["mode"] = "a" + if is_csv: + kwargs["header"] = False + return kwargs + + def create_global_file( + self, local_file, local2global_file, columns, output_path, message, ensure_unique_with=None, + columns_special=None + ): + assert output_path.endswith("json") or output_path.endswith("csv") + + if ensure_unique_with is not None: + unique_values = set() + else: + unique_values = None + + first_written = False + + for ind, env_path in tqdm( + enumerate(self.environments), desc=message, leave=True, + dynamic_ncols=True, total=len(self.environments) + ): + mapped_local = self.read_mapped_local( + env_path, local_file, local2global_file, columns, columns_special=columns_special + ) + + if mapped_local is not None: + if unique_values is not None: + unique_verify = list(zip(*(mapped_local[col_name] for col_name in ensure_unique_with))) + + mapped_local = mapped_local.loc[ + map(lambda x: x not in unique_values, unique_verify) + ] + unique_values.update(unique_verify) + + kwargs = self.get_writing_mode(output_path.endswith("csv"), first_written) + + persist(mapped_local, output_path, **kwargs) + first_written = True + + + # def create_global_file( + # self, local_file, local2global_file, columns, output_path, message, ensure_unique_with=None, + # columns_special=None + # ): + # global_table = None + # for ind, env_path in tqdm( + # enumerate(self.environments), desc=message, leave=True, + # dynamic_ncols=True, total=len(self.environments) + # ): + # global_table = self.merge_files( + # env_path, local_file, local2global_file, columns, global_table, columns_special=columns_special + # ) + # + # if ensure_unique_with is not None: + # global_table = global_table.drop_duplicates(subset=ensure_unique_with) + # + # if global_table is not None: + # global_table.reset_index(drop=True, inplace=True) + # assert len(global_table) == len(global_table.index.unique()) + # + # persist(global_table, output_path) + + def filter_orphaned_nodes(self, nodes_path, edges_path): + logging.info("Filter orphaned nodes") + active_nodes = set() + + for edges in read_edges(edges_path, as_chunks=True): + active_nodes.update(edges['source_node_id']) + active_nodes.update(edges['target_node_id']) + + temp_nodes = join(os.path.dirname(nodes_path), "temp_" + os.path.basename(nodes_path)) + + for ind, nodes in enumerate(read_nodes(nodes_path, as_chunks=True)): + nodes = nodes[ + nodes['id'].apply(lambda id_: id_ in active_nodes) + ] + + kwargs = self.get_writing_mode(temp_nodes.endswith("csv"), first_written=ind != 0) + persist(nodes, temp_nodes, **kwargs) + + os.remove(nodes_path) + os.rename(temp_nodes, nodes_path) + + def join_files(self, files, local2global_filename, output_dir): + for file in files: + params = copy(self.merging_specification[file]) + params["output_path"] = join(output_dir, params.pop("output_path")) + self.create_global_file(file, 
local2global_filename, message=f"Merging {file}", **params) + + def merge_graph_without_ast(self, output_path): + self.join_files(self.files_for_merging, "local2global.bz2", output_path) + + get_path = partial(join, output_path) + + nodes_path = get_path("common_nodes.json") + edges_path = get_path("common_edges.json") + + self.filter_orphaned_nodes( + nodes_path, + edges_path, + ) + node_names = self.extract_node_names( + nodes_path, min_count=2 + ) + if node_names is not None: + persist(node_names, get_path("node_names.json")) + + self.handle_parallel_edges(edges_path) + + if self.visualize: + self.visualize_func( + read_nodes(nodes_path), + read_edges(edges_path), + get_path("visualization.pdf") + ) + + def merge_graph_with_ast(self, output_path): + + self.join_files(self.files_for_merging_with_ast, "local2global_with_ast.bz2", output_path) + + get_path = partial(join, output_path) + + nodes_path = get_path("common_nodes.json") + edges_path = get_path("common_edges.json") + + if self.remove_type_annotations: + self.filter_type_edges(nodes_path, edges_path) + + self.handle_parallel_edges(edges_path) + + self.post_pruning(nodes_path, edges_path) + + self.filter_orphaned_nodes( + nodes_path, + edges_path, + ) + # persist(global_nodes, get_path("common_nodes.json")) + node_names = self.extract_node_names( + nodes_path, min_count=2 + ) + if node_names is not None: + persist(node_names, get_path("node_names.json")) + + if self.visualize: + self.visualize_func( + read_nodes(nodes_path), + read_edges(edges_path), + get_path("visualization.pdf") + ) + + @abstractmethod + def create_output_dirs(self, output_path): + pass + + @abstractmethod + def _prepare_environments(self): + pass + + @staticmethod + @abstractmethod + def filter_type_edges(nodes, edges): + pass + + @staticmethod + @abstractmethod + def extract_node_names(nodes, min_count): + pass + + @abstractmethod + def do_extraction(self): + pass + + @abstractmethod + def merge(self, output_directory): + pass + + if self.extract: + logging.info("Extracting...") + self.do_extraction() + + no_ast_path, with_ast_path = self.create_output_dirs(output_directory) + + if not self.only_with_annotations: + self.merge_graph_without_ast(no_ast_path) + + self.merge_graph_with_ast(with_ast_path) + + @abstractmethod + def visualize_func(self, nodes, edges, output_path): + pass \ No newline at end of file diff --git a/SourceCodeTools/code/data/ast_graph/build_ast_graph.py b/SourceCodeTools/code/data/ast_graph/build_ast_graph.py new file mode 100644 index 00000000..5772284e --- /dev/null +++ b/SourceCodeTools/code/data/ast_graph/build_ast_graph.py @@ -0,0 +1,715 @@ +import hashlib +import logging +import os.path +import shutil +import sys +from os.path import join + +import pandas as pd + +from tqdm import tqdm + +from SourceCodeTools.cli_arguments import AstDatasetCreatorArguments +from SourceCodeTools.code.ast import has_valid_syntax +from SourceCodeTools.code.common import read_nodes, read_edges +from SourceCodeTools.code.data.AbstractDatasetCreator import AbstractDatasetCreator +from SourceCodeTools.code.data.ast_graph.extract_node_names import extract_node_names +from SourceCodeTools.code.data.ast_graph.filter_type_edges import filter_type_edges, filter_type_edges_with_chunks +from SourceCodeTools.code.data.file_utils import persist, unpersist, unpersist_if_present +from SourceCodeTools.code.data.ast_graph.draw_graph import visualize +from SourceCodeTools.code.ast.python_ast2 import AstGraphGenerator, GNode, PythonSharedNodes +from 
SourceCodeTools.code.annotator_utils import adjust_offsets2, map_offsets, to_offsets, get_cum_lens +from SourceCodeTools.code.data.ast_graph.local2global import get_local2global +from SourceCodeTools.nlp.string_tools import get_byte_to_char_map + + +class MentionTokenizer: + def __init__(self, bpe_tokenizer_path, create_subword_instances, connect_subwords): + from SourceCodeTools.nlp.embed.bpe import make_tokenizer + from SourceCodeTools.nlp.embed.bpe import load_bpe_model + + self.bpe = make_tokenizer(load_bpe_model(bpe_tokenizer_path)) \ + if bpe_tokenizer_path else None + self.create_subword_instances = create_subword_instances + self.connect_subwords = connect_subwords + + def replace_mentions_with_subwords(self, edges): + """ + Process edges and tokenize certain node types + :param edges: List of edges + :return: List of edges, including new edges for subword tokenization + """ + + if self.create_subword_instances: + def produce_subw_edges(subwords, dst): + return self.produce_subword_edges_with_instances(subwords, dst) + else: + def produce_subw_edges(subwords, dst): + return self.produce_subword_edges(subwords, dst, self.connect_subwords) + + new_edges = [] + for edge in edges: + + if edge['type'] == "local_mention": + + dst = edge['dst'] + + if self.bpe is not None: + if hasattr(dst, "name_scope") and dst.name_scope == "local": + # TODO + # this rule seems to be irrelevant now + subwords = self.bpe(dst.name.split("@")[0]) + else: + subwords = self.bpe(edge['src'].name) + + new_edges.extend(produce_subw_edges(subwords, dst)) + else: + new_edges.append(edge) + + elif self.bpe is not None and edge["type"] == "__global_name": + # should not have global edges + # subwords = self.bpe(edge['src'].name) + # new_edges.extend(produce_subw_edges(subwords, edge['dst'])) + pass + elif self.bpe is not None and edge['src'].type in PythonSharedNodes.tokenizable_types_and_annotations: + new_edges.append(edge) + if edge['type'] != "global_mention_rev": + # should not have global edges here + pass + + dst = edge['src'] + subwords = self.bpe(dst.name) + new_edges.extend(produce_subw_edges(subwords, dst)) + else: + new_edges.append(edge) + + return new_edges + + @staticmethod + def connect_prev_next_subwords(edges, current, prev_subw, next_subw): + if next_subw is not None: + edges.append({ + 'src': current, + 'dst': next_subw, + 'type': 'next_subword', + 'offsets': None + }) + if prev_subw is not None: + edges.append({ + 'src': current, + 'dst': prev_subw, + 'type': 'prev_subword', + 'offsets': None + }) + + def produce_subword_edges(self, subwords, dst, connect_subwords=False): + new_edges = [] + + subwords = list(map(lambda x: GNode(name=x, type="subword"), subwords)) + for ind, subword in enumerate(subwords): + new_edges.append({ + 'src': subword, + 'dst': dst, + 'type': 'subword', + 'offsets': None + }) + if connect_subwords: + self.connect_prev_next_subwords(new_edges, subword, subwords[ind - 1] if ind > 0 else None, + subwords[ind + 1] if ind < len(subwords) - 1 else None) + return new_edges + + def produce_subword_edges_with_instances(self, subwords, dst, connect_subwords=True): + new_edges = [] + + subwords = list(map(lambda x: GNode(name=x, type="subword"), subwords)) + instances = list(map(lambda x: GNode(name=x.name + "@" + dst.name, type="subword_instance"), subwords)) + for ind, subword in enumerate(subwords): + subword_instance = instances[ind] + new_edges.append({ + 'src': subword, + 'dst': subword_instance, + 'type': 'subword_instance', + 'offsets': None + }) + new_edges.append({ + 
'src': subword_instance, + 'dst': dst, + 'type': 'subword', + 'offsets': None + }) + if connect_subwords: + self.connect_prev_next_subwords(new_edges, subword_instance, instances[ind - 1] if ind > 0 else None, + instances[ind + 1] if ind < len(instances) - 1 else None) + return new_edges + + +class NodeIdResolver: + def __init__(self): + self.node_ids = {} + self.new_nodes = [] + self.stashed_nodes = [] + + self._resolver_cache = dict() + + def stash_new_nodes(self): + """ + Put new nodes into temporary storage. + :return: Nothing + """ + self.stashed_nodes.extend(self.new_nodes) + self.new_nodes = [] + + def get_node_id(self, type_name): + return hashlib.md5(type_name.encode('utf-8')).hexdigest() + + def resolve_node_id(self, node, **kwargs): + """ + Resolve node id from name and type, create new node is no nodes like that found. + :param node: node + :param kwargs: + :return: updated node (return object with the save reference as input) + """ + if not hasattr(node, "id"): + node_repr = f"{node.name.strip()}_{node.type.strip()}" + + if node_repr in self.node_ids: + node_id = self.node_ids[node_repr] + node.setprop("id", node_id) + else: + new_id = self.get_node_id(node_repr) + self.node_ids[node_repr] = new_id + + if not PythonSharedNodes.is_shared(node) and not node.name == "unresolved_name": + assert "0x" in node.name + + self.new_nodes.append( + { + "id": new_id, + "type": node.type, + "serialized_name": node.name, + "mentioned_in": pd.NA, + "string": node.string + } + ) + if hasattr(node, "scope"): + self.resolve_node_id(node.scope) + self.new_nodes[-1]["mentioned_in"] = node.scope.id + node.setprop("id", new_id) + return node + + def prepare_for_write(self, from_stashed=False): + nodes = self.new_nodes_for_write(from_stashed)[ + ['id', 'type', 'serialized_name', 'mentioned_in', 'string'] + ] + + return nodes + + def new_nodes_for_write(self, from_stashed=False): + + new_nodes = pd.DataFrame(self.new_nodes if not from_stashed else self.stashed_nodes) + if len(new_nodes) == 0: + return None + + new_nodes = new_nodes[ + ['id', 'type', 'serialized_name', 'mentioned_in', 'string'] + ].astype({"mentioned_in": "string", "id": "string"}) + + return new_nodes + + def adjust_ast_node_types(self, mapping, from_stashed=False): + nodes = self.new_nodes if not from_stashed else self.stashed_nodes + + for node in nodes: + node["type"] = mapping.get(node["type"], node["type"]) + + def drop_nodes(self, node_ids_to_drop, from_stashed=False): + nodes = self.new_nodes if not from_stashed else self.stashed_nodes + + position = 0 + while position < len(nodes): + if nodes[position]["id"] in node_ids_to_drop: + nodes.pop(position) + else: + position += 1 + + def map_mentioned_in_to_global(self, mapping, from_stashed=False): + nodes = self.new_nodes if not from_stashed else self.stashed_nodes + + for node in nodes: + node["mentioned_in"] = mapping.get(node["mentioned_in"], node["mentioned_in"]) + + +class AstProcessor(AstGraphGenerator): + def get_edges(self, as_dataframe=True): + edges = [] + all_edges, top_node_name = self.parse(self.root) + edges.extend(all_edges) + + if as_dataframe: + df = pd.DataFrame(edges) + df = df.astype({col: "Int32" for col in df.columns if col not in {"src", "dst", "type"}}) + + body = "\n".join(self.source) + cum_lens = get_cum_lens(body, as_bytes=True) + byte2char = get_byte_to_char_map(body) + + def format_offsets(edges: pd.DataFrame): + edges["start_line__end_line__start_column__end_column"] = list( + zip(edges["line"], edges["end_line"], edges["col_offset"], 
edges["end_col_offset"]) + ) + + def into_offset(range): + try: + return to_offsets(body, [(*range, None)], cum_lens=cum_lens, b2c=byte2char, as_bytes=True)[-1][:2] + except: + return None + + edges["offsets"] = edges["start_line__end_line__start_column__end_column"].apply(into_offset) + edges.drop( + axis=1, + labels=[ + "start_line__end_line__start_column__end_column", + "line", + "end_line", + "col_offset", + "end_col_offset" + ], inplace=True + ) + + format_offsets(df) + return df + else: + body = "\n".join(self.source) + cum_lens = get_cum_lens(body, as_bytes=True) + byte2char = get_byte_to_char_map(body) + + def format_offsets(edge): + def into_offset(range): + try: + return to_offsets(body, [(*range, None)], cum_lens=cum_lens, b2c=byte2char, as_bytes=True)[-1][:2] + except: + return None + + if "line" in edge: + edge["offsets"] = into_offset( + (edge["line"], edge["end_line"], edge["col_offset"], edge["end_col_offset"]) + ) + edge.pop("line") + edge.pop("end_line") + edge.pop("col_offset") + edge.pop("end_col_offset") + else: + edge["offsets"] = None + if "var_line" in edge: + edge["var_offsets"] = into_offset( + (edge["var_line"], edge["var_end_line"], edge["var_col_offset"], edge["var_end_col_offset"]) + ) + edge.pop("var_line") + edge.pop("var_end_line") + edge.pop("var_col_offset") + edge.pop("var_end_col_offset") + + for edge in edges: + format_offsets(edge) + + return edges + + +def standardize_new_edges(edges, node_resolver, mention_tokenizer): + """ + Tokenize relevant node names, assign id to every node, collapse edge representation to id-based + :param edges: list of edges + :param node_resolver: helper class that tracks node ids + :param mention_tokenizer: helper class that performs tokenization of relevant nodes + :return: + """ + + edges = mention_tokenizer.replace_mentions_with_subwords(edges) + + resolve_node_id = lambda node: node_resolver.resolve_node_id(node) + extract_id = lambda node: node.id + + for edge in edges: + edge["src"] = resolve_node_id(edge["src"]) + edge["dst"] = resolve_node_id(edge["dst"]) + if "scope" in edge: + edge["scope"] = resolve_node_id(edge["scope"]) + + for edge in edges: + edge["src"] = extract_id(edge["src"]) + edge["dst"] = extract_id(edge["dst"]) + if "scope" in edge: + edge["scope"] = extract_id(edge["scope"]) + else: + edge["scope"] = pd.NA + edge["file_id"] = pd.NA + + return edges + + +def process_code_without_index(source_code, node_resolver, mention_tokenizer, track_offsets=False): + try: + ast_processor = AstProcessor(source_code) + except: + return None, None, None + try: # TODO recursion error does not appear consistently. The issue is probably with library versions... + edges = ast_processor.get_edges(as_dataframe=False) + except RecursionError: + return None, None, None + + if len(edges) == 0: + return None, None, None + + # tokenize names, replace nodes with their ids + edges = standardize_new_edges(edges, node_resolver, mention_tokenizer) + + if track_offsets: + def get_valid_offsets(edges): + """ + :param edges: Dictionary that represents edge. 
Information is tored in edges but is related to source node + :return: Information about location of this edge (offset) in the source file in fromat (start, end, node_id) + """ + return [(edge["offsets"][0], edge["offsets"][1], edge["src"], edge["scope"]) for edge in edges if edge["offsets"] is not None] + + # recover ast offsets for the current file + ast_offsets = get_valid_offsets(edges) + else: + ast_offsets = None + + return edges, ast_offsets + + +def build_ast_only_graph( + source_codes, bpe_tokenizer_path, create_subword_instances, connect_subwords, lang, track_offsets=False +): + node_resolver = NodeIdResolver() + mention_tokenizer = MentionTokenizer(bpe_tokenizer_path, create_subword_instances, connect_subwords) + all_ast_edges = [] + all_offsets = [] + + for package, source_code_id, source_code in tqdm(source_codes, desc="Processing modules"): + source_code_ = source_code.lstrip() + initial_strip = source_code[:len(source_code) - len(source_code_)] + + if not has_valid_syntax(source_code): + continue + + edges, ast_offsets = process_code_without_index( + source_code, node_resolver, mention_tokenizer, track_offsets=track_offsets + ) + + if ast_offsets is not None: + adjust_offsets2(ast_offsets, len(initial_strip)) + + if edges is None: + continue + + # afterprocessing + + for edge in edges: + edge["file_id"] = source_code_id + + # finish afterprocessing + + all_ast_edges.extend(edges) + + def format_offsets(ast_offsets, target): + """ + Format offset as a record and add to the common storage for offsets + :param ast_offsets: + :param target: List where all other offsets are stored. + :return: Nothing + """ + if ast_offsets is not None: + for offset in ast_offsets: + target.append({ + "file_id": source_code_id, + "start": offset[0], + "end": offset[1], + "node_id": offset[2], + "mentioned_in": offset[3], + "string": source_code[offset[0]: offset[1]], + "package": package + }) + + format_offsets(ast_offsets, target=all_offsets) + + node_resolver.stash_new_nodes() + + all_ast_nodes = node_resolver.new_nodes_for_write(from_stashed=True) + + if all_ast_nodes is None: + return None, None, None + + def prepare_edges(all_ast_edges): + all_ast_edges = pd.DataFrame(all_ast_edges) + all_ast_edges.drop_duplicates(["type", "src", "dst"], inplace=True) + all_ast_edges = all_ast_edges.query("src != dst") + all_ast_edges["id"] = 0 + + all_ast_edges = all_ast_edges[["id", "type", "src", "dst", "file_id", "scope"]] \ + .rename({'src': 'source_node_id', 'dst': 'target_node_id', 'scope': 'mentioned_in'}, axis=1) \ + .astype({'file_id': 'Int32', "mentioned_in": 'string'}) + + all_ast_edges["id"] = range(len(all_ast_edges)) + return all_ast_edges + + all_ast_edges = prepare_edges(all_ast_edges) + + if len(all_offsets) > 0: + all_offsets = pd.DataFrame(all_offsets) + else: + all_offsets = None + + node2id = dict(zip(all_ast_nodes["id"], range(len(all_ast_nodes)))) + + def map_columns_to_int(table, dense_columns, sparse_columns): + types = {column: "int64" for column in dense_columns} + types.update({column: "Int64" for column in sparse_columns}) + + for column, dtype in types.items(): + table[column] = table[column].apply(node2id.get).astype(dtype) + + map_columns_to_int(all_ast_nodes, dense_columns=["id"], sparse_columns=["mentioned_in"]) + map_columns_to_int( + all_ast_edges, + dense_columns=["source_node_id", "target_node_id"], + sparse_columns=["mentioned_in"] + ) + map_columns_to_int(all_offsets, dense_columns=["node_id"], sparse_columns=["mentioned_in"]) + + return all_ast_nodes, all_ast_edges, 
all_offsets + + +pd.options.mode.chained_assignment = None # default='warn' + + +def create_test_data(output_dir): + # [(id, source), (id, source)] + test_code = pd.DataFrame.from_records([ + {"id": 1, "filecontent": "import numpy\nnumpy.array([1,2,3])", "package": "any_name_1"}, + {"id": 2, "filecontent": "from numpy.submodule import fn1 as f1, fn2 as f2\n", "package": "can use the same name here any_name_1"}, + {"id": 3, "filecontent": """try: + a = b +except Exception as e: + a = c +else: + a = d +finally: + print(a)""", "package": "a"} + ]) + persist(test_code, os.path.join(output_dir, "source_code.bz2")) + + +def build_ast_graph_from_modules(): + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("source_code", type=str, help="Path to DataFrame pickle (written with pandas.to_pickle, use `bz2` format).") + parser.add_argument("output_path") + parser.add_argument("--bpe_tokenizer", type=str, help="Path to sentencepiece model. When provided, names will be subtokenized.") + parser.add_argument("--visualize", action="store_true", help="Visualize graph. Do not use on large graphs.") + parser.add_argument("--create_test_data", action="store_true", help="Visualize graph. Do not use on large graphs.") + args = parser.parse_args() + + if args.create_test_data: + print(f"Creating test data in {args.output_path}") + create_test_data(args.output_path) + sys.exit() + + source_code = unpersist(args.source_code) + + output_dir = args.output_path + + nodes, edges, offsets = build_ast_only_graph( + zip(source_code["package"], source_code["id"], source_code["filecontent"]), args.bpe_tokenizer, + create_subword_instances=False, connect_subwords=False, lang="py", track_offsets=True + ) + + print(f"Writing output to {output_dir}") + persist(source_code, os.path.join(output_dir, "common_filecontent.bz2")) + persist(nodes, os.path.join(output_dir, "common_nodes.bz2")) + persist(edges, os.path.join(output_dir, "common_edges.bz2")) + persist(offsets, os.path.join(output_dir, "common_offsets.bz2")) + + if args.visualize: + visualize(nodes, edges, os.path.join(output_dir, "visualization.pdf")) + + +class AstDatasetCreator(AbstractDatasetCreator): + + merging_specification = { + "source_graph_bodies.bz2": {"columns": ['id'], "output_path": "common_source_graph_bodies.json", "columns_special": [("replacement_list", map_offsets)]}, + "function_variable_pairs.bz2": {"columns": ['src'], "output_path": "common_function_variable_pairs.json"}, + "call_seq.bz2": {"columns": ['src', 'dst'], "output_path": "common_call_seq.json"}, + + "nodes_with_ast.bz2": {"columns": ['id', 'mentioned_in'], "output_path": "common_nodes.json", "ensure_unique_with": ['type', 'serialized_name']}, + "edges_with_ast.bz2": {"columns": ['target_node_id', 'source_node_id', 'mentioned_in'], "output_path": "common_edges.json"}, + "offsets.bz2": {"columns": ['node_id', 'mentioned_in'], "output_path": "common_offsets.json"}, + "filecontent_with_package.bz2": {"columns": [], "output_path": "common_filecontent.json"}, + "name_mappings.bz2": {"columns": [], "output_path": "common_name_mappings.json"}, + } + + files_for_merging_with_ast = [ + "nodes_with_ast.bz2", "edges_with_ast.bz2", "source_graph_bodies.bz2", "function_variable_pairs.bz2", + "call_seq.bz2", "offsets.bz2", "filecontent_with_package.bz2" + ] + + edge_priority = { + "next": -1, "prev": -1, "global_mention": -1, "global_mention_rev": -1, + "calls": 0, + "called_by": 0, + "defines": 1, + "defined_in": 1, + "inheritance": 1, + "inherited_by": 1, + "imports": 1, + 
"imported_by": 1, + "uses": 2, + "used_by": 2, + "uses_type": 2, + "type_used_by": 2, + "mention_scope": 10, + "mention_scope_rev": 10, + "defined_in_function": 4, + "defined_in_function_rev": 4, + "defined_in_class": 5, + "defined_in_class_rev": 5, + "defined_in_module": 6, + "defined_in_module_rev": 6 + } + + restricted_edges = {"global_mention_rev"} + restricted_in_types = { + "Op", "Constant", "#attr#", "#keyword#", + 'CtlFlow', 'JoinedStr', 'Name', 'ast_Literal', + 'subword', 'type_annotation' + } + + type_annotation_edge_types = ['annotation_for', 'returned_by'] + + def __init__( + self, path, lang, bpe_tokenizer, create_subword_instances, connect_subwords, only_with_annotations, + do_extraction=False, visualize=False, track_offsets=False, remove_type_annotations=False, + recompute_l2g=False, chunksize=10000, keep_frac=1.0 + ): + self.chunksize = chunksize + self.keep_frac = keep_frac + super().__init__( + path, lang, bpe_tokenizer, create_subword_instances, connect_subwords, only_with_annotations, + do_extraction, visualize, track_offsets, remove_type_annotations, recompute_l2g + ) + + def __del__(self): + # TODO use /tmp and add flag for overriding temp folder location + if hasattr(self, "temp_path") and os.path.isdir(self.temp_path): + shutil.rmtree(self.temp_path) + # pass + + def _prepare_environments(self): + dataset_location = os.path.dirname(self.path) + temp_path = os.path.join(dataset_location, "temp_graph_builder") + self.temp_path = temp_path + if os.path.isdir(temp_path): + raise FileExistsError(f"Directory exists: {temp_path}") + os.mkdir(temp_path) + + for ind, chunk in enumerate(pd.read_csv(self.path, chunksize=self.chunksize)): + chunk = chunk.sample(frac=self.keep_frac) + chunk_path = os.path.join(temp_path, f"chunk_{ind}") + os.mkdir(chunk_path) + persist(chunk, os.path.join(chunk_path, "source_code.bz2")) + + paths = (os.path.join(temp_path, dir) for dir in os.listdir(temp_path)) + self.environments = sorted(list(filter(lambda path: os.path.isdir(path), paths)), key=lambda x: x.lower()) + + @staticmethod + def extract_node_names(nodes_path, min_count): + logging.info("Extract node names") + return extract_node_names(read_nodes(nodes_path), min_count=min_count) + + def filter_type_edges(self, nodes_path, edges_path): + logging.info("Filter type edges") + filter_type_edges_with_chunks(nodes_path, edges_path, kwarg_fn=self.get_writing_mode) + + def do_extraction(self): + global_nodes_with_ast = set() + + for env_path in self.environments: + logging.info(f"Found {os.path.basename(env_path)}") + + if not self.recompute_l2g: + + source_code = unpersist(join(env_path, "source_code.bz2")) + + nodes_with_ast, edges_with_ast, offsets = build_ast_only_graph( + zip(source_code["package"], source_code["id"], source_code["filecontent"]), self.bpe_tokenizer, + create_subword_instances=self.create_subword_instances, connect_subwords=self.connect_subwords, + lang=self.lang, track_offsets=self.track_offsets + ) + + else: + nodes_with_ast = unpersist_if_present(join(env_path, "nodes_with_ast.bz2")) + + if nodes_with_ast is None: + continue + + edges_with_ast = offsets = source_code = None + + local2global_with_ast = get_local2global( + global_nodes=global_nodes_with_ast, local_nodes=nodes_with_ast + ) + + global_nodes_with_ast.update(local2global_with_ast["global_id"]) + + self.write_type_annotation_flag(edges_with_ast, env_path) + + self.write_local( + env_path, + local2global_with_ast=local2global_with_ast, + nodes_with_ast=nodes_with_ast, edges_with_ast=edges_with_ast, 
offsets=offsets, + filecontent_with_package=source_code, + ) + + self.compact_mapping_for_l2g(global_nodes_with_ast, "local2global_with_ast.bz2") + + def create_output_dirs(self, output_path): + if not os.path.isdir(output_path): + os.mkdir(output_path) + + with_ast_path = join(output_path, "with_ast") + + if not os.path.isdir(with_ast_path): + os.mkdir(with_ast_path) + + return with_ast_path + + def merge(self, output_directory): + + if self.extract: + logging.info("Extracting...") + self.do_extraction() + + with_ast_path = self.create_output_dirs(output_directory) + + self.merge_graph_with_ast(with_ast_path) + + def visualize_func(self, nodes, edges, output_path): + visualize(nodes, edges, output_path) + + +if __name__ == "__main__": + + args = AstDatasetCreatorArguments().parse() + + if args.recompute_l2g: + args.do_extraction = True + + logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(message)s") + + dataset = AstDatasetCreator( + args.source_code, args.language, args.bpe_tokenizer, args.create_subword_instances, + args.connect_subwords, args.only_with_annotations, args.do_extraction, args.visualize, args.track_offsets, + args.remove_type_annotations, args.recompute_l2g, args.chunksize, args.keep_frac + ) + dataset.merge(args.output_directory) \ No newline at end of file diff --git a/SourceCodeTools/code/data/ast_graph/draw_graph.py b/SourceCodeTools/code/data/ast_graph/draw_graph.py new file mode 100644 index 00000000..c72369f7 --- /dev/null +++ b/SourceCodeTools/code/data/ast_graph/draw_graph.py @@ -0,0 +1,35 @@ +import pandas as pd + + + +def visualize(nodes, edges, output_path): + import pygraphviz as pgv + + edges = edges[edges["type"].apply(lambda x: not x.endswith("_rev"))] + + id2name = dict(zip(nodes['id'], nodes['serialized_name'])) + + g = pgv.AGraph(strict=False, directed=True) + + from SourceCodeTools.code.ast.python_ast2 import PythonNodeEdgeDefinitions + auxiliaty_edge_types = PythonNodeEdgeDefinitions.auxiliary_edges() + + for ind, edge in edges.iterrows(): + src = edge['source_node_id'] + dst = edge['target_node_id'] + src_name = id2name[src] + dst_name = id2name[dst] + g.add_node(src_name, color="black") + g.add_node(dst_name, color="black") + g.add_edge(src_name, dst_name, color="blue" if edge['type'] in auxiliaty_edge_types else "black") + g_edge = g.get_edge(src_name, dst_name) + g_edge.attr['label'] = edge['type'] + + g.layout("dot") + g.draw(output_path) + + +if __name__ == "__main__": + nodes = pd.read_pickle("common_nodes.bz2") + edges = pd.read_pickle("common_edges.bz2") + visualize(nodes, edges, "test.pdf") \ No newline at end of file diff --git a/SourceCodeTools/code/data/ast_graph/extract_node_names.py b/SourceCodeTools/code/data/ast_graph/extract_node_names.py new file mode 100644 index 00000000..9c37e3dd --- /dev/null +++ b/SourceCodeTools/code/data/ast_graph/extract_node_names.py @@ -0,0 +1,29 @@ +def extract_node_names(nodes, min_count): + + data = nodes.copy().rename({"id": "src", "serialized_name": "dst"}, axis=1) + + corrected_names = [] + for type_, name_ in data[["type", "dst"]].values: + if type_ == "mention": + corrected_names.append(name_.split("@")[0]) + else: + corrected_names.append(name_) + + data["dst"] = corrected_names + + def not_contains(name): + return "0x" not in name + + data = data[ + data["dst"].apply(not_contains) + ] + + counts = data['dst'].value_counts() + + data['counts'] = data['dst'].apply(lambda x: counts[x]) + data = data.query(f"counts >= {min_count}") + + if len(data) > 0: + return data[['src', 
'dst']] + else: + return None \ No newline at end of file diff --git a/SourceCodeTools/code/data/ast_graph/filter_type_edges.py b/SourceCodeTools/code/data/ast_graph/filter_type_edges.py new file mode 100644 index 00000000..9f36b6d0 --- /dev/null +++ b/SourceCodeTools/code/data/ast_graph/filter_type_edges.py @@ -0,0 +1,76 @@ +import os +from os.path import join + +from SourceCodeTools.code.common import read_nodes, read_edges +from SourceCodeTools.code.data.file_utils import persist + + +def filter_type_edges(nodes, edges, keep_proportion=0.0): + annotations = edges.query(f"type == 'annotation_for' or type == 'returned_by' or type == 'annotation_for_rev' or type == 'returned_by_rev'") + no_annotations = edges.query(f"type != 'annotation_for' and type != 'returned_by' and type != 'annotation_for_rev' and type != 'returned_by_rev'") + + annotations = annotations.query("type == 'annotation_for' or type == 'returned_by'") + + to_keep = int(len(annotations) * keep_proportion) + if to_keep == 0: + annotations_removed = annotations + annotations_kept = None + elif to_keep == len(annotations): + annotations_removed = None + annotations_kept = annotations + else: + annotations = annotations.sample(frac=1.) + annotations_kept, annotations_removed = annotations.iloc[:to_keep], annotations.iloc[to_keep:] + + if annotations_kept is not None: + no_annotations = no_annotations.append(annotations_kept) + + annotations = annotations_removed + if annotations is not None: + annotations = annotations_removed + node2name = dict(zip(nodes["id"], nodes["serialized_name"])) + get_name = lambda id_: node2name[id_] + annotations["source_node_id"] = annotations["source_node_id"].apply(get_name) + # rename columns to use as a dataset + # annotations.rename({"source_node_id": "dst", "target_node_id": "src"}, axis=1, inplace=True) + annotations.rename({"target_node_id": "src", "type_string": "dst"}, axis=1, inplace=True) + annotations = annotations[["src","dst"]] + + return no_annotations, annotations + + +def filter_type_edges_with_chunks(nodes_path, edges_path, kwarg_fn): + + node2name = {} + for nodes in read_nodes(nodes_path, as_chunks=True): + node2name.update(dict(zip(nodes["id"], nodes["serialized_name"]))) + + temp_edges = join(os.path.dirname(edges_path), "temp_" + os.path.basename(edges_path)) + annotations_path = join(os.path.dirname(edges_path), "type_annotations.json") + + annotations_written = False + + for ind, edges in enumerate(read_edges(edges_path, as_chunks=True)): + annotations = edges.query( + f"type == 'annotation_for' or type == 'returned_by' or type == 'annotation_for_rev' or type == 'returned_by_rev'") + no_annotations = edges.query( + f"type != 'annotation_for' and type != 'returned_by' and type != 'annotation_for_rev' and type != 'returned_by_rev'") + + annotations = annotations.query("type == 'annotation_for' or type == 'returned_by'") + + if annotations is not None and len(annotations) > 0: + annotations["type_string"] = annotations["source_node_id"].apply(node2name.get) + # rename columns to use as a dataset + annotations.rename({"target_node_id": "src", "type_string": "dst"}, axis=1, inplace=True) + annotations = annotations[["src", "dst"]] + + kwargs = kwarg_fn(annotations_path.endswith("csv"), first_written=annotations_written) + persist(annotations, annotations_path, **kwargs) + + annotations_written = True + + kwargs = kwarg_fn(temp_edges.endswith("csv"), first_written=ind != 0) + persist(no_annotations, temp_edges, **kwargs) + + os.remove(edges_path) + os.rename(temp_edges, 
edges_path) \ No newline at end of file diff --git a/SourceCodeTools/code/data/ast_graph/local2global.py b/SourceCodeTools/code/data/ast_graph/local2global.py new file mode 100644 index 00000000..41c66989 --- /dev/null +++ b/SourceCodeTools/code/data/ast_graph/local2global.py @@ -0,0 +1,45 @@ +import sys + +from SourceCodeTools.code.common import compute_long_id, create_node_repr +from SourceCodeTools.code.data.file_utils import * + + +def create_local_to_global_id_map(local_nodes, global_nodes): + # local_nodes = local_nodes.copy() + # global_nodes = global_nodes.copy() + # + # global_nodes['node_repr'] = create_node_repr(global_nodes) + # local_nodes['node_repr'] = create_node_repr(local_nodes) + # + # rev_id_map = dict(zip( + # global_nodes['node_repr'].tolist(), global_nodes['id'].tolist() + # )) + # id_map = dict(zip( + # local_nodes["id"].tolist(), map( + # lambda x: rev_id_map[x], local_nodes["node_repr"].tolist() + # ) + # )) + id_map = dict(zip( + local_nodes["id"], map(compute_long_id, create_node_repr(local_nodes)) + )) + + return id_map + + +def get_local2global(global_nodes, local_nodes) -> pd.DataFrame: + local_nodes = local_nodes.copy() + id_map = create_local_to_global_id_map(local_nodes=local_nodes, global_nodes=global_nodes) + + local_nodes['global_id'] = local_nodes['id'].apply(lambda x: id_map.get(x, None)) + + return local_nodes[['id', 'global_id']] + + +if __name__ == "__main__": + global_nodes = unpersist_or_exit(sys.argv[1], "Global nodes do not exist!") + local_nodes = unpersist_or_exit(sys.argv[2], "No processed nodes, skipping") + local_map_path = sys.argv[3] + + local2global = get_local2global(global_nodes, local_nodes) + + persist(local2global, local_map_path) diff --git a/SourceCodeTools/code/data/cubert_python_benchmarks/SQLTable.py b/SourceCodeTools/code/data/cubert_python_benchmarks/SQLTable.py new file mode 100644 index 00000000..586c63ab --- /dev/null +++ b/SourceCodeTools/code/data/cubert_python_benchmarks/SQLTable.py @@ -0,0 +1,50 @@ +import sqlite3 + +import pandas as pd + + +class SQLTable: + def __init__(self, filename): + self.conn = sqlite3.connect(filename) + self.path = filename + + def replace_records(self, table, table_name, **kwargs): + table.to_sql(table_name, con=self.conn, if_exists='replace', index=False, method="multi", chunksize=1000, **kwargs) + self.create_index_for_table(table, table_name) + + def add_records(self, table, table_name, **kwargs): + table.to_sql(table_name, con=self.conn, if_exists='append', index=False, method="multi", chunksize=1000, **kwargs) + self.create_index_for_table(table, table_name) + + def create_index_for_table(self, table, table_name): + self.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{table_name} + ON {table_name}({','.join(repr(col) for col in table.columns)}) + """ + ) + + def create_index_for_columns(self, columns, table_name): + self.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{table_name} + ON {table_name}({','.join(repr(col) for col in columns)}) + """ + ) + + def query(self, query_string, **kwargs): + return pd.read_sql(query_string, self.conn, **kwargs) + + def execute(self, query_string): + self.conn.execute(query_string) + self.conn.commit() + + def drop_table(self, table_name): + self.conn.execute(f"DROP TABLE IF EXISTS {table_name}") + self.conn.execute(f"DROP INDEX IF EXISTS idx_{table_name}") + self.conn.commit() + + def __del__(self): + self.conn.close() + # if os.path.isfile(self.path): + # os.remove(self.path) \ No newline at end of file diff --git 
a/SourceCodeTools/code/data/cubert_python_benchmarks/convert_for_ast_graph_builder.py b/SourceCodeTools/code/data/cubert_python_benchmarks/convert_for_ast_graph_builder.py new file mode 100644 index 00000000..7c55e4cd --- /dev/null +++ b/SourceCodeTools/code/data/cubert_python_benchmarks/convert_for_ast_graph_builder.py @@ -0,0 +1,81 @@ +import bz2 +import json +from pathlib import Path + +import pandas as pd +from tqdm import tqdm + + +def write_chunk(temp, column_order, first_part, output_path): + data = pd.DataFrame.from_records(temp, columns=column_order) + data.rename({"function": "filecontent"}, axis=1, inplace=True) + if first_part is True: + data.to_csv(output_path, index=False) + else: + data.to_csv(output_path, index=False, mode="a", header=False) + + +def convert_bzip(dataset_path, output_path): + assert dataset_path.name.endswith("bz2") + + dataset = bz2.open(dataset_path, mode="rt") + + convert(dataset, output_path) + + +def convert_jsonl(dataset_path, output_path): + assert dataset_path.name.endswith("jsonl") + + dataset = open(dataset_path, mode="r") + + convert(dataset, output_path) + + +def convert(dataset, output_path): + + temp = [] + id_ = 0 + + column_order = None + first_part = True + + for line in tqdm(dataset): + record = json.loads(line) + + if record.pop("parsing_error"): + continue + + if column_order is None: + column_order = list(record.keys()) + column_order.insert(0, "id") + + record["id"] = id_ + id_ += 1 + + temp.append(record) + + if len(temp) > 10000: + write_chunk(temp, column_order, first_part, output_path) + first_part = False + temp.clear() + + if len(temp) > 0: + write_chunk(temp, column_order, first_part, output_path) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser("Convert CuBERT dataset file from jsonl format to pandas table comparible with AST graph builder") + + parser.add_argument("dataset_path", help="Path to jsonl file (compressed bz2)") + parser.add_argument("output_path") + + args = parser.parse_args() + + dataset_path = Path(args.dataset_path) + output_path = Path(args.output_path) + + if dataset_path.name.endswith("bz2"): + convert_bzip(dataset_path, output_path) + else: + convert_jsonl(dataset_path, output_path) \ No newline at end of file diff --git a/SourceCodeTools/code/data/cubert_python_benchmarks/download.sh b/SourceCodeTools/code/data/cubert_python_benchmarks/download.sh new file mode 100644 index 00000000..65e3f144 --- /dev/null +++ b/SourceCodeTools/code/data/cubert_python_benchmarks/download.sh @@ -0,0 +1,106 @@ +mkdir function_docstring_datasets +gsutil -m cp \ + "gs://cubert/20200621_Python/function_docstring_datasets/dev.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/dev.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/dev.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/dev.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/eval.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/eval.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/eval.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/eval.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/train.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/train.jsontxt-00001-of-00004" \ + 
"gs://cubert/20200621_Python/function_docstring_datasets/train.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/function_docstring_datasets/train.jsontxt-00003-of-00004" \ + function_docstring_datasets + +mkdir exception_datasets +gsutil -m cp \ + "gs://cubert/20200621_Python/exception_datasets/dev.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/dev.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/dev.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/dev.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/eval.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/eval.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/eval.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/eval.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/train.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/train.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/train.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/exception_datasets/train.jsontxt-00003-of-00004" \ + exception_datasets + +mkdir variable_misuse_datasets +gsutil -m cp \ + "gs://cubert/20200621_Python/variable_misuse_datasets/dev.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/dev.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/dev.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/dev.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/eval.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/eval.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/eval.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/eval.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/train.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/train.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/train.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_datasets/train.jsontxt-00003-of-00004" \ + variable_misuse_datasets + +mkdir swapped_operands_datasets +gsutil -m cp \ + "gs://cubert/20200621_Python/swapped_operands_datasets/dev.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/dev.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/dev.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/dev.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/eval.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/eval.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/eval.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/eval.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/train.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/train.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/train.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/swapped_operands_datasets/train.jsontxt-00003-of-00004" \ + swapped_operands_datasets + +mkdir wrong_binary_operator_datasets +gsutil -m 
cp \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/dev.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/dev.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/dev.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/dev.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/eval.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/eval.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/eval.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/eval.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/train.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/train.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/train.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/wrong_binary_operator_datasets/train.jsontxt-00003-of-00004" \ + wrong_binary_operator_datasets + +mkdir variable_misuse_repair_datasets +gsutil -m cp \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/dev.jsontxt-00000-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/dev.jsontxt-00001-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/dev.jsontxt-00002-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/dev.jsontxt-00003-of-00004" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/eval.jsontxt-00000-of-00006" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/eval.jsontxt-00001-of-00006" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/eval.jsontxt-00002-of-00006" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/eval.jsontxt-00003-of-00006" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/eval.jsontxt-00004-of-00006" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/eval.jsontxt-00005-of-00006" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/githubcommits.jsontxt-00000-of-00001" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/githubcommits.raw.jsontxt-00000-of-00001" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00000-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00001-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00002-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00003-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00004-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00005-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00006-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00007-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00008-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00009-of-00011" \ + "gs://cubert/20200621_Python/variable_misuse_repair_datasets/train.jsontxt-00010-of-00011" \ + variable_misuse_repair_datasets \ No newline at end of file diff --git a/SourceCodeTools/code/data/cubert_python_benchmarks/extract_partition.py 
b/SourceCodeTools/code/data/cubert_python_benchmarks/extract_partition.py new file mode 100644 index 00000000..e4ae5aa0 --- /dev/null +++ b/SourceCodeTools/code/data/cubert_python_benchmarks/extract_partition.py @@ -0,0 +1,28 @@ +import os.path + +from SourceCodeTools.code.data.file_utils import unpersist, persist + + +def extract_partitions(path): + filecontent = unpersist(path) + + items = filecontent[["id"]] + masks = filecontent["partition"].values + + # create masks + items["train_mask"] = masks == "train" + items["val_mask"] = masks == "dev" + items["test_mask"] = masks == "eval" + + dirname = os.path.dirname(path) + persist(items, os.path.join(dirname, "partition.json")) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("filecontent") + + args = parser.parse_args() + + extract_partitions(args.filecontent) \ No newline at end of file diff --git a/SourceCodeTools/code/data/cubert_python_benchmarks/partitioning.py b/SourceCodeTools/code/data/cubert_python_benchmarks/partitioning.py new file mode 100644 index 00000000..39fd1eab --- /dev/null +++ b/SourceCodeTools/code/data/cubert_python_benchmarks/partitioning.py @@ -0,0 +1,87 @@ +import json +from functools import partial +from os.path import join +from random import random + +from SourceCodeTools.code.common import read_edges, read_nodes +from SourceCodeTools.code.data.file_utils import persist + + +def add_splits(items, train_frac, restricted_id_pool=None): + items = items.copy() + + def random_partition(): + r = random() + if r < train_frac: + return "train" + elif r < train_frac + (1 - train_frac) / 2: + return "val" + else: + return "test" + + import numpy as np + # define partitioning + masks = np.array([random_partition() for _ in range(len(items))]) + + # create masks + items["train_mask"] = masks == "train" + items["val_mask"] = masks == "val" + items["test_mask"] = masks == "test" + + if restricted_id_pool is not None: + # if `restricted_id_pool` is provided, mask all nodes not in `restricted_id_pool` negatively + to_keep = items.eval("id in @restricted_ids", local_dict={"restricted_ids": restricted_id_pool}) + items["train_mask"] = items["train_mask"] & to_keep + items["test_mask"] = items["test_mask"] & to_keep + items["val_mask"] = items["val_mask"] & to_keep + + return items + + +def subgraph_partitioning(path_to_dataset, partition_column, train_frac=0.8): + + get_path = partial(join, path_to_dataset) + + # nodes = read_nodes(get_path("common_nodes.json.bz2")) + edges = read_edges(get_path("common_edges.json.bz2")) + + subgraph_ids = add_splits( + edges[[partition_column]].dropna(axis=0).drop_duplicates(), train_frac=train_frac + ).rename({partition_column: "id"}, axis=1) + + valid_subgraph_ids = set(subgraph_ids["id"]) + + persist(subgraph_ids, get_path("subgraph_partition.json")) + + edges = edges[["source_node_id", "target_node_id", partition_column]].sort_values(partition_column) + + last_subgraph_id = -1 + pool = set() + + with open(get_path("subgraph_mapping.json"), "w") as subgraph_mapping: + for src, dst, subgraph_id in edges[["source_node_id", "target_node_id", partition_column]].values: + if subgraph_id in valid_subgraph_ids: + if subgraph_id != last_subgraph_id and last_subgraph_id != -1: + record = json.dumps({ + "subgraph_id": last_subgraph_id, + "node_ids": list(pool) + }) + subgraph_mapping.write(f"{record}\n") + + pool.clear() + + pool.add(src) + pool.add(dst) + last_subgraph_id = subgraph_id + + + +if __name__ == "__main__": + import argparse + 
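+ # `dataset_path` must contain common_edges.json.bz2 produced by the dataset creator; `subgraph_column` is the column of the edges table used to group edges into subgraphs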
parser = argparse.ArgumentParser() + parser.add_argument("dataset_path") + parser.add_argument("subgraph_column") + + args = parser.parse_args() + + subgraph_partitioning(args.dataset_path, args.subgraph_column) \ No newline at end of file diff --git a/SourceCodeTools/code/data/cubert_python_benchmarks/prepare_for_ast_parser.py b/SourceCodeTools/code/data/cubert_python_benchmarks/prepare_for_ast_parser.py new file mode 100644 index 00000000..812baa94 --- /dev/null +++ b/SourceCodeTools/code/data/cubert_python_benchmarks/prepare_for_ast_parser.py @@ -0,0 +1,394 @@ +import hashlib + +# import pandas as pd +import ast +import json +from pathlib import Path +from nltk import RegexpTokenizer +from itertools import chain + +from tqdm import tqdm + +from SourceCodeTools.code.data.cubert_python_benchmarks.SQLTable import SQLTable + + +class CodeTokenizer: + tok = RegexpTokenizer("\w+|\s+|\W+") + + @classmethod + def tokenize(cls, code): + return cls.tok.tokenize(code) + + @classmethod + def detokenize(cls, code_tokens): + return "".join(code_tokens) + + +def test_CodeTokenizer(): + code = """ +data = pd.DataFrame2.from_records(json.loads(line) for line in open(filename).readlines()) + +success = 0 +errors = 0 +""" + assert CodeTokenizer.detokenize(CodeTokenizer.tokenize(code)) == code + + +test_CodeTokenizer() + + +class Info: + """ + Example + dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async`->`self` + """ + def __init__(self, info): + self.info = info + parts = info.split(" ") + self.parse_filepath(parts) + self.parse_fn_path(parts) + self.parse_state(parts) + self.parse_package(parts) + self.set_id() + + def parse_filepath(self, parts): + self.filepath = " ".join(self.info.split(".py")[0].split(" ")[1:]) + ".py" + # self.filepath = parts[1].split(".py")[0] + ".py" + + def parse_fn_path(self, parts): + self.fn_path = self.filepath + "/" + self.info.split(".py")[1].lstrip("/").lstrip(" ").split("/")[0] + + def parse_state(self, parts): + self.state = "/".join(self.info.split(".py")[1].lstrip("/").lstrip(" ").split("/")[1:]) + + def parse_package(self, parts): + self.package = self.filepath.split("/")[0] + + def set_id(self): + identifier = "\t".join([self.fn_path, self.state]) + self.id = hashlib.md5(identifier.encode('utf-8')).hexdigest() + + oidentifier = "\t".join([self.fn_path, "original"]) + + if self.state.endswith("riginal"): + assert identifier == oidentifier + self.original_id = hashlib.md5(oidentifier.encode('utf-8')).hexdigest() + + def __repr__(self): + return self.info + + +class DatasetAdapter: + replacements = { + "async": "async_", + "await": "await_" + } + + fields = { + "variable_misuse": ["function", "label", "info"], + "variable_misuse_repair": ["function", "target_mask", "error_location_mask", "candidate_mask", "provenance"], + "exception": ["function", "label", "info"], + "function_docstring": ["function", "docstring", "label", "info"], + "swapped_operands": ["function", "label", "info"], + "wrong_binary_operator": ["function", "label", "info"] + } + supported_partitions = ["train", "dev", "eval"] # + + benchmark_names = { + "exception": "exception_datasets", + "function_docstring": "function_docstring_datasets", + "swapped_operands": "swapped_operands_datasets", + "variable_misuse": "variable_misuse_dataset", + "variable_misuse_repair": "variable_misuse_repair_datasets", + "wrong_binary_operator": "wrong_binary_operator_datasets" + } + + preferred_column_order = ["id", "package", "function", 
"info", "label", "partition"] + + import_order = [ + # "variable_misuse", + "variable_misuse_repair", + # "exception", + # "function_docstring", + # "swapped_operands", + # "wrong_binary_operator" + ] + + # expected_directory_structure = "dev.jsontxt-00000-of-00004\n" + # "dev.jsontxt-00001-of-00004\n" + # "dev.jsontxt-00002-of-00004\n" + # "dev.jsontxt-00003-of-00004\n" + # "eval.jsontxt-00000-of-00004\n" + # "eval.jsontxt-00001-of-00004\n" + # "eval.jsontxt-00002-of-00004\n" + # "eval.jsontxt-00003-of-00004\n" + # "train.jsontxt-00000-of-00004\n" + # "train.jsontxt-00001-of-00004\n" + # "train.jsontxt-00002-of-00004\n" + # "train.jsontxt-00003-of-00004\n" + + def __init__(self, dataset_location): + self.dataset_location = Path(dataset_location) + + self.replacement_fns = { + "function": self.fix_code_if_needed, + "info": self.fix_info_if_needed + } + + self.preprocess = { + # "variable_misuse_repair": { + # "function": self.cubert_detokenize + # } + } + + self.extra_fields = { + "info": [("package", self.get_package)], + "provenance": [("info", self.fix_info_if_needed)] + } + + self.db = SQLTable(self.dataset_location.joinpath("cubert_benchmarks.db")) + + def load_original_functions(self): + functions = self.db.query("SELECT DISTINCT original_id, function FROM functions where comment = 'original' AND dataset = 'variable_misuse'") + self.original_functions = dict(zip(functions["original_id"], functions["function"])) + + def prepare_misuse_repair_record(self, record): + print(record) + + @staticmethod + def get_source_from_ast_range(node, function, strip=True): + lines = function.split("\n") + start_line = node.lineno + end_line = node.end_lineno + start_col = node.col_offset + end_col = node.end_col_offset + + source = "" + num_lines = end_line - start_line + 1 + if start_line == end_line: + section = lines[start_line - 1].encode("utf8")[start_col:end_col].decode( + "utf8") + source += section.strip() if strip else section + "\n" + else: + for ind, lineno in enumerate(range(start_line - 1, end_line)): + if ind == 0: + section = lines[lineno].encode("utf8")[start_col:].decode( + "utf8") + source += section.strip() if strip else section + "\n" + elif ind == num_lines - 1: + section = lines[lineno].encode("utf8")[:end_col].decode( + "utf8") + source += section.strip() if strip else section + "\n" + else: + section = lines[lineno] + source += section.strip() if strip else section + "\n" + + return source.rstrip() + + def get_dispatch(self, function): + root = ast.parse(function) + return self.get_source_from_ast_range(root.body[0].decorator_list[0], function) + + @staticmethod + def remove_indent(code): + lines = code.strip("\n").split("\n") + first_line_indent = lines[0][:len(lines[0]) - len(lines[0].lstrip())] + start_char = len(first_line_indent) + if start_char != 0: + clean = "\n".join(line[start_char:] if line.startswith(first_line_indent) else line for line in lines) + else: + if lines[0].lstrip().startswith("@"): + for ind, line in enumerate(lines): + stripped = line.lstrip() + if stripped.startswith("def "): + lines[ind] = stripped + break + clean = "\n".join(lines) + return clean + + @classmethod + def fix_code_if_needed(cls, code): + # f = code.lstrip() + f = cls.remove_indent(code) + try: + ast.parse(f) + except Exception as e: + tokens = CodeTokenizer.tokenize(f) + recovered_tokens = [cls.replacements[token] if token in cls.replacements else token for token in tokens] + f = CodeTokenizer.detokenize(recovered_tokens) + ast.parse(f) + return f + + @classmethod + def 
fix_info_if_needed(cls, info): + if not info.endswith("original"): + parts = info.split(" ") + variable_replacement = parts[-1] + + variable_replacement = CodeTokenizer.detokenize( + cls.replacements[token] if token in cls.replacements else token + for token in CodeTokenizer.tokenize(variable_replacement) + ) + + parts[-1] = variable_replacement + + info = " ".join(parts) + return info + + @staticmethod + def get_package(info): + return info.split(" ")[1].split("/")[0] + + @classmethod + def sort_columns(cls, columns): + return sorted(columns, key=cls.preferred_column_order.index) + + def process_record(self, record, preprocess_fns): + new_record = {} + for field, data in record.items(): + if preprocess_fns is not None and field in preprocess_fns: + record[field] = preprocess_fns[field](data) + + new_record[field] = data if field not in self.replacement_fns else self.replacement_fns[field](data) + + for new_field, new_field_fn in self.extra_fields.get(field, []): + new_record[new_field] = new_field_fn(data) + + return new_record + + @classmethod + def stream_original_partition(cls, directory, partition, *args, **kwargs): + directory = Path(directory) + assert partition in cls.supported_partitions, f"Only the partitions should be one of: {cls.supported_partitions}, but {partition} given" + for file in directory.iterdir(): + if file.name.startswith(partition): + with open(file) as p: + for line in p: + yield json.loads(line) + + def stream_processed_partition(self, directory, partition, add_partition=False, preprocess_fns=None): + for record in self.stream_original_partition(directory, partition): + try: + r = self.process_record(record, preprocess_fns=preprocess_fns) + r["parsing_error"] = None + except Exception as e: + r = record + r["parsing_error"] = e.msg if hasattr(e, "msg") else e.__class__.__name__ + r["package"] = self.get_package(r["info"]) + # except MemoryError: # there are two functions that cause this error + # continue + # except SyntaxError as e: + # continue + if add_partition: + r["partition"] = partition + yield r + + # @classmethod + # def process_dataset(cls, original_data_location, output_location): + # partitions = ["train", "dev", "eval"] + # + # last_id = 0 + # column_order = None + # + # data = pd.DataFrame.from_records( + # chain( + # *(cls.stream_processed_partition(original_data_location, partition, add_partition=True) for partition in + # partitions)), + # # columns = column_order + # ) + # data["id"] = range(len(data)) + # # data.to_pickle(output_location, index=False, columns=cls.sort_columns(data.columns)) + # data.to_pickle(output_location) + + def iterate_dataset(self, dataset_name): + for partition in self.supported_partitions: + for record in self.stream_processed_partition( + self.dataset_location.joinpath(self.benchmark_names[dataset_name]), partition, + preprocess_fns=self.preprocess.get(dataset_name, None) + ): + record["partition"] = partition + yield record + + def import_data(self): + + functions = [] + + # added_original = set() + + for dataset_name in self.import_order: + + parsed_successfully = 0 + parsed_with_errors = 0 + + dataset_file = open(self.dataset_location.joinpath(f"{dataset_name}.jsonl"), "w") + + for record in tqdm(self.iterate_dataset(dataset_name), desc=f"Processing {dataset_name}"): + info = Info(record["info"]) + + # if info.id != "12101e677d30de9462596e7894d2bbd1": + # continue + + # if dataset_name == "variable_misuse_repair": + # record = self.prepare_misuse_repair_record(record) + # if 
record["function"].startswith("@dispatch"): + # dispatch = self.get_dispatch(record["function"]) + # info.fn_path += f"@{dispatch}" + # info.set_id() + + record_for_writing = { + # "id": info.id, + # "original_id": info.original_id, + # "filepath": info.filepath, + "fn_path": info.fn_path, + "function": record["function"], + "comment": info.state, + "package": info.package, + "label": record["label"], + # "dataset": dataset_name, + "partition": record["partition"], + "parsing_error": record["parsing_error"] + } + + if record["parsing_error"] is None: + parsed_successfully += 1 + dataset_file.write(f"{json.dumps(record_for_writing)}\n") + else: + parsed_with_errors += 1 + + dataset_file.close() + print(f"{dataset_name}: success {parsed_successfully} error {parsed_with_errors}") + # functions.append(record_for_writing) + # + # if len(functions) > 100000: + # self.db.add_records(pd.DataFrame.from_records(functions), "functions") + # functions.clear() + # + # if len(functions) > 0: + # self.db.add_records(pd.DataFrame.from_records(functions), "functions") + # functions.clear() + + + +def test_fix_info_if_needed(): + example_info = "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async`->`self`" + assert DatasetAdapter.fix_info_if_needed( + example_info) == "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async_`->`self`" + + example_info = "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async_call`->`self`" + assert DatasetAdapter.fix_info_if_needed( + example_info) == "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async_call`->`self`" + + +test_fix_info_if_needed() + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser("Convert CuBERT's variable misuse detection dataset for further processing") + parser.add_argument("dataset_path", help="Path to dataset folder") + # parser.add_argument("output_path", help="Path to output file") + args = parser.parse_args() + + dataset = DatasetAdapter(args.dataset_path) + dataset.import_data() + # DatasetAdapter.process_dataset(args.dataset_path, args.output_path) \ No newline at end of file diff --git a/SourceCodeTools/code/data/cubert_python_benchmarks/variable_misuse_node_level_labels.py b/SourceCodeTools/code/data/cubert_python_benchmarks/variable_misuse_node_level_labels.py new file mode 100644 index 00000000..e968585f --- /dev/null +++ b/SourceCodeTools/code/data/cubert_python_benchmarks/variable_misuse_node_level_labels.py @@ -0,0 +1,71 @@ +from os.path import join + +import pandas as pd +from tqdm import tqdm + +from SourceCodeTools.code.data.file_utils import unpersist, persist + + +def get_node_labels(dataset_path): + filecontent = unpersist(join(dataset_path, "common_filecontent.json.bz2")) + nodes = unpersist(join(dataset_path, "common_nodes.json.bz2")) + edges = unpersist(join(dataset_path, "common_edges.json.bz2")) + + id2comment = dict(zip(filecontent["id"], filecontent["comment"])) + nodeid2name = dict(zip(nodes["id"], nodes["serialized_name"])) + nodeid2type = dict(zip(nodes["id"], nodes["type"])) + + edges.sort_values("file_id", inplace=True) + + del nodes, filecontent + + misuse_labels = [] + + for file_id, file_edges in tqdm(edges.groupby("file_id")): + comment = id2comment[file_id] + + if comment 
is None: + continue + + if comment.startswith("original"): + continue + + misused_vars = set(comment.split(" ")[-1].strip("`").split("`->`")) + + if len(misused_vars) != 2: + print(f"Error in {file_id}") + continue + + all_nodes = set(file_edges["source_node_id"].append(file_edges["target_node_id"])) + + for node_id in all_nodes: + if nodeid2type[node_id] == "mention": + node_name = nodeid2name[node_id].split("@")[0] + if node_name in misused_vars: + misuse_labels.append({ + "src": node_id, + "dst": "misused" + }) + else: + misuse_labels.append({ + "src": node_id, + "dst": "correct" + }) + + labels = pd.DataFrame.from_records(misuse_labels) + + persist(labels, join(dataset_path, "misuse_labels.json")) + + + + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("dataset_path") + + args = parser.parse_args() + + get_node_labels(args.dataset_path) \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/Dataset.py b/SourceCodeTools/code/data/dataset/Dataset.py similarity index 50% rename from SourceCodeTools/code/data/sourcetrail/Dataset.py rename to SourceCodeTools/code/data/dataset/Dataset.py index 56f3de34..f47b94db 100644 --- a/SourceCodeTools/code/data/sourcetrail/Dataset.py +++ b/SourceCodeTools/code/data/dataset/Dataset.py @@ -1,6 +1,7 @@ # from collections import Counter # from itertools import chain from collections import Counter +from typing import List, Optional import pandas import numpy @@ -8,45 +9,16 @@ from os.path import join -from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker, NodeNameMasker -from SourceCodeTools.code.data.sourcetrail.file_utils import * -from SourceCodeTools.code.python_ast import PythonSharedNodes +from SourceCodeTools.code.data.dataset.SubwordMasker import SubwordMasker, NodeNameMasker, NodeClfMasker +from SourceCodeTools.code.data.dataset.reader import load_data +from SourceCodeTools.code.data.file_utils import * +from SourceCodeTools.code.ast.python_ast import PythonSharedNodes from SourceCodeTools.nlp.embed.bpe import make_tokenizer, load_bpe_model from SourceCodeTools.tabular.common import compact_property from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types from SourceCodeTools.code.data.sourcetrail.sourcetrail_extract_node_names import extract_node_names -def load_data(node_path, edge_path): - nodes = unpersist(node_path) - edges = unpersist(edge_path) - - nodes_ = nodes.rename(mapper={ - 'serialized_name': 'name' - }, axis=1).astype({ - 'type': 'category' - }) - - edges_ = edges.rename(mapper={ - 'source_node_id': 'src', - 'target_node_id': 'dst' - }, axis=1).astype({ - 'type': 'category' - }) - - return nodes_, edges_ - - -def get_name_group(name): - parts = name.split("@") - if len(parts) == 1: - return pd.NA - elif len(parts) == 2: - local_name, group = parts - return group - return pd.NA - - def filter_dst_by_freq(elements, freq=1): counter = Counter(elements["dst"]) allowed = {item for item, count in counter.items() if count >= freq} @@ -54,115 +26,6 @@ def filter_dst_by_freq(elements, freq=1): return target -def create_train_val_test_masks(nodes, train_idx, val_idx, test_idx): - nodes['train_mask'] = True - # nodes.loc[train_idx, 'train_mask'] = True - nodes['val_mask'] = False - nodes.loc[val_idx, 'val_mask'] = True - nodes['test_mask'] = False - nodes.loc[test_idx, 'test_mask'] = True - nodes['train_mask'] = nodes['train_mask'] ^ (nodes['val_mask'] | nodes['test_mask']) - starts_with = lambda x: 
x.startswith("##node_type") - nodes.loc[nodes.eval("name.map(@starts_with)", local_dict={"starts_with": starts_with}), ['train_mask', 'val_mask', 'test_mask']] = False - - -def get_train_val_test_indices(indices, train_frac=0.6, random_seed=None): - if random_seed is not None: - numpy.random.seed(random_seed) - logging.warning("Random state for splitting dataset is fixed") - else: - logging.info("Random state is not set") - - indices = indices.to_numpy() - - numpy.random.shuffle(indices) - - train = int(indices.size * train_frac) - test = int(indices.size * (train_frac + (1 - train_frac) / 2)) - - logging.info( - f"Splitting into train {train}, validation {test - train}, and test {indices.size - test} sets" - ) - - return indices[:train], indices[train: test], indices[test:] - - -def get_train_val_test_indices_on_packages(nodes, package_names, train_frac=0.6, random_seed=None): - if random_seed is not None: - numpy.random.seed(random_seed) - logging.warning("Random state for splitting dataset is fixed") - else: - logging.info("Random state is not set") - - nodes = nodes.copy() - - package_names = [name.replace("\n", "").replace("-","_").replace(".","_") for name in package_names] - - package_names = numpy.array(package_names) - numpy.random.shuffle(package_names) - - train = int(package_names.size * train_frac) - test = int(package_names.size * (train_frac + (1 - train_frac) / 2)) - - logging.info( - f"Splitting into train {train}, validation {test - train}, and test {package_names.size - test} packages" - ) - - train, valid, test = package_names[:train], package_names[train: test], package_names[test:] - - train = set(train.tolist()) - valid = set(valid.tolist()) - test = set(test.tolist()) - - nodes["node_package_names"] = nodes["name"].map(lambda name: name.split(".")[0]) - - def get_split_indices(split): - global_types = {val for _, val in node_types.items()} - - def is_global_type(type): - return type in global_types - - def in_split(name): - return name in split - - global_nodes = nodes.query( - "node_package_names.map(@in_split) and type_backup.map(@is_global_type)", - local_dict={"in_split": in_split, "is_global_type": is_global_type} - )["id"] - global_nodes = set(global_nodes.tolist()) - - def in_global(node_id): - return node_id in global_nodes - - ast_nodes = nodes.query("mentioned_in.map(@in_global)", local_dict={"in_global": in_global})["id"] - - split_nodes = global_nodes | set(ast_nodes.tolist()) - - def nodes_in_split(node_id): - return node_id in split_nodes - - return nodes.query("id.map(@nodes_in_split)", local_dict={"nodes_in_split": nodes_in_split}).index - - return get_split_indices(train), get_split_indices(valid), get_split_indices(test) - - -def get_global_edges(): - """ - :return: Set of global edges and their reverses - """ - from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import special_mapping, node_types - types = set() - - for key, value in special_mapping.items(): - types.add(key) - types.add(value) - - for _, value in node_types.items(): - types.add(value + "_name") - - return types - - class SourceGraphDataset: g = None nodes = None @@ -175,14 +38,19 @@ class SourceGraphDataset: labels_from = None use_node_types = None use_edge_types = None - filter = None + filter_edges = None self_loops = None - def __init__(self, data_path, - label_from, use_node_types=False, - use_edge_types=False, filter=None, self_loops=False, - train_frac=0.6, random_seed=None, tokenizer_path=None, min_count_for_objectives=1, - no_global_edges=False, 
remove_reverse=False, package_names=None): + def __init__( + self, data_path: Union[str, Path], use_node_types: bool = False, use_edge_types: bool = False, + filter_edges: Optional[List[str]] = None, self_loops: bool = False, + train_frac: float = 0.6, random_seed: Optional[int] = None, tokenizer_path: Union[str, Path] = None, + min_count_for_objectives: int = 1, + no_global_edges: bool = False, remove_reverse: bool = False, custom_reverse: Optional[List[str]] = None, + # package_names: Optional[List[str]] = None, + restricted_id_pool: Optional[List[int]] = None, use_ns_groups: bool = False, + subgraph_id_column=None, subgraph_partition=None + ): """ Prepares the data for training GNN model. The graph is prepared in the following way: 1. Edges are split into the train set and holdout set. Holdout set is used in the future experiments. @@ -195,53 +63,69 @@ def __init__(self, data_path, node_types flag to True. 4. Graphs require contiguous indexing of nodes. For this reason additional mapping is created that tracks the relationship between the new graph id and the original node id from the training data. - :param nodes_path: path to csv or compressed csv for nodes witch columns - "id", "type", {"name", "serialized_name"}, {any column with labels} - :param edges_path: path to csv or compressed csv for edges with columns - "id", "type", {"source_node_id", "src"}, {"target_node_id", "dst"} - :param label_from: the column where the labels are taken from - :param use_node_types: boolean value, whether to use node types or not - (node-heterogeneous graph} - :param use_edge_types: boolean value, whether to use edge types or not - (edge-heterogeneous graph} - :param filter: list[str], the types of edges to filter from graph - """ + :param data_path: path to the directory with dataset files stored in `bz2` format + :param use_node_types: whether to use node types in the graph + :param use_edge_types: whether to use edge types in the graph + :param filter_edges: edge types to be removed from the graph + :param self_loops: whether to include self-loops + :param train_frac: fraction of the nodes that will be used for training + :param random_seed: seed for generating random splits + :param tokenizer_path: path to bpe tokenizer, needed to process op names correctly + :param min_count_for_objectives: minimum degree of nodes, after which they are excluded from training data + :param no_global_edges: whether to remove global edges from the dataset. + :param remove_reverse: whether to remove reverse edges from the dataset + :param custom_reverse: list of edges for which reverse types should be added. Used together with `remove_reverse` + :param package_names: list of packages that should be used for partitioning into train and test sets. 
Used to + draw a solid distinction between code in train and test sets + :param restricted_id_pool: path to csv file with column `node_id` that stores nodes that should be involved into + training and testing + :param use_ns_groups: currently not used + """ self.random_seed = random_seed self.nodes_have_types = use_node_types self.edges_have_types = use_edge_types - self.labels_from = label_from self.data_path = data_path self.tokenizer_path = tokenizer_path self.min_count_for_objectives = min_count_for_objectives self.no_global_edges = no_global_edges self.remove_reverse = remove_reverse + self.custom_reverse = custom_reverse + self.subgraph_id_column = subgraph_id_column + self.subgraph_partition = subgraph_partition + + self.use_ns_groups = use_ns_groups - nodes_path = join(data_path, "nodes.bz2") - edges_path = join(data_path, "edges.bz2") + nodes_path = join(data_path, "common_nodes.json.bz2") + edges_path = join(data_path, "common_edges.json.bz2") self.nodes, self.edges = load_data(nodes_path, edges_path) + # self.nodes, self.edges, self.holdout = self.holdout(self.nodes, self.edges) + # index is later used for sampling and is assumed to be unique assert len(self.nodes) == len(self.nodes.index.unique()) assert len(self.edges) == len(self.edges.index.unique()) if self_loops: - self.nodes, self.edges = SourceGraphDataset.assess_need_for_self_loops(self.nodes, self.edges) + self.nodes, self.edges = SourceGraphDataset._assess_need_for_self_loops(self.nodes, self.edges) - if filter is not None: - for e_type in filter.split(","): + if filter_edges is not None: + for e_type in filter_edges: logging.info(f"Filtering edge type {e_type}") self.edges = self.edges.query(f"type != '{e_type}'") if self.remove_reverse: - self.remove_reverse_edges() + self._remove_reverse_edges() if self.no_global_edges: - self.remove_global_edges() + self._remove_global_edges() + + if self.custom_reverse is not None: + self._add_custom_reverse() if use_node_types is False and use_edge_types is False: - new_nodes, new_edges = self.create_nodetype_edges() + new_nodes, new_edges = self._create_nodetype_edges(self.nodes, self.edges) self.nodes = self.nodes.append(new_nodes, ignore_index=True) self.edges = self.edges.append(new_edges, ignore_index=True) @@ -250,7 +134,8 @@ def __init__(self, data_path, self.nodes['type'] = "node_" self.nodes = self.nodes.astype({'type': 'category'}) - self.add_embeddable_flag() + self._add_embedding_names() + # self._add_embeddable_flag() # need to do this to avoid issues insode dgl library self.edges['type'] = self.edges['type'].apply(lambda x: f"{x}_") @@ -269,49 +154,26 @@ def __init__(self, data_path, logging.info(f"Unique edges: {len(self.edges)}, edge types: {len(self.edges['type'].unique())}") # self.nodes, self.label_map = self.add_compact_labels() - self.add_typed_ids() + self._add_typed_ids() - self.add_splits(train_frac=train_frac, package_names=package_names) + self._add_splits(train_frac=train_frac, + package_names=None, #package_names, + restricted_id_pool=restricted_id_pool) # self.mark_leaf_nodes() - self.create_hetero_graph() + self._create_hetero_graph() - self.update_global_id() + self._update_global_id() self.nodes.sort_values('global_graph_id', inplace=True) - # self.splits = SourceGraphDataset.get_global_graph_id_splits(self.nodes) - - @classmethod - def get_global_graph_id_splits(cls, nodes): - - splits = ( - nodes.query("train_mask == True")['global_graph_id'].values, - nodes.query("val_mask == True")['global_graph_id'].values, - nodes.query("test_mask == 
True")['global_graph_id'].values, - ) - - return splits + def _add_embedding_names(self): + self.nodes["embeddable"] = True + self.nodes["embeddable_name"] = self.nodes["name"].apply(self.get_embeddable_name) - def compress_node_types(self): - node_type_map = compact_property(self.nodes['type']) - self.node_types = pd.DataFrame( - {"str_type": k, "int_type": v} for k, v in compact_property(self.nodes['type']).items() - ) - - self.nodes['type'] = self.nodes['type'].apply(lambda x: node_type_map[x]) - - def compress_edge_types(self): - edge_type_map = compact_property(self.edges['type']) - self.edge_types = pd.DataFrame( - {"str_type": k, "int_type": v} for k, v in compact_property(self.edges['type']).items() - ) - - self.edges['type'] = self.edges['type'].apply(lambda x: edge_type_map[x]) - - def add_embeddable_flag(self): - embeddable_types = PythonSharedNodes.shared_node_types + def _add_embeddable_flag(self): + embeddable_types = PythonSharedNodes.shared_node_types | set(list(node_types.values())) if len(self.nodes.query("type_backup == 'subword'")) > 0: # some of the types should not be embedded if subwords were generated @@ -326,9 +188,11 @@ def add_embeddable_flag(self): inplace=True ) - def op_tokens(self): + self.nodes["embeddable_name"] = self.nodes["name"].apply(self.get_embeddable_name) + + def _op_tokens(self): if self.tokenizer_path is None: - from SourceCodeTools.code.python_tokens_to_bpe_subwords import python_ops_to_bpe + from SourceCodeTools.code.ast.python_tokens_to_bpe_subwords import python_ops_to_bpe logging.info("Using heuristic tokenization for ops") # def op_tokenize(op_name): @@ -342,7 +206,7 @@ def op_tokens(self): # def op_tokenize(op_name): # return op_tokenize_or_none(op_name, tokenizer) - from SourceCodeTools.code.python_tokens_to_bpe_subwords import python_ops_to_literal + from SourceCodeTools.code.ast.python_tokens_to_bpe_subwords import python_ops_to_literal return { op_name: tokenizer(op_literal) for op_name, op_literal in python_ops_to_literal.items() @@ -351,7 +215,7 @@ def op_tokens(self): # self.nodes.eval("name_alter_tokens = name.map(@op_tokenize)", # local_dict={"op_tokenize": op_tokenize}, inplace=True) - def add_splits(self, train_frac, package_names=None): + def _add_splits(self, train_frac, package_names=None, restricted_id_pool=None): """ Generates train, validation, and test masks Store the masks is pandas table for nodes @@ -359,41 +223,44 @@ def add_splits(self, train_frac, package_names=None): :return: """ + self.nodes.reset_index(drop=True, inplace=True) assert len(self.nodes.index) == self.nodes.index.max() + 1 # generate splits for all nodes, additional filtering will be applied later # by an objective if package_names is None: - splits = get_train_val_test_indices( + splits = self.get_train_val_test_indices( self.nodes.index, train_frac=train_frac, random_seed=self.random_seed ) else: - splits = get_train_val_test_indices_on_packages( + splits = self.get_train_val_test_indices_on_packages( self.nodes, package_names, train_frac=train_frac, random_seed=self.random_seed ) - create_train_val_test_masks(self.nodes, *splits) + self.create_train_val_test_masks(self.nodes, *splits) + + if restricted_id_pool is not None: + node_ids = set(pd.read_csv(restricted_id_pool)["node_id"].tolist()) | \ + set(self.nodes.query("type_backup == 'FunctionDef' or type_backup == 'mention'")["id"].tolist()) + to_keep = self.nodes["id"].apply(lambda id_: id_ in node_ids) + self.nodes["train_mask"] = self.nodes["train_mask"] & to_keep + self.nodes["test_mask"] = 
self.nodes["test_mask"] & to_keep + self.nodes["val_mask"] = self.nodes["val_mask"] & to_keep - def add_typed_ids(self): + def _add_typed_ids(self): nodes = self.nodes.copy() typed_id_map = {} # node_types = dict(zip(self.node_types['int_type'], self.node_types['str_type'])) - for type in nodes['type'].unique(): - # need to use indexes because will need to reference - # the original table - type_ind = nodes[nodes['type'] == type].index - - id_map = compact_property(nodes.loc[type_ind, 'id']) - - nodes.loc[type_ind, 'typed_id'] = nodes.loc[type_ind, 'id'].apply(lambda old_id: id_map[old_id]) - - # typed_id_map[node_types[type]] = id_map - typed_id_map[type] = id_map + for type_ in nodes['type'].unique(): + type_mask = nodes['type'] == type_ + id_map = compact_property(nodes.loc[type_mask, 'id']) + nodes.loc[type_mask, 'typed_id'] = nodes.loc[type_mask, 'id'].apply(lambda old_id: id_map[old_id]) + typed_id_map[type_] = id_map assert any(pandas.isna(nodes['typed_id'])) is False @@ -408,7 +275,8 @@ def add_typed_ids(self): # nodes['compact_label'] = nodes['label'].apply(lambda old_id: label_map[old_id]) # return nodes, label_map - def add_node_types_to_edges(self, nodes, edges): + @staticmethod + def _add_node_types_to_edges(nodes, edges): # nodes = self.nodes # edges = self.edges.copy() @@ -421,15 +289,16 @@ def add_node_types_to_edges(self, nodes, edges): return edges - def create_nodetype_edges(self): - node_new_id = self.nodes["id"].max() + 1 - edge_new_id = self.edges["id"].max() + 1 + @staticmethod + def _create_nodetype_edges(nodes, edges): + node_new_id = nodes["id"].max() + 1 + edge_new_id = edges["id"].max() + 1 new_nodes = [] new_edges = [] added_type_nodes = {} - node_slice = self.nodes[["id", "type"]].values + node_slice = nodes[["id", "type"]].values for id, type in node_slice: if type not in added_type_nodes: @@ -455,30 +324,44 @@ def create_nodetype_edges(self): return pd.DataFrame(new_nodes), pd.DataFrame(new_edges) - def remove_ast_edges(self): - global_edges = get_global_edges() + def _remove_ast_edges(self): + global_edges = self.get_global_edges() global_edges.add("subword") is_global = lambda type: type in global_edges edges = self.edges.query("type_backup.map(@is_global)", local_dict={"is_global": is_global}) - self.nodes, self.edges = ensure_connectedness(self.nodes, edges) + self.nodes, self.edges = self.ensure_connectedness(self.nodes, edges) - def remove_global_edges(self): - global_edges = get_global_edges() - global_edges.add("global_mention") + def _remove_global_edges(self): + global_edges = self.get_global_edges() + # global_edges.add("global_mention") + # global_edges |= set(edge + "_rev" for edge in global_edges) is_ast = lambda type: type not in global_edges edges = self.edges.query("type.map(@is_ast)", local_dict={"is_ast": is_ast}) self.edges = edges # self.nodes, self.edges = ensure_connectedness(self.nodes, edges) - def remove_reverse_edges(self): + def _remove_reverse_edges(self): from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import special_mapping - global_reverse = {val for _, val in special_mapping.items()} + # TODO test this change + global_reverse = {key for key, val in special_mapping.items()} not_reverse = lambda type: not (type.endswith("_rev") or type in global_reverse) edges = self.edges.query("type.map(@not_reverse)", local_dict={"not_reverse": not_reverse}) self.edges = edges - def update_global_id(self): + def _add_custom_reverse(self): + to_reverse = self.edges[ + self.edges["type"].apply(lambda type_: type_ in 
self.custom_reverse) + ] + + to_reverse["type"] = to_reverse["type"].apply(lambda type_: type_ + "_rev") + tmp = to_reverse["src"] + to_reverse["src"] = to_reverse["dst"] + to_reverse["dst"] = tmp + + self.edges = self.edges.append(to_reverse[["src", "dst", "type"]]) + + def _update_global_id(self): orig_id = [] graph_id = [] prev_offset = 0 @@ -518,11 +401,11 @@ def typed_node_counts(self): return typed_node_counts - def create_hetero_graph(self): + def _create_hetero_graph(self): nodes = self.nodes.copy() edges = self.edges.copy() - edges = self.add_node_types_to_edges(nodes, edges) + edges = self._add_node_types_to_edges(nodes, edges) typed_node_id = dict(zip(nodes['id'], nodes['typed_id'])) @@ -585,8 +468,8 @@ def create_hetero_graph(self): self.g.nodes[ntype].data['typed_id'] = torch.tensor(node_data['typed_id'].values, dtype=torch.int64) self.g.nodes[ntype].data['original_id'] = torch.tensor(node_data['id'].values, dtype=torch.int64) - @classmethod - def assess_need_for_self_loops(cls, nodes, edges): + @staticmethod + def _assess_need_for_self_loops(nodes, edges): # this is a hack when where are only outgoing connections from this node type need_self_loop = set(edges['src'].values.tolist()) - set(edges['dst'].values.tolist()) for nid in need_self_loop: @@ -599,35 +482,249 @@ def assess_need_for_self_loops(cls, nodes, edges): return nodes, edges - # @classmethod - # def holdout(cls, nodes, edges, holdout_frac, random_seed): - # """ - # Create a set of holdout edges, ensure that there are no orphan nodes after these edges are removed. - # :param nodes: - # :param edges: - # :param holdout_frac: - # :param random_seed: - # :return: - # """ - # - # train, test = split(edges, holdout_frac, random_seed=random_seed) - # - # nodes, train_edges = ensure_connectedness(nodes, train) - # - # nodes, test_edges = ensure_valid_edges(nodes, test) - # - # return nodes, train_edges, test_edges + @staticmethod + def holdout(nodes: pd.DataFrame, edges: pd.DataFrame, holdout_size=10000, random_seed=42): + """ + Create a set of holdout edges, ensure that there are no orphan nodes after these edges are removed. 
+ :param nodes: + :param edges: + :param holdout_frac: + :param random_seed: + :return: + """ + + from collections import Counter + + degree_count = Counter(edges["src"].tolist()) | Counter(edges["dst"].tolist()) + + heldout = [] + + edges = edges.reset_index(drop=True) + index = edges.index.to_numpy() + numpy.random.seed(random_seed) + numpy.random.shuffle(index) + + for i in index: + src_id = edges.loc[i].src + if degree_count[src_id] > 2: + heldout.append(edges.loc[i].id) + degree_count[src_id] -= 1 + if len(heldout) >= holdout_size: + break + + heldout = set(heldout) + + def is_held(id_): + return id_ in heldout + + train_edges = edges[ + edges["id"].apply(lambda id_: not is_held(id_)) + ] + + heldout_edges = edges[ + edges["id"].apply(is_held) + ] + + assert len(edges) == edges["id"].unique().size + + return nodes, train_edges, heldout_edges + + @staticmethod + def get_name_group(name): + parts = name.split("@") + if len(parts) == 1: + return pd.NA + elif len(parts) == 2: + local_name, group = parts + return group + return pd.NA + + @staticmethod + def create_train_val_test_masks(nodes, train_idx, val_idx, test_idx): + nodes['train_mask'] = True + # nodes.loc[train_idx, 'train_mask'] = True + nodes['val_mask'] = False + nodes.loc[val_idx, 'val_mask'] = True + nodes['test_mask'] = False + nodes.loc[test_idx, 'test_mask'] = True + nodes['train_mask'] = nodes['train_mask'] ^ (nodes['val_mask'] | nodes['test_mask']) + starts_with = lambda x: x.startswith("##node_type") + nodes.loc[ + nodes.eval("name.map(@starts_with)", local_dict={"starts_with": starts_with}), ['train_mask', 'val_mask', + 'test_mask']] = False + @staticmethod + def get_train_val_test_indices(indices, train_frac=0.6, random_seed=None): + if random_seed is not None: + numpy.random.seed(random_seed) + logging.warning("Random state for splitting dataset is fixed") + else: + logging.info("Random state is not set") + + indices = indices.to_numpy() + + numpy.random.shuffle(indices) + + train = int(indices.size * train_frac) + test = int(indices.size * (train_frac + (1 - train_frac) / 2)) + + logging.info( + f"Splitting into train {train}, validation {test - train}, and test {indices.size - test} sets" + ) + + return indices[:train], indices[train: test], indices[test:] + + @staticmethod + def get_train_val_test_indices_on_packages(nodes, package_names, train_frac=0.6, random_seed=None): + if random_seed is not None: + numpy.random.seed(random_seed) + logging.warning("Random state for splitting dataset is fixed") + else: + logging.info("Random state is not set") + + nodes = nodes.copy() + + package_names = [name.replace("\n", "").replace("-", "_").replace(".", "_") for name in package_names] + + package_names = numpy.array(package_names) + numpy.random.shuffle(package_names) + + train = int(package_names.size * train_frac) + test = int(package_names.size * (train_frac + (1 - train_frac) / 2)) + + logging.info( + f"Splitting into train {train}, validation {test - train}, and test {package_names.size - test} packages" + ) + + train, valid, test = package_names[:train], package_names[train: test], package_names[test:] + + train = set(train.tolist()) + valid = set(valid.tolist()) + test = set(test.tolist()) + + nodes["node_package_names"] = nodes["name"].map(lambda name: name.split(".")[0]) + + def get_split_indices(split): + global_types = {val for _, val in node_types.items()} + + def is_global_type(type): + return type in global_types + + def in_split(name): + return name in split + + global_nodes = nodes.query( + 
"node_package_names.map(@in_split) and type_backup.map(@is_global_type)", + local_dict={"in_split": in_split, "is_global_type": is_global_type} + )["id"] + global_nodes = set(global_nodes.tolist()) + + def in_global(node_id): + return node_id in global_nodes + + ast_nodes = nodes.query("mentioned_in.map(@in_global)", local_dict={"in_global": in_global})["id"] + + split_nodes = global_nodes | set(ast_nodes.tolist()) + + def nodes_in_split(node_id): + return node_id in split_nodes + + return nodes.query("id.map(@nodes_in_split)", local_dict={"nodes_in_split": nodes_in_split}).index + + return get_split_indices(train), get_split_indices(valid), get_split_indices(test) + + @staticmethod + def get_global_edges(): + """ + :return: Set of global edges and their reverses + """ + from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import special_mapping, node_types + types = set() + + for key, value in special_mapping.items(): + types.add(key) + types.add(value) + + for _, value in node_types.items(): + types.add(value + "_name") + + return types + + @staticmethod + def get_embeddable_name(name): + if "@" in name: + return name.split("@")[0] + elif "_0x" in name: + return name.split("_0x")[0] + else: + return name + + @staticmethod + def ensure_connectedness(nodes: pandas.DataFrame, edges: pandas.DataFrame): + """ + Filtering isolated nodes + :param nodes: DataFrame + :param edges: DataFrame + :return: + """ + + logging.info( + f"Filtering isolated nodes. " + f"Starting from {nodes.shape[0]} nodes and {edges.shape[0]} edges...", + ) + unique_nodes = set(edges['src'].append(edges['dst'])) + + nodes = nodes[ + nodes['id'].apply(lambda nid: nid in unique_nodes) + ] + + logging.info( + f"Ending up with {nodes.shape[0]} nodes and {edges.shape[0]} edges" + ) + + return nodes, edges + + @staticmethod + def ensure_valid_edges(nodes, edges, ignore_src=False): + """ + Filter edges that link to nodes that do not exist + :param nodes: + :param edges: + :param ignore_src: + :return: + """ + print( + f"Filtering edges to invalid nodes. 
" + f"Starting from {nodes.shape[0]} nodes and {edges.shape[0]} edges...", + end="" + ) + + unique_nodes = set(nodes['id'].values.tolist()) + + if not ignore_src: + edges = edges[ + edges['src'].apply(lambda nid: nid in unique_nodes) + ] + + edges = edges[ + edges['dst'].apply(lambda nid: nid in unique_nodes) + ] + + print( + f"ending up with {nodes.shape[0]} nodes and {edges.shape[0]} edges" + ) + + return nodes, edges # def mark_leaf_nodes(self): # leaf_types = {'subword', "Op", "Constant", "Name"} # the last is used in graphs without subwords # # self.nodes['is_leaf'] = self.nodes['type_backup'].apply(lambda type_: type_ in leaf_types) - def get_typed_node_id(self, node_id, node_type): - return self.typed_id_map[node_type][node_id] - - def get_global_node_id(self, node_id, node_type=None): - return self.node_id_to_global_id[node_id] + # def get_typed_node_id(self, node_id, node_type): + # return self.typed_id_map[node_type][node_id] + # + # def get_global_node_id(self, node_id, node_type=None): + # return self.node_id_to_global_id[node_id] def load_node_names(self): """ @@ -638,6 +735,11 @@ def load_node_names(self): ][['id', 'type_backup', 'name']]\ .rename({"name": "serialized_name", "type_backup": "type"}, axis=1) + global_node_types = set(node_types.values()) + for_training = for_training[ + for_training["type"].apply(lambda x: x not in global_node_types) + ] + node_names = extract_node_names(for_training, 2) node_names = filter_dst_by_freq(node_names, freq=self.min_count_for_objectives) @@ -645,17 +747,33 @@ def load_node_names(self): # path = join(self.data_path, "node_names.bz2") # return unpersist(path) + def load_subgraph_function_names(self): + names_path = os.path.join(self.data_path, "common_name_mappings.json.bz2") + names = unpersist(names_path) + + fname2gname = dict(zip(names["ast_name"], names["proper_names"])) + + functions = self.nodes.query( + "id in @functions", local_dict={"functions": set(self.nodes["mentioned_in"])} + ).query("type_backup == 'FunctionDef'") + + functions["gname"] = functions["name"].apply(lambda x: fname2gname.get(x, pd.NA)) + functions = functions.dropna(axis=0) + functions["gname"] = functions["gname"].apply(lambda x: x.split(".")[-1]) + + return functions.rename({"id": "src", "gname": "dst"}, axis=1)[["src", "dst"]] + def load_var_use(self): """ :return: DataFrame that contains mapping from function ids to variable names that appear in those functions """ - path = join(self.data_path, "common_function_variable_pairs.bz2") + path = join(self.data_path, "common_function_variable_pairs.json.bz2") var_use = unpersist(path) var_use = filter_dst_by_freq(var_use, freq=self.min_count_for_objectives) return var_use def load_api_call(self): - path = join(self.data_path, "common_call_seq.bz2") + path = join(self.data_path, "common_call_seq.json.bz2") api_call = unpersist(path) api_call = filter_dst_by_freq(api_call, freq=self.min_count_for_objectives) return api_call @@ -686,12 +804,12 @@ def load_token_prediction(self): def load_global_edges_prediction(self): - nodes_path = join(self.data_path, "nodes.bz2") - edges_path = join(self.data_path, "edges.bz2") + nodes_path = join(self.data_path, "common_nodes.json.bz2") + edges_path = join(self.data_path, "common_edges.json.bz2") _, edges = load_data(nodes_path, edges_path) - global_edges = get_global_edges() + global_edges = self.get_global_edges() global_edges = global_edges - {"defines", "defined_in"} # these edges are already in AST? 
global_edges.add("global_mention") @@ -707,9 +825,84 @@ def load_global_edges_prediction(self): return edges[["src", "dst"]] + def load_edge_prediction(self): + + nodes_path = join(self.data_path, "common_nodes.json.bz2") + edges_path = join(self.data_path, "common_edges.json.bz2") + + _, edges = load_data(nodes_path, edges_path) + + edges.rename( + { + "source_node_id": "src", + "target_node_id": "dst" + }, inplace=True, axis=1 + ) + + global_edges = {"global_mention", "subword", "next", "prev"} + global_edges = global_edges | {"mention_scope", "defined_in_module", "defined_in_class", "defined_in_function"} + + if self.no_global_edges: + global_edges = global_edges | self.get_global_edges() + + global_edges = global_edges | set(edge + "_rev" for edge in global_edges) + is_ast = lambda type: type not in global_edges + edges = edges.query("type.map(@is_ast)", local_dict={"is_ast": is_ast}) + edges = edges[edges["type"].apply(lambda type_: not type_.endswith("_rev"))] + + valid_nodes = set(edges["src"].tolist()) + valid_nodes = valid_nodes.intersection(set(edges["dst"].tolist())) + + # if self.use_ns_groups: + # groups = self.get_negative_sample_groups() + # valid_nodes = valid_nodes.intersection(set(groups["id"].tolist())) + + edges = edges[ + edges["src"].apply(lambda id_: id_ in valid_nodes) + ] + edges = edges[ + edges["dst"].apply(lambda id_: id_ in valid_nodes) + ] + + return edges[["src", "dst", "type"]] + + def load_type_prediction(self): + + type_ann = unpersist(join(self.data_path, "type_annotations.json.bz2")) + + filter_rule = lambda name: "0x" not in name + + type_ann = type_ann[ + type_ann["dst"].apply(filter_rule) + ] + + node2id = dict(zip(self.nodes["id"], self.nodes["type_backup"])) + type_ann = type_ann[ + type_ann["src"].apply(lambda id_: id_ in node2id) + ] + + type_ann["src_type"] = type_ann["src"].apply(lambda x: node2id[x]) + + type_ann = type_ann[ + type_ann["src_type"].apply(lambda type_: type_ in {"mention"}) # FunctionDef {"arg", "AnnAssign"}) + ] + + norm = lambda x: x.strip("\"").strip("'").split("[")[0].split(".")[-1] + + type_ann["dst"] = type_ann["dst"].apply(norm) + type_ann = filter_dst_by_freq(type_ann, self.min_count_for_objectives) + type_ann = type_ann[["src", "dst"]] + + return type_ann + + def load_cubert_subgraph_labels(self): + + filecontent = unpersist(join(self.data_path, "common_filecontent.json.bz2")) + return filecontent[["id", "label"]].rename({"id": "src", "label": "dst"}, axis=1) + def load_docstring(self): - docstrings_path = os.path.join(self.data_path, "common_source_graph_bodies.bz2") + docstrings_path = os.path.join(self.data_path, "common_source_graph_bodies.json.bz2") dosctrings = unpersist(docstrings_path)[["id", "docstring"]] @@ -730,6 +923,18 @@ def normalize(text): return dosctrings + def load_node_classes(self): + have_inbound = set(self.edges["dst"].tolist()) + labels = self.nodes.query("train_mask == True or test_mask == True or val_mask == True")[["id", "type_backup"]].rename({ + "id": "src", + "type_backup": "dst" + }, axis=1) + + labels = labels[ + labels["src"].apply(lambda id_: id_ in have_inbound) + ] + return labels + def buckets_from_pretrained_embeddings(self, pretrained_path, n_buckets): from SourceCodeTools.nlp.embed.fasttext import load_w2v_map @@ -754,7 +959,7 @@ def op_embedding(op_tokens): embedding = embedding + token_emb return embedding - python_ops_to_bpe = self.op_tokens() + python_ops_to_bpe = self._op_tokens() for op, op_tokens in python_ops_to_bpe.items(): op_emb = op_embedding(op_tokens) if op_emb is not 
None: @@ -783,86 +988,61 @@ def create_node_name_masker(self, tokenizer_path): """ return NodeNameMasker(self.nodes, self.edges, self.load_node_names(), tokenizer_path) + def create_node_clf_masker(self): + """ + :param tokenizer_path: path to bpe tokenizer + :return: SubwordMasker for function nodes. Suitable for node name use prediction objective + """ + return NodeClfMasker(self.nodes, self.edges) -def ensure_connectedness(nodes: pandas.DataFrame, edges: pandas.DataFrame): - """ - Filtering isolated nodes - :param nodes: DataFrame - :param edges: DataFrame - :return: - """ + def get_negative_sample_groups(self): + return self.nodes[["id", "mentioned_in"]].dropna(axis=0) - logging.info( - f"Filtering isolated nodes. " - f"Starting from {nodes.shape[0]} nodes and {edges.shape[0]} edges...", - ) - unique_nodes = set(edges['src'].values.tolist() + - edges['dst'].values.tolist()) + @property + def subgraph_mapping(self): + assert self.subgraph_id_column is not None, "`subgraph_id_column` was not provided" - nodes = nodes[ - nodes['id'].apply(lambda nid: nid in unique_nodes) - ] + id2type = dict(zip(self.nodes["id"], self.nodes["type"])) - logging.info( - f"Ending up with {nodes.shape[0]} nodes and {edges.shape[0]} edges" - ) + subgraph_mapping = dict() - return nodes, edges - - -def ensure_valid_edges(nodes, edges, ignore_src=False): - """ - Filter edges that link to nodes that do not exist - :param nodes: - :param edges: - :param ignore_src: - :return: - """ - print( - f"Filtering edges to invalid nodes. " - f"Starting from {nodes.shape[0]} nodes and {edges.shape[0]} edges...", - end="" - ) + def add_item(subgraph_dict, node_id): + type_ = id2type[node_id] - unique_nodes = set(nodes['id'].values.tolist()) + if type_ not in subgraph_dict: + subgraph_dict[type_] = set() - if not ignore_src: - edges = edges[ - edges['src'].apply(lambda nid: nid in unique_nodes) - ] + subgraph_dict[type_].add(self.typed_id_map[type_][node_id]) - edges = edges[ - edges['dst'].apply(lambda nid: nid in unique_nodes) - ] + for src, dst, subgraph_id in self.edges[["src", "dst", self.subgraph_id_column]].values: + if subgraph_id not in subgraph_mapping: + subgraph_mapping[subgraph_id] = dict() - print( - f"ending up with {nodes.shape[0]} nodes and {edges.shape[0]} edges" - ) + subgraph_dict = subgraph_mapping[subgraph_id] + add_item(subgraph_dict, src) + add_item(subgraph_dict, dst) - return nodes, edges + for subgraph_id, subgraph_dict in subgraph_mapping.items(): + for type_ in subgraph_dict: + subgraph_dict[type_] = list(subgraph_dict[type_]) + return subgraph_mapping -def read_or_create_dataset(args, model_base, labels_from="type"): - if args.restore_state: + @classmethod + def load(cls, path, args): + dataset = pickle.load(open(path, "rb")) + dataset.data_path = args["data_path"] + if dataset.tokenizer_path is not None: + dataset.tokenizer_path = args["tokenizer"] + return dataset + + +def read_or_create_gnn_dataset(args, model_base, force_new=False, restore_state=False): + if restore_state and not force_new: # i'm not happy with this behaviour that differs based on the flag status - dataset = pickle.load(open(join(model_base, "dataset.pkl"), "rb")) + dataset = SourceGraphDataset.load(join(model_base, "dataset.pkl"), args) else: - dataset = SourceGraphDataset( - # args.node_path, args.edge_path, - args.data_path, - label_from=labels_from, - use_node_types=args.use_node_types, - use_edge_types=args.use_edge_types, - filter=args.filter_edges, - self_loops=args.self_loops, - train_frac=args.train_frac, - 
tokenizer_path=args.tokenizer, - random_seed=args.random_seed, - min_count_for_objectives=args.min_count_for_objectives, - no_global_edges=args.no_global_edges, - remove_reverse=args.remove_reverse, - package_names=open(args.packages_file).readlines() if args.packages_file is not None else None - ) + dataset = SourceGraphDataset(**args) # save dataset state for recovery pickle.dump(dataset, open(join(model_base, "dataset.pkl"), "wb")) @@ -880,7 +1060,6 @@ def test_dataset(): dataset = SourceGraphDataset( data_path, # nodes_path, edges_path, - label_from='type', use_node_types=False, use_edge_types=True, ) diff --git a/SourceCodeTools/code/data/sourcetrail/SubwordMasker.py b/SourceCodeTools/code/data/dataset/SubwordMasker.py similarity index 87% rename from SourceCodeTools/code/data/sourcetrail/SubwordMasker.py rename to SourceCodeTools/code/data/dataset/SubwordMasker.py index 89607ad3..509ac997 100644 --- a/SourceCodeTools/code/data/sourcetrail/SubwordMasker.py +++ b/SourceCodeTools/code/data/dataset/SubwordMasker.py @@ -103,4 +103,25 @@ def instantiate(self, nodes, orig_edges, **kwargs): self.lookup[key].extend([subword2key[sub] for sub in subwords if sub in subword2key]) +class NodeClfMasker(SubwordMasker): + """ + Masker that tells which node ids are subwords for variables mentioned in a given function. + """ + def __init__(self, nodes: pd.DataFrame, edges: pd.DataFrame, **kwargs): + super(NodeClfMasker, self).__init__(nodes, edges, **kwargs) + + def instantiate(self, nodes, orig_edges, **kwargs): + pass + + def get_mask(self, ids): + """ + Accepts node ids that represent embeddable tokens as an input + :param ids: + :return: + """ + if isinstance(ids, dict): + for_masking = ids + else: + for_masking = {"node_": ids} + return for_masking diff --git a/SourceCodeTools/code/data/dataset/__init__.py b/SourceCodeTools/code/data/dataset/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SourceCodeTools/code/data/dataset/reader.py b/SourceCodeTools/code/data/dataset/reader.py new file mode 100644 index 00000000..05e652ca --- /dev/null +++ b/SourceCodeTools/code/data/dataset/reader.py @@ -0,0 +1,65 @@ +from pathlib import Path + +from SourceCodeTools.code.common import read_nodes, read_edges +from SourceCodeTools.code.data.file_utils import unpersist +from SourceCodeTools.code.annotator_utils import source_code_graph_alignment + + +def load_data(node_path, edge_path, rename_columns=True): + nodes = read_nodes(node_path) + edges = read_edges(edge_path) + + if rename_columns: + nodes = nodes.rename(mapper={ + 'serialized_name': 'name' + }, axis=1) + edges = edges.rename(mapper={ + 'source_node_id': 'src', + 'target_node_id': 'dst' + }, axis=1) + + return nodes, edges + + +def load_graph(dataset_directory, rename_columns=True): + dataset_path = Path(dataset_directory) + nodes = dataset_path.joinpath("common_nodes.bz2") + edges = dataset_path.joinpath("common_edges.bz2") + + return load_data(nodes, edges, rename_columns=rename_columns) + + +def load_aligned_source_code(dataset_directory, tokenizer="codebert"): + dataset_path = Path(dataset_directory) + + files = unpersist(dataset_path.joinpath("common_filecontent.bz2")).rename({"id": "file_id"}, axis=1) + + content = dict(zip(zip(files["package"], files["file_id"]), files["filecontent"])) + pd_offsets = unpersist(dataset_path.joinpath("common_offsets.bz2")) + + seen = set() + + source_codes = [] + offsets = [] + + for group, data in pd_offsets.groupby(by=["package", "file_id"]): + source_codes.append(content[group]) + 
offsets.append(list(zip(data["start"], data["end"], data["node_id"]))) + seen.add(group) + + for key, val in content.items(): + if key not in seen: + source_codes.append(val) + offsets.append([]) + + return source_code_graph_alignment(source_codes, offsets, tokenizer=tokenizer) + + +if __name__ == "__main__": + import sys + data_path = sys.argv[1] + for tokens, node_tags in load_aligned_source_code(data_path): + for t, tt in zip(tokens, node_tags): + print(t, tt, sep="\t") + print() + diff --git a/SourceCodeTools/code/data/sourcetrail/file_utils.py b/SourceCodeTools/code/data/file_utils.py similarity index 53% rename from SourceCodeTools/code/data/sourcetrail/file_utils.py rename to SourceCodeTools/code/data/file_utils.py index 05287e26..6f25cc4f 100644 --- a/SourceCodeTools/code/data/sourcetrail/file_utils.py +++ b/SourceCodeTools/code/data/file_utils.py @@ -1,5 +1,11 @@ +import bz2 import logging +import tempfile from csv import QUOTE_NONNUMERIC +from pathlib import Path +from typing import Union + +import numpy as np import pandas as pd import os @@ -44,6 +50,70 @@ def read_parquet(path, **kwargs): return pd.read_parquet(path, **kwargs) +def write_json(df, path, **kwargs): + if "mode" in kwargs: + mode = kwargs.pop("mode") + else: + mode = None + + if mode == "a": + with open(path, "a") as sink: + sink.write("\n") + sink.write(df.to_json(path_or_buf=None, orient="records", lines=True, **kwargs)) + else: + df.to_json(path_or_buf=path, orient="records", lines=True, **kwargs) + + +def read_json_with_generator(path, chunksize, **kwargs): + if str(path).endswith(".bz2"): + source = bz2.open(path, mode="rt") + else: + source = open(path, "r") + # need to read manually, probably a bug in pandas + buffer = [] + last_index = 0 + + def prepare_chunk(buffer, last_index): + chunk = pd.read_json("".join(buffer), orient="records", lines=True, **kwargs) + chunk.index = np.array(list(range(last_index, last_index + len(chunk)))) + last_index = last_index + len(chunk) + return chunk, last_index + + for ind, line in enumerate(source): + buffer.append(line) + if len(buffer) >= chunksize: + chunk, last_index = prepare_chunk(buffer, last_index) + # chunk = pd.read_json("".join(buffer), orient="records", lines=True, **kwargs) + # chunk.index = np.array(list(range(last_index, last_index + len(chunk)))) + # last_index = last_index + len(chunk) + yield chunk + buffer.clear() + if len(buffer) != 0: + chunk, last_index = prepare_chunk(buffer, last_index) + # chunk = pd.read_json("".join(buffer), orient="records", lines=True, **kwargs) + # chunk.index = np.array(list(range(last_index, last_index + len(chunk)))) + # last_index = last_index + len(chunk) + yield chunk + + +def _grow_with_chunks(chunks): + table = None + for chunk in chunks: + if table is None: + table = chunk + else: + table = pd.concat([table, chunk], copy=False) + return table + + +def read_json(path, **kwargs): + if "chunksize" in kwargs: + return read_json_with_generator(path, kwargs.pop("chunksize"), **kwargs) + else: + return _grow_with_chunks(read_json_with_generator(path, 100000, **kwargs)) + # return pd.read_json(path, orient="records", lines=True, **kwargs) + + def read_source_location(base_path): source_location_path = os.path.join(base_path, filenames["source_location"]) @@ -138,26 +208,58 @@ def write_processed_bodies(df, base_path): persist(df, bodies_path) -def persist(df: pd.DataFrame, path: str, **kwargs): - if path.endswith(".csv"): +def likely_format(path): + path = Path(path) + name_parts = path.name.split(".") + if len(name_parts) 
== 1: + raise ValueError("Extension is not found for the file:", str(path)) + + extensions = "." + ".".join(name_parts[1:]) + + if ".csv" in extensions or ".tsv" in extensions: + ext = "csv" + elif ".json" in extensions: + ext = "json" + elif ".pkl" in extensions or extensions.endswith(".bz2"): + ext = "pkl" + elif ".parquet" in extensions: + ext = "parquet" + else: + raise NotImplementedError("supported extensions: csv, bz2, pkl, parquet, json", extensions) + + return ext + + +def persist(df: pd.DataFrame, path: Union[str, Path, bytes], **kwargs): + if isinstance(path, Path): + path = str(path.absolute()) + + format = likely_format(path) + if format == "csv": write_csv(df, path, **kwargs) - elif path.endswith(".pkl") or path.endswith(".bz2"): + elif format == "pkl": write_pickle(df, path, **kwargs) - elif path.endswith(".parquet"): + elif format == "parquet": write_parquet(df, path, **kwargs) - else: - raise NotImplementedError("supported extensions: csv, bz2, pkl, parquet") + elif format == "json": + write_json(df, path, **kwargs) + +def unpersist(path: Union[str, Path, bytes], **kwargs) -> pd.DataFrame: + if isinstance(path, Path): + path = str(path.absolute()) -def unpersist(path: str, **kwargs) -> pd.DataFrame: - if path.endswith(".csv"): + format = likely_format(path) + if format == "csv": data = read_csv(path, **kwargs) - elif path.endswith(".pkl") or path.endswith(".bz2"): + elif format == "pkl": data = read_pickle(path, **kwargs) - elif path.endswith(".parquet"): + elif format == "parquet": data = read_parquet(path, **kwargs) + elif format == "json": + data = read_json(path, **kwargs) else: - raise NotImplementedError("supported extensions: csv, bz2, pkl, parquet") + data = None return data @@ -178,3 +280,19 @@ def unpersist_or_exit(path, exit_message=None, **kwargs): sys.exit() else: return data + + +def get_random_name(length=10): + char_ranges = [chr(i) for i in range(ord("a"), ord("a")+26)] + \ + [chr(i) for i in range(ord("A"), ord("A")+26)] + \ + [chr(i) for i in range(ord("0"), ord("0")+10)] + from random import sample + return "".join(sample(char_ranges, k=length)) + +def get_temporary_filename(): + tmp_dir = tempfile.gettempdir() + name_generator = tempfile._get_candidate_names() + path = os.path.join(tmp_dir, next(name_generator)) + while os.path.isdir(path): + path = os.path.join(tmp_dir, next(name_generator)) + return path \ No newline at end of file diff --git a/SourceCodeTools/code/data/py_150/create_download_links.py b/SourceCodeTools/code/data/py_150/create_download_links.py new file mode 100644 index 00000000..63344569 --- /dev/null +++ b/SourceCodeTools/code/data/py_150/create_download_links.py @@ -0,0 +1,22 @@ +import sys + +def parse_line(line): + hash, repo_link = line.split("\t") + parts = repo_link.split("/") + user = parts[3] + repo = parts[4] + return user, repo, hash + +def main(): + # download_link = "https://codeload.github.com/{user}/{repo}/zip/{hash}" + download_link = "https://codeload.github.com/{}/{}/zip/{}" + + for line in sys.stdin: + line = line.strip() + if line == "": + continue + + print(download_link.format(*parse_line(line))) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/py_150/download.sh b/SourceCodeTools/code/data/py_150/download.sh new file mode 100644 index 00000000..bbd77dc8 --- /dev/null +++ b/SourceCodeTools/code/data/py_150/download.sh @@ -0,0 +1,7 @@ +while read p; do + reponame=$(echo $p | awk -F"/" '{print $5}') + echo $reponame, $p + wget -O "$reponame.zip" $p + echo 
"Waiting 2s..." + sleep 2s +done \ No newline at end of file diff --git a/SourceCodeTools/code/data/scaa/partitioning.py b/SourceCodeTools/code/data/scaa/partitioning.py new file mode 100644 index 00000000..4a9367e9 --- /dev/null +++ b/SourceCodeTools/code/data/scaa/partitioning.py @@ -0,0 +1,80 @@ +from functools import partial +from os.path import join +from random import random + +import numpy as np + +from SourceCodeTools.code.common import read_edges, read_nodes +from SourceCodeTools.code.data.file_utils import persist, unpersist + + +def add_splits(items, train_frac, restricted_id_pool=None): + items = items.copy() + + def random_partition(): + r = random() + if r < train_frac: + return "train" + elif r < train_frac + (1 - train_frac) / 2: + return "val" + else: + return "test" + + import numpy as np + # define partitioning + masks = np.array([random_partition() for _ in range(len(items))]) + + # create masks + items["train_mask"] = masks == "train" + items["val_mask"] = masks == "val" + items["test_mask"] = masks == "test" + + if restricted_id_pool is not None: + # if `restricted_id_pool` is provided, mask all nodes not in `restricted_id_pool` negatively + to_keep = items.eval("id in @restricted_ids", local_dict={"restricted_ids": restricted_id_pool}) + items["train_mask"] = items["train_mask"] & to_keep + items["test_mask"] = items["test_mask"] & to_keep + items["val_mask"] = items["val_mask"] & to_keep + + return items + + +def subgraph_partitioning(path_to_dataset, partition_column, train_frac=0.7): + + get_path = partial(join, path_to_dataset) + + # nodes = read_nodes(get_path("common_nodes.json.bz2")) + edges = read_edges(get_path("common_edges.json.bz2")) + filecontent = unpersist(get_path("common_filecontent.json.bz2")) + + def random_partition(): + r = random() + if r < train_frac: + return "train" + elif r < train_frac + (1 - train_frac) / 2: + return "val" + else: + return "test" + + task2split = dict(zip(filecontent[partition_column], [random_partition() for task in filecontent[partition_column]])) + file_id2split = np.array([task2split[task] for task in filecontent[partition_column]]) + + subgraph_ids = filecontent[["id"]] + + subgraph_ids["train_mask"] = file_id2split == "train" + subgraph_ids["val_mask"] = file_id2split == "val" + subgraph_ids["test_mask"] = file_id2split == "test" + + persist(subgraph_ids, get_path("subgraph_partition.json")) + + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("dataset_path") + parser.add_argument("subgraph_column") + + args = parser.parse_args() + + subgraph_partitioning(args.dataset_path, args.subgraph_column) \ No newline at end of file diff --git a/SourceCodeTools/code/data/scaa/prepare_for_ast_parser.py b/SourceCodeTools/code/data/scaa/prepare_for_ast_parser.py new file mode 100644 index 00000000..fffde9f0 --- /dev/null +++ b/SourceCodeTools/code/data/scaa/prepare_for_ast_parser.py @@ -0,0 +1,342 @@ +import hashlib + +# import pandas as pd +import ast +import json +from pathlib import Path + +import pandas as pd +from nltk import RegexpTokenizer +from itertools import chain + +from tqdm import tqdm + +from SourceCodeTools.code.data.cubert_python_benchmarks.SQLTable import SQLTable + + +class CodeTokenizer: + tok = RegexpTokenizer("\w+|\s+|\W+") + + @classmethod + def tokenize(cls, code): + return cls.tok.tokenize(code) + + @classmethod + def detokenize(cls, code_tokens): + return "".join(code_tokens) + + +def test_CodeTokenizer(): + code = """ +data = 
pd.DataFrame2.from_records(json.loads(line) for line in open(filename).readlines()) + +success = 0 +errors = 0 +""" + assert CodeTokenizer.detokenize(CodeTokenizer.tokenize(code)) == code + + +test_CodeTokenizer() + + +class Info: + """ + Example + dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async`->`self` + """ + def __init__(self, info): + self.info = info + parts = info.split(" ") + self.parse_filepath(parts) + self.parse_fn_path(parts) + self.parse_state(parts) + self.parse_package(parts) + self.set_id() + + def parse_filepath(self, parts): + self.filepath = " ".join(self.info.split(".py")[0].split(" ")[1:]) + ".py" + # self.filepath = parts[1].split(".py")[0] + ".py" + + def parse_fn_path(self, parts): + self.fn_path = self.filepath + "/" + self.info.split(".py")[1].lstrip("/").lstrip(" ").split("/")[0] + + def parse_state(self, parts): + self.state = "/".join(self.info.split(".py")[1].lstrip("/").lstrip(" ").split("/")[1:]) + + def parse_package(self, parts): + self.package = self.filepath.split("/")[0] + + def set_id(self): + identifier = "\t".join([self.fn_path, self.state]) + self.id = hashlib.md5(identifier.encode('utf-8')).hexdigest() + + oidentifier = "\t".join([self.fn_path, "original"]) + + if self.state.endswith("riginal"): + assert identifier == oidentifier + self.original_id = hashlib.md5(oidentifier.encode('utf-8')).hexdigest() + + def __repr__(self): + return self.info + + +class DatasetAdapter: + replacements = { + "async": "async_", + "await": "await_" + } + + fields = { + "python_tasks": ["flines", "task" , "user", "year"], + } + # supported_partitions = ["train", "dev", "eval"] # + + benchmark_names = { + "python_tasks": "scaa_python.csv", + } + + preferred_column_order = ["id", "package", "function", "info", "label", "partition"] + + import_order = [ + "python_tasks", + ] + + def __init__(self, dataset_location): + self.dataset_location = Path(dataset_location) + + self.replacement_fns = { + "flines": self.fix_code_if_needed, + } + + self.preprocess = { + # "variable_misuse_repair": { + # "function": self.cubert_detokenize + # } + } + + self.extra_fields = { + # "info": [("package", self.get_package)], + # "provenance": [("info", self.fix_info_if_needed)] + } + + # self.db = SQLTable(self.dataset_location.joinpath("cubert_benchmarks.db")) + + # def load_original_functions(self): + # functions = self.db.query("SELECT DISTINCT original_id, function FROM functions where comment = 'original' AND dataset = 'variable_misuse'") + # self.original_functions = dict(zip(functions["original_id"], functions["function"])) + + def prepare_misuse_repair_record(self, record): + print(record) + + @staticmethod + def get_source_from_ast_range(node, function, strip=True): + lines = function.split("\n") + start_line = node.lineno + end_line = node.end_lineno + start_col = node.col_offset + end_col = node.end_col_offset + + source = "" + num_lines = end_line - start_line + 1 + if start_line == end_line: + section = lines[start_line - 1].encode("utf8")[start_col:end_col].decode( + "utf8") + source += section.strip() if strip else section + "\n" + else: + for ind, lineno in enumerate(range(start_line - 1, end_line)): + if ind == 0: + section = lines[lineno].encode("utf8")[start_col:].decode( + "utf8") + source += section.strip() if strip else section + "\n" + elif ind == num_lines - 1: + section = lines[lineno].encode("utf8")[:end_col].decode( + "utf8") + source += section.strip() if strip else section + 
"\n" + else: + section = lines[lineno] + source += section.strip() if strip else section + "\n" + + return source.rstrip() + + def get_dispatch(self, function): + root = ast.parse(function) + return self.get_source_from_ast_range(root.body[0].decorator_list[0], function) + + @staticmethod + def remove_indent(code): + lines = code.strip("\n").split("\n") + first_line_indent = lines[0][:len(lines[0]) - len(lines[0].lstrip())] + start_char = len(first_line_indent) + if start_char != 0: + clean = "\n".join(line[start_char:] if line.startswith(first_line_indent) else line for line in lines) + else: + if lines[0].lstrip().startswith("@"): + for ind, line in enumerate(lines): + stripped = line.lstrip() + if stripped.startswith("def "): + lines[ind] = stripped + break + clean = "\n".join(lines) + return clean + + @classmethod + def fix_code_if_needed(cls, code): + # f = code.lstrip() + f = cls.remove_indent(code) + try: + ast.parse(f) + except Exception as e: + tokens = CodeTokenizer.tokenize(f) + recovered_tokens = [cls.replacements[token] if token in cls.replacements else token for token in tokens] + f = CodeTokenizer.detokenize(recovered_tokens) + ast.parse(f) + return f + + @classmethod + def fix_info_if_needed(cls, info): + if not info.endswith("original"): + parts = info.split(" ") + variable_replacement = parts[-1] + + variable_replacement = CodeTokenizer.detokenize( + cls.replacements[token] if token in cls.replacements else token + for token in CodeTokenizer.tokenize(variable_replacement) + ) + + parts[-1] = variable_replacement + + info = " ".join(parts) + return info + + @staticmethod + def get_package(info): + return info.split(" ")[1].split("/")[0] + + @classmethod + def sort_columns(cls, columns): + return sorted(columns, key=cls.preferred_column_order.index) + + def process_record(self, record, preprocess_fns): + new_record = {} + for field, data in record.items(): + if preprocess_fns is not None and field in preprocess_fns: + record[field] = preprocess_fns[field](data) + + new_record[field] = data if field not in self.replacement_fns else self.replacement_fns[field](" " + data) + + for new_field, new_field_fn in self.extra_fields.get(field, []): + new_record[new_field] = new_field_fn(data) + + return new_record + + @classmethod + def stream_original_partition(cls, file, *args, **kwargs): + data = pd.read_csv(file) + for record in data.to_dict(orient="records"): + yield record + + def stream_processed_partition(self, file, preprocess_fns=None): + for record in self.stream_original_partition(file): + try: + r = self.process_record(record, preprocess_fns=preprocess_fns) + r["parsing_error"] = None + except Exception as e: + r = record + r["parsing_error"] = e.msg if hasattr(e, "msg") else e.__class__.__name__ + # except MemoryError: # there are two functions that cause this error + # continue + # except SyntaxError as e: + # continue + yield r + + # @classmethod + # def process_dataset(cls, original_data_location, output_location): + # partitions = ["train", "dev", "eval"] + # + # last_id = 0 + # column_order = None + # + # data = pd.DataFrame.from_records( + # chain( + # *(cls.stream_processed_partition(original_data_location, partition, add_partition=True) for partition in + # partitions)), + # # columns = column_order + # ) + # data["id"] = range(len(data)) + # # data.to_pickle(output_location, index=False, columns=cls.sort_columns(data.columns)) + # data.to_pickle(output_location) + + def iterate_dataset(self, dataset_name): + for record in self.stream_processed_partition( + 
self.dataset_location.joinpath(self.benchmark_names[dataset_name]), + preprocess_fns=self.preprocess.get(dataset_name, None) + ): + yield record + + def import_data(self): + + functions = [] + + # added_original = set() + + for dataset_name in self.import_order: + + parsed_successfully = 0 + parsed_with_errors = 0 + + dataset_file = open(self.dataset_location.joinpath(f"{dataset_name}.jsonl"), "w") + + for record in tqdm(self.iterate_dataset(dataset_name), desc=f"Processing {dataset_name}"): + record_for_writing = { + "id": record["id"], + "function": record["flines"], + "user": record["user"], + "task": record["task"], + "year": record["year"], + "package": record["user"], + "parsing_error": record["parsing_error"] + } + + if record["parsing_error"] is None: + parsed_successfully += 1 + dataset_file.write(f"{json.dumps(record_for_writing)}\n") + else: + parsed_with_errors += 1 + + dataset_file.close() + print(f"{dataset_name}: success {parsed_successfully} error {parsed_with_errors}") + # functions.append(record_for_writing) + # + # if len(functions) > 100000: + # self.db.add_records(pd.DataFrame.from_records(functions), "functions") + # functions.clear() + # + # if len(functions) > 0: + # self.db.add_records(pd.DataFrame.from_records(functions), "functions") + # functions.clear() + + + +def test_fix_info_if_needed(): + example_info = "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async`->`self`" + assert DatasetAdapter.fix_info_if_needed( + example_info) == "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async_`->`self`" + + example_info = "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async_call`->`self`" + assert DatasetAdapter.fix_info_if_needed( + example_info) == "dataset/ETHPy150Open EricssonResearch/calvin-base/calvin/requests/request_handler.py RequestHandler.get_index/VarMisuse@32/36 `async_call`->`self`" + + +test_fix_info_if_needed() + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser("Convert CuBERT's variable misuse detection dataset for further processing") + parser.add_argument("dataset_path", help="Path to dataset folder") + # parser.add_argument("output_path", help="Path to output file") + args = parser.parse_args() + + dataset = DatasetAdapter(args.dataset_path) + dataset.import_data() + # DatasetAdapter.process_dataset(args.dataset_path, args.output_path) \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py b/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py index 88217468..dd69c271 100644 --- a/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py +++ b/SourceCodeTools/code/data/sourcetrail/DatasetCreator2.py @@ -1,8 +1,14 @@ +import logging +import os from os.path import join -from tqdm import tqdm -from SourceCodeTools.code.data.sourcetrail.common import map_offsets -from SourceCodeTools.code.data.sourcetrail.sourcetrail_map_id_columns import map_columns +from SourceCodeTools.cli_arguments import DatasetCreatorArguments +from SourceCodeTools.code.annotator_utils import map_offsets +from SourceCodeTools.code.common import read_nodes +from SourceCodeTools.code.data.AbstractDatasetCreator import AbstractDatasetCreator +from SourceCodeTools.code.data.ast_graph.filter_type_edges import filter_type_edges_with_chunks +from 
SourceCodeTools.code.data.file_utils import filenames, unpersist_if_present, read_element_component +from SourceCodeTools.code.data.sourcetrail.sourcetrail_filter_type_edges import filter_type_edges from SourceCodeTools.code.data.sourcetrail.sourcetrail_merge_graphs import get_global_node_info, merge_global_with_local from SourceCodeTools.code.data.sourcetrail.sourcetrail_node_local2global import get_local2global from SourceCodeTools.code.data.sourcetrail.sourcetrail_node_name_merge import merge_names @@ -15,179 +21,85 @@ from SourceCodeTools.code.data.sourcetrail.sourcetrail_extract_variable_names import extract_var_names from SourceCodeTools.code.data.sourcetrail.sourcetrail_extract_node_names import extract_node_names -from SourceCodeTools.code.data.sourcetrail.file_utils import * +class DatasetCreator(AbstractDatasetCreator): + """ + Merges several environments indexed with Sourcetrail into a single graph. + """ + + merging_specification = { + "nodes.bz2": {"columns": ['id'], "output_path": "common_nodes.json", "ensure_unique_with": ['type', 'serialized_name']}, + "edges.bz2": {"columns": ['target_node_id', 'source_node_id'], "output_path": "common_edges.json"}, + "bodies.bz2": {"columns": ['id'], "output_path": "common_bodies.json", "columns_special": [("replacement_list", map_offsets)]}, + "function_variable_pairs.bz2": {"columns": ['src'], "output_path": "common_function_variable_pairs.json"}, + "call_seq.bz2": {"columns": ['src', 'dst'], "output_path": "common_call_seq.json"}, + + "nodes_with_ast.bz2": {"columns": ['id', 'mentioned_in'], "output_path": "common_nodes.json", "ensure_unique_with": ['type', 'serialized_name']}, + "edges_with_ast.bz2": {"columns": ['target_node_id', 'source_node_id', 'mentioned_in'], "output_path": "common_edges.json"}, + "offsets.bz2": {"columns": ['node_id'], "output_path": "common_offsets.json", "columns_special": [("mentioned_in", map_offsets)]}, + "filecontent_with_package.bz2": {"columns": [], "output_path": "common_filecontent.json"}, + "name_mappings.bz2": {"columns": [], "output_path": "common_name_mappings.json"}, + } + + files_for_merging = [ + "nodes.bz2", "edges.bz2", "bodies.bz2", "function_variable_pairs.bz2", "call_seq.bz2" + ] + files_for_merging_with_ast = [ + "nodes_with_ast.bz2", "edges_with_ast.bz2", "bodies.bz2", "function_variable_pairs.bz2", + "call_seq.bz2", "offsets.bz2", "filecontent_with_package.bz2", "name_mappings.bz2" + ] + + edge_priority = { + "next": -1, "prev": -1, "global_mention": -1, "global_mention_rev": -1, + "calls": 0, + "called_by": 0, + "defines": 1, + "defined_in": 1, + "inheritance": 1, + "inherited_by": 1, + "imports": 1, + "imported_by": 1, + "uses": 2, + "used_by": 2, + "uses_type": 2, + "type_used_by": 2, + "mention_scope": 10, + "mention_scope_rev": 10, + "defined_in_function": 4, + "defined_in_function_rev": 4, + "defined_in_class": 5, + "defined_in_class_rev": 5, + "defined_in_module": 6, + "defined_in_module_rev": 6 + } + + restricted_edges = {"global_mention_rev"} + restricted_in_types = { + "Op", "Constant", "#attr#", "#keyword#", + 'CtlFlow', 'JoinedStr', 'Name', 'ast_Literal', + 'subword', 'type_annotation' + } + + type_annotation_edge_types = ['annotation_for', 'returned_by'] -class DatasetCreator: def __init__( self, path, lang, bpe_tokenizer, create_subword_instances, connect_subwords, only_with_annotations, - do_extraction=False, visualize=False, track_offsets=False + do_extraction=False, visualize=False, track_offsets=False, remove_type_annotations=False, + recompute_l2g=False ): - 
self.indexed_path = path - self.lang = lang - self.bpe_tokenizer = bpe_tokenizer - self.create_subword_instances = create_subword_instances - self.connect_subwords = connect_subwords - self.only_with_annotations = only_with_annotations - self.extract = do_extraction - self.visualize = visualize - self.track_offsets = track_offsets - - paths = (os.path.join(path, dir) for dir in os.listdir(path)) - self.environments = sorted(list(filter(lambda path: os.path.isdir(path), paths)), key=lambda x: x.lower()) - - self.local2global_cache = {} + super().__init__( + path, lang, bpe_tokenizer, create_subword_instances, connect_subwords, only_with_annotations, + do_extraction, visualize, track_offsets, remove_type_annotations, recompute_l2g + ) from SourceCodeTools.code.data.sourcetrail.common import UNRESOLVED_SYMBOL self.unsolved_symbol = UNRESOLVED_SYMBOL - def merge(self, output_directory): - - if self.extract: - logging.info("Extracting...") - self.do_extraction() - - no_ast_path, with_ast_path = self.create_output_dirs(output_directory) - - if not self.only_with_annotations: - self.create_global_file("nodes.bz2", "local2global.bz2", ['id'], - join(no_ast_path, "common_nodes.bz2"), message="Merging nodes", ensure_unique_with=['type', 'serialized_name']) - self.create_global_file("edges.bz2", "local2global.bz2", ['target_node_id', 'source_node_id'], - join(no_ast_path, "common_edges.bz2"), message="Merging edges") - self.create_global_file("source_graph_bodies.bz2", "local2global.bz2", ['id'], - join(no_ast_path, "common_source_graph_bodies.bz2"), "Merging bodies", columns_special=[("replacement_list", map_offsets)]) - self.create_global_file("function_variable_pairs.bz2", "local2global.bz2", ['src'], - join(no_ast_path, "common_function_variable_pairs.bz2"), "Merging variables") - self.create_global_file("call_seq.bz2", "local2global.bz2", ['src', 'dst'], - join(no_ast_path, "common_call_seq.bz2"), "Merging call seq") - # self.create_global_file("name_groups.bz2", "local2global.bz2", [], - # join(no_ast_path, "name_groups.bz2"), "Merging name groups") - - global_nodes = self.filter_orphaned_nodes( - unpersist(join(no_ast_path, "common_nodes.bz2")), no_ast_path - ) - persist(global_nodes, join(no_ast_path, "common_nodes.bz2")) - node_names = extract_node_names( - global_nodes, min_count=2 - ) - if node_names is not None: - persist(node_names, join(no_ast_path, "node_names.bz2")) - - if self.visualize: - self.visualize_func( - unpersist(join(no_ast_path, "common_nodes.bz2")), - unpersist(join(no_ast_path, "common_edges.bz2")), - join(no_ast_path, "visualization.pdf") - ) - - self.create_global_file("nodes_with_ast.bz2", "local2global_with_ast.bz2", ['id', 'mentioned_in'], - join(with_ast_path, "common_nodes.bz2"), message="Merging nodes with ast", ensure_unique_with=['type', 'serialized_name']) - self.create_global_file("edges_with_ast.bz2", "local2global_with_ast.bz2", ['target_node_id', 'source_node_id', 'mentioned_in'], - join(with_ast_path, "common_edges.bz2"), "Merging edges with ast") - self.create_global_file("source_graph_bodies.bz2", "local2global_with_ast.bz2", ['id'], - join(with_ast_path, "common_source_graph_bodies.bz2"), "Merging bodies with ast", columns_special=[("replacement_list", map_offsets)]) - self.create_global_file("function_variable_pairs.bz2", "local2global_with_ast.bz2", ['src'], - join(with_ast_path, "common_function_variable_pairs.bz2"), "Merging variables with ast") - self.create_global_file("call_seq.bz2", "local2global_with_ast.bz2", ['src', 'dst'], - 
join(with_ast_path, "common_call_seq.bz2"), "Merging call seq with ast") - # self.create_global_file("name_groups.bz2", "local2global_with_ast.bz2", [], - # join(with_ast_path, "name_groups.bz2"), "Merging name groups") - - global_nodes = self.filter_orphaned_nodes( - unpersist(join(with_ast_path, "common_nodes.bz2")), with_ast_path - ) - persist(global_nodes, join(with_ast_path, "common_nodes.bz2")) - node_names = extract_node_names( - global_nodes, min_count=2 - ) - if node_names is not None: - persist(node_names, join(with_ast_path, "node_names.bz2")) - - if self.visualize: - self.visualize_func( - unpersist(join(with_ast_path, "common_nodes.bz2")), - unpersist(join(with_ast_path, "common_edges.bz2")), - join(with_ast_path, "visualization.pdf") - ) - - - def do_extraction(self): - global_nodes = None - global_nodes_with_ast = None - name_groups = None - - for env_path in self.environments: - logging.info(f"Found {os.path.basename(env_path)}") - - if not self.is_indexed(env_path): - logging.info("Package not indexed") - continue - - nodes, edges, source_location, occurrence, filecontent, element_component = \ - self.read_sourcetrail_files(env_path) - - if nodes is None: - logging.info("Index is empty") - continue - - edges = filter_ambiguous_edges(edges, element_component) - - nodes, edges = self.filter_unsolved_symbols(nodes, edges) - - bodies = process_bodies(nodes, edges, source_location, occurrence, filecontent, self.lang) - call_seq = extract_call_seq(nodes, edges, source_location, occurrence) - - edges = add_reverse_edges(edges) - - # if bodies is not None: - ast_nodes, ast_edges, offsets, name_group_tracker = get_ast_from_modules( - nodes, edges, source_location, occurrence, filecontent, - self.bpe_tokenizer, self.create_subword_instances, self.connect_subwords, self.lang, - track_offsets=self.track_offsets - ) - nodes_with_ast = nodes.append(ast_nodes) - edges_with_ast = edges.append(ast_edges) - if bodies is not None: - vars = extract_var_names(nodes, bodies, self.lang) - else: - vars = None - if name_groups is None: - name_groups = name_group_tracker - else: - name_groups = name_groups.append(name_group_tracker) - # else: - # nodes_with_ast = nodes - # edges_with_ast = edges - # vars = None - # offsets = None - - global_nodes = self.merge_with_global(global_nodes, nodes) - global_nodes_with_ast = self.merge_with_global(global_nodes_with_ast, nodes_with_ast) - - local2global = get_local2global( - global_nodes=global_nodes, local_nodes=nodes - ) - local2global_with_ast = get_local2global( - global_nodes=global_nodes_with_ast, local_nodes=nodes_with_ast - ) - - self.write_local(env_path, nodes, edges, bodies, call_seq, vars, - nodes_with_ast, edges_with_ast, offsets, - local2global, local2global_with_ast, name_groups) - - def get_local2global(self, path): - if path in self.local2global_cache: - return self.local2global_cache[path] - else: - local2global_df = unpersist_if_present(path) - if local2global_df is None: - return None - else: - local2global = dict(zip(local2global_df['id'], local2global_df['global_id'])) - self.local2global_cache[path] = local2global - return local2global + def _prepare_environments(self): + paths = (os.path.join(self.path, dir) for dir in os.listdir(self.path)) + self.environments = sorted(list(filter(lambda path: os.path.isdir(path), paths)), key=lambda x: x.lower()) def create_output_dirs(self, output_path): if not os.path.isdir(output_path): @@ -204,14 +116,16 @@ def create_output_dirs(self, output_path): return no_ast_path, with_ast_path - def 
is_indexed(self, path): + @staticmethod + def is_indexed(path): basename = os.path.basename(path) if os.path.isfile(os.path.join(path, f"{basename}.srctrldb")): return True else: return False - def get_csv_name(self, name, path): + @staticmethod + def get_csv_name(name, path): return os.path.join(path, filenames[name]) def filter_unsolved_symbols(self, nodes, edges): @@ -237,31 +151,11 @@ def read_sourcetrail_files(self, env_path): else: return nodes, edges, source_location, occurrence, filecontent, element_component - def write_local(self, dir, nodes, edges, bodies, call_seq, vars, - nodes_with_ast, edges_with_ast, offsets, - local2global, local2global_with_ast, name_groups): - write_nodes(nodes, dir) - write_edges(edges, dir) - if bodies is not None: - write_processed_bodies(bodies, dir) - if call_seq is not None: - persist(call_seq, join(dir, filenames['call_seq'])) - if vars is not None: - persist(vars, join(dir, filenames['function_variable_pairs'])) - persist(nodes_with_ast, join(dir, "nodes_with_ast.bz2")) - persist(edges_with_ast, join(dir, "edges_with_ast.bz2")) - if offsets is not None: - persist(offsets, join(dir, "offsets.bz2")) - if len(edges_with_ast.query("type == 'annotation_for' or type == 'returned_by'")) > 0: - with open(join(dir, "has_annotations"), "w") as has_annotations: - pass - persist(local2global, join(dir, "local2global.bz2")) - persist(local2global_with_ast, join(dir, "local2global_with_ast.bz2")) - - if name_groups is not None: - persist(name_groups, join(dir, "name_groups.bz2")) - def get_global_node_info(self, global_nodes): + """ + :param global_nodes: nodes from a global merged graph + :return: Set of existing nodes represented with (type, node_name), minimal available free id + """ if global_nodes is None: existing_nodes, next_valid_id = set(), 0 else: @@ -269,6 +163,12 @@ def get_global_node_info(self, global_nodes): return existing_nodes, next_valid_id def merge_with_global(self, global_nodes, local_nodes): + """ + Merge nodes obtained from the source code with the previously existing nodes. 
+ :param global_nodes: Nodes from a global inter-package graph + :param local_nodes: Nodes from a local file-level graph + :return: Updated version of the global inter-package graph + """ existing_nodes, next_valid_id = self.get_global_node_info(global_nodes) new_nodes = merge_global_with_local(existing_nodes, next_valid_id, local_nodes) @@ -279,80 +179,127 @@ def merge_with_global(self, global_nodes, local_nodes): return global_nodes - def merge_files(self, env_path, filename, map_filename, columns_to_map, original, columns_special=None): - input_table_path = join(env_path, filename) - local2global = self.get_local2global(join(env_path, map_filename)) - if os.path.isfile(input_table_path) and local2global is not None: - input_table = unpersist(input_table_path) - if self.only_with_annotations: - if not os.path.isfile(join(env_path, "has_annotations")): - return original - new_table = map_columns(input_table, local2global, columns_to_map, columns_special=columns_special) - if original is None: - return new_table + def do_extraction(self): + global_nodes = set() + global_nodes_with_ast = set() + + for env_path in self.environments: + logging.info(f"Found {os.path.basename(env_path)}") + + if not self.is_indexed(env_path): + logging.info("Package not indexed") + continue + + if not self.recompute_l2g: + + nodes, edges, source_location, occurrence, filecontent, element_component = \ + self.read_sourcetrail_files(env_path) + + if nodes is None: + logging.info("Index is empty") + continue + + edges = filter_ambiguous_edges(edges, element_component) + + nodes, edges = self.filter_unsolved_symbols(nodes, edges) + + bodies = process_bodies(nodes, edges, source_location, occurrence, filecontent, self.lang) + call_seq = extract_call_seq(nodes, edges, source_location, occurrence) + + edges = add_reverse_edges(edges) + + # if bodies is not None: + ast_nodes, ast_edges, offsets, name_mappings = get_ast_from_modules( + nodes, edges, source_location, occurrence, filecontent, + self.bpe_tokenizer, self.create_subword_instances, self.connect_subwords, self.lang, + track_offsets=self.track_offsets + ) + + if offsets is not None: + offsets["package"] = os.path.basename(env_path) + filecontent["package"] = os.path.basename(env_path) + + # need this check in situations when module has a single file and this file cannot be parsed + nodes_with_ast = nodes.append(ast_nodes) if ast_nodes is not None else nodes + edges_with_ast = edges.append(ast_edges) if ast_edges is not None else edges + + if bodies is not None: + vars = extract_var_names(nodes, bodies, self.lang) + else: + vars = None else: - return original.append(new_table) - else: - return original - - def create_global_file(self, local_file, local2global_file, columns, output_path, message, ensure_unique_with=None, columns_special=None): - global_table = None - for ind, env_path in tqdm( - enumerate(self.environments), desc=message, leave=True, - dynamic_ncols=True, total=len(self.environments) - ): - global_table = self.merge_files( - env_path, local_file, local2global_file, columns, global_table, columns_special=columns_special + nodes = unpersist_if_present(join(env_path, "nodes.bz2")) + nodes_with_ast = unpersist_if_present(join(env_path, "nodes_with_ast.bz2")) + + if nodes is None or nodes_with_ast is None: + continue + + edges = bodies = call_seq = vars = edges_with_ast = offsets = name_mappings = filecontent = None + + # global_nodes = self.merge_with_global(global_nodes, nodes) + # global_nodes_with_ast = 
self.merge_with_global(global_nodes_with_ast, nodes_with_ast) + + local2global = get_local2global( + global_nodes=global_nodes, local_nodes=nodes + ) + local2global_with_ast = get_local2global( + global_nodes=global_nodes_with_ast, local_nodes=nodes_with_ast ) - if ensure_unique_with is not None: - global_table = global_table.drop_duplicates(subset=ensure_unique_with) + global_nodes.update(local2global["global_id"]) + global_nodes_with_ast.update(local2global_with_ast["global_id"]) - if global_table is not None: - global_table.reset_index(drop=True, inplace=True) - assert len(global_table) == len(global_table.index.unique()) + self.write_type_annotation_flag(edges_with_ast, env_path) - persist(global_table, output_path) + self.write_local( + env_path, nodes=nodes, edges=edges, bodies=bodies, call_seq=call_seq, function_variable_pairs=vars, + nodes_with_ast=nodes_with_ast, edges_with_ast=edges_with_ast, offsets=offsets, + local2global=local2global, local2global_with_ast=local2global_with_ast, + name_mappings=name_mappings, filecontent_with_package=filecontent + ) - def filter_orphaned_nodes(self, global_nodes, output_dir): - edges = unpersist(join(output_dir, "common_edges.bz2")) - active_nodes = set(edges['source_node_id'].tolist() + edges['target_node_id'].tolist()) - global_nodes = global_nodes[ - global_nodes['id'].apply(lambda id_: id_ in active_nodes) - ] - return global_nodes + self.compact_mapping_for_l2g(global_nodes, "local2global.bz2") + self.compact_mapping_for_l2g(global_nodes_with_ast, "local2global_with_ast.bz2") + + @staticmethod + def extract_node_names(nodes_path, min_count): + logging.info("Extract node names") + return extract_node_names(read_nodes(nodes_path), min_count=min_count) + + def filter_type_edges(self, nodes_path, edges_path): + logging.info("Filter type edges") + filter_type_edges_with_chunks(nodes_path, edges_path, kwarg_fn=self.get_writing_mode) + + def merge(self, output_directory): + + if self.extract: + logging.info("Extracting...") + self.do_extraction() + + no_ast_path, with_ast_path = self.create_output_dirs(output_directory) + + if not self.only_with_annotations: + self.merge_graph_without_ast(no_ast_path) + + self.merge_graph_with_ast(with_ast_path) def visualize_func(self, nodes, edges, output_path): from SourceCodeTools.code.data.sourcetrail.sourcetrail_draw_graph import visualize visualize(nodes, edges, output_path) + if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description='Merge indexed environments into a single graph') - parser.add_argument('indexed_environments', - help='Path to environments indexed by sourcetrail') - parser.add_argument('output_directory', - help='') - parser.add_argument('--language', "-l", dest="language", default="python", - help='Path to environments indexed by sourcetrail') - parser.add_argument('--bpe_tokenizer', '-bpe', dest='bpe_tokenizer', type=str, - help='') - parser.add_argument('--create_subword_instances', action='store_true', default=False, help="") - parser.add_argument('--connect_subwords', action='store_true', default=False, - help="Takes effect only when `create_subword_instances` is False") - parser.add_argument('--only_with_annotations', action='store_true', default=False, help="") - parser.add_argument('--do_extraction', action='store_true', default=False, help="") - parser.add_argument('--visualize', action='store_true', default=False, help="") - parser.add_argument('--track_offsets', action='store_true', default=False, help="") - - - args = parser.parse_args() + + 
args = DatasetCreatorArguments().parse() + + if args.recompute_l2g: + args.do_extraction = True logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(message)s") - dataset = DatasetCreator(args.indexed_environments, args.language, - args.bpe_tokenizer, args.create_subword_instances, - args.connect_subwords, args.only_with_annotations, - args.do_extraction, args.visualize, args.track_offsets) + dataset = DatasetCreator( + args.indexed_environments, args.language, args.bpe_tokenizer, args.create_subword_instances, + args.connect_subwords, args.only_with_annotations, args.do_extraction, args.visualize, args.track_offsets, + args.remove_type_annotations, args.recompute_l2g + ) dataset.merge(args.output_directory) diff --git a/SourceCodeTools/code/data/sourcetrail/annotations_random_split.py b/SourceCodeTools/code/data/sourcetrail/annotations_random_split.py index a9fae023..60701586 100644 --- a/SourceCodeTools/code/data/sourcetrail/annotations_random_split.py +++ b/SourceCodeTools/code/data/sourcetrail/annotations_random_split.py @@ -1,4 +1,4 @@ -from SourceCodeTools.code.data.sourcetrail.Dataset import split, ensure_connectedness, ensure_valid_edges +from SourceCodeTools.code.data.dataset.Dataset import SourceGraphDataset # split, ensure_connectedness, ensure_valid_edges import pandas as pd import sys from os.path import join, dirname @@ -16,9 +16,9 @@ nodes['id'] = pd.concat([type_ann['src'], type_ann['dst']], axis=0).unique() -nodes, train = ensure_connectedness(nodes, train) +nodes, train = SourceGraphDataset.ensure_connectedness(nodes, train) -nodes, test = ensure_valid_edges(nodes, test, ignore_src=True) +nodes, test = SourceGraphDataset.ensure_valid_edges(nodes, test, ignore_src=True) train.rename({ "dst":"source_node_id", diff --git a/SourceCodeTools/code/data/sourcetrail/assess_type_pred.neighbours.py b/SourceCodeTools/code/data/sourcetrail/assess_type_pred.neighbours.py new file mode 100644 index 00000000..8a529f1e --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/assess_type_pred.neighbours.py @@ -0,0 +1,57 @@ +import pickle + +import sys + +import numpy as np + +from SourceCodeTools.code.data.file_utils import unpersist + +type_annotations_path = sys.argv[1] +nodes = sys.argv[2] +embeddings_path = sys.argv[3] + +nodes = unpersist(nodes) +type_annotations = unpersist(type_annotations_path) +embs = pickle.load(open(embeddings_path, "rb"))# [0] + +def normalize(typeann): + return typeann.strip("\"").strip("'").split("[")[0] + +node2nodetype = dict(zip(nodes["id"], nodes["type"])) + +type_annotations["dst"] = type_annotations["dst"].apply(normalize) + +type_annotations = type_annotations[ + type_annotations["src"].apply(lambda x: node2nodetype[x] == "mention") +] + +seed_nodes = type_annotations["src"] +node2type = dict(zip(type_annotations["src"], type_annotations["dst"])) + +seed_embs = [] +missing_emb = [] +for nid in seed_nodes: + if nid in embs: + seed_embs.append(embs[nid]) + else: + missing_emb.append(nid) + +type_embs = np.vstack(seed_embs) + +added = set() +same_or_not = dict() + +for nid, emb in zip(seed_nodes, type_embs): + diff = emb.reshape(1, -1) - type_embs + diff2 = np.square(diff) + dist = np.sqrt(np.sum(diff2, axis=1)) + order = np.argsort(dist) + min_pos = order[1] + closest = seed_nodes.iloc[min_pos] + if (nid, closest) not in added and (closest, nid) not in added: + added.add((nid, closest)) + added.add((closest, nid)) + same_or_not[(nid, closest)] = node2type[nid] == node2type[closest] + +print(sum(list(same_or_not.values())) / 
len(same_or_not)) + diff --git a/SourceCodeTools/code/data/sourcetrail/common.py b/SourceCodeTools/code/data/sourcetrail/common.py index 8cc28b4c..e8635e2d 100644 --- a/SourceCodeTools/code/data/sourcetrail/common.py +++ b/SourceCodeTools/code/data/sourcetrail/common.py @@ -1,31 +1,21 @@ -from SourceCodeTools.code.data.sourcetrail.file_utils import * +import hashlib + +from SourceCodeTools.code.data.file_utils import * from tqdm import tqdm DEFINITION_TYPE = 1 UNRESOLVED_SYMBOL = "unsolved_symbol" -import sqlite3 - - -class SQLTable: - def __init__(self, df, filename, table_name): - self.conn = sqlite3.connect(filename) - self.path = filename - self.table_name = table_name - - df.to_sql(self.table_name, con=self.conn, if_exists='replace', index=False, index_label=df.columns) - - def query(self, query_string): - return pd.read_sql(query_string, self.conn) - - def __del__(self): - self.conn.close() - if os.path.isfile(self.path): - os.remove(self.path) - - def get_occurrence_groups(nodes, edges, source_location, occurrence): + """ + Group nodes based on file id. Return dataset that contains node ids and their offsets in the source code. + :param nodes: dataframe with nodes + :param edges: dataframe with edges + :param source_location: dataframe with sources + :param occurrence: dataframe with with offsets + :return: Result of group by file id + """ edges = edges.rename(columns={'type': 'e_type'}) edges = edges.query("id >= 0") # filter reverse edges @@ -70,53 +60,3 @@ def sql_get_occurrences_from_range(occurrences, start, end) -> pd.DataFrame: f"select * from {occurrences.table_name} where start_line >= {start} and end_line <= {end} and occ_type != {DEFINITION_TYPE} and start_line = end_line") df = df.astype({"source_node_id": "Int32", "target_node_id": "Int32"}) return df - - -def create_node_repr(nodes): - return list(zip(nodes['serialized_name'], nodes['type'])) - - -def map_id_columns(df, column_names, mapper): - df = df.copy() - for col in column_names: - if col in df.columns: - df[col] = df[col].apply(lambda x: mapper.get(x, pd.NA)) - return df - - -def map_offsets(column, id_map): - def map_entry(entry): - return [(e[0], e[1], id_map[e[2]]) for e in entry] - return [map_entry(entry) for entry in column] - - -def merge_with_file_if_exists(df, merge_with_file): - if os.path.isfile(merge_with_file): - original_data = unpersist(merge_with_file) - data = pd.concat([original_data, df], axis=0) - else: - data = df - return data - - -def create_local_to_global_id_map(local_nodes, global_nodes): - local_nodes = local_nodes.copy() - global_nodes = global_nodes.copy() - - global_nodes['node_repr'] = create_node_repr(global_nodes) - local_nodes['node_repr'] = create_node_repr(local_nodes) - - rev_id_map = dict(zip( - global_nodes['node_repr'].tolist(), global_nodes['id'].tolist() - )) - id_map = dict(zip( - local_nodes["id"].tolist(), map( - lambda x: rev_id_map[x], local_nodes["node_repr"].tolist() - ) - )) - - return id_map - - -def custom_tqdm(iterable, total, message): - return tqdm(iterable, total=total, desc=message, leave=False, dynamic_ncols=True) \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/deprecated/DatasetCreator.py b/SourceCodeTools/code/data/sourcetrail/deprecated/DatasetCreator.py index 82a3929c..073953e8 100644 --- a/SourceCodeTools/code/data/sourcetrail/deprecated/DatasetCreator.py +++ b/SourceCodeTools/code/data/sourcetrail/deprecated/DatasetCreator.py @@ -14,7 +14,7 @@ from 
SourceCodeTools.code.data.sourcetrail.sourcetrail_extract_variable_names import extract_var_names from SourceCodeTools.code.data.sourcetrail.sourcetrail_extract_node_names import extract_node_names -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * class DatasetCreator: diff --git a/SourceCodeTools/code/data/sourcetrail/deprecated/sourcetrail_extract_type_information.py b/SourceCodeTools/code/data/sourcetrail/deprecated/sourcetrail_extract_type_information.py deleted file mode 100644 index 590f4be9..00000000 --- a/SourceCodeTools/code/data/sourcetrail/deprecated/sourcetrail_extract_type_information.py +++ /dev/null @@ -1,18 +0,0 @@ -import sys - -from SourceCodeTools.code.data.sourcetrail.file_utils import * - - -def main(): - edges = unpersist(sys.argv[1]) - out_annotations = sys.argv[2] - out_no_annotations = sys.argv[3] - - annotations = edges.query(f"type == 'annotation_for' or type == 'returned_by'") - no_annotations = edges.query(f"type != 'annotation_for' and type != 'returned_by'") - - persist(annotations, out_annotations) - persist(no_annotations, out_no_annotations) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/deprecated/verify_edges.py b/SourceCodeTools/code/data/sourcetrail/deprecated/verify_edges.py index 9523a025..c8e87b17 100644 --- a/SourceCodeTools/code/data/sourcetrail/deprecated/verify_edges.py +++ b/SourceCodeTools/code/data/sourcetrail/deprecated/verify_edges.py @@ -1,4 +1,4 @@ -from SourceCodeTools.code.data.sourcetrail.Dataset import load_data +from SourceCodeTools.code.data.dataset.Dataset import load_data import sys node_path = sys.argv[1] diff --git a/SourceCodeTools/code/data/sourcetrail/extra_reverse_finder.py b/SourceCodeTools/code/data/sourcetrail/extra_reverse_finder.py index a60fa359..297994dd 100644 --- a/SourceCodeTools/code/data/sourcetrail/extra_reverse_finder.py +++ b/SourceCodeTools/code/data/sourcetrail/extra_reverse_finder.py @@ -1,7 +1,7 @@ import os import sys -from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist +from SourceCodeTools.code.data.file_utils import unpersist def main(): diff --git a/SourceCodeTools/code/data/sourcetrail/get_embeddings_with_string.py b/SourceCodeTools/code/data/sourcetrail/get_embeddings_with_string.py new file mode 100644 index 00000000..9eb532eb --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/get_embeddings_with_string.py @@ -0,0 +1,51 @@ +import pickle +import sys + +import pandas as pd + +from SourceCodeTools.code.data.file_utils import unpersist + +# get tb embedding where labels are human-readable strings + + +def main(): + nodes_with_strings = sys.argv[1] + embeddings = sys.argv[2] + out_name_for_tb = sys.argv[3] + + all_embs = pickle.load(open(embeddings, "rb"))[0] + nodes = unpersist(nodes_with_strings) + + embs = [] + strings = [] + + tb_meta_path = out_name_for_tb + "meta.tsv" + tb_embs_path = out_name_for_tb + "embs.tsv" + + nodes.dropna(inplace=True) + nodes = nodes.sample(n=5000) + + with open(tb_meta_path, "w") as tb_meta: + with open(tb_embs_path, "w") as tb_embs: + tb_meta.write("string\ttype\n") + for id, string, type in zip(nodes["id"], nodes["string"], nodes["type"]): + if string is not None and not pd.isna(string) and "srctrl" not in string and "\n" not in string: + # embs.append( + # all_embs[id] + # ) + # strings.append(string) + sep = "\t" + + try: + string = f"{string}\t{type}\n" #f"{string.encode('utf-8')}\n" + emb_string = 
f"{sep.join(str(e) for e in all_embs[int(id)])}\n" + tb_meta.write(string) + tb_embs.write(emb_string) + except: + pass + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/k_hop_graph.py b/SourceCodeTools/code/data/sourcetrail/k_hop_graph.py new file mode 100644 index 00000000..b821aec4 --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/k_hop_graph.py @@ -0,0 +1,66 @@ +from os.path import join + +import networkx as nx +from tqdm import tqdm + +from SourceCodeTools.code.data.dataset.Dataset import load_data + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("working_directory") + parser.add_argument("k_hops", type=int) + parser.add_argument("output") + + args = parser.parse_args() + + nodes, edges = load_data( + join(args.working_directory, "common_nodes.bz2"), join(args.working_directory, "common_edges.bz2") + ) + + edge_types = {} + edge_lists = {} + for s, d, t in edges[["src", "dst", "type"]].values: + edge_types[(s,d)] = t + if s not in edge_lists: + edge_lists[s] = [] + edge_lists[s].append(d) + + g = nx.from_pandas_edgelist( + edges, source="src", target="dst", create_using=nx.DiGraph, edge_attr="type" + ) + + # def expand_edges(node_id, view, edge_prefix, level=0): + # edges = [] + # if level <= args.k_hops: + # if edge_prefix != "": + # edge_prefix += "|" + # for e in view: + # edges.append((node_id, e, edge_prefix + view[e]["type"])) + # edges.extend(expand_edges(node_id, g[e], edge_prefix + view[e]["type"], level=level+1)) + # return edges + # + # edges = [] + # for node in tqdm(g.nodes): + # edges.extend(expand_edges(node, g[node], "", level=0)) + + def expand_edges(node_id, s, dlist, edge_prefix, level=0): + edges = [] + if level <= args.k_hops: + if edge_prefix != "": + edge_prefix += "|" + for d in dlist: + etype = edge_prefix + edge_types[(s,d)] + edges.append((node_id, d, etype)) + edges.extend(expand_edges(node_id, d, edge_lists[d], etype, level=level+1)) + return edges + + edges = [] + for node in tqdm(edge_lists): + edges.extend(expand_edges(node, node, edge_lists[node], "", level=0)) + + print() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/nodes_of_interest_from_dataset.py b/SourceCodeTools/code/data/sourcetrail/nodes_of_interest_from_dataset.py new file mode 100644 index 00000000..d9e69bcc --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/nodes_of_interest_from_dataset.py @@ -0,0 +1,29 @@ +import json + + +def get_node_ids_from_dataset(dataset_path): + node_ids = [] + with open(dataset_path, "r") as dataset: + for line in dataset: + entry = json.loads(line) + for _, _, id_ in entry["replacements"]: + node_ids.append(int(id_)) + return node_ids + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("dataset") + parser.add_argument("output") + args = parser.parse_args() + + node_ids = get_node_ids_from_dataset(args.dataset) + with open(args.output, "w") as sink: + sink.write("node_id\n") + for id_ in node_ids: + sink.write(f"{id_}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/pandas_format_converter.py b/SourceCodeTools/code/data/sourcetrail/pandas_format_converter.py index ecf6fb01..450e1f7c 100644 --- a/SourceCodeTools/code/data/sourcetrail/pandas_format_converter.py +++ b/SourceCodeTools/code/data/sourcetrail/pandas_format_converter.py @@ -1,5 +1,5 @@ 
import sys -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * input_path = sys.argv[1] target_format = sys.argv[2] diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_add_reverse_edges.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_add_reverse_edges.py index a27850c7..42d9cc5b 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_add_reverse_edges.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_add_reverse_edges.py @@ -1,4 +1,4 @@ -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * import pandas as p import sys diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_analyze_tree_depth.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_analyze_tree_depth.py new file mode 100644 index 00000000..cb1f96c6 --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_analyze_tree_depth.py @@ -0,0 +1,46 @@ +import ast +import os +from typing import Iterable + +from SourceCodeTools.code.data.file_utils import unpersist +import numpy as np + +class DepthEstimator: + def __init__(self): + self.depth = 0 + + def go(self, node, depth=0): + depth += 1 + if depth > self.depth: + self.depth = depth + if isinstance(node, Iterable) and not isinstance(node, str): + for subnode in node: + self.go(subnode, depth=depth) + if hasattr(node, "_fields"): + for field in node._fields: + self.go(getattr(node, field), depth=depth) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("bodies") + args = parser.parse_args() + + bodies = unpersist(args.bodies) + + depths = [] + + for ind, row in bodies.iterrows(): + body = row.body + body_ast = ast.parse(body.strip()) + de = DepthEstimator() + de.go(body_ast) + depths.append(de.depth) + + print(f"Average depth: {sum(depths)/len(depths)}") + depths = np.array(depths, dtype=np.int32) + np.savetxt(os.path.join(os.path.dirname(args.bodies), "bodies_depths.txt"), depths, "%d") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py index 14cf586d..e98dac59 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges.py @@ -4,13 +4,11 @@ from nltk import RegexpTokenizer -from SourceCodeTools.code.python_ast import AstGraphGenerator -from SourceCodeTools.code.python_ast import GNode -from SourceCodeTools.nlp.entity.annotator.annotator_utils import to_offsets, overlap, resolve_self_collision -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.ast.python_ast import AstGraphGenerator, GNode, PythonSharedNodes +from SourceCodeTools.code.annotator_utils import to_offsets, overlap, resolve_self_collision +from SourceCodeTools.code.data.file_utils import * from SourceCodeTools.nlp.embed.bpe import load_bpe_model, make_tokenizer -from SourceCodeTools.code.data.sourcetrail.common import custom_tqdm -from SourceCodeTools.code.python_ast import PythonSharedNodes +from SourceCodeTools.code.common import custom_tqdm pd.options.mode.chained_assignment = None @@ -219,7 +217,7 @@ def get_ast_nodes(edges): return nodes -from SourceCodeTools.nlp.entity.annotator.annotator_utils import adjust_offsets +from SourceCodeTools.code.annotator_utils import adjust_offsets def format_replacement_offsets(offsets): 
@@ -423,6 +421,9 @@ def make_reverse_edge(edge): rev_edge['type'] = edge['type'] + "_rev" rev_edge['src'] = edge['dst'] rev_edge['dst'] = edge['src'] + rev_edge['offsets'] = None + if "scope" in edge: + rev_edge["scope"] = edge["scope"] return rev_edge diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py index d852163b..8ee8d578 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_ast_edges2.py @@ -1,19 +1,20 @@ import argparse import re from copy import copy -from time import time_ns import networkx as nx +from SourceCodeTools.code.IdentifierPool import IntIdentifierPool +from SourceCodeTools.code.ast import has_valid_syntax +from SourceCodeTools.code.common import custom_tqdm from SourceCodeTools.code.data.sourcetrail.common import * -from SourceCodeTools.code.data.sourcetrail.sourcetrail_add_reverse_edges import add_reverse_edges from SourceCodeTools.code.data.sourcetrail.sourcetrail_ast_edges import NodeResolver, make_reverse_edge -from SourceCodeTools.code.python_ast import AstGraphGenerator, GNode, PythonSharedNodes -from SourceCodeTools.nlp.entity.annotator.annotator_utils import adjust_offsets2 -from SourceCodeTools.nlp.entity.annotator.annotator_utils import overlap as range_overlap -from SourceCodeTools.nlp.entity.annotator.annotator_utils import to_offsets, get_cum_lens +from SourceCodeTools.code.ast.python_ast2 import AstGraphGenerator, GNode, PythonSharedNodes +# from SourceCodeTools.code.python_ast_cf import AstGraphGenerator +from SourceCodeTools.code.annotator_utils import adjust_offsets2 +from SourceCodeTools.code.annotator_utils import overlap as range_overlap +from SourceCodeTools.code.annotator_utils import to_offsets, get_cum_lens from SourceCodeTools.nlp.string_tools import get_byte_to_char_map -from SourceCodeTools.code.data.sourcetrail.sourcetrail_parse_bodies2 import has_valid_syntax class MentionTokenizer: @@ -27,13 +28,18 @@ def __init__(self, bpe_tokenizer_path, create_subword_instances, connect_subword self.connect_subwords = connect_subwords def replace_mentions_with_subwords(self, edges): + """ + Process edges and tokenize certain node types + :param edges: List of edges + :return: List of edges, including new edges for subword tokenization + """ if self.create_subword_instances: - def produce_subw_edges(subwords, dst): - return self.produce_subword_edges_with_instances(subwords, dst) + def produce_subw_edges(subwords, dst, scope=None): + return self.produce_subword_edges_with_instances(subwords, dst, self.connect_subwords, scope=scope) else: - def produce_subw_edges(subwords, dst): - return self.produce_subword_edges(subwords, dst, self.connect_subwords) + def produce_subw_edges(subwords, dst, scope=None): + return self.produce_subword_edges(subwords, dst, self.connect_subwords, scope=scope) new_edges = [] for edge in edges: @@ -56,13 +62,13 @@ def produce_subw_edges(subwords, dst): else: subwords = self.bpe(edge['src'].name) - new_edges.extend(produce_subw_edges(subwords, dst)) + new_edges.extend(produce_subw_edges(subwords, dst, edge["scope"] if "scope" in edge else None)) else: new_edges.append(edge) elif self.bpe is not None and edge["type"] == "__global_name": subwords = self.bpe(edge['src'].name) - new_edges.extend(produce_subw_edges(subwords, edge['dst'])) + new_edges.extend(produce_subw_edges(subwords, edge['dst'], edge["scope"] if "scope" in edge else None)) elif self.bpe is not 
None and edge['src'].type in PythonSharedNodes.tokenizable_types_and_annotations: new_edges.append(edge) if edge['type'] != "global_mention_rev": @@ -70,7 +76,7 @@ def produce_subw_edges(subwords, dst): dst = edge['src'] subwords = self.bpe(dst.name) - new_edges.extend(produce_subw_edges(subwords, dst)) + new_edges.extend(produce_subw_edges(subwords, dst, edge["scope"] if "scope" in edge else None)) # elif self.bpe is not None and edge['dst'].type in {"Global"} and edge['src'].type != "Constant": # # this brach is disabled because it does not seem to make sense # # Globals can be referred by Name nodes, but they are already processed in the branch above @@ -79,7 +85,7 @@ def produce_subw_edges(subwords, dst): # # dst = edge['src'] # subwords = self.bpe(dst.name) - # new_edges.extend(produce_subw_edges(subwords, dst)) + # new_edges.extend(produce_subw_edges(subwords, dst, edge["scope"] if "scope" in edge else None)) else: new_edges.append(edge) @@ -105,7 +111,7 @@ def produce_subw_edges(subwords, dst): # return global_edges @staticmethod - def connect_prev_next_subwords(edges, current, prev_subw, next_subw): + def connect_prev_next_subwords(edges, current, prev_subw, next_subw, scope=None): if next_subw is not None: edges.append({ 'src': current, @@ -113,6 +119,8 @@ def connect_prev_next_subwords(edges, current, prev_subw, next_subw): 'type': 'next_subword', 'offsets': None }) + if scope is not None: + edges[-1]["scope"] = scope if prev_subw is not None: edges.append({ 'src': current, @@ -120,8 +128,10 @@ def connect_prev_next_subwords(edges, current, prev_subw, next_subw): 'type': 'prev_subword', 'offsets': None }) + if scope is not None: + edges[-1]["scope"] = scope - def produce_subword_edges(self, subwords, dst, connect_subwords=False): + def produce_subword_edges(self, subwords, dst, connect_subwords=False, scope=None): new_edges = [] subwords = list(map(lambda x: GNode(name=x, type="subword"), subwords)) @@ -132,12 +142,14 @@ def produce_subword_edges(self, subwords, dst, connect_subwords=False): 'type': 'subword', 'offsets': None }) + if scope is not None: + new_edges[-1]["scope"] = scope if connect_subwords: self.connect_prev_next_subwords(new_edges, subword, subwords[ind - 1] if ind > 0 else None, - subwords[ind + 1] if ind < len(subwords) - 1 else None) + subwords[ind + 1] if ind < len(subwords) - 1 else None, scope=scope) return new_edges - def produce_subword_edges_with_instances(self, subwords, dst, connect_subwords=True): + def produce_subword_edges_with_instances(self, subwords, dst, connect_subwords=True, scope=None): new_edges = [] subwords = list(map(lambda x: GNode(name=x, type="subword"), subwords)) @@ -150,15 +162,19 @@ def produce_subword_edges_with_instances(self, subwords, dst, connect_subwords=T 'type': 'subword_instance', 'offsets': None }) + if scope is not None: + new_edges[-1]["scope"] = scope new_edges.append({ 'src': subword_instance, 'dst': dst, 'type': 'subword', 'offsets': None }) + if scope is not None: + new_edges[-1]["scope"] = scope if connect_subwords: - self.connect_prev_next_subwords(new_edges, subword, subwords[ind - 1] if ind > 0 else None, - subwords[ind + 1] if ind < len(subwords) - 1 else None) + self.connect_prev_next_subwords(new_edges, subword_instance, instances[ind - 1] if ind > 0 else None, + instances[ind + 1] if ind < len(instances) - 1 else None, scope=scope) return new_edges @@ -166,7 +182,7 @@ class GlobalNodeMatcher: def __init__(self, nodes, edges): from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types # 
filter function, classes, methods, and modules - self.allowed_node_types = set([node_types[tt] for tt in [8, 128, 4096, 8192]]) + self.allowed_node_types = set(node_types.values()) # set([node_types[tt] for tt in [8, 128, 2048, 4096, 8192]]) self.allowed_edge_types = {"defined_in"} #{"defines_rev"} self.allowed_ast_node_types = {"FunctionDef", "ClassDef", "Module", "mention"} @@ -181,6 +197,8 @@ def __init__(self, nodes, edges): self.global_graph = nx.DiGraph() self.global_graph.add_edges_from(self.get_edges_for_graph(self.edges)) + self.name_mapping = dict() + @staticmethod def get_edges_for_graph(edges): try: @@ -201,12 +219,21 @@ def get_node_types(nodes): return id2type def match_with_global_nodes(self, nodes, edges): + """ + Match AST nodes that represent functions, classes, and modules with nodes from the global graph. + This information is not stored explicitly, so we need to walk the graph to resolve it. + :param nodes: + :param edges: + :return: Mapping from AST nodes to corresponding global nodes + """ nodes = pd.DataFrame([node for node in nodes if node["type"] in self.allowed_ast_node_types]) edges = pd.DataFrame([edge for edge in edges if edge["type"] in self.allowed_ast_edge_types]) if len(edges) == 0: return {} + local_names = dict(zip(nodes["id"], nodes["serialized_name"])) + func_nodes = nodes.query("type == 'FunctionDef'")["id"].values class_nodes = nodes.query("type == 'ClassDef'")["id"].values module_nodes = nodes.query("type == 'Module'")["id"].values @@ -216,8 +243,8 @@ def match_with_global_nodes(self, nodes, edges): def find_global_id(graph, def_id, motif): """ - Do function or class global id lookup - :param graph: + Perform function or class global id lookup by following a motif of edge types through the graph. + :param graph: nx graph :param def_id: :param motif: list with edge types, return None if path does not exist :return: @@ -225,6 +252,8 @@ def find_global_id(graph, def_id, motif): name_node = def_id motif = copy(motif) while len(motif) > 0: + if name_node not in graph: + return None link_type = motif.pop(0) for node, eprop in graph[name_node].items(): if eprop["type"] == link_type: @@ -297,7 +326,9 @@ def get_global_module_id2(func_global, class_global): module_global = get_global_module_id(module_nodes, module_candidates, func_global, class_global) new_node_ids.update(module_global) - return new_node_ids + name_mapping = {local_names[key]: self.global_names[val] for key, val in new_node_ids.items()} + + return new_node_ids, name_mapping @staticmethod def merge_global_references(global_references, module_global_references): @@ -309,36 +340,47 @@ def merge_global_references(global_references, module_global_references): class ReplacementNodeResolver(NodeResolver): - def __init__(self, nodes): - - self.nodeid2name = dict(zip(nodes['id'].tolist(), nodes['serialized_name'].tolist())) - self.nodeid2type = dict(zip(nodes['id'].tolist(), nodes['type'].tolist())) + def __init__(self, nodes=None): + + if nodes is not None: + self.nodeid2name = dict(zip(nodes['id'].tolist(), nodes['serialized_name'].tolist())) + self.nodeid2type = dict(zip(nodes['id'].tolist(), nodes['type'].tolist())) + self.valid_new_node = nodes['id'].max() + 1 + self.old_nodes = nodes.copy() + self.old_nodes['mentioned_in'] = pd.NA + self.old_nodes['string'] = pd.NA + self.old_nodes = self.old_nodes.astype({'mentioned_in': 'Int32'}) + self.old_nodes = self.old_nodes.astype({'string': 'string'}) + else: + self.nodeid2name = dict() + self.nodeid2type = dict() + self.valid_new_node = 0 + self.old_nodes = pd.DataFrame() - self.valid_new_node = 
nodes['id'].max() + 1 self.node_ids = {} self.new_nodes = [] self.stashed_nodes = [] - self.old_nodes = nodes.copy() - self.old_nodes['mentioned_in'] = pd.NA - self.old_nodes = self.old_nodes.astype({'mentioned_in': 'Int32'}) + self._resolver_cache = dict() def stash_new_nodes(self): + """ + Put new nodes into temporary storage. + :return: Nothing + """ self.stashed_nodes.extend(self.new_nodes) self.new_nodes = [] - def resolve_substrings(self, node, replacement2srctrl): - - decorated = "@" in node.name - assert not decorated - - name_ = copy(node.name) + def recover_original_string(self, name_, replacement2srctrl): + if name_ in self._resolver_cache: + return self._resolver_cache[name_] replacements = dict() global_node_id = [] global_name = [] global_type = [] - for name in re.finditer("srctrlrpl_[0-9]+", name_): + + for name in re.finditer("srctrlrpl_[0-9]{19}", name_): if isinstance(name, re.Match): name = name.group() elif isinstance(name, str): @@ -364,8 +406,26 @@ def resolve_substrings(self, node, replacement2srctrl): for r, v in replacements.items(): real_name = real_name.replace(r, v["name"]) + self._resolver_cache[name_] = (real_name, global_name, global_node_id, global_type) + return real_name, global_name, global_node_id, global_type + + def resolve_substrings(self, node, replacement2srctrl): + + decorated = "@" in node.name + + name_ = copy(node.name) + if decorated: + name_, decorator = name_.split("@") + # assert not decorated + + real_name, global_name, global_node_id, global_type = self.recover_original_string(name_, replacement2srctrl) + + if decorated: + real_name = real_name + "@" + decorator + return GNode( - name=real_name, type=node.type, global_name=global_name, global_id=global_node_id, global_type=global_type + name=real_name, type=node.type, global_name=global_name, global_id=global_node_id, global_type=global_type, + string=node.string ) def resolve_regular_replacement(self, node, replacement2srctrl): @@ -408,28 +468,52 @@ def resolve_regular_replacement(self, node, replacement2srctrl): type_ = node.type if hasattr(node, "scope"): new_node = GNode(name=real_name, type=type_, global_name=global_name, global_id=global_node_id, - global_type=global_type, scope=node.scope) + global_type=global_type, scope=node.scope, string=node.string) else: new_node = GNode(name=real_name, type=type_, global_name=global_name, global_id=global_node_id, - global_type=global_type) + global_type=global_type, string=node.string) else: new_node = node return new_node def resolve(self, node, replacement2srctrl): + """ + :param node: string that represents node + :param replacement2srctrl: dictionary of {sourcetrail_node: original name} + :return: resolved string with original node name + """ + # this function is defined in this class instead of SourceTrail resolver because it needs access to + # global node ids if node.type == "type_annotation": new_node = self.resolve_substrings(node, replacement2srctrl) + if "srctrlrpl_" in new_node.name: + new_node.name = "Any" else: new_node = self.resolve_regular_replacement(node, replacement2srctrl) - if "srctrlrpl_" in new_node.name: # hack to process imports + if "srctrlrpl_" in new_node.name: # hack to process imports TODO maybe need to use while to resolve everything? 
new_node = self.resolve_substrings(node, replacement2srctrl) + if "srctrlrpl_" in new_node.name: # if still failed to resolve + new_node.name = "unresolved_name" + + if node.string is not None: + string,_,_,_ = self.recover_original_string(node.string, replacement2srctrl) + node.string = string - assert "srctrlrpl_" not in new_node.name + if hasattr(node, "scope"): + node.scope = self.resolve(node.scope, replacement2srctrl) + + # assert "srctrlrpl_" not in new_node.name return new_node def resolve_node_id(self, node, **kwargs): + """ + Resolve node id from name and type, creating a new node if no such node is found. + :param node: node + :param kwargs: + :return: updated node (returns the object with the same reference as the input) + """ if not hasattr(node, "id"): node_repr = (node.name.strip(), node.type.strip()) @@ -440,34 +524,42 @@ def resolve_node_id(self, node, **kwargs): if node_repr in self.node_ids: new_id = self.get_new_node_id() self.node_ids[node_repr] = new_id - if not PythonSharedNodes.is_shared(node): + if not PythonSharedNodes.is_shared(node) and not node.name == "unresolved_name": assert "0x" in node.name - self.new_nodes.append( - { - "id": new_id, - "type": node.type, - "serialized_name": node.name, - "mentioned_in": pd.NA - } - ) + if isinstance(node.string, str) and "srctrl" in node.string: + print(node.string) + + new_node = { + "id": new_id, + "type": node.type, + "serialized_name": node.name, + "mentioned_in": pd.NA, + "string": node.string + } + + self.new_nodes.append(new_node) if hasattr(node, "scope"): self.resolve_node_id(node.scope) - self.new_nodes[-1]["mentioned_in"] = node.scope.id + new_node["mentioned_in"] = node.scope.id node.setprop("id", new_id) return node def prepare_for_write(self, from_stashed=False): nodes = pd.concat([self.old_nodes, self.new_nodes_for_write(from_stashed)])[ - ['id', 'type', 'serialized_name', 'mentioned_in'] + ['id', 'type', 'serialized_name', 'mentioned_in', 'string'] ] return nodes def new_nodes_for_write(self, from_stashed=False): - new_nodes = pd.DataFrame(self.new_nodes if not from_stashed else self.stashed_nodes)[ - ['id', 'type', 'serialized_name', 'mentioned_in'] + new_nodes = pd.DataFrame(self.new_nodes if not from_stashed else self.stashed_nodes) + if len(new_nodes) == 0: + return None + + new_nodes = new_nodes[ + ['id', 'type', 'serialized_name', 'mentioned_in', 'string'] + ].astype({"mentioned_in": "Int32"}) return new_nodes @@ -609,6 +701,11 @@ def into_offset(range): class SourcetrailResolver: + """ + Helper class to work with source code stored in a Sourcetrail database. Implements functions for: + - Iterating over files + - Preserving Sourcetrail nodes + """ def __init__(self, nodes, edges, source_location, occurrence, file_content, lang): self.nodes = nodes self.node2name = dict(zip(nodes['id'], nodes['serialized_name'])) @@ -622,6 +719,9 @@ def __init__(self, nodes, edges, source_location, occurrence, file_content, lang @property def occurrence_groups(self): + """ + :return: Iterator for occurrences grouped by file id. + """ if self._occurrence_groups is None: self._occurrence_groups = get_occurrence_groups(self.nodes, self.edges, self.source_location, self.occurrence) return self._occurrence_groups @@ -709,11 +809,18 @@ def add_names_from_edges(self, edges): def to_df(self): return pd.DataFrame(self.group2names) -def global_mention_edges_from_node(node): + +def global_mention_edges_from_node(node, scope=None): + """ + Construct a new edge that will link to a global node. 
+ :param node: + :return: + """ global_edges = [] if type(node.global_id) is int: id_type = [(node.global_id, node.global_type)] else: + # in case there are many global references id_type = zip(node.global_id, node.global_type) for gid, gtype in id_type: @@ -723,19 +830,27 @@ global_mention = { "src": gid, "dst": node.id, "type": "global_mention", "offsets": None } + if scope is not None: + global_mention["scope"] = scope global_edges.append(global_mention) global_edges.append(make_reverse_edge(global_mention)) return global_edges + def add_global_mentions(edges): + """ + :param edges: List of dictionaries that represent edges + :return: Original edges with additional edges for global references (e.g. global_mention edge + for function calls) + """ new_edges = [] for edge in edges: if edge['src'].type in {"#attr#", "Name"}: if hasattr(edge['src'], "global_id"): - new_edges.extend(global_mention_edges_from_node(edge['src'])) + new_edges.extend(global_mention_edges_from_node(edge['src'], scope=edge["scope"] if "scope" in edge else None)) elif edge['dst'].type == "mention": if hasattr(edge['dst'], "global_id"): - new_edges.extend(global_mention_edges_from_node(edge['dst'])) + new_edges.extend(global_mention_edges_from_node(edge['dst'], scope=edge["scope"] if "scope" in edge else None)) new_edges.append(edge) return new_edges @@ -777,6 +892,13 @@ def produce_nodes_without_name(global_nodes, ast_edges): def standardize_new_edges(edges, node_resolver, mention_tokenizer): + """ + Tokenize relevant node names, assign id to every node, collapse edge representation to id-based + :param edges: list of edges + :param node_resolver: helper class that tracks node ids + :param mention_tokenizer: helper class that performs tokenization of relevant nodes + :return: + """ edges = mention_tokenizer.replace_mentions_with_subwords(edges) @@ -801,22 +923,37 @@ return edges -def process_code(source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, named_group_tracker, track_offsets=False): +def process_code(source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, track_offsets=False): + """ + + :param source_file_content: String for a file from Sourcetrail database + :param offsets: List of global occurrences in this file + :param node_resolver: + :param mention_tokenizer: + :param node_matcher: + :param track_offsets: Flag that tells whether to perform offset tracking. + :return: Tuple that stores edges, list of global and ast offsets for the current file in the format + (start, end, node_id, set(offsets for functions where the current offset occurs), + mapping from AST nodes to corresponding global nodes (will be used to replace one with the other). + """ # replace global occurrences with special tokens to help further parsing with ast package replacer = OccurrenceReplacer() replacer.perform_replacements(source_file_content, offsets) # compute ast edges - ast_processor = AstProcessor(replacer.source_with_replacements) + try: + ast_processor = AstProcessor(replacer.source_with_replacements) + except: + return None, None, None, None try: # TODO recursion error does not appear consistently. The issue is probably with library versions... 
edges = ast_processor.get_edges(as_dataframe=False) except RecursionError: - return None, None, None + return None, None, None, None if len(edges) == 0: - return None, None, None + return None, None, None, None - # resolve existing node names (primarily for subwords) + # resolve sourcetrail nodes and replace with original names resolve = lambda node: node_resolver.resolve(node, replacer.replacement_index) for edge in edges: @@ -828,23 +965,32 @@ def process_code(source_file_content, offsets, node_resolver, mention_tokenizer, # insert global mentions using replacements that were created on the first step edges = add_global_mentions(edges) - named_group_tracker.add_names_from_edges(edges) + # It seems I do not need this feature + # named_group_tracker.add_names_from_edges(edges) + # tokenize names, replace nodes with their ids edges = standardize_new_edges(edges, node_resolver, mention_tokenizer) if track_offsets: def get_valid_offsets(edges): + """ + :param edges: Dictionary that represents an edge. Information is stored in edges but is related to the source node + :return: Information about location of this edge (offset) in the source file in format (start, end, node_id) + """ return [(edge["offsets"][0], edge["offsets"][1], edge["src"]) for edge in edges if edge["offsets"] is not None] + # recover ast offsets for the current file ast_offsets = replacer.recover_offsets_with_edits2(get_valid_offsets(edges)) def merge_global_and_ast_offsets(ast_offsets, global_offsets, definitions): """ - Merge local and global offsets and add information about the scope - :param ast_offsets: - :param global_offsets: - :param definitions: - :return: + Merge local and global offsets and add information about the scope. Preserve the information that + indicates a function where the occurrence takes place. + :param ast_offsets: List of offsets in format (start, end, node_id) + :param global_offsets: List of offsets in format (start, end, node_id) + :param definitions: offsets for function declarations + :return: List of offsets in format [start, end, node_id, set(function_offset)]. Each offset can belong to + several functions in the case of nested functions. """ offsets = [[*offset, set()] for offset in ast_offsets] offsets = offsets + [[*offset, set()] for offset in global_offsets] @@ -864,23 +1010,40 @@ def merge_global_and_ast_offsets(ast_offsets, global_offsets, definitions): else: global_and_ast_offsets = None - ast_nodes_to_srctrl_nodes = node_matcher.match_with_global_nodes(node_resolver.new_nodes, edges) + # Get mapping from AST nodes to global nodes + ast_nodes_to_srctrl_nodes, ast_node_names_to_global_node_names = node_matcher.match_with_global_nodes(node_resolver.new_nodes, edges) - return edges, global_and_ast_offsets, ast_nodes_to_srctrl_nodes + return edges, global_and_ast_offsets, ast_nodes_to_srctrl_nodes, ast_node_names_to_global_node_names def get_ast_from_modules( nodes, edges, source_location, occurrence, file_content, bpe_tokenizer_path, create_subword_instances, connect_subwords, lang, track_offsets=False ): - + """ + Create edges from source code and merge them with the global graph. Prepare all offsets in the uniform format. 
+ :param nodes: DataFrame with nodes + :param edges: DataFrame with edges + :param source_location: Dataframe that links nodes to source files + :param occurrence: Dataframe with records about occurrences + :param file_content: Dataframe with sources + :param bpe_tokenizer_path: path to sentencepiece model + :param create_subword_instances: + :param connect_subwords: + :param lang: + :param track_offsets: + :return: Tuple: + - Dataframe with all nodes. Schema: id, type, name, mentioned_in (global and AST) + - Dataframe with all edges. Schema: id, type, src, dst, file_id, mentioned_in + - Dataframe with all_offsets. Schema: file_id, start, end, node_id, mentioned_in + """ srctrl_resolver = SourcetrailResolver(nodes, edges, source_location, occurrence, file_content, lang) node_resolver = ReplacementNodeResolver(nodes) node_matcher = GlobalNodeMatcher(nodes, edges) # add_reverse_edges(edges)) mention_tokenizer = MentionTokenizer(bpe_tokenizer_path, create_subword_instances, connect_subwords) - name_group_tracker = NameGroupTracker() all_ast_edges = [] all_global_references = {} + all_name_mappings = {} all_offsets = [] for group_ind, (file_id, occurrences) in custom_tqdm( @@ -897,8 +1060,8 @@ def get_ast_from_modules( # process code # try: - edges, global_and_ast_offsets, ast_nodes_to_srctrl_nodes = process_code( - source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, name_group_tracker, track_offsets=track_offsets + edges, global_and_ast_offsets, ast_nodes_to_srctrl_nodes, ast_node_names_to_global_node_names = process_code( + source_file_content, offsets, node_resolver, mention_tokenizer, node_matcher, track_offsets=track_offsets ) # except SyntaxError: # logging.warning(f"Error processing file_id {file_id}") @@ -916,8 +1079,15 @@ def get_ast_from_modules( all_ast_edges.extend(edges) node_matcher.merge_global_references(all_global_references, ast_nodes_to_srctrl_nodes) + all_name_mappings.update(ast_node_names_to_global_node_names) def format_offsets(global_and_ast_offsets, target): + """ + Format offset as a record and add to the common storage for offsets + :param global_and_ast_offsets: + :param target: List where all other offsets are stored. 
+ :return: Nothing + """ if global_and_ast_offsets is not None: for offset in global_and_ast_offsets: target.append({ @@ -939,7 +1109,7 @@ def replace_ast_node_to_global(edges, mapping): if "scope" in edge: edge["scope"] = mapping.get(edge["scope"], edge["scope"]) - replace_ast_node_to_global(all_ast_edges, all_global_references) + # replace_ast_node_to_global(all_ast_edges, all_global_references) # disabled to keep global nodes separate def create_subwords_for_global_nodes(): all_ast_edges.extend( @@ -947,7 +1117,7 @@ def create_subwords_for_global_nodes(): node_resolver, mention_tokenizer)) node_resolver.stash_new_nodes() - create_subwords_for_global_nodes() + # create_subwords_for_global_nodes() # disabled to keep global nodes separate def prepare_new_nodes(node_resolver): @@ -960,11 +1130,32 @@ def prepare_new_nodes(node_resolver): }, from_stashed=True ) node_resolver.map_mentioned_in_to_global(all_global_references, from_stashed=True) + # TODO since some function definitions and modules are dropped, need to rename decorated nodes as well + # find old name in node_resolver.stashed nodes, go over all nodes and replace corresponding substrings in names node_resolver.drop_nodes(set(all_global_references.keys()), from_stashed=True) - prepare_new_nodes(node_resolver) + for offset in all_offsets: + offset["node_id"] = all_global_references.get(offset["node_id"], offset["node_id"]) + offset["mentioned_in"] = [(e[0], e[1], all_global_references.get(e[2], e[2])) for e in offset["mentioned_in"]] + + + # prepare_new_nodes(node_resolver) # disabled to keep global nodes separate all_ast_nodes = node_resolver.new_nodes_for_write(from_stashed=True) + if all_ast_nodes is None: + return None, None, None, None + + def decipher_node_name(name): + if name in all_name_mappings: + return all_name_mappings[name] + elif "@" in name: + for key, value in all_name_mappings.items(): + if key in name: + return name.replace(key, value) + return name + return name + + # all_ast_nodes["serialized_name"] = all_ast_nodes["serialized_name"].apply(decipher_node_name) def prepare_edges(all_ast_edges): all_ast_edges = pd.DataFrame(all_ast_edges) @@ -983,7 +1174,15 @@ def prepare_edges(all_ast_edges): else: all_offsets = None - return all_ast_nodes, all_ast_edges, all_offsets, name_group_tracker.to_df() + if len(all_name_mappings) > 0: + all_name_mappings = pd.DataFrame({ + "ast_name": list(all_name_mappings.keys()), + "proper_names": list(all_name_mappings.values()) + }) + else: + all_name_mappings = None + + return all_ast_nodes, all_ast_edges, all_offsets, all_name_mappings class OccurrenceReplacer: @@ -995,6 +1194,8 @@ def __init__(self): self.evicted = None self.edits = None + self._identifier_pool = IntIdentifierPool() + @staticmethod def format_offsets_for_replacements(offsets): offsets = offsets.sort_values(by=["start", "end"], ascending=[True, False]) @@ -1057,8 +1258,10 @@ def perform_replacements(self, source_file_content, offsets): offset = self.group_overlapping_offsets(offset, pending) # new_name = f"srctrlnd_{offset[2][0]}" - replacement_id = int(time_ns()) - new_name = "srctrlrpl_" + str(replacement_id) + # replacement_id = str(int(time_ns())) + replacement_id = self._identifier_pool.get_new_identifier() + assert len(replacement_id) == 19 + new_name = "srctrlrpl_" + replacement_id replacement_index[new_name] = { "srctrl_id": offset[2][0] if type(offset[2]) is not list else [o[0] for o in offset[2]], "original_string": src_str @@ -1190,7 +1393,7 @@ def recover_offsets_with_edits2(self, offsets): edges = 
read_edges(working_directory) file_content = read_filecontent(working_directory) - ast_nodes, ast_edges, offsets, ng = get_ast_from_modules(nodes, edges, source_location, occurrence, file_content, + ast_nodes, ast_edges, offsets = get_ast_from_modules(nodes, edges, source_location, occurrence, file_content, args.bpe_tokenizer, args.create_subword_instances, args.connect_subwords, args.lang) diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_call_seq_extractor.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_call_seq_extractor.py index 73f213a4..387272c2 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_call_seq_extractor.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_call_seq_extractor.py @@ -1,3 +1,4 @@ +from SourceCodeTools.code.common import custom_tqdm, SQLTable from SourceCodeTools.code.data.sourcetrail.common import * import sys @@ -31,7 +32,7 @@ def extract_call_seq(nodes, edges, source_location, occurrence): enumerate(occurrence_groups), message="Extracting call sequences", total=len(occurrence_groups) ): - sql_occurrences = SQLTable(occurrences, "/tmp/sourcetrail_occurrences.db", "occurrences") + sql_occurrences = SQLTable(occurrences, ":memory:", "occurrences") function_definitions = sql_get_function_definitions(sql_occurrences) diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_compute_function_diameter.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_compute_function_diameter.py index 7c97bf49..3fbc9d32 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_compute_function_diameter.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_compute_function_diameter.py @@ -1,5 +1,5 @@ -from SourceCodeTools.code.data.sourcetrail.file_utils import * -from SourceCodeTools.code.data.sourcetrail.common import custom_tqdm +from SourceCodeTools.code.data.file_utils import * +from SourceCodeTools.code.common import custom_tqdm import networkx as nx from collections import Counter diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_connected_component.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_connected_component.py index 84828dc8..ad0876e4 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_connected_component.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_connected_component.py @@ -1,7 +1,7 @@ import networkx as nx import sys -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * def main(): diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_decode_edge_types.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_decode_edge_types.py index 34f2ba14..51369150 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_decode_edge_types.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_decode_edge_types.py @@ -1,5 +1,5 @@ from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import edge_types -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * import sys import os diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_draw_graph.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_draw_graph.py index 527309fe..f7793e2d 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_draw_graph.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_draw_graph.py @@ -1,16 +1,14 @@ import pandas as pd -import pygraphviz as pgv def visualize(nodes, edges, output_path): from 
SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types + import pygraphviz as pgv global_types = list(node_types.values()) edges = edges[edges["type"].apply(lambda x: not x.endswith("_rev"))] - from SourceCodeTools.code.data.sourcetrail.Dataset import get_global_edges, ensure_connectedness - # def remove_ast_edges(nodes, edges): # global_edges = get_global_edges() # global_edges.add("subword") @@ -32,6 +30,9 @@ def visualize(nodes, edges, output_path): g = pgv.AGraph(strict=False, directed=True) + from SourceCodeTools.code.ast.python_ast2 import PythonNodeEdgeDefinitions + auxiliaty_edge_types = PythonNodeEdgeDefinitions.auxiliary_edges() + for ind, edge in edges.iterrows(): src = edge['source_node_id'] dst = edge['target_node_id'] @@ -39,7 +40,7 @@ def visualize(nodes, edges, output_path): dst_name = id2name[dst] g.add_node(src_name, color="blue" if id2type[src] in global_types else "black") g.add_node(dst_name, color="blue" if id2type[dst] in global_types else "black") - g.add_edge(src_name, dst_name) + g.add_edge(src_name, dst_name, color="blue" if edge['type'] in auxiliaty_edge_types else "black") g_edge = g.get_edge(src_name, dst_name) g_edge.attr['label'] = edge['type'] diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_edges_name_resolve.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_edges_name_resolve.py index 8e69b141..3e9345d0 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_edges_name_resolve.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_edges_name_resolve.py @@ -1,6 +1,5 @@ -import pandas as p -import sys, os -from SourceCodeTools.code.data.sourcetrail.file_utils import * +import sys +from SourceCodeTools.code.data.file_utils import * nodes = unpersist(sys.argv[1]) edges = unpersist(sys.argv[2]) diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py index bbf490bd..da50d1a2 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py @@ -1,6 +1,7 @@ import sys -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.ast_graph.extract_node_names import extract_node_names +from SourceCodeTools.code.data.file_utils import * def get_node_name(full_name): @@ -10,25 +11,13 @@ def get_node_name(full_name): return full_name.split(".")[-1].split("___")[0] -def extract_node_names(nodes, min_count): - - +def extract_node_names_(nodes, min_count): # some cells are empty, probably because of empty strings in AST # data = nodes.dropna(axis=0) # data = data[data['type'] != 262144] - data = nodes - data['src'] = data['id'] - data['dst'] = data['serialized_name'].apply(get_node_name) - - counts = data['dst'].value_counts() - - data['counts'] = data['dst'].apply(lambda x: counts[x]) - data = data.query(f"counts >= {min_count}") - - if len(data) > 0: - return data[['src', 'dst']] - else: - return None + nodes = nodes.copy() + nodes['serialized_name'] = nodes['serialized_name'].apply(get_node_name) + return extract_node_names(nodes, min_count) if __name__ == "__main__": diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_variable_names.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_variable_names.py index 8378163e..553e763e 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_variable_names.py +++ 
b/SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_variable_names.py @@ -3,8 +3,8 @@ import sys from collections import Counter -from SourceCodeTools.code.data.sourcetrail.common import custom_tqdm -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.common import custom_tqdm +from SourceCodeTools.code.data.file_utils import * pandas.options.mode.chained_assignment = None diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_ambiguous_edges.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_ambiguous_edges.py index 403c6fa2..16d55342 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_ambiguous_edges.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_ambiguous_edges.py @@ -1,7 +1,7 @@ #%% import sys -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * def filter_ambiguous_edges(edges, ambiguous_edges): diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_type_edges.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_type_edges.py new file mode 100644 index 00000000..6cf74270 --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_filter_type_edges.py @@ -0,0 +1,21 @@ +import sys +from os.path import join + +from SourceCodeTools.code.data.ast_graph.filter_type_edges import filter_type_edges +from SourceCodeTools.code.data.file_utils import * + + +def main(): + working_directory = sys.argv[1] + nodes = unpersist(join(working_directory, "nodes.bz2")) + edges = unpersist(join(working_directory, "edges.bz2")) + out_annotations = sys.argv[2] + out_no_annotations = sys.argv[3] + + no_annotations, annotations = filter_type_edges(nodes, edges) + + persist(annotations, out_annotations) + persist(no_annotations, out_no_annotations) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_fn_visualizer.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_fn_visualizer.py new file mode 100644 index 00000000..f44ca3c5 --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_fn_visualizer.py @@ -0,0 +1,35 @@ +from os.path import join + +from SourceCodeTools.code.data.dataset.Dataset import load_data +from SourceCodeTools.code.data.sourcetrail.sourcetrail_draw_graph import visualize + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("working_directory") + parser.add_argument("function_id", type=int) + parser.add_argument("output_path") + + args = parser.parse_args() + + nodes_path = join(args.working_directory, "common_nodes.bz2") + edges_path = join(args.working_directory, "common_edges.bz2") + + nodes, edges = load_data(nodes_path, edges_path, rename_columns=False) + + fn_edges = edges.query(f"mentioned_in == {args.function_id}") + fn_node_ids = set(fn_edges["source_node_id"] + fn_edges["target_node_id"]) + registered_edges = set(fn_edges["id"]) + extra_edges = edges[ + (edges["source_node_id"].apply(lambda id_: id_ in fn_node_ids) | + edges["target_node_id"].apply(lambda id_: id_ in fn_node_ids) ) #& + # edges["id"].apply(lambda id_: id_ not in registered_edges) + ] + fn_nodes = nodes[ + nodes["id"].apply(lambda id_: id_ in fn_node_ids) + ] + visualize(nodes, fn_edges, args.output_path) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_get_module_sizes.py 
b/SourceCodeTools/code/data/sourcetrail/sourcetrail_get_module_sizes.py index 6eac0345..c9da6c51 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_get_module_sizes.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_get_module_sizes.py @@ -2,9 +2,7 @@ from collections import Counter from pprint import pprint -from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist -from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types -import pandas as pd +from SourceCodeTools.code.data.file_utils import unpersist # def estimate_module_sizes(nodes_path): diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_graph_complexity_analysis.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_graph_complexity_analysis.py new file mode 100644 index 00000000..0813b96d --- /dev/null +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_graph_complexity_analysis.py @@ -0,0 +1,230 @@ +import logging +from collections import defaultdict +from os.path import join, dirname + +import dgl +import pandas as pd +import argparse +import numpy as np + +import matplotlib.pyplot as plt +from tqdm import tqdm + +from SourceCodeTools.code.ast import has_valid_syntax +from SourceCodeTools.code.data.file_utils import unpersist +from SourceCodeTools.code.data.sourcetrail.sourcetrail_ast_edges2 import AstProcessor, standardize_new_edges, \ + ReplacementNodeResolver, MentionTokenizer +from SourceCodeTools.code.data.sourcetrail.sourcetrail_compute_function_diameter import compute_diameter +# from SourceCodeTools.code.data.type_annotation_dataset.create_type_annotation_dataset import get_docstring, \ +# remove_offsets +from SourceCodeTools.nlp import create_tokenizer + + +import tokenize +from io import StringIO +def remove_comments_and_docstrings(source): + """ + Returns 'source' minus comments and docstrings. + https://stackoverflow.com/questions/1769332/script-to-remove-python-comments-docstrings + """ + io_obj = StringIO(source) + out = "" + prev_toktype = tokenize.INDENT + last_lineno = -1 + last_col = 0 + for tok in tokenize.generate_tokens(io_obj.readline): + token_type = tok[0] + token_string = tok[1] + start_line, start_col = tok[2] + end_line, end_col = tok[3] + ltext = tok[4] + # The following two conditionals preserve indentation. + # This is necessary because we're not using tokenize.untokenize() + # (because it spits out code with copious amounts of oddly-placed + # whitespace). + if start_line > last_lineno: + last_col = 0 + if start_col > last_col: + out += (" " * (start_col - last_col)) + # Remove comments: + if token_type == tokenize.COMMENT: + pass + # This series of conditionals removes docstrings: + elif token_type == tokenize.STRING: + if prev_toktype != tokenize.INDENT: + # This is likely a docstring; double-check we're not inside an operator: + if prev_toktype != tokenize.NEWLINE: + if start_col > 0: + # Unlabelled indentation means we're inside an operator + out += token_string + else: + out += token_string + prev_toktype = token_type + last_col = end_col + last_lineno = end_line + return out + + +def process_code(source_file_content, node_resolver, mention_tokenizer): + try: + ast_processor = AstProcessor(source_file_content) + except: + logging.warning("Unknown exception") + return None + try: # TODO recursion error does not appear consistently. The issue is probably with library versions... 
+ edges = ast_processor.get_edges(as_dataframe=False) + except RecursionError: + return None + + if len(edges) == 0: + return None + + edges = standardize_new_edges(edges, node_resolver, mention_tokenizer) + + return edges + + +def compute_gnn_passings(body, mention_tokenizer, num_layers=None): + + node_resolver = ReplacementNodeResolver() + + source_file_content = body.lstrip() + + edges = process_code( + source_file_content, node_resolver, mention_tokenizer + ) + + if edges is None: + return None + + edges = pd.DataFrame(edges).rename({"src": "source_node_id", "dst": "target_node_id"}, axis=1) + diameter = compute_diameter(edges, func_id=0) + + G = dgl.DGLGraph() + G.add_edges(edges["source_node_id"], edges["target_node_id"]) + + def compute_for_n_layers(n_layers): + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(n_layers) + + for node in G.nodes(): + non_zero_blocks = 0 + num_edges = 0 + loader = dgl.dataloading.NodeDataLoader( + G, [node], sampler, batch_size=1, shuffle=True, num_workers=0) + for input_nodes, seeds, blocks in loader: + num_edges = 0 + for block in blocks: + if block.num_edges() > 0: + non_zero_blocks += 1 + num_edges += block.num_edges() + if num_edges != 0 and non_zero_blocks >= n_layers: + break + return num_edges + + passings_num_layers = compute_for_n_layers(num_layers) + passings_diameter = compute_for_n_layers(diameter) + + return len(edges), passings_num_layers, passings_diameter + + +def compute_transformer_passings(body, bpe): + num_tokens = len(bpe(body)) + return num_tokens + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("bodies") + parser.add_argument("bpe_path") + parser.add_argument("--num_layers", default=8, type=int) + + args = parser.parse_args() + + bodies = unpersist(args.bodies) + bpe = create_tokenizer(type="bpe", bpe_path=args.bpe_path) + mention_tokenizer = MentionTokenizer(args.bpe_path, create_subword_instances=False, connect_subwords=False) + + lengths_tr = defaultdict(list) + lengths_gnn_layers = defaultdict(list) + lengths_gnn_diameter = defaultdict(list) + ratio_layers = [] + ratio_diameter = [] + + for body in tqdm(bodies["body"]): + if not has_valid_syntax(body): + continue + + body_ = body + body = body_.lstrip() + initial_strip = body[:len(body_) - len(body)] + + # docsting_offsets = get_docstring(body) + # body, replacements, docstrings = remove_offsets(body, [], docsting_offsets) + body = remove_comments_and_docstrings(body) + + n_tokens = compute_transformer_passings(body, bpe) + result = compute_gnn_passings(body, mention_tokenizer, args.num_layers) + if result is None: + continue + n_edges, n_passings, n_passings_diam = result + + lengths_tr[n_tokens].append(n_tokens ** 2 * args.num_layers) + lengths_gnn_layers[n_tokens].append(n_passings)# * args.num_layers) + lengths_gnn_diameter[n_tokens].append(n_passings_diam) # * args.num_layers) + ratio_layers.append((n_tokens, n_passings)) + ratio_diameter.append((n_tokens, n_passings_diam)) + + for key in lengths_tr: + data_tr = np.array(lengths_tr[key]) + data_gnn_layers = np.array(lengths_gnn_layers[key]) + data_gnn_diameter = np.array(lengths_gnn_diameter[key]) + + lengths_tr[key] = np.mean(data_tr)#, np.std(data_tr)) + lengths_gnn_layers[key] = np.mean(data_gnn_layers)#, np.std(data_gnn)) + lengths_gnn_diameter[key] = np.mean(data_gnn_diameter) + + data_ratios_layers = np.array(ratio_layers) + data_ratios_diameter = np.array(ratio_diameter) + + plt.plot(data_ratios_layers[:, 0], data_ratios_layers[:, 1], "*") + plt.plot(data_ratios_diameter[:, 0], 
data_ratios_diameter[:, 1], "*") + plt.xlabel("Number of Tokens") + plt.ylabel("Number of Message Exchanges") + plt.legend([f"Layers = {args.num_layers}", "Layers = Graph Diameter"]) + plt.savefig(join(dirname(args.bodies), "tokens_edges.png")) + plt.close() + + plt.hist(data_ratios_layers[:, 1] / data_ratios_layers[:, 0], bins=20) + plt.hist(data_ratios_diameter[:, 1] / data_ratios_diameter[:, 0], bins=20) + plt.xlabel("Number of edges / Number of tokens") + plt.legend([f"Layers = {args.num_layers}", "Layers = Graph Diameter"]) + plt.savefig(join(dirname(args.bodies), "ratio.png")) + plt.close() + + ratio_layers = data_ratios_layers[:, 1] / data_ratios_layers[:, 0] + ratio_layers = (np.mean(ratio_layers), np.std(ratio_layers)) + + ratio_diameter = data_ratios_diameter[:, 1] / data_ratios_diameter[:, 0] + ratio_diameter = (np.mean(ratio_diameter), np.std(ratio_diameter)) + + plt.plot(list(lengths_tr.keys()), np.array(list(lengths_tr.values())), "*") + plt.plot(list(lengths_gnn_layers.keys()), np.array(list(lengths_gnn_layers.values())), "*") + plt.plot(list(lengths_gnn_diameter.keys()), np.array(list(lengths_gnn_diameter.values())), "*") + plt.gca().set_yscale('log') + plt.grid() + plt.legend([ + f"Transformer Layers = {args.num_layers}", + f"GNN Layers = {args.num_layers}", + f"GNN Layers = Graph Diameter"] + ) + plt.xlabel("Number of Tokens") + plt.ylabel("Number of Message Exchanges") + plt.savefig(join(dirname(args.bodies), "avg_passings.png")) + plt.close() + + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns.py index 6820849a..bbbaa38e 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns.py @@ -1,24 +1,8 @@ import sys -from typing import Iterable -from SourceCodeTools.code.data.sourcetrail.common import \ - map_id_columns, merge_with_file_if_exists, create_local_to_global_id_map -from SourceCodeTools.code.data.sourcetrail.file_utils import * - - -def map_columns(input_table, id_map, columns, columns_special=None): - - input_table = map_id_columns(input_table, columns, id_map) - - if columns_special is not None: - assert isinstance(columns_special, list), "`columns_special` should be iterable" - for column, map_func in columns_special: - input_table[column] = map_func(input_table[column], id_map) - - if len(input_table) == 0: - return None - else: - return input_table +from SourceCodeTools.code.common import map_columns, merge_with_file_if_exists +from SourceCodeTools.code.data.ast_graph.local2global import create_local_to_global_id_map +from SourceCodeTools.code.data.file_utils import * if __name__ == "__main__": diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns_only_annotations.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns_only_annotations.py index d022b1da..047ccbe6 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns_only_annotations.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_map_id_columns_only_annotations.py @@ -1,7 +1,8 @@ import sys -from SourceCodeTools.code.data.sourcetrail.common import map_id_columns, merge_with_file_if_exists, create_local_to_global_id_map -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.common import map_id_columns, merge_with_file_if_exists +from 
SourceCodeTools.code.data.ast_graph.local2global import create_local_to_global_id_map +from SourceCodeTools.code.data.file_utils import * def map_columns_with_annotations(input_table, id_map, columns): diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py index 4d68195d..471b2a56 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_merge_graphs.py @@ -1,7 +1,7 @@ import sys -from SourceCodeTools.code.data.sourcetrail.file_utils import * -from SourceCodeTools.code.data.sourcetrail.common import merge_with_file_if_exists +from SourceCodeTools.code.common import merge_with_file_if_exists +from SourceCodeTools.code.data.file_utils import * pd.options.mode.chained_assignment = None @@ -43,7 +43,7 @@ def merge_global_with_local(existing_nodes, next_valid_id, local_nodes): assert len(new_nodes) == len(set(new_nodes['node_repr'].to_list())) - ids_start = next_valid_id + ids_start = int(next_valid_id) ids_end = ids_start + len(new_nodes) new_nodes['id'] = list(range(ids_start, ids_end)) diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_local2global.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_local2global.py index 27280010..93a28b91 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_local2global.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_local2global.py @@ -1,14 +1,14 @@ import sys -from SourceCodeTools.code.data.sourcetrail.file_utils import * -from SourceCodeTools.code.data.sourcetrail.common import create_local_to_global_id_map +from SourceCodeTools.code.data.ast_graph.local2global import create_local_to_global_id_map +from SourceCodeTools.code.data.file_utils import * def get_local2global(global_nodes, local_nodes) -> pd.DataFrame: local_nodes = local_nodes.copy() id_map = create_local_to_global_id_map(local_nodes=local_nodes, global_nodes=global_nodes) - local_nodes['global_id'] = local_nodes['id'].apply(lambda x: id_map.get(x, -1)) + local_nodes['global_id'] = local_nodes['id'].apply(lambda x: id_map.get(x, None)) return local_nodes[['id', 'global_id']] diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_name_merge.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_name_merge.py index 22715a08..6f4d5fc5 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_name_merge.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_node_name_merge.py @@ -1,6 +1,6 @@ import sys from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * # needs testing def normalize(line): diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py index a127ff49..f80953ed 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py +++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_parse_bodies2.py @@ -2,8 +2,10 @@ import sys from typing import Tuple, List, Optional +from SourceCodeTools.code.ast import has_valid_syntax +from SourceCodeTools.code.common import custom_tqdm, SQLTable from SourceCodeTools.code.data.sourcetrail.common import * -from SourceCodeTools.nlp.entity.annotator.annotator_utils import to_offsets +from SourceCodeTools.code.annotator_utils import to_offsets 
 pd.options.mode.chained_assignment = None  # default='warn'
@@ -60,7 +62,10 @@ def get_function_body(file_content, file_id, start, end, s_col, e_col) -> str:
 
     source_lines = file_content[file_id].split("\n")
 
-    body_lines = source_lines[start: end]
+    if start == end:  # handle situations when the entire function takes only one line
+        body_lines = source_lines[start]
+    else:
+        body_lines = source_lines[start: end]
 
     initial_strip = body_lines[0][0:len(body_lines[0]) - len(body_lines[0].lstrip())]
     body = initial_strip + file_content[file_id][offsets[0][0]: offsets[0][1]]
@@ -68,14 +73,6 @@
     return body
 
 
-def has_valid_syntax(function_body):
-    try:
-        ast.parse(function_body.lstrip())
-        return True
-    except SyntaxError:
-        return False
-
-
 def get_range_for_replacement(occurrence, start_col, end_col, line, nodeid2name):
     extended_range = extend_range(start_col, end_col, line)
 
@@ -100,6 +97,15 @@ def get_range_for_replacement(occurrence, start_col, end_col, line, nodeid2name)
 
 
 def process_body(body, local_occurrences, nodeid2name, f_id, f_start):
+    """
+    Extract the list of sourcetrail node offsets found in a single function body.
+    :param body:
+    :param local_occurrences:
+    :param nodeid2name:
+    :param f_id:
+    :param f_start:
+    :return:
+    """
     body_lines = body.split("\n")
 
     local_occurrences = sort_occurrences(local_occurrences)
@@ -141,6 +147,16 @@
 
 
 def process_bodies(nodes, edges, source_location, occurrence, file_content, lang):
+    """
+    :param nodes:
+    :param edges:
+    :param source_location:
+    :param occurrence:
+    :param file_content:
+    :param lang:
+    :return: Dataframe with columns id, body, sourcetrail_node_offsets. Node offsets are not resolved from the
+        global graph.
+    """
 
     occurrence_groups = get_occurrence_groups(nodes, edges, source_location, occurrence)
 
@@ -153,7 +169,7 @@ def process_bodies(nodes, edges, source_location, occurrence, file_content, lang
     for group_ind, (file_id, occurrences) in custom_tqdm(
             enumerate(occurrence_groups), message="Processing function bodies", total=len(occurrence_groups)
     ):
-        sql_occurrences = SQLTable(occurrences, "/tmp/sourcetrail_occurrences.db", "occurrences")
+        sql_occurrences = SQLTable(occurrences, ":memory:", "occurrences")
 
         function_definitions = sql_get_function_definitions(sql_occurrences)
diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py b/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py
index bbad1d77..dc442261 100644
--- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py
+++ b/SourceCodeTools/code/data/sourcetrail/sourcetrail_types.py
@@ -15,7 +15,8 @@
     2: "uses_type",  # from user to type
     16: "inheritance",
     4: "uses",  # from user to item
-    512: "imports"  # from module to imported object
+    512: "imports",  # from module to imported object
+    32: "non-indexed-package"
 }
diff --git a/SourceCodeTools/code/data/type_annotation_dataset/create_type_annotation_aware_graph_partition.py b/SourceCodeTools/code/data/type_annotation_dataset/create_type_annotation_aware_graph_partition.py
new file mode 100644
index 00000000..a7a5e4b0
--- /dev/null
+++ b/SourceCodeTools/code/data/type_annotation_dataset/create_type_annotation_aware_graph_partition.py
@@ -0,0 +1,91 @@
+import argparse
+import json
+from os.path import join
+from random import random
+
+import pandas as pd
+
+from SourceCodeTools.code.common import read_nodes
+from SourceCodeTools.code.data.file_utils import persist
+
+
+def add_splits(items, train_frac, restricted_id_pool=None,
force_test=None): + items = items.copy() + + if force_test is None: + force_test = set() + + def random_partition(node_id): + r = random() + if node_id not in force_test: + if r < train_frac: + return "train" + elif r < train_frac + (1 - train_frac) / 2: + return "val" + else: + return "test" + else: + if r < .5: + return "val" + else: + return "test" + + import numpy as np + # define partitioning + masks = np.array([random_partition(node_id) for node_id in items["id"]]) + + # create masks + items["train_mask"] = masks == "train" + items["val_mask"] = masks == "val" + items["test_mask"] = masks == "test" + + if restricted_id_pool is not None: + # if `restricted_id_pool` is provided, mask all nodes not in `restricted_id_pool` negatively + to_keep = items.eval("id in @restricted_ids", local_dict={"restricted_ids": restricted_id_pool}) + items["train_mask"] = items["train_mask"] & to_keep + items["test_mask"] = items["test_mask"] & to_keep + items["val_mask"] = items["val_mask"] & to_keep + + return items + + +def read_test_set_nodes(path): + node_ids = set() + with open(path, "r") as dataset: + for line in dataset: + if line.strip() == "": + continue + + text, entry = json.loads(line) + + for _, _, node_id in entry["replacements"]: + node_ids.add(node_id) + + return node_ids + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("working_directory") + parser.add_argument("type_annotation_test_set") + parser.add_argument("output_path") + + args = parser.parse_args() + + all_nodes = [] + for nodes in read_nodes(join(args.working_directory, "common_nodes.json.bz2"), as_chunks=True): + all_nodes.append(nodes[["id"]]) + + partition = add_splits( + items=pd.concat(all_nodes), + train_frac=0.8, + force_test=read_test_set_nodes(args.type_annotation_test_set) + ) + + persist(partition, args.output_path) + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py b/SourceCodeTools/code/data/type_annotation_dataset/create_type_annotation_dataset.py similarity index 59% rename from SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py rename to SourceCodeTools/code/data/type_annotation_dataset/create_type_annotation_dataset.py index 731b93ef..702e7c4e 100644 --- a/SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py +++ b/SourceCodeTools/code/data/type_annotation_dataset/create_type_annotation_dataset.py @@ -2,13 +2,14 @@ import json import logging import os +from os.path import join import pandas as pd from tqdm import tqdm -from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist, unpersist_if_present +from SourceCodeTools.code.data.file_utils import unpersist, unpersist_if_present from SourceCodeTools.nlp import create_tokenizer -from SourceCodeTools.nlp.entity.annotator.annotator_utils import to_offsets, adjust_offsets2, \ +from SourceCodeTools.code.annotator_utils import to_offsets, adjust_offsets2, \ resolve_self_collisions2 from SourceCodeTools.nlp.spacy_tools import isvalid @@ -82,19 +83,34 @@ def correct_entities(entities, removed_offsets): for offset_len, offset in zip(offset_lens, offsets_sorted): new_entities = [] for entity in for_correction: - if offset[0] <= entity[0] and offset[1] <= entity[0]: + if offset[0] <= entity[0] and offset[1] <= entity[0]: # removed span is to the left of the entitity if len(entity) == 2: new_entities.append((entity[0] - offset_len, entity[1] - offset_len)) 
elif len(entity) == 3: new_entities.append((entity[0] - offset_len, entity[1] - offset_len, entity[2])) else: raise Exception("Invalid entity size") - elif offset[0] >= entity[1] and offset[1] >= entity[1]: + elif offset[0] >= entity[1] and offset[1] >= entity[1]: # removed span is to the right of the entitity new_entities.append(entity) - elif offset[0] <= entity[1] <= offset[1] or offset[0] <= entity[0] <= offset[1]: - pass # likely to be a type annotation being removed + elif offset[0] <= entity[0] <= offset[1] and offset[0] <= entity[1] <= offset[1]: # removed span covers the entity + pass + elif offset[0] <= entity[0] <= offset[1] and entity[0] <= offset[1] <= entity[1]: # removed span overlaps on the left + if len(entity) == 3: + new_entities.append((entity[0] - offset_len + offset[1] - entity[1], entity[1] - offset_len, entity[2])) + elif len(entity) == 2: + new_entities.append((entity[0] - offset_len + offset[1] - entity[1], entity[1] - offset_len)) + else: + raise Exception("Invalid entity size") + elif entity[0] <= offset[0] <= entity[1] and entity[0] <= entity[1] <= offset[1]: # removed span overlaps on the right + if len(entity) == 3: + new_entities.append((entity[0], entity[1] - offset_len + offset[1] - entity[1], entity[2])) + elif len(entity) == 2: + new_entities.append((entity[0], entity[1] - offset_len + offset[1] - entity[1])) + else: + raise Exception("Invalid entity size") else: - raise Exception("Invalid data?") + logging.warning(f"Encountered invalid offset: {entity}") + # raise Exception("Invalid data?") for_correction = new_entities @@ -160,6 +176,9 @@ def unpack_returns(body: str, labels: pd.DataFrame): :param labels: DataFrame with information about return type annotation :return: Trimmed body and list of return types (normally one). """ + if labels is None: + return [], [] + returns = [] for ind, row in labels.iterrows(): @@ -193,6 +212,22 @@ def unpack_returns(body: str, labels: pd.DataFrame): return ret, cuts +def get_defaults_spans(body): + root = ast.parse(body) + defaults_offsets = to_offsets( + body, + [(arg.lineno-1, arg.end_lineno-1, arg.col_offset, arg.end_col_offset, "default") for arg in root.body[0].args.defaults], + as_bytes=True + ) + + extended = [] + for start, end, label in defaults_offsets: + while body[start] != "=": + start -= 1 + extended.append((start, end)) + return extended + + def unpack_annotations(body, labels): """ Use information from ast package to strip type annotation from function body @@ -200,6 +235,11 @@ def unpack_annotations(body, labels): :param labels: DataFrame with information about type annotations :return: Trimmed body and list of annotations. """ + if labels is None: + return [], [] + + global remove_default + variables = [] annotations = [] @@ -215,6 +255,7 @@ def unpack_annotations(body, labels): # but type annotations usually appear in the end of signature and in the beginnig of a line variables = to_offsets(body, variables, as_bytes=True) annotations = to_offsets(body, annotations, as_bytes=True) + defaults_spans = get_defaults_spans(body) cuts = [] vars = [] @@ -225,21 +266,35 @@ def unpack_annotations(body, labels): head = body[:offset_ann[0]] orig_len = len(head) - head = head.rstrip() + head = head.rstrip(" \\\n\t") # include character \ since it can be used to indicate line break stripped_len = len(head) annsymbol = ":" assert head.endswith(annsymbol) beginning = beginning - (orig_len - stripped_len) - len(annsymbol) + # Workaround if there is a space before annotation. 
Example: "groupby : List[str]" + while beginning > 0 and body[beginning - 1] == " ": + beginning -= 1 cuts.append((beginning, end)) assert offset_var[0] != len(head) vars.append((offset_var[0], beginning, preprocess(body[offset_ann[0]: offset_ann[1]]))) + if remove_default: + cuts.extend(defaults_spans) + return vars, cuts -def process_body(nlp, body: str, replacements=None): +def body_valid(body): + try: + ast.parse(body) + return True + except: + return False + + +def process_body(nlp, body: str, replacements=None, require_labels=False): """ Extract annotation information, strip documentation and type annotations. :param nlp: Spacy tokenizer @@ -267,9 +322,10 @@ def process_body(nlp, body: str, replacements=None): body_, replacements, docstrings = remove_offsets(body_, replacements, docsting_offsets) entry['docstrings'].extend(docstrings) + was_valid = body_valid(body_) initial_labels = get_initial_labels(body_) - if initial_labels is None: + if require_labels and initial_labels is None: return None returns, return_cuts = unpack_returns(body_, initial_labels) @@ -277,6 +333,11 @@ def process_body(nlp, body: str, replacements=None): body_, replacements_annotations, _ = remove_offsets(body_, replacements + annotations, return_cuts + annotation_cuts) + is_valid = body_valid(body_) + if was_valid != is_valid: + print("Failed processing") + return None + # raise Exception() replacements_annotations = adjust_offsets2(replacements_annotations, len(initial_strip)) body_ = initial_strip + body_ @@ -288,7 +349,7 @@ def process_body(nlp, body: str, replacements=None): entry['replacements'] = resolve_self_collisions2(entry['replacements']) - assert isvalid(nlp, body_, entry['replacements']) + # assert isvalid(nlp, body_, entry['replacements']) assert isvalid(nlp, body_, entry['ents']) return entry @@ -386,25 +447,28 @@ def load_names(nodes_path): return names -def process_package(working_directory, global_names=None): +def process_package(working_directory, global_names=None, require_labels=False): """ Find functions with annotations, extract annotation information, strip documentation and type annotations. 
:param working_directory: location of package related files :param global_names: optional, mapping from global node ids to names :return: list of entries in spacy compatible format """ - bodies = unpersist_if_present(os.path.join(working_directory, "source_graph_bodies.bz2")) - if bodies is None: - return [] + # bodies = unpersist_if_present(os.path.join(working_directory, "source_graph_bodies.bz2")) + # if bodies is None: + # return [] - offsets_path = os.path.join(working_directory, "offsets.bz2") + # offsets_path = os.path.join(working_directory, "offsets.bz2") - # offsets store information about spans for nodes referenced in the source code - if os.path.isfile(offsets_path): - offsets = unpersist(offsets_path) - else: - logging.warning(f"No file with offsets: {offsets_path}") - offsets = None + # # offsets store information about spans for nodes referenced in the source code + # if os.path.isfile(offsets_path): + # offsets = unpersist(offsets_path) + # else: + # logging.warning(f"No file with offsets: {offsets_path}") + # offsets = None + + if not os.path.isfile(join(working_directory, "has_annotations")): + return [] def load_local2global(working_directory): local2global = unpersist(os.path.join(working_directory, "local2global_with_ast.bz2")) @@ -415,62 +479,184 @@ def load_local2global(working_directory): local_names = load_names(os.path.join(working_directory, "nodes_with_ast.bz2")) - nlp = create_tokenizer("spacy") + node_maps = get_node_maps(unpersist(join(working_directory, "nodes_with_ast.bz2"))) + filecontent = get_filecontent_maps(unpersist(join(working_directory, "filecontent_with_package.bz2"))) + offsets = group_offsets(unpersist(join(working_directory, "offsets.bz2"))) data = [] + nlp = create_tokenizer("spacy") - for ind, (_, row) in tqdm( - enumerate(bodies.iterrows()), total=len(bodies), - leave=True, desc=os.path.basename(working_directory) - ): - body = row['body'] - - if offsets is not None: - graph_node_spans = offsets_for_func(offsets, body, row["id"]) - else: - graph_node_spans = [] - - entry = process_body(nlp, body, replacements=graph_node_spans) + for ind, (f_body, f_offsets) in enumerate(iterate_functions(offsets, node_maps, filecontent)): + try: + entry = process_body(nlp, f_body, replacements=f_offsets, require_labels=require_labels) + except Exception as e: + logging.warning("Error during processing") + print(working_directory) + print(e) + continue if entry is not None: entry = to_global_ids(entry, id_maps, global_names, local_names) data.append(entry) + # nlp = create_tokenizer("spacy") + # + # data = [] + # + # for ind, (_, row) in tqdm( + # enumerate(bodies.iterrows()), total=len(bodies), + # leave=True, desc=os.path.basename(working_directory) + # ): + # body = row['body'] + # + # if offsets is not None: + # graph_node_spans = offsets_for_func(offsets, body, row["id"]) + # else: + # graph_node_spans = [] + # + # entry = process_body(nlp, body, replacements=graph_node_spans) + # + # if entry is not None: + # entry = to_global_ids(entry, id_maps, global_names, local_names) + # data.append(entry) + return data -def main(): +def iterate_functions(offsets, nodes, filecontent): + + allowed_entity_types = {"class_method", "function"} + + for package_id in offsets: + content = filecontent[package_id] + + # entry is a function or a class + for (entity_start, entity_end, entity_node_id), entity_offsets in offsets[package_id].items(): + if nodes[entity_node_id][1] in allowed_entity_types: + body = content[entity_start: entity_end] + adjusted_entity_offsets = 
adjust_offsets2(entity_offsets, -entity_start) + + yield body, adjusted_entity_offsets + + +def get_node_maps(nodes): + return dict(zip(nodes["id"], zip(nodes["serialized_name"], nodes["type"]))) + + +def get_filecontent_maps(filecontent): + return dict(zip(zip(filecontent["package"], filecontent["id"]), filecontent["content"])) + + +def group_offsets(offsets): + """ + :param offsets: Dataframe with offsets + :return: offsets grouped first by package name and file id, and then by the entity in which they occur. + """ + # This function will process all function that have graph annotations. If there are no + # annotations - the function is not processed. + offsets_grouped = {} + + for file_id, start, end, node_id, mentioned_in, package in offsets.values: + package_id = (package, file_id) + if package_id not in offsets_grouped: + offsets_grouped[package_id] = {} + + offset_ent = (start, end, node_id) + + for e in mentioned_in: + e = tuple(e) + if e not in offsets_grouped[package_id]: + offsets_grouped[package_id][e] = [] + + offsets_grouped[package_id][e].append(offset_ent) + + return offsets_grouped + + +def create_from_dataset(args): from argparse import ArgumentParser - parser = ArgumentParser() - parser.add_argument("packages", type=str, help="") - parser.add_argument("output_dataset", type=str, help="") - parser.add_argument("--format", "-f", dest="format", default="jsonl", help="jsonl|csv") - parser.add_argument("--global_nodes", "-g", dest="global_nodes", default=None) + # parser = ArgumentParser() + # parser.add_argument("dataset_path", type=str, help="") + # parser.add_argument("output_path", type=str, help="") + # parser.add_argument("--format", "-f", dest="format", default="jsonl", help="jsonl|csv") + # parser.add_argument("--remove_default", action="store_true", default=False) + # + # args = parser.parse_args() + + global remove_default + remove_default = args.remove_default + + node_maps = get_node_maps(unpersist(join(args.dataset_path, "common_nodes.json.bz2"))) + filecontent = get_filecontent_maps(unpersist(join(args.dataset_path, "common_filecontent.json.bz2"))) + offsets = group_offsets(unpersist(join(args.dataset_path, "common_offsets.json.bz2"))) + + data = [] + nlp = create_tokenizer("spacy") + + for ind, (f_body, f_offsets) in enumerate(iterate_functions(offsets, node_maps, filecontent)): + data.append(process_body(nlp, f_body, replacements=f_offsets, require_labels=args.require_labels)) + + store(data, args) - args = parser.parse_args() + +def create_from_environments(args): + # from argparse import ArgumentParser + # parser = ArgumentParser() + # parser.add_argument("packages", type=str, help="") + # parser.add_argument("output_path", type=str, help="") + # parser.add_argument("--format", "-f", dest="format", default="jsonl", help="jsonl|csv") + # parser.add_argument("--global_nodes", "-g", dest="global_nodes", default=None) + # parser.add_argument("--remove_default", default=False, action="store_true") + + # args = parser.parse_args() global_names = load_names(args.global_nodes) + global remove_default + remove_default = args.remove_default + data = [] - for package in os.listdir(args.packages): - pkg_path = os.path.join(args.packages, package) + for package in os.listdir(args.dataset_path): + pkg_path = os.path.join(args.dataset_path, package) if not os.path.isdir(pkg_path): continue - data.extend(process_package(working_directory=pkg_path, global_names=global_names)) + data.extend(process_package(working_directory=pkg_path, global_names=global_names, 
require_labels=args.require_labels))
+
+    store(data, args)
+
+def store(data, args):
+
     if args.format == "jsonl":  # jsonl format is used by spacy
-        with open(args.output_dataset, "w") as sink:
+        with open(args.output_path, "w") as sink:
             for entry in data:
                 sink.write(f"{json.dumps(entry)}\n")
     elif args.format == "csv":
-        if os.path.isfile(args.output_dataset):
+        if os.path.isfile(args.output_path):
             header = False
         else:
             header = True
-        pd.DataFrame(data).to_csv(args.output_dataset, index=False, header=header)
+        pd.DataFrame(data).to_csv(args.output_path, index=False, header=header)
 
 
 if __name__ == "__main__":
-    main()
+    from argparse import ArgumentParser
+
+    parser = ArgumentParser()
+    parser.add_argument("dataset_format", type=str)
+    parser.add_argument("dataset_path", type=str, help="")
+    parser.add_argument("output_path", type=str, help="")
+    parser.add_argument("--format", "-f", dest="format", default="jsonl", help="jsonl|csv")
+    parser.add_argument("--remove_default", action="store_true", default=False)
+    parser.add_argument("--global_nodes", "-g", dest="global_nodes", default=None)
+    parser.add_argument("--require_labels", action="store_true", default=False)
+    args = parser.parse_args()
+
+    if args.dataset_format == "envs":
+        create_from_environments(args)
+    elif args.dataset_format == "dataset":
+        create_from_dataset(args)
+    else:
+        raise ValueError("supported dataset formats are: envs|dataset")
+    # remove_default = False
diff --git a/SourceCodeTools/code/data/type_annotation_dataset/map_args_to_mentions.py b/SourceCodeTools/code/data/type_annotation_dataset/map_args_to_mentions.py
new file mode 100644
index 00000000..75582ef7
--- /dev/null
+++ b/SourceCodeTools/code/data/type_annotation_dataset/map_args_to_mentions.py
@@ -0,0 +1,61 @@
+import argparse
+import json
+from os.path import join
+
+import pandas as pd
+
+from SourceCodeTools.code.common import read_nodes
+from SourceCodeTools.code.data.dataset.Dataset import load_data
+from SourceCodeTools.code.data.file_utils import unpersist
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("working_directory")
+    parser.add_argument("output")
+    parser.add_argument("--dataset_file", default=None)
+    args = parser.parse_args()
+
+    if args.dataset_file is None:
+        args.dataset_file = join(args.working_directory, "function_annotations.jsonl")
+
+    # nodes, edges = load_data(join(args.working_directory, "common_nodes.json.bz2"), join(args.working_directory, "common_edges.json.bz2"))
+
+    arguments = set()
+    mentions = set()
+    for nodes in read_nodes(join(args.working_directory, "common_nodes.json.bz2"), as_chunks=True):
+        arguments.update(nodes.query("type == 'arg'")["id"])
+        mentions.update(nodes.query("type == 'mention'")["id"])
+
+    # type_annotated = set(unpersist(join(args.working_directory, "type_annotations.json.bz2"))["src"].tolist())
+
+    scrutinize_edges = []
+
+    for edges in read_nodes(join(args.working_directory, "common_edges.json.bz2"), as_chunks=True):
+        edges = edges.query("(source_node_id in @mentions) and (target_node_id in @arguments)", local_dict={"arguments": arguments, "mentions": mentions})
+        scrutinize_edges.append(edges)
+
+    edges = pd.concat(scrutinize_edges)
+    mapping = {}
+    for src, dst in edges[["source_node_id", "target_node_id"]].values:
+        if dst in mapping:
+            print()
+        mapping[dst] = src
+
+    with open(args.output, "w") as sink:
+        with open(args.dataset_file) as fa:
+            for line in fa:
+                entry = json.loads(line)
+                new_repl = [[s, e, int(mapping.get(r, r))] for s, e, r in entry["replacements"]]
+
entry["replacements"] = new_repl + + sink.write(f"{json.dumps(entry)}\n") + + + print() + + # pickle.dump(mapping, open(args.output, "wb")) + + +if __name__ == "__main__": + main() diff --git a/SourceCodeTools/code/data/type_annotation_dataset/split_dataset.py b/SourceCodeTools/code/data/type_annotation_dataset/split_dataset.py new file mode 100644 index 00000000..911cb8b2 --- /dev/null +++ b/SourceCodeTools/code/data/type_annotation_dataset/split_dataset.py @@ -0,0 +1,52 @@ +import argparse +import json +import os +from collections import Counter + +from SourceCodeTools.nlp.entity.utils.data import read_data + + +def get_all_annotations(dataset): + ann = [] + for _, annotations in dataset: + for _, _, e in annotations["entities"]: + ann.append(e) + return ann + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--data_path', dest='data_path', default=None, + help='Path to the dataset file') + parser.add_argument('--min_entity_count', dest='min_entity_count', default=3, type=int, + help='') + parser.add_argument('--random_seed', dest='random_seed', default=None, type=int, + help='') + parser.add_argument('--name_suffix', default="", type=str, + help='') + + args = parser.parse_args() + + train_data, test_data = read_data( + open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, include_only="entities", + min_entity_count=args.min_entity_count, random_seed=args.random_seed + ) + + directory = os.path.dirname(args.data_path) + + def write_to_file(data, directory, suffix, partition): + with open(os.path.join(directory, f"type_prediction_dataset_{suffix}_{partition}.json"), "w") as sink: + for entry in data: + sink.write(f"{json.dumps(entry)}\n") + + write_to_file(train_data, directory, args.name_suffix, "train") + write_to_file(test_data, directory, args.name_suffix, "test") + + ent_counts = Counter(get_all_annotations(train_data)) | Counter(get_all_annotations(test_data)) + + with open(os.path.join(directory, f"type_prediction_dataset_{args.name_suffix}_annotations_counts.txt"), "w") as sink: + for ent, count in ent_counts.most_common(): + sink.write(f"{ent}\t{count}\n") + + print() \ No newline at end of file diff --git a/SourceCodeTools/code/detect_hierarchy.py b/SourceCodeTools/code/detect_hierarchy.py index 1b78c95a..937fcd2a 100644 --- a/SourceCodeTools/code/detect_hierarchy.py +++ b/SourceCodeTools/code/detect_hierarchy.py @@ -1,8 +1,8 @@ import argparse -from SourceCodeTools.code.data.sourcetrail.Dataset import load_data +from SourceCodeTools.code.data.dataset.Dataset import load_data from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import edge_types -from SourceCodeTools.code.data.sourcetrail.file_utils import * +from SourceCodeTools.code.data.file_utils import * import networkx as nx from functools import lru_cache diff --git a/SourceCodeTools/code/experiments/Experiments.py b/SourceCodeTools/code/experiments/Experiments.py index 85bb8295..3e75374e 100644 --- a/SourceCodeTools/code/experiments/Experiments.py +++ b/SourceCodeTools/code/experiments/Experiments.py @@ -1,4 +1,7 @@ # %% +import json +from typing import Iterable + import pandas from os.path import join @@ -6,6 +9,8 @@ import numpy as np # from graphtools import Embedder +from SourceCodeTools.code.data.dataset.Dataset import filter_dst_by_freq +from SourceCodeTools.code.data.file_utils import unpersist from SourceCodeTools.models.Embedder import Embedder import pickle @@ -24,6 +29,13 @@ def get_nodes_with_split_ids(embedder, 
split_ids): # ]['id'].to_numpy() +def shuffle(X, y): + np.random.seed(42) + ind_shuffle = np.arange(0, X.shape[0]) + np.random.shuffle(ind_shuffle) + return X[ind_shuffle], y[ind_shuffle] + + class Experiments: """ Provides convenient interface for creating experiments. @@ -47,7 +59,8 @@ def __init__(self, variable_use_path=None, function_name_path=None, type_ann=None, - gnn_layer=-1): + gnn_layer=-1, + embeddings_path=None): """ :param base_path: path tp trained gnn model @@ -78,7 +91,9 @@ def __init__(self, # self.splits = torch.load(os.path.join(self.base_path, "state_dict.pt"))["splits"] if base_path is not None: - self.embed = pickle.load(open(join(self.base_path, "embeddings.pkl"), "rb"))[gnn_layer] + self.embed = pickle.load(open(embeddings_path, "rb")) + if isinstance(self.embed, Iterable): + self.embed = self.embed[gnn_layer] # alternative_nodes = pickle.load(open("nodes.pkl", "rb")) # self.embed.e = alternative_nodes # self.embed.e = np.random.randn(self.embed.e.shape[0], self.embed.e.shape[1]) @@ -112,12 +127,12 @@ def __getitem__(self, type: str): :param type: str description of the experiment :return: Experiment object """ - nodes = pandas.read_csv(join(self.base_path, "nodes.csv")) - edges = pandas.read_csv(join(self.base_path, "held.csv")).astype({"src": "int32", "dst": "int32"}) - from SourceCodeTools.code.data.sourcetrail.Dataset import SourceGraphDataset + nodes = unpersist(join(self.base_path, "common_nodes.json.bz2")) + # edges = pandas.read_csv(join(self.base_path, "held.csv")).astype({"src": "int32", "dst": "int32"}) + edges = None # unpersist(join(self.base_path, "common_edges.json.bz2")) - self.splits = SourceGraphDataset.get_global_graph_id_splits(nodes) + # self.splits = SourceGraphDataset.get_global_graph_id_splits(nodes) # global_ids = nodes['global_graph_id'].values # self.splits = ( # global_ids[nodes['train_mask'].values], @@ -200,22 +215,136 @@ def __getitem__(self, type: str): compact_dst=False) elif type == "typeann": - type_ann = pandas.read_csv(self.experiments['typeann']).astype({"src": "int32", "dst": "str"}) - # node_pool = set(self.splits[2]).union(self.splits[1]).union(self.splits[0]) - # node_pool = set(nodes['id'].values.tolist()) - # node_pool = set(get_nodes_with_split_ids(nodes, set(self.splits[2]).union(self.splits[1]).union(self.splits[0]))) - node_pool = set( - get_nodes_with_split_ids(self.embed, set(self.splits[2]).union(self.splits[1]).union(self.splits[0]))) + random_seed = 42 + min_entity_count = 3 + + type_ann = unpersist(self.experiments['typeann']) + + # node_names = dict(zip(nodes["id"], nodes["serialized_name"])) + + filter_rule = lambda name: "0x" not in name type_ann = type_ann[ - type_ann['src'].apply(lambda nid: nid in node_pool) + type_ann["dst"].apply(filter_rule) ] + node2id = dict(zip(nodes["id"], nodes["type"])) + type_ann["src_type"] = type_ann["src"].apply(lambda x: node2id[x]) + + type_ann = type_ann[ + type_ann["src_type"].apply(lambda type_: type_ in {"mention"}) + ] + + norm = lambda x: x.strip("\"").strip("'").split("[")[0].split(".")[-1] + + type_ann["dst"] = type_ann["dst"].apply(norm) + type_ann = filter_dst_by_freq(type_ann, min_entity_count) + + # this is used for codebert embeddings + filter_rule = lambda id_: id_ in self.embed.ind + + type_ann = type_ann[ + type_ann["src"].apply(filter_rule) + ] + + # allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray', + # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', + # 
'Type'} + test_only_popular_types = False + if test_only_popular_types: + allowed = { + # 'str', 'Optional', 'int', 'Any', 'Union', 'bool', 'Other', 'Callable', 'Dict', 'bytes', 'float', + # 'Description', + # 'List', 'Sequence', 'Namespace', 'T', 'Type', 'object', 'HTTPServerRequest', 'Future' + "str", + "Optional", + "int", + "Any", + "Union", + "bool", + "Other", + "Callable", + "Dict", + "bytes", + "float", + "Description", + "List", + "Sequence", + "Namespace", + "T", + "Type", + "object", + "HTTPServerRequest", + "Future", + "Matcher", + } + type_ann = type_ann[ + type_ann["dst"].apply(lambda type_: type_ in allowed) + ] + else: + allowed = None + + from pathlib import Path + dataset_dir = Path(self.experiments['typeann']).parent + from SourceCodeTools.nlp.entity.type_prediction import filter_labels + def read_dataset_file(path): + with open(path, "r") as source: + return [json.loads(line) for line in source] + + train_data = filter_labels( + read_dataset_file(dataset_dir.joinpath("type_prediction_dataset_no_default_args_mapped_train.json")), + allowed=allowed + ) + test_data = filter_labels( + read_dataset_file(dataset_dir.joinpath("type_prediction_dataset_no_default_args_mapped_test.json")), + allowed=allowed + ) + + def get_ids(typeann_dataset): + ids = [] + for sent, annotations in typeann_dataset: + for _, _, r in annotations["replacements"]: + ids.append(r) + + return set(ids) + + train_ids = get_ids(train_data) + test_ids = get_ids(test_data) + + type_ann = type_ann[["src", "dst"]] + + # splits = get_train_val_test_indices(nodes.index) + # create_train_val_test_masks(nodes, *splits) + + # train_data, test_data = read_data( + # open(os.path.join(self.base_path, "function_annotations.jsonl"), "r").readlines(), + # normalize=True, allowed=None, include_replacements=True, + # include_only="entities", + # min_entity_count=min_entity_count, random_seed=random_seed + # ) + # + # train_nodes = set() + # for _, ann in train_data: + # for _, _, nid in ann["replacements"]: + # train_nodes.add(nid) + # + # test_nodes = set() + # for _, ann in test_data: + # for _, _, nid in ann["replacements"]: + # test_nodes.add(nid) + # + # type_nodes = set(type_ann["src"]) + # + # train_nodes = train_nodes.intersection(type_nodes) + # test_nodes = test_nodes.intersection(type_nodes) + # return Experiment2(self.embed, nodes, edges, type_ann, split_on="nodes", neg_sampling_strategy="word2vec") - return Experiment3(self.embed, nodes, edges, type_ann, split_on="edges", - neg_sampling_strategy="word2vec", compact_dst=True) + return Experiment3( + self.embed, nodes, edges, type_ann, split_on="edges", neg_sampling_strategy="word2vec", + compact_dst=True, train_ids=train_ids, test_ids=test_ids + ) elif type == "typeann_name": type_ann = pandas.read_csv(self.experiments['typeann']).astype({"src": "int32", "dst": "str"}) @@ -298,7 +427,8 @@ def __init__(self, split_on="nodes", neg_sampling_strategy="word2vec", K=1, - test_frac=0.1, compact_dst=True): + test_frac=0.1, compact_dst=True, + train_ids=None, test_ids=None): # store local copies" self.embed = embeddings @@ -310,6 +440,8 @@ def __init__(self, self.neg_smpl_strategy = neg_sampling_strategy self.K = K self.TEST_FRAC = test_frac + self.train_ids = train_ids + self.test_ids = test_ids # make sure to drop duplicate edges to prevent leakage into the test set # do it before creating experiment? 
@@ -391,11 +523,6 @@ def get_training_data(self): assert X_train.shape[0] == y_train.shape[0] assert X_test.shape[0] == y_test.shape[0] - def shuffle(X, y): - ind_shuffle = np.arange(0, X.shape[0]) - np.random.shuffle(ind_shuffle) - return X[ind_shuffle], y[ind_shuffle] - self.X_train, self.y_train = shuffle(X_train, y_train) self.X_test, self.y_test = shuffle(X_test, y_test) @@ -557,19 +684,20 @@ def __init__(self, embeddings: Embedder, split_on="nodes", neg_sampling_strategy="word2vec", K=1, - test_frac=0.1, compact_dst=True): + test_frac=0.1, compact_dst=True, + train_ids=None, test_ids=None): super(Experiment2, self).__init__(embeddings, nodes, edges, target, split_on=split_on, neg_sampling_strategy=neg_sampling_strategy, K=K, test_frac=test_frac, - compact_dst=compact_dst) + compact_dst=compact_dst, train_ids=train_ids, test_ids=test_ids) # TODO # make sure that compact_property work always identically between runs - self.name_map = compact_property(target['dst']) + self.name_map, self.inv_index = compact_property(target['dst'], return_order=True) self.dst_orig = target['dst'] target['orig_dst'] = target['dst'] target['dst'] = target['dst'].apply(lambda name: self.name_map[name]) - print(f"Doing experiment with {len(self.name_map)} distinct target targets") + print(f"Doing experiment with {len(self.name_map)} distinct targets") self.unique_src = self.target['src'].unique() self.unique_dst = self.target['dst'].unique() @@ -663,7 +791,15 @@ def __init__(self, embeddings: Embedder, neg_sampling_strategy=neg_sampling_strategy, K=K, test_frac=test_frac, compact_dst=compact_dst) - from SourceCodeTools.models.graph.ElementEmbedder import window, hashstr, create_fixed_length + from SourceCodeTools.models.graph.ElementEmbedder import create_fixed_length + import hashlib + def hashstr(s, num_buckets): + return int(hashlib.md5(s.encode('utf8')).hexdigest(), 16) % num_buckets + + def window(x, gram_size): + x = "<" + x + ">" + length = len(x) + return (x[i:i + gram_size] for i in range(0, length) if i+gram_size<=length) reprs = target['orig_dst'].map(lambda x: window(x, gram_size)) \ .map(lambda grams: (hashstr(g, num_buckets) for g in grams)) \ @@ -684,33 +820,40 @@ def get_negative_out(self, num): class Experiment3(Experiment2): - def __init__(self, embeddings: Embedder, - nodes: pandas.DataFrame, - edges: pandas.DataFrame, - target: pandas.DataFrame, - split_on="nodes", - neg_sampling_strategy="word2vec", - K=1, - test_frac=0.1, compact_dst=True - ): + def __init__( + self, embeddings: Embedder, + nodes: pandas.DataFrame, + edges: pandas.DataFrame, + target: pandas.DataFrame, + split_on="nodes", + neg_sampling_strategy="word2vec", + K=1, + test_frac=0.1, compact_dst=True, + train_ids=None, test_ids=None + ): super(Experiment3, self).__init__(embeddings, nodes, edges, target, split_on=split_on, neg_sampling_strategy=neg_sampling_strategy, K=K, test_frac=test_frac, - compact_dst=compact_dst) + compact_dst=compact_dst, train_ids=train_ids, test_ids=test_ids) def get_training_data(self): # self.get_train_test_split() - train_positive = self.target.iloc[self.train_edge_ind].values - test_positive = self.target.iloc[self.test_edge_ind].values + if self.train_ids is not None and self.test_ids is not None: + test_positive = self.target[ + self.target["src"].apply(lambda x: x in self.test_ids) + ].values + chosen_for_test = set(test_positive[:, 0].tolist()) + train_positive = self.target[ + self.target["src"].apply(lambda x: x in self.train_ids and x not in chosen_for_test) + ].values - X_train, 
y_train = train_positive[:, 0].reshape(-1, 1), train_positive[:, 1].reshape(-1, 1) - X_test, y_test = test_positive[:, 0].reshape(-1, 1), test_positive[:, 1].reshape(-1, 1) + else: + train_positive = self.target.iloc[self.train_edge_ind].values + test_positive = self.target.iloc[self.test_edge_ind].values - def shuffle(X, y): - ind_shuffle = np.arange(0, X.shape[0]) - np.random.shuffle(ind_shuffle) - return X[ind_shuffle], y[ind_shuffle] + X_train, y_train = train_positive[:, 0].reshape(-1, 1).astype(np.int32), train_positive[:, 1].reshape(-1, 1).astype(np.int32) + X_test, y_test = test_positive[:, 0].reshape(-1, 1).astype(np.int32), test_positive[:, 1].reshape(-1, 1).astype(np.int32) self.X_train, self.y_train = shuffle(X_train, y_train) self.X_test, self.y_test = shuffle(X_test, y_test) diff --git a/SourceCodeTools/code/experiments/classifiers.py b/SourceCodeTools/code/experiments/classifiers.py index ff037f9b..3969021f 100644 --- a/SourceCodeTools/code/experiments/classifiers.py +++ b/SourceCodeTools/code/experiments/classifiers.py @@ -1,3 +1,5 @@ +import logging + import tensorflow as tf from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input, Embedding, concatenate @@ -95,6 +97,9 @@ class NodeClassifier(Model): def __init__(self, node_emb_size, n_classes, h_size=None): super(NodeClassifier, self).__init__() + self.proj = Dense(node_emb_size, use_bias=False) + logging.warning("Using projection matrix") + if h_size is None: h_size = [30, 15] @@ -115,7 +120,7 @@ def __init__(self, node_emb_size, n_classes, h_size=None): # self.logits = Dense(n_classes, input_shape=(h_size[1],)) def __call__(self, x, **kwargs): - h = x + h = self.proj(x) for l in self.layers_: h = l(h) return h diff --git a/SourceCodeTools/code/experiments/run_experiment.py b/SourceCodeTools/code/experiments/run_experiment.py index 7c287c8d..7bf1ff3d 100644 --- a/SourceCodeTools/code/experiments/run_experiment.py +++ b/SourceCodeTools/code/experiments/run_experiment.py @@ -1,5 +1,7 @@ import os +from sklearn import metrics + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' @@ -22,6 +24,168 @@ all_experiments = link_prediction_experiments + name_prediction_experiments + node_classification_experiments + name_subword_predictor +class Tracker: + def __init__(self, inverted_index=None): + self.all_true = [] + self.all_estimated = [] + self.all_emb = [] + self.inv_index = inverted_index + + def add(self, embs, pred, true): + self.all_true.append(true) + self.all_estimated.append(pred) + self.all_emb.append(embs) + + def decode_label_names(self, ids): + assert self.inv_index is not None, "Need inverted index" + return list(map(lambda x: self.inv_index[x], ids)) + + @property + def true_labels(self): + return np.concatenate(self.all_true, axis=0).reshape(-1,).tolist() + + @property + def true_label_names(self): + return self.decode_label_names(self.true_labels) + + @property + def pred_labels(self): + return np.argmax(np.concatenate(self.all_estimated, axis=0), axis=-1).reshape(-1, ).tolist() + + @property + def pred_label_names(self): + return self.decode_label_names(self.pred_labels) + + @property + def pred_scores(self): + return np.concatenate(self.all_estimated, axis=0) + + @property + def embeddings(self): + return np.concatenate(self.all_emb, axis=0) + + def clear(self): + self.all_true.clear() + self.all_estimated.clear() + self.all_emb.clear() + + def save_embs_for_tb(self, save_name): + assert self.inv_index is not None, 
"Cannot export for tensorboard without metadata" + np.savetxt(f"{save_name}_embeddings.tsv", self.embeddings, delimiter="\t") + with open(f"{save_name}_meta.tsv", "w") as meta_sink: + for label in list(map(lambda x: self.inv_index[x], self.true_labels)): + meta_sink.write(f"{label}\n") + + def save_umap(self, save_name): + type_freq = { + "str": 532, + "Optional": 232, + "int": 206, + "Any": 171, + "Union": 156, + "bool": 143, + "Callable": 80, + "Dict": 77, + "bytes": 58, + "float": 48 + } + from umap import UMAP + import matplotlib.pyplot as plt + plt.rcParams.update({'font.size': 5}) + reducer = UMAP(50) + embedding = reducer.fit_transform(self.embeddings) + + labels = list(map(lambda x: self.inv_index[x], self.true_labels)) + unique_labels = sorted(list(set(labels))) + + plt.figure(figsize=(4,4)) + legend = [] + for label in unique_labels: + if label not in type_freq: + continue + xs = [] + ys = [] + for lbl, (x, y) in zip(labels, embedding): + if lbl == label: + xs.append(x) + ys.append(y) + plt.scatter(xs, ys, 1.) + legend.append(label) + plt.axis('off') + plt.legend(legend) + plt.savefig(f"{save_name}_umap.pdf") + plt.close() + # plt.show() + + + + def get_metrics(self): + all_true = self.true_labels + all_scores = self.pred_scores + + metric_dict = {} + + for k in [1,3,5]: + metric_dict[f"Acc@{k}"] = metrics.top_k_accuracy_score(y_true=all_true, y_score=all_scores, k=k, labels=list(range(all_scores.shape[1]))) + + return metric_dict + + def save_confusion_matrix(self, save_path): + estimate_confusion( + self.pred_label_names, + self.true_label_names, + save_path=save_path + ) + + +def estimate_confusion(pred, true, save_path): + pred_filtered = pred + true_filtered = true + + import matplotlib.pyplot as plt + plt.rcParams.update({'font.size': 50}) + + labels = sorted(list(set(true_filtered + pred_filtered))) + label2ind = dict(zip(labels, range(len(labels)))) + + confusion = np.zeros((len(labels), len(labels))) + + for pred, true in zip(pred_filtered, true_filtered): + confusion[label2ind[true], label2ind[pred]] += 1 + + norm = np.array([x if x != 0 else 1. for x in np.sum(confusion, axis=1)]).reshape(-1,1) + confusion /= norm + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(45,45)) + from matplotlib.pyplot import cm + im = ax.imshow(confusion, interpolation='nearest', cmap=cm.Blues) + + # We want to show all ticks... + ax.set_xticks(np.arange(len(labels))) + ax.set_yticks(np.arange(len(labels))) + # ... and label them with the respective list entries + ax.set_xticklabels(labels) + ax.set_yticklabels(labels) + + # Rotate the tick labels and set their alignment. + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", + rotation_mode="anchor") + + # Loop over data dimensions and create text annotations. 
+ for i in range(len(labels)): + for j in range(len(labels)): + text = ax.text(j, i, f"{confusion[i, j]: .2f}", + ha="center", va="center", color="w") + + ax.set_title("Confusion matrix for Python type prediction") + fig.tight_layout() + # plt.show() + plt.savefig(save_path) + plt.close() + + def run_experiment(e, experiment_name, args): experiment = e[experiment_name] @@ -70,8 +234,8 @@ def run_experiment(e, experiment_name, args): test_loss = tf.keras.metrics.Mean(name='test_loss') test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') - @tf.function - def train_step(batch): + # @tf.function + def train_step(batch, tracker=None): with tf.GradientTape() as tape: # training=True is only needed if there are layers with different # behavior during training versus inference (e.g. Dropout). @@ -80,22 +244,30 @@ def train_step(batch): gradients = tape.gradient(loss, clf.trainable_variables) optimizer.apply_gradients(zip(gradients, clf.trainable_variables)) + if tracker is not None: + tracker.add(batch["x"], predictions.numpy(), batch["y"]) + train_loss(loss) train_accuracy(batch["y"], predictions) - @tf.function - def test_step(batch): + # @tf.function + def test_step(batch, tracker=None): # training=False is only needed if there are layers with different # behavior during training versus inference (e.g. Dropout). predictions = clf(**batch, training=False) t_loss = loss_object(batch["y"], predictions) + if tracker is not None: + tracker.add(batch["x"], predictions.numpy(), batch["y"]) + test_loss(t_loss) test_accuracy(batch["y"], predictions) - args.epochs = 500 - + trains = [] tests = [] + metrics = [] + + test_tracker = Tracker(inverted_index=experiment.inv_index if hasattr(experiment, "inv_index") else None) for epoch in range(args.epochs): # Reset the metrics at the start of the next epoch @@ -108,12 +280,17 @@ def test_step(batch): train_step(batch) if epoch % 1 == 0: + + test_tracker.clear() + for batch in experiment.test_batches(): - test_step(batch) + test_step(batch, tracker=test_tracker) ma_train = train_accuracy.result() * 100 * ma_alpha + ma_train * (1 - ma_alpha) ma_test = test_accuracy.result() * 100 * ma_alpha + ma_test * (1 - ma_alpha) + trains.append(ma_train) tests.append(ma_test) + metrics.append(test_tracker.get_metrics()) # template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.4f}, Test Loss: {:.4f}, Test Accuracy: {:.4f}, Average Test {:.4f}' # print(template.format(epoch+1, @@ -123,9 +300,26 @@ def test_step(batch): # test_accuracy.result()*100, # ma_test)) + # plot confusion matrix + if hasattr(experiment, "inv_index") and args.out_path is not None: + test_tracker.save_confusion_matrix(save_path=args.out_path) + + if hasattr(experiment, "inv_index") and args.emb_out: + test_tracker.save_embs_for_tb(save_name=os.path.join(args.out_path, "tb")) + test_tracker.save_umap(save_name=os.path.join(args.out_path, "tb")) + + print(metrics[tests.index(max(tests))]) + # ma_train = train_accuracy.result() * 100 * ma_alpha + ma_train * (1 - ma_alpha) # ma_test = test_accuracy.result() * 100 * ma_alpha + ma_test * (1 - ma_alpha) + import matplotlib.pyplot as plt + plt.plot(trains) + plt.plot(tests) + plt.legend(["Train", "Test"]) + # plt.savefig(os.path.join(args.base_path, f"{args.experiment}.png")) + plt.show() + return ma_train, max(tests) @@ -146,15 +340,24 @@ def test_step(batch): parser.add_argument("--type_link", default=None, help="") parser.add_argument("--type_link_train", default=None, help="") parser.add_argument("--type_link_test", default=None, 
help="") - parser.add_argument("--epochs", default=500, type=int, help="") + parser.add_argument("--epochs", default=1, type=int, help="") parser.add_argument("--name_emb_dim", default=100, type=int, help="") parser.add_argument("--element_predictor_h_size", default=50, type=int, help="") parser.add_argument("--link_predictor_h_size", default="[20]", type=str, help="") parser.add_argument("--node_classifier_h_size", default="[30,15]", type=str, help="") + parser.add_argument("--out_path", default=None, type=str, help="") + # parser.add_argument("--confusion_out_path", default=None, type=str, help="") + parser.add_argument("--trials", default=1, type=int, help="") + parser.add_argument("--emb_out", default=False, action="store_true", help="") + # parser.add_argument("--emb_out", default=None, type=str, help="") + parser.add_argument("--embeddings", default=None) parser.add_argument('--random', action='store_true') parser.add_argument('--test_embedder', action='store_true') args = parser.parse_args() + if args.out_path is None: + args.out_path = os.path.dirname(args.embeddings) + print(args.__dict__) e = Experiments(base_path=args.base_path, @@ -166,15 +369,17 @@ def test_step(batch): function_name_path=None, type_ann=args.type_ann, gnn_layer=-1, + embeddings_path=args.embeddings ) experiments = args.experiment.split(",") for experiment_name in experiments: - print(f"\n{experiment_name}:") - try: - train_acc, test_acc = run_experiment(e, experiment_name, args) - print("Train Accuracy: {:.4f}, Test Accuracy: {:.4f}".format(train_acc, test_acc)) - except ValueError as err: - print(err) - print("\n") + for trial in range(args.trials): + print(f"\n{experiment_name}, trial {trial}:") + try: + train_acc, test_acc = run_experiment(e, experiment_name, args) + print(f"Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}") + except ValueError as err: + print(err) + print("\n") diff --git a/SourceCodeTools/mltools/torch/__init__.py b/SourceCodeTools/mltools/torch/__init__.py new file mode 100644 index 00000000..47e7becd --- /dev/null +++ b/SourceCodeTools/mltools/torch/__init__.py @@ -0,0 +1,5 @@ +import torch + + +def compute_accuracy(pred_, true_): + return torch.sum(pred_ == true_).item() / len(true_) \ No newline at end of file diff --git a/SourceCodeTools/models/graph/ElementEmbedder.py b/SourceCodeTools/models/graph/ElementEmbedder.py index 56b111a7..4ade9c23 100644 --- a/SourceCodeTools/models/graph/ElementEmbedder.py +++ b/SourceCodeTools/models/graph/ElementEmbedder.py @@ -8,17 +8,25 @@ from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase from SourceCodeTools.models.graph.train.Scorer import Scorer -from SourceCodeTools.models.nlp.Encoder import LSTMEncoder, Encoder +from SourceCodeTools.models.nlp.TorchEncoder import LSTMEncoder, Encoder class GraphLinkSampler(ElementEmbedderBase, Scorer): - def __init__(self, elements, nodes, compact_dst=True, dst_to_global=True, emb_size=None): + def __init__( + self, elements, nodes, compact_dst=True, dst_to_global=True, emb_size=None, device="cpu", + method="inner_prod", nn_index="brute", ns_groups=None + ): assert emb_size is not None - ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes, compact_dst=compact_dst, dst_to_global=dst_to_global) - Scorer.__init__(self, num_embs=len(self.elements["dst"].unique()), emb_size=emb_size, src2dst=self.element_lookup) + ElementEmbedderBase.__init__( + self, elements=elements, nodes=nodes, compact_dst=compact_dst, dst_to_global=dst_to_global + ) + Scorer.__init__( + 
self, num_embs=len(self.elements["dst"].unique()), emb_size=emb_size, src2dst=self.element_lookup, + device=device, method=method, index_backend=nn_index, ns_groups=ns_groups + ) def sample_negative(self, size, ids=None, strategy="closest"): - if strategy == "w2v": + if strategy == "w2v" or self.scorer_index is None: negative = ElementEmbedderBase.sample_negative(self, size) else: negative = Scorer.sample_closest_negative(self, ids, k=size // len(ids)) @@ -26,10 +34,12 @@ def sample_negative(self, size, ids=None, strategy="closest"): return negative -class ElementEmbedder(ElementEmbedderBase, nn.Module): +class ElementEmbedder(ElementEmbedderBase, nn.Module, Scorer): def __init__(self, elements, nodes, emb_size, compact_dst=True): ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes, compact_dst=compact_dst) nn.Module.__init__(self) + Scorer.__init__(self, num_embs=len(self.elements["dst"].unique()), emb_size=emb_size, + src2dst=self.element_lookup) self.emb_size = emb_size n_elems = self.elements['emb_id'].unique().size @@ -39,9 +49,30 @@ def __init__(self, elements, nodes, emb_size, compact_dst=True): def __getitem__(self, ids): return torch.LongTensor(ElementEmbedderBase.__getitem__(self, ids=ids)) + def sample_negative(self, size, ids=None, strategy="closest"): + # TODO + # Try other distributions + if strategy == "w2v": + negative = ElementEmbedderBase.sample_negative(self, size) + else: + ### negative = random.choices(Scorer.sample_closest_negative(self, ids), k=size) + negative = Scorer.sample_closest_negative(self, ids, k=size // len(ids)) + assert len(negative) == size + + return torch.LongTensor(negative) + def forward(self, input, **kwargs): return self.norm(self.embed(input)) + def set_embed(self): + all_keys = self.get_keys_for_scoring() + with torch.set_grad_enabled(False): + self.scorer_all_emb = self(torch.LongTensor(all_keys).to(self.embed.weight.device)).detach().cpu().numpy() + + def prepare_index(self): + self.set_embed() + Scorer.prepare_index(self) + # def hashstr(s, num_buckets): # return int(hashlib.md5(s.encode('utf8')).hexdigest(), 16) % num_buckets diff --git a/SourceCodeTools/models/graph/ElementEmbedderBase.py b/SourceCodeTools/models/graph/ElementEmbedderBase.py index 50d829b5..2b2727bd 100644 --- a/SourceCodeTools/models/graph/ElementEmbedderBase.py +++ b/SourceCodeTools/models/graph/ElementEmbedderBase.py @@ -16,7 +16,8 @@ def __init__(self, elements, nodes, compact_dst=True, dst_to_global=False): def init(self, compact_dst): if compact_dst: elem2id, self.inverse_dst_map = compact_property(self.elements['dst'], return_order=True) - self.elements['emb_id'] = self.elements['dst'].apply(lambda x: elem2id[x]) + self.elements['emb_id'] = self.elements['dst'].apply(lambda x: elem2id.get(x, -1)) + assert -1 not in self.elements['emb_id'].tolist() else: self.elements['emb_id'] = self.elements['dst'] @@ -49,7 +50,11 @@ def preprocess_element_data(self, element_data, nodes, compact_dst, dst_to_globa id2typedid = dict(zip(nodes['id'].tolist(), nodes['typed_id'].tolist())) id2type = dict(zip(nodes['id'].tolist(), nodes['type'].tolist())) - def get_node_pools(element_data): + self.id2nodeid = id2nodeid + self.id2typedid = id2typedid + self.id2type = id2type + + def get_node_pools(element_data): # create typed node list for possible use with dgl node_typed_pools = {} for orig_node_id in element_data['src']: global_id = id2nodeid.get(orig_node_id, None) @@ -128,9 +133,9 @@ def get_src_pool(self, ntypes=None): # return {ntype: set(self.elements.query(f"src_type 
== '{ntype}'")['src_typed_id'].tolist()) for ntype in ntypes} def _create_pools(self, train_idx, val_idx, test_idx, pool) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - train_idx = np.fromiter(pool.intersection(train_idx.tolist()), dtype=np.int64) - val_idx = np.fromiter(pool.intersection(val_idx.tolist()), dtype=np.int64) - test_idx = np.fromiter(pool.intersection(test_idx.tolist()), dtype=np.int64) + train_idx = np.fromiter(pool.intersection(train_idx.reshape((-1,)).tolist()), dtype=np.int64) + val_idx = np.fromiter(pool.intersection(val_idx.reshape((-1,)).tolist()), dtype=np.int64) + test_idx = np.fromiter(pool.intersection(test_idx.reshape((-1,)).tolist()), dtype=np.int64) return train_idx, val_idx, test_idx def create_idx_pools(self, train_idx, val_idx, test_idx): diff --git a/SourceCodeTools/models/graph/LinkPredictor.py b/SourceCodeTools/models/graph/LinkPredictor.py index 340803c6..5320e772 100644 --- a/SourceCodeTools/models/graph/LinkPredictor.py +++ b/SourceCodeTools/models/graph/LinkPredictor.py @@ -49,21 +49,74 @@ class BilinearLinkPedictor(nn.Module): def __init__(self, embedding_dim_1, embedding_dim_2, target_classes=2): super(BilinearLinkPedictor, self).__init__() - self.bilinear = nn.Bilinear(embedding_dim_1, embedding_dim_2, target_classes) + self.l1 = nn.Linear(embedding_dim_1, 300) + self.l2 = nn.Linear(embedding_dim_2, 300) + self.act = nn.Sigmoid() + self.bilinear = nn.Bilinear(300, 300, target_classes) def forward(self, x1, x2): - return self.bilinear(x1, x2) + return self.bilinear(self.act(self.l1(x1)), self.act(self.l2(x2))) class CosineLinkPredictor(nn.Module): - def __init__(self): + def __init__(self, margin=0.): + """ + Dummy link predictor, using to keep API the same + """ super(CosineLinkPredictor, self).__init__() self.cos = nn.CosineSimilarity() - self.max_margin = torch.Tensor([0.4]) + self.max_margin = torch.Tensor([margin])[0] def forward(self, x1, x2): if self.max_margin.device != x1.device: self.max_margin = self.max_margin.to(x1.device) + # this will not train logit = (self.cos(x1, x2) > self.max_margin).float().unsqueeze(1) - return torch.cat([1 - logit, logit], dim=1) \ No newline at end of file + return torch.cat([1 - logit, logit], dim=1) + + +class L2LinkPredictor(nn.Module): + def __init__(self, margin=1.): + super(L2LinkPredictor, self).__init__() + self.margin = torch.Tensor([margin])[0] + + def forward(self, x1, x2): + if self.margin.device != x1.device: + self.margin = self.margin.to(x1.device) + # this will not train + logit = (torch.norm(x1 - x2, dim=-1, keepdim=True) < self.margin).float() + return torch.cat([1 - logit, logit], dim=1) + + +class TransRLinkPredictor(nn.Module): + def __init__(self, input_dim, rel_dim, num_relations, margin=0.3): + super(TransRLinkPredictor, self).__init__() + self.rel_dim = rel_dim + self.input_dim = input_dim + self.margin = margin + + self.rel_emb = nn.Embedding(num_embeddings=num_relations, embedding_dim=rel_dim) + self.proj_matr = nn.Embedding(num_embeddings=num_relations, embedding_dim=input_dim * rel_dim) + self.triplet_loss = nn.TripletMarginLoss(margin=margin) + + def forward(self, a, p, n, labels): + weights = self.proj_matr(labels).reshape((-1, self.rel_dim, self.input_dim)) + rels = self.rel_emb(labels) + m_a = (weights * a.unsqueeze(1)).sum(-1) + m_p = (weights * p.unsqueeze(1)).sum(-1) + m_n = (weights * n.unsqueeze(1)).sum(-1) + + transl = m_a + rels + + sim = torch.norm(torch.cat([transl - m_p, transl - m_n], dim=0), dim=-1) + + # pos_diff = torch.norm(transl - m_p, dim=-1) + # neg_diff = 
torch.norm(transl - m_n, dim=-1) + + # loss = pos_diff + torch.maximum(torch.tensor([0.]).to(neg_diff.device), self.margin - neg_diff) + # return loss.mean(), sim + return self.triplet_loss(transl, m_p, m_n), sim < self.margin + + + diff --git a/SourceCodeTools/models/graph/NodeEmbedder.py b/SourceCodeTools/models/graph/NodeEmbedder.py index c6149723..28ab6015 100644 --- a/SourceCodeTools/models/graph/NodeEmbedder.py +++ b/SourceCodeTools/models/graph/NodeEmbedder.py @@ -7,6 +7,9 @@ class NodeEmbedder(nn.Module): def __init__(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=None): super(NodeEmbedder, self).__init__() + self.init(nodes, emb_size, dtype, n_buckets, pretrained) + + def init(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=None): self.emb_size = emb_size self.dtype = dtype if dtype is None: @@ -15,11 +18,13 @@ def __init__(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=Non self.buckets = None + embedding_field = "embeddable_name" + nodes_with_embeddings = nodes.query("embeddable == True")[ - ['global_graph_id', 'typed_id', 'type', 'type_backup', 'name'] + ['global_graph_id', 'typed_id', 'type', 'type_backup', embedding_field] ] - type_name = list(zip(nodes_with_embeddings['type_backup'], nodes_with_embeddings['name'])) + type_name = list(zip(nodes_with_embeddings['type_backup'], nodes_with_embeddings[embedding_field])) self.node_info = dict(zip( list(zip(nodes_with_embeddings['type'], nodes_with_embeddings['typed_id'])), @@ -39,7 +44,7 @@ def __init__(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=Non self._create_buckets_from_pretrained(pretrained) def _create_buckets(self): - self.buckets = nn.Embedding(self.n_buckets + 1, self.emb_size, padding_idx=self.n_buckets) + self.buckets = nn.Embedding(self.n_buckets + 1, self.emb_size, padding_idx=self.n_buckets, sparse=True) def _create_buckets_from_pretrained(self, pretrained): @@ -49,7 +54,7 @@ def _create_buckets_from_pretrained(self, pretrained): weights_with_pad = torch.tensor(np.vstack([pretrained, np.zeros((1, self.emb_size), dtype=np.float32)])) - self.buckets = nn.Embedding.from_pretrained(weights_with_pad, freeze=False, padding_idx=self.n_buckets) + self.buckets = nn.Embedding.from_pretrained(weights_with_pad, freeze=False, padding_idx=self.n_buckets, sparse=True) def _get_embedding_from_node_info(self, keys, node_info, masked=None): idxs = [] @@ -93,6 +98,42 @@ def forward(self, node_type=None, node_ids=None, train_embeddings=True, masked=N return self.get_embeddings(node_type, node_ids.tolist(), masked=masked) +class NodeIdEmbedder(NodeEmbedder): + def __init__(self, nodes=None, emb_size=None, dtype=None, n_buckets=500000, pretrained=None): + super(NodeIdEmbedder, self).__init__(nodes, emb_size, dtype, n_buckets, pretrained) + + def init(self, nodes, emb_size, dtype=None, n_buckets=500000, pretrained=None): + self.emb_size = emb_size + self.dtype = dtype + if dtype is None: + self.dtype = torch.float32 + self.n_buckets = n_buckets + + self.buckets = None + + embedding_field = "embeddable_name" + + nodes_with_embeddings = nodes.query("embeddable == True")[ + ['global_graph_id', 'typed_id', 'type', 'type_backup', embedding_field] + ] + + self.to_global_map = {} + for global_graph_id, typed_id, type_, type_backup, name in nodes_with_embeddings.values: + if type_ not in self.to_global_map: + self.to_global_map[type_] = {} + + self.to_global_map[type_][typed_id] = global_graph_id + + self._create_buckets() + + def get_embeddings(self, node_type=None, 
node_ids=None, masked=None): + assert node_ids is not None + if node_type is not None: + node_ids = list(map(lambda local_id: self.to_global_map[node_type][local_id], node_ids)) + + return self.buckets(torch.LongTensor(node_ids)) + + # class SimpleNodeEmbedder(nn.Module): # def __init__(self, dataset, emb_size, dtype=None, n_buckets=500000, pretrained=None): # super(SimpleNodeEmbedder, self).__init__() diff --git a/SourceCodeTools/models/graph/basis_gatconv.py b/SourceCodeTools/models/graph/basis_gatconv.py new file mode 100644 index 00000000..641f7ba9 --- /dev/null +++ b/SourceCodeTools/models/graph/basis_gatconv.py @@ -0,0 +1,237 @@ +from dgl import DGLError +from dgl.nn.functional import edge_softmax +from dgl.nn.pytorch import GATConv +from dgl.nn.pytorch.utils import Identity +from dgl.utils import expand_as_pair +import dgl.function as fn +from torch import nn, softmax +import torch as th +from torch.utils import checkpoint + + +class BasisGATConv(GATConv): + """ + Does not seem to improve memory requirements + """ + def __init__(self, + in_feats, + out_feats, + num_heads, + basis, + attn_basis, + basis_coef, + feat_drop=0., + attn_drop=0., + negative_slope=0.2, + residual=False, + activation=None, + allow_zero_in_degree=False, + bias=True, + use_checkpoint=False): + super(GATConv, self).__init__() + self._basis = basis + self._basis_coef = basis_coef + self._attn_basis = attn_basis + + self._num_heads = num_heads + self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) + self._out_feats = out_feats + self._allow_zero_in_degree = allow_zero_in_degree + # if isinstance(in_feats, tuple): + # self.fc_src = nn.Linear( + # self._in_src_feats, out_feats * num_heads, bias=False) + # self.fc_dst = nn.Linear( + # self._in_dst_feats, out_feats * num_heads, bias=False) + # else: + # self.fc = nn.Linear( + # self._in_src_feats, out_feats * num_heads, bias=False) + # self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) + # self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) + self.feat_drop = nn.Dropout(feat_drop) + self.attn_drop = nn.Dropout(attn_drop) + self.leaky_relu = nn.LeakyReLU(negative_slope) + if bias: + self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_feats,))) + else: + self.register_buffer('bias', None) + if residual: + if self._in_dst_feats != out_feats: + self.res_fc = nn.Linear( + self._in_dst_feats, num_heads * out_feats, bias=False) + else: + self.res_fc = Identity() + else: + self.register_buffer('res_fc', None) + self.reset_parameters() + self.activation = activation + + self.dummy_tensor = th.ones(1, dtype=th.float32, requires_grad=True) + self.use_checkpoint = use_checkpoint + + def reset_parameters(self): + """ + + Description + ----------- + Reinitialize learnable parameters. + + Note + ---- + The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization. + The attention weights are using xavier initialization method. 
+ """ + gain = nn.init.calculate_gain('relu') + # if hasattr(self, 'fc'): + # nn.init.xavier_normal_(self.fc.weight, gain=gain) + # else: + # nn.init.xavier_normal_(self.fc_src.weight, gain=gain) + # nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) + # nn.init.xavier_normal_(self.attn_l, gain=gain) + # nn.init.xavier_normal_(self.attn_r, gain=gain) + nn.init.constant_(self.bias, 0) + if isinstance(self.res_fc, nn.Linear): + nn.init.xavier_normal_(self.res_fc.weight, gain=gain) + + def set_allow_zero_in_degree(self, set_value): + r""" + + Description + ----------- + Set allow_zero_in_degree flag. + + Parameters + ---------- + set_value : bool + The value to be set to the flag. + """ + self._allow_zero_in_degree = set_value + + + def _forward(self, graph, feat, get_attention=False): + r""" + + Description + ----------- + Compute graph attention network layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : torch.Tensor or pair of torch.Tensor + If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where + :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. + If a pair of torch.Tensor is given, the pair must contain two tensors of shape + :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`. + get_attention : bool, optional + Whether to return the attention values. Default to False. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, H, D_{out})` where :math:`H` + is the number of heads, and :math:`D_{out}` is size of output feature. + torch.Tensor, optional + The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of + edges. This is returned only when :attr:`get_attention` is ``True``. + + Raises + ------ + DGLError + If there are 0-in-degree nodes in the input graph, it will raise DGLError + since no message will be passed to those nodes. This will cause invalid output. + The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``. + """ + with graph.local_scope(): + if not self._allow_zero_in_degree: + if (graph.in_degrees() == 0).any(): + raise DGLError('There are 0-in-degree nodes in the graph, ' + 'output for those nodes will be invalid. ' + 'This is harmful for some applications, ' + 'causing silent performance regression. ' + 'Adding self-loop on the input graph by ' + 'calling `g = dgl.add_self_loop(g)` will resolve ' + 'the issue. 
Setting ``allow_zero_in_degree`` ' + 'to be `True` when constructing this module will ' + 'suppress the check and let the code run.') + + if isinstance(feat, tuple): + h_src = self.feat_drop(feat[0]) + h_dst = self.feat_drop(feat[1]) + basis_coef = softmax(self._basis_coef, dim=-1).reshape(-1, 1, 1) + # if not hasattr(self, 'fc_src'): + params_src = (self._basis[0] * basis_coef).sum(dim=0) + params_dst = (self._basis[1] * basis_coef).sum(dim=0) + feat_src = (params_src @ h_src.T).view(-1, self._num_heads, self._out_feats) + feat_dst = (params_dst @ h_dst.T).view(-1, self._num_heads, self._out_feats) + # # feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats) + # # feat_dst = self.fc(h_dst).view(-1, self._num_heads, self._out_feats) + # else: + # params = self._basis * basis_coef + # feat_src = (params @ h_src.T).view(-1, self._num_heads, self._out_feats) + # feat_dst = (params @ h_dst.T).view(-1, self._num_heads, self._out_feats) + # # feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats) + # # feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats) + else: + h_src = h_dst = self.feat_drop(feat) + basis_coef = softmax(self._basis_coef, dim=-1).reshape(-1, 1, 1) + params = (self._basis * basis_coef).sum(dim=0) + feat_src = feat_dst = (params @ h_src.T).view(-1, self._num_heads, self._out_feats) + # feat_src = feat_dst = self.fc(h_src).view( + # -1, self._num_heads, self._out_feats) + if graph.is_block: + feat_dst = feat_src[:graph.number_of_dst_nodes()] + # NOTE: GAT paper uses "first concatenation then linear projection" + # to compute attention scores, while ours is "first projection then + # addition", the two approaches are mathematically equivalent: + # We decompose the weight vector a mentioned in the paper into + # [a_l || a_r], then + # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j + # Our implementation is much efficient because we do not need to + # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, + # addition could be optimized with DGL's built-in function u_add_v, + # which further speeds up computation and saves memory footprint. + attn_l_param = (self._attn_basis[0] * basis_coef).sum(dim=0) + attn_r_param = (self._attn_basis[1] * basis_coef).sum(dim=0) + el = (feat_src * attn_l_param).sum(dim=-1).unsqueeze(-1) + er = (feat_dst * attn_r_param).sum(dim=-1).unsqueeze(-1) + # el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) + # er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) + graph.srcdata.update({'ft': feat_src, 'el': el}) + graph.dstdata.update({'er': er}) + # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. 
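+            # u_add_v below forms, for every edge (i, j), the unnormalized score e_ij = a_l Wh_i + a_r Wh_j per head; after the LeakyReLU, edge_softmax normalizes these scores over each destination node's incoming edges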
+ graph.apply_edges(fn.u_add_v('el', 'er', 'e')) + e = self.leaky_relu(graph.edata.pop('e')) + # compute softmax + graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) + # message passing + graph.update_all(fn.u_mul_e('ft', 'a', 'm'), + fn.sum('m', 'ft')) + rst = graph.dstdata['ft'] + # residual + if self.res_fc is not None: + resval = self.res_fc(h_dst).view(h_dst.shape[0], self._num_heads, self._out_feats) + rst = rst + resval + # bias + if self.bias is not None: + rst = rst + self.bias.view(1, self._num_heads, self._out_feats) + # activation + if self.activation: + rst = self.activation(rst) + + if get_attention: + return rst, graph.edata['a'] + else: + return rst + + def custom(self, graph, get_attention): + def custom_forward(*inputs): + feat0, feat1, dummy = inputs + return self._forward(graph, (feat0, feat1), get_attention=get_attention) + return custom_forward + + def forward(self, graph, feat, get_attention=False): + if self.use_checkpoint: + return checkpoint.checkpoint(self.custom(graph, get_attention), feat[0], feat[1], self.dummy_tensor).squeeze(1) + else: + return self._forward(graph, feat, get_attention=get_attention).squeeze(1) \ No newline at end of file diff --git a/SourceCodeTools/models/graph/ggnn.py b/SourceCodeTools/models/graph/ggnn.py index a0716fbf..f082f6ed 100644 --- a/SourceCodeTools/models/graph/ggnn.py +++ b/SourceCodeTools/models/graph/ggnn.py @@ -4,7 +4,7 @@ import dgl.function as fn from dgl.nn.pytorch.conv import GatedGraphConv -from SourceCodeTools.code.data.sourcetrail.Dataset import compact_property +from SourceCodeTools.code.data.dataset.Dataset import compact_property from SourceCodeTools.models.Embedder import Embedder diff --git a/SourceCodeTools/models/graph/rgcn_sampling.py b/SourceCodeTools/models/graph/rgcn_sampling.py index 42ed4536..ae71a788 100644 --- a/SourceCodeTools/models/graph/rgcn_sampling.py +++ b/SourceCodeTools/models/graph/rgcn_sampling.py @@ -7,6 +7,7 @@ import dgl.nn as dglnn # import tqdm from torch.utils import checkpoint +from tqdm import tqdm from SourceCodeTools.models.Embedder import Embedder @@ -38,9 +39,9 @@ def custom_forward(*inputs): def forward(self, graph, feat): if self.use_checkpoint: - return checkpoint.checkpoint(self.custom(graph), feat[0], feat[1], self.dummy_tensor) + return checkpoint.checkpoint(self.custom(graph), feat[0], feat[1], self.dummy_tensor) #.squeeze(1) else: - return super(CkptGATConv, self).forward(graph, feat) + return super(CkptGATConv, self).forward(graph, feat) #.squeeze(1) class RelGraphConvLayer(nn.Module): @@ -71,6 +72,7 @@ def __init__(self, in_feat, out_feat, rel_names, + ntype_names, num_bases, *, weight=True, @@ -82,18 +84,13 @@ def __init__(self, self.in_feat = in_feat self.out_feat = out_feat self.rel_names = rel_names + self.ntype_names = ntype_names self.num_bases = num_bases self.bias = bias self.activation = activation self.self_loop = self_loop self.use_gcn_checkpoint = use_gcn_checkpoint - # TODO - # think of possibility switching to GAT - # rel : dglnn.GATConv(in_feat, out_feat, num_heads=4) - # rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False, allow_zero_in_degree=True) - self.create_conv(in_feat, out_feat, rel_names) - self.use_weight = weight self.use_basis = num_bases < len(self.rel_names) and weight if self.use_weight: @@ -104,17 +101,33 @@ def __init__(self, # nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) nn.init.xavier_normal_(self.weight) + # TODO + # think of possibility switching to GAT + # rel : 
dglnn.GATConv(in_feat, out_feat, num_heads=4) + # rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False, allow_zero_in_degree=True) + self.create_conv(in_feat, out_feat, rel_names) + # bias if bias: - self.h_bias = nn.Parameter(th.Tensor(out_feat)) - nn.init.zeros_(self.h_bias) + self.bias_dict = nn.ParameterDict() + for ntype_name in self.ntype_names: + self.bias_dict[ntype_name] = nn.Parameter(torch.Tensor(1, out_feat)) + nn.init.normal_(self.bias_dict[ntype_name]) + # self.h_bias = nn.Parameter(th.Tensor(1, out_feat)) + # nn.init.normal_(self.h_bias) # weight for self loop if self.self_loop: + # if self.use_basis: + # self.loop_weight_basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.ntype_names)) + # else: + # self.loop_weight = nn.Parameter(th.Tensor(len(self.ntype_names), in_feat, out_feat)) + # # nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) + # nn.init.xavier_normal_(self.loop_weight) self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) - # nn.init.xavier_uniform_(self.loop_weight, - # gain=nn.init.calculate_gain('relu')) - nn.init.xavier_normal_(self.loop_weight) + nn.init.xavier_uniform_(self.loop_weight, + gain=nn.init.calculate_gain('tanh')) + # # nn.init.xavier_normal_(self.loop_weight) self.dropout = nn.Dropout(dropout) @@ -145,6 +158,10 @@ def forward(self, g, inputs): weight = self.basis() if self.use_basis else self.weight wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)} for i, w in enumerate(th.split(weight, 1, dim=0))} + # if self.self_loop: + # self_loop_weight = self.loop_weight_basis() if self.use_basis else self.loop_weight + # self_loop_wdict = {self.ntype_names[i]: w.squeeze(0) + # for i, w in enumerate(th.split(self_loop_weight, 1, dim=0))} else: wdict = {} @@ -160,9 +177,11 @@ def forward(self, g, inputs): def _apply(ntype, h): if self.self_loop: + # h = h + th.matmul(inputs_dst[ntype], self_loop_wdict[ntype]) h = h + th.matmul(inputs_dst[ntype], self.loop_weight) if self.bias: - h = h + self.h_bias + h = h + self.bias_dict[ntype] + # h = h + self.h_bias if self.activation: h = self.activation(h) return self.dropout(h) @@ -171,47 +190,47 @@ def _apply(ntype, h): # return {ntype: _apply(ntype, h) for ntype, h in hs.items()} return {ntype : _apply(ntype, h).mean(1) for ntype, h in hs.items()} -class RelGraphEmbed(nn.Module): - r"""Embedding layer for featureless heterograph.""" - def __init__(self, - g, - embed_size, - embed_name='embed', - activation=None, - dropout=0.0): - super(RelGraphEmbed, self).__init__() - self.g = g - self.embed_size = embed_size - self.embed_name = embed_name - # self.activation = activation - # self.dropout = nn.Dropout(dropout) - - # create weight embeddings for each node for each relation - self.embeds = nn.ParameterDict() - for ntype in g.ntypes: - embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size)) - # TODO - # watch for activation in init - # nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu')) - nn.init.xavier_normal_(embed) - self.embeds[ntype] = embed - - def forward(self, block=None): - """Forward computation - - Parameters - ---------- - block : DGLHeteroGraph, optional - If not specified, directly return the full graph with embeddings stored in - :attr:`embed_name`. Otherwise, extract and store the embeddings to the block - graph and return. - - Returns - ------- - DGLHeteroGraph - The block graph fed with embeddings. 
- """ - return self.embeds +# class RelGraphEmbed(nn.Module): +# r"""Embedding layer for featureless heterograph.""" +# def __init__(self, +# g, +# embed_size, +# embed_name='embed', +# activation=None, +# dropout=0.0): +# super(RelGraphEmbed, self).__init__() +# self.g = g +# self.embed_size = embed_size +# self.embed_name = embed_name +# # self.activation = activation +# # self.dropout = nn.Dropout(dropout) +# +# # create weight embeddings for each node for each relation +# self.embeds = nn.ParameterDict() +# for ntype in g.ntypes: +# embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size)) +# # TODO +# # watch for activation in init +# # nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu')) +# nn.init.xavier_normal_(embed) +# self.embeds[ntype] = embed +# +# def forward(self, block=None): +# """Forward computation +# +# Parameters +# ---------- +# block : DGLHeteroGraph, optional +# If not specified, directly return the full graph with embeddings stored in +# :attr:`embed_name`. Otherwise, extract and store the embeddings to the block +# graph and return. +# +# Returns +# ------- +# DGLHeteroGraph +# The block graph fed with embeddings. +# """ +# return self.embeds class RGCNSampling(nn.Module): def __init__(self, @@ -231,6 +250,8 @@ def __init__(self, self.rel_names = list(set(g.etypes)) self.rel_names.sort() + self.ntype_names = list(set(g.etypes)) + self.ntype_names.sort() if num_bases < 0 or num_bases > len(self.rel_names): self.num_bases = len(self.rel_names) else: @@ -244,14 +265,14 @@ def __init__(self, self.layer_norm = nn.ModuleList() # i2h self.layers.append(RelGraphConvLayer( - self.h_dim, self.h_dim, self.rel_names, + self.h_dim, self.h_dim, self.rel_names, self.ntype_names, self.num_bases, activation=self.activation, self_loop=self.use_self_loop, dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint)) self.layer_norm.append(nn.LayerNorm([self.h_dim])) # h2h for i in range(self.num_hidden_layers): self.layers.append(RelGraphConvLayer( - self.h_dim, self.h_dim, self.rel_names, + self.h_dim, self.h_dim, self.rel_names, self.ntype_names, self.num_bases, activation=self.activation, self_loop=self.use_self_loop, dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint)) # changed weight for GATConv self.layer_norm.append(nn.LayerNorm([self.h_dim])) @@ -260,7 +281,7 @@ def __init__(self, # weight=False # h2o self.layers.append(RelGraphConvLayer( - self.h_dim, self.out_dim, self.rel_names, + self.h_dim, self.out_dim, self.rel_names, self.ntype_names, self.num_bases, activation=None, self_loop=self.use_self_loop, weight=False, use_gcn_checkpoint=use_gcn_checkpoint)) # changed weight for GATConv self.layer_norm.append(nn.LayerNorm([self.out_dim])) @@ -314,9 +335,10 @@ def forward(self, h, blocks=None, # all_layers.append(h) # added this as an experimental feature for intermediate supervision # else: # minibatch training + h0 = h for layer, norm, block in zip(self.layers, self.layer_norm, blocks): # h = checkpoint.checkpoint(self.custom(layer), block, h) - h = layer(block, h) + h = layer(block, h, h0) h = self.normalize(h, norm) all_layers.append(h) # added this as an experimental feature for intermediate supervision @@ -332,6 +354,7 @@ def inference(self, batch_size, device, num_workers, x=None): For node classification, the model is trained to predict on only one node type's label. Therefore, only that type's final representation is meaningful. 
""" + h0 = x with th.set_grad_enabled(False): @@ -355,7 +378,7 @@ def inference(self, batch_size, device, num_workers, x=None): drop_last=False, num_workers=num_workers) - for input_nodes, output_nodes, blocks in dataloader:#tqdm.tqdm(dataloader): + for input_nodes, output_nodes, blocks in tqdm(dataloader, desc=f"Layer {l}"): block = blocks[0].to(device) if not isinstance(input_nodes, dict): @@ -363,8 +386,9 @@ def inference(self, batch_size, device, num_workers, x=None): input_nodes = {key: input_nodes} output_nodes = {key: output_nodes} + _h0 = {k: h0[k][input_nodes[k]].to(device) for k in input_nodes.keys()} h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()} - h = layer(block, h) + h = layer(block, h, _h0) h = self.normalize(h, norm) for k in h.keys(): diff --git a/SourceCodeTools/models/graph/rggan.py b/SourceCodeTools/models/graph/rggan.py index ec5fc160..54fb2ce8 100644 --- a/SourceCodeTools/models/graph/rggan.py +++ b/SourceCodeTools/models/graph/rggan.py @@ -1,5 +1,7 @@ +import torch from torch.utils import checkpoint +# from SourceCodeTools.models.graph.basis_gatconv import BasisGATConv from SourceCodeTools.models.graph.rgcn_sampling import RGCNSampling, RelGraphConvLayer, CkptGATConv import torch as th @@ -37,7 +39,7 @@ def do_stuff(self, query, key, value, dummy=None): def forward(self, list_inputs, dsttype): # pylint: disable=unused-argument if len(list_inputs) == 1: return list_inputs[0] - key = value = th.stack(list_inputs).squeeze(dim=1) + key = value = th.stack(list_inputs)#.squeeze(dim=1) query = self.query_emb(th.LongTensor([token_hasher(dsttype, self.num_query_buckets)]).to(self.att.in_proj_bias.device)).unsqueeze(0).repeat(1, key.shape[1], 1) # query = self.query_emb[token_hasher(dsttype, self.num_query_buckets)].unsqueeze(0).repeat(1, key.shape[1], 1) if self.use_checkpoint: @@ -45,14 +47,42 @@ def forward(self, list_inputs, dsttype): # pylint: disable=unused-argument else: att_out, att_w = self.do_stuff(query, key, value) # att_out, att_w = self.att(query, key, value) + # return att_out.mean(0)#.unsqueeze(1) return att_out.mean(0).unsqueeze(1) +from torch.nn import init +class NZBiasGraphConv(dglnn.GraphConv): + def __init__(self, *args, **kwargs): + super(NZBiasGraphConv, self).__init__(*args, **kwargs) + self.use_checkpoint = True + self.dummy_tensor = th.ones(1, dtype=th.float32, requires_grad=True) + + def reset_parameters(self): + if self.weight is not None: + init.xavier_uniform_(self.weight) + if self.bias is not None: + init.normal_(self.bias) + + def custom(self, graph): + def custom_forward(*inputs): + feat0, feat1, weight, _ = inputs + return super(NZBiasGraphConv, self).forward(graph, (feat0, feat1), weight=weight) + return custom_forward + + def forward(self, graph, feat, weight=None): + if self.use_checkpoint: + return checkpoint.checkpoint(self.custom(graph), feat[0], feat[1], weight, self.dummy_tensor) #.squeeze(1) + else: + return super(NZBiasGraphConv, self).forward(graph, feat, weight=weight) #.squeeze(1) + + class RGANLayer(RelGraphConvLayer): def __init__(self, in_feat, out_feat, rel_names, + ntype_names, num_bases, *, weight=True, @@ -62,17 +92,40 @@ def __init__(self, dropout=0.0, use_gcn_checkpoint=False, use_att_checkpoint=False): self.use_att_checkpoint = use_att_checkpoint super(RGANLayer, self).__init__( - in_feat, out_feat, rel_names, num_bases, + in_feat, out_feat, rel_names, ntype_names, num_bases, weight=weight, bias=bias, activation=activation, self_loop=self_loop, dropout=dropout, 
use_gcn_checkpoint=use_gcn_checkpoint ) def create_conv(self, in_feat, out_feat, rel_names): - self.attentive_aggregator = AttentiveAggregator(out_feat, use_checkpoint=self.use_att_checkpoint) + # self.attentive_aggregator = AttentiveAggregator(out_feat, use_checkpoint=self.use_att_checkpoint) + # + # num_heads = 1 + # basis_size = 10 + # basis = torch.nn.Parameter(torch.Tensor(2, basis_size, in_feat, out_feat * num_heads)) + # attn_basis = torch.nn.Parameter(torch.Tensor(2, basis_size, num_heads, out_feat)) + # basis_coef = nn.ParameterDict({rel: torch.nn.Parameter(torch.rand(basis_size,)) for rel in rel_names}) + # + # torch.nn.init.xavier_normal_(basis, gain=1.) + # torch.nn.init.xavier_normal_(attn_basis, gain=1.) + self.conv = dglnn.HeteroGraphConv({ - rel: CkptGATConv((in_feat, in_feat), out_feat, num_heads=1, use_checkpoint=self.use_gcn_checkpoint) + rel: NZBiasGraphConv( + in_feat, out_feat, norm='right', weight=False, bias=True, allow_zero_in_degree=True,# activation=self.activation + ) + # rel: BasisGATConv( + # (in_feat, in_feat), out_feat, num_heads=num_heads, + # basis=basis, + # attn_basis=attn_basis, + # basis_coef=basis_coef[rel], + # use_checkpoint=self.use_gcn_checkpoint + # ) for rel in rel_names - }, aggregate=self.attentive_aggregator) + }, aggregate="mean") #self.attentive_aggregator) + # self.conv = dglnn.HeteroGraphConv({ + # rel: CkptGATConv((in_feat, in_feat), out_feat, num_heads=num_heads, use_checkpoint=self.use_gcn_checkpoint) + # for rel in rel_names + # }, aggregate="mean") #self.attentive_aggregator) class RGAN(RGCNSampling): @@ -92,6 +145,8 @@ def __init__(self, self.rel_names = list(set(g.etypes)) self.rel_names.sort() + self.ntype_names = list(set(g.ntypes)) + self.ntype_names.sort() if num_bases < 0 or num_bases > len(self.rel_names): self.num_bases = len(self.rel_names) else: @@ -103,14 +158,14 @@ def __init__(self, self.layers = nn.ModuleList() # i2h self.layers.append(RGANLayer( - self.h_dim, self.h_dim, self.rel_names, + self.h_dim, self.h_dim, self.rel_names, self.ntype_names, self.num_bases, activation=self.activation, self_loop=self.use_self_loop, dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint, use_att_checkpoint=use_att_checkpoint)) # h2h for i in range(self.num_hidden_layers): self.layers.append(RGANLayer( - self.h_dim, self.h_dim, self.rel_names, + self.h_dim, self.h_dim, self.rel_names, self.ntype_names, self.num_bases, activation=self.activation, self_loop=self.use_self_loop, dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint, use_att_checkpoint=use_att_checkpoint)) # changed weight for GATConv @@ -119,7 +174,7 @@ def __init__(self, # weight=False # h2o self.layers.append(RGANLayer( - self.h_dim, self.out_dim, self.rel_names, + self.h_dim, self.out_dim, self.rel_names, self.ntype_names, self.num_bases, activation=None, self_loop=self.use_self_loop, weight=False, use_gcn_checkpoint=use_gcn_checkpoint, use_att_checkpoint=use_att_checkpoint)) # changed weight for GATConv @@ -147,6 +202,7 @@ def __init__(self, dim, use_checkpoint=False): self.dummy_tensor = th.ones(1, dtype=th.float32, requires_grad=True) def do_stuff(self, x, h, dummy_tensor=None): + # x = x.unsqueeze(1) r = self.act_r(self.gru_rx(x) + self.gru_rh(h)) z = self.act_z(self.gru_zx(x) + self.gru_zh(h)) n = self.act_n(self.gru_nx(x) + self.gru_nh(r * h)) @@ -165,6 +221,7 @@ def __init__(self, in_feat, out_feat, rel_names, + ntype_names, num_bases, *, weight=True, @@ -173,14 +230,15 @@ def __init__(self, self_loop=False, dropout=0.0, 
use_gcn_checkpoint=False, use_att_checkpoint=False, use_gru_checkpoint=False): super(RGGANLayer, self).__init__( - in_feat, out_feat, rel_names, num_bases, weight=weight, bias=bias, activation=activation, + in_feat, out_feat, rel_names, ntype_names, num_bases, weight=weight, bias=bias, activation=activation, self_loop=self_loop, dropout=dropout, use_gcn_checkpoint=use_gcn_checkpoint, use_att_checkpoint=use_att_checkpoint ) + # self.mix_weights = nn.Parameter(torch.randn(3).reshape((3,1,1))) - self.gru = OneStepGRU(out_feat, use_checkpoint=use_gru_checkpoint) + # self.gru = OneStepGRU(out_feat, use_checkpoint=use_gru_checkpoint) - def forward(self, g, inputs): + def forward(self, g, inputs, h0): """Forward computation Parameters @@ -200,6 +258,10 @@ def forward(self, g, inputs): weight = self.basis() if self.use_basis else self.weight wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)} for i, w in enumerate(th.split(weight, 1, dim=0))} + # if self.self_loop: + # self_loop_weight = self.loop_weight_basis() if self.use_basis else self.loop_weight + # self_loop_wdict = {self.ntype_names[i]: w.squeeze(0) + # for i, w in enumerate(th.split(self_loop_weight, 1, dim=0))} else: wdict = {} @@ -215,9 +277,15 @@ def forward(self, g, inputs): def _apply(ntype, h): if self.self_loop: + # h = h + th.matmul(inputs_dst[ntype], self_loop_wdict[ntype]) h = h + th.matmul(inputs_dst[ntype], self.loop_weight) + # mix = nn.functional.softmax(self.mix_weights, dim=0) + # h = h - inputs_dst[ntype] + th.matmul(inputs_dst[ntype], self.loop_weight) #+ h0[ntype][:h.size(0), :] + # h = h + th.matmul(inputs_dst[ntype], self.loop_weight) + h0[ntype][:h.size(0), :] + # h = (torch.stack([h, th.matmul(inputs_dst[ntype], self.loop_weight), h0[ntype][:h.size(0), :]], dim=0) * mix).sum(0) if self.bias: - h = h + self.h_bias + h = h + self.bias_dict[ntype] + # h = h + self.h_bias if self.activation: h = self.activation(h) return self.dropout(h) @@ -226,10 +294,10 @@ def _apply(ntype, h): # TODO # think of possibility switching to GAT - # return {ntype: _apply(ntype, h) for ntype, h in hs.items()} - h_gru_input = {ntype : _apply(ntype, h) for ntype, h in hs.items()} - - return {dsttype: self.gru(h_dst, inputs_dst[dsttype].unsqueeze(1)).squeeze(dim=1) for dsttype, h_dst in h_gru_input.items()} + return {ntype: _apply(ntype, h) for ntype, h in hs.items()} + # h_gru_input = {ntype : _apply(ntype, h) for ntype, h in hs.items()} + # + # return {dsttype: self.gru(h_dst, inputs_dst[dsttype].unsqueeze(1)).squeeze(dim=1) for dsttype, h_dst in h_gru_input.items()} class RGGAN(RGAN): """A gated recurrent unit (GRU) cell @@ -246,9 +314,9 @@ class RGGAN(RGAN): where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.""" def __init__(self, g, - h_dim, num_classes, + h_dim, node_emb_size, num_bases, - num_steps=1, + n_layers=1, dropout=0, use_self_loop=False, activation=F.relu, @@ -256,13 +324,15 @@ def __init__(self, super(RGCNSampling, self).__init__() self.g = g self.h_dim = h_dim - self.out_dim = num_classes + self.out_dim = node_emb_size self.activation = activation - assert h_dim == num_classes, f"Parameter h_dim and num_classes should be equal in {self.__class__.__name__}" + assert h_dim == node_emb_size, f"Parameter h_dim and num_classes should be equal in {self.__class__.__name__}" self.rel_names = list(set(g.etypes)) + self.ntype_names = list(set(g.ntypes)) self.rel_names.sort() + self.ntype_names.sort() if num_bases < 0 or num_bases > len(self.rel_names): self.num_bases = len(self.rel_names) else: @@ 
-273,16 +343,16 @@ def __init__(self, # i2h self.layer = RGGANLayer( - self.h_dim, self.h_dim, self.rel_names, + self.h_dim, self.h_dim, self.rel_names, self.ntype_names, self.num_bases, activation=self.activation, self_loop=self.use_self_loop, - dropout=self.dropout, weight=False, use_gcn_checkpoint=use_gcn_checkpoint, + dropout=self.dropout, weight=True, use_gcn_checkpoint=use_gcn_checkpoint, # : ) use_att_checkpoint=use_att_checkpoint, use_gru_checkpoint=use_gru_checkpoint ) # TODO # think of possibility switching to GAT # weight=False - self.emb_size = num_classes - self.num_layers = num_steps - self.layers = [self.layer] * num_steps + self.emb_size = node_emb_size + self.num_layers = n_layers + self.layers = [self.layer] * n_layers self.layer_norm = nn.ModuleList([nn.LayerNorm([self.h_dim]) for _ in range(self.num_layers)]) diff --git a/SourceCodeTools/models/graph/train/Scorer.py b/SourceCodeTools/models/graph/train/Scorer.py index 11fb736f..eba46e3a 100644 --- a/SourceCodeTools/models/graph/train/Scorer.py +++ b/SourceCodeTools/models/graph/train/Scorer.py @@ -1,47 +1,161 @@ import random -from collections import Iterable +import time +from collections import Iterable, defaultdict +from typing import Dict, List import torch import numpy as np -from sklearn.metrics import ndcg_score +from sklearn.metrics import ndcg_score, top_k_accuracy_score +from sklearn.neighbors import NearestNeighbors from sklearn.neighbors._ball_tree import BallTree from sklearn.preprocessing import normalize class FaissIndex: - def __init__(self, X, *args, **kwargs): + def __init__(self, X, method="inner_prod", *args, **kwargs): import faiss - self.index = faiss.IndexFlatL2(X.shape[1]) + self.method = method + if method == "inner_prod": + self.index = faiss.IndexFlatIP(X.shape[1]) + elif method == "l2": + self.index = faiss.IndexFlatL2(X.shape[1]) + else: + raise NotImplementedError() self.index.add(X.astype(np.float32)) - def query(self, X, k): + def query_inner_prod(self, X, k): + X = normalize(X, axis=1) return self.index.search(X.astype(np.float32), k=k) + def query_l2(self, X, k): + return self.index.search(X.astype(np.float32), k=k) + + def query(self, X, k): + if self.method == "inner_prod": + return self.query_inner_prod(X, k) + elif self.method == "l2": + return self.query_l2(X, k) + else: + raise NotImplementedError() + + +class Brute: + def __init__(self, X, method="inner_prod", device="cpu", *args, **kwargs): + self.vectors = X + self.method = method + self.device = device + + def query_inner_prod(self, X, k): + X = torch.Tensor(normalize(X, axis=1)).to(self.device) + vectors = torch.Tensor(self.vectors).to(self.device) + score = (vectors @ X.T).reshape(-1,).to("cpu").numpy() + ind = np.flip(np.argsort(score))[:k] + return score[ind][:k], ind + + def query_l2(self, X, k): + vectors = torch.Tensor(self.vectors).to(self.device) + X = torch.Tensor(X).to(self.device) + score = torch.norm(vectors - X, dim=-1).to("cpu").numpy() + # score = np.linalg.norm(vectors - X).reshape(-1,) + ind = np.argsort(score)[:k] + return score[ind][:k], ind + + def query(self, X, k): + assert X.shape[0] == 1 + if self.method == "inner_prod": + dist, ind = self.query_inner_prod(X, k) + elif self.method == "l2": + dist, ind = self.query_l2(X, k) + else: + raise NotImplementedError() + + return dist, ind#.reshape((-1,1)) + class Scorer: """ - Implements sampler for hard triplet loss. This sampler is useful when the loss is based on the neighbourhood + Implements sampler for triplet loss. 
This sampler is useful when the loss is based on the neighbourhood similarity. It becomes less useful when the decision is made by neural network because it does not need to mode points to learn how to make correct decisions. """ - def __init__(self, num_embs, emb_size, src2dst, neighbours_to_sample=5, index_backend="sklearn"): + def __init__( + self, num_embs, emb_size, src2dst: Dict[int, List[int]], neighbours_to_sample=5, index_backend="brute", + method = "inner_prod", device="cpu", ns_groups=None + ): + """ + Creates an embedding table, the embeddings in this table are updated once during an epoch. Embeddings from this + table are used for nearest neighbour queries during negative sampling. We avoid keeping track of all possible + embeddings by knowing that only part of embeddings are eligible as DST. + :param num_embs: number of unique DST + :param emb_size: embedding dimensionality + :param src2dst: Mapping from SRC to all DST, need this to find the hardest negative example for all DST at once + :param neighbours_to_sample: default number of neighbours + :param index_backend: Choose between sklearn and faiss + """ self.scorer_num_emb = num_embs self.scorer_emb_size = emb_size - self.scorer_src2dst = src2dst + self.scorer_src2dst = src2dst # mapping from src to all possible dst self.scorer_index_backend = index_backend + self.scorer_method = method + self.scorer_device = device - self.scorer_all_emb = np.ones((num_embs, emb_size)) + self.scorer_all_emb = normalize(np.ones((num_embs, emb_size)), axis=1) # unique dst embedding table self.scorer_all_keys = self.get_cand_to_score_against(None) self.scorer_key_order = dict(zip(self.scorer_all_keys, range(len(self.scorer_all_keys)))) self.scorer_index = None self.neighbours_to_sample = min(neighbours_to_sample, self.scorer_num_emb) + self.prepare_ns_groups(ns_groups) + + def prepare_ns_groups(self, ns_groups): + if ns_groups is None: + return + + self.scorer_node2ns_group = {} + self.scorer_ns_group2nodes = defaultdict(list) + + unique_dst = set() + for dsts in self.scorer_src2dst.values(): + for dst in dsts: + if isinstance(dst, tuple): + unique_dst.add(dst[0]) + else: + unique_dst.add(dst) + + for id, mentioned_in in ns_groups.values: + id_ = self.id2nodeid[id] + mentioned_in_ = self.id2nodeid[mentioned_in] + self.scorer_node2ns_group[id_] = mentioned_in_ + if id_ in unique_dst: + self.scorer_ns_group2nodes[mentioned_in_].append(id_) + + def sample_negative_from_groups(self, key_groups, k): + possible_targets = [] + for key_group in key_groups: + any_key = key_group[0] + possible_negative = self.scorer_ns_group2nodes[self.scorer_node2ns_group[any_key]] + possible_negative_ = list(set(possible_negative) - set(key_group)) + if len(possible_negative_) < k: + # backup strategy + possible_negative_.extend( + random.choices(list(set(self.scorer_all_keys) - set(key_group)), k=k - len(possible_negative_)) + ) + possible_targets.append(possible_negative_) + return possible_targets - def prepare_index(self): + + def prepare_index(self, override_strategy=None): + if self.scorer_method == "nn": + self.scorer_index = None + return if self.scorer_index_backend == "sklearn": - self.scorer_index = BallTree(self.scorer_all_emb, leaf_size=1) + self.scorer_index = NearestNeighbors() + self.scorer_index.fit(self.scorer_all_emb) + # self.scorer_index = BallTree(self.scorer_all_emb, leaf_size=1) # self.scorer_index = BallTree(normalize(self.scorer_all_emb, axis=1), leaf_size=1) elif self.scorer_index_backend == "faiss": - self.scorer_index = 
FaissIndex(self.scorer_all_emb) + self.scorer_index = FaissIndex(self.scorer_all_emb, method=self.scorer_method) + elif self.scorer_index_backend == "brute": + self.scorer_index = Brute(self.scorer_all_emb, method=self.scorer_method, device=self.scorer_device) else: raise ValueError(f"Unsupported backend: {self.scorer_index_backend}. Supported backends are: sklearn|faiss") @@ -51,8 +165,18 @@ def sample_closest_negative(self, ids, k=None): assert ids is not None seed_pool = [] - [seed_pool.append(self.scorer_src2dst[id]) for id in ids] - nested_negative = self.get_closest_to_keys(seed_pool, k=k+1) + for id in ids: + pool = self.scorer_src2dst[id] + if len(pool) > 0 and isinstance(pool[0], tuple): + pool = [x[0] for x in pool] + seed_pool.append(pool) + if id in self.scorer_key_order: + seed_pool[-1] = seed_pool[-1] + [id] # make sure that original list is not changed + # [seed_pool.append(self.scorer_src2dst[id]) for id in ids] + if hasattr(self, "scorer_ns_group2nodes"): + nested_negative = self.sample_negative_from_groups(seed_pool, k=k+1) + else: + nested_negative = self.get_closest_to_keys(seed_pool, k=k+1) negative = [] for neg in nested_negative: @@ -68,10 +192,11 @@ def get_closest_to_keys(self, key_groups, k=None): self.scorer_all_emb[self.scorer_key_order[key]].reshape(1, -1), k=k ) closest_keys.extend(self.scorer_all_keys[c] for c in closest.ravel()) + # ensure that negative samples do not come from positive edges closest_keys_ = list(set(closest_keys) - set(key_group)) if len(closest_keys_) == 0: # backup strategy - closest_keys_ = random.choices(list(set(self.scorer_all_keys) - set(key_group)), k=k-1) + closest_keys_ = random.choices(list(set(self.scorer_all_keys) - set(key_group)), k=k) possible_targets.append(closest_keys_) # possible_targets.extend(self.scorer_all_keys[c] for c in closest.ravel()) return possible_targets @@ -80,75 +205,263 @@ def get_closest_to_keys(self, key_groups, k=None): def set_embed(self, ids, embs): ids = np.array(list(map(self.scorer_key_order.get, ids.tolist()))) - self.scorer_all_emb[ids, :] = embs# normalize(embs, axis=1) + self.scorer_all_emb[ids, :] = normalize(embs, axis=1) if self.scorer_method == "inner_prod" else embs # for ind, id in enumerate(ids): # self.all_embs[self.key_order[id], :] = embs[ind, :] def score_candidates_cosine(self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=None): - score_matr = (to_score_embs @ embs_to_score_against.t()) / \ - to_score_embs.norm(p=2, dim=1, keepdim=True) / \ - embs_to_score_against.norm(p=2, dim=1, keepdim=True).t() - y_pred = score_matr.tolist() + to_score_embs = to_score_embs / to_score_embs.norm(p=2, dim=1, keepdim=True) + embs_to_score_against = embs_to_score_against / embs_to_score_against.norm(p=2, dim=1, keepdim=True) + + score_matr = (to_score_embs @ embs_to_score_against.t()) + score_matr = (score_matr + 1.) / 2. + # score_matr = score_matr - self.margin + # score_matr[score_matr < 0.] = 0. 
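+        # the returned scores are cosine similarities rescaled from [-1, 1] to [0, 1] (e.g. -1.0, 0.0 and 1.0 map to 0.0, 0.5 and 1.0), so larger values mean a candidate is ranked as more relevant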
+ y_pred = score_matr.cpu().tolist() return y_pred - def score_candidates_lp(self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, link_predictor, at=None): + def set_margin(self, margin): + self.margin = margin + + def score_candidates_l2(self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=None): y_pred = [] for i in range(len(to_score_ids)): - input_embs = to_score_embs[i, :].repeat((embs_to_score_against.shape[0], 1)) - # predictor_input = torch.cat([input_embs, all_emb], dim=1) - y_pred.append(torch.nn.functional.softmax(link_predictor(input_embs, embs_to_score_against), dim=1)[:, 1].tolist()) # 0 - negative, 1 - positive + input_embs = to_score_embs[i, :].reshape(1, -1) + score_matr = torch.norm(embs_to_score_against - input_embs, dim=-1) + score_matr = 1. / (1. + score_matr) + # score_matr = score_matr + self.margin + # score_matr[score_matr < 0.] = 0 + y_pred.append(score_matr.cpu().tolist()) + + # embs_to_score_against = embs_to_score_against.unsqueeze(0) + # to_score_embs = to_score_embs.unsqueeze(1) + # + # score_matr = -torch.norm(embs_to_score_against - to_score_embs, dim=-1) + # score_matr = score_matr + self.margin + # score_matr[score_matr < 0.] = 0 + # y_pred = score_matr.cpu().tolist() return y_pred + def score_candidates_lp( + self, to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, link_predictor, at=None, + with_types=None + ): + + if with_types is None: + y_pred = [] + for i in range(len(to_score_ids)): + input_embs = to_score_embs[i, :].repeat((embs_to_score_against.shape[0], 1)) + # predictor_input = torch.cat([input_embs, all_emb], dim=1) + y_pred.append( + torch.nn.functional.softmax(link_predictor(input_embs, embs_to_score_against), dim=1)[:, 1].tolist() + ) # 0 - negative, 1 - positive + + return y_pred + + else: + y_pred = [] + for i in range(len(to_score_ids)): + y_pred.append(dict()) + for type in with_types[i]: + input_embs = to_score_embs[i, :].unsqueeze(0) + # predictor_input = torch.cat([input_embs, all_emb], dim=1) + + labels = torch.LongTensor([type]).to(link_predictor.proj_matr.weight.device) + weights = link_predictor.proj_matr(labels).reshape((-1, link_predictor.rel_dim, link_predictor.input_dim)) + rels = link_predictor.rel_emb(labels) + m_a = (weights * input_embs.unsqueeze(1)).sum(-1) + m_s = (weights * embs_to_score_against.unsqueeze(1)).sum(-1) + + transl = m_a + rels + sim = torch.norm(transl - m_s, dim=-1) + sim = 1./ (1. + sim) + y_pred[-1][type] = sim.cpu().tolist() + + return y_pred + + def get_gt_candidates(self, ids): candidates = [set(list(self.scorer_src2dst[id])) for id in ids] return candidates def get_cand_to_score_against(self, ids): + """ + Generate sorted list of all possible DST. 
These will be used as possible targets during NDCG calculation + :param ids: + :return: + """ all_keys = set() - [all_keys.update(self.scorer_src2dst[key]) for key in self.scorer_src2dst] + for key in self.scorer_src2dst: + cand = self.scorer_src2dst[key] + for c in cand: + if isinstance(c, tuple): # happens with graphlinkclassifier objectives + all_keys.add(c[0]) + else: + all_keys.add(c) + + # [all_keys.update(self.scorer_src2dst[key]) for key in self.scorer_src2dst] return sorted(list(all_keys)) # list(self.elem2id[a] for a in all_keys) def get_embeddings_for_scoring(self, device, **kwargs): + """ + Get all embeddings as a tensor + :param device: + :param kwargs: + :return: + """ return torch.Tensor(self.scorer_all_emb).to(device) def get_keys_for_scoring(self): return self.scorer_all_keys + def hits_at_k(self, y_true, y_pred, k): + correct = y_true + predicted = y_pred + result = [] + for y_true, y_pred in zip(correct, predicted): + ind_true = set([ind for ind, y_t in enumerate(y_true) if y_t == 1.]) + ind_pred = set(list(np.flip(np.argsort(y_pred))[:k])) + result.append(len(ind_pred.intersection(ind_true)) / min(len(ind_true), k)) + + return sum(result) / len(result) + + def mean_rank(self, y_true, y_pred): + true_ranks = [] + reciprocal_ranks = [] + + correct = y_true + predicted = y_pred + for y_true, y_pred in zip(correct, predicted): + ranks = sorted(zip(y_true, y_pred), key=lambda x: x[1], reverse=True) + for ind, (true, pred) in enumerate(ranks): + if true > 0.: + true_ranks.append(ind + 1) + reciprocal_ranks.append(1 / (ind + 1)) + break # should consider only first result https://en.wikipedia.org/wiki/Mean_reciprocal_rank + + return sum(true_ranks) / len(true_ranks), sum(reciprocal_ranks) / len(reciprocal_ranks) + + def mean_average_precision(self, y_true, y_pred): + correct = y_true + predicted = y_pred + map = 0. + for y_true, y_pred in zip(correct, predicted): + ranks = sorted(zip(y_true, y_pred), key=lambda x: x[1], reverse=True) + found_relevant = 0 + avep = 0. + for ind, (true, pred) in enumerate(ranks): + if true > 0.: + found_relevant += 1 + avep += found_relevant / (ind + 1) # precision@k + avep /= found_relevant + + map += avep + + map /= len(correct) + + return map + + def get_y_true_from_candidate_list(self, candidates, keys_to_score_against): + y_true = [[1. if key in cand else 0. for key in keys_to_score_against] for cand in candidates] + return y_true + + def get_y_true_from_candidates(self, candidates, keys_to_score_against): + + has_types = isinstance(list(candidates[0])[0], tuple) + + if not has_types: + return self.get_y_true_from_candidate_list(candidates, keys_to_score_against) + + candidate_dicts = [] + for cand in candidates: + candidate_dicts.append(dict()) + for ent, type in cand: + if type not in candidate_dicts[-1]: + candidate_dicts[-1][type] = [] + candidate_dicts[-1][type].append(ent) + + y_true = [] + for cand in candidate_dicts: + y_true.append(dict()) + for type, ents in cand.items(): + y_true[-1][type] = [1. if key in ents else 0. for key in keys_to_score_against] + assert sum(y_true[-1][type]) > 0. 
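+ # sanity check: every relation type of the current candidate must have at least one positive key among the scoring targets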
+ + return y_true + + def flatten_pred(self, y): + + flattened = [] + for x in y: + for key, scores in x.items(): + flattened.append(scores) + + return flattened + def score_candidates(self, to_score_ids, to_score_embs, link_predictor=None, at=None, type=None, device="cpu"): if at is None: at = [1, 3, 5, 10] + start = time.time() + to_score_ids = to_score_ids.tolist() - candidates = self.get_gt_candidates(to_score_ids) + candidates = self.get_gt_candidates(to_score_ids) # positive candidates # keys_to_score_against = self.get_cand_to_score_against(to_score_ids) keys_to_score_against = self.get_keys_for_scoring() - y_true = [[1. if key in cand else 0. for key in keys_to_score_against] for cand in candidates] + y_true = self.get_y_true_from_candidates(candidates, keys_to_score_against) embs_to_score_against = self.get_embeddings_for_scoring(device=to_score_embs.device) if type == "nn": + has_types = isinstance(y_true[0], dict) + y_pred = self.score_candidates_lp( to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, - link_predictor, at=at + link_predictor, at=at, with_types=y_true if has_types else None ) elif type == "inner_prod": y_pred = self.score_candidates_cosine( to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=at ) + elif type == "l2": + y_pred = self.score_candidates_l2( + to_score_ids, to_score_embs, keys_to_score_against, embs_to_score_against, at=at + ) else: raise ValueError(f"`type` can be either `nn` or `inner_prod` but `{type}` given") + has_types = isinstance(y_pred[0], dict) + + if has_types: + y_true = self.flatten_pred(y_true) + y_pred = self.flatten_pred(y_pred) + + scores = {} + # y_true_onehot = np.array(y_true) + # labels=list(range(y_true_onehot.shape[1])) + if isinstance(at, Iterable): - scores = {f"ndcg@{k}": ndcg_score(y_true, y_pred, k=k) for k in at} + scores.update({f"hits@{k}": self.hits_at_k(y_true, y_pred, k=k) for k in at}) + scores.update({f"ndcg@{k}": ndcg_score(y_true, y_pred, k=k) for k in at}) + # scores = {f"ndcg@{k}": ndcg_score(y_true, y_pred, k=k) for k in at} else: - scores = {f"ndcg@{at}": ndcg_score(y_true, y_pred, k=at)} - return scores \ No newline at end of file + scores.update({f"hits@{at}": self.hits_at_k(y_true, y_pred, k=at)}) + scores.update({f"ndcg@{at}": ndcg_score(y_true, y_pred, k=at)}) + # scores = {f"ndcg@{at}": ndcg_score(y_true, y_pred, k=at)} + + mr, mrr = self.mean_rank(y_true, y_pred) + map = self.mean_average_precision(y_true, y_pred) + scores["mr"] = mr + scores["mrr"] = mrr + scores["map"] = map + scores["scoring_time"] = time.time() - start + return scores diff --git a/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py b/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py index f17197cf..d1f361b7 100644 --- a/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py +++ b/SourceCodeTools/models/graph/train/deprecated/sampling_multitask.py @@ -11,6 +11,7 @@ from os.path import join import logging +from SourceCodeTools.mltools.torch import compute_accuracy from SourceCodeTools.models.Embedder import Embedder from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedderWithBpeSubwords from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase @@ -19,10 +20,6 @@ from SourceCodeTools.models.graph.NodeEmbedder import NodeEmbedder -def _compute_accuracy(pred_, true_): - return torch.sum(pred_ == true_).item() / len(true_) - - class SamplingMultitaskTrainer: def __init__(self, @@ -382,7 +379,7 @@ def 
_evaluate_embedder(self, ee, lp, loader, neg_sampling_factor=1): logp = nn.functional.log_softmax(logits, 1) loss = nn.functional.cross_entropy(logp, labels) - acc = _compute_accuracy(logp.argmax(dim=1), labels) + acc = compute_accuracy(logp.argmax(dim=1), labels) total_loss += loss.item() total_acc += acc @@ -405,7 +402,7 @@ def _evaluate_nodes(self, ee, lp, create_api_call_loader, loader, logp = nn.functional.log_softmax(logits, 1) loss = nn.functional.cross_entropy(logp, labels) - acc = _compute_accuracy(logp.argmax(dim=1), labels) + acc = compute_accuracy(logp.argmax(dim=1), labels) total_loss += loss.item() total_acc += acc @@ -560,9 +557,9 @@ def train_all(self): input_nodes_api_call, seeds_api_call, blocks_api_call ) - train_acc_node_name = _compute_accuracy(logits_node_name.argmax(dim=1), labels_node_name) - train_acc_var_use = _compute_accuracy(logits_var_use.argmax(dim=1), labels_var_use) - train_acc_api_call = _compute_accuracy(logits_api_call.argmax(dim=1), labels_api_call) + train_acc_node_name = compute_accuracy(logits_node_name.argmax(dim=1), labels_node_name) + train_acc_var_use = compute_accuracy(logits_var_use.argmax(dim=1), labels_var_use) + train_acc_api_call = compute_accuracy(logits_api_call.argmax(dim=1), labels_api_call) train_logits = torch.cat([logits_node_name, logits_var_use, logits_api_call], 0) train_labels = torch.cat([labels_node_name, labels_var_use, labels_api_call], 0) diff --git a/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py b/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py index 93b10249..2364abbe 100644 --- a/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py +++ b/SourceCodeTools/models/graph/train/objectives/AbstractObjective.py @@ -1,52 +1,96 @@ import logging from abc import abstractmethod -from collections import OrderedDict -from itertools import chain +from collections import defaultdict import dgl import torch -from torch.nn import CosineEmbeddingLoss +from tqdm import tqdm -from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker -from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedderWithBpeSubwords, NameEmbedderWithGroups, \ - GraphLinkSampler +from SourceCodeTools.code.data.dataset.SubwordMasker import SubwordMasker +from SourceCodeTools.mltools.torch import compute_accuracy +from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedderWithBpeSubwords, GraphLinkSampler from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase -from SourceCodeTools.models.graph.LinkPredictor import LinkPredictor, CosineLinkPredictor, BilinearLinkPedictor +from SourceCodeTools.models.graph.LinkPredictor import CosineLinkPredictor, BilinearLinkPedictor, L2LinkPredictor import torch.nn as nn -def _compute_accuracy(pred_: torch.Tensor, true_: torch.Tensor): - return torch.sum(pred_ == true_).item() / len(true_) +class ZeroEdges(Exception): + def __init__(self, *args): + super(ZeroEdges, self).__init__(*args) + + +class EarlyStoppingTracker: + def __init__(self, early_stopping_tolerance): + self.early_stopping_tolerance = early_stopping_tolerance + self.early_stopping_counter = 0 + self.early_stopping_value = 0. 
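+ # early_stopping_value holds the best metric seen so far; the counter tracks consecutive evaluations without improvement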
+ self.early_stopping_trigger = False + + def should_stop(self, metric): + if metric <= self.early_stopping_value: + self.early_stopping_counter += 1 + if self.early_stopping_counter >= self.early_stopping_tolerance: + return True + else: + self.early_stopping_counter = 0 + self.early_stopping_value = metric + return False + + def reset(self): + self.early_stopping_counter = 0 + self.early_stopping_value = 0. + + +def sum_scores(s): + n = len(s) + if n == 0: + n += 1 + return sum(s) / n class AbstractObjective(nn.Module): + # # set in the init + # name = None + # graph_model = None + # sampling_neighbourhood_size = None + # batch_size = None + # target_emb_size = None + # node_embedder = None + # device = None + # masker = None + # link_predictor_type = None + # measure_scores = None + # dilate_scores = None + # early_stopping_tracker = None + # early_stopping_trigger = None + # + # # set elsewhere + # target_embedder = None + # link_predictor = None + # positive_label = None + # negative_label = None + # label_dtype = None + # + # train_loader = None + # test_loader = None + # val_loader = None + # num_train_batches = None + # num_test_batches = None + # num_val_batches = None + # + # ntypes = None + def __init__( - # self, name, objective_type, graph_model, node_embedder, nodes, data_loading_func, device, self, name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, - tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker=None, - measure_ndcg=False, dilate_ndcg=1 + tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, + measure_scores=False, dilate_scores=1, early_stopping=False, early_stopping_tolerance=20, nn_index="brute", + ns_groups=None ): - """ - :param name: name for reference - :param objective_type: one of: graph_link_prediction|graph_link_classification|subword_ranker|node_classification - :param graph_model: - :param nodes: - :param data_loading_func: - :param device: - :param sampling_neighbourhood_size: - :param tokenizer_path: - :param target_emb_size: - :param link_predictor_type: - """ super(AbstractObjective, self).__init__() - # if objective_type not in {"graph_link_prediction", "graph_link_classification", "subword_ranker", "classification"}: - # raise NotImplementedError() - self.name = name - # self.type = objective_type self.graph_model = graph_model self.sampling_neighbourhood_size = sampling_neighbourhood_size self.batch_size = batch_size @@ -55,8 +99,12 @@ def __init__( self.device = device self.masker = masker self.link_predictor_type = link_predictor_type - self.measure_ndcg = measure_ndcg - self.dilate_ndcg = dilate_ndcg + self.measure_scores = measure_scores + self.dilate_scores = dilate_scores + self.nn_index = nn_index + self.early_stopping_tracker = EarlyStoppingTracker(early_stopping_tolerance) if early_stopping else None + self.early_stopping_trigger = False + self.ns_groups = ns_groups self.verify_parameters() @@ -64,10 +112,12 @@ def __init__( self.create_link_predictor() self.create_loaders() + self.target_embedding_fn = self.get_targets_from_embedder + self.negative_factor = 1 + self.update_embeddings_for_queries = True + @abstractmethod def verify_parameters(self): - # if self.link_predictor_type == "inner_prod": # TODO incorrect - # assert self.target_emb_size == self.graph_model.emb_size, "Graph embedding and target embedder dimensionality should match for `inner_prod` type of link predictor." 
pass def create_base_element_sampler(self, data_loading_func, nodes): @@ -78,7 +128,8 @@ def create_base_element_sampler(self, data_loading_func, nodes): def create_graph_link_sampler(self, data_loading_func, nodes): self.target_embedder = GraphLinkSampler( elements=data_loading_func(), nodes=nodes, compact_dst=False, dst_to_global=True, - emb_size=self.target_emb_size + emb_size=self.target_emb_size, device=self.device, method=self.link_predictor_type, nn_index=self.nn_index, + ns_groups=self.ns_groups ) def create_subword_embedder(self, data_loading_func, nodes, tokenizer_path): @@ -89,64 +140,88 @@ def create_subword_embedder(self, data_loading_func, nodes, tokenizer_path): @abstractmethod def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): - # # create target embedder - # if self.type == "graph_link_prediction" or self.type == "graph_link_classification": - # self.create_base_element_sampler(data_loading_func, nodes) - # elif self.type == "subword_ranker": - # self.create_subword_embedder(data_loading_func, nodes, tokenizer_path) - # elif self.type == "node_classification": - # self.target_embedder = None + # self.create_base_element_sampler(data_loading_func, nodes) + # self.create_graph_link_sampler(data_loading_func, nodes) + # self.create_subword_embedder(data_loading_func, nodes, tokenizer_path) raise NotImplementedError() def create_nn_link_predictor(self): - # self.link_predictor = LinkPredictor(self.target_emb_size + self.graph_model.emb_size).to(self.device) self.link_predictor = BilinearLinkPedictor(self.target_emb_size, self.graph_model.emb_size, 2).to(self.device) self.positive_label = 1 self.negative_label = 0 self.label_dtype = torch.long def create_inner_prod_link_predictor(self): - self.link_predictor = CosineLinkPredictor().to(self.device) - self.cosine_loss = CosineEmbeddingLoss(margin=0.4) + self.margin = -0.2 + self.target_embedder.set_margin(self.margin) + self.link_predictor = CosineLinkPredictor(margin=self.margin).to(self.device) + self.hinge_loss = nn.HingeEmbeddingLoss(margin=1. - self.margin) + + def cosine_loss(x1, x2, label): + sim = nn.CosineSimilarity() + dist = 1. - sim(x1, x2) + return self.hinge_loss(dist, label) + + # self.cosine_loss = CosineEmbeddingLoss(margin=self.margin) + self.cosine_loss = cosine_loss + self.positive_label = 1. + self.negative_label = -1. + self.label_dtype = torch.float32 + + def create_l2_link_predictor(self): + self.margin = 2.0 + self.target_embedder.set_margin(self.margin) + self.link_predictor = L2LinkPredictor().to(self.device) + # self.hinge_loss = nn.HingeEmbeddingLoss(margin=self.margin) + self.triplet_loss = nn.TripletMarginLoss(margin=self.margin) + + def l2_loss(x1, x2, label): + half = x1.shape[0] // 2 + pos = x2[:half, :] + neg = x2[half:, :] + + return self.triplet_loss(x1[:half, :], pos, neg) + # dist = torch.norm(x1 - x2, dim=-1) + # return self.hinge_loss(dist, label) + + self.l2_loss = l2_loss self.positive_label = 1. self.negative_label = -1. 
self.label_dtype = torch.float32 - @abstractmethod def create_link_predictor(self): - # # create link predictors - # if self.type in {"graph_link_prediction", "graph_link_classification", "subword_ranker"}: - # if self.link_predictor_type == "nn": - # self.create_nn_link_predictor() - # elif self.link_predictor_type == "inner_prod": - # self.create_inner_prod_link_predictor() - # else: - # raise NotImplementedError() - # else: - # # for node classifier - # self.link_predictor = LinkPredictor(self.graph_model.emb_size).to(self.device) - raise NotImplementedError() + if self.link_predictor_type == "nn": + self.create_nn_link_predictor() + elif self.link_predictor_type == "inner_prod": + self.create_inner_prod_link_predictor() + elif self.link_predictor_type == "l2": + self.create_l2_link_predictor() + else: + raise NotImplementedError() - # @abstractmethod def create_loaders(self): - # create loaders - # if self.type in {"graph_link_prediction", "graph_link_classification", "subword_ranker"}: + print("Number of nodes", self.graph_model.g.number_of_nodes()) train_idx, val_idx, test_idx = self._get_training_targets() train_idx, val_idx, test_idx = self.target_embedder.create_idx_pools( train_idx=train_idx, val_idx=val_idx, test_idx=test_idx ) - # else: - # raise NotImplementedError() logging.info( f"Pool sizes for {self.name}: train {self._idx_len(train_idx)}, " f"val {self._idx_len(val_idx)}, " f"test {self._idx_len(test_idx)}." ) - self.train_loader, self.test_loader, self.val_loader = self._get_loaders( + self.train_loader, self.val_loader, self.test_loader = self._get_loaders( train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, batch_size=self.batch_size # batch_size_node_name ) + def get_num_nodes(ids): + return sum(len(ids[key_]) for key_ in ids) // self.batch_size + 1 + + self.num_train_batches = get_num_nodes(train_idx) + self.num_test_batches = get_num_nodes(test_idx) + self.num_val_batches = get_num_nodes(val_idx) + def _idx_len(self, idx): if isinstance(idx, dict): length = 0 @@ -179,49 +254,51 @@ def _get_training_targets(self): # labels = labels[key] self.use_types = False - train_idx = { - ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['train_mask'], as_tuple=False).squeeze() - for ntype in self.ntypes - } - val_idx = { - ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['val_mask'], as_tuple=False).squeeze() - for ntype in self.ntypes - } - test_idx = { - ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['test_mask'], as_tuple=False).squeeze() - for ntype in self.ntypes - } + def get_targets(data_label): + return { + ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data[data_label], as_tuple=False).squeeze() + for ntype in self.ntypes + } + + train_idx = get_targets("train_mask") + val_idx = get_targets("val_mask") + test_idx = get_targets("test_mask") + + # train_idx = { + # ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['train_mask'], as_tuple=False).squeeze() + # for ntype in self.ntypes + # } + # val_idx = { + # ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['val_mask'], as_tuple=False).squeeze() + # for ntype in self.ntypes + # } + # test_idx = { + # ntype: torch.nonzero(self.graph_model.g.nodes[ntype].data['test_mask'], as_tuple=False).squeeze() + # for ntype in self.ntypes + # } else: # not sure when this is called raise NotImplementedError() - # self.ntypes = None - # # labels = g.ndata['labels'] - # train_idx = self.graph_model.g.ndata['train_mask'] - # val_idx = self.graph_model.g.ndata['val_mask'] - # test_idx 
= self.graph_model.g.ndata['test_mask'] - # self.use_types = False return train_idx, val_idx, test_idx - def _get_loaders(self, train_idx, val_idx, test_idx, batch_size): - # train sampler - layers = self.graph_model.num_layers - sampler = dgl.dataloading.MultiLayerNeighborSampler([self.sampling_neighbourhood_size] * layers) + def _create_loader(self, ids, batch_size=None, shuffle=False): + if batch_size is None: + # TODO + # only works when ids do not have types + batch_size = self._idx_len(ids) + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(self.graph_model.num_layers) loader = dgl.dataloading.NodeDataLoader( - self.graph_model.g, train_idx, sampler, batch_size=batch_size, shuffle=False, num_workers=0) + self.graph_model.g, ids, sampler, batch_size=batch_size, shuffle=shuffle, num_workers=0) + return loader - # validation sampler - # we do not use full neighbor to save computation resources - val_sampler = dgl.dataloading.MultiLayerNeighborSampler([self.sampling_neighbourhood_size] * layers) - val_loader = dgl.dataloading.NodeDataLoader( - self.graph_model.g, val_idx, val_sampler, batch_size=batch_size, shuffle=False, num_workers=0) + def _get_loaders(self, train_idx, val_idx, test_idx, batch_size): - # we do not use full neighbor to save computation resources - test_sampler = dgl.dataloading.MultiLayerNeighborSampler([self.sampling_neighbourhood_size] * layers) - test_loader = dgl.dataloading.NodeDataLoader( - self.graph_model.g, test_idx, test_sampler, batch_size=batch_size, shuffle=False, num_workers=0) + train_loader = self._create_loader(train_idx, batch_size, shuffle=False) + val_loader = self._create_loader(val_idx, batch_size, shuffle=False) + test_loader = self._create_loader(test_idx, batch_size, shuffle=False) - return loader, val_loader, test_loader + return train_loader, val_loader, test_loader def reset_iterator(self, data_split): iter_name = f"{data_split}_loader_iter" @@ -233,11 +310,10 @@ def loader_next(self, data_split): setattr(self, iter_name, iter(getattr(self, f"{data_split}_loader"))) return next(getattr(self, iter_name)) - def _create_loader(self, indices): - sampler = dgl.dataloading.MultiLayerNeighborSampler( - [self.sampling_neighbourhood_size] * self.graph_model.num_layers) - return dgl.dataloading.NodeDataLoader( - self.graph_model.g, indices, sampler, batch_size=len(indices), num_workers=0) + # def _create_loader(self, indices): + # sampler = dgl.dataloading.MultiLayerFullNeighborSampler(self.graph_model.num_layers) + # return dgl.dataloading.NodeDataLoader( + # self.graph_model.g, indices, sampler, batch_size=len(indices), num_workers=0) def _extract_embed(self, input_nodes, train_embeddings=True, masked=None): emb = {} @@ -257,15 +333,30 @@ def compute_acc_loss(self, node_embs_, element_embs_, labels): elif self.link_predictor_type == "inner_prod": loss = self.cosine_loss(node_embs_, element_embs_, labels) labels[labels < 0] = 0 + elif self.link_predictor_type == "l2": + loss = self.l2_loss(node_embs_, element_embs_, labels) + labels[labels < 0] = 0 + # num_examples = len(labels) // 2 + # anchor = node_embs_[:num_examples, :] + # positive = element_embs_[:num_examples, :] + # negative = element_embs_[num_examples:, :] + # # pos_labels_ = labels[:num_examples] + # # neg_labels_ = labels[num_examples:] + # margin = 1. 
+ # triplet = nn.TripletMarginLoss(margin=margin) + # self.target_embedder.set_margin(margin) + # loss = triplet(anchor, positive, negative) + # logits = (torch.norm(node_embs_ - element_embs_, keepdim=True) < 1.).float() + # logits = torch.cat([1 - logits, logits], dim=1) + # labels[labels < 0] = 0 else: raise NotImplementedError() - acc = _compute_accuracy(logits.argmax(dim=1), labels) + acc = compute_accuracy(logits.argmax(dim=1), labels) return acc, loss - - def _logits_batch(self, input_nodes, blocks, train_embeddings=True, masked=None): + def _graph_embeddings(self, input_nodes, blocks, train_embeddings=True, masked=None): cumm_logits = [] @@ -312,91 +403,173 @@ def seeds_to_global(self, seeds): else: return seeds - def _logits_embedder(self, node_embeddings, elem_embedder, link_predictor, seeds, negative_factor=1): - k = negative_factor - indices = self.seeds_to_global(seeds).tolist() - batch_size = len(indices) + def sample_negative(self, ids, k, neg_sampling_strategy): + if neg_sampling_strategy is not None: + negative = self.target_embedder.sample_negative( + k, ids=ids, strategy=neg_sampling_strategy + ) + else: + negative = self.target_embedder.sample_negative( + k, ids=ids, + ) + return negative - node_embeddings_batch = node_embeddings - element_embeddings = elem_embedder(elem_embedder[indices].to(self.device)) + def get_targets_from_nodes( + self, positive_indices, negative_indices=None, train_embeddings=True + ): + negative_indices = torch.tensor(negative_indices, dtype=torch.long) if negative_indices is not None else None - # labels_pos = torch.ones(batch_size, dtype=torch.long) - labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype) + def get_embeddings_for_targets(dst): + unique_dst, slice_map = self._handle_non_unique(dst) + assert unique_dst[slice_map].tolist() == dst.tolist() - node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1) - # negative_random = elem_embedder(elem_embedder.sample_negative(batch_size * k).to(self.device)) - negative_random = elem_embedder(elem_embedder.sample_negative(batch_size * k, ids=indices).to(self.device)) # closest negative + dataloader = self._create_loader(unique_dst) + input_nodes, dst_seeds, blocks = next(iter(dataloader)) + blocks = [blk.to(self.device) for blk in blocks] + assert dst_seeds.shape == unique_dst.shape + assert dst_seeds.tolist() == unique_dst.tolist() + unique_dst_embeddings = self._graph_embeddings(input_nodes, blocks, train_embeddings) # use_types, ntypes) + dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)] + + if self.update_embeddings_for_queries: + self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(), + unique_dst_embeddings.detach().cpu().numpy()) - # labels_neg = torch.zeros(batch_size * k, dtype=torch.long) - labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype) + return dst_embeddings - # positive_batch = torch.cat([node_embeddings_batch, element_embeddings], 1) - # negative_batch = torch.cat([node_embeddings_neg_batch, negative_random], 1) - # batch = torch.cat([positive_batch, negative_batch], 0) - # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) + positive_dst = get_embeddings_for_targets(positive_indices) + negative_dst = get_embeddings_for_targets(negative_indices) if negative_indices is not None else None + return positive_dst, negative_dst + + def get_targets_from_embedder( + self, positive_indices, negative_indices=None, train_embeddings=True + ): + + # def 
get_embeddings_for_targets(dst): + # unique_dst, slice_map = self._handle_non_unique(dst) + # assert unique_dst[slice_map].tolist() == dst.tolist() + # unique_dst_embeddings = self.target_embedder(unique_dst.to(self.device)) + # dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)] # - # logits = link_predictor(batch) + # if self.update_embeddings_for_queries: + # self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(), + # unique_dst_embeddings.detach().cpu().numpy()) # - # return logits, labels - nodes = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0) - embs = torch.cat([element_embeddings, negative_random], dim=0) - labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) - return nodes, embs, labels + # return dst_embeddings + + positive_dst = self.target_embedder(positive_indices.to(self.device)) + negative_dst = self.target_embedder(negative_indices.to(self.device)) if negative_indices is not None else None + # + # positive_dst = get_embeddings_for_targets(positive_indices) + # negative_dst = get_embeddings_for_targets(negative_indices) if negative_indices is not None else None + + return positive_dst, negative_dst + + def create_positive_labels(self, ids): + return torch.full((len(ids),), self.positive_label, dtype=self.label_dtype) - def _logits_nodes(self, node_embeddings, - elem_embedder, link_predictor, create_dataloader, - src_seeds, negative_factor=1, train_embeddings=True): + def create_negative_labels(self, ids, k): + return torch.full((len(ids) * k,), self.negative_label, dtype=self.label_dtype) + + def prepare_for_prediction( + self, node_embeddings, seeds, target_embedding_fn, negative_factor=1, + neg_sampling_strategy=None, train_embeddings=True, + ): k = negative_factor - indices = self.seeds_to_global(src_seeds).tolist() + indices = self.seeds_to_global(seeds).tolist() batch_size = len(indices) node_embeddings_batch = node_embeddings - next_call_indices = elem_embedder[indices] # this assumes indices is torch tensor - - # dst targets are not unique - unique_dst, slice_map = self._handle_non_unique(next_call_indices) - assert unique_dst[slice_map].tolist() == next_call_indices.tolist() - - dataloader = create_dataloader(unique_dst) - input_nodes, dst_seeds, blocks = next(iter(dataloader)) - blocks = [blk.to(self.device) for blk in blocks] - assert dst_seeds.shape == unique_dst.shape - assert dst_seeds.tolist() == unique_dst.tolist() - unique_dst_embeddings = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes) - next_call_embeddings = unique_dst_embeddings[slice_map.to(self.device)] - # labels_pos = torch.ones(batch_size, dtype=torch.long) - labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype) - node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1) - # negative_indices = torch.tensor(elem_embedder.sample_negative( - # batch_size * k), dtype=torch.long) # embeddings are sampled from 3/4 unigram distribution - negative_indices = torch.tensor(elem_embedder.sample_negative( - batch_size * k, ids=indices), dtype=torch.long) # closest negative - unique_negative, slice_map = self._handle_non_unique(negative_indices) - assert unique_negative[slice_map].tolist() == negative_indices.tolist() - - dataloader = create_dataloader(unique_negative) - input_nodes, dst_seeds, blocks = next(iter(dataloader)) - blocks = [blk.to(self.device) for blk in blocks] - assert dst_seeds.shape == unique_negative.shape - assert dst_seeds.tolist() == unique_negative.tolist() - 
unique_negative_random = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes) - negative_random = unique_negative_random[slice_map.to(self.device)] - # labels_neg = torch.zeros(batch_size * k, dtype=torch.long) - labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype) - - # positive_batch = torch.cat([node_embeddings_batch, next_call_embeddings], 1) - # negative_batch = torch.cat([node_embeddings_neg_batch, negative_random], 1) - # batch = torch.cat([positive_batch, negative_batch], 0) - # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) - # - # logits = link_predictor(batch) - # - # return logits, labels - nodes = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0) - embs = torch.cat([next_call_embeddings, negative_random], dim=0) + + positive_indices = self.target_embedder[indices] + negative_indices = self.sample_negative( + k=batch_size * k, ids=indices, neg_sampling_strategy=neg_sampling_strategy + ) + + positive_dst, negative_dst = target_embedding_fn( + positive_indices, negative_indices, train_embeddings + ) + + # TODO breaks cache in + # SourceCodeTools.models.graph.train.objectives.GraphLinkClassificationObjective.TargetLinkMapper.get_labels + labels_pos = self.create_positive_labels(indices) + labels_neg = self.create_negative_labels(indices, k) + + src_embs = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0) + dst_embs = torch.cat([positive_dst, negative_dst], dim=0) labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) - return nodes, embs, labels + return src_embs, dst_embs, labels + + # def _logits_embedder( + # self, node_embeddings, elem_embedder, link_predictor, seeds, negative_factor=1, neg_sampling_strategy=None + # ): + # k = negative_factor + # indices = self.seeds_to_global(seeds).tolist() + # batch_size = len(indices) + # + # node_embeddings_batch = node_embeddings + # node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1) + # + # element_embeddings = elem_embedder(elem_embedder[indices].to(self.device)) + # negative_random = elem_embedder(self.sample_negative( + # k=batch_size * k, ids=indices, neg_sampling_strategy=neg_sampling_strategy + # ).to(self.device)) + # + # labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype) + # labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype) + # + # src_embs = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0) + # dst_embs = torch.cat([element_embeddings, negative_random], dim=0) + # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) + # return src_embs, dst_embs, labels + # + # def _logits_nodes( + # self, node_embeddings, elem_embedder, link_predictor, create_dataloader, + # src_seeds, negative_factor=1, train_embeddings=True, neg_sampling_strategy=None, + # update_embeddings_for_queries=False + # ): + # k = negative_factor + # indices = self.seeds_to_global(src_seeds).tolist() + # batch_size = len(indices) + # + # node_embeddings_batch = node_embeddings + # node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1) + # + # next_call_indices = elem_embedder[indices] # this assumes indices is torch tensor + # negative_indices = torch.tensor(self.sample_negative( + # k=batch_size * k, ids=indices, neg_sampling_strategy=neg_sampling_strategy + # ), dtype=torch.long) + # + # # dst targets are not unique + # def get_embeddings_for_targets(dst, update_embeddings_for_queries): + # unique_dst, slice_map = 
self._handle_non_unique(dst) + # assert unique_dst[slice_map].tolist() == dst.tolist() + # + # dataloader = create_dataloader(unique_dst) + # input_nodes, dst_seeds, blocks = next(iter(dataloader)) + # blocks = [blk.to(self.device) for blk in blocks] + # assert dst_seeds.shape == unique_dst.shape + # assert dst_seeds.tolist() == unique_dst.tolist() + # unique_dst_embeddings = self._graph_embeddings(input_nodes, blocks, train_embeddings) # use_types, ntypes) + # dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)] + # + # if update_embeddings_for_queries: + # self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(), + # unique_dst_embeddings.detach().cpu().numpy()) + # + # return dst_embeddings + # + # next_call_embeddings = get_embeddings_for_targets(next_call_indices, update_embeddings_for_queries) + # negative_random = get_embeddings_for_targets(negative_indices, update_embeddings_for_queries) + # + # labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype) + # labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype) + # + # src_embs = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0) + # dst_embs = torch.cat([next_call_embeddings, negative_random], dim=0) + # labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) + # return src_embs, dst_embs, labels def seeds_to_python(self, seeds): if isinstance(seeds, dict): @@ -407,47 +580,32 @@ def seeds_to_python(self, seeds): python_seeds = seeds.tolist() return python_seeds - @abstractmethod - def forward(self, input_nodes, seeds, blocks, train_embeddings=True): - # masked = None - # if self.type in {"subword_ranker"}: - # masked = self.masker.get_mask(self.seeds_to_python(seeds)) - # graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked) - # if self.type in {"subword_ranker"}: - # # logits, labels = self._logits_embedder(graph_emb, self.target_embedder, self.link_predictor, seeds) - # node_embs_, element_embs_, labels = self._logits_embedder(graph_emb, self.target_embedder, self.link_predictor, seeds) - # elif self.type in {"graph_link_prediction", "graph_link_classification"}: - # # logits, labels = self._logits_nodes(graph_emb, self.target_embedder, self.link_predictor, - # # self._create_loader, seeds, train_embeddings=train_embeddings) - # node_embs_, element_embs_, labels = self._logits_nodes(graph_emb, self.target_embedder, self.link_predictor, - # self._create_loader, seeds, train_embeddings=train_embeddings) - # else: - # raise NotImplementedError() - # - # # logits = self.link_predictor(node_embs_, element_embs_) - # # - # # acc = _compute_accuracy(logits.argmax(dim=1), labels) - # # logp = nn.functional.log_softmax(logits, 1) - # # loss = nn.functional.nll_loss(logp, labels) - # acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) - # - # return loss, acc - raise NotImplementedError() + def forward(self, input_nodes, seeds, blocks, train_embeddings=True, neg_sampling_strategy=None): + masked = self.masker.get_mask(self.seeds_to_python(seeds)) if self.masker is not None else None + graph_emb = self._graph_embeddings(input_nodes, blocks, train_embeddings, masked=masked) + node_embs_, element_embs_, labels = self.prepare_for_prediction( + graph_emb, seeds, self.target_embedding_fn, negative_factor=self.negative_factor, + neg_sampling_strategy=neg_sampling_strategy, + train_embeddings=train_embeddings + ) + + acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) - def 
_evaluate_embedder(self, ee, lp, data_split, neg_sampling_factor=1): + return loss, acc - total_loss = 0 - total_acc = 0 - ndcg_at = [1, 3, 5, 10] - total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at} - ndcg_count = 0 + def evaluate_objective(self, data_split, neg_sampling_strategy=None, negative_factor=1): + # total_loss = 0 + # total_acc = 0 + at = [1, 3, 5, 10] + # total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at} + # ndcg_count = 0 count = 0 - # if self.measure_ndcg: - # if self.link_predictor == "inner_prod": - # self.target_embedder.prepare_index() + scores = defaultdict(list) - for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"): + for input_nodes, seeds, blocks in tqdm( + getattr(self, f"{data_split}_loader"), total=getattr(self, f"num_{data_split}_batches") + ): blocks = [blk.to(self.device) for blk in blocks] if self.masker is None: @@ -455,134 +613,151 @@ def _evaluate_embedder(self, ee, lp, data_split, neg_sampling_factor=1): else: masked = self.masker.get_mask(self.seeds_to_python(seeds)) - src_embs = self._logits_batch(input_nodes, blocks, masked=masked) - # logits, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor) - node_embs_, element_embs_, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor) - - if self.measure_ndcg: - if count % self.dilate_ndcg == 0: - ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device) - for key, val in ndcg.items(): - total_ndcg[key] = total_ndcg[key] + val - ndcg_count += 1 - - # logits = self.link_predictor(node_embs_, element_embs_) - # - # logp = nn.functional.log_softmax(logits, 1) - # loss = nn.functional.cross_entropy(logp, labels) - # acc = _compute_accuracy(logp.argmax(dim=1), labels) + src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked) + node_embs_, element_embs_, labels = self.prepare_for_prediction( + src_embs, seeds, self.target_embedding_fn, negative_factor=negative_factor, + neg_sampling_strategy=neg_sampling_strategy, + train_embeddings=False + ) + + if self.measure_scores: + if count % self.dilate_scores == 0: + scores_ = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, + self.link_predictor, at=at, + type=self.link_predictor_type, device=self.device) + for key, val in scores_.items(): + scores[key].append(val) acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) - total_loss += loss.item() - total_acc += acc + scores["Loss"].append(loss.item()) + scores["Accuracy"].append(acc) count += 1 - return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_ndcg else None - def _evaluate_nodes(self, ee, lp, create_api_call_loader, data_split, neg_sampling_factor=1): - - total_loss = 0 - total_acc = 0 - ndcg_at = [1, 3, 5, 10] - total_ndcg = {f"ndcg@{k}": 0. 
for k in ndcg_at} - ndcg_count = 0 - count = 0 - - for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"): - blocks = [blk.to(self.device) for blk in blocks] - - if self.masker is None: - masked = None - else: - masked = self.masker.get_mask(self.seeds_to_python(seeds)) - - src_embs = self._logits_batch(input_nodes, blocks, masked=masked) - # logits, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor) - node_embs_, element_embs_, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor) - - if self.measure_ndcg: - if count % self.dilate_ndcg == 0: - ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device) - for key, val in ndcg.items(): - total_ndcg[key] = total_ndcg[key] + val - ndcg_count += 1 - - # logits = self.link_predictor(node_embs_, element_embs_) - # - # logp = nn.functional.log_softmax(logits, 1) - # loss = nn.functional.cross_entropy(logp, labels) - # acc = _compute_accuracy(logp.argmax(dim=1), labels) - - acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) - - total_loss += loss.item() - total_acc += acc - count += 1 - return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_ndcg else None + scores = {key: sum_scores(val) for key, val in scores.items()} + return scores + # return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in + # total_ndcg.items()} if self.measure_scores else None + + # def _evaluate_embedder(self, ee, lp, data_split, neg_sampling_factor=1): + # + # total_loss = 0 + # total_acc = 0 + # ndcg_at = [1, 3, 5, 10] + # total_ndcg = {f"ndcg@{k}": 0. for k in ndcg_at} + # ndcg_count = 0 + # count = 0 + # + # for input_nodes, seeds, blocks in tqdm( + # getattr(self, f"{data_split}_loader"), total=getattr(self, f"num_{data_split}_batches") + # ): + # blocks = [blk.to(self.device) for blk in blocks] + # + # if self.masker is None: + # masked = None + # else: + # masked = self.masker.get_mask(self.seeds_to_python(seeds)) + # + # src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked) + # # logits, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor) + # node_embs_, element_embs_, labels = self._logits_embedder(src_embs, ee, lp, seeds, neg_sampling_factor) + # + # if self.measure_scores: + # if count % self.dilate_scores == 0: + # ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device) + # for key, val in ndcg.items(): + # total_ndcg[key] = total_ndcg[key] + val + # ndcg_count += 1 + # + # # logits = self.link_predictor(node_embs_, element_embs_) + # # + # # logp = nn.functional.log_softmax(logits, 1) + # # loss = nn.functional.cross_entropy(logp, labels) + # # acc = _compute_accuracy(logp.argmax(dim=1), labels) + # + # acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) + # + # total_loss += loss.item() + # total_acc += acc + # count += 1 + # return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_scores else None + # + # def _evaluate_nodes(self, ee, lp, create_api_call_loader, data_split, neg_sampling_factor=1): + # + # total_loss = 0 + # total_acc = 0 + # ndcg_at = [1, 3, 5, 10] + # total_ndcg = {f"ndcg@{k}": 0. 
for k in ndcg_at} + # ndcg_count = 0 + # count = 0 + # + # for input_nodes, seeds, blocks in tqdm( + # getattr(self, f"{data_split}_loader"), total=getattr(self, f"num_{data_split}_batches") + # ): + # blocks = [blk.to(self.device) for blk in blocks] + # + # if self.masker is None: + # masked = None + # else: + # masked = self.masker.get_mask(self.seeds_to_python(seeds)) + # + # src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked) + # # logits, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor) + # # node_embs_, element_embs_, labels = self._logits_nodes(src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor) + # + # node_embs_, element_embs_, labels = self.prepare_for_prediction( + # src_embs, seeds, self.target_embedding_fn, negative_factor=1, + # neg_sampling_strategy=None, train_embeddings=True, + # update_embeddings_for_queries=False + # ) + # + # if self.measure_scores: + # if count % self.dilate_scores == 0: + # ndcg = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, self.link_predictor, at=ndcg_at, type=self.link_predictor_type, device=self.device) + # for key, val in ndcg.items(): + # total_ndcg[key] = total_ndcg[key] + val + # ndcg_count += 1 + # + # # logits = self.link_predictor(node_embs_, element_embs_) + # # + # # logp = nn.functional.log_softmax(logits, 1) + # # loss = nn.functional.cross_entropy(logp, labels) + # # acc = _compute_accuracy(logp.argmax(dim=1), labels) + # + # acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) + # + # total_loss += loss.item() + # total_acc += acc + # count += 1 + # return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_scores else None + + def check_early_stopping(self, metric): + """ + Checks the metric value and raises Early Stopping when the metric stops increasing. + Assumes that the metric grows. Uses accuracy as a metric by default. Check implementation of child classes. 
+ :param metric: metric value + :return: Nothing + """ + if self.early_stopping_tracker is not None: + self.early_stopping_trigger = self.early_stopping_tracker.should_stop(metric) - @abstractmethod - def evaluate(self, data_split, neg_sampling_factor=1): - # if self.type in {"subword_ranker"}: - # loss, acc, ndcg = self._evaluate_embedder( - # self.target_embedder, self.link_predictor, data_split=data_split, neg_sampling_factor=neg_sampling_factor - # ) - # elif self.type in {"graph_link_prediction", "graph_link_classification"}: - # loss, acc = self._evaluate_nodes( - # self.target_embedder, self.link_predictor, self._create_loader, data_split=data_split, - # neg_sampling_factor=neg_sampling_factor - # ) - # ndcg = None - # else: - # raise NotImplementedError() - # - # return loss, acc, ndcg - raise NotImplementedError() + def evaluate(self, data_split, *, neg_sampling_strategy=None, early_stopping=False, early_stopping_tolerance=20): + # negative factor is 1 for evaluation + scores = self.evaluate_objective(data_split, neg_sampling_strategy=None, negative_factor=1) + if data_split == "val": + self.check_early_stopping(scores["Accuracy"]) + return scores @abstractmethod def parameters(self, recurse: bool = True): - # if self.type in {"subword_ranker"}: - # return chain(self.target_embedder.parameters(), self.link_predictor.parameters()) - # elif self.type in {"graph_link_prediction", "graph_link_classification"}: - # return self.link_predictor.parameters() - # else: - # raise NotImplementedError() raise NotImplementedError() @abstractmethod def custom_state_dict(self): - # state_dict = OrderedDict() - # if self.type in {"subword_ranker"}: - # for k, v in self.target_embedder.state_dict().items(): - # state_dict[f"target_embedder.{k}"] = v - # for k, v in self.link_predictor.state_dict().items(): - # state_dict[f"link_predictor.{k}"] = v - # # state_dict["target_embedder"] = self.target_embedder.state_dict() - # # state_dict["link_predictor"] = self.link_predictor.state_dict() - # elif self.type in {"graph_link_prediction", "graph_link_classification"}: - # for k, v in self.link_predictor.state_dict().items(): - # state_dict[f"link_predictor.{k}"] = v - # # state_dict["link_predictor"] = self.link_predictor.state_dict() - # else: - # raise NotImplementedError() - # - # return state_dict raise NotImplementedError() @abstractmethod def custom_load_state_dict(self, state_dicts): - # if self.type in {"subword_ranker"}: - # self.target_embedder.load_state_dict( - # self.get_prefix("target_embedder", state_dicts) - # ) - # self.link_predictor.load_state_dict( - # self.get_prefix("link_predictor", state_dicts) - # ) - # elif self.type in {"graph_link_prediction", "graph_link_classification"}: - # self.link_predictor.load_state_dict( - # self.get_prefix("link_predictor", state_dicts) - # ) - # else: - # raise NotImplementedError() raise NotImplementedError() def get_prefix(self, prefix, state_dict): diff --git a/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py b/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py index 0cd50be7..4f762541 100644 --- a/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py +++ b/SourceCodeTools/models/graph/train/objectives/GraphLinkClassificationObjective.py @@ -1,15 +1,15 @@ -from collections import OrderedDict import numpy as np import random as rnd import torch from torch.nn import CrossEntropyLoss -from SourceCodeTools.code.data.sourcetrail.SubwordMasker import 
SubwordMasker +from SourceCodeTools.code.data.dataset.SubwordMasker import SubwordMasker +from SourceCodeTools.mltools.torch import compute_accuracy +from SourceCodeTools.models.graph.ElementEmbedder import GraphLinkSampler from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase -from SourceCodeTools.models.graph.LinkPredictor import LinkClassifier, BilinearLinkPedictor +from SourceCodeTools.models.graph.LinkPredictor import BilinearLinkPedictor, TransRLinkPredictor from SourceCodeTools.models.graph.train.Scorer import Scorer -from SourceCodeTools.models.graph.train.objectives import GraphLinkObjective -from SourceCodeTools.models.graph.train.sampling_multitask2 import _compute_accuracy +from SourceCodeTools.models.graph.train.objectives.GraphLinkObjective import GraphLinkObjective from SourceCodeTools.tabular.common import compact_property @@ -18,69 +18,32 @@ def __init__( self, name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1, ns_groups=None ): super().__init__( name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, ns_groups=ns_groups + ) + self.measure_scores = True + self.update_embeddings_for_queries = True + + def create_graph_link_sampler(self, data_loading_func, nodes): + self.target_embedder = TargetLinkMapper( + elements=data_loading_func(), nodes=nodes, emb_size=self.target_emb_size, ns_groups=self.ns_groups ) - self.measure_ndcg = False def create_link_predictor(self): - self.link_predictor = BilinearLinkPedictor(self.target_emb_size, self.graph_model.emb_size, self.target_embedder.num_classes).to(self.device) + self.link_predictor = BilinearLinkPedictor( + self.target_emb_size, self.graph_model.emb_size, self.target_embedder.num_classes + ).to(self.device) # self.positive_label = 1 - self.negative_label = 0 + self.negative_label = self.target_embedder.null_class self.label_dtype = torch.long - def _logits_nodes(self, node_embeddings, - elem_embedder, link_predictor, create_dataloader, - src_seeds, negative_factor=1, train_embeddings=True): - k = negative_factor - indices = self.seeds_to_global(src_seeds).tolist() - batch_size = len(indices) - - node_embeddings_batch = node_embeddings - dst_indices, labels_pos = elem_embedder[indices] # this assumes indices is torch tensor - - # dst targets are not unique - unique_dst, slice_map = self._handle_non_unique(dst_indices) - assert unique_dst[slice_map].tolist() == dst_indices.tolist() - - dataloader = create_dataloader(unique_dst) - input_nodes, dst_seeds, blocks = next(iter(dataloader)) - blocks = [blk.to(self.device) for blk in blocks] - assert dst_seeds.shape == unique_dst.shape - assert dst_seeds.tolist() == unique_dst.tolist() - unique_dst_embeddings = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes) - dst_embeddings = unique_dst_embeddings[slice_map.to(self.device)] - - self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(), unique_dst_embeddings.detach().cpu().numpy()) - - node_embeddings_neg_batch = 
node_embeddings_batch.repeat(k, 1) - # negative_indices = torch.tensor(elem_embedder.sample_negative( - # batch_size * k), dtype=torch.long) # embeddings are sampled from 3/4 unigram distribution - negative_indices = torch.tensor(elem_embedder.sample_negative( - batch_size * k, ids=indices), dtype=torch.long) # closest negative - unique_negative, slice_map = self._handle_non_unique(negative_indices) - assert unique_negative[slice_map].tolist() == negative_indices.tolist() - - dataloader = create_dataloader(unique_negative) - input_nodes, dst_seeds, blocks = next(iter(dataloader)) - blocks = [blk.to(self.device) for blk in blocks] - assert dst_seeds.shape == unique_negative.shape - assert dst_seeds.tolist() == unique_negative.tolist() - unique_negative_random = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes) - negative_random = unique_negative_random[slice_map.to(self.device)] - labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype) - - self.target_embedder.set_embed(unique_negative.detach().cpu().numpy(), unique_negative_random.detach().cpu().numpy()) - - nodes = torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0) - embs = torch.cat([dst_embeddings, negative_random], dim=0) - labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) - return nodes, embs, labels + def create_positive_labels(self, ids): + return torch.LongTensor(self.target_embedder.get_labels(ids)) def compute_acc_loss(self, node_embs_, element_embs_, labels): logits = self.link_predictor(node_embs_, element_embs_) @@ -89,14 +52,52 @@ def compute_acc_loss(self, node_embs_, element_embs_, labels): loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1)) - acc = _compute_accuracy(logits.argmax(dim=1), labels) + acc = compute_accuracy(logits.argmax(dim=1), labels) return acc, loss -class TargetLinkMapper(ElementEmbedderBase, Scorer): - def __init__(self, elements, nodes): - ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes, ) +class TransRObjective(GraphLinkClassificationObjective): + def __init__( + self, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, + measure_scores=False, dilate_scores=1, ns_groups=None + ): + super().__init__( + "TransR", graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, ns_groups=ns_groups + ) + + def create_link_predictor(self): + self.link_predictor = TransRLinkPredictor( + input_dim=self.target_emb_size, rel_dim=30, + num_relations=self.target_embedder.num_classes + ).to(self.device) + # self.positive_label = 1 + self.negative_label = -1 + self.label_dtype = torch.long + + def compute_acc_loss(self, node_embs_, element_embs_, labels): + + num_examples = len(labels) // 2 + anchor = node_embs_[:num_examples, :] + positive = element_embs_[:num_examples, :] + negative = element_embs_[num_examples:, :] + labels_ = labels[:num_examples] + + loss, sim = self.link_predictor(anchor, positive, negative, labels_) + acc = compute_accuracy(sim, labels >= 0) + + return acc, loss + +class TargetLinkMapper(GraphLinkSampler): + def __init__(self, elements, nodes, emb_size=1, ns_groups=None): + 
super(TargetLinkMapper, self).__init__( + elements, nodes, compact_dst=False, dst_to_global=True, emb_size=emb_size, ns_groups=ns_groups + ) def init(self, compact_dst): if compact_dst: @@ -106,7 +107,7 @@ def init(self, compact_dst): self.elements['emb_id'] = self.elements['dst'] self.link_type2id, self.inverse_link_type_map = compact_property(self.elements['type'], return_order=True, index_from_one=True) - self.elements["link_type"] = self.elements["type"].apply(lambda x: self.link_type2id[x]) + self.elements["link_type"] = list(map(lambda x: self.link_type2id[x], self.elements["type"].tolist())) self.element_lookup = {} for id_, emb_id, link_type in self.elements[["id", "emb_id", "link_type"]].values: @@ -117,12 +118,22 @@ def init(self, compact_dst): self.init_neg_sample() self.num_classes = len(self.inverse_link_type_map) + self.null_class = 0 def __getitem__(self, ids): + self.cached_ids = ids node_ids, labels = zip(*(rnd.choice(self.element_lookup[id]) for id in ids)) - return np.array(node_ids, dtype=np.int32), np.array(labels, dtype=np.int32) + self.cached_labels = list(labels) + return np.array(node_ids, dtype=np.int32)#, torch.LongTensor(np.array(labels, dtype=np.int32)) + + def get_labels(self, ids): + if self.cached_ids == ids: + return self.cached_labels + else: + node_ids, labels = zip(*(rnd.choice(self.element_lookup[id]) for id in ids)) + return list(labels) - def sample_negative(self, size, ids=None, strategy="closest"): # TODO switch to w2v? + def sample_negative(self, size, ids=None, strategy="w2v"): # TODO switch to w2v? if strategy == "w2v": negative = ElementEmbedderBase.sample_negative(self, size) else: diff --git a/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py b/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py index 3aaaedde..00258c50 100644 --- a/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py +++ b/SourceCodeTools/models/graph/train/objectives/GraphLinkObjective.py @@ -1,9 +1,6 @@ from collections import OrderedDict -import torch - -from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker -from SourceCodeTools.models.graph.LinkPredictor import LinkClassifier +from SourceCodeTools.code.data.dataset.SubwordMasker import SubwordMasker from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective @@ -12,14 +9,18 @@ def __init__( self, name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1, nn_index="brute", ns_groups=None ): super(GraphLinkObjective, self).__init__( name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index, + ns_groups=ns_groups ) + self.target_embedding_fn = self.get_targets_from_nodes + self.negative_factor = 1 + self.update_embeddings_for_queries = True def verify_parameters(self): pass @@ -27,34 +28,6 @@ def verify_parameters(self): def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): self.create_graph_link_sampler(data_loading_func, nodes) - def 
create_link_predictor(self): - if self.link_predictor_type == "nn": - self.create_nn_link_predictor() - elif self.link_predictor_type == "inner_prod": - self.create_inner_prod_link_predictor() - else: - raise NotImplementedError() - - def forward(self, input_nodes, seeds, blocks, train_embeddings=True): - masked = None - graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked) - node_embs_, element_embs_, labels = self._logits_nodes( - graph_emb, self.target_embedder, self.link_predictor, - self._create_loader, seeds, train_embeddings=train_embeddings - ) - - acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) - - return loss, acc - - def evaluate(self, data_split, neg_sampling_factor=1): - loss, acc, ndcg = self._evaluate_nodes( - self.target_embedder, self.link_predictor, self._create_loader, data_split=data_split, - neg_sampling_factor=neg_sampling_factor - ) - # ndcg = None - return loss, acc, ndcg - def parameters(self, recurse: bool = True): return self.link_predictor.parameters() @@ -69,86 +42,37 @@ def custom_load_state_dict(self, state_dicts): self.get_prefix("link_predictor", state_dicts) ) - def _logits_nodes(self, node_embeddings, - elem_embedder, link_predictor, create_dataloader, - src_seeds, negative_factor=1, train_embeddings=True): - k = negative_factor - indices = self.seeds_to_global(src_seeds).tolist() - batch_size = len(indices) - - node_embeddings_batch = node_embeddings - next_call_indices = elem_embedder[indices] # this assumes indices is torch tensor - - # dst targets are not unique - unique_dst, slice_map = self._handle_non_unique(next_call_indices) - assert unique_dst[slice_map].tolist() == next_call_indices.tolist() - - dataloader = create_dataloader(unique_dst) - input_nodes, dst_seeds, blocks = next(iter(dataloader)) - blocks = [blk.to(self.device) for blk in blocks] - assert dst_seeds.shape == unique_dst.shape - assert dst_seeds.tolist() == unique_dst.tolist() - unique_dst_embeddings = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes) - next_call_embeddings = unique_dst_embeddings[slice_map.to(self.device)] - labels_pos = torch.full((batch_size,), self.positive_label, dtype=self.label_dtype) - - node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1) - # negative_indices = torch.tensor(elem_embedder.sample_negative( - # batch_size * k), dtype=torch.long) # embeddings are sampled from 3/4 unigram distribution - negative_indices = torch.tensor(elem_embedder.sample_negative( - batch_size * k, ids=indices), dtype=torch.long) # closest negative - unique_negative, slice_map = self._handle_non_unique(negative_indices) - assert unique_negative[slice_map].tolist() == negative_indices.tolist() - - self.target_embedder.set_embed(unique_dst.detach().cpu().numpy(), unique_dst_embeddings.detach().cpu().numpy()) - - dataloader = create_dataloader(unique_negative) - input_nodes, dst_seeds, blocks = next(iter(dataloader)) - blocks = [blk.to(self.device) for blk in blocks] - assert dst_seeds.shape == unique_negative.shape - assert dst_seeds.tolist() == unique_negative.tolist() - unique_negative_random = self._logits_batch(input_nodes, blocks, train_embeddings) # use_types, ntypes) - negative_random = unique_negative_random[slice_map.to(self.device)] - labels_neg = torch.full((batch_size * k,), self.negative_label, dtype=self.label_dtype) - - self.target_embedder.set_embed(unique_negative.detach().cpu().numpy(), unique_negative_random.detach().cpu().numpy()) - - nodes = 
torch.cat([node_embeddings_batch, node_embeddings_neg_batch], dim=0) - embs = torch.cat([next_call_embeddings, negative_random], dim=0) - labels = torch.cat([labels_pos, labels_neg], 0).to(self.device) - return nodes, embs, labels - - -class GraphLinkTypeObjective(GraphLinkObjective): - def __init__( - self, name, graph_model, node_embedder, nodes, data_loading_func, device, - sampling_neighbourhood_size, batch_size, - tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 - ): - self.set_num_classes(data_loading_func) - - super(GraphLinkObjective, self).__init__( - name, graph_model, node_embedder, nodes, data_loading_func, device, - sampling_neighbourhood_size, batch_size, - tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg - ) - - def set_num_classes(self, data_loading_func): - pass - - def create_nn_link_type_predictor(self): - self.link_predictor = LinkClassifier(2 * self.graph_model.emb_size, self.num_classes).to(self.device) - self.positive_label = 1 - self.negative_label = 0 - self.label_dtype = torch.long - def create_link_predictor(self): - if self.link_predictor_type == "nn": - self.create_nn_link_predictor() - elif self.link_predictor_type == "inner_prod": - self.create_inner_prod_link_predictor() - else: - raise NotImplementedError() +# class GraphLinkTypeObjective(GraphLinkObjective): +# def __init__( +# self, name, graph_model, node_embedder, nodes, data_loading_func, device, +# sampling_neighbourhood_size, batch_size, +# tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, +# measure_scores=False, dilate_scores=1 +# ): +# self.set_num_classes(data_loading_func) +# +# super(GraphLinkObjective, self).__init__( +# name, graph_model, node_embedder, nodes, data_loading_func, device, +# sampling_neighbourhood_size, batch_size, +# tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, +# masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores +# ) +# +# def set_num_classes(self, data_loading_func): +# pass +# +# def create_nn_link_type_predictor(self): +# self.link_predictor = LinkClassifier(2 * self.graph_model.emb_size, self.num_classes).to(self.device) +# self.positive_label = 1 +# self.negative_label = 0 +# self.label_dtype = torch.long +# +# def create_link_predictor(self): +# if self.link_predictor_type == "nn": +# self.create_nn_link_predictor() +# elif self.link_predictor_type == "inner_prod": +# self.create_inner_prod_link_predictor() +# else: +# raise NotImplementedError() diff --git a/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py b/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py index cd568ec8..8d26f45f 100644 --- a/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py +++ b/SourceCodeTools/models/graph/train/objectives/NodeClassificationObjective.py @@ -1,31 +1,33 @@ -from collections import OrderedDict +import logging +from collections import OrderedDict, defaultdict from itertools import chain import torch -from sklearn.metrics import ndcg_score +from sklearn.metrics import ndcg_score, top_k_accuracy_score from torch import nn from torch.nn import CrossEntropyLoss -from SourceCodeTools.code.data.sourcetrail import SubwordMasker -from 
SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, _compute_accuracy +from SourceCodeTools.code.data.dataset import SubwordMasker +from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, compute_accuracy, \ + sum_scores from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase from SourceCodeTools.models.graph.train.Scorer import Scorer import numpy as np -class NodeNameClassifier(AbstractObjective): +class NodeClassifierObjective(AbstractObjective): def __init__( - self, graph_model, node_embedder, nodes, data_loading_func, device, + self, name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type=None, masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1, early_stopping=False, early_stopping_tolerance=20 ): super().__init__( - "NodeNameClassifier", graph_model, node_embedder, nodes, data_loading_func, device, + name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, early_stopping=early_stopping, early_stopping_tolerance=early_stopping_tolerance ) def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): @@ -36,63 +38,85 @@ def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): def create_link_predictor(self): self.classifier = NodeClassifier(self.target_emb_size, self.target_embedder.num_classes).to(self.device) - def compute_acc_loss(self, graph_emb, labels, return_logits=False): + def compute_acc_loss(self, graph_emb, element_emb, labels, return_logits=False): logits = self.classifier(graph_emb) loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1)) - acc = _compute_accuracy(logits.argmax(dim=1), labels) + acc = compute_accuracy(logits.argmax(dim=1), labels) if return_logits: return acc, loss, logits return acc, loss - def forward(self, input_nodes, seeds, blocks, train_embeddings=True): - masked = self.masker.get_mask(self.seeds_to_python(seeds)) if self.masker is not None else None - graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked) + def prepare_for_prediction( + self, node_embeddings, seeds, target_embedding_fn, negative_factor=1, + neg_sampling_strategy=None, train_embeddings=True, + ): indices = self.seeds_to_global(seeds).tolist() labels = torch.LongTensor(self.target_embedder[indices]).to(self.device) - acc, loss = self.compute_acc_loss(graph_emb, labels) - return loss, acc + return node_embeddings, None, labels - def evaluate_generation(self, data_split): - total_loss = 0 - total_acc = 0 - ndcg_at = [1, 3, 5, 10] - total_ndcg = {f"ndcg@{k}": 0. 
for k in ndcg_at} - ndcg_count = 0 + def evaluate_objective(self, data_split, neg_sampling_strategy=None, negative_factor=1): + at = [1, 3, 5, 10] count = 0 + scores = defaultdict(list) for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"): blocks = [blk.to(self.device) for blk in blocks] - src_embs = self._logits_batch(input_nodes, blocks) - indices = self.seeds_to_global(seeds).tolist() - labels = self.target_embedder[indices] - labels = torch.LongTensor(labels).to(self.device) - acc, loss, logits = self.compute_acc_loss(src_embs, labels, return_logits=True) + if self.masker is None: + masked = None + else: + masked = self.masker.get_mask(self.seeds_to_python(seeds)) + + src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked) + node_embs_, element_embs_, labels = self.prepare_for_prediction( + src_embs, seeds, self.target_embedding_fn, negative_factor=negative_factor, + neg_sampling_strategy=neg_sampling_strategy, + train_embeddings=False + ) + # indices = self.seeds_to_global(seeds).tolist() + # labels = self.target_embedder[indices] + # labels = torch.LongTensor(labels).to(self.device) + acc, loss, logits = self.compute_acc_loss(node_embs_, element_embs_, labels, return_logits=True) y_pred = nn.functional.softmax(logits, dim=-1).to("cpu").numpy() y_true = np.zeros(y_pred.shape) y_true[np.arange(0, y_true.shape[0]), labels.to("cpu").numpy()] = 1. - if self.measure_ndcg: - if count % self.dilate_ndcg == 0: - for k in ndcg_at: - total_ndcg[f"ndcg@{k}"] += ndcg_score(y_true, y_pred, k=k) - ndcg_count += 1 + if self.measure_scores: + if y_pred.shape[1] == 2: + logging.warning("Scores are meaningless for binary classification. Disabling.") + self.measure_scores = False + else: + if count % self.dilate_scores == 0: + y_true_onehot = np.array(y_true) + labels = list(range(y_true_onehot.shape[1])) + + for k in at: + if k >= y_pred.shape[1]: # do not measure for binary classification + if not hasattr(self, f"meaning_scores_warning_{k}"): + logging.warning(f"Disabling @{k} scores for task with {y_pred.shape[1]} classes") + setattr(self, f"meaning_scores_warning_{k}", True) + continue # scores do not have much sense in this situation + scores[f"ndcg@{k}"].append(ndcg_score(y_true, y_pred, k=k)) + scores[f"acc@{k}"].append( + top_k_accuracy_score(y_true_onehot.argmax(-1), y_pred, k=k, labels=labels) + ) + + scores["Loss"].append(loss.item()) + scores["Accuracy"].append(acc) + count += 1 - total_loss += loss.item() - total_acc += acc + if count == 0: count += 1 - return total_loss / count, total_acc / count, {key: val / ndcg_count for key, val in total_ndcg.items()} if self.measure_ndcg else None - def evaluate(self, data_split, neg_sampling_factor=1): - loss, acc, bleu = self.evaluate_generation(data_split) - return loss, acc, bleu + scores = {key: sum_scores(val) for key, val in scores.items()} + return scores def parameters(self, recurse: bool = True): return chain(self.classifier.parameters()) @@ -100,10 +124,25 @@ def parameters(self, recurse: bool = True): def custom_state_dict(self): state_dict = OrderedDict() for k, v in self.classifier.state_dict().items(): - state_dict[f"target_embedder.{k}"] = v + state_dict[f"classifier.{k}"] = v return state_dict +class NodeNameClassifier(NodeClassifierObjective): + def __init__( + self, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=None, target_emb_size=None, link_predictor_type=None, masker: SubwordMasker = None, + measure_scores=False, 
dilate_scores=1, early_stopping=False, early_stopping_tolerance=20 + ): + super().__init__( + "NodeNameClassifier", graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, early_stopping=early_stopping, early_stopping_tolerance=early_stopping_tolerance + ) + + class ClassifierTargetMapper(ElementEmbedderBase, Scorer): def __init__(self, elements, nodes): ElementEmbedderBase.__init__(self, elements=elements, nodes=nodes,) @@ -112,7 +151,7 @@ def __init__(self, elements, nodes): def set_embed(self, *args, **kwargs): pass - def prepare_index(self): + def prepare_index(self, *args): pass diff --git a/SourceCodeTools/models/graph/train/objectives/SubgraphClassifierObjective.py b/SourceCodeTools/models/graph/train/objectives/SubgraphClassifierObjective.py new file mode 100644 index 00000000..fe54b393 --- /dev/null +++ b/SourceCodeTools/models/graph/train/objectives/SubgraphClassifierObjective.py @@ -0,0 +1,398 @@ +import logging +from collections import OrderedDict, defaultdict +from itertools import chain +from typing import Optional + +import numpy as np +import torch +from sklearn.metrics import ndcg_score, top_k_accuracy_score +from torch import nn +from tqdm import tqdm + +from SourceCodeTools.code.data.dataset.SubwordMasker import SubwordMasker +from SourceCodeTools.code.data.file_utils import unpersist +from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedderWithBpeSubwords +from SourceCodeTools.models.graph.ElementEmbedderBase import ElementEmbedderBase +from SourceCodeTools.models.graph.train.Scorer import Scorer +from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, sum_scores +from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeClassifier, \ + NodeClassifierObjective +from SourceCodeTools.tabular.common import compact_property + + +class SubgraphLoader: + def __init__(self, ids, subgraph_mapping, loading_fn, batch_size, graph_node_types): + self.ids = ids + self.loading_fn = loading_fn + self.subgraph_mapping = subgraph_mapping + self.iterator = None + self.batch_size = batch_size + self.graph_node_types = graph_node_types + + def __iter__(self): + # TODO + # supports only nodes without types + + for i in range(0, len(self.ids), self.batch_size): + + node_ids = dict() + subgraphs = dict() + + batch_ids = self.ids[i: i + self.batch_size] + for id_ in batch_ids: + subgraph_nodes = self._get_subgraph(id_) + subgraphs[id_] = subgraph_nodes + + for type_ in subgraph_nodes: + if type_ not in node_ids: + node_ids[type_] = set() + node_ids[type_].update(subgraph_nodes[type_]) + + for type_ in node_ids: + node_ids[type_] = sorted(list(node_ids[type_])) + + coincidence_matrix = [] + for id_, subgraph in subgraphs.items(): + coincidence_matrix.append([]) + for type_ in self.graph_node_types: + subgraph_nodes = subgraph[type_] + for node_id in node_ids[type_]: + coincidence_matrix[-1].append(node_id in subgraph_nodes) + + coincidence_matrix = torch.BoolTensor(coincidence_matrix) + + loader = self.loading_fn(node_ids) + + for input_nodes, seeds, blocks in loader: + yield input_nodes, (coincidence_matrix, torch.LongTensor(batch_ids)), blocks + + # for id_ in self.ids: + # idx = self._get_subgraph(id_) + # loader = self.loading_fn(idx) + # for input_nodes, seeds, blocks in 
loader: + # yield input_nodes, torch.LongTensor([id_]), blocks + + def _get_subgraph(self, id_): + return self.subgraph_mapping[id_] + + +class SubgraphAbstractObjective(AbstractObjective): + def __init__( + self, name, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", + masker: Optional[SubwordMasker] = None, measure_scores=False, dilate_scores=1, + early_stopping=False, early_stopping_tolerance=20, nn_index="brute", + ns_groups=None, subgraph_mapping=None, subgraph_partition=None + ): + assert subgraph_partition is not None, "Provide train/val/test splits with `subgraph_partition` option" + self.subgraph_mapping = subgraph_mapping + self.subgraph_partition = unpersist(subgraph_partition) + super(SubgraphAbstractObjective, self).__init__( + name, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path, target_emb_size, link_predictor_type, masker, + measure_scores, dilate_scores, early_stopping, early_stopping_tolerance, nn_index, + ns_groups + ) + + self.target_embedding_fn = self.get_targets_from_embedder + self.negative_factor = 1 + self.update_embeddings_for_queries = False + + def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): + self.target_embedder = SubgraphElementEmbedderWithSubwords( + data_loading_func(), self.target_emb_size, tokenizer_path + ) + + def _get_training_targets(self): + + if hasattr(self.graph_model.g, 'ntypes'): + self.ntypes = self.graph_model.g.ntypes + self.use_types = True + + if len(self.graph_model.g.ntypes) == 1: + self.use_types = False + else: + # not sure when this is called + raise NotImplementedError() + + train_idx = self.subgraph_partition[ + self.subgraph_partition["train_mask"] + ]["id"].to_numpy() + val_idx = self.subgraph_partition[ + self.subgraph_partition["val_mask"] + ]["id"].to_numpy() + test_idx = self.subgraph_partition[ + self.subgraph_partition["test_mask"] + ]["id"].to_numpy() + + return train_idx, val_idx, test_idx + + def create_loaders(self): + print("Number of nodes", self.graph_model.g.number_of_nodes()) + train_idx, val_idx, test_idx = self._get_training_targets() + train_idx, val_idx, test_idx = self.target_embedder.create_idx_pools( + train_idx=train_idx, val_idx=val_idx, test_idx=test_idx + ) + logging.info( + f"Pool sizes for {self.name}: train {self._idx_len(train_idx)}, " + f"val {self._idx_len(val_idx)}, " + f"test {self._idx_len(test_idx)}." 
+ ) + loaders = self._get_loaders( + train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, + batch_size=self.batch_size # batch_size_node_name + ) + self.train_loader, self.val_loader, self.test_loader = loaders + + # def get_num_nodes(ids): + # return sum(len(ids[key_]) for key_ in ids) // self.batch_size + 1 + + self.num_train_batches = len(train_idx) // self.batch_size + 1 + self.num_test_batches = len(test_idx) // self.batch_size + 1 + self.num_val_batches = len(val_idx) // self.batch_size + 1 + + def _get_loaders(self, train_idx, val_idx, test_idx, batch_size): + + logging.info("Batch size is ignored for subgraphs") + + subgraph_mapping = self.subgraph_mapping + + train_loader = SubgraphLoader(train_idx, subgraph_mapping, self._create_loader, batch_size, self.graph_model.g.ntypes) + val_loader = SubgraphLoader(val_idx, subgraph_mapping, self._create_loader, batch_size, self.graph_model.g.ntypes) + test_loader = SubgraphLoader(test_idx, subgraph_mapping, self._create_loader, batch_size, self.graph_model.g.ntypes) + + return train_loader, val_loader, test_loader + + def parameters(self, recurse: bool = True): + return chain(self.target_embedder.parameters(), self.link_predictor.parameters()) + + def custom_state_dict(self): + state_dict = OrderedDict() + for k, v in self.target_embedder.state_dict().items(): + state_dict[f"target_embedder.{k}"] = v + for k, v in self.link_predictor.state_dict().items(): + state_dict[f"link_predictor.{k}"] = v + return state_dict + + def custom_load_state_dict(self, state_dicts): + self.target_embedder.load_state_dict( + self.get_prefix("target_embedder", state_dicts) + ) + self.link_predictor.load_state_dict( + self.get_prefix("link_predictor", state_dicts) + ) + + def pooling_fn(self, node_embeddings): + return torch.mean(node_embeddings, dim=0, keepdim=True) + + def _graph_embeddings(self, input_nodes, blocks, train_embeddings=True, masked=None, subgraph_masks=None): + node_embs = super(SubgraphAbstractObjective, self)._graph_embeddings( + input_nodes, blocks, train_embeddings, masked + ) + + subgraph_embs = [] + for subgraph_mask in subgraph_masks: + subgraph_embs.append(self.pooling_fn(node_embs[subgraph_mask])) + + return torch.cat(subgraph_embs, dim=0) + + def forward(self, input_nodes, seeds, blocks, train_embeddings=True, neg_sampling_strategy=None): + subgraph_masks, seeds = seeds + masked = self.masker.get_mask(self.seeds_to_python(seeds)) if self.masker is not None else None + graph_emb = self._graph_embeddings(input_nodes, blocks, train_embeddings, masked=masked, subgraph_masks=subgraph_masks) + subgraph_embs_, element_embs_, labels = self.prepare_for_prediction( + graph_emb, seeds, self.target_embedding_fn, negative_factor=self.negative_factor, + neg_sampling_strategy=neg_sampling_strategy, + train_embeddings=train_embeddings + ) + + acc, loss = self.compute_acc_loss(subgraph_embs_, element_embs_, labels) + + return loss, acc + + def evaluate_objective(self, data_split, neg_sampling_strategy=None, negative_factor=1): + at = [1, 3, 5, 10] + count = 0 + + scores = defaultdict(list) + + for input_nodes, seeds, blocks in tqdm( + getattr(self, f"{data_split}_loader"), total=getattr(self, f"num_{data_split}_batches") + ): + blocks = [blk.to(self.device) for blk in blocks] + + subgraph_masks, seeds = seeds + + if self.masker is None: + masked = None + else: + masked = self.masker.get_mask(self.seeds_to_python(seeds)) + + src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked, subgraph_masks=subgraph_masks) + node_embs_, 
element_embs_, labels = self.prepare_for_prediction( + src_embs, seeds, self.target_embedding_fn, negative_factor=negative_factor, + neg_sampling_strategy=neg_sampling_strategy, + train_embeddings=False + ) + + if self.measure_scores: + if count % self.dilate_scores == 0: + scores_ = self.target_embedder.score_candidates(self.seeds_to_global(seeds), src_embs, + self.link_predictor, at=at, + type=self.link_predictor_type, device=self.device) + for key, val in scores_.items(): + scores[key].append(val) + + acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) + + scores["Loss"].append(loss.item()) + scores["Accuracy"].append(acc) + count += 1 + + scores = {key: sum_scores(val) for key, val in scores.items()} + return scores + + def verify_parameters(self): + pass + + +class SubgraphEmbeddingObjective(SubgraphAbstractObjective): + def __init__( + self, name, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", + masker: Optional[SubwordMasker] = None, measure_scores=False, dilate_scores=1, + early_stopping=False, early_stopping_tolerance=20, nn_index="brute", + ns_groups=None, subgraph_mapping=None, subgraph_partition=None + ): + super(SubgraphEmbeddingObjective, self).__init__( + name, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path, target_emb_size, link_predictor_type, + masker, measure_scores, dilate_scores, early_stopping, early_stopping_tolerance, nn_index, + ns_groups, subgraph_mapping, subgraph_partition + ) + + +class SubgraphClassifierObjective(NodeClassifierObjective, SubgraphAbstractObjective): + def __init__( + self, name, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", + masker: Optional[SubwordMasker] = None, measure_scores=False, dilate_scores=1, + early_stopping=False, early_stopping_tolerance=20, nn_index="brute", + ns_groups=None, subgraph_mapping=None, subgraph_partition=None + ): + SubgraphAbstractObjective.__init__(self, + name, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path, target_emb_size, link_predictor_type, + masker, measure_scores, dilate_scores, early_stopping, early_stopping_tolerance, nn_index, + ns_groups, subgraph_mapping, subgraph_partition + ) + + def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): + self.target_embedder = SubgraphClassifierTargetMapper( + elements=data_loading_func() + ) + + def evaluate_objective(self, data_split, neg_sampling_strategy=None, negative_factor=1): + at = [1, 3, 5, 10] + count = 0 + scores = defaultdict(list) + + for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"): + blocks = [blk.to(self.device) for blk in blocks] + + subgraph_masks, seeds = seeds + + if self.masker is None: + masked = None + else: + masked = self.masker.get_mask(self.seeds_to_python(seeds)) + + src_embs = self._graph_embeddings(input_nodes, blocks, masked=masked, subgraph_masks=subgraph_masks) + node_embs_, element_embs_, labels = self.prepare_for_prediction( + src_embs, seeds, self.target_embedding_fn, negative_factor=negative_factor, + neg_sampling_strategy=neg_sampling_strategy, + train_embeddings=False + ) + # indices = self.seeds_to_global(seeds).tolist() + # labels = 
self.target_embedder[indices] + # labels = torch.LongTensor(labels).to(self.device) + acc, loss, logits = self.compute_acc_loss(node_embs_, element_embs_, labels, return_logits=True) + + y_pred = nn.functional.softmax(logits, dim=-1).to("cpu").numpy() + y_true = np.zeros(y_pred.shape) + y_true[np.arange(0, y_true.shape[0]), labels.to("cpu").numpy()] = 1. + + if self.measure_scores: + if count % self.dilate_scores == 0: + y_true_onehot = np.array(y_true) + labels = list(range(y_true_onehot.shape[1])) + + for k in at: + scores[f"ndcg@{k}"].append(ndcg_score(y_true, y_pred, k=k)) + scores[f"acc@{k}"].append( + top_k_accuracy_score(y_true_onehot.argmax(-1), y_pred, k=k, labels=labels) + ) + + scores["Loss"].append(loss.item()) + scores["Accuracy"].append(acc) + count += 1 + + if count == 0: + count += 1 + + scores = {key: sum_scores(val) for key, val in scores.items()} + return scores + + def parameters(self, recurse: bool = True): + return chain(self.classifier.parameters()) + + def custom_state_dict(self): + state_dict = OrderedDict() + for k, v in self.classifier.state_dict().items(): + state_dict[f"classifier.{k}"] = v + return state_dict + + +class SubgraphElementEmbedderBase(ElementEmbedderBase): + def __init__(self, elements, compact_dst=True): + # super(ElementEmbedderBase, self).__init__() + self.elements = elements.rename({"src": "id"}, axis=1) + self.init(compact_dst) + + def preprocess_element_data(self, *args, **kwargs): + pass + + def create_idx_pools(self, train_idx, val_idx, test_idx): + pool = set(self.elements["id"]) + train_pool, val_pool, test_pool = self._create_pools(train_idx, val_idx, test_idx, pool) + return train_pool, val_pool, test_pool + + +class SubgraphElementEmbedderWithSubwords(SubgraphElementEmbedderBase, ElementEmbedderWithBpeSubwords): + def __init__(self, elements, emb_size, tokenizer_path, num_buckets=100000, max_len=10): + self.tokenizer_path = tokenizer_path + SubgraphElementEmbedderBase.__init__(self, elements=elements, compact_dst=False) + nn.Module.__init__(self) + Scorer.__init__(self, num_embs=len(self.elements["dst"].unique()), emb_size=emb_size, + src2dst=self.element_lookup) + + self.emb_size = emb_size + self.init_subwords(elements, num_buckets=num_buckets, max_len=max_len) + + +class SubgraphClassifierTargetMapper(SubgraphElementEmbedderBase, Scorer): + def __init__(self, elements): + SubgraphElementEmbedderBase.__init__(self, elements=elements) + self.num_classes = len(self.inverse_dst_map) + + def set_embed(self, *args, **kwargs): + pass + + def prepare_index(self, *args): + pass \ No newline at end of file diff --git a/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py b/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py index ff83d6e0..86cc3782 100644 --- a/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py +++ b/SourceCodeTools/models/graph/train/objectives/SubwordEmbedderObjective.py @@ -1,7 +1,7 @@ from collections import OrderedDict from itertools import chain -from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker +from SourceCodeTools.code.data.dataset.SubwordMasker import SubwordMasker from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective @@ -10,14 +10,17 @@ def __init__( self, name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - 
measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1, nn_index="brute" ): super(SubwordEmbedderObjective, self).__init__( name, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index ) + self.target_embedding_fn = self.get_targets_from_embedder + self.negative_factor = 1 + self.update_embeddings_for_queries = False def verify_parameters(self): if self.link_predictor_type == "inner_prod": @@ -26,34 +29,6 @@ def verify_parameters(self): def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): self.create_subword_embedder(data_loading_func, nodes, tokenizer_path) - def create_link_predictor(self): - if self.link_predictor_type == "nn": - self.create_nn_link_predictor() - elif self.link_predictor_type == "inner_prod": - self.create_inner_prod_link_predictor() - else: - raise NotImplementedError() - - def forward(self, input_nodes, seeds, blocks, train_embeddings=True): - masked = self.masker.get_mask(self.seeds_to_python(seeds)) if self.masker is not None else None - graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings, masked=masked) - node_embs_, element_embs_, labels = self._logits_embedder( - graph_emb, self.target_embedder, self.link_predictor, seeds - ) - acc, loss = self.compute_acc_loss(node_embs_, element_embs_, labels) - - return loss, acc - - # def train(self): - # pass - - def evaluate(self, data_split, neg_sampling_factor=1): - loss, acc, ndcg = self._evaluate_embedder( - self.target_embedder, self.link_predictor, data_split=data_split, - neg_sampling_factor=neg_sampling_factor - ) - return loss, acc, ndcg - def parameters(self, recurse: bool = True): return chain(self.target_embedder.parameters(), self.link_predictor.parameters()) diff --git a/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py b/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py index 56ff5d8a..de941d73 100644 --- a/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py +++ b/SourceCodeTools/models/graph/train/objectives/TextPredictionObjective.py @@ -1,19 +1,18 @@ -from collections import OrderedDict +from collections import OrderedDict, defaultdict from itertools import chain import datasets import torch -from torch import nn -from torch.nn import CrossEntropyLoss, NLLLoss +from torch.nn import CrossEntropyLoss -from SourceCodeTools.code.data.sourcetrail import SubwordMasker +from SourceCodeTools.code.data.dataset import SubwordMasker from SourceCodeTools.models.graph.ElementEmbedder import DocstringEmbedder, create_fixed_length, \ ElementEmbedderWithBpeSubwords from SourceCodeTools.models.graph.train.objectives import SubwordEmbedderObjective -from SourceCodeTools.models.graph.train.objectives.AbstractObjective import AbstractObjective, _compute_accuracy -from SourceCodeTools.models.nlp.Decoder import LSTMDecoder, Decoder +from SourceCodeTools.models.graph.train.objectives.AbstractObjective import compute_accuracy, \ + sum_scores +from SourceCodeTools.models.nlp.TorchDecoder import Decoder from SourceCodeTools.models.nlp.Vocabulary import Vocabulary -from SourceCodeTools.nlp.embed.bpe import load_bpe_model import numpy as np @@ -22,13 +21,13 @@ def __init__( self, graph_model, 
node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1 ): super().__init__( "GraphTextPrediction", graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores ) def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): @@ -43,14 +42,14 @@ def __init__( self, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1, max_len=20 + measure_scores=False, dilate_scores=1, max_len=20 ): self.max_len = max_len + 2 # add pad and eos super().__init__( "GraphTextGeneration", graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores ) def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): @@ -85,17 +84,21 @@ def compute_acc_loss(self, graph_emb, labels, lengths, return_logits=False): max_len = logits.shape[1] length_mask = torch.arange(max_len).to(self.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) + logits_unrolled = logits[length_mask, :] + labels_unrolled = labels[length_mask] + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(logits_unrolled, labels_unrolled) # mask_ = length_mask.reshape(-1,) # loss_fct = NLLLoss(reduction="sum") - loss = loss_fct(logits.reshape(-1, logits.size(-1)),#[mask_, :], - labels.reshape(-1)) #[mask_]) + # loss = loss_fct(logits.reshape(-1, logits.size(-1)),#[mask_, :], + # labels.reshape(-1)) #[mask_]) def masked_accuracy(pred, true, mask): mask = mask.reshape(-1,) pred = pred.reshape(-1,)[mask] true = true.reshape(-1,)[mask] - return _compute_accuracy(pred, true) + return compute_accuracy(pred, true) acc = masked_accuracy(logits.argmax(dim=2), labels, length_mask) @@ -103,9 +106,8 @@ def masked_accuracy(pred, true, mask): return acc, loss, logits return acc, loss - - def forward(self, input_nodes, seeds, blocks, train_embeddings=True): - graph_emb = self._logits_batch(input_nodes, blocks, train_embeddings) + def forward(self, input_nodes, seeds, blocks, train_embeddings=True, neg_sampling_strategy=None): + graph_emb = self._graph_embeddings(input_nodes, blocks, train_embeddings) indices = self.seeds_to_global(seeds).tolist() labels, lengths = self.target_embedder[indices] labels = labels.to(self.device) @@ -129,16 +131,13 @@ def get_generated(self, tokens): return sents def evaluate_generation(self, data_split): - total_loss = 0 - total_acc = 0 - total_bleu = {f"bleu": 0.} - bleu_count = 0 + scores = defaultdict(list) count = 0 for input_nodes, seeds, blocks in getattr(self, f"{data_split}_loader"): blocks = [blk.to(self.device) for blk in blocks] - src_embs = self._logits_batch(input_nodes, blocks) + src_embs = 
self._graph_embeddings(input_nodes, blocks) indices = self.seeds_to_global(seeds).tolist() labels, lengths = self.target_embedder[indices] labels = labels.to(self.device) @@ -154,16 +153,19 @@ def evaluate_generation(self, data_split): # bleu = sacrebleu.corpus_bleu(pred, true) bleu = self.bleu_metric.compute(predictions=pred, references=[[t] for t in true]) - bleu_count += 1 - total_bleu["bleu"] += bleu['score'] + scores["bleu"].append(bleu['score']) - total_loss += loss.item() - total_acc += acc + scores["Loss"].append(loss.item()) + scores["Accuracy"].append(acc) count += 1 - return total_loss / count, total_acc / count, {"bleu": total_bleu["bleu"] / bleu_count} - def evaluate(self, data_split, neg_sampling_factor=1): + scores = {key: sum_scores(val) for key, val in scores.items()} + return scores + + def evaluate(self, data_split, *, neg_sampling_strategy=None, early_stopping=False, early_stopping_tolerance=20): loss, acc, bleu = self.evaluate_generation(data_split) + if data_split == "val": + self.check_early_stopping(acc) return loss, acc, bleu def parameters(self, recurse: bool = True): diff --git a/SourceCodeTools/models/graph/train/objectives/__init__.py b/SourceCodeTools/models/graph/train/objectives/__init__.py index 57b55612..9721cfa8 100644 --- a/SourceCodeTools/models/graph/train/objectives/__init__.py +++ b/SourceCodeTools/models/graph/train/objectives/__init__.py @@ -1,6 +1,8 @@ -from SourceCodeTools.code.data.sourcetrail.SubwordMasker import SubwordMasker -from SourceCodeTools.models.graph.train.objectives.GraphLinkObjective import GraphLinkObjective, GraphLinkTypeObjective -from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeNameClassifier +from SourceCodeTools.code.data.dataset.SubwordMasker import SubwordMasker +from SourceCodeTools.models.graph.train.objectives.GraphLinkClassificationObjective import \ + GraphLinkClassificationObjective +from SourceCodeTools.models.graph.train.objectives.GraphLinkObjective import GraphLinkObjective +from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeNameClassifier, NodeClassifierObjective from SourceCodeTools.models.graph.train.objectives.SubwordEmbedderObjective import SubwordEmbedderObjective from SourceCodeTools.models.graph.train.objectives.TextPredictionObjective import GraphTextPrediction, GraphTextGeneration @@ -10,13 +12,13 @@ def __init__( self, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1 ): super(TokenNamePrediction, self).__init__( "TokenNamePrediction", graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores ) @@ -25,13 +27,29 @@ def __init__( self, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1, nn_index="brute" ): super(NodeNamePrediction, self).__init__( "NodeNamePrediction", 
graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index + ) + + +class TypeAnnPrediction(NodeClassifierObjective): + def __init__( + self, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=None, target_emb_size=None, link_predictor_type=None, masker: SubwordMasker = None, + measure_scores=False, dilate_scores=1 + ): + super(TypeAnnPrediction, self).__init__( + "TypeAnnPrediction", graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + # tokenizer_path=tokenizer_path, + target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, + masker=masker, dilate_scores=dilate_scores ) @@ -40,13 +58,13 @@ def __init__( self, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1 ): super(VariableNameUsePrediction, self).__init__( "VariableNamePrediction", graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores ) @@ -55,13 +73,13 @@ def __init__( self, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1 ): super(NextCallPrediction, self).__init__( "NextCallPrediction", graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores ) @@ -70,26 +88,68 @@ def __init__( self, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1 ): super(GlobalLinkPrediction, self).__init__( "GlobalLinkPrediction", graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores ) -class LinkTypePrediction(GraphLinkTypeObjective): +# class LinkTypePrediction(GraphLinkTypeObjective): +# def __init__( +# self, graph_model, node_embedder, nodes, data_loading_func, device, +# sampling_neighbourhood_size, batch_size, +# tokenizer_path=None, 
target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, +# measure_scores=False, dilate_scores=1 +# ): +# super(GraphLinkTypeObjective, self).__init__( +# "LinkTypePrediction", graph_model, node_embedder, nodes, data_loading_func, device, +# sampling_neighbourhood_size, batch_size, +# tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, +# masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores +# ) + + +class EdgePrediction(GraphLinkObjective): def __init__( self, graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, - measure_ndcg=False, dilate_ndcg=1 + measure_scores=False, dilate_scores=1, nn_index="brute", ns_groups=None ): - super(GraphLinkTypeObjective, self).__init__( - "LinkTypePrediction", graph_model, node_embedder, nodes, data_loading_func, device, + super(EdgePrediction, self).__init__( + "EdgePrediction", graph_model, node_embedder, nodes, data_loading_func, device, sampling_neighbourhood_size, batch_size, tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, - masker=masker, measure_ndcg=measure_ndcg, dilate_ndcg=dilate_ndcg + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores, nn_index=nn_index, + ns_groups=ns_groups ) + # super(EdgePrediction, self).__init__( + # "LinkTypePrediction", graph_model, node_embedder, nodes, data_loading_func, device, + # sampling_neighbourhood_size, batch_size, + # tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, + # masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores + # ) + +class EdgePrediction2(SubwordEmbedderObjective): + def __init__( + self, graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=None, target_emb_size=None, link_predictor_type="inner_prod", masker: SubwordMasker = None, + measure_scores=False, dilate_scores=1 + ): + super(EdgePrediction2, self).__init__( + "EdgePrediction2", graph_model, node_embedder, nodes, data_loading_func, device, + sampling_neighbourhood_size, batch_size, + tokenizer_path=tokenizer_path, target_emb_size=target_emb_size, link_predictor_type=link_predictor_type, + masker=masker, measure_scores=measure_scores, dilate_scores=dilate_scores + ) + + def create_target_embedder(self, data_loading_func, nodes, tokenizer_path): + from SourceCodeTools.models.graph.ElementEmbedder import ElementEmbedder + self.target_embedder = ElementEmbedder( + elements=data_loading_func(), nodes=nodes, emb_size=self.target_emb_size, + ).to(self.device) \ No newline at end of file diff --git a/SourceCodeTools/models/graph/train/sampling_multitask2.py b/SourceCodeTools/models/graph/train/sampling_multitask2.py index 46c9d873..dbd1c394 100644 --- a/SourceCodeTools/models/graph/train/sampling_multitask2.py +++ b/SourceCodeTools/models/graph/train/sampling_multitask2.py @@ -1,7 +1,10 @@ import os +from collections import defaultdict from copy import copy +from pprint import pprint from typing import Tuple +import numpy as np import torch import torch.nn as nn from torch.utils.tensorboard import SummaryWriter @@ -10,68 +13,98 @@ from os.path import join import logging +from tqdm import tqdm + from SourceCodeTools.models.Embedder import Embedder from 
SourceCodeTools.models.graph.train.objectives import VariableNameUsePrediction, TokenNamePrediction, \ NextCallPrediction, NodeNamePrediction, GlobalLinkPrediction, GraphTextPrediction, GraphTextGeneration, \ - NodeNameClassifier + NodeNameClassifier, EdgePrediction, TypeAnnPrediction, EdgePrediction2, NodeClassifierObjective from SourceCodeTools.models.graph.NodeEmbedder import NodeEmbedder +from SourceCodeTools.models.graph.train.objectives.GraphLinkClassificationObjective import TransRObjective +from SourceCodeTools.models.graph.train.objectives.SubgraphClassifierObjective import SubgraphAbstractObjective, \ + SubgraphClassifierObjective, SubgraphEmbeddingObjective + + +class EarlyStopping(Exception): + def __init__(self, *args, **kwargs): + super(EarlyStopping, self).__init__(*args, **kwargs) -def _compute_accuracy(pred_, true_): - return torch.sum(pred_ == true_).item() / len(true_) +def add_to_summary(summary, partition, objective_name, scores, postfix): + summary.update({ + f"{key}/{partition}/{objective_name}_{postfix}": val for key, val in scores.items() + }) class SamplingMultitaskTrainer: - def __init__(self, - dataset=None, model_name=None, model_params=None, - trainer_params=None, restore=None, device=None, - pretrained_embeddings_path=None, - tokenizer_path=None - ): + def __init__( + self, dataset=None, model_name=None, model_params=None, trainer_params=None, restore=None, device=None, + pretrained_embeddings_path=None, tokenizer_path=None, load_external_dataset=None + ): self.graph_model = model_name(dataset.g, **model_params).to(device) self.model_params = model_params self.trainer_params = trainer_params self.device = device self.epoch = 0 + self.restore_epoch = 0 self.batch = 0 self.dtype = torch.float32 + if load_external_dataset is not None: + logging.info("Loading external dataset") + external_args, external_dataset = load_external_dataset() + self.graph_model.g = external_dataset.g + dataset = external_dataset self.create_node_embedder( dataset, tokenizer_path, n_dims=model_params["h_dim"], pretrained_path=pretrained_embeddings_path, n_buckets=trainer_params["embedding_table_size"] ) - self.summary_writer = SummaryWriter(self.model_base_path) - self.create_objectives(dataset, tokenizer_path) if restore: self.restore_from_checkpoint(self.model_base_path) - self.optimizer = self._create_optimizer() + if load_external_dataset is not None: + self.trainer_params["model_base_path"] = external_args.external_model_base + + self._create_optimizer() self.lr_scheduler = ExponentialLR(self.optimizer, gamma=1.0) # self.lr_scheduler = ReduceLROnPlateau(self.optimizer, patience=10, cooldown=20) + self.summary_writer = SummaryWriter(self.model_base_path) def create_objectives(self, dataset, tokenizer_path): + objective_list = self.trainer_params["objectives"] + self.objectives = nn.ModuleList() - if "token_pred" in self.trainer_params["objectives"]: + if "token_pred" in objective_list: self.create_token_pred_objective(dataset, tokenizer_path) - if "node_name_pred" in self.trainer_params["objectives"]: + if "node_name_pred" in objective_list: self.create_node_name_objective(dataset, tokenizer_path) - if "var_use_pred" in self.trainer_params["objectives"]: + if "var_use_pred" in objective_list: self.create_var_use_objective(dataset, tokenizer_path) - if "next_call_pred" in self.trainer_params["objectives"]: + if "next_call_pred" in objective_list: self.create_api_call_objective(dataset, tokenizer_path) - if "global_link_pred" in self.trainer_params["objectives"]: + if 
"global_link_pred" in objective_list: self.create_global_link_objective(dataset, tokenizer_path) - if "doc_pred" in self.trainer_params["objectives"]: + if "edge_pred" in objective_list: + self.create_edge_objective(dataset, tokenizer_path) + if "transr" in objective_list: + self.create_transr_objective(dataset, tokenizer_path) + if "doc_pred" in objective_list: self.create_text_prediction_objective(dataset, tokenizer_path) - if "doc_gen" in self.trainer_params["objectives"]: + if "doc_gen" in objective_list: self.create_text_generation_objective(dataset, tokenizer_path) - if "node_clf" in self.trainer_params["objectives"]: - self.create_node_name_classifier_objective(dataset, tokenizer_path) + if "node_clf" in objective_list: + self.create_node_classifier_objective(dataset, tokenizer_path) + if "type_ann_pred" in objective_list: + self.create_type_ann_objective(dataset, tokenizer_path) + if "subgraph_name_clf" in objective_list: + self.create_subgraph_name_objective(dataset, tokenizer_path) + if "subgraph_clf" in objective_list: + self.create_subgraph_classifier_objective(dataset, tokenizer_path) def create_token_pred_objective(self, dataset, tokenizer_path): self.objectives.append( @@ -80,21 +113,86 @@ def create_token_pred_objective(self, dataset, tokenizer_path): dataset.load_token_prediction, self.device, self.sampling_neighbourhood_size, self.batch_size, tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod", - masker=dataset.create_subword_masker(), measure_ndcg=self.trainer_params["measure_ndcg"], - dilate_ndcg=self.trainer_params["dilate_ndcg"] + masker=dataset.create_subword_masker(), measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"] ) ) def create_node_name_objective(self, dataset, tokenizer_path): self.objectives.append( + # GraphTextGeneration( + # self.graph_model, self.node_embedder, dataset.nodes, + # dataset.load_node_names, self.device, + # self.sampling_neighbourhood_size, self.batch_size, + # tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, + # ) NodeNamePrediction( self.graph_model, self.node_embedder, dataset.nodes, dataset.load_node_names, self.device, self.sampling_neighbourhood_size, self.batch_size, tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod", masker=dataset.create_node_name_masker(tokenizer_path), - measure_ndcg=self.trainer_params["measure_ndcg"], - dilate_ndcg=self.trainer_params["dilate_ndcg"] + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"], nn_index=self.trainer_params["nn_index"] + ) + ) + + def _create_subgraph_objective( + self, *, objective_name, objective_class, dataset, labels_fn, tokenizer_path, subgraph_mapping=None, + subgraph_partition=None, + masker=None + ): + if subgraph_mapping is None: + subgraph_mapping = dataset.subgraph_mapping + + if subgraph_partition is None: + subgraph_partition = dataset.subgraph_partition + + return objective_class( + objective_name, + self.graph_model, self.node_embedder, dataset.nodes, + labels_fn, self.device, + self.sampling_neighbourhood_size, self.batch_size, + tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod", + masker=masker, # dataset.create_node_name_masker(tokenizer_path), + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"], nn_index=self.trainer_params["nn_index"], + 
subgraph_mapping=subgraph_mapping, + subgraph_partition=subgraph_partition + ) + + def create_subgraph_name_objective(self, dataset, tokenizer_path): + self.objectives.append( + self._create_subgraph_objective( + objective_name="SubgraphNameEmbeddingObjective", + objective_class=SubgraphEmbeddingObjective, + dataset=dataset, + tokenizer_path=tokenizer_path, + labels_fn=dataset.load_subgraph_function_names, + ) + ) + + def create_subgraph_classifier_objective(self, dataset, tokenizer_path): + self.objectives.append( + self._create_subgraph_objective( + objective_name="SubgraphClassifierObjective", + objective_class=SubgraphClassifierObjective, + dataset=dataset, + tokenizer_path=tokenizer_path, + labels_fn=dataset.load_cubert_subgraph_labels, + ) + ) + + def create_type_ann_objective(self, dataset, tokenizer_path): + self.objectives.append( + TypeAnnPrediction( + self.graph_model, self.node_embedder, dataset.nodes, + dataset.load_type_prediction, self.device, + self.sampling_neighbourhood_size, self.batch_size, + tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod", + masker=None, + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"] ) ) @@ -106,8 +204,8 @@ def create_node_name_classifier_objective(self, dataset, tokenizer_path): self.sampling_neighbourhood_size, self.batch_size, tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, masker=dataset.create_node_name_masker(tokenizer_path), - measure_ndcg=self.trainer_params["measure_ndcg"], - dilate_ndcg=self.trainer_params["dilate_ndcg"] + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"] ) ) @@ -119,8 +217,8 @@ def create_var_use_objective(self, dataset, tokenizer_path): self.sampling_neighbourhood_size, self.batch_size, tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod", masker=dataset.create_variable_name_masker(tokenizer_path), - measure_ndcg=self.trainer_params["measure_ndcg"], - dilate_ndcg=self.trainer_params["dilate_ndcg"] + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"] ) ) @@ -131,8 +229,8 @@ def create_api_call_objective(self, dataset, tokenizer_path): dataset.load_api_call, self.device, self.sampling_neighbourhood_size, self.batch_size, tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod", - measure_ndcg=self.trainer_params["measure_ndcg"], - dilate_ndcg=self.trainer_params["dilate_ndcg"] + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"] ) ) @@ -145,8 +243,35 @@ def create_global_link_objective(self, dataset, tokenizer_path): dataset.load_global_edges_prediction, self.device, self.sampling_neighbourhood_size, self.batch_size, tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="nn", - measure_ndcg=self.trainer_params["measure_ndcg"], - dilate_ndcg=self.trainer_params["dilate_ndcg"] + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"] + ) + ) + + def create_edge_objective(self, dataset, tokenizer_path): + self.objectives.append( + EdgePrediction( + self.graph_model, self.node_embedder, dataset.nodes, + dataset.load_edge_prediction, self.device, + self.sampling_neighbourhood_size, self.batch_size, + tokenizer_path=tokenizer_path, 
target_emb_size=self.elem_emb_size, link_predictor_type=self.trainer_params["metric"], + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"], nn_index=self.trainer_params["nn_index"], + # ns_groups=dataset.get_negative_sample_groups() + ) + ) + + def create_transr_objective(self, dataset, tokenizer_path): + self.objectives.append( + TransRObjective( + self.graph_model, self.node_embedder, dataset.nodes, + dataset.load_edge_prediction, self.device, + self.sampling_neighbourhood_size, self.batch_size, + tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, + link_predictor_type=self.trainer_params["metric"], + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"], + # ns_groups=dataset.get_negative_sample_groups() ) ) @@ -157,8 +282,8 @@ def create_text_prediction_objective(self, dataset, tokenizer_path): dataset.load_docstring, self.device, self.sampling_neighbourhood_size, self.batch_size, tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, link_predictor_type="inner_prod", - measure_ndcg=self.trainer_params["measure_ndcg"], - dilate_ndcg=self.trainer_params["dilate_ndcg"] + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"] ) ) @@ -172,6 +297,22 @@ def create_text_generation_objective(self, dataset, tokenizer_path): ) ) + def create_node_classifier_objective(self, dataset, tokenizer_path): + self.objectives.append( + NodeClassifierObjective( + "NodeTypeClassifier", + self.graph_model, self.node_embedder, dataset.nodes, + dataset.load_node_classes, self.device, + self.sampling_neighbourhood_size, self.batch_size, + tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, + masker=dataset.create_node_clf_masker(), + measure_scores=self.trainer_params["measure_scores"], + dilate_scores=self.trainer_params["dilate_scores"], + early_stopping=self.trainer_params["early_stopping"], + early_stopping_tolerance=self.trainer_params["early_stopping_tolerance"] + ) + ) + def create_node_embedder(self, dataset, tokenizer_path, n_dims=None, pretrained_path=None, n_buckets=500000): from SourceCodeTools.nlp.embed.fasttext import load_w2v_map @@ -204,7 +345,7 @@ def create_node_embedder(self, dataset, tokenizer_path, n_dims=None, pretrained_ @property def lr(self): - return self.trainer_params['lr'] + return self.trainer_params['learning_rate'] @property def batch_size(self): @@ -248,6 +389,10 @@ def finetune(self): return False return self.epoch >= self.trainer_params['pretraining_phase'] + @property + def subgraph_id_column(self): + return self.trainer_params["subgraph_id_column"] + @property def do_save(self): return self.trainer_params['save_checkpoints'] @@ -271,14 +416,31 @@ def write_hyperparams(self, scores, epoch): def _create_optimizer(self): parameters = nn.ParameterList(self.graph_model.parameters()) - parameters.extend(self.node_embedder.parameters()) + nodeembedder_params = list(self.node_embedder.parameters()) + # parameters.extend(self.node_embedder.parameters()) [parameters.extend(objective.parameters()) for objective in self.objectives] # AdaHessian TODO could not run # optimizer = Yogi(parameters, lr=self.lr) - optimizer = torch.optim.AdamW( - [{"params": parameters}], lr=self.lr + self.optimizer = torch.optim.AdamW( + [{"params": parameters}], lr=self.lr, weight_decay=0.5 + ) + self.sparse_optimizer = torch.optim.SparseAdam( + [{"params": nodeembedder_params}], lr=self.lr ) - return 
optimizer + + def compute_embeddings_for_scorer(self, objective): + if hasattr(objective.target_embedder, "scorer_all_keys") and objective.update_embeddings_for_queries: + def chunks(lst, n): + for i in range(0, len(lst), n): + yield torch.LongTensor(lst[i:i + n]) + + batches = chunks(objective.target_embedder.scorer_all_keys, self.trainer_params["batch_size"]) + for batch in tqdm( + batches, + total=len(objective.target_embedder.scorer_all_keys) // self.trainer_params["batch_size"] + 1, + desc="Precompute Target Embeddings", leave=True + ): + _ = objective.target_embedding_fn(batch) # scorer embedding updated inside def train_all(self): """ @@ -287,9 +449,12 @@ def train_all(self): """ summary_dict = {} + best_val_loss = float("inf") + write_best_model = False for objective in self.objectives: with torch.set_grad_enabled(False): + self.compute_embeddings_for_scorer(objective) objective.target_embedder.prepare_index() # need this to update sampler for the next epoch for epoch in range(self.epoch, self.epochs): @@ -297,10 +462,18 @@ def train_all(self): start = time() - keep_training = True - summary_dict = {} - while keep_training: + num_batches = min([objective.num_train_batches for objective in self.objectives]) + + train_losses = defaultdict(list) + train_accs = defaultdict(list) + + # def append_metric(destination, name, metric): + # if name not in destination: + # destination[name] = [] + # destination[name].append(metric) + + for step in tqdm(range(num_batches), total=num_batches, desc=f"Epoch {self.epoch}"): loss_accum = 0 @@ -312,10 +485,24 @@ def train_all(self): break self.optimizer.zero_grad() - for objective, (input_nodes, seeds, blocks) in zip(self.objectives, loaders): + self.sparse_optimizer.zero_grad() + for ind, (objective, (input_nodes, seeds, blocks)) in enumerate(zip(self.objectives, loaders)): blocks = [blk.to(self.device) for blk in blocks] - loss, acc = objective(input_nodes, seeds, blocks, train_embeddings=self.finetune) + objective.target_embedder.prepare_index() + + do_break = False + for block in blocks: + if block.num_edges() == 0: + do_break = True + if do_break: + break + + # try: + loss, acc = objective( + input_nodes, seeds, blocks, train_embeddings=self.finetune, + neg_sampling_strategy="w2v" if self.trainer_params["force_w2v_ns"] else None + ) loss = loss / len(self.objectives) # assumes the same batch size for all objectives loss_accum += loss.item() @@ -323,21 +510,34 @@ def train_all(self): # for param in groups["params"]: # torch.nn.utils.clip_grad_norm_(param, max_norm=1.) 
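# Descriptive note on the optimizer split introduced in this patch: loss.backward() fills
# gradients for two parameter groups. The dense parameters (graph model plus objective heads)
# are updated by the AdamW optimizer created in _create_optimizer, while the node-embedder
# parameters are handled separately by SparseAdam; both optimizers are zeroed before the
# objective loop and both must be stepped after the backward pass.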
loss.backward() # create_graph = True + + summary = {} + add_to_summary( + summary=summary, partition="train", objective_name=objective.name, + scores={"Loss": loss.item(), "Accuracy": acc}, postfix="" + ) + + train_losses[f"Loss/train_avg/{objective.name}"].append(loss.item()) + train_accs[f"Accuracy/train_avg/{objective.name}"].append(acc) - summary.update({ - f"Loss/train/{objective.name}_vs_batch": loss.item(), - f"Accuracy/train/{objective.name}_vs_batch": acc, - }) + # except ZeroEdges as e: + # logging.warning(f"Zero edges in loader in step {step}") + # except Exception as e: + # raise e self.optimizer.step() + self.sparse_optimizer.step() + step += 1 self.write_summary(summary, self.batch) summary_dict.update(summary) self.batch += 1 - summary = { - f"Loss/train": loss_accum, - } + # summary = { + # f"Loss/train": loss_accum, + # } + summary = {key: sum(val) / len(val) for key, val in train_losses.items()} + summary.update({key: sum(val) / len(val) for key, val in train_accs.items()}) self.write_summary(summary, self.batch) summary_dict.update(summary) @@ -350,38 +550,43 @@ def train_all(self): with torch.set_grad_enabled(False): objective.target_embedder.prepare_index() # need this to update sampler for the next epoch - val_loss, val_acc, val_ndcg = objective.evaluate("val") - test_loss, test_acc, test_ndcg = objective.evaluate("test") - - summary = { - f"Accuracy/test/{objective.name}_vs_batch": test_acc, - f"Accuracy/val/{objective.name}_vs_batch": val_acc, - } - if test_ndcg is not None: - summary.update({f"{key}/test/{objective.name}_vs_batch": val for key, val in test_ndcg.items()}) - if val_ndcg is not None: - summary.update({f"{key}/val/{objective.name}_vs_batch": val for key, val in val_ndcg.items()}) + val_scores = objective.evaluate("val") + test_scores = objective.evaluate("test") + + summary = {} + add_to_summary(summary, "val", objective.name, val_scores, postfix="") + add_to_summary(summary, "test", objective.name, test_scores, postfix="") self.write_summary(summary, self.batch) summary_dict.update(summary) + val_losses = [item for key, item in summary_dict.items() if key.startswith("Loss/val")] + avg_val_loss = sum(val_losses) / len(val_losses) + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + write_best_model = True + + if self.do_save: + self.save_checkpoint(self.model_base_path, write_best_model=write_best_model) + write_best_model = False + + if objective.early_stopping_trigger is True: + raise EarlyStopping() + objective.train() # self.write_hyperparams({k.replace("vs_batch", "vs_epoch"): v for k, v in summary_dict.items()}, self.epoch) end = time() - print(f"Epoch: {self.epoch}, Time: {int(end - start)} s", end="\t") - print(summary_dict) - - if self.do_save: - self.save_checkpoint(self.model_base_path) + print(f"Epoch: {self.epoch}, Time: {int(end - start)} s", end="\n") + pprint(summary_dict) self.lr_scheduler.step() - def save_checkpoint(self, checkpoint_path=None, checkpoint_name=None, **kwargs): + def save_checkpoint(self, checkpoint_path=None, checkpoint_name=None, write_best_model=False, **kwargs): - checkpoint_path = join(checkpoint_path, "saved_state.pt") + model_path = join(checkpoint_path, f"saved_state.pt") param_dict = { 'graph_model': self.graph_model.state_dict(), @@ -396,15 +601,21 @@ def save_checkpoint(self, checkpoint_path=None, checkpoint_name=None, **kwargs): if len(kwargs) > 0: param_dict.update(kwargs) - torch.save(param_dict, checkpoint_path) + torch.save(param_dict, model_path) + if 
self.trainer_params["save_each_epoch"]: + torch.save(param_dict, join(checkpoint_path, f"saved_state_{self.epoch}.pt")) + + if write_best_model: + torch.save(param_dict, join(checkpoint_path, f"best_model.pt")) def restore_from_checkpoint(self, checkpoint_path): - checkpoint = torch.load(join(checkpoint_path, "saved_state.pt")) + checkpoint = torch.load(join(checkpoint_path, "saved_state.pt"), map_location=torch.device('cpu')) self.graph_model.load_state_dict(checkpoint['graph_model']) self.node_embedder.load_state_dict(checkpoint['node_embedder']) for objective in self.objectives: objective.custom_load_state_dict(checkpoint[objective.name]) self.epoch = checkpoint['epoch'] + self.restore_epoch = checkpoint['epoch'] self.batch = checkpoint['batch'] logging.info(f"Restored from epoch {checkpoint['epoch']}") # TODO needs test @@ -417,27 +628,24 @@ def final_evaluation(self): objective.reset_iterator("train") objective.reset_iterator("val") objective.reset_iterator("test") + # objective.early_stopping = False + self.compute_embeddings_for_scorer(objective) + objective.target_embedder.prepare_index() + objective.update_embeddings_for_queries = False with torch.set_grad_enabled(False): for objective in self.objectives: - train_loss, train_acc, train_ndcg = objective.evaluate("train") - val_loss, val_acc, val_ndcg = objective.evaluate("val") - test_loss, test_acc, test_ndcg = objective.evaluate("test") - - summary = { - f"Accuracy/train/{objective.name}_final": train_acc, - f"Accuracy/test/{objective.name}_final": test_acc, - f"Accuracy/val/{objective.name}_final": val_acc, - } - if train_ndcg is not None: - summary.update({f"{key}/train/{objective.name}_final": val for key, val in train_ndcg.items()}) - if val_ndcg is not None: - summary.update({f"{key}/val/{objective.name}_final": val for key, val in val_ndcg.items()}) - if test_ndcg is not None: - summary.update({f"{key}/test/{objective.name}_final": val for key, val in test_ndcg.items()}) - + # train_scores = objective.evaluate("train") + val_scores = objective.evaluate("val") + test_scores = objective.evaluate("test") + + summary = {} + # add_to_summary(summary, "train", objective.name, train_scores, postfix="final") + add_to_summary(summary, "val", objective.name, val_scores, postfix="final") + add_to_summary(summary, "test", objective.name, test_scores, postfix="final") + summary_dict.update(summary) # self.write_hyperparams(summary_dict, self.epoch) @@ -474,7 +682,8 @@ def get_embeddings(self): for ntype in self.graph_model.g.ntypes } - h = self.graph_model.inference(batch_size=256, device='cpu', num_workers=0, x=node_embs) + logging.info("Computing all embeddings") + h = self.graph_model.inference(batch_size=2048, device='cpu', num_workers=0, x=node_embs) original_id = [] global_id = [] @@ -491,62 +700,90 @@ def get_embeddings(self): def select_device(args): device = 'cpu' - use_cuda = args.gpu >= 0 and torch.cuda.is_available() + use_cuda = args["gpu"] >= 0 and torch.cuda.is_available() if use_cuda: - torch.cuda.set_device(args.gpu) - device = 'cuda:%d' % args.gpu + torch.cuda.set_device(args["gpu"]) + device = 'cuda:%d' % args["gpu"] return device -def training_procedure( - dataset, model_name, model_params, args, model_base_path -) -> Tuple[SamplingMultitaskTrainer, dict]: +def resolve_activation_function(function_name): + known_functions = { + "tanh": torch.tanh + } - device = select_device(args) + return known_functions.get(function_name, eval(f"nn.functional.{function_name}")) - model_params['num_classes'] = args.node_emb_size - 
model_params['use_gcn_checkpoint'] = args.use_gcn_checkpoint - model_params['use_att_checkpoint'] = args.use_att_checkpoint - model_params['use_gru_checkpoint'] = args.use_gru_checkpoint - trainer_params = { - 'lr': model_params.pop('lr'), - 'batch_size': args.batch_size, - 'sampling_neighbourhood_size': args.num_per_neigh, - 'neg_sampling_factor': args.neg_sampling_factor, - 'epochs': args.epochs, - # 'node_name_file': args.fname_file, - # 'var_use_file': args.varuse_file, - # 'call_seq_file': args.call_seq_file, - 'elem_emb_size': args.elem_emb_size, - 'model_base_path': model_base_path, - 'pretraining_phase': args.pretraining_phase, - 'use_layer_scheduling': args.use_layer_scheduling, - 'schedule_layers_every': args.schedule_layers_every, - 'embedding_table_size': args.embedding_table_size, - 'save_checkpoints': args.save_checkpoints, - 'measure_ndcg': args.measure_ndcg, - 'dilate_ndcg': args.dilate_ndcg, - "objectives": args.objectives.split(",") - } - - trainer = SamplingMultitaskTrainer( +def training_procedure( + dataset, model_name, model_params, trainer_params, model_base_path, + tokenizer_path=None, trainer=None, load_external_dataset=None +) -> Tuple[SamplingMultitaskTrainer, dict]: + model_params = copy(model_params) + trainer_params = copy(trainer_params) + + if trainer is None: + trainer = SamplingMultitaskTrainer + + device = select_device(trainer_params) + + # model_params['num_classes'] = args.node_emb_size + # model_params['use_gcn_checkpoint'] = args.use_gcn_checkpoint + # model_params['use_att_checkpoint'] = args.use_att_checkpoint + # model_params['use_gru_checkpoint'] = args.use_gru_checkpoint + + if len(trainer_params["objectives"].split(",")) > 1 and trainer_params["early_stopping"] is True: + print("Early stopping disabled when several objectives are used") + trainer_params["early_stopping"] = False + + trainer_params['model_base_path'] = model_base_path + + model_params["activation"] = resolve_activation_function(model_params["activation"]) + + # trainer_params = { + # 'lr': model_params.pop('lr'), + # 'batch_size': args.batch_size, + # 'sampling_neighbourhood_size': args.num_per_neigh, + # 'neg_sampling_factor': args.neg_sampling_factor, + # 'epochs': args.epochs, + # 'elem_emb_size': args.elem_emb_size, + # 'model_base_path': model_base_path, + # 'pretraining_phase': args.pretraining_phase, + # 'use_layer_scheduling': args.use_layer_scheduling, + # 'schedule_layers_every': args.schedule_layers_every, + # 'embedding_table_size': args.embedding_table_size, + # 'save_checkpoints': args.save_checkpoints, + # 'measure_scores': args.measure_scores, + # 'dilate_scores': args.dilate_scores, + # "objectives": args.objectives.split(","), + # "early_stopping": args.early_stopping, + # "early_stopping_tolerance": args.early_stopping_tolerance, + # "force_w2v_ns": args.force_w2v_ns, + # "metric": args.metric, + # "nn_index": args.nn_index, + # "save_each_epoch": args.save_each_epoch + # } + + trainer = trainer( dataset=dataset, model_name=model_name, model_params=model_params, trainer_params=trainer_params, - restore=args.restore_state, + restore=trainer_params["restore_state"], device=device, - pretrained_embeddings_path=args.pretrained, - tokenizer_path=args.tokenizer + pretrained_embeddings_path=trainer_params["pretrained"], + tokenizer_path=tokenizer_path, + load_external_dataset=load_external_dataset ) - try: - trainer.train_all() - except KeyboardInterrupt: - print("Training interrupted") - except Exception as e: - raise e + # try: + trainer.train_all() + # except 
KeyboardInterrupt: + # logging.info("Training interrupted") + # except EarlyStopping: + # logging.info("Early stopping triggered") + # except Exception as e: + # print("There was an exception", e) trainer.eval() scores = trainer.final_evaluation() @@ -557,8 +794,12 @@ def training_procedure( def evaluation_procedure( - dataset, model_name, model_params, args, model_base_path -) -> Tuple[SamplingMultitaskTrainer, dict]: + dataset, model_name, model_params, args, model_base_path, trainer=None +): + model_params = copy(model_params) + + if trainer is None: + trainer = SamplingMultitaskTrainer device = select_device(args) @@ -575,9 +816,6 @@ def evaluation_procedure( 'sampling_neighbourhood_size': args.num_per_neigh, 'neg_sampling_factor': args.neg_sampling_factor, 'epochs': args.epochs, - # 'node_name_file': args.fname_file, - # 'var_use_file': args.varuse_file, - # 'call_seq_file': args.call_seq_file, 'elem_emb_size': args.elem_emb_size, 'model_base_path': model_base_path, 'pretraining_phase': args.pretraining_phase, @@ -585,11 +823,11 @@ def evaluation_procedure( 'schedule_layers_every': args.schedule_layers_every, 'embedding_table_size': args.embedding_table_size, 'save_checkpoints': args.save_checkpoints, - 'measure_ndcg': args.measure_ndcg, - 'dilate_ndcg': args.dilate_ndcg + 'measure_scores': args.measure_scores, + 'dilate_scores': args.dilate_scores } - trainer = SamplingMultitaskTrainer( + trainer = trainer( #SamplingMultitaskTrainer( dataset=dataset, model_name=model_name, model_params=model_params, diff --git a/SourceCodeTools/models/graph/train/test_rggan.py b/SourceCodeTools/models/graph/train/test_rggan.py new file mode 100644 index 00000000..ca804ddf --- /dev/null +++ b/SourceCodeTools/models/graph/train/test_rggan.py @@ -0,0 +1,537 @@ +import itertools +import logging +from copy import copy +from datetime import datetime +from os import mkdir +from os.path import isdir + +import pandas as pd + +from dgl.data import FB15kDataset, FB15k237Dataset, AMDataset +import torch +from sklearn.model_selection import ParameterGrid + +from SourceCodeTools.code.data.dataset.Dataset import SourceGraphDataset +from SourceCodeTools.models.graph import RGGAN +from SourceCodeTools.models.graph.NodeEmbedder import NodeIdEmbedder +from SourceCodeTools.models.graph.train.sampling_multitask2 import SamplingMultitaskTrainer +from SourceCodeTools.models.graph.train.utils import get_name, get_model_base +from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments + +rggan_grids = [ + { + 'h_dim': [15], + 'num_bases': [-1], + 'num_steps': [5], + 'dropout': [0.0], + 'use_self_loop': [False], + 'activation': [torch.tanh], # torch.nn.functional.hardswish], #[torch.nn.functional.hardtanh], #torch.nn.functional.leaky_relu + 'lr': [1e-3], # 1e-4] + } +] + +rggan_params = list( + itertools.chain.from_iterable( + [ParameterGrid(p) for p in rggan_grids] + ) +) + + +class TestNodeClfGraph(SourceGraphDataset): + def __init__(self, data_loader, + label_from=None, use_node_types=True, + use_edge_types=True, filter=None, self_loops=False, + train_frac=0.6, random_seed=None, tokenizer_path=None, min_count_for_objectives=1, + no_global_edges=False, remove_reverse=False, package_names=None): + self.random_seed = random_seed + self.nodes_have_types = use_node_types + self.edges_have_types = use_edge_types + self.labels_from = label_from + # self.data_path = data_path + self.tokenizer_path = tokenizer_path + self.min_count_for_objectives = min_count_for_objectives + self.no_global_edges = 
no_global_edges + self.remove_reverse = remove_reverse + + # dataset = AIFBDataset() + # dataset = MUTAGDataset() + # dataset = BGSDataset() + # dataset = AMDataset() + dataset = data_loader() + + self.nodes, self.edges, self.typed_id_map = self.create_nodes_edges_df(dataset) + + # index is later used for sampling and is assumed to be unique + assert len(self.nodes) == len(self.nodes.index.unique()) + assert len(self.edges) == len(self.edges.index.unique()) + + if self_loops: + self.nodes, self.edges = SourceGraphDataset._assess_need_for_self_loops(self.nodes, self.edges) + + if filter is not None: + for e_type in filter.split(","): + logging.info(f"Filtering edge type {e_type}") + self.edges = self.edges.query(f"type != '{e_type}'") + + # if self.remove_reverse: + # self.remove_reverse_edges() + # + # if self.no_global_edges: + # self.remove_global_edges() + + if use_node_types is False and use_edge_types is False: + new_nodes, new_edges = self._create_nodetype_edges() + self.nodes = self.nodes.append(new_nodes, ignore_index=True) + self.edges = self.edges.append(new_edges, ignore_index=True) + + self.nodes['type_backup'] = self.nodes['type'] + if not self.nodes_have_types: + self.nodes['type'] = "node_" + self.nodes = self.nodes.astype({'type': 'category'}) + + self.add_embeddable_flag() + + # need to do this to avoid issues insode dgl library + self.edges['type'] = self.edges['type'].apply(lambda x: f"{x}_") + self.edges['type_backup'] = self.edges['type'] + if not self.edges_have_types: + self.edges['type'] = "edge_" + self.edges = self.edges.astype({'type': 'category'}) + + # compact labels + # self.nodes['label'] = self.nodes[label_from] + # self.nodes = self.nodes.astype({'label': 'category'}) + # self.label_map = compact_property(self.nodes['label']) + # assert any(pandas.isna(self.nodes['label'])) is False + + logging.info(f"Unique nodes: {len(self.nodes)}, node types: {len(self.nodes['type'].unique())}") + logging.info(f"Unique edges: {len(self.edges)}, edge types: {len(self.edges['type'].unique())}") + + # self.nodes, self.label_map = self.add_compact_labels() + self._add_typed_ids() + + # self.add_splits(train_frac=train_frac, package_names=package_names) + + # self.mark_leaf_nodes() + + self._create_hetero_graph() + + self._update_global_id() + + self.nodes.sort_values('global_graph_id', inplace=True) + + # self.splits = SourceGraphDataset.get_global_graph_id_splits(self.nodes) + + def create_nodes_edges_df(self, dataset): + graph = dataset[0] + nodes_df = None + + node_id_map = {} + typed_id_map = {} + + for ntype in graph.ntypes: + typed_id = graph.nodes(ntype=ntype).tolist() + type = [ntype] * len(typed_id) + name = list(map(lambda x: ntype+f"_{x}", typed_id)) + id_ = graph.nodes[ntype].data["_ID"].tolist() + node_dict = {"id": id_, "type": type, "name": name, "typed_id": typed_id} + + node_id_map[ntype] = dict(zip(typed_id, id_)) + typed_id_map[ntype] = dict(zip(id_, typed_id)) + + if "labels" in graph.nodes[ntype].data: + node_dict["labels"] = graph.nodes[ntype].data["labels"].tolist() + else: + node_dict["labels"] = [-1] * len(typed_id) + + if "train_mask" in graph.nodes[ntype].data: + node_dict["train_mask"] = graph.nodes[ntype].data["train_mask"].bool().tolist() + else: + node_dict["train_mask"] = [False] * len(typed_id) + + if "test_mask" in graph.nodes[ntype].data: + node_dict["test_mask"] = graph.nodes[ntype].data["test_mask"].bool().tolist() + else: + node_dict["test_mask"] = [False] * len(typed_id) + + if "val_mask" in graph.nodes[ntype].data: + 
node_dict["val_mask"] = graph.nodes[ntype].data["val_mask"].bool().tolist() + else: + node_dict["val_mask"] = [False] * len(typed_id) + + if nodes_df is None: + nodes_df = pd.DataFrame.from_dict(node_dict) + else: + nodes_df = nodes_df.append(pd.DataFrame.from_dict(node_dict)) + + assert len(nodes_df["id"]) == len(nodes_df["id"].unique()) # thus is a must have assert + + nodes_df = nodes_df.reset_index(drop=True) + # contiguous_node_index = dict(zip(nodes_df["index"], nodes_df.index)) + + edges_df = None + for srctype, etype, dsttype in graph.canonical_etypes: + src, dst = graph.edges(etype=(srctype, etype, dsttype)) + src = list(map(lambda x: node_id_map[srctype][x], src.tolist())) + dst = list(map(lambda x: node_id_map[dsttype][x], dst.tolist())) + + edge_data = { + "id": graph.edata["_ID"][(srctype, etype, dsttype)].tolist(), + "type": [etype] * len(src), + "src": src, + "dst": dst + } + + if edges_df is None: + edges_df = pd.DataFrame.from_dict(edge_data) + else: + edges_df = edges_df.append(pd.DataFrame.from_dict(edge_data)) + + # assert len(edges_df["id"]) == len(edges_df["id"].unique()) # fails with AMDataset + + edges_df = edges_df.reset_index(drop=True) + + return nodes_df, edges_df, typed_id_map + + def add_embeddable_flag(self): + self.nodes['embeddable'] = True + + def _add_typed_ids(self): + pass + + def load_node_classes(self): + labels = self.nodes.query("train_mask == True or test_mask == True or val_mask == True")[["id", "labels"]].rename({ + "id": "src", + "labels": "dst" + }, axis=1) + return labels + + +class TestLinkPredGraph(TestNodeClfGraph): + def __init__( + self, data_loader, label_from=None, use_node_types=True, + use_edge_types=True, filter=None, self_loops=False, + train_frac=0.6, random_seed=None, tokenizer_path=None, min_count_for_objectives=1, + no_global_edges=False, remove_reverse=False, package_names=None + ): + self.random_seed = random_seed + self.nodes_have_types = use_node_types + self.edges_have_types = use_edge_types + self.labels_from = label_from + # self.data_path = data_path + self.tokenizer_path = tokenizer_path + self.min_count_for_objectives = min_count_for_objectives + self.no_global_edges = no_global_edges + self.remove_reverse = remove_reverse + + dataset = data_loader() + + self.nodes, self.edges, self.typed_id_map = self.create_nodes_edges_df(dataset) + + # index is later used for sampling and is assumed to be unique + assert len(self.nodes) == len(self.nodes.index.unique()) + assert len(self.edges) == len(self.edges.index.unique()) + + if self_loops: + self.nodes, self.edges = SourceGraphDataset._assess_need_for_self_loops(self.nodes, self.edges) + self.nodes, self.val_edges = SourceGraphDataset._assess_need_for_self_loops(self.nodes, self.val_edges) + self.nodes, self.test_edges = SourceGraphDataset._assess_need_for_self_loops(self.nodes, self.test_edges) + + if filter is not None: + for e_type in filter.split(","): + logging.info(f"Filtering edge type {e_type}") + self.edges = self.edges.query(f"type != '{e_type}'") + self.val_edges = self.val_edges.query(f"type != '{e_type}'") + self.test_edges = self.test_edges.query(f"type != '{e_type}'") + + if use_node_types is False and use_edge_types is False: + new_nodes, new_edges = self._create_nodetype_edges() + self.nodes = self.nodes.append(new_nodes, ignore_index=True) + self.edges = self.edges.append(new_edges, ignore_index=True) + + self.nodes['type_backup'] = self.nodes['type'] + if not self.nodes_have_types: + self.nodes['type'] = "node_" + self.nodes = self.nodes.astype({'type': 
'category'}) + + self.add_embeddable_flag() + + # need to do this to avoid issues insode dgl library + self.edges['type'] = self.edges['type'].apply(lambda x: f"{x}_") + self.edges['type_backup'] = self.edges['type'] + if not self.edges_have_types: + self.edges['type'] = "edge_" + self.edges = self.edges.astype({'type': 'category'}) + + # compact labels + # self.nodes['label'] = self.nodes[label_from] + # self.nodes = self.nodes.astype({'label': 'category'}) + # self.label_map = compact_property(self.nodes['label']) + # assert any(pandas.isna(self.nodes['label'])) is False + + logging.info(f"Unique nodes: {len(self.nodes)}, node types: {len(self.nodes['type'].unique())}") + logging.info(f"Unique edges: {len(self.edges)}, edge types: {len(self.edges['type'].unique())}") + + # self.nodes, self.label_map = self.add_compact_labels() + self._add_typed_ids() + + # self.add_splits(train_frac=train_frac, package_names=package_names) + + # self.mark_leaf_nodes() + + self._create_hetero_graph() + + self._update_global_id() + + self.nodes.sort_values('global_graph_id', inplace=True) + + # self.splits = SourceGraphDataset.get_global_graph_id_splits(self.nodes) + + def create_nodes_edges_df(self, dataset): + graph = dataset[0] + nodes_df = None + + node_id_map = {} + typed_id_map = {} + typed_id = [] + + id_ = graph.nodes().tolist() + type = graph.ndata["ntype"].tolist() + src, dst = graph.edges() + src = src.tolist() + dst = dst.tolist() + edge_type = graph.edata["etype"].tolist() + train_mask = graph.edata["train_mask"].tolist() + val_mask = graph.edata["val_mask"].tolist() + test_mask = graph.edata["test_mask"].tolist() + # node2train_mask = dict(zip(src, train_mask)) + # node2val_mask = dict(zip(src, val_mask)) + # node2test_mask = dict(zip(src, test_mask)) + + for nid_, ntype_ in zip(id_, type): + if ntype_ not in node_id_map: + node_id_map[ntype_] = dict() + typed_id_map[ntype_] = dict() + + node_id_map[ntype_][len(node_id_map[ntype_])] = nid_ + typed_id_map[ntype_][nid_] = len(typed_id_map[ntype_]) + typed_id.append(typed_id_map[ntype_][nid_]) + + name = list(map(lambda x: f"{x[0]}_{x[1]}", zip(type, typed_id))) + + node_dict = { + "id": id_, "type": type, "name": name, "typed_id": typed_id, + # "train_mask": list(map(lambda x: node2train_mask[x], id_)), + # "val_mask": list(map(lambda x: node2val_mask[x], id_)), + # "test_mask": list(map(lambda x: node2test_mask[x], id_)) + } + + nodes_df = pd.DataFrame.from_dict(node_dict) + + assert len(nodes_df["id"]) == len(nodes_df["id"].unique()) + + nodes_df = nodes_df.reset_index(drop=True) + # contiguous_node_index = dict(zip(nodes_df["index"], nodes_df.index)) + + assert len(node_id_map) == 1 + + edges_df = [] + for e_ind, (src, etype, dst) in enumerate(zip(src, edge_type, dst)): + edge_data = { + "id": e_ind, + "type": etype, + "src": src, + "dst": dst, + "train_mask": train_mask[e_ind], + "val_mask": val_mask[e_ind], + "test_mask": test_mask[e_ind] + } + edges_df.append(edge_data) + + edges_df = pd.DataFrame(edges_df) + + assert len(edges_df["id"]) == len(edges_df["id"].unique()) + + edges_df = edges_df.reset_index(drop=True) + self.train_edges = edges_df.query("train_mask == True") + self.val_edges = edges_df.query("train_mask == True or val_mask == True") + self.test_edges = edges_df + + return nodes_df, self.train_edges, typed_id_map + + +class TestTrainer(SamplingMultitaskTrainer): + def __init__(self, *args, **kwargs): + super(TestTrainer, self).__init__(*args, **kwargs) + + def create_node_embedder(self, dataset, tokenizer_path, n_dims=None, 
pretrained_path=None, n_buckets=500000): + self.node_embedder = NodeIdEmbedder( + nodes=dataset.nodes, + emb_size=n_dims, + dtype=self.dtype, + n_buckets=len(dataset.nodes) + 1 # override this because bucket size should be the same as number of nodes for this embedder + ) + + +def main_node_clf(models, args, data_loader): + for model, param_grid in models.items(): + for params in param_grid: + + if args.h_dim is None: + params["h_dim"] = args.node_emb_size + else: + params["h_dim"] = args.h_dim + + params["num_steps"] = args.n_layers + + date_time = str(datetime.now()) + print("\n\n") + print(date_time) + print(f"Model: {model.__name__}, Params: {params}") + + model_attempt = get_name(model, date_time) + + model_base = get_model_base(args, model_attempt) + + dataset = TestNodeClfGraph(data_loader=data_loader) + + args.objectives = "node_clf" + + from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure + + trainer, scores = training_procedure(dataset, model, copy(params), args, model_base, trainer=TestTrainer) + + return scores + + +def main_link_pred(models, args, data_loader): + for model, param_grid in models.items(): + for params in param_grid: + + if args.h_dim is None: + params["h_dim"] = args.node_emb_size + else: + params["h_dim"] = args.h_dim + + params["num_steps"] = args.n_layers + + date_time = str(datetime.now()) + print("\n\n") + print(date_time) + print(f"Model: {model.__name__}, Params: {params}") + + model_attempt = get_name(model, date_time) + + model_base = get_model_base(args, model_attempt) + + dataset = TestLinkPredGraph(data_loader=data_loader) + + args.objectives = "link_pred" + + from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure + + trainer, scores = training_procedure(dataset, model, copy(params), args, model_base, trainer=TestTrainer) + + return scores + +def format_data(dataset): + graph = dataset[0] + category = dataset.predict_category + train_mask = graph.nodes[category].data.pop('train_mask') + test_mask = graph.nodes[category].data.pop('test_mask') + labels = graph.nodes[category].data.pop('labels') + train_labels = labels[train_mask] + test_labels = labels[test_mask] + +def node_clf(args): + logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s") + + models_ = { + # GCNSampling: gcnsampling_params, + # GATSampler: gatsampling_params, + # RGCNSampling: rgcnsampling_params, + # RGAN: rgcnsampling_params, + RGGAN: rggan_params + + } + + if not isdir(args.model_output_dir): + mkdir(args.model_output_dir) + args.save_checkpoints = False + + # data_loaders = [AIFBDataset, MUTAGDataset, BGSDataset, AMDataset] + data_loaders = [AMDataset] + + for dl in data_loaders: + print(dl.__name__) + scores = main_node_clf(models_, args, data_loader=dl) + print("\t", scores) + + +def link_pred(args): + logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s") + + models_ = { + # GCNSampling: gcnsampling_params, + # GATSampler: gatsampling_params, + # RGCNSampling: rgcnsampling_params, + # RGAN: rgcnsampling_params, + RGGAN: rggan_params + + } + + if not isdir(args.model_output_dir): + mkdir(args.model_output_dir) + args.save_checkpoints = False + + data_loaders = [FB15kDataset, FB15k237Dataset] + + for dl in data_loaders: + print(dl.__name__) + scores = main_link_pred(models_, args, data_loader=dl) + print("\t", scores) + + + # dataset = AIFBDataset() + # graph = dataset[0] + # category = 
dataset.predict_category + # num_classes = dataset.num_classes + # train_mask = graph.nodes[category].data.pop('train_mask') + # test_mask = graph.nodes[category].data.pop('test_mask') + # labels = graph.nodes[category].data.pop('labels') + # + # print() + +# def main(): +# dataset = FB15k237Dataset() +# g = dataset.graph +# e_type = g.edata['e_type'] +# +# train_mask = g.edata['train_mask'] +# val_mask = g.edata['val_mask'] +# test_mask = g.edata['test_mask'] +# +# # graph = dataset[0] +# # train_mask = graph.edata['train_mask'] +# # test_mask = g.edata['test_mask'] +# +# train_set = torch.arange(g.number_of_edges())[train_mask] +# val_set = torch.arange(g.number_of_edges())[val_mask] +# +# train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze() +# src, dst = graph.edges(train_idx) +# rel = graph.edata['etype'][train_idx] +# +# print() + +if __name__=="__main__": + import argparse + + parser = argparse.ArgumentParser(description='Process some integers.') + add_gnn_train_args(parser) + + args = parser.parse_args() + verify_arguments(args) + + node_clf(args) + # link_pred(args) \ No newline at end of file diff --git a/SourceCodeTools/models/graph/train/utils.py b/SourceCodeTools/models/graph/train/utils.py index 95d1a14c..cf07e3ee 100644 --- a/SourceCodeTools/models/graph/train/utils.py +++ b/SourceCodeTools/models/graph/train/utils.py @@ -18,11 +18,11 @@ def get_name(model, timestamp): return "{} {}".format(model.__name__, timestamp).replace(":", "-").replace(" ", "-").replace(".", "-") -def get_model_base(args, model_attempt): - if args.restore_state: - model_base = args.model_output_dir +def get_model_base(args, model_attempt, force_new=False): + if args["restore_state"] and not force_new: + model_base = args["model_output_dir"] else: - model_base = join(args.model_output_dir, model_attempt) + model_base = join(args["model_output_dir"], model_attempt) if not isdir(model_base): mkdir(model_base) diff --git a/SourceCodeTools/models/graph/utils/converters/convert_dglke_for_experiments.py b/SourceCodeTools/models/graph/utils/converters/convert_dglke_for_experiments.py index 0e157b34..d4844ec4 100644 --- a/SourceCodeTools/models/graph/utils/converters/convert_dglke_for_experiments.py +++ b/SourceCodeTools/models/graph/utils/converters/convert_dglke_for_experiments.py @@ -3,7 +3,7 @@ import os import pandas as pd import numpy as np -from SourceCodeTools.code.data.sourcetrail.Dataset import get_train_val_test_indices, load_data +from SourceCodeTools.code.data.dataset.Dataset import SourceGraphDataset, load_data # get_train_val_test_indices, load_data import pickle import argparse @@ -66,13 +66,13 @@ def load_npy(ent_map_path, npy_path): # splits = get_train_val_test_indices(nodes.index) from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types -splits = get_train_val_test_indices(nodes.query(f"type_backup == '{node_types[4096]}'").index) +splits = SourceGraphDataset.get_train_val_test_indices(nodes.query(f"type_backup == '{node_types[4096]}'").index) # nodes, edges, held = SourceGraphDataset.holdout(nodes, edges, 0.001) # nodes['label'] = nodes['type'] -from SourceCodeTools.code.data.sourcetrail.Dataset import create_train_val_test_masks +# from SourceCodeTools.code.data.dataset.Dataset import create_train_val_test_masks # def add_splits(nodes, splits): # nodes['train_mask'] = False # nodes.loc[nodes.index[splits[0]], 'train_mask'] = True @@ -94,7 +94,7 @@ def load_npy(ent_map_path, npy_path): os.path.join(args.out_path, "state_dict.pt") ) 
-create_train_val_test_masks(nodes, *splits) +SourceGraphDataset.create_train_val_test_masks(nodes, *splits) nodes.to_csv(os.path.join(args.out_path, "nodes.csv"), index=False) edges.to_csv(os.path.join(args.out_path, "edges.csv"), index=False) diff --git a/SourceCodeTools/models/graph/utils/converters/convert_w2v_for_experiments.py b/SourceCodeTools/models/graph/utils/converters/convert_w2v_for_experiments.py index 3e60e587..f59687bd 100644 --- a/SourceCodeTools/models/graph/utils/converters/convert_w2v_for_experiments.py +++ b/SourceCodeTools/models/graph/utils/converters/convert_w2v_for_experiments.py @@ -2,7 +2,7 @@ import torch import os, sys import numpy as np -from SourceCodeTools.code.data.sourcetrail.Dataset import get_train_val_test_indices, load_data +from SourceCodeTools.code.data.dataset.Dataset import SourceGraphDataset, load_data # get_train_val_test_indices, load_data import pickle @@ -43,7 +43,7 @@ def load_w2v(path): # splits = get_train_val_test_indices(nodes.index) from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types -splits = get_train_val_test_indices(nodes.query(f"type_backup == '{node_types[4096]}'").index) +splits = SourceGraphDataset.get_train_val_test_indices(nodes.query(f"type_backup == '{node_types[4096]}'").index) id_map, vecs = load_w2v(emb_path) @@ -65,8 +65,8 @@ def load_w2v(path): os.path.join(out_path, "state_dict.pt") ) -from SourceCodeTools.code.data.sourcetrail.Dataset import create_train_val_test_masks -create_train_val_test_masks(nodes, *splits) +# from SourceCodeTools.code.data.dataset.Dataset import create_train_val_test_masks +SourceGraphDataset.create_train_val_test_masks(nodes, *splits) nodes.to_csv(os.path.join(out_path, "nodes.csv"), index=False) edges.to_csv(os.path.join(out_path, "edges.csv"), index=False) diff --git a/SourceCodeTools/models/graph/utils/dglke_to_embedder.py b/SourceCodeTools/models/graph/utils/dglke_to_embedder.py new file mode 100644 index 00000000..093b1a00 --- /dev/null +++ b/SourceCodeTools/models/graph/utils/dglke_to_embedder.py @@ -0,0 +1,38 @@ +import pickle + +import numpy +import pandas as pd + +from SourceCodeTools.models.Embedder import Embedder + + +def read_entities(path): + return pd.read_csv(path, sep="\t", header=None)[1] + + +def read_vectors(path): + return numpy.load(path) + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("entities") + parser.add_argument("vectors") + parser.add_argument("output") + + args = parser.parse_args() + + entities = read_entities(args.entities) + vectors = read_vectors(args.vectors) + + embedder = Embedder(dict(zip(entities, range(len(entities)))), vectors) + with open(args.output, "wb") as sink: + pickle.dump(embedder, sink) + + + + + +if __name__=="__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/models/graph/utils/export4visualization.py b/SourceCodeTools/models/graph/utils/export4visualization.py index 9533a5d1..69d28454 100644 --- a/SourceCodeTools/models/graph/utils/export4visualization.py +++ b/SourceCodeTools/models/graph/utils/export4visualization.py @@ -1,89 +1,130 @@ import pickle -import sys -import pandas +from os.path import join + import numpy as np import os from collections import Counter - -model_path = sys.argv[1] - -if len(sys.argv) > 2: - max_embs = int(sys.argv[2]) -else: - max_embs = 5000 - -nodes_path = os.path.join(model_path, "nodes.csv") -edges_path = os.path.join(model_path, "edges.csv") -embedders_path = os.path.join(model_path, "embeddings.pkl") - 
-embedders = pickle.load(open(embedders_path, "rb")) -nodes = pandas.read_csv(nodes_path) - -edges = pandas.read_csv(edges_path) -degrees = Counter(edges['src'].tolist()) + Counter(edges['dst'].tolist()) - -ids = nodes['id'].values -names = nodes['name']#.apply(lambda x: x.split(".")[-1]).values - -id_name_map = list(zip(ids, names)) -id_name_map_d = dict(id_name_map) - -ind_mapper = embedders[0].ind - -id_name_map = sorted(id_name_map, key=lambda x: ind_mapper[x[0]]) - -ids, names = zip(*id_name_map) - -print(f"Limiting to {max_embs} embeddings") - - -# emb0 = [] -# emb1 = [] -# emb2 = [] - -for group, gr_ in nodes.groupby("type_backup"): - c_ = 0 - - names = [] - embs = [[] for _ in embedders] - - nodes_in_group = set(gr_['id'].tolist()) - for ind, (id_, count) in enumerate(degrees.most_common()): - if id_ not in nodes_in_group: continue - names.append(id_name_map_d[id_]) - for emb_, embedder in zip(embs, embedders): - emb_.append(embedder.e[embedder.ind[id_]]) - # emb0.append(embedders[0].e[embedders[0].ind[id_]]) - # emb1.append(embedders[1].e[embedders[1].ind[id_]]) - # emb2.append(embedders[2].e[embedders[2].ind[id_]]) - c_ += 1 - if c_ >= max_embs - 1: break - # if ind >= max_embs-1: break - - # np.savetxt(os.path.join(model_path, "emb4proj_meta.tsv"), np.array(names).reshape(-1, 1)) - for ind, emb_ in enumerate(embs): - np.savetxt(os.path.join(model_path, f"emb4proj{ind}_{group}.tsv"), np.array(emb_), delimiter="\t") - # np.savetxt(os.path.join(model_path, "emb4proj0.tsv"), np.array(emb0), delimiter="\t") - # np.savetxt(os.path.join(model_path, "emb4proj1.tsv"), np.array(emb1), delimiter="\t") - # np.savetxt(os.path.join(model_path, "emb4proj2.tsv"), np.array(emb2), delimiter="\t") - - print("Writing meta...", end="") - with open(os.path.join(model_path, f"emb4proj_meta_{group}.tsv"), "w") as meta: - for name in names[:max_embs]: - meta.write(f"{name}\n") - print("done") - -# with open(os.path.join(model_path, "emb4proj2w2v.txt"), "w") as w2v: -# for ind, name in enumerate(names): -# w2v.write("%s " % name) -# for j, v in enumerate(emb2[ind]): -# if j < len(emb2[ind]) - 1: -# w2v.write("%f " % v) -# else: -# w2v.write("%f\n" % v) - - -# for ind, e in enumerate(embedders): -# print(f"Writing embedding layer {ind}...", end="") -# np.savetxt(os.path.join(model_path, f"emb4proj{ind}.tsv"), e.e[:max_embs], delimiter="\t") -# print("done") +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("model_path") + parser.add_argument("output_path") + parser.add_argument("--max_embs", type=int, default=5000) + parser.add_argument("--into_groups", action="store_true") + args = parser.parse_args() + model_path = args.model_path + + dataset = pickle.load(open(join(args.model_path, "dataset.pkl"), "rb")) + embedders = pickle.load(open(join(args.model_path, "embeddings.pkl"), "rb")) + + nodes = dataset.nodes + edges = dataset.edges + + degrees = Counter(edges['src'].tolist()) + Counter(edges['dst'].tolist()) + + ids = nodes['id'].values + names = nodes['name']#.apply(lambda x: x.split(".")[-1]).values + + id_name_map = list(zip(ids, names)) + id_name_map_d = dict(id_name_map) + + ind_mapper = embedders[0].ind + + id_name_map = sorted(id_name_map, key=lambda x: ind_mapper[x[0]]) + + ids, names = zip(*id_name_map) + + print(f"Limiting to {args.max_embs} embeddings") + + # emb0 = [] + # emb1 = [] + # emb2 = [] + + def write_in_groups(): + + for group, gr_ in nodes.groupby("type_backup"): + c_ = 0 + + names = [] + embs = [[] for _ in embedders] + + nodes_in_group = 
set(gr_['id'].tolist()) + for ind, (id_, count) in enumerate(degrees.most_common()): + if id_ not in nodes_in_group: continue + names.append(id_name_map_d[id_]) + for emb_, embedder in zip(embs, embedders): + emb_.append(embedder.e[embedder.ind[id_]]) + # emb0.append(embedders[0].e[embedders[0].ind[id_]]) + # emb1.append(embedders[1].e[embedders[1].ind[id_]]) + # emb2.append(embedders[2].e[embedders[2].ind[id_]]) + c_ += 1 + if c_ >= args.max_embs - 1: break + # if ind >= max_embs-1: break + + # np.savetxt(os.path.join(model_path, "emb4proj_meta.tsv"), np.array(names).reshape(-1, 1)) + for ind, emb_ in enumerate(embs): + np.savetxt(os.path.join(args.output_path, f"emb4proj{ind}_{group}.tsv"), np.array(emb_), delimiter="\t") + # np.savetxt(os.path.join(model_path, "emb4proj0.tsv"), np.array(emb0), delimiter="\t") + # np.savetxt(os.path.join(model_path, "emb4proj1.tsv"), np.array(emb1), delimiter="\t") + # np.savetxt(os.path.join(model_path, "emb4proj2.tsv"), np.array(emb2), delimiter="\t") + + print("Writing meta...", end="") + with open(os.path.join(args.output_path, f"emb4proj_meta_{group}.tsv"), "w") as meta: + for name in names[:args.max_embs]: + meta.write(f"{name}\n") + print("done") + + def write_all_together(): + c_ = 0 + + names = [] + embs = [[] for _ in embedders] + + nodes_in_group = set(nodes['id'].tolist()) + for ind, (id_, count) in enumerate(degrees.most_common()): + if id_ not in nodes_in_group: continue + names.append(id_name_map_d[id_]) + for emb_, embedder in zip(embs, embedders): + emb_.append(embedder.e[embedder.ind[id_]]) + c_ += 1 + if c_ >= args.max_embs - 1: break + # if ind >= max_embs-1: break + + # np.savetxt(os.path.join(model_path, "emb4proj_meta.tsv"), np.array(names).reshape(-1, 1)) + for ind, emb_ in enumerate(embs): + np.savetxt(os.path.join(args.output_path, f"emb4proj{ind}.tsv"), np.array(emb_), delimiter="\t") + # np.savetxt(os.path.join(model_path, "emb4proj0.tsv"), np.array(emb0), delimiter="\t") + # np.savetxt(os.path.join(model_path, "emb4proj1.tsv"), np.array(emb1), delimiter="\t") + # np.savetxt(os.path.join(model_path, "emb4proj2.tsv"), np.array(emb2), delimiter="\t") + + print("Writing meta...", end="") + with open(os.path.join(args.output_path, f"emb4proj_meta.tsv"), "w") as meta: + for name in names[:args.max_embs]: + meta.write(f"{name}\n") + print("done") + + if args.into_groups: + write_in_groups() + else: + write_all_together() + + + # with open(os.path.join(model_path, "emb4proj2w2v.txt"), "w") as w2v: + # for ind, name in enumerate(names): + # w2v.write("%s " % name) + # for j, v in enumerate(emb2[ind]): + # if j < len(emb2[ind]) - 1: + # w2v.write("%f " % v) + # else: + # w2v.write("%f\n" % v) + + + # for ind, e in enumerate(embedders): + # print(f"Writing embedding layer {ind}...", end="") + # np.savetxt(os.path.join(model_path, f"emb4proj{ind}.tsv"), e.e[:max_embs], delimiter="\t") + # print("done") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/models/graph/utils/import_into_neo4j.py b/SourceCodeTools/models/graph/utils/import_into_neo4j.py index bfc7571e..2b122be8 100644 --- a/SourceCodeTools/models/graph/utils/import_into_neo4j.py +++ b/SourceCodeTools/models/graph/utils/import_into_neo4j.py @@ -2,7 +2,7 @@ from neo4j import GraphDatabase import argparse -from SourceCodeTools.code.data.sourcetrail.Dataset import load_data +from SourceCodeTools.code.data.dataset.Dataset import load_data # from SourceCodeTools.data.sourcetrail.sourcetrail_types import node_types, edge_types diff --git 
a/SourceCodeTools/models/graph/utils/prepare_dglke_format.py b/SourceCodeTools/models/graph/utils/prepare_dglke_format.py index 3482f452..e039b249 100644 --- a/SourceCodeTools/models/graph/utils/prepare_dglke_format.py +++ b/SourceCodeTools/models/graph/utils/prepare_dglke_format.py @@ -1,79 +1,85 @@ +import logging import os -import pandas as pd -from SourceCodeTools.code.data.sourcetrail.Dataset import SourceGraphDataset, load_data, compact_property +from os.path import isdir, join, isfile + +from SourceCodeTools.code.data.dataset.Dataset import load_data, compact_property, SourceGraphDataset import argparse -parser = argparse.ArgumentParser(description='Process some integers.') -parser.add_argument('--nodes_path', dest='nodes_path', default=None, - help='Path to the file with nodes') -parser.add_argument('--edges_path', dest='edges_path', default=None, - help='Path to the file with edges') -parser.add_argument('--fname_path', dest='fname_path', default=None, - help='') -parser.add_argument('--varuse_path', dest='varuse_path', default=None, - help='') -parser.add_argument('--apicall_path', dest='apicall_path', default=None, - help='') -parser.add_argument('--out_path', dest='out_path', default=None, - help='') +from SourceCodeTools.code.data.file_utils import unpersist, persist + + +def get_paths(dataset_path, use_extra_objectives): + extra_objectives = ["node_names.bz2", "common_function_variable_pairs.bz2", "common_call_seq.bz2", "type_annotations.bz2"] + + largest_component = join(dataset_path, "largest_component") + if isdir(largest_component): + logging.info("Using graph from largest_component directory") + nodes_path = join(largest_component, "nodes.bz2") + edges_path = join(largest_component, "edges.bz2") + else: + nodes_path = join(dataset_path, "common_nodes.bz2") + edges_path = join(dataset_path, "common_edges.bz2") + + + extra_paths = list(map( + lambda file: file if use_extra_objectives and isfile(file) else None, + (join(dataset_path, objective) for objective in extra_objectives) + )) -args = parser.parse_args() + return nodes_path, edges_path, extra_paths -nodes, edges = load_data(args.nodes_path, args.edges_path) -node2graph_id = compact_property(nodes['id']) -nodes['global_graph_id'] = nodes['id'].apply(lambda x: node2graph_id[x]) -nodes, edges, held = SourceGraphDataset.holdout(nodes, edges, 0.005) -edges.to_csv(os.path.join(args.out_path, "edges_train.csv"), index=False) -held.to_csv(os.path.join(args.out_path, "held.csv"), index=False) +def filter_relevant(data, node_ids): + return data.query("src in @allowed", local_dict={"allowed": node_ids}) -edges = edges.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']] -node_ids = set(nodes['id'].unique()) +def main(): + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('dataset_path', default=None, help='Path to the dataset') + parser.add_argument('output_path', default=None, help='') + parser.add_argument("--extra_objectives", action="store_true", default=False) + parser.add_argument("--eval_frac", dest="eval_frac", default=0.05, type=float) -if args.fname_path is not None: - fname = pd.read_csv(args.fname_path).astype({"src": 'int32', "dst": "str"}) - fname['type'] = 'fname' + args = parser.parse_args() - fname = fname[ - fname['src'].apply(lambda x: x in node_ids) - ] + nodes_path, edges_path, extra_paths = get_paths( + args.dataset_path, use_extra_objectives=args.extra_objectives + ) - edges = pd.concat([edges, fname]) + nodes, edges = load_data(nodes_path, 
edges_path) + nodes, edges, holdout = SourceGraphDataset.holdout(nodes, edges) + edges = edges.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']] + holdout = holdout.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']] -if args.varuse_path is not None: - varuse = pd.read_csv(args.varuse_path).astype({"src": 'int32', "dst": "str"}) - varuse['type'] = 'varuse' + node2graph_id = compact_property(nodes['id']) + nodes['global_graph_id'] = nodes['id'].apply(lambda x: node2graph_id[x]) - varuse = varuse[ - varuse['src'].apply(lambda x: x in node_ids) - ] + node_ids = set(nodes['id'].unique()) - edges = pd.concat([edges, varuse]) + if args.extra_objectives: + for objective_path in extra_paths: + data = unpersist(objective_path) + data = filter_relevant(data, node_ids) + data["type"] = objective_path.split(".")[0] + edges = edges.append(data) -if args.apicall_path is not None: - apicall = pd.read_csv(args.apicall_path).astype({"src": 'int32', "dst": "int32"}) - apicall['type'] = 'nextcall' + if not os.path.isdir(args.output_path): + os.mkdir(args.output_path) - apicall = apicall[ - apicall['src'].apply(lambda x: x in node_ids) - ] + edges = edges[['src','dst','type']] + eval_sample = edges.sample(frac=args.eval_frac) - edges = pd.concat([edges, apicall]) + persist(nodes, join(args.output_path, "nodes_dglke.csv")) + persist(edges, join(args.output_path, "edges_train_dglke.tsv"), header=False, sep="\t") + persist(edges, join(args.output_path, "edges_train_node2vec.tsv"), header=False, sep=" ") + persist(eval_sample, join(args.output_path, "edges_eval_dglke.tsv"), header=False, sep="\t") + persist(eval_sample, join(args.output_path, "edges_eval_node2vec.tsv"), header=False, sep=" ") + persist(holdout, join(args.output_path, "edges_eval_dglke_10000.tsv"), header=False, sep="\t") + persist(holdout, join(args.output_path, "edges_eval_node2vec_10000.tsv"), header=False, sep=" ") -# splits = get_train_test_val_indices(edges.index, train_frac=0.6) -nodes['label'] = nodes['type'] -if not os.path.isdir(args.out_path): - os.mkdir(args.out_path) -nodes.to_csv(os.path.join(args.out_path, "nodes.csv"), index=False) -edges.to_csv(os.path.join(args.out_path, "edges_train_dglke.tsv"), index=False, header=False, sep="\t") -edges[['src','dst']].to_csv(os.path.join(args.out_path, "edges_train_node2vec.csv"), index=False, header=False, sep=" ") -# edges.iloc[splits[0]].to_csv(os.path.join(args.out_path, "edges_train.csv"), index=False, header=False, sep="\t") -# edges.iloc[splits[1]].to_csv(os.path.join(args.out_path, "edges_val.csv"), index=False, header=False, sep="\t") -# edges.iloc[splits[2]].to_csv(os.path.join(args.out_path, "edges_test.csv"), index=False, header=False, sep="\t") -held[['src','dst','type']].to_csv(os.path.join(args.out_path, "held_dglkg.tsv"), index=False, sep='\t', header=False) -held[['src','dst']].to_csv(os.path.join(args.out_path, "held_node2vec.csv"), index=False, sep=' ', header=False) \ No newline at end of file +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/models/graph/utils/prepare_dglke_format2.py b/SourceCodeTools/models/graph/utils/prepare_dglke_format2.py new file mode 100644 index 00000000..72b985c5 --- /dev/null +++ b/SourceCodeTools/models/graph/utils/prepare_dglke_format2.py @@ -0,0 +1,270 @@ +import logging +import os +from collections import defaultdict +from os.path import isdir, join, isfile +from random import random + +import numpy as np +import pandas as pd + +from 
SourceCodeTools.code.common import read_edges, read_nodes +from SourceCodeTools.code.data.dataset.Dataset import load_data, compact_property, SourceGraphDataset + +import argparse + +from SourceCodeTools.code.data.file_utils import unpersist, persist +from SourceCodeTools.code.data.sourcetrail.sourcetrail_types import node_types + + +def get_paths(dataset_path, use_extra_objectives): + extra_objectives = ["node_names.json", "common_function_variable_pairs.json", "common_call_seq.json", "type_annotations.json"] + + largest_component = join(dataset_path, "largest_component") + if isdir(largest_component): + logging.info("Using graph from largest_component directory") + nodes_path = join(largest_component, "common_nodes.json") + edges_path = join(largest_component, "common_edges.json") + else: + nodes_path = join(dataset_path, "common_nodes.json") + edges_path = join(dataset_path, "common_edges.json") + + extra_paths = list(filter(lambda x: x is not None, map( + lambda file: file if use_extra_objectives and isfile(file) else None, + (join(dataset_path, objective) for objective in extra_objectives) + ))) + + return nodes_path, edges_path, extra_paths + + +def filter_relevant(data, node_ids): + return data.query("src in @allowed", local_dict={"allowed": node_ids}) + + +def add_counts(counter, node_ids): + for node_id in node_ids: + counter[node_id] += 1 + + +def count_degrees(edges_path): + counter = defaultdict(lambda: 0) + for edges in read_edges(edges_path, as_chunks=True): + add_counts(counter, edges["source_node_id"]) + add_counts(counter, edges["target_node_id"]) + + return counter + + +def count_with_occurrence(counter, min_occurrence): + c = 0 + for id_, count in counter.items(): + if count > min_occurrence: + c += 1 + return c + + +def get_writing_mode(is_csv, first_written): + kwargs = {} + if first_written is True: + kwargs["mode"] = "a" + if is_csv: + kwargs["header"] = False + return kwargs + + +def do_holdout(edges_path, output_path, node_descriptions, holdout_size=10000, min_count=2): + + counter = count_degrees(edges_path) + num_valid_candidates = count_with_occurrence(counter, min_count) + + frac = holdout_size / num_valid_candidates + + # temp_edges = join(os.path.dirname(edges_path), "temp_" + os.path.basename(edges_path)) + out_edges_path = join(output_path, "edges_train_dglke.tsv") + out_holdout_path = join(output_path, "edges_eval_dglke_10000.tsv") + is_csv = True + + first_edges = False + first_holdout = False + + total_edges = 0 + total_holdout = 0 + + seen = set() + + for edges in read_edges(edges_path, as_chunks=True): + edges.rename({"source_node_id": "src", "target_node_id": "dst"}, axis=1, inplace=True) + edges = edges[['src', 'dst', 'type']] + + sufficient_count = edges["src"].apply(lambda x: counter[x] > min_count) & \ + edges["dst"].apply(lambda x: counter[x] > min_count) + + definitely_keep = edges[~sufficient_count] + probably_keep = edges[sufficient_count] + + probably_holdout_mask = np.array([random() < frac for _ in range(len(probably_keep))]) + + probably_holdout = probably_keep[probably_holdout_mask] + + definitely_holdout_mask = [] + for src, dst, type_ in probably_holdout.values: + if counter[src] > min_count and counter[dst] > min_count: + definitely_holdout_mask.append(True) + counter[src] -= 1 + counter[dst] -= 1 + else: + definitely_holdout_mask.append(False) + + definitely_holdout_mask = np.array(definitely_holdout_mask) + + definitely_holdout = probably_holdout[definitely_holdout_mask] + definitely_keep = pd.concat([definitely_keep, 
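# Illustrative sketch (not part of the patch): the core rule of `do_holdout` above is that an edge
# is moved into the holdout split only while both of its endpoints still have a degree above
# `min_count`; the counts are decremented immediately, so later candidates see the reduced degree
# and no node can be drained out of the training graph.
def try_move_to_holdout(src, dst, counter, min_count=2):
    if counter[src] > min_count and counter[dst] > min_count:
        counter[src] -= 1  # this edge no longer counts towards the training degree
        counter[dst] -= 1
        return True        # safe to hold out
    return False           # keep the edge in the training split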
probably_keep[~probably_holdout_mask], probably_holdout[~definitely_holdout_mask]]) + + total_edges += len(definitely_keep) + total_holdout += len(definitely_holdout) + + def apply_description(edges): + edges["src"] = edges["src"].apply(node_descriptions.get) + edges["dst"] = edges["dst"].apply(node_descriptions.get) + return edges + + def write_filtered(table, path, first_written): + with_description = apply_description(table) + + with_description.drop_duplicates(inplace=True) + + reprs = [(src, dst, type_) for src, dst, type_ in with_description.values] + + seen_mask = np.array(list(map(lambda x: x in seen, reprs))) + + with_description = with_description.loc[~seen_mask] + seen.update(reprs) + + kwargs = get_writing_mode(is_csv, first_written) + persist(with_description, path, sep="\t", **kwargs) + + write_filtered(definitely_keep, out_edges_path, first_edges) + first_edges = True + + if len(definitely_holdout) > 0: + write_filtered(definitely_holdout, out_holdout_path, first_holdout) + first_holdout = True + + return counter, total_edges, total_holdout + + +def add_extra_objectives(extra_paths, output_path, node_ids): + out_edges_path = join(output_path, "edges_train_dglke.tsv") + + total_extra = 0 + + for objective_path in extra_paths: + data = unpersist(objective_path) + data = filter_relevant(data, node_ids) + data["type"] = data["type"].split(".")[0] + data = data[["src", "dst", "type"]] + raise NotImplementedError() + kwargs = get_writing_mode(is_csv=True, first_written=True) + persist(data, out_edges_path, sep="\t", **kwargs) # write_filtered + + total_extra += len(data) + + return total_extra + + +def save_eval(output_dir, eval_frac): + eval_path = join(output_dir, "edges_eval_dglke.tsv") + + total_eval = 0 + + for ind, edges in enumerate(read_edges(join(output_dir, "edges_train_dglke.tsv"), as_chunks=True)): + eval = edges.sample(frac=eval_frac) + if len(eval) > 0: + kwargs = get_writing_mode(is_csv=True, first_written=ind != 0) + persist(eval, eval_path, sep="\t", **kwargs) + + total_eval += len(eval) + + return total_eval + + +def get_node_descriptions(nodes_path, distinct_node_types): + + description = {} + + for nodes in read_nodes(nodes_path, as_chunks=True): + transform_mask = nodes.eval("type in @distinct_node_types", local_dict={"distinct_node_types": distinct_node_types}) + + nodes.loc[transform_mask, "transformed"] = nodes.loc[transform_mask, "id"].astype("string") + nodes.loc[~transform_mask, "transformed"] = nodes.loc[~transform_mask, "type"] + + for id, desc in nodes[["id", "transformed"]].values: + description[id] = desc + + return description + + +def main(): + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('dataset_path', default=None, help='Path to the dataset') + parser.add_argument('output_path', default=None, help='') + parser.add_argument("--extra_objectives", action="store_true", default=False) + parser.add_argument("--eval_frac", dest="eval_frac", default=0.05, type=float) + + distinct_node_types = set(node_types.values()) | { + "FunctionDef", "mention", "Op", "#attr#", "#keyword#", "subword" + } + + args = parser.parse_args() + + if not os.path.isdir(args.output_path): + os.mkdir(args.output_path) + + nodes_path, edges_path, extra_paths = get_paths( + args.dataset_path, use_extra_objectives=args.extra_objectives + ) + + node_descriptions = get_node_descriptions(nodes_path, distinct_node_types) + + counter, total_edges, total_holdout = do_holdout(edges_path, args.output_path, node_descriptions) + + total_extra 
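# Illustrative sketch (not part of the patch): `get_writing_mode` above switches to append mode
# and drops the header for every chunk after the first, the usual pattern for streaming a large
# table to disk chunk by chunk. `persist` is assumed to forward these keyword arguments to
# `DataFrame.to_csv`; `chunks` below is a hypothetical iterable of DataFrames.
first_written = False
for chunk in chunks:
    kwargs = {"mode": "a", "header": False} if first_written else {}
    chunk.to_csv("edges_train_dglke.tsv", sep="\t", index=False, **kwargs)
    first_written = True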
= add_extra_objectives(extra_paths, args.output_path, set(counter.keys())) + + temp_edges = join(args.output_path, "temp_common_edges.tsv") + + total_eval = save_eval(args.output_path, args.eval_frac) + + # nodes, edges = load_data(nodes_path, edges_path) + # nodes, edges, holdout = SourceGraphDataset.holdout(nodes, edges) + # edges = edges.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']] + # holdout = holdout.astype({"src": 'str', "dst": "str", "type": 'str'})[['src', 'dst', 'type']] + # + # node2graph_id = compact_property(nodes['id']) + # nodes['global_graph_id'] = nodes['id'].apply(lambda x: node2graph_id[x]) + # + # node_ids = set(nodes['id'].unique()) + # + # if args.extra_objectives: + # for objective_path in extra_paths: + # data = unpersist(objective_path) + # data = filter_relevant(data, node_ids) + # data["type"] = objective_path.split(".")[0] + # edges = edges.append(data) + + + + # edges = edges[['src','dst','type']] + # eval_sample = edges.sample(frac=args.eval_frac) + # + # persist(nodes, join(args.output_path, "nodes_dglke.csv")) + # persist(edges, join(args.output_path, "edges_train_dglke.tsv"), header=False, sep="\t") + # persist(edges, join(args.output_path, "edges_train_node2vec.tsv"), header=False, sep=" ") + # persist(eval_sample, join(args.output_path, "edges_eval_dglke.tsv"), header=False, sep="\t") + # persist(eval_sample, join(args.output_path, "edges_eval_node2vec.tsv"), header=False, sep=" ") + # persist(holdout, join(args.output_path, "edges_eval_dglke_10000.tsv"), header=False, sep="\t") + # persist(holdout, join(args.output_path, "edges_eval_node2vec_10000.tsv"), header=False, sep=" ") + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/models/nlp/TFDecoder.py b/SourceCodeTools/models/nlp/TFDecoder.py new file mode 100644 index 00000000..0c313c8f --- /dev/null +++ b/SourceCodeTools/models/nlp/TFDecoder.py @@ -0,0 +1,122 @@ +import tensorflow as tf +from tensorflow.python.keras.layers import Layer, Dense, Dropout +from tensorflow_addons.layers import MultiHeadAttention + +from SourceCodeTools.models.nlp.common import positional_encoding + + +class ConditionalDecoderLayer(tf.keras.layers.Layer): + def __init__(self, d_model, num_heads, dff, rate=0.1): + super(ConditionalDecoderLayer, self).__init__() + + def point_wise_feed_forward_network(d_model, dff): + return tf.keras.Sequential([ + tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff) + tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model) + ]) + + self.mha1 = MultiHeadAttention(d_model, num_heads, return_attn_coef=True) + self.mha2 = MultiHeadAttention(d_model, num_heads, return_attn_coef=True) + + self.ffn = point_wise_feed_forward_network(d_model, dff) + + self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) + self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) + self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6) + + self.dropout1 = tf.keras.layers.Dropout(rate) + self.dropout2 = tf.keras.layers.Dropout(rate) + self.dropout3 = tf.keras.layers.Dropout(rate) + + def call(self, inputs, look_ahead_mask=None, mask=None, training=None): + encoder_out, x = inputs + # enc_output.shape == (batch_size, input_seq_len, d_model) + + attn1, attn_weights_block1 = self.mha1((x, x, x), mask=look_ahead_mask) # (batch_size, target_seq_len, d_model) + attn1 = self.dropout1(attn1, training=training) + out1 = self.layernorm1(attn1 + x) # skip connection + + attn2, 
attn_weights_block2 = self.mha2( + (out1, encoder_out, encoder_out), mask=tf.tile(tf.expand_dims(encoder_out._keras_mask, axis=1), (1,out1.shape[1],1))) # (batch_size, target_seq_len, d_model) + attn2 = self.dropout2(attn2, training=training) + out2 = self.layernorm2(attn2 + out1) # skip connection (batch_size, target_seq_len, d_model) + + ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model) + ffn_output = self.dropout3(ffn_output, training=training) + out3 = self.layernorm3(ffn_output + out2) # skip connection (batch_size, target_seq_len, d_model) + + return out3, attn_weights_block1, attn_weights_block2 + + +class ConditionalAttentionDecoder(tf.keras.layers.Layer): + def __init__(self, input_dim, out_dim, num_layers, num_heads, ff_hidden, target_vocab_size, + maximum_position_encoding, rate=0.1): + super(ConditionalAttentionDecoder, self).__init__() + + self.d_model = out_dim + self.num_layers = num_layers + + self.embedding = tf.keras.layers.Embedding(target_vocab_size, input_dim) + self.pos_encoding = positional_encoding(maximum_position_encoding, input_dim) + + self.dec_layers = [ConditionalDecoderLayer(input_dim, num_heads, ff_hidden, rate) + for _ in range(num_layers)] + self.dropout = tf.keras.layers.Dropout(rate) + + self.look_ahead_mask = self.create_look_ahead_mask(1) + self.fc_out = Dense(out_dim) + + def create_look_ahead_mask(self, size): + mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) + return mask # (seq_len, seq_len) + + def compute_mask(self, inputs, mask=None): + # encoder_out, target = inputs + return mask + + def call(self, inputs, training=None, mask=None): + if len(inputs) == 2: + encoder_out, target = inputs + elif len(inputs) == 1: + encoder_out = inputs + target = None + else: + raise ValueError("Incorrect number of parameters") + + if target is None: + seq_len = tf.shape(encoder_out)[1] + x = encoder_out + else: + seq_len = tf.shape(target)[1] + x = self.embedding(target) # (batch_size, target_seq_len, d_model) + x += self.pos_encoding[:, :seq_len, :] + + if self.look_ahead_mask.shape[0] != seq_len: + self.look_ahead_mask = self.create_look_ahead_mask(seq_len) + + attention_weights = {} + + x = self.dropout(x, training=training) + + for i in range(self.num_layers): + x, block1, block2 = self.dec_layers[i]( + (encoder_out, x), look_ahead_mask=self.look_ahead_mask, mask=mask, training=training + ) + + attention_weights[f'decoder_layer{i+1}_block1'] = block1 + attention_weights[f'decoder_layer{i+1}_block2'] = block2 + + # x.shape == (batch_size, target_seq_len, d_model) + return self.fc_out(x), attention_weights + + +class FlatDecoder(Layer): + def __init__(self, out_dims, hidden=100, dropout=0.1): + super(FlatDecoder, self).__init__() + self.fc1 = Dense(hidden, activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.HeNormal()) + self.drop = Dropout(rate=dropout) + self.fc2 = Dense(out_dims) + + def call(self, inputs, training=None, mask=None): + encoder_out, target = inputs + return self.fc2(self.drop(self.fc1(encoder_out, training=training), training=training), training=training), None \ No newline at end of file diff --git a/SourceCodeTools/models/nlp/TFEncoder.py b/SourceCodeTools/models/nlp/TFEncoder.py new file mode 100644 index 00000000..c6cf6af0 --- /dev/null +++ b/SourceCodeTools/models/nlp/TFEncoder.py @@ -0,0 +1,238 @@ +from copy import copy + +from scipy.linalg import toeplitz +from tensorflow.keras.layers import Layer, Dense, Conv2D, Flatten, Input, Embedding, concatenate, GRU +from tensorflow.keras import Model 
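# Illustrative sketch (not part of the patch): `create_look_ahead_mask` above builds the standard
# causal mask. `tf.linalg.band_part(x, -1, 0)` keeps the lower triangle including the diagonal,
# so subtracting it from a matrix of ones leaves 1s only at future positions, e.g. for size == 4:
#
#   1 - tf.linalg.band_part(tf.ones((4, 4)), -1, 0)
#   -> [[0., 1., 1., 1.],
#       [0., 0., 1., 1.],
#       [0., 0., 0., 1.],
#       [0., 0., 0., 0.]]
#
# Positions marked with 1 are hidden from the decoder's self-attention.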
+import tensorflow as tf +from tensorflow.python.keras.layers import Dropout + + +class DefaultEmbedding(Layer): + """ + Creates an embedder that provides the default value for the index -1. The default value is a zero-vector + """ + def __init__(self, init_vectors=None, shape=None, trainable=True): + super(DefaultEmbedding, self).__init__() + + if init_vectors is not None: + self.embs = tf.Variable(init_vectors, dtype=tf.float32, + trainable=trainable, name="default_embedder_var") + shape = init_vectors.shape + else: + # TODO + # the default value is no longer constant. need to replace this with a standard embedder + self.embs = tf.Variable(tf.random.uniform(shape=(shape[0], shape[1]), dtype=tf.float32), + name="default_embedder_pad") + self.pad = tf.zeros(shape=(1, shape[1]), name="default_embedder_pad") + # self.pad = tf.random.uniform(shape=(1, init_vectors.shape[1]), name="default_embedder_pad") + # self.pad = tf.Variable(tf.random.uniform(shape=(1, shape[1]), dtype=tf.float32), + # name="default_embedder_pad") + + # def compute_mask(self, inputs, mask=None): + # ids, lengths = inputs + # # position with value -1 is a pad + # return tf.sequence_mask(lengths, ids.shape[1]) + # # return inputs != self.embs.shape[0] + + def call(self, ids): + emb_matr = tf.concat([self.embs, self.pad], axis=0) + return tf.nn.embedding_lookup(params=emb_matr, ids=ids) + # return tf.expand_dims(tf.nn.embedding_lookup(params=self.emb_matr, ids=ids), axis=3) + + +class PositionalEncoding(Model): + def __init__(self, seq_len, pos_emb_size): + """ + Create positional embedding with a trainable embedding matrix. Currently not using because it results + in N^2 computational complexity. Should move this functionality to batch preparation. + :param seq_len: maximum sequence length + :param pos_emb_size: the dimensionality of positional embeddings + """ + super(PositionalEncoding, self).__init__() + + positions = list(range(seq_len * 2)) + position_splt = positions[:seq_len] + position_splt.reverse() + self.position_encoding = tf.constant(toeplitz(position_splt, positions[seq_len:]), + dtype=tf.int32, + name="position_encoding") + # self.position_embedding = tf.random.uniform(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32) + self.position_embedding = tf.Variable(tf.random.uniform(shape=(seq_len * 2, pos_emb_size), dtype=tf.float32), + name="position_embedding") + # self.position_embedding = tf.Variable(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32) + + def call(self): + # return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup") + return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup") + + +class TextCnnLayer(Model): + def __init__(self, out_dim, kernel_shape, activation=None): + super(TextCnnLayer, self).__init__() + + self.kernel_shape = kernel_shape + self.out_dim = out_dim + + self.textConv = Conv2D(filters=out_dim, kernel_size=kernel_shape, + activation=activation, data_format='channels_last') + + padding_size = (self.kernel_shape[0] - 1) // 2 + assert padding_size * 2 + 1 == self.kernel_shape[0] + self.pad_constant = tf.constant([[0, 0], [padding_size, padding_size], [0, 0], [0, 0]]) + + self.supports_masking = True + + def call(self, x, training=None, mask=None): + padded = tf.pad(x, self.pad_constant) + # emb_sent_exp = tf.expand_dims(input, axis=3) + convolve = self.textConv(padded) + return tf.squeeze(convolve, axis=-2) + + +class 
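# Illustrative sketch (not part of the patch): `DefaultEmbedding` above appends a zero row to the
# embedding matrix, so the index equal to the number of known items acts as the padding/default id
# and always resolves to a zero vector. The shapes below are made up.
import numpy as np
import tensorflow as tf

init = np.random.rand(3, 4).astype("float32")  # 3 known items, dimensionality 4
embedder = DefaultEmbedding(init_vectors=init)
vectors = embedder(tf.constant([0, 2, 3]))     # id 3 == number of known items -> zero vector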
TextCnnEncoder(Model): + """ + TextCnnEncoder model for classifying tokens in a sequence. The model uses following pipeline: + + token_embeddings (provided from outside) -> + several convolutional layers, get representations for all tokens -> + pass representation for all tokens through a dense network -> + classify each token + """ + def __init__(self, + # input_size, + h_sizes, seq_len, + pos_emb_size, cnn_win_size, dense_size, out_dim, + activation=None, dense_activation=None, drop_rate=0.2): + """ + + :param input_size: dimensionality of input embeddings + :param h_sizes: sizes of hidden CNN layers, internal dimensionality of token embeddings + :param seq_len: maximum sequence length + :param pos_emb_size: dimensionality of positional embeddings + :param cnn_win_size: width of cnn window + :param dense_size: number of unius in dense network + :param num_classes: number of output units + :param activation: activation for cnn + :param dense_activation: activation for dense layers + :param drop_rate: dropout rate for dense network + """ + super(TextCnnEncoder, self).__init__() + + self.seq_len = seq_len + self.h_sizes = h_sizes + self.dense_size = dense_size + self.out_dim = out_dim + self.pos_emb_size = pos_emb_size + self.cnn_win_size = cnn_win_size + self.activation = activation + self.dense_activation = dense_activation + self.drop_rate = drop_rate + + self.supports_masking = True + + def build(self, input_shape): + assert len(input_shape) == 3 + + input_size = input_shape[2] + + def infer_kernel_sizes(h_sizes): + """ + Compute kernel sizes from the desired dimensionality of hidden layers + :param h_sizes: + :return: + """ + kernel_sizes = copy(h_sizes) + kernel_sizes.pop(-1) # pop last because it is the output of the last CNN layer + kernel_sizes.insert(0, input_size) # the first kernel size should be (cnn_win_size, input_size) + kernel_sizes = [(self.cnn_win_size, ks) for ks in kernel_sizes] + return kernel_sizes + + kernel_sizes = infer_kernel_sizes(self.h_sizes) + + self.layers_tok = [TextCnnLayer(out_dim=h_size, kernel_shape=kernel_size, activation=self.activation) + for h_size, kernel_size in zip(self.h_sizes, kernel_sizes)] + + # self.layers_pos = [TextCnnLayer(out_dim=h_size, kernel_shape=(cnn_win_size, pos_emb_size), activation=activation) + # for h_size, _ in zip(h_sizes, kernel_sizes)] + + # self.positional = PositionalEncoding(seq_len=seq_len, pos_emb_size=pos_emb_size) + + if self.dense_activation is None: + dense_activation = self.activation + + # self.attention = tfa.layers.MultiHeadAttention(head_size=200, num_heads=1) + + self.dense_1 = Dense(self.dense_size, activation=self.dense_activation) + self.dropout_1 = tf.keras.layers.Dropout(rate=self.drop_rate) + self.dense_2 = Dense(self.out_dim, activation=None) # logits + self.dropout_2 = tf.keras.layers.Dropout(rate=self.drop_rate) + + def compute_mask(self, inputs, mask=None): + return mask + + def call(self, embs, training=True, mask=None): + + temp_cnn_emb = embs # shape (?, seq_len, input_size) + + # pass embeddings through several CNN layers + for l in self.layers_tok: + temp_cnn_emb = l(tf.expand_dims(temp_cnn_emb, axis=3)) # shape (?, seq_len, h_size) + + # TODO + # simplify to one CNN and one attention + + # pos_cnn = self.positional() + # for l in self.layers_pos: + # pos_cnn = l(tf.expand_dims(pos_cnn, axis=3)) + # + # cnn_pool_feat = [] + # for i in range(self.seq_len): + # # slice tensor for the line that corresponds to the current position in the sentence + # position_features = tf.expand_dims(pos_cnn[i, 
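# Illustrative sketch (not part of the patch): `infer_kernel_sizes` above derives the 2D kernel
# shapes from the desired hidden sizes. With made-up values input_size = 100, cnn_win_size = 5 and
# h_sizes = [200, 300]:
#   start from the hidden sizes      -> [200, 300]
#   pop the last (it is the output)  -> [200]
#   insert input_size at position 0  -> [100, 200]
#   pair with the window width       -> [(5, 100), (5, 200)]
# so the first convolution spans the full input embedding and every later layer spans the output
# dimensionality of the previous one.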
...], axis=0, name="exp_dim_%d" % i) + # # convolution without activation can be combined later, hence: temp_cnn_emb + position_features + # cnn_pool_feat.append( + # tf.expand_dims(tf.nn.tanh(tf.reduce_max(temp_cnn_emb + position_features, axis=1)), axis=1)) + # # cnn_pool_feat.append( + # # tf.expand_dims(tf.nn.tanh(tf.reduce_max(tf.concat([temp_cnn_emb, position_features], axis=-1), axis=1)), axis=1)) + # + # cnn_pool_features = tf.concat(cnn_pool_feat, axis=1) + cnn_pool_features = temp_cnn_emb + + # cnn_pool_features = self.attention([cnn_pool_features, cnn_pool_features]) + + # token_features = self.dropout_1( + # tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1])) + # , training=training) + + # reshape before passing through a dense network + # token_features = tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1])) # shape (? * seq_len, h_size[-1]) + + # local_h2 = self.dropout_2( + # self.dense_1(token_features) + # , training=training) + local_h2 = self.dense_1(cnn_pool_features) # shape (? * seq_len, dense_size) + tag_logits = self.dense_2(local_h2) # shape (? * seq_len, num_classes) + + return tag_logits # tf.reshape(tag_logits, (-1, seq_len, self.out_dim)) # reshape back, shape (?, seq_len, num_classes) + + +class GRUEncoder(Model): + def __init__(self, input_dim, out_dim=100, num_layers=1, dropout=0.1): + super(GRUEncoder, self).__init__() + self.num_layers = num_layers + + self.gru_layers = [ + tf.keras.layers.Bidirectional(GRU(out_dim, dropout=dropout, return_sequences=True)) for _ in range(num_layers) + ] + + self.dropout = Dropout(dropout) + self.supports_masking = True + + def call(self, inputs, training=None, mask=None): + x = inputs + + for layer in self.gru_layers: + x = layer(x, training=training, mask=mask) + x = self.dropout(x, training=training) + + return x \ No newline at end of file diff --git a/SourceCodeTools/models/nlp/Decoder.py b/SourceCodeTools/models/nlp/TorchDecoder.py similarity index 90% rename from SourceCodeTools/models/nlp/Decoder.py rename to SourceCodeTools/models/nlp/TorchDecoder.py index 1ba6f128..dae8fc21 100644 --- a/SourceCodeTools/models/nlp/Decoder.py +++ b/SourceCodeTools/models/nlp/TorchDecoder.py @@ -1,25 +1,33 @@ +import logging + import torch from torch import nn from torch.nn import Embedding import torch.nn.functional as F from torch.autograd import Variable +from SourceCodeTools.models.nlp.common import positional_encoding + class Decoder(nn.Module): - def __init__(self, encoder_out_dim, decoder_dim, out_dim, vocab_size, nheads=1, layers=1): + def __init__(self, encoder_out_dim, decoder_dim, out_dim, vocab_size, seq_len, nheads=1, layers=1): super(Decoder, self).__init__() self.encoder_adapter = nn.Linear(encoder_out_dim, decoder_dim) self.embed = nn.Embedding(vocab_size, decoder_dim) self.decoder_layer = nn.TransformerDecoderLayer(decoder_dim, nheads, dim_feedforward=decoder_dim) self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=layers) self.mask = self.generate_square_subsequent_mask(1) + self.seq_len = seq_len + + self.position_encoding = positional_encoding(self.seq_len, decoder_dim).permute(1, 0, 2) self.fc = nn.Linear(decoder_dim, out_dim) def forward(self, encoder_out, target): encoder_out = self.encoder_adapter(encoder_out).permute(1, 0, 2) target = self.embed(target).permute(1, 0, 2) - if self.mask.size(0) != target.size(0): + target = target + self.position_encoding[:target.shape[0]] + if self.mask.size(0) != target.size(0): # for self-attention self.mask = 
self.generate_square_subsequent_mask(target.size(0)).to(encoder_out.device) out = self.decoder(target, encoder_out, tgt_mask=self.mask) @@ -64,6 +72,7 @@ class LSTMDecoder(nn.Module): def __init__(self, num_buckets, padding=0, encoder_embed_dim=100, embed_dim=100, out_embed_dim=100, num_layers=1, dropout_in=0.1, dropout_out=0.1, use_cuda=True): + embed_dim = out_embed_dim = encoder_embed_dim super(LSTMDecoder, self).__init__() self.use_cuda = use_cuda self.dropout_in = dropout_in @@ -73,7 +82,7 @@ def __init__(self, num_buckets, padding=0, encoder_embed_dim=100, embed_dim=100, padding_idx = padding self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) - self.create_layers(embed_dim, embed_dim, num_layers) + self.create_layers(encoder_embed_dim, embed_dim, num_layers) self.attention = AttentionLayer(out_embed_dim, encoder_embed_dim, embed_dim) if embed_dim != out_embed_dim: @@ -82,10 +91,14 @@ def __init__(self, num_buckets, padding=0, encoder_embed_dim=100, embed_dim=100, def create_layers(self, encoder_embed_dim, embed_dim, num_layers): self.layers = nn.ModuleList([ - LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim) + LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else encoder_embed_dim, encoder_embed_dim if layer == 0 else embed_dim) for layer in range(num_layers) ]) + self.norms = nn.ModuleList([ + nn.LayerNorm(encoder_embed_dim + embed_dim if layer == 0 else encoder_embed_dim) for layer in range(num_layers) + ]) + def forward(self, prev_output_tokens, encoder_out, incremental_state=None, inference=False): if incremental_state is not None: # TODO what is this? prev_output_tokens = prev_output_tokens[:, -1:] @@ -112,9 +125,9 @@ def forward(self, prev_output_tokens, encoder_out, incremental_state=None, infer else: # _, encoder_hiddens, encoder_cells = encoder_out num_layers = len(self.layers) - prev_hiddens = [Variable(x.data.new(bsz, embed_dim).zero_()) for i in range(num_layers)] - prev_cells = [Variable(x.data.new(bsz, embed_dim).zero_()) for i in range(num_layers)] - input_feed = Variable(x.data.new(bsz, embed_dim).zero_()) + prev_hiddens = [Variable(x.data.new(bsz, embed_dim).zero_()) if i != 0 else encoder_outs[0] for i in range(num_layers)] + prev_cells = [Variable(x.data.new(bsz, embed_dim if i!=0 else encoder_outs.size(-1)).zero_()) for i in range(num_layers)] + input_feed = Variable(x.data.new(bsz, encoder_outs.size(-1)).zero_()) attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_()) outs = [] @@ -122,8 +135,9 @@ def forward(self, prev_output_tokens, encoder_out, incremental_state=None, infer # input feeding: concatenate context vector from previous time step input = torch.cat((x[j, :, :], input_feed), dim=1) - for i, rnn in enumerate(self.layers): + for i, (rnn, lnorm) in enumerate(zip(self.layers, self.norms)): # recurrent cell + input = lnorm(input) hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer diff --git a/SourceCodeTools/models/nlp/Encoder.py b/SourceCodeTools/models/nlp/TorchEncoder.py similarity index 97% rename from SourceCodeTools/models/nlp/Encoder.py rename to SourceCodeTools/models/nlp/TorchEncoder.py index 82adf0c3..938d9efb 100644 --- a/SourceCodeTools/models/nlp/Encoder.py +++ b/SourceCodeTools/models/nlp/TorchEncoder.py @@ -9,7 +9,7 @@ class Encoder(nn.Module): def __init__(self, encoder_dim, out_dim, nheads=1, layers=1): super(Encoder, self).__init__() # self.embed = nn.Embedding(vocab_size, encoder_dim) - self.encoder_lauer = 
nn.TransformerEncoderLayer(encoder_dim, nheads, dim_feedforward=encoder_dim) + self.encoder_layer = nn.TransformerEncoderLayer(encoder_dim, nheads, dim_feedforward=encoder_dim) self.encoder = nn.TransformerEncoder(self.encoder_lauer, num_layers=layers) self.out_adapter = nn.Linear(encoder_dim, out_dim) @@ -29,7 +29,6 @@ def forward(self, input, lengths=None): return out.permute(1, 0, 2) - class LSTMEncoder(nn.Module): """LSTM encoder.""" def __init__(self, embed_dim=100, num_layers=1, dropout_in=0.1, dropout_out=0.1): diff --git a/SourceCodeTools/models/nlp/common.py b/SourceCodeTools/models/nlp/common.py new file mode 100644 index 00000000..1004b0b1 --- /dev/null +++ b/SourceCodeTools/models/nlp/common.py @@ -0,0 +1,20 @@ +import numpy as np + +def positional_encoding(position, d_model): + def get_angles(pos, i, d_model): + angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model)) + return pos * angle_rates + + angle_rads = get_angles(np.arange(position)[:, np.newaxis], + np.arange(d_model)[np.newaxis, :], + d_model) + + # apply sin to even indices in the array; 2i + angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) + + # apply cos to odd indices in the array; 2i+1 + angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) + + pos_encoding = angle_rads[np.newaxis, ...] + return pos_encoding + # return tf.cast(pos_encoding, dtype=tf.float32) \ No newline at end of file diff --git a/SourceCodeTools/models/training_config.py b/SourceCodeTools/models/training_config.py new file mode 100644 index 00000000..9e7ebd4b --- /dev/null +++ b/SourceCodeTools/models/training_config.py @@ -0,0 +1,126 @@ +from copy import copy + +import yaml + + +config_specification = { + "DATASET": { + "data_path": None, + "train_frac": 0.9, + "filter_edges": None, + "min_count_for_objectives": 5, + # "packages_file": None, # partition file + "self_loops": False, + "use_node_types": False, + "use_edge_types": False, + "no_global_edges": False, + "remove_reverse": False, + "custom_reverse": None, + "restricted_id_pool": None, + "random_seed": None, + "subgraph_id_column": "mentioned_in", + "subgraph_partition": None + }, + "TRAINING": { + "model_output_dir": None, + "pretrained": None, + "pretraining_phase": 0, + + "sampling_neighbourhood_size": 10, + "neg_sampling_factor": 3, + "use_layer_scheduling": False, + "schedule_layers_every": 10, + + "elem_emb_size": 100, + "embedding_table_size": 200000, + + "epochs": 100, + "batch_size": 128, + "learning_rate": 1e-3, + + "objectives": None, + "save_each_epoch": False, + "save_checkpoints": True, + "early_stopping": False, + "early_stopping_tolerance": 20, + + "force_w2v_ns": False, + "use_ns_groups": False, + "nn_index": "brute", + + "metric": "inner_prod", + + "measure_scores": False, + "dilate_scores": 200, # downsample + + "gpu": -1, + + "external_dataset": None, + + "restore_state": False, + }, + "MODEL": { + "node_emb_size": 100, + "h_dim": 100, + "n_layers": 5, + "use_self_loop": True, + + "use_gcn_checkpoint": False, + "use_att_checkpoint": False, + "use_gru_checkpoint": False, + + 'num_bases': 10, + 'dropout': 0.0, + + 'activation': "tanh", + # torch.nn.functional.hardswish], #[torch.nn.functional.hardtanh], #torch.nn.functional.leaky_relu + }, + "TOKENIZER": { + "tokenizer_path": None, + } +} + + +def default_config(): + config = copy(config_specification) + + assert len( + set(config["MODEL"].keys()) | + set(config["TRAINING"].keys()) | + set(config["DATASET"].keys()) + ) == ( + len(set(config["MODEL"].keys())) + + len(set(config["TRAINING"].keys())) + + 
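# Illustrative sketch (not part of the patch): `positional_encoding` above is the standard
# sinusoidal encoding (sin at even feature indices, cos at odd ones) and returns a numpy array
# with a leading batch axis. The sizes below are made up.
pe = positional_encoding(position=50, d_model=128)
assert pe.shape == (1, 50, 128)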
len(set(config["DATASET"].keys())) + ), "All parameter names in the configuration should be unique" + return config + + +def get_config(**kwargs): + config = default_config() + return update_config(config, **kwargs) + + +def update_config(config, **kwargs): + recognized_options = set() + + for section, args in config.items(): + for key, value in kwargs.items(): + if key in args: + args[key] = value + recognized_options.add(key) + + unrecognized = {key: value for key, value in kwargs.items() if key not in recognized_options} + + if len(unrecognized) > 0: + raise ValueError(f"Some configuration options are not recognized: {unrecognized}") + + return config + + +def save_config(config, path): + yaml.dump(config, open(path, "w")) + + +def load_config(path): + return yaml.load(open(path, "r").read(), Loader=yaml.Loader) \ No newline at end of file diff --git a/SourceCodeTools/models/training_options.py b/SourceCodeTools/models/training_options.py new file mode 100644 index 00000000..2100f4ff --- /dev/null +++ b/SourceCodeTools/models/training_options.py @@ -0,0 +1,85 @@ +def add_data_arguments(parser): + parser.add_argument('--data_path', '-d', dest='data_path', default=None, help='Path to folder with dataset') + parser.add_argument('--train_frac', dest='train_frac', default=0.9, type=float, help='Fraction of nodes to be used for training') + parser.add_argument('--filter_edges', dest='filter_edges', default=None, help='Comma separated list of edges that should be filtered from the graph') + parser.add_argument('--min_count_for_objectives', dest='min_count_for_objectives', default=5, type=int, help='Filter all target examples that occurr less than set numbers of times') + # parser.add_argument('--packages_file', dest='packages_file', default=None, type=str, help='???') + parser.add_argument('--self_loops', action='store_true', help='Add self loops to the graph') + parser.add_argument('--use_node_types', action='store_true', help='Add node types to the graph') + parser.add_argument('--use_edge_types', action='store_true', help='Add edge types to the graph') + parser.add_argument('--restore_state', action='store_true', help='Load from checkpoint') + parser.add_argument('--no_global_edges', action='store_true', help='Remove all global edges from the graph') + parser.add_argument('--remove_reverse', action='store_true', help="Remove reverse edges from the graph") + parser.add_argument('--custom_reverse', dest='custom_reverse', default=None, help='List of edges for which to add reverse types. Should use together with `remove_reverse`') + parser.add_argument('--restricted_id_pool', dest='restricted_id_pool', default=None, help='???') + parser.add_argument('--subgraph_partition', default=None) + parser.add_argument('--subgraph_id_column', default=None) + + +def add_pretraining_arguments(parser): + parser.add_argument('--pretrained', '-p', dest='pretrained', default=None, help='Path to pretrained subtoken vectors') + parser.add_argument('--tokenizer_path', '-t', dest='tokenizer_path', default=None, help='???') + parser.add_argument('--pretraining_phase', dest='pretraining_phase', default=0, type=int, help='Number of epochs for pretraining') + + +def add_training_arguments(parser): + parser.add_argument('--embedding_table_size', dest='embedding_table_size', default=200000, type=int, help='Bucket size for the embedding table. 
Overriden when pretrained vectors provided???') + parser.add_argument('--random_seed', dest='random_seed', default=None, type=int, help='Random seed for generating dataset splits') + + parser.add_argument('--node_emb_size', dest='node_emb_size', default=100, type=int, help='Dimensionality of node embeddings') + parser.add_argument('--elem_emb_size', dest='elem_emb_size', default=100, type=int, help='Dimensionality of target embeddings (node names). Should match node embeddings when cosine distance loss is used') + parser.add_argument('--sampling_neighbourhood_size', dest='sampling_neighbourhood_size', default=10, type=int, help='Number of dependencies to sample per node') + parser.add_argument('--neg_sampling_factor', dest='neg_sampling_factor', default=3, type=int, help='Number of negative samples for each positive') + + parser.add_argument('--use_layer_scheduling', action='store_true', help='???') + parser.add_argument('--schedule_layers_every', dest='schedule_layers_every', default=10, type=int, help='???') + + parser.add_argument('--epochs', dest='epochs', default=100, type=int, help='Number of epochs') + parser.add_argument('--batch_size', dest='batch_size', default=128, type=int, help='Batch size') + + parser.add_argument("--h_dim", dest="h_dim", default=None, type=int, help='Should be the same as `node_emb_size`') + parser.add_argument("--n_layers", dest="n_layers", default=5, type=int, help='Number of layers') + parser.add_argument("--objectives", dest="objectives", default=None, type=str, help='???') + + parser.add_argument("--save_each_epoch", action="store_true", help='Save checkpoints for each epoch (high disk space utilization)') + parser.add_argument("--early_stopping", action="store_true", help='???') + parser.add_argument("--early_stopping_tolerance", default=20, type=int, help='???') + parser.add_argument("--force_w2v_ns", action="store_true", help='Use w2v negative sampling strategy p_unigram^(3/4)') + parser.add_argument("--use_ns_groups", action="store_true", help='Perform negative sampling only from closest neighbours???') + + parser.add_argument("--metric", default="inner_prod", type=str, help='???') + parser.add_argument("--nn_index", default="brute", type=str, help='Index backend for generating negative samples???') + + parser.add_argument("--external_dataset", default=None, type=str, help='Path to external graph, use for inference') + + +def add_scoring_arguments(parser): + parser.add_argument('--measure_scores', action='store_true') + parser.add_argument('--dilate_scores', dest='dilate_scores', default=200, type=int, help='') + + +def add_performance_arguments(parser): + parser.add_argument('--no_checkpoints', dest="save_checkpoints", action='store_false') + + parser.add_argument('--use_gcn_checkpoint', action='store_true') + parser.add_argument('--use_att_checkpoint', action='store_true') + parser.add_argument('--use_gru_checkpoint', action='store_true') + + +def add_gnn_train_args(parser): + parser.add_argument("--config", default=None, help="Path to config file") + add_data_arguments(parser) + add_pretraining_arguments(parser) + add_training_arguments(parser) + add_scoring_arguments(parser) + add_performance_arguments(parser) + + # parser.add_argument('--note', dest='note', default="", help='Note, added to metadata') + parser.add_argument('model_output_dir', help='Location of the final model') + + # parser.add_argument('--intermediate_supervision', action='store_true') + parser.add_argument('--gpu', dest='gpu', default=-1, type=int, help='') + + +def 
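# Illustrative sketch (not part of the patch): `add_gnn_train_args` above wires all option groups
# onto a single parser; the only positional argument is the model output directory. The argument
# values below are made up.
import argparse

parser = argparse.ArgumentParser(description="Train graph model")
add_gnn_train_args(parser)
args = parser.parse_args(["--data_path", "/data/graph", "--epochs", "10", "model_output"])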
verify_arguments(args): + pass \ No newline at end of file diff --git a/SourceCodeTools/nlp/batchers/PythonBatcher.py b/SourceCodeTools/nlp/batchers/PythonBatcher.py index e6ff2049..42f5c580 100644 --- a/SourceCodeTools/nlp/batchers/PythonBatcher.py +++ b/SourceCodeTools/nlp/batchers/PythonBatcher.py @@ -1,21 +1,26 @@ import json import os import shelve +import shutil import tempfile -from functools import lru_cache +from collections import defaultdict +from time import time from math import ceil from typing import Dict, Optional, List -from spacy.gold import biluo_tags_from_offsets +import spacy -from SourceCodeTools.code.ast_tools import get_declarations +from SourceCodeTools.code.ast.ast_tools import get_declarations from SourceCodeTools.models.ClassWeightNormalizer import ClassWeightNormalizer from SourceCodeTools.nlp import create_tokenizer, tag_map_from_sentences, TagMap, token_hasher, try_int from SourceCodeTools.nlp.entity import fix_incorrect_tags +from SourceCodeTools.code.annotator_utils import adjust_offsets, biluo_tags_from_offsets from SourceCodeTools.nlp.entity.utils import overlap - import numpy as np + + + def filter_unlabeled(entities, declarations): """ Get a list of declarations that were not mentioned in `entities` @@ -33,37 +38,46 @@ def filter_unlabeled(entities, declarations): return for_mask +def print_token_tag(doc, tags): + for t, tag in zip(doc, tags): + print(t, "\t", tag) + + class PythonBatcher: def __init__( self, data, batch_size: int, seq_len: int, - graphmap: Dict[str, int], wordmap: Dict[str, int], tagmap: Optional[TagMap] = None, + wordmap: Dict[str, int], *, graphmap: Optional[Dict[str, int]], tagmap: Optional[TagMap] = None, mask_unlabeled_declarations=True, - class_weights=False, element_hash_size=1000 + class_weights=False, element_hash_size=1000, len_sort=True, tokenizer="spacy", no_localization=False ): self.create_cache() - self.data = data + self.data = sorted(data, key=lambda x: len(x[0])) if len_sort else data self.batch_size = batch_size self.seq_len = seq_len self.class_weights = None self.mask_unlabeled_declarations = mask_unlabeled_declarations + self.tokenizer = tokenizer + if tokenizer == "codebert": + self.vocab = spacy.blank("en").vocab + self.no_localization = no_localization - self.nlp = create_tokenizer("spacy") + self.nlp = create_tokenizer(tokenizer) if tagmap is None: self.tagmap = tag_map_from_sentences(list(zip(*[self.prepare_sent(sent) for sent in data]))[1]) else: self.tagmap = tagmap - self.graphpad = len(graphmap) + self.graphpad = len(graphmap) if graphmap is not None else None self.wordpad = len(wordmap) self.tagpad = self.tagmap["O"] self.prefpad = element_hash_size self.suffpad = element_hash_size - self.graphmap_func = lambda g: graphmap.get(g, len(graphmap)) + self.graphmap_func = (lambda g: graphmap.get(g, len(graphmap))) if graphmap is not None else None self.wordmap_func = lambda w: wordmap.get(w, len(wordmap)) - self.tagmap_func = lambda t: self.tagmap.get(t, len(self.tagmap)) + self.tagmap_func = lambda t: self.tagmap.get(t, self.tagmap["O"]) self.prefmap_func = lambda w: token_hasher(w[:3], element_hash_size) self.suffmap_func = lambda w: token_hasher(w[-3:], element_hash_size) @@ -81,10 +95,26 @@ def __init__( else: self.classw_func = lambda t: 1. 
+ def __del__(self): + self.sent_cache.close() + self.batch_cache.close() + + from shutil import rmtree + rmtree(self.tmp_dir, ignore_errors=True) + def create_cache(self): - self.tmp_dir = tempfile.TemporaryDirectory() - self.sent_cache = shelve.open(os.path.join(self.tmp_dir.name, "sent_cache.db")) - self.batch_cache = shelve.open(os.path.join(self.tmp_dir.name, "batch_cache.db")) + char_ranges = [chr(i) for i in range(ord("a"), ord("a")+26)] + [chr(i) for i in range(ord("A"), ord("A")+26)] + [chr(i) for i in range(ord("0"), ord("0")+10)] + from random import sample + rnd_name = "".join(sample(char_ranges, k=10)) + str(int(time() * 1e6)) + time() + + self.tmp_dir = os.path.join(tempfile.gettempdir(), rnd_name) + if os.path.isdir(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + os.mkdir(self.tmp_dir) + + self.sent_cache = shelve.open(os.path.join(self.tmp_dir, "sent_cache")) + self.batch_cache = shelve.open(os.path.join(self.tmp_dir, "batch_cache")) def num_classes(self): return len(self.tagmap) @@ -106,16 +136,48 @@ def prepare_sent(self, sent): unlabeled_dec = filter_unlabeled(ents, get_declarations(text)) tokens = [t.text for t in doc] - ents_tags = biluo_tags_from_offsets(doc, ents) - repl_tags = biluo_tags_from_offsets(doc, repl) + + if self.tokenizer == "codebert": + backup_tokens = doc + fixed_spaces = [False] + fixed_words = [""] + + for ind, t in enumerate(doc): + if len(t.text) > 1: + fixed_words.append(t.text.strip("Ġ")) + else: + fixed_words.append(t.text) + if ind != 0: + fixed_spaces.append(t.text.startswith("Ġ") and len(t.text) > 1) + fixed_spaces.append(False) + fixed_spaces.append(False) + fixed_words.append("") + + assert len(fixed_spaces) == len(fixed_words) + + from spacy.tokens import Doc + doc = Doc(self.vocab, fixed_words, fixed_spaces) + + assert len(doc) - 2 == len(backup_tokens) + assert len(doc.text) - 7 == len(backup_tokens.text) + ents = adjust_offsets(ents, -3) + repl = adjust_offsets(repl, -3) + if self.mask_unlabeled_declarations: + unlabeled_dec = adjust_offsets(unlabeled_dec, -3) + + ents_tags = biluo_tags_from_offsets(doc, ents, self.no_localization) + repl_tags = biluo_tags_from_offsets(doc, repl, self.no_localization) if self.mask_unlabeled_declarations: - unlabeled_dec = biluo_tags_from_offsets(doc, unlabeled_dec) + unlabeled_dec = biluo_tags_from_offsets(doc, unlabeled_dec, self.no_localization) fix_incorrect_tags(ents_tags) fix_incorrect_tags(repl_tags) if self.mask_unlabeled_declarations: fix_incorrect_tags(unlabeled_dec) + if self.tokenizer == "codebert": + tokens = [""] + [t.text for t in backup_tokens] + [""] + assert len(tokens) == len(ents_tags) == len(repl_tags) if self.mask_unlabeled_declarations: assert len(tokens) == len(unlabeled_dec) @@ -151,7 +213,7 @@ def encode(seq, encode_func, pad): pref = encode(sent, self.prefmap_func, self.prefpad) suff = encode(sent, self.suffmap_func, self.suffpad) s = encode(sent, self.wordmap_func, self.wordpad) - r = encode(repl, self.graphmap_func, self.graphpad) # TODO test + r = encode(repl, self.graphmap_func, self.graphpad) if self.graphmap_func is not None else None # labels t = encode(tags, self.tagmap_func, self.tagpad) @@ -160,38 +222,55 @@ def encode(seq, encode_func, pad): hidem = encode( list(range(len(sent))) if unlabeled_decls is None else unlabeled_decls, self.mask_unlbl_func, self.mask_unlblpad - ) + ).astype(np.bool) # class weights classw = encode(tags, self.classw_func, self.classwpad) - assert len(s) == len(r) == len(pref) == len(suff) == len(t) == len(classw) == len(hidem) + assert len(s) 
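# Illustrative sketch (not part of the patch): CodeBERT's byte-level BPE (as in RoBERTa/GPT-2)
# marks tokens that start a new word with a leading "Ġ", i.e. an encoded space. The codebert
# branch above strips that marker and records it as a "space follows the previous token" flag so
# a spaCy Doc with consistent character offsets can be rebuilt. For a made-up fragment:
#   BPE tokens:  ["def", "Ġadd", "(", "x", ")"]
#   words:       ["def", "add", "(", "x", ")"]
#   spaces:      [True, False, False, False, False]   # whitespace follows "def" only
#   doc.text:    "def add(x)"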
== len(pref) == len(suff) == len(t) == len(classw) == len(hidem) + if r is not None: + assert len(r) == len(s) + + no_localization_mask = np.array([tag != self.tagpad for tag in t]).astype(np.bool) output = { "tok_ids": s, - "graph_ids": r, + "replacements": repl, + # "graph_ids": r, "prefix": pref, "suffix": suff, "tags": t, "class_weights": classw, "hide_mask": hidem, - "lens": len(s) if len(s) < self.seq_len else self.seq_len + "no_loc_mask": no_localization_mask, + "lens": len(sent) if len(sent) < self.seq_len else self.seq_len } + if r is not None: + output["graph_ids"] = r + self.batch_cache[input_json] = output return output def format_batch(self, batch): - fbatch = { - "tok_ids": [], "graph_ids": [], "prefix": [], "suffix": [], - "tags": [], "class_weights": [], "hide_mask": [], "lens": [] - } + # fbatch = { + # "tok_ids": [], "graph_ids": [], "prefix": [], "suffix": [], + # "tags": [], "class_weights": [], "hide_mask": [], "lens": [], "replacements": [] + # } + fbatch = defaultdict(list) for sent in batch: for key, val in sent.items(): fbatch[key].append(val) - return {key: np.stack(val) if key != "lens" else np.array(val, dtype=np.int32) for key, val in fbatch.items()} + if len(fbatch["graph_ids"]) == 0: + fbatch.pop("graph_ids") + + max_len = max(fbatch["lens"]) + + return { + key: np.stack(val)[:,:max_len] if key != "lens" and key != "replacements" + else (np.array(val, dtype=np.int32) if key != "replacements" else np.array(val)) for key, val in fbatch.items()} def generate_batches(self): batch = [] @@ -200,7 +279,9 @@ def generate_batches(self): if len(batch) >= self.batch_size: yield self.format_batch(batch) batch = [] - yield self.format_batch(batch) + if len(batch) > 0: + yield self.format_batch(batch) + # yield self.format_batch(batch) def __iter__(self): return self.generate_batches() diff --git a/SourceCodeTools/nlp/batchers/PythonBatcherWithMentions.py b/SourceCodeTools/nlp/batchers/PythonBatcherWithMentions.py index c9fa2c53..a84a3fcd 100644 --- a/SourceCodeTools/nlp/batchers/PythonBatcherWithMentions.py +++ b/SourceCodeTools/nlp/batchers/PythonBatcherWithMentions.py @@ -5,7 +5,7 @@ import numpy as np from spacy.gold import biluo_tags_from_offsets -from SourceCodeTools.code.ast_tools import get_declarations +from SourceCodeTools.code.ast.ast_tools import get_declarations from SourceCodeTools.nlp import TagMap, try_int from SourceCodeTools.nlp.batchers import PythonBatcher from SourceCodeTools.nlp.entity import fix_incorrect_tags diff --git a/SourceCodeTools/nlp/codebert/codebert.py b/SourceCodeTools/nlp/codebert/codebert.py new file mode 100644 index 00000000..78bd9c2d --- /dev/null +++ b/SourceCodeTools/nlp/codebert/codebert.py @@ -0,0 +1,114 @@ +import pickle + +import torch +from tqdm import tqdm +from transformers import RobertaTokenizer, RobertaModel + +from SourceCodeTools.models.Embedder import Embedder +from SourceCodeTools.nlp.entity.type_prediction import get_type_prediction_arguments, ModelTrainer +from SourceCodeTools.nlp.entity.utils.data import read_data + + +def load_typed_nodes(path): + from SourceCodeTools.code.data.file_utils import unpersist + type_ann = unpersist(path) + + filter_rule = lambda name: "0x" not in name + + type_ann = type_ann[ + type_ann["dst"].apply(filter_rule) + ] + + typed_nodes = set(type_ann["src"].tolist()) + return typed_nodes + + +class CodeBertModelTrainer(ModelTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def set_type_ann_edges(self, path): + self.type_ann_edges = path + + def 
get_batcher(self, *args, **kwargs): + kwargs.update({"tokenizer": "codebert"}) + return self.batcher(*args, **kwargs) + + def train_model(self): + # graph_emb = load_pkl_emb(self.graph_emb_path) if self.graph_emb_path is not None else None + + typed_nodes = load_typed_nodes(self.type_ann_edges) + + decoder_mapping = RobertaTokenizer.from_pretrained("microsoft/codebert-base").decoder + tok_ids, words = zip(*decoder_mapping.items()) + vocab_mapping = dict(zip(words, tok_ids)) + batcher = self.get_batcher( + self.train_data + self.test_data, self.batch_size, seq_len=self.seq_len, + graphmap=None, + wordmap=vocab_mapping, tagmap=None, + class_weights=False, element_hash_size=1 + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = RobertaModel.from_pretrained("microsoft/codebert-base") + model.to(device) + + node_ids = [] + embeddings = [] + + for ind, batch in enumerate(tqdm(batcher)): + # token_ids, graph_ids, labels, class_weights, lengths = b + token_ids = torch.LongTensor(batch["tok_ids"]) + lens = torch.LongTensor(batch["lens"]) + + token_ids[token_ids == len(vocab_mapping)] = vocab_mapping[""] + + def get_length_mask(target, lens): + mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None] + return mask + + mask = get_length_mask(token_ids, lens) + with torch.no_grad(): + embs = model(input_ids=token_ids, attention_mask=mask) + + for s_emb, s_repl in zip(embs.last_hidden_state, batch["replacements"]): + unique_repls = set(list(s_repl)) + repls_for_ann = [r for r in unique_repls if r in typed_nodes] + + for r in repls_for_ann: + position = s_repl.index(r) + if position > 512: + continue + node_ids.append(r) + embeddings.append(s_emb[position]) + + all_embs = torch.stack(embeddings, dim=0).numpy() + embedder = Embedder(dict(zip(node_ids, range(len(node_ids)))), all_embs) + pickle.dump(embedder, open("codebert_embeddings.pkl", "wb"), fix_imports=False) + print(node_ids) + + +def main(): + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") + model = RobertaModel.from_pretrained("microsoft/codebert-base") + # model.to(device) + args = get_type_prediction_arguments() + + # allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray', + # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', 'Type'} + + train_data, test_data = read_data( + open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, + include_only="entities", + min_entity_count=args.min_entity_count, random_seed=args.random_seed + ) + + trainer = CodeBertModelTrainer(train_data, test_data, params={}, seq_len=512) + trainer.set_type_ann_edges(args.type_ann_edges) + trainer.train_model() + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/nlp/codebert/codebert_train.py b/SourceCodeTools/nlp/codebert/codebert_train.py new file mode 100644 index 00000000..e166786f --- /dev/null +++ b/SourceCodeTools/nlp/codebert/codebert_train.py @@ -0,0 +1,421 @@ +import json +import os +import pickle +from datetime import datetime +from time import time + +import torch +from torch.utils.tensorboard import SummaryWriter +from torch.version import cuda +from tqdm import tqdm +from transformers import RobertaTokenizer, RobertaModel + +from SourceCodeTools.models.Embedder import Embedder +from SourceCodeTools.nlp.codebert.codebert 
import CodeBertModelTrainer, load_typed_nodes +from SourceCodeTools.nlp.entity.type_prediction import get_type_prediction_arguments, ModelTrainer, load_pkl_emb, \ + scorer, filter_labels +from SourceCodeTools.nlp.entity.utils.data import read_data + +import torch.nn as nn + +class CodebertHybridModel(nn.Module): + def __init__( + self, codebert_model, graph_emb, padding_idx, num_classes, dense_hidden=100, dropout=0.1, bert_emb_size=768, + no_graph=False + ): + super(CodebertHybridModel, self).__init__() + + self.codebert_model = codebert_model + self.use_graph = not no_graph + + num_emb = padding_idx + 1 # padding id is usually not a real embedding + emb_dim = graph_emb.shape[1] + self.graph_emb = nn.Embedding(num_embeddings=num_emb, embedding_dim=emb_dim, padding_idx=padding_idx) + + import numpy as np + pretrained_embeddings = torch.from_numpy(np.concatenate([graph_emb, np.zeros((1, emb_dim))], axis=0)).float() + new_param = torch.nn.Parameter(pretrained_embeddings) + assert self.graph_emb.weight.shape == new_param.shape + self.graph_emb.weight = new_param + self.graph_emb.weight.requires_grad = False + + self.fc1 = nn.Linear( + bert_emb_size + (emb_dim if self.use_graph else 0), + dense_hidden + ) + self.drop = nn.Dropout(dropout) + self.fc2 = nn.Linear(dense_hidden, num_classes) + + self.loss_f = nn.CrossEntropyLoss(reduction="mean") + + def forward(self, token_ids, graph_ids, mask, finetune=False): + if finetune: + x = self.codebert_model(input_ids=token_ids, attention_mask=mask).last_hidden_state + else: + with torch.no_grad(): + x = self.codebert_model(input_ids=token_ids, attention_mask=mask).last_hidden_state + + if self.use_graph: + graph_emb = self.graph_emb(graph_ids) + x = torch.cat([x, graph_emb], dim=-1) + + x = torch.relu(self.fc1(x)) + x = self.drop(x) + x = self.fc2(x) + + return x + + def loss(self, logits, labels, mask, class_weights=None, extra_mask=None): + if extra_mask is not None: + mask = torch.logical_and(mask, extra_mask) + logits = logits[mask, :] + labels = labels[mask] + loss = self.loss_f(logits, labels) + # if class_weights is None: + # loss = tf.reduce_mean(tf.boolean_mask(losses, seq_mask)) + # else: + # loss = tf.reduce_mean(tf.boolean_mask(losses * class_weights, seq_mask)) + + return loss + + def score(self, logits, labels, mask, scorer=None, extra_mask=None): + if extra_mask is not None: + mask = torch.logical_and(mask, extra_mask) + true_labels = labels[mask] + argmax = logits.argmax(-1) + estimated_labels = argmax[mask] + + p, r, f1 = scorer(to_numpy(estimated_labels), to_numpy(true_labels)) + + return p, r, f1 + + +def get_length_mask(target, lens): + mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None] + return mask + + +def batch_to_torch(batch, device): + key_types = { + 'tok_ids': torch.LongTensor, + 'tags': torch.LongTensor, + 'hide_mask': torch.BoolTensor, + 'no_loc_mask': torch.BoolTensor, + 'lens': torch.LongTensor, + 'graph_ids': torch.LongTensor + } + for key, tf in key_types.items(): + batch[key] = tf(batch[key]).to(device) + + +def to_numpy(tensor): + return tensor.cpu().detach().numpy() + + +def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids, labels, lengths, + extra_mask=None, class_weights=None, scorer=None, finetune=False, vocab_mapping=None): + token_ids[token_ids == len(vocab_mapping)] = vocab_mapping[""] + seq_mask = get_length_mask(token_ids, lengths) + logits = model(token_ids, graph_ids, mask=seq_mask, finetune=finetune) + loss = model.loss(logits, labels, mask=seq_mask, 
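# Illustrative sketch (not part of the patch): shape flow through `CodebertHybridModel.forward`
# above for a batch of B sequences of length L and graph embeddings of size D (sizes made up):
#   token_ids (B, L) --CodeBERT-->      last_hidden_state (B, L, 768)
#   graph_ids (B, L) --nn.Embedding-->  graph vectors     (B, L, D)
#   concatenate on the last axis -> (B, L, 768 + D) -> fc1 + relu -> dropout -> fc2 -> (B, L, num_classes)
# With `no_graph` set, the graph lookup is skipped and fc1 only receives the 768 CodeBERT features.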
class_weights=class_weights, extra_mask=extra_mask) + p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return loss, p, r, f1 + + +def test_step( + model, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None, + vocab_mapping=None +): + with torch.no_grad(): + token_ids[token_ids == len(vocab_mapping)] = vocab_mapping[""] + seq_mask = get_length_mask(token_ids, lengths) + logits = model(token_ids, graph_ids, mask=seq_mask) + loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask) + p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask) + + return loss, p, r, f1 + + +class CodeBertModelTrainer2(CodeBertModelTrainer): + def __init__(self, *args, gpu_id=-1, **kwargs): + self.gpu_id = gpu_id + super().__init__(*args, **kwargs) + self.set_gpu() + + def get_dataloaders(self, word_emb, graph_emb, suffix_prefix_buckets): + decoder_mapping = RobertaTokenizer.from_pretrained("microsoft/codebert-base").decoder + tok_ids, words = zip(*decoder_mapping.items()) + self.vocab_mapping = dict(zip(words, tok_ids)) + + train_batcher = self.get_batcher( + self.train_data, self.batch_size, seq_len=self.seq_len, + graphmap=graph_emb.ind if graph_emb is not None else None, + wordmap=self.vocab_mapping, tagmap=None, + class_weights=False, element_hash_size=suffix_prefix_buckets, no_localization=self.no_localization + ) + test_batcher = self.get_batcher( + self.test_data, self.batch_size, seq_len=self.seq_len, + graphmap=graph_emb.ind if graph_emb is not None else None, + wordmap=self.vocab_mapping, + tagmap=train_batcher.tagmap, # use the same mapping + class_weights=False, element_hash_size=suffix_prefix_buckets, # class_weights are not used for testing + no_localization=self.no_localization + ) + return train_batcher, test_batcher + + def train( + self, model, train_batches, test_batches, epochs, report_every=10, scorer=None, learning_rate=0.01, + learning_rate_decay=1., finetune=False, summary_writer=None, save_ckpt_fn=None, no_localization=False + ): + + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=learning_rate_decay) + + train_losses = [] + test_losses = [] + train_f1s = [] + test_f1s = [] + + num_train_batches = len(train_batches) + num_test_batches = len(test_batches) + + best_f1 = 0. 
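# Illustrative sketch (not part of the patch): `get_length_mask` used by the train and test steps
# builds a boolean padding mask by broadcasting a position index against the sequence lengths.
import torch

lens = torch.tensor([3, 1])
target = torch.zeros(2, 4)  # batch of 2, padded to length 4
mask = torch.arange(target.size(1))[None, :] < lens[:, None]
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])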
+ + try: + for e in range(epochs): + losses = [] + ps = [] + rs = [] + f1s = [] + + start = time() + model.train() + + for ind, batch in enumerate(tqdm(train_batches)): + batch_to_torch(batch, self.device) + # token_ids, graph_ids, labels, class_weights, lengths = b + loss, p, r, f1 = train_step_finetune( + model=model, optimizer=optimizer, token_ids=batch['tok_ids'], + prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'], + labels=batch['tags'], lengths=batch['lens'], + extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'], + # class_weights=batch['class_weights'], + scorer=scorer, finetune=finetune and e / epochs > 0.6, + vocab_mapping=self.vocab_mapping + ) + losses.append(loss.cpu().item()) + ps.append(p) + rs.append(r) + f1s.append(f1) + + self.summary_writer.add_scalar("Loss/Train", loss, global_step=e * num_train_batches + ind) + self.summary_writer.add_scalar("Precision/Train", p, global_step=e * num_train_batches + ind) + self.summary_writer.add_scalar("Recall/Train", r, global_step=e * num_train_batches + ind) + self.summary_writer.add_scalar("F1/Train", f1, global_step=e * num_train_batches + ind) + + test_alosses = [] + test_aps = [] + test_ars = [] + test_af1s = [] + + model.eval() + + for ind, batch in enumerate(test_batches): + batch_to_torch(batch, self.device) + # token_ids, graph_ids, labels, class_weights, lengths = b + test_loss, test_p, test_r, test_f1 = test_step( + model=model, token_ids=batch['tok_ids'], + prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'], + labels=batch['tags'], lengths=batch['lens'], + extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'], + # class_weights=batch['class_weights'], + scorer=scorer, vocab_mapping=self.vocab_mapping + ) + + self.summary_writer.add_scalar("Loss/Test", test_loss, global_step=e * num_test_batches + ind) + self.summary_writer.add_scalar("Precision/Test", test_p, global_step=e * num_test_batches + ind) + self.summary_writer.add_scalar("Recall/Test", test_r, global_step=e * num_test_batches + ind) + self.summary_writer.add_scalar("F1/Test", test_f1, global_step=e * num_test_batches + ind) + test_alosses.append(test_loss.cpu().item()) + test_aps.append(test_p) + test_ars.append(test_r) + test_af1s.append(test_f1) + + epoch_time = time() - start + + train_losses.append(float(sum(losses) / len(losses))) + train_f1s.append(float(sum(f1s) / len(f1s))) + test_losses.append(float(sum(test_alosses) / len(test_alosses))) + test_f1s.append(float(sum(test_af1s) / len(test_af1s))) + + print( + f"Epoch: {e}, {epoch_time: .2f} s, Train Loss: {train_losses[-1]: .4f}, Train P: {sum(ps) / len(ps): .4f}, Train R: {sum(rs) / len(rs): .4f}, Train F1: {sum(f1s) / len(f1s): .4f}, " + f"Test loss: {test_losses[-1]: .4f}, Test P: {sum(test_aps) / len(test_aps): .4f}, Test R: {sum(test_ars) / len(test_ars): .4f}, Test F1: {test_f1s[-1]: .4f}") + + if save_ckpt_fn is not None and float(test_f1s[-1]) > best_f1: + save_ckpt_fn() + best_f1 = float(test_f1s[-1]) + + scheduler.step(epoch=e) + + except KeyboardInterrupt: + pass + + return train_losses, train_f1s, test_losses, test_f1s + + def create_summary_writer(self, path): + self.summary_writer = SummaryWriter(path) + + def set_gpu(self): + # torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.gpu_id == -1: + self.use_cuda = False + self.device = "cpu" + else: + torch.cuda.set_device(self.gpu_id) + self.use_cuda = True + self.device = f"cuda:{self.gpu_id}" + + def train_model(self): + + 
print(f"\n\n{self.model_params}") + lr = self.model_params.pop("learning_rate") + lr_decay = self.model_params.pop("learning_rate_decay") + suffix_prefix_buckets = self.model_params.pop("suffix_prefix_buckets") + + graph_emb = load_pkl_emb(self.graph_emb_path) if self.graph_emb_path is not None else None + + train_batcher, test_batcher = self.get_dataloaders(None, graph_emb, suffix_prefix_buckets=suffix_prefix_buckets) + + codebert_model = RobertaModel.from_pretrained("microsoft/codebert-base") + model = CodebertHybridModel( + codebert_model, graph_emb.e, padding_idx=train_batcher.graphpad, num_classes=train_batcher.num_classes(), + no_graph=self.no_graph + ) + if self.use_cuda: + model.cuda() + + trial_dir = os.path.join(self.output_dir, "codebert_" + str(datetime.now())).replace(":", "-").replace(" ", "_") + os.mkdir(trial_dir) + self.create_summary_writer(trial_dir) + + def save_ckpt_fn(): + checkpoint_path = os.path.join(trial_dir, "checkpoint") + torch.save(model, open(checkpoint_path, 'wb')) + + train_losses, train_f1, test_losses, test_f1 = self.train( + model=model, train_batches=train_batcher, test_batches=test_batcher, + epochs=self.epochs, learning_rate=lr, + scorer=lambda pred, true: scorer(pred, true, train_batcher.tagmap, no_localization=self.no_localization), + learning_rate_decay=lr_decay, finetune=self.finetune, save_ckpt_fn=save_ckpt_fn, + no_localization=self.no_localization + ) + + # checkpoint_path = os.path.join(trial_dir, "checkpoint") + # model.save_weights(checkpoint_path) + + metadata = { + "train_losses": train_losses, + "train_f1": train_f1, + "test_losses": test_losses, + "test_f1": test_f1, + "learning_rate": lr, + "learning_rate_decay": lr_decay, + "epochs": self.epochs, + "suffix_prefix_buckets": suffix_prefix_buckets, + "seq_len": self.seq_len, + "batch_size": self.batch_size, + "no_localization": self.no_localization + } + + print("Maximum f1:", max(test_f1)) + + # write_config(trial_dir, params, extra_params={"suffix_prefix_buckets": suffix_prefix_buckets, "seq_len": seq_len}) + + metadata.update(self.model_params) + + with open(os.path.join(trial_dir, "params.json"), "w") as metadata_sink: + metadata_sink.write(json.dumps(metadata, indent=4)) + + pickle.dump(train_batcher.tagmap, open(os.path.join(trial_dir, "tag_types.pkl"), "wb")) + + # for ind, batch in enumerate(tqdm(batcher)): + # # token_ids, graph_ids, labels, class_weights, lengths = b + # token_ids = torch.LongTensor(batch["tok_ids"]) + # lens = torch.LongTensor(batch["lens"]) + # + # token_ids[token_ids == len(vocab_mapping)] = vocab_mapping[""] + # + # def get_length_mask(target, lens): + # mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None] + # return mask + # + # mask = get_length_mask(token_ids, lens) + # with torch.no_grad(): + # embs = model(input_ids=token_ids, attention_mask=mask) + # + # for s_emb, s_repl in zip(embs.last_hidden_state, batch["replacements"]): + # unique_repls = set(list(s_repl)) + # repls_for_ann = [r for r in unique_repls if r in typed_nodes] + # + # for r in repls_for_ann: + # position = s_repl.index(r) + # if position > 512: + # continue + # node_ids.append(r) + # embeddings.append(s_emb[position]) + # + # all_embs = torch.stack(embeddings, dim=0).numpy() + # embedder = Embedder(dict(zip(node_ids, range(len(node_ids)))), all_embs) + # pickle.dump(embedder, open("codebert_embeddings.pkl", "wb"), fix_imports=False) + # print(node_ids) + + +def main(): + args = get_type_prediction_arguments() + + # allowed = {'str', 'bool', 'Optional', 
'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray', + # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', 'Type'} + if args.restrict_allowed: + allowed = { + 'str', 'Optional', 'int', 'Any', 'Union', 'bool', 'Callable', 'Dict', 'bytes', 'float', 'Description', + 'List', 'Sequence', 'Namespace', 'T', 'Type', 'object', 'HTTPServerRequest', 'Future', "Matcher" + } + else: + allowed = None + + # train_data, test_data = read_data( + # open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, + # include_only="entities", + # min_entity_count=args.min_entity_count, random_seed=args.random_seed + # ) + + from pathlib import Path + dataset_dir = Path(args.data_path).parent + train_data = filter_labels( + pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_train.pkl"), "rb")), + allowed=allowed + ) + test_data = filter_labels( + pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_test.pkl"), "rb")), + allowed=allowed + ) + + trainer = CodeBertModelTrainer2( + train_data, test_data, params={"learning_rate": 1e-4, "learning_rate_decay": 0.99, "suffix_prefix_buckets": 1}, + graph_emb_path=args.graph_emb_path, word_emb_path=args.word_emb_path, + output_dir=args.model_output, epochs=args.epochs, batch_size=args.batch_size, gpu_id=args.gpu, + finetune=args.finetune, trials=args.trials, seq_len=args.max_seq_len, no_localization=args.no_localization, + no_graph=args.no_graph + ) + trainer.set_type_ann_edges(args.type_ann_edges) + trainer.train_model() + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/nlp/codebert/codebert_training.py b/SourceCodeTools/nlp/codebert/codebert_training.py new file mode 100644 index 00000000..5f1ebd11 --- /dev/null +++ b/SourceCodeTools/nlp/codebert/codebert_training.py @@ -0,0 +1,101 @@ +import pickle + +import torch +from tqdm import tqdm +from transformers import RobertaTokenizer, RobertaModel + +from SourceCodeTools.models.Embedder import Embedder +from SourceCodeTools.nlp.codebert.codebert import CodeBertModelTrainer, load_typed_nodes +from SourceCodeTools.nlp.entity.type_prediction import get_type_prediction_arguments, ModelTrainer +from SourceCodeTools.nlp.entity.utils.data import read_data + + +class CodeBertModelTrainer2(CodeBertModelTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def set_type_ann_edges(self, path): + self.type_ann_edges = path + + def get_batcher(self, *args, **kwargs): + kwargs.update({"tokenizer": "codebert"}) + return self.batcher(*args, **kwargs) + + def train_model(self): + # graph_emb = load_pkl_emb(self.graph_emb_path) if self.graph_emb_path is not None else None + + typed_nodes = load_typed_nodes(self.type_ann_edges) + + decoder_mapping = RobertaTokenizer.from_pretrained("microsoft/codebert-base").decoder + tok_ids, words = zip(*decoder_mapping.items()) + vocab_mapping = dict(zip(words, tok_ids)) + batcher = self.get_batcher( + self.train_data + self.test_data, self.batch_size, seq_len=self.seq_len, + graphmap=None, + wordmap=vocab_mapping, tagmap=None, + class_weights=False, element_hash_size=1 + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = RobertaModel.from_pretrained("microsoft/codebert-base") + model.to(device) + + node_ids = [] + embeddings = [] + + for ind, batch in enumerate(tqdm(batcher)): + # token_ids, graph_ids, labels, class_weights, lengths = b + token_ids = 
torch.LongTensor(batch["tok_ids"]) + lens = torch.LongTensor(batch["lens"]) + + token_ids[token_ids == len(vocab_mapping)] = vocab_mapping[""] + + def get_length_mask(target, lens): + mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None] + return mask + + mask = get_length_mask(token_ids, lens) + with torch.no_grad(): + embs = model(input_ids=token_ids, attention_mask=mask) + + for s_emb, s_repl in zip(embs.last_hidden_state, batch["replacements"]): + unique_repls = set(list(s_repl)) + repls_for_ann = [r for r in unique_repls if r in typed_nodes] + + for r in repls_for_ann: + position = s_repl.index(r) + if position > 512: + continue + node_ids.append(r) + embeddings.append(s_emb[position]) + + all_embs = torch.stack(embeddings, dim=0).numpy() + embedder = Embedder(dict(zip(node_ids, range(len(node_ids)))), all_embs) + pickle.dump(embedder, open("codebert_embeddings.pkl", "wb"), fix_imports=False) + print(node_ids) + + +def main(): + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") + model = RobertaModel.from_pretrained("microsoft/codebert-base") + # model.to(device) + args = get_type_prediction_arguments() + + # allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray', + # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', 'Type'} + + train_data, test_data = read_data( + open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, + include_only="entities", + min_entity_count=args.min_entity_count, random_seed=args.random_seed + ) + + trainer = CodeBertModelTrainer(train_data, test_data, params={}, seq_len=512) + trainer.set_type_ann_edges(args.type_ann_edges) + trainer.train_model() + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/SourceCodeTools/nlp/embed/lda.py b/SourceCodeTools/nlp/embed/lda.py index fddbeafa..76f229ae 100644 --- a/SourceCodeTools/nlp/embed/lda.py +++ b/SourceCodeTools/nlp/embed/lda.py @@ -11,7 +11,7 @@ def read_corpus(path, data_field): if path.endswith("bz2") or path.endswith("parquet") or path.endswith("csv"): - from SourceCodeTools.code.data.sourcetrail.file_utils import unpersist + from SourceCodeTools.code.data.file_utils import unpersist data = unpersist(path)[data_field].tolist() elif path.endswith("jsonl"): import json diff --git a/SourceCodeTools/nlp/entity/apply_model.py b/SourceCodeTools/nlp/entity/apply_model.py index d6d75e68..9ddd64bd 100644 --- a/SourceCodeTools/nlp/entity/apply_model.py +++ b/SourceCodeTools/nlp/entity/apply_model.py @@ -4,11 +4,17 @@ import os import pickle +import numpy as np +from tqdm import tqdm + +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' +os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' + import tensorflow as tf from SourceCodeTools.nlp.entity import parse_biluo from SourceCodeTools.nlp.entity.tf_models.tf_model import TypePredictor -from SourceCodeTools.nlp.entity.type_prediction import PythonBatcher, load_pkl_emb +from SourceCodeTools.nlp.entity.type_prediction import PythonBatcher, load_pkl_emb, span_f1 from SourceCodeTools.nlp.entity.utils.data import read_data @@ -42,14 +48,14 @@ def predict_one(model, input, tagmap): return logits_to_annotations( tf.math.argmax(model( token_ids=input['tok_ids'], prefix_ids=input['prefix'], suffix_ids=input['suffix'], - graph_ids=input['graph_ids'], training=False + graph_ids=input['graph_ids'], 
training=False, mask=tf.sequence_mask(input["lens"], input["tok_ids"].shape[1]) ), axis=-1), input['lens'], tagmap ) -def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=None, checkpoint_path=None, batch_size=1): +def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=None, checkpoint_path=None): graph_emb = load_pkl_emb(graph_emb_path) word_emb = load_pkl_emb(word_emb_path) @@ -61,10 +67,11 @@ def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=No seq_len = params.pop("seq_len") suffix_prefix_buckets = params.pop("suffix_prefix_buckets") + batch_size = params.pop("batch_size") data_batcher = Batcher( data, batch_size, seq_len=seq_len, wordmap=word_emb.ind, graphmap=graph_emb.ind, tagmap=tagmap, - mask_unlabeled_declarations=True, class_weights=False, element_hash_size=suffix_prefix_buckets + mask_unlabeled_declarations=True, class_weights=False, element_hash_size=suffix_prefix_buckets, len_sort=True ) params.pop("train_losses") @@ -76,26 +83,96 @@ def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=No params.pop("learning_rate_decay") model = Model( - word_emb, graph_emb, train_embeddings=False, num_classes=len(tagmap), seq_len=seq_len, **params + word_emb, graph_emb, train_embeddings=False, num_classes=len(tagmap), seq_len=seq_len, + suffix_prefix_buckets=suffix_prefix_buckets, **params ) model.load_weights(os.path.join(checkpoint_path, "checkpoint")) all_true = [] all_estimated = [] + true_scoring = [] + pred_scoring = [] - for batch in data_batcher: + for batch_ind, batch in enumerate(data_batcher): true_annotations = logits_to_annotations(batch['tags'], batch['lens'], tagmap) est_annotations = predict_one(model, batch, tagmap) all_true.extend(true_annotations) all_estimated.extend(est_annotations) + for s_ind in range(len(batch["lens"])): + for ent in true_annotations[s_ind]: + true_scoring.append((batch_ind, s_ind, ent)) + for ent in est_annotations[s_ind]: + pred_scoring.append((batch_ind, s_ind, ent)) + + precision, recall, f1 = span_f1(set(pred_scoring), set(true_scoring)) + + with open(os.path.join(args.checkpoint_path, "scores.txt"), "w") as scores_sink: + scores_str = f"Precision: {precision: .2f}, Recall: {recall: .2f}, f1: {f1: .2f}" + print(scores_str) + scores_sink.write(f"{scores_str}\n") + from SourceCodeTools.nlp.entity.entity_render import render_annotations - html = render_annotations(zip(data, all_estimated, all_true)) - with open("render.html", "w") as render: + html = render_annotations(zip(data_batcher.data, all_estimated, all_true)) + with open(os.path.join(args.checkpoint_path, "render.html"), "w") as render: render.write(html) + estimate_confusion(all_estimated, all_true) + + +def estimate_confusion(pred, true): + pred_filtered = [] + true_filtered = [] + for p, t in zip(pred, true): + for e in p: + e_span = e[:2] + for e_t in t: + t_span = e_t[:2] + if e_span == t_span: + pred_filtered.append(e[2]) + true_filtered.append(e_t[2]) + break + + labels = sorted(list(set(true_filtered + pred_filtered))) + label2ind = dict(zip(labels, range(len(labels)))) + + confusion = np.zeros((len(labels), len(labels))) + + for pred, true in zip(pred_filtered, true_filtered): + confusion[label2ind[true], label2ind[pred]] += 1 + + norm = np.array([x if x != 0 else 1. for x in np.sum(confusion, axis=1)]).reshape(-1,1) + confusion /= norm + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(45,45)) + im = ax.imshow(confusion) + + # We want to show all ticks... 
+ ax.set_xticks(np.arange(len(labels))) + ax.set_yticks(np.arange(len(labels))) + # ... and label them with the respective list entries + ax.set_xticklabels(labels) + ax.set_yticklabels(labels) + + # Rotate the tick labels and set their alignment. + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", + rotation_mode="anchor") + + # Loop over data dimensions and create text annotations. + for i in range(len(labels)): + for j in range(len(labels)): + text = ax.text(j, i, f"{confusion[i, j]: .2f}", + ha="center", va="center", color="w") + + ax.set_title("Confusion matrix for Python type prediction") + fig.tight_layout() + # plt.show() + plt.savefig(os.path.join(args.checkpoint_path, "confusion.png")) + if __name__ == "__main__": import argparse @@ -108,16 +185,18 @@ def apply_to_dataset(data, Batcher, Model, graph_emb_path=None, word_emb_path=No parser.add_argument('--word_emb_path', dest='word_emb_path', default=None, help='Path to the file with edges') parser.add_argument('checkpoint_path', default=None, help='') + parser.add_argument('--random_seed', dest='random_seed', default=None, type=int, + help='') args = parser.parse_args() train_data, test_data = read_data( open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, - include_only="entities", - min_entity_count=5 + include_only="entities", random_seed=args.random_seed, + min_entity_count=3 ) apply_to_dataset( test_data, PythonBatcher, TypePredictor, graph_emb_path=args.graph_emb_path, word_emb_path=args.word_emb_path, - checkpoint_path=args.checkpoint_path, batch_size=args.batch_size + checkpoint_path=args.checkpoint_path ) diff --git a/SourceCodeTools/nlp/entity/entity_render.py b/SourceCodeTools/nlp/entity/entity_render.py index 9b708142..8dcf73ed 100644 --- a/SourceCodeTools/nlp/entity/entity_render.py +++ b/SourceCodeTools/nlp/entity/entity_render.py @@ -1,5 +1,5 @@ import spacy -from SourceCodeTools.nlp.entity.util import inject_tokenizer +from SourceCodeTools.nlp import create_tokenizer html_template = """ @@ -57,7 +57,7 @@ def annotate(doc, entities): def render_annotations(annotations): - nlp = inject_tokenizer(spacy.blank("en")) + nlp = create_tokenizer("spacy") entries = "" for annotation in annotations: text, predicted, annotated = annotation diff --git a/SourceCodeTools/nlp/entity/map_args_to_mentions.py b/SourceCodeTools/nlp/entity/map_args_to_mentions.py new file mode 100644 index 00000000..087c1970 --- /dev/null +++ b/SourceCodeTools/nlp/entity/map_args_to_mentions.py @@ -0,0 +1,48 @@ +import argparse +import json +from os.path import join + +from SourceCodeTools.code.data.dataset.Dataset import load_data +from SourceCodeTools.code.data.file_utils import unpersist + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("working_directory") + parser.add_argument("output") + args = parser.parse_args() + + nodes, edges = load_data(join(args.working_directory, "nodes.bz2"), join(args.working_directory, "edges.bz2")) + type_annotated = set(unpersist(join(args.working_directory, "type_annotations.bz2"))["src"].tolist()) + arguments = set(nodes.query("type == 'arg'")["id"].tolist()) + mentions = set(nodes.query("type == 'mention'")["id"].tolist()) + + edges["in_mentions"] = edges["src"].apply(lambda src: src in mentions) + + edges["in_args"] = edges["dst"].apply(lambda dst: dst in arguments) + + edges = edges.query("in_mentions == True and in_args == True") + + mapping = {} + for src, dst in edges[["src", "dst"]].values: + if dst in mapping: + print() + 
mapping[dst] = src + + with open(args.output, "w") as sink: + with open(join(args.working_directory, "function_annotations.jsonl")) as fa: + for line in fa: + entry = json.loads(line) + new_repl = [[s, e, int(mapping.get(r, r))] for s, e, r in entry["replacements"]] + entry["replacements"] = new_repl + + sink.write(f"{json.dumps(entry)}\n") + + + print() + + # pickle.dump(mapping, open(args.output, "wb")) + + +if __name__ == "__main__": + main() diff --git a/SourceCodeTools/nlp/entity/split_dataset.py b/SourceCodeTools/nlp/entity/split_dataset.py new file mode 100644 index 00000000..2aa75596 --- /dev/null +++ b/SourceCodeTools/nlp/entity/split_dataset.py @@ -0,0 +1,34 @@ +import pickle +from collections import Counter + +from SourceCodeTools.nlp.entity.type_prediction import get_type_prediction_arguments +from SourceCodeTools.nlp.entity.utils.data import read_data + + +def get_all_annotations(dataset): + ann = [] + for _, annotations in dataset: + for _, _, e in annotations["entities"]: + ann.append(e) + return ann + + +if __name__ == "__main__": + + args = get_type_prediction_arguments() + + train_data, test_data = read_data( + open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, include_only="entities", + min_entity_count=args.min_entity_count, random_seed=args.random_seed + ) + + pickle.dump(train_data, open("type_prediction_dataset_no_defaults_train.pkl", "wb")) + pickle.dump(test_data, open("type_prediction_dataset_no_defaults_test.pkl", "wb")) + + ent_counts = Counter(get_all_annotations(train_data)) | Counter(get_all_annotations(test_data)) + + with open("type_prediction_dataset_argument_annotations_counts.txt", "w") as sink: + for ent, count in ent_counts.most_common(): + sink.write(f"{ent}\t{count}\n") + + print() \ No newline at end of file diff --git a/SourceCodeTools/nlp/entity/tf_models/params.py b/SourceCodeTools/nlp/entity/tf_models/params.py index f8560eeb..341e578f 100644 --- a/SourceCodeTools/nlp/entity/tf_models/params.py +++ b/SourceCodeTools/nlp/entity/tf_models/params.py @@ -20,22 +20,30 @@ # "suffix_prefix_dims": 20, # "suffix_prefix_buckets": 1000, # }, - { - "h_sizes": [40, 40, 40], - "dense_size": 30, - "pos_emb_size": 30, - "cnn_win_size": 5, - "suffix_prefix_dims": 50, - "suffix_prefix_buckets": 2000, - }, - # { + # { + # "h_sizes": [40, 40, 40], + # "dense_size": 30, + # "pos_emb_size": 30, + # "cnn_win_size": 5, + # "suffix_prefix_dims": 50, + # "suffix_prefix_buckets": 2000, + # }, + # { # "h_sizes": [80, 80, 80], # "dense_size": 40, # "pos_emb_size": 50, # "cnn_win_size": 7, # "suffix_prefix_dims": 70, # "suffix_prefix_buckets": 3000, - # } + # }, + { + "h_sizes": [100, 100, 100], + "dense_size": 60, + "pos_emb_size": 50, + "cnn_win_size": 7, + "suffix_prefix_dims": 70, + "suffix_prefix_buckets": 3000, + } ], "learning_rate": [0.0001], "learning_rate_decay": [0.998] # 0.991 @@ -115,4 +123,4 @@ def flatten(params): att_params = list( map(flatten, ParameterGrid(att_params_grids)) -) \ No newline at end of file +) diff --git a/SourceCodeTools/nlp/entity/tf_models/tf_model.py b/SourceCodeTools/nlp/entity/tf_models/tf_model.py index 4a1ba2b9..100d52ec 100644 --- a/SourceCodeTools/nlp/entity/tf_models/tf_model.py +++ b/SourceCodeTools/nlp/entity/tf_models/tf_model.py @@ -1,17 +1,21 @@ # import tensorflow as tf # import sys # from gensim.models import Word2Vec +import logging +from time import time + import numpy as np # from collections import Counter from scipy.linalg import toeplitz # from gensim.models import 
KeyedVectors -from copy import copy +from copy import copy, deepcopy import tensorflow as tf import tensorflow_addons as tfa from tensorflow.keras.layers import Dense, Conv2D, Flatten, Input, Embedding, concatenate from tensorflow.keras import Model +from tensorflow.keras.layers import Layer # from tensorflow.keras import regularizers # from spacy.gold import offsets_from_biluo_tags @@ -21,194 +25,45 @@ # https://github.com/dhiraa/tener/tree/master/src/tener/models # https://arxiv.org/pdf/1903.07785v1.pdf # https://github.com/tensorflow/models/tree/master/research/cvt_text/model - - -class DefaultEmbedding(Model): - """ - Creates an embedder that provides the default value for the index -1. The default value is a zero-vector - """ - def __init__(self, init_vectors=None, shape=None, trainable=True): - super(DefaultEmbedding, self).__init__() - - if init_vectors is not None: - self.embs = tf.Variable(init_vectors, dtype=tf.float32, - trainable=trainable, name="default_embedder_var") - shape = init_vectors.shape - else: - # TODO - # the default value is no longer constant. need to replace this with a standard embedder - self.embs = tf.Variable(tf.random.uniform(shape=(shape[0], shape[1]), dtype=tf.float32), - name="default_embedder_pad") - # self.pad = tf.zeros(shape=(1, init_vectors.shape[1]), name="default_embedder_pad") - # self.pad = tf.random.uniform(shape=(1, init_vectors.shape[1]), name="default_embedder_pad") - self.pad = tf.Variable(tf.random.uniform(shape=(1, shape[1]), dtype=tf.float32), - name="default_embedder_pad") - - - def __call__(self, ids): - emb_matr = tf.concat([self.embs, self.pad], axis=0) - return tf.nn.embedding_lookup(params=emb_matr, ids=ids) - # return tf.expand_dims(tf.nn.embedding_lookup(params=self.emb_matr, ids=ids), axis=3) - - -class PositionalEncoding(Model): - def __init__(self, seq_len, pos_emb_size): - """ - Create positional embedding matrix for tokens. Currently not using because it results in N^2 computational - complexity. Should move this functionality to batch preparation. 
- :param seq_len: maximum sequence length - :param pos_emb_size: the dimensionality of positional embeddings - """ - super(PositionalEncoding, self).__init__() - - positions = list(range(seq_len * 2)) - position_splt = positions[:seq_len] - position_splt.reverse() - self.position_encoding = tf.constant(toeplitz(position_splt, positions[seq_len:]), - dtype=tf.int32, - name="position_encoding") - # self.position_embedding = tf.random.uniform(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32) - self.position_embedding = tf.Variable(tf.random.uniform(shape=(seq_len * 2, pos_emb_size), dtype=tf.float32), - name="position_embedding") - # self.position_embedding = tf.Variable(name="position_embedding", shape=(seq_len * 2, pos_emb_size), dtype=tf.float32) - - def __call__(self): - # return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup") - return tf.nn.embedding_lookup(self.position_embedding, self.position_encoding, name="position_lookup") - - -class TextCnnLayer(Model): - """ - - """ - def __init__(self, out_dim, kernel_shape, activation=None): - super(TextCnnLayer, self).__init__() - - self.kernel_shape = kernel_shape - self.out_dim = out_dim - - self.textConv = Conv2D(filters=out_dim, kernel_size=kernel_shape, - activation=activation, data_format='channels_last') - - padding_size = (self.kernel_shape[0] - 1) // 2 - assert padding_size * 2 + 1 == self.kernel_shape[0] - self.pad_constant = tf.constant([[0, 0], [padding_size, padding_size], [0, 0], [0, 0]]) - - def __call__(self, x): - padded = tf.pad(x, self.pad_constant) - # emb_sent_exp = tf.expand_dims(input, axis=3) - convolve = self.textConv(padded) - return tf.squeeze(convolve, axis=-2) - - -class TextCnn(Model): - """ - TextCnn model for classifying tokens in a sequence. 
The model uses following pipeline: - - token_embeddings (provided from outside) -> - several convolutional layers, get representations for all tokens -> - pass representation for all tokens through a dense network -> - classify each token - """ - def __init__(self, input_size, h_sizes, seq_len, - pos_emb_size, cnn_win_size, dense_size, num_classes, - activation=None, dense_activation=None, drop_rate=0.2): - """ - - :param input_size: dimensionality of input embeddings - :param h_sizes: sizes of hidden CNN layers, internal dimensionality of token embeddings - :param seq_len: maximum sequence length - :param pos_emb_size: dimensionality of positional embeddings - :param cnn_win_size: width of cnn window - :param dense_size: number of unius in dense network - :param num_classes: number of output units - :param activation: activation for cnn - :param dense_activation: activation for dense layers - :param drop_rate: dropout rate for dense network - """ - super(TextCnn, self).__init__() - - self.seq_len = seq_len - self.h_sizes = h_sizes - self.dense_size = dense_size - self.num_classes = num_classes - - def infer_kernel_sizes(h_sizes): - """ - Compute kernel sizes from the desired dimensionality of hidden layers - :param h_sizes: - :return: - """ - kernel_sizes = copy(h_sizes) - kernel_sizes.pop(-1) # pop last because it is the output of the last CNN layer - kernel_sizes.insert(0, input_size) # the first kernel size should be (cnn_win_size, input_size) - kernel_sizes = [(cnn_win_size, ks) for ks in kernel_sizes] - return kernel_sizes - - kernel_sizes = infer_kernel_sizes(h_sizes) - - self.layers_tok = [ TextCnnLayer(out_dim=h_size, kernel_shape=kernel_size, activation=activation) - for h_size, kernel_size in zip(h_sizes, kernel_sizes)] - - # self.layers_pos = [TextCnnLayer(out_dim=h_size, kernel_shape=(cnn_win_size, pos_emb_size), activation=activation) - # for h_size, _ in zip(h_sizes, kernel_sizes)] - - # self.positional = PositionalEncoding(seq_len=seq_len, pos_emb_size=pos_emb_size) - - if dense_activation is None: - dense_activation = activation - - # self.attention = tfa.layers.MultiHeadAttention(head_size=200, num_heads=1) - - self.dense_1 = Dense(dense_size, activation=dense_activation) - self.dropout_1 = tf.keras.layers.Dropout(rate=drop_rate) - self.dense_2 = Dense(num_classes, activation=None) # logits - self.dropout_2 = tf.keras.layers.Dropout(rate=drop_rate) - - def __call__(self, embs, training=True): - - temp_cnn_emb = embs # shape (?, seq_len, input_size) - - # pass embeddings through several CNN layers - for l in self.layers_tok: - temp_cnn_emb = l(tf.expand_dims(temp_cnn_emb, axis=3)) # shape (?, seq_len, h_size) - - # TODO - # simplify to one CNN and one attention - - # pos_cnn = self.positional() - # for l in self.layers_pos: - # pos_cnn = l(tf.expand_dims(pos_cnn, axis=3)) - # - # cnn_pool_feat = [] - # for i in range(self.seq_len): - # # slice tensor for the line that corresponds to the current position in the sentence - # position_features = tf.expand_dims(pos_cnn[i, ...], axis=0, name="exp_dim_%d" % i) - # # convolution without activation can be combined later, hence: temp_cnn_emb + position_features - # cnn_pool_feat.append( - # tf.expand_dims(tf.nn.tanh(tf.reduce_max(temp_cnn_emb + position_features, axis=1)), axis=1)) - # # cnn_pool_feat.append( - # # tf.expand_dims(tf.nn.tanh(tf.reduce_max(tf.concat([temp_cnn_emb, position_features], axis=-1), axis=1)), axis=1)) - # - # cnn_pool_features = tf.concat(cnn_pool_feat, axis=1) - cnn_pool_features = temp_cnn_emb - - # 
cnn_pool_features = self.attention([cnn_pool_features, cnn_pool_features]) - - # token_features = self.dropout_1( - # tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1])) - # , training=training) - - # reshape before passing through a dense network - token_features = tf.reshape(cnn_pool_features, shape=(-1, self.h_sizes[-1])) # shape (? * seq_len, h_size[-1]) - - # local_h2 = self.dropout_2( - # self.dense_1(token_features) - # , training=training) - local_h2 = self.dense_1(token_features) # shape (? * seq_len, dense_size) - tag_logits = self.dense_2(local_h2) # shape (? * seq_len, num_classes) - - return tf.reshape(tag_logits, (-1, self.seq_len, self.num_classes)) # reshape back, shape (?, seq_len, num_classes) +from tensorflow_addons.layers import MultiHeadAttention +from tqdm import tqdm + +from SourceCodeTools.models.nlp.TFDecoder import ConditionalAttentionDecoder, FlatDecoder +from SourceCodeTools.models.nlp.TFEncoder import DefaultEmbedding, TextCnnEncoder, GRUEncoder + + +class T5Encoder(Model): + def __init__(self): + super(T5Encoder, self).__init__() + from transformers.models.t5.modeling_tf_t5 import T5Config, TFT5MainLayer + self.adapter = Dense(40, activation=tf.nn.relu) + self.t5_config = T5Config(d_model=40, d_ff=40, num_layers=1, num_decoder_layers=1, num_heads=1, is_encoder_decoder=False,d_kv=40) + self.encoder = TFT5MainLayer(self.t5_config,) + + def call(self, embs, training=True, mask=None): + from transformers.modeling_tf_utils import input_processing + inputs = input_processing( + func=self.call, + config=self.t5_config, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=self.adapter(embs), + decoder_inputs_embeds=None, + use_cache=False, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=training, + kwargs_call={}, + ) + encoded = self.encoder(inputs["inputs"], inputs_embeds=inputs["inputs_embeds"], training=training) + return encoded["last_hidden_state"] class TypePredictor(Model): @@ -222,9 +77,10 @@ class TypePredictor(Model): def __init__(self, tok_embedder, graph_embedder, train_embeddings=False, h_sizes=None, dense_size=100, num_classes=None, seq_len=100, pos_emb_size=30, cnn_win_size=3, - crf_transitions=None, suffix_prefix_dims=50, suffix_prefix_buckets=1000): + crf_transitions=None, suffix_prefix_dims=50, suffix_prefix_buckets=1000, + no_graph=False): """ - Initialize TypePredictor. Model initializes embedding layers and then passes embeddings to TextCnn model + Initialize TypePredictor. 
Model initializes embedding layers and then passes embeddings to TextCnnEncoder model :param tok_embedder: Embedder for tokens :param graph_embedder: Embedder for graph nodes :param train_embeddings: whether to finetune embeddings @@ -255,31 +111,38 @@ def __init__(self, tok_embedder, graph_embedder, train_embeddings=False, self.prefix_emb = DefaultEmbedding(shape=(suffix_prefix_buckets, suffix_prefix_dims)) self.suffix_emb = DefaultEmbedding(shape=(suffix_prefix_buckets, suffix_prefix_dims)) - # self.tok_emb = Embedding(input_dim=tok_embedder.e.shape[0], - # output_dim=tok_embedder.e.shape[1], - # weights=tok_embedder.e, trainable=train_embeddings, - # mask_zero=True) - # - # self.graph_emb = Embedding(input_dim=graph_embedder.e.shape[0], - # output_dim=graph_embedder.e.shape[1], - # weights=graph_embedder.e, trainable=train_embeddings, - # mask_zero=True) - # compute final embedding size after concatenation - input_dim = tok_embedder.e.shape[1] + suffix_prefix_dims * 2 + graph_embedder.e.shape[1] - - self.text_cnn = TextCnn(input_size=input_dim, h_sizes=h_sizes, - seq_len=seq_len, pos_emb_size=pos_emb_size, - cnn_win_size=cnn_win_size, dense_size=dense_size, - num_classes=num_classes, activation=tf.nn.relu, - dense_activation=tf.nn.tanh) + input_dim = tok_embedder.e.shape[1] + suffix_prefix_dims * 2 #+ graph_embedder.e.shape[1] + + if cnn_win_size % 2 == 0: + cnn_win_size += 1 + logging.info(f"Window size should be odd. Setting to {cnn_win_size}") + + self.encoder = TextCnnEncoder( + # input_size=input_dim, + h_sizes=h_sizes, + seq_len=seq_len, pos_emb_size=pos_emb_size, + cnn_win_size=cnn_win_size, dense_size=dense_size, + out_dim=input_dim, activation=tf.nn.relu, + dense_activation=tf.nn.tanh) + # self.encoder = GRUEncoder(input_dim=input_dim, out_dim=input_dim, num_layers=1, dropout=0.1) + # self.encoder = T5Encoder() + + # self.decoder = ConditionalAttentionDecoder( + # input_dim, out_dim=num_classes, num_layers=1, num_heads=1, + # ff_hidden=100, target_vocab_size=num_classes, maximum_position_encoding=self.seq_len + # ) + self.decoder = FlatDecoder(out_dims=num_classes) self.crf_transition_params = None + self.supports_masking = True + self.use_graph = not no_graph - def __call__(self, token_ids, prefix_ids, suffix_ids, graph_ids, training=True): + # @tf.function + def __call__(self, token_ids, prefix_ids, suffix_ids, graph_ids, target=None, training=False, mask=None): """ - Inference + # Inference :param token_ids: ids for tokens, shape (?, seq_len) :param prefix_ids: ids for prefixes, shape (?, seq_len) :param suffix_ids: ids for suffixes, shape (?, seq_len) @@ -287,22 +150,35 @@ def __call__(self, token_ids, prefix_ids, suffix_ids, graph_ids, training=True): :param training: whether to finetune embeddings :return: logits for token classes, shape (?, seq_len, num_classes) """ + assert mask is not None, "Mask is required" + tok_emb = self.tok_emb(token_ids) - graph_emb = self.graph_emb(graph_ids) prefix_emb = self.prefix_emb(prefix_ids) suffix_emb = self.suffix_emb(suffix_ids) embs = tf.concat([tok_emb, - graph_emb, + # graph_emb, prefix_emb, suffix_emb], axis=-1) - logits = self.text_cnn(embs, training=training) + encoded = self.encoder(embs, training=training, mask=mask) + # if target is None: + # logits = self.decoder.seq_decode(encoded, training=training, mask=mask) + # else: + if self.use_graph: + graph_emb = self.graph_emb(graph_ids) + encoded = tf.concat([encoded, graph_emb], axis=-1) + logits, _ = self.decoder((encoded, target), training=training, mask=mask) # consider 
sending input instead of target return logits + def compute_mask(self, inputs, mask=None): + mask = self.encoder.compute_mask(None, mask=mask) + return self.decoder.compute_mask(None, mask=mask) - def loss(self, logits, labels, lengths, class_weights=None, extra_mask=None): + + # @tf.function + def loss(self, logits, labels, mask, class_weights=None, extra_mask=None): """ Compute cross-entropy loss for each meaningful tokens. Mask padded tokens. :param logits: shape (?, seq_len, num_classes) @@ -313,7 +189,7 @@ def loss(self, logits, labels, lengths, class_weights=None, extra_mask=None): :return: average cross-entropy loss """ losses = tf.nn.softmax_cross_entropy_with_logits(tf.one_hot(labels, depth=logits.shape[-1]), logits, axis=-1) - seq_mask = tf.sequence_mask(lengths, self.seq_len) + seq_mask = mask # logits._keras_mask# tf.sequence_mask(lengths, self.seq_len) if extra_mask is not None: seq_mask = tf.math.logical_and(seq_mask, extra_mask) if class_weights is None: @@ -329,28 +205,35 @@ def loss(self, logits, labels, lengths, class_weights=None, extra_mask=None): return loss - def score(self, logits, labels, lengths, scorer=None, extra_mask=None): + def score(self, logits, labels, mask, scorer=None, extra_mask=None): """ Compute precision, recall and f1 scores using the provided scorer function :param logits: shape (?, seq_len, num_classes) :param labels: ids of token labels, shape (?, seq_len) :param lengths: tensor of actual sentence lengths, shape (?,) - :param scorer: scorer function, takes `pred_labels` and `true_labels` as aguments + :param scorer: scorer function, takes `pred_labels` and `true_labels` as arguments :param extra_mask: mask for hiding some of the token labels, not counting them towards the score, shape (?, seq_len) :return: """ - mask = tf.sequence_mask(lengths, self.seq_len) + # mask = logits._keras_mask # tf.sequence_mask(lengths, self.seq_len) if extra_mask is not None: mask = tf.math.logical_and(mask, extra_mask) true_labels = tf.boolean_mask(labels, mask) argmax = tf.math.argmax(logits, axis=-1) estimated_labels = tf.cast(tf.boolean_mask(argmax, mask), tf.int32) - p, r, f1 = scorer(estimated_labels.numpy(), true_labels.numpy()) + p, r, f1 = scorer(to_numpy(estimated_labels), to_numpy(true_labels)) return p, r, f1 +def to_numpy(tensor): + if hasattr(tensor, "numpy"): + return tensor.numpy() + else: + return tf.make_ndarray(tf.make_tensor_proto(tensor)) + + # def estimate_crf_transitions(batches, n_tags): # transitions = [] # for _, _, labels, lengths in batches: @@ -359,7 +242,7 @@ def score(self, logits, labels, lengths, scorer=None, extra_mask=None): # # return np.stack(transitions, axis=0).mean(axis=0) - +# @tf.function def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None, finetune=False): """ @@ -379,9 +262,11 @@ def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids, :return: values for loss, precision, recall and f1-score """ with tf.GradientTape() as tape: - logits = model(token_ids, prefix, suffix, graph_ids, training=True) - loss = model.loss(logits, labels, lengths, class_weights=class_weights, extra_mask=extra_mask) - p, r, f1 = model.score(logits, labels, lengths, scorer=scorer, extra_mask=extra_mask) + seq_mask = tf.sequence_mask(lengths, token_ids.shape[1]) + logits = model(token_ids, prefix, suffix, graph_ids, target=None, training=True, mask=seq_mask) + loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, 
extra_mask=extra_mask) + # token_acc = tf.reduce_sum(tf.cast(tf.argmax(logits, axis=-1) == labels, tf.float32)) / (token_ids.shape[0] * token_ids.shape[1]) + p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask) gradients = tape.gradient(loss, model.trainable_variables) if not finetune: # do not update embeddings @@ -393,7 +278,10 @@ def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids, return loss, p, r, f1 # @tf.function -def test_step(model, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None): +def test_step( + model, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None, + no_localization=False +): """ :param model: TypePrediction model instance @@ -408,14 +296,44 @@ def test_step(model, token_ids, prefix, suffix, graph_ids, labels, lengths, extr :param scorer: scorer function, takes `pred_labels` and `true_labels` as aguments :return: values for loss, precision, recall and f1-score """ - logits = model(token_ids, prefix, suffix, graph_ids, training=False) - loss = model.loss(logits, labels, lengths, class_weights=class_weights, extra_mask=extra_mask) - p, r, f1 = model.score(logits, labels, lengths, scorer=scorer, extra_mask=extra_mask) + seq_mask = tf.sequence_mask(lengths, token_ids.shape[1]) + logits = model(token_ids, prefix, suffix, graph_ids, target=None, training=False, mask=seq_mask) + loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask) + p, r, f1 = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask) return loss, p, r, f1 -def train(model, train_batches, test_batches, epochs, report_every=10, scorer=None, learning_rate=0.01, learning_rate_decay=1., finetune=False, summary_writer=None): +def test(model, test_batches, scorer=None): + test_alosses = [] + test_aps = [] + test_ars = [] + test_af1s = [] + + for ind, batch in enumerate(test_batches): + # token_ids, graph_ids, labels, class_weights, lengths = b + test_loss, test_p, test_r, test_f1 = test_step( + model=model, token_ids=batch['tok_ids'], + prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'], + labels=batch['tags'], lengths=batch['lens'], extra_mask=batch['hide_mask'], + # class_weights=batch['class_weights'], + scorer=scorer + ) + + test_alosses.append(test_loss) + test_aps.append(test_p) + test_ars.append(test_r) + test_af1s.append(test_f1) + + def avg(arr): + return sum(arr) / len(arr) + + return avg(test_alosses), avg(test_aps), avg(test_ars), avg(test_af1s) + +def train( + model, train_batches, test_batches, epochs, report_every=10, scorer=None, learning_rate=0.01, + learning_rate_decay=1., finetune=False, summary_writer=None, save_ckpt_fn=None, no_localization=False +): assert summary_writer is not None @@ -430,6 +348,8 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No num_train_batches = len(train_batches) num_test_batches = len(test_batches) + best_f1 = 0. 
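# Editor's note, not part of the original patch: the TF train/test steps above now derive the
# padding mask from the per-batch lengths with tf.sequence_mask instead of a fixed seq_len.
# A minimal standalone sketch of what that call produces (example values are illustrative only):
import tensorflow as tf

lens = tf.constant([3, 1])        # true lengths of two padded sequences
mask = tf.sequence_mask(lens, 4)  # boolean mask over a batch padded to length 4
# mask -> [[True, True, True, False],
#          [True, False, False, False]]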
+ try: with summary_writer.as_default(): @@ -440,12 +360,15 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No rs = [] f1s = [] - for ind, batch in enumerate(train_batches): + start = time() + + for ind, batch in enumerate(tqdm(train_batches)): # token_ids, graph_ids, labels, class_weights, lengths = b loss, p, r, f1 = train_step_finetune( model=model, optimizer=optimizer, token_ids=batch['tok_ids'], prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'], - labels=batch['tags'], lengths=batch['lens'], extra_mask=batch['hide_mask'], + labels=batch['tags'], lengths=batch['lens'], + extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'], # class_weights=batch['class_weights'], scorer=scorer, finetune=finetune and e/epochs > 0.6 ) @@ -459,12 +382,18 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No tf.summary.scalar("Recall/Train", r, step=e * num_train_batches + ind) tf.summary.scalar("F1/Train", f1, step=e * num_train_batches + ind) + test_alosses = [] + test_aps = [] + test_ars = [] + test_af1s = [] + for ind, batch in enumerate(test_batches): # token_ids, graph_ids, labels, class_weights, lengths = b test_loss, test_p, test_r, test_f1 = test_step( model=model, token_ids=batch['tok_ids'], prefix=batch['prefix'], suffix=batch['suffix'], graph_ids=batch['graph_ids'], - labels=batch['tags'], lengths=batch['lens'], extra_mask=batch['hide_mask'], + labels=batch['tags'], lengths=batch['lens'], + extra_mask=batch['no_loc_mask'] if no_localization else batch['hide_mask'], # class_weights=batch['class_weights'], scorer=scorer ) @@ -473,14 +402,24 @@ def train(model, train_batches, test_batches, epochs, report_every=10, scorer=No tf.summary.scalar("Precision/Test", test_p, step=e * num_test_batches + ind) tf.summary.scalar("Recall/Test", test_r, step=e * num_test_batches + ind) tf.summary.scalar("F1/Test", test_f1, step=e * num_test_batches + ind) + test_alosses.append(test_loss) + test_aps.append(test_p) + test_ars.append(test_r) + test_af1s.append(test_f1) - print(f"Epoch: {e}, Train Loss: {sum(losses) / len(losses)}, Train P: {sum(ps) / len(ps)}, Train R: {sum(rs) / len(rs)}, Train F1: {sum(f1s) / len(f1s)}, " - f"Test loss: {test_loss}, Test P: {test_p}, Test R: {test_r}, Test F1: {test_f1}") + epoch_time = time() - start train_losses.append(float(sum(losses) / len(losses))) train_f1s.append(float(sum(f1s) / len(f1s))) - test_losses.append(float(test_loss)) - test_f1s.append(float(test_f1)) + test_losses.append(float(sum(test_alosses) / len(test_alosses))) + test_f1s.append(float(sum(test_af1s) / len(test_af1s))) + + print(f"Epoch: {e}, {epoch_time: .2f} s, Train Loss: {train_losses[-1]: .4f}, Train P: {sum(ps) / len(ps): .4f}, Train R: {sum(rs) / len(rs): .4f}, Train F1: {sum(f1s) / len(f1s): .4f}, " + f"Test loss: {test_losses[-1]: .4f}, Test P: {sum(test_aps) / len(test_aps): .4f}, Test R: {sum(test_ars) / len(test_ars): .4f}, Test F1: {test_f1s[-1]: .4f}") + + if save_ckpt_fn is not None and float(test_f1s[-1]) > best_f1: + save_ckpt_fn() + best_f1 = float(test_f1s[-1]) lr.assign(lr * learning_rate_decay) diff --git a/SourceCodeTools/nlp/entity/type_prediction.py b/SourceCodeTools/nlp/entity/type_prediction.py index b15851b6..3394d0ef 100644 --- a/SourceCodeTools/nlp/entity/type_prediction.py +++ b/SourceCodeTools/nlp/entity/type_prediction.py @@ -1,25 +1,32 @@ from __future__ import unicode_literals, print_function import json +import logging import os import pickle from copy import 
copy from datetime import datetime +# os.environ['CUDA_VISIBLE_DEVICES'] = '-1' +os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' + +from pathlib import Path + import tensorflow from SourceCodeTools.nlp.batchers import PythonBatcher from SourceCodeTools.nlp.entity import parse_biluo from SourceCodeTools.nlp.entity.tf_models.params import cnn_params -from SourceCodeTools.nlp.entity.utils import get_unique_entities +from SourceCodeTools.nlp.entity.utils import get_unique_entities, overlap from SourceCodeTools.nlp.entity.utils.data import read_data def load_pkl_emb(path): """ - - :param path: - :return: + Load graph embeddings from a pickle file. Embeddings are stored in class Embedder or in a list of Embedders. The + last embedder in the list is returned. + :param path: path to graph embeddings stored as Embedder pickle + :return: Embedder object """ embedder = pickle.load(open(path, "rb")) if isinstance(embedder, list): @@ -27,7 +34,41 @@ return embedder -def scorer(pred, labels, tagmap, eps=1e-8): +def compute_precision_recall_f1(tp, fp, fn, eps=1e-8): + precision = tp / (tp + fp + eps) + recall = tp / (tp + fn + eps) + f1 = 2 * precision * recall / (precision + recall + eps) + return precision, recall, f1 + + +def localized_f1(pred_spans, true_spans, eps=1e-8): + + tp = 0. + fp = 0. + fn = 0. + + for pred, true in zip(pred_spans, true_spans): + if true != "O": + if true == pred: + tp += 1 + else: + if pred == "O": + fn += 1 + else: + fp += 1 + + return compute_precision_recall_f1(tp, fp, fn) + + +def span_f1(pred_spans, true_spans, eps=1e-8): + tp = len(pred_spans.intersection(true_spans)) + fp = len(pred_spans - true_spans) + fn = len(true_spans - pred_spans) + + return compute_precision_recall_f1(tp, fp, fn) + + +def scorer(pred, labels, tagmap, no_localization=False, eps=1e-8): """ Compute f1 score, precision, and recall from BILUO labels :param pred: predicted BILUO labels @@ -42,16 +83,13 @@ pred_biluo = [tagmap.inverse(p) for p in pred] labels_biluo = [tagmap.inverse(p) for p in labels] - pred_spans = set(parse_biluo(pred_biluo)) - true_spans = set(parse_biluo(labels_biluo)) + if not no_localization: + pred_spans = set(parse_biluo(pred_biluo)) + true_spans = set(parse_biluo(labels_biluo)) - tp = len(pred_spans.intersection(true_spans)) - fp = len(pred_spans - true_spans) - fn = len(true_spans - pred_spans) - - precision = tp / (tp + fp + eps) - recall = tp / (tp + fn + eps) - f1 = 2 * precision * recall / (precision + recall + eps) + precision, recall, f1 = span_f1(pred_spans, true_spans, eps=eps) + else: + precision, recall, f1 = localized_f1(pred_biluo, labels_biluo, eps=eps) return precision, recall, f1 @@ -74,7 +112,8 @@ def write_config(trial_dir, params, extra_params=None): class ModelTrainer: def __init__(self, train_data, test_data, params, graph_emb_path=None, word_emb_path=None, - output_dir=None, epochs=30, batch_size=32, seq_len=100, finetune=False, trials=1): + output_dir=None, epochs=30, batch_size=32, seq_len=100, finetune=False, trials=1, + no_localization=False, ckpt_path=None, no_graph=False): self.set_batcher_class() self.set_model_class() @@ -89,6 +128,9 @@ def __init__(self, train_data, test_data, params, graph_emb_path=None, word_emb_ self.finetune = finetune self.trials = trials self.seq_len = seq_len + self.no_localization = no_localization + self.ckpt_path = ckpt_path + self.no_graph = no_graph def set_batcher_class(self): self.batcher = PythonBatcher @@ -97,16 +139,23 @@ def 
set_model_class(self): from SourceCodeTools.nlp.entity.tf_models.tf_model import TypePredictor self.model = TypePredictor - def get_batcher(self, *args, **kwards): - return self.batcher(*args, **kwards) + def get_batcher(self, *args, **kwargs): + return self.batcher(*args, **kwargs) def get_model(self, *args, **kwargs): - return self.model(*args, **kwargs) + model = self.model(*args, **kwargs) + if self.ckpt_path is not None: + model.load_weights(os.path.join(self.ckpt_path, "checkpoint")) + return model def train(self, *args, **kwargs): from SourceCodeTools.nlp.entity.tf_models.tf_model import train return train(*args, summary_writer=self.summary_writer, **kwargs) + def test(self, *args, **kwargs): + from SourceCodeTools.nlp.entity.tf_models.tf_model import test + return test(*args, **kwargs) + def create_summary_writer(self, path): self.summary_writer = tensorflow.summary.create_file_writer(path) @@ -114,48 +163,69 @@ def create_summary_writer(self, path): # with self.summary_writer.as_default(): # tensorflow.summary.scalar(value_name, value, step=step) - def train_model(self): + def get_dataloaders(self, word_emb, graph_emb, suffix_prefix_buckets): - graph_emb = load_pkl_emb(self.graph_emb_path) - word_emb = load_pkl_emb(self.word_emb_path) - - suffix_prefix_buckets = params.pop("suffix_prefix_buckets") + if self.ckpt_path is not None: + tagmap = pickle.load(open(os.path.join(self.ckpt_path, "tag_types.pkl"), "rb")) + else: + tagmap = None train_batcher = self.get_batcher( - train_data, self.batch_size, seq_len=self.seq_len, graphmap=graph_emb.ind, wordmap=word_emb.ind, tagmap=None, - class_weights=False, element_hash_size=suffix_prefix_buckets + self.train_data, self.batch_size, seq_len=self.seq_len, + graphmap=graph_emb.ind if graph_emb is not None else None, + wordmap=word_emb.ind, tagmap=tagmap, + class_weights=False, element_hash_size=suffix_prefix_buckets, no_localization=self.no_localization ) test_batcher = self.get_batcher( - test_data, self.batch_size, seq_len=self.seq_len, graphmap=graph_emb.ind, wordmap=word_emb.ind, + self.test_data, self.batch_size, seq_len=self.seq_len, + graphmap=graph_emb.ind if graph_emb is not None else None, + wordmap=word_emb.ind, tagmap=train_batcher.tagmap, # use the same mapping - class_weights=False, element_hash_size=suffix_prefix_buckets # class_weights are not used for testing + class_weights=False, element_hash_size=suffix_prefix_buckets, # class_weights are not used for testing + no_localization=self.no_localization ) + return train_batcher, test_batcher + + def train_model(self): + + graph_emb = load_pkl_emb(self.graph_emb_path) if self.graph_emb_path is not None else None + word_emb = load_pkl_emb(self.word_emb_path) + + suffix_prefix_buckets = params.pop("suffix_prefix_buckets") + + train_batcher, test_batcher = self.get_dataloaders(word_emb, graph_emb, suffix_prefix_buckets) print(f"\n\n{params}") lr = params.pop("learning_rate") lr_decay = params.pop("learning_rate_decay") - param_dir = os.path.join(output_dir, str(datetime.now())) + timestamp = str(datetime.now()).replace(":","-").replace(" ","_") + param_dir = os.path.join(output_dir, timestamp) os.mkdir(param_dir) for trial_ind in range(self.trials): trial_dir = os.path.join(param_dir, repr(trial_ind)) + logging.info(f"Running trial: {timestamp}") os.mkdir(trial_dir) self.create_summary_writer(trial_dir) model = self.get_model( word_emb, graph_emb, train_embeddings=self.finetune, suffix_prefix_buckets=suffix_prefix_buckets, - num_classes=train_batcher.num_classes(), 
seq_len=self.seq_len, **params + num_classes=train_batcher.num_classes(), seq_len=self.seq_len, no_graph=self.no_graph, **params ) + def save_ckpt_fn(): + checkpoint_path = os.path.join(trial_dir, "checkpoint") + model.save_weights(checkpoint_path) + train_losses, train_f1, test_losses, test_f1 = self.train( model=model, train_batches=train_batcher, test_batches=test_batcher, - epochs=self.epochs, learning_rate=lr, scorer=lambda pred, true: scorer(pred, true, train_batcher.tagmap), - learning_rate_decay=lr_decay, finetune=self.finetune + epochs=self.epochs, learning_rate=lr, scorer=lambda pred, true: scorer(pred, true, train_batcher.tagmap, no_localization=self.no_localization), + learning_rate_decay=lr_decay, finetune=self.finetune, save_ckpt_fn=save_ckpt_fn, no_localization=self.no_localization ) - checkpoint_path = os.path.join(trial_dir, "checkpoint") - model.save_weights(checkpoint_path) + # checkpoint_path = os.path.join(trial_dir, "checkpoint") + # model.save_weights(checkpoint_path) metadata = { "train_losses": train_losses, @@ -166,9 +236,13 @@ def train_model(self): "learning_rate_decay": lr_decay, "epochs": self.epochs, "suffix_prefix_buckets": suffix_prefix_buckets, - "seq_len": self.seq_len + "seq_len": self.seq_len, + "batch_size": self.batch_size, + "no_localization": self.no_localization } + print("Maximum f1:", max(test_f1)) + # write_config(trial_dir, params, extra_params={"suffix_prefix_buckets": suffix_prefix_buckets, "seq_len": seq_len}) metadata.update(params) @@ -184,11 +258,13 @@ def get_type_prediction_arguments(): parser = argparse.ArgumentParser(description='Process some integers.') parser.add_argument('--data_path', dest='data_path', default=None, - help='Path to the file with nodes') + help='Path to the dataset file') parser.add_argument('--graph_emb_path', dest='graph_emb_path', default=None, - help='Path to the file with edges') + help='Path to the file with graph embeddings') parser.add_argument('--word_emb_path', dest='word_emb_path', default=None, - help='Path to the file with edges') + help='Path to the file with token embeddings') + parser.add_argument('--type_ann_edges', dest='type_ann_edges', default=None, + help='Path to type annotation edges') parser.add_argument('--learning_rate', dest='learning_rate', default=0.01, type=float, help='') parser.add_argument('--learning_rate_decay', dest='learning_rate_decay', default=1.0, type=float, @@ -199,17 +275,34 @@ def get_type_prediction_arguments(): help='') parser.add_argument('--max_seq_len', dest='max_seq_len', default=100, type=int, help='') - # parser.add_argument('--pretrain_phase', dest='pretrain_phase', default=20, type=int, - # help='') + parser.add_argument('--min_entity_count', dest='min_entity_count', default=3, type=int, + help='') + parser.add_argument('--pretraining_epochs', dest='pretraining_epochs', default=0, type=int, + help='') + parser.add_argument('--ckpt_path', dest='ckpt_path', default=None, type=str, + help='') parser.add_argument('--epochs', dest='epochs', default=500, type=int, help='') parser.add_argument('--trials', dest='trials', default=1, type=int, help='') + parser.add_argument('--gpu', dest='gpu', default=-1, type=int, + help='Does not work with Tensorflow backend') parser.add_argument('--finetune', action='store_true') + parser.add_argument('--no_localization', action='store_true') + parser.add_argument('--restrict_allowed', action='store_true', default=False) + parser.add_argument('--no_graph', action='store_true', default=False) parser.add_argument('model_output', 
help='') args = parser.parse_args() + + if args.finetune is False and args.pretraining_epochs > 0: + logging.info(f"Finetuning is disabled, but the the number of pretraining epochs is {args.pretraining_epochs}. Setting pretraining epochs to 0.") + args.pretraining_epochs = 0 + + if args.graph_emb_path is not None and not os.path.isfile(args.graph_emb_path): + logging.warning(f"File with graph embeddings does not exist: {args.graph_emb_path}") + args.graph_emb_path = None return args @@ -219,6 +312,22 @@ def save_entities(path, entities): entitiesfile.write(f"{e}\n") +def filter_labels(dataset, allowed=None, field=None): + if allowed is None: + return dataset + dataset = copy(dataset) + for sent, annotations in dataset: + annotations["entities"] = [e for e in annotations["entities"] if e[2] in allowed] + return dataset + + +def find_example(dataset, needed_label): + for sent, annotations in dataset: + for start, end, e in annotations["entities"]: + if e == needed_label: + print(f"{sent}: {sent[start: end]}") + + if __name__ == "__main__": args = get_type_prediction_arguments() @@ -228,10 +337,27 @@ def save_entities(path, entities): # allowed = {'str', 'bool', 'Optional', 'None', 'int', 'Any', 'Union', 'List', 'Dict', 'Callable', 'ndarray', # 'FrameOrSeries', 'bytes', 'DataFrame', 'Matcher', 'float', 'Tuple', 'bool_t', 'Description', 'Type'} - - train_data, test_data = read_data( - open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, include_only="entities", - min_entity_count=1, random_seed=args.random_seed + if args.restrict_allowed: + allowed = { + 'str', 'Optional', 'int', 'Any', 'Union', 'bool', 'Callable', 'Dict', 'bytes', 'float', 'Description', + 'List', 'Sequence', 'Namespace', 'T', 'Type', 'object', 'HTTPServerRequest', 'Future', "Matcher" + } + else: + allowed = None + + # train_data, test_data = read_data( + # open(args.data_path, "r").readlines(), normalize=True, allowed=None, include_replacements=True, include_only="entities", + # min_entity_count=args.min_entity_count, random_seed=args.random_seed + # ) + + dataset_dir = Path(args.data_path).parent + train_data = filter_labels( + pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_train.pkl"), "rb")), + allowed=allowed + ) + test_data = filter_labels( + pickle.load(open(dataset_dir.joinpath("type_prediction_dataset_no_defaults_test.pkl"), "rb")), + allowed=allowed ) unique_entities = get_unique_entities(train_data, field="entities") @@ -241,6 +367,7 @@ def save_entities(path, entities): trainer = ModelTrainer( train_data, test_data, params, graph_emb_path=args.graph_emb_path, word_emb_path=args.word_emb_path, output_dir=output_dir, epochs=args.epochs, batch_size=args.batch_size, - finetune=args.finetune, trials=args.trials, seq_len=args.max_seq_len, + finetune=args.finetune, trials=args.trials, seq_len=args.max_seq_len, no_localization=args.no_localization, + ckpt_path=args.ckpt_path, no_graph=args.no_graph ) trainer.train_model() diff --git a/SourceCodeTools/nlp/entity/utils/__init__.py b/SourceCodeTools/nlp/entity/utils/__init__.py index 52df6d10..8cf742ca 100644 --- a/SourceCodeTools/nlp/entity/utils/__init__.py +++ b/SourceCodeTools/nlp/entity/utils/__init__.py @@ -24,7 +24,7 @@ def normalize_entities(entities): - norm = lambda x: x.split("[")[0].split(".")[-1] + norm = lambda x: x.strip("\"").strip("'").split("[")[0].split(".")[-1] if len(entities) == 0: return entities diff --git a/SourceCodeTools/nlp/entity/utils/data.py 
b/SourceCodeTools/nlp/entity/utils/data.py index 8ef6652c..8135b607 100644 --- a/SourceCodeTools/nlp/entity/utils/data.py +++ b/SourceCodeTools/nlp/entity/utils/data.py @@ -55,7 +55,7 @@ def read_data( logging.info("Splitting dataset randomly") else: random.seed(random_seed) - logging.warning(f"Using ransom seed {random_seed} for dataset split") + logging.warning(f"Using random seed {random_seed} for dataset split") filter_infrequent( train_data, entities_in_dataset=Counter(entities_in_dataset), diff --git a/SourceCodeTools/nlp/entity/utils/visualize_dataset.py b/SourceCodeTools/nlp/entity/utils/visualize_dataset.py new file mode 100644 index 00000000..7098156e --- /dev/null +++ b/SourceCodeTools/nlp/entity/utils/visualize_dataset.py @@ -0,0 +1,36 @@ +import os.path + +import spacy +import sys +import json +from spacy.gold import biluo_tags_from_offsets +from spacy.tokenizer import Tokenizer +import re + +from SourceCodeTools.nlp import create_tokenizer +from SourceCodeTools.nlp.entity import parse_biluo +from SourceCodeTools.nlp.entity.entity_render import render_annotations + +annotations_path = sys.argv[1] +output_path = os.path.dirname(annotations_path) + +data = [] +references = [] +results = [] + +nlp = create_tokenizer("spacy") + +with open(annotations_path) as annotations: + for line in annotations: + entry = json.loads(line.strip()) + + doc = nlp(entry['text']) + tags_r = parse_biluo(biluo_tags_from_offsets(doc, entry['replacements'])) + tags_e = parse_biluo(biluo_tags_from_offsets(doc, entry['ents'])) + + data.append([entry['text']]) + references.append(tags_r) + results.append(tags_e) + +html = render_annotations(zip(data, references, results)) +open(os.path.join(output_path, "annotations.html"), "w").write(html) \ No newline at end of file diff --git a/SourceCodeTools/nlp/generation/data.py b/SourceCodeTools/nlp/generation/data.py new file mode 100644 index 00000000..9f8f0c68 --- /dev/null +++ b/SourceCodeTools/nlp/generation/data.py @@ -0,0 +1 @@ +# dataloader for documentation generation \ No newline at end of file diff --git a/SourceCodeTools/nlp/tokenizers.py b/SourceCodeTools/nlp/tokenizers.py index e20fa9e7..40084d1c 100644 --- a/SourceCodeTools/nlp/tokenizers.py +++ b/SourceCodeTools/nlp/tokenizers.py @@ -75,9 +75,53 @@ def default_tokenizer(text): from SourceCodeTools.nlp.embed.bpe import load_bpe_model, make_tokenizer return make_tokenizer(load_bpe_model(bpe_path)) + elif type == "codebert": + from transformers import RobertaTokenizer + import spacy + from spacy.tokens import Doc + + tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") + nlp = spacy.blank("en") + + def tokenize(text): + tokens = tokenizer.tokenize(text) + doc = Doc(nlp.vocab, tokens, spaces=[False] * len(tokens)) + return doc + + return tokenize else: raise Exception("Supported tokenizer types: spacy, regex, bpe") + +def codebert_to_spacy(tokens): + backup_tokens = tokens + fixed_spaces = [False] + fixed_words = [""] + + for ind, t in enumerate(tokens): + if len(t.text) > 1: + fixed_words.append(t.text.strip("Ġ")) + else: + fixed_words.append(t.text) + if ind != 0: + fixed_spaces.append(t.text.startswith("Ġ") and len(t.text) > 1) + fixed_spaces.append(False) + fixed_spaces.append(False) + fixed_words.append("") + + assert len(fixed_spaces) == len(fixed_words) + + from spacy.tokens import Doc + import spacy + doc = Doc(spacy.blank("en").vocab, fixed_words, fixed_spaces) + + assert len(doc) - 2 == len(backup_tokens) + assert len(doc.text) - 7 == len(backup_tokens.text) + + adjustment = 
-3 + # spans = [adjust_offsets(sp, -3) for sp in spans] + return doc, adjustment + # import tokenize # from io import BytesIO # tokenize.tokenize(BytesIO(s.encode('utf-8')).readline) \ No newline at end of file diff --git a/SourceCodeTools/tabular/common.py b/SourceCodeTools/tabular/common.py index 771945fa..ec4a74ac 100644 --- a/SourceCodeTools/tabular/common.py +++ b/SourceCodeTools/tabular/common.py @@ -2,6 +2,13 @@ def compact_property(values, return_order=False, index_from_one=False): + """ + Returns a map from the original value to the compact index + :param values: + :param return_order: + :param index_from_one: + :return: + """ uniq = numpy.unique(values) if index_from_one: index = range(1, uniq.size + 1) @@ -12,5 +19,5 @@ def compact_property(values, return_order=False, index_from_one=False): inv_index = uniq.tolist() if index_from_one: inv_index.insert(0, "NA") - return prop2pid, + return prop2pid, inv_index return prop2pid \ No newline at end of file diff --git a/docker/preprocess_with_sourcetrail/Dockerfile b/docker/preprocess_with_sourcetrail/Dockerfile new file mode 100644 index 00000000..c673d08b --- /dev/null +++ b/docker/preprocess_with_sourcetrail/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.8-bullseye + +WORKDIR /usr/src/app + +RUN wget https://github.com/CoatiSoftware/Sourcetrail/releases/download/2020.1.117/Sourcetrail_2020_1_117_Linux_64bit.tar.gz +RUN tar -zxvf Sourcetrail_2020_1_117_Linux_64bit.tar.gz +ENV PATH="/usr/src/app/Sourcetrail":$PATH +RUN export LD_LIBRARY_PATH=/usr/src/app/Sourcetrail/lib:$LD_LIBRARY_PATH +ENV APP_PATH="/usr/src/app" + +RUN python -m venv python_env +RUN pip install pandas==1.1.1 +RUN apt-get update -y +RUN apt-get install -y sqlite3 +RUN apt-get install -y libglx0 + +COPY sourcetrail_verify_files.py . +COPY process_folders.sh . + +CMD ["bash", "process_folders.sh", "python_env", "/dataset"] \ No newline at end of file diff --git a/docker/preprocess_with_sourcetrail/process_folders.sh b/docker/preprocess_with_sourcetrail/process_folders.sh new file mode 100644 index 00000000..1355cc2d --- /dev/null +++ b/docker/preprocess_with_sourcetrail/process_folders.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +# this script processes arbitrary code with sourcetrails, no dependencies are pulled +PYTHON_ENV=$(realpath $1) +DATA_PATH=$2 +#RUN_DIR=$(realpath "$(dirname "$0")") +RUN_DIR=$APP_PATH +VERIFIER_PATH="$RUN_DIR/sourcetrail_verify_files.py" + +create_sourcetrail_project_if_not_exist () { + PYTHON_ENV=$1 + PRJ_PATH=$2 + if [ ! -f $PRJ_PATH ]; then + echo "Creating Sourcetrail project for $repo" + echo " + + + + Python Source Group + $PYTHON_ENV + + .py + + + . + + enabled + Python Source Group + + + 8 +" > $PRJ_PATH + fi +} + + +run_indexer () { + repo=$1 + if [ ! -f "$repo/sourcetrail.log" ]; then + run_indexing=true + else + find_edges=$(cat "$repo/sourcetrail.log" | grep " Edges") + if [ -z "$find_edges" ]; then + echo "Indexing was interrupted, recovering..." 
+ run_indexing=true + else + run_indexing=false + fi + fi + + if $run_indexing; then + echo "Begin indexing" + Sourcetrail.sh index -i $repo/$repo.srctrlprj >> $repo/sourcetrail.log + else + echo "Already indexed" + fi +} + +#echo "Running from $RUN_DIR" +#echo "Python env path $PYTHON_ENV" +#echo "Data path $DATA_PATH" +#echo "Verifier path $VERIFIER_PATH" + +for repo_dir in $DATA_PATH/* +do +# echo $repo_dir +# repo=$DATA_PATH/$folder + if [ -d "$repo_dir" ] + then + repo="$(basename $repo_dir)" + echo $repo + cd $DATA_PATH + create_sourcetrail_project_if_not_exist $PYTHON_ENV "$repo/$repo.srctrlprj" + run_indexer "$repo" + + if [ -f "$repo/$repo.srctrldb" ]; then + cd "$repo" + echo $(pwd) + echo ".headers on +.mode csv +.output edges.csv +SELECT * FROM edge; +.output nodes.csv +SELECT * FROM node; +.output element_component.csv +SELECT * FROM element_component; +.output source_location.csv +SELECT * FROM source_location; +.output occurrence.csv +SELECT * FROM occurrence; +.output filecontent.csv +SELECT * FROM filecontent; +.quit" | sqlite3 "$repo.srctrldb" + python $VERIFIER_PATH . + cd .. + else + echo "Package not indexed" + fi + + cd $RUN_DIR + fi +done diff --git a/docker/preprocess_with_sourcetrail/sourcetrail_verify_files.py b/docker/preprocess_with_sourcetrail/sourcetrail_verify_files.py new file mode 100644 index 00000000..622ac79f --- /dev/null +++ b/docker/preprocess_with_sourcetrail/sourcetrail_verify_files.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +import pandas as pd +import sys, os + +def verify_files(working_dir): + + fileheaders = { + "edges.csv": "id,type,source_node_id,target_node_id\n", + "nodes.csv": "id,type,serialized_name\n", + "element_component.csv": "id,element_id,type,data\n", + "source_location.csv": "id,file_node_id,start_line,start_column,end_line,end_column,type\n", + "occurrence.csv": "element_id,source_location_id\n", + "filecontent.csv": "id,content\n" + } + + for filename, header in fileheaders.items(): + file_path = os.path.join(working_dir, filename) + try: + pd.read_csv(file_path) + except pd.errors.EmptyDataError: + with open(file_path, "w") as sink: + sink.write(header) + + +if __name__ == "__main__": + working_dir = sys.argv[1] + verify_files(working_dir) \ No newline at end of file diff --git a/examples/Node Classification 2nd part.ipynb b/examples/Node Classification 2nd part.ipynb new file mode 100644 index 00000000..a1790f6b --- /dev/null +++ b/examples/Node Classification 2nd part.ipynb @@ -0,0 +1,408 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%load_ext tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "from random import random\n", + "\n", + "from SourceCodeTools.models.training_config import get_config, save_config, load_config\n", + "from SourceCodeTools.code.data.dataset.Dataset import SourceGraphDataset, filter_dst_by_freq\n", + "from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure, SamplingMultitaskTrainer\n", + "from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeClassifierObjective\n", + "from SourceCodeTools.models.graph.train.utils import get_name, get_model_base\n", + "from SourceCodeTools.models.graph import RGGAN\n", + "from SourceCodeTools.tabular.common 
import compact_property\n", + "from SourceCodeTools.code.data.file_utils import unpersist\n", + "\n", + "import dgl\n", + "import torch\n", + "import numpy as np\n", + "from torch import nn\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare parameters and options\n", + "\n", + "Full list of options that can be added can be found in `SourceCodeTools/models/training_options.py`. They are ment to be used as arguments for cli trainer. Trainer script can be found in `SourceCodeTools/scripts/train.py`.\n", + "\n", + "There are a lot of parameters. Ones that might be of interest are marked with `***`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "config = get_config(\n", + " # tokenizer\n", + " tokenizer_path=\"sentencepiece_bpe.model\", # *** path to sentencepiece model\n", + " \n", + " # dataset parameters\n", + " data_path=\"large_graph\", # *** path to node type\n", + " use_node_types=False, # node types currently not supported\n", + " use_edge_types=True, # whether to use edge types\n", + " filter_edges=None, # None or list of edge type names\n", + " self_loops=False, # whether to use self loops\n", + " train_frac=0.8, # *** fraction of nodes to use for training\n", + " random_seed=42, # random seed for splitting dataset int o train test validation\n", + " min_count_for_objectives=5, # *** minimum frequency of targets\n", + " no_global_edges=False, # remove global edges\n", + " remove_reverse=False, # remove reverse edges\n", + " custom_reverse=None, # None or list of edges, for which reverse edges should be created (use together with `remove_reverse`)\n", + " \n", + " # training parameters\n", + " model_output_dir=\"large_graph\", # *** directory to save checkpoints and training data\n", + " batch_size=128, # *** \n", + " sampling_neighbourhood_size=10, # number of dependencies to sample for each node\n", + " neg_sampling_factor=1, # *** number of negative samples for each positive sample\n", + " epochs=10, # *** number of epochs\n", + " elem_emb_size=100, # *** dimensionality of target embeddings (for node name prediction)\n", + " pretraining_phase=0, # number of epochs for pretraining\n", + " embedding_table_size=200000, # *** embedding table size for subwords\n", + " save_checkpoints=False, # set to False if checkpoints are not needed\n", + " save_each_epoch=False, # save each epoch, useful in case of studying model behavior\n", + " measure_scores=True, # *** measure ranking scores during evaluation\n", + " dilate_scores=200, # downsampling factor for measuring scores to make evaluation faster\n", + " objectives=\"node_clf\", # type of objective\n", + " force_w2v_ns=True, # negative sampling strategy\n", + " gpu=-1, # gpuid\n", + " restore_state=False,\n", + " pretrained=None,\n", + " \n", + " # model parameters\n", + " node_emb_size=100, # *** dimensionality of node embeddings\n", + " h_dim=100, # *** should match to node dimensionality\n", + " num_bases=10, # number of bases for computing parmetwer weights for different edge types\n", + " dropout=0.2, # *** \n", + " use_self_loop=True, #\n", + " activation=\"tanh\", # *** \n", + " learning_rate=1e-3, # *** \n", + " use_gcn_checkpoint=True,\n", + " use_att_checkpoint=True,\n", + " use_gru_checkpoint=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "save_config(config, \"config.yaml\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "config = load_config(\"config.yaml\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'DATASET': {'custom_reverse': None,\n", + " 'data_path': 'large_graph',\n", + " 'filter_edges': None,\n", + " 'min_count_for_objectives': 5,\n", + " 'no_global_edges': False,\n", + " 'random_seed': 42,\n", + " 'remove_reverse': False,\n", + " 'restricted_id_pool': None,\n", + " 'self_loops': False,\n", + " 'train_frac': 0.8,\n", + " 'use_edge_types': True,\n", + " 'use_node_types': False},\n", + " 'MODEL': {'activation': 'tanh',\n", + " 'dropout': 0.2,\n", + " 'h_dim': 100,\n", + " 'n_layers': 5,\n", + " 'node_emb_size': 100,\n", + " 'num_bases': 10,\n", + " 'use_att_checkpoint': True,\n", + " 'use_gcn_checkpoint': True,\n", + " 'use_gru_checkpoint': True,\n", + " 'use_self_loop': True},\n", + " 'TOKENIZER': {'tokenizer_path': 'sentencepiece_bpe.model'},\n", + " 'TRAINING': {'batch_size': 128,\n", + " 'dilate_scores': 200,\n", + " 'early_stopping': False,\n", + " 'early_stopping_tolerance': 20,\n", + " 'elem_emb_size': 100,\n", + " 'embedding_table_size': 200000,\n", + " 'epochs': 10,\n", + " 'external_dataset': None,\n", + " 'force_w2v_ns': True,\n", + " 'gpu': -1,\n", + " 'learning_rate': 0.001,\n", + " 'measure_scores': True,\n", + " 'metric': 'inner_prod',\n", + " 'model_output_dir': 'large_graph',\n", + " 'neg_sampling_factor': 1,\n", + " 'nn_index': 'brute',\n", + " 'objectives': 'node_clf',\n", + " 'pretrained': None,\n", + " 'pretraining_phase': 0,\n", + " 'restore_state': False,\n", + " 'sampling_neighbourhood_size': 10,\n", + " 'save_checkpoints': False,\n", + " 'save_each_epoch': False,\n", + " 'schedule_layers_every': 10,\n", + " 'use_layer_scheduling': False,\n", + " 'use_ns_groups': False}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Create Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Random state for splitting dataset is fixed\n" + ] + } + ], + "source": [ + "dataset = SourceGraphDataset(\n", + " **{**config[\"DATASET\"], **config[\"TOKENIZER\"]},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Declare target loading function (labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def load_type_prediction():\n", + " from SourceCodeTools.code.data.dataset.reader import load_data\n", + " \n", + " nodes, edges = dataset.nodes, dataset.edges\n", + " \n", + " type_ann = unpersist(\"large_graph/type_annotations.json.bz2\").query(\"src in @node_ids\", local_dict={\"node_ids\": nodes[\"id\"]})\n", + " \n", + " norm = lambda x: x.strip(\"\\\"\").strip(\"'\").split(\"[\")[0].split(\".\")[-1]\n", + "\n", + " type_ann[\"dst\"] = type_ann[\"dst\"].apply(norm)\n", + " type_ann = filter_dst_by_freq(type_ann, config[\"DATASET\"][\"min_count_for_objectives\"])\n", + " type_ann = type_ann[[\"src\", \"dst\"]]\n", + "\n", + " return type_ann" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define objectives\n", + "\n", + "Currenlty objectives for node classification (`NodeClassifierObjective`), and name-based node embedding 
training `SubwordEmbedderObjective`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](\"examples/figures/img1.png)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One or several objectives could be used" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "class Trainer(SamplingMultitaskTrainer):\n", + " def create_objectives(self, dataset, tokenizer_path):\n", + " self.objectives = nn.ModuleList()\n", + " \n", + "# self.objectives.append(\n", + "# NodeClassifierObjective(\n", + "# \"NodeTypeClassifier\",\n", + "# self.graph_model, self.node_embedder, dataset.nodes,\n", + "# dataset.load_node_classes, # need to define this function\n", + "# self.device, self.sampling_neighbourhood_size, self.batch_size,\n", + "# tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size,\n", + "# masker=dataset.create_node_clf_masker(), # this is needed only for node type classification\n", + "# measure_scores=self.trainer_params[\"measure_scores\"],\n", + "# dilate_scores=self.trainer_params[\"dilate_scores\"]\n", + "# )\n", + "# )\n", + " \n", + " self.objectives.append(\n", + " NodeClassifierObjective(\n", + " \"TypeAnnPrediction\",\n", + " self.graph_model, self.node_embedder, dataset.nodes,\n", + " load_type_prediction, # need to define this function\n", + " self.device, self.sampling_neighbourhood_size, self.batch_size,\n", + " tokenizer_path=tokenizer_path, target_emb_size=self.elem_emb_size, \n", + " masker=None, # masker is not needed here\n", + " measure_scores=self.trainer_params[\"measure_scores\"],\n", + " dilate_scores=self.trainer_params[\"dilate_scores\"]\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ERROR: Could not find `tensorboard`. Please ensure that your PATH\n", + "contains an executable `tensorboard` program, or explicitly specify\n", + "the path to a TensorBoard binary by setting the `TENSORBOARD_BINARY`\n", + "environment variable." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%tensorboard --logdir \"large_graph\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of nodes 324218\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0: 0%| | 0/30 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "" + ], + "text/plain": [ + " id type name \\\n", + "0 0 module ExampleModule \n", + "1 1 class ExampleModule.ExampleClass \n", + "2 2 class_method ExampleModule.ExampleClass.__init__ \n", + "3 3 non_indexed_symbol builtins \n", + "4 4 class builtins.int \n", + ".. ... ... ... \n", + "145 152 mention print@FunctionDef_0x16d41315c9c41f53 \n", + "146 153 Call Call_0x16d41315c9148a75 \n", + "147 154 BinOp BinOp_0x16d41315c919b5a9 \n", + "148 155 mention main@Module_0x16d41315c9361936 \n", + "149 156 Call Call_0x16d41315c99405f0 \n", + "\n", + " mentioned_in string \n", + "0 \n", + "1 \n", + "2 \n", + "3 \n", + "4 \n", + ".. ... ... \n", + "145 143 \n", + "146 143 print(a+b) \n", + "147 143 a+b \n", + "148 138 \n", + "149 138 main() \n", + "\n", + "[150 rows x 5 columns]" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nodes" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "40cbd760-f8c0-42a7-94cc-0a6e12aa2735", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ], + "text/plain": [ + " id type src dst file_id mentioned_in\n", + "0 0 defines 0 1 NaN \n", + "1 1 defines 1 2 NaN \n", + "2 2 defines 3 4 NaN \n", + "3 3 uses_type 2 4 NaN \n", + "4 4 defines 1 5 NaN \n", + ".. ... ... ... ... ... ...\n", + "410 464 func_rev 156 155 32.0 138\n", + "411 465 next 143 156 32.0 138\n", + "412 466 prev 156 143 32.0 138\n", + "413 467 defined_in_module 156 138 32.0 138\n", + "414 468 defined_in_module_rev 138 156 32.0 138\n", + "\n", + "[415 rows x 6 columns]" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "edges" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "99a5bcf2-1835-4ef2-b6c4-190a76f477fc", + "metadata": {}, + "outputs": [], + "source": [ + "assert all(edges.eval(\"src in @node_ids\", local_dict={\"node_ids\": nodes[\"id\"]}))\n", + "assert all(edges.eval(\"dst in @node_ids\", local_dict={\"node_ids\": nodes[\"id\"]}))" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "bbd3686e-0227-4673-8f9f-2c4e25cb2d2e", + "metadata": {}, + "outputs": [], + "source": [ + "nodes = nodes[[\"id\", \"type\", \"name\"]]\n", + "edges = edges[[\"id\", \"type\", \"src\", \"dst\"]]" + ] + }, + { + "cell_type": "markdown", + "id": "e33b5977-7e85-4d79-bb9f-a13d930a8275", + "metadata": {}, + "source": [ + "## Reading type annotations" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "1bdb3167-ce60-4ac7-8c70-79ec664c1e85", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ], + "text/plain": [ + " src dst\n", + "0 22 int\n", + "2 35 str\n", + "4 49 str\n", + "6 53 int\n", + "8 56 str\n", + "10 79 None\n", + "12 104 int\n", + "14 128 str" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type_annotations = unpersist(\"small_graph/type_annotations.json\").query(\"src in @node_ids\", local_dict={\"node_ids\": nodes[\"id\"]})\n", + "type_annotations" + ] + }, + { + "cell_type": "markdown", + "id": "3bbe4302-77fc-4c52-939b-f658a08b7c80", + "metadata": {}, + "source": [ + "# Preprocessing graph\n", + "## Removing some edges" + ] + }, + { + "cell_type": "markdown", + "id": "72fdc4f8-f0d8-44d8-a680-99c777aa8a35", + "metadata": {}, + "source": [ + "As an exercise, we remove some edge types" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "d90f13af-c576-4da4-a40a-5f17b22182d0", + "metadata": {}, + "outputs": [], + "source": [ + "def remove_global_edges(edges):\n", + " global_edges = SourceGraphDataset.get_global_edges()\n", + " is_ast = lambda type: type not in global_edges\n", + " edges = edges.query(\"type.map(@is_ast)\", local_dict={\"is_ast\": is_ast})\n", + " return edges" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "d4d5fd73-1591-427b-89d4-5e6abadfffcf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ], + "text/plain": [ + " id type src dst\n", + "50 50 subword 14 15\n", + "51 51 arg 15 19\n", + "52 52 arg_rev 19 15\n", + "53 53 args 19 20\n", + "54 54 args_rev 20 19\n", + ".. ... ... ... ...\n", + "410 464 func_rev 156 155\n", + "411 465 next 143 156\n", + "412 466 prev 156 143\n", + "413 467 defined_in_module 156 138\n", + "414 468 defined_in_module_rev 138 156\n", + "\n", + "[327 rows x 4 columns]" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "edges_ast = remove_global_edges(edges)\n", + "edges_ast" + ] + }, + { + "cell_type": "markdown", + "id": "0fd6e0c9-63f9-4e81-865d-22557bb97800", + "metadata": {}, + "source": [ + "## Making sure no isolated nodes are present" + ] + }, + { + "cell_type": "markdown", + "id": "35a9d53e-85bc-435a-ac13-9087c7401008", + "metadata": {}, + "source": [ + "After graph has been edited, need to make sure there are no isolated nodes. They will cause errors when training GNN." + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "cc1e0844-b5f4-4a0c-89d2-ca5f8992f8d4", + "metadata": {}, + "outputs": [], + "source": [ + "def ensure_connectedness(nodes, edges):\n", + " \"\"\"\n", + " Filter isolated nodes\n", + " :param nodes: DataFrame\n", + " :param edges: DataFrame\n", + " :return:\n", + " \"\"\"\n", + " unique_connected_nodes = set(edges['src'].append(edges['dst']))\n", + " \n", + " nodes = nodes.query(\"id in @unique_connected_nodes\", local_dict={\"unique_connected_nodes\": unique_connected_nodes})\n", + " \n", + " print(f\"Ending up with {len(nodes)} nodes and {len(edges)} edges\")\n", + " return nodes, edges" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "2758500f-3b46-4192-a364-5f1e65da9275", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ending up with 150 nodes and 415 edges\n" + ] + } + ], + "source": [ + "nodes, edges = ensure_connectedness(nodes, edges)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "1fcf3230-8028-4a8a-8ce4-df30c7a3b563", + "metadata": {}, + "outputs": [], + "source": [ + "assert all(edges.eval(\"src in @node_ids\", local_dict={\"node_ids\": nodes[\"id\"]}))\n", + "assert all(edges.eval(\"dst in @node_ids\", local_dict={\"node_ids\": nodes[\"id\"]}))" + ] + }, + { + "cell_type": "markdown", + "id": "5942c29a-7076-4b0c-9590-0551b6b7a74e", + "metadata": {}, + "source": [ + "## Adding extra node information" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "551ff497-6660-4ad9-be08-33d059bdd86b", + "metadata": {}, + "outputs": [], + "source": [ + "def format_node_types(nodes):\n", + " \"\"\"\n", + " DGL confuses some node types with internal objects, need to change current type names\n", + " \"\"\"\n", + " nodes = nodes.copy() # copying is slow for large datasets, prefer in-place operations\n", + " nodes['type_backup'] = nodes['type']\n", + " nodes['type'] = nodes['type'].apply(lambda x: f\"{x}_\")\n", + " # nodes['type'] = \"node_\"\n", + " # nodes = nodes.astype({'type': 'category'})\n", + " return nodes\n", + "\n", + "# def add_embedding_names(nodes):\n", + "# \"\"\"\n", + "# Embedding names are used for initial embeddings (layer 0)\n", + "# \"\"\"\n", + "# nodes = nodes.copy()\n", + "# nodes[\"embeddable\"] = True\n", + "# nodes[\"embeddable_name\"] = nodes[\"name\"].apply(SourceGraphDataset.get_embeddable_name)\n", + "# return nodes\n", + "\n", + "def add_splits(nodes, train_frac, restricted_id_pool=None):\n", + 
" nodes = nodes.copy()\n", + " \n", + " def random_partition():\n", + " r = random()\n", + " if r < train_frac:\n", + " return \"train\"\n", + " elif r < train_frac + (1 - train_frac) / 2:\n", + " return \"val\"\n", + " else:\n", + " return \"test\"\n", + " \n", + " import numpy as np\n", + " # define partitioning\n", + " masks = np.array([random_partition() for _ in range(len(nodes))])\n", + " \n", + " # create masks\n", + " nodes[\"train_mask\"] = masks == \"train\"\n", + " nodes[\"val_mask\"] = masks == \"val\"\n", + " nodes[\"test_mask\"] = masks == \"test\"\n", + " \n", + " if restricted_id_pool is not None:\n", + " # if `restricted_id_pool` is provided, mask all nodes not in `restricted_id_pool` negatively\n", + " to_keep = nodes.eval(\"id in @restricted_ids\", local_dict={\"restricted_ids\": restricted_id_pool})\n", + " nodes[\"train_mask\"] = nodes[\"train_mask\"] & to_keep\n", + " nodes[\"test_mask\"] = nodes[\"test_mask\"] & to_keep\n", + " nodes[\"val_mask\"] = nodes[\"val_mask\"] & to_keep\n", + " \n", + " return nodes" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "63b0ec32-69aa-4793-b73c-8fb766adecc5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train examples: 4\n", + "Test examples: 1\n", + "Validation examples: 3\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ], + "text/plain": [ + " id type name \\\n", + "0 0 module_ ExampleModule \n", + "1 1 class_ ExampleModule.ExampleClass \n", + "2 2 class_method_ ExampleModule.ExampleClass.__init__ \n", + "3 3 non_indexed_symbol_ builtins \n", + "4 4 class_ builtins.int \n", + ".. ... ... ... \n", + "145 152 mention_ print@FunctionDef_0x16d41315c9c41f53 \n", + "146 153 Call_ Call_0x16d41315c9148a75 \n", + "147 154 BinOp_ BinOp_0x16d41315c919b5a9 \n", + "148 155 mention_ main@Module_0x16d41315c9361936 \n", + "149 156 Call_ Call_0x16d41315c99405f0 \n", + "\n", + " type_backup train_mask val_mask test_mask \n", + "0 module False False False \n", + "1 class False False False \n", + "2 class_method False False False \n", + "3 non_indexed_symbol False False False \n", + "4 class False False False \n", + ".. ... ... ... ... \n", + "145 mention False False False \n", + "146 Call False False False \n", + "147 BinOp False False False \n", + "148 mention False False False \n", + "149 Call False False False \n", + "\n", + "[150 rows x 7 columns]" + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nodes = add_splits(format_node_types(nodes), 0.5, restricted_id_pool=type_annotations[\"src\"])\n", + "print(\"Train examples:\", len(nodes.query(\"train_mask == True\")))\n", + "print(\"Test examples:\", len(nodes.query(\"test_mask == True\")))\n", + "print(\"Validation examples:\", len(nodes.query(\"val_mask == True\")))\n", + "nodes" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "fe87c0bd-4294-4978-bbba-a3f24bf73024", + "metadata": {}, + "outputs": [], + "source": [ + "def add_type_dependent_dense_ids_to_nodes(nodes):\n", + " \"\"\"\n", + " DGL requires dense ids: https://docs.dgl.ai/en/latest/generated/dgl.heterograph.html#dgl.heterograph\n", + " Compute dense ids for each node type\n", + " \"\"\"\n", + " nodes = nodes.copy()\n", + "\n", + " typed_id_map = {}\n", + "\n", + " for type_ in nodes['type'].unique():\n", + " # create mask for the current node type\n", + " type_mask = nodes['type'] == type_\n", + "\n", + " # `compact_property` will create a dense mapping\n", + " # it is equivalent to dict(zip(node_ids, range(len(node_ids))))\n", + " id_map = compact_property(nodes.loc[type_mask, 'id'])\n", + "\n", + " # add a new column with dense type-dependent ids\n", + " nodes.loc[type_mask, 'typed_id'] = nodes.loc[type_mask, 'id'].apply(lambda old_id: id_map[old_id])\n", + "\n", + " # store for further reference\n", + " typed_id_map[type_] = id_map\n", + "\n", + " nodes = nodes.astype({\"typed_id\": \"int64\"})\n", + " return nodes, typed_id_map" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "40b80587-d44c-4093-b074-66a434354264", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ], + "text/plain": [ + " id type name \\\n", + "0 0 module_ ExampleModule \n", + "1 1 class_ ExampleModule.ExampleClass \n", + "2 2 class_method_ ExampleModule.ExampleClass.__init__ \n", + "3 3 non_indexed_symbol_ builtins \n", + "4 4 class_ builtins.int \n", + ".. ... ... ... \n", + "145 152 mention_ print@FunctionDef_0x16d41315c9c41f53 \n", + "146 153 Call_ Call_0x16d41315c9148a75 \n", + "147 154 BinOp_ BinOp_0x16d41315c919b5a9 \n", + "148 155 mention_ main@Module_0x16d41315c9361936 \n", + "149 156 Call_ Call_0x16d41315c99405f0 \n", + "\n", + " type_backup train_mask val_mask test_mask typed_id \n", + "0 module False False False 0 \n", + "1 class False False False 0 \n", + "2 class_method False False False 0 \n", + "3 non_indexed_symbol False False False 0 \n", + "4 class False False False 1 \n", + ".. ... ... ... ... ... \n", + "145 mention False False False 33 \n", + "146 Call False False False 9 \n", + "147 BinOp False False False 1 \n", + "148 mention False False False 34 \n", + "149 Call False False False 10 \n", + "\n", + "[150 rows x 8 columns]" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nodes, typed_id_map = add_type_dependent_dense_ids_to_nodes(nodes)\n", + "nodes" + ] + }, + { + "cell_type": "markdown", + "id": "1cd0fa5c-d17e-4273-a7a4-b7a1fbea93da", + "metadata": {}, + "source": [ + "## Adding extra edge information" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "be0d6b46-2b58-471a-880c-baf1f4f2a38a", + "metadata": {}, + "outputs": [], + "source": [ + "def format_edge_types(edges):\n", + " \"\"\"\n", + " DGL confuses some edge types with internal objects, need to change current type names\n", + " \"\"\"\n", + " edges = edges.copy()\n", + " edges['type'] = edges['type'].apply(lambda x: f\"{x}_\")\n", + " return edges" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "70626daf-c6f4-46cb-a405-828e93eeea6e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ], + "text/plain": [ + " id type src dst\n", + "0 0 defines_ 0 1\n", + "1 1 defines_ 1 2\n", + "2 2 defines_ 3 4\n", + "3 3 uses_type_ 2 4\n", + "4 4 defines_ 1 5\n", + ".. ... ... ... ...\n", + "410 464 func_rev_ 156 155\n", + "411 465 next_ 143 156\n", + "412 466 prev_ 156 143\n", + "413 467 defined_in_module_ 156 138\n", + "414 468 defined_in_module_rev_ 138 156\n", + "\n", + "[415 rows x 4 columns]" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "edges = format_edge_types(edges)\n", + "edges" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "ca377c26-bd4c-403b-a628-d33b2b3aad06", + "metadata": {}, + "outputs": [], + "source": [ + "def add_node_types_to_edges(nodes, edges):\n", + " \"\"\"\n", + " Add node types because they are needed for refining edge signatures\n", + " \"\"\"\n", + " edges = edges.copy()\n", + " node_type_map = dict(zip(nodes['id'], nodes['type']))\n", + "\n", + " edges['src_type'] = edges['src'].apply(lambda src_id: node_type_map[src_id])\n", + " edges['dst_type'] = edges['dst'].apply(lambda dst_id: node_type_map[dst_id])\n", + " edges = edges.astype({'src_type': 'category', 'dst_type': 'category'})\n", + "\n", + " return edges" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "5d330c29-65eb-4b18-9003-d89503e5747e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ], + "text/plain": [ + " id type src dst src_type dst_type\n", + "0 0 defines_ 0 1 module_ class_\n", + "1 1 defines_ 1 2 class_ class_method_\n", + "2 2 defines_ 3 4 non_indexed_symbol_ class_\n", + "3 3 uses_type_ 2 4 class_method_ class_\n", + "4 4 defines_ 1 5 class_ class_field_\n", + ".. ... ... ... ... ... ...\n", + "410 464 func_rev_ 156 155 Call_ mention_\n", + "411 465 next_ 143 156 FunctionDef_ Call_\n", + "412 466 prev_ 156 143 Call_ FunctionDef_\n", + "413 467 defined_in_module_ 156 138 Call_ Module_\n", + "414 468 defined_in_module_rev_ 138 156 Module_ Call_\n", + "\n", + "[415 rows x 6 columns]" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "edges = add_node_types_to_edges(nodes, edges)\n", + "edges" + ] + }, + { + "cell_type": "markdown", + "id": "05ddca13-ee7f-4525-9813-0938abc35536", + "metadata": {}, + "source": [ + "# Building graph" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "6ccab980-1be4-44c7-a220-3f849c9ed2c3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unique nodes: 150, node types: 26\n", + "Unique edges: 415, edge types: 50\n" + ] + } + ], + "source": [ + "print(f\"Unique nodes: {len(nodes)}, node types: {len(nodes['type'].unique())}\")\n", + "print(f\"Unique edges: {len(edges)}, edge types: {len(edges['type'].unique())}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "94b006cb-9821-4e13-b03b-26dae995f9c3", + "metadata": {}, + "outputs": [], + "source": [ + "def add_global_dense_graph_id(nodes, graph, typed_id_map):\n", + " \"\"\"\n", + " Add dense global node ids to make it easier working with embeddings in the future\n", + " \"\"\"\n", + " orig_id = []\n", + " graph_id = []\n", + " prev_offset = 0\n", + " \n", + " nodes = nodes.copy()\n", + "\n", + " # simply assign global id in the order node types appear in `graph.ntypes`\n", + " for type_ in graph.ntypes:\n", + " from_id, to_id = zip(*typed_id_map[type_].items())\n", + " orig_id.extend(from_id)\n", + " graph_id.extend([t + prev_offset for t in to_id])\n", + " prev_offset += graph.number_of_nodes(type_)\n", + "\n", + " global_map = dict(zip(orig_id, graph_id))\n", + "\n", + " nodes['global_graph_id'] = nodes['id'].apply(lambda old_id: global_map[old_id])\n", + " \n", + " return nodes\n", + "\n", + "def add_node_data(graph, nodes):\n", + " field_types = {\n", + " \"train_mask\": torch.bool,\n", + " \"test_mask\": torch.bool,\n", + " \"val_mask\": torch.bool,\n", + " \"typed_id\": torch.int64,\n", + " \"original_id\": torch.int64,\n", + " \"global_graph_id\": torch.int64,\n", + " }\n", + " \n", + " for ntype in graph.ntypes:\n", + " node_data = nodes.query(f\"type == '{ntype}'\").sort_values('typed_id').rename({\"id\": \"original_id\"}, axis=1)\n", + " \n", + " for field_name, field_type in field_types.items():\n", + " graph.nodes[ntype].data[field_name] = torch.tensor(node_data[field_name].values, dtype=field_type)\n", + " \n", + " return graph\n", + "\n", + "def create_hetero_graph(nodes, edges, typed_id_map):\n", + " # nodes = nodes.copy()\n", + " # edges = edges.copy()\n", + "\n", + " # edges = add_node_types_to_edges(nodes, edges)\n", + "\n", + " typed_node_id = dict(zip(nodes['id'], nodes['typed_id']))\n", + "\n", + " typed_subgraphs = {}\n", + "\n", + " # group by in pandas is slow, use something else for large datasets\n", + " for signature, signature_edges in edges.groupby(['src_type', 'type', 'dst_type']): \n", + " # `signature` 
is a tuple (src_type, edge_type, dst_type)\n", + " typed_subgraphs[signature] = list(\n", + " zip(\n", + " signature_edges['src'].map(lambda old_id: typed_node_id[old_id]),\n", + " signature_edges['dst'].map(lambda old_id: typed_node_id[old_id])\n", + " )\n", + " )\n", + "\n", + " print(\n", + " f\"Unique triplet types in the graph: {len(typed_subgraphs.keys())}\"\n", + " )\n", + "\n", + " g = dgl.heterograph(typed_subgraphs)\n", + " \n", + " nodes = add_global_dense_graph_id(nodes, g, typed_id_map)\n", + " \n", + " g = add_node_data(g, nodes)\n", + " return g, nodes, edges" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "7d81d65a-2f1b-4525-b4e9-2cbdeb234980", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unique triplet types in the graph: 140\n" + ] + } + ], + "source": [ + "g, nodes, edges = create_hetero_graph(nodes, edges, typed_id_map)" + ] + }, + { + "cell_type": "markdown", + "id": "f461773c-620c-45c9-9ce5-1d26f8f9c72d", + "metadata": {}, + "source": [ + "## Graph Atributes" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "25fc1ce2-309c-40c2-9bff-e6455800dc7f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['#attr#_',\n", + " 'AnnAssign_',\n", + " 'Assign_',\n", + " 'Attribute_',\n", + " 'BinOp_',\n", + " 'Call_',\n", + " 'ClassDef_',\n", + " 'Constant_',\n", + " 'FunctionDef_',\n", + " 'ImportFrom_',\n", + " 'JoinedStr_',\n", + " 'Module_',\n", + " 'Op_',\n", + " 'Return_',\n", + " 'alias_',\n", + " 'arg_',\n", + " 'arguments_',\n", + " 'class_',\n", + " 'class_field_',\n", + " 'class_method_',\n", + " 'function_',\n", + " 'global_variable_',\n", + " 'mention_',\n", + " 'module_',\n", + " 'non_indexed_symbol_',\n", + " 'subword_']" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.ntypes" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "b5110724-8f41-4e93-b4a7-689aea7760a2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "NodeSpace(data={'train_mask': tensor([False, False, False]), 'test_mask': tensor([False, False, False]), 'val_mask': tensor([False, False, False]), 'typed_id': tensor([0, 1, 2]), 'original_id': tensor([ 0, 10, 91]), 'global_graph_id': tensor([122, 123, 124])})" + ] + }, + "execution_count": 94, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.nodes[\"module_\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "17680572-d69d-4f5c-9409-a662992c0307", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'#attr#_': tensor([False, False, False, False]),\n", + " 'AnnAssign_': tensor([False, False]),\n", + " 'Assign_': tensor([False, False, False, False, False]),\n", + " 'Attribute_': tensor([False, False, False, False, False, False, False]),\n", + " 'BinOp_': tensor([False, False]),\n", + " 'Call_': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False]),\n", + " 'ClassDef_': tensor([False, False]),\n", + " 'Constant_': tensor([False]),\n", + " 'FunctionDef_': tensor([False, True, False, True, False, False, False, False]),\n", + " 'ImportFrom_': tensor([False, False]),\n", + " 'JoinedStr_': tensor([False]),\n", + " 'Module_': tensor([False, False, False, False]),\n", + " 'Op_': tensor([False]),\n", + " 'Return_': tensor([False, False, False, False]),\n", + " 'alias_': tensor([False, False]),\n", + " 'arg_': tensor([False, 
False, False, False, False, False, False, False, False]),\n", + " 'arguments_': tensor([False, False, False, False, False, False]),\n", + " 'class_': tensor([False, False, False, False]),\n", + " 'class_field_': tensor([False, False]),\n", + " 'class_method_': tensor([False, False, False]),\n", + " 'function_': tensor([False, False, False, False, False, False]),\n", + " 'global_variable_': tensor([False]),\n", + " 'mention_': tensor([False, False, False, False, False, False, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False]),\n", + " 'module_': tensor([False, False, False]),\n", + " 'non_indexed_symbol_': tensor([False]),\n", + " 'subword_': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False])}" + ] + }, + "execution_count": 95, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.ndata[\"train_mask\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "bfbb97fa-2c68-4290-b74e-b63f0b5cd226", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[('#attr#_', 'attr_', 'Attribute_'),\n", + " ('AnnAssign_', 'defined_in_function_', 'FunctionDef_'),\n", + " ('AnnAssign_', 'next_', 'AnnAssign_'),\n", + " ('AnnAssign_', 'next_', 'Return_'),\n", + " ('AnnAssign_', 'prev_', 'AnnAssign_'),\n", + " ('AnnAssign_', 'target_rev_', 'mention_'),\n", + " ('AnnAssign_', 'value_rev_', 'Attribute_'),\n", + " ('AnnAssign_', 'value_rev_', 'Call_'),\n", + " ('Assign_', 'defined_in_function_', 'FunctionDef_'),\n", + " ('Assign_', 'defined_in_module_', 'Module_'),\n", + " ('Assign_', 'next_', 'Assign_'),\n", + " ('Assign_', 'next_', 'Call_'),\n", + " ('Assign_', 'next_', 'FunctionDef_'),\n", + " ('Assign_', 'prev_', 'Assign_'),\n", + " ('Assign_', 'prev_', 'ImportFrom_'),\n", + " ('Assign_', 'targets_rev_', 'Attribute_'),\n", + " ('Assign_', 'targets_rev_', 'mention_'),\n", + " ('Assign_', 'value_rev_', 'Call_'),\n", + " ('Assign_', 'value_rev_', 'mention_'),\n", + " ('Attribute_', 'func_', 'Call_'),\n", + " ('Attribute_', 'left_', 'BinOp_'),\n", + " ('Attribute_', 'right_', 'BinOp_'),\n", + " ('Attribute_', 'targets_', 'Assign_'),\n", + " ('Attribute_', 'value_', 'AnnAssign_'),\n", + " ('Attribute_', 'value_rev_', 'mention_'),\n", + " ('BinOp_', 'args_', 'Call_'),\n", + " ('BinOp_', 'left_rev_', 'Attribute_'),\n", + " ('BinOp_', 'left_rev_', 'mention_'),\n", + " ('BinOp_', 'right_rev_', 'Attribute_'),\n", + " ('BinOp_', 'right_rev_', 'mention_'),\n", + " ('Call_', 'args_', 'Call_'),\n", + " ('Call_', 'args_rev_', 'BinOp_'),\n", + " ('Call_', 'args_rev_', 'Call_'),\n", + " ('Call_', 'args_rev_', 'mention_'),\n", + " ('Call_', 'defined_in_function_', 'FunctionDef_'),\n", + " ('Call_', 'defined_in_module_', 'Module_'),\n", + " ('Call_', 'func_rev_', 'Attribute_'),\n", + " ('Call_', 'func_rev_', 'mention_'),\n", + " ('Call_', 'prev_', 'Assign_'),\n", + " ('Call_', 'prev_', 'FunctionDef_'),\n", + " ('Call_', 'value_', 'AnnAssign_'),\n", + " ('Call_', 'value_', 'Assign_'),\n", + " ('Call_', 'value_', 'Return_'),\n", + " ('ClassDef_', 'class_name_', 'mention_'),\n", + " ('ClassDef_', 'defined_in_class_rev_', 'FunctionDef_'),\n", + " ('ClassDef_', 'defined_in_module_', 'Module_'),\n", + " ('Constant_', 'args_', 'Call_'),\n", + " 
('FunctionDef_', 'args_rev_', 'arguments_'),\n", + " ('FunctionDef_', 'defined_in_class_', 'ClassDef_'),\n", + " ('FunctionDef_', 'defined_in_function_rev_', 'AnnAssign_'),\n", + " ('FunctionDef_', 'defined_in_function_rev_', 'Assign_'),\n", + " ('FunctionDef_', 'defined_in_function_rev_', 'Call_'),\n", + " ('FunctionDef_', 'defined_in_function_rev_', 'Return_'),\n", + " ('FunctionDef_', 'defined_in_module_', 'Module_'),\n", + " ('FunctionDef_', 'function_name_', 'mention_'),\n", + " ('FunctionDef_', 'next_', 'Call_'),\n", + " ('FunctionDef_', 'next_', 'FunctionDef_'),\n", + " ('FunctionDef_', 'prev_', 'Assign_'),\n", + " ('FunctionDef_', 'prev_', 'FunctionDef_'),\n", + " ('FunctionDef_', 'prev_', 'ImportFrom_'),\n", + " ('ImportFrom_', 'defined_in_module_', 'Module_'),\n", + " ('ImportFrom_', 'module_rev_', 'mention_'),\n", + " ('ImportFrom_', 'names_rev_', 'alias_'),\n", + " ('ImportFrom_', 'next_', 'Assign_'),\n", + " ('ImportFrom_', 'next_', 'FunctionDef_'),\n", + " ('JoinedStr_', 'value_', 'Return_'),\n", + " ('Module_', 'defined_in_module_rev_', 'Assign_'),\n", + " ('Module_', 'defined_in_module_rev_', 'Call_'),\n", + " ('Module_', 'defined_in_module_rev_', 'ClassDef_'),\n", + " ('Module_', 'defined_in_module_rev_', 'FunctionDef_'),\n", + " ('Module_', 'defined_in_module_rev_', 'ImportFrom_'),\n", + " ('Op_', 'op_', 'BinOp_'),\n", + " ('Return_', 'defined_in_function_', 'FunctionDef_'),\n", + " ('Return_', 'prev_', 'AnnAssign_'),\n", + " ('Return_', 'value_rev_', 'Call_'),\n", + " ('Return_', 'value_rev_', 'mention_'),\n", + " ('alias_', 'asname_rev_', 'mention_'),\n", + " ('alias_', 'name_rev_', 'mention_'),\n", + " ('alias_', 'names_', 'ImportFrom_'),\n", + " ('arg_', 'arg_rev_', 'mention_'),\n", + " ('arg_', 'args_', 'arguments_'),\n", + " ('arguments_', 'args_', 'FunctionDef_'),\n", + " ('arguments_', 'args_rev_', 'arg_'),\n", + " ('class_', 'defined_in_', 'module_'),\n", + " ('class_', 'defined_in_', 'non_indexed_symbol_'),\n", + " ('class_', 'defines_', 'class_field_'),\n", + " ('class_', 'defines_', 'class_method_'),\n", + " ('class_', 'defines_', 'function_'),\n", + " ('class_', 'global_mention_', 'mention_'),\n", + " ('class_', 'imported_by_', 'module_'),\n", + " ('class_', 'type_used_by_', 'class_method_'),\n", + " ('class_', 'type_used_by_', 'function_'),\n", + " ('class_field_', 'defined_in_', 'class_'),\n", + " ('class_field_', 'used_by_', 'class_method_'),\n", + " ('class_field_', 'used_by_', 'function_'),\n", + " ('class_method_', 'called_by_', 'function_'),\n", + " ('class_method_', 'called_by_', 'module_'),\n", + " ('class_method_', 'defined_in_', 'class_'),\n", + " ('class_method_', 'global_mention_', 'mention_'),\n", + " ('class_method_', 'uses_', 'class_field_'),\n", + " ('class_method_', 'uses_type_', 'class_'),\n", + " ('function_', 'called_by_', 'function_'),\n", + " ('function_', 'called_by_', 'module_'),\n", + " ('function_', 'calls_', 'class_method_'),\n", + " ('function_', 'calls_', 'function_'),\n", + " ('function_', 'defined_in_', 'class_'),\n", + " ('function_', 'defined_in_', 'non_indexed_symbol_'),\n", + " ('function_', 'global_mention_', 'mention_'),\n", + " ('function_', 'uses_', 'class_field_'),\n", + " ('function_', 'uses_', 'global_variable_'),\n", + " ('function_', 'uses_type_', 'class_'),\n", + " ('global_variable_', 'defined_in_', 'module_'),\n", + " ('global_variable_', 'global_mention_', 'mention_'),\n", + " ('global_variable_', 'used_by_', 'function_'),\n", + " ('mention_', 'arg_', 'arg_'),\n", + " ('mention_', 'args_', 'Call_'),\n", + " 
('mention_', 'asname_', 'alias_'),\n", + " ('mention_', 'class_name_rev_', 'ClassDef_'),\n", + " ('mention_', 'func_', 'Call_'),\n", + " ('mention_', 'function_name_rev_', 'FunctionDef_'),\n", + " ('mention_', 'left_', 'BinOp_'),\n", + " ('mention_', 'module_', 'ImportFrom_'),\n", + " ('mention_', 'name_', 'alias_'),\n", + " ('mention_', 'right_', 'BinOp_'),\n", + " ('mention_', 'target_', 'AnnAssign_'),\n", + " ('mention_', 'targets_', 'Assign_'),\n", + " ('mention_', 'value_', 'Assign_'),\n", + " ('mention_', 'value_', 'Attribute_'),\n", + " ('mention_', 'value_', 'Return_'),\n", + " ('module_', 'calls_', 'class_method_'),\n", + " ('module_', 'calls_', 'function_'),\n", + " ('module_', 'defines_', 'class_'),\n", + " ('module_', 'defines_', 'global_variable_'),\n", + " ('module_', 'global_mention_', 'mention_'),\n", + " ('module_', 'imports_', 'class_'),\n", + " ('module_', 'used_by_', 'module_'),\n", + " ('module_', 'uses_', 'module_'),\n", + " ('non_indexed_symbol_', 'defines_', 'class_'),\n", + " ('non_indexed_symbol_', 'defines_', 'function_'),\n", + " ('subword_', 'subword_', 'mention_')]" + ] + }, + "execution_count": 96, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.canonical_etypes" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "b4b3e786-a021-4048-9b0c-020733171683", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "EdgeSpace(data={})" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.edges[('subword_', 'subword_', 'mention_')]" + ] + }, + { + "cell_type": "markdown", + "id": "1df19fd9-7675-41ba-bcf7-99121f15048f", + "metadata": {}, + "source": [ + "## Dataloaders" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "1a41f3b6-2f04-4ac3-8c2c-109c26c6d578", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'FunctionDef_': tensor([1, 3]), 'mention_': tensor([ 6, 19])}" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def get_train_nodes(graph):\n", + " train_nodes = {}\n", + " for node_type, mask in g.ndata[\"train_mask\"].items():\n", + " train_ids = g.ndata[\"typed_id\"][node_type][mask]\n", + " if len(train_ids) > 0:\n", + " train_nodes[node_type] = train_ids\n", + " return train_nodes\n", + "\n", + "get_train_nodes(g)" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "ef3e7b73-925c-43e8-a2ab-c25ba710251d", + "metadata": {}, + "outputs": [], + "source": [ + "sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n", + "loader = dgl.dataloading.NodeDataLoader(\n", + " g, get_train_nodes(g), sampler, batch_size=1, shuffle=False, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "ba78130d-64e9-460d-87fb-122f413902f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 0\n", + "Seeds:\n", + "FunctionDef_ tensor([35])\n", + "\n", + "Input nodes:\n", + "ClassDef_ tensor([17])\n", + "FunctionDef_ tensor([35, 16, 49])\n", + "Return_ tensor([45])\n", + "arguments_ tensor([37])\n", + "mention_ tensor([46])\n", + "\n", + "Layer 0\n", + "tensor([17]) --> ('ClassDef_', 'defined_in_class_rev_', 'FunctionDef_') --> tensor([35])\n", + "tensor([35, 16, 49]) --> ('FunctionDef_', 'next_', 'FunctionDef_') --> tensor([35])\n", + "tensor([35, 16, 49]) --> ('FunctionDef_', 'prev_', 'FunctionDef_') --> tensor([35])\n", + "tensor([45]) --> ('Return_', 
'defined_in_function_', 'FunctionDef_') --> tensor([35])\n", + "tensor([37]) --> ('arguments_', 'args_', 'FunctionDef_') --> tensor([35])\n", + "tensor([46]) --> ('mention_', 'function_name_rev_', 'FunctionDef_') --> tensor([35])\n", + "\n", + "\n", + "Batch: 1\n", + "Seeds:\n", + "FunctionDef_ tensor([79])\n", + "\n", + "Input nodes:\n", + "Assign_ tensor([75])\n", + "Call_ tensor([83, 90])\n", + "FunctionDef_ tensor([79])\n", + "Module_ tensor([66])\n", + "mention_ tensor([89])\n", + "\n", + "Layer 0\n", + "tensor([75]) --> ('Assign_', 'next_', 'FunctionDef_') --> tensor([79])\n", + "tensor([83, 90]) --> ('Call_', 'defined_in_function_', 'FunctionDef_') --> tensor([79])\n", + "tensor([83, 90]) --> ('Call_', 'prev_', 'FunctionDef_') --> tensor([79])\n", + "tensor([66]) --> ('Module_', 'defined_in_module_rev_', 'FunctionDef_') --> tensor([79])\n", + "tensor([89]) --> ('mention_', 'function_name_rev_', 'FunctionDef_') --> tensor([79])\n", + "\n", + "\n", + "Batch: 2\n", + "Seeds:\n", + "mention_ tensor([53])\n", + "\n", + "Input nodes:\n", + "AnnAssign_ tensor([54])\n", + "Call_ tensor([59])\n", + "mention_ tensor([53])\n", + "subword_ tensor([52, 47])\n", + "\n", + "Layer 0\n", + "tensor([54]) --> ('AnnAssign_', 'target_rev_', 'mention_') --> tensor([53])\n", + "tensor([59]) --> ('Call_', 'args_rev_', 'mention_') --> tensor([53])\n", + "tensor([52, 47]) --> ('subword_', 'subword_', 'mention_') --> tensor([53])\n", + "\n", + "\n", + "Batch: 3\n", + "Seeds:\n", + "mention_ tensor([104])\n", + "\n", + "Input nodes:\n", + "Assign_ tensor([106])\n", + "arg_ tensor([105])\n", + "mention_ tensor([104])\n", + "subword_ tensor([103])\n", + "\n", + "Layer 0\n", + "tensor([106]) --> ('Assign_', 'value_rev_', 'mention_') --> tensor([104])\n", + "tensor([105]) --> ('arg_', 'arg_rev_', 'mention_') --> tensor([104])\n", + "tensor([103]) --> ('subword_', 'subword_', 'mention_') --> tensor([104])\n", + "\n", + "\n" + ] + } + ], + "source": [ + "for ind, (input_nodes, seeds, blocks) in enumerate(loader):\n", + " print(\"Batch:\", ind)\n", + " \n", + " print(\"Seeds:\")\n", + " for key, val in seeds.items():\n", + " if len(val) > 0:\n", + " print(key, blocks[-1].dstnodes[key].data[\"original_id\"])\n", + " \n", + " print()\n", + " \n", + " print(\"Input nodes:\")\n", + " for key, val in input_nodes.items():\n", + " if len(val) > 0:\n", + " print(key, blocks[0].srcnodes[key].data[\"original_id\"])\n", + " \n", + " print()\n", + " \n", + " for b_ind, block in enumerate(blocks):\n", + " print(\"Layer\", b_ind)\n", + " for etype in block.canonical_etypes:\n", + " if block[etype].num_edges() > 0:\n", + " # for srctype, dsttype in zip(\n", + " # print(blocks[0][etype].adj().to_dense())\n", + " print(block.srcnodes[etype[0]].data[\"original_id\"], \"-->\", etype, \"-->\", block.dstnodes[etype[2]].data[\"original_id\"]) #, block[etype].num_edges()) \n", + " print()\n", + " print()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f18ad6d7-8e92-48a5-9888-f9c2c9471414", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:SourceCodeTools]", + "language": "python", + "name": "conda-env-SourceCodeTools-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/examples/Subgraph Classification.ipynb b/examples/Subgraph Classification.ipynb new file mode 100644 index 00000000..97555ffa --- /dev/null +++ b/examples/Subgraph Classification.ipynb @@ -0,0 +1,698 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "244ac956", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%load_ext tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "bdefc65e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "from random import random\n", + "\n", + "from SourceCodeTools.models.training_config import get_config, save_config, load_config\n", + "from SourceCodeTools.code.data.dataset.Dataset import SourceGraphDataset, filter_dst_by_freq\n", + "from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure, SamplingMultitaskTrainer\n", + "from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeClassifierObjective\n", + "from SourceCodeTools.models.graph.train.objectives.SubgraphClassifierObjective import SubgraphAbstractObjective, \\\n", + " SubgraphClassifierObjective, SubgraphEmbeddingObjective\n", + "from SourceCodeTools.models.graph.train.utils import get_name, get_model_base\n", + "from SourceCodeTools.models.graph import RGGAN\n", + "from SourceCodeTools.tabular.common import compact_property\n", + "from SourceCodeTools.code.data.file_utils import unpersist\n", + "\n", + "import dgl\n", + "import torch\n", + "import numpy as np\n", + "from argparse import Namespace\n", + "from torch import nn\n", + "from datetime import datetime\n", + "from os.path import join\n", + "from functools import partial" + ] + }, + { + "cell_type": "markdown", + "id": "4becf482", + "metadata": { + "tags": [] + }, + "source": [ + "# Prepare parameters and options\n", + "\n", + "Full list of options that can be added can be found in `SourceCodeTools/models/training_options.py`. They are ment to be used as arguments for cli trainer. Trainer script can be found in `SourceCodeTools/scripts/train.py`.\n", + "\n", + "For the task of subgraph classification the important options are:\n", + "- `subgraph_partition` is path to subgraph-based train/val/test sets. Storead as Dataframe with subgraph id and partition mask\n", + "- `subgraph_id_column` is a column is `common_edges` file that stores subgraph id.\n", + "- For variable misuse task (same will apply to authorship attribution) subgraphs are created for individual functions (files for SCAA). The label is stored in `common_filecontent`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5637f70e-b8fd-4a9f-9956-d60a61bd2870", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_path = \"sentencepiece_bpe.model\"\n", + "\n", + "data_path = \"cubert_varmisuse_tiny\"\n", + "subgraph_partition = join(data_path, \"partition.json.bz2\")\n", + "filecontent_path = join(data_path, \"common_filecontent.json.bz2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "404565f5-09a3-4535-ba51-8533491f7c72", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id train_mask val_mask test_mask\n", + "0 5040 True False False\n", + "1 9406 True False False\n", + "2 9720 True False False\n", + "3 5332 True False False\n", + "4 6923 True False False\n", + "... ... ... ... ...\n", + "23087 993090 False False True\n", + "23088 998594 False False True\n", + "23089 998064 False False True\n", + "23090 997465 False False True\n", + "23091 994477 False False True\n", + "\n", + "[23092 rows x 4 columns]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unpersist(subgraph_partition)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "d09b740a-eaaf-4fac-a73f-0bd9bb671f70", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id type source_node_id target_node_id file_id \\\n", + "0 0 subword 483174 425064 5040 \n", + "1 1 arg 425064 97124 5040 \n", + "2 2 arg_rev 97124 425064 5040 \n", + "3 3 args 97124 248751 5040 \n", + "4 4 args_rev 248751 97124 5040 \n", + "5 5 args 248751 142323 5040 \n", + "6 6 args_rev 142323 248751 5040 \n", + "7 7 subword 786198 143287 5040 \n", + "8 8 decorator_list 143287 142323 5040 \n", + "9 9 decorator_list_rev 142323 143287 5040 \n", + "\n", + " mentioned_in \n", + "0 NaN \n", + "1 142323.0 \n", + "2 142323.0 \n", + "3 142323.0 \n", + "4 142323.0 \n", + "5 142323.0 \n", + "6 142323.0 \n", + "7 NaN \n", + "8 142323.0 \n", + "9 142323.0 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unpersist(join(data_path, \"common_edges.json.bz2\"), nrows=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d38f83db", + "metadata": {}, + "outputs": [], + "source": [ + "config = get_config(\n", + " subgraph_id_column=\"file_id\",\n", + " data_path=data_path,\n", + " model_output_dir=data_path,\n", + " subgraph_partition=subgraph_partition,\n", + " tokenizer_path=tokenizer_path,\n", + " objectives=\"subgraph_clf\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bf61e1a2-2a2b-4688-b6b6-4ad3c22b22a5", + "metadata": {}, + "outputs": [], + "source": [ + "save_config(config, \"var_misuse_tiny.yaml\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "07dc7b86", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'DATASET': {'data_path': 'cubert_varmisuse_tiny',\n", + " 'train_frac': 0.9,\n", + " 'filter_edges': None,\n", + " 'min_count_for_objectives': 5,\n", + " 'self_loops': False,\n", + " 'use_node_types': False,\n", + " 'use_edge_types': False,\n", + " 'no_global_edges': False,\n", + " 'remove_reverse': False,\n", + " 'custom_reverse': None,\n", + " 'restricted_id_pool': None,\n", + " 'random_seed': None,\n", + " 'subgraph_id_column': 'file_id',\n", + " 'subgraph_partition': 'cubert_varmisuse_tiny/partition.json.bz2'},\n", + " 'TRAINING': {'model_output_dir': 'cubert_varmisuse_tiny',\n", + " 'pretrained': None,\n", + " 'pretraining_phase': 0,\n", + " 'sampling_neighbourhood_size': 10,\n", + " 'neg_sampling_factor': 3,\n", + " 'use_layer_scheduling': False,\n", + " 'schedule_layers_every': 10,\n", + " 'elem_emb_size': 100,\n", + " 'embedding_table_size': 200000,\n", + " 'epochs': 100,\n", + " 'batch_size': 128,\n", + " 'learning_rate': 0.001,\n", + " 'objectives': 'subgraph_clf',\n", + " 'save_each_epoch': False,\n", + " 'save_checkpoints': True,\n", + " 'early_stopping': False,\n", + " 'early_stopping_tolerance': 20,\n", + " 'force_w2v_ns': False,\n", + " 'use_ns_groups': False,\n", + " 'nn_index': 'brute',\n", + " 'metric': 'inner_prod',\n", + " 'measure_scores': False,\n", + " 'dilate_scores': 200,\n", + " 'gpu': -1,\n", + " 'external_dataset': None,\n", + " 'restore_state': False},\n", + " 'MODEL': {'node_emb_size': 100,\n", + " 'h_dim': 100,\n", + " 'n_layers': 5,\n", + " 'use_self_loop': True,\n", + " 'use_gcn_checkpoint': False,\n", + " 'use_att_checkpoint': False,\n", + " 'use_gru_checkpoint': False,\n", + " 'num_bases': 10,\n", + " 'dropout': 0.0,\n", + " 'activation': 'tanh'},\n", + " 'TOKENIZER': {'tokenizer_path': 'sentencepiece_bpe.model'}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "markdown", + 
"id": "069a2528", + "metadata": {}, + "source": [ + "# Create Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "d0ce29fc", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SourceGraphDataset(\n", + " **{**config[\"DATASET\"], **config[\"TOKENIZER\"]}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "897d41bf", + "metadata": {}, + "source": [ + "# Declare target loading function (labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c279f43e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def load_labels():\n", + " filecontent = unpersist(filecontent_path)\n", + " return filecontent[[\"id\", \"label\"]].rename({\"id\": \"src\", \"label\": \"dst\"}, axis=1)" + ] + }, + { + "cell_type": "markdown", + "id": "1ac27d90", + "metadata": {}, + "source": [ + "One or several objectives could be used" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "67605a81", + "metadata": {}, + "outputs": [], + "source": [ + "class Trainer(SamplingMultitaskTrainer):\n", + " def create_objectives(self, dataset, tokenizer_path):\n", + " self.objectives = nn.ModuleList()\n", + " \n", + " self.objectives.append(\n", + " self._create_subgraph_objective(\n", + " objective_name=\"VariableMisuseClf\",\n", + " objective_class=SubgraphClassifierObjective,\n", + " dataset=dataset,\n", + " tokenizer_path=tokenizer_path,\n", + " labels_fn=load_labels,\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "393251c2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%tensorboard --logdir data_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "369ed1bb", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0: 0%| | 0/110 [00:00 str:\n \"\"\"\n Call another method. 
\u0412\u044b\u0437\u043e\u0432 \u0434\u0440\u0443\u0433\u043e\u0433\u043e \u043c\u0435\u0442\u043e\u0434\u0430.\n :return:\n \"\"\"\n return self.method2()\n\n def method2(self) -> str:\n \"\"\"\n Simple operations.\n \u041f\u0440\u043e\u0441\u0442\u044b\u0435 \u043e\u043f\u0435\u0440\u0430\u0446\u0438\u0438.\n :return:\n \"\"\"\n variable1: int = self.field\n variable2: str = str(variable1)\n return variable2"} +{"id":18,"type":"Module","serialized_name":"Module_0x16d41315b631c9ef","mentioned_in":null,"string":null} +{"id":19,"type":"arg","serialized_name":"arg_0x16d41315b616478a","mentioned_in":16,"string":"self"} +{"id":20,"type":"arguments","serialized_name":"arguments_0x16d41315b6320346","mentioned_in":16,"string":null} +{"id":21,"type":"subword","serialized_name":"\u2581argument","mentioned_in":null,"string":null} +{"id":22,"type":"mention","serialized_name":"argument@FunctionDef_0x16d41315b61e0b59","mentioned_in":16,"string":null} +{"id":23,"type":"arg","serialized_name":"arg_0x16d41315b6854c98","mentioned_in":16,"string":"argument: int"} +{"id":26,"type":"Assign","serialized_name":"Assign_0x16d41315b65e0868","mentioned_in":16,"string":"self.field = argument"} +{"id":27,"type":"Attribute","serialized_name":"Attribute_0x16d41315b6d16a6d","mentioned_in":16,"string":"self.field"} +{"id":28,"type":"#attr#","serialized_name":"field","mentioned_in":null,"string":null} +{"id":30,"type":"subword","serialized_name":"\u2581__","mentioned_in":null,"string":null} +{"id":31,"type":"mention","serialized_name":"__init__@ClassDef_0x16d41315b60713c7","mentioned_in":17,"string":null} +{"id":32,"type":"subword","serialized_name":"init","mentioned_in":null,"string":null} +{"id":33,"type":"subword","serialized_name":"__","mentioned_in":null,"string":null} +{"id":34,"type":"mention","serialized_name":"self@FunctionDef_0x16d41315b69b5874","mentioned_in":35,"string":null} +{"id":35,"type":"FunctionDef","serialized_name":"FunctionDef_0x16d41315b69b5874","mentioned_in":17,"string":"def method1(self) -> str:\n \"\"\"\n Call another method. 
\u0412\u044b\u0437\u043e\u0432 \u0434\u0440\u0443\u0433\u043e\u0433\u043e \u043c\u0435\u0442\u043e\u0434\u0430.\n :return:\n \"\"\"\n return self.method2()"} +{"id":36,"type":"arg","serialized_name":"arg_0x16d41315b6881089","mentioned_in":35,"string":"self"} +{"id":37,"type":"arguments","serialized_name":"arguments_0x16d41315b64e4b84","mentioned_in":35,"string":null} +{"id":39,"type":"subword","serialized_name":"\u2581str","mentioned_in":null,"string":null} +{"id":40,"type":"Attribute","serialized_name":"Attribute_0x16d41315b6633cfb","mentioned_in":35,"string":"self.method2"} +{"id":41,"type":"#attr#","serialized_name":"method2","mentioned_in":null,"string":null} +{"id":42,"type":"subword","serialized_name":"\u2581method","mentioned_in":null,"string":null} +{"id":43,"type":"subword","serialized_name":"2","mentioned_in":null,"string":null} +{"id":44,"type":"Call","serialized_name":"Call_0x16d41315b6f8b441","mentioned_in":35,"string":"self.method2()"} +{"id":45,"type":"Return","serialized_name":"Return_0x16d41315b6887186","mentioned_in":35,"string":"return self.method2()"} +{"id":46,"type":"mention","serialized_name":"method1@ClassDef_0x16d41315b60713c7","mentioned_in":17,"string":null} +{"id":47,"type":"subword","serialized_name":"1","mentioned_in":null,"string":null} +{"id":48,"type":"mention","serialized_name":"self@FunctionDef_0x16d41315b6ad90da","mentioned_in":49,"string":null} +{"id":49,"type":"FunctionDef","serialized_name":"FunctionDef_0x16d41315b6ad90da","mentioned_in":17,"string":"def method2(self) -> str:\n \"\"\"\n Simple operations.\n \u041f\u0440\u043e\u0441\u0442\u044b\u0435 \u043e\u043f\u0435\u0440\u0430\u0446\u0438\u0438.\n :return:\n \"\"\"\n variable1: int = self.field\n variable2: str = str(variable1)\n return variable2"} +{"id":50,"type":"arg","serialized_name":"arg_0x16d41315b69c8916","mentioned_in":49,"string":"self"} +{"id":51,"type":"arguments","serialized_name":"arguments_0x16d41315b685c455","mentioned_in":49,"string":null} +{"id":52,"type":"subword","serialized_name":"\u2581variable","mentioned_in":null,"string":null} +{"id":53,"type":"mention","serialized_name":"variable1@FunctionDef_0x16d41315b6ad90da","mentioned_in":49,"string":null} +{"id":54,"type":"AnnAssign","serialized_name":"AnnAssign_0x16d41315b69586f8","mentioned_in":49,"string":"variable1: int = self.field"} +{"id":55,"type":"Attribute","serialized_name":"Attribute_0x16d41315b692f3a5","mentioned_in":49,"string":"self.field"} +{"id":56,"type":"mention","serialized_name":"variable2@FunctionDef_0x16d41315b6ad90da","mentioned_in":49,"string":null} +{"id":57,"type":"AnnAssign","serialized_name":"AnnAssign_0x16d41315b64fbc69","mentioned_in":49,"string":"variable2: str = str(variable1)"} +{"id":58,"type":"mention","serialized_name":"str@FunctionDef_0x16d41315b6ad90da","mentioned_in":49,"string":null} +{"id":59,"type":"Call","serialized_name":"Call_0x16d41315b6720a4f","mentioned_in":49,"string":"str(variable1)"} +{"id":60,"type":"Return","serialized_name":"Return_0x16d41315b639dcda","mentioned_in":49,"string":"return variable2"} +{"id":61,"type":"mention","serialized_name":"method2@ClassDef_0x16d41315b60713c7","mentioned_in":17,"string":null} +{"id":62,"type":"subword","serialized_name":"\u2581Example","mentioned_in":null,"string":null} +{"id":63,"type":"mention","serialized_name":"ExampleClass@Module_0x16d41315b631c9ef","mentioned_in":18,"string":null} +{"id":64,"type":"subword","serialized_name":"Class","mentioned_in":null,"string":null} 
+{"id":65,"type":"mention","serialized_name":"ExampleModule@Module_0x16d41315b7d4afab","mentioned_in":66,"string":null} +{"id":66,"type":"Module","serialized_name":"Module_0x16d41315b7d4afab","mentioned_in":null,"string":null} +{"id":67,"type":"subword","serialized_name":"Module","mentioned_in":null,"string":null} +{"id":68,"type":"ImportFrom","serialized_name":"ImportFrom_0x16d41315b71ee551","mentioned_in":66,"string":"from ExampleModule import ExampleClass as EC"} +{"id":69,"type":"mention","serialized_name":"ExampleClass@Module_0x16d41315b7d4afab","mentioned_in":66,"string":null} +{"id":70,"type":"alias","serialized_name":"alias_0x16d41315b78c6d18","mentioned_in":66,"string":null} +{"id":71,"type":"subword","serialized_name":"\u2581EC","mentioned_in":null,"string":null} +{"id":72,"type":"mention","serialized_name":"EC@Module_0x16d41315b7d4afab","mentioned_in":66,"string":null} +{"id":73,"type":"Call","serialized_name":"Call_0x16d41315b75ea742","mentioned_in":66,"string":"EC(5)"} +{"id":74,"type":"Constant","serialized_name":"Constant_","mentioned_in":null,"string":null} +{"id":75,"type":"Assign","serialized_name":"Assign_0x16d41315b72b1e6c","mentioned_in":66,"string":"instance = EC(5)"} +{"id":76,"type":"subword","serialized_name":"\u2581instance","mentioned_in":null,"string":null} +{"id":77,"type":"mention","serialized_name":"instance@Module_0x16d41315b7d4afab","mentioned_in":66,"string":null} +{"id":79,"type":"FunctionDef","serialized_name":"FunctionDef_0x16d41315b735b6ae","mentioned_in":66,"string":"def main() -> None:\n print(instance.method1())"} +{"id":81,"type":"subword","serialized_name":"\u2581print","mentioned_in":null,"string":null} +{"id":82,"type":"mention","serialized_name":"print@FunctionDef_0x16d41315b735b6ae","mentioned_in":79,"string":null} +{"id":83,"type":"Call","serialized_name":"Call_0x16d41315b795d9b1","mentioned_in":79,"string":"print(instance.method1())"} +{"id":84,"type":"mention","serialized_name":"instance@FunctionDef_0x16d41315b735b6ae","mentioned_in":79,"string":null} +{"id":85,"type":"Attribute","serialized_name":"Attribute_0x16d41315b7d4d09e","mentioned_in":79,"string":"instance.method1"} +{"id":86,"type":"#attr#","serialized_name":"method1","mentioned_in":null,"string":null} +{"id":87,"type":"Call","serialized_name":"Call_0x16d41315b7d93dbb","mentioned_in":79,"string":"instance.method1()"} +{"id":88,"type":"subword","serialized_name":"\u2581main","mentioned_in":null,"string":null} +{"id":89,"type":"mention","serialized_name":"main@Module_0x16d41315b7d4afab","mentioned_in":66,"string":null} +{"id":90,"type":"Call","serialized_name":"Call_0x16d41315b738fe7d","mentioned_in":66,"string":"main()"} +{"id":91,"type":"module","serialized_name":"Module","mentioned_in":null,"string":null} +{"id":92,"type":"class","serialized_name":"Module.Number","mentioned_in":null,"string":null} +{"id":93,"type":"class_method","serialized_name":"Module.Number.__init__","mentioned_in":null,"string":null} +{"id":94,"type":"class_field","serialized_name":"Module.Number.val","mentioned_in":null,"string":null} +{"id":95,"type":"function","serialized_name":"Module.Number.__add__","mentioned_in":null,"string":null} +{"id":96,"type":"function","serialized_name":"Module.Number.__repr__","mentioned_in":null,"string":null} +{"id":97,"type":"mention","serialized_name":"self@FunctionDef_0x16d41315c8deaa67","mentioned_in":98,"string":null} +{"id":98,"type":"FunctionDef","serialized_name":"FunctionDef_0x16d41315c8deaa67","mentioned_in":99,"string":"def __init__(self, value: int):\n \"\"\"\n 
Initialize. \u0418\u043d\u0438\u0446\u0438\u0430\u043b\u0438\u0437\u0430\u0446\u0438\u044f\n :param argument:\n \"\"\"\n self.val = value"} +{"id":99,"type":"ClassDef","serialized_name":"ClassDef_0x16d41315c87e23dc","mentioned_in":100,"string":"class Number:\n def __init__(self, value: int):\n \"\"\"\n Initialize. \u0418\u043d\u0438\u0446\u0438\u0430\u043b\u0438\u0437\u0430\u0446\u0438\u044f\n :param argument:\n \"\"\"\n self.val = value\n\n def __add__(self, value):\n \"\"\"\n Add two numbers.\n \u0421\u043b\u043e\u0436\u0438\u0442\u044c 2 \u0447\u0438\u0441\u043b\u0430\n :param value:\n :return:\n \"\"\"\n return Number(self.val + value.val)\n\n def __repr__(self) -> str:\n \"\"\"\n Return representation\n :return: \u041f\u043e\u043b\u0443\u0447\u0438\u0442\u044c \u043f\u0440\u0435\u0434\u0441\u0442\u0430\u0432\u043b\u0435\u043d\u0438\u0435\n \"\"\"\n return f\"Number({self.val})\""} +{"id":100,"type":"Module","serialized_name":"Module_0x16d41315c8795210","mentioned_in":null,"string":null} +{"id":101,"type":"arg","serialized_name":"arg_0x16d41315c8c0c85f","mentioned_in":98,"string":"self"} +{"id":102,"type":"arguments","serialized_name":"arguments_0x16d41315c86b314b","mentioned_in":98,"string":null} +{"id":103,"type":"subword","serialized_name":"\u2581value","mentioned_in":null,"string":null} +{"id":104,"type":"mention","serialized_name":"value@FunctionDef_0x16d41315c8deaa67","mentioned_in":98,"string":null} +{"id":105,"type":"arg","serialized_name":"arg_0x16d41315c829b52d","mentioned_in":98,"string":"value: int"} +{"id":106,"type":"Assign","serialized_name":"Assign_0x16d41315c8f71d5b","mentioned_in":98,"string":"self.val = value"} +{"id":107,"type":"Attribute","serialized_name":"Attribute_0x16d41315c898a89c","mentioned_in":98,"string":"self.val"} +{"id":108,"type":"#attr#","serialized_name":"val","mentioned_in":null,"string":null} +{"id":110,"type":"mention","serialized_name":"__init__@ClassDef_0x16d41315c87e23dc","mentioned_in":99,"string":null} +{"id":111,"type":"mention","serialized_name":"self@FunctionDef_0x16d41315c8336015","mentioned_in":112,"string":null} +{"id":112,"type":"FunctionDef","serialized_name":"FunctionDef_0x16d41315c8336015","mentioned_in":99,"string":"def __add__(self, value):\n \"\"\"\n Add two numbers.\n \u0421\u043b\u043e\u0436\u0438\u0442\u044c 2 \u0447\u0438\u0441\u043b\u0430\n :param value:\n :return:\n \"\"\"\n return Number(self.val + value.val)"} +{"id":113,"type":"arg","serialized_name":"arg_0x16d41315c836c8e3","mentioned_in":112,"string":"self"} +{"id":114,"type":"arguments","serialized_name":"arguments_0x16d41315c89f1854","mentioned_in":112,"string":null} +{"id":115,"type":"mention","serialized_name":"value@FunctionDef_0x16d41315c8336015","mentioned_in":112,"string":null} +{"id":116,"type":"arg","serialized_name":"arg_0x16d41315c8b89fd0","mentioned_in":112,"string":"value"} +{"id":117,"type":"subword","serialized_name":"\u2581Number","mentioned_in":null,"string":null} +{"id":118,"type":"mention","serialized_name":"Number@FunctionDef_0x16d41315c8336015","mentioned_in":112,"string":null} +{"id":119,"type":"Call","serialized_name":"Call_0x16d41315c81ad247","mentioned_in":112,"string":"Number(self.val + value.val)"} +{"id":120,"type":"Attribute","serialized_name":"Attribute_0x16d41315c8e05daa","mentioned_in":112,"string":"self.val"} +{"id":121,"type":"BinOp","serialized_name":"BinOp_0x16d41315c81b0e8d","mentioned_in":112,"string":"self.val + value.val"} 
+{"id":122,"type":"Attribute","serialized_name":"Attribute_0x16d41315c83902eb","mentioned_in":112,"string":"value.val"} +{"id":123,"type":"Op","serialized_name":"Add","mentioned_in":null,"string":null} +{"id":124,"type":"Return","serialized_name":"Return_0x16d41315c814af65","mentioned_in":112,"string":"return Number(self.val + value.val)"} +{"id":125,"type":"mention","serialized_name":"__add__@ClassDef_0x16d41315c87e23dc","mentioned_in":99,"string":null} +{"id":126,"type":"subword","serialized_name":"add","mentioned_in":null,"string":null} +{"id":127,"type":"mention","serialized_name":"self@FunctionDef_0x16d41315c81c9bdf","mentioned_in":128,"string":null} +{"id":128,"type":"FunctionDef","serialized_name":"FunctionDef_0x16d41315c81c9bdf","mentioned_in":99,"string":"def __repr__(self) -> str:\n \"\"\"\n Return representation\n :return: \u041f\u043e\u043b\u0443\u0447\u0438\u0442\u044c \u043f\u0440\u0435\u0434\u0441\u0442\u0430\u0432\u043b\u0435\u043d\u0438\u0435\n \"\"\"\n return f\"Number({self.val})\""} +{"id":129,"type":"arg","serialized_name":"arg_0x16d41315c8766105","mentioned_in":128,"string":"self"} +{"id":130,"type":"arguments","serialized_name":"arguments_0x16d41315c8044c7c","mentioned_in":128,"string":null} +{"id":131,"type":"JoinedStr","serialized_name":"JoinedStr_","mentioned_in":null,"string":null} +{"id":132,"type":"Return","serialized_name":"Return_0x16d41315c8a05fc6","mentioned_in":128,"string":"return f\"Number({self.val})\""} +{"id":133,"type":"mention","serialized_name":"__repr__@ClassDef_0x16d41315c87e23dc","mentioned_in":99,"string":null} +{"id":134,"type":"subword","serialized_name":"repr","mentioned_in":null,"string":null} +{"id":135,"type":"mention","serialized_name":"Number@Module_0x16d41315c8795210","mentioned_in":100,"string":null} +{"id":136,"type":"subword","serialized_name":"\u2581Module","mentioned_in":null,"string":null} +{"id":137,"type":"mention","serialized_name":"Module@Module_0x16d41315c9361936","mentioned_in":138,"string":null} +{"id":138,"type":"Module","serialized_name":"Module_0x16d41315c9361936","mentioned_in":null,"string":null} +{"id":139,"type":"ImportFrom","serialized_name":"ImportFrom_0x16d41315c996d4c8","mentioned_in":138,"string":"from Module import Number"} +{"id":140,"type":"mention","serialized_name":"Number@Module_0x16d41315c9361936","mentioned_in":138,"string":null} +{"id":141,"type":"alias","serialized_name":"alias_0x16d41315c9a652fd","mentioned_in":138,"string":null} +{"id":142,"type":"mention","serialized_name":"Number@FunctionDef_0x16d41315c9c41f53","mentioned_in":143,"string":null} +{"id":143,"type":"FunctionDef","serialized_name":"FunctionDef_0x16d41315c9c41f53","mentioned_in":138,"string":"def main():\n a = Number(4)\n b = Number(5)\n print(a+b)"} +{"id":144,"type":"Call","serialized_name":"Call_0x16d41315c99e75f7","mentioned_in":143,"string":"Number(4)"} +{"id":145,"type":"Assign","serialized_name":"Assign_0x16d41315c909ace6","mentioned_in":143,"string":"a = Number(4)"} +{"id":146,"type":"subword","serialized_name":"\u2581a","mentioned_in":null,"string":null} +{"id":147,"type":"mention","serialized_name":"a@FunctionDef_0x16d41315c9c41f53","mentioned_in":143,"string":null} +{"id":148,"type":"Call","serialized_name":"Call_0x16d41315c9bee524","mentioned_in":143,"string":"Number(5)"} +{"id":149,"type":"Assign","serialized_name":"Assign_0x16d41315c9c3c775","mentioned_in":143,"string":"b = Number(5)"} +{"id":150,"type":"subword","serialized_name":"\u2581b","mentioned_in":null,"string":null} 
+{"id":151,"type":"mention","serialized_name":"b@FunctionDef_0x16d41315c9c41f53","mentioned_in":143,"string":null} +{"id":152,"type":"mention","serialized_name":"print@FunctionDef_0x16d41315c9c41f53","mentioned_in":143,"string":null} +{"id":153,"type":"Call","serialized_name":"Call_0x16d41315c9148a75","mentioned_in":143,"string":"print(a+b)"} +{"id":154,"type":"BinOp","serialized_name":"BinOp_0x16d41315c919b5a9","mentioned_in":143,"string":"a+b"} +{"id":155,"type":"mention","serialized_name":"main@Module_0x16d41315c9361936","mentioned_in":138,"string":null} +{"id":156,"type":"Call","serialized_name":"Call_0x16d41315c99405f0","mentioned_in":138,"string":"main()"} diff --git a/examples/small_graph/type_annotations.json b/examples/small_graph/type_annotations.json new file mode 100644 index 00000000..72c533f6 --- /dev/null +++ b/examples/small_graph/type_annotations.json @@ -0,0 +1,16 @@ +{"src":22,"dst":"int"} +{"src":24,"dst":"argument@FunctionDef_0x16d41315b61e0b59"} +{"src":35,"dst":"str"} +{"src":38,"dst":"FunctionDef_0x16d41315b69b5874"} +{"src":49,"dst":"str"} +{"src":38,"dst":"FunctionDef_0x16d41315b6ad90da"} +{"src":53,"dst":"int"} +{"src":24,"dst":"variable1@FunctionDef_0x16d41315b6ad90da"} +{"src":56,"dst":"str"} +{"src":38,"dst":"variable2@FunctionDef_0x16d41315b6ad90da"} +{"src":79,"dst":"None"} +{"src":78,"dst":"FunctionDef_0x16d41315b735b6ae"} +{"src":104,"dst":"int"} +{"src":24,"dst":"value@FunctionDef_0x16d41315c8deaa67"} +{"src":128,"dst":"str"} +{"src":38,"dst":"FunctionDef_0x16d41315c81c9bdf"} diff --git a/examples/var_misuse_tiny.yaml b/examples/var_misuse_tiny.yaml new file mode 100644 index 00000000..a3fab0f0 --- /dev/null +++ b/examples/var_misuse_tiny.yaml @@ -0,0 +1,55 @@ +DATASET: + custom_reverse: null + data_path: cubert_varmisuse_tiny + filter_edges: null + min_count_for_objectives: 5 + no_global_edges: false + random_seed: null + remove_reverse: false + restricted_id_pool: null + self_loops: false + subgraph_id_column: file_id + subgraph_partition: cubert_varmisuse_tiny/partition.json.bz2 + train_frac: 0.9 + use_edge_types: false + use_node_types: false +MODEL: + activation: tanh + dropout: 0.0 + h_dim: 100 + n_layers: 5 + node_emb_size: 100 + num_bases: 10 + use_att_checkpoint: false + use_gcn_checkpoint: false + use_gru_checkpoint: false + use_self_loop: true +TOKENIZER: + tokenizer_path: sentencepiece_bpe.model +TRAINING: + batch_size: 128 + dilate_scores: 200 + early_stopping: false + early_stopping_tolerance: 20 + elem_emb_size: 100 + embedding_table_size: 200000 + epochs: 100 + external_dataset: null + force_w2v_ns: false + gpu: -1 + learning_rate: 0.001 + measure_scores: false + metric: inner_prod + model_output_dir: cubert_varmisuse_tiny + neg_sampling_factor: 3 + nn_index: brute + objectives: subgraph_clf + pretrained: null + pretraining_phase: 0 + restore_state: false + sampling_neighbourhood_size: 10 + save_checkpoints: true + save_each_epoch: false + schedule_layers_every: 10 + use_layer_scheduling: false + use_ns_groups: false diff --git a/figures/graph_examples/100.png b/figures/graph_examples/100.png new file mode 100644 index 00000000..86be27a8 Binary files /dev/null and b/figures/graph_examples/100.png differ diff --git a/figures/graph_examples/100_.png b/figures/graph_examples/100_.png new file mode 100644 index 00000000..ecb56188 Binary files /dev/null and b/figures/graph_examples/100_.png differ diff --git a/figures/graph_examples/112.png b/figures/graph_examples/112.png new file mode 100644 index 00000000..6966861e Binary files /dev/null and 
b/figures/graph_examples/112.png differ diff --git a/figures/graph_examples/118.png b/figures/graph_examples/118.png new file mode 100644 index 00000000..78b6ed01 Binary files /dev/null and b/figures/graph_examples/118.png differ diff --git a/figures/graph_examples/125.png b/figures/graph_examples/125.png new file mode 100644 index 00000000..afd49762 Binary files /dev/null and b/figures/graph_examples/125.png differ diff --git a/figures/graph_examples/139_.png b/figures/graph_examples/139_.png new file mode 100644 index 00000000..6a4ca38f Binary files /dev/null and b/figures/graph_examples/139_.png differ diff --git a/figures/graph_examples/140.png b/figures/graph_examples/140.png new file mode 100644 index 00000000..db2cbc0a Binary files /dev/null and b/figures/graph_examples/140.png differ diff --git a/figures/graph_examples/141.png b/figures/graph_examples/141.png new file mode 100644 index 00000000..4faf97c7 Binary files /dev/null and b/figures/graph_examples/141.png differ diff --git a/figures/graph_examples/144_.png b/figures/graph_examples/144_.png new file mode 100644 index 00000000..18c76ccf Binary files /dev/null and b/figures/graph_examples/144_.png differ diff --git a/figures/graph_examples/149.png b/figures/graph_examples/149.png new file mode 100644 index 00000000..e1e59e34 Binary files /dev/null and b/figures/graph_examples/149.png differ diff --git a/figures/graph_examples/16_.png b/figures/graph_examples/16_.png new file mode 100644 index 00000000..f03a2674 Binary files /dev/null and b/figures/graph_examples/16_.png differ diff --git a/figures/graph_examples/17_.png b/figures/graph_examples/17_.png new file mode 100644 index 00000000..9398955f Binary files /dev/null and b/figures/graph_examples/17_.png differ diff --git a/figures/graph_examples/35_.png b/figures/graph_examples/35_.png new file mode 100644 index 00000000..c814b7a0 Binary files /dev/null and b/figures/graph_examples/35_.png differ diff --git a/figures/graph_examples/40.png b/figures/graph_examples/40.png new file mode 100644 index 00000000..eeea76d5 Binary files /dev/null and b/figures/graph_examples/40.png differ diff --git a/figures/graph_examples/49_.png b/figures/graph_examples/49_.png new file mode 100644 index 00000000..94cb9370 Binary files /dev/null and b/figures/graph_examples/49_.png differ diff --git a/figures/graph_examples/54.png b/figures/graph_examples/54.png new file mode 100644 index 00000000..887f0f06 Binary files /dev/null and b/figures/graph_examples/54.png differ diff --git a/figures/graph_examples/60.png b/figures/graph_examples/60.png new file mode 100644 index 00000000..9ea01603 Binary files /dev/null and b/figures/graph_examples/60.png differ diff --git a/figures/graph_examples/61.png b/figures/graph_examples/61.png new file mode 100644 index 00000000..b2899e0c Binary files /dev/null and b/figures/graph_examples/61.png differ diff --git a/figures/graph_examples/67_.png b/figures/graph_examples/67_.png new file mode 100644 index 00000000..96b62864 Binary files /dev/null and b/figures/graph_examples/67_.png differ diff --git a/figures/graph_examples/69.png b/figures/graph_examples/69.png new file mode 100644 index 00000000..8919ed13 Binary files /dev/null and b/figures/graph_examples/69.png differ diff --git a/figures/graph_examples/78.png b/figures/graph_examples/78.png new file mode 100644 index 00000000..b0294e70 Binary files /dev/null and b/figures/graph_examples/78.png differ diff --git a/figures/graph_examples/80_.png b/figures/graph_examples/80_.png new file mode 100644 index 
00000000..aefe5039 Binary files /dev/null and b/figures/graph_examples/80_.png differ diff --git a/figures/graph_examples/no_global_index.png b/figures/graph_examples/no_global_index.png new file mode 100644 index 00000000..983eb59d Binary files /dev/null and b/figures/graph_examples/no_global_index.png differ diff --git a/res/python_testdata/example_code/example/ExampleModule.py b/res/python_testdata/example_code/example/ExampleModule.py new file mode 100644 index 00000000..44700482 --- /dev/null +++ b/res/python_testdata/example_code/example/ExampleModule.py @@ -0,0 +1,24 @@ +class ExampleClass: + def __init__(self, argument: int): + """ + Initialize. Инициализация + :param argument: + """ + self.field = argument + + def method1(self) -> str: + """ + Call another method. Вызов другого метода. + :return: + """ + return self.method2() + + def method2(self) -> str: + """ + Simple operations. + Простые операции. + :return: + """ + variable1: int = self.field + variable2: str = str(variable1) + return variable2 diff --git a/res/python_testdata/example_code/example/main.py b/res/python_testdata/example_code/example/main.py new file mode 100644 index 00000000..30caa43d --- /dev/null +++ b/res/python_testdata/example_code/example/main.py @@ -0,0 +1,9 @@ +from ExampleModule import ExampleClass as EC + +instance = EC(5) + +def main() -> None: + print(instance.method1()) + +main() + diff --git a/res/python_testdata/example_code/example2/Module.py b/res/python_testdata/example_code/example2/Module.py new file mode 100644 index 00000000..6c949d86 --- /dev/null +++ b/res/python_testdata/example_code/example2/Module.py @@ -0,0 +1,24 @@ +class Number: + def __init__(self, value: int): + """ + Initialize. Инициализация + :param argument: + """ + self.val = value + + def __add__(self, value): + """ + Add two numbers. + Сложить 2 числа + :param value: + :return: + """ + return Number(self.val + value.val) + + def __repr__(self) -> str: + """ + Return representation + :return: Получить представление + """ + return f"Number({self.val})" + diff --git a/res/python_testdata/example_code/example2/main.py b/res/python_testdata/example_code/example2/main.py new file mode 100644 index 00000000..77b3999d --- /dev/null +++ b/res/python_testdata/example_code/example2/main.py @@ -0,0 +1,9 @@ +from Module import Number + +def main(): + a = Number(4) + b = Number(5) + print(a+b) + +main() + diff --git a/scripts/data_collection/python/process_authors.sh b/scripts/data_collection/python/process_authors.sh new file mode 100644 index 00000000..b7180997 --- /dev/null +++ b/scripts/data_collection/python/process_authors.sh @@ -0,0 +1,51 @@ +# this script processes arbitrary code with sourcetrails, no dependencies are pulled + +conda activate python37 + +while read repo +do + + if [ ! -f "$repo/$repo.srctrlprj" ]; then + echo "Creating Sourcetrail project for $repo" + echo " + + + + Python Source Group + . + + .py + + + . + + enabled + Python Source Group + + + 8 +" > $repo/$repo.srctrlprj + fi + + if [ ! -f "$repo/sourcetrail.log" ]; then + run_indexing=true + else + find_edges=$(cat "$repo/sourcetrail.log" | grep " Edges") + if [ -z "$find_edges" ]; then + echo "Indexing was interrupted, recovering..." 
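+      # no " Edges" summary line in sourcetrail.log means the previous indexing run never finished, so index this repo again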
+ run_indexing=true + else + run_indexing=false + fi + fi + + if $run_indexing; then + echo "Begin indexing" + Sourcetrail.sh index -i $repo/$repo.srctrlprj >> $repo/sourcetrail.log + else + echo "Already indexed" + fi + +done < "${1:-/dev/stdin}" + +conda deactivate diff --git a/scripts/data_collection/python/process_folders.sh b/scripts/data_collection/python/process_folders.sh new file mode 100644 index 00000000..09f42a41 --- /dev/null +++ b/scripts/data_collection/python/process_folders.sh @@ -0,0 +1,59 @@ +# this script processes arbitrary code with sourcetrails, no dependencies are pulled + +conda activate SourceCodeTools + +create_sourcetrail_project_if_not_exist () { + if [ ! -f $1 ]; then + echo "Creating Sourcetrail project for $repo" + echo " + + + + Python Source Group + . + + .py + + + . + + enabled + Python Source Group + + + 8 +" > $1 + fi +} + + +run_indexer () { + repo=$1 + if [ ! -f "$repo/sourcetrail.log" ]; then + run_indexing=true + else + find_edges=$(cat "$repo/sourcetrail.log" | grep " Edges") + if [ -z "$find_edges" ]; then + echo "Indexing was interrupted, recovering..." + run_indexing=true + else + run_indexing=false + fi + fi + + if $run_indexing; then + echo "Begin indexing" + Sourcetrail.sh index -i $repo/$repo.srctrlprj >> $repo/sourcetrail.log + else + echo "Already indexed" + fi +} + + +while read repo +do + create_sourcetrail_project_if_not_exist "$repo/$repo.srctrlprj" + run_indexer "$repo" +done < "${1:-/dev/stdin}" + +conda deactivate diff --git a/scripts/data_collection/requirements.txt b/scripts/data_collection/requirements.txt index 9f0f1d16..2af0f6c9 100644 --- a/scripts/data_collection/requirements.txt +++ b/scripts/data_collection/requirements.txt @@ -8,7 +8,7 @@ mkl-fft==1.0.15 mkl-random==1.1.1 mkl-service==2.3.0 msgpack-python==0.5.6 -nltk==3.5 +nltk==3.6.6 numba==0.50.1 numpy==1.18.1 pandas==1.0.3 diff --git a/scripts/data_extraction/process_sourcetrail.sh b/scripts/data_extraction/process_sourcetrail.sh index 52ff257c..4dfbc28e 100644 --- a/scripts/data_extraction/process_sourcetrail.sh +++ b/scripts/data_extraction/process_sourcetrail.sh @@ -18,21 +18,21 @@ for dir in "$ENVS_DIR"/*; do echo "Found package $package_name" - if [ -f "$dir/source_graph_bodies.csv" ]; then - rm "$dir/source_graph_bodies.csv" - fi - if [ -f "$dir/nodes_with_ast.csv" ]; then - rm "$dir/nodes_with_ast.csv" - fi - if [ -f "$dir/edges_with_ast.csv" ]; then - rm "$dir/edges_with_ast.csv" - fi - if [ -f "$dir/call_seq.csv" ]; then - rm "$dir/call_seq.csv" - fi - if [ -f "$dir/source_graph_function_variable_pairs.csv" ]; then - rm "$dir/source_graph_function_variable_pairs.csv" - fi +# if [ -f "$dir/source_graph_bodies.csv" ]; then +# rm "$dir/source_graph_bodies.csv" +# fi +# if [ -f "$dir/nodes_with_ast.csv" ]; then +# rm "$dir/nodes_with_ast.csv" +# fi +# if [ -f "$dir/edges_with_ast.csv" ]; then +# rm "$dir/edges_with_ast.csv" +# fi +# if [ -f "$dir/call_seq.csv" ]; then +# rm "$dir/call_seq.csv" +# fi +# if [ -f "$dir/source_graph_function_variable_pairs.csv" ]; then +# rm "$dir/source_graph_function_variable_pairs.csv" +# fi if [ -f "$dir/$package_name.srctrldb" ]; then @@ -40,25 +40,25 @@ for dir in "$ENVS_DIR"/*; do sqlite3 "$dir/$package_name.srctrldb" < "$SQL_Q" cd "$RUN_DIR" sourcetrail_verify_files.py "$dir" - sourcetrail_node_name_merge.py "$dir/nodes.csv" - sourcetrail_decode_edge_types.py "$dir/edges.csv" - sourcetrail_filter_ambiguous_edges.py $dir - sourcetrail_parse_bodies.py "$dir" - sourcetrail_call_seq_extractor.py "$dir" +# 
sourcetrail_node_name_merge.py "$dir/nodes.csv" +# sourcetrail_decode_edge_types.py "$dir/edges.csv" +# sourcetrail_filter_ambiguous_edges.py $dir +# sourcetrail_parse_bodies.py "$dir" +# sourcetrail_call_seq_extractor.py "$dir" - sourcetrail_add_reverse_edges.py "$dir/edges.bz2" - if [ -n "$BPE_PATH" ]; then - BPE_PATH=$(realpath "$BPE_PATH") - sourcetrail_ast_edges.py "$dir" -bpe $BPE_PATH --create_subword_instances - else - sourcetrail_ast_edges.py "$dir" - fi - sourcetrail_extract_variable_names.py python "$dir" +# sourcetrail_add_reverse_edges.py "$dir/edges.bz2" +# if [ -n "$BPE_PATH" ]; then +# BPE_PATH=$(realpath "$BPE_PATH") +# sourcetrail_ast_edges.py "$dir" -bpe $BPE_PATH --create_subword_instances +# else +# sourcetrail_ast_edges.py "$dir" +# fi +# sourcetrail_extract_variable_names.py python "$dir" - if [ -f "$dir"/edges_with_ast_temp.csv ]; then - rm "$dir"/edges_with_ast_temp.csv - fi +# if [ -f "$dir"/edges_with_ast_temp.csv ]; then +# rm "$dir"/edges_with_ast_temp.csv +# fi else echo "Package not indexed" diff --git a/scripts/data_extraction/process_sourcetrail_v2.sh b/scripts/data_extraction/process_sourcetrail_v2.sh new file mode 100644 index 00000000..a9679b14 --- /dev/null +++ b/scripts/data_extraction/process_sourcetrail_v2.sh @@ -0,0 +1,24 @@ +conda activate SourceCodeTools + +ENVS_DIR=$(realpath "$1") +RUN_DIR=$(realpath "$(dirname "$0")") +SQL_Q=$(realpath "$RUN_DIR/extract.sql") + +for dir in "$ENVS_DIR"/*; do + if [ -d "$dir" ]; then + package_name="$(basename "$dir")" + + echo "Found package $package_name" + + if [ -f "$dir/$package_name.srctrldb" ]; then + cd "$dir" + sqlite3 "$dir/$package_name.srctrldb" < "$SQL_Q" + cd "$RUN_DIR" + sourcetrail_verify_files.py "$dir" + else + echo "Package not indexed" + fi + fi +done + +conda deactivate \ No newline at end of file diff --git a/scripts/training/dglke_log_parser.py b/scripts/training/dglke_log_parser.py new file mode 100644 index 00000000..01a7bee4 --- /dev/null +++ b/scripts/training/dglke_log_parser.py @@ -0,0 +1,73 @@ +import matplotlib.pyplot as plt +import numpy as np + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("logfile") + + args = parser.parse_args() + + with open(args.logfile, "r") as logfile: + steps = [[]] + buffer_pos = [[]] + buffer_neg = [[]] + model_names = [] + scores = [{}] + for line in logfile: + if "pos_loss" in line: + buffer_pos[-1].append(float(line.split("pos_loss: ")[-1])) + steps[-1].append(float(line.split("(")[-1].split("/")[0])) + elif "neg_loss" in line: + buffer_neg[-1].append(float(line.split("neg_loss: ")[-1])) + elif "Save model" in line: + model_names.append(line.split("/")[-2]) + scores[-1]["Model"] = model_names[-1] + elif "Test average MRR" in line: + scores[-1]["MRR"] = line.split(":")[-1].strip() + elif "Test average MR" in line: + scores[-1]["MR"] = line.split(":")[-1].strip() + elif "Test average HITS@1:" in line: + scores[-1]["HITS@1"] = line.split(":")[-1].strip() + elif "Test average HITS@3:" in line: + scores[-1]["HITS@3"] = line.split(":")[-1].strip() + elif "Test average HITS@10:" in line: + scores[-1]["HITS@10"] = line.split(":")[-1].strip() + buffer_pos.append([]) + buffer_neg.append([]) + steps.append([]) + scores.append({}) + + for s, v in zip(steps, buffer_pos): + plt.plot(np.log10(s), np.log10(v)) + plt.xlabel("Step") + plt.ylabel("Loss") + plt.legend(model_names) + plt.savefig("positive_loss.png") + plt.close() + + for s, v in zip(steps, buffer_neg): + plt.plot(np.log10(s), np.log10(v)) + plt.xlabel("Step") + 
plt.ylabel("Loss") + plt.legend(model_names) + plt.savefig("negative_loss.png") + plt.close() + + for s, v, n in zip(steps, buffer_pos, buffer_neg): + plt.plot(np.log10(s), np.log10(np.array(v) + np.array(n))) + plt.xlabel("Step") + plt.ylabel("Loss") + plt.legend(model_names) + plt.savefig("overall_loss.png") + plt.close() + + import pandas as pd + s = pd.DataFrame.from_records(scores) + print(s.to_string) + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/training/evaluate.py b/scripts/training/evaluate.py index 69e23432..a5425523 100644 --- a/scripts/training/evaluate.py +++ b/scripts/training/evaluate.py @@ -1,20 +1,44 @@ import json import logging import os +import shutil from datetime import datetime from os import mkdir -from os.path import isdir, join +from os.path import isdir -from SourceCodeTools.code.data.sourcetrail.Dataset import read_or_create_dataset -from SourceCodeTools.models.graph import RGCNSampling, RGAN, RGGAN +from SourceCodeTools.code.data.dataset.Dataset import read_or_create_gnn_dataset +from SourceCodeTools.models.graph import RGGAN from SourceCodeTools.models.graph.train.utils import get_name, get_model_base -from params import rgcnsampling_params, rggan_params +from SourceCodeTools.models.training_options import add_gnn_train_args + + +def detect_checkpoint_files(path): + checkpoints = [] + for file in os.listdir(path): + filepath = os.path.join(path, file) + if not os.path.isfile: + continue + + if not file.startswith("saved_state_"): + continue + + epoch = int(file.split("_")[2].split(".")[0]) + checkpoints.append((epoch, filepath)) + + checckpoints = sorted(checkpoints, key=lambda x: x[0]) + + return checckpoints def main(models, args): for model, param_grid in models.items(): for params in param_grid: + if args.h_dim is None: + params["h_dim"] = args.node_emb_size + else: + params["h_dim"] = args.h_dim + date_time = str(datetime.now()) print("\n\n") print(date_time) @@ -24,29 +48,26 @@ def main(models, args): model_base = get_model_base(args, model_attempt) - dataset = read_or_create_dataset(args=args, model_base=model_base) + dataset = read_or_create_gnn_dataset(args=args, model_base=model_base) + + from SourceCodeTools.models.graph.train.sampling_multitask2 import evaluation_procedure - if args.training_mode == "multitask": + checkpoints = detect_checkpoint_files(model_base) - if args.intermediate_supervision: - # params['use_self_loop'] = True # ???? - from SourceCodeTools.models.graph.train.sampling_multitask_intermediate_supervision import evaluation_procedure - else: - from SourceCodeTools.models.graph.train.sampling_multitask2 import evaluation_procedure + for epoch, ckpt_path in checkpoints: + shutil.copy(ckpt_path, os.path.join(model_base, "saved_state.pt")) evaluation_procedure(dataset, model, params, args, model_base) - else: - raise ValueError("Issue! 
", args.training_mode) if __name__ == "__main__": import argparse - from train import add_train_args + # from train import add_train_args parser = argparse.ArgumentParser(description='Process some integers.') - add_train_args(parser) + add_gnn_train_args(parser) args = parser.parse_args() @@ -54,11 +75,17 @@ def main(models, args): model_out = args.model_output_dir - saved_args = json.loads(open(os.path.join(model_out, "metadata.json")).read()) - - models_ = { - eval(saved_args.pop("name").split("-")[0]): [saved_args.pop("parameters")] - } + try: + saved_args = json.loads(open(os.path.join(model_out, "metadata.json")).read()) + models_ = { + eval(saved_args.pop("name").split("-")[0]): [saved_args.pop("parameters")] + } + except: + param_keys = 'activation', 'use_self_loop', 'num_steps', 'dropout', 'num_bases', 'lr' + saved_args = json.loads(open(os.path.join(model_out, "params.json")).read()) + models_ = { + RGGAN: [{key: saved_args[key] for key in param_keys}] + } args.__dict__.update(saved_args) args.restore_state = True diff --git a/scripts/training/evaluate_only.py b/scripts/training/evaluate_only.py new file mode 100644 index 00000000..1a122480 --- /dev/null +++ b/scripts/training/evaluate_only.py @@ -0,0 +1,105 @@ +import json +import logging +from copy import copy +from datetime import datetime +from os import mkdir +from os.path import isdir, join + +from SourceCodeTools.code.data.dataset.Dataset import read_or_create_gnn_dataset +from SourceCodeTools.models.graph import RGGAN +from SourceCodeTools.models.graph.train.utils import get_name, get_model_base +from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments +from params import rggan_params + + +def main(models, args): + for model, param_grid in models.items(): + for params in param_grid: + + if args.h_dim is None: + params["h_dim"] = args.node_emb_size + else: + params["h_dim"] = args.h_dim + + params["num_steps"] = args.n_layers + + date_time = str(datetime.now()) + print("\n\n") + print(date_time) + print(f"Model: {model.__name__}, Params: {params}") + + model_attempt = get_name(model, date_time) + + model_base = get_model_base(args, model_attempt) + + dataset = read_or_create_gnn_dataset(args=args, model_base=model_base) + + def write_params(args, params): + args = copy(args.__dict__) + args.update(params) + args['activation'] = args['activation'].__name__ + with open(join(model_base, "params.json"), "w") as mdata: + mdata.write(json.dumps(args, indent=4)) + + if not args.restore_state: + write_params(args, params) + + from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure + + trainer, scores = \ + training_procedure(dataset, model, copy(params), args, model_base) + + trainer.save_checkpoint(model_base) + + print("Saving...") + + params['activation'] = params['activation'].__name__ + + metadata = { + "base": model_base, + "name": model_attempt, + "parameters": params, + "layers": "embeddings.pkl", + "mappings": "nodes.csv", + "state": "state_dict.pt", + "scores": scores, + "time": date_time, + } + + metadata.update(args.__dict__) + + # pickle.dump(dataset, open(join(model_base, "dataset.pkl"), "wb")) + import pickle + pickle.dump(trainer.get_embeddings(), open(join(model_base, metadata['layers']), "wb")) + + with open(join(model_base, "metadata.json"), "w") as mdata: + mdata.write(json.dumps(metadata, indent=4)) + + print("Done saving") + + +if __name__ == "__main__": + + import argparse + + parser = argparse.ArgumentParser(description='') + 
add_gnn_train_args(parser) + + args = parser.parse_args() + verify_arguments(args) + + logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s") + + models_ = { + # GCNSampling: gcnsampling_params, + # GATSampler: gatsampling_params, + # RGCNSampling: rgcnsampling_params, + # RGAN: rgcnsampling_params, + RGGAN: rggan_params + + } + + if not isdir(args.model_output_dir): + mkdir(args.model_output_dir) + + main(models_, args) diff --git a/scripts/training/params.py b/scripts/training/params.py index 85444fce..0a5a350c 100644 --- a/scripts/training/params.py +++ b/scripts/training/params.py @@ -68,12 +68,12 @@ rggan_grids = [ { - 'h_dim': [100], - 'num_bases': [-1], - 'num_steps': [5], - 'dropout': [0.0], - 'use_self_loop': [False], - 'activation': [torch.nn.functional.leaky_relu], # torch.nn.functional.hardswish], #[torch.nn.functional.hardtanh], #torch.nn.functional.leaky_relu + 'h_dim': [100], # set from cli + 'num_bases': [10], + 'num_steps': [9], # set from cli + 'dropout': [0.2], + 'use_self_loop': [True], + 'activation': [torch.tanh], # torch.nn.functional.hardswish], #[torch.nn.functional.hardtanh], #torch.nn.functional.leaky_relu 'lr': [1e-3], # 1e-4] } ] diff --git a/scripts/training/substitute_model.py b/scripts/training/substitute_model.py new file mode 100644 index 00000000..001d8898 --- /dev/null +++ b/scripts/training/substitute_model.py @@ -0,0 +1,122 @@ +import json +import logging +from copy import copy +from datetime import datetime +from os import mkdir +from os.path import isdir, join + +from SourceCodeTools.code.data.dataset.Dataset import read_or_create_gnn_dataset +from SourceCodeTools.models.graph import RGGAN +from SourceCodeTools.models.graph.train.utils import get_name, get_model_base +from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments +from params import rggan_params + + +def main(models, args): + for model, param_grid in models.items(): + for params in param_grid: + + if args.h_dim is None: + params["h_dim"] = args.node_emb_size + else: + params["h_dim"] = args.h_dim + + params["num_steps"] = args.n_layers + + date_time = str(datetime.now()) + print("\n\n") + print(date_time) + print(f"Model: {model.__name__}, Params: {params}") + + model_attempt = get_name(model, date_time) + + model_base = get_model_base(args, model_attempt) + + dataset = read_or_create_gnn_dataset(args=args, model_base=model_base) + + if args.external_dataset is not None: + external_args = copy(args) + external_args.data_path = external_args.external_dataset + external_args.external_model_base = get_model_base(external_args, model_attempt, force_new=True) + def load_external_dataset(): + return external_args, read_or_create_gnn_dataset(args=external_args, + model_base=external_args.external_model_base, + force_new=True) + else: + load_external_dataset = None + + def write_params(args, params): + args = copy(args.__dict__) + args.update(params) + args['activation'] = args['activation'].__name__ + with open(join(model_base, "params.json"), "w") as mdata: + mdata.write(json.dumps(args, indent=4)) + + if not args.restore_state: + write_params(args, params) + + from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure + + trainer, scores = \ + training_procedure(dataset, model, copy(params), args, model_base, load_external_dataset=load_external_dataset) + + if load_external_dataset is not None: + model_base = external_args.external_model_base + + trainer.save_checkpoint(model_base) 
+ + print("Saving...") + + params['activation'] = params['activation'].__name__ + + metadata = { + "base": model_base, + "name": model_attempt, + "parameters": params, + "layers": "embeddings.pkl", + "mappings": "nodes.csv", + "state": "state_dict.pt", + "scores": scores, + "time": date_time, + } + + metadata.update(args.__dict__) + + # pickle.dump(dataset, open(join(model_base, "dataset.pkl"), "wb")) + import pickle + pickle.dump(trainer.get_embeddings(), open(join(model_base, metadata['layers']), "wb")) + + with open(join(model_base, "metadata.json"), "w") as mdata: + mdata.write(json.dumps(metadata, indent=4)) + + print("Done saving") + + + + + +if __name__ == "__main__": + + import argparse + + parser = argparse.ArgumentParser(description='') + add_gnn_train_args(parser) + + args = parser.parse_args() + verify_arguments(args) + + logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s") + + models_ = { + # GCNSampling: gcnsampling_params, + # GATSampler: gatsampling_params, + # RGCNSampling: rgcnsampling_params, + # RGAN: rgcnsampling_params, + RGGAN: rggan_params + + } + + if not isdir(args.model_output_dir): + mkdir(args.model_output_dir) + + main(models_, args) diff --git a/scripts/training/train.py b/scripts/training/train.py index 840fa42f..227a3752 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -5,13 +5,17 @@ from os import mkdir from os.path import isdir, join -from SourceCodeTools.code.data.sourcetrail.Dataset import read_or_create_dataset -from SourceCodeTools.models.graph import RGCNSampling, RGAN, RGGAN +from SourceCodeTools.code.data.dataset.Dataset import read_or_create_gnn_dataset +from SourceCodeTools.models.graph import RGGAN +from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure from SourceCodeTools.models.graph.train.utils import get_name, get_model_base -from params import rgcnsampling_params, rggan_params +from SourceCodeTools.models.training_config import get_config, load_config, update_config, save_config +from SourceCodeTools.models.training_options import add_gnn_train_args, verify_arguments +from params import rggan_params -def main(models, args): +def train_grid(models, args): + for model, param_grid in models.items(): for params in param_grid: @@ -31,7 +35,7 @@ def main(models, args): model_base = get_model_base(args, model_attempt) - dataset = read_or_create_dataset(args=args, model_base=model_base) + dataset = read_or_create_gnn_dataset(args=args, model_base=model_base) def write_params(args, params): args = copy(args.__dict__) @@ -40,24 +44,17 @@ def write_params(args, params): with open(join(model_base, "params.json"), "w") as mdata: mdata.write(json.dumps(args, indent=4)) - write_params(args, params) - - if args.training_mode == "multitask": + if not args.restore_state: + write_params(args, params) - # if args.intermediate_supervision: - # # params['use_self_loop'] = True # ???? 
- # from SourceCodeTools.models.graph.train.sampling_multitask_intermediate_supervision import training_procedure - # else: - from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure + from SourceCodeTools.models.graph.train.sampling_multitask2 import training_procedure - trainer, scores = \ - training_procedure(dataset, model, copy(params), args, model_base) + trainer, scores = \ + training_procedure(dataset, model, copy(params), args, model_base) - trainer.save_checkpoint(model_base) - else: - raise ValueError("Unknown training mode:", args.training_mode) + trainer.save_checkpoint(model_base) - print("Saving...", end="") + print("Saving...") params['activation'] = params['activation'].__name__ @@ -81,107 +78,98 @@ def write_params(args, params): with open(join(model_base, "metadata.json"), "w") as mdata: mdata.write(json.dumps(metadata, indent=4)) - print("done") - + print("Done saving") -def add_data_arguments(parser): - parser.add_argument('--data_path', '-d', dest='data_path', default=None, help='Path to the files') - parser.add_argument('--train_frac', dest='train_frac', default=0.9, type=float, help='') - parser.add_argument('--filter_edges', dest='filter_edges', default=None, help='Edges filtered before training') - parser.add_argument('--min_count_for_objectives', dest='min_count_for_objectives', default=5, type=int, help='') - parser.add_argument('--packages_file', dest='packages_file', default=None, type=str, help='') - parser.add_argument('--self_loops', action='store_true') - parser.add_argument('--use_node_types', action='store_true') - parser.add_argument('--use_edge_types', action='store_true') - parser.add_argument('--restore_state', action='store_true') - parser.add_argument('--no_global_edges', action='store_true') - parser.add_argument('--remove_reverse', action='store_true') +def main(args): -def add_pretraining_arguments(parser): - parser.add_argument('--pretrained', '-p', dest='pretrained', default=None, help='') - parser.add_argument('--tokenizer', '-t', dest='tokenizer', default=None, help='') - parser.add_argument('--pretraining_phase', dest='pretraining_phase', default=0, type=int, help='') + args = copy(args.__dict__) + config_path = args.pop("config") + if config_path is None: + config = get_config(**args) + else: + config = load_config(config_path) -def add_training_arguments(parser): - parser.add_argument('--embedding_table_size', dest='embedding_table_size', default=200000, type=int, help='Batch size') - parser.add_argument('--random_seed', dest='random_seed', default=None, type=int, help='') + model = RGGAN - parser.add_argument('--node_emb_size', dest='node_emb_size', default=100, type=int, help='') - parser.add_argument('--elem_emb_size', dest='elem_emb_size', default=100, type=int, help='') - parser.add_argument('--num_per_neigh', dest='num_per_neigh', default=10, type=int, help='') - parser.add_argument('--neg_sampling_factor', dest='neg_sampling_factor', default=3, type=int, help='') + date_time = str(datetime.now()) + print("\n\n") + print(date_time) - parser.add_argument('--use_layer_scheduling', action='store_true') - parser.add_argument('--schedule_layers_every', dest='schedule_layers_every', default=10, type=int, help='') + restore_state = config["TRAINING"]["restore_state"] - parser.add_argument('--epochs', dest='epochs', default=100, type=int, help='Number of epochs') - parser.add_argument('--batch_size', dest='batch_size', default=128, type=int, help='Batch size') + model_attempt = get_name(model, date_time) - 
parser.add_argument("--h_dim", dest="h_dim", default=None, type=int) - parser.add_argument("--n_layers", dest="n_layers", default=5, type=int) - parser.add_argument("--objectives", dest="objectives", default=None, type=str) - - -def add_scoring_arguments(parser): - parser.add_argument('--measure_ndcg', action='store_true') - parser.add_argument('--dilate_ndcg', dest='dilate_ndcg', default=200, type=int, help='') + model_base = get_model_base(config["TRAINING"], model_attempt) + dataset = read_or_create_gnn_dataset( + args={**config["DATASET"], **config["TOKENIZER"]}, + model_base=model_base, restore_state=restore_state + ) -def add_performance_arguments(parser): - parser.add_argument('--no_checkpoints', dest="save_checkpoints", action='store_false') + if not restore_state: + save_config(config, join(model_base, "config.yaml")) - parser.add_argument('--use_gcn_checkpoint', action='store_true') - parser.add_argument('--use_att_checkpoint', action='store_true') - parser.add_argument('--use_gru_checkpoint', action='store_true') + trainer, scores = training_procedure( + dataset, + model_name=model, + model_params=config["MODEL"], + trainer_params=config["TRAINING"], + tokenizer_path=config["TOKENIZER"]["tokenizer_path"], + model_base_path=model_base + ) + trainer.save_checkpoint(model_base) -def add_train_args(parser): - parser.add_argument( - '--training_mode', '-tr', dest='training_mode', default=None, - help='Selects one of training procedures [multitask]' - ) + print("Saving...") - add_data_arguments(parser) - add_pretraining_arguments(parser) - add_training_arguments(parser) - add_scoring_arguments(parser) - add_performance_arguments(parser) + metadata = { + "base": model_base, + "name": model_attempt, + "layers": "embeddings.pkl", + "mappings": "nodes.csv", + "state": "state_dict.pt", + "scores": scores, + "time": date_time, + } - parser.add_argument('--note', dest='note', default="", help='Note, added to metadata') - parser.add_argument('model_output_dir', help='Location of the final model') + metadata["config"] = args - # parser.add_argument('--intermediate_supervision', action='store_true') - parser.add_argument('--gpu', dest='gpu', default=-1, type=int, help='') + # pickle.dump(dataset, open(join(model_base, "dataset.pkl"), "wb")) + import pickle + pickle.dump(trainer.get_embeddings(), open(join(model_base, metadata['layers']), "wb")) + with open(join(model_base, "metadata.json"), "w") as mdata: + mdata.write(json.dumps(metadata, indent=4)) -def verify_arguments(args): - pass + print("Done saving") if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Process some integers.') - add_train_args(parser) + parser = argparse.ArgumentParser(description='') + add_gnn_train_args(parser) args = parser.parse_args() verify_arguments(args) logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(module)s:%(lineno)d:%(message)s") - models_ = { - # GCNSampling: gcnsampling_params, - # GATSampler: gatsampling_params, - # RGCNSampling: rgcnsampling_params, - # RGAN: rgcnsampling_params, - RGGAN: rggan_params - - } + # models_ = { + # # GCNSampling: gcnsampling_params, + # # GATSampler: gatsampling_params, + # # RGCNSampling: rgcnsampling_params, + # # RGAN: rgcnsampling_params, + # RGGAN: rggan_params + # + # } if not isdir(args.model_output_dir): mkdir(args.model_output_dir) - main(models_, args) + main(args) + + # main(models_, args) diff --git a/scripts/training/train_all_emb_types.sh b/scripts/training/train_all_emb_types.sh index 
4ddd0545..1e2dd656 100755
--- a/scripts/training/train_all_emb_types.sh
+++ b/scripts/training/train_all_emb_types.sh
@@ -1,7 +1,16 @@
+DATASET=$1
+DGLKE_OUT=$2
+
+#conda activate SourceCodeTools
+#prepare_dglke_format.py $DATASET $DGLKE_OUT
+#conda deactivate
+
 conda activate dglke
-dglke_train --model_name TransR --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path transr_u --hidden_dim 100 --num_thread 4 --gpu 0
-dglke_train --model_name RESCAL --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path RESCAL_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 160000
-dglke_train --model_name DistMult --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path DistMult_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 400000
-dglke_train --model_name ComplEx --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path ComplEx_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 400000
-dglke_train --model_name RotatE --dataset code --data_path . --data_files edges_train_dglke.tsv held_dglkg.tsv held_dglkg.tsv --format raw_udd_htr --save_path RotatE_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 200000 -de
+
+dglke_train --model_name TransR --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/transr_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name RESCAL --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/RESCAL_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name DistMult --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/DistMult_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name ComplEx --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/ComplEx_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -adv --test
+dglke_train --model_name RotatE --dataset code --data_path $DGLKE_OUT --data_files edges_train_dglke.tsv edges_eval_dglke_10000.tsv edges_eval_dglke_10000.tsv --format raw_udd_htr --save_path $DGLKE_OUT/RotatE_u --hidden_dim 100 --num_thread 4 --gpu 0 --max_step 1000000 -de -adv --test
+
 conda deactivate
diff --git a/setup.py b/setup.py
index 7262d3af..bb7f9977 100644
--- a/setup.py
+++ b/setup.py
@@ -1,31 +1,34 @@
 from distutils.core import setup
 
-requitements = [
-    'nltk',
-    'tensorflow>=2.4.0',
-    'torch>=1.7.1',
-    'pandas>=1.1.1',
-    'sklearn',
-    'sentencepiece',
-    'gensim',
-    'numpy>=1.19.2',
-    'scipy',
-    'networkx',
-    'sacrebleu',
-    'datasets',
+requirements = [
+    'nltk==3.6',
+    'tensorflow==2.6.2',
+    'torch==1.9.0',
+    'pandas==1.1.1',
+    'scikit-learn==1.0',
+    'sentencepiece==0.1.96',
+    'gensim==3.8',
+    'numpy==1.19.5',
+    'scipy==1.4.1',
+    'networkx==2.5',
+    'sacrebleu==1.5.1',
+    'datasets==1.5.0',
     'spacy==2.3.2',
-    'pytest'
+    'pytest==6.1.2',
+    'faiss-cpu==1.7.0',
+    'tqdm==4.49.0'
     # 'pygraphviz'
     # 'javac_parser'
 ]
-
+# conda install pytorch cudatoolkit=11.1 dgl-cuda11.1 -c dglteam -c pytorch -c nvidia
 setup(name='SourceCodeTools',
-      version='0.0.2',
+      version='0.0.3',
       py_modules=['SourceCodeTools'],
-      install_requires=requitements + ["dgl"],
+      install_requires=requirements + ["dgl==0.7.1"],
       extras_require={
-          "gpu": requitements + ["dgl-cu110"]
+          "gpu": requirements + ["dgl-cu111==0.7.1"]
       },
+      dependency_links=['https://data.dgl.ai/wheels/repo.html'],
       scripts=[
          'SourceCodeTools/code/data/sourcetrail/sourcetrail_call_seq_extractor.py',
          'SourceCodeTools/code/data/sourcetrail/sourcetrail_extract_node_names.py',
@@ -45,7 +48,8 @@
          'SourceCodeTools/code/data/sourcetrail/sourcetrail_node_local2global.py',
          'SourceCodeTools/code/data/sourcetrail/sourcetrail_connected_component.py',
          'SourceCodeTools/code/data/sourcetrail/pandas_format_converter.py',
-          'SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py',
+#          'SourceCodeTools/code/data/sourcetrail/sourcetrail_create_type_annotation_dataset.py',
          'SourceCodeTools/nlp/embed/converters/convert_fasttext_format_bin_to_vec.py',
+          'SourceCodeTools/models/graph/utils/prepare_dglke_format.py',
      ],
-)
+    )