diff --git a/keras/format/data/kotlin_sequence.json b/keras/format/data/kotlin_sequence.json
new file mode 100644
index 0000000..999c1d3
--- /dev/null
+++ b/keras/format/data/kotlin_sequence.json
@@ -0,0 +1,1099 @@
+{
+  "entries": {
+    "GoldErrors": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 12, 14, 15, 17, 19, 4, 5, 20, 4, 21, 18, 16, 22, 4, 23, 6, 24, 25, 14, 27, 29, 30, 28, 22, 26, 31, 32, 6, 24, 25, 14, 27, 29, 30, 28, 22, 26, 31, 32, 6, 24, 25, 14, 27, 29, 30, 28, 22, 26, 31, 32, 6, 24, 25, 14, 27, 29, 30, 28, 22, 26, 31, 32, 6, 24, 25, 14, 27, 29, 30, 28, 22, 26, 6, 33, 13, 6, 11]
+    },
+    "ModeType": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 7, 9, 6, 8, 12, 4, 23, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 33, 13, 6, 11]
+    },
+    "GroovyData": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 12, 14, 6, 15, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 18, 16, 6, 22, 4, 23, 6, 33, 13, 6, 11]
+    },
+    "Qid": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 6, 38, 40, 6, 39, 7, 9, 6, 9, 6, 8, 6, 10, 7, 9, 6, 8, 12, 14, 6, 15, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 18, 16, 6, 22, 4, 23, 6, 24, 12, 4, 23, 6, 24, 24, 34, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 41, 14, 6, 27, 24, 24, 24, 24, 17, 5, 37, 29, 30, 18, 31, 32, 6, 24, 24, 24, 24, 17, 5, 37, 29, 30, 18, 31, 32, 6, 24, 24, 24, 24, 17, 5, 37, 29, 30, 18, 28, 6, 22, 42, 30, 35, 6, 33, 13, 6, 33, 13, 6, 11]
+    },
+    "Qid2": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 6, 38, 40, 6, 39, 7, 9, 6, 9, 6, 8, 6, 10, 12, 14, 15, 17, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 18, 16, 22, 4, 23, 6, 24, 12, 4, 23, 6, 24, 24, 34, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 41, 14, 27, 28, 22, 42, 30, 35, 6, 33, 13, 6, 33, 13, 6, 11]
+    },
+    "GoldConstants": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 7, 9, 6, 8, 12, 4, 23, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 33, 13, 6, 11]
+    },
+    "GoldEnums": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 7, 9, 6, 9, 6, 8, 12, 14, 15, 17, 19, 4, 5, 20, 4, 21, 18, 16, 22, 4, 23, 6, 24, 25, 14, 27, 29, 30, 28, 22, 26, 31, 32, 6, 24, 25, 14, 27, 29, 30, 28, 22, 26, 6, 33, 13, 6, 11, 6, 10, 12, 4, 23, 6, 24, 25, 26, 31, 32, 6, 24, 25, 26, 31, 32, 6, 24, 25, 26, 31, 32, 6, 24, 25, 26, 6, 33, 13, 6, 11]
+    },
+    "CryptoCurrency": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 12, 4, 23, 6, 24, 25, 26, 31, 32, 6, 24, 25, 26, 31, 32, 6, 24, 25, 26, 6, 33, 13, 6, 11]
+    },
+    "Money": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 6, 38, 40, 6, 40, 6, 39, 7, 9, 6, 9, 6, 8, 6, 10, 12, 14, 6, 15, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 18, 16, 6, 22, 4, 23, 6, 24, 12, 4, 23, 6, 24, 24, 34, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 41, 14, 27, 28, 22, 42, 30, 35, 6, 33, 13, 6, 33, 13, 6, 11]
+    },
+    "GoldBuffer": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 12, 14, 6, 15, 24, 17, 19, 4, 5, 20, 4, 21, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 18, 31, 32, 6, 24, 17, 19, 4, 5, 20, 4, 21, 18, 16, 6, 22, 4, 23, 6, 33, 13, 6, 11]
+    },
+    "DataReader": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 43, 12, 4, 23, 6, 45, 47, 48, 46, 45, 47, 48, 49, 24, 17, 18, 31, 32, 6, 24, 17, 18, 31, 32, 6, 24, 17, 18, 50, 46, 33, 13, 6, 44]
+    },
+    "ModeTypeXml": {
+      "sequence": [1, 3, 4, 5, 6, 6, 2, 7, 9, 6, 9, 6, 8, 6, 10, 7, 9, 6, 8, 12, 4, 23, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 24, 34, 36, 4, 19, 4, 5, 20, 4, 21, 4, 37, 4, 29, 30, 35, 6, 33, 13, 6, 11]
+    }
+  },
+  "name": "Kotlin AST Training Sequences",
+  "type": "LSTMTrainingSequences",
+  "author": "mrco",
+  "version": "1.0",
+  "process": "PrepareTrainDataFromASTXml",
+  "updateDate": "2025-10-13T18:43:34.903419"
+}
\ No newline at end of file
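Review note: each `sequence` above is one Kotlin source block flattened into vocabulary ids from `kotlin_vocab.json` (whose `dictionary` map appears in the next hunk, e.g. `NamespaceDeclaration_open` has id 1). A minimal decoding sketch, assuming both files sit in `./data/` as in the scripts below:

```python
import json

# Invert the vocabulary: id -> token name (shape as in kotlin_vocab.json).
with open("./data/kotlin_vocab.json") as f:
    vocab = json.load(f)["dictionary"]
id_to_token = {entry["id"]: name for name, entry in vocab.items()}

with open("./data/kotlin_sequence.json") as f:
    entries = json.load(f)["entries"]

# e.g. the first few tokens of the GoldErrors training sequence
print([id_to_token[i] for i in entries["GoldErrors"]["sequence"][:5]])
```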
"updateDate": "2025-10-13T18:43:34.902845", "dictionary": { "NamespaceDeclaration_open": { "id": 1, diff --git a/keras/format/data/lstm-kotlin-n4_v50_u64.h1.keras.88571 b/keras/format/data/lstm-kotlin-n4_v50_u64.h1.keras.88571 new file mode 100644 index 0000000..882fe06 Binary files /dev/null and b/keras/format/data/lstm-kotlin-n4_v50_u64.h1.keras.88571 differ diff --git a/keras/format/data/lstm-kotlin-n4_v50_u96.h1.keras.92673 b/keras/format/data/lstm-kotlin-n4_v50_u96.h1.keras.92673 new file mode 100644 index 0000000..ae9b92e Binary files /dev/null and b/keras/format/data/lstm-kotlin-n4_v50_u96.h1.keras.92673 differ diff --git a/keras/format/data/lstm-kotlin-n4_v50_u96.h1.keras.96366 b/keras/format/data/lstm-kotlin-n4_v50_u96.h1.keras.96366 new file mode 100644 index 0000000..fad3835 Binary files /dev/null and b/keras/format/data/lstm-kotlin-n4_v50_u96.h1.keras.96366 differ diff --git a/keras/format/format_xml.py b/keras/format/format_xml.py new file mode 100644 index 0000000..6e7d705 --- /dev/null +++ b/keras/format/format_xml.py @@ -0,0 +1,43 @@ +import getpass + +from sequences import Dictionary, DictionaryOperations +from sequences import Sequence, SequenceOperations +from lstm_formatter import XmlOperations +from lstm_formatter import LSTMFormatter + +if __name__ == "__main__": + KOTLIN_VOCAB_FILE = "./data/kotlin_vocab.json" + process = "PrepareTrainDataFromASTXml" + inp_words = 4 + units = 96 + sequence_operations = SequenceOperations() + dictionary_operations = DictionaryOperations() + xml_operations = XmlOperations() + + dictionary = dictionary_operations.load( + filepath = KOTLIN_VOCAB_FILE, + username = getpass.getuser(), + process = process + ) + + model_filename = f"./data/lstm-kotlin-n{inp_words}_v{dictionary.size()}_u{units}.h1.keras" + sequences = xml_operations.loadSequencesUseCase( + directory="../../generated/kotlin/", + filename="output_tree_Kotlin.xml", + dictionary= dictionary + ) + if sequences.is_err(): + print(f"Error loading sequences: {sequences.unwrap_err()}") + exit(1) + sequences = sequences.unwrap() + sequences.author = getpass.getuser() + sequences.process = process + + formatter = LSTMFormatter(inp_words=inp_words) + if (not formatter.loadModel(model_filename)): + print(f"Error loading model") + exit(1) + + # formatter.trainModel(sequences) + # formatter.model.save("./data/lstm-kotlin-n4.h1.keras") + diff --git a/keras/format/lstm_formatter/.gitignore b/keras/format/lstm_formatter/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/keras/format/lstm_formatter/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/keras/format/lstm_formatter/LSTMFormatter.py b/keras/format/lstm_formatter/LSTMFormatter.py new file mode 100644 index 0000000..dc45ca7 --- /dev/null +++ b/keras/format/lstm_formatter/LSTMFormatter.py @@ -0,0 +1,68 @@ +import os +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import LSTM, Dense, Embedding +from keras import optimizers +from tensorflow.keras.models import load_model +from sequences import Sequence +from sequences import Dictionary +import numpy as np +from tensorflow.keras.utils import to_categorical +from keras.callbacks import ModelCheckpoint + +class LSTMFormatter: + def __init__(self, inp_words: int = 4): + self.inp_words = inp_words + self.paddingVec = [0] * (inp_words - 1) + self.filename = "" + self.model = None + self.rms = None + + def loadModel(self, filename: str) -> bool: + self.filename = filename + self.rms = 
diff --git a/keras/format/lstm_formatter/LSTMFormatter.py b/keras/format/lstm_formatter/LSTMFormatter.py
new file mode 100644
index 0000000..dc45ca7
--- /dev/null
+++ b/keras/format/lstm_formatter/LSTMFormatter.py
@@ -0,0 +1,68 @@
+import os
+
+import numpy as np
+from tensorflow.keras.models import Sequential, load_model
+from tensorflow.keras.layers import LSTM, Dense, Embedding
+from tensorflow.keras import optimizers
+from tensorflow.keras.callbacks import ModelCheckpoint
+from sequences import Sequence
+from sequences import Dictionary
+
+class LSTMFormatter:
+    def __init__(self, inp_words: int = 4):
+        self.inp_words = inp_words
+        # inp_words - 1 leading zeros so the first real token already fills a window.
+        self.paddingVec = [0] * (inp_words - 1)
+        self.filename = ""
+        self.model = None
+        self.rms = None
+
+    def loadModel(self, filename: str) -> bool:
+        self.filename = filename
+        self.rms = optimizers.RMSprop(learning_rate=0.0005)
+        if os.path.exists(filename):
+            self.model = load_model(filename)
+            self.model.compile(optimizer=self.rms, loss='sparse_categorical_crossentropy')
+            return True
+        return False
+
+    def defineModel(self, units: int, dictionary: Dictionary, filename: str):
+        self.filename = filename
+        self.rms = optimizers.RMSprop(learning_rate=0.0005)
+        if os.path.exists(filename):
+            self.loadModel(filename)
+            return
+        self.model = Sequential()
+        dictionary_size = dictionary.size() + 1  # +1 for the padding token (id 0)
+        self.model.add(Embedding(dictionary_size,
+                                 output_dim=units,
+                                 input_length=self.inp_words,
+                                 mask_zero=True))
+        self.model.add(LSTM(units))
+        self.model.add(Dense(dictionary_size, activation='softmax'))
+        self.model.build(input_shape=(None, self.inp_words))
+        self.model.summary()
+        self.model.compile(optimizer=self.rms, loss='sparse_categorical_crossentropy')
+
+    def trainModel(self, sequence: Sequence):
+        vectors = [it['sequence'] for it in sequence.entries.values()]
+        vectors = [self.paddingVec + sb for sb in vectors]
+        # Slide a window of inp_words tokens over every sequence; the token
+        # that follows the window is the prediction target.
+        X = []
+        Y = []
+        for sb in vectors:
+            for i in range(len(sb) - self.inp_words):
+                X.append(sb[i:i + self.inp_words])
+                Y.append(sb[i + self.inp_words])
+        X = np.array(X)
+        Y = np.array(Y)
+        print(f"X shape: {X.shape}, Y shape: {Y.shape}")
+
+        # Keep only the checkpoint with the best validation loss.
+        checkpoint = ModelCheckpoint(self.filename, monitor='val_loss', verbose=1,
+                                     save_best_only=True, mode='min')
+        self.model.fit(x=X,
+                       y=Y,
+                       batch_size=16,
+                       validation_split=0.2,
+                       callbacks=[checkpoint],
+                       epochs=4096)
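Review note: the sliding-window encoding in `trainModel` above is easiest to see on a toy sequence. Illustrative only; the token ids are made up, not taken from the training data:

```python
# How trainModel turns one token sequence into (X, Y) samples with inp_words = 4.
inp_words = 4
seq = [7, 9, 6, 8, 6, 10]                  # hypothetical vocabulary ids
padded = [0] * (inp_words - 1) + seq       # same as paddingVec + sb
samples = [(padded[i:i + inp_words], padded[i + inp_words])
           for i in range(len(padded) - inp_words)]
# samples == [([0, 0, 0, 7], 9), ([0, 0, 7, 9], 6), ([0, 7, 9, 6], 8),
#             ([7, 9, 6, 8], 6), ([9, 6, 8, 6], 10)]
```

The `mask_zero=True` embedding treats id 0 as padding, which is why `defineModel` sizes the embedding as `dictionary.size() + 1`.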
diff --git a/keras/format/lstm_formatter/XmlOperations.py b/keras/format/lstm_formatter/XmlOperations.py
new file mode 100644
index 0000000..bb6f03c
--- /dev/null
+++ b/keras/format/lstm_formatter/XmlOperations.py
@@ -0,0 +1,120 @@
+from lxml import etree
+from sequences import Sequence
+from sequences import Dictionary
+import copy
+from datetime import datetime
+from result import Result, Err, Ok
+
+class XmlOperations:
+    # Tags that never get a matching _close token in the vocabulary.
+    OPEN_ONLY_TAGS = [
+        "WorkingDirectory",
+        "PackageDirectory",
+        "VariableName",
+        "CommentLeaf",
+        "AstTypeLeaf",
+        "ImportLeaf",
+        "Space",
+        "NlSeparator",
+        "Indent",
+        "Keyword"
+    ]
+    SKIP_TAGS = [
+        "FileMetaInformation"
+    ]
+
+    def __init__(self):
+        pass
+
+    def process_childs(self, elem, vocab, id):
+        for child in elem:
+            if child.tag in XmlOperations.SKIP_TAGS:
+                # Skip this tag and its children
+                continue
+            tagName = child.tag
+            tagCanBeClosed = tagName not in XmlOperations.OPEN_ONLY_TAGS
+            openTag = f"{tagName}_open" if tagCanBeClosed else f"{tagName}"
+            if tagName == "Keyword":
+                openTag = f"{tagName}_{child.attrib['name']}"
+
+            if openTag not in vocab:
+                vocab[openTag] = {"id": id, "priority": 0}
+                id += 1
+
+            if tagCanBeClosed:
+                closeTag = f"{tagName}_close"
+                if closeTag not in vocab:
+                    vocab[closeTag] = {"id": id, "priority": 0}
+                    id += 1
+
+            id = self.process_childs(child, vocab, id)
+
+        return id
+
+    def refreshDictionaryUseCase(self, directory, filename, dictionary: Dictionary) -> Dictionary:
+        tree = etree.parse(directory + "/" + filename)
+        root = tree.getroot()
+        newDictionary = copy.deepcopy(dictionary)
+        id = newDictionary.nextId()
+        for child in root:
+            print(f"Child tag: {child.tag}, attributes: {child.attrib}")
+            id = self.process_childs(child, newDictionary.entries, id)
+
+        print(f"Vocabulary size: {len(newDictionary.entries)}")
+        newDictionary.updateDate = datetime.now().isoformat()
+        return newDictionary
+
+    def prepareTrainingSequencesUseCase(self, directory, filename, dictionary: Dictionary) -> Sequence:
+        tree = etree.parse(directory + "/" + filename)
+        root = tree.getroot()
+        sequences = Sequence(username="", process="")
+
+        for child in root:
+            blockName = child.attrib['name']
+            sequence = []
+            self.process_childs_for_sequence(child, dictionary, sequence)
+            sequences.entries[blockName] = {"sequence": sequence}
+
+        print(f"Training sequences: {len(sequences.entries)}")
+        return sequences
+
+    def process_childs_for_sequence(self, elem, dictionary: Dictionary, sequence: list):
+        for child in elem:
+            if child.tag in XmlOperations.SKIP_TAGS:
+                # Skip this tag and its children
+                continue
+            tagName = child.tag
+            tagCanBeClosed = tagName not in XmlOperations.OPEN_ONLY_TAGS
+            openTag = f"{tagName}_open" if tagCanBeClosed else f"{tagName}"
+            if tagName == "Keyword":
+                openTag = f"{tagName}_{child.attrib['name']}"
+
+            if openTag in dictionary.entries:
+                sequence.append(dictionary.entries[openTag]["id"])
+
+            if tagCanBeClosed:
+                closeTag = f"{tagName}_close"
+                if closeTag in dictionary.entries:
+                    # Process children first (depth-first), then emit the close token
+                    self.process_childs_for_sequence(child, dictionary, sequence)
+                    sequence.append(dictionary.entries[closeTag]["id"])
+                else:
+                    raise Exception(f"'{closeTag}' not found in vocabulary")
+        return sequence
+
+    def loadSequencesUseCase(self, directory, filename, dictionary: Dictionary) -> Result[Sequence, str]:
+        tree = etree.parse(directory + "/" + filename)
+        root = tree.getroot()
+        sequences = Sequence(username="", process="")
+
+        for child in root:
+            blockName = child.attrib['name']
+            sequence = []
+            try:
+                self.process_childs_for_sequence(child, dictionary, sequence)
+            except Exception as e:
+                return Err(f"Error processing block '{blockName}': {str(e)}")
+            sequences.entries[blockName] = {"sequence": sequence}
+
+        print(f"Sequences count: {len(sequences.entries)}")
+        return Ok(sequences)
diff --git a/keras/format/lstm_formatter/__init__.py b/keras/format/lstm_formatter/__init__.py
new file mode 100644
index 0000000..4975cc2
--- /dev/null
+++ b/keras/format/lstm_formatter/__init__.py
@@ -0,0 +1,6 @@
+# lstm_formatter/__init__.py
+
+from .XmlOperations import XmlOperations
+from .LSTMFormatter import LSTMFormatter
+
+__all__ = ['XmlOperations', 'LSTMFormatter']
\ No newline at end of file
diff --git a/keras/format/sequences/Dictionary.py b/keras/format/sequences/Dictionary.py
index da01ec6..7bf1cc0 100644
--- a/keras/format/sequences/Dictionary.py
+++ b/keras/format/sequences/Dictionary.py
@@ -12,6 +12,14 @@ def __init__(self, username: str, process: str):
         self.process = process
         self.updateDate = datetime.now().isoformat()
 
+    def size(self):
+        return len(self.entries)
+
+    def nextId(self):
+        if not self.entries:
+            return 1
+        return max(entry["id"] for entry in self.entries.values()) + 1
+
 class DictionaryOperations:
     def __init__(self):
         self.data = {}
diff --git a/keras/format/sequences/__init__.py b/keras/format/sequences/__init__.py
index 097a70c..9bc98f7 100644
--- a/keras/format/sequences/__init__.py
+++ b/keras/format/sequences/__init__.py
@@ -1,4 +1,4 @@
-# styxlib/__init__.py
+# sequences/__init__.py
 
 from .Dictionary import Dictionary, DictionaryOperations
 from .Sequence import Sequence, SequenceOperations
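Review note: `process_childs_for_sequence` linearises the AST with matched open/close tokens, emitting each element's children depth-first between them; open-only tags (like `Keyword`) produce a single token. A minimal sketch of the same walk against a hypothetical two-tag vocabulary (ids chosen for illustration):

```python
from lxml import etree

# Hypothetical miniature vocabulary in the kotlin_vocab.json shape.
entries = {
    "ClassDeclaration_open": {"id": 1, "priority": 0},
    "ClassDeclaration_close": {"id": 2, "priority": 0},
    "Keyword_class": {"id": 3, "priority": 0},   # Keyword is open-only
}

root = etree.fromstring(
    '<Block name="Demo"><ClassDeclaration><Keyword name="class"/></ClassDeclaration></Block>'
)

def walk(elem, seq):
    for child in elem:
        tag = child.tag
        open_only = tag == "Keyword"               # stand-in for OPEN_ONLY_TAGS
        open_tag = f"{tag}_{child.attrib['name']}" if tag == "Keyword" else f"{tag}_open"
        seq.append(entries[open_tag]["id"])
        if not open_only:
            walk(child, seq)                        # children first, depth-first
            seq.append(entries[f"{tag}_close"]["id"])
    return seq

print(walk(root, []))   # -> [1, 3, 2]: open token, keyword leaf, close token
```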
diff --git a/keras/format/train_from_xml.py b/keras/format/train_from_xml.py
index 8c6f82f..f2ff710 100644
--- a/keras/format/train_from_xml.py
+++ b/keras/format/train_from_xml.py
@@ -1,134 +1,19 @@
-# from tensorflow.keras.models import Sequential
-# from tensorflow.keras.layers import Embedding, LSTM, RepeatVector, Dense
-# from tensorflow.keras.utils import plot_model
-# from keras import optimizers
-# import numpy as np
-# from tensorflow.keras.utils import to_categorical
-# from keras.models import load_model
-# from keras.callbacks import ModelCheckpoint
-from lxml import etree
-import json
 import getpass
-from datetime import datetime
 
 from sequences import Dictionary, DictionaryOperations
 from sequences import Sequence, SequenceOperations
-
-OPEN_ONLY_TAGS = [
-    "WorkingDirectory",
-    "PackageDirectory",
-    "VariableName",
-    "CommentLeaf",
-    "AstTypeLeaf",
-    "ImportLeaf",
-    "Space",
-    "NlSeparator",
-    "Indent",
-    "Keyword"
-    ]
-SKIP_TAGS = [
-    "FileMetaInformation"
-]
-
-
-def process_childs(elem,
-                   vocab,
-                   id):
-    for child in elem:
-        if child.tag in SKIP_TAGS:
-            ## Skip this tag and its children
-            continue
-        tagName = child.tag
-        tagCanBeClosed = tagName not in OPEN_ONLY_TAGS
-        openTag = f"{tagName}_open" if tagCanBeClosed else f"{tagName}"
-        if tagName == "Keyword":
-            openTag = f"{tagName}_{child.attrib['name']}"
-
-        if openTag not in vocab:
-            vocab[openTag] = {"id": id, "priority": 0}
-            id += 1
-
-        if tagCanBeClosed:
-            closeTag = f"{tagName}_close"
-            if closeTag not in vocab:
-                vocab[closeTag] = {"id": id, "priority": 0}
-                id += 1
-
-        id = process_childs(child, vocab, id)
-
-    return id
-
-def read_and_update_vocab(directory, filename, vocab):
-    tree = etree.parse(directory + "/" + filename)
-    root = tree.getroot()
-
-    if vocab:
-        id = max(entry["id"] for entry in vocab.values()) + 1
-    else:
-        id = 1
-
-    for child in root:
-        print(f"Child tag: {child.tag}, attributes: {child.attrib}")
-        id = process_childs(child, vocab, id)
-
-    print(f"Vocabulary size: {len(vocab)}")
-    return vocab
-
-def read_and_split_out_tree(directory, filename):
-    tree = etree.parse(directory + "/" + filename)
-    root = tree.getroot()
-
-    for child in root:
-        output_filename = f"{directory}/{child.attrib['name']}.xml"
-        # print(output_filename)
-        with open(output_filename, "wb") as f:
-            f.write(etree.tostring(child, pretty_print=True, xml_declaration=True, encoding="UTF-8"))
-
-def process_childs_for_sequence(elem, vocab, sequence):
-    for child in elem:
-        if child.tag in SKIP_TAGS:
-            ## Skip this tag and its children
-            continue
-        tagName = child.tag
-        tagCanBeClosed = tagName not in OPEN_ONLY_TAGS
-        openTag = f"{tagName}_open" if tagCanBeClosed else f"{tagName}"
-        if tagName == "Keyword":
-            openTag = f"{tagName}_{child.attrib['name']}"
-
-        if openTag in vocab:
-            sequence.append(vocab[openTag]["id"])
-
-        if tagCanBeClosed:
-            closeTag = f"{tagName}_close"
-            if closeTag in vocab:
-                # Process children first (depth-first)
-                process_childs_for_sequence(child, vocab, sequence)
-                sequence.append(vocab[closeTag]["id"])
-            else:
-                raise Exception(f"'{closeTag}' not found in vocabulary")
-    return sequence
-
-def read_and_prepare_training_sequence(directory, filename, vocab, sequences: Sequence):
-    tree = etree.parse(directory + "/" + filename)
-    root = tree.getroot()
-
-    for child in root:
-        blockName = child.attrib['name']
-        # print(f"Child tag: {child.tag}, attributes: {child.attrib['name']}")
-        sequence = []
-        process_childs_for_sequence(child, vocab, sequence)
-        sequences.entries[blockName] = {"sequence": sequence}
-
-    print(f"Training sequence length: {len(sequences.entries)}")
-    return sequences
-
+from lstm_formatter import XmlOperations
+from lstm_formatter import LSTMFormatter
 
 if __name__ == "__main__":
     KOTLIN_VOCAB_FILE = "./data/kotlin_vocab.json"
     sequence_file = "./data/kotlin_sequence.json"
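Review note: the refactored pipeline below trains and checkpoints the model, but neither script queries it yet. A minimal inference sketch, assuming a trained `LSTMFormatter` instance `formatter` and reusing the window/padding convention from `trainModel` (the helper name is hypothetical):

```python
import numpy as np

def predict_next_token(formatter, tokens):
    # Left-pad with the padding id (0) so short prefixes still fill a window,
    # then keep only the last inp_words tokens, mirroring trainModel.
    window = ([0] * formatter.inp_words + tokens)[-formatter.inp_words:]
    probs = formatter.model.predict(np.array([window]), verbose=0)[0]
    return int(np.argmax(probs))   # id of the most likely next vocabulary token

# e.g. predict_next_token(formatter, [1, 3, 4, 5]) -> hypothetical next id
```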
process = "PrepareTrainDataFromASTXml" + inp_words = 4 + units = 96 sequence_operations = SequenceOperations() dictionary_operations = DictionaryOperations() + xml_operations = XmlOperations() dictionary = dictionary_operations.load( filepath = KOTLIN_VOCAB_FILE, @@ -136,23 +21,29 @@ def read_and_prepare_training_sequence(directory, filename, vocab, sequences: Se process = process ) - read_and_update_vocab( + dictionary = xml_operations.refreshDictionaryUseCase( directory="../../generated/kotlin/", filename="output_tree_formatted_Kotlin.xml", - vocab = dictionary.entries + dictionary = dictionary ) - dictionary.updateDate = datetime.now().isoformat() dictionary_operations.store(dictionary, KOTLIN_VOCAB_FILE) - sequences = Sequence( - username = getpass.getuser(), - process = process - ) - - read_and_prepare_training_sequence( + model_filename = f"./data/lstm-kotlin-n{inp_words}_v{dictionary.size()}_u{units}.h1.keras" + sequences = xml_operations.prepareTrainingSequencesUseCase( directory="../../generated/kotlin/", filename="output_tree_formatted_Kotlin.xml", - vocab = dictionary.entries, - sequences = sequences + dictionary= dictionary ) + sequences.author = getpass.getuser() + sequences.process = process sequence_operations.store(sequences, sequence_file) + + formatter = LSTMFormatter(inp_words=inp_words) + formatter.defineModel( + units=units, + dictionary=dictionary, + filename=model_filename + ) + formatter.trainModel(sequences) + # formatter.model.save("./data/lstm-kotlin-n4.h1.keras") + diff --git a/test/project.json b/test/project.json index d6f37a8..f91471e 100644 --- a/test/project.json +++ b/test/project.json @@ -5,7 +5,6 @@ "./test/constants.kts", "./test/enums.kts", "./test/dataclass.kts", - "./test/interface.kts", "./test/constants.xml" ], "targets": [