diff --git a/AUTHORS.md b/AUTHORS.md
index 69fe259..a78cebc 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -10,6 +10,8 @@
 [Jie Ouyang](https://github.com/0russwest0)
 
+[Weizhe Huang](https://github.com/weizhehuang0827)
+
 [Bihan Xu](https://github.com/xbh0720)
 
 The starred is the corresponding author
diff --git a/EduKTM/SKT/SKT.py b/EduKTM/SKT/SKT.py
new file mode 100644
index 0000000..58e490b
--- /dev/null
+++ b/EduKTM/SKT/SKT.py
@@ -0,0 +1,94 @@
+# coding: utf-8
+# 2023/3/17 @ weizhehuang0827
+
+import logging
+import numpy as np
+import torch
+from tqdm import tqdm
+from EduKTM import KTM
+from .SKTNet import SKTNet
+from EduKTM.utils import SLMLoss, tensor2list, pick
+from sklearn.metrics import roc_auc_score, accuracy_score
+
+
+class SKT(KTM):
+    def __init__(self, ku_num, graph_params, hidden_num, net_params: dict = None, loss_params=None):
+        super(SKT, self).__init__()
+        self.skt_model = SKTNet(
+            ku_num,
+            graph_params,
+            hidden_num,
+            **(net_params if net_params is not None else {})
+        )
+        self.loss_params = loss_params if loss_params is not None else {}
+
+    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
+        loss_function = SLMLoss(**self.loss_params).to(device)
+        self.skt_model = self.skt_model.to(device)
+        trainer = torch.optim.Adam(self.skt_model.parameters(), lr)
+
+        for e in range(epoch):
+            losses = []
+            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
+                # move the batch to the target device
+                question: torch.Tensor = question.to(device)
+                data: torch.Tensor = data.to(device)
+                data_mask: torch.Tensor = data_mask.to(device)
+                label: torch.Tensor = label.to(device)
+                pick_index: torch.Tensor = pick_index.to(device)
+                label_mask: torch.Tensor = label_mask.to(device)
+
+                # forward pass and loss
+                predicted_response, _ = self.skt_model(
+                    question, data, data_mask)
+
+                loss = loss_function(predicted_response,
+                                     pick_index, label, label_mask)
+
+                # back propagation
+                trainer.zero_grad()
+                loss.backward()
+                trainer.step()
+
+                losses.append(loss.mean().item())
+            print("[Epoch %d] SLMLoss: %.6f" % (e, float(np.mean(losses))))
+
+            if test_data is not None:
+                auc, accuracy = self.eval(test_data, device=device)
+                print("[Epoch %d] auc: %.6f, accuracy: %.6f" %
+                      (e, auc, accuracy))
+
+    def eval(self, test_data, device="cpu") -> tuple:
+        self.skt_model.eval()
+        y_true = []
+        y_pred = []
+
+        for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
+            # move the batch to the target device
+            question: torch.Tensor = question.to(device)
+            data: torch.Tensor = data.to(device)
+            data_mask: torch.Tensor = data_mask.to(device)
+            label: torch.Tensor = label.to(device)
+            pick_index: torch.Tensor = pick_index.to(device)
+            label_mask: torch.Tensor = label_mask.to(device)
+
+            # forward pass; gather the prediction for each next question
+            output, _ = self.skt_model(question, data, data_mask)
+            output = output[:, :-1]
+            output = pick(output, pick_index.to(output.device))
+            pred = tensor2list(output)
+            label = tensor2list(label)
+            for i, length in enumerate(label_mask.cpu().tolist()):
+                length = int(length)
+                y_true.extend(label[i][:length])
+                y_pred.extend(pred[i][:length])
+        self.skt_model.train()
+        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)
+
+    def save(self, filepath) -> ...:
+        torch.save(self.skt_model.state_dict(), filepath)
+        logging.info("save parameters to %s" % filepath)
+
+    def load(self, filepath):
+        self.skt_model.load_state_dict(torch.load(filepath))
+        logging.info("load parameters from %s" % filepath)
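A note on the evaluation path in `SKT.eval` above: the prediction at the last time step is dropped (`output[:, :-1]`), and `pick` then selects, at every remaining step, the probability the model assigned to the question answered next, so predictions line up with `label`/`label_mask` one step apart. The snippet below is a minimal sketch of that gather under the assumption that `pick` indexes along the last axis; `pick_sketch` is a hypothetical stand-in, not the actual `EduKTM.utils.pick`.

```python
import torch


def pick_sketch(output: torch.Tensor, pick_index: torch.Tensor) -> torch.Tensor:
    # output: (batch, steps, ku_num) per-concept probabilities
    # pick_index: (batch, steps) id of the question answered at the next step
    return torch.gather(output, -1, pick_index.unsqueeze(-1)).squeeze(-1)


probs = torch.rand(2, 3, 5)                    # 2 students, 3 steps, 5 concepts
next_q = torch.tensor([[1, 4, 2], [0, 0, 3]])  # next-question ids per step
print(pick_sketch(probs, next_q).shape)        # torch.Size([2, 3])
```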
diff --git a/EduKTM/SKT/SKTNet.py b/EduKTM/SKT/SKTNet.py
new file mode 100644
index 0000000..064d187
--- /dev/null
+++ b/EduKTM/SKT/SKTNet.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+# 2023/3/17 @ weizhehuang0827
+__all__ = ["SKTNet"]
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from EduKTM.utils import GRUCell, begin_states, get_states, expand_tensor, \
+    format_sequence, mask_sequence_variable_length
+from .utils import Graph
+
+
+class SKTNet(nn.Module):
+    def __init__(self, ku_num, graph_params=None,
+                 alpha=0.5,
+                 latent_dim=None, activation=None,
+                 hidden_num=90, concept_dim=None,
+                 # dropout=0.5, self_dropout=0.0,
+                 dropout=0.0, self_dropout=0.5,
+                 # dropout=0.0, self_dropout=0.0,
+                 sync_dropout=0.0,
+                 prop_dropout=0.0,
+                 agg_dropout=0.0,
+                 params=None):
+        super(SKTNet, self).__init__()
+        self.ku_num = int(ku_num)
+        self.hidden_num = self.ku_num if hidden_num is None else int(
+            hidden_num)
+        self.latent_dim = self.hidden_num if latent_dim is None else int(
+            latent_dim)
+        self.concept_dim = self.hidden_num if concept_dim is None else int(
+            concept_dim)
+        graph_params = graph_params if graph_params is not None else []
+        self.graph = Graph.from_file(ku_num, graph_params)
+        self.alpha = alpha
+
+        sync_activation = nn.ReLU() if activation is None else activation
+        prop_activation = nn.ReLU() if activation is None else activation
+        agg_activation = nn.ReLU() if activation is None else activation
+
+        self.rnn = GRUCell(self.hidden_num)
+        self.response_embedding = nn.Embedding(
+            2 * self.ku_num, self.latent_dim)
+        self.concept_embedding = nn.Embedding(self.ku_num, self.concept_dim)
+        self.f_self = GRUCell(self.hidden_num)
+        self.self_dropout = nn.Dropout(self_dropout)
+        self.f_prop = nn.Sequential(
+            nn.Linear(self.hidden_num * 2, self.hidden_num),
+            prop_activation,
+            nn.Dropout(prop_dropout),
+        )
+        self.f_sync = nn.Sequential(
+            nn.Linear(self.hidden_num * 3, self.hidden_num),
+            sync_activation,
+            nn.Dropout(sync_dropout),
+        )
+        self.f_agg = nn.Sequential(
+            nn.Linear(self.hidden_num, self.hidden_num),
+            agg_activation,
+            nn.Dropout(agg_dropout),
+        )
+        self.dropout = nn.Dropout(dropout)
+        self.out = nn.Linear(self.hidden_num, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def neighbors(self, x, ordinal=True):
+        return self.graph.neighbors(x, ordinal)
+
+    def successors(self, x, ordinal=True):
+        return self.graph.successors(x, ordinal)
+
+    def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True,
+                *args, **kwargs):
+        length = questions.shape[1]
+        device = questions.device
+        inputs, axis, batch_size = format_sequence(
+            length, questions, layout, False)
+        answers, _, _ = format_sequence(length, answers, layout, False)
+        states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0]
+        states = states.to(device)
+        outputs = []
+        for i in range(length):
+            inputs_i = inputs[i].reshape([batch_size, ])
+            answer_i = answers[i].reshape([batch_size, ])
+
+            # concept embedding, expanded over the batch
+            concept_embeddings = self.concept_embedding.weight.data
+            concept_embeddings = expand_tensor(
+                concept_embeddings, 0, batch_size)
+            # concept_embeddings = (_self_mask + _successors_mask + _neighbors_mask) * concept_embeddings
+
+            # self-influence: update the answered vertex with the response (GRU variant)
+            _self_state = get_states(inputs_i, states)
+            # fc
+            # _next_self_state = self.f_self(mx.nd.concat(_self_state, self.response_embedding(answers[i]), dim=-1))
+            # gru
+            _next_self_state, _ = self.f_self(
+                self.response_embedding(answer_i), [_self_state])
+            # _next_self_state = self.f_self(mx.nd.concat(_self_hidden_states, _self_state))
+            # _next_self_state, _ = self.f_self(_self_hidden_states, [_self_state])
+            _next_self_state = self.self_dropout(_next_self_state)
+
+            # get self mask
+            _self_mask = torch.unsqueeze(F.one_hot(inputs_i, self.ku_num), -1)
+            _self_mask = torch.broadcast_to(
+                _self_mask, (-1, -1, self.hidden_num))
+
+            # find neighbors
+            _neighbors = self.neighbors(inputs_i)
+            _neighbors_mask = torch.unsqueeze(
+                torch.tensor(_neighbors, device=device), -1)
+            _neighbors_mask = torch.broadcast_to(
+                _neighbors_mask, (-1, -1, self.hidden_num))
+
+            # synchronization
+            _broadcast_next_self_states = torch.unsqueeze(_next_self_state, 1)
+            _broadcast_next_self_states = torch.broadcast_to(
+                _broadcast_next_self_states, (-1, self.ku_num, -1))
+            # _sync_diff = mx.nd.concat(states, _broadcast_next_self_states, concept_embeddings, dim=-1)
+            _sync_diff = torch.concat(
+                (states, _broadcast_next_self_states, concept_embeddings), dim=-1)
+            _sync_inf = _neighbors_mask * self.f_sync(_sync_diff)
+
+            # reflection on current vertex
+            _reflec_inf = torch.sum(_sync_inf, dim=1)
+            _reflec_inf = torch.broadcast_to(
+                torch.unsqueeze(_reflec_inf, 1), (-1, self.ku_num, -1))
+            _sync_inf = _sync_inf + _self_mask * _reflec_inf
+
+            # find successors
+            _successors = self.successors(inputs_i)
+            _successors_mask = torch.unsqueeze(
+                torch.tensor(_successors, device=device), -1)
+            _successors_mask = torch.broadcast_to(
+                _successors_mask, (-1, -1, self.hidden_num))
+
+            # propagation
+            _prop_diff = torch.concat(
+                (_next_self_state - _self_state, self.concept_embedding(inputs_i)), dim=-1)
+            # _prop_diff = _next_self_state - _self_state
+
+            # 1
+            _prop_inf = self.f_prop(_prop_diff)
+            _prop_inf = _successors_mask * \
+                torch.broadcast_to(torch.unsqueeze(
+                    _prop_inf, dim=1), (-1, self.ku_num, -1))
+            # 2
+            # _broadcast_diff = mx.nd.broadcast_to(mx.nd.expand_dims(_prop_diff, axis=1), (0, self.ku_num, 0))
+            # _pro_inf = _successors_mask * self.f_prop(
+            #     mx.nd.concat(_broadcast_diff, concept_embeddings, dim=-1)
+            # )
+            # _pro_inf = _successors_mask * self.f_prop(
+            #     _broadcast_diff
+            # )
+
+            # aggregate
+            _inf = self.f_agg(self.alpha * _sync_inf + (1 - self.alpha) * _prop_inf)
+            next_states, _ = self.rnn(_inf, [states])
+            # next_states, _ = self.rnn(torch.concat((_inf, concept_embeddings), dim=-1), [states])
+            # states = (1 - _self_mask) * next_states + _self_mask * _broadcast_next_self_states
+            states = next_states
+            output = self.sigmoid(torch.squeeze(
+                self.out(self.dropout(states)), dim=-1))
+            outputs.append(output)
+            # if valid_length is not None and not compressed_out:
+            #     all_states.append([states])
+
+        if valid_length is not None:
+            if compressed_out:
+                states = None
+            outputs = mask_sequence_variable_length(torch, outputs, valid_length)
+
+        return outputs, states
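The heart of `SKTNet.forward` is the aggregation step: a synchronization influence flows between undirected neighbors (`f_sync`, gated by `_neighbors_mask`), a propagation influence flows to directed successors (`f_prop`, gated by `_successors_mask`), and the two are mixed with weight `alpha` before a GRU updates all concept states. Below is a toy sketch of that mixing with constants standing in for the learned networks; the values are illustrative only.

```python
import torch

ku_num, hidden_num, alpha = 4, 3, 0.5
sync_inf = torch.ones(1, ku_num, hidden_num)         # stand-in for f_sync output
prop_inf = torch.full((1, ku_num, hidden_num), 2.0)  # stand-in for f_prop output

# 0/1 masks, as Graph.neighbors / Graph.successors would produce for one vertex
neighbors_mask = torch.tensor([0., 1., 1., 0.]).view(1, ku_num, 1)
successors_mask = torch.tensor([0., 0., 1., 1.]).view(1, ku_num, 1)

inf = alpha * (neighbors_mask * sync_inf) + (1 - alpha) * (successors_mask * prop_inf)
print(inf[0, :, 0])  # tensor([0.0, 0.5, 1.5, 1.0]): only masked vertices receive influence
```

In the real module this combined influence additionally passes through `f_agg` and then drives the state update `next_states, _ = self.rnn(_inf, [states])`.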
diff --git a/EduKTM/SKT/__init__.py b/EduKTM/SKT/__init__.py
new file mode 100644
index 0000000..69108dd
--- /dev/null
+++ b/EduKTM/SKT/__init__.py
@@ -0,0 +1,2 @@
+from .SKT import SKT
+from .etl import etl
diff --git a/EduKTM/SKT/etl.py b/EduKTM/SKT/etl.py
new file mode 100644
index 0000000..a9043c3
--- /dev/null
+++ b/EduKTM/SKT/etl.py
@@ -0,0 +1,81 @@
+# coding: utf-8
+# 2023/3/17 @ weizhehuang0827
+
+
+import torch
+import json
+from tqdm import tqdm
+from EduKTM.utils.torch import PadSequence, FixedBucketSampler
+
+
+def extract(data_src, max_step=200):  # pragma: no cover
+    responses = []
+    step = max_step
+    with open(data_src) as f:
+        for line in tqdm(f, "reading data from %s" % data_src):
+            data = json.loads(line)
+            if step is not None:
+                for i in range(0, len(data), step):
+                    if len(data[i: i + step]) < 2:
+                        continue
+                    responses.append(data[i: i + step])
+            else:
+                responses.append(data)
+
+    return responses
+
+
+def transform(raw_data, batch_size, num_buckets=100):
+    # data transformation interface:
+    # raw_data --> batch_data
+
+    responses = raw_data
+
+    batch_idxes = FixedBucketSampler(
+        [len(rs) for rs in responses], batch_size, num_buckets=num_buckets)
+    batch = []
+
+    def index(r):
+        correct = 0 if r[1] <= 0 else 1
+        return r[0] * 2 + correct
+
+    for batch_idx in tqdm(batch_idxes, "batchify"):
+        batch_qs = []
+        batch_rs = []
+        batch_pick_index = []
+        batch_labels = []
+        for idx in batch_idx:
+            batch_qs.append([r[0] for r in responses[idx]])
+            batch_rs.append([index(r) for r in responses[idx]])
+            if len(responses[idx]) <= 1:  # pragma: no cover
+                pick_index, labels = [], []
+            else:
+                pick_index, labels = zip(
+                    *[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]])
+            batch_pick_index.append(list(pick_index))
+            batch_labels.append(list(labels))
+
+        max_len = max([len(rs) for rs in batch_rs])
+        padder = PadSequence(max_len, pad_val=0)
+        batch_qs = [padder(qs) for qs in batch_qs]
+        batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])
+
+        max_len = max([len(rs) for rs in batch_labels])
+        padder = PadSequence(max_len, pad_val=0)
+        batch_labels, label_mask = zip(
+            *[(padder(labels), len(labels)) for labels in batch_labels])
+        batch_pick_index = [padder(pick_index)
+                            for pick_index in batch_pick_index]
+        # load everything into tensors
+        batch.append(
+            [torch.tensor(batch_qs), torch.tensor(batch_rs), torch.tensor(data_mask), torch.tensor(batch_labels),
+             torch.tensor(batch_pick_index),
+             torch.tensor(label_mask)])
+
+    return batch
+
+
+def etl(data_src, cfg=None, batch_size=None, **kwargs):  # pragma: no cover
+    batch_size = batch_size if batch_size is not None else cfg.batch_size
+    raw_data = extract(data_src)
+    return transform(raw_data, batch_size, **kwargs)
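`transform` above encodes each raw record `[question_id, correct]` into a response id `question_id * 2 + correct`, and builds per-batch tensors `(questions, responses, data_mask, labels, pick_index, label_mask)`, where `pick_index`/`labels` are the sequence shifted by one step so the model is always scored on the *next* response. A tiny worked example of that encoding, mirroring the code above (plain Python, without the padding and bucketing):

```python
raw = [[1, 1], [2, 0], [3, 1]]  # one student's (question_id, correct) records

qs = [r[0] for r in raw]                                # [1, 2, 3]
rs = [r[0] * 2 + (0 if r[1] <= 0 else 1) for r in raw]  # [3, 4, 7]
pick_index = [r[0] for r in raw[1:]]                    # [2, 3] -> next questions
labels = [0 if r[1] <= 0 else 1 for r in raw[1:]]       # [0, 1] -> next correctness
print(qs, rs, pick_index, labels)
```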
diff --git a/EduKTM/SKT/utils.py b/EduKTM/SKT/utils.py
new file mode 100644
index 0000000..6692e34
--- /dev/null
+++ b/EduKTM/SKT/utils.py
@@ -0,0 +1,109 @@
+# coding: utf-8
+# 2023/3/17 @ weizhehuang0827
+
+
+import json
+import torch
+import networkx as nx
+__all__ = ["Graph", "load_graph"]
+
+
+def as_list(obj) -> list:
+    if isinstance(obj, list):
+        return obj
+    elif isinstance(obj, tuple):
+        return list(obj)
+    else:
+        return [obj]
+
+
+class Graph(object):
+    def __init__(self, ku_num, directed_graphs, undirected_graphs):
+        self.ku_num = ku_num
+        self.directed_graphs = as_list(directed_graphs)
+        self.undirected_graphs = as_list(undirected_graphs)
+
+    @staticmethod
+    def _info(graph: nx.Graph):
+        return {"edges": len(graph.edges)}
+
+    @property
+    def info(self):
+        return {
+            "directed": [self._info(graph) for graph in self.directed_graphs],
+            "undirected": [self._info(graph) for graph in self.undirected_graphs]
+        }
+
+    def neighbors(self, x, ordinal=True, merge_to_one=True, with_weight=False, excluded=None):
+        excluded = set() if excluded is None else excluded
+
+        if isinstance(x, torch.Tensor):
+            x = x.tolist()
+        if isinstance(x, list):
+            return [self.neighbors(_x) for _x in x]
+        elif isinstance(x, (int, float)):
+            if len(self.undirected_graphs) == 0:
+                return [0] * self.ku_num
+            else:
+                _ret = [0] * self.ku_num
+                for graph in self.undirected_graphs:
+                    for i in graph.neighbors(int(x)):
+                        _ret[i] = 1
+                return _ret
+
+    def successors(self, x, ordinal=True, merge_to_one=True, excluded=None):
+        excluded = set() if excluded is None else excluded
+
+        if isinstance(x, torch.Tensor):
+            x = x.tolist()
+        if isinstance(x, list):
+            return [self.successors(_x) for _x in x]
+        elif isinstance(x, (int, float)):
+            if len(self.directed_graphs) == 0:
+                return [0] * self.ku_num
+            else:
+                _ret = [0] * self.ku_num
+                for graph in self.directed_graphs:
+                    for i in graph.successors(int(x)):
+                        _ret[i] = 1
+                return _ret
+
+    @classmethod
+    def from_file(cls, graph_nodes_num, graph_params):
+        directed_graphs = []
+        undirected_graphs = []
+        for graph_param in graph_params:
+            graph, directed = load_graph(
+                graph_nodes_num, *as_list(graph_param))
+            if directed:
+                directed_graphs.append(graph)
+            else:
+                undirected_graphs.append(graph)
+        return cls(graph_nodes_num, directed_graphs, undirected_graphs)
+
+
+def load_graph(graph_nodes_num, filename=None, directed: bool = True, threshold=0.0):
+    directed = bool(directed)
+    if directed:
+        graph = nx.DiGraph()
+    else:
+        graph = nx.Graph()
+
+    graph.add_nodes_from(range(graph_nodes_num))
+    if threshold < 0.0:
+        for i in range(graph_nodes_num):
+            for j in range(graph_nodes_num):
+                graph.add_edge(i, j)
+    else:
+        assert filename is not None
+        with open(filename) as f:
+            for data in json.load(f):
+                pre, suc = data[0], data[1]
+                if len(data) >= 3 and float(data[2]) < threshold:
+                    continue
+                elif len(data) >= 3:
+                    weight = float(data[2])
+                    graph.add_edge(pre, suc, weight=weight)
+                    continue
+                graph.add_edge(pre, suc)
+    return graph, directed
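`Graph.from_file` consumes `graph_params` entries of the form `(filename, directed)` or `(filename, directed, threshold)`; each graph file holds JSON records `[pre, suc]` or `[pre, suc, weight]`, weighted edges below the threshold are dropped, and a negative threshold instead builds a fully-connected graph. A small runnable sketch of that format (the file name `toy_graph.json` is made up for illustration; the expected outputs follow from the code above):

```python
import json
from EduKTM.SKT.utils import Graph

with open("toy_graph.json", "w") as wf:
    json.dump([[0, 1], [0, 2], [1, 2, 0.9]], wf)  # optional third field = weight

graph_params = [
    ["toy_graph.json", False],      # undirected -> feeds Graph.neighbors
    ["toy_graph.json", True, 0.5],  # directed; weighted edges under 0.5 dropped
]
g = Graph.from_file(3, graph_params)
print(g.neighbors(0))   # [0, 1, 1]: vertices 1 and 2 are neighbors of 0
print(g.successors(1))  # [0, 0, 1]: only 1 -> 2 exists in the directed graph
```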
diff --git a/EduKTM/__init__.py b/EduKTM/__init__.py
index a841ba1..e665d35 100644
--- a/EduKTM/__init__.py
+++ b/EduKTM/__init__.py
@@ -10,4 +10,5 @@
 from .LPKT import LPKT
 from .GKT import GKT
 from .DKVMN import DKVMN
+from .SKT import SKT
 from .LBKT import LBKT
diff --git a/README.md b/README.md
index 67167dc..eef70eb 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ Knowledge Tracing (KT), which aims to monitor students’ evolving knowledge sta
 * [GKT](EduKTM/GKT) [[doc]](docs/GKT.md) [[example]](examples/GKT)
 * [AKT](EduKTM/AKT) [[doc]](docs/AKT.md) [[example]](examples/AKT)
 * [LPKT](EduKTM/LPKT) [[doc]](docs/LPKT.md) [[example]](examples/LPKT)
+* [SKT](EduKTM/SKT) [[doc]](docs/SKT.md) [[example]](examples/SKT)
 
 ## Contribute
diff --git a/docs/SKT.md b/docs/SKT.md
new file mode 100644
index 0000000..ef3f198
--- /dev/null
+++ b/docs/SKT.md
@@ -0,0 +1,14 @@
+# Structure-based Knowledge Tracing (SKT)
+
+For the details of SKT, please refer to the paper: *[Structure-based Knowledge Tracing: An Influence Propagation View](http://staff.ustc.edu.cn/~huangzhy/files/papers/ShiweiTong-ICDM2020.pdf)*.
+
+```bibtex
+@inproceedings{tong2020structure,
+  title={Structure-based knowledge tracing: An influence propagation view},
+  author={Tong, Shiwei and Liu, Qi and Huang, Wei and Huang, Zhenya and Chen, Enhong and Liu, Chuanren and Ma, Haiping and Wang, Shijin},
+  booktitle={2020 IEEE International Conference on Data Mining (ICDM)},
+  pages={541--550},
+  year={2020},
+  organization={IEEE}
+}
+```
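In the notation of `SKTNet.forward` above, one influence-propagation step can be summarized as follows; this is a sketch that mirrors the code, and the paper's exact formulation may differ:

```latex
i_t = f_{agg}\bigl(\alpha \,(m_{nb} \odot f_{sync}) + (1-\alpha)\,(m_{suc} \odot f_{prop})\bigr),
\qquad
H_{t+1} = \operatorname{GRU}(i_t, H_t)
```

where \(H_t\) holds one hidden state per knowledge unit and \(m_{nb}\), \(m_{suc}\) are the 0/1 neighbor and successor masks taken from the prerequisite graphs.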
diff --git a/examples/SKT/SKT.ipynb b/examples/SKT/SKT.ipynb
new file mode 100644
index 0000000..47521ad
--- /dev/null
+++ b/examples/SKT/SKT.ipynb
@@ -0,0 +1,223 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Structure-based Knowledge Tracing"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This notebook will show you how to train and use the SKT model. First, we will show how to get the data (here we use a0910 as the dataset). Then we will show how to train an SKT model and persist its parameters. Finally, we will show how to load the parameters from the file and evaluate the model on the test dataset.\n",
+    "\n",
+    "\n",
+    "The script version can be found in [SKT.py](https://github.com/bigdata-ustc/EduKTM/blob/main/examples/SKT/SKT.py)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Data Preparation\n",
+    "\n",
+    "Before we process the data, we need to first acquire the dataset, as shown in [prepare_dataset.ipynb](https://github.com/bigdata-ustc/EduKTM/blob/main/examples/SKT/prepare_dataset.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/data/huangweizhe/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "reading data from ../../data/assistment_2009_2010/train.json: 3025it [00:01, 2397.38it/s]\n",
+      "batchify: 100%|██████████| 327/327 [00:00<00:00, 432.65it/s]\n",
+      "reading data from ../../data/assistment_2009_2010/test.json: 856it [00:00, 3526.76it/s]\n",
+      "/data/huangweizhe/EduKTM/EduKTM/utils/torch/extlib/sampler.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[104, 108, 115, 119, 120, 122, 125, 127, 129, 130, 134, 143, 147, 157, 159, 160, 163, 165, 166, 169, 173, 174, 178, 181, 184, 188, 189, 192, 193, 194, 196]\n",
+      "  warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %\n",
+      "batchify: 100%|██████████| 134/134 [00:00<00:00, 618.56it/s]\n",
+      "reading data from ../../data/assistment_2009_2010/test.json: 856it [00:00, 2156.08it/s]\n",
+      "batchify: 100%|██████████| 134/134 [00:00<00:00, 500.12it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from EduKTM.SKT import etl\n",
+    "\n",
+    "batch_size = 16\n",
+    "train = etl(\"../../data/assistment_2009_2010/train.json\", batch_size=batch_size)\n",
+    "valid = etl(\"../../data/assistment_2009_2010/test.json\", batch_size=batch_size)\n",
+    "test = etl(\"../../data/assistment_2009_2010/test.json\", batch_size=batch_size)\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training and Persistence"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "\n",
+    "logging.getLogger().setLevel(logging.INFO)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0: 100%|██████████| 327/327 [05:52<00:00, 1.08s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Epoch 0] SLMLoss: 0.445898\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "evaluating: 100%|██████████| 134/134 [01:27<00:00, 1.53it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Epoch 0] auc: 0.594213, accuracy: 0.631270\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1: 100%|██████████| 327/327 [06:32<00:00, 1.20s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Epoch 1] SLMLoss: 0.433631\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "evaluating: 100%|██████████| 134/134 [01:26<00:00, 1.56it/s]\n",
+      "INFO:root:save parameters to skt.params\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Epoch 1] auc: 0.624451, accuracy: 0.681644\n"
+     ]
+    }
+   ],
+   "source": [
+    "from EduKTM import SKT\n",
+    "import torch\n",
+    "\n",
+    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
+    "model = SKT(ku_num=124, graph_params=[\n",
+    "        ['../../data/assistment_2009_2010/correct_transition_graph.json', True],\n",
+    "        ['../../data/assistment_2009_2010/ctrans_sim.json', False]\n",
+    "    ], hidden_num=5)\n",
+    "model.train(train, valid, epoch=2, device=device)\n",
+    "model.save(\"skt.params\")\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading and Testing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:root:load parameters from skt.params\n",
+      "evaluating: 100%|██████████| 134/134 [01:13<00:00, 1.81it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "auc: 0.624451, accuracy: 0.681644\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "model.load(\"skt.params\")\n",
+    "auc, accuracy = model.eval(test, device=device)\n",
+    "print(\"auc: %.6f, accuracy: %.6f\" % (auc, accuracy))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
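One practical caveat for the notebook above: `SKT.save` stores the `state_dict` from the training device (`cuda:1` here), and `SKT.load` calls `torch.load(filepath)` without a `map_location`, which restores tensors onto the device they were saved from. If `skt.params` is later loaded on a machine without that GPU, one workaround is to remap manually; a sketch using the public attribute defined in `SKT.py` above:

```python
import torch

# load the saved parameters onto CPU regardless of where they were trained
state = torch.load("skt.params", map_location=torch.device("cpu"))
model.skt_model.load_state_dict(state)
```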
diff --git a/examples/SKT/SKT.py b/examples/SKT/SKT.py
new file mode 100644
index 0000000..16b3d01
--- /dev/null
+++ b/examples/SKT/SKT.py
@@ -0,0 +1,27 @@
+# coding: utf-8
+# 2023/3/17 @ weizhehuang0827
+
+import logging
+from EduKTM.SKT import etl
+from EduKTM import SKT
+import torch
+
+batch_size = 16
+train = etl("../../data/assistment_2009_2010/train.json",
+            batch_size=batch_size)
+valid = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size)
+test = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size)
+
+logging.getLogger().setLevel(logging.INFO)
+
+device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+model = SKT(ku_num=124, graph_params=[
+    ['../../data/assistment_2009_2010/correct_transition_graph.json', True],
+    ['../../data/assistment_2009_2010/ctrans_sim.json', False]
+    ], hidden_num=5)
+model.train(train, valid, epoch=2, device=device)
+model.save("skt.params")
+
+model.load("skt.params")
+auc, accuracy = model.eval(test, device=device)
+print("auc: %.6f, accuracy: %.6f" % (auc, accuracy))
diff --git a/examples/SKT/prepare_dataset.ipynb b/examples/SKT/prepare_dataset.ipynb
new file mode 100644
index 0000000..4ca7c23
--- /dev/null
+++ b/examples/SKT/prepare_dataset.ipynb
@@ -0,0 +1,166 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/correct_transition_graph.json is saved as ../../data/assistment_2009_2010/correct_transition_graph.json\n"
+     ]
+    },
+    {
+     "name": "stdout",
"output_type": "stream", + "text": [ + "Downloading ../../data/assistment_2009_2010/correct_transition_graph.json 100.00%: 34.1KB | 34.1KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/ctrans_sim.json is saved as ../../data/assistment_2009_2010/ctrans_sim.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/ctrans_sim.json 100.00%: 200KB | 200KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/dense_graph.json is saved as ../../data/assistment_2009_2010/dense_graph.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/dense_graph.json 100.00%: 361KB | 361KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/test.json is saved as ../../data/assistment_2009_2010/test.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/test.json 100.00%: 1.02MB | 1.02MB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/train.json is saved as ../../data/assistment_2009_2010/train.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/train.json 100.00%: 3.23MB | 3.23MB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/trans_sim.json is saved as ../../data/assistment_2009_2010/trans_sim.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/trans_sim.json 100.00%: 350KB | 350KB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/transition_graph.json is saved as ../../data/assistment_2009_2010/transition_graph.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ../../data/assistment_2009_2010/transition_graph.json 100.00%: 52.9KB | 52.9KB" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/valid.json is saved as ../../data/assistment_2009_2010/valid.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Downloading ../../data/assistment_2009_2010/valid.json 100.00%: 310KB | 310KB\n" + ] + }, + { + "data": { + "text/plain": [ + "'../../data/assistment_2009_2010'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from EduData import get_data\n", + "\n", + "get_data(\"ktbd-a0910\", \"../../data\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + 
"nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/skt/__init__.py b/tests/skt/__init__.py new file mode 100644 index 0000000..2fa4817 --- /dev/null +++ b/tests/skt/__init__.py @@ -0,0 +1,2 @@ +# coding: utf-8 +# 2023/3/17 @ weizhehuang0827 diff --git a/tests/skt/conftest.py b/tests/skt/conftest.py new file mode 100644 index 0000000..f0a11ee --- /dev/null +++ b/tests/skt/conftest.py @@ -0,0 +1,60 @@ +# coding: utf-8 +# 2023/3/17 @ weizhehuang0827 +import pytest +import json +from EduKTM.utils.tests import pseudo_data_generation +from EduKTM.SKT.etl import transform + + +@pytest.fixture(scope="package") +def conf(): + ques_num = 10 + hidden_num = 5 + return ques_num, hidden_num + + +@pytest.fixture(scope="package") +def data(conf): + ques_num, _ = conf + return transform(pseudo_data_generation(ques_num), 16) + + +@pytest.fixture(scope="session") +def graphs(tmpdir_factory): + graph_dir = tmpdir_factory.mktemp("data").join("graph_1.json") + _graphs = [] + with open(graph_dir, 'w') as wf: + json.dump([[0, 1], [0, 2]], wf) + _graphs.append([graph_dir, False]) + _graphs.append([graph_dir, True]) + graph_dir = tmpdir_factory.mktemp("data").join("graph_2.json") + with open(graph_dir, 'w') as wf: + json.dump([[1, 2, 0.9], [1, 5, 0.1]], wf) + _graphs.append([graph_dir, True, 0.5]) + return _graphs + + +@pytest.fixture(scope="session") +def graphs_2(tmpdir_factory): + graph_dir = tmpdir_factory.mktemp("data").join("graph_3.json") + _graphs = [] + with open(graph_dir, 'w') as wf: + json.dump([[0, 1], [0, 2]], wf) + _graphs.append((graph_dir, False)) + _graphs.append([graph_dir, True]) + graph_dir = tmpdir_factory.mktemp("data").join("graph_4.json") + with open(graph_dir, 'w') as wf: + json.dump([[1, 2, 0.9], [1, 5, 0.1]], wf) + _graphs.append([graph_dir, True, -1]) + return _graphs + + +@pytest.fixture(scope="session") +def graphs_3(tmpdir_factory): + graph_dir = tmpdir_factory.mktemp("data").join("graph_5.json") + _graphs = [] + with open(graph_dir, 'w') as wf: + json.dump([], wf) + _graphs.append((graph_dir, False)) + _graphs.append([graph_dir, True]) + return _graphs diff --git a/tests/skt/test_skt.py b/tests/skt/test_skt.py new file mode 100644 index 0000000..3fad8ad --- /dev/null +++ b/tests/skt/test_skt.py @@ -0,0 +1,23 @@ +# coding: utf-8 +# 2023/3/17 @ weizhehuang0827 +from EduKTM import SKT +from EduKTM.SKT.utils import Graph + + +def test_train(data, conf, graphs, tmp_path): + ku_num, hidden_num = conf + mgkt = SKT(ku_num, graphs, hidden_num) + mgkt.train(data, test_data=data, epoch=1) + filepath = tmp_path / "skt.params" + mgkt.save(filepath) + mgkt.load(filepath) + + +def test_graph(conf, graphs_2): + ku_num, _ = conf + graph = Graph.from_file(ku_num, graphs_2) + graph.info + graph.neighbors(0, excluded=[1]) + none_graph = Graph(ku_num, [], []) + none_graph.neighbors(0) + none_graph.successors(0)