|
1 | 1 | """ |
2 | | -Task-related component analysis (TRCA)-based SSVEP detection |
3 | | -============================================================ |
| 2 | +Task-related component analysis for SSVEP detection |
| 3 | +=================================================== |
4 | 4 |
|
5 | 5 | Sample code for the task-related component analysis (TRCA)-based steady |
6 | 6 | -state visual evoked potential (SSVEP) detection method [1]_. The filter |
7 | | -bank analysis [2, 3]_ can also be combined to the TRCA-based algorithm. |
| 7 | +bank analysis can also be combined with the TRCA-based algorithm [2]_ [3]_.
8 | 8 |
|
9 | | -Uses meegkit.trca.TRCA() |
| 9 | +This code is based on the Matlab implementation from: |
| 10 | +https://github.com/mnakanishi/TRCA-SSVEP |
| 11 | +
|
| 12 | +Uses `meegkit.trca.TRCA()`. |
10 | 13 |
|
11 | 14 | References: |
12 | 15 |
|
|
21 | 24 | "High-speed spelling with a noninvasive brain-computer interface", |
22 | 25 | Proc. Int. Natl. Acad. Sci. U. S. A, 112(44): E6058-6067, 2015. |
23 | 26 |
|
24 | | -This code is based on the Matlab implementation from |
25 | | -https://github.com/mnakanishi/TRCA-SSVEP |
26 | | -
|
27 | 27 | """ |
28 | | -# Author: Giuseppe Ferraro <giuseppe.ferraro@isae-supaero.fr> |
| 28 | +# Authors: Giuseppe Ferraro <giuseppe.ferraro@isae-supaero.fr> |
| 29 | +# Nicolas Barascud <nicolas.barascud@gmail.com> |
29 | 30 | import os |
30 | 31 | import time |
31 | 32 |
|
| 33 | +import matplotlib.pyplot as plt |
32 | 34 | import numpy as np |
33 | 35 | import scipy.io |
34 | 36 | from meegkit.trca import TRCA |
|
39 | 41 | ############################################################################### |
40 | 42 | # Parameters |
41 | 43 | # ----------------------------------------------------------------------------- |
42 | | -len_gaze_s = 0.5 # data length for target identification [s] |
43 | | -len_delay_s = 0.13 # visual latency being considered in the analysis [s] |
| 44 | +dur_gaze = 0.5 # data length for target identification [s] |
| 45 | +delay = 0.13 # visual latency being considered in the analysis [s] |
44 | 46 | n_bands = 5 # number of sub-bands in filter bank analysis |
45 | 47 | is_ensemble = True # True = ensemble TRCA method; False = TRCA method |
46 | 48 | alpha_ci = 0.05 # 100*(1-alpha_ci): confidence interval for accuracy |
47 | 49 | sfreq = 250 # sampling rate [Hz] |
48 | | -len_shift_s = 0.5 # duration for gaze shifting [s] |
49 | | -list_freqs = np.concatenate( |
50 | | - [[x + 8 for x in range(8)], |
| 50 | +dur_shift = 0.5 # duration for gaze shifting [s] |
| 51 | +list_freqs = np.array( |
| 52 | + [[x + 8.0 for x in range(8)], |
51 | 53 | [x + 8.2 for x in range(8)], |
52 | 54 | [x + 8.4 for x in range(8)], |
53 | 55 | [x + 8.6 for x in range(8)], |
54 | | - [x + 8.8 for x in range(8)]]) # list of stimulus frequencies |
55 | | -n_targets = len(list_freqs) # The number of stimuli |
| 56 | + [x + 8.8 for x in range(8)]]).T # list of stimulus frequencies |
| 57 | +n_targets = list_freqs.size # The number of stimuli |
56 | 58 |
|
57 | | -# Preparing useful variables (DONT'T need to modify) |
58 | | -len_gaze_smpl = round_half_up(len_gaze_s * sfreq) # data length [samples] |
59 | | -len_delay_smpl = round_half_up(len_delay_s * sfreq) # visual latency [samples] |
60 | | -len_sel_s = len_gaze_s + len_shift_s # selection time [s] |
| 59 | +# Useful variables (no need to modify) |
| 60 | +dur_gaze_s = round_half_up(dur_gaze * sfreq) # data length [samples] |
| 61 | +delay_s = round_half_up(delay * sfreq) # visual latency [samples] |
| 62 | +dur_sel_s = dur_gaze + dur_shift # selection time [s] |
61 | 63 | ci = 100 * (1 - alpha_ci) # confidence interval |
62 | 64 |
|
63 | | -# Performing the TRCA-based SSVEP detection algorithm |
64 | | -print('Results of the ensemble TRCA-based method:\n') |
65 | | - |
66 | 65 | ############################################################################### |
67 | 66 | # Load data |
68 | 67 | # ----------------------------------------------------------------------------- |
69 | 68 | path = os.path.join('..', 'tests', 'data', 'trcadata.mat') |
70 | | -mat = scipy.io.loadmat(path) |
71 | | -eeg = mat["eeg"] |
| 69 | +eeg = scipy.io.loadmat(path)["eeg"] |
72 | 70 |
|
73 | | -n_trials = eeg.shape[0] |
74 | | -n_chans = eeg.shape[1] |
75 | | -n_samples = eeg.shape[2] |
76 | | -n_blocks = eeg.shape[3] |
| 71 | +n_trials, n_chans, n_samples, n_blocks = eeg.shape |
77 | 72 |
|
78 | 73 | # Convert dummy Matlab format to (sample, channels, trials) and construct |
79 | 74 | # vector of labels |
80 | 75 | eeg = np.reshape(eeg.transpose([2, 1, 3, 0]), |
81 | 76 | (n_samples, n_chans, n_trials * n_blocks)) |
82 | 77 | labels = np.array([x for x in range(n_targets)] * n_blocks) |
83 | | - |
84 | | -crop_data = np.arange(len_delay_smpl, len_delay_smpl + len_gaze_smpl) |
| 78 | +crop_data = np.arange(delay_s, delay_s + dur_gaze_s) |
85 | 79 | eeg = eeg[crop_data] |
86 | 80 |
|
87 | 81 | ############################################################################### |
88 | 82 | # TRCA classification |
89 | 83 | # ----------------------------------------------------------------------------- |
90 | 84 | # Estimate classification performance with a Leave-One-Block-Out |
91 | 85 | # cross-validation approach. |
| 86 | +# |
| 87 | +# To get a sense of the filterbank specification in relation to the stimuli |
| 88 | +# we can plot the individual filterbank sub-bands as well as the target |
| 89 | +# frequencies (with their expected harmonics in the EEG spectrum). We use the |
| 90 | +# filterbank specification described in [2]_. |
92 | 91 |
|
93 | | -# We use the filterbank specification described in [2]_. |
94 | 92 | filterbank = [[(6, 90), (4, 100)], # passband, stopband freqs [(Wp), (Ws)] |
95 | 93 | [(14, 90), (10, 100)], |
96 | 94 | [(22, 90), (16, 100)], |
97 | 95 | [(30, 90), (24, 100)], |
98 | 96 | [(38, 90), (32, 100)], |
99 | 97 | [(46, 90), (40, 100)], |
100 | 98 | [(54, 90), (48, 100)]] |
| 99 | + |
| 100 | +f, ax = plt.subplots(1, figsize=(7, 4)) |
| 101 | +for i, band in enumerate(filterbank): |
| 102 | + ax.axvspan(ymin=i / len(filterbank) + .02, |
| 103 | + ymax=(i + 1) / len(filterbank) - .02, |
| 104 | + xmin=filterbank[i][1][0], xmax=filterbank[i][1][1], |
| 105 | + alpha=0.2, facecolor=f'C{i}') |
| 106 | + ax.axvspan(ymin=i / len(filterbank) + .02, |
| 107 | + ymax=(i + 1) / len(filterbank) - .02, |
| 108 | + xmin=filterbank[i][0][0], xmax=filterbank[i][0][1], |
| 109 | + alpha=0.5, label=f'sub-band{i}', facecolor=f'C{i}') |
| 110 | + |
| 111 | +for f in list_freqs.flat: |
| 112 | + colors = np.ones((9, 4)) |
| 113 | + colors[:, :3] = np.linspace(0, .5, 9)[:, None] |
| 114 | + ax.scatter(f * np.arange(1, 10), [f] * 9, c=colors, s=8, zorder=100) |
| 115 | + |
| 116 | +ax.set_ylabel('Stimulus frequency (Hz)') |
| 117 | +ax.set_xlabel('EEG response frequency (Hz)') |
| 118 | +ax.set_xlim([0, 102]) |
| 119 | +ax.set_xticks(np.arange(0, 100, 10)) |
| 120 | +ax.grid(True, ls=':', axis='x') |
| 121 | +ax.legend(bbox_to_anchor=(1.05, .5), fontsize='small') |
| 122 | +plt.tight_layout() |
| 123 | +plt.show() |
| 124 | + |
| 125 | +############################################################################### |
| 126 | +# Now perform the TRCA-based SSVEP detection algorithm |
101 | 127 | trca = TRCA(sfreq, filterbank, is_ensemble) |
102 | 128 |
|
| 129 | +print('Results of the ensemble TRCA-based method:\n') |
103 | 130 | accs = np.zeros(n_blocks) |
104 | 131 | itrs = np.zeros(n_blocks) |
105 | 132 | for i in range(n_blocks): |
106 | 133 |
|
107 | | - # Training stage |
108 | | - traindata = eeg.copy() |
109 | | - |
110 | 134 | # Select all folds except one for training |
111 | 135 | traindata = np.concatenate( |
112 | | - (traindata[..., :i * n_trials], |
113 | | - traindata[..., (i + 1) * n_trials:]), 2) |
| 136 | + (eeg[..., :i * n_trials], |
| 137 | + eeg[..., (i + 1) * n_trials:]), 2) |
114 | 138 | y_train = np.concatenate( |
115 | 139 | (labels[:i * n_trials], labels[(i + 1) * n_trials:]), 0) |
116 | 140 |
|
|
125 | 149 | # Evaluation of the performance for this fold (accuracy and ITR) |
126 | 150 | is_correct = estimated == y_test |
127 | 151 | accs[i] = np.mean(is_correct) * 100 |
128 | | - itrs[i] = itr(n_targets, np.mean(is_correct), len_sel_s) |
| 152 | + itrs[i] = itr(n_targets, np.mean(is_correct), dur_sel_s) |
129 | 153 | print(f"Block {i}: accuracy = {accs[i]:.1f}, \tITR = {itrs[i]:.1f}") |
130 | 154 |
|
131 | 155 | # Mean accuracy and ITR computation |
132 | 156 | mu, _, muci, _ = normfit(accs, alpha_ci) |
133 | | -print() |
134 | | -print(f"Mean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa |
| 157 | +print(f"\nMean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa |
135 | 158 |
|
136 | 159 | mu, _, muci, _ = normfit(itrs, alpha_ci) |
137 | | -print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") |
| 160 | +print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f})") |
138 | 161 | if is_ensemble: |
139 | 162 | ensemble = 'ensemble TRCA-based method' |
140 | 163 | else: |
|
0 commit comments