Skip to content

Commit 8123b70

Browse files
committed
No need for classifier registry
1 parent f9c6ac0 commit 8123b70

File tree

11 files changed

+336
-301
lines changed

11 files changed

+336
-301
lines changed

README.md

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,7 @@ classifier.build(X_train, y_train)
114114
# ... continue with training
115115
```
116116

117-
### Model Persistence
118117

119-
```python
120-
# Save configuration
121-
classifier.to_json('model_config.json')
122-
123-
# Load configuration (creates new instance)
124-
new_classifier = torchTextClassifiers.from_json('model_config.json')
125-
126-
# You'll need to retrain the loaded classifier
127-
new_classifier.build(X_train, y_train)
128-
new_classifier.train(X_train, y_train, X_val, y_val, ...)
129-
```
130118

131119
## 🔧 Advanced Usage
132120

examples/multiclass_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def main():
134134
from torchTextClassifiers import torchTextClassifiers
135135
loaded_classifier = torchTextClassifiers.from_json('multiclass_classifier_config.json')
136136
print("✅ Configuration loaded successfully!")
137-
print(f"Loaded classifier type: {loaded_classifier.classifier_type}")
137+
print(f"Loaded classifier wrapper: {type(loaded_classifier.classifier_wrapper).__name__}")
138138
print(f"Loaded num_classes: {loaded_classifier.config.num_classes}")
139139

140140
print("\n🎉 Multi-class example completed successfully!")

notebooks/example.ipynb

Lines changed: 145 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
},
4343
{
4444
"cell_type": "code",
45-
"execution_count": 2,
45+
"execution_count": 1,
4646
"id": "37c042fe",
4747
"metadata": {},
4848
"outputs": [],
@@ -70,7 +70,7 @@
7070
},
7171
{
7272
"cell_type": "code",
73-
"execution_count": 3,
73+
"execution_count": 2,
7474
"id": "92402df7",
7575
"metadata": {},
7676
"outputs": [],
@@ -217,7 +217,7 @@
217217
},
218218
{
219219
"cell_type": "code",
220-
"execution_count": 4,
220+
"execution_count": 3,
221221
"id": "1fd02895",
222222
"metadata": {},
223223
"outputs": [
@@ -258,10 +258,22 @@
258258
},
259259
{
260260
"cell_type": "code",
261-
"execution_count": null,
261+
"execution_count": 4,
262262
"id": "61b0252e",
263263
"metadata": {},
264-
"outputs": [],
264+
"outputs": [
265+
{
266+
"ename": "ModuleNotFoundError",
267+
"evalue": "No module named 'torchTextClassifiers'",
268+
"output_type": "error",
269+
"traceback": [
270+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
271+
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
272+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorchTextClassifiers\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutilities\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpreprocess\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m clean_text_feature\n\u001b[32m 2\u001b[39m df[\u001b[33m\"\u001b[39m\u001b[33mlibelle_processed\u001b[39m\u001b[33m\"\u001b[39m] = clean_text_feature(df[\u001b[33m\"\u001b[39m\u001b[33mlibelle\u001b[39m\u001b[33m\"\u001b[39m])\n",
273+
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'torchTextClassifiers'"
274+
]
275+
}
276+
],
265277
"source": [
266278
"from torchTextClassifiers.utilities.preprocess import clean_text_feature\n",
267279
"df[\"libelle_processed\"] = clean_text_feature(df[\"libelle\"])"
@@ -445,7 +457,9 @@
445457
"outputs": [],
446458
"source": [
447459
"model.to_json('torchTextClassifiers_config.json')\n",
448-
"# model = create_fasttext.from_json('torchTextClassifiers_config.json')"
460+
"# Loading from JSON now works with the new API:\n",
461+
"# from torchTextClassifiers import torchTextClassifiers\n",
462+
"# loaded_model = torchTextClassifiers.from_json('torchTextClassifiers_config.json')"
449463
]
450464
},
451465
{
@@ -698,6 +712,75 @@
698712
"model.to_json('torchTextClassifiers_config.json')"
699713
]
700714
},
715+
{
716+
"cell_type": "code",
717+
"execution_count": null,
718+
"id": "9amb3ku6gim",
719+
"metadata": {},
720+
"outputs": [],
721+
"source": [
722+
"# Demonstrate the new JSON loading approach\n",
723+
"from torchTextClassifiers import torchTextClassifiers\n",
724+
"\n",
725+
"# Load model from JSON (works with new wrapper-based approach)\n",
726+
"loaded_model = torchTextClassifiers.from_json('torchTextClassifiers_config.json')\n",
727+
"\n",
728+
"print(\"✅ Model loaded from JSON successfully!\")\n",
729+
"print(f\"Loaded wrapper type: {type(loaded_model.classifier_wrapper).__name__}\")\n",
730+
"print(f\"Config parameters: embedding_dim={loaded_model.config.embedding_dim}, sparse={loaded_model.config.sparse}\")\n",
731+
"\n",
732+
"# The loaded model needs to be built before use\n",
733+
"# loaded_model.build(X_train, y_train)"
734+
]
735+
},
736+
{
737+
"cell_type": "markdown",
738+
"id": "f7rq00g68p",
739+
"metadata": {},
740+
"source": [
741+
"## New API Features\n",
742+
"\n",
743+
"The updated `torchTextClassifiers` API provides more flexibility through the following features:\n",
744+
"\n",
745+
"### 1. **Direct Wrapper Usage**\n",
746+
"Create classifiers directly using wrapper classes, enabling custom implementations:\n",
747+
"\n",
748+
"```python\n",
749+
"from torchTextClassifiers import torchTextClassifiers\n",
750+
"from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper\n",
751+
"from torchTextClassifiers.classifiers.fasttext.core import FastTextConfig\n",
752+
"\n",
753+
"config = FastTextConfig(...)\n",
754+
"wrapper = FastTextWrapper(config)\n",
755+
"classifier = torchTextClassifiers(wrapper)\n",
756+
"```\n",
757+
"\n",
758+
"### 2. **Convenience Functions (Backward Compatible)**\n",
759+
"The familiar convenience functions still work:\n",
760+
"\n",
761+
"```python\n",
762+
"from torchTextClassifiers import create_fasttext\n",
763+
"classifier = create_fasttext(embedding_dim=50, sparse=False, ...)\n",
764+
"```\n",
765+
"\n",
766+
"### 3. **Enhanced JSON Support**\n",
767+
"Improved serialization/deserialization that works with custom wrapper classes:\n",
768+
"\n",
769+
"```python\n",
770+
"# Save configuration\n",
771+
"classifier.to_json('config.json')\n",
772+
"\n",
773+
"# Load configuration (automatically detects wrapper type)\n",
774+
"loaded_classifier = torchTextClassifiers.from_json('config.json')\n",
775+
"\n",
776+
"# Or specify wrapper class explicitly\n",
777+
"loaded_classifier = torchTextClassifiers.from_json('config.json', FastTextWrapper)\n",
778+
"```\n",
779+
"\n",
780+
"### 4. **Custom Classifier Support**\n",
781+
"Users can now easily create their own classifier wrappers by inheriting from `BaseClassifierWrapper` and implementing the required methods."
782+
]
783+
},
701784
{
702785
"cell_type": "markdown",
703786
"id": "017f8d12-0be8-45df-a0e4-80919c89db2d",
@@ -713,10 +796,61 @@
713796
"where one can first build the tokenizer, and then build the model with\n",
714797
"custom architecture parameters.\n",
715798
"\n",
799+
"**Note**: With the new API, you can also create classifiers directly using wrapper classes:\n",
800+
"\n",
801+
"```python\n",
802+
"from torchTextClassifiers import torchTextClassifiers\n",
803+
"from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper\n",
804+
"from torchTextClassifiers.classifiers.fasttext.core import FastTextConfig\n",
805+
"\n",
806+
"config = FastTextConfig(embedding_dim=50, sparse=False, ...)\n",
807+
"wrapper = FastTextWrapper(config)\n",
808+
"classifier = torchTextClassifiers(wrapper)\n",
809+
"```\n",
810+
"\n",
716811
"The tokenizer can be loaded **from the same JSON file** as the model\n",
717812
"parameters, or initialized using the right arguments."
718813
]
719814
},
815+
{
816+
"cell_type": "code",
817+
"execution_count": null,
818+
"id": "g0rmedya9eb",
819+
"metadata": {},
820+
"outputs": [],
821+
"source": [
822+
"# Example of the new direct wrapper approach\n",
823+
"from torchTextClassifiers import torchTextClassifiers\n",
824+
"from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper\n",
825+
"from torchTextClassifiers.classifiers.fasttext.core import FastTextConfig\n",
826+
"\n",
827+
"# Create configuration\n",
828+
"config = FastTextConfig(\n",
829+
" embedding_dim=50,\n",
830+
" sparse=False,\n",
831+
" num_tokens=100000,\n",
832+
" min_count=1,\n",
833+
" min_n=3,\n",
834+
" max_n=6,\n",
835+
" len_word_ngrams=3,\n",
836+
" categorical_embedding_dims=10,\n",
837+
" num_classes=NUM_CLASSES,\n",
838+
" num_categorical_features=NUM_CAT_VAR,\n",
839+
" categorical_vocabulary_sizes=CAT_VOCAB_SIZE\n",
840+
")\n",
841+
"\n",
842+
"# Create wrapper and classifier\n",
843+
"wrapper = FastTextWrapper(config)\n",
844+
"direct_model = torchTextClassifiers(wrapper)\n",
845+
"\n",
846+
"# Build the model\n",
847+
"direct_model.build(X_train, y_train, lightning=True, lr=parameters_train.get(\"lr\"))\n",
848+
"\n",
849+
"print(\"✅ Direct wrapper model created successfully!\")\n",
850+
"print(f\"Model type: {type(direct_model.classifier_wrapper).__name__}\")\n",
851+
"print(f\"Config type: {type(direct_model.config).__name__}\")"
852+
]
853+
},
720854
{
721855
"cell_type": "code",
722856
"execution_count": 18,
@@ -1064,7 +1198,9 @@
10641198
"id": "f84e6bff-8fa7-4896-b60a-005ae5f1d3eb",
10651199
"metadata": {},
10661200
"source": [
1067-
"# Explainability"
1201+
"# Explainability\n",
1202+
"\n",
1203+
"The `torchTextClassifiers` framework provides explainability features through the `predict_and_explain` method. This allows you to understand which parts of the input text contribute most to the model's predictions."
10681204
]
10691205
},
10701206
{
@@ -1093,10 +1229,9 @@
10931229
],
10941230
"metadata": {
10951231
"kernelspec": {
1096-
"display_name": "Python 3 (ipykernel)",
1232+
"display_name": "Python 3",
10971233
"language": "python",
1098-
"name": "python3",
1099-
"path": "/opt/conda/share/jupyter/kernels/python3"
1234+
"name": "python3"
11001235
},
11011236
"language_info": {
11021237
"codemirror_mode": {

tests/test_base_classes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ def create_dataloader(self, dataset, batch_size, num_workers=0, shuffle=True):
8888

8989
def load_best_model(self, checkpoint_path):
9090
self.trained = True
91+
92+
@classmethod
93+
def get_config_class(cls):
94+
return Mock
9195

9296
mock_config = Mock()
9397
wrapper = ConcreteWrapper(mock_config)
@@ -137,6 +141,10 @@ def create_dataloader(self, dataset, batch_size, num_workers=0, shuffle=True):
137141
def load_best_model(self, checkpoint_path):
138142
self.trained = True
139143
self.pytorch_model = f"model_from_{checkpoint_path}"
144+
145+
@classmethod
146+
def get_config_class(cls):
147+
return Mock
140148

141149
mock_config = Mock()
142150
wrapper = ConcreteWrapper(mock_config)
@@ -238,6 +246,10 @@ def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuf
238246

239247
def load_best_model(self, checkpoint_path: str) -> None:
240248
pass
249+
250+
@classmethod
251+
def get_config_class(cls):
252+
return Mock
241253

242254
# Should be able to instantiate with all methods implemented
243255
mock_config = Mock()

0 commit comments

Comments
 (0)