22from copy import deepcopy
33import math
44import logging
5- import sys
65from collections import OrderedDict
76from glob import glob
8- from typing import Union , List
7+ from typing import Union , List , Dict
98from time import sleep , time
9+ from numpy .core .fromnumeric import std
1010
1111import pandas as pd
1212import numpy as np
1616from mne .channels import make_standard_montage
1717from mne .filter import create_filter
1818from matplotlib import pyplot as plt
19- from scipy import stats
2019from scipy .signal import lfilter , lfilter_zi
2120
2221from eegnb import _get_recording_dir
2322from eegnb .devices .eeg import EEG
2423from eegnb .devices .utils import EEG_INDICES , SAMPLE_FREQS
2524
25+
2626
2727# this should probably not be done here
2828sns .set_context ("talk" )
3232logger = logging .getLogger (__name__ )
3333
3434
35- def _bootstrap (data , n_boot : int , ci : float ):
36- """From: https://stackoverflow.com/a/47582329/965332"""
37- boot_dist = []
38- for i in range (int (n_boot )):
39- resampler = np .random .randint (0 , data .shape [0 ], data .shape [0 ])
40- sample = data .take (resampler , axis = 0 )
41- boot_dist .append (np .mean (sample , axis = 0 ))
42- b = np .array (boot_dist )
43- s1 = np .apply_along_axis (stats .scoreatpercentile , 0 , b , 50 - ci / 2 )
44- s2 = np .apply_along_axis (stats .scoreatpercentile , 0 , b , 50 + ci / 2 )
45- return (s1 , s2 )
46-
47-
48- def _tsplotboot (ax , data , time : list , n_boot : int , ci : float , color ):
49- """From: https://stackoverflow.com/a/47582329/965332"""
50- # Time forms the xaxis of the plot
51- if time is None :
52- x = np .arange (data .shape [1 ])
53- else :
54- x = np .asarray (time )
55- est = np .mean (data , axis = 0 )
56- cis = _bootstrap (data , n_boot , ci )
57- ax .fill_between (x , cis [0 ], cis [1 ], alpha = 0.2 , color = color )
58- ax .plot (x , est , color = color )
59- ax .margins (x = 0 )
60-
61-
6235def load_csv_as_raw (
6336 fnames : List [str ],
6437 sfreq : float ,
@@ -179,9 +152,7 @@ def load_data(
179152 site = "*"
180153
181154 data_path = (
182- _get_recording_dir (
183- device_name , experiment , subject_str , session_str , site , data_dir
184- )
155+ _get_recording_dir (device_name , experiment , subject_str , session_str , site , data_dir )
185156 / "*.csv"
186157 )
187158 fnames = glob (str (data_path ))
@@ -222,8 +193,7 @@ def plot_conditions(
222193 ylim = (- 6 , 6 ),
223194 diff_waveform = (1 , 2 ),
224195 channel_count = 4 ,
225- channel_order = None ,
226- ):
196+ channel_order = None ):
227197 """Plot ERP conditions.
228198 Args:
229199 epochs (mne.epochs): EEG epochs
@@ -249,9 +219,10 @@ def plot_conditions(
249219 """
250220
251221 if channel_order :
252- channel_order = np .array (channel_order )
222+ channel_order = np .array (channel_order )
253223 else :
254- channel_order = np .array (range (channel_count ))
224+ channel_order = np .array (range (channel_count ))
225+
255226
256227 if isinstance (conditions , dict ):
257228 conditions = OrderedDict (conditions )
@@ -261,7 +232,7 @@ def plot_conditions(
261232
262233 X = epochs .get_data () * 1e6
263234
264- X = X [:, channel_order ]
235+ X = X [:,channel_order ]
265236
266237 times = epochs .times
267238 y = pd .Series (epochs .events [:, - 1 ])
@@ -278,15 +249,13 @@ def plot_conditions(
278249
279250 for ch in range (channel_count ):
280251 for cond , color in zip (conditions .values (), palette ):
281- y_cond = y .isin (cond )
282- X_cond = X [y_cond , ch ]
283- _tsplotboot (
284- ax = axes [ch ],
285- data = X_cond ,
252+ sns .tsplot (
253+ X [y .isin (cond ), ch ],
286254 time = times ,
287255 color = color ,
288256 n_boot = n_boot ,
289257 ci = ci ,
258+ ax = axes [ch ],
290259 )
291260
292261 if diff_waveform :
0 commit comments