diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
new file mode 100644
index 0000000..90047e5
--- /dev/null
+++ b/.devcontainer/Dockerfile
@@ -0,0 +1,16 @@
+FROM python:3.8
+USER root
+
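+# System packages in a single layer: Japanese locale support and a basic editor.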
+RUN apt-get update && \
+    apt-get install -y locales vim && \
+    localedef -f UTF-8 -i ja_JP ja_JP.UTF-8
+
+ENV LANG=ja_JP.UTF-8
+ENV LANGUAGE=ja_JP:ja
+ENV LC_ALL=ja_JP.UTF-8
+ENV TZ=JST-9
+
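+# Poetry drives dependency installation and packaging (see build.sh at the repo root).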
+RUN pip install --upgrade pip setuptools && \
+    pip install poetry poetry-dynamic-versioning
\ No newline at end of file
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000..e275622
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,10 @@
+{
+ "name": "PyTorchCML",
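+  // VS Code attaches to the "python" service defined in docker-compose.yml.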
+ "dockerComposeFile": "docker-compose.yml",
+ "extensions": [
+    "ms-python.python"
+ ],
+ "service": "python",
+ "workspaceFolder": "/work",
+ "shutdownAction": "stopCompose"
+}
\ No newline at end of file
diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml
new file mode 100644
index 0000000..4edc229
--- /dev/null
+++ b/.devcontainer/docker-compose.yml
@@ -0,0 +1,10 @@
+version: "3"
+
+services:
+ python:
+ build: .
+ volumes:
+ - ../:/work
+ command: sleep infinity
+
+# To build the package inside the container: chmod +x /work/build.sh && /work/build.sh
diff --git a/PyTorchCML/models/BaseEmbeddingModel.py b/PyTorchCML/models/BaseEmbeddingModel.py
index 4438939..1d26476 100644
--- a/PyTorchCML/models/BaseEmbeddingModel.py
+++ b/PyTorchCML/models/BaseEmbeddingModel.py
@@ -3,6 +3,10 @@
import torch
from torch import nn
+import pandas as pd
+from joblib import Parallel, delayed
+from tqdm import tqdm
+
from ..adaptors import BaseAdaptor
@@ -44,7 +48,8 @@ def __init__(
)
else:
- self.user_embedding = nn.Embedding.from_pretrained(user_embedding_init)
+ self.user_embedding = nn.Embedding.from_pretrained(
+ user_embedding_init)
self.user_embedding.weight.requires_grad = True
if item_embedding_init is None:
@@ -52,7 +57,8 @@ def __init__(
n_item, n_dim, sparse=False, max_norm=max_norm
)
else:
- self.item_embedding = nn.Embedding.from_pretrained(item_embedding_init)
+ self.item_embedding = nn.Embedding.from_pretrained(
+ item_embedding_init)
self.item_embedding.weight.requires_grad = True
def forward(
@@ -112,3 +118,40 @@ def get_item_weight(self, users: torch.Tensor) -> torch.Tensor:
torch.Tensor: Tensor of weight size (n, n_item)
"""
raise NotImplementedError
+
+    def get_topk_items(
+        self, users: torch.Tensor, k: int, num_batch: int = 100, n_jobs: int = -1
+    ) -> pd.DataFrame:
+        """Method of getting the top k items for each user.
+
+        Args:
+            users (torch.Tensor): 1d tensor of user ids, size (n,).
+            k (int): number of top items to return per user.
+            num_batch (int): number of users scored per batch (chunk size for torch.split).
+            n_jobs (int): number of parallel jobs (-1 uses all available cores).
+
+        Returns:
+            pd.DataFrame: top k items for each user, with columns ["user", "item", "score"].
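+
+        Example (illustrative sketch; assumes a trained subclass that
+        implements ``predict``, e.g. CollaborativeMetricLearning):
+            >>> users = torch.arange(n_user)
+            >>> topk_df = model.get_topk_items(users, k=3)
+            >>> list(topk_df.columns)
+            ['user', 'item', 'score']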
+ """
+
+        # Split users into chunks of size `num_batch`; tqdm reports per-chunk progress.
+        batches = torch.split(users, num_batch)
+        inputs = tqdm(batches)
+        # Keep item ids on the same device as users so the concatenation below works.
+        items = torch.arange(self.n_item, device=users.device)
+
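+        # For one chunk of users, build the full (user, item) cross join,
+        # score every pair with `predict`, and keep the top k rows per user.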
+ def predict_user(i, batch_users, k):
+ users_expand = batch_users.expand(self.n_item, -1).T.reshape(-1, 1)
+ items_expand = items.expand(len(batch_users), -1).reshape(-1, 1)
+            pairs_tensor = torch.cat([users_expand, items_expand], dim=1)
+ pairs_array = pairs_tensor.cpu().detach().numpy()
+ pairs_df = pd.DataFrame(pairs_array, columns=['user', 'item'])
+ score_tensor = self.predict(pairs_tensor)
+ pairs_df['score'] = score_tensor.cpu().detach().numpy()
+ pairs_df = pairs_df.sort_values(
+ by=["user", "score"], ascending=[True, False])
+ topk_pairs = pairs_df.groupby("user").head(k)
+ return i, topk_pairs
+
+        # Score the chunks in parallel; each worker returns (chunk_index, topk_df).
+        scored = Parallel(n_jobs=n_jobs)(
+            delayed(predict_user)(i, batch_users=batch_users, k=k)
+            for i, batch_users in enumerate(inputs)
+        )
+        # Workers may finish out of order, so restore chunk order before concatenating.
+        scored = sorted(scored, key=lambda x: x[0])
+        scored = [s[1] for s in scored]
+        return pd.concat(scored, axis=0)
diff --git a/PyTorchCML/samplers/BaseSampler.py b/PyTorchCML/samplers/BaseSampler.py
index a8359ac..8f6a2f3 100644
--- a/PyTorchCML/samplers/BaseSampler.py
+++ b/PyTorchCML/samplers/BaseSampler.py
@@ -62,7 +62,8 @@ def __init__(
neutral_cpu = neutral.cpu()
not_negative = torch.cat([train_set_cpu, neutral_cpu])
self.not_negative_flag = csr_matrix(
- (np.ones(not_negative.shape[0]), (not_negative[:, 0], not_negative[:, 1])),
+ (np.ones(not_negative.shape[0]),
+ (not_negative[:, 0], not_negative[:, 1])),
[n_user, n_item],
)
self.not_negative_flag.sum_duplicates()
@@ -70,7 +71,8 @@ def __init__(
# device
if device is None:
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ self.device = torch.device(
+ "cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = device
@@ -86,7 +88,8 @@ def __init__(
pos_weight_pair = pos_weight[train_set[:, 0].cpu()]
else:
- raise NotImplementedError
+ raise ValueError(
+ "The length of pos_weight does not match any of n_user, n_item, or n_positive_pair.")
else: # uniform
pos_weight_pair = torch.ones(train_set.shape[0])
@@ -108,7 +111,8 @@ def __init__(
self.neg_item_weight = torch.Tensor(neg_weight).to(self.device)
else:
- raise NotImplementedError
+ raise ValueError(
+ "The length of neg_weight does not match any of n_user or n_item.")
def get_pos_batch(self) -> torch.Tensor:
"""Method for positive sampling.
@@ -152,6 +156,7 @@ def get_neg_batch(self, users: torch.Tensor) -> torch.Tensor:
else:
neg_sampler = Categorical(probs=self.neg_item_weight)
- neg_samples = neg_sampler.sample([self.batch_size, self.n_neg_samples])
+ neg_samples = neg_sampler.sample(
+ [self.batch_size, self.n_neg_samples])
return neg_samples
diff --git a/PyTorchCML/trainers/BaseTrainer.py b/PyTorchCML/trainers/BaseTrainer.py
index 924393d..a0e78b5 100644
--- a/PyTorchCML/trainers/BaseTrainer.py
+++ b/PyTorchCML/trainers/BaseTrainer.py
@@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
-from tqdm import tqdm
+from tqdm.auto import tqdm
from ..evaluators import BaseEvaluator
from ..losses import BaseLoss
@@ -66,13 +66,17 @@ def fit(
for b in pbar:
# batch sampling
batch = self.sampler.get_pos_batch()
- users = batch[:, self.column_names["user_id"]].reshape(-1, 1)
- pos_items = batch[:, self.column_names["item_id"]].reshape(-1, 1)
+            user_col = self.column_names["user_id"]
+            item_col = self.column_names["item_id"]
+            users = batch[:, user_col].reshape(-1, 1)
+            pos_items = batch[:, item_col].reshape(-1, 1)
if self.sampler.two_stage:
neg_candidates = self.sampler.get_and_set_candidates()
- dist = self.model.spreadout_distance(pos_items, neg_candidates)
- self.sampler.set_candidates_weight(dist, self.model.n_dim)
+ dist = self.model.spreadout_distance(
+ pos_items, neg_candidates)
+ self.sampler.set_candidates_weight(
+ dist, self.model.n_dim)
neg_items = self.sampler.get_neg_batch(users.reshape(-1))
@@ -83,7 +87,8 @@ def fit(
embeddings_dict = self.model(users, pos_items, neg_items)
# compute loss
- loss = self.criterion(embeddings_dict, batch, self.column_names)
+ loss = self.criterion(
+ embeddings_dict, batch, self.column_names)
# adding loss for domain adaptation
if self.model.user_adaptor is not None:
@@ -117,4 +122,5 @@ def fit(
valid_scores_sub = valid_evaluator.score(self.model)
valid_scores_sub["epoch"] = ep + 1
valid_scores_sub["loss"] = accum_loss / n_batch
- self.valid_scores = pd.concat([self.valid_scores, valid_scores_sub])
+ self.valid_scores = pd.concat(
+ [self.valid_scores, valid_scores_sub])
diff --git a/build.sh b/build.sh
new file mode 100755
index 0000000..7538a76
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,3 @@
+poetry config virtualenvs.in-project true
+poetry install
+poetry build
\ No newline at end of file
diff --git a/examples/notebooks/movielens_cml.ipynb b/examples/notebooks/movielens_cml.ipynb
index 0760bbb..cd2bd12 100644
--- a/examples/notebooks/movielens_cml.ipynb
+++ b/examples/notebooks/movielens_cml.ipynb
@@ -3,15 +3,17 @@
{
"cell_type": "code",
"execution_count": 1,
+ "metadata": {},
+ "outputs": [],
"source": [
"# !pip install PyTorchCML"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 1,
+ "metadata": {},
+ "outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../../\")\n",
@@ -26,13 +28,13 @@
"from sklearn.model_selection import train_test_split\n",
"from sklearn.decomposition import TruncatedSVD\n",
"from scipy.sparse import csr_matrix"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 2,
+ "metadata": {},
+ "outputs": [],
"source": [
"# download movielens dataset\n",
"movielens = pd.read_csv(\n",
@@ -84,20 +86,20 @@
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"train_set = torch.LongTensor(train_set).to(device)\n",
"test_set = torch.LongTensor(test_set).to(device)\n"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {},
"source": [
"## Defalt"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 3,
+ "metadata": {},
+ "outputs": [],
"source": [
"lr = 1e-3\n",
"n_dim = 10\n",
@@ -113,58 +115,53 @@
"}\n",
"evaluator = evaluators.UserwiseEvaluator(test_set, score_function_dict, ks=[3,5])\n",
"trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler)\n"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 4,
- "source": [
- "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
- "100%|██████████| 943/943 [00:20<00:00, 46.66it/s]\n",
- "epoch1 avg_loss:0.931: 100%|██████████| 256/256 [00:06<00:00, 38.85it/s]\n",
- "epoch2 avg_loss:0.753: 100%|██████████| 256/256 [00:06<00:00, 40.74it/s]\n",
- "epoch3 avg_loss:0.658: 100%|██████████| 256/256 [00:06<00:00, 39.85it/s]\n",
- "epoch4 avg_loss:0.597: 100%|██████████| 256/256 [00:05<00:00, 46.31it/s]\n",
- "epoch5 avg_loss:0.558: 100%|██████████| 256/256 [00:07<00:00, 34.50it/s]\n",
- "epoch6 avg_loss:0.525: 100%|██████████| 256/256 [00:05<00:00, 44.82it/s]\n",
- "epoch7 avg_loss:0.500: 100%|██████████| 256/256 [00:06<00:00, 42.24it/s]\n",
- "epoch8 avg_loss:0.476: 100%|██████████| 256/256 [00:07<00:00, 35.35it/s]\n",
- "epoch9 avg_loss:0.455: 100%|██████████| 256/256 [00:07<00:00, 34.50it/s]\n",
- "epoch10 avg_loss:0.433: 100%|██████████| 256/256 [00:06<00:00, 41.74it/s]\n",
- "100%|██████████| 943/943 [00:21<00:00, 44.59it/s]\n",
- "epoch11 avg_loss:0.412: 100%|██████████| 256/256 [00:06<00:00, 41.59it/s]\n",
- "epoch12 avg_loss:0.387: 100%|██████████| 256/256 [00:06<00:00, 39.21it/s]\n",
- "epoch13 avg_loss:0.368: 100%|██████████| 256/256 [00:06<00:00, 42.29it/s]\n",
- "epoch14 avg_loss:0.350: 100%|██████████| 256/256 [00:06<00:00, 40.07it/s]\n",
- "epoch15 avg_loss:0.334: 100%|██████████| 256/256 [00:06<00:00, 42.41it/s]\n",
- "epoch16 avg_loss:0.319: 100%|██████████| 256/256 [00:06<00:00, 42.64it/s]\n",
- "epoch17 avg_loss:0.303: 100%|██████████| 256/256 [00:06<00:00, 41.52it/s]\n",
- "epoch18 avg_loss:0.294: 100%|██████████| 256/256 [00:06<00:00, 41.94it/s]\n",
- "epoch19 avg_loss:0.283: 100%|██████████| 256/256 [00:05<00:00, 44.28it/s]\n",
- "epoch20 avg_loss:0.274: 100%|██████████| 256/256 [00:06<00:00, 41.52it/s]\n",
- "100%|██████████| 943/943 [00:21<00:00, 43.05it/s]\n"
+ "100%|██████████| 943/943 [00:22<00:00, 42.54it/s]\n",
+ "epoch1 avg_loss:0.934: 100%|██████████| 256/256 [00:07<00:00, 33.02it/s]\n",
+ "epoch2 avg_loss:0.758: 100%|██████████| 256/256 [00:08<00:00, 30.57it/s]\n",
+ "epoch3 avg_loss:0.660: 100%|██████████| 256/256 [00:07<00:00, 33.52it/s]\n",
+ "epoch4 avg_loss:0.599: 100%|██████████| 256/256 [00:07<00:00, 33.07it/s]\n",
+ "epoch5 avg_loss:0.555: 100%|██████████| 256/256 [00:07<00:00, 32.99it/s]\n",
+ "epoch6 avg_loss:0.525: 100%|██████████| 256/256 [00:07<00:00, 32.41it/s]\n",
+ "epoch7 avg_loss:0.501: 100%|██████████| 256/256 [00:09<00:00, 27.15it/s]\n",
+ "epoch8 avg_loss:0.478: 100%|██████████| 256/256 [00:09<00:00, 27.08it/s]\n",
+ "epoch9 avg_loss:0.455: 100%|██████████| 256/256 [00:08<00:00, 31.50it/s]\n",
+ "epoch10 avg_loss:0.438: 100%|██████████| 256/256 [00:08<00:00, 29.36it/s]\n",
+ "100%|██████████| 943/943 [00:24<00:00, 38.02it/s]\n",
+ "epoch11 avg_loss:0.411: 100%|██████████| 256/256 [00:08<00:00, 29.68it/s]\n",
+ "epoch12 avg_loss:0.390: 100%|██████████| 256/256 [00:08<00:00, 30.24it/s]\n",
+ "epoch13 avg_loss:0.366: 100%|██████████| 256/256 [00:08<00:00, 29.65it/s]\n",
+ "epoch14 avg_loss:0.347: 100%|██████████| 256/256 [00:08<00:00, 29.27it/s]\n",
+ "epoch15 avg_loss:0.330: 100%|██████████| 256/256 [00:09<00:00, 28.05it/s]\n",
+ "epoch16 avg_loss:0.313: 100%|██████████| 256/256 [00:08<00:00, 30.14it/s]\n",
+ "epoch17 avg_loss:0.297: 100%|██████████| 256/256 [00:08<00:00, 29.47it/s]\n",
+ "epoch18 avg_loss:0.285: 100%|██████████| 256/256 [00:09<00:00, 26.84it/s]\n",
+ "epoch19 avg_loss:0.276: 100%|██████████| 256/256 [00:09<00:00, 28.35it/s]\n",
+ "epoch20 avg_loss:0.266: 100%|██████████| 256/256 [00:09<00:00, 27.31it/s]\n",
+ "100%|██████████| 943/943 [00:26<00:00, 35.80it/s]\n"
]
}
],
- "metadata": {}
+ "source": [
+ "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
+ ]
},
{
"cell_type": "code",
"execution_count": 5,
- "source": [
- "trainer.valid_scores"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -198,36 +195,36 @@
"
\n",
" \n",
" | 0 | \n",
- " 0.007423 | \n",
- " 0.012902 | \n",
- " 0.001568 | \n",
- " 0.008634 | \n",
- " 0.017922 | \n",
- " 0.002917 | \n",
+ " 0.011286 | \n",
+ " 0.019883 | \n",
+ " 0.001652 | \n",
+ " 0.011248 | \n",
+ " 0.024576 | \n",
+ " 0.002976 | \n",
" 0 | \n",
" NaN | \n",
"
\n",
" \n",
" | 0 | \n",
- " 0.042387 | \n",
- " 0.070078 | \n",
- " 0.006008 | \n",
- " 0.046897 | \n",
- " 0.084353 | \n",
- " 0.011537 | \n",
+ " 0.046553 | \n",
+ " 0.072022 | \n",
+ " 0.005758 | \n",
+ " 0.048727 | \n",
+ " 0.084752 | \n",
+ " 0.010512 | \n",
" 10 | \n",
- " 0.432954 | \n",
+ " 0.437877 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 0.202032 | \n",
- " 0.291888 | \n",
- " 0.044877 | \n",
- " 0.202131 | \n",
- " 0.311188 | \n",
- " 0.073077 | \n",
+ " 0.198986 | \n",
+ " 0.272269 | \n",
+ " 0.038865 | \n",
+ " 0.201747 | \n",
+ " 0.292173 | \n",
+ " 0.068094 | \n",
" 20 | \n",
- " 0.274121 | \n",
+ " 0.266131 | \n",
"
\n",
" \n",
"\n",
@@ -235,27 +232,151 @@
],
"text/plain": [
" nDCG@3 MAP@3 Recall@3 nDCG@5 MAP@5 Recall@5 epoch loss\n",
- "0 0.007423 0.012902 0.001568 0.008634 0.017922 0.002917 0 NaN\n",
- "0 0.042387 0.070078 0.006008 0.046897 0.084353 0.011537 10 0.432954\n",
- "0 0.202032 0.291888 0.044877 0.202131 0.311188 0.073077 20 0.274121"
+ "0 0.011286 0.019883 0.001652 0.011248 0.024576 0.002976 0 NaN\n",
+ "0 0.046553 0.072022 0.005758 0.048727 0.084752 0.010512 10 0.437877\n",
+ "0 0.198986 0.272269 0.038865 0.201747 0.292173 0.068094 20 0.266131"
]
},
+ "execution_count": 5,
"metadata": {},
- "execution_count": 5
+ "output_type": "execute_result"
}
],
- "metadata": {}
+ "source": [
+ "trainer.valid_scores"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 10/10 [00:06<00:00, 1.59it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+       "(HTML table markup elided; the same dataframe is rendered in text/plain below)"
+ ],
+ "text/plain": [
+ " user item score\n",
+ "155 0 155 1.500806\n",
+ "134 0 134 1.471315\n",
+ "181 0 181 1.450132\n",
+ "1696 1 14 1.589973\n",
+ "1939 1 257 1.584218"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "users = torch.arange(n_user)\n",
+ "topk_items_df = model.get_topk_items(users, k=3)\n",
+ "topk_items_df.head(5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "user\n",
+ "0 [155, 134, 181]\n",
+ "1 [14, 257, 274]\n",
+ "2 [545, 72, 146]\n",
+ "3 [323, 545, 852]\n",
+ "4 [3, 185, 654]\n",
+ "Name: item, dtype: object"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "topk_items_df.groupby(\"user\")[\"item\"].unique().head()"
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {},
"source": [
"## Strict Negative"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 6,
+ "metadata": {},
+ "outputs": [],
"source": [
"lr = 1e-3\n",
"n_dim = 10\n",
@@ -271,20 +392,18 @@
"}\n",
"evaluator = evaluators.UserwiseEvaluator(test_set, score_function_dict, ks=[3,5])\n",
"trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler)\n"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 7,
- "source": [
- "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
- ],
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
"100%|██████████| 943/943 [00:18<00:00, 51.01it/s]\n",
"epoch1 avg_loss:0.949: 100%|██████████| 256/256 [00:09<00:00, 26.99it/s]\n",
@@ -312,19 +431,16 @@
]
}
],
- "metadata": {
- "tags": []
- }
+ "source": [
+ "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
+ ]
},
{
"cell_type": "code",
"execution_count": 8,
- "source": [
- "trainer.valid_scores"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -400,22 +516,27 @@
"0 0.273629 0.369123 0.034714 0.272501 0.385817 0.059150 20 0.287908"
]
},
+ "execution_count": 8,
"metadata": {},
- "execution_count": 8
+ "output_type": "execute_result"
}
],
- "metadata": {}
+ "source": [
+ "trainer.valid_scores"
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {},
"source": [
"## Global Orthogonal Regularization"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 9,
+ "metadata": {},
+ "outputs": [],
"source": [
"lr = 1e-3\n",
"n_dim = 10\n",
@@ -432,20 +553,16 @@
"}\n",
"evaluator = evaluators.UserwiseEvaluator(test_set, score_function_dict, ks=[3,5])\n",
"trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler)"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 10,
- "source": [
- "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
"100%|██████████| 943/943 [00:21<00:00, 44.39it/s]\n",
"epoch1 avg_loss:0.948: 100%|██████████| 256/256 [00:11<00:00, 21.50it/s]\n",
@@ -473,17 +590,16 @@
]
}
],
- "metadata": {}
+ "source": [
+ "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
+ ]
},
{
"cell_type": "code",
"execution_count": 11,
- "source": [
- "trainer.valid_scores"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -559,34 +675,39 @@
"0 0.276036 0.380877 0.035444 0.276884 0.395498 0.060833 20 0.281246"
]
},
+ "execution_count": 11,
"metadata": {},
- "execution_count": 11
+ "output_type": "execute_result"
}
],
- "metadata": {}
+ "source": [
+ "trainer.valid_scores"
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {},
"source": [
"## Two Stage"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 3,
+ "metadata": {},
+ "outputs": [],
"source": [
"item_count = train.groupby(\"item_id\")[\"user_id\"].count()\n",
"count_index = np.array(item_count.index)\n",
"neg_weight = np.zeros(n_item)\n",
"neg_weight[count_index] = item_count ** 0.1"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 4,
+ "metadata": {},
+ "outputs": [],
"source": [
"lr = 1e-3\n",
"n_dim = 10\n",
@@ -608,20 +729,16 @@
"}\n",
"evaluator = evaluators.UserwiseEvaluator(test_set, score_function_dict, ks=[3,5])\n",
"trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler)"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 5,
- "source": [
- "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
"100%|██████████| 943/943 [00:27<00:00, 34.76it/s]\n",
"epoch1 avg_loss:1.495: 100%|██████████| 256/256 [00:08<00:00, 31.49it/s]\n",
@@ -649,17 +766,16 @@
]
}
],
- "metadata": {}
+ "source": [
+ "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
+ ]
},
{
"cell_type": "code",
"execution_count": 6,
- "source": [
- "trainer.valid_scores"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -735,22 +851,27 @@
"0 0.356546 0.484093 0.052573 0.326033 0.484409 0.074481 20 1.001474"
]
},
+ "execution_count": 6,
"metadata": {},
- "execution_count": 6
+ "output_type": "execute_result"
}
],
- "metadata": {}
+ "source": [
+ "trainer.valid_scores"
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {},
"source": [
"## model weighted negative sampler"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 4,
+ "metadata": {},
+ "outputs": [],
"source": [
"def svd_init(X, dim):\n",
" \"\"\"\n",
@@ -769,13 +890,13 @@
" vb = (2 / n_dim) ** 0.5 * V_.sum(axis=0) * s\n",
"\n",
" return U, V, ub, vb"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 5,
+ "metadata": {},
+ "outputs": [],
"source": [
"n_dim = 10\n",
"X = csr_matrix(\n",
@@ -791,13 +912,13 @@
" item_bias_init = torch.Tensor(vb)\n",
").to(device)\n",
"neg_weight_model.link_weight = lambda x : 1 - torch.sigmoid(x)"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 6,
+ "metadata": {},
+ "outputs": [],
"source": [
"lr = 1e-3\n",
"model = models.CollaborativeMetricLearning(n_user, n_item, n_dim).to(device)\n",
@@ -816,20 +937,16 @@
"}\n",
"evaluator = evaluators.UserwiseEvaluator(test_set, score_function_dict, ks=[3,5])\n",
"trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler)"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 7,
- "source": [
- "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
"100%|██████████| 943/943 [00:16<00:00, 55.70it/s]\n",
"epoch1 avg_loss:0.968: 100%|██████████| 256/256 [00:05<00:00, 44.73it/s]\n",
@@ -857,17 +974,16 @@
]
}
],
- "metadata": {}
+ "source": [
+ "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
+ ]
},
{
"cell_type": "code",
"execution_count": 8,
- "source": [
- "trainer.valid_scores"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -943,62 +1059,62 @@
"0 0.233268 0.322552 0.030232 0.23401 0.336276 0.049536 20 0.430135"
]
},
+ "execution_count": 8,
"metadata": {},
- "execution_count": 8
+ "output_type": "execute_result"
}
],
- "metadata": {}
+ "source": [
+ "trainer.valid_scores"
+ ]
},
{
"cell_type": "code",
"execution_count": null,
- "source": [],
+ "metadata": {},
"outputs": [],
- "metadata": {}
+ "source": []
},
{
"cell_type": "code",
"execution_count": 14,
- "source": [],
+ "metadata": {},
"outputs": [],
- "metadata": {}
+ "source": []
},
{
"cell_type": "markdown",
+ "metadata": {},
"source": [
"# Domain Adaptation"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 3,
+ "metadata": {},
+ "outputs": [],
"source": [
"from PyTorchCML import adaptors"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 4,
+ "metadata": {},
+ "outputs": [],
"source": [
"df_item = pd.read_csv('http://files.grouplens.org/datasets/movielens/ml-100k/u.item' , sep=\"|\", header=None, encoding='latin-1')\n",
"item_feature = df_item.iloc[:, -19:]\n",
"item_feature_torch = torch.Tensor(item_feature.values)"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 6,
- "source": [
- "item_feature"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -1318,15 +1434,20 @@
"[1682 rows x 19 columns]"
]
},
+ "execution_count": 6,
"metadata": {},
- "execution_count": 6
+ "output_type": "execute_result"
}
],
- "metadata": {}
+ "source": [
+ "item_feature"
+ ]
},
{
"cell_type": "code",
"execution_count": 6,
+ "metadata": {},
+ "outputs": [],
"source": [
"lr = 1e-3\n",
"n_dim = 10\n",
@@ -1343,20 +1464,16 @@
"}\n",
"evaluator = evaluators.UserwiseEvaluator(test_set, score_function_dict, ks=[3,5])\n",
"trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler)\n"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": 7,
- "source": [
- "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
"100%|██████████| 943/943 [00:18<00:00, 51.76it/s]\n",
"epoch1 avg_loss:1.192: 100%|██████████| 256/256 [00:05<00:00, 50.56it/s]\n",
@@ -1384,17 +1501,16 @@
]
}
],
- "metadata": {}
+ "source": [
+ "trainer.fit(n_batch=256, n_epoch=20, valid_evaluator = evaluator, valid_per_epoch=10)"
+ ]
},
{
"cell_type": "code",
"execution_count": 8,
- "source": [
- "trainer.valid_scores"
- ],
+ "metadata": {},
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -1470,24 +1586,30 @@
"0 0.243347 0.335101 0.054578 0.234080 0.344372 0.080134 20 0.400818"
]
},
+ "execution_count": 8,
"metadata": {},
- "execution_count": 8
+ "output_type": "execute_result"
}
],
- "metadata": {}
+ "source": [
+ "trainer.valid_scores"
+ ]
},
{
"cell_type": "code",
"execution_count": null,
- "source": [],
+ "metadata": {},
"outputs": [],
- "metadata": {}
+ "source": []
}
],
"metadata": {
+ "interpreter": {
+ "hash": "1a6e8c4c71356cfd7f7f45384d81183fdca12e98ad893ee020bd76249bbd6be9"
+ },
"kernelspec": {
- "name": "python3",
- "display_name": "Python 3.8.6 64-bit ('pytorchcml-MJCCLiEQ-py3.8': poetry)"
+ "display_name": "Python 3.8.6 64-bit ('pytorchcml-MJCCLiEQ-py3.8': poetry)",
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -1499,12 +1621,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.6"
- },
- "interpreter": {
- "hash": "1a6e8c4c71356cfd7f7f45384d81183fdca12e98ad893ee020bd76249bbd6be9"
+ "version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
-}
\ No newline at end of file
+}
diff --git a/tests/models/test_CollaborativeMetricLearning.py b/tests/models/test_CollaborativeMetricLearning.py
index c71eed9..f06b198 100644
--- a/tests/models/test_CollaborativeMetricLearning.py
+++ b/tests/models/test_CollaborativeMetricLearning.py
@@ -64,3 +64,21 @@ def test_spreadout_distance(self):
# y_hat shape
shape = so_dist.shape
self.assertEqual(shape, torch.Size([2, 3]))
+
+ def test_get_topk_items(self):
+ n_user = 1000
+ k = 3
+
+ model = CollaborativeMetricLearning(
+ n_user=n_user,
+ n_item=100,
+ n_dim=10,
+ )
+
+        users = torch.arange(n_user)
+ topk_items_df = model.get_topk_items(users, k=k)
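+        # Each user should contribute exactly k rows with columns ["user", "item", "score"].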
+ n_items_per_user = topk_items_df.groupby("user")["item"].count().mean()
+ n, m = topk_items_df.shape
+ self.assertEqual(n, n_user * k)
+ self.assertEqual(m, 3)
+ self.assertEqual(n_items_per_user, k)