# torchTextClassifiers

A unified, extensible framework for text classification using PyTorch and PyTorch Lightning.

## 🚀 Features

- **Unified API**: Consistent interface for different classifier types
- **FastText Support**: Built-in FastText classifier implementation
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
- **Mixed Features**: Support for both text and categorical features
- **Extensible**: Easy to add new classifier types
- **Production Ready**: Model serialization, validation, and inference

## 📦 Installation

```bash
# Clone the repository
git clone https://github.com/your-repo/torch-fastText.git
cd torch-fastText

# Install with uv (recommended)
uv sync

# Or install with pip
pip install -e .
```

## 🎯 Quick Start

### Basic FastText Classification

```python
import numpy as np
from torchTextClassifiers import create_fasttext

# Create a FastText classifier
classifier = create_fasttext(
    embedding_dim=100,
    sparse=False,
    num_tokens=10000,
    min_count=2,
    min_n=3,
    max_n=6,
    len_word_ngrams=2,
    num_classes=2
)

# Prepare your data
X_train = np.array([
    "This is a positive example",
    "This is a negative example",
    "Another positive case",
    "Another negative case"
])
y_train = np.array([1, 0, 1, 0])

X_val = np.array([
    "Validation positive",
    "Validation negative"
])
y_val = np.array([1, 0])

# Build the tokenizer and model
classifier.build(X_train, y_train)

# Train the model
classifier.train(
    X_train, y_train, X_val, y_val,
    num_epochs=50,
    batch_size=32,
    patience_train=5,
    verbose=True
)

# Make predictions
X_test = np.array(["This is a test sentence"])
predictions = classifier.predict(X_test)
print(f"Predictions: {predictions}")

# Evaluate accuracy on a labeled test set
accuracy = classifier.validate(X_test, np.array([1]))
print(f"Accuracy: {accuracy:.3f}")
```
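The same flow covers multi-class problems. Assuming labels are integer-encoded from `0` to `num_classes - 1` (consistent with the binary example above), only `num_classes` and the label array change:

```python
# Sketch: a 3-class variant of the classifier above.
# Assumes integer labels 0..num_classes-1, as in the binary example.
classifier = create_fasttext(
    embedding_dim=100, sparse=False, num_tokens=10000,
    min_count=2, min_n=3, max_n=6, len_word_ngrams=2,
    num_classes=3,
)
y_train = np.array([0, 1, 2, 1])  # one integer label per training sentence
```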

### Working with Mixed Features (Text + Categorical)

```python
import numpy as np
from torchTextClassifiers import create_fasttext

# Text data with categorical features
X_train = np.column_stack([
    np.array(["Great product!", "Terrible service", "Love it!"]),  # Text column
    np.array([[1, 2], [2, 1], [1, 3]])  # Two categorical feature columns
])
y_train = np.array([1, 0, 1])

# Create a classifier with categorical support
classifier = create_fasttext(
    embedding_dim=50,
    sparse=False,
    num_tokens=5000,
    min_count=1,
    min_n=3,
    max_n=6,
    len_word_ngrams=2,
    num_classes=2,
    categorical_vocabulary_sizes=[3, 4],  # Vocabulary size of each categorical feature
    categorical_embedding_dims=[10, 10]   # Embedding dimension of each categorical feature
)

# Build and train as usual
classifier.build(X_train, y_train)
# ... continue with training
```
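At inference time, the input presumably needs the same column layout the model was trained on: text in the first column, categorical features after it. A minimal sketch under that assumption:

```python
# Assumption: predict() expects the same text-then-categorical layout used in build()/train()
X_test = np.column_stack([
    np.array(["Works perfectly", "Broke after one day"]),  # Text column
    np.array([[1, 2], [2, 3]])                             # Categorical columns
])
predictions = classifier.predict(X_test)
```

Note that `np.column_stack` promotes the stacked array to a common (string) dtype, so the categorical codes travel as strings inside `X_train` and `X_test`.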

### Model Persistence

```python
from torchTextClassifiers import torchTextClassifiers

# Save the configuration (hyperparameters, not trained weights)
classifier.to_json('model_config.json')

# Load the configuration (creates a new, untrained instance)
new_classifier = torchTextClassifiers.from_json('model_config.json')

# You'll need to retrain the loaded classifier
new_classifier.build(X_train, y_train)
new_classifier.train(X_train, y_train, X_val, y_val, ...)
```

## 🔧 Advanced Usage

### Custom Configuration

```python
from torchTextClassifiers import torchTextClassifiers, ClassifierType
from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig

# Create a custom configuration
config = FastTextConfig(
    embedding_dim=200,
    sparse=True,
    num_tokens=20000,
    min_count=3,
    min_n=2,
    max_n=8,
    len_word_ngrams=3,
    num_classes=5,
    direct_bagging=False,  # Custom FastText parameter
)

# Create the classifier with the custom config
classifier = torchTextClassifiers(ClassifierType.FASTTEXT, config)
```

### Using Pre-trained Tokenizers

```python
from torchTextClassifiers import build_fasttext_from_tokenizer

# Assume you have a pre-trained tokenizer
# my_tokenizer = ...  (a previously trained NGramTokenizer)

classifier = build_fasttext_from_tokenizer(
    tokenizer=my_tokenizer,
    embedding_dim=100,
    num_classes=3,
    sparse=False
)

# Model and tokenizer are already built, ready for training
classifier.train(X_train, y_train, X_val, y_val, ...)
```

### Training Customization

```python
# Custom PyTorch Lightning trainer parameters
trainer_params = {
    'accelerator': 'gpu',
    'devices': 1,
    'precision': 16,  # Mixed precision training
    'gradient_clip_val': 1.0,
}

classifier.train(
    X_train, y_train, X_val, y_val,
    num_epochs=100,
    batch_size=64,
    patience_train=10,
    trainer_params=trainer_params,
    verbose=True
)
```

## 📊 API Reference

### Main Classes

#### `torchTextClassifiers`
The main classifier class providing a unified interface.

**Key Methods:**
- `build(X_train, y_train)`: Build tokenizer and model
- `train(X_train, y_train, X_val, y_val, ...)`: Train the model
- `predict(X)`: Make predictions
- `validate(X, Y)`: Evaluate on test data
- `to_json(filepath)`: Save configuration
- `from_json(filepath)`: Load configuration
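Taken together, a typical lifecycle chains these methods (arguments as in the Quick Start; `X_test`/`y_test` assumed to be a labeled held-out set):

```python
# End-to-end lifecycle using the methods listed above
clf = create_fasttext(
    embedding_dim=100, sparse=False, num_tokens=10000,
    min_count=2, min_n=3, max_n=6, len_word_ngrams=2, num_classes=2,
)
clf.build(X_train, y_train)              # fit the tokenizer, instantiate the model
clf.train(X_train, y_train, X_val, y_val,
          num_epochs=50, batch_size=32, patience_train=5)
predictions = clf.predict(X_test)        # class predictions
accuracy = clf.validate(X_test, y_test)  # accuracy on the held-out set
clf.to_json('model_config.json')         # save configuration (weights are not saved)
```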

#### `ClassifierType`
Enumeration of supported classifier types.
- `FASTTEXT`: FastText classifier

#### `ClassifierFactory`
Factory for creating classifier instances.

### FastText Specific

#### `create_fasttext(**kwargs)`
Convenience function to create FastText classifiers.

**Parameters:**
- `embedding_dim`: Embedding dimension
- `sparse`: Use sparse embeddings
- `num_tokens`: Vocabulary size
- `min_count`: Minimum token frequency
- `min_n`, `max_n`: Character n-gram range
- `len_word_ngrams`: Word n-gram length
- `num_classes`: Number of output classes

#### `build_fasttext_from_tokenizer(tokenizer, **kwargs)`
Create a FastText classifier from an existing tokenizer.

## 🏗️ Architecture

The framework follows a modular architecture:

```
torchTextClassifiers/
├── torchTextClassifiers.py            # Main classifier interface
├── classifiers/
│   ├── base.py                        # Abstract base classes
│   └── fasttext/                      # FastText implementation
│       ├── config.py                  # Configuration
│       ├── wrapper.py                 # Classifier wrapper
│       ├── factory.py                 # Convenience methods
│       ├── tokenizer.py               # N-gram tokenizer
│       ├── pytorch_model.py           # PyTorch model
│       ├── lightning_module.py        # Lightning module
│       └── dataset.py                 # Dataset implementation
├── utilities/
│   └── checkers.py                    # Input validation utilities
└── factories.py                       # Generic factory system
```
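New classifier types plug in through the abstract base classes in `classifiers/base.py` and the generic factory system in `factories.py`. The actual interface is documented in the [Developer Guide](docs/developer_guide.md); as a rough orientation, a new backend might look like the following sketch (the class and method names here are illustrative, not the real API):

```python
# Hypothetical sketch only: names are illustrative; see classifiers/base.py
# and docs/developer_guide.md for the actual base-class interface.
from dataclasses import dataclass

@dataclass
class MyClassifierConfig:
    num_classes: int
    embedding_dim: int

class MyClassifierWrapper:  # would subclass the abstract base in classifiers/base.py
    def __init__(self, config: MyClassifierConfig):
        self.config = config

    def build(self, X_train, y_train):
        ...  # construct preprocessing and the underlying PyTorch model

    def train(self, X_train, y_train, X_val, y_val, **training_params):
        ...  # wrap the model in a Lightning module and fit it
```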

## 🔬 Testing

Run the test suite:

```bash
# Run all tests
uv run pytest

# Run with coverage
uv run pytest --cov=torchTextClassifiers

# Run specific test file
uv run pytest tests/test_torchTextClassifiers.py -v
```

## 🤝 Contributing

We welcome contributions! See our [Developer Guide](docs/developer_guide.md) for information on:

- Adding new classifier types
- Code organization and patterns
- Testing requirements
- Documentation standards

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

- Built with [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/)
- Inspired by [FastText](https://fasttext.cc/) for efficient text classification
- Uses [uv](https://github.com/astral-sh/uv) for dependency management

## 📚 Examples

See the [examples/](examples/) directory for:

- Basic text classification
- Multi-class classification
- Mixed features (text + categorical)
- Custom classifier implementation
- Advanced training configurations

## 🐛 Support

If you encounter any issues:

1. Check the [examples](examples/) for similar use cases
2. Review the API documentation above
3. Open an issue on GitHub with:
   - Python version
   - Package versions (`uv tree` or `pip list`)
   - Minimal reproduction code
   - Error messages/stack traces