 # torchTextClassifiers
 
-A unified, extensible framework for text classification using PyTorch and PyTorch Lightning.
+A unified, extensible framework for text classification built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
+
 
 ## 🚀 Features
 
-- **Unified API**: Consistent interface for different classifier types
-- **FastText Support**: Built-in FastText classifier implementation
+- **Unified API**: Consistent interface for different classifier wrappers
+- **Extensible**: Easy to add new classifier implementations through wrapper pattern
+- **FastText Support**: Built-in FastText classifier with n-gram tokenization
+- **Flexible Preprocessing**: Each classifier can implement its own text preprocessing approach
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
-- **Mixed Features**: Support for both text and categorical features
-- **Extensible**: Easy to add new classifier types
-- **Production Ready**: Model serialization, validation, and inference
+
 
 ## 📦 Installation
 
 ```bash
 # Clone the repository
-git clone https://github.com/your-repo/torch-fastText.git
-cd torch-fastText
+git clone https://github.com/InseeFrLab/torchTextClassifiers.git
+cd torchtextClassifiers
 
 # Install with uv (recommended)
 uv sync
@@ -82,47 +84,44 @@ accuracy = classifier.validate(X_test, np.array([1]))
 print(f"Accuracy: {accuracy:.3f}")
 ```
 
-### Working with Mixed Features (Text + Categorical)
+### Custom Classifier Implementation
 
 ```python
 import numpy as np
-from torchTextClassifiers import create_fasttext
+from torchTextClassifiers import torchTextClassifiers
+from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextWrapper, SimpleTextConfig
 
-# Text data with categorical features
-X_train = np.column_stack([
-    np.array(["Great product!", "Terrible service", "Love it!"]),  # Text
-    np.array([[1, 2], [2, 1], [1, 3]])  # Categorical features
-])
-y_train = np.array([1, 0, 1])
-
-# Create classifier with categorical support
-classifier = create_fasttext(
-    embedding_dim=50,
-    sparse=False,
-    num_tokens=5000,
-    min_count=1,
-    min_n=3,
-    max_n=6,
-    len_word_ngrams=2,
+# Example: TF-IDF based classifier (alternative to tokenization)
+config = SimpleTextConfig(
+    hidden_dim=128,
     num_classes=2,
-    categorical_vocabulary_sizes=[3, 4],  # Vocab sizes for categorical features
-    categorical_embedding_dims=[10, 10]  # Embedding dims for categorical features
+    max_features=5000,
+    learning_rate=1e-3,
+    dropout_rate=0.2
 )
 
-# Build and train as usual
+# Create classifier with TF-IDF preprocessing
+wrapper = SimpleTextWrapper(config)
+classifier = torchTextClassifiers(wrapper)
+
+# Text data
+X_train = np.array(["Great product!", "Terrible service", "Love it!"])
+y_train = np.array([1, 0, 1])
+
+# Build and train
 classifier.build(X_train, y_train)
 # ... continue with training
 ```
 
 
-
 ## 🔧 Advanced Usage
 
 ### Custom Configuration
 
 ```python
-from torchTextClassifiers import torchTextClassifiers, ClassifierType
+from torchTextClassifiers import torchTextClassifiers
 from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
+from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
 
 # Create custom configuration
 config = FastTextConfig(
@@ -138,7 +137,8 @@ config = FastTextConfig(
 )
 
 # Create classifier with custom config
-classifier = torchTextClassifiers(ClassifierType.FASTTEXT, config)
+wrapper = FastTextWrapper(config)
+classifier = torchTextClassifiers(wrapper)
 ```
 
 ### Using Pre-trained Tokenizers
@@ -189,19 +189,18 @@ classifier.train(
 The main classifier class providing a unified interface.
 
 **Key Methods:**
-- `build(X_train, y_train)`: Build tokenizer and model
+- `build(X_train, y_train)`: Build text preprocessing and model
 - `train(X_train, y_train, X_val, y_val, ...)`: Train the model
 - `predict(X)`: Make predictions
 - `validate(X, Y)`: Evaluate on test data
 - `to_json(filepath)`: Save configuration
 - `from_json(filepath)`: Load configuration
 
-#### `ClassifierType`
-Enumeration of supported classifier types.
-- `FASTTEXT`: FastText classifier
+#### `BaseClassifierWrapper`
+Base class for all classifier wrappers. Each classifier implementation extends this class.
 
-#### `ClassifierFactory`
-Factory for creating classifier instances.
+#### `FastTextWrapper`
+Wrapper for FastText classifier implementation with tokenization-based preprocessing.
 
 ### FastText Specific
 
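A minimal sketch of the wrapper pattern that the hunk above documents: a main classifier object that delegates preprocessing and prediction to whichever wrapper it receives. This diff does not show the real `BaseClassifierWrapper` interface, so every name below (`BaseWrapperSketch`, `KeywordVoteWrapper`, `UnifiedClassifierSketch`, and their method names) is a hypothetical stand-in. The toy model is a plain token-vote counter, not FastText or TF-IDF.

```python
from abc import ABC, abstractmethod
from collections import Counter


class BaseWrapperSketch(ABC):
    """Hypothetical stand-in for a base wrapper class (not the library's API)."""

    @abstractmethod
    def prepare_text_features(self, texts):
        """Turn raw strings into whatever the wrapped model consumes."""

    @abstractmethod
    def fit(self, texts, labels):
        """Train the wrapped model."""

    @abstractmethod
    def predict(self, texts):
        """Return one predicted label per input text."""


class KeywordVoteWrapper(BaseWrapperSketch):
    """Toy wrapper: lowercase whitespace tokens, each voting for the labels it was seen with."""

    def __init__(self):
        self.token_votes = {}

    def prepare_text_features(self, texts):
        return [text.lower().split() for text in texts]

    def fit(self, texts, labels):
        for tokens, label in zip(self.prepare_text_features(texts), labels):
            for token in tokens:
                self.token_votes.setdefault(token, Counter())[label] += 1
        return self

    def predict(self, texts):
        predictions = []
        for tokens in self.prepare_text_features(texts):
            votes = Counter()
            for token in tokens:
                votes.update(self.token_votes.get(token, {}))
            # None signals that no token was seen during training.
            predictions.append(votes.most_common(1)[0][0] if votes else None)
        return predictions


class UnifiedClassifierSketch:
    """Hypothetical front end that only talks to the wrapper interface."""

    def __init__(self, wrapper):
        self.wrapper = wrapper

    def train(self, texts, labels):
        self.wrapper.fit(texts, labels)

    def predict(self, texts):
        return self.wrapper.predict(texts)


classifier = UnifiedClassifierSketch(KeywordVoteWrapper())
classifier.train(["Great product!", "Terrible service", "Love it!"], [1, 0, 1])
predictions = classifier.predict(["terrible service", "never seen"])
print(predictions)
```

Swapping in a different wrapper changes the preprocessing and the model while the calling code stays the same, which is the property the "Unified API" and "Flexible Preprocessing" bullets describe.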
@@ -222,24 +221,25 @@ Create FastText classifier from existing tokenizer.
 
 ## 🏗️ Architecture
 
-The framework follows a modular architecture:
+The framework follows a wrapper-based architecture:
 
 ```
 torchTextClassifiers/
 ├── torchTextClassifiers.py      # Main classifier interface
 ├── classifiers/
-│   ├── base.py                  # Abstract base classes
-│   └── fasttext/                # FastText implementation
-│       ├── config.py            # Configuration
-│       ├── wrapper.py           # Classifier wrapper
-│       ├── factory.py           # Convenience methods
-│       ├── tokenizer.py         # N-gram tokenizer
-│       ├── pytorch_model.py     # PyTorch model
-│       ├── lightning_module.py  # Lightning module
-│       └── dataset.py           # Dataset implementation
+│   ├── base.py                  # Abstract base wrapper classes
+│   ├── fasttext/                # FastText implementation
+│   │   ├── config.py            # Configuration
+│   │   ├── wrapper.py           # FastText wrapper (tokenization)
+│   │   ├── factory.py           # Convenience methods
+│   │   ├── tokenizer.py         # N-gram tokenizer
+│   │   ├── pytorch_model.py     # PyTorch model
+│   │   ├── lightning_module.py  # Lightning module
+│   │   └── dataset.py           # Dataset implementation
+│   └── simple_text_classifier.py # Example TF-IDF wrapper
 ├── utilities/
 │   └── checkers.py              # Input validation utilities
-└── factories.py                 # Generic factory system
+└── factories.py                 # Convenience factory functions
 ```
 
 ## 🔬 Testing