From 6b5191818eeb39c87c8722425d01d0254c43011f Mon Sep 17 00:00:00 2001 From: hwz0827 Date: Fri, 17 Mar 2023 14:15:52 +0000 Subject: [PATCH 1/8] [feat] add SKT example --- examples/SKT/SKT.ipynb | 223 +++++++++++++++++++++++++++++ examples/SKT/SKT.py | 26 ++++ examples/SKT/prepare_dataset.ipynb | 166 +++++++++++++++++++++ 3 files changed, 415 insertions(+) create mode 100644 examples/SKT/SKT.ipynb create mode 100644 examples/SKT/SKT.py create mode 100644 examples/SKT/prepare_dataset.ipynb diff --git a/examples/SKT/SKT.ipynb b/examples/SKT/SKT.ipynb new file mode 100644 index 0000000..47521ad --- /dev/null +++ b/examples/SKT/SKT.ipynb @@ -0,0 +1,223 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Structure-based Knowledge Tracing" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook will show you how to train and use the SKT. First, we will show how to get the data (here we use a0910 as the dataset). Then we will show how to train a GKT and perform the parameters persistence. At last, we will show how to load the parameters from the file and evaluate on the test dataset.\n", + "\n", + "\n", + "The script version could be found in [SKT.py](https://github.com/bigdata-ustc/EduKTM/blob/main/examples/SKT/SKT.py)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Preparation\n", + "\n", + "Before we process the data, we need to first acquire the dataset which is shown in [prepare_dataset.ipynb](https://github.com/bigdata-ustc/EduKTM/blob/main/examples/SKT/prepare_dataset.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/huangweizhe/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "reading data from ../../data/assistment_2009_2010/train.json: 3025it [00:01, 2397.38it/s]\n", + "batchify: 100%|██████████| 327/327 [00:00<00:00, 432.65it/s]\n", + "reading data from ../../data/assistment_2009_2010/test.json: 856it [00:00, 3526.76it/s]\n", + "/data/huangweizhe/EduKTM/EduKTM/utils/torch/extlib/sampler.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[104, 108, 115, 119, 120, 122, 125, 127, 129, 130, 134, 143, 147, 157, 159, 160, 163, 165, 166, 169, 173, 174, 178, 181, 184, 188, 189, 192, 193, 194, 196]\n", + " warnings.warn('Some buckets are empty and will be removed. 
Unused bucket keys=%s' %\n", + "batchify: 100%|██████████| 134/134 [00:00<00:00, 618.56it/s]\n", + "reading data from ../../data/assistment_2009_2010/test.json: 856it [00:00, 2156.08it/s]\n", + "batchify: 100%|██████████| 134/134 [00:00<00:00, 500.12it/s]\n" + ] + } + ], + "source": [ + "from EduKTM.GKT import etl\n", + "\n", + "batch_size = 16\n", + "train = etl(\"../../data/assistment_2009_2010/train.json\", batch_size=batch_size)\n", + "valid = etl(\"../../data/assistment_2009_2010/test.json\", batch_size=batch_size)\n", + "test = etl(\"../../data/assistment_2009_2010/test.json\", batch_size=batch_size)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training and Persistence" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logging.getLogger().setLevel(logging.INFO)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0: 100%|██████████| 327/327 [05:52<00:00, 1.08s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Epoch 0] SLMoss: 0.445898\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "evaluating: 100%|██████████| 134/134 [01:27<00:00, 1.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Epoch 0] auc: 0.594213, accuracy: 0.631270\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 1: 100%|██████████| 327/327 [06:32<00:00, 1.20s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Epoch 1] SLMoss: 0.433631\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "evaluating: 100%|██████████| 134/134 [01:26<00:00, 1.56it/s]\n", + "INFO:root:save parameters to skt.params\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Epoch 1] auc: 0.624451, accuracy: 0.681644\n" + ] + } + ], + "source": [ + "from EduKTM import SKT\n", + "import torch\n", + "\n", + "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n", + "model = SKT(ku_num=124, graph_params=[\n", + " ['../../data/assistment_2009_2010/correct_transition_graph.json', True],\n", + " ['../../data/assistment_2009_2010/ctrans_sim.json', False]\n", + " ], hidden_num=5)\n", + "model.train(train, valid, epoch=2, device=device)\n", + "model.save(\"skt.params\")\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading and Testing" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:load parameters from skt.params\n", + "evaluating: 100%|██████████| 134/134 [01:13<00:00, 1.81it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "auc: 0.624451, accuracy: 0.681644\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model.load(\"skt.params\")\n", + "auc, accuracy = model.eval(test, device=device)\n", + "print(\"auc: %.6f, accuracy: %.6f\" % (auc, accuracy))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": 
"text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/SKT/SKT.py b/examples/SKT/SKT.py new file mode 100644 index 0000000..fbdd098 --- /dev/null +++ b/examples/SKT/SKT.py @@ -0,0 +1,26 @@ +# coding: utf-8 + + +import logging +from EduKTM.GKT import etl +from EduKTM import SKT +import torch + +batch_size = 16 +train = etl("../../data/assistment_2009_2010/train.json", batch_size=batch_size) +valid = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) +test = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) + +logging.getLogger().setLevel(logging.INFO) + +device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") +model = SKT(ku_num=124, graph_params=[ + ['../../data/assistment_2009_2010/correct_transition_graph.json', True], + ['../../data/assistment_2009_2010/ctrans_sim.json', False] + ], hidden_num=5) +model.train(train, valid, epoch=2, device=device) +model.save("skt.params") + +model.load("skt.params") +auc, accuracy = model.eval(test, device=device) +print("auc: %.6f, accuracy: %.6f" % (auc, accuracy)) diff --git a/examples/SKT/prepare_dataset.ipynb b/examples/SKT/prepare_dataset.ipynb new file mode 100644 index 0000000..4ca7c23 --- /dev/null +++ b/examples/SKT/prepare_dataset.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/correct_transition_graph.json is saved as ../../data/assistment_2009_2010/correct_transition_graph.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ../../data/assistment_2009_2010/correct_transition_graph.json 100.00%: 34.1KB | 34.1KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/ctrans_sim.json is saved as ../../data/assistment_2009_2010/ctrans_sim.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/ctrans_sim.json 100.00%: 200KB | 200KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/dense_graph.json is saved as ../../data/assistment_2009_2010/dense_graph.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/dense_graph.json 100.00%: 361KB | 361KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/test.json is saved as ../../data/assistment_2009_2010/test.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/test.json 100.00%: 1.02MB | 1.02MB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/train.json is saved as ../../data/assistment_2009_2010/train.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/train.json 100.00%: 3.23MB | 3.23MB" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/trans_sim.json is saved as ../../data/assistment_2009_2010/trans_sim.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/trans_sim.json 100.00%: 350KB | 350KB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/transition_graph.json is saved as ../../data/assistment_2009_2010/transition_graph.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ../../data/assistment_2009_2010/transition_graph.json 100.00%: 52.9KB | 52.9KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/valid.json is saved as ../../data/assistment_2009_2010/valid.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/valid.json 100.00%: 310KB | 310KB\n" + ] + }, + { + "data": { + "text/plain": [ + "'../../data/assistment_2009_2010'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from EduData import get_data\n", + "\n", + "get_data(\"ktbd-a0910\", \"../../data\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 85675806a79a78589228ce2b096f650ff647cdb6 Mon Sep 17 00:00:00 2001 From: hwz0827 Date: Fri, 17 Mar 2023 14:28:04 +0000 Subject: [PATCH 2/8] [feat] add SKT --- EduKTM/GKT/GKT.py | 9 +- EduKTM/GKT/GKTNet.py | 17 ++-- EduKTM/SKT/SKT.py | 93 +++++++++++++++++++ EduKTM/SKT/SKTNet.py | 176 ++++++++++++++++++++++++++++++++++++ EduKTM/SKT/__init__.py | 2 + EduKTM/SKT/etl.py | 80 +++++++++++++++++ EduKTM/SKT/utils.py | 198 +++++++++++++++++++++++++++++++++++++++++ EduKTM/__init__.py | 1 + examples/SKT/SKT.py | 6 +- 9 files changed, 567 insertions(+), 15 deletions(-) create mode 100644 EduKTM/SKT/SKT.py create mode 100644 EduKTM/SKT/SKTNet.py create mode 100644 EduKTM/SKT/__init__.py create mode 100644 EduKTM/SKT/etl.py create mode 100644 EduKTM/SKT/utils.py diff --git a/EduKTM/GKT/GKT.py b/EduKTM/GKT/GKT.py index 3d477a8..1dc042c 100644 --- a/EduKTM/GKT/GKT.py +++ b/EduKTM/GKT/GKT.py @@ -24,9 +24,10 @@ def __init__(self, ku_num, graph, hidden_num, net_params: dict = None, loss_para self.loss_params = loss_params if loss_params is not None else {} def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...: - loss_function = SLMLoss(**self.loss_params) + loss_function = SLMLoss(**self.loss_params).to(device) + self.gkt_model = self.gkt_model.to(device) trainer = torch.optim.Adam(self.gkt_model.parameters(), lr) - + for e in range(epoch): losses = [] for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e): @@ -52,7 +53,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00 print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses)))) if test_data is not None: 
- auc, accuracy = self.eval(test_data) + auc, accuracy = self.eval(test_data, device=device) print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy)) def eval(self, test_data, device="cpu") -> tuple: @@ -75,7 +76,7 @@ def eval(self, test_data, device="cpu") -> tuple: output = pick(output, pick_index.to(output.device)) pred = tensor2list(output) label = tensor2list(label) - for i, length in enumerate(label_mask.numpy().tolist()): + for i, length in enumerate(label_mask.cpu().tolist()): length = int(length) y_true.extend(label[i][:length]) y_pred.extend(pred[i][:length]) diff --git a/EduKTM/GKT/GKTNet.py b/EduKTM/GKT/GKTNet.py index ee82579..b9b091e 100644 --- a/EduKTM/GKT/GKTNet.py +++ b/EduKTM/GKT/GKTNet.py @@ -38,7 +38,7 @@ def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0) def in_weight(self, x, ordinal=True, with_weight=True): if isinstance(x, torch.Tensor): - x = x.numpy().tolist() + x = x.cpu().numpy().tolist() if isinstance(x, list): return [self.in_weight(_x) for _x in x] elif isinstance(x, (int, float)): @@ -57,7 +57,7 @@ def in_weight(self, x, ordinal=True, with_weight=True): def out_weight(self, x, ordinal=True, with_weight=True): if isinstance(x, torch.Tensor): - x = x.numpy().tolist() + x = x.cpu().numpy().tolist() if isinstance(x, list): return [self.out_weight(_x) for _x in x] elif isinstance(x, (int, float)): @@ -76,7 +76,7 @@ def out_weight(self, x, ordinal=True, with_weight=True): def neighbors(self, x, ordinal=True, with_weight=False): if isinstance(x, torch.Tensor): - x = x.numpy().tolist() + x = x.cpu().numpy().tolist() if isinstance(x, list): return [self.neighbors(_x) for _x in x] elif isinstance(x, (int, float)): @@ -95,10 +95,11 @@ def neighbors(self, x, ordinal=True, with_weight=False): def forward(self, questions, answers, valid_length=None, compressed_out=True, layout="NTC"): length = questions.shape[1] + device = questions.device inputs, axis, batch_size = format_sequence(length, questions, layout, False) answers, _, _ = format_sequence(length, answers, layout, False) - states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0] + states = states.to(device) outputs = [] all_states = [] for i in range(length): @@ -107,8 +108,8 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la answer_i = answers[i].reshape([batch_size, ]) _neighbors = self.neighbors(inputs_i) - neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num) - _neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num + self.latent_dim) + neighbors_mask = expand_tensor(torch.tensor(_neighbors, device=device), -1, self.hidden_num) + _neighbors_mask = expand_tensor(torch.tensor(_neighbors, device=device), -1, self.hidden_num + self.latent_dim) # get concept embedding concept_embeddings = self.concept_embedding.weight.data @@ -133,8 +134,8 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la _in_state = self.n_in(_hidden_states) _out_state = self.n_out(_hidden_states) - in_weight = expand_tensor(torch.Tensor(self.in_weight(inputs_i)), -1, self.hidden_num) - out_weight = expand_tensor(torch.Tensor(self.out_weight(inputs_i)), -1, self.hidden_num) + in_weight = expand_tensor(torch.tensor(self.in_weight(inputs_i), device=device), -1, self.hidden_num) + out_weight = expand_tensor(torch.tensor(self.out_weight(inputs_i), device=device), -1, self.hidden_num) next_neighbors_states = in_weight * _in_state + out_weight * _out_state diff --git a/EduKTM/SKT/SKT.py 
b/EduKTM/SKT/SKT.py new file mode 100644 index 0000000..8701097 --- /dev/null +++ b/EduKTM/SKT/SKT.py @@ -0,0 +1,93 @@ +# coding: utf-8 + +import logging +import numpy as np +import torch +from tqdm import tqdm +from EduKTM import KTM +from .SKTNet import SKTNet +from EduKTM.utils import SLMLoss, tensor2list, pick +from sklearn.metrics import roc_auc_score, accuracy_score + + +class SKT(KTM): + def __init__(self, ku_num, graph_params, hidden_num, net_params: dict = None, loss_params=None): + super(SKT, self).__init__() + self.skt_model = SKTNet( + ku_num, + graph_params, + hidden_num, + **(net_params if net_params is not None else {}) + ) + self.loss_params = loss_params if loss_params is not None else {} + + def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...: + loss_function = SLMLoss(**self.loss_params).to(device) + self.skt_model = self.skt_model.to(device) + trainer = torch.optim.Adam(self.skt_model.parameters(), lr) + + for e in range(epoch): + losses = [] + for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e): + # convert to device + question: torch.Tensor = question.to(device) + data: torch.Tensor = data.to(device) + data_mask: torch.Tensor = data_mask.to(device) + label: torch.Tensor = label.to(device) + pick_index: torch.Tensor = pick_index.to(device) + label_mask: torch.Tensor = label_mask.to(device) + + # real training + predicted_response, _ = self.skt_model( + question, data, data_mask) + + loss = loss_function(predicted_response, + pick_index, label, label_mask) + + # back propagation + trainer.zero_grad() + loss.backward() + trainer.step() + + losses.append(loss.mean().item()) + print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses)))) + + if test_data is not None: + auc, accuracy = self.eval(test_data, device=device) + print("[Epoch %d] auc: %.6f, accuracy: %.6f" % + (e, auc, accuracy)) + + def eval(self, test_data, device="cpu") -> tuple: + self.skt_model.eval() + y_true = [] + y_pred = [] + + for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"): + # convert to device + question: torch.Tensor = question.to(device) + data: torch.Tensor = data.to(device) + data_mask: torch.Tensor = data_mask.to(device) + label: torch.Tensor = label.to(device) + pick_index: torch.Tensor = pick_index.to(device) + label_mask: torch.Tensor = label_mask.to(device) + + # real evaluating + output, _ = self.skt_model(question, data, data_mask) + output = output[:, :-1] + output = pick(output, pick_index.to(output.device)) + pred = tensor2list(output) + label = tensor2list(label) + for i, length in enumerate(label_mask.cpu().tolist()): + length = int(length) + y_true.extend(label[i][:length]) + y_pred.extend(pred[i][:length]) + self.skt_model.train() + return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5) + + def save(self, filepath) -> ...: + torch.save(self.skt_model.state_dict(), filepath) + logging.info("save parameters to %s" % filepath) + + def load(self, filepath): + self.skt_model.load_state_dict(torch.load(filepath)) + logging.info("load parameters from %s" % filepath) diff --git a/EduKTM/SKT/SKTNet.py b/EduKTM/SKT/SKTNet.py new file mode 100644 index 0000000..2660d1a --- /dev/null +++ b/EduKTM/SKT/SKTNet.py @@ -0,0 +1,176 @@ +__all__ = ["SKTNet"] + + +import torch +import torch.nn as nn +from torch import nn +import torch.nn.functional as F +from EduKTM.utils import GRUCell, begin_states, get_states, expand_tensor, \ + 
format_sequence, mask_sequence_variable_length +from .utils import Graph + + +class SKTNet(nn.Module): + def __init__(self, ku_num, graph_params=None, + alpha=0.5, + latent_dim=None, activation=None, + hidden_num=90, concept_dim=None, + # dropout=0.5, self_dropout=0.0, + dropout=0.0, self_dropout=0.5, + # dropout=0.0, self_dropout=0.0, + sync_dropout=0.0, + prop_dropout=0.0, + agg_dropout=0.0, + params=None): + super(SKTNet, self).__init__() + self.ku_num = int(ku_num) + self.hidden_num = self.ku_num if hidden_num is None else int( + hidden_num) + self.latent_dim = self.hidden_num if latent_dim is None else int( + latent_dim) + self.concept_dim = self.hidden_num if concept_dim is None else int( + concept_dim) + graph_params = graph_params if graph_params is not None else [] + self.graph = Graph.from_file(ku_num, graph_params) + self.alpha = alpha + + sync_activation = nn.ReLU() if activation is None else activation + prop_activation = nn.ReLU() if activation is None else activation + agg_activation = nn.ReLU() if activation is None else activation + + self.rnn = GRUCell(self.hidden_num) + self.response_embedding = nn.Embedding( + 2 * self.ku_num, self.latent_dim) + self.concept_embedding = nn.Embedding(self.ku_num, self.concept_dim) + self.f_self = GRUCell(self.hidden_num) + self.self_dropout = nn.Dropout(self_dropout) + self.f_prop = nn.Sequential( + nn.Linear(self.hidden_num * 2, self.hidden_num), + prop_activation, + nn.Dropout(prop_dropout), + ) + self.f_sync = nn.Sequential( + nn.Linear(self.hidden_num * 3, self.hidden_num), + sync_activation, + nn.Dropout(sync_dropout), + ) + self.f_agg = nn.Sequential( + nn.Linear(self.hidden_num, self.hidden_num), + agg_activation, + nn.Dropout(agg_dropout), + ) + self.dropout = nn.Dropout(dropout) + self.out = nn.Linear(self.hidden_num, 1) + self.sigmoid = nn.Sigmoid() + + def neighbors(self, x, ordinal=True): + return self.graph.neighbors(x, ordinal) + + def successors(self, x, ordinal=True): + return self.graph.successors(x, ordinal) + + def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True, + *args, **kwargs): + length = questions.shape[1] + device = questions.device + inputs, axis, batch_size = format_sequence( + length, questions, layout, False) + answers, _, _ = format_sequence(length, answers, layout, False) + states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0] + states = states.to(device) + outputs = [] + all_states = [] + for i in range(length): + inputs_i = inputs[i].reshape([batch_size, ]) + answer_i = answers[i].reshape([batch_size, ]) + + # concept embedding + concept_embeddings = self.concept_embedding.weight.data + concept_embeddings = expand_tensor( + concept_embeddings, 0, batch_size) + # concept_embeddings = (_self_mask + _successors_mask + _neighbors_mask) * concept_embeddings + + # self - influence + _self_state = get_states(inputs_i, states) + # fc + # _next_self_state = self.f_self(mx.nd.concat(_self_state, self.response_embedding(answers[i]), dim=-1)) + # gru + _next_self_state, _ = self.f_self( + self.response_embedding(answer_i), [_self_state]) + # _next_self_state = self.f_self(mx.nd.concat(_self_hidden_states, _self_state)) + # _next_self_state, _ = self.f_self(_self_hidden_states, [_self_state]) + _next_self_state = self.self_dropout(_next_self_state) + + # get self mask + _self_mask = torch.unsqueeze(F.one_hot(inputs_i, self.ku_num), -1) + _self_mask = torch.broadcast_to( + _self_mask, (-1, -1, self.hidden_num)) + + # find neighbors + _neighbors = 
self.neighbors(inputs_i) + _neighbors_mask = torch.unsqueeze( + torch.tensor(_neighbors, device=device), -1) + _neighbors_mask = torch.broadcast_to( + _neighbors_mask, (-1, -1, self.hidden_num)) + + # synchronization + _broadcast_next_self_states = torch.unsqueeze(_next_self_state, 1) + _broadcast_next_self_states = torch.broadcast_to( + _broadcast_next_self_states, (-1, self.ku_num, -1)) + # _sync_diff = mx.nd.concat(states, _broadcast_next_self_states, concept_embeddings, dim=-1) + _sync_diff = torch.concat( + (states, _broadcast_next_self_states, concept_embeddings), dim=-1) + _sync_inf = _neighbors_mask * self.f_sync(_sync_diff) + + # reflection on current vertex + _reflec_inf = torch.sum(_sync_inf, dim=1) + _reflec_inf = torch.broadcast_to( + torch.unsqueeze(_reflec_inf, 1), (-1, self.ku_num, -1)) + _sync_inf = _sync_inf + _self_mask * _reflec_inf + + # find successors + _successors = self.successors(inputs_i) + _successors_mask = torch.unsqueeze( + torch.tensor(_successors, device=device), -1) + _successors_mask = torch.broadcast_to( + _successors_mask, (-1, -1, self.hidden_num)) + + # propagation + _prop_diff = torch.concat( + (_next_self_state - _self_state, self.concept_embedding(inputs_i)), dim=-1) + # _prop_diff = _next_self_state - _self_state + + # 1 + _prop_inf = self.f_prop(_prop_diff) + _prop_inf = _successors_mask * \ + torch.broadcast_to(torch.unsqueeze( + _prop_inf, axis=1), (-1, self.ku_num, -1)) + # 2 + # _broadcast_diff = mx.nd.broadcast_to(mx.nd.expand_dims(_prop_diff, axis=1), (0, self.ku_num, 0)) + # _pro_inf = _successors_mask * self.f_prop( + # mx.nd.concat(_broadcast_diff, concept_embeddings, dim=-1) + # ) + # _pro_inf = _successors_mask * self.f_prop( + # _broadcast_diff + # ) + + # aggregate + _inf = self.f_agg(self.alpha * _sync_inf + + (1 - self.alpha) * _prop_inf) + next_states, _ = self.rnn(_inf, [states]) + # next_states, _ = self.rnn(torch.concat((_inf, concept_embeddings), dim=-1), [states]) + # states = (1 - _self_mask) * next_states + _self_mask * _broadcast_next_self_states + states = next_states + output = self.sigmoid(torch.squeeze( + self.out(self.dropout(states)), axis=-1)) + outputs.append(output) + if valid_length is not None and not compressed_out: + all_states.append([states]) + + if valid_length is not None: + if compressed_out: + states = None + outputs = mask_sequence_variable_length( + torch, outputs, length, valid_length, axis, merge=True) + + return outputs, states diff --git a/EduKTM/SKT/__init__.py b/EduKTM/SKT/__init__.py new file mode 100644 index 0000000..69108dd --- /dev/null +++ b/EduKTM/SKT/__init__.py @@ -0,0 +1,2 @@ +from .SKT import SKT +from .etl import etl diff --git a/EduKTM/SKT/etl.py b/EduKTM/SKT/etl.py new file mode 100644 index 0000000..3e91131 --- /dev/null +++ b/EduKTM/SKT/etl.py @@ -0,0 +1,80 @@ +# coding: utf-8 + + +import torch +import json +from tqdm import tqdm +from EduKTM.utils.torch import PadSequence, FixedBucketSampler + + +def extract(data_src, max_step=200): # pragma: no cover + responses = [] + step = max_step + with open(data_src) as f: + for line in tqdm(f, "reading data from %s" % data_src): + data = json.loads(line) + if step is not None: + for i in range(0, len(data), step): + if len(data[i: i + step]) < 2: + continue + responses.append(data[i: i + step]) + else: + responses.append(data) + + return responses + + +def transform(raw_data, batch_size, num_buckets=100): + # 定义数据转换接口 + # raw_data --> batch_data + + responses = raw_data + + batch_idxes = FixedBucketSampler( + [len(rs) for rs in 
responses], batch_size, num_buckets=num_buckets) + batch = [] + + def index(r): + correct = 0 if r[1] <= 0 else 1 + return r[0] * 2 + correct + + for batch_idx in tqdm(batch_idxes, "batchify"): + batch_qs = [] + batch_rs = [] + batch_pick_index = [] + batch_labels = [] + for idx in batch_idx: + batch_qs.append([r[0] for r in responses[idx]]) + batch_rs.append([index(r) for r in responses[idx]]) + if len(responses[idx]) <= 1: # pragma: no cover + pick_index, labels = [], [] + else: + pick_index, labels = zip( + *[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]]) + batch_pick_index.append(list(pick_index)) + batch_labels.append(list(labels)) + + max_len = max([len(rs) for rs in batch_rs]) + padder = PadSequence(max_len, pad_val=0) + batch_qs = [padder(qs) for qs in batch_qs] + batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs]) + + max_len = max([len(rs) for rs in batch_labels]) + padder = PadSequence(max_len, pad_val=0) + batch_labels, label_mask = zip( + *[(padder(labels), len(labels)) for labels in batch_labels]) + batch_pick_index = [padder(pick_index) + for pick_index in batch_pick_index] + # Load + batch.append( + [torch.tensor(batch_qs), torch.tensor(batch_rs), torch.tensor(data_mask), torch.tensor(batch_labels), + torch.tensor(batch_pick_index), + torch.tensor(label_mask)]) + + return batch + + +def etl(data_src, cfg=None, batch_size=None, **kwargs): # pragma: no cover + batch_size = batch_size if batch_size is not None else cfg.batch_size + raw_data = extract(data_src) + return transform(raw_data, batch_size, **kwargs) diff --git a/EduKTM/SKT/utils.py b/EduKTM/SKT/utils.py new file mode 100644 index 0000000..4bcf746 --- /dev/null +++ b/EduKTM/SKT/utils.py @@ -0,0 +1,198 @@ +# coding: utf-8 + + +import json +import torch +import networkx as nx +__all__ = ["Graph", "load_graph"] + + +def as_list(obj) -> list: + if isinstance(obj, list): + return obj + elif isinstance(obj, tuple): + return list(obj) + else: + return [obj] + + +class Graph(object): + def __init__(self, ku_num, directed_graphs, undirected_graphs): + self.ku_num = ku_num + self.directed_graphs = as_list(directed_graphs) + self.undirected_graphs = as_list(undirected_graphs) + + @staticmethod + def _info(graph: nx.Graph): + return {"edges": len(graph.edges)} + + @property + def info(self): + return { + "directed": [self._info(graph) for graph in self.directed_graphs], + "undirected": [self._info(graph) for graph in self.undirected_graphs] + } + + def neighbors(self, x, ordinal=True, merge_to_one=True, with_weight=False, excluded=None): + excluded = set() if excluded is None else excluded + + if isinstance(x, torch.Tensor): + x = x.tolist() + if isinstance(x, list): + return [self.neighbors(_x) for _x in x] + elif isinstance(x, (int, float)): + if not ordinal: + if len(self.undirected_graphs) == 0: + return None if not merge_to_one else [] + elif len(self.undirected_graphs) == 1: + return [v for v in self.undirected_graphs[0].neighbors(int(x)) if v not in excluded] + else: + if not merge_to_one: + return [[v for v in graph.neighbors(int(x)) if v not in excluded] for graph in + self.undirected_graphs] + else: + _ret = [] + for graph in self.undirected_graphs: + _ret.extend([v for v in graph.neighbors( + int(x)) if v not in excluded]) + return _ret + else: # ordinal + if not merge_to_one: + if len(self.undirected_graphs) == 0: + return None + elif len(self.undirected_graphs) == 1: + graph = self.undirected_graphs[0] + _ret = [0] * self.ku_num + for i in graph.neighbors(int(x)): + if i in excluded: 
+ continue + if with_weight: + _ret[i] = graph[x][i].get('weight', 1) + else: + _ret[i] = 1 + return _ret + else: + _ret = [] + for graph in self.undirected_graphs: + __ret = [0] * self.ku_num + for i in graph.neighbors(int(x)): + if i in excluded: + continue + if with_weight: + __ret[i] = graph[x][i].get('weight', 1) + else: + __ret[i] = 1 + _ret.append(__ret) + else: + if len(self.undirected_graphs) == 0: + return [0] * self.ku_num + else: + _ret = [0] * self.ku_num + for graph in self.undirected_graphs: + for i in graph.neighbors(int(x)): + if i in excluded: + continue + if with_weight: + _ret[i] += graph[x][i].get('weight', 1) + else: + _ret[i] = 1 + return _ret + else: + raise TypeError("cannot handle %s" % type(x)) + + def successors(self, x, ordinal=True, merge_to_one=True, excluded=None): + excluded = set() if excluded is None else excluded + + if isinstance(x, torch.Tensor): + x = x.tolist() + if isinstance(x, list): + return [self.neighbors(_x) for _x in x] + elif isinstance(x, (int, float)): + if not ordinal: + if len(self.directed_graphs) == 0: + return None if not merge_to_one else [] + elif len(self.directed_graphs) == 1: + return [v for v in self.directed_graphs[0].successors(int(x)) if v not in excluded] + else: + if not merge_to_one: + return [[v for v in graph.successors(int(x)) if v not in excluded] for graph in + self.directed_graphs] + else: + _ret = [] + for graph in self.directed_graphs: + _ret.extend([v for v in graph.successors( + int(x)) if v not in excluded]) + return _ret + else: + if not merge_to_one: + if len(self.directed_graphs) == 0: + return None + elif len(self.directed_graphs) == 1: + _ret = [0] * self.ku_num + for i in self.directed_graphs[0].successors(int(x)): + if i in excluded: + continue + _ret[i] = 1 + return _ret + else: + _ret = [] + for graph in self.directed_graphs: + __ret = [0] * self.ku_num + for i in graph.successors(int(x)): + if i in excluded: + continue + _ret[i] = 1 + _ret.append(__ret) + else: + if len(self.directed_graphs) == 0: + return [0] * self.ku_num + else: + _ret = [0] * self.ku_num + for graph in self.directed_graphs: + for i in graph.successors(int(x)): + if i in excluded: + continue + _ret[i] = 1 + return _ret + else: + raise TypeError("cannot handle %s" % type(x)) + + @classmethod + def from_file(cls, graph_nodes_num, graph_params): + directed_graphs = [] + undirected_graphs = [] + for graph_param in graph_params: + graph, directed = load_graph( + graph_nodes_num, *as_list(graph_param)) + if directed: + directed_graphs.append(graph) + else: + undirected_graphs.append(graph) + return cls(graph_nodes_num, directed_graphs, undirected_graphs) + + +def load_graph(graph_nodes_num, filename=None, directed: bool = True, threshold=0.0): + directed = bool(directed) + if directed: + graph = nx.DiGraph() + else: + graph = nx.Graph() + + graph.add_nodes_from(range(graph_nodes_num)) + if threshold < 0.0: + for i in range(graph_nodes_num): + for j in range(graph_nodes_num): + graph.add_edge(i, j) + else: + assert filename is not None + with open(filename) as f: + for data in json.load(f): + pre, suc = data[0], data[1] + if len(data) >= 3 and float(data[2]) < threshold: + continue + elif len(data) >= 3: + weight = float(data[2]) + graph.add_edge(pre, suc, weight=weight) + continue + graph.add_edge(pre, suc) + return graph, directed diff --git a/EduKTM/__init__.py b/EduKTM/__init__.py index 6dc7d36..7fdc264 100644 --- a/EduKTM/__init__.py +++ b/EduKTM/__init__.py @@ -10,3 +10,4 @@ from .LPKT import LPKT from .GKT import GKT from .DKVMN 
import DKVMN +from .SKT import SKT diff --git a/examples/SKT/SKT.py b/examples/SKT/SKT.py index fbdd098..5d0543f 100644 --- a/examples/SKT/SKT.py +++ b/examples/SKT/SKT.py @@ -1,13 +1,13 @@ # coding: utf-8 - import logging from EduKTM.GKT import etl from EduKTM import SKT import torch batch_size = 16 -train = etl("../../data/assistment_2009_2010/train.json", batch_size=batch_size) +train = etl("../../data/assistment_2009_2010/train.json", + batch_size=batch_size) valid = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) test = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) @@ -17,7 +17,7 @@ model = SKT(ku_num=124, graph_params=[ ['../../data/assistment_2009_2010/correct_transition_graph.json', True], ['../../data/assistment_2009_2010/ctrans_sim.json', False] - ], hidden_num=5) + ], hidden_num=5) model.train(train, valid, epoch=2, device=device) model.save("skt.params") From d7e9f1bb4d13bbf7e7441552a6ec7b1e591b5430 Mon Sep 17 00:00:00 2001 From: hwz0827 Date: Fri, 17 Mar 2023 15:28:41 +0000 Subject: [PATCH 3/8] [docs] update docs for skt --- AUTHORS.md | 2 ++ EduKTM/SKT/SKT.py | 1 + EduKTM/SKT/SKTNet.py | 2 ++ EduKTM/SKT/etl.py | 1 + EduKTM/SKT/utils.py | 1 + README.md | 1 + docs/SKT.md | 14 ++++++++++++++ examples/SKT/SKT.py | 1 + tests/skt/__init__.py | 2 ++ tests/skt/conftest.py | 34 ++++++++++++++++++++++++++++++++++ tests/skt/test_skt.py | 12 ++++++++++++ 11 files changed, 71 insertions(+) create mode 100644 docs/SKT.md create mode 100644 tests/skt/__init__.py create mode 100644 tests/skt/conftest.py create mode 100644 tests/skt/test_skt.py diff --git a/AUTHORS.md b/AUTHORS.md index f8933ea..373da82 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -10,4 +10,6 @@ [Jie Ouyang](https://github.com/0russwest0) +[Weizhe Huang](https://github.com/weizhehuang0827) + The starred is the corresponding author diff --git a/EduKTM/SKT/SKT.py b/EduKTM/SKT/SKT.py index 8701097..58e490b 100644 --- a/EduKTM/SKT/SKT.py +++ b/EduKTM/SKT/SKT.py @@ -1,4 +1,5 @@ # coding: utf-8 +# 2023/3/17 @ weizhehuang0827 import logging import numpy as np diff --git a/EduKTM/SKT/SKTNet.py b/EduKTM/SKT/SKTNet.py index 2660d1a..a67cb4d 100644 --- a/EduKTM/SKT/SKTNet.py +++ b/EduKTM/SKT/SKTNet.py @@ -1,3 +1,5 @@ +# coding: utf-8 +# 2023/3/17 @ weizhehuang0827 __all__ = ["SKTNet"] diff --git a/EduKTM/SKT/etl.py b/EduKTM/SKT/etl.py index 3e91131..a9043c3 100644 --- a/EduKTM/SKT/etl.py +++ b/EduKTM/SKT/etl.py @@ -1,4 +1,5 @@ # coding: utf-8 +# 2023/3/17 @ weizhehuang0827 import torch diff --git a/EduKTM/SKT/utils.py b/EduKTM/SKT/utils.py index 4bcf746..d6cc224 100644 --- a/EduKTM/SKT/utils.py +++ b/EduKTM/SKT/utils.py @@ -1,4 +1,5 @@ # coding: utf-8 +# 2023/3/17 @ weizhehuang0827 import json diff --git a/README.md b/README.md index 67167dc..eef70eb 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ Knowledge Tracing (KT), which aims to monitor students’ evolving knowledge sta * [GKT](EduKTM/GKT)[[doc]](docs/GKT.md) [[example]](examples/GKT) * [AKT](EduKTM/AKT) [[doc]](docs/AKT.md) [[example]](examples/AKT) * [LPKT](EduKTM/LPKT) [[doc]](docs/LPKT.md) [[example]](examples/LPKT) +* [SKT](EduKTM/SKT) [[doc]](docs/SKT.md) [[example]](examples/SKT) ## Contribute diff --git a/docs/SKT.md b/docs/SKT.md new file mode 100644 index 0000000..8b10e61 --- /dev/null +++ b/docs/SKT.md @@ -0,0 +1,14 @@ +# Structure-based Knowledge Tracing (SKT) + +If the reader wants to know the details of SKT, please refer to the Appendix of the paper: *[Structure-based Knowledge Tracing: An Influence Propagation 
View](https://rlgm.github.io/papers/70.pdf)*. + +```bibtex +@inproceedings{tong2020structure, + title={Structure-based knowledge tracing: An influence propagation view}, + author={Tong, Shiwei and Liu, Qi and Huang, Wei and Hunag, Zhenya and Chen, Enhong and Liu, Chuanren and Ma, Haiping and Wang, Shijin}, + booktitle={2020 IEEE international conference on data mining (ICDM)}, + pages={541--550}, + year={2020}, + organization={IEEE} +} +``` diff --git a/examples/SKT/SKT.py b/examples/SKT/SKT.py index 5d0543f..16b3d01 100644 --- a/examples/SKT/SKT.py +++ b/examples/SKT/SKT.py @@ -1,4 +1,5 @@ # coding: utf-8 +# 2023/3/17 @ weizhehuang0827 import logging from EduKTM.GKT import etl diff --git a/tests/skt/__init__.py b/tests/skt/__init__.py new file mode 100644 index 0000000..2fa4817 --- /dev/null +++ b/tests/skt/__init__.py @@ -0,0 +1,2 @@ +# coding: utf-8 +# 2023/3/17 @ weizhehuang0827 diff --git a/tests/skt/conftest.py b/tests/skt/conftest.py new file mode 100644 index 0000000..90fe83c --- /dev/null +++ b/tests/skt/conftest.py @@ -0,0 +1,34 @@ +# coding: utf-8 +# 2023/3/17 @ weizhehuang0827 +import pytest +import json +from EduKTM.utils.tests import pseudo_data_generation +from EduKTM.SKT.etl import transform + + +@pytest.fixture(scope="package") +def conf(): + ques_num = 10 + hidden_num = 5 + return ques_num, hidden_num + + +@pytest.fixture(scope="package") +def data(conf): + ques_num, _ = conf + return transform(pseudo_data_generation(ques_num), 16) + + +@pytest.fixture(scope="session") +def graphs(tmpdir_factory): + graph_dir = tmpdir_factory.mktemp("data").join("graph_1.json") + _graphs = [] + with open(graph_dir, 'w') as wf: + json.dump([[0, 1], [0, 2]], wf) + _graphs.append([graph_dir, False]) + _graphs.append([graph_dir, True]) + graph_dir = tmpdir_factory.mktemp("data").join("graph_2.json") + with open(graph_dir, 'w') as wf: + json.dump([[1, 2, 0.9], [1, 5, 0.1]], wf) + _graphs.append([graph_dir, True, 0.5]) + return _graphs diff --git a/tests/skt/test_skt.py b/tests/skt/test_skt.py new file mode 100644 index 0000000..89af4cc --- /dev/null +++ b/tests/skt/test_skt.py @@ -0,0 +1,12 @@ +# coding: utf-8 +# 2023/3/17 @ weizhehuang0827 +from EduKTM import SKT + + +def test_train(data, conf, graphs, tmp_path): + ku_num, hidden_num = conf + mgkt = SKT(ku_num, graphs, hidden_num) + mgkt.train(data, test_data=data, epoch=1) + filepath = tmp_path / "skt.params" + mgkt.save(filepath) + mgkt.load(filepath) From 4d69d08c49a367689467a4dbd5d0a0b9099bfb5c Mon Sep 17 00:00:00 2001 From: weizhehuang0827 <871982879@qq.com> Date: Sat, 18 Mar 2023 22:04:11 +0800 Subject: [PATCH 4/8] [docs] fix skt paper url --- docs/SKT.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/SKT.md b/docs/SKT.md index 8b10e61..ef3f198 100644 --- a/docs/SKT.md +++ b/docs/SKT.md @@ -1,6 +1,6 @@ # Structure-based Knowledge Tracing (SKT) -If the reader wants to know the details of SKT, please refer to the Appendix of the paper: *[Structure-based Knowledge Tracing: An Influence Propagation View](https://rlgm.github.io/papers/70.pdf)*. +If the reader wants to know the details of SKT, please refer to the Appendix of the paper: *[Structure-based Knowledge Tracing: An Influence Propagation View](http://staff.ustc.edu.cn/~huangzhy/files/papers/ShiweiTong-ICDM2020.pdf)*. 
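A minimal usage sketch, following `examples/SKT/SKT.py` in this repository (the dataset paths, `ku_num=124` and `hidden_num=5` are simply the values used in that example; each `graph_params` entry is `[json_file, directed]`, optionally followed by an edge-weight threshold):

```python
import torch
from EduKTM import SKT
from EduKTM.GKT import etl  # the example script reuses GKT's data pipeline; EduKTM.SKT also ships an etl

data_dir = "../../data/assistment_2009_2010"
train = etl(data_dir + "/train.json", batch_size=16)
test = etl(data_dir + "/test.json", batch_size=16)

model = SKT(
    ku_num=124,  # number of knowledge concepts
    graph_params=[
        [data_dir + "/correct_transition_graph.json", True],   # directed graph
        [data_dir + "/ctrans_sim.json", False],                # undirected graph
    ],
    hidden_num=5,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train(train, test, epoch=2, device=device)
model.save("skt.params")

model.load("skt.params")
auc, accuracy = model.eval(test, device=device)
```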
```bibtex @inproceedings{tong2020structure, From 5feecd1bb9dfd70cffc98328bb56c0ddfa34254d Mon Sep 17 00:00:00 2001 From: weizhehuang0827 <871982879@qq.com> Date: Sun, 19 Mar 2023 17:35:17 +0800 Subject: [PATCH 5/8] [style] fix flake8 --- EduKTM/GKT/GKT.py | 11 +++++++---- EduKTM/GKT/GKTNet.py | 45 +++++++++++++++++++++++++++++--------------- EduKTM/SKT/SKTNet.py | 5 ++--- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/EduKTM/GKT/GKT.py b/EduKTM/GKT/GKT.py index 1dc042c..51de62a 100644 --- a/EduKTM/GKT/GKT.py +++ b/EduKTM/GKT/GKT.py @@ -27,7 +27,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00 loss_function = SLMLoss(**self.loss_params).to(device) self.gkt_model = self.gkt_model.to(device) trainer = torch.optim.Adam(self.gkt_model.parameters(), lr) - + for e in range(epoch): losses = [] for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e): @@ -40,9 +40,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00 label_mask: torch.Tensor = label_mask.to(device) # real training - predicted_response, _ = self.gkt_model(question, data, data_mask) + predicted_response, _ = self.gkt_model( + question, data, data_mask) - loss = loss_function(predicted_response, pick_index, label, label_mask) + loss = loss_function(predicted_response, + pick_index, label, label_mask) # back propagation trainer.zero_grad() @@ -54,7 +56,8 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00 if test_data is not None: auc, accuracy = self.eval(test_data, device=device) - print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy)) + print("[Epoch %d] auc: %.6f, accuracy: %.6f" % + (e, auc, accuracy)) def eval(self, test_data, device="cpu") -> tuple: self.gkt_model.eval() diff --git a/EduKTM/GKT/GKTNet.py b/EduKTM/GKT/GKTNet.py index b9b091e..8a8df85 100644 --- a/EduKTM/GKT/GKTNet.py +++ b/EduKTM/GKT/GKTNet.py @@ -15,8 +15,10 @@ class GKTNet(nn.Module): def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0): super(GKTNet, self).__init__() self.ku_num = int(ku_num) - self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num) - self.latent_dim = self.ku_num if latent_dim is None else int(latent_dim) + self.hidden_num = self.ku_num if hidden_num is None else int( + hidden_num) + self.latent_dim = self.ku_num if latent_dim is None else int( + latent_dim) self.neighbor_dim = self.hidden_num + self.latent_dim self.graph = nx.DiGraph() self.graph.add_nodes_from(list(range(ku_num))) @@ -25,10 +27,12 @@ def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0) self.graph.add_weighted_edges_from(json.load(f)) except ValueError: with open(graph) as f: - self.graph.add_weighted_edges_from([e + [1.0] for e in json.load(f)]) + self.graph.add_weighted_edges_from( + [e + [1.0] for e in json.load(f)]) self.rnn = GRUCell(self.hidden_num) - self.response_embedding = nn.Embedding(2 * self.ku_num, self.latent_dim) + self.response_embedding = nn.Embedding( + 2 * self.ku_num, self.latent_dim) self.concept_embedding = nn.Embedding(self.ku_num, self.latent_dim) self.f_self = nn.Linear(self.neighbor_dim, self.hidden_num) self.n_out = nn.Linear(2 * self.neighbor_dim, self.hidden_num) @@ -96,7 +100,8 @@ def neighbors(self, x, ordinal=True, with_weight=False): def forward(self, questions, answers, valid_length=None, compressed_out=True, layout="NTC"): length = questions.shape[1] device = questions.device - inputs, axis, 
batch_size = format_sequence(length, questions, layout, False) + inputs, axis, batch_size = format_sequence( + length, questions, layout, False) answers, _, _ = format_sequence(length, answers, layout, False) states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0] states = states.to(device) @@ -108,12 +113,15 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la answer_i = answers[i].reshape([batch_size, ]) _neighbors = self.neighbors(inputs_i) - neighbors_mask = expand_tensor(torch.tensor(_neighbors, device=device), -1, self.hidden_num) - _neighbors_mask = expand_tensor(torch.tensor(_neighbors, device=device), -1, self.hidden_num + self.latent_dim) + neighbors_mask = expand_tensor(torch.tensor( + _neighbors, device=device), -1, self.hidden_num) + _neighbors_mask = expand_tensor(torch.tensor( + _neighbors, device=device), -1, self.hidden_num + self.latent_dim) # get concept embedding concept_embeddings = self.concept_embedding.weight.data - concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size) + concept_embeddings = expand_tensor( + concept_embeddings, 0, batch_size) agg_states = torch.cat((concept_embeddings, states), dim=-1) @@ -122,20 +130,25 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la # self - aggregate _concept_embedding = get_states(inputs_i, states) - _self_hidden_states = torch.cat((_concept_embedding, self.response_embedding(answer_i)), dim=-1) + _self_hidden_states = torch.cat( + (_concept_embedding, self.response_embedding(answer_i)), dim=-1) _self_mask = F.one_hot(inputs_i, self.ku_num) # p _self_mask = expand_tensor(_self_mask, -1, self.hidden_num) - self_hidden_states = expand_tensor(_self_hidden_states, 1, self.ku_num) + self_hidden_states = expand_tensor( + _self_hidden_states, 1, self.ku_num) # aggregate - _hidden_states = torch.cat((_neighbors_states, self_hidden_states), dim=-1) + _hidden_states = torch.cat( + (_neighbors_states, self_hidden_states), dim=-1) _in_state = self.n_in(_hidden_states) _out_state = self.n_out(_hidden_states) - in_weight = expand_tensor(torch.tensor(self.in_weight(inputs_i), device=device), -1, self.hidden_num) - out_weight = expand_tensor(torch.tensor(self.out_weight(inputs_i), device=device), -1, self.hidden_num) + in_weight = expand_tensor(torch.tensor(self.in_weight( + inputs_i), device=device), -1, self.hidden_num) + out_weight = expand_tensor(torch.tensor(self.out_weight( + inputs_i), device=device), -1, self.hidden_num) next_neighbors_states = in_weight * _in_state + out_weight * _out_state @@ -147,7 +160,8 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la next_states = neighbors_mask * next_neighbors_states + next_self_states next_states, _ = self.rnn(next_states, [states]) - next_states = (_self_mask + neighbors_mask) * next_states + (1 - _self_mask - neighbors_mask) * states + next_states = (_self_mask + neighbors_mask) * \ + next_states + (1 - _self_mask - neighbors_mask) * states states = self.dropout(next_states) output = torch.sigmoid(self.out(states).squeeze(axis=-1)) # p @@ -158,6 +172,7 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la if valid_length is not None: if compressed_out: states = None - outputs = mask_sequence_variable_length(torch, outputs, length, valid_length, axis, merge=True) + outputs = mask_sequence_variable_length( + torch, outputs, length, valid_length, axis, merge=True) return outputs, states diff --git a/EduKTM/SKT/SKTNet.py 
b/EduKTM/SKT/SKTNet.py index a67cb4d..df44384 100644 --- a/EduKTM/SKT/SKTNet.py +++ b/EduKTM/SKT/SKTNet.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn -from torch import nn import torch.nn.functional as F from EduKTM.utils import GRUCell, begin_states, get_states, expand_tensor, \ format_sequence, mask_sequence_variable_length @@ -157,8 +156,8 @@ def forward(self, questions, answers, valid_length=None, states=None, layout='NT # ) # aggregate - _inf = self.f_agg(self.alpha * _sync_inf + - (1 - self.alpha) * _prop_inf) + _inf = self.f_agg(self.alpha * _sync_inf + + (1 - self.alpha) * _prop_inf) next_states, _ = self.rnn(_inf, [states]) # next_states, _ = self.rnn(torch.concat((_inf, concept_embeddings), dim=-1), [states]) # states = (1 - _self_mask) * next_states + _self_mask * _broadcast_next_self_states From b585057c82ae41d7e6a53e3db12faccf0623611f Mon Sep 17 00:00:00 2001 From: weizhehuang0827 <871982879@qq.com> Date: Sun, 19 Mar 2023 17:51:41 +0800 Subject: [PATCH 6/8] [style] fix flake8 --- EduKTM/SKT/SKTNet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/EduKTM/SKT/SKTNet.py b/EduKTM/SKT/SKTNet.py index df44384..641929e 100644 --- a/EduKTM/SKT/SKTNet.py +++ b/EduKTM/SKT/SKTNet.py @@ -156,8 +156,7 @@ def forward(self, questions, answers, valid_length=None, states=None, layout='NT # ) # aggregate - _inf = self.f_agg(self.alpha * _sync_inf - + (1 - self.alpha) * _prop_inf) + _inf = self.f_agg(self.alpha * _sync_inf + (1 - self.alpha) * _prop_inf) next_states, _ = self.rnn(_inf, [states]) # next_states, _ = self.rnn(torch.concat((_inf, concept_embeddings), dim=-1), [states]) # states = (1 - _self_mask) * next_states + _self_mask * _broadcast_next_self_states From bfe6d1efbb7fc567a0b91128bd935d93724d7b26 Mon Sep 17 00:00:00 2001 From: weizhehuang0827 <871982879@qq.com> Date: Sun, 19 Mar 2023 22:48:14 +0800 Subject: [PATCH 7/8] [test] fix skt test converage --- EduKTM/SKT/SKTNet.py | 5 +- EduKTM/SKT/utils.py | 126 ++++++------------------------------------ tests/skt/conftest.py | 26 +++++++++ tests/skt/test_skt.py | 11 ++++ 4 files changed, 57 insertions(+), 111 deletions(-) diff --git a/EduKTM/SKT/SKTNet.py b/EduKTM/SKT/SKTNet.py index 641929e..f034816 100644 --- a/EduKTM/SKT/SKTNet.py +++ b/EduKTM/SKT/SKTNet.py @@ -80,7 +80,6 @@ def forward(self, questions, answers, valid_length=None, states=None, layout='NT states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0] states = states.to(device) outputs = [] - all_states = [] for i in range(length): inputs_i = inputs[i].reshape([batch_size, ]) answer_i = answers[i].reshape([batch_size, ]) @@ -164,8 +163,8 @@ def forward(self, questions, answers, valid_length=None, states=None, layout='NT output = self.sigmoid(torch.squeeze( self.out(self.dropout(states)), axis=-1)) outputs.append(output) - if valid_length is not None and not compressed_out: - all_states.append([states]) + # if valid_length is not None and not compressed_out: + # all_states.append([states]) if valid_length is not None: if compressed_out: diff --git a/EduKTM/SKT/utils.py b/EduKTM/SKT/utils.py index d6cc224..6692e34 100644 --- a/EduKTM/SKT/utils.py +++ b/EduKTM/SKT/utils.py @@ -13,8 +13,8 @@ def as_list(obj) -> list: return obj elif isinstance(obj, tuple): return list(obj) - else: - return [obj] + # else: + # return [obj] class Graph(object): @@ -42,64 +42,14 @@ def neighbors(self, x, ordinal=True, merge_to_one=True, with_weight=False, exclu if isinstance(x, list): return [self.neighbors(_x) for _x in x] elif isinstance(x, (int, float)): 
- if not ordinal: - if len(self.undirected_graphs) == 0: - return None if not merge_to_one else [] - elif len(self.undirected_graphs) == 1: - return [v for v in self.undirected_graphs[0].neighbors(int(x)) if v not in excluded] - else: - if not merge_to_one: - return [[v for v in graph.neighbors(int(x)) if v not in excluded] for graph in - self.undirected_graphs] - else: - _ret = [] - for graph in self.undirected_graphs: - _ret.extend([v for v in graph.neighbors( - int(x)) if v not in excluded]) - return _ret - else: # ordinal - if not merge_to_one: - if len(self.undirected_graphs) == 0: - return None - elif len(self.undirected_graphs) == 1: - graph = self.undirected_graphs[0] - _ret = [0] * self.ku_num - for i in graph.neighbors(int(x)): - if i in excluded: - continue - if with_weight: - _ret[i] = graph[x][i].get('weight', 1) - else: - _ret[i] = 1 - return _ret - else: - _ret = [] - for graph in self.undirected_graphs: - __ret = [0] * self.ku_num - for i in graph.neighbors(int(x)): - if i in excluded: - continue - if with_weight: - __ret[i] = graph[x][i].get('weight', 1) - else: - __ret[i] = 1 - _ret.append(__ret) - else: - if len(self.undirected_graphs) == 0: - return [0] * self.ku_num - else: - _ret = [0] * self.ku_num - for graph in self.undirected_graphs: - for i in graph.neighbors(int(x)): - if i in excluded: - continue - if with_weight: - _ret[i] += graph[x][i].get('weight', 1) - else: - _ret[i] = 1 - return _ret - else: - raise TypeError("cannot handle %s" % type(x)) + if len(self.undirected_graphs) == 0: + return [0] * self.ku_num + else: + _ret = [0] * self.ku_num + for graph in self.undirected_graphs: + for i in graph.neighbors(int(x)): + _ret[i] = 1 + return _ret def successors(self, x, ordinal=True, merge_to_one=True, excluded=None): excluded = set() if excluded is None else excluded @@ -107,56 +57,16 @@ def successors(self, x, ordinal=True, merge_to_one=True, excluded=None): if isinstance(x, torch.Tensor): x = x.tolist() if isinstance(x, list): - return [self.neighbors(_x) for _x in x] + return [self.successors(_x) for _x in x] elif isinstance(x, (int, float)): - if not ordinal: - if len(self.directed_graphs) == 0: - return None if not merge_to_one else [] - elif len(self.directed_graphs) == 1: - return [v for v in self.directed_graphs[0].successors(int(x)) if v not in excluded] - else: - if not merge_to_one: - return [[v for v in graph.successors(int(x)) if v not in excluded] for graph in - self.directed_graphs] - else: - _ret = [] - for graph in self.directed_graphs: - _ret.extend([v for v in graph.successors( - int(x)) if v not in excluded]) - return _ret + if len(self.directed_graphs) == 0: + return [0] * self.ku_num else: - if not merge_to_one: - if len(self.directed_graphs) == 0: - return None - elif len(self.directed_graphs) == 1: - _ret = [0] * self.ku_num - for i in self.directed_graphs[0].successors(int(x)): - if i in excluded: - continue - _ret[i] = 1 - return _ret - else: - _ret = [] - for graph in self.directed_graphs: - __ret = [0] * self.ku_num - for i in graph.successors(int(x)): - if i in excluded: - continue - _ret[i] = 1 - _ret.append(__ret) - else: - if len(self.directed_graphs) == 0: - return [0] * self.ku_num - else: - _ret = [0] * self.ku_num - for graph in self.directed_graphs: - for i in graph.successors(int(x)): - if i in excluded: - continue - _ret[i] = 1 - return _ret - else: - raise TypeError("cannot handle %s" % type(x)) + _ret = [0] * self.ku_num + for graph in self.directed_graphs: + for i in graph.successors(int(x)): + _ret[i] = 1 + return _ret 
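# With this simplification, `neighbors` and `successors` always return ordinal
# multi-hot vectors of length `ku_num`, merged over all undirected / directed
# graphs respectively; SKTNet.forward expands them into the masks used for the
# synchronization and propagation steps. A hypothetical example with ku_num=5
# and a graph file containing [[0, 1], [0, 2]] loaded as a directed graph:
#     g = Graph.from_file(5, [["graph.json", True]])
#     g.successors(0)       # -> [0, 1, 1, 0, 0]
#     g.successors([0, 3])  # -> [[0, 1, 1, 0, 0], [0, 0, 0, 0, 0]]
#     g.neighbors(0)        # -> [0, 0, 0, 0, 0]  (no undirected graph given)
# A torch.Tensor of concept ids is converted with .tolist() first, so the same
# calls accept a batch of question ids directly.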
@classmethod def from_file(cls, graph_nodes_num, graph_params): diff --git a/tests/skt/conftest.py b/tests/skt/conftest.py index 90fe83c..f0a11ee 100644 --- a/tests/skt/conftest.py +++ b/tests/skt/conftest.py @@ -32,3 +32,29 @@ def graphs(tmpdir_factory): json.dump([[1, 2, 0.9], [1, 5, 0.1]], wf) _graphs.append([graph_dir, True, 0.5]) return _graphs + + +@pytest.fixture(scope="session") +def graphs_2(tmpdir_factory): + graph_dir = tmpdir_factory.mktemp("data").join("graph_3.json") + _graphs = [] + with open(graph_dir, 'w') as wf: + json.dump([[0, 1], [0, 2]], wf) + _graphs.append((graph_dir, False)) + _graphs.append([graph_dir, True]) + graph_dir = tmpdir_factory.mktemp("data").join("graph_4.json") + with open(graph_dir, 'w') as wf: + json.dump([[1, 2, 0.9], [1, 5, 0.1]], wf) + _graphs.append([graph_dir, True, -1]) + return _graphs + + +@pytest.fixture(scope="session") +def graphs_3(tmpdir_factory): + graph_dir = tmpdir_factory.mktemp("data").join("graph_5.json") + _graphs = [] + with open(graph_dir, 'w') as wf: + json.dump([], wf) + _graphs.append((graph_dir, False)) + _graphs.append([graph_dir, True]) + return _graphs diff --git a/tests/skt/test_skt.py b/tests/skt/test_skt.py index 89af4cc..3fad8ad 100644 --- a/tests/skt/test_skt.py +++ b/tests/skt/test_skt.py @@ -1,6 +1,7 @@ # coding: utf-8 # 2023/3/17 @ weizhehuang0827 from EduKTM import SKT +from EduKTM.SKT.utils import Graph def test_train(data, conf, graphs, tmp_path): @@ -10,3 +11,13 @@ def test_train(data, conf, graphs, tmp_path): filepath = tmp_path / "skt.params" mgkt.save(filepath) mgkt.load(filepath) + + +def test_graph(conf, graphs_2): + ku_num, _ = conf + graph = Graph.from_file(ku_num, graphs_2) + graph.info + graph.neighbors(0, excluded=[1]) + none_graph = Graph(ku_num, [], []) + none_graph.neighbors(0) + none_graph.successors(0) From 21b4b24fbbc832bbef611b07fbc3d4c2b62619ae Mon Sep 17 00:00:00 2001 From: weizhehuang0827 <871982879@qq.com> Date: Wed, 22 Mar 2023 19:36:53 +0800 Subject: [PATCH 8/8] [revert] revert GKT --- EduKTM/GKT/GKT.py | 16 +++++--------- EduKTM/GKT/GKTNet.py | 52 +++++++++++++++----------------------------- EduKTM/SKT/SKTNet.py | 3 +-- 3 files changed, 25 insertions(+), 46 deletions(-) diff --git a/EduKTM/GKT/GKT.py b/EduKTM/GKT/GKT.py index 51de62a..3d477a8 100644 --- a/EduKTM/GKT/GKT.py +++ b/EduKTM/GKT/GKT.py @@ -24,8 +24,7 @@ def __init__(self, ku_num, graph, hidden_num, net_params: dict = None, loss_para self.loss_params = loss_params if loss_params is not None else {} def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...: - loss_function = SLMLoss(**self.loss_params).to(device) - self.gkt_model = self.gkt_model.to(device) + loss_function = SLMLoss(**self.loss_params) trainer = torch.optim.Adam(self.gkt_model.parameters(), lr) for e in range(epoch): @@ -40,11 +39,9 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00 label_mask: torch.Tensor = label_mask.to(device) # real training - predicted_response, _ = self.gkt_model( - question, data, data_mask) + predicted_response, _ = self.gkt_model(question, data, data_mask) - loss = loss_function(predicted_response, - pick_index, label, label_mask) + loss = loss_function(predicted_response, pick_index, label, label_mask) # back propagation trainer.zero_grad() @@ -55,9 +52,8 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00 print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses)))) if test_data is not None: - auc, accuracy = self.eval(test_data, 
-                print("[Epoch %d] auc: %.6f, accuracy: %.6f" %
-                      (e, auc, accuracy))
+                auc, accuracy = self.eval(test_data)
+                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
 
     def eval(self, test_data, device="cpu") -> tuple:
         self.gkt_model.eval()
@@ -79,7 +75,7 @@ def eval(self, test_data, device="cpu") -> tuple:
                 output = pick(output, pick_index.to(output.device))
                 pred = tensor2list(output)
                 label = tensor2list(label)
-                for i, length in enumerate(label_mask.cpu().tolist()):
+                for i, length in enumerate(label_mask.numpy().tolist()):
                     length = int(length)
                     y_true.extend(label[i][:length])
                     y_pred.extend(pred[i][:length])
diff --git a/EduKTM/GKT/GKTNet.py b/EduKTM/GKT/GKTNet.py
index 9673f6e..821c259 100644
--- a/EduKTM/GKT/GKTNet.py
+++ b/EduKTM/GKT/GKTNet.py
@@ -15,10 +15,8 @@ class GKTNet(nn.Module):
     def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0):
         super(GKTNet, self).__init__()
         self.ku_num = int(ku_num)
-        self.hidden_num = self.ku_num if hidden_num is None else int(
-            hidden_num)
-        self.latent_dim = self.ku_num if latent_dim is None else int(
-            latent_dim)
+        self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num)
+        self.latent_dim = self.ku_num if latent_dim is None else int(latent_dim)
         self.neighbor_dim = self.hidden_num + self.latent_dim
         self.graph = nx.DiGraph()
         self.graph.add_nodes_from(list(range(ku_num)))
@@ -27,12 +25,10 @@ def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0)
                 self.graph.add_weighted_edges_from(json.load(f))
         except ValueError:
             with open(graph) as f:
-                self.graph.add_weighted_edges_from(
-                    [e + [1.0] for e in json.load(f)])
+                self.graph.add_weighted_edges_from([e + [1.0] for e in json.load(f)])
 
         self.rnn = GRUCell(self.hidden_num)
-        self.response_embedding = nn.Embedding(
-            2 * self.ku_num, self.latent_dim)
+        self.response_embedding = nn.Embedding(2 * self.ku_num, self.latent_dim)
         self.concept_embedding = nn.Embedding(self.ku_num, self.latent_dim)
         self.f_self = nn.Linear(self.neighbor_dim, self.hidden_num)
         self.n_out = nn.Linear(2 * self.neighbor_dim, self.hidden_num)
@@ -42,7 +38,7 @@ def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0)
 
     def in_weight(self, x, ordinal=True, with_weight=True):
         if isinstance(x, torch.Tensor):
-            x = x.cpu().numpy().tolist()
+            x = x.numpy().tolist()
         if isinstance(x, list):
             return [self.in_weight(_x) for _x in x]
         elif isinstance(x, (int, float)):
@@ -53,7 +49,7 @@
 
     def out_weight(self, x, ordinal=True, with_weight=True):
         if isinstance(x, torch.Tensor):
-            x = x.cpu().numpy().tolist()
+            x = x.numpy().tolist()
         if isinstance(x, list):
             return [self.out_weight(_x) for _x in x]
         elif isinstance(x, (int, float)):
@@ -64,7 +60,7 @@
 
     def neighbors(self, x, ordinal=True, with_weight=False):
         if isinstance(x, torch.Tensor):
-            x = x.cpu().numpy().tolist()
+            x = x.numpy().tolist()
         if isinstance(x, list):
             return [self.neighbors(_x) for _x in x]
         elif isinstance(x, (int, float)):
@@ -75,12 +71,10 @@
 
     def forward(self, questions, answers, valid_length=None, compressed_out=True, layout="NTC"):
         length = questions.shape[1]
-        device = questions.device
-        inputs, axis, batch_size = format_sequence(
-            length, questions, layout, False)
+        inputs, axis, batch_size = format_sequence(length, questions, layout, False)
         answers, _, _ = format_sequence(length, answers, layout, False)
 
+        states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0]
-        states = states.to(device)
         outputs = []
         for i in range(length):
             # neighbors - aggregate
@@ -88,15 +82,12 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la
             answer_i = answers[i].reshape([batch_size, ])
 
             _neighbors = self.neighbors(inputs_i)
-            neighbors_mask = expand_tensor(torch.tensor(
-                _neighbors, device=device), -1, self.hidden_num)
-            _neighbors_mask = expand_tensor(torch.tensor(
-                _neighbors, device=device), -1, self.hidden_num + self.latent_dim)
+            neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num)
+            _neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num + self.latent_dim)
 
             # get concept embedding
             concept_embeddings = self.concept_embedding.weight.data
-            concept_embeddings = expand_tensor(
-                concept_embeddings, 0, batch_size)
+            concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size)
 
             agg_states = torch.cat((concept_embeddings, states), dim=-1)
 
@@ -105,25 +96,20 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la
 
             # self - aggregate
             _concept_embedding = get_states(inputs_i, states)
-            _self_hidden_states = torch.cat(
-                (_concept_embedding, self.response_embedding(answer_i)), dim=-1)
+            _self_hidden_states = torch.cat((_concept_embedding, self.response_embedding(answer_i)), dim=-1)
 
             _self_mask = F.one_hot(inputs_i, self.ku_num)  # p
             _self_mask = expand_tensor(_self_mask, -1, self.hidden_num)
 
-            self_hidden_states = expand_tensor(
-                _self_hidden_states, 1, self.ku_num)
+            self_hidden_states = expand_tensor(_self_hidden_states, 1, self.ku_num)
 
             # aggregate
-            _hidden_states = torch.cat(
-                (_neighbors_states, self_hidden_states), dim=-1)
+            _hidden_states = torch.cat((_neighbors_states, self_hidden_states), dim=-1)
             _in_state = self.n_in(_hidden_states)
             _out_state = self.n_out(_hidden_states)
-            in_weight = expand_tensor(torch.tensor(self.in_weight(
-                inputs_i), device=device), -1, self.hidden_num)
-            out_weight = expand_tensor(torch.tensor(self.out_weight(
-                inputs_i), device=device), -1, self.hidden_num)
+            in_weight = expand_tensor(torch.Tensor(self.in_weight(inputs_i)), -1, self.hidden_num)
+            out_weight = expand_tensor(torch.Tensor(self.out_weight(inputs_i)), -1, self.hidden_num)
 
             next_neighbors_states = in_weight * _in_state + out_weight * _out_state
 
@@ -135,8 +121,7 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la
             next_states = neighbors_mask * next_neighbors_states + next_self_states
 
             next_states, _ = self.rnn(next_states, [states])
-            next_states = (_self_mask + neighbors_mask) * \
-                next_states + (1 - _self_mask - neighbors_mask) * states
+            next_states = (_self_mask + neighbors_mask) * next_states + (1 - _self_mask - neighbors_mask) * states
             states = self.dropout(next_states)
 
             output = torch.sigmoid(self.out(states).squeeze(axis=-1))  # p
@@ -147,5 +132,4 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la
                 states = None
             outputs = mask_sequence_variable_length(torch, outputs, valid_length)
-
 
         return outputs, states
diff --git a/EduKTM/SKT/SKTNet.py b/EduKTM/SKT/SKTNet.py
index f034816..064d187 100644
--- a/EduKTM/SKT/SKTNet.py
+++ b/EduKTM/SKT/SKTNet.py
@@ -169,7 +169,6 @@ def forward(self, questions, answers, valid_length=None, states=None, layout='NT
         if valid_length is not None:
             if compressed_out:
                 states = None
-            outputs = mask_sequence_variable_length(
-                torch, outputs, length, valid_length, axis, merge=True)
+            outputs = mask_sequence_variable_length(torch, outputs, valid_length)
 
         return outputs, states