|
42 | 42 | }, |
43 | 43 | { |
44 | 44 | "cell_type": "code", |
45 | | - "execution_count": 2, |
| 45 | + "execution_count": 1, |
46 | 46 | "id": "37c042fe", |
47 | 47 | "metadata": {}, |
48 | 48 | "outputs": [], |
|
70 | 70 | }, |
71 | 71 | { |
72 | 72 | "cell_type": "code", |
73 | | - "execution_count": 3, |
| 73 | + "execution_count": 2, |
74 | 74 | "id": "92402df7", |
75 | 75 | "metadata": {}, |
76 | 76 | "outputs": [], |
|
217 | 217 | }, |
218 | 218 | { |
219 | 219 | "cell_type": "code", |
220 | | - "execution_count": 4, |
| 220 | + "execution_count": 3, |
221 | 221 | "id": "1fd02895", |
222 | 222 | "metadata": {}, |
223 | 223 | "outputs": [ |
|
258 | 258 | }, |
259 | 259 | { |
260 | 260 | "cell_type": "code", |
261 | | - "execution_count": null, |
| 261 | + "execution_count": 4, |
262 | 262 | "id": "61b0252e", |
263 | 263 | "metadata": {}, |
264 | 264 | "outputs": [], |
265 | 277 | "source": [ |
266 | 278 | "from torchTextClassifiers.utilities.preprocess import clean_text_feature\n", |
267 | 279 | "df[\"libelle_processed\"] = clean_text_feature(df[\"libelle\"])" |
|
445 | 457 | "outputs": [], |
446 | 458 | "source": [ |
447 | 459 | "model.to_json('torchTextClassifiers_config.json')\n", |
448 | | - "# model = create_fasttext.from_json('torchTextClassifiers_config.json')" |
| 460 | + "# Loading from JSON now works with the new API:\n", |
| 461 | + "# from torchTextClassifiers import torchTextClassifiers\n", |
| 462 | + "# loaded_model = torchTextClassifiers.from_json('torchTextClassifiers_config.json')" |
449 | 463 | ] |
450 | 464 | }, |
451 | 465 | { |
|
698 | 712 | "model.to_json('torchTextClassifiers_config.json')" |
699 | 713 | ] |
700 | 714 | }, |
| 715 | + { |
| 716 | + "cell_type": "code", |
| 717 | + "execution_count": null, |
| 718 | + "id": "9amb3ku6gim", |
| 719 | + "metadata": {}, |
| 720 | + "outputs": [], |
| 721 | + "source": [ |
| 722 | + "# Demonstrate the new JSON loading approach\n", |
| 723 | + "from torchTextClassifiers import torchTextClassifiers\n", |
| 724 | + "\n", |
| 725 | + "# Load model from JSON (works with new wrapper-based approach)\n", |
| 726 | + "loaded_model = torchTextClassifiers.from_json('torchTextClassifiers_config.json')\n", |
| 727 | + "\n", |
| 728 | + "print(\"✅ Model loaded from JSON successfully!\")\n", |
| 729 | + "print(f\"Loaded wrapper type: {type(loaded_model.classifier_wrapper).__name__}\")\n", |
| 730 | + "print(f\"Config parameters: embedding_dim={loaded_model.config.embedding_dim}, sparse={loaded_model.config.sparse}\")\n", |
| 731 | + "\n", |
| 732 | + "# The loaded model needs to be built before use\n", |
| 733 | + "# loaded_model.build(X_train, y_train)" |
| 734 | + ] |
| 735 | + }, |
| 736 | + { |
| 737 | + "cell_type": "markdown", |
| 738 | + "id": "f7rq00g68p", |
| 739 | + "metadata": {}, |
| 740 | + "source": [ |
| 741 | + "## New API Features\n", |
| 742 | + "\n", |
| 743 | + "The updated `torchTextClassifiers` API provides more flexibility by allowing users to:\n", |
| 744 | + "\n", |
| 745 | + "### 1. **Direct Wrapper Usage**\n", |
| 746 | + "Create classifiers directly using wrapper classes, enabling custom implementations:\n", |
| 747 | + "\n", |
| 748 | + "```python\n", |
| 749 | + "from torchTextClassifiers import torchTextClassifiers\n", |
| 750 | + "from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper\n", |
| 751 | + "from torchTextClassifiers.classifiers.fasttext.core import FastTextConfig\n", |
| 752 | + "\n", |
| 753 | + "config = FastTextConfig(...)\n", |
| 754 | + "wrapper = FastTextWrapper(config)\n", |
| 755 | + "classifier = torchTextClassifiers(wrapper)\n", |
| 756 | + "```\n", |
| 757 | + "\n", |
| 758 | + "### 2. **Convenience Functions (Backward Compatible)**\n", |
| 759 | + "The familiar convenience functions still work:\n", |
| 760 | + "\n", |
| 761 | + "```python\n", |
| 762 | + "from torchTextClassifiers import create_fasttext\n", |
| 763 | + "classifier = create_fasttext(embedding_dim=50, sparse=False, ...)\n", |
| 764 | + "```\n", |
| 765 | + "\n", |
| 766 | + "### 3. **Enhanced JSON Support**\n", |
| 767 | + "Improved serialization/deserialization that works with custom wrapper classes:\n", |
| 768 | + "\n", |
| 769 | + "```python\n", |
| 770 | + "# Save configuration\n", |
| 771 | + "classifier.to_json('config.json')\n", |
| 772 | + "\n", |
| 773 | + "# Load configuration (automatically detects wrapper type)\n", |
| 774 | + "loaded_classifier = torchTextClassifiers.from_json('config.json')\n", |
| 775 | + "\n", |
| 776 | + "# Or specify wrapper class explicitly\n", |
| 777 | + "loaded_classifier = torchTextClassifiers.from_json('config.json', FastTextWrapper)\n", |
| 778 | + "```\n", |
| 779 | + "\n", |
| 780 | + "### 4. **Custom Classifier Support**\n", |
| 781 | + "Users can now easily create their own classifier wrappers by inheriting from `BaseClassifierWrapper` and implementing the required methods." |
| 782 | + ] |
| 783 | + }, |
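| | + { |
| | + "cell_type": "markdown", |
| | + "id": "custom-wrapper-sketch", |
| | + "metadata": {}, |
| | + "source": [ |
| | + "A minimal, hedged sketch of a custom wrapper. The import path of `BaseClassifierWrapper` and the method names to override are assumptions for illustration (mirroring how the FastText wrapper is used above), and `my_config` is a hypothetical placeholder:\n", |
| | + "\n", |
| | + "```python\n", |
| | + "from torchTextClassifiers import torchTextClassifiers\n", |
| | + "# Assumed import path for the base class:\n", |
| | + "from torchTextClassifiers.classifiers.base import BaseClassifierWrapper\n", |
| | + "\n", |
| | + "class MyCustomWrapper(BaseClassifierWrapper):\n", |
| | + "    # Implement the abstract methods declared by BaseClassifierWrapper\n", |
| | + "    # (names assumed; see the library source for the exact interface):\n", |
| | + "    def build(self, X_train, y_train, **kwargs):\n", |
| | + "        ...\n", |
| | + "\n", |
| | + "    def predict(self, X):\n", |
| | + "        ...\n", |
| | + "\n", |
| | + "classifier = torchTextClassifiers(MyCustomWrapper(my_config))\n", |
| | + "```" |
| | + ] |
| | + }, |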
701 | 784 | { |
702 | 785 | "cell_type": "markdown", |
703 | 786 | "id": "017f8d12-0be8-45df-a0e4-80919c89db2d", |
|
713 | 796 | "where one can first build the tokenizer, and then build the model with\n", |
714 | 797 | "custom architecture parameters.\n", |
715 | 798 | "\n", |
| 799 | + "**Note**: With the new API, you can also create classifiers directly using wrapper classes:\n", |
| 800 | + "\n", |
| 801 | + "```python\n", |
| 802 | + "from torchTextClassifiers import torchTextClassifiers\n", |
| 803 | + "from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper\n", |
| 804 | + "from torchTextClassifiers.classifiers.fasttext.core import FastTextConfig\n", |
| 805 | + "\n", |
| 806 | + "config = FastTextConfig(embedding_dim=50, sparse=False, ...)\n", |
| 807 | + "wrapper = FastTextWrapper(config)\n", |
| 808 | + "classifier = torchTextClassifiers(wrapper)\n", |
| 809 | + "```\n", |
| 810 | + "\n", |
716 | 811 | "The tokenizer can be loaded **from the same JSON file** as the model\n", |
717 | 812 | "parameters, or initialized using the right arguments." |
718 | 813 | ] |
719 | 814 | }, |
| 815 | + { |
| 816 | + "cell_type": "code", |
| 817 | + "execution_count": null, |
| 818 | + "id": "g0rmedya9eb", |
| 819 | + "metadata": {}, |
| 820 | + "outputs": [], |
| 821 | + "source": [ |
| 822 | + "# Example of the new direct wrapper approach\n", |
| 823 | + "from torchTextClassifiers import torchTextClassifiers\n", |
| 824 | + "from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper\n", |
| 825 | + "from torchTextClassifiers.classifiers.fasttext.core import FastTextConfig\n", |
| 826 | + "\n", |
| 827 | + "# Create configuration\n", |
| 828 | + "config = FastTextConfig(\n", |
| 829 | + " embedding_dim=50,\n", |
| 830 | + " sparse=False,\n", |
| 831 | + " num_tokens=100000,\n", |
| 832 | + " min_count=1,\n", |
| 833 | + " min_n=3,\n", |
| 834 | + " max_n=6,\n", |
| 835 | + " len_word_ngrams=3,\n", |
| 836 | + " categorical_embedding_dims=10,\n", |
| 837 | + " num_classes=NUM_CLASSES,\n", |
| 838 | + " num_categorical_features=NUM_CAT_VAR,\n", |
| 839 | + " categorical_vocabulary_sizes=CAT_VOCAB_SIZE\n", |
| 840 | + ")\n", |
| 841 | + "\n", |
| 842 | + "# Create wrapper and classifier\n", |
| 843 | + "wrapper = FastTextWrapper(config)\n", |
| 844 | + "direct_model = torchTextClassifiers(wrapper)\n", |
| 845 | + "\n", |
| 846 | + "# Build the model\n", |
| 847 | + "direct_model.build(X_train, y_train, lightning=True, lr=parameters_train.get(\"lr\"))\n", |
| 848 | + "\n", |
| 849 | + "print(\"✅ Direct wrapper model created successfully!\")\n", |
| 850 | + "print(f\"Model type: {type(direct_model.classifier_wrapper).__name__}\")\n", |
| 851 | + "print(f\"Config type: {type(direct_model.config).__name__}\")" |
| 852 | + ] |
| 853 | + }, |
720 | 854 | { |
721 | 855 | "cell_type": "code", |
722 | 856 | "execution_count": 18, |
|
1064 | 1198 | "id": "f84e6bff-8fa7-4896-b60a-005ae5f1d3eb", |
1065 | 1199 | "metadata": {}, |
1066 | 1200 | "source": [ |
1067 | | - "# Explainability" |
| 1201 | + "# Explainability\n", |
| 1202 | + "\n", |
| 1203 | + "The `torchTextClassifiers` framework provides explainability features through the `predict_and_explain` method. This allows you to understand which parts of the input text contribute most to the model's predictions." |
1068 | 1204 | ] |
1069 | 1205 | }, |
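| | + { |
| | + "cell_type": "markdown", |
| | + "id": "explain-usage-sketch", |
| | + "metadata": {}, |
| | + "source": [ |
| | + "A hedged sketch of the expected call pattern; the exact signature and return format of `predict_and_explain` are assumed here (mirroring the `predict`-style inputs used earlier), and `X_test` is a hypothetical held-out feature matrix:\n", |
| | + "\n", |
| | + "```python\n", |
| | + "# Assumed usage: pass the same kind of feature matrix used for predict()\n", |
| | + "result = model.predict_and_explain(X_test)\n", |
| | + "print(result)\n", |
| | + "```" |
| | + ] |
| | + }, |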
1070 | 1206 | { |
|
1093 | 1229 | ], |
1094 | 1230 | "metadata": { |
1095 | 1231 | "kernelspec": { |
1096 | | - "display_name": "Python 3 (ipykernel)", |
| 1232 | + "display_name": "Python 3", |
1097 | 1233 | "language": "python", |
1098 | | - "name": "python3", |
1099 | | - "path": "/opt/conda/share/jupyter/kernels/python3" |
| 1234 | + "name": "python3" |
1100 | 1235 | }, |
1101 | 1236 | "language_info": { |
1102 | 1237 | "codemirror_mode": { |
|