Commit 2044f69

test: restructure test suite for modular classifier architecture

1 parent fda4bbf, commit 2044f69

17 files changed: +1947 −290 lines changed

README.md: 261 additions & 86 deletions
@@ -1,135 +1,310 @@
-# torchTextClassifiers : Efficient text classification with PyTorch
+# torchTextClassifiers
 
-A flexible PyTorch implementation of models for text classification with support for categorical features.
+A unified, extensible framework for text classification using PyTorch and PyTorch Lightning.
 
-## Features
+## 🚀 Features
 
-- Supports text classification with FastText architecture
-- Handles both text and categorical features
-- N-gram tokenization
-- Flexible optimizer and scheduler options
-- GPU and CPU support
-- Model checkpointing and early stopping
-- Prediction and model explanation capabilities
+- **Unified API**: Consistent interface for different classifier types
+- **FastText Support**: Built-in FastText classifier implementation
+- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
+- **Mixed Features**: Support for both text and categorical features
+- **Extensible**: Easy to add new classifier types
+- **Production Ready**: Model serialization, validation, and inference
 
-## Installation
-
-- With `pip`:
+## 📦 Installation
 
 ```bash
-pip install torchTextClassifiers
+# Clone the repository
+git clone https://github.com/your-repo/torch-fastText.git
+cd torch-fastText
+
+# Install with uv (recommended)
+uv sync
+
+# Or install with pip
+pip install -e .
 ```
 
-- with `uv`:
+## 🎯 Quick Start
 
+### Basic FastText Classification
 
-```bash
-uv add torchTextClassifiers
+```python
+import numpy as np
+from torchTextClassifiers import create_fasttext
+
+# Create a FastText classifier
+classifier = create_fasttext(
+    embedding_dim=100,
+    sparse=False,
+    num_tokens=10000,
+    min_count=2,
+    min_n=3,
+    max_n=6,
+    len_word_ngrams=2,
+    num_classes=2
+)
+
+# Prepare your data
+X_train = np.array([
+    "This is a positive example",
+    "This is a negative example",
+    "Another positive case",
+    "Another negative case"
+])
+y_train = np.array([1, 0, 1, 0])
+
+X_val = np.array([
+    "Validation positive",
+    "Validation negative"
+])
+y_val = np.array([1, 0])
+
+# Build the model
+classifier.build(X_train, y_train)
+
+# Train the model
+classifier.train(
+    X_train, y_train, X_val, y_val,
+    num_epochs=50,
+    batch_size=32,
+    patience_train=5,
+    verbose=True
+)
+
+# Make predictions
+X_test = np.array(["This is a test sentence"])
+predictions = classifier.predict(X_test)
+print(f"Predictions: {predictions}")
+
+# Validate on test set
+accuracy = classifier.validate(X_test, np.array([1]))
+print(f"Accuracy: {accuracy:.3f}")
 ```
 
-## Key Components
+### Working with Mixed Features (Text + Categorical)
+
+```python
+import numpy as np
+from torchTextClassifiers import create_fasttext
+
+# Text data with categorical features
+X_train = np.column_stack([
+    np.array(["Great product!", "Terrible service", "Love it!"]),  # Text
+    np.array([[1, 2], [2, 1], [1, 3]])  # Categorical features
+])
+y_train = np.array([1, 0, 1])
+
+# Create classifier with categorical support
+classifier = create_fasttext(
+    embedding_dim=50,
+    sparse=False,
+    num_tokens=5000,
+    min_count=1,
+    min_n=3,
+    max_n=6,
+    len_word_ngrams=2,
+    num_classes=2,
+    categorical_vocabulary_sizes=[3, 4],  # Vocab sizes for categorical features
+    categorical_embedding_dims=[10, 10]  # Embedding dims for categorical features
+)
+
+# Build and train as usual
+classifier.build(X_train, y_train)
+# ... continue with training
+```
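A minimal sketch of how the elided training step might continue, reusing the `train`/`predict` API from the Quick Start and the `classifier`, `X_train`, and `y_train` defined in the block above; the validation arrays here are illustrative, not part of the diff:

```python
# Illustrative continuation; assumes validation data with the same
# text-plus-categorical column layout as X_train above.
X_val = np.column_stack([
    np.array(["Decent quality", "Would not buy again"]),  # text column
    np.array([[2, 2], [1, 1]])                            # categorical columns
])
y_val = np.array([1, 0])

classifier.train(
    X_train, y_train, X_val, y_val,
    num_epochs=20,
    batch_size=2,
    patience_train=3,
    verbose=False
)
predictions = classifier.predict(X_val)
```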
+
+### Model Persistence
+
+```python
+# Save configuration
+classifier.to_json('model_config.json')
+
+# Load configuration (creates new instance)
+new_classifier = torchTextClassifiers.from_json('model_config.json')
 
-- `build()`: Constructs the FastText model architecture
-- `train()`: Trains the model with built-in callbacks and logging
-- `predict()`: Generates class predictions
-- `predict_and_explain()`: Provides predictions with feature attributions
+# You'll need to retrain the loaded classifier
+new_classifier.build(X_train, y_train)
+new_classifier.train(X_train, y_train, X_val, y_val, ...)
+```
 
-## Subpackages
+## 🔧 Advanced Usage
 
-- `preprocess`: To preprocess text input, using the `nltk` and `unidecode` libraries.
-- `explainability`: Simple methods to visualize feature attributions at word and letter levels, using the `captum` library.
+### Custom Configuration
 
-Run `pip install torchTextClassifiers[preprocess]` or `pip install torchTextClassifiers[explainability]` to install these optional dependencies.
+```python
+from torchTextClassifiers import torchTextClassifiers, ClassifierType
+from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
+
+# Create custom configuration
+config = FastTextConfig(
+    embedding_dim=200,
+    sparse=True,
+    num_tokens=20000,
+    min_count=3,
+    min_n=2,
+    max_n=8,
+    len_word_ngrams=3,
+    num_classes=5,
+    direct_bagging=False,  # Custom FastText parameter
+)
 
+# Create classifier with custom config
+classifier = torchTextClassifiers(ClassifierType.FASTTEXT, config)
+```
 
-## Quick Start
+### Using Pre-trained Tokenizers
 
 ```python
-from torchTextClassifiers import torchTextClassifiers
+from torchTextClassifiers import build_fasttext_from_tokenizer
 
-# Initialize the model
-model = torchTextclassifiers(
-    num_tokens=1000000,
+# Assume you have a pre-trained tokenizer
+# my_tokenizer = ... (previously trained NGramTokenizer)
+
+classifier = build_fasttext_from_tokenizer(
+    tokenizer=my_tokenizer,
     embedding_dim=100,
-    min_count=5,
-    min_n=3,
-    max_n=6,
-    len_word_ngrams=True,
-    sparse=True
+    num_classes=3,
+    sparse=False
 )
 
-# Train the model
-model.train(
-    X_train=train_data,
-    y_train=train_labels,
-    X_val=val_data,
-    y_val=val_labels,
-    num_epochs=10,
+# Model and tokenizer are already built, ready for training
+classifier.train(X_train, y_train, X_val, y_val, ...)
+```
+
+### Training Customization
+
+```python
+# Custom PyTorch Lightning trainer parameters
+trainer_params = {
+    'accelerator': 'gpu',
+    'devices': 1,
+    'precision': 16,  # Mixed precision training
+    'gradient_clip_val': 1.0,
+}
+
+classifier.train(
+    X_train, y_train, X_val, y_val,
+    num_epochs=100,
     batch_size=64,
-    lr=4e-3
+    patience_train=10,
+    trainer_params=trainer_params,
+    verbose=True
 )
-# Make predictions
-predictions = model.predict(test_data)
 ```
 
-where `train_data` is an array of size $(N,d)$, with the text as strings in the first column and the remaining columns containing tokenized categorical variables in `int` format.
+## 📊 API Reference
 
-Please make sure `y_train` contains each possible label at least once.
+### Main Classes
 
-## Dependencies
+#### `torchTextClassifiers`
+The main classifier class providing a unified interface.
 
-- PyTorch Lightning
-- NumPy
+**Key Methods:**
+- `build(X_train, y_train)`: Build tokenizer and model
+- `train(X_train, y_train, X_val, y_val, ...)`: Train the model
+- `predict(X)`: Make predictions
+- `validate(X, Y)`: Evaluate on test data
+- `to_json(filepath)`: Save configuration
+- `from_json(filepath)`: Load configuration
 
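Taken together, the methods above form one lifecycle; a condensed sketch, reusing the imports and data arrays from the Quick Start (the `X_test`/`y_test` arrays are illustrative):

```python
# Condensed lifecycle using only the methods documented above.
classifier.build(X_train, y_train)              # 1. build tokenizer and model
classifier.train(X_train, y_train, X_val, y_val,
                 num_epochs=50, batch_size=32,
                 patience_train=5)               # 2. fit with early stopping
preds = classifier.predict(X_test)               # 3. class predictions
acc = classifier.validate(X_test, y_test)        # 4. accuracy on held-out data
classifier.to_json('model_config.json')          # 5. persist configuration

# from_json() restores configuration only; rebuild and retrain afterwards.
restored = torchTextClassifiers.from_json('model_config.json')
```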
-## Categorical features
+#### `ClassifierType`
+Enumeration of supported classifier types.
+- `FASTTEXT`: FastText classifier
 
-If any, each categorical feature $i$ is associated with an embedding matrix of size (number of unique values, embedding dimension), where the latter is a user-chosen hyperparameter (`categorical_embedding_dims`) that can take three types of values:
+#### `ClassifierFactory`
+Factory for creating classifier instances.
 
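The factory's exact method names are not shown in this diff, so the first line below is a hypothetical sketch (`create_classifier` is an assumed name); the constructor form from the Custom Configuration section above is the confirmed route:

```python
# Hypothetical: 'create_classifier' is an assumed method name; the factory's
# real interface is not documented in this README.
# classifier = ClassifierFactory.create_classifier(ClassifierType.FASTTEXT, config)

# Confirmed equivalent, as shown in Custom Configuration above:
classifier = torchTextClassifiers(ClassifierType.FASTTEXT, config)
```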
-- `None`: same embedding dimension as the token embedding matrix. The categorical embeddings are then summed with the sentence-level embedding (itself an average of the token embeddings). See [Figure 1](#Default-architecture).
-- `int`: the categorical embeddings all share this embedding dimension; they are averaged, and the resulting vector is concatenated to the sentence-level embedding (the last linear layer has an adapted input size). See [Figure 2](#avg-architecture).
-- `list`: the categorical embeddings have different embedding dimensions; all of them are concatenated without aggregation to the sentence-level embedding (the last linear layer has an adapted input size). See [Figure 3](#concat-architecture).
+### FastText Specific
 
-Default is `None`.
+#### `create_fasttext(**kwargs)`
+Convenience function to create FastText classifiers.
 
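The removed section above distinguishes three accepted forms for `categorical_embedding_dims`. Assuming `create_fasttext` still forwards that argument as in the mixed-features example, the three variants would look roughly like this (all values illustrative):

```python
from torchTextClassifiers import create_fasttext

# Shared hyperparameters for the three variants (values illustrative).
common = dict(embedding_dim=50, sparse=False, num_tokens=5000, min_count=1,
              min_n=3, max_n=6, len_word_ngrams=2, num_classes=2,
              categorical_vocabulary_sizes=[3, 4])

# None (default): categorical embeddings reuse embedding_dim and are summed
# into the sentence-level embedding ('sum' architecture, Figure 1).
clf_sum = create_fasttext(**common)

# int: one shared dimension; embeddings are averaged, then concatenated
# to the sentence-level embedding ('average and concatenate', Figure 2).
clf_avg = create_fasttext(**common, categorical_embedding_dims=10)

# list: one dimension per feature; all are concatenated without averaging
# ('concatenate all', Figure 3).
clf_cat = create_fasttext(**common, categorical_embedding_dims=[10, 20])
```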
-<a name="figure-1"></a>
-![Default-architecture](images/NN.drawio.png "Default architecture")
-*Figure 1: The 'sum' architecture*
+**Parameters:**
+- `embedding_dim`: Embedding dimension
+- `sparse`: Use sparse embeddings
+- `num_tokens`: Vocabulary size
+- `min_count`: Minimum token frequency
+- `min_n`, `max_n`: Character n-gram range
+- `len_word_ngrams`: Word n-gram length
+- `num_classes`: Number of output classes
 
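For reference, a call that exercises each documented parameter; the signature matches the Quick Start above, and the values are illustrative:

```python
from torchTextClassifiers import create_fasttext

classifier = create_fasttext(
    embedding_dim=100,    # size of the embedding vectors
    sparse=False,         # dense (non-sparse) embedding layer
    num_tokens=10000,     # vocabulary size
    min_count=2,          # ignore tokens seen fewer than 2 times
    min_n=3,              # shortest character n-gram
    max_n=6,              # longest character n-gram
    len_word_ngrams=2,    # word n-gram length
    num_classes=2,        # number of output classes
)
```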
-<a name="figure-2"></a>
-![avg-architecture](images/avg_concat.png "Average-and-concatenate architecture")
-*Figure 2: The 'average and concatenate' architecture*
+#### `build_fasttext_from_tokenizer(tokenizer, **kwargs)`
+Create a FastText classifier from an existing tokenizer.
 
-<a name="figure-3"></a>
-![concat-architecture](images/full_concat.png "Concatenate-all architecture")
-*Figure 3: The 'concatenate all' architecture*
+## 🏗️ Architecture
 
-## Documentation
+The framework follows a modular architecture:
 
-For detailed usage and examples, please refer to the [example notebook](notebooks/example.ipynb). Use `pip install -r requirements.txt` after cloning the repository to install the necessary dependencies (some are specific to the notebook).
+```
+torchTextClassifiers/
+├── torchTextClassifiers.py        # Main classifier interface
+├── classifiers/
+│   ├── base.py                    # Abstract base classes
+│   └── fasttext/                  # FastText implementation
+│       ├── config.py              # Configuration
+│       ├── wrapper.py             # Classifier wrapper
+│       ├── factory.py             # Convenience methods
+│       ├── tokenizer.py           # N-gram tokenizer
+│       ├── pytorch_model.py       # PyTorch model
+│       ├── lightning_module.py    # Lightning module
+│       └── dataset.py             # Dataset implementation
+├── utilities/
+│   └── checkers.py                # Input validation utilities
+└── factories.py                   # Generic factory system
+```
 
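The abstract base's actual interface is not shown in this diff, so the skeleton below is hypothetical (every name is assumed); it only illustrates where a new classifier type would plug into the `classifiers/base.py` contract described above:

```python
# Hypothetical sketch: BaseClassifierWrapper and its method names are assumed,
# not taken from the repository.
from abc import ABC, abstractmethod

class BaseClassifierWrapper(ABC):
    """Assumed shape of the abstract base in classifiers/base.py."""
    @abstractmethod
    def build(self, X_train, y_train): ...
    @abstractmethod
    def train(self, X_train, y_train, X_val, y_val, **kwargs): ...
    @abstractmethod
    def predict(self, X): ...

class MyClassifierWrapper(BaseClassifierWrapper):
    """A new classifier type would implement the same lifecycle hooks."""
    def build(self, X_train, y_train):
        pass  # construct the tokenizer and model here
    def train(self, X_train, y_train, X_val, y_val, **kwargs):
        pass  # run the Lightning training loop here
    def predict(self, X):
        pass  # return class predictions
```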
-## Contributing
+## 🔬 Testing
 
-Contributions are welcome! Please feel free to submit a Pull Request.
+Run the test suite:
+
+```bash
+# Run all tests
+uv run pytest
 
-## License
+# Run with coverage
+uv run pytest --cov=torchTextClassifiers
 
-MIT
+# Run specific test file
+uv run pytest tests/test_torchTextClassifiers.py -v
+```
 
+## 🤝 Contributing
 
-## References
+We welcome contributions! See our [Developer Guide](docs/developer_guide.md) for information on:
 
-Inspired by the original FastText paper [1] and implementation.
+- Adding new classifier types
+- Code organization and patterns
+- Testing requirements
+- Documentation standards
 
-[1] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov, [*Bag of Tricks for Efficient Text Classification*](https://arxiv.org/abs/1607.01759)
+## 📄 License
 
-```
-@InProceedings{joulin2017bag,
-  title={Bag of Tricks for Efficient Text Classification},
-  author={Joulin, Armand and Grave, Edouard and Bojanowski, Piotr and Mikolov, Tomas},
-  booktitle={Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics: Volume 2, Short Papers},
-  month={April},
-  year={2017},
-  publisher={Association for Computational Linguistics},
-  pages={427--431},
-}
-```
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+## 🙏 Acknowledgments
+
+- Built with [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/)
+- Inspired by [FastText](https://fasttext.cc/) for efficient text classification
+- Uses [uv](https://github.com/astral-sh/uv) for dependency management
+
+## 📚 Examples
+
+See the [examples/](examples/) directory for:
+- Basic text classification
+- Multi-class classification
+- Mixed features (text + categorical)
+- Custom classifier implementation
+- Advanced training configurations
+
+## 🐛 Support
+
+If you encounter any issues:
+
+1. Check the [examples](examples/) for similar use cases
+2. Review the API documentation above
+3. Open an issue on GitHub with:
+   - Python version
+   - Package versions (`uv tree` or `pip list`)
+   - Minimal reproduction code
+   - Error messages/stack traces
0 commit comments