Source code for autoclean.mixins.signal_processing.artifacts

"""Artifacts detection and rejection mixin for autoclean tasks."""

from typing import Optional, Union

import mne
import numpy as np

from autoclean.utils.logging import message


[docs] class ArtifactsMixin: """Mixin class providing artifact detection and rejection functionality for EEG data."""
[docs] def detect_dense_oscillatory_artifacts( self, data: Union[mne.io.Raw, None] = None, window_size_ms: int = 100, channel_threshold_uv: float = 45, min_channels: int = 75, padding_ms: float = 500, annotation_label: str = "BAD_REF_AF", ) -> mne.io.Raw: """Detect smaller, dense oscillatory multichannel artifacts. This method identifies oscillatory artifacts that affect multiple channels simultaneously, while excluding large single deflections. Parameters ---------- data : mne.io.Raw, Optional The raw data to detect artifacts from. If None, uses self.raw. window_size_ms : int, Optional Window size in milliseconds for artifact detection, by default 100. channel_threshold_uv : float, Optional Threshold for peak-to-peak amplitude in microvolts, by default 45. min_channels : int, Optional Minimum number of channels that must exhibit oscillations, by default 75. padding_ms : float, Optional Amount of padding in milliseconds to add before and after each detected artifact, by default 500. annotation_label : str, Optional Label to use for the annotations, by default "BAD_REF_AF". stage_name : str, Optional Name for saving and metadata, by default "detect_dense_oscillatory_artifacts". Returns ------- result_raw : instance of mne.io.Raw The raw data object with updated artifact annotations. *Note the self.raw is updated in place. So the return value is optional.* Examples -------- >>> #Inside a task class that uses the autoclean framework >>> self.detect_dense_oscillatory_artifacts() >>> #Or with custom parameters >>> self.detect_dense_oscillatory_artifacts(window_size_ms=200, channel_threshold_uv=50, min_channels=100, padding_ms=1000, annotation_label="BAD_CUSTOM_ARTIFACT") Notes ----- This method is intended to find reference artifacts but may also be triggered by other artifacts. """ # Determine which data to use data = self._get_data_object(data) # Type checking if not isinstance(data, mne.io.Raw) and not isinstance( data, mne.io.base.BaseRaw ): raise TypeError("Data must be an MNE Raw object for artifact detection") try: # Convert parameters to samples and volts sfreq = data.info["sfreq"] window_size = int(window_size_ms * sfreq / 1000) channel_threshold = channel_threshold_uv * 1e-6 # Convert µV to V padding_sec = padding_ms / 1000.0 # Convert padding to seconds # Get data and times raw_data, times = data.get_data(return_times=True) _, n_samples = raw_data.shape artifact_annotations = [] # Sliding window detection for start_idx in range(0, n_samples - window_size, window_size): window = raw_data[:, start_idx : start_idx + window_size] # Compute peak-to-peak amplitude for each channel in the window ptp_amplitudes = np.ptp( window, axis=1 ) # Peak-to-peak amplitude per channel # Count channels exceeding the threshold num_channels_exceeding = np.sum(ptp_amplitudes > channel_threshold) # Check if artifact spans multiple channels with oscillatory behavior if num_channels_exceeding >= min_channels: start_time = times[start_idx] - padding_sec # Add padding before end_time = ( times[start_idx + window_size] + padding_sec ) # Add padding after # Ensure we don't go beyond recording bounds start_time = max(start_time, times[0]) end_time = min(end_time, times[-1]) artifact_annotations.append( [start_time, end_time - start_time, annotation_label] ) # Create a copy of the raw data result_raw = data.copy() # Add annotations to the raw data if artifact_annotations: for annotation in artifact_annotations: result_raw.annotations.append( onset=annotation[0], duration=annotation[1], description=annotation[2], ) message( "info", f"Added {len(artifact_annotations)} potential reference artifact annotations", ) else: message("info", "No reference artifacts detected") # Add flags if needed if len(artifact_annotations) > self.REFERENCE_ARTIFACT_THRESHOLD: flagged_reason = f"WARNING: {len(artifact_annotations)} potential reference artifacts detected" # pylint: disable=line-too-long self._update_flagged_status(flagged=True, reason=flagged_reason) # Update metadata metadata = { "window_size_ms": window_size_ms, "channel_threshold_uv": channel_threshold_uv, "min_channels": min_channels, "padding_ms": padding_ms, "annotation_label": annotation_label, "artifacts_detected": len(artifact_annotations), } self._update_metadata("step_detect_dense_oscillatory_artifacts", metadata) # Save the result self._save_raw_result(result_raw, "post_artifact_detection") # Update self.raw if we're using it self._update_instance_data(data, result_raw) return result_raw except Exception as e: message("error", f"Error during artifact detection: {str(e)}") raise RuntimeError(f"Failed to detect artifacts: {str(e)}") from e
[docs] def detect_muscle_beta_focus( self, data: Union[mne.io.Raw, None] = None, freq_band: tuple = (20, 30), scale_factor: float = 3.0, window_length: float = 1.0, window_overlap: float = 0.5, annotation_description: str = "BAD_MOVEMENT", ) -> mne.io.Raw: """Detect muscle artifacts in continuous Raw data and add annotations. This method detects muscle artifacts in continuous EEG data by analyzing high-frequency activity in peripheral electrodes. It automatically adds annotations to the Raw object marking segments with detected artifacts. Parameters ---------- data : mne.io.Raw, Optional The raw data to detect artifacts from. If None, uses self.raw. freq_band : tuple, Optional Frequency band for filtering (min, max), by default (20, 30). scale_factor : float, Optional Scale factor for threshold calculation, by default 3.0. window_length : float, Optional Length of sliding window in seconds, by default 1.0. window_overlap : float, Optional Overlap between windows as a fraction (0-1), by default 0.5. annotation_description : str, Optional Description for the annotations, by default "BAD_MOVEMENT". Returns ------- results_raw : instance of mne.io.Raw The raw data object with updated artifact annotations. *Note the self.raw is updated in place. So the return value is optional.* Examples -------- >>> #Inside a task class that uses the autoclean framework >>> self.detect_muscle_beta_focus() >>> #Or with custom parameters >>> self.detect_muscle_beta_focus(freq_band=(20, 30), scale_factor=4.0, window_length=2.0, window_overlap=0.7, annotation_description="BAD_CUSTOM_ARTIFACT") """ # Determine which data to use data = self._get_data_object(data) # Type checking if not isinstance(data, mne.io.Raw) and not isinstance( data, mne.io.base.BaseRaw ): raise TypeError( "Data must be an MNE Raw object for muscle artifact detection" ) # Ensure data is loaded data.load_data() # Create a copy to work with results_raw = data.copy() # Filter in beta/gamma band raw_beta = data.copy().filter( l_freq=freq_band[0], h_freq=freq_band[1], verbose=False ) # Build channel_region_map from the provided channel data # Make sure all "OTHER" electrodes are listed here channel_region_map = { "E17": "OTHER", "E38": "OTHER", "E43": "OTHER", "E44": "OTHER", "E48": "OTHER", "E49": "OTHER", "E56": "OTHER", "E73": "OTHER", "E81": "OTHER", "E88": "OTHER", "E94": "OTHER", "E107": "OTHER", "E113": "OTHER", "E114": "OTHER", "E119": "OTHER", "E120": "OTHER", "E121": "OTHER", "E125": "OTHER", "E126": "OTHER", "E127": "OTHER", "E128": "OTHER", } # Get channel names ch_names = raw_beta.ch_names # Select only OTHER channels selected_ch_indices = [ i for i, ch in enumerate(ch_names) if channel_region_map.get(ch, "") == "OTHER" ] # If no OTHER channels are found, return if not selected_ch_indices: message("info", "No 'OTHER' channels found for muscle artifact detection") return None # Calculate window parameters sfreq = raw_beta.info["sfreq"] n_samples = len(raw_beta.times) window_samples = int(window_length * sfreq) step_samples = int(window_samples * (1 - window_overlap)) # Create sliding windows n_windows = max(1, int((n_samples - window_samples) / step_samples) + 1) # Store peak-to-peak values for each window max_p2p_values = [] window_times = [] # Process each window for i in range(n_windows): start_sample = i * step_samples end_sample = min(start_sample + window_samples, n_samples) # Skip if window is too small if end_sample - start_sample < window_samples / 2: continue # Extract data for this window (only selected channels) window_data = raw_beta.get_data( picks=selected_ch_indices, start=start_sample, stop=end_sample ) # Compute peak-to-peak amplitude per channel p2p = window_data.max(axis=1) - window_data.min(axis=1) # Compute maximum peak-to-peak amplitude across channels max_p2p = np.max(p2p) max_p2p_values.append(max_p2p) # Store window time boundaries start_time = start_sample / sfreq end_time = end_sample / sfreq window_times.append((start_time, end_time)) # Compute median and MAD max_p2p_values = np.array(max_p2p_values) med = np.median(max_p2p_values) mad = np.median(np.abs(max_p2p_values - med)) # Robust threshold threshold = med + scale_factor * mad # Identify bad windows bad_window_indices = np.where(max_p2p_values > threshold)[0].tolist() bad_windows = [window_times[i] for i in bad_window_indices] # Add annotations if bad_windows: # Merge overlapping windows merged_windows = self._merge_overlapping_windows(bad_windows) # Add annotations for start, end in merged_windows: results_raw.annotations.append( onset=start, duration=end - start, description=annotation_description, ) message( "info", f"Added {len(merged_windows)} {annotation_description} annotations to Raw data", ) # Update the original data with the new annotations self._update_instance_data(data, results_raw) else: message("info", "No muscle artifacts detected") # Update metadata metadata = { "freq_band": freq_band, "scale_factor": scale_factor, "window_length": window_length, "window_overlap": window_overlap, "annotation_description": annotation_description, } self._update_metadata("step_detect_muscle_artifacts", metadata) return results_raw
def _merge_overlapping_windows(self, windows): """Merge overlapping time windows. Args: windows : List of tuples (start_time, end_time) in seconds Returns ------- List of merged tuples (start_time, end_time) with no overlaps """ if not windows: return [] # Sort windows by start time sorted_windows = sorted(windows, key=lambda x: x[0]) # Initialize with the first window merged = [sorted_windows[0]] # Iterate through remaining windows for current in sorted_windows[1:]: previous = merged[-1] # If current window overlaps with previous, merge them if current[0] <= previous[1]: merged[-1] = (previous[0], max(previous[1], current[1])) else: merged.append(current) return merged
[docs] def reject_bad_segments( self, data: Union[mne.io.Raw, None] = None, bad_label: Optional[str] = None, stage_name: str = "bad_segment_rejection", ) -> mne.io.Raw: """Remove all time spans annotated with a specific label or all 'BAD' segments. This method removes segments marked as bad and concatenates the remaining good segments. Parameters ---------- data : mne.io.Raw, Optional The raw data to detect artifacts from. If None, uses self.raw. bad_label : str, Optional Specific label of annotations to reject. If None, rejects all segments where description starts with 'BAD' stage_name : str, Optional Name for saving and metadata, by default "bad_segment_rejection". Returns ------- raw_cleaned : instance of mne.io.Raw The raw data object with updated artifact annotations. *Note the self.raw is updated in place. So the return value is optional.*. Examples -------- >>> #Inside a task class that uses the autoclean framework >>> self.reject_bad_segments() >>> #Or with custom label >>> self.reject_bad_segments(bad_label="BAD_CUSTOM_ARTIFACT") """ # Determine which data to use data = self._get_data_object(data) # Type checking if not isinstance(data, mne.io.base.BaseRaw): raise TypeError("Data must be an MNE Raw object for segment rejection") try: # Get annotations annotations = data.annotations # Identify bad intervals based on label matching strategy bad_intervals = [ (onset, onset + duration) for onset, duration, desc in zip( annotations.onset, annotations.duration, annotations.description ) if (bad_label is None and desc.startswith("BAD")) or (bad_label is not None and desc == bad_label) ] # Define good intervals (non-bad spans) good_intervals = [] prev_end = 0 # Start of the first good interval for start, end in sorted(bad_intervals): if prev_end < start: good_intervals.append((prev_end, start)) # Add non-bad span prev_end = end if prev_end < data.times[-1]: # Add final good interval if it exists good_intervals.append((prev_end, data.times[-1])) # Crop and concatenate good intervals if not good_intervals: message("warning", "No good segments found after rejection") return data.copy() raw_segments = [ data.copy().crop(tmin=start, tmax=end) for start, end in good_intervals ] raw_cleaned = mne.concatenate_raws(raw_segments) # Update metadata metadata = { "bad_label": bad_label if bad_label else "All BAD*", "segments_removed": len(bad_intervals), "segments_kept": len(good_intervals), "original_duration": data.times[-1], "cleaned_duration": raw_cleaned.times[-1], } self._update_metadata("step_reject_bad_segments", metadata) # Save the result self._save_raw_result(raw_cleaned, stage_name) # Update self.raw if we're using it self._update_instance_data(data, raw_cleaned) return raw_cleaned except Exception as e: message("error", f"Error during segment rejection: {str(e)}") raise RuntimeError(f"Failed to reject bad segments: {str(e)}") from e