This repository provides an extended 3D U-Net framework for segmentation and analysis of volumetric composite microstructures.
The implementation builds upon the original pytorch-3dunet framework by Wolny et al. and significantly extends it for applications in composite materials and textile reinforcements.
Key extensions include:
- Flexible n-dimensional UNet architectures with configurable blocks, normalization, and activations
- Support for multi-class volumetric segmentation
- Integration of orientation-aware supervision for fiber and yarn direction prediction
- Customizable loss functions and training pipelines
- Adaptation to low-resolution industrial CT data
- Efficient processing of large volumetric datasets
The framework is designed for segmentation of textile reinforcement architectures in composite materials, enabling identification of components such as warp yarns, weft yarns, and matrix regions in 3D CT scans. The resulting segmentations can be used for microstructural characterization, multiscale modeling, and simulation-driven materials design.
- Integrate spatial and self-attention mechanisms (e.g., QKV attention) to improve contextual feature aggregation
- Extend the framework for AI-based segmentation and orientation estimation in low-resolution CT scans, using structure tensor analysis on high-resolution datasets and synthetic data as reference for fiber and yarn orientation
- Enable combined segmentation–orientation pipelines for multiscale microstructural characterization
The corresponding publication is available at:
The dataset can be accessed via Zenodo:
This repository is built with PyTorch, a Python-based, GPU-accelerated deep learning library. It leverages the CUDA toolkit for efficient computation on NVIDIA GPUs.
We strongly recommend using an NVIDIA GPU and installing the appropriate CUDA drivers for full functionality and performance.
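A quick way to verify that your PyTorch installation sees the GPU before launching any training (a generic sanity check, assuming PyTorch is installed):

```python
import torch

# verify that PyTorch was built with CUDA support and detects a GPU
print(torch.__version__)
print(torch.cuda.is_available())  # True if an NVIDIA GPU and matching driver are usable
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```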
- Clone the repository and navigate to it in your terminal.
```bash
git clone https://github.com/choROPeNt/3dseg.git
cd 3dseg
```

Then run:

```bash
python -m pip install -e .
```

This installs the `3dseg` Python package via pip into the currently active virtual environment. For instructions on setting one up, please refer to the virtual environment section.
If you are using the High Performance Computing (HPC) cluster of TU Dresden, we recommend using one of the GPU clusters such as Alpha (NVIDIA A100 SXM 40 GB) or Capella (NVIDIA H100). First, allocate some resources, e.g. on Alpha:
```bash
srun -p alpha -N 1 -t 01:00:00 -c 6 --mem=16G --gres=gpu:1 --pty /bin/bash -l
```

You can use the following module setup (adjust as needed for your cluster's module system):
```bash
ml release/24.10 GCC/13.3.0 Python/3.12.3 CUDA/12.8.0 OpenMPI/5.0.3
```

Afterwards, create a new virtual environment in the project directory:
```bash
python -m venv --system-site-packages .venv
```

It is important to set the `--system-site-packages` flag; otherwise you don't have access to the prebuilt PyTorch package (recommended workaround).
Activate the environment via:

```bash
source .venv/bin/activate
```

Model training is initiated using the `train.py` script and a corresponding YAML configuration file:
```bash
python scripts/train.py --config=<path-to-config-yml>
```

The configuration file specifies model architecture, dataset paths, training hyperparameters, logging, and checkpointing options. Example configurations can be found in the `configs` folder. Each config file contains inline comments or is self-explanatory with regard to most parameters, such as batch size, learning rate, data augmentation, loss functions, and optimizer settings.
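For orientation, a training config might look roughly like the sketch below. The key names follow the upstream pytorch-3dunet convention and are assumptions here; consult the files in `configs` for the exact schema used by this fork:

```yaml
# sketch only -- key names follow upstream pytorch-3dunet and may differ in this fork
model:
  name: UNet3D
  in_channels: 1
  out_channels: 3        # e.g. warp yarns, weft yarns, matrix
  f_maps: 32
loss:
  name: CrossEntropyLoss
optimizer:
  learning_rate: 0.0002
trainer:
  checkpoint_dir: checkpoints/3dseg
loaders:
  batch_size: 1
```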
During training, checkpoints are saved periodically, and training metrics are logged for visualization (e.g., via TensorBoard or custom loggers).
To run inference using a trained model, use:
```bash
python scripts/predict.py --config=<path-to-config-yml>
```

This loads the model from the checkpoint defined in the config file and performs prediction on the specified input data.
Please note that a mirror (reflection) padding mode is recommended for better predictions at the volume edges.
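Mirror padding of a volume can be sketched with `torch.nn.functional.pad` (a generic PyTorch call for illustration, not necessarily how this repository applies padding internally):

```python
import torch
import torch.nn.functional as F

vol = torch.rand(1, 1, 8, 8, 8)  # [batch, channel, z, y, x]

# mirror-pad 4 voxels on each side of every spatial axis;
# 'reflect' requires the pad width to be smaller than the padded dimension
padded = F.pad(vol, (4, 4, 4, 4, 4, 4), mode="reflect")
print(padded.shape)  # torch.Size([1, 1, 16, 16, 16])
```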
The model outputs prediction probabilities after the chosen activation function (e.g. sigmoid or softmax) for every channel. Please consider memory allocation and free space on your hard drive: the prediction is saved as a `[c, z, y, x]` array of dtype float32.
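To estimate the required disk space up front, the size of a float32 `[c, z, y, x]` array is simple arithmetic (a hypothetical helper, not a repo utility):

```python
def prediction_size_gib(c: int, z: int, y: int, x: int, dtype_bytes: int = 4) -> float:
    """Size in GiB of a [c, z, y, x] prediction array (float32 by default)."""
    return c * z * y * x * dtype_bytes / 1024**3

# a 2-channel prediction for a 1000^3-voxel scan already needs ~7.45 GiB
print(round(prediction_size_gib(2, 1000, 1000, 1000), 2))  # 7.45
```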
We employ OmniOpt, a hyperparameter optimization framework developed at TU Dresden, to tune model parameters for improved performance. Integration into this project is currently under development, and future releases will include automated optimization workflows using OmniOpt.
Further information can be found in the OmniOpt documentation or from ScaDS.AI.
Currently, the FFT-based 2-point correlation in PyTorch is available. For higher-order descriptors, we kindly refer to MCRpy from the NEFM at TU Dresden.
The FFT-based 2-point correlation function is defined as follows:

$$S_2 = \frac{1}{N}\,(x \ast x) = \frac{1}{N}\,\mathcal{F}^{-1}\!\left[\mathcal{F}(x)\cdot\overline{\mathcal{F}(x)}\right]$$

where

- $x$ is your binary input (microstructure or phase)
- $\ast$ is convolution (autocorrelation)
- $\mathcal{F}$ and $\mathcal{F}^{-1}$ are the FFT and inverse FFT
- $N$ is the total number of elements (for normalization)
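This formula can be sketched in a few lines of PyTorch via the Wiener–Khinchin relation (the function name is hypothetical and not this repository's API; it assumes periodic boundaries, which circular FFT autocorrelation implies):

```python
import torch

def two_point_correlation(x: torch.Tensor) -> torch.Tensor:
    """Periodic FFT-based 2-point correlation S2 of a binary volume x.

    At zero shift, S2 equals the volume fraction of the phase encoded by 1s.
    """
    x = x.to(torch.float32)
    N = x.numel()
    X = torch.fft.fftn(x)
    # autocorrelation = F^-1[ F(x) * conj(F(x)) ], normalized by N
    return torch.fft.ifftn(X * torch.conj(X)).real / N

# half-filled 4^3 volume: S2 at zero shift is the volume fraction
phase = torch.zeros(4, 4, 4)
phase[:2] = 1.0
print(two_point_correlation(phase)[0, 0, 0].item())  # ≈ 0.5
```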
