forked from r-zemblys/irf
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_training.py
More file actions
203 lines (162 loc) · 5.68 KB
/
Copy pathrun_training.py
File metadata and controls
203 lines (162 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Raimondas Zemblys
@email: [email protected]
"""
#%% imports
import os, sys, glob
from distutils.dir_util import mkpath
from tqdm import tqdm
import numpy as np
import pandas as pd
###
import argparse
import multiprocessing
import json
from datetime import datetime
import parse
from sklearn.ensemble import RandomForestClassifier
import joblib
from sklearn.metrics import cohen_kappa_score
from util_lib.etdata import ETData
from util_lib.utils import split_path
from util_lib.irf import extractFeatures, ft_all, get_i2mc
#%%
def get_arguments():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace with the required positionals ``root`` and
        ``dataset``, plus ``output_dir`` (str, default None) and
        ``workers`` (int, default 4).
    """
    cli = argparse.ArgumentParser(
        description='Eye-movement event detection using Random Forest.')
    # (flags, keyword options) for every argument, registered in this order.
    option_specs = (
        (('root',), dict(type=str,
                         help='The path containing eye-movement data.')),
        (('dataset',), dict(type=str,
                            help='The directory containing experiment data.')),
        (('--output_dir',), dict(type=str, default=None,
                                 help='The directory to save output.')),
        (('--workers',), dict(type=int, default=4,
                              help='Number of workers to use.')),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
#%% Setup parameters and variables
args = get_arguments()
# Output directory name: the explicit --output_dir if given, else "<dataset>_irf".
exp_output = (args.output_dir if args.output_dir is not None
              else '%s_irf' % args.dataset)
mkpath('%s/%s' % (args.root, exp_output))
db_path = '%s/%s' % (args.root, args.dataset)
n_avail_cores = multiprocessing.cpu_count()
# A worker count of None or 0 falls back to all available cores
# (scikit-learn would otherwise reject n_jobs=0); negative values are passed
# through unchanged and keep their scikit-learn meaning.
n_jobs = args.workers if args.workers else n_avail_cores
#load config
# Global settings are read from config.json in the current working directory.
with open('config.json', 'r') as f:
    config = json.load(f)
# Dataset-specific settings live next to the data.
with open('%s/db_config.json' % db_path, 'r') as f:
    db_config = json.load(f)
config['geom'] = db_config['geom']
etdata = ETData()
#%%Extract features for train and val
# Glob each split directory explicitly: a pattern such as '*[train][val]*'
# is two single-character classes, not an alternation of directory names,
# and would also match unrelated directories. Only train/ and val/ are read
# by the later training and evaluation stages.
FILES = sorted(
    fpath
    for split in ('train', 'val')
    for fpath in glob.glob('%s/%s/%s/*.npy' % (args.root, args.dataset, split))
)
i2mc_failed = []  # input files for which get_i2mc returned None
for fpath in tqdm(FILES):
    fdir, fname = split_path(fpath)
    # Mirror the input directory layout under the experiment output dir.
    odir = fdir.replace(args.dataset, exp_output)
    #create output dirs
    odir_feat = '%s/feat' % odir
    odir_evt = '%s/evt' % odir
    mkpath(odir_feat)
    mkpath(odir_evt)
    spath_feat = '%s/feat_%s.npy' % (odir_feat, fname)
    spath_evt = '%s/evt_%s.npy' % (odir_evt, fname)
    # Reuse cached results: recompute only when either output file is missing.
    if os.path.exists(spath_feat) and os.path.exists(spath_evt):
        continue
    etdata.load(fpath)
    # Blank out samples whose label is not one of the configured events.
    evt_mask = np.isin(etdata.data['evt'], config["events"])
    etdata.data['x'][~evt_mask] = np.nan
    etdata.data['y'][~evt_mask] = np.nan
    etdata.data['status'][~evt_mask] = False
    #extract features
    if 'i2mc' in config['features']:
        fpath_i2mc = '%s/i2mc/%s_i2mc.mat' % (odir, fname)
        i2mc = get_i2mc(etdata, fpath_i2mc, config['geom'])
        if i2mc is None:
            # Keep going so every affected file is reported, then abort below.
            i2mc_failed.append(fpath)
            continue
        config['extr_kwargs']['i2mc'] = i2mc
    irf_features, pred_mask = extractFeatures(etdata, **config['extr_kwargs'])
    #select required features
    X = irf_features[ft_all]
    # Cast every field of the structured array to float32, then view it as a
    # plain 2-D float32 matrix with one column per field selected by ft_all.
    X = X.astype([(name, np.float32) for name in X.dtype.names])
    X = X.view(np.float32).reshape(X.shape + (-1,))
    y = etdata.data['evt'][pred_mask]
    # Validate with real exceptions: assert statements are removed under -O.
    if not np.isin(np.unique(y), config["events"]).all():
        raise ValueError('%s: labels outside config["events"]' % fpath)
    if not (len(X) == len(y) == pred_mask.sum()):
        raise ValueError('%s: feature/label length mismatch' % fpath)
    #save
    np.save(spath_feat, X)
    np.save(spath_evt, y)
if i2mc_failed:
    # Features are incomplete: stop with a non-zero exit status instead of
    # silently reporting success.
    sys.exit('I2MC data unavailable for %d file(s):\n%s'
             % (len(i2mc_failed), '\n'.join(i2mc_failed)))
#%%Load features and train IRF
FILES = sorted(glob.glob('%s/%s/train/*.npy' % (args.root, args.dataset)))
if not FILES:
    # np.concatenate([]) below would otherwise fail with an opaque message.
    sys.exit('No training files found in %s/%s/train' % (args.root, args.dataset))
X = []
y = []
# Boolean column mask over ft_all that keeps only the configured features.
#TODO: handle different ordering -- the boolean mask keeps the stored column
# order (ft_all's), not the order in which config['features'] lists them.
ft_mask = np.isin(ft_all, config['features'])
for fpath in tqdm(FILES):
    fdir, fname = split_path(fpath)
    # Feature/label files were written under the experiment output dir.
    fdir = fdir.replace(args.dataset, exp_output)
    path_feat = '%s/feat/feat_%s.npy' % (fdir, fname)
    path_evt = '%s/evt/evt_%s.npy' % (fdir, fname)
    X.append(np.load(path_feat)[:, ft_mask])
    y.append(np.load(path_evt))
# Stack the per-file arrays into one training matrix and one label vector.
X = np.concatenate(X)
y = np.concatenate(y)
#train IRF
# Random-forest hyper-parameters; only the tree count comes from config.
rf_params = dict(
    n_estimators=config["n_trees"],
    max_depth=None,  # no depth limit
    class_weight='balanced_subsample',
    max_features=3,
    n_jobs=n_jobs,
    verbose=3,
)
clf = RandomForestClassifier(**rf_params)
clf.fit(X, y)
#save
print('Saving model...')
stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
spath = 'models/irf_%s' % stamp
mkpath(spath)
# Store the feature list alongside the classifier, presumably so a loader can
# select matching input columns when the model is reused.
model_bundle = [config["features"], clf]
joblib.dump(model_bundle, '%s/model.pkl' % spath, compress=9, protocol=2)
print('...done')
#%%evaluate
clf.set_params(verbose=0)
FILES = sorted(glob.glob('%s/%s/val/*.npy' % (args.root, args.dataset)))
# File-name pattern of the validation recordings: subject, sampling rate and
# an rms value (presumably noise level); non-matching names get None.
fmt = 'lookAtPoint_EL_S{sub:d}_{fs:d}_{rms:.4f}{}'
result = []
for fpath in tqdm(FILES):
    fdir, fname = split_path(fpath)
    # Feature/label files were written under the experiment output dir.
    fdir = fdir.replace(args.dataset, exp_output)
    path_feat = '%s/feat/feat_%s.npy' % (fdir, fname)
    path_evt = '%s/evt/evt_%s.npy' % (fdir, fname)
    #load data
    _x = np.load(path_feat)[:, ft_mask]
    _y = np.load(path_evt)
    #predict
    # clf.predict maps the most probable column back through clf.classes_,
    # so labels are correct even when event codes are not exactly 1..K
    # (argmax(predict_proba) + 1 silently assumed that).
    pred_val_class = clf.predict(_x)
    #evaluate
    k = cohen_kappa_score(_y, pred_val_class)
    #save
    # parse.parse returns None when fname does not match fmt.
    _p = parse.parse(fmt, fname)
    if _p is None:
        fs = sub = rms = None
    else:
        named = _p.named
        fs, sub, rms = named['fs'], named['sub'], named['rms']
    result.append([fname, sub, fs, rms, k])
result_df = pd.DataFrame(result, columns=['fname', 'sub', 'fs', 'rms', 'k'])
result_df.to_csv('%s/result_val.csv' % spath, index=False)