To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 6842a8f0 authored by Lukas Wolf's avatar Lukas Wolf
Browse files

added function for data split

parent a9ae96b9
import numpy as np
import math
import logging
def generate_split(X, y):
"""
Computes train, test and validation data from X, y
Splits them in 70/15/15 percentages
"""
size = len(y)
ids = y[:,0]
num_ids = len(np.unique(ids))
test_split = math.ceil(0.15 * num_ids)
valid_split = test_split
train_split = num_ids - valid_split - test_split
assert(train_split + test_split + valid_split == num_ids)
# Run through the data and assemble the sets
X_train, y_train, X_test, y_test, X_valid, y_valid = [], [], [], [], [], []
curr_id = -1
id_count = 0
# Create the ids up to which we split parts of the set
valid_split = train_split + valid_split
test_split = valid_split + test_split
for i, label in enumerate(y):
id = label[0]
if id != curr_id:
curr_id = id
id_count += 1
if id_count <= train_split:
X_train.append(X[i])
y_train.append(label)
elif id_count <= valid_split:
X_valid.append(X[i])
y_valid.append(label)
else:
X_test.append(X[i])
y_test.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)
X_test = np.array(X_test)
y_test = np.array(y_test)
X_valid = np.array(X_valid)
y_valid = np.array(y_valid)
logging.info(f"Training data shapes X, y: {X_train, y_train}")
logging.info(f"Test data shapes X, y: {X_test, y_test}")
logging.info(f"Validation data shapes: {X_valid, y_valid}")
return X_train, y_train, X_test, y_test, X_valid, y_valid
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment