"""Segment-based rejection functions for EEG data.
This module provides standalone functions for identifying and annotating
noisy or problematic segments in continuous EEG data using various statistical
and correlation-based methods.
"""
from typing import Dict, List, Optional, Union
import mne
import numpy as np
import pandas as pd
import scipy.stats
import xarray as xr
from scipy.spatial import distance_matrix
def annotate_noisy_segments(
raw: mne.io.Raw,
epoch_duration: float = 2.0,
epoch_overlap: float = 0.0,
picks: Optional[Union[List[str], str]] = None,
quantile_k: float = 3.0,
quantile_flag_crit: float = 0.2,
annotation_description: str = "BAD_noisy_segment",
verbose: Optional[bool] = None,
) -> mne.io.Raw:
"""Identify and annotate noisy segments in continuous EEG data.
This function temporarily epochs the continuous data, calculates channel-wise
standard deviations for each epoch, and then identifies epochs where a
significant proportion of channels exhibit outlier standard deviations.
    The outlier detection uses quantile-based bounds derived from the distances
    between the median and the first/third quartiles, following the approach
    used in the pylossless pipeline.
The method works by:
1. Creating temporary fixed-length epochs from continuous data
2. Calculating standard deviation for each channel in each epoch
3. Using IQR-based outlier detection to identify abnormal standard deviations
4. Flagging epochs where too many channels show outlier behavior
5. Adding annotations to mark these problematic time periods
Parameters
----------
raw : mne.io.Raw
The continuous EEG data to analyze for noisy segments.
epoch_duration : float, default 2.0
Duration of epochs in seconds for noise detection. Shorter epochs
provide finer temporal resolution but may be less stable for
outlier detection.
epoch_overlap : float, default 0.0
Overlap between epochs in seconds. Non-zero overlap provides
smoother detection but increases computation time.
picks : list of str, str, or None, default None
Channels to include in analysis. If None, defaults to 'eeg'.
Can be channel names (e.g., ['EEG 001', 'EEG 002']) or channel
types (e.g., 'eeg', 'grad').
    quantile_k : float, default 3.0
        Multiplier applied to the median-to-quartile distance when defining the
        outlier threshold for channel standard deviations. A channel's std in an
        epoch is an outlier if it exceeds median + k * (Q3 - median) of that
        channel's stds across all epochs; only unusually high stds are flagged,
        since noise raises variance. Higher values = more conservative detection.
quantile_flag_crit : float, default 0.2
Proportion threshold (0.0-1.0). If more than this proportion of picked
channels are marked as outliers (having outlier std) within an epoch,
that epoch is flagged as noisy. Lower values = more sensitive detection.
    annotation_description : str, default "BAD_noisy_segment"
        The description to use for MNE annotations marking noisy segments.
        Should start with "BAD_" so MNE treats the spans as bad segments
        (e.g., for reject_by_annotation during subsequent epoching).
verbose : bool or None, default None
Control verbosity of output during processing.
Returns
-------
raw_annotated : mne.io.Raw
Copy of input Raw object with added annotations for noisy segments.
Original data is not modified.
Raises
------
TypeError
If raw is not an MNE Raw object.
ValueError
If parameters are outside valid ranges or no epochs can be created.
RuntimeError
If processing fails due to insufficient data or other errors.
Notes
-----
    **Detection Algorithm:**

    1. For each channel, the standard deviation is calculated within each epoch
    2. For each channel, the distribution of its standard deviations across all
       epochs is summarized by its quartiles (Q1, median, Q3)
    3. The upper outlier threshold is median + k * (Q3 - median); because noise
       manifests as unusually high variance, only this upper bound is applied
    4. An epoch is marked as noisy if the proportion of channels whose standard
       deviation exceeds their threshold is greater than the quantile_flag_crit
       threshold
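
    A minimal numeric sketch of the per-channel upper bound (illustrative
    values only, not drawn from real data):

    >>> import numpy as np
    >>> stds = np.array([1.0, 1.1, 0.9, 1.2, 5.0])  # per-epoch stds; last is noisy
    >>> q1, med, q3 = np.quantile(stds, [0.25, 0.5, 0.75])
    >>> upper = med + 3.0 * (q3 - med)  # quantile_k = 3.0
    >>> bool(stds[-1] > upper)
    True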
**Parameter Guidelines:**
- epoch_duration: 1-4 seconds typical. Shorter for transient artifacts,
longer for stable outlier detection
- quantile_k: 2-4 typical. Higher values = fewer false positives
- quantile_flag_crit: 0.1-0.3 typical. Lower = more sensitive
**Performance Considerations:**
- Processing time scales with (data_length / epoch_duration)
- Memory usage depends on number of epochs and channels
- Overlap increases computation but may improve detection continuity
Examples
--------
Basic noise detection with default parameters:
>>> from autoclean import annotate_noisy_segments
>>> raw_clean = annotate_noisy_segments(raw)
>>> noisy_annotations = [ann for ann in raw_clean.annotations
... if 'noisy' in ann['description']]
>>> print(f"Found {len(noisy_annotations)} noisy segments")
Conservative detection for high-quality data:
>>> raw_clean = annotate_noisy_segments(
... raw,
... epoch_duration=3.0,
... quantile_k=4.0,
... quantile_flag_crit=0.3
... )
Sensitive detection for noisy data:
>>> raw_clean = annotate_noisy_segments(
... raw,
... epoch_duration=1.0,
... quantile_k=2.0,
... quantile_flag_crit=0.1,
... annotation_description="BAD_very_noisy"
... )
EEG-only detection with channel selection:
>>> raw_clean = annotate_noisy_segments(
... raw,
... picks=['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4'],
... epoch_duration=2.0
... )
See Also
--------
annotate_uncorrelated_segments : Detect segments with poor channel correlations
mne.Annotations : MNE annotations system
autoclean.detect_outlier_epochs : Statistical outlier detection for epochs
References
----------
This implementation adapts concepts from the PREP pipeline and pylossless:
Bigdely-Shamlo, N., Mullen, T., Kothe, C., Su, K. M., & Robbins, K. A. (2015).
The PREP pipeline: standardized preprocessing for large-scale EEG analysis.
    Frontiers in Neuroinformatics, 9, 16.
"""
# Input validation
if not isinstance(raw, mne.io.BaseRaw):
raise TypeError(f"Data must be an MNE Raw object, got {type(raw).__name__}")
if epoch_duration <= 0:
raise ValueError(f"epoch_duration must be positive, got {epoch_duration}")
if epoch_overlap < 0:
raise ValueError(f"epoch_overlap must be non-negative, got {epoch_overlap}")
if not 0 <= quantile_flag_crit <= 1:
raise ValueError(
f"quantile_flag_crit must be between 0 and 1, got {quantile_flag_crit}"
)
if quantile_k <= 0:
raise ValueError(f"quantile_k must be positive, got {quantile_k}")
# Set default picks
if picks is None:
picks = "eeg"
try:
# Create fixed-length epochs for analysis
events = mne.make_fixed_length_events(
raw, duration=epoch_duration, overlap=epoch_overlap
)
# Ensure events are within data boundaries
if events.shape[0] == 0:
raise ValueError("No epochs could be created with the given parameters")
        # Event samples are absolute (they include raw.first_samp), so compare
        # against the absolute index of the end of the data
        epoch_len_samples = int(epoch_duration * raw.info["sfreq"])
        max_sample = raw.first_samp + len(raw.times)
        if events[-1, 0] + epoch_len_samples > max_sample:
            # Prune events whose epochs would run past the end of the data
            valid_events_mask = events[:, 0] + epoch_len_samples <= max_sample
            events = events[valid_events_mask]
            if events.shape[0] == 0:
                raise ValueError("No valid epochs after boundary check")
# Create epochs
epochs = mne.Epochs(
raw,
events,
tmin=0.0,
tmax=epoch_duration - 1.0 / raw.info["sfreq"],
picks=picks,
preload=True,
baseline=None, # No baseline correction for std calculation
reject=None, # We are detecting bads, not rejecting yet
verbose=verbose,
)
if len(epochs) == 0:
raise ValueError(f"No epochs left after picking channels: {picks}")
# Convert epochs to xarray DataArray (channels, epochs, time)
epochs_xr = _epochs_to_xr(epochs)
# Calculate standard deviation for each channel within each epoch
data_sd = epochs_xr.std("time") # Shape: (channels, epochs)
# Detect noisy epochs using outlier detection logic
outliers_kwargs = {"k": quantile_k}
bad_epoch_indices = _detect_outliers(
data_sd,
flag_dim="epoch", # We want to flag epochs
outlier_method="quantile",
flag_crit=quantile_flag_crit,
init_dir="pos", # Interested in high std_dev for noise
outliers_kwargs=outliers_kwargs,
)
if len(bad_epoch_indices) == 0:
# No noisy segments found, return copy of original
return raw.copy()
# Add annotations to the original raw object
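        # Convert absolute event samples to onsets relative to the start of the
        # recording (subtracting the first_samp offset)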
relative_onsets = epochs.events[bad_epoch_indices, 0] / raw.info["sfreq"]
onsets = relative_onsets - raw.first_samp / raw.info["sfreq"]
# Duration of each annotation matches epoch_duration
annotation_durations = np.full_like(onsets, fill_value=epoch_duration)
descriptions = [annotation_description] * len(bad_epoch_indices)
# Create new annotations
new_annotations = mne.Annotations(
onset=onsets,
duration=annotation_durations,
description=descriptions,
orig_time=raw.annotations.orig_time,
)
# Make a copy and add annotations
raw_annotated = raw.copy()
raw_annotated.set_annotations(raw_annotated.annotations + new_annotations)
return raw_annotated
    except ValueError:
        # Preserve the documented ValueError behavior (e.g., no epochs created)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to annotate noisy segments: {e}") from e
# Helper functions
def _epochs_to_xr(epochs: mne.Epochs) -> xr.DataArray:
"""Create an Xarray DataArray from MNE Epochs.
Converts epochs data to xarray format for easier manipulation
with dimensions (channels, epochs, time).
"""
data = epochs.get_data() # n_epochs, n_channels, n_times
ch_names = epochs.ch_names
# Transpose to (n_channels, n_epochs, n_times)
data_transposed = data.transpose(1, 0, 2)
return xr.DataArray(
data_transposed,
coords={
"ch": ch_names,
"epoch": np.arange(data_transposed.shape[1]),
"time": epochs.times,
},
dims=("ch", "epoch", "time"),
)
def _get_outliers_quantile(
array: xr.DataArray,
dim: str,
lower: float = 0.25,
upper: float = 0.75,
mid: float = 0.5,
k: float = 3.0,
) -> tuple:
"""Calculate outlier bounds based on the IQR method."""
lower_val, mid_val, upper_val = array.quantile([lower, mid, upper], dim=dim)
lower_dist = mid_val - lower_val
upper_dist = upper_val - mid_val
return mid_val - lower_dist * k, mid_val + upper_dist * k
def _detect_outliers(
array: xr.DataArray,
flag_dim: str,
outlier_method: str = "quantile",
flag_crit: float = 0.2,
init_dir: str = "pos",
outliers_kwargs: Optional[Dict] = None,
) -> np.ndarray:
"""Mark items along flag_dim as flagged for artifact."""
if outliers_kwargs is None:
outliers_kwargs = {}
# Determine the dimension to operate across
dims = list(array.dims)
if flag_dim not in dims:
raise ValueError(f"flag_dim '{flag_dim}' not in array dimensions: {dims}")
dims.remove(flag_dim)
if not dims:
raise ValueError("Array must have at least two dimensions")
operate_dim = dims[0]
if outlier_method == "quantile":
l_out, u_out = _get_outliers_quantile(array, dim=flag_dim, **outliers_kwargs)
else:
raise ValueError(
f"outlier_method '{outlier_method}' not supported. Use 'quantile'"
)
outlier_mask = xr.zeros_like(array, dtype=bool)
if init_dir == "pos" or init_dir == "both":
outlier_mask = outlier_mask | (array > u_out)
if init_dir == "neg" or init_dir == "both":
outlier_mask = outlier_mask | (array < l_out)
# Calculate proportion of outliers along operate_dim
prop_outliers = outlier_mask.astype(float).mean(operate_dim)
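    # DataArray.quantile leaves a scalar "quantile" coordinate on the bounds;
    # drop it if it propagated through the comparison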
if "quantile" in list(prop_outliers.coords.keys()):
prop_outliers = prop_outliers.drop_vars("quantile")
flagged_indices = prop_outliers[prop_outliers > flag_crit].coords[flag_dim].values
return flagged_indices
def _calculate_neighbor_correlations(
epochs: mne.Epochs,
n_nearest_neighbors: int,
corr_method: str = "max",
corr_trim_percent: float = 10.0,
) -> xr.DataArray:
"""Compute nearest neighbor correlations for channels within epochs."""
    montage = epochs.get_montage()
    if montage is None:
        raise ValueError("Epochs must have a montage with channel positions")
    ch_positions = montage.get_positions()["ch_pos"]
valid_chs = [ch for ch in epochs.ch_names if ch in ch_positions]
if not valid_chs:
raise ValueError(
"No channel positions found for any channels in the epochs object"
)
if len(valid_chs) <= n_nearest_neighbors:
actual_n_neighbors = max(0, len(valid_chs) - 1)
else:
actual_n_neighbors = n_nearest_neighbors
# Create distance matrix and find nearest neighbors
ch_locs_df = pd.DataFrame(ch_positions).T.loc[valid_chs]
dist_matrix_val = distance_matrix(ch_locs_df.values, ch_locs_df.values)
chan_dist_df = pd.DataFrame(
dist_matrix_val, columns=ch_locs_df.index, index=ch_locs_df.index
)
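    # Rank distances per row; after subtracting 1, each channel's own
    # zero-distance entry gets rank 0 and is masked to NaN so a channel is
    # never selected as its own neighbor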
rank = chan_dist_df.rank(axis="columns", method="first", ascending=True) - 1
rank[rank == 0] = np.nan
nearest_neighbor_df = pd.DataFrame(
index=ch_locs_df.index, columns=range(actual_n_neighbors), dtype=object
)
for ch_name_iter in ch_locs_df.index:
sorted_neighbors = rank.loc[ch_name_iter].dropna().sort_values()
nearest_neighbor_df.loc[ch_name_iter] = sorted_neighbors.index[
:actual_n_neighbors
].values
# Pick only valid channels for epochs_xr
epochs_xr = _epochs_to_xr(epochs.copy().pick(valid_chs))
all_channel_corrs = []
    for ch_name in valid_chs:
neighbor_names_for_ch = [
n
for n in nearest_neighbor_df.loc[ch_name].values.tolist()
if pd.notna(n) and n != ch_name
]
if not neighbor_names_for_ch:
# No valid neighbors
ch_neighbor_corr_aggregated = xr.DataArray(
np.full(epochs_xr.sizes["epoch"], np.nan),
coords={"epoch": epochs_xr.coords["epoch"]},
dims=["epoch"],
)
else:
# Data for the current reference channel
this_ch_data = epochs_xr.sel(ch=ch_name)
# Data for its neighbors
neighbor_chs_data = epochs_xr.sel(ch=neighbor_names_for_ch)
# Calculate Pearson correlation along the 'time' dimension
ch_to_neighbors_corr = xr.corr(this_ch_data, neighbor_chs_data, dim="time")
# Aggregate correlations based on corr_method
if corr_method == "max":
ch_neighbor_corr_aggregated = np.abs(ch_to_neighbors_corr).max(dim="ch")
elif corr_method == "mean":
ch_neighbor_corr_aggregated = np.abs(ch_to_neighbors_corr).mean(
dim="ch"
)
elif corr_method == "trimmean":
proportion_to_cut = corr_trim_percent / 100.0
np_data = np.abs(ch_to_neighbors_corr).transpose("epoch", "ch").data
trimmed_means_per_epoch = [
(
scipy.stats.trim_mean(
epoch_data_for_trim, proportiontocut=proportion_to_cut
)
if not np.all(np.isnan(epoch_data_for_trim))
and len(epoch_data_for_trim) > 0
else np.nan
)
for epoch_data_for_trim in np_data
]
ch_neighbor_corr_aggregated = xr.DataArray(
trimmed_means_per_epoch,
coords={"epoch": ch_to_neighbors_corr.coords["epoch"]},
dims=["epoch"],
)
else:
raise ValueError(f"Unknown corr_method: {corr_method}")
# Expand to add the reference channel's name as a new 'ch' dimension
expanded_corr = ch_neighbor_corr_aggregated.expand_dims(dim={"ch": [ch_name]})
all_channel_corrs.append(expanded_corr)
if not all_channel_corrs:
return xr.DataArray(
np.empty((0, epochs_xr.sizes.get("epoch", 0))),
coords={"ch": [], "epoch": epochs_xr.coords.get("epoch", [])},
dims=("ch", "epoch"),
)
# Concatenate results for all reference channels
concatenated_corrs = xr.concat(all_channel_corrs, dim="ch")
return concatenated_corrs