This project implements deep learning models for plant disease classification using leaf images. We experiment with several model architectures:
- Multi-Layer Perceptron (MLP)
- Convolutional Neural Networks (CNN)
- Vision Transformers (ViT)
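The architectural contrast matters: an MLP flattens the image and loses spatial structure, while a CNN processes local neighborhoods. A minimal PyTorch sketch of both (hypothetical — not the repository's actual model code; the class count and input size are placeholders):

```python
# Minimal sketch, not the repo's models: contrasts an MLP and a CNN classifier.
import torch
import torch.nn as nn

NUM_CLASSES = 38  # placeholder; set to your dataset's label count

mlp = nn.Sequential(
    nn.Flatten(),                    # flattening discards spatial layout
    nn.Linear(3 * 224 * 224, 512),
    nn.ReLU(),
    nn.Linear(512, NUM_CLASSES),
)

cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # convolutions see local patches
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, NUM_CLASSES),
)

x = torch.randn(1, 3, 224, 224)  # one fake 224x224 RGB leaf image
print(mlp(x).shape, cnn(x).shape)  # both emit (1, NUM_CLASSES) logits
```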
- Clone the repository:

  ```bash
  git clone https://github.com/lolyhop/plant-disease-classification.git
  cd plant-disease-classification
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Download the dataset:

  ```bash
  bash scripts/download_raw_data.sh
  ```
```
├── configs/          # Model configurations
├── data/             # Dataset and logs
├── docs/             # Images and documentation
├── front/            # Streamlit frontend
├── notebooks/        # Jupyter notebooks for exploration
├── scripts/          # Helper scripts
├── src/              # Code for models, training, and utilities
├── train.py          # Training script
├── inference.py      # Inference script
├── requirements.txt
└── README.md
```
You can run training and inference for any model using a config file.
Training:

```bash
python train.py --config configs/<config_name>.yaml
```

Examples:

```bash
python train.py --config configs/resnet.yaml
python train.py --config configs/efficientnet.yaml
python train.py --config configs/vit_2021.yaml
```

Inference:

```bash
python inference.py --config configs/<config_name>.yaml
```

Examples:

```bash
python inference.py --config configs/mlp.yaml
python inference.py --config configs/densenet.yaml
python inference.py --config configs/t2t_vit.yaml
```

This approach works for all models and keeps commands consistent.
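Each config bundles the model, data, and training settings read by `train.py` and `inference.py`. The exact schema is defined by the code in `src/`, so the sketch below is hypothetical — every key name here is an assumption; consult an existing file in `configs/` for the real layout:

```yaml
# Hypothetical config sketch — key names are illustrative, not the repo's actual schema.
model:
  name: resnet18        # architecture to build
  num_classes: 38       # size of the label set
data:
  root: data/raw        # where download_raw_data.sh places images
  image_size: 224
  batch_size: 32
train:
  epochs: 20
  lr: 3.0e-4
log_dir: data/logs/resnet   # checkpoints such as best_model.pt land here
```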
Build the image:

```bash
cd plant-disease-classification
docker build -f deploy/Dockerfile -t plant-classifier .
```

Mount your local data directory and provide the model config and weights:

```bash
cd plant-disease-classification
docker run -p 8000:8000 \
  -v $(pwd)/data:/app/data \
  -e MODEL_CONFIG_PATH=/app/configs/<your_model>.yaml \
  -e MODEL_WEIGHTS_PATH=/app/data/logs/<your_model>/best_model.pt \
  plant-classifier
```

- `-p 8000:8000` exposes the API on port 8000.
- `-v $(pwd)/data:/app/data` allows the container to access model weights.
- `MODEL_CONFIG_PATH` and `MODEL_WEIGHTS_PATH` must point to the YAML config and `.pt` file inside the container.
OpenAPI JSON: http://localhost:8000/openapi.json
Send a POST request to `/predict` with a leaf image:

```bash
curl -X POST "http://localhost:8000/predict" \
  -F "file=@path_to_leaf_image.jpg"
```

The API returns a JSON dictionary with class probabilities:

```json
{
  "Tomato___Early_blight": 0.92,
  "Tomato___Leaf_Mold": 0.05,
  "Tomato___healthy": 0.02,
  ...
}
```

Navigate to the frontend directory and set environment variables:
```bash
cd front/app
export MODEL_CONFIG_PATH=../../configs/<config_name>.yaml
export MODEL_WEIGHTS_PATH="weights/best_model.pt"
python main.py
```

Available configs: `deep_mlp.yaml`, `densenet.yaml`, `efficientnet.yaml`, `mlp.yaml`, `resnet.yaml`, `t2t_vit.yaml`, `vit_2021.yaml`
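The probability dictionary returned by `/predict` can be post-processed client-side. A small helper for picking the top-k classes (hypothetical — `top_predictions` is not part of this repository):

```python
def top_predictions(probs: dict, k: int = 3) -> list:
    """Sort the class-probability dict returned by /predict, highest first."""
    return sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]

# Example using the response format shown above
response = {
    "Tomato___Early_blight": 0.92,
    "Tomato___Leaf_Mold": 0.05,
    "Tomato___healthy": 0.02,
}
label, prob = top_predictions(response, k=1)[0]
print(label, prob)  # Tomato___Early_blight 0.92
```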
| Model | Accuracy | Precision | Recall | F1-score |
|---|---|---|---|---|
| DenseNet-121 | 0.9961 | 0.9962 | 0.9960 | 0.9960 |
| ResNet-18 | 0.9959 | 0.9958 | 0.9959 | 0.9958 |
| EfficientNet-B0 | 0.9950 | 0.9952 | 0.9947 | 0.9948 |
| Tokens-to-Token ViT (T2T-ViT) | 0.9929 | 0.9930 | 0.9929 | 0.9929 |
| Vision Transformer (ViT) | 0.9909 | 0.9911 | 0.9908 | 0.9909 |
| Deep MLP | 0.7620 | 0.7755 | 0.7618 | 0.7601 |
| MLP | 0.0552 | 0.0430 | 0.0522 | 0.0263 |
The top-performing models are CNNs and Transformer-based architectures. Simple MLPs perform poorly, which highlights the importance of spatial feature extraction for image classification.
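Metrics of this kind are typically computed with `sklearn.metrics`; the dependency-free sketch below shows the idea behind accuracy and macro-averaged F1 (per-class F1, averaged with equal class weight), assuming integer-encoded labels:

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions matching the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def macro_f1(y_true, y_pred):
    """Average per-class F1, weighting every class equally (macro average)."""
    classes = set(y_true) | set(y_pred)
    f1s = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return sum(f1s) / len(f1s)

# Toy 3-class example (not the leaf dataset)
y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2]
print(accuracy(y_true, y_pred))  # 0.8
```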
- Python 3.10+
- PyTorch
- Albumentations
- Torchvision
