# Noise Is All You Need

This is the codebase for the **DDPMs** implementation, supplementing the project report *Noise Is All You Need*.

## Installation

Clone this repository, then run:

```bash
pip install -e .
```

to install the `semantic_diffusion` package and all required dependencies.

## Training Data

You can either use our [dataset](https://polybox.ethz.ch/index.php/s/iYOsCjI9xcjhkMI) or create your own training dataset. In the latter case, make sure the data folder has the following structure:

```
data
├───subdir_1
│   ├───training
│   │   ├───images
│   │   │       satimage_1.png
│   │   │       satimage_2.png
│   │   │       ...
│   │   └───groundtruth
│   │           satimage_1.png
│   │           satimage_2.png
│   │           ...
│   └───validation
│       ├───images
│       │       satimage_10.png
│       │       satimage_11.png
│       │       ...
│       └───groundtruth
│               satimage_10.png
│               satimage_11.png
│               ...
└───...
```

The training set will then correspond to all image-groundtruth pairs from the `training` directories of all subdirectories contained in the data folder.
The same holds for the validation set.
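
If you create your own dataset, one quick way to scaffold the expected layout for a single subdirectory (here named `subdir_1`, as in the tree above) is:

```bash
# create training/validation folders with images and groundtruth subfolders
mkdir -p data/subdir_1/{training,validation}/{images,groundtruth}
```

Each image should have a groundtruth mask with the same filename, as shown in the tree above.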

## Training

To train the diffusion model, run

```bash
python scripts/train.py --data_dir path/to/data/ $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS 
```

In the following, we summarize the most relevant supported flags; an example of how to combine them into the flag groups above is shown after the flag lists.

### Model Flags

- ```resolution```: Resolution of the input images. Images whose size does not match are rescaled to this resolution.
- ```use_pretrained_model```: Boolean flag to indicate whether to use the pretrained U-Net++ model or the non-pretrained U-Net model.

The following flags are only relevant when the non-pretrained U-Net model is used, i.e., when ```--use_pretrained_model False``` is set.

- ```num_res_blocks```: Number of residual blocks per resolution level.
- ```num_channels```: Multiplier for the number of channels per resolution level.
- ```attention_resolutions```: A list of downsample resolutions at which an attention layer is inserted.
- ```num_heads```: Number of attention heads for each attention layer.

### Diffusion Flags

- ```diffusion_type```: One of ```[gaussian, bernoulli]```.
- ```diffusion_steps```: Number of diffusion steps.
- ```noise_schedule```: One of ```[linear, cosine]```.
- ```loss_type```: One of ```[vlb, mse, hybrid]``` for the continuous Gaussian diffusion, and one of ```[vlb, nll, hybrid]``` for the discrete Bernoulli diffusion.
- ```hybrid_loss_lambda```: Only relevant if ```--loss_type hybrid``` is set. The loss will then be computed as ```vlb + λ mse``` or ```vlb + λ nll```, respectively.

### Training Flags

- ```lr```: Initial learning rate.
- ```batch_size```: Training batch size.
- ```train_steps```: Number of training steps. If set to `-1`, training will continue until manually aborted.
- ```validation_interval```: Number of training steps after which a validation step is performed.
- ```save_interval```: Number of training steps after which a model checkpoint is saved.
- ```log_interval```: Number of training steps after which the training progress is logged.
- ```wandb_logging```: Set to ```True``` to enable logging to [WANDB](https://wandb.ai). It is necessary to create a `.env` file containing the following keys: ```WANDB_API_KEY```, ```WANDB_ENTITY```, ```WANDB_PROJECT```.
- ```resume_checkpoint```: Provide ```path/to/modelxxxx.pt``` to continue training from a saved checkpoint.
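
As an illustration, the flag groups in the generic command above could be defined as environment variables like this (the values mirror the first experiment listed below; the learning rate is an illustrative choice, not a recommended setting):

```bash
# example flag groups; adjust values to your setup
MODEL_FLAGS="--resolution 128 --use_pretrained_model False --num_channels 64 --num_res_blocks 2"
DIFFUSION_FLAGS="--diffusion_type gaussian --loss_type mse --diffusion_steps 100 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 54 --train_steps -1 --log_interval 100 --save_interval 5000 --validation_interval 5000"

python scripts/train.py --data_dir path/to/data/ $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```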

## Predicting new masks

The above training script saves checkpoints to `.pt` files in the logging directory. These checkpoints can be used to predict new masks for a given set of images like so:

```bash
python scripts/predict.py --test_dir "path/to/images/" --output_dir "./path/to/output_dir/" --model_path "/path/to/checkpoint.pt" $PREDICT_FLAGS
```

### Predict Flags

- ```num_mask_samples```: Number of masks to sample for each input image. The average of the sampled masks will be output.
- ```threshold```: Threshold value in ```[0, 1]``` used to decide whether a pixel is assigned label 0 or 1.
- ```batch_size```: Number of input images to be processed at a time.
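
For example, to sample several masks per image and average them before thresholding, the flag group could be set as follows (the values are illustrative):

```bash
# example prediction flags; adjust to your setup
PREDICT_FLAGS="--num_mask_samples 4 --threshold 0.5 --batch_size 8"

python scripts/predict.py --test_dir "path/to/images/" --output_dir "./path/to/output_dir/" --model_path "/path/to/checkpoint.pt" $PREDICT_FLAGS
```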

## Experiments

In the following, we list the commands used to run the experiments described in the report.

- U-Net model with Gaussian diffusion:
  
  ```bash
  python scripts/train.py --data_dir path/to/data/ --resolution 128 --use_pretrained False --num_channels 64 --num_res_blocks 2 --diffusion_type gaussian --loss_type mse --diffusion_steps 100 --noise_schedule linear --train_steps -1 --batch_size 54 --augmentation True --validation_num_samples 4 --log_interval 100 --save_interval 5000 --validation_interval 5000
  ```

- U-Net model with Bernoulli diffusion:
  
  ```bash
  python scripts/train.py --data_dir path/to/data/ --resolution 128 --use_pretrained False --num_channels 64 --num_res_blocks 2 --diffusion_type bernoulli --loss_type nll --diffusion_steps 100 --noise_schedule linear --train_steps -1 --batch_size 54 --augmentation True --validation_num_samples 4 --log_interval 100 --save_interval 5000 --validation_interval 5000
  ```

- Pretrained U-Net++ model with Bernoulli diffusion:
  
  ```bash
  python scripts/train.py --data_dir path/to/data/ --resolution 416 --use_pretrained True --diffusion_type bernoulli --loss_type nll --diffusion_steps 100 --noise_schedule linear --train_steps -1 --batch_size 12 --augmentation True --validation_num_samples 4 --log_interval 100 --save_interval 5000 --validation_interval 5000
  ```