An Android demo application showcasing Federated Learning (FL) on a mobile device using the org.flex:flexible library and LiteRT (TensorFlow Lite) for on-device inference.
The demo trains a small neural network collaboratively across clients without sharing raw data. Each round, the device trains locally on an XOR classification task and sends only the updated model weights back to the FL server for aggregation.
```
FlexDemo/
├── app/                                    # Android application
│   └── src/main/java/com/example/flexdemo/
│       ├── MainActivity.kt                 # Entry point and Compose UI
│       ├── LiteRTClient.kt                 # FL client with TFLite integration
│       ├── client.kt                       # DummyClient (reference) and demo() launcher
│       ├── TrainingEvent.kt                # UI state sealed class
│       ├── TrainingProgressListener.kt     # Bridge between FL callbacks and UI state
│       └── TensorUtils.kt                  # Weight serialization utilities
├── scripts/
│   ├── generate_xor_model.py               # Generates the xor_model.tflite asset
│   ├── python_client.py                    # Python reference client for the same FL server
│   └── README.md                           # Model generation instructions
└── AGENTS.md                               # Agent/CI guide
```
```
FL Server
    │
    │ 1. broadcast global weights
    ▼
LiteRTClient (Android)
    │
    │ 2. setWeights() – receive aggregated weights from server
    │ 3. train()      – run 100 epochs of local backprop on XOR data
    │ 4. evaluate()   – run inference via LiteRT, compute accuracy/loss
    │ 5. getWeights() – serialize and upload updated weights
    │
    └──► FL Server (next aggregation round)
```
The UI reacts to each step via TrainingEvent states emitted through TrainingProgressListener.
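The round structure above can be sketched in plain Kotlin. Note that `FlRoundClient`, `StubClient`, and `runOneRound` are illustrative names invented for this sketch; the real contract comes from `org.flex:flexible` and `LiteRTClient.kt` and may differ.

```kotlin
// Hypothetical per-round client contract mirroring steps 2-5 of the diagram.
// The actual interface is defined by the org.flex:flexible library.
interface FlRoundClient {
    fun setWeights(weights: FloatArray)      // step 2: receive global weights
    fun train()                              // step 3: local training
    fun evaluate(): Pair<Double, Double>     // step 4: (accuracy, loss)
    fun getWeights(): FloatArray             // step 5: updated weights for upload
}

// Minimal stub standing in for LiteRTClient, just to show the round order.
class StubClient : FlRoundClient {
    private var weights = FloatArray(17)     // 17 parameters (see model table below)
    override fun setWeights(weights: FloatArray) { this.weights = weights.copyOf() }
    override fun train() { for (i in weights.indices) weights[i] += 0.01f }
    override fun evaluate() = 1.0 to 0.0
    override fun getWeights() = weights.copyOf()
}

// One aggregation round from the client's point of view.
fun runOneRound(client: FlRoundClient, globalWeights: FloatArray): FloatArray {
    client.setWeights(globalWeights)         // 2. receive aggregated weights
    client.train()                           // 3. local backprop
    client.evaluate()                        // 4. local metrics for the UI
    return client.getWeights()               // 5. send updated weights to server
}
```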
A 2-layer MLP trained on the XOR problem:
| Layer | Shape | Activation |
|---|---|---|
| Input | 2 units | — |
| Hidden | 4 units | ReLU |
| Output | 1 unit | Sigmoid |
Total parameters: 17 (W1: 8, b1: 4, W2: 4, b2: 1)
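As a quick sanity check, the 17-parameter total follows from the usual dense-layer arithmetic (weights = inputs × units, plus one bias per unit):

```kotlin
// Each dense layer contributes (inputs * units) weights + (units) biases.
fun denseParams(inputs: Int, units: Int) = inputs * units + units

val hidden = denseParams(2, 4)   // W1: 2*4 = 8 weights, b1: 4 biases -> 12
val output = denseParams(4, 1)   // W2: 4*1 = 4 weights, b2: 1 bias  -> 5
val total = hidden + output      // 12 + 5 = 17
```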
The model is pre-trained in Python and exported as xor_model.tflite into the app's assets. On-device, LiteRT is used for inference and evaluation, while training uses manual backpropagation (the LiteRT interpreter is inference-only).
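Since the LiteRT interpreter cannot train, the local update has to be hand-written. The following is a simplified, self-contained sketch of what manual backprop for this 2-4-1 architecture looks like (squared-error loss, per-sample SGD); the actual `LiteRTClient.kt` implementation may use a different loss, learning rate, and initialization:

```kotlin
import kotlin.math.exp
import kotlin.random.Random

// Illustrative 2-4-1 MLP (ReLU hidden, sigmoid output) trained by hand.
class XorMlp(seed: Int = 42) {
    private val rng = Random(seed)
    val w1 = Array(4) { FloatArray(2) { rng.nextFloat() - 0.5f } }  // hidden weights
    val b1 = FloatArray(4)                                          // hidden biases
    val w2 = FloatArray(4) { rng.nextFloat() - 0.5f }               // output weights
    var b2 = 0f                                                     // output bias

    private fun sigmoid(x: Float) = 1f / (1f + exp(-x))

    // Forward pass: returns hidden activations and the sigmoid output.
    fun forward(x: FloatArray): Pair<FloatArray, Float> {
        val h = FloatArray(4) { j ->
            maxOf(0f, w1[j][0] * x[0] + w1[j][1] * x[1] + b1[j])    // ReLU
        }
        var z = b2
        for (j in 0 until 4) z += w2[j] * h[j]
        return h to sigmoid(z)
    }

    // One SGD step on a single (x, y) pair with squared-error loss.
    fun step(x: FloatArray, y: Float, lr: Float = 0.1f): Float {
        val (h, p) = forward(x)
        val loss = (p - y) * (p - y)
        val dz = 2f * (p - y) * p * (1f - p)                        // dL/dz via sigmoid'
        for (j in 0 until 4) {
            val dh = if (h[j] > 0f) dz * w2[j] else 0f              // ReLU gradient
            w2[j] -= lr * dz * h[j]
            w1[j][0] -= lr * dh * x[0]
            w1[j][1] -= lr * dh * x[1]
            b1[j] -= lr * dh
        }
        b2 -= lr * dz
        return loss
    }
}

// Train on the four XOR points; returns (first-epoch loss, last-epoch loss).
fun trainXor(epochs: Int = 100): Pair<Float, Float> {
    val xs = arrayOf(floatArrayOf(0f, 0f), floatArrayOf(0f, 1f),
                     floatArrayOf(1f, 0f), floatArrayOf(1f, 1f))
    val ys = floatArrayOf(0f, 1f, 1f, 0f)
    val net = XorMlp()
    var first = 0f; var last = 0f
    for (e in 0 until epochs) {
        var total = 0f
        for (i in xs.indices) total += net.step(xs[i], ys[i])
        if (e == 0) first = total
        last = total
    }
    return first to last
}
```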
Weights are packed as TensorData (shape + raw bytes). TensorUtils.kt converts FloatArray to/from byte arrays using Big-Endian order to match the server's expected format.
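A minimal sketch of the big-endian conversion, using `java.nio.ByteBuffer`. The function names here are illustrative, not the actual `TensorUtils.kt` API:

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Serialize a FloatArray to raw bytes in big-endian order (the server's format).
fun floatsToBytes(values: FloatArray): ByteArray {
    val buf = ByteBuffer.allocate(values.size * 4).order(ByteOrder.BIG_ENDIAN)
    values.forEach { buf.putFloat(it) }
    return buf.array()
}

// Deserialize big-endian bytes back to a FloatArray.
fun bytesToFloats(bytes: ByteArray): FloatArray {
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN)
    return FloatArray(bytes.size / 4) { buf.getFloat() }
}
```

A round trip through both functions returns the original values; for 1.0f the leading byte is 0x3F, confirming big-endian layout.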
- Android SDK, `minSdk = 33`, `compileSdk = 36`
- The `org.flex:flexible` library installed in your local Maven cache (`~/.m2`)
- A running FL server (default: `http://192.168.1.137:8080`)
- Python 3 + TensorFlow (only needed to regenerate the TFLite model)
The model asset is required before building. Run from the project root:
```
pip install tensorflow numpy
python scripts/generate_xor_model.py
```

This trains an XOR MLP for 500 epochs and writes `app/src/main/assets/xor_model.tflite`.
Alternatively, place any compatible TFLite model at that path manually. The client expects:
- Input shape: `[1, 2]`
- Output shape: `[1, 1]`
The server address is hardcoded in `app/src/main/java/com/example/flexdemo/client.kt`:

```kotlin
val config = ClientConfig.simple("http://192.168.1.137:8080")
```

Update this to point to your FL server before building.
```
# Assemble debug APK
./gradlew assembleDebug

# Run unit tests
./gradlew test

# Lint check
./gradlew lint

# Clean
./gradlew clean
```

Install and launch on a device or emulator (API 33+). The app connects to the FL server automatically on startup and displays training progress in real time.
scripts/python_client.py is a Keras-based FL client that connects to the same server. It is useful for testing the server independently of the Android app or for running multi-client FL experiments across platforms.
The single-screen Compose UI reflects the following states:
| `TrainingEvent` | Display |
|---|---|
| `Idle` | "Waiting for training to start..." |
| `Started` | Spinner + "Training Started" |
| `Progress(message, loss?)` | Spinner + round message + optional loss |
| `Completed(metrics)` | Green checkmark + final metrics card |
| `Error(message)` | Red error text + error detail card |
| Dependency | Purpose |
|---|---|
| `org.flex:flexible` | Federated learning client library |
| `com.google.ai.edge:litert` | On-device TFLite inference |
| `androidx.compose.*` | UI (Material 3, BOM 2024.09.00) |
| `androidx.lifecycle:*` | Coroutine lifecycle scope |