Commit 5f819931 authored by Lukas's avatar Lukas
Browse files

added functionality for calibration task

parent 7b30a04c
......@@ -43,7 +43,7 @@ finally used for classification.
Cluster can be set to clustering(), clustering2() or clustering3(), where different clusters based on literature are used.
"""
# Choose experiment TODO: create config['experiment'] which can be chosen like the model via commenting out
config['gaze-reg'] = False # Set to False if you want to run the saccade classification task
config['gaze-reg'] = True # Set to False if you want to run the saccade classification task
config['prosaccade'] = False
config['calibration-task'] = False
......@@ -57,10 +57,11 @@ config['epochs'] = 1
config['batch_size'] = 64
# Choose which dataset to run the gaze regression on
config['data_mode'] = 'fix_only'
#config['data_mode'] = 'fix_only'
#onfig['data_mode'] = 'sacc_only'
#config['data_mode'] = 'sacc_fix'
#config['data_mode'] = 'fix_sacc_fix'
config['data_mode'] = 'calib_task_fix_sacc_fix'
# Choose model
#config['model'] = 'cnn'
......@@ -100,6 +101,8 @@ if config['gaze-reg']:
config['padding'] = 'repeat' # options: zero, repeat #TODO: find more options for clever padding
config['min_fixation'] = 50 # choose a minimum length for the gaze fixation
config['max_fixation'] = 150 # choose a maximum length for the gaze fixation
if config['data_mode'] == 'calib_task_fix_sacc_fix':
config['max_fixation'] = 1000
config['min_saccade'] = 10 # minimum number of samples for a saccade that we want to use
config['max_saccade'] = 30 # maximum number of samples for a saccade that we want to use
config['x_screen'] = 600
......
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"rows, cols = (27,2)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"tar_pos = [[400, 300],\n",
" [650, 500],\n",
" [400, 100],\n",
" [100, 450],\n",
" [700, 450],\n",
" [100, 500],\n",
" [200, 350],\n",
" [300, 400],\n",
" [100, 150],\n",
" [150, 500],\n",
" [150, 100],\n",
" [700, 100],\n",
" [300, 200],\n",
" [100, 100],\n",
" [700, 500],\n",
" [500, 400],\n",
" [600, 250],\n",
" [650, 100],\n",
" [400, 300],\n",
" [200, 250],\n",
" [400, 500],\n",
" [700, 150],\n",
" [500, 200],\n",
" [100, 300],\n",
" [700, 300],\n",
" [600, 350],\n",
" [400, 300]]"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"tar_pos_back = list(reversed(tar_pos))"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[[400, 300],\n",
" [600, 350],\n",
" [700, 300],\n",
" [100, 300],\n",
" [500, 200],\n",
" [700, 150],\n",
" [400, 500],\n",
" [200, 250],\n",
" [400, 300],\n",
" [650, 100],\n",
" [600, 250],\n",
" [500, 400],\n",
" [700, 500],\n",
" [100, 100],\n",
" [300, 200],\n",
" [700, 100],\n",
" [150, 100],\n",
" [150, 500],\n",
" [100, 150],\n",
" [300, 400],\n",
" [200, 350],\n",
" [100, 500],\n",
" [700, 450],\n",
" [100, 450],\n",
" [400, 100],\n",
" [650, 500],\n",
" [400, 300]]"
]
},
"metadata": {},
"execution_count": 63
}
],
"source": [
"tar_pos_back"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"tar_pos_flip = []\n",
"for i in range(rows):\n",
" tar_pos_flip.append([800 - tar_pos[i][0], 600 - tar_pos[i][1]])"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[[400, 300],\n",
" [150, 100],\n",
" [400, 500],\n",
" [700, 150],\n",
" [100, 150],\n",
" [700, 100],\n",
" [600, 250],\n",
" [500, 200],\n",
" [700, 450],\n",
" [650, 100],\n",
" [650, 500],\n",
" [100, 500],\n",
" [500, 400],\n",
" [700, 500],\n",
" [100, 100],\n",
" [300, 200],\n",
" [200, 350],\n",
" [150, 500],\n",
" [400, 300],\n",
" [600, 350],\n",
" [400, 100],\n",
" [100, 450],\n",
" [300, 400],\n",
" [700, 300],\n",
" [100, 300],\n",
" [200, 250],\n",
" [400, 300]]"
]
},
"metadata": {},
"execution_count": 50
}
],
"source": [
"tar_pos_flip_back = list(reversed(tar_pos_flip))\n",
"tar_pos_flip"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[[400, 300],\n",
" [200, 250],\n",
" [100, 300],\n",
" [700, 300],\n",
" [300, 400],\n",
" [100, 450],\n",
" [400, 100],\n",
" [600, 350],\n",
" [400, 300],\n",
" [150, 500],\n",
" [200, 350],\n",
" [300, 200],\n",
" [100, 100],\n",
" [700, 500],\n",
" [500, 400],\n",
" [100, 500],\n",
" [650, 500],\n",
" [650, 100],\n",
" [700, 450],\n",
" [500, 200],\n",
" [600, 250],\n",
" [700, 100],\n",
" [100, 150],\n",
" [700, 150],\n",
" [400, 500],\n",
" [150, 100],\n",
" [400, 300]]"
]
},
"metadata": {},
"execution_count": 49
}
],
"source": [
"tar_pos_flip_back"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}
\ No newline at end of file
%% Cell type:code id: tags:
``` python
rows, cols = (27,2)
```
%% Cell type:code id: tags:
``` python
tar_pos = [[400, 300],
[650, 500],
[400, 100],
[100, 450],
[700, 450],
[100, 500],
[200, 350],
[300, 400],
[100, 150],
[150, 500],
[150, 100],
[700, 100],
[300, 200],
[100, 100],
[700, 500],
[500, 400],
[600, 250],
[650, 100],
[400, 300],
[200, 250],
[400, 500],
[700, 150],
[500, 200],
[100, 300],
[700, 300],
[600, 350],
[400, 300]]
```
%% Cell type:code id: tags:
``` python
tar_pos_back = list(reversed(tar_pos))
```
%% Cell type:code id: tags:
``` python
tar_pos_back
```
%%%% Output: execute_result
[[400, 300],
[600, 350],
[700, 300],
[100, 300],
[500, 200],
[700, 150],
[400, 500],
[200, 250],
[400, 300],
[650, 100],
[600, 250],
[500, 400],
[700, 500],
[100, 100],
[300, 200],
[700, 100],
[150, 100],
[150, 500],
[100, 150],
[300, 400],
[200, 350],
[100, 500],
[700, 450],
[100, 450],
[400, 100],
[650, 500],
[400, 300]]
%% Cell type:code id: tags:
``` python
tar_pos_flip = []
for i in range(rows):
tar_pos_flip.append([800 - tar_pos[i][0], 600 - tar_pos[i][1]])
```
%% Cell type:code id: tags:
``` python
tar_pos_flip_back = list(reversed(tar_pos_flip))
tar_pos_flip
```
%%%% Output: execute_result
[[400, 300],
[150, 100],
[400, 500],
[700, 150],
[100, 150],
[700, 100],
[600, 250],
[500, 200],
[700, 450],
[650, 100],
[650, 500],
[100, 500],
[500, 400],
[700, 500],
[100, 100],
[300, 200],
[200, 350],
[150, 500],
[400, 300],
[600, 350],
[400, 100],
[100, 450],
[300, 400],
[700, 300],
[100, 300],
[200, 250],
[400, 300]]
%% Cell type:code id: tags:
``` python
tar_pos_flip_back
```
%%%% Output: execute_result
[[400, 300],
[200, 250],
[100, 300],
[700, 300],
[300, 400],
[100, 450],
[400, 100],
[600, 350],
[400, 300],
[150, 500],
[200, 350],
[300, 200],
[100, 100],
[700, 500],
[500, 400],
[100, 500],
[650, 500],
[650, 100],
[700, 450],
[500, 200],
[600, 250],
[700, 100],
[100, 150],
[700, 150],
[400, 500],
[150, 100],
[400, 300]]
%% Cell type:code id: tags:
``` python
```
def get_block_targets():
"""
Computes the targets for the 5 blocks of the calibration task
blocks: TarPos, TarPosBack, TarPosFlip, TarPosFlipBack, TarPos
"""
num_coordinates = 27
# Define the target positions for the first block
tar_pos = [
[400, 300],
[650, 500],
[400, 100],
[100, 450],
[700, 450],
[100, 500],
[200, 350],
[300, 400],
[100, 150],
[150, 500],
[150, 100],
[700, 100],
[300, 200],
[100, 100],
[700, 500],
[500, 400],
[600, 250],
[650, 100],
[400, 300],
[200, 250],
[400, 500],
[700, 150],
[500, 200],
[100, 300],
[700, 300],
[600, 350],
[400, 300]
]
# Compute the targets for the second block
tar_pos_back = list(reversed(tar_pos))
# Compute the third block
tar_pos_flip = []
for i in range(num_coordinates):
tar_pos_flip.append([800 - tar_pos[i][0], 600 - tar_pos[i][1]])
# Compute the fourth block
tar_pos_flip_back = list(reversed(tar_pos_flip))
return [tar_pos, tar_pos_back, tar_pos_flip, tar_pos_flip_back, tar_pos]
\ No newline at end of file
......@@ -12,6 +12,8 @@ from sklearn import preprocessing
import os
import pandas as pd
from utils.calibration_task_preprocessing import get_block_targets
def max_fixation_duration():
"""
......@@ -56,6 +58,9 @@ def load_regression_data(verbose=True):
elif config['data_mode'] == 'fix_only':
logging.info("Using fixation only dataset")
return get_fix_data(verbose=verbose)
elif config['data_mode'] == 'calib_task_fix_sacc_fix':
logging.info("Using fixation-saccade-fixation from calibration task dataset")
return get_calibration_task_fix_sacc_fix_data(verbose=verbose)
else:
raise Exception("Choose a valid data_mode in config.py")
......@@ -468,21 +473,45 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
"""
Returns X, y for the gaze regression task, specifically the calibration task dataset, with EEG data X only from triplets of fixation-saccade-fixation
"""
L_saccade = 'L_saccade'
L_fixation = 'L_fixation'
L_blink = 'L_blink'
R_saccade = 'R_saccade'
R_fixation = 'R_fixation'
R_blink = 'R_blink'
blockOn = '55'
blockOff = '56'
current_block = 0
block_list = ['TarPos', 'TarPosBack', 'TarPosFlip', 'TarPosFlipBack', 'TarPos']
#import hashlib
# somehow the string comparisons didn't work properly because of the mixed numbers and string as event names, therefore I hash the string and compare them
# Event names
L_saccade = 'L_saccade' #hashlib.md5(str('L_saccade').encode()).hexdigest()
L_fixation = 'L_fixation' #hashlib.md5(str('L_fixation').encode()).hexdigest()
L_blink = 'L_blink' #hashlib.md5(str('L_blink').encode()).hexdigest()
R_saccade = 'R_saccade' #hashlib.md5(str('R_saccade').encode()).hexdigest()
R_fixation = 'R_fixation' #hashlib.md5(str('R_fixation').encode()).hexdigest()
R_blink = 'R_blink' #hashlib.md5(str('R_blink').encode()).hexdigest()
blockOn = 55 #.md5(str('55').encode()).hexdigest()
blockOff = 56 #hashlib.md5(str('56').encode()).hexdigest()
stimulusOff = 41 #hashlib.md5(str('L_saccade').encode()).hexdigest()
dots_start = [12, 13, 14, 15, 16, 17]
#dots_start = [hashlib.md5(str('12').encode()).hexdigest(), hashlib.md5(str('13').encode()).hexdigest(), hashlib.md5(str('14').encode()).hexdigest(),
#hashlib.md5(str('15').encode()).hexdigest(), hashlib.md5(str('16').encode()).hexdigest(), hashlib.md5(str('17').encode()).hexdigest()]
# Define lists of event names for easier search
fixations = [L_fixation, R_fixation]
saccades = [L_saccade, R_saccade]
blinks = [L_blink, R_blink]
non_stimulus_strings = fixations + saccades + blinks
non_stimulus_ints = [blockOff, blockOn, stimulusOff] + dots_start
print(non_stimulus_ints)
print(non_stimulus_strings)
#print(dots_start)
# Threshold for the distance between fixation_avg_pos and the expected label
THRESH_DIFF = 30
# Compute the positions of the five blocks, are returned as [tarpos, tarposback, tarposflip, tarposflipback, tarpos]
block_list = get_block_targets()
# Intialize current target and block
current_target_list = block_list[0] # extract the first target list [tar1, tar2, ..., tar27]
current_target = current_target_list[0] # extract the first expected target [x, y]
next_target = current_target_list[1] # this would be the target for the next fixation
# Loop over all directories and extract and concat the events from all people
rootdir = './data/calibration_task' # modify it if necessary
......@@ -490,6 +519,9 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
x_list = []
y_list = []
# Start searching for events
isBlock = True
for subdir, dirs, files in os.walk(rootdir):
for file in files:
# Get the correct path to the current file
......@@ -497,30 +529,151 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
events = load_sEEG_events(path) # access event i via events[i]
data = load_sEEG_data(path)
# We now look for fixation-sacade-fixation triplets without blinks or other actions inbetween
# Extract X and y from sEEG.data and sEEG.events
# Reset target and block counter for the new file
block_cnt = -1 # index for the block we take from the target list
target_cnt = -1 # index for the current target in the current block
"""
We now look for fixation-sacade-fixation triplets without blinks or other actions inbetween
For that matter we look for [fixation, stimulusOff, newStimulus, saccade, fixation] tuples,
where both fixations are within a thresholded distance of the expected coordinates
"""
for i in range(len(events)):
first_event = events[i]
first_event_name = first_event[0][0] # dereference the event name, e.g. 'L_saccade'
if first_event_name not in fixations:
continue # we first want a saccade
# Read the current event
event_1 = events[i]
event_1_name = event_1[0][0] # dereference the event name, e.g. 'L_saccade'
#event_1_name = hashlib.md5(str(event_1_name).encode()).hexdigest()
print("{}".format(i) + ": " + event_1_name)
try:
num = int(event_1_name)
isNumeric = True
except:
isNumeric = False
#print("is numeric: {}".format(isNumeric))
# Check wheter the file starts
if isNumeric:
if int(event_1_name) in dots_start:
#print("dots start")
continue # this was the start of the DOTS file
if int(event_1_name) == stimulusOff:
#print("found 41 stimulus off")
continue
if int(event_1_name) == blockOff:
#print("block off")
isBlock = False # Block ends, therefore do not read any of the next events
continue
if int(event_1_name) == blockOn:
#print("block on")
isBlock = True # start looking for events again
block_cnt += 1 # increment to next block
current_target_list = block_list[block_cnt] # get the target list of the next block
target_cnt = -1 # reset the target pointer to first element of new list
continue
# Ignore events while we are outside of a block
if not isBlock:
#("not within block")
continue
# Consume the new stimulus signal and increment the target
if isNumeric:
if event_1_name not in non_stimulus_ints:
#print("consume stimulus")
if target_cnt <= 26:
target_cnt += 1
current_target = current_target_list[target_cnt]
if target_cnt <= 25: # otherwise we are at the end of the current block and will change targets after block ends
next_target = current_target_list[target_cnt+1]
continue
# From here on it must be a fixation first, therefore if the current event is numeric, we can continue
if isNumeric:
continue
# From here on we look for the quintuple
if event_1_name not in fixations:
#("first event wasn't a fixation")
continue # we first want a fixation
#print("look for stimulusoff trigger")
# We found a fixation as first event, now look for a following saccade
second_event = events[i + 1]
second_event_name = second_event[0][0]