Train and edit images using label-conditioned diffusion models, with support for custom datasets such as chest X-rays, fundus images, or any other medical or natural image data.
- A full training loop for conditional diffusion using a Stable Diffusion-style VAE + UNet
- Modular label encoder for conditioning
- Dataset format guide and examples
- Image-editing code that inverts an image to noise and regenerates it under a new label
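The label encoder turns a scalar condition into something the UNet can consume. In diffusers, UNet2DConditionModel takes conditioning through its encoder_hidden_states argument, shaped [batch, sequence_length, cross_attention_dim] (768 for Stable Diffusion v1). A minimal sketch, assuming the encoder is a small MLP that emits a single such token; ScalarLabelEncoder and its dimensions are hypothetical, not the encoder actually shipped in models/.

```python
import torch
import torch.nn as nn


class ScalarLabelEncoder(nn.Module):
    """Hypothetical encoder: scalar label -> one [B, 1, embed_dim] conditioning token."""

    def __init__(self, embed_dim: int = 768, hidden_dim: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        # labels: [B] ints or floats -> float column vector [B, 1]
        x = labels.float().view(-1, 1)
        # Project to embed_dim and add a sequence axis -> [B, 1, embed_dim]
        return self.mlp(x).unsqueeze(1)


encoder = ScalarLabelEncoder(embed_dim=768)
tokens = encoder(torch.tensor([0, 1, 4]))
print(tokens.shape)  # torch.Size([3, 1, 768])
```

A single token keeps the cross-attention path unchanged; the output could be passed as encoder_hidden_states in place of text embeddings.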
```
conditional-diffusion-toolkit/
├── training/   # Training script
├── editing/    # Inference/editing script
├── datasets/   # Dataset format for image + label
├── models/     # Conditioning encoder
├── utils/      # Schedulers & diffusion helpers
├── outputs/    # Saved weights and generated samples
└── my_data/    # You create this to hold your dataset
```
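The helpers in utils/ are not shown here. As a hedged sketch of what a diffusion scheduler helper typically computes, this uses the DDPM linear β schedule (1e-4 to 0.02 over 1000 steps) and the closed-form forward noising step x_t = sqrt(ā_t)·x_0 + sqrt(1 − ā_t)·ε. The function names are hypothetical, and Stable Diffusion itself uses a different ("scaled linear") schedule; diffusers' DDPMScheduler.add_noise implements the same forward step.

```python
import torch


def linear_alphas_cumprod(num_steps: int = 1000, beta_start: float = 1e-4,
                          beta_end: float = 0.02) -> torch.Tensor:
    """a_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, num_steps)
    return torch.cumprod(1.0 - betas, dim=0)


def add_noise(x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor,
              alphas_cumprod: torch.Tensor) -> torch.Tensor:
    """Forward process q(x_t | x_0): sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
    # Reshape a_bar to [B, 1, 1, ...] so it broadcasts over [C, H, W]
    a_bar = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise


alphas_cumprod = linear_alphas_cumprod()
x0 = torch.randn(2, 4, 32, 32)   # e.g. a batch of VAE latents
noise = torch.randn_like(x0)
t = torch.tensor([0, 999])       # one timestep per sample
xt = add_noise(x0, noise, t, alphas_cumprod)
print(xt.shape)  # torch.Size([2, 4, 32, 32])
```

At t = 0 the sample is almost unchanged; at t = 999, ā is about 4e-5, so x_t is nearly pure noise.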
You must provide a dataset whose items contain:
- image: Tensor [C, H, W]
- condition: scalar label (float or int)
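Any torch.utils.data.Dataset that yields an image tensor and a scalar label should fit this contract. A minimal sketch, assuming items are returned as an (image, condition) tuple; check datasets/image_condition_dataset.py for the exact format the training script expects. TensorConditionDataset is a hypothetical name, not part of the toolkit.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TensorConditionDataset(Dataset):
    """Hypothetical stand-in: pairs each [C, H, W] image with a scalar label."""

    def __init__(self, images: torch.Tensor, labels: torch.Tensor):
        assert images.dim() == 4, "images must be [N, C, H, W]"
        assert len(images) == len(labels), "need exactly one label per image"
        self.images = images
        self.labels = labels.float()  # store labels as float scalars

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        # image: [C, H, W]; condition: 0-dim tensor
        return self.images[idx], self.labels[idx]


ds = TensorConditionDataset(torch.rand(8, 1, 64, 64), torch.randint(0, 5, (8,)))
image, condition = ds[0]
batch_images, batch_labels = next(iter(DataLoader(ds, batch_size=4)))
print(image.shape, condition.shape, batch_images.shape)
# torch.Size([1, 64, 64]) torch.Size([]) torch.Size([4, 1, 64, 64])
```

The default collate function stacks the 0-dim labels into a [batch_size] tensor, which is the shape a label encoder would receive.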
```python
# my_data/dataset.py
import torch
from datasets.image_condition_dataset import ImageConditionDataset


def get_dataset():
    images = torch.load("path/to/images.pt")  # [N, C, H, W]
    labels = torch.load("path/to/labels.pt")  # [N]
    return ImageConditionDataset(images, labels)
```

Or load images from JPG files and labels from a CSV using torchvision:
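```python
# my_data/dataset.py (alternative: JPG images + labels.csv)
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

from datasets.image_condition_dataset import ImageConditionDataset


def get_dataset():
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # PIL image -> float tensor [C, H, W] in [0, 1]
    ])
    # labels.csv is expected to have "filename" and "label" columns
    df = pd.read_csv("labels.csv")
    imgs, labs = [], []
    for _, row in df.iterrows():
        # "L" loads grayscale (1 channel); use "RGB" for 3-channel color images
        img = Image.open(f"images/{row['filename']}").convert("L")
        imgs.append(transform(img))
        labs.append(float(row["label"]))
    return ImageConditionDataset(torch.stack(imgs), torch.tensor(labs))
```

Train on your dataset:

```bash
python training/train.py --dataset my_data.dataset --batch_size 16
```

Your weights and generated samples will be saved in outputs/.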
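For orientation, one conditional-diffusion training step usually goes like this: noise the latents at a random timestep, predict that noise given the label, and regress with MSE. This is a hedged sketch, not training/train.py. It assumes the ε-prediction objective and a linear β schedule, and it replaces the VAE + UNet with a tiny hypothetical ToyEpsModel so it runs on CPU. In a Stable Diffusion-style setup the latents would usually come from vae.encode(x).latent_dist.sample() * vae.config.scaling_factor (diffusers AutoencoderKL).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, T), dim=0)


class ToyEpsModel(nn.Module):
    """Hypothetical stand-in for the UNet: predicts noise from (x_t, t, label)."""

    def __init__(self, channels: int = 4, embed_dim: int = 32):
        super().__init__()
        self.label_mlp = nn.Sequential(
            nn.Linear(1, embed_dim), nn.SiLU(), nn.Linear(embed_dim, channels))
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x_t, t, labels):
        # Per-channel bias from the label plus the normalized timestep: [B, C]
        cond = self.label_mlp(labels.float().view(-1, 1)) + (t.float() / T).view(-1, 1)
        return self.conv(x_t + cond[:, :, None, None])


def training_step(model, optimizer, latents, labels):
    noise = torch.randn_like(latents)
    t = torch.randint(0, T, (latents.shape[0],))            # random timestep per sample
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a_bar.sqrt() * latents + (1 - a_bar).sqrt() * noise  # forward noising
    loss = F.mse_loss(model(x_t, t, labels), noise)            # epsilon-prediction loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToyEpsModel()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
latents = torch.randn(16, 4, 8, 8)   # stand-in for VAE latents (batch_size 16)
labels = torch.randint(0, 2, (16,))  # e.g. binary "tumor: yes/no"
print(training_step(model, opt, latents, labels))
```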
To edit an image:

```bash
python editing/edit_image.py \
  --image_path path/to/sample.pt \
  --original_label 0
```

This inverts the image into noise space and regenerates it under a new condition. Results are saved to edited_outputs/.
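The README does not say which inversion the editor uses. As a hedged sketch, assuming deterministic DDIM (η = 0) inversion and a linear β schedule, "invert, then regenerate under a new label" can look like this. edit, ddim_step, and toy_model are hypothetical names, and toy_model is a placeholder for the label-conditioned UNet so the example runs on its own.

```python
import torch

T = 1000
# Linear beta schedule (assumed); a_bar_t = prod_{s<=t} (1 - beta_s)
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, T), dim=0)


def ddim_step(x, eps, a_from, a_to):
    """Deterministic DDIM (eta = 0) move from noise level a_from to a_to."""
    x0_pred = (x - (1 - a_from).sqrt() * eps) / a_from.sqrt()
    return a_to.sqrt() * x0_pred + (1 - a_to).sqrt() * eps


@torch.no_grad()
def edit(model, image, original_label, target_label, num_steps=50):
    timesteps = torch.linspace(0, T - 1, num_steps).long()
    a_bars = alphas_cumprod[timesteps]  # the clean image is treated as level a_bars[0]

    # 1) Inversion: walk up to noise under the ORIGINAL label,
    #    using the noise prediction at the current (lower) level
    x = image
    for i in range(num_steps - 1):
        eps = model(x, timesteps[i], original_label)
        x = ddim_step(x, eps, a_bars[i], a_bars[i + 1])

    # 2) Regeneration: walk back down under the TARGET label
    for i in reversed(range(num_steps - 1)):
        eps = model(x, timesteps[i + 1], target_label)
        x = ddim_step(x, eps, a_bars[i + 1], a_bars[i])
    return x


def toy_model(x, t, label):
    # Hypothetical stand-in for the conditional UNet: noise guess depends only on the label
    return 0.1 * float(label) * torch.ones_like(x)


image = torch.rand(1, 4, 8, 8)  # e.g. a VAE latent
same = edit(toy_model, image, original_label=0, target_label=0)
edited = edit(toy_model, image, original_label=0, target_label=1)
print(torch.allclose(same, image, atol=1e-4))  # True
```

With the same label in both passes, the toy model reconstructs the input, so only the label change moves the result. With a trained UNet the reconstruction is approximate, because the noise prediction differs between adjacent timesteps.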
Example conditioning signals:
✅ Binary classification (e.g. “tumor: yes/no”)
✅ Severity scores (e.g. “DR Grade 0–4”)
✅ Digit class (0–9) or multi-class labels
You define the meaning of the conditioning signal!
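Because the condition is a single scalar, each kind of label needs a fixed mapping to a number before it reaches the dataset. A hedged sketch with hypothetical helper names; rescaling grades to [0, 1] is one design choice, not something the toolkit requires.

```python
from typing import Sequence


def encode_binary(value: str) -> float:
    """Map a yes/no style label (e.g. "tumor: yes/no") to 1.0 / 0.0."""
    return 1.0 if value.strip().lower() in {"yes", "y", "true", "1", "positive"} else 0.0


def encode_grade(grade: int, max_grade: int = 4) -> float:
    """Map an ordinal grade (e.g. DR Grade 0-4) to [0, 1], preserving the ordering."""
    if not 0 <= grade <= max_grade:
        raise ValueError(f"grade {grade} outside 0..{max_grade}")
    return grade / max_grade


def encode_class(name: str, classes: Sequence[str]) -> int:
    """Map a class name to its integer index (e.g. digits "0" to "9")."""
    return list(classes).index(name)


digits = [str(d) for d in range(10)]
print(encode_binary("Yes"), encode_grade(3), encode_class("7", digits))  # 1.0 0.75 7
```

Keep the mapping fixed between training and editing, so a label passed to the editor means the same thing it meant during training.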
Install the dependencies:

```bash
pip install -r requirements.txt
```

Requirements:
- torch >= 1.10
- torchvision
- diffusers
- tqdm
- pandas
- Pillow
Made with ❤️ to help others learn diffusion through practical code.