diff --git a/.gitignore b/.gitignore index adf0e2c2e9..1a48f990df 100644 --- a/.gitignore +++ b/.gitignore @@ -5,9 +5,6 @@ __pycache__/ venv/ .env -private/ -/data/ - /models/ -/simple_classifier_checkpoints/ -/main_classifier_checkpoints/ \ No newline at end of file +/data/ +/private/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d2bc28c2b..242328f4f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,6 +9,6 @@ repos: hooks: - id: run-unittests name: Run unittests - entry: pipenv run python -m unittest discover tests + entry: pipenv run python -m unittest discover tests/unit language: system pass_filenames: false diff --git a/Pipfile b/Pipfile index b646d315d9..5430039488 100644 --- a/Pipfile +++ b/Pipfile @@ -15,7 +15,8 @@ pre-commit = "*" torch = {index = "pytorch-cu128", version = "2.7.0+cu128", markers = "platform_system == 'Windows'"} torchvision = {index = "pytorch-cu128", version = "0.22.0+cu128", markers = "platform_system == 'Windows'"} torchaudio = {index = "pytorch-cu128", version = "2.7.0+cu128", markers = "platform_system == 'Windows'"} -pandas = "*" +pillow = "*" +pandas= "*" torcheval = "*" tqdm = "*" requests = "*" diff --git a/README.md b/README.md index e8a0869ec6..ae09e817c4 100644 --- a/README.md +++ b/README.md @@ -23,24 +23,6 @@ Make sure you have the following installed: For detailed installation instructions, [click here](https://pipenv.pypa.io/en/latest/installation.html). ## Getting Started -### Setting up the repository - -If you're collaborating or want to explore the latest version of the project: - -1. Fork this repository. -2. Clone your fork locally. -3. Configure a remote pointing to the upstream repository to sync changes between your fork and the original repository. 
- ```bash - git remote add upstream https://github.com/sinan2000/recaptcha - ``` - **Don't skip this step.** We might update the original repository, so you should be able to easily pull our changes. - - To update your forked repo follow these steps: - 1. `git fetch upstream` - 2. `git rebase upstream/main` - 3. `git push origin main` - Sometimes you may need to use git push --force origin main. Only use this flag the first time you push after you rebased, and be careful as you might overwrite someone' changes. - ## Installing Dependencies To install the project dependencies run: @@ -102,9 +84,9 @@ To make navigating through the repository easier, you can find its structure bel ├───README.md # Instructions ``` -## **API Launching 🚀** +## **🚀 Usage Guide** -To start the FastAPI server locally, follow these steps: +### Running the App 1. Activate pipenv environment (if not already activated) @@ -112,50 +94,50 @@ To start the FastAPI server locally, follow these steps: pipenv shell ``` -2. You can start the FastAPI server using: - +2. To launch any component of our project, run: ```bash -uvicorn recaptcha_classifier.api:app --host 0.0.0.0 --port 8000 --reload +python main.py [OPTION] ``` -Alternatively, run the API by running main.py: +Available list of options: +--streamlit - Launches Streamlit UI +--api - starts the FastAPI backend +--train-simple-cnn - Trains the simple baseline model +--train-main-cnn - Trains our main model -Make sure the open_api() function is uncommented in main.py +If no argument has been passed, an interactive menu will appear to let you choose the action. -```python -def main(): - # train_main_classifier() - open_api() # ✅ Uncomment this line -``` +## API Documentation -### API Documentation +### Example API call and response format -After running the server, you can access the Documentation: +You can make a call to the api using curl, by running the command below. Make sure to include a valid file path. 
The path can either be absolute (full) or relative to your +current location from command terminal. -Interactive API docs (Swagger UI): http://localhost:8000/docs +```bash +curl -X POST "http://localhost:8000/predict" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@" +``` +You will get a response in the following format: -ReDoc documentation: http://localhost:8000/redoc +```json + { + "class_id":1, + "class_name":"Bridge", + "confidence": "99.9%" + } +``` -These interfaces allow you to test predictions and inspect the request/response formats. +The API is stateless, initializing the model on launching, and caches responses for 1 hour. -### API call and response format -You can make a call to the api using curl, by running the chunk below. Make sure to include a valid file path. +After running the server, you can access the Documentation: -```bash -curl -X POST "http://localhost:8000/predict" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=" -``` -You will get a response in the following format: +Interactive API docs (Swagger UI): http://localhost:8000/docs -```bash - {"class_id":1,"class_name":"Bridge"} -``` +ReDoc documentation: http://localhost:8000/redoc -If you're using Windows, running the request in the format mentioned above may not work in Powershell. Instead, use the format below in Command Prompt: +These interfaces allow you to test predictions and inspect the request/response formats. 
-```bash -curl -X POST "http://localhost:8000/predict" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@" -``` ### Possible error Responses diff --git a/tests/data/__init__.py b/evaluation_results.json similarity index 100% rename from tests/data/__init__.py rename to evaluation_results.json diff --git a/main.py b/main.py index 5c508c1aa1..bcfe6274a5 100644 --- a/main.py +++ b/main.py @@ -49,10 +49,14 @@ def handle_action(choice: str): handle_action(choice) def ui(): - from recaptcha_classifier.server.app import StreamlitApp + import subprocess + import os + root_dir = os.path.dirname(os.path.abspath(__file__)) + streamlit_file = os.path.join(root_dir, "recaptcha_classifier", "server", "app.py") - app = StreamlitApp() - app.render() + subprocess.Popen(["uvicorn", "recaptcha_classifier.server.api:app", "--reload"]) + + subprocess.run(["streamlit", "run", streamlit_file]) def train_simple_cnn(): from recaptcha_classifier.pipeline.simple_cnn_pipeline import SimpleClassifierPipeline @@ -63,7 +67,7 @@ def train_simple_cnn(): def train_main_classifier(): from recaptcha_classifier.pipeline.main_model_pipeline import MainClassifierPipeline - pipeline = MainClassifierPipeline(epochs=1, k_folds=2) + pipeline = MainClassifierPipeline() pipeline.run(save_train_checkpoints=False) def open_api(): @@ -71,6 +75,5 @@ def open_api(): # opens endpoint at http://localhost:8000/ uvicorn.run("recaptcha_classifier.server.api:app", reload=True) - if __name__ == '__main__': main() \ No newline at end of file diff --git a/recaptcha_classifier/__init__.py b/recaptcha_classifier/__init__.py index c1a3c4d72c..3e5b54a524 100644 --- a/recaptcha_classifier/__init__.py +++ b/recaptcha_classifier/__init__.py @@ -1,9 +1,18 @@ from .models import SimpleCNN +from .models.main_model import MainCNN, HPOptimizer from .detection_labels import DetectionLabels from .data import DataPreprocessingPipeline +from .train import Trainer +from .features import evaluate_model +from 
.server import load_model __all__ = [ "DetectionLabels", "DataPreprocessingPipeline", - 'SimpleCNN' + 'SimpleCNN', + 'MainCNN', + 'HPOptimizer', + 'Trainer', + 'evaluate_model', + 'load_model' ] \ No newline at end of file diff --git a/recaptcha_classifier/constants.py b/recaptcha_classifier/constants.py index 943487486f..a09a783239 100644 --- a/recaptcha_classifier/constants.py +++ b/recaptcha_classifier/constants.py @@ -5,7 +5,7 @@ IMAGE_SIZE = (IMAGE_WIDTH, IMAGE_HEIGHT) INPUT_SHAPE = (IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH) -MODELS_FOLDER="models" # Folder where models are saved +MODELS_FOLDER="models/final" # Folder where models are saved MAIN_MODEL_FILE_NAME = "main_model.pt" SIMPLE_MODEL_FILE_NAME = "simple_model.pt" OPTIMIZER_FILE_NAME = "optimizer.pt" diff --git a/recaptcha_classifier/data/__init__.py b/recaptcha_classifier/data/__init__.py index c341ff9f94..ab71a0435a 100644 --- a/recaptcha_classifier/data/__init__.py +++ b/recaptcha_classifier/data/__init__.py @@ -1,26 +1,19 @@ from .pipeline import DataPreprocessingPipeline from .downloader import DatasetDownloader -from .pair_loader import ImageLabelLoader +from .paths_loader import ImagePathsLoader from .splitter import DataSplitter from .visualizer import Visualizer from .loader_factory import LoaderFactory from .dataset import ImageDataset from .preprocessor import ImagePrep -from .augment import ( - AugmentationPipeline, - HorizontalFlip, - RandomRotation -) -from .scaler import YOLOScaler +from .augment import AugmentationPipeline from .collate_batch import collate_batch from .types import ( - FilePair, - FilePairList, - DatasetSplitDict, - BBoxList, - DataPair, + ImagePathList, + DatasetSplitMap, + LoadedImg, DataItem, DataBatch ) @@ -29,24 +22,20 @@ # Classes "DataPreprocessingPipeline", "DatasetDownloader", - "ImageLabelLoader", + "ImagePathsLoader", "DataSplitter", "Visualizer", "LoaderFactory", "ImageDataset", "ImagePrep", "AugmentationPipeline", - "HorizontalFlip", - "RandomRotation", - 
"YOLOScaler", # Methods "collate_batch", # Types "FilePair", - "FilePairList", - "DatasetSplitDict", - "BBoxList", - "DataPair", + "ImagePathList", + "DatasetSplitMap", + "LoadedImg", "DataItem", "DataBatch", ] diff --git a/recaptcha_classifier/data/augment.py b/recaptcha_classifier/data/augment.py index ffbf127646..9a84e1538b 100644 --- a/recaptcha_classifier/data/augment.py +++ b/recaptcha_classifier/data/augment.py @@ -1,93 +1,27 @@ import random -from abc import ABC, abstractmethod -from PIL import Image from typing import List -from .scaler import YOLOScaler -from .types import DataPair, BBoxList - - -class Augmentation(ABC): - """Abstract class for data augmentation.""" - @abstractmethod - def augment(self, - image: Image.Image, - annotations: List) -> DataPair: - """ - Apply the transformation of the image and updates the bounding boxes - if necessary. - - Args: - image (Image.Image): The image to be augmented. - annotations (List): List of annotations associated with the image. - - Returns: - DataPair: The augmented image and the updated - annotations. - """ - pass +from .types import LoadedImg +from torchvision import transforms class AugmentationPipeline: """Class to manage a series of augmentations in sequence.""" - def __init__(self, transforms=[]) -> None: - self._transforms: List[Augmentation] = transforms + def __init__(self, transforms_list: List) -> None: + """ + Initializes the augmentation pipeline with a list of transformations. + """ + self._transforms_list = transforms_list def apply_transforms(self, - image: Image.Image, - annotations: BBoxList) -> DataPair: + image: LoadedImg) -> LoadedImg: """ - Apply all transformations in the pipeline to the image and - annotations. + Apply a random transformation from the pipeline to the image. Args: - image (Image.Image): The image to be augmented. - annotations (BBoxList): List of annotations - associated with the image. + image (LoadedImg): The image to be augmented. 
Returns: - DataPair: The augmented image and the updated - annotations. + LoadedImg: The augmented image. """ - for transform in self._transforms: - if hasattr(transform, 'prob') and random.random() > transform.prob: - continue - image, annotations = transform.augment(image, annotations) - return image, annotations - - -class HorizontalFlip(Augmentation): - """Flips the image horizontally, with probability p and updates bboxes.""" - def __init__(self, p: float = 0.5) -> None: - self.prob = p - - def augment(self, - image: Image.Image, - annotations: BBoxList) -> DataPair: - flipped = image.transpose(Image.FLIP_LEFT_RIGHT) - new_annotations = YOLOScaler.scale_for_flip(annotations) - - return flipped, new_annotations - - -class RandomRotation(Augmentation): - """ - Rotates the image by a random angle, - also updates bboxes to reflect the rotation. - """ - def __init__(self, degrees: float = 30.0, p: float = 0.5) -> None: - self._degrees = degrees - self.prob = p - - def augment(self, - image: Image.Image, - annotations: BBoxList) -> DataPair: - angle = random.uniform(-self._degrees, self._degrees) - - rotated = image.rotate(angle) - new_annotations = (YOLOScaler - .scale_for_rotation(annotations, - angle, - image.size) - ) - - return rotated, new_annotations + transform = random.choice(self._transforms_list) + return transform(image) \ No newline at end of file diff --git a/recaptcha_classifier/data/collate_batch.py b/recaptcha_classifier/data/collate_batch.py index 55cc67c688..05b0050130 100644 --- a/recaptcha_classifier/data/collate_batch.py +++ b/recaptcha_classifier/data/collate_batch.py @@ -1,34 +1,29 @@ from typing import List -from .types import DataBatch, DataItem +from .types import DataItem, DataBatch import torch def collate_batch(batch: List[DataItem]) -> DataBatch: """ - Custom collate function used in the PyTorch DataLoader to handle - the non-uniform length of bounding box lists. 
+ Custom collate function used in the PyTorch DataLoader to combine + the list of data items into a stacked batch for model training. Args: - batch (List[DataItem]): A batch of training items; each item is - a tuple of format (image tensor, bounding boxes, class index). + batch (List[DataItem]): A list of dataset items, tensors of + image and label pair. Returns: Batch: A single tuple containing: - images_tensor (Tensor): Batched images of shape (batch_size, 3, H, W) - - bboxes (List[BBoxList]): A list of - bounding boxes for each image - labels_tensor (Tensor): A tensor of shape containing the class indices. """ images = [item[0] for item in batch] - # bboxes = [item[1] for item in batch] - # labels = [item[2] for item in batch] labels = [item[1] for item in batch] # Stack them as (3, H, W) tensors images_tensor = torch.stack(images) labels_tensor = torch.stack(labels) - # return images_tensor, bboxes, labels_tensor return images_tensor, labels_tensor diff --git a/recaptcha_classifier/data/dataset.py b/recaptcha_classifier/data/dataset.py index 803b4b6027..4d64cccfb7 100644 --- a/recaptcha_classifier/data/dataset.py +++ b/recaptcha_classifier/data/dataset.py @@ -2,12 +2,12 @@ from torch.utils.data import Dataset from .preprocessor import ImagePrep from .augment import AugmentationPipeline -from .types import FilePairList, DataItem +from .types import ImagePathList, DataItem class ImageDataset(Dataset): """ - A class to handle the pairs of (image, label) for the dataset. + A class to handle the images for the dataset. It makes them ready for training, by applying augmentation for the training set, preprocessing and makes sure that the output format is in PyTorch Tensor format. 
@@ -17,7 +17,7 @@ class ImageDataset(Dataset): https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html """ def __init__(self, - pairs: FilePairList, + items: ImagePathList, preprocessor: ImagePrep, augmentator: Optional[AugmentationPipeline] = None, class_map: dict = {} @@ -25,13 +25,13 @@ def __init__(self, """ Initializes the ImageDataset with the given parameters. """ - self._pairs = pairs + self._items = items self._prep = preprocessor self._aug = augmentator self._class_map = class_map def __len__(self) -> int: - return len(self._pairs) + return len(self._items) def __getitem__(self, idx: int @@ -47,20 +47,14 @@ def __getitem__(self, preprocessed image in tensor format, the YOLO bound box annotations and the label. """ - # img_path, lbl_path = self._pairs[idx] - img_path = self._pairs[idx] + img_path = self._items[idx] - # Load image and label + # Load image img = self._prep.load_image(img_path) - # bb = self._prep.load_labels(lbl_path) - - # if not bb: - # raise ValueError(f"Bounding box list is empty for {lbl_path}") # Apply augmentation if passed if self._aug: - # img, bb = self._aug.apply_transforms(img, bb) - img, _ = self._aug.apply_transforms(img, []) + img = self._aug.apply_transforms(img) # Convert image to tensor tensor = self._prep.to_tensor(img) @@ -71,6 +65,4 @@ def __getitem__(self, raise KeyError(f"Class name '{c_name}' not found in classes.") c_id = self._class_map[c_name] - # Return image tensor, bounding box and class index return tensor, self._prep.class_id_to_tensor(c_id) - # return tensor, bb, c_id diff --git a/recaptcha_classifier/data/downloader.py b/recaptcha_classifier/data/downloader.py index 0ba4c6a336..219c7f25f0 100644 --- a/recaptcha_classifier/data/downloader.py +++ b/recaptcha_classifier/data/downloader.py @@ -1,11 +1,10 @@ import requests import zipfile -import logging +import shutil from pathlib import Path +from typing import List from alive_progress import alive_bar -logger = logging.getLogger(__name__) - class 
DatasetDownloader: """ @@ -16,6 +15,7 @@ class DatasetDownloader: dataset downloading operation, with no other responsibilities. """ def __init__(self, + class_names: List[str], url: str = ("https://www.kaggle.com/api/v1/datasets/" "download/mikhailma/test-dataset"), dest: str = "data") -> None: @@ -30,6 +30,7 @@ def __init__(self, self._dest: Path = Path(dest) self._zip_path: Path = self._dest / "dataset.zip" self._progress = alive_bar + self._expected_folder_names = class_names def download(self) -> None: """ @@ -37,22 +38,31 @@ def download(self) -> None: If not, downloads and then unzips it. """ if self._is_downloaded(): - logger.info("Dataset already exists, skipping download.") + print("Dataset already exists, skipping download.") return self._prepare_dest() - logger.info(f"Downloading {self._url} to {self._dest}...") + print(f"Downloading {self._url} to {self._dest}...") self._download_zip() - self._unzip_and_cleanup() + self._extract_zip() + self._move_subfolders() + self._delete_labels() + self._flatten_images_folder() + self._zip_path.unlink() + print("Download and extraction completed successfully.") def _is_downloaded(self) -> bool: """ - Checks if the dataset is already downloaded. + Checks if the dataset is already downloaded and in the expected format. Returns: bool: True if the dataset is already downloaded, False otherwise. """ - return self._dest.exists() and any(self._dest.iterdir()) + if not self._dest.exists(): + return False + folders = {p.name for p in self._dest.iterdir() if p.is_dir()} + expected = set(self._expected_folder_names) + return expected.issubset(folders) def _prepare_dest(self) -> None: """ @@ -82,24 +92,52 @@ def _download_zip(self) -> None: for chunk in resp.iter_content(chunk_size=8192): f.write(chunk) bar(len(chunk)) - logger.info("Download completed successfully.") - - def _unzip_and_cleanup(self) -> None: + + def _extract_zip(self) -> None: """ - Unzips the downloaded dataset. 
- Then it cleans up the zip file and its main extracted directory. - - Note: this assumes that the downloaded dataset has exactly - the structure of our selected Kaggle dataset for simplicity. + Extracts the downloaded zip file to the destination directory. """ - logger.info("Extracting...") with zipfile.ZipFile(self._zip_path) as z: z.extractall(self._dest) - root = next(p for p in self._dest.iterdir() if p.is_dir()) + def _move_subfolders(self) -> None: + """ + Moves the main images and labels subfolders + from the extracted directory to the destination directory. + Finally, it removes the main extracted directory that is now empty. + """ + root = next(p for p in self._dest.iterdir() if p.is_dir() and p.name not in self._expected_folder_names) + for sub in ("images", "labels"): - (root / sub).rename(self._dest / sub) + source = root / sub + if source.exists(): + target = self._dest / sub + if target.exists(): + shutil.rmtree(target) + source.rename(target) + + if root.exists() and root.is_dir(): + root.rmdir() + + def _delete_labels(self) -> None: + """ + Deletes the labels directory if it exists. 
+ """ + label_dir = self._dest / "labels" + if label_dir.exists() and label_dir.is_dir(): + shutil.rmtree(label_dir) + + def _flatten_images_folder(self) -> None: + images_dir = self._dest / "images" + if not images_dir.exists(): + return - root.rmdir() - self._zip_path.unlink() - print("Extraction and cleanup completed successfully.") + for subfolder in images_dir.iterdir(): + if subfolder.is_dir(): + target_path = self._dest / subfolder.name + if target_path.exists(): + shutil.rmtree(target_path) + subfolder.rename(target_path) + + if not any(images_dir.iterdir()): + images_dir.rmdir() diff --git a/recaptcha_classifier/data/loader_factory.py b/recaptcha_classifier/data/loader_factory.py index 26e807c834..0aeded526f 100644 --- a/recaptcha_classifier/data/loader_factory.py +++ b/recaptcha_classifier/data/loader_factory.py @@ -4,7 +4,7 @@ from .dataset import ImageDataset from .preprocessor import ImagePrep from .augment import AugmentationPipeline -from .types import DatasetSplitDict, FilePairList +from .types import DatasetSplitMap, ImagePathList from .collate_batch import collate_batch @@ -25,7 +25,7 @@ def __init__(self, Args: class_map (dict): A dictionary mapping class names to indices. - preprocessor (ImagePrep): The preprocessor to use. + preprocessor (ImagePrep): The preprocessor to use for data. augmentator (Optional[AugmentationPipeline]): The augmentator used batch_size (int): Batch size for DataLoader. num_workers (int): Number of workers for DataLoader. 
@@ -39,29 +39,29 @@ def __init__(self, self._class_map = class_map def create_loaders(self, - splits: DatasetSplitDict) -> Dict[str, DataLoader]: + splits: DatasetSplitMap) -> Dict[str, DataLoader]: loaders: Dict[str, DataLoader] = {} for split_name, cls_dict in splits.items(): - # flatten nested dict of pairs - flat_pairs: FilePairList = [pair + # flatten nested dict of image_paths + flat_image_paths: ImagePathList = [image_path # traversing over classes - for pairs in cls_dict.values() - # traversing over pairs - for pair in pairs + for image_paths in cls_dict.values() + # traversing over image_paths + for image_path in image_paths ] # augmentatr only for training set augmentator = self._aug if split_name == 'train' else None dataset = ImageDataset( - pairs=flat_pairs, + items=flat_image_paths, preprocessor=self._preprocessor, augmentator=augmentator, class_map=self._class_map ) - sampler = (self._build_sampler(flat_pairs) if self._balance + sampler = (self._build_sampler(flat_image_paths) if self._balance and split_name == "train" else None) loader = DataLoader( @@ -77,14 +77,13 @@ def create_loaders(self, return loaders - def _build_sampler(self, pairs): + def _build_sampler(self, image_paths): """ Builds a sampler for the dataset to balance the classes. """ class_counts = Counter() targets = [] - # for img_path, _ in pairs: - for img_path in pairs: + for img_path in image_paths: cls = img_path.parent.name targets.append(self._class_map[cls]) class_counts[self._class_map[cls]] += 1 diff --git a/recaptcha_classifier/data/pair_loader.py b/recaptcha_classifier/data/pair_loader.py deleted file mode 100644 index c8885cbf2f..0000000000 --- a/recaptcha_classifier/data/pair_loader.py +++ /dev/null @@ -1,121 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Dict -from .types import ClassFileDict - -logger = logging.getLogger(__name__) - - -class ImageLabelLoader: - """ - This class loads all image-label pairs for given classes. 
- - It scans for all matching images and labels and caches the result. - - It follows Single Responsibility Principle (SRP) as it only handles - the loading of the pairs. Also, it uses the Iterator pattern, as - it can be looped over to get the list pairs as tuples by class, - in format (class, [(img_path, lbl_path), ...].) - """ - def __init__(self, - classes: List[str], - images_dir: str = "data/images", - # labels_dir: str = "data/labels" - ) -> None: - """ - Initializes the PairsLoader instance. - - Args: - classes (List[str]): List of class names to load. - images_dir (str): Path to the directory containing images. - labels_dir (str): Path to the directory containing labels. - """ - self._classes = classes - self._images_dir = Path(images_dir) - # self._labels_dir = Path(labels_dir) - self._pairs: ClassFileDict = dict() - - def find_pairs(self) -> ClassFileDict: - """ - Returns all pairs loaded for the given classes. - It caches the response after first run. - - Returns: - List[Tuple[Path, Path]]: List of tuples - containing image and label paths. (img_path, lbl_path) - """ - if not self._pairs: - self._load_pairs() - return self._pairs - - def __iter__(self): - """ - Iterates over classes and their respective list of pairs. - - Yields: - Tuple[str, List[Tuple[Path, Path]]]: Class name and list of - tuples containing image and label paths. - """ - for cls, pairs in self.find_pairs().items(): - yield cls, pairs - - def __len__(self) -> int: - """ - Returns total number of matched pairs. - - Returns: - int: Number of matched pairs. - """ - return sum(len(pairs) for pairs in self.find_pairs().values()) - - def class_count(self) -> Dict[str, int]: - """ - Returns a dictionary with the count of pairs for each class. - The keys are class names and the values are the counts. - - Returns: - Dict[str, int]: Dictionary with class names as keys and - counts as values. 
- """ - return {cls: len(pairs) for cls, pairs in self.find_pairs().items()} - - def _load_pairs(self) -> None: - """ - Private method to scan directories of given classes and match all - available image-label pairs. - It ignores images with missing labels. - """ - total_count = 0 - for cls in self._classes: - img_dir = self._images_dir / cls - # lbl_dir = self._labels_dir / cls - # skipped, cls_count = 0, 0 - - if not Path.is_dir(img_dir): # or not Path.exists(lbl_dir) - logger.info(f"Warning: Missing folder for {cls}. Skipping.") - continue - - self._pairs[cls] = list(img_dir.glob("*.png")) # [] - N = len(self._pairs[cls]) - logger.info(f"Found {N} images in {cls}.") - total_count += N - """ - for img_path in img_dir.glob("*.png"): - lbl_path = lbl_dir / img_path.name.replace(".png", ".txt") - cls_count += 1 - - if not Path.exists(lbl_path): - skipped += 1 - continue - - self._pairs[cls].append((img_path, lbl_path)) - - - print(f"Loaded {cls_count - skipped} image-label pairs for {cls}.") - if skipped > 0: - print(f"Warning: {skipped} missing labels in {cls}. Skipped.") - - total_count += cls_count - skipped - """ - - print(f"Total pairs loaded: {total_count}") diff --git a/recaptcha_classifier/data/paths_loader.py b/recaptcha_classifier/data/paths_loader.py new file mode 100644 index 0000000000..712e09ad90 --- /dev/null +++ b/recaptcha_classifier/data/paths_loader.py @@ -0,0 +1,93 @@ +from pathlib import Path +from typing import List, Dict +from .types import ClassToImgPaths + + +class ImagePathsLoader: + """ + This class loads all image paths for given classes. + + It scans for all matching images and caches the result. + + It follows Single Responsibility Principle (SRP) as it only handles + the loading of the image paths. Also, it uses the Iterator pattern, as + it can be looped over to get the list paths as tuples by class, + in format (class, [img_path1, ...]). 
+ """ + def __init__(self, + classes: List[str], + images_dir: str = "data") -> None: + """ + Initializes the ImagePathsLoader instance. + + Args: + classes (List[str]): List of class names to load. + images_dir (str): Path to the directory containing images. + """ + self._classes = classes + self._images_dir = Path(images_dir) + self._paths: ClassToImgPaths = dict() + + def find_image_paths(self) -> ClassToImgPaths: + """ + Returns all image paths loaded for the given classes. + It caches the response after first run. + + Returns: + ClassToImgPaths: Dictionary mapping class names to lists + of image paths. + """ + if not self._paths: + self._load_pairs() + return self._paths + + def __iter__(self): + """ + Iterates over classes and their respective list of pairs. + + Yields: + Tuple[str, List[Path]]: Class name and list of + image paths. + """ + for cls, pairs in self.find_image_paths().items(): + yield cls, pairs + + def __len__(self) -> int: + """ + Returns total number of matched pairs. + + Returns: + int: Number of matched pairs. + """ + return sum(len(pairs) for pairs in self.find_image_paths().values()) + + def class_count(self) -> Dict[str, int]: + """ + Returns a dictionary with the count of pairs for each class. + The keys are class names and the values are the counts. + + Returns: + Dict[str, int]: Dictionary with class names as keys and + counts as values. + """ + return {cls: len(pairs) for cls, pairs in self.find_image_paths().items()} + + def _load_pairs(self) -> None: + """ + Private method to scan directories of given classes and match all + available image paths. + """ + total_count = 0 + for cls in self._classes: + img_dir = self._images_dir / cls + + if not Path.is_dir(img_dir): + print(f"Warning: Missing folder for {cls}. 
Skipping.") + continue + + self._paths[cls] = sorted(img_dir.glob("*.png")) + N = len(self._paths[cls]) + print(f"Found {N} images in {cls}.") + total_count += N + + print(f"Total image paths found: {total_count}") diff --git a/recaptcha_classifier/data/pipeline.py b/recaptcha_classifier/data/pipeline.py index a2949c67c6..63c474c774 100644 --- a/recaptcha_classifier/data/pipeline.py +++ b/recaptcha_classifier/data/pipeline.py @@ -1,16 +1,13 @@ from typing import Dict, Tuple from torch.utils.data import DataLoader - +from enum import EnumMeta from .downloader import DatasetDownloader -from .pair_loader import ImageLabelLoader +from .paths_loader import ImagePathsLoader from .splitter import DataSplitter from .visualizer import Visualizer from .preprocessor import ImagePrep -from .augment import ( - AugmentationPipeline, - HorizontalFlip, - RandomRotation -) +from .augment import AugmentationPipeline +from torchvision import transforms from .loader_factory import LoaderFactory @@ -33,18 +30,18 @@ class DataPreprocessingPipeline: pipeline and it includes all its components. """ def __init__(self, - class_map: Dict[str, int], + class_enum: EnumMeta, ratios: Tuple[float, float, float] = (0.7, 0.2, 0.1), seed: int = 23, # our group number batch_size: int = 32, num_workers: int = 4, - balance: bool = False, - show_plots: bool = True) -> None: + balance: bool = True, + show_plots: bool = False) -> None: """ Initializes the DataPreprocessingPipeline with the given parameters. Args: - class_map (Dict[str, int]): Mappng of class names to indices. + class_enum (EnumMeta): Enum class containing dataset classes. ratios (Tuple[float, float, float]): Ratios for train, val, and test splits. seed (int): Random seed for reproducibility. @@ -53,14 +50,16 @@ def __init__(self, balance (bool): Whether to balance the dataset. show_plots (bool): Whether to show plots. 
""" - self._downloader = DatasetDownloader() - self._loader = ImageLabelLoader(list(class_map.keys())) + self._class_enum = class_enum + self._downloader = DatasetDownloader(self._class_enum + .dataset_classnames()) + self._loader = ImagePathsLoader(self._class_enum.dataset_classnames()) self._splitter = DataSplitter(ratios, seed=seed) self._show_plots = show_plots self._preproc = ImagePrep() self._augment = self._build_augmentator() self._creator = LoaderFactory( - class_map=class_map, + class_map=self._class_enum.to_class_map(), preprocessor=self._preproc, augmentator=self._augment, batch_size=batch_size, @@ -71,13 +70,26 @@ def __init__(self, def _build_augmentator(self) -> AugmentationPipeline: """ Builds the augmentation pipeline. + One of these augmentations will be applied randomly + in the dataset, for any image in the training set. Returns: AugmentationPipeline: The augmentation pipeline. """ return AugmentationPipeline([ - HorizontalFlip(p=0.5), - RandomRotation(degrees=30, p=0.5) + # 50 % chance to flip the image horizontally + transforms.RandomHorizontalFlip(p=0.5), + # rotates image randomly within +-15 degrees + transforms.RandomRotation(degrees=15), + # randomly changes brightness, contrast and saturation + transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3), + # Translates/ shifts image up to 10% of its size + # in both x and y directions + transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), + # Mimics a camera lens blur effect + transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)), + # Crops image 90% then resizes it to 224x224 (zoom in effect) + transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0)), ]) def run(self) -> Dict[str, DataLoader]: @@ -95,7 +107,7 @@ def run(self) -> Dict[str, DataLoader]: # 2. Finds all pairs of images and YOLO annotations from the dataset print("b. Searching for all the data...") - pairs_by_class = self._loader.find_pairs() + pairs_by_class = self._loader.find_image_paths() # 3a. 
Splits the data into train, val, and test sets print("c. Splitting the data...") @@ -114,4 +126,4 @@ def run(self) -> Dict[str, DataLoader]: # 4. Create DataLoaders for each split print("e. Creating DataLoaders for each split...") loaders = self._creator.create_loaders(splits) - return loaders + return loaders \ No newline at end of file diff --git a/recaptcha_classifier/data/preprocessor.py b/recaptcha_classifier/data/preprocessor.py index c171503a2e..51cd4dd0ea 100644 --- a/recaptcha_classifier/data/preprocessor.py +++ b/recaptcha_classifier/data/preprocessor.py @@ -3,7 +3,6 @@ from PIL import Image import numpy as np import torch -from .types import BBoxList class ImagePrep: @@ -39,28 +38,6 @@ def load_image(self, img_path: Path) -> Image.Image: loaded = Image.open(img_path).convert("RGB") return self._resize(loaded) - def load_labels(self, lbl_path: Path) -> BBoxList: - """ - Parses the list of YOLO format labels from the file at the given path. - - Args: - lbl_path (Path): Path to the label file. - - Returns: - BBoxList: List of bounding boxes in YOLO format - (x_center, y_center, width, height). - """ - bounding_boxes = [] - with open(lbl_path, "r") as f: - for line in f: - parts = line.strip().split() - if len(parts) < 5: - continue # invalid line skipped - _, x_center, y_center, width, height = map(float, parts) - bounding_boxes.append((x_center, y_center, width, height)) - - return bounding_boxes - def _resize(self, img: Image.Image) -> Image.Image: """ Resizes the image to the target size. diff --git a/recaptcha_classifier/data/scaler.py b/recaptcha_classifier/data/scaler.py deleted file mode 100644 index f35a303125..0000000000 --- a/recaptcha_classifier/data/scaler.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Tuple -from .types import BBoxList - - -class YOLOScaler: - """ - This class is responsible for scaling the YOLO format bounding boxes - based on the transform applied to the image. 
It is only - used inside the AugmentationPipeline class, to adjust - the coordinates of the bounding boxes after applying - transformations to the image. - """ - @staticmethod - def scale_for_flip(bboxes: BBoxList) -> BBoxList: - """ - Adjusts the bounding boxes for horizontal flip. - - Args: - bboxes (BBoxList): List of bounding boxes in YOLO format - (x_center, y_center, width, height). - - Returns: - BBoxList: List of scaled bounding boxes. - """ - return [(1 - x, y, w, h) for (x, y, w, h) in bboxes] - - @staticmethod - def scale_for_rotation(bboxes: BBoxList, - angle: float, - size: Tuple[int, int]) -> BBoxList: - """ - Adjusts the bounding boxes for rotation. - - Args: - bboxes (BBoxList): List of bounding boxes in YOLO format - (x_center, y_center, width, height). - angle (float): Angle of rotation. - size (Tuple[int, int]): Size of the image. (width, height) - - Returns: - BBoxList: List of scaled bounding boxes. - """ - import math - width, height = size - angle_rad = math.radians(angle) - c_x, c_y = width / 2, height / 2 # center coordinates of the image - - n_ann = [] # the new bounding boxes, that we will return - - for x, y, w, h in bboxes: - # calculate pixel coordinates from bb - x0, y0 = x * width, y * height - bw, bh = w * width, h * height - - # calculate corner coordinates - corners = [ - (x0 - bw / 2, y0 - bh / 2), - (x0 + bw / 2, y0 - bh / 2), - (x0 + bw / 2, y0 + bh / 2), - (x0 - bw / 2, y0 + bh / 2) - ] - - # rotate the corners - new_corners = [] - for cx, cy in corners: - # rotate the corners around the center of the image - # formula at https://en.wikipedia.org/wiki/Rotation_matrix - x_rot = (math.cos(angle_rad) * (cx - c_x) - - math.sin(angle_rad) * (cy - c_y) + - c_x) - y_rot = (math.sin(angle_rad) * (cx - c_x) + - math.cos(angle_rad) * (cy - c_y) + - c_y) - - new_corners.append((x_rot, y_rot)) - - # calculate new bounding box - x_min = min(c[0] for c in new_corners) - x_max = max(c[0] for c in new_corners) - y_min = min(c[1] for c in 
new_corners) - y_max = max(c[1] for c in new_corners) - - new_x = max(0, min(1, (x_min + x_max) / (2 * width))) - new_y = max(0, min(1, (y_min + y_max) / (2 * height))) - new_w = max(0, min(1, (x_max - x_min) / width)) - new_h = max(0, min(1, (y_max - y_min) / height)) - - n_ann.append((new_x, new_y, new_w, new_h)) - - return n_ann diff --git a/recaptcha_classifier/data/splitter.py b/recaptcha_classifier/data/splitter.py index 4e36bbdb76..dcba9fd3b3 100644 --- a/recaptcha_classifier/data/splitter.py +++ b/recaptcha_classifier/data/splitter.py @@ -1,6 +1,6 @@ import random from typing import Tuple -from .types import ClassFileDict, DatasetSplitDict, FilePairList +from .types import ClassToImgPaths, DatasetSplitMap, ImagePathList class DataSplitter: @@ -15,7 +15,8 @@ class DataSplitter: def __init__(self, ratios: Tuple[float, float, float] = (0.7, 0.2, 0.1), shuffle: bool = True, - seed: int = None) -> None: + seed: int = 23 # our group number + ) -> None: """ Args: ratios (Tuple[float, float, float]): Ratios for @@ -29,8 +30,8 @@ def __init__(self, self._validate_ratios() def split(self, - pairs_by_class: ClassFileDict - ) -> DatasetSplitDict: + pairs_by_class: ClassToImgPaths + ) -> DatasetSplitMap: """ Splits each class into train, validation, and test sets. It shuffles the data if specified and returns a dictionary @@ -40,7 +41,7 @@ def split(self, items (List): List of items to be split. Returns: - DatasetSplitDict: Nested dictionary containing + DatasetSplitMap: Nested dictionary containing splits for each class. """ splits = {'train': {}, 'val': {}, 'test': {}} @@ -69,16 +70,16 @@ def _validate_ratios(self) -> None: if any(ratio < 0 for ratio in self._ratios): raise ValueError("Ratios must be positive.") - def _shuffle_items(self, items: FilePairList) -> FilePairList: + def _shuffle_items(self, items: ImagePathList) -> ImagePathList: """ Returns a shuffled copied version of the items list, using seed if provided. 
Args: - items (FilePairList): List of items to be shuffled. + items (ImagePathList): List of items to be shuffled. Returns: - FilePairList: Shuffled list of items. + ImagePathList: Shuffled list of items. """ new_items = items.copy() rand = random.Random(self._seed) diff --git a/recaptcha_classifier/data/types.py b/recaptcha_classifier/data/types.py index f5ad6b42d7..ade77d8f66 100644 --- a/recaptcha_classifier/data/types.py +++ b/recaptcha_classifier/data/types.py @@ -3,34 +3,22 @@ from pathlib import Path from torch import Tensor -# OBJECT DETECTION TASK CLASSES -# A (image, label) pair containing their system paths/ locations -FilePair = Path # Tuple[Path, Path] - -# A list of (image, label) pairs, for more items in the dataset -FilePairList = List[FilePair] +# A list of image paths, for more items in the dataset +ImagePathList = List[Path] # A dictionary where the keys are class names and -# the values are lists of (image, label) pairs, elements of that class -ClassFileDict = Dict[str, FilePairList] +# the values are lists of all image paths for that class in the dataset +ClassToImgPaths = Dict[str, ImagePathList] # A nested dictionary where main keys are train/val/test -# and the subkeys are class names -DatasetSplitDict = Dict[str, ClassFileDict] - -# YOLO bounding box format (x_center, y_center, width, height) -BBox = Tuple[float, float, float, float] - -# List of bounding boxes for one image from the dataset -BBoxList = List[BBox] +# and the subkeys are class names, containing all image paths of that class +DatasetSplitMap = Dict[str, ClassToImgPaths] -# A dataset item, of (image, annotations), where both features -# now loaded in memory, instead of Paths like in FilePair -DataPair = Tuple[Image.Image, BBoxList] +# A dataset item, where the image is now loaded from disk +LoadedImg = Image.Image -# A final dataset item, that now contains the image tensor, -# the bounding boxes and the class id; ready for model training -DataItem = Tuple[Tensor, BBoxList, 
int] +# A final dataset item, that now contains the image tensor, and the class id; ready for model training +DataItem = Tuple[Tensor, Tensor] # Output of the dataloader, a batch of data items -DataBatch = Tuple[Tensor, List[BBoxList], Tensor] +DataBatch = Tuple[Tensor, Tensor] diff --git a/recaptcha_classifier/data/visualizer.py b/recaptcha_classifier/data/visualizer.py index cca973d9ef..cd3801a83f 100644 --- a/recaptcha_classifier/data/visualizer.py +++ b/recaptcha_classifier/data/visualizer.py @@ -1,5 +1,5 @@ import matplotlib.pyplot as plt -from .types import DatasetSplitDict +from .types import DatasetSplitMap import numpy as np @@ -10,31 +10,31 @@ class Visualizer: visualization of classes across the splits. """ @classmethod - def print_counts(cls, splits: DatasetSplitDict) -> None: + def print_counts(cls, splits: DatasetSplitMap) -> None: """ Prints the counts of samples in each class, for each split. Args: - splits (DatasetSplitDict): the dataset splits - containing the pairs for each class. + splits (DatasetSplitMap): the dataset splits + containing the items for each class. """ for split, cls_dict in splits.items(): print(f"{split.upper()}:") - for cls, pairs in cls_dict.items(): - print(f" {cls:5s}: {len(pairs)}") + for cls, items in cls_dict.items(): + print(f" {cls:5s}: {len(items)}") print() @classmethod def plot_splits(cls, - splits: DatasetSplitDict, + splits: DatasetSplitMap, title: str = "Class Distribution in Splits") -> None: """ Plots a bar chart showing the amount and percentage of samples present in each class, for each of the splits. Args: - splits (DatasetSplitDict): the dataset splits - containing the pairs for each class. + splits (DatasetSplitMap): the dataset splits + containing the items for each class. title (str): the title of the plot. 
""" classes = list(splits['train'].keys()) @@ -51,40 +51,8 @@ def plot_splits(cls, # total for each class, simply adding counts of each split totals = [t+v+te for t, v, te in zip(counts_train, counts_val, counts_test)] - - """ - # drawing bars - train_bars = plt.bar([i - width for i in x], - counts_train, - width, - label='Train') - val_bars = plt.bar(x, counts_val, width, label='Val') - test_bars = plt.bar([i + width for i in x], - counts_test, - width, - label='Test') - - # adding perc labels on top of each of the bar - for bars, counts in [(train_bars, counts_train), - (val_bars, counts_val), - (test_bars, counts_test)]: - for bar, count, total in zip(bars, counts, totals): - perc = count / total * 100 - plt.text( - bar.get_x() + bar.get_width() / 2, # middle of the bar, - bar.get_height(), # on top of the bar, - f'{perc:.1f}%', # percentage, formatted to 1 decimal - ha='center', va='bottom' - ) - - plt.xticks(x, classes) - plt.ylabel('No. of Samples') - plt.title(title) - plt.legend() - plt.tight_layout() - plt.show() - """ - fig, ax = plt.subplots(figsize=(num_classes, 6)) + + _, ax = plt.subplots(figsize=(num_classes, 6)) bar1 = ax.bar(x - width, counts_train, width, label='Train') bar2 = ax.bar(x, counts_val, width, label='Val') bar3 = ax.bar(x + width, counts_test, width, label='Test') @@ -109,4 +77,4 @@ def annotate(bars, counts): ax.grid(axis='y', linestyle='--', alpha=0.7) plt.tight_layout() - plt.show() + plt.show() \ No newline at end of file diff --git a/recaptcha_classifier/detection_labels.py b/recaptcha_classifier/detection_labels.py index 5c9fe8872c..f99c72e3ee 100644 --- a/recaptcha_classifier/detection_labels.py +++ b/recaptcha_classifier/detection_labels.py @@ -5,13 +5,6 @@ class DetectionLabels(Enum): """ Enum for improving readability of the object classes. 
""" - - """ - OBJECT DETECTION TASK CLASSES - CROSSWALK = 0 - CHIMNEY = 1 - STAIR = 2 - """ BICYCLE = 0 BRIDGE = 1 BUS = 2 @@ -60,7 +53,7 @@ def to_class_map(cls) -> dict: Returns: dict: Dictionary representation of the enum. """ - return {cl.name.capitalize().replace("_", " "): + return {cl.name.replace("_", " ").title(): cl.value for cl in cls} @classmethod @@ -83,12 +76,10 @@ def from_id(cls, id: int) -> str: @classmethod def dataset_classnames(cls) -> list: """ - Returns a list of class names, only with first letter capitalized. - We use it for the pair loader, as that is the format of the folders - downloaded from the dataset. + Returns a list of class names. Returns: list: List of class names. """ - return [name.capitalize().replace("_", " ") + return [name.replace("_", " ").title() for name in cls.__members__.keys()] diff --git a/recaptcha_classifier/features/__init__.py b/recaptcha_classifier/features/__init__.py index e69de29bb2..915c534014 100644 --- a/recaptcha_classifier/features/__init__.py +++ b/recaptcha_classifier/features/__init__.py @@ -0,0 +1,5 @@ +from .evaluation import evaluate_model + +__all__ = [ + "evaluate_model" +] \ No newline at end of file diff --git a/recaptcha_classifier/features/evaluation/__init__.py b/recaptcha_classifier/features/evaluation/__init__.py index e69de29bb2..95daef105c 100644 --- a/recaptcha_classifier/features/evaluation/__init__.py +++ b/recaptcha_classifier/features/evaluation/__init__.py @@ -0,0 +1,5 @@ +from .evaluate import evaluate_model + +__all__ = [ + "evaluate_model" +] \ No newline at end of file diff --git a/recaptcha_classifier/features/evaluation/classification_metrics.py b/recaptcha_classifier/features/evaluation/classification_metrics.py index 0fc17a2e0d..3ad76ba0c9 100644 --- a/recaptcha_classifier/features/evaluation/classification_metrics.py +++ b/recaptcha_classifier/features/evaluation/classification_metrics.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from torchmetrics import 
Accuracy, F1Score -from torchmetrics.classification import MulticlassConfusionMatrix +from torchmetrics.classification import MulticlassConfusionMatrix, MulticlassAccuracy from typing import Optional from recaptcha_classifier.detection_labels import DetectionLabels import matplotlib.pyplot as plt @@ -13,7 +13,8 @@ def evaluate_classification(y_pred: Tensor, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), average: str = 'weighted', cm_plot: bool = True, - class_names: Optional[list[str]] = None) -> dict: + class_names: Optional[list[str]] = None, + ) -> dict: """ Evaluate classification model using torchmetrics @@ -36,18 +37,27 @@ def evaluate_classification(y_pred: Tensor, dict: accuracy, f1, confusion_matrix """ + logits = y_pred # Convert logits to predicted labels y_pred = torch.argmax(y_pred, dim=1) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + y_pred = y_pred.to(device) + y_true = y_true.to(device) + + if num_classes <= 0: + raise ValueError("num_classes must be a positive integer") # Initialize Metrics acc = Accuracy(task="multiclass", num_classes=num_classes).to(device=device) f1 = F1Score(task="multiclass", num_classes=num_classes, average=average).to(device=device) + topk_acc = MulticlassAccuracy(num_classes=num_classes, top_k=3).to(device=device) confmat = MulticlassConfusionMatrix(num_classes=num_classes).to(device=device) # Compute metrics acc_val = acc(y_pred, y_true) f1_val = f1(y_pred, y_true) + topk_acc_val = topk_acc(logits, y_true) cm = confmat(y_pred, y_true) if cm_plot: @@ -57,5 +67,6 @@ def evaluate_classification(y_pred: Tensor, return { 'Accuracy': acc_val.item(), 'F1-score': f1_val.item(), - 'Confusion Matrix': cm + 'Confusion Matrix': cm, + 'Top-3 Accuracy': topk_acc_val.item() } diff --git a/recaptcha_classifier/features/evaluation/evaluate.py b/recaptcha_classifier/features/evaluation/evaluate.py index 1224a8922a..db998a1c39 100644 --- 
a/recaptcha_classifier/features/evaluation/evaluate.py +++ b/recaptcha_classifier/features/evaluation/evaluate.py @@ -1,4 +1,4 @@ -import recaptcha_classifier.features.evaluation.classification_metrics as cm +from .classification_metrics import evaluate_classification import torch from tqdm import tqdm from recaptcha_classifier.detection_labels import DetectionLabels @@ -43,7 +43,7 @@ def evaluate_model(model: torch.nn.Module, y_pred = torch.cat(all_preds) y_true = torch.cat(all_targets) - class_results = cm.evaluate_classification( + class_results = evaluate_classification( y_pred=y_pred, y_true=y_true, class_names=class_names, diff --git a/recaptcha_classifier/models/__init__.py b/recaptcha_classifier/models/__init__.py index 5f6908ea92..735c490da1 100644 --- a/recaptcha_classifier/models/__init__.py +++ b/recaptcha_classifier/models/__init__.py @@ -1,3 +1,4 @@ from .simple_classifier_model import SimpleCNN +from .main_model import MainCNN, HPOptimizer -__all__ = ['SimpleCNN'] +__all__ = ['SimpleCNN', 'MainCNN', 'HPOptimizer'] diff --git a/recaptcha_classifier/models/main_model/HPoptimizer.py b/recaptcha_classifier/models/main_model/HPoptimizer.py index 4bb6490f90..7f669ce0ad 100644 --- a/recaptcha_classifier/models/main_model/HPoptimizer.py +++ b/recaptcha_classifier/models/main_model/HPoptimizer.py @@ -1,8 +1,7 @@ import itertools import random - import pandas as pd - +from typing import List from recaptcha_classifier.models.main_model.model_class import MainCNN from recaptcha_classifier.train.training import Trainer @@ -28,9 +27,9 @@ def get_history(self)->pd.DataFrame: def optimize_hyperparameters(self, - n_layers: list = list(range(1,3)), - kernel_sizes: list = [3, 4, 5], - learning_rates: list = [1e-3, 1e-4], + n_layers: list = [1, 2, 3], + kernel_sizes: list = [3, 5], + learning_rates: list = [1e-2, 1e-3, 1e-4], save_checkpoints: bool = True, n_models: int = 1, n_combos: int = 8, # Number of random samples @@ -81,9 +80,10 @@ def 
optimize_hyperparameters(self, return df_opt_data.copy()[:n_models] - def _train_one_model(self, hp_combo, save_checkpoints) -> None: + def _train_one_model(self, hp_combo: List, save_checkpoints: bool) -> None: model = MainCNN(n_layers=int(hp_combo[0]), kernel_size=int(hp_combo[1])) - self.trainer.train(model=model, lr=hp_combo[2], load_checkpoint=False, save_checkpoint=save_checkpoints) + self.trainer.train(model=model, lr=hp_combo[2], load_checkpoint=False, + save_checkpoint=save_checkpoints) def _generate_hp_combinations(self, hp) -> list: diff --git a/recaptcha_classifier/models/main_model/__init__.py b/recaptcha_classifier/models/main_model/__init__.py index e69de29bb2..405ca63fb8 100644 --- a/recaptcha_classifier/models/main_model/__init__.py +++ b/recaptcha_classifier/models/main_model/__init__.py @@ -0,0 +1,7 @@ +from .model_class import MainCNN +from .HPoptimizer import HPOptimizer + +__all__ = [ + "MainCNN", + "HPOptimizer" +] \ No newline at end of file diff --git a/recaptcha_classifier/models/main_model/kfold_validation.py b/recaptcha_classifier/models/main_model/kfold_validation.py index 2465030f75..7513e6166f 100644 --- a/recaptcha_classifier/models/main_model/kfold_validation.py +++ b/recaptcha_classifier/models/main_model/kfold_validation.py @@ -2,8 +2,7 @@ from sklearn.model_selection import KFold from sympy.printing.pytorch import torch from torch.utils.data import DataLoader, Subset - -from recaptcha_classifier import DetectionLabels +import matplotlib.pyplot as plt from recaptcha_classifier.features.evaluation.evaluate import evaluate_model from recaptcha_classifier.train.training import Trainer from recaptcha_classifier.models.main_model.model_class import MainCNN @@ -29,15 +28,13 @@ def __init__(self, :param hp_optimizer: Instance of HPOptimizer :param device: Optional torch device """ - self._class_map = DetectionLabels.all() # class labels self.train_loader = train_loader self.val_loader = val_loader self.k_folds = k_folds self.device = 
device if device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.results = None - + self.device = torch.device("cuda" if torch.cuda.is_available() + else "cpu") def run_cross_validation(self, hp: list, @@ -60,52 +57,82 @@ def run_cross_validation(self, dataset = self.train_loader.dataset kf = KFold(n_splits=self.k_folds, shuffle=True, random_state=42) - + results = [] - + n_layers, kernel_sizes, learning_rates = hp - for fold_index, (train_idx, val_idx) in enumerate(kf.split(all_indices)): + for fold_index, (t_idx, val_idx) in enumerate(kf.split(all_indices)): + print(f"\n--- Fold {fold_index + 1}/{self.k_folds} ---") - train_subset = Subset(dataset, [all_indices[i] for i in train_idx]) + train_subset = Subset(dataset, [all_indices[i] + for i in t_idx]) val_subset = Subset(dataset, [all_indices[i] for i in val_idx]) - fold_train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False) - fold_val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False) - + fold_train_loader = DataLoader(train_subset, + batch_size=batch_size, + shuffle=False) + fold_val_loader = DataLoader(val_subset, + batch_size=batch_size, + shuffle=False) + model = MainCNN(n_layers=n_layers, kernel_size=kernel_sizes) - + trainer = Trainer(fold_train_loader, fold_val_loader, epochs=20, save_folder=MODELS_FOLDER, device=self.device) - trainer.train(model, lr=learning_rates, save_checkpoint=save_checkpoints, + + trainer.train(model, + lr=learning_rates, + save_checkpoint=save_checkpoints, load_checkpoint=load_checkpoints) - - metrics = evaluate_model(model, fold_val_loader, device=self.device) - metrics.pop('Confusion Matrix') + + metrics = evaluate_model(model, fold_val_loader, + device=self.device) + metrics.pop('Confusion Matrix') metrics["fold"] = fold_index + 1 results.append(metrics) - + df_results = pd.DataFrame(results) - - self.results = df_results - - def print_summary(self) -> None: + + 
df_results.to_csv(f"{MODELS_FOLDER}/kfold_results.csv", index=False) + self.print_summary(df_results) + self.plot_results(df_results) + + @staticmethod + def print_summary(results: pd.DataFrame) -> None: """ Prints a summary of the cross-validation results. """ - if self.results is None: - print("No results to display. Run cross-validation first.") - return - - print("\n--- Cross-Validation Summary ---") - print(self.results) - - mean_results = self.results.mean() + print("\n~~ Cross-Validation Summary ~~") + print(results.round(3)) + + res = results.drop(columns=["fold"]) + + means = res.mean() print("\nMean Results Across Folds:") - print(mean_results) - - std_results = self.results.std() + print(means.round(3)) + + stds = res.std() print("\nStandard Deviation Across Folds:") - print(std_results) \ No newline at end of file + print(stds.round(3)) + + @staticmethod + def plot_results(results: pd.DataFrame) -> None: + """ + Plots the results of the cross-validation. + """ + metrics = [col for col in results.columns if col != 'fold'] + mean_vals = results[metrics].mean() + std_vals = results[metrics].std() + + fig, ax = plt.subplots(figsize=(10, 6)) + mean_vals.plot(kind="bar", yerr=std_vals, capsize=5, ax=ax) + + ax.set_title("Cross-Validation Metrics (mean +- std)") + ax.set_ylabel("Score") + ax.set_xticklabels(metrics, rotation=45, ha='right') + plt.tight_layout() + plt.grid(axis='y') + plt.show() diff --git a/recaptcha_classifier/models/main_model/model_class.py b/recaptcha_classifier/models/main_model/model_class.py index 96d2b74b01..b24d810212 100644 --- a/recaptcha_classifier/models/main_model/model_class.py +++ b/recaptcha_classifier/models/main_model/model_class.py @@ -1,8 +1,10 @@ import torch import torch.nn as nn from recaptcha_classifier.constants import INPUT_SHAPE +from recaptcha_classifier.models.base_model import BaseModel -class MainCNN(nn.Module): # should inherit BaseModel(nn.Module) + +class MainCNN(BaseModel): def __init__(self, n_layers: int, 
@@ -64,8 +66,5 @@ def forward(self, x) -> torch.Tensor: x = x.view(x.size(0), -1) x = self.classifier(x) - return x - - - + return x # ! return logits, torch.softmax(x, dim=1) if needed for uncertainty diff --git a/recaptcha_classifier/pipeline/base_pipeline.py b/recaptcha_classifier/pipeline/base_pipeline.py index 7ec64c0fab..8355855936 100644 --- a/recaptcha_classifier/pipeline/base_pipeline.py +++ b/recaptcha_classifier/pipeline/base_pipeline.py @@ -14,7 +14,8 @@ def __init__(self, save_folder: str = "checkpoints", model_file_name: str = "model.pt", optimizer_file_name: str = "optimizer.pt", - scheduler_file_name: str = "scheduler.pt" + scheduler_file_name: str = "scheduler.pt", + early_stopping: bool = True ): self.lr = lr @@ -24,8 +25,8 @@ def __init__(self, self.model_file_name = model_file_name self.optimizer_file_name = optimizer_file_name self.scheduler_file_name = scheduler_file_name - - self._class_map = DetectionLabels.to_class_map() + self.early_stopping = early_stopping + self._class_map = DetectionLabels self._loaders = None self._data = None self._model = None @@ -55,18 +56,19 @@ def _initialize_trainer(self) -> Trainer: model_file_name=self.model_file_name, optimizer_file_name=self.optimizer_file_name, scheduler_file_name=self.scheduler_file_name, - device=self.device) + device=self.device, + early_stopping=self.early_stopping) @property def class_map_length(self): - return len(self._class_map) + return len(self._class_map.all()) def evaluate(self, plot_cm: bool = False) -> dict: eval_results = evaluate_model( model=self._model, test_loader=self._loaders['test'], device=self._trainer.device, - class_names=list(self._class_map.keys()), + class_names=self._class_map.dataset_classnames(), plot_cm=plot_cm ) return eval_results diff --git a/recaptcha_classifier/pipeline/main_model_pipeline.py b/recaptcha_classifier/pipeline/main_model_pipeline.py index 65f7aceaa3..95f46efd5a 100644 --- a/recaptcha_classifier/pipeline/main_model_pipeline.py +++ 
b/recaptcha_classifier/pipeline/main_model_pipeline.py @@ -55,15 +55,12 @@ def _run_kfold_cross_validation(self) -> None: train_loader=self._loaders["train"], val_loader=self._loaders["val"], k_folds=self.k_folds, - hp_optimizer=self._hp_optimizer, device=self.device ) best_hp = self._hp_optimizer.get_best_hp() - self._kfold.run_cross_validation(hp=best_hp) - print("\n~~ Cross-Validation Summary ~~") - self._kfold.print_summary() + self._kfold.run_cross_validation(hp=best_hp) self.lr = best_hp[2] self._model = self._initialize_model( @@ -76,6 +73,7 @@ def _initialize_model(self, n_layers: int, kernel_size: int) -> MainCNN: num_classes=self.class_map_length) def save_model(self): + os.makedirs(self.save_folder, exist_ok=True) torch.save({ "model_state_dict": self._model.state_dict(), "config": { diff --git a/recaptcha_classifier/pipeline/simple_cnn_pipeline.py b/recaptcha_classifier/pipeline/simple_cnn_pipeline.py index efd3a792f6..4f9b3496d6 100644 --- a/recaptcha_classifier/pipeline/simple_cnn_pipeline.py +++ b/recaptcha_classifier/pipeline/simple_cnn_pipeline.py @@ -1,7 +1,6 @@ import torch from recaptcha_classifier.models.simple_classifier_model import SimpleCNN from recaptcha_classifier.pipeline.base_pipeline import BasePipeline -from recaptcha_classifier.train.training import Trainer import os from recaptcha_classifier.constants import ( MODELS_FOLDER, SIMPLE_MODEL_FILE_NAME, @@ -41,4 +40,5 @@ def _initialize_model(self) -> SimpleCNN: return SimpleCNN(num_classes=self.class_map_length) def save_model(self): + os.makedirs(self.save_folder, exist_ok=True) torch.save(self._model.state_dict(), os.path.join(self.save_folder, self.model_file_name)) \ No newline at end of file diff --git a/recaptcha_classifier/server/api.py b/recaptcha_classifier/server/api.py index e952408e36..8b7b30c2c5 100644 --- a/recaptcha_classifier/server/api.py +++ b/recaptcha_classifier/server/api.py @@ -3,36 +3,51 @@ import torch from fastapi import FastAPI, File, UploadFile, Response from 
fastapi.responses import JSONResponse +import torch.nn.functional as F from PIL import Image -from .load_model import load_main_model +from .load_model import load_main_model, get_model_path from recaptcha_classifier.detection_labels import DetectionLabels from recaptcha_classifier.constants import IMAGE_SIZE from pydantic import BaseModel from typing import Literal +import os class PredictionResponse(BaseModel): label: str - confidence: float + confidence: str class_id: int app = FastAPI() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = None @app.on_event("startup") def load_models(): """Load the models into memory at startup.""" global model - model = load_main_model(device) - model.to(device) + model_path = get_model_path("main") + if os.path.exists(model_path): + model = load_main_model(device) + else: + print("Model file not found. API will return an error for predictions" + "until model is available.") @app.post("/predict", response_model=PredictionResponse) -async def predict(response: Response, file: UploadFile = File(...)) -> PredictionResponse: +async def predict(response: Response, + file: UploadFile = File(...) + ) -> PredictionResponse: + if model is None: + return JSONResponse( + status_code=503, + content={"error": "Model not loaded. Please train " + "or download first."} + ) try: data = await file.read() img = Image.open(io.BytesIO(data)).convert("RGB") - result = predict(model, device, img) + result = inference(model, device, img) response.headers["Cache-Control"] = "max-age=3600" @@ -43,7 +58,7 @@ async def predict(response: Response, file: UploadFile = File(...)) -> Predictio content={"error": str(e)} ) -def predict(model: torch.nn.Module, device: torch.device, image: Image.Image) -> dict: +def inference(model: torch.nn.Module, device: torch.device, image: Image.Image) -> dict: """ Handles the prediction logic for the uploaded image. 
""" @@ -54,11 +69,13 @@ def predict(model: torch.nn.Module, device: torch.device, image: Image.Image) -> with torch.no_grad(): output = model(tensor) + prob = F.softmax(output, dim=1) + conf = prob.argmax(dim=1).item() id = output.argmax(dim=1).item() label = DetectionLabels.from_id(id) return { - "label": label.name, - "confidence": float(output.max().item()), + "label": label, + "confidence": f"{conf * 100:.2f}%", "class_id": id } \ No newline at end of file diff --git a/recaptcha_classifier/server/app.py b/recaptcha_classifier/server/app.py index 2a2c770da7..e75e889683 100644 --- a/recaptcha_classifier/server/app.py +++ b/recaptcha_classifier/server/app.py @@ -1,6 +1,9 @@ import streamlit as st import torch from typing import Literal +import requests +from recaptcha_classifier.pipeline.main_model_pipeline import MainClassifierPipeline +from recaptcha_classifier.pipeline.simple_cnn_pipeline import SimpleClassifierPipeline class StreamlitApp: @@ -45,7 +48,18 @@ def render_training_tab(self) -> None: help="CUDA is not available on this machine.") if st.button("Start Training"): - # logic here + if self.model_type == "Simple": + pipeline = SimpleClassifierPipeline(lr=self.lr, + epochs=self.epochs, + early_stopping=self.early_stopping, + device=self.device) + pipeline.run() + else: + pipeline = MainClassifierPipeline(lr=self.lr, + epochs=self.epochs, + early_stopping=self.early_stopping, + device=self.device) + pipeline.run() st.success(f"Started training {self.model_type} model with " f"learning rate {self.lr}, epochs {self.epochs}, " f"early stopping: {self.early_stopping}, on {self.device}.") @@ -54,10 +68,25 @@ def render_inference_tab(self) -> None: st.title("Inference") st.write("Please upload an image for inference.") - uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) + file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) - if uploaded_file is not None: - # process and call api - st.image(uploaded_file, 
caption='Uploaded Image.', use_column_width=True) + if file is not None: + st.image(file, caption='Uploaded Image.', use_column_width=True) if st.button("Run Inference"): - pass \ No newline at end of file + files = {"file": (file.name, file.getvalue(), file.type)} + + try: + resp = requests.post("http://127.0.0.1:8000/predict", files=files) + resp.raise_for_status() + result = resp.json() + + st.success("Prediction successful!") + st.write(f"Label: {result['label']}") + st.write(f"Confidence: {result['confidence']}") + st.write(f"Class ID: {result['class_id']}") + except Exception as e: + st.error(f"Error during inference: {str(e)}") + +if __name__ == "__main__": + app = StreamlitApp() + app.render() \ No newline at end of file diff --git a/recaptcha_classifier/server/load_model.py b/recaptcha_classifier/server/load_model.py index c8d62eb58f..e8fa22bf75 100644 --- a/recaptcha_classifier/server/load_model.py +++ b/recaptcha_classifier/server/load_model.py @@ -1,9 +1,6 @@ import torch -from torchvision import transforms -from PIL import Image - +import os from ..models.main_model.model_class import MainCNN -from recaptcha_classifier.detection_labels import DetectionLabels from recaptcha_classifier.constants import ( MODELS_FOLDER, @@ -18,7 +15,7 @@ def load_simple_model(device: torch.device = torch.device("cpu")): """ from ..models.simple_classifier_model import SimpleCNN model = SimpleCNN() - path = MODELS_FOLDER + "/" + SIMPLE_MODEL_FILE_NAME + path = get_model_path("simple") model.load_state_dict(torch.load(path, map_location=device)) model.to(device) model.eval() @@ -28,7 +25,7 @@ def load_main_model(device: torch.device = torch.device("cpu")): """ Load the main CNN model for image classification. 
""" - path = MODELS_FOLDER + "/" + MAIN_MODEL_FILE_NAME + path = get_model_path("main") checkpoint = torch.load(path, map_location=device) config = checkpoint['config'] model = MainCNN( @@ -38,4 +35,14 @@ def load_main_model(device: torch.device = torch.device("cpu")): model.load_state_dict(checkpoint['model_state_dict']) model.to(device) model.eval() - return model \ No newline at end of file + return model + +def get_model_path(model_type: str) -> str: + """ + Get the path to the model file based on the model type. + """ + if model_type == "simple": + return os.path.join(MODELS_FOLDER, SIMPLE_MODEL_FILE_NAME) + elif model_type == "main": + return os.path.join(MODELS_FOLDER, MAIN_MODEL_FILE_NAME) + return None \ No newline at end of file diff --git a/recaptcha_classifier/train/__init__.py b/recaptcha_classifier/train/__init__.py index e69de29bb2..3d6a085185 100644 --- a/recaptcha_classifier/train/__init__.py +++ b/recaptcha_classifier/train/__init__.py @@ -0,0 +1,3 @@ +from .training import Trainer + +__all__ = ['Trainer'] \ No newline at end of file diff --git a/recaptcha_classifier/train/training.py b/recaptcha_classifier/train/training.py index c663c04ecf..83ee75e507 100644 --- a/recaptcha_classifier/train/training.py +++ b/recaptcha_classifier/train/training.py @@ -16,11 +16,12 @@ def __init__(self, val_loader: DataLoader, epochs: int, save_folder: str, - model_file_name='model.pt', - optimizer_file_name='optimizer.pt', - scheduler_file_name='scheduler.pt', + model_file_name: str ='model.pt', + optimizer_file_name: str ='optimizer.pt', + scheduler_file_name: str ='scheduler.pt', device: torch.device |None = None, - early_stop_threshold: int = 5 + early_stop_threshold: int = 5, + early_stopping: bool = True ): """ Constructor for Trainer class. 
@@ -47,7 +48,7 @@ def __init__(self, self._loss_acc_history = [] self.optimizer = None self.scheduler = None - + self._does_early_stop = early_stopping self._early_stop_threshold = early_stop_threshold self._best_val_loss = float('inf') self._stagnation_counter = 0 @@ -126,12 +127,11 @@ def train(self, val_acc = val_accuracy_counter.compute().item() print(f"Epoch {epoch+1} - Val loss: {val_loss:.4f}, Val accuracy: {val_acc:.4f}") - if self._early_stop(val_loss): + if self._does_early_stop and self._early_stop(val_loss): print(f"Early stopping at epoch {epoch + 1}.") print(f"Best validation loss: {self._best_val_loss:.4f}") break - def _train_one_epoch(self, model, train_accuracy_counter, train_loss_counter, train_progress_bar): model.train() for data, targets in train_progress_bar: diff --git a/reports/figures/main-model.png b/reports/figures/main-model.png new file mode 100644 index 0000000000..726336548d Binary files /dev/null and b/reports/figures/main-model.png differ diff --git a/reports/figures/simple-model.png b/reports/figures/simple-model.png new file mode 100644 index 0000000000..55b6774af4 Binary files /dev/null and b/reports/figures/simple-model.png differ diff --git a/reports/main-model.png b/reports/main-model.png deleted file mode 100644 index 4d076e623d..0000000000 Binary files a/reports/main-model.png and /dev/null differ diff --git a/reports/simple-model.png b/reports/simple-model.png deleted file mode 100644 index 8004263166..0000000000 Binary files a/reports/simple-model.png and /dev/null differ diff --git a/tests/data/notyet-test_augment.py b/tests/data/notyet-test_augment.py deleted file mode 100644 index b153bf66cd..0000000000 --- a/tests/data/notyet-test_augment.py +++ /dev/null @@ -1,56 +0,0 @@ -import unittest -import numpy as np -from unittest.mock import patch -from PIL import Image -from recaptcha_classifier.data.augment import ( - AugmentationPipeline, - HorizontalFlip, - RandomRotation -) -from recaptcha_classifier.data.scaler import 
YOLOScaler - - -class TestAugmentation(unittest.TestCase): - def setUp(self): - self.img = Image.fromarray( - np.tile(np.arange(100, dtype=np.uint8).reshape(100, 1), - (1, 100, 3)) - ) - self.img = self.img.convert("RGB") - self.bb = [(0.5, 0.5, 0.2, 0.2)] - - def test_horizontal_flip(self): - aug = HorizontalFlip(p=1.0) - flipped_img, flipped_bb = aug.augment(self.img, self.bb) - - self.assertFalse(np.array_equal(np.array(flipped_img), - np.array(self.img))) - - result = YOLOScaler.scale_for_flip(self.bb) - self.assertEqual(flipped_bb, result) - - @patch('random.uniform', return_value=30) - def test_random_rotation(self, _): - augmenter = RandomRotation(degrees=30) - rotated_img, rotated_bb = augmenter.augment(self.img, self.bb) - - self.assertFalse(np.array_equal(np.array(rotated_img), - np.array(self.img))) - - result = YOLOScaler.scale_for_rotation(self.bb, 30, self.img.size) - for i, j in zip(rotated_bb, result): - self.assertAlmostEqual(i, j, places=4) - - def test_pipeline(self): - pipeline = AugmentationPipeline() - pipeline.add_transform(HorizontalFlip(p=1.0)) - pipeline.add_transform(RandomRotation(degrees=30)) - - new_img, new_bb = pipeline.apply_transforms(self.img, self.bb) - - self.assertIsInstance(new_img, Image.Image) - self.assertIsInstance(new_bb, list) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/data/notyet-test_pair_loader.py b/tests/data/notyet-test_pair_loader.py deleted file mode 100644 index 987cdf3af3..0000000000 --- a/tests/data/notyet-test_pair_loader.py +++ /dev/null @@ -1,83 +0,0 @@ -import unittest -from pathlib import Path -from unittest.mock import patch - -from recaptcha_classifier.data.pair_loader import ImageLabelLoader - - -class TestImageLabelLoader(unittest.TestCase): - @patch("recaptcha_classifier.data.pair_loader.Path.glob") - @patch("recaptcha_classifier.data.pair_loader.Path.exists", - return_value=True) - @patch("recaptcha_classifier.data.pair_loader.Path.is_dir", - return_value=True) - def 
test_load_pairs_with_all_labels(self, - is_dir_mock, - exists_mock, - glob_mock): - glob_mock.return_value = [ - Path("data/images/class1/img1.png"), - Path("data/images/class1/img2.png"), - ] # 2 images found in the png glob - - loader = ImageLabelLoader(["class1"]) - pairs = loader.find_pairs() - - expected_pairs = { - (Path("data/images/class1/img1.png"), - Path("data/labels/class1/img1.txt")), - (Path("data/images/class1/img2.png"), - Path("data/labels/class1/img2.txt")), - } - - self.assertEqual(set(pairs["class1"]), expected_pairs) - self.assertEqual(len(pairs["class1"]), 2) - - @patch("recaptcha_classifier.data.pair_loader.Path.glob") - @patch("recaptcha_classifier.data.pair_loader.Path.exists") - @patch("recaptcha_classifier.data.pair_loader.Path.is_dir", - return_value=True) - def test_load_pairs_with_missing_labels(self, - is_dir_mock, - exists_mock, - glob_mock): - glob_mock.return_value = [ - Path("data/images/class1/img1.png"), - Path("data/images/class1/img2.png"), - ] # 2 images found in the png glob - - # label folder exists, img1.txt exists but img2.txt does not - exists_mock.side_effect = [True, True, False] - - loader = ImageLabelLoader(["class1"]) - pairs = loader.find_pairs() - - self.assertIn("class1", pairs) - expected_pair = { - (Path("data/images/class1/img1.png"), - Path("data/labels/class1/img1.txt")), - } - - self.assertEqual(set(pairs["class1"]), expected_pair) - self.assertEqual(len(pairs["class1"]), 1) - - @patch("recaptcha_classifier.data.pair_loader.Path.glob") - @patch("recaptcha_classifier.data.pair_loader.Path.exists") - @patch("recaptcha_classifier.data.pair_loader.Path.is_dir") - def test_caching(self, is_dir_mock, exists_mock, glob_mock): - loader = ImageLabelLoader(["class1"]) - loader._pairs = {"test": [(Path("test.png"), Path("test.txt"))]} - - # if any mocked method is called, then the cache is not used - for method in (is_dir_mock, exists_mock, glob_mock): - method.side_effect = AssertionError(f"{method} called, error!") 
- - pairs = loader.find_pairs() - - self.assertIs(pairs, loader._pairs) - self.assertEqual(pairs, {"test": [(Path("test.png"), - Path("test.txt"))]}) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/data/notyet-test_scaler.py b/tests/data/notyet-test_scaler.py deleted file mode 100644 index 9a54dd26d4..0000000000 --- a/tests/data/notyet-test_scaler.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest -from recaptcha_classifier.data.scaler import YOLOScaler - - -class TestYOLOScaler(unittest.TestCase): - def test_scale_for_flip(self): - bb = [(0.5, 0.5, 0.2, 0.2), (0.3, 0.4, 0.1, 0.1)] - flipped = YOLOScaler.scale_for_flip(bb) - self.assertEqual(flipped, [(0.5, 0.5, 0.2, 0.2), - (0.7, 0.4, 0.1, 0.1)]) - - def test_scale_for_rotation(self): - bb = [(0.5, 0.5, 0.4, 0.2)] - rot = YOLOScaler.scale_for_rotation(bb, 90, (224, 224)) - self.assertEqual(len(rot), 1) - self.assertTrue(all(0 <= v <= 1 for bb in rot for v in bb)) - - def test_empty_list(self): - self.assertEqual(YOLOScaler.scale_for_flip([]), []) - self.assertEqual(YOLOScaler.scale_for_rotation([], 45, (224, 224)), []) diff --git a/tests/integration/test_checkpoint.py b/tests/integration/test_checkpoint.py new file mode 100644 index 0000000000..91c7294d09 --- /dev/null +++ b/tests/integration/test_checkpoint.py @@ -0,0 +1,48 @@ +import unittest +import torch +import os +from recaptcha_classifier.models.main_model import MainCNN +from recaptcha_classifier.train.training import Trainer +from recaptcha_classifier.data.pipeline import DataPreprocessingPipeline +from recaptcha_classifier.detection_labels import DetectionLabels + + +class TestCheckpointIntegration(unittest.TestCase): + def test_checkpoint_save_load(self): + pipeline = DataPreprocessingPipeline( + DetectionLabels, + batch_size=2, + num_workers=0 + ) + + loaders = pipeline.run() + + model = MainCNN(n_layers=1, kernel_size=3, num_classes=len(DetectionLabels)) + optimizer = torch.optim.Adam(model.parameters()) + scheduler = 
torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + device = torch.device("cpu") + + trainer = Trainer( + train_loader=loaders["train"], + val_loader=loaders["val"], + optimizer=optimizer, + model=model, + scheduler=scheduler, + device=device, + epochs=1, + ) + + trainer.train(model) + + model_state = {k: v.clone() for k, v in model.state_dict().items()} + + model_new = MainCNN(n_layers=1, kernel_size=3, num_classes=len(DetectionLabels)) + + trainer.load_checkpoint_states(model_new) + + for key in model_state: + self.assertTrue( + torch.equal(model_state[key], model_new.state_dict()[key]) + ) + + trainer.delete_checkpoints() \ No newline at end of file diff --git a/tests/integration/test_data_pipeline.py b/tests/integration/test_data_pipeline.py new file mode 100644 index 0000000000..4322638e2f --- /dev/null +++ b/tests/integration/test_data_pipeline.py @@ -0,0 +1,13 @@ +import unittest + +from recaptcha_classifier.data.pipeline import DataPreprocessingPipeline +from recaptcha_classifier.detection_labels import DetectionLabels + + +class TestDataPipeline(unittest.TestCase): + def test_pipeline_runs(self): + pipeline = DataPreprocessingPipeline(DetectionLabels) + loaders = pipeline.run() + + self.assertIn("train", loaders) + self.assertGreater(len(loaders["train"]), 0) \ No newline at end of file diff --git a/tests/features/__init__.py b/tests/unit/data/__init__.py similarity index 100% rename from tests/features/__init__.py rename to tests/unit/data/__init__.py diff --git a/tests/unit/data/test_augment.py b/tests/unit/data/test_augment.py new file mode 100644 index 0000000000..de6eeef074 --- /dev/null +++ b/tests/unit/data/test_augment.py @@ -0,0 +1,29 @@ +import unittest +import numpy as np +from unittest.mock import patch +from PIL import Image +from recaptcha_classifier.data.augment import AugmentationPipeline +from torchvision import transforms + + +class TestAugmentation(unittest.TestCase): + def setUp(self): + self.img = Image.fromarray( + 
np.tile(np.arange(100, dtype=np.uint8).reshape(100, 1), + (1, 100, 3)) + ) + self.img = self.img.convert("RGB") + + def test_pipeline(self): + pipeline = AugmentationPipeline([ + transforms.RandomHorizontalFlip(p=1.0), + transforms.RandomRotation(degrees=30) + ]) + + new_img = pipeline.apply_transforms(self.img) + + self.assertIsInstance(new_img, Image.Image) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/data/notyet-test_dataset.py b/tests/unit/data/test_dataset.py similarity index 50% rename from tests/data/notyet-test_dataset.py rename to tests/unit/data/test_dataset.py index 48ce2e0376..1fd9118067 100644 --- a/tests/data/notyet-test_dataset.py +++ b/tests/unit/data/test_dataset.py @@ -10,57 +10,39 @@ class TestImageDataset(unittest.TestCase): def setUp(self): - self.pairs = [(Path("data/images/c1/i1.png"), - Path("data/labels/c1/i1.txt"))] + self.items = [Path("data/c1/i1.png")] self.class_map = {"c1": 0} self.preprocessor = ImagePrep() - self.augmentator = AugmentationPipeline() + self.augmentator = AugmentationPipeline(transforms_list=[]) @patch.object(ImagePrep, 'load_image') - @patch.object(ImagePrep, 'load_labels') @patch.object(ImagePrep, 'to_tensor') - def test_loading(self, to_tensor_mock, load_labels_mock, load_image_mock): + @patch.object(ImagePrep, 'class_id_to_tensor') + def test_loading(self, class_id_mock, to_tensor_mock, load_image_mock): # Mock the return values load_image_mock.return_value = MagicMock() - load_labels_mock.return_value = [(0.1, 0.2, 0.3, 0.4)] to_tensor_mock.return_value = torch.rand(3, 224, 224) # expect shape + class_id_mock.return_value = torch.tensor(0) dataset = ImageDataset( - pairs=self.pairs, + items=self.items, preprocessor=self.preprocessor, augmentator=self.augmentator, class_map=self.class_map ) - tensor, bboxes, cid = dataset[0] + tensor, cid = dataset[0] self.assertIsInstance(tensor, torch.Tensor) self.assertEqual(tensor.shape, (3, 224, 224)) - self.assertIsInstance(bboxes, list) - 
self.assertEqual(bboxes, [(0.1, 0.2, 0.3, 0.4)]) - self.assertIsInstance(cid, int) - self.assertEqual(cid, 0) + self.assertIsInstance(cid, torch.Tensor) + self.assertEqual(cid.item(), 0) @patch.object(ImagePrep, 'load_image') - @patch.object(ImagePrep, 'load_labels', return_value=[]) - def test_empty_bb(self, _, load_image_mock): - load_image_mock.return_value = MagicMock() - dataset = ImageDataset( - pairs=self.pairs, - preprocessor=self.preprocessor, - augmentator=self.augmentator, - class_map=self.class_map - ) - with self.assertRaises(ValueError): - dataset[0] - - @patch.object(ImagePrep, 'load_image') - @patch.object(ImagePrep, 'load_labels', - return_value=[(0.1, 0.2, 0.3, 0.4)]) - def test_no_class(self, load_labels_mock, load_image_mock): + def test_no_class(self, load_image_mock): load_image_mock.return_value = np.ones((224, 224, 3)) dataset = ImageDataset( - pairs=self.pairs, + items=self.items, preprocessor=self.preprocessor, augmentator=self.augmentator, ) diff --git a/tests/data/notyet-test_loader_factory.py b/tests/unit/data/test_loader_factory.py similarity index 67% rename from tests/data/notyet-test_loader_factory.py rename to tests/unit/data/test_loader_factory.py index cab1187202..18f768b64e 100644 --- a/tests/data/notyet-test_loader_factory.py +++ b/tests/unit/data/test_loader_factory.py @@ -12,7 +12,7 @@ class TestLoaderFactory(unittest.TestCase): def test_create_loaders(self, dataset_mock): class_map = {"class1": 0, "class2": 1} preprocessor = ImagePrep() - augmentator = AugmentationPipeline() + augmentator = AugmentationPipeline(transforms_list=[]) factory = LoaderFactory(class_map, preprocessor, augmentator) dataset_mock.return_value = dataset_mock @@ -20,20 +20,17 @@ def test_create_loaders(self, dataset_mock): splits = { "train": { - "class1": [(Path("img1.png"), Path("label1.txt")), - (Path("img2.png"), Path("label2.txt"))], - "class2": [(Path("img3.png"), Path("label3.txt"))] + "class1": [Path("img1.png"), Path("img2.png")], }, "val": { 
- "class1": [(Path("img4.png"), Path("label4.txt"))], - "class2": [(Path("img5.png"), Path("label5.txt"))] + "class1": + [Path("img3.png")], }, "test": { - "class1": [(Path("img6.png"), Path("label6.txt"))], - "class2": [(Path("img7.png"), Path("label7.txt"))] + "class1": [Path("img4.png"), Path("img5.png"), Path("img6.png")] } } - + loaders = factory.create_loaders(splits) self.assertIn("train", loaders) @@ -45,9 +42,7 @@ def test_create_loaders(self, dataset_mock): # We also need to check that the train set has the augmentator dataset_mock.assert_any_call( - pairs=[(Path("img1.png"), Path("label1.txt")), - (Path("img2.png"), Path("label2.txt")), - (Path("img3.png"), Path("label3.txt"))], + items=[Path("img1.png"),Path("img2.png")], preprocessor=preprocessor, augmentator=augmentator, class_map=class_map diff --git a/tests/unit/data/test_paths_loader.py b/tests/unit/data/test_paths_loader.py new file mode 100644 index 0000000000..3eebc66e82 --- /dev/null +++ b/tests/unit/data/test_paths_loader.py @@ -0,0 +1,35 @@ +import unittest +from pathlib import Path +from unittest.mock import patch + +from recaptcha_classifier.data.paths_loader import ImagePathsLoader + + +class TestImagePathsLoader(unittest.TestCase): + @patch("recaptcha_classifier.data.paths_loader.Path.glob") + @patch("recaptcha_classifier.data.paths_loader.Path.exists", + return_value=True) + @patch("recaptcha_classifier.data.paths_loader.Path.is_dir", + return_value=True) + def test_load_paths(self, + is_dir_mock, + exists_mock, + glob_mock): + glob_mock.return_value = [ + Path("data/images/class1/img1.png"), + Path("data/images/class1/img2.png"), + ] # 2 images found in the png glob + + loader = ImagePathsLoader(["class1"]) + pairs = loader.find_image_paths() + + expected_pairs = { + Path("data/images/class1/img1.png"), + Path("data/images/class1/img2.png"), + } + + self.assertEqual(set(pairs["class1"]), expected_pairs) + self.assertEqual(len(pairs["class1"]), 2) + +if __name__ == "__main__": + 
unittest.main() diff --git a/tests/data/notyet-test_preprocessor.py b/tests/unit/data/test_preprocessor.py similarity index 80% rename from tests/data/notyet-test_preprocessor.py rename to tests/unit/data/test_preprocessor.py index daec7a0f05..9f533df829 100644 --- a/tests/data/notyet-test_preprocessor.py +++ b/tests/unit/data/test_preprocessor.py @@ -24,13 +24,6 @@ def test_load_image(self, open_mock): img_mock.convert.assert_called_once_with("RGB") resized_mock.resize.assert_called_once_with((224, 224), Image.LANCZOS) - @patch("builtins.open", new_callable=unittest.mock.mock_open, - read_data="0 0.2 0.3 0.4 0.5\n") - def test_load_labels(self, file_mock): - result = self.prep.load_labels(Path("test")) - - self.assertEqual(result, [(0.2, 0.3, 0.4, 0.5)]) - def test_to_tensor(self): img = Image.new("RGB", (224, 224)) tensor = self.prep.to_tensor(img) diff --git a/tests/data/notyet-test_splitter.py b/tests/unit/data/test_splitter.py similarity index 94% rename from tests/data/notyet-test_splitter.py rename to tests/unit/data/test_splitter.py index 406b04dc5f..10796cc6d0 100644 --- a/tests/data/notyet-test_splitter.py +++ b/tests/unit/data/test_splitter.py @@ -4,7 +4,7 @@ class TestDataSplitter(unittest.TestCase): def test_default_split(self): - data = {"test_class": [(f"img_{i}", f"label_{i}") for i in range(10)]} + data = {"test_class": [f"img_{i}" for i in range(10)]} splits = DataSplitter().split(data) self.assertEqual(len(splits['train']['test_class']), 7) diff --git a/tests/data/notyet-test_visualizer.py b/tests/unit/data/test_visualizer.py similarity index 79% rename from tests/data/notyet-test_visualizer.py rename to tests/unit/data/test_visualizer.py index 43096a5805..d4f0d9e8e8 100644 --- a/tests/data/notyet-test_visualizer.py +++ b/tests/unit/data/test_visualizer.py @@ -8,16 +8,16 @@ def setUp(self): # only written pairs as string, instead of paths, for simplicity self.sample_splits = { "train": { - "class1": [("img1", "label1"), ("img2", "label2")], - 
"class2": [("img3", "label3"), ("img4", "label4")], + "class1": ["img1", "img2"], + "class2": ["img3", "img4"], }, "val": { - "class1": [("img5", "label5")], - "class2": [("img6", "label6")], + "class1": ["img5"], + "class2": ["img6"], }, "test": { - "class1": [("img7", "label7")], - "class2": [("img8", "label8")], + "class1": ["img7"], + "class2": ["img8"], }, } diff --git a/tests/models/__init__.py b/tests/unit/features/__init__.py similarity index 100% rename from tests/models/__init__.py rename to tests/unit/features/__init__.py diff --git a/tests/features/test_classification_metrics.py b/tests/unit/features/test_classification_metrics.py similarity index 100% rename from tests/features/test_classification_metrics.py rename to tests/unit/features/test_classification_metrics.py diff --git a/tests/features/test_evaluate.py b/tests/unit/features/test_evaluate.py similarity index 100% rename from tests/features/test_evaluate.py rename to tests/unit/features/test_evaluate.py diff --git a/tests/unit/models/__init__.py b/tests/unit/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/test_HPoptimizer.py b/tests/unit/models/test_HPoptimizer.py similarity index 100% rename from tests/models/test_HPoptimizer.py rename to tests/unit/models/test_HPoptimizer.py diff --git a/tests/test_classifier.py b/tests/unit/models/test_classifier.py similarity index 100% rename from tests/test_classifier.py rename to tests/unit/models/test_classifier.py diff --git a/tests/models/test_training.py b/tests/unit/models/test_training.py similarity index 96% rename from tests/models/test_training.py rename to tests/unit/models/test_training.py index 14b59f42f7..f63c389480 100644 --- a/tests/models/test_training.py +++ b/tests/unit/models/test_training.py @@ -47,6 +47,7 @@ def test_train_process(self) -> None: def test_train_load_checkpoint(self): + self.trainer.train(model=self.model, load_checkpoint=False) self.trainer.train(model=self.model, 
load_checkpoint=True) assert_equal(os.path.exists(os.path.join(self.trainer.save_folder, self.trainer.model_file_name)), True, diff --git a/tests/models/utils_training_hpo.py b/tests/unit/models/utils_training_hpo.py similarity index 100% rename from tests/models/utils_training_hpo.py rename to tests/unit/models/utils_training_hpo.py