Source code for autoclean.mixins.signal_processing.eventid_epochs

"""Event ID epochs creation mixin for autoclean tasks.

This module provides functionality for creating epochs based on event markers in
EEG data. Event-based epochs are time segments centered around specific event markers
that represent stimuli, responses, or other experimental events of interest.

The EventIDEpochsMixin class implements methods for creating these epochs and
detecting artifacts within them, particularly focusing on reference and muscle
artifacts that can contaminate the data.

Event-based epoching is particularly useful for task-based EEG analysis, where
the data needs to be segmented around specific events of interest for further
processing and analysis, such as event-related potentials (ERPs) or time-frequency
analysis.
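
Example
-------
A minimal usage sketch (hypothetical; ``task`` is an autoclean task
instance that mixes in ``EventIDEpochsMixin`` and has raw data loaded)::

    epochs = task.create_eventid_epochs(
        event_id={"target": 1, "standard": 2},
        tmin=-0.2,
        tmax=0.8,
    )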
"""

from typing import Dict, Optional, Union

import mne
import numpy as np
import pandas as pd

from autoclean.functions.epoching import create_eventid_epochs as _create_eventid_epochs
from autoclean.utils.logging import message


class EventIDEpochsMixin:
    """Mixin class providing event ID based epochs creation functionality for EEG data."""
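    # NOTE: This mixin assumes its host task class provides the helpers used
    # below -- _check_step_enabled, _get_data_object, _save_epochs_result,
    # _update_metadata, _update_flagged_status -- plus the
    # EPOCH_RETENTION_THRESHOLD attribute and, optionally, self.raw.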
    def create_eventid_epochs(
        self,
        data: Union[mne.io.Raw, None] = None,
        event_id: Optional[Dict[str, int]] = None,
        tmin: float = -0.5,
        tmax: float = 2,
        baseline: Optional[tuple] = (None, 0),
        volt_threshold: Optional[Dict[str, float]] = None,
        reject_by_annotation: bool = False,
        keep_all_epochs: bool = False,
        stage_name: str = "post_epochs",
    ) -> Optional[mne.Epochs]:
        """Create epochs based on event IDs from raw data.

        Parameters
        ----------
        data : mne.io.Raw, optional
            The raw data to create epochs from. If None, uses self.raw.
        event_id : dict, optional
            Dictionary mapping event names to event IDs
            (e.g., ``{"target": 1, "standard": 2}``).
        tmin : float, optional
            Start time of the epoch relative to the event in seconds, by default -0.5.
        tmax : float, optional
            End time of the epoch relative to the event in seconds, by default 2.
        baseline : tuple, optional
            Baseline correction window as a (start, end) tuple, by default (None, 0).
        volt_threshold : dict, optional
            Dictionary mapping channel types to rejection thresholds, by default None.
        reject_by_annotation : bool, optional
            Whether to reject epochs overlapping "bad" annotations, by default False.
        keep_all_epochs : bool, optional
            If True, no epochs are dropped; bad epochs are only marked in the
            metadata, by default False.
        stage_name : str, optional
            Name used for saving results and metadata, by default "post_epochs".

        Returns
        -------
        epochs_clean : mne.Epochs | None
            The created epochs, or None if epoching is disabled.

        Notes
        -----
        This method creates epochs centered around specific event IDs in the raw
        data. It is useful for event-related potential (ERP) analysis, where
        segments of data time-locked to specific events must be extracted.
        """
        # Check whether the epoch_settings step is enabled in the configuration
        is_enabled, epoch_config = self._check_step_enabled("epoch_settings")

        if not is_enabled:
            message("info", "Epoch settings step is disabled in configuration")
            return None

        # Pull epoch settings from the configuration, falling back to the
        # values passed as arguments
        if epoch_config and isinstance(epoch_config, dict):
            epoch_value = epoch_config.get("value", {})
            if isinstance(epoch_value, dict):
                tmin = epoch_value.get("tmin", tmin)
                tmax = epoch_value.get("tmax", tmax)

            event_id = epoch_config.get("event_id", event_id)

            # Get keep_all_epochs setting if available
            keep_all_epochs = epoch_config.get("keep_all_epochs", keep_all_epochs)

            # Get baseline settings
            baseline_settings = epoch_config.get("remove_baseline", {})
            if isinstance(baseline_settings, dict) and baseline_settings.get(
                "enabled", False
            ):
                baseline = baseline_settings.get("window", baseline)

            # Get threshold settings
            threshold_settings = epoch_config.get("threshold_rejection", {})
            if isinstance(threshold_settings, dict) and threshold_settings.get(
                "enabled", False
            ):
                threshold_config = threshold_settings.get("volt_threshold", {})
                if isinstance(threshold_config, (int, float)):
                    volt_threshold = {"eeg": float(threshold_config)}
                elif isinstance(threshold_config, dict):
                    volt_threshold = {
                        k: float(v) for k, v in threshold_config.items()
                    }

        # Determine which data object to use
        data = self._get_data_object(data)

        # Type checking
        if not isinstance(data, (mne.io.Raw, mne.io.base.BaseRaw)):
            raise TypeError("Data must be an MNE Raw object for epoch creation")

        try:
            # Check that an event_id is provided (None or empty)
            if not event_id:
                message("warning", "No event_id provided for event-based epoching")
                return None

            message("header", f"Creating epochs based on event IDs: {event_id}")

            # Get all events from annotations
            events_all, event_id_all = mne.events_from_annotations(data)
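            # events_all is an (n_events, 3) int array of
            # [sample_index, previous_value, event_code] rows, and
            # event_id_all maps annotation descriptions to those codes,
            # e.g. {"target": 1, "standard": 2}.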
            # Find all event types that exactly match the requested event_id keys
            event_patterns = {}  # Name and code of the events to epoch by
            for event_key in event_id.keys():
                # Exact string match only, so one event name cannot match
                # another name that merely contains it as a substring
                matching_events = [
                    k for k in event_id_all.keys() if event_key == str(k)
                ]
                for match in matching_events:
                    event_patterns[match] = event_id_all[match]

            message(
                "info",
                f"Looking for events matching patterns: {list(event_patterns.keys())}",
            )

            # Filter events to include only those with matching trigger codes
            trigger_codes = list(event_patterns.values())
            events_trig = events_all[np.isin(events_all[:, 2], trigger_codes)]

            if len(events_trig) == 0:
                message("warning", "No matching events found")
                return None

            message("info", f"Found {len(events_trig)} events matching the patterns")

            # Create epochs with the filtered events, delegating the core
            # epoch creation to the standalone function
            epochs = _create_eventid_epochs(
                data=data,
                event_id=event_patterns,
                tmin=tmin,
                tmax=tmax,
                baseline=baseline,
                reject=(None if keep_all_epochs else volt_threshold),
                reject_by_annotation=(reject_by_annotation and not keep_all_epochs),
                preload=True,
                on_missing="ignore",  # Don't error if no events are found
            )

            # Filter the remaining events to keep only those that fall
            # within the kept epochs
            sfreq = data.info["sfreq"]
            epoch_samples = epochs.events[:, 0]  # sample indices of epoch triggers

            # Compute the valid sample range for each epoch (in raw sample indices)
            start_offset = int(tmin * sfreq)
            end_offset = int(tmax * sfreq)
            epoch_sample_ranges = [
                (s + start_offset, s + end_offset) for s in epoch_samples
            ]

            # Keep only the events from events_all that fall inside any of
            # those ranges; both lists are sorted by sample, so we can stop
            # early once an event precedes the current range
            events_in_epochs = []
            for sample, prev, code in events_all:
                for start, end in epoch_sample_ranges:
                    if start <= sample <= end:
                        events_in_epochs.append([sample, prev, code])
                        break  # prevent double counting
                    elif sample < start:
                        break
            events_in_epochs = np.array(events_in_epochs, dtype=int)

            event_descriptions = {v: k for k, v in event_id_all.items()}

            # Build one metadata row per epoch, listing any additional events
            # inside that epoch with their latency relative to the epoch trigger
            metadata_rows = []
            for i, (start, end) in enumerate(epoch_sample_ranges):
                epoch_events = []
                for sample, _, code in events_in_epochs:
                    if start <= sample <= end:
                        relative_time = (sample - epoch_samples[i]) / sfreq
                        label = event_descriptions.get(code, f"code_{code}")
                        epoch_events.append((label, relative_time))
                metadata_rows.append({"additional_events": epoch_events})

            # Attach the metadata column
            if epochs.metadata is not None:
                epochs.metadata["additional_events"] = [
                    row["additional_events"] for row in metadata_rows
                ]
            else:
                epochs.metadata = pd.DataFrame(metadata_rows)

            # Create a copy that bad epochs can be dropped from
            epochs_clean = epochs.copy()
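            # When keep_all_epochs is True the rejection criteria were not
            # applied at epoch creation, so bad epochs are detected below and
            # recorded in the metadata instead of being dropped.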
            if keep_all_epochs:
                # Mark epochs that would have been rejected by the voltage
                # threshold without actually dropping them: apply the
                # threshold to a throwaway copy and read its drop log
                if volt_threshold is not None:
                    epochs_thresh = epochs.copy()
                    epochs_thresh.drop_bad(reject=volt_threshold)
                    bad_epochs_thresh = []
                    for idx, log in enumerate(epochs_thresh.drop_log):
                        if len(log) > 0:  # this epoch would have been dropped
                            bad_epochs_thresh.append(idx)
                            # Record which channels exceeded the threshold
                            for ch_name in log:
                                col_name = f"THRESHOLD_{str(ch_name).upper()}"
                                if col_name not in epochs.metadata.columns:
                                    epochs.metadata[col_name] = False
                                epochs.metadata.loc[idx, col_name] = True
                    message(
                        "info",
                        f"Marked {len(bad_epochs_thresh)} epochs exceeding "
                        "voltage thresholds (not dropped)",
                    )

            # If not rejecting by annotation, or keeping all epochs, manually
            # track which epochs overlap bad annotations
            if not reject_by_annotation or keep_all_epochs:
                # Find epochs that overlap with any "bad"/"BAD" annotations
                bad_epochs = []
                bad_annotations = {}  # which annotations affected each epoch

                for ann in data.annotations:
                    # Only consider annotations whose description starts
                    # with "bad" (case-insensitive)
                    if ann["description"].lower().startswith("bad"):
                        ann_start = ann["onset"]
                        ann_end = ann["onset"] + ann["duration"]

                        # Check each epoch for overlap with the annotation
                        for idx, event in enumerate(epochs.events):
                            # Epoch window in seconds, offset by tmin/tmax
                            epoch_start = event[0] / epochs.info["sfreq"] + tmin
                            epoch_end = event[0] / epochs.info["sfreq"] + tmax

                            # Two intervals overlap iff each one starts
                            # before the other ends
                            if (epoch_start <= ann_end) and (
                                epoch_end >= ann_start
                            ):
                                bad_epochs.append(idx)
                                bad_annotations.setdefault(idx, []).append(
                                    ann["description"]
                                )

                # Remove duplicates and sort
                bad_epochs = sorted(set(bad_epochs))

                # Mark bad epochs in the metadata
                epochs.metadata["BAD_ANNOTATION"] = [
                    idx in bad_epochs for idx in range(len(epochs))
                ]

                # Add a column per specific annotation description
                for idx, annotations in bad_annotations.items():
                    for annotation in annotations:
                        col_name = annotation.upper()
                        if col_name not in epochs.metadata.columns:
                            epochs.metadata[col_name] = False
                        epochs.metadata.loc[idx, col_name] = True

                message(
                    "info",
                    f"Marked {len(bad_epochs)} epochs with bad annotations "
                    "(not dropped)",
                )

                # Save epochs with bad epochs marked but not dropped
                self._save_epochs_result(result_data=epochs, stage_name=stage_name)

                # Drop bad epochs only if not keeping all epochs
                if not keep_all_epochs:
                    epochs_clean.drop(bad_epochs, reason="BAD_ANNOTATION")

                    message("debug", "reordering metadata after dropping")
                    # After epochs_clean.drop(), epochs_clean.events holds the
                    # surviving events, while epochs.metadata still holds the
                    # fully augmented metadata for the original set of epochs.
                    # Select the metadata rows that correspond to the surviving
                    # events by matching event sample indices.
                    if epochs_clean.metadata is not None:
                        # Sample indices of surviving vs. original events
                        surviving_event_samples = epochs_clean.events[:, 0]
                        original_event_samples = epochs.events[:, 0]
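                        # np.isin gives a boolean mask over the original
                        # samples; np.where turns it into row indices. For
                        # example, original samples [100, 400, 700] with
                        # survivors [100, 700] yield kept indices [0, 2].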
                        kept_original_indices = np.where(
                            np.isin(original_event_samples, surviving_event_samples)
                        )[0]

                        if len(kept_original_indices) != len(epochs_clean.events):
                            message(
                                "error",
                                f"Mismatch when aligning surviving events to original metadata. "
                                f"Expected {len(epochs_clean.events)} matches, "
                                f"found {len(kept_original_indices)}. "
                                f"Metadata might be incorrect.",
                            )
                            # A mismatch indicates a deeper issue, such as
                            # non-unique event samples. We proceed with
                            # potentially incorrect metadata and let MNE raise
                            # if the lengths still disagree later; a more
                            # robust solution would raise an error here.

                        # Slice the augmented metadata with the derived indices
                        # so it has one row per surviving epoch
                        epochs_clean.metadata = epochs.metadata.iloc[
                            kept_original_indices
                        ].reset_index(drop=True)
                    else:
                        message(
                            "warning",
                            "epochs_clean.metadata was None before assignment, "
                            "which is unexpected.",
                        )

            # If keeping all epochs, use the original epochs for subsequent
            # processing
            if keep_all_epochs:
                epochs_clean = epochs
                message(
                    "info", "Keeping all epochs as requested (keep_all_epochs=True)"
                )

            # Analyze the drop log to tally the different rejection reasons
            drop_log = epochs_clean.drop_log
            total_epochs = len(drop_log)
            good_epochs = sum(1 for log in drop_log if len(log) == 0)

            # Dynamically collect all unique rejection reasons
            annotation_types = {}
            for log in drop_log:
                if len(log) > 0:  # this epoch was dropped
                    for annotation in log:
                        # Convert numpy strings to regular strings if needed
                        annotation = str(annotation)
                        annotation_types[annotation] = (
                            annotation_types.get(annotation, 0) + 1
                        )

            message("info", "\nEpoch Drop Log Summary:")
            message("info", f"Total epochs: {total_epochs}")
            message("info", f"Good epochs: {good_epochs}")
            for annotation, count in annotation_types.items():
                message("info", f"Epochs with {annotation}: {count}")

            # Flag the run if too few epochs survived (only when epochs are
            # actually being dropped)
            if (
                not keep_all_epochs
                and (good_epochs / total_epochs) < self.EPOCH_RETENTION_THRESHOLD
            ):
                flagged_reason = (
                    f"WARNING: Only {good_epochs / total_epochs * 100:.1f}% "
                    "of epochs were kept"
                )
                self._update_flagged_status(flagged=True, reason=flagged_reason)

            # Record the good and total counts alongside the rejection tallies
            annotation_types["KEEP"] = good_epochs
            annotation_types["TOTAL"] = total_epochs

            # Update metadata
            metadata = {
                "duration": tmax - tmin,
                "reject_by_annotation": reject_by_annotation,
                "keep_all_epochs": keep_all_epochs,
                "initial_epoch_count": len(events_trig),
                "final_epoch_count": len(epochs_clean),
                "single_epoch_duration": epochs.times[-1] - epochs.times[0],
                "single_epoch_samples": epochs.times.shape[0],
                "initial_duration": (epochs.times[-1] - epochs.times[0])
                * len(epochs_clean),
                "numberSamples": epochs.times.shape[0] * len(epochs_clean),
                "channelCount": len(epochs.ch_names),
                "annotation_types": annotation_types,
                "marked_epochs_file": "post_epochs",
                "cleaned_epochs_file": (
                    "post_drop_bads" if not keep_all_epochs else "post_epochs"
                ),
                "tmin": tmin,
                "tmax": tmax,
                "event_id": event_id,
            }

            self._update_metadata("step_create_eventid_epochs", metadata)

            # Store the epochs on the task instance
            if hasattr(self, "config") and self.config.get("run_id"):
                self.epochs = epochs_clean

            # Save the cleaned epochs
            if not keep_all_epochs:
                self._save_epochs_result(
                    result_data=epochs_clean, stage_name="post_drop_bad_epochs"
                )

            return epochs_clean

        except Exception as e:
            message("error", f"Error during event ID epoch creation: {str(e)}")
            raise RuntimeError(f"Failed to create event ID epochs: {str(e)}") from e