# Noise Is All You Need
This is the codebase for the **DDPMs** implementation, supplementing the project report *Noise Is All You Need*.
## Installation
Clone this repository, then run:
```bash
pip install -e .
```
to install the `semantic_diffusion` package and all required dependencies.
## Training Data
You can either use our [dataset](https://polybox.ethz.ch/index.php/s/iYOsCjI9xcjhkMI) or create your own training dataset. In the latter case, make sure the data folder has the following structure:
```
data
└───subdir_1
│   └───training
│   │   └───images
│   │   │       satimage_1.png
│   │   │       satimage_2.png
│   │   │       ...
│   │   └───groundtruth
│   │           satimage_1.png
│   │           satimage_2.png
│   │           ...
│   └───validation
│       └───images
│       │       satimage_10.png
│       │       satimage_11.png
│       │       ...
│       └───groundtruth
│               satimage_10.png
│               satimage_11.png
│               ...
└───...
```
The training set then consists of all image-groundtruth pairs found in the `training` directories of all subdirectories contained in the data folder.
The validation set is built analogously from the `validation` directories.
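The pairing rule can be sketched as follows. This is a minimal illustration, not the repository's actual data loader: the function name `collect_pairs` is hypothetical, and it assumes that an image and its groundtruth mask share the same file name and that all images are `.png` files, as in the directory tree of this section.

```python
from pathlib import Path


def collect_pairs(data_dir, split="training"):
    """Collect (image, groundtruth) path pairs for one split ("training" or "validation").

    Hypothetical helper. Assumes <data_dir>/<subdir>/<split>/images/*.png, each with a
    groundtruth mask of the same file name in <data_dir>/<subdir>/<split>/groundtruth/.
    """
    pairs = []
    # Visit every subdirectory of the data folder (subdir_1, subdir_2, ...).
    for subdir in sorted(p for p in Path(data_dir).iterdir() if p.is_dir()):
        image_dir = subdir / split / "images"
        mask_dir = subdir / split / "groundtruth"
        if not image_dir.is_dir():
            continue  # this subdirectory has no data for the requested split
        for image_path in sorted(image_dir.glob("*.png")):
            mask_path = mask_dir / image_path.name  # masks are matched by file name
            if not mask_path.is_file():
                raise FileNotFoundError(f"no groundtruth mask for {image_path}")
            pairs.append((image_path, mask_path))
    return pairs
```

For example, `collect_pairs("path/to/data/", "validation")` would return the validation pairs from all subdirectories. Failing loudly on a missing mask catches misnamed files before training starts.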
## Training
To train the diffusion model, run
```bash
python scripts/train.py --data_dir path/to/data/ $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```
In the following, we provide a summary of the most relevant supported flags:
### Model Flags
- ```resolution```: Resolution of the input image. If the input image size does not match, the image will be rescaled.
- ```use_pretrained_model```: Boolean flag to indicate whether to use the pretrained U-Net++ model or the non-pretrained U-Net model.
The following flags are only relevant in case the non-pretrained U-Net model is used, i.e., when ```--use_pretrained_model False``` is set.
- ```num_res_blocks```: Number of residual blocks per resolution level.
- ```num_channels```: Multiplier for the number of channels per resolution level.
- ```attention_resolutions```: A list of downsample resolutions at which an attention layer is inserted.
- ```num_heads```: Number of attention heads for each attention layer.
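To illustrate how `attention_resolutions` selects the places where attention is inserted, here is a small sketch. It assumes (this README does not confirm it) the convention of OpenAI's improved-diffusion code, where the flag is a comma-separated list of feature-map sizes such as `16,8` and each U-Net level halves the spatial resolution. The helper `attention_levels` and its `num_levels` parameter are hypothetical.

```python
def attention_levels(image_size, attention_resolutions, num_levels):
    """Return the indices of U-Net levels whose feature maps get an attention layer.

    Hypothetical helper. Assumes attention_resolutions is a comma-separated string of
    spatial sizes (e.g. "16,8") and that every level halves the feature-map size.
    """
    targets = {int(r) for r in attention_resolutions.split(",") if r.strip()}
    levels = []
    for level in range(num_levels):
        resolution = image_size // 2 ** level  # feature-map size at this level
        if resolution in targets:
            levels.append(level)
    return levels
```

Under these assumptions, `attention_levels(128, "16,8", 5)` returns `[3, 4]`: with `--resolution 128`, attention would be applied at the 16×16 and 8×8 feature maps only, where it is cheapest.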
### Diffusion Flags
- ```diffusion_type```: One of ```[gaussian, bernoulli]```.
- ```diffusion_steps```: Number of diffusion steps.
- ```noise_schedule```: One of ```[linear, cosine]```.
- ```loss_type```: One of ```[vlb, mse, hybrid]``` for the continuous Gaussian diffusion, and one of ```[vlb, nll, hybrid]``` for the discrete Bernoulli diffusion.
- ```hybrid_loss_lambda```: Only relevant if ```--loss_type hybrid``` is set. The loss will then be computed as ```vlb + λ mse``` or ```vlb + λ nll```, respectively.
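The two `noise_schedule` options and the `hybrid` loss can be sketched as below. This is an assumption-labeled illustration, not code taken from this repository: the linear schedule follows Ho et al. (2020) with endpoints rescaled to the number of steps, and the cosine schedule follows Nichol & Dhariwal (2021), both as implemented in OpenAI's improved-diffusion code. The function names are hypothetical.

```python
import math


def linear_betas(diffusion_steps):
    """Linear schedule; endpoints rescaled so any step count matches the 1000-step original."""
    if diffusion_steps < 2:
        raise ValueError("need at least two diffusion steps")
    scale = 1000 / diffusion_steps
    start, end = scale * 1e-4, scale * 0.02
    return [start + (end - start) * i / (diffusion_steps - 1) for i in range(diffusion_steps)]


def cosine_betas(diffusion_steps, s=0.008, max_beta=0.999):
    """Cosine schedule: beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped at max_beta."""
    def alpha_bar(t):  # cumulative signal level for t in [0, 1]
        return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2

    return [
        min(1 - alpha_bar((i + 1) / diffusion_steps) / alpha_bar(i / diffusion_steps), max_beta)
        for i in range(diffusion_steps)
    ]


def hybrid_loss(vlb, other, hybrid_loss_lambda):
    """vlb + λ·mse (Gaussian) or vlb + λ·nll (Bernoulli), as stated for --loss_type hybrid."""
    return vlb + hybrid_loss_lambda * other
```

With `--diffusion_steps 100`, the linear schedule under this convention runs from β = 0.001 to β = 0.2. The cosine schedule destroys information more gradually early on, and its last β is clipped to 0.999 to avoid a singularity at the final step.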
### Training Flags
- ```lr```: Initial learning rate.
- ```batch_size```: Training batch size.
- ```train_steps```: Number of training steps. If set to `-1`, training continues until it is manually stopped.
- ```validation_interval```: Number of training steps after which a validation step is performed.
- ```save_interval```: Number of training steps after which a model checkpoint is saved.
- ```log_interval```: Number of training steps after which the training progress is logged.
- ```wandb_logging```: Set to ```True``` to enable logging to [WANDB](https://wandb.ai). It is necessary to create a `.env` file containing the following keys: ```WANDB_API_KEY```, ```WANDB_ENTITY```, ```WANDB_PROJECT```.
- ```resume_checkpoint```: Provide ```path/to/modelxxxx.pt``` to continue training from a saved checkpoint.
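The interaction of `train_steps` with the three interval flags can be sketched as a step schedule. This is a hypothetical illustration of the semantics described above, not the repository's training loop: it assumes steps are counted from 1 and that an event fires whenever the step is a multiple of its interval.

```python
import itertools


def training_schedule(train_steps, log_interval, validation_interval, save_interval):
    """Yield (step, events) pairs; train_steps == -1 means run until interrupted.

    Hypothetical sketch: steps are assumed to be 1-based, and an event fires
    whenever the step is a multiple of the corresponding interval.
    """
    steps = itertools.count(1) if train_steps == -1 else range(1, train_steps + 1)
    for step in steps:
        events = []
        if step % log_interval == 0:
            events.append("log")
        if step % validation_interval == 0:
            events.append("validate")
        if step % save_interval == 0:
            events.append("save")
        yield step, events
```

With the settings used in the experiments (`--log_interval 100 --save_interval 5000 --validation_interval 5000`), this schedule would log every 100 steps and validate and save together every 5000 steps; with `--train_steps -1` the generator never ends, mirroring a run that stops only when aborted.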
## Predicting new masks
The above training script saves checkpoints to `.pt` files in the logging directory. These checkpoints can be used to predict new masks for a given set of images like so:
```bash
python scripts/predict.py --test_dir "path/to/images/" --output_dir "./path/to/output_dir/" --model_path "/path/to/checkpoint.pt" $PREDICT_FLAGS
```
### Predict Flags
- ```num_mask_samples```: Number of masks to sample for each input image. The average of the sampled masks is output.
- ```threshold```: Threshold value in ```[0, 1]``` used to binarize the averaged mask, i.e., to assign each pixel label 0 or 1.
- ```batch_size```: Number of input images to be processed at a time.
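The combination of `num_mask_samples` and `threshold` amounts to averaging the sampled masks and binarizing the mean. The sketch below is a hypothetical helper, not the code in `scripts/predict.py`; in particular, whether a pixel exactly at the threshold becomes 1 (`>=`) or 0 (`>`) is an assumption.

```python
import numpy as np


def aggregate_masks(mask_samples, threshold=0.5):
    """Average the sampled masks of one image and binarize the mean.

    Hypothetical helper. mask_samples has shape (num_mask_samples, H, W) with
    values in [0, 1]; pixels whose mean is >= threshold become 1, all others 0.
    """
    mean_mask = np.asarray(mask_samples, dtype=np.float64).mean(axis=0)
    return (mean_mask >= threshold).astype(np.uint8)
```

Averaging several samples acts like a per-pixel vote: a higher `threshold` keeps only pixels on which most sampled masks agree, trading recall for precision.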
## Experiments
In the following, we list the commands used to perform the experiments described in the report.
- U-Net model with Gaussian diffusion:
```bash
python scripts/train.py --data_dir path/to/data/ --resolution 128 --use_pretrained False --num_channels 64 --num_res_blocks 2 --diffusion_type gaussian --loss_type mse --diffusion_steps 100 --noise_schedule linear --train_steps -1 --batch_size 54 --augmentation True --validation_num_samples 4 --log_interval 100 --save_interval 5000 --validation_interval 5000
```
- U-Net model with Bernoulli diffusion:
```bash
python scripts/train.py --data_dir path/to/data/ --resolution 128 --use_pretrained False --num_channels 64 --num_res_blocks 2 --diffusion_type bernoulli --loss_type nll --diffusion_steps 100 --noise_schedule linear --train_steps -1 --batch_size 54 --augmentation True --validation_num_samples 4 --log_interval 100 --save_interval 5000 --validation_interval 5000
```
- Pretrained U-Net++ model with Bernoulli diffusion:
```bash
python scripts/train.py --data_dir path/to/data/ --resolution 416 --use_pretrained True --diffusion_type bernoulli --loss_type nll --diffusion_steps 100 --noise_schedule linear --train_steps -1 --batch_size 12 --augmentation True --validation_num_samples 4 --log_interval 100 --save_interval 5000 --validation_interval 5000
```