QTIM-Lab/Conditional-Diffusion-Editing-basic

🌀 Conditional Diffusion Training & Editing

Train label-conditioned diffusion models and use them to edit images. Custom datasets are supported, such as chest X-rays, fundus images, or any other medical or visual data.

✨ What’s Included

  • A full training loop for conditional diffusion using a Stable Diffusion-style VAE + UNet
  • Modular label encoder for conditioning
  • Dataset format guide and examples
  • Image editing code using reverse+forward diffusion

📁 Folder Structure

conditional-diffusion-toolkit/
├── training/           # Training script
├── editing/            # Inference/editing script
├── datasets/           # Dataset format for image+label
├── models/             # Conditioning encoder
├── utils/              # Schedulers & diffusion helpers
├── outputs/            # Saved weights and generated samples
└── my_data/            # You create this to provide your dataset

🧠 How to Use with Your Own Data

You must provide a dataset returning:

  • image: Tensor [C, H, W]
  • condition: scalar label (float or int)
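
As an illustration of the contract above, here is a minimal sketch of a dataset that yields an image tensor and a scalar condition. The class name `ToyImageConditionDataset` is hypothetical and is not the repository's `ImageConditionDataset`; returning an `(image, condition)` tuple is also an assumption, since the real class may return a dict or a different structure.

```python
import torch
from torch.utils.data import Dataset


class ToyImageConditionDataset(Dataset):
    """Hypothetical stand-in, not the repository's ImageConditionDataset."""

    def __init__(self, images: torch.Tensor, labels: torch.Tensor):
        if images.ndim != 4:
            raise ValueError("images must have shape [N, C, H, W]")
        if labels.ndim != 1 or labels.shape[0] != images.shape[0]:
            raise ValueError("labels must have shape [N], one per image")
        self.images = images
        self.labels = labels.float()  # store conditions as float scalars

    def __len__(self) -> int:
        return self.images.shape[0]

    def __getitem__(self, idx: int):
        # Assumption: each item is an (image, condition) pair.
        return self.images[idx], self.labels[idx]


# Eight random grayscale images with alternating binary labels.
dataset = ToyImageConditionDataset(torch.rand(8, 1, 64, 64), torch.arange(8) % 2)
image, condition = dataset[3]
```

Validating shapes in the constructor surfaces a mismatch between images and labels before training starts, rather than partway through an epoch.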

✅ Step 1: Format your data

# my_data/dataset.py
import torch
from datasets.image_condition_dataset import ImageConditionDataset

def get_dataset():
    images = torch.load("path/to/images.pt")        # [N, C, H, W]
    labels = torch.load("path/to/labels.pt")        # [N]
    return ImageConditionDataset(images, labels)

Or load JPG files and a CSV of labels using pandas, Pillow, and torchvision:

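import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from datasets.image_condition_dataset import ImageConditionDataset

def get_dataset():
    # Resize every image to a fixed size and convert it to a [C, H, W] tensor in [0, 1]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # labels.csv needs a "filename" column and a "label" column
    df = pd.read_csv("labels.csv")
    imgs, labs = [], []
    for _, row in df.iterrows():
        # "L" loads grayscale (1 channel); use "RGB" for 3-channel color images
        img = Image.open(f"images/{row['filename']}").convert("L")
        imgs.append(transform(img))
        labs.append(float(row['label']))
    return ImageConditionDataset(torch.stack(imgs), torch.tensor(labs))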
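
The CSV loader above reads two columns, `filename` and `label`. The sketch below shows one possible layout for `labels.csv` and a small check that fails early if a column is missing. The file names and label values are made up for illustration; only the two column names come from the loader.

```python
import csv
import io

# Hypothetical file contents; only the column names come from the loader above.
SAMPLE_LABELS_CSV = """filename,label
patient_001.jpg,0
patient_002.jpg,1
patient_003.jpg,1
"""


def read_labels(csv_text: str) -> list[tuple[str, float]]:
    """Parse rows into (filename, label) pairs, failing early on bad input."""
    reader = csv.DictReader(io.StringIO(csv_text))
    missing = {"filename", "label"} - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"labels.csv is missing columns: {sorted(missing)}")
    return [(row["filename"], float(row["label"])) for row in reader]


pairs = read_labels(SAMPLE_LABELS_CSV)
```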

🏋️‍♀️ Train the Model

python training/train.py --dataset my_data.dataset --batch_size 16

Model weights and generated samples are saved in outputs/.
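
The `--dataset my_data.dataset` argument looks like a dotted Python module path. One plausible way such a path could be resolved is sketched below; this is an assumption about how `training/train.py` behaves, not a copy of it, and `load_dataset` is a hypothetical helper name.

```python
import importlib


def load_dataset(module_path: str):
    """Import a dotted module path and call its get_dataset() function."""
    module = importlib.import_module(module_path)
    if not hasattr(module, "get_dataset"):
        raise AttributeError(f"{module_path} does not define get_dataset()")
    return module.get_dataset()
```

If the script works this way, `my_data/` must be importable from the directory where the command is run, which usually means launching it from the repository root.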


🎨 Edit Images with New Conditions

python editing/edit_image.py \
    --image_path path/to/sample.pt \
    --original_label 0

This inverts the image to noise space and re-generates under new conditions.

Results are saved to edited_outputs/.
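
The invert-then-regenerate idea can be sketched with deterministic DDIM-style updates. This is a toy under stated assumptions, not the code in `editing/edit_image.py`: `toy_noise_predictor` is a hypothetical stand-in for the trained conditional UNet, the noise schedule is an arbitrary linear one, and the VAE encode/decode stage is omitted.

```python
import torch


def toy_noise_predictor(x: torch.Tensor, t: int, label: float) -> torch.Tensor:
    # Hypothetical stand-in for a trained conditional UNet: the predicted
    # noise depends only on the label, which keeps the toy exactly invertible.
    return torch.full_like(x, 0.1 * (label + 1.0))


def ddim_step(x, t_from, t_to, label, alphas_cumprod):
    """Deterministic DDIM update from timestep t_from to t_to (either direction)."""
    a_from, a_to = alphas_cumprod[t_from], alphas_cumprod[t_to]
    eps = toy_noise_predictor(x, t_from, label)
    # Estimate the clean image, then re-noise it to the target timestep.
    x0_pred = (x - (1 - a_from).sqrt() * eps) / a_from.sqrt()
    return a_to.sqrt() * x0_pred + (1 - a_to).sqrt() * eps


def edit(image, original_label, new_label, num_steps=50):
    # alphas_cumprod[0] = 1 is the clean image; values shrink as noise grows.
    alphas_cumprod = torch.linspace(1.0, 0.05, num_steps + 1, dtype=image.dtype)
    x = image
    for t in range(num_steps):          # invert: image -> noise, original label
        x = ddim_step(x, t, t + 1, original_label, alphas_cumprod)
    for t in range(num_steps, 0, -1):   # regenerate: noise -> image, new label
        x = ddim_step(x, t, t - 1, new_label, alphas_cumprod)
    return x


image = torch.rand(1, 8, 8, dtype=torch.float64)
same = edit(image, original_label=0.0, new_label=0.0)     # round trip
changed = edit(image, original_label=0.0, new_label=1.0)  # edited result
```

With this toy predictor, regenerating under the original label reproduces the input up to floating-point error, while a different label shifts the result. A real UNet's prediction also depends on the image and timestep, so inversion with a trained model is only approximate.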


💡 Examples

✅ Binary classification (e.g. “tumor: yes/no”)
✅ Severity scores (e.g. “DR Grade 0–4”)
✅ Digit class (0–9) or multi-class labels

You define the meaning of the conditioning signal!


📦 Installation

pip install -r requirements.txt

✅ Requirements

  • torch >= 1.10
  • torchvision
  • diffusers
  • tqdm
  • pandas
  • Pillow

Made with ❤️ to help others learn diffusion through practical code.
