To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

definitions.py 3.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Stephan Wegner & Vithurjan Visuvalingam
pdz, ETH Zürich
2020

This file contains all definitions you might want to use for running the software code.
"""

import os
import json
import numpy as np
from opti import hyperparam, length

# Paths
# ROOT_DIR is the absolute directory that contains this file; the project
# paths defined in this module are built relative to it.
ROOT_DIR = os.path.split(os.path.abspath(__file__))[0]
CONFIG_PATH = os.path.join(ROOT_DIR, "configuration.conf")
# NOTE(review): presumably the common name prefix of the data set - confirm against callers.
name_base = "EHC_Y_P"

# Choose which operation(s) to start:
#   0 -> the operation is not started
#   1 -> the operation will be started
operation = {
    '2DCNN_Inference': 0,
    'extract_features': 0,
    'create_segments': 0,
    '3DCNN_train_class': 0,
    '3DCNN_predict': 1,
    'post-processing': 1,
}

# Configurations
# Numbers assigned to the training/validation set and to the test set.
# NOTE(review): presumably recording/participant numbers - confirm against the data loader.
train_val_nums = [
    2, 3, 4, 5, 6, 7, 8, 9,
    12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
    24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
]
test_nums = [1, 10, 11, 22, 23]

# Choose: "train" or "test"
mode = 'test'

##############################
### Definitions for 2D CNN ###
##############################

# Choose: "original" or "black"
image_type = "original"

# Class weights, in this order:
#   'w_pen': 1.0, 'w_phone': 0.96, 'w_pillow': 1.3, 'w_smart': 1.6
class_weights = [
    1,
    0.96,
    1.3,
    1.6,
]

# Read the objects of interest (OOIs) for the 2D CNN from the label file;
# 'Hands' is appended after the OOIs.
# The path is assembled with os.path.join so that it also resolves on
# systems that do not use '\\' as path separator.
path_to_labels = os.path.join(ROOT_DIR, 'data', 'raw', 'labels', 'labels_yps.json')
# JSON text is UTF-8; state the encoding explicitly instead of relying on the
# platform default, which would mis-decode non-ASCII label names.
with open(path_to_labels, encoding='utf-8') as json_file:
    data = json.load(json_file)
# One row per (key, value) pair of the label file.
OOI_data = np.array(list(data.items()))
# The first entry of every row is the key of the label file, i.e. the OOI name.
OOIs = [str(entry[0]) for entry in OOI_data]
OOIs.append('Hands')

##############################
### Definitions for 3D CNN ###
##############################

###################
#### For inference

# Model file that is loaded as classification network.
# Built with os.path.join so the path also resolves on systems that do not
# use '\\' as path separator (on Windows the resulting string is unchanged).
path_to_load_classification_network = os.path.join(
    ROOT_DIR,
    'models',
    'ThreeDCNN',
    'classification_NN',
    '2020Dec24_0035_0.81_TCNN_class_acc_0.6784.h5',
)
# NOTE(review): presumably listed in the order of the network's output indices - confirm.
HEC_classes = ['Background', 'Guiding', 'Directing', 'Checking', 'Observing']

#################
#### For training

# During the training, the model is saved every 5 epochs.
# Define the learning rate and the number of epochs for the training.
# The training set is divided into k folds for the training/validation split.
# You can define which fold the training should start with.


# Parameters for training the 3D CNN.
# Both names are imported from `opti` at the top of this file; the
# self-assignments below do not change their values.
# NOTE(review): presumably kept so these settings are listed next to the
# other training definitions - confirm before removing.
hyperparam=hyperparam # hyperparameters for training the 3D CNN
length=length # number of hyperparameter sets

# k-fold cross validation: for every fold the start and end point of the
# validation split are defined; the remaining samples are collected in the
# training set.
k = 5
step = 1 / k  # width of one fold

starts = []  # start of each split, for k=5: 0.01, 0.21, 0.41, 0.61, 0.81
ends = []    # end of each split, for k=5: 0.20, 0.40, 0.60, 0.80, 1.00
for a in range(k):
    # Every split starts 0.01 after the end of the previous one, so two
    # neighbouring splits never share a boundary value.
    start = round(1 / 100 + step * a, 2)
    end = round(start + step - 1 / 100, 2)
    starts.append(start)
    ends.append(end)
runs = len(starts)  # number of folds

# Decide which fold is the start fold (0, 1, 2, 3, 4)
start_fold = 0

# Choose whether undersampling should be applied (True or False)
undersampling = False

##### Restart training
# You can restart the training from one of the saved models.
# Built with os.path.join so the path also resolves on systems that do not
# use '\\' as path separator (on Windows the resulting string is unchanged).
path_to_load_retrain_network = os.path.join(
    ROOT_DIR,
    'models',
    'ThreeDCNN',
    'classification_NN',
    'temp_training',
    'epoch_15.h5',
)

# Set `model` to path_to_load_retrain_network to restart from that saved
# model, or leave it as the string 'None'.
model = 'None'

# Epoch at which the training (re)starts: 0 when `model` is 'None', otherwise
# 15, which matches the epoch_15.h5 file named in path_to_load_retrain_network.
start_epoch = 0 if model == 'None' else 15

# New learning rate and number of epochs for training the model
learning_rate = 1e-3
epochs = 50

###################################
### Definitions for post-processing
###################################

# Define whether the background (BG) should be used in the
# OOI distribution - HEC plot (True or False)
use_bg = False