|
| 1 | +import numpy as np |
| 2 | +from scipy.linalg import pinv |
| 3 | +from scipy.special import lpmv |
| 4 | + |
| 5 | +def eeg_interp(EEG, bad_chans, method='spherical', t_range=None, params=None): |
| 6 | + # set defaults |
| 7 | + if method not in ('spherical','sphericalKang','sphericalCRD','sphericalfast'): |
| 8 | + raise ValueError(f"Unknown method {method}") |
| 9 | + if t_range is None: |
| 10 | + t_range = (EEG.xmin, EEG.xmax) |
| 11 | + if params is None: |
| 12 | + if method=='spherical': |
| 13 | + params = (0,4,7) |
| 14 | + elif method=='sphericalKang': |
| 15 | + params = (1e-8,3,50) |
| 16 | + elif method=='sphericalCRD': |
| 17 | + params = (1e-5,4,500) |
| 18 | + else: |
| 19 | + if len(params)!=3: |
| 20 | + raise ValueError("params must be length-3 tuple") |
| 21 | + method = 'spherical' |
| 22 | + |
| 23 | + # ensure channel locations present |
| 24 | + locs = EEG.chanlocs |
| 25 | + if not locs or any(('X' not in ch or 'Y' not in ch or 'Z' not in ch) for ch in locs): |
| 26 | + raise RuntimeError("Channel locations required for interpolation") |
| 27 | + |
| 28 | + # convert bad_chans from labels to indices if needed |
| 29 | + if isinstance(bad_chans, list) and isinstance(bad_chans[0], str): |
| 30 | + labels = [ch['labels'] for ch in locs] |
| 31 | + bad_idx = [labels.index(lbl) for lbl in bad_chans] |
| 32 | + else: |
| 33 | + bad_idx = sorted(bad_chans) |
| 34 | + |
| 35 | + good_idx = [i for i in range(EEG.nbchan) if i not in bad_idx] |
| 36 | + if method=='sphericalfast': |
| 37 | + # drop bad, later reshuffle if desired |
| 38 | + data = EEG.data.copy() |
| 39 | + data = np.delete(data, bad_idx, axis=0) |
| 40 | + EEG.data = data |
| 41 | + EEG.nbchan = data.shape[0] |
| 42 | + return EEG |
| 43 | + |
| 44 | + # extract Cartesian positions and normalize to unit sphere |
| 45 | + def _norm(ch_ids): |
| 46 | + xyz = np.vstack([ [locs[i][c] for i in ch_ids] for c in ('X','Y','Z') ]) |
| 47 | + rad = np.linalg.norm(xyz, axis=0) |
| 48 | + return xyz / rad |
| 49 | + |
| 50 | + xyz_good = _norm(good_idx) |
| 51 | + xyz_bad = _norm(bad_idx) |
| 52 | + |
| 53 | + # reshape data to (n_chan, n_timepoints) |
| 54 | + d = EEG.data.reshape(EEG.nbchan, -1) |
| 55 | + |
| 56 | + # compute interpolated signals for bad channels |
| 57 | + bad_data = spheric_spline( |
| 58 | + xelec=xyz_good[0], yelec=xyz_good[1], zelec=xyz_good[2], |
| 59 | + xbad =xyz_bad[0], ybad =xyz_bad[1], zbad =xyz_bad[2], |
| 60 | + values=d[good_idx,:], |
| 61 | + params=params |
| 62 | + ) |
| 63 | + |
| 64 | + # restore original time range if needed |
| 65 | + if t_range != (EEG.xmin, EEG.xmax): |
| 66 | + start, end = t_range |
| 67 | + ts = np.arange(EEG.nbchan) # dummy |
| 68 | + # here you would mask out-of-range portions as in MATLAB |
| 69 | + |
| 70 | + # assemble full data array |
| 71 | + full = np.zeros_like(d) |
| 72 | + full[good_idx,:] = d[good_idx,:] |
| 73 | + full[bad_idx,:] = bad_data |
| 74 | + |
| 75 | + EEG.data = full.reshape(EEG.nbchan, EEG.pnts, EEG.trials) |
| 76 | + return EEG |
| 77 | + |
| 78 | +def spheric_spline(xelec, yelec, zelec, xbad, ybad, zbad, values, params): |
| 79 | + # values: (n_good, n_points) |
| 80 | + Gelec = computeg(xelec, yelec, zelec, xelec, yelec, zelec, params) |
| 81 | + Gsph = computeg(xbad, ybad, zbad, xelec, yelec, zelec, params) |
| 82 | + |
| 83 | + meanvals = values.mean(axis=1, keepdims=True) |
| 84 | + V = values - meanvals |
| 85 | + V = np.vstack([V, np.zeros((1, V.shape[1]))]) |
| 86 | + |
| 87 | + lam = params[0] |
| 88 | + A = np.vstack([Gelec + np.eye(Gelec.shape[0])*lam, |
| 89 | + np.ones((1, Gelec.shape[0]))]) |
| 90 | + C = pinv(A) @ V |
| 91 | + |
| 92 | + allres = Gsph @ C |
| 93 | + meanval_broadcast = values.mean() # scalar mean across all good channels and time points |
| 94 | + allres += meanval_broadcast |
| 95 | + return allres |
| 96 | + |
| 97 | +def computeg(x, y, z, xelec, yelec, zelec, params): |
| 98 | + # x,y,z are points to interpolate; xelec,... electrode locations |
| 99 | + X = x.ravel()[:,None]; Y = y.ravel()[:,None]; Z = z.ravel()[:,None] |
| 100 | + E = 1 - np.sqrt((X - xelec[None,:])**2 + (Y - yelec[None,:])**2 + (Z - zelec[None,:])**2) |
| 101 | + |
| 102 | + m, maxn = params[1], int(params[2]) |
| 103 | + g = np.zeros((E.shape[0], E.shape[1])) |
| 104 | + for n in range(1, maxn+1): |
| 105 | + Pn = lpmv(0, n, E) # shape (E.shape) |
| 106 | + g += ((2*n+1)/(n**m*(n+1)**m)) * Pn |
| 107 | + |
| 108 | + return g/(4*np.pi) |
| 109 | + |
| 110 | +def test_spheric_spline(): |
| 111 | + import numpy as np |
| 112 | + from scipy.io import loadmat, savemat |
| 113 | + |
| 114 | + # generate random electrode positions on the unit sphere |
| 115 | + rng = np.random.default_rng(0) |
| 116 | + n_good, n_bad, n_pts = 10, 2, 100 |
| 117 | + xyz = rng.normal(size=(3, n_good)) |
| 118 | + xyz /= np.linalg.norm(xyz, axis=0) |
| 119 | + xbad = rng.normal(size=(3, n_bad)) |
| 120 | + xbad /= np.linalg.norm(xbad, axis=0) |
| 121 | + |
| 122 | + # random “good” channel data |
| 123 | + values = rng.standard_normal((n_good, n_pts)) |
| 124 | + |
| 125 | + # write to MATLAB file |
| 126 | + mat = { |
| 127 | + 'xelec': xyz[0], |
| 128 | + 'yelec': xyz[1], |
| 129 | + 'zelec': xyz[2], |
| 130 | + 'xbad': xbad[0], |
| 131 | + 'ybad': xbad[1], |
| 132 | + 'zbad': xbad[2], |
| 133 | + 'values': values, |
| 134 | + 'params': (0.0, 4.0, 7.0) |
| 135 | + } |
| 136 | + savemat('test_spheric_spline.mat', mat) |
| 137 | + |
| 138 | + # compute in Python |
| 139 | + py_res = spheric_spline( |
| 140 | + xelec=xyz[0], yelec=xyz[1], zelec=xyz[2], |
| 141 | + xbad=xbad[0], ybad=xbad[1], zbad=xbad[2], |
| 142 | + values=values, params=(0, 4, 7) |
| 143 | + ) |
| 144 | + |
| 145 | + # # load MATLAB result (assumed saved as `mat_res` in test.mat) |
| 146 | + mat_data = loadmat('test_spheric_spline_results.mat') |
| 147 | + mat_res = mat_data['allres'] # Assuming the MATLAB result is saved as 'mat_res' |
| 148 | + |
| 149 | + # # compare |
| 150 | + diff = np.abs(py_res - mat_res) |
| 151 | + |
| 152 | + # do a proper max abs and rel difference |
| 153 | + max_abs_diff = np.max(np.abs(py_res - mat_res)) |
| 154 | + max_rel_diff = np.max(np.abs(py_res - mat_res) / np.abs(mat_res)) |
| 155 | + print(f"Max absolute difference: {max_abs_diff}") |
| 156 | + print(f"Max relative difference: {max_rel_diff}") |
| 157 | + |
| 158 | +def test_computeg(): |
| 159 | + import numpy as np |
| 160 | + from scipy.io import loadmat, savemat |
| 161 | + # test computeg |
| 162 | + x = np.linspace(0, 1, 100) |
| 163 | + y = np.linspace(0, 1, 100) |
| 164 | + z = np.linspace(0, 1, 100) |
| 165 | + xelec = np.linspace(0, 1, 10) |
| 166 | + yelec = np.linspace(0, 1, 10) |
| 167 | + zelec = np.linspace(0, 1, 10) |
| 168 | + params = (0.0, 4.0, 7.0) |
| 169 | + |
| 170 | + # save to mat file |
| 171 | + mat = { |
| 172 | + 'x': x, |
| 173 | + 'y': y, |
| 174 | + 'z': z, |
| 175 | + 'xelec': xelec, |
| 176 | + 'yelec': yelec, |
| 177 | + 'zelec': zelec, |
| 178 | + 'params': params |
| 179 | + } |
| 180 | + savemat('test_computeg.mat', mat) |
| 181 | + |
| 182 | + # compute in Python |
| 183 | + g = computeg(x, y, z, xelec, yelec, zelec, params) |
| 184 | + print("g.shape python:", g.shape) |
| 185 | + |
| 186 | + # load MATLAB result |
| 187 | + mat_data = loadmat('test_computeg_results.mat') |
| 188 | + mat_res = mat_data['g'] |
| 189 | + print("g.shape matlab:", mat_res.shape) |
| 190 | + |
| 191 | + # compare |
| 192 | + diff = np.abs(g - mat_res) |
| 193 | + |
| 194 | + # do a proper max abs and rel difference |
| 195 | + max_abs_diff = np.max(np.abs(g - mat_res)) |
| 196 | + max_rel_diff = np.max(np.abs(g - mat_res) / np.abs(mat_res)) |
| 197 | + print(f"Max absolute difference: {max_abs_diff}") |
| 198 | + print(f"Max relative difference: {max_rel_diff}") |
| 199 | + |
| 200 | +if __name__ == '__main__': |
| 201 | + print("Running test_computeg") |
| 202 | + test_computeg() |
| 203 | + print("\nRunning test_spheric_spline") |
| 204 | + test_spheric_spline() |
| 205 | + |
| 206 | + |
0 commit comments