Unsupervised Domain Adaptation (MNIST -> USPS) using a 'From Scratch' implementation of the Sinkhorn Optimal Transport Algorithm.

Faycal214/Optimal-Transport-Domain-Adaptation

🌉 Unsupervised Domain Adaptation via Optimal Transport

The Question: Imagine you train a self-driving car in sunny California. What happens when you deploy it in snowy Norway? Most likely it fails, because the data it now sees no longer looks like the data it was trained on.

This is the Domain Shift problem. This repository implements a mathematical solution to transfer knowledge from a Source Domain (MNIST) to a Target Domain (USPS) without seeing a single label from the target dataset.

It features a "From Scratch" implementation (no POT library) of the Sinkhorn-Knopp Algorithm to solve the Entropic Optimal Transport problem.


🎬 Visualizing the Math

Before diving into code, let's see Optimal Transport in action. The animation below shows the Geodesic Interpolation between the Source distribution (Red) and the Target distribution (Blue).

Optimal Transport Interpolation

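The algorithm moves the source features through the latent space, following the transport plan, until they align with the target geometry. The plan is chosen to minimize the total transport cost, which is an entropy-regularized approximation of the Wasserstein distance.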
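To make the interpolation idea concrete, here is a minimal sketch of how intermediate frames could be computed once a transport plan is available. It uses a barycentric projection (each source point moves in a straight line toward the plan-weighted average of the target points), which is a common simplification and not necessarily how the animation in this repository is produced. The function names and the toy plan `P` are hypothetical.

```python
import numpy as np

def barycentric_targets(P, X_target):
    """Map each source point to the plan-weighted average of the target points."""
    row_mass = P.sum(axis=1, keepdims=True)  # mass leaving each source point
    return (P @ X_target) / row_mass

def interpolate(X_source, X_target, P, t):
    """Positions of the source points at time t in [0, 1]."""
    return (1.0 - t) * X_source + t * barycentric_targets(P, X_target)

# Toy data (assumption): two source points, two target points,
# and a hand-written plan that sends source i to target i.
X_source = np.array([[0.0, 0.0], [1.0, 0.0]])
X_target = np.array([[0.0, 2.0], [1.0, 2.0]])
P = np.array([[0.5, 0.0],
              [0.0, 0.5]])

# Five frames from t = 0 (source) to t = 1 (aligned with target)
frames = [interpolate(X_source, X_target, P, t) for t in np.linspace(0.0, 1.0, 5)]
```

At `t = 0` the points sit on the source, at `t = 1` they sit on their barycentric targets, and the frames in between trace straight-line paths.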


🧠 The Theory: What is Optimal Transport?

In simple terms, Optimal Transport (OT) answers the question: "What is the most efficient way to move a pile of sand (distribution A) to fill a hole (distribution B)?"

In Machine Learning, we use it to align two datasets that contain similar information but look different (e.g., different camera angles, lighting, or writing styles).

The Math Behind the Code

Standard OT is computationally expensive: exact solvers scale roughly as $O(n^3 \log n)$ in the number of samples $n$. To make it usable for Deep Learning, we use Entropic Regularization. This turns the problem into a strictly convex optimization that can be solved with simple matrix scaling iterations, each costing only $O(n^2)$.

I implemented the Sinkhorn Algorithm manually to solve:

$$P^* = \text{argmin}_{P \in U(r, c)} \langle P, C \rangle - \epsilon H(P)$$

Where:

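  • $P$: The Transport Plan. $P_{ij}$ is the amount of mass moved from source sample $i$ to target sample $j$.
  • $U(r, c)$: The set of valid plans, i.e. non-negative matrices whose rows sum to $r$ (source weights) and whose columns sum to $c$ (target weights).
  • $C$: The Cost Matrix. $C_{ij}$ is the distance between source feature $i$ and target feature $j$.
  • $H(P)$: The Entropy of the plan, $H(P) = -\sum_{i,j} P_{ij} \log P_{ij}$. It "blurs" the plan and makes the problem strictly convex.
  • $\epsilon$: The regularization coefficient. Larger values give a smoother plan and faster convergence; smaller values bring the solution closer to unregularized OT.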
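For intuition, the minimizer of the objective above can be written as $P^* = \mathrm{diag}(u)\, K\, \mathrm{diag}(v)$ with $K = e^{-C/\epsilon}$, and the Sinkhorn iterations alternately rescale $u$ and $v$ until the row and column sums match $r$ and $c$. Below is a minimal NumPy sketch of that loop. It is an illustration under assumptions, not the code from this repository: the function names, the squared Euclidean cost, the value of `eps`, and the stopping rule are all choices made for the example.

```python
import numpy as np

def squared_euclidean_cost(X_s, X_t):
    """Cost matrix C[i, j] = ||X_s[i] - X_t[j]||^2 (an assumed choice of distance)."""
    diff = X_s[:, None, :] - X_t[None, :, :]
    return (diff ** 2).sum(axis=-1)

def sinkhorn(C, r, c, eps=0.1, n_iters=500, tol=1e-9):
    """Entropic OT plan between weights r (rows) and c (columns)."""
    K = np.exp(-C / eps)          # Gibbs kernel
    u = np.ones_like(r)
    v = np.ones_like(c)
    for _ in range(n_iters):
        u_prev = u
        u = r / (K @ v)           # rescale rows to match r
        v = c / (K.T @ u)         # rescale columns to match c
        if np.max(np.abs(u - u_prev)) < tol:
            break
    return u[:, None] * K * v[None, :]

# Toy features standing in for source/target embeddings (assumption)
rng = np.random.default_rng(0)
X_s = rng.normal(size=(5, 3))
X_t = rng.normal(size=(4, 3)) + 1.0

C = squared_euclidean_cost(X_s, X_t)
C = C / C.max()                   # normalize so eps has a comparable scale
r = np.full(5, 1 / 5)             # uniform source weights
c = np.full(4, 1 / 4)             # uniform target weights
P = sinkhorn(C, r, c, eps=0.5)
```

For very small `eps` the kernel `K` underflows and the divisions become unstable; a log-domain variant of the same updates is the usual remedy.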

🚀 Results & Performance

We transfer knowledge from MNIST (Source) to USPS (Target).

  • Source: Model trained only on MNIST labels.
  • Target: Model evaluated on USPS (Target labels are never used during adaptation).

| Method | Accuracy on USPS |
| --- | --- |
| Random Guessing | 10.0% |
| Naive Model (Source Only) | ~35.0% |
| Our Approach (Sinkhorn OT) | 87.20% 🚀 |

By simply aligning the geometry of the two feature distributions, accuracy on USPS rises from roughly 35% to 87.2% without using a single target label.


📊 Analysis & Interpretation

1. Geometric Matching (The "Transport Plan")

The algorithm automatically finds which source digits correspond to which target digits based on their topology.

Feature Matching

  • Observation: The model correctly maps a sharp MNIST "5" to a blurry USPS "5".
  • Insight: It relies on geometric features (curves, loops) rather than pixel values.
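One way to turn a transport plan into predictions is to let each target sample inherit the label that sends it the most mass. The sketch below shows that idea on a hand-written 3x3 plan. It is an assumption about how labels could be propagated, not a description of what `full_demo.py` does, and the function names and toy values are hypothetical.

```python
import numpy as np

def propagate_labels(P, y_source, n_classes):
    """Give each target sample the class that sends it the most mass."""
    onehot = np.eye(n_classes)[y_source]   # (n_source, n_classes)
    class_mass = P.T @ onehot              # (n_target, n_classes)
    return class_mass.argmax(axis=1)

def best_source_match(P):
    """Index of the source sample most strongly coupled to each target sample."""
    return P.argmax(axis=0)

# Toy plan (assumption): rows are source samples, columns are target samples
P = np.array([
    [0.30, 0.03, 0.00],
    [0.02, 0.00, 0.32],
    [0.01, 0.30, 0.02],
])
y_source = np.array([5, 7, 4])   # source labels; target labels are never used

y_pred = propagate_labels(P, y_source, n_classes=10)
matches = best_source_match(P)
```

Here the first target sample receives most of its mass from the source "5", so it is labeled 5; `matches` gives the source/target pairs one would display side by side.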

2. Error Analysis (Confusion Matrix)

Where does the geometry fail?

Confusion Matrix

  • The "4 vs 7" Problem: Look at the bottom right. The model often confuses 4 with 7.
  • Why? In the USPS dataset, handwriting for '4' often has an open loop or a straight tail, making it topologically almost identical to a '7' in the feature space. This highlights a limitation of purely geometric unsupervised adaptation.
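For reference, a confusion matrix like the one discussed above can be computed in a few lines, and the most frequent off-diagonal entry points directly at the hardest pair of classes. This is a minimal sketch with made-up labels; the function names are hypothetical, and target labels appear here only for evaluation, never for adaptation.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes=10):
    """cm[i, j] counts samples of true class i predicted as class j."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)   # handles repeated (i, j) pairs correctly
    return cm

def most_confused_pair(cm):
    """Return (true_class, predicted_class, count) of the largest off-diagonal entry."""
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)
    i, j = np.unravel_index(off_diag.argmax(), off_diag.shape)
    return int(i), int(j), int(off_diag[i, j])

# Made-up evaluation labels (assumption), chosen to mimic a 4-vs-7 confusion
y_true = np.array([4, 4, 4, 7, 7, 5])
y_pred = np.array([7, 7, 4, 7, 4, 5])

cm = confusion_matrix(y_true, y_pred)
worst = most_confused_pair(cm)
```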

💻 How to Run

1. Installation

Clone the repo and install the dependencies (lightweight, mainly PyTorch and NumPy).

git clone https://github.com/YourUsername/Optimal-Transport-Domain-Adaptation.git
cd Optimal-Transport-Domain-Adaptation
pip install -r requirements.txt

2. Train Source Model

First, we need a model that understands digits on the Source domain (MNIST).

python src/train_mnist.py

3. Run Domain Adaptation

This script extracts features, computes the Sinkhorn transport plan, predicts labels, and generates the GIF/Plots.

python full_demo.py
