Skip to content

Commit e3e8cb9

Browse files
ludovicdmtnbara
andauthored
TRCA variation (#39)
* Add riemann mean variation to TRCA Regularization in covariance matrices estimations + riemannian mean instead of euclid mean for S computation * Output of example notebook * Fix docstring * Still docstring errors * Docstring fixes * Fix docstring * Still docstring errors * Docstring fixes * Missing blank line * style fixes * comments * illustrate example + fix tests * title * Update requirements.txt Co-authored-by: nbara <10333715+nbara@users.noreply.github.com>
1 parent 4aa4ba4 commit e3e8cb9

File tree

7 files changed

+437
-189
lines changed

7 files changed

+437
-189
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# MEEGkit
22

33
[![unit-tests](https://github.com/nbara/python-meegkit/workflows/unit-tests/badge.svg?style=flat)](https://github.com/nbara/python-meegkit/actions?workflow=unit-tests)
4-
[![documentation](https://img.shields.io/travis/nbara/python-meegkit.svg?label=documentation&logo=travis)](https://travis-ci.org/nbara/python-meegkit)
4+
[![documentation](https://img.shields.io/travis/nbara/python-meegkit.svg?label=documentation&logo=travis)](https://www.travis-ci.com/github/nbara/python-meegkit)
55
[![codecov](https://codecov.io/gh/nbara/python-meegkit/branch/master/graph/badge.svg)](https://codecov.io/gh/nbara/python-meegkit)
66
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/nbara/python-meegkit/master)
77
[![twitter](https://img.shields.io/twitter/follow/lebababa?label=Twitter&style=flat&logo=Twitter)](https://twitter.com/intent/follow?screen_name=lebababa)

examples/example_trca.ipynb

Lines changed: 110 additions & 61 deletions
Large diffs are not rendered by default.

examples/example_trca.py

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""
2-
Task-related component analysis (TRCA)-based SSVEP detection
3-
============================================================
2+
Task-related component analysis for SSVEP detection
3+
===================================================
44
55
Sample code for the task-related component analysis (TRCA)-based steady
66
-state visual evoked potential (SSVEP) detection method [1]_. The filter
7-
bank analysis [2, 3]_ can also be combined to the TRCA-based algorithm.
7+
bank analysis can also be combined to the TRCA-based algorithm [2]_ [3]_.
88
9-
Uses meegkit.trca.TRCA()
9+
This code is based on the Matlab implementation from:
10+
https://github.com/mnakanishi/TRCA-SSVEP
11+
12+
Uses `meegkit.trca.TRCA()`.
1013
1114
References:
1215
@@ -21,14 +24,13 @@
2124
"High-speed spelling with a noninvasive brain-computer interface",
2225
Proc. Int. Natl. Acad. Sci. U. S. A, 112(44): E6058-6067, 2015.
2326
24-
This code is based on the Matlab implementation from
25-
https://github.com/mnakanishi/TRCA-SSVEP
26-
2727
"""
28-
# Author: Giuseppe Ferraro <giuseppe.ferraro@isae-supaero.fr>
28+
# Authors: Giuseppe Ferraro <giuseppe.ferraro@isae-supaero.fr>
29+
# Nicolas Barascud <nicolas.barascud@gmail.com>
2930
import os
3031
import time
3132

33+
import matplotlib.pyplot as plt
3234
import numpy as np
3335
import scipy.io
3436
from meegkit.trca import TRCA
@@ -39,78 +41,100 @@
3941
###############################################################################
4042
# Parameters
4143
# -----------------------------------------------------------------------------
42-
len_gaze_s = 0.5 # data length for target identification [s]
43-
len_delay_s = 0.13 # visual latency being considered in the analysis [s]
44+
dur_gaze = 0.5 # data length for target identification [s]
45+
delay = 0.13 # visual latency being considered in the analysis [s]
4446
n_bands = 5 # number of sub-bands in filter bank analysis
4547
is_ensemble = True # True = ensemble TRCA method; False = TRCA method
4648
alpha_ci = 0.05 # 100*(1-alpha_ci): confidence interval for accuracy
4749
sfreq = 250 # sampling rate [Hz]
48-
len_shift_s = 0.5 # duration for gaze shifting [s]
49-
list_freqs = np.concatenate(
50-
[[x + 8 for x in range(8)],
50+
dur_shift = 0.5 # duration for gaze shifting [s]
51+
list_freqs = np.array(
52+
[[x + 8.0 for x in range(8)],
5153
[x + 8.2 for x in range(8)],
5254
[x + 8.4 for x in range(8)],
5355
[x + 8.6 for x in range(8)],
54-
[x + 8.8 for x in range(8)]]) # list of stimulus frequencies
55-
n_targets = len(list_freqs) # The number of stimuli
56+
[x + 8.8 for x in range(8)]]).T # list of stimulus frequencies
57+
n_targets = list_freqs.size # The number of stimuli
5658

57-
# Preparing useful variables (DONT'T need to modify)
58-
len_gaze_smpl = round_half_up(len_gaze_s * sfreq) # data length [samples]
59-
len_delay_smpl = round_half_up(len_delay_s * sfreq) # visual latency [samples]
60-
len_sel_s = len_gaze_s + len_shift_s # selection time [s]
59+
# Useful variables (no need to modify)
60+
dur_gaze_s = round_half_up(dur_gaze * sfreq) # data length [samples]
61+
delay_s = round_half_up(delay * sfreq) # visual latency [samples]
62+
dur_sel_s = dur_gaze + dur_shift # selection time [s]
6163
ci = 100 * (1 - alpha_ci) # confidence interval
6264

63-
# Performing the TRCA-based SSVEP detection algorithm
64-
print('Results of the ensemble TRCA-based method:\n')
65-
6665
###############################################################################
6766
# Load data
6867
# -----------------------------------------------------------------------------
6968
path = os.path.join('..', 'tests', 'data', 'trcadata.mat')
70-
mat = scipy.io.loadmat(path)
71-
eeg = mat["eeg"]
69+
eeg = scipy.io.loadmat(path)["eeg"]
7270

73-
n_trials = eeg.shape[0]
74-
n_chans = eeg.shape[1]
75-
n_samples = eeg.shape[2]
76-
n_blocks = eeg.shape[3]
71+
n_trials, n_chans, n_samples, n_blocks = eeg.shape
7772

7873
# Convert dummy Matlab format to (sample, channels, trials) and construct
7974
# vector of labels
8075
eeg = np.reshape(eeg.transpose([2, 1, 3, 0]),
8176
(n_samples, n_chans, n_trials * n_blocks))
8277
labels = np.array([x for x in range(n_targets)] * n_blocks)
83-
84-
crop_data = np.arange(len_delay_smpl, len_delay_smpl + len_gaze_smpl)
78+
crop_data = np.arange(delay_s, delay_s + dur_gaze_s)
8579
eeg = eeg[crop_data]
8680

8781
###############################################################################
8882
# TRCA classification
8983
# -----------------------------------------------------------------------------
9084
# Estimate classification performance with a Leave-One-Block-Out
9185
# cross-validation approach.
86+
#
87+
# To get a sense of the filterbank specification in relation to the stimuli
88+
# we can plot the individual filterbank sub-bands as well as the target
89+
# frequencies (with their expected harmonics in the EEG spectrum). We use the
90+
# filterbank specification described in [2]_.
9291

93-
# We use the filterbank specification described in [2]_.
9492
filterbank = [[(6, 90), (4, 100)], # passband, stopband freqs [(Wp), (Ws)]
9593
[(14, 90), (10, 100)],
9694
[(22, 90), (16, 100)],
9795
[(30, 90), (24, 100)],
9896
[(38, 90), (32, 100)],
9997
[(46, 90), (40, 100)],
10098
[(54, 90), (48, 100)]]
99+
100+
f, ax = plt.subplots(1, figsize=(7, 4))
101+
for i, band in enumerate(filterbank):
102+
ax.axvspan(ymin=i / len(filterbank) + .02,
103+
ymax=(i + 1) / len(filterbank) - .02,
104+
xmin=filterbank[i][1][0], xmax=filterbank[i][1][1],
105+
alpha=0.2, facecolor=f'C{i}')
106+
ax.axvspan(ymin=i / len(filterbank) + .02,
107+
ymax=(i + 1) / len(filterbank) - .02,
108+
xmin=filterbank[i][0][0], xmax=filterbank[i][0][1],
109+
alpha=0.5, label=f'sub-band{i}', facecolor=f'C{i}')
110+
111+
for f in list_freqs.flat:
112+
colors = np.ones((9, 4))
113+
colors[:, :3] = np.linspace(0, .5, 9)[:, None]
114+
ax.scatter(f * np.arange(1, 10), [f] * 9, c=colors, s=8, zorder=100)
115+
116+
ax.set_ylabel('Stimulus frequency (Hz)')
117+
ax.set_xlabel('EEG response frequency (Hz)')
118+
ax.set_xlim([0, 102])
119+
ax.set_xticks(np.arange(0, 100, 10))
120+
ax.grid(True, ls=':', axis='x')
121+
ax.legend(bbox_to_anchor=(1.05, .5), fontsize='small')
122+
plt.tight_layout()
123+
plt.show()
124+
125+
###############################################################################
126+
# Now perform the TRCA-based SSVEP detection algorithm
101127
trca = TRCA(sfreq, filterbank, is_ensemble)
102128

129+
print('Results of the ensemble TRCA-based method:\n')
103130
accs = np.zeros(n_blocks)
104131
itrs = np.zeros(n_blocks)
105132
for i in range(n_blocks):
106133

107-
# Training stage
108-
traindata = eeg.copy()
109-
110134
# Select all folds except one for training
111135
traindata = np.concatenate(
112-
(traindata[..., :i * n_trials],
113-
traindata[..., (i + 1) * n_trials:]), 2)
136+
(eeg[..., :i * n_trials],
137+
eeg[..., (i + 1) * n_trials:]), 2)
114138
y_train = np.concatenate(
115139
(labels[:i * n_trials], labels[(i + 1) * n_trials:]), 0)
116140

@@ -125,16 +149,15 @@
125149
# Evaluation of the performance for this fold (accuracy and ITR)
126150
is_correct = estimated == y_test
127151
accs[i] = np.mean(is_correct) * 100
128-
itrs[i] = itr(n_targets, np.mean(is_correct), len_sel_s)
152+
itrs[i] = itr(n_targets, np.mean(is_correct), dur_sel_s)
129153
print(f"Block {i}: accuracy = {accs[i]:.1f}, \tITR = {itrs[i]:.1f}")
130154

131155
# Mean accuracy and ITR computation
132156
mu, _, muci, _ = normfit(accs, alpha_ci)
133-
print()
134-
print(f"Mean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa
157+
print(f"\nMean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa
135158

136159
mu, _, muci, _ = normfit(itrs, alpha_ci)
137-
print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)")
160+
print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f})")
138161
if is_ensemble:
139162
ensemble = 'ensemble TRCA-based method'
140163
else:

0 commit comments

Comments
 (0)