import logging
import os
import sys
import time

import h5py
import numpy as np
import scipy
from scipy import io
from scipy.io import savemat

from config import config
from ensemble import run
from kerasTuner import tune
from SimpleClassifiers.sklearnclassifier import try_sklearn_classifiers, cross_validate_kNN, cross_validate_RFC, \
    cross_validate_SVC
from utils import IOHelper
from utils.utils import select_best_model, comparison_plot_loss, comparison_plot_accuracy


class Tee:
    """File-like object that duplicates every write onto several streams.

    Typical use is ``sys.stdout = Tee(sys.stdout, log_file)`` so that printed
    output appears on the terminal and is also saved to a file.
    """

    def __init__(self, *files):
        # Tuple of target streams; writes/flushes visit them in this order.
        self.files = files

    def write(self, obj):
        """Write ``obj`` to each stream in turn, flushing each one right away.

        Flushing after every write keeps all targets up to date immediately.
        Returns ``None``.
        """
        for target in self.files:
            target.write(obj)
            target.flush()

    def flush(self):
        """Flush every wrapped stream."""
        for target in self.files:
            target.flush()

33
def main():
34
    # For Logging the data
Ard Kastrati's avatar
Ard Kastrati committed
35
36
    logging.basicConfig(filename=config['info_log'], level=logging.INFO)
    logging.info('Started the Logging')
okiss's avatar
okiss committed
37

38
39
40
41
    # For being able to see progress that some methods use verbose
    f = open(config['model_dir'] + '/console.out', 'w')
    original = sys.stdout
    sys.stdout = Tee(sys.stdout, f)
Ard Kastrati's avatar
Ard Kastrati committed
42

43
44
45
46
47
48
    # Start of the main program
    start_time = time.time()
    # try:
    trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
    print(trainX.shape)
    print(trainY.shape)
Ard Kastrati's avatar
Ard Kastrati committed
49

50
51
    if config['model'] == 'simple_classifier':
        # try_sklearn_classifiers(trainX, trainY)
52
        # cross_validate_kNN(trainX, trainY)
53
        # cross_validate_SVC(trainX, trainY)
54
        cross_validate_RFC(trainX, trainY)
55
56
57
58
59
    else:
        # tune(trainX,trainY)
        run(trainX,trainY)
        # comparison_plot_accuracy()
        # comparison_plot_loss()
Ard Kastrati's avatar
Ard Kastrati committed
60

Ard Kastrati's avatar
Ard Kastrati committed
61
62
    logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
    logging.info('Finished Logging')


# Run the pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()