PyTorch code for diffusion-based medical image segmentation
This repository provides the official PyTorch implementation for "Diffusion Models for Implicit Image Segmentation Ensembles," a novel semantic segmentation method for medical images. It targets researchers and practitioners in medical imaging and computer vision seeking advanced segmentation techniques with built-in uncertainty quantification. The primary benefit is improved segmentation performance through implicit ensembling and detailed pixel-wise uncertainty maps.
How It Works
The method adapts diffusion models to image segmentation by conditioning the diffusion process on the input image. During training, noise is added to the ground-truth segmentation mask and the model learns to remove it, with the input image supplied alongside the noisy mask as a prior. During sampling, the model starts from random noise and the image prior is applied again at every denoising step. Because each run starts from different noise, repeated sampling yields a distribution of segmentation masks. This stochasticity is used to build pixel-wise uncertainty maps and an implicit ensemble of segmentations, which improves overall accuracy.
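The conditioning idea can be sketched with a standard DDPM. This is a minimal sketch, not the repository's code. It assumes a linear noise schedule, a fixed reverse-process variance, and channel-wise concatenation of the image with the noisy mask. It uses NumPy instead of PyTorch. `toy_denoiser`, `training_loss` and `sample_mask` are hypothetical names. The untrained stand-in denoiser takes the place of the real U-Net.

```python
import numpy as np

# ASSUMPTION: standard DDPM settings; the repo configures these via $DIFFUSION_FLAGS.
T = 1000                              # number of diffusion steps (assumed)
betas = np.linspace(1e-4, 0.02, T)    # linear noise schedule (assumed)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)


def toy_denoiser(model_input, t):
    """HYPOTHETICAL stand-in for the trained network.

    Receives image channels + noisy mask channel and should predict the noise
    in the mask channel. This untrained placeholder simply returns zeros.
    """
    return np.zeros_like(model_input[-1:])


def training_loss(image, mask, denoiser, rng):
    """One training step: noise the ground-truth mask, condition on the image."""
    t = rng.integers(0, T)                         # random timestep
    eps = rng.standard_normal(mask.shape)          # true noise
    x_t = np.sqrt(alpha_bars[t]) * mask + np.sqrt(1.0 - alpha_bars[t]) * eps
    model_input = np.concatenate([image, x_t], axis=0)  # image acts as the prior
    eps_hat = denoiser(model_input, t)
    return float(np.mean((eps - eps_hat) ** 2))    # noise-prediction MSE


def sample_mask(image, denoiser, rng):
    """Reverse process: start from noise, re-attach the image at every step."""
    x = rng.standard_normal((1,) + image.shape[1:])  # one mask channel
    for t in reversed(range(T)):
        model_input = np.concatenate([image, x], axis=0)
        eps_hat = denoiser(model_input, t)
        # Posterior mean of x_{t-1} given the predicted noise
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
        if t > 0:
            # ASSUMPTION: fixed variance sigma_t^2 = beta_t (no learned sigma)
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(x.shape)
        else:
            x = mean
    return x  # a raw sample; different rng states give different masks
```

Calling `sample_mask` several times with different random states is what produces the distribution of masks described above. The reverse loop concatenates the image at each step, so every intermediate estimate stays tied to the same input.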
Quick Start & Requirements
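Install dependencies (specific requirements are not detailed in the README):
pip install -r requirements.txt
Train:
python3 scripts/segmentation_train.py --data_dir ./data/training $TRAIN_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS
Sample an ensemble of 5 segmentations per test image:
python scripts/segmentation_sample.py --data_dir ./data/testing --model_path ./results/savedmodel.pt --num_ensemble=5 $MODEL_FLAGS $DIFFUSION_FLAGS
Here $TRAIN_FLAGS, $MODEL_FLAGS and $DIFFUSION_FLAGS are shell variables holding the training, model and diffusion settings.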
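The sampling command draws several masks per test image (`--num_ensemble=5`). This is a minimal sketch of combining such an ensemble into one segmentation and an uncertainty map. It assumes each sample has already been converted to a mask with values in [0, 1]. It takes the pixel-wise mean, thresholds it at 0.5 (an assumed value), and uses the pixel-wise variance as the uncertainty. `aggregate_ensemble` is a hypothetical helper, not a function from the repository.

```python
import numpy as np


def aggregate_ensemble(samples, threshold=0.5):
    """HYPOTHETICAL helper: fuse N sampled masks into a segmentation + uncertainty.

    samples: shape (N, H, W), values assumed in [0, 1]
    (raw diffusion outputs would first need rescaling/clipping).
    """
    samples = np.asarray(samples, dtype=float)
    mean_mask = samples.mean(axis=0)       # implicit ensemble: pixel-wise average
    segmentation = mean_mask > threshold   # final binary mask (threshold assumed)
    uncertainty = samples.var(axis=0)      # disagreement between samples per pixel
    return segmentation, uncertainty


# Five hypothetical 2x2 binary masks, as if produced with --num_ensemble=5
samples = np.array([
    [[1, 0], [1, 1]],
    [[1, 0], [0, 1]],
    [[1, 0], [1, 1]],
    [[1, 1], [0, 1]],
    [[1, 0], [1, 1]],
])
seg, unc = aggregate_ensemble(samples)
print(seg.astype(int))  # [[1 0]
                        #  [1 1]]
print(unc)              # 0.16 and 0.24 at the two disputed pixels, 0 elsewhere
```

The uncertainty map is zero wherever all samples agree. For binary masks it is largest where the samples split evenly, which is where a single deterministic prediction would hide the most doubt.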
Maintenance & Community
Last updated 2 years ago; the repository is marked inactive.