Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bda13ff
Remove tensorflow components
timokau Mar 12, 2021
e4e907f
Mention the migration status in the README
timokau Apr 3, 2021
05999f8
Remove tensorflow and keras dependencies
timokau Mar 12, 2021
4b1b38a
Avoid ambiguous variable names
timokau Mar 15, 2021
172722e
Mark poetry2nix setup as broken
timokau Apr 3, 2021
b596113
Update nixpkgs
timokau Mar 17, 2021
4bb3df3
Make the nix definitions reusable
timokau Mar 16, 2021
ae7e44c
Update to python 3.8
timokau Mar 15, 2021
8dcd6b3
Add torch dependency
timokau Mar 15, 2021
614cd54
Add skorch dependency
timokau Mar 15, 2021
a009ef9
Update the dependencies in the README
timokau Apr 3, 2021
ddeb39d
Sort estimator class listings alphabetically
timokau Apr 3, 2021
ffaa488
Add pytorch losses and metrics
timokau Mar 15, 2021
99a2aa9
Add superclasses for skorch based estimators
timokau Apr 8, 2021
d643736
Add pytorch FATE estimators
timokau Apr 8, 2021
b48bba1
Prepare for pytorch tests
timokau Mar 23, 2021
4a46a3d
Always use 32 bit floats in tests
timokau Mar 23, 2021
93f5096
Remove some of the fit special casing in tests
timokau Apr 1, 2021
14a6aee
Remove star imports in tests
timokau Apr 3, 2021
0f2b0bc
Add the pytorch based FATE estimators to the tests
timokau Mar 22, 2021
63439d1
Deduplicate the README
timokau Apr 3, 2021
553eaaf
Add a changelog entry for the pytorch migration
timokau Apr 3, 2021
6d14786
Mark FATE as available again in the README
timokau Apr 3, 2021
0eab5b1
Define common skorch arguments in the tests
timokau Apr 7, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: python
python:
- 3.7.9 # Pinned since tensorflow 1.x is not available for python > 3.7
- 3.8

cache:
directories:
Expand All @@ -23,18 +23,16 @@ stages:
jobs:
fast_finish: true
include:
- python: 3.7.9
- python: 3.8
env: TOXENV=test1
- python: 3.7.9
- python: 3.8
env: TOXENV=test2
- python: 3.7.9
- python: 3.8
env: TOXENV=test3
- python: 3.7.9
env: TOXENV=test4
- python: 3.7.9
- python: 3.8
env: TOXENV=lint
- stage: docs
python: 3.7.9
python: 3.8
env: TOXENV=docs

before_deploy:
Expand All @@ -47,6 +45,6 @@ deploy:
script: poetry publish -v --build
on:
tags: true
python: 3.7.9
python: 3.8
repo: kiudee/cs-ranking
branch: master
9 changes: 8 additions & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
History
=======

Unreleased
2.0.0 (Unreleased)
------------------

* The library has been migrated to pytorch. This is a breaking change. You will
likely need to adapt to this new version if you have been using estimators
from version 1.x.

1.3.0 (Unreleased)
------------------

* We no longer override any of the defaults of our default optimizer (SGD). In
Expand Down
22 changes: 15 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
|Build Status| |Coverage| |Binder|

****
NOTE
****

This repository is currently in the process of a migration from tensorflow to
PyTorch. You should use the latest released version if you are not interested
in the partial PyTorch implementation.

*******
CS-Rank
*******

CS-Rank is a Python package for context-sensitive ranking and choice
algorithms.

We implement the following new object ranking/choice architectures:

* FATE (First aggregate then evaluate)
* FETA (First evaluate then aggregate)
* FETA (First evaluate then aggregate) (currently not available due to the
PyTorch migration)

In addition, we also implement these algorithms for choice functions:

* RankNetChoiceFunction
* RankNetChoiceFunction (currently not available due to the PyTorch migration)
* GeneralizedLinearModel
* PairwiseSVMChoiceFunction

Expand All @@ -24,12 +34,10 @@ setting:
* MixedLogitModel
* NestedLogitModel
* PairedCombinatorialLogit
* RankNetDiscreteChoiceFunction
* RankNetDiscreteChoiceFunction (currently not available due to the PyTorch
migration)
* PairwiseSVMDiscreteChoiceFunction

Check out our `interactive notebooks`_ to quickly find out what our package can
do.


Getting started
===============
Expand Down Expand Up @@ -73,7 +81,7 @@ Another option is to clone the repository and install CS-Rank using::

Dependencies
------------
CS-Rank depends on Tensorflow, Keras, NumPy, SciPy, matplotlib, scikit-learn,
CS-Rank depends on PyTorch, skorch, NumPy, SciPy, matplotlib, scikit-learn,
joblib and tqdm. For data processing and generation you will
also need PyGMO, H5Py and pandas.

Expand Down
203 changes: 0 additions & 203 deletions csrank/callbacks.py

This file was deleted.

10 changes: 0 additions & 10 deletions csrank/choicefunction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
from .baseline import AllPositive
from .cmpnet_choice import CmpNetChoiceFunction
from .fate_choice import FATEChoiceFunction
from .fatelinear_choice import FATELinearChoiceFunction
from .feta_choice import FETAChoiceFunction
from .fetalinear_choice import FETALinearChoiceFunction
from .generalized_linear_model import GeneralizedLinearModel
from .pairwise_choice import PairwiseSVMChoiceFunction
from .ranknet_choice import RankNetChoiceFunction

__all__ = [
"AllPositive",
"CmpNetChoiceFunction",
"FATEChoiceFunction",
"FATELinearChoiceFunction",
"FETAChoiceFunction",
"FETALinearChoiceFunction",
"GeneralizedLinearModel",
"PairwiseSVMChoiceFunction",
"RankNetChoiceFunction",
]
49 changes: 49 additions & 0 deletions csrank/choicefunction/choice_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import logging

import numpy as np
import skorch
import torch.nn as nn

from csrank.constants import CHOICE_FUNCTION
from csrank.learner import SkorchInstanceEstimator
from csrank.metrics_np import f1_measure
from csrank.util import progress_bar

Expand Down Expand Up @@ -69,3 +72,49 @@ def _tune_threshold(self, X_val, Y_val, thin_thresholds=1, verbose=0):
" a micro F1-measure of {:.2f}".format(threshold, best)
)
return threshold


class SkorchChoiceFunction(ChoiceFunctions, SkorchInstanceEstimator):
"""A variable choice estimator based on some scoring module.

This estimator takes a scoring module and combines it with a sigmoid
activation to predict scores between 0 and 1. The choice is then made based
on a fixed threshold value. This makes it very simple to derive new
estimators with any given scoring function. Refer to skorch's documentation
for supported parameters. For example the optimizer or the optimizer's
learning rate could be overridden.

Parameters
----------
module : torch module (class)
This is the scoring module. It should be an uninstantiated
``torch.nn.Module`` class that expects the number of features per
object as its only parameter on initialization.

criterion : torch criterion (class)
The criterion that is used to evaluate and optimize the module.

threshold : float
The threshold value that is used to convert scores to a choice. Must be
between 0 and 1. Defaults to 0.5.

**kwargs : skorch NeuralNet arguments
All keyword arguments are passed to the constructor of
``skorch.NeuralNet``. See the documentation of that class for more
details.
"""

def __init__(self, module, criterion=nn.BCELoss, threshold=0.5, **kwargs):
super().__init__(module=module, criterion=criterion, **kwargs)
# The scoring is trained to predict something close to "0" for
# non-chosen values, something close to "1" for chosen values. So 0.5
# is a natural threshold. It would be possible to additionally tune
# that threshold.
self.threshold_ = threshold

def initialize_module(self, *args, **kwargs):
params = self.get_params_for("module")
# Add a Sigmoid activation since the resulting "scores" should be
# between 0 and 1.
self.module_ = nn.Sequential(self.module(**params), nn.Sigmoid())
self.module_ = skorch.utils.to_device(self.module_, self.device)
Loading