This project implements an image captioning model using a CNN encoder (ResNet-50) and a Transformer-based decoder. It's built with PyTorch and designed for flexibility and ease of use, allowing for training, evaluation, and prediction of image captions.
- Project Overview
- Features
- Workflow
- Setup
- Configuration
- Usage
- Dataset
- Model Architecture
- Logging
- Dependencies
- License
## Project Overview

The core goal of this project is to generate descriptive captions for input images. It leverages a pre-trained ResNet-50 to extract visual features from images and a Transformer decoder to generate textual descriptions based on these features. The entire pipeline, from data loading to model inference, is configurable through a central JSON file.
## Features

- Encoder-Decoder Architecture: Uses ResNet-50 for image encoding and a Transformer for caption decoding.
- Pre-trained Encoder: Option to use pre-trained ResNet-50 weights and fine-tune or freeze the encoder.
- Configurable Parameters: Easily manage hyperparameters, paths, and settings via `config.json`.
- Vocabulary Management: Builds and saves/loads a vocabulary from the training captions.
- Training Pipeline: Supports training from scratch or resuming from checkpoints, with validation and early stopping.
- Evaluation: Implements standard COCO evaluation metrics (BLEU, METEOR, ROUGE-L, CIDEr, SPICE) using `pycocoevalcap`.
- Prediction: Allows caption generation for single or multiple custom images, or for examples from the validation set.
- Logging: Comprehensive logging for training, evaluation, and prediction processes.
- Dependency Management: Uses PDM for managing Python dependencies.
## Workflow

- Configuration: Define paths, hyperparameters, and other settings in `config.json`.
- Data Preparation: The `dataset.py` script handles loading images and COCO-style captions. A vocabulary is built using `vocabulary.py`.
- Training: Run `train.py` to train the model. Progress is logged, and checkpoints (including the best model) are saved.
- Evaluation: Use `evaluate.py` to assess the trained model's performance on a test/validation set using COCO metrics. Results are saved.
- Prediction: Employ `predict.py` to generate captions for new images.
## Setup

### Prerequisites

- Python (version specified in `pyproject.toml`, typically 3.8+)
- PDM (Python Dependency Manager)

### Installation

- Clone the repository (if you haven't already):

  ```bash
  git clone git@github.com:turgaybulut/image-captioning.git
  cd image-captioning
  ```

- Install PDM (if not already installed):

  ```bash
  pip install pdm
  ```

- Install project dependencies. This command installs all dependencies, including development tools:

  ```bash
  pdm install -G:all
  ```

  If you only need runtime dependencies:

  ```bash
  pdm install
  ```
## Configuration

All project settings are managed through the `config.json` file. This includes the following (a minimal loading sketch follows the list):
- Paths: Locations for datasets, vocabulary, model checkpoints, and evaluation results.
- Dataset Parameters: Vocabulary frequency threshold, subset sizes for quick runs.
- Dataloader Settings: Batch size, number of workers, pin memory.
- Model Hyperparameters: Embedding size, decoder layers, heads, feed-forward dimensions, dropout, CNN training flag.
- Training Settings: Learning rate, number of epochs, model loading flag, gradient clipping, early stopping patience.
- Prediction Settings: Maximum caption length.
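The scripts consume these settings at runtime; as a minimal sketch, they can be read with the standard `json` module. The section and key names below (`dataloader`, `training`, `batch_size`, `learning_rate`) are illustrative placeholders, not the project's actual schema:

```python
import json

# Minimal sketch: load the central configuration file.
# The key names below are placeholders; the real schema is defined by config.json.
with open("config.json", "r") as f:
    config = json.load(f)

batch_size = config["dataloader"]["batch_size"]      # hypothetical key
learning_rate = config["training"]["learning_rate"]  # hypothetical key
print(f"batch_size={batch_size}, learning_rate={learning_rate}")
```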
## Usage

The project scripts are typically run using PDM.
To train the model:
```bash
pdm run train
```

- Ensure `config.json` points to your training images and COCO-style caption files.
- The script will build/load the vocabulary, initialize the model, and start the training loop.
- Checkpoints and the best model will be saved according to the paths in `config.json`.
To evaluate a trained model:
```bash
pdm run evaluate
```

- This uses the model specified by `best_model_checkpoint` in `config.json`.
- It generates captions for the validation/test set and computes COCO metrics.
- Evaluation results (generated captions and scores) are logged and saved.
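For reference, the same metrics can also be computed directly with `pycocoevalcap`. The sketch below uses toy data, whereas `evaluate.py` builds these dictionaries from the real validation set; METEOR and SPICE additionally require a Java runtime:

```python
# Sketch of computing caption metrics with pycocoevalcap directly.
# gts/res map an image id to a list of caption strings.
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider

gts = {  # reference captions (toy data)
    1: ["a dog runs across the grass", "a brown dog running outside"],
    2: ["a man rides a bicycle down the street"],
}
res = {  # generated captions, one per image id (toy data)
    1: ["a dog running on the grass"],
    2: ["a man riding a bike on a street"],
}

bleu_scores, _ = Bleu(4).compute_score(gts, res)   # BLEU-1 .. BLEU-4
cider_score, _ = Cider().compute_score(gts, res)
print("BLEU:", bleu_scores, "CIDEr:", cider_score)
```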
To generate captions for new images via the command line:
```bash
pdm run predict --image_paths /path/to/your/image1.jpg /path/to/your/image2.png
```

- If `--image_paths` is omitted, it will predict on a few random examples from the validation set specified in `config.json`.
- The script loads the `best_model_checkpoint` and the associated vocabulary (a rough preprocessing sketch follows this list).
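The rough shape of the single-image path is sketched below; the exact transform parameters, checkpoint format, and `generate_caption` signature are assumptions and may differ from the actual `predict.py`:

```python
import torch
from PIL import Image
from torchvision import transforms

# Typical ImageNet-style preprocessing for a ResNet-50 encoder
# (the exact resize/normalization values used by the project are assumptions).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image = Image.open("/path/to/your/image1.jpg").convert("RGB")
batch = preprocess(image).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Hypothetical inference call; the real model/vocabulary loading lives in predict.py:
# model.eval()
# with torch.no_grad():
#     caption = model.generate_caption(batch, vocab, max_length=20)
```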
The project also includes a Flask-based web application (`app.py`) for interactive caption generation.
Running the Web App:
- Ensure all dependencies are installed (as per the Installation section).
- Make sure your `config.json` is correctly set up, especially the paths to the `best_model_checkpoint` and `vocab_file`.
- Run the Flask application using PDM:

  ```bash
  pdm run app
  ```
- Open your web browser and navigate to `http://localhost:8080` (or the port specified in `app.py` or your environment).
The web interface allows you to upload an image and view the generated caption. It also includes a feature to visualize attention maps if the model supports it and the `generate_caption_with_word_attention` method is implemented in `model.py`.
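A minimal sketch of such an upload endpoint is shown below; the route name and the `generate_caption_for_image` helper are illustrative placeholders, not the actual `app.py` implementation:

```python
from flask import Flask, jsonify, request
from PIL import Image

app = Flask(__name__)

def generate_caption_for_image(image: Image.Image) -> str:
    # Placeholder: the real app calls the trained CaptioningModel here.
    return "a placeholder caption"

@app.route("/caption", methods=["POST"])  # illustrative route name
def caption():
    file = request.files["image"]                    # uploaded file under form field "image"
    image = Image.open(file.stream).convert("RGB")
    return jsonify({"caption": generate_caption_for_image(image)})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)
```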
## Dataset

This project is designed to work with datasets in the COCO format.
- Images: A directory containing image files (e.g., `.jpg`, `.png`).
- Captions: A JSON file in COCO annotation format, containing an `"images"` list and an `"annotations"` list. Each annotation should have an `image_id` and a `caption`.
Update the paths section in `config.json` to point to your dataset directories and caption files.
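For illustration, a caption file in this format can be inspected as follows (field names follow the standard COCO captions schema; the file path is a placeholder):

```python
import json

# Load a COCO-style caption file and pair each caption with its image file name.
with open("annotations/captions_train.json") as f:  # placeholder path
    coco = json.load(f)

id_to_file = {img["id"]: img["file_name"] for img in coco["images"]}
pairs = [(id_to_file[ann["image_id"]], ann["caption"]) for ann in coco["annotations"]]
print(pairs[:3])
```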
## Model Architecture

The model (`model.py`) consists of the following components (a simplified sketch follows this list):
- `EncoderCNN`: A CNN based on a pre-trained ResNet-50 (from `torchvision.models`) to extract image features. The final classification layer is removed. Can be fine-tuned or frozen.
- `PositionalEncoding`: Adds positional information to token embeddings, crucial for Transformers.
- `DecoderTransformer`: A stack of Transformer decoder layers (`torch.nn.TransformerDecoder`) that generates captions word by word based on image features and previously generated words.
- `CaptioningModel`: Encapsulates the encoder and decoder. Includes a `generate_caption` method for inference.
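The sketch below mirrors this structure in simplified form; layer sizes, constructor signatures, and the way image features are passed to the decoder as memory are assumptions, so `model.py` remains the reference implementation:

```python
import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
    """ResNet-50 feature extractor with the classification head removed (simplified)."""
    def __init__(self, embed_size: int, train_cnn: bool = False):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # drop final fc layer
        for p in self.backbone.parameters():
            p.requires_grad = train_cnn
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):                      # images: (B, 3, H, W)
        feats = self.backbone(images).flatten(1)    # (B, 2048)
        return self.fc(feats).unsqueeze(1)          # (B, 1, embed_size) used as decoder memory

class DecoderTransformer(nn.Module):
    """Transformer decoder attending to the image feature as memory (simplified)."""
    def __init__(self, vocab_size, embed_size, num_layers=3, nhead=8, ff_dim=2048, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        layer = nn.TransformerDecoderLayer(embed_size, nhead, ff_dim, dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers)
        self.out = nn.Linear(embed_size, vocab_size)

    def forward(self, tokens, memory):              # tokens: (B, T), memory: (B, 1, E)
        # (positional encoding, described above, is omitted in this sketch)
        tgt = self.embed(tokens)
        mask = nn.Transformer.generate_square_subsequent_mask(tokens.size(1))
        return self.out(self.decoder(tgt, memory, tgt_mask=mask))  # (B, T, vocab_size)
```

Here a single pooled image feature serves as the decoder's memory; the real model also applies `PositionalEncoding` to the token embeddings and wraps both modules in `CaptioningModel`, which adds the `generate_caption` loop for inference.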
## Logging

- The `utils.py` module sets up logging for the project (an illustrative sketch follows this list).
- Logs are output to the console and saved to files within the `logs/` directory:
  - `train.log`: For the training script.
  - `evaluate.log`: For the evaluation script.
  - `predict.log`: For the prediction script.
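A sketch of this kind of setup is shown below; the actual handler and format configuration lives in `utils.py`, so the details here are assumptions:

```python
import logging
from pathlib import Path

def setup_logger(name: str, log_file: str) -> logging.Logger:
    """Illustrative logger setup: console output plus a file under logs/."""
    Path("logs").mkdir(exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    for handler in (logging.StreamHandler(), logging.FileHandler(f"logs/{log_file}")):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger

logger = setup_logger("train", "train.log")
logger.info("Training started")
```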
## Dependencies

Key Python libraries used:
- PyTorch (`torch`, `torchvision`)
- `pycocotools` and `pycocoevalcap` for COCO dataset interaction and evaluation.
- PIL (Pillow) for image manipulation.
- `tqdm` for progress bars.
All dependencies are managed by PDM via `pyproject.toml` and `pdm.lock`.