Source code for autoclean.mixins.signal_processing.regular_epochs

"""Regular epochs creation mixin for autoclean tasks.

This module provides functionality for creating regular fixed-length epochs from
continuous EEG data. Regular epochs are time segments of equal duration that are
created at fixed intervals throughout the recording, regardless of event markers.

The RegularEpochsMixin class implements methods for creating these epochs and
handling annotations, allowing users to either automatically reject epochs that
overlap with bad annotations or just mark them in the metadata for later processing.

Regular epoching is particularly useful for resting-state data analysis, where
there are no specific events of interest, but the data needs to be segmented
into manageable chunks for further processing and analysis.

"""

from typing import Dict, Optional, Union

import mne

from autoclean.functions.epoching import create_regular_epochs as _create_regular_epochs
from autoclean.utils.logging import message


[docs] class RegularEpochsMixin: """Mixin class providing regular (fixed-length) epochs creation functionality for EEG data."""
[docs] def create_regular_epochs( self, data: Union[mne.io.Raw, None] = None, tmin: float = -1, tmax: float = 1, baseline: Optional[tuple] = None, volt_threshold: Optional[Dict[str, float]] = None, stage_name: str = "post_epochs", reject_by_annotation: bool = False, export: bool = False, ) -> mne.Epochs: """Create regular fixed-length epochs from raw data. Parameters ---------- data : mne.io.Raw, Optional The raw data to create epochs from. If None, uses self.raw. tmin : float, Optional The start time of the epoch in seconds. Default is -1. tmax : float, Optional The end time of the epoch in seconds. Default is 1. baseline : tuple of float, Optional The time interval to apply baseline correction. Default is None. volt_threshold : dict, Optional Dictionary of channel types and thresholds for rejection, by default None. stage_name : str, Optional Name for saving and metadata tracking. Default is "post_epochs". reject_by_annotation : bool, Optional Whether to automatically reject epochs that overlap with bad annotations, or just mark them in the metadata for later processing. Default is False. export : bool, Optional If True, exports the processed epochs to the stage directory. Default is False. Returns ------- epochs_clean: mne.Epochs The created epochs object with bad epochs marked (and dropped if reject_by_annotation=True) Notes ----- If reject_by_annotation is False, an intermediate file with bad epochs marked but not dropped is saved. The epoching parameters can be customized through the configuration file (autoclean_config.yaml) under the "epoch_settings" section. If enabled, the configuration values will override the default parameters. See Also -------- create_eventid_epochs : For creating epochs based on specific event markers. """ # Check if this step is enabled in the configuration is_enabled, config_value = self._check_step_enabled("epoch_settings") if not is_enabled: message("info", "Epoch creation step is disabled in configuration") return None # Get parameters from config if available if config_value and isinstance(config_value, dict): # Get epoch settings epoch_value = config_value.get("value", {}) if isinstance(epoch_value, dict): tmin = epoch_value.get("tmin", tmin) tmax = epoch_value.get("tmax", tmax) # Get baseline settings baseline_settings = config_value.get("remove_baseline", {}) if isinstance(baseline_settings, dict) and baseline_settings.get( "enabled", False ): baseline = baseline_settings.get("window", baseline) # Get threshold settings threshold_settings = config_value.get("threshold_rejection", {}) if isinstance(threshold_settings, dict) and threshold_settings.get( "enabled", False ): threshold_config = threshold_settings.get("volt_threshold", {}) if isinstance(threshold_config, (int, float)): volt_threshold = {"eeg": float(threshold_config)} elif isinstance(threshold_config, dict): volt_threshold = {k: float(v) for k, v in threshold_config.items()} # Determine which data to use data = self._get_data_object(data) # Type checking if not isinstance(data, mne.io.Raw) and not isinstance( data, mne.io.base.BaseRaw ): raise TypeError("Data must be an MNE Raw object for epoch creation") try: # Use standalone function for core epoch creation message("header", f"Creating regular epochs from {tmin}s to {tmax}s...") epochs = _create_regular_epochs( data=data, tmin=tmin, tmax=tmax, baseline=baseline, reject=volt_threshold, reject_by_annotation=reject_by_annotation, include_metadata=True, # Always include metadata for pipeline preload=True, ) # Note: metadata is now handled by the standalone function # No additional metadata processing needed here since the standalone function handles it # Create a copy for dropping if using amplitude thresholds epochs_clean = epochs.copy() # If not using reject_by_annotation, manually track bad annotations if not reject_by_annotation: # Find epochs that overlap with any "bad" or "BAD" annotations bad_epochs = [] bad_annotations = {} # To track which annotation affected each epoch for ann in data.annotations: # Check if annotation description starts with "bad" or "BAD" if ann["description"].lower().startswith("bad"): ann_start = ann["onset"] ann_end = ann["onset"] + ann["duration"] # Check each epoch for idx, event in enumerate(epochs.events): epoch_start = ( event[0] / epochs.info["sfreq"] ) # Convert to seconds epoch_end = epoch_start + (tmax - tmin) # Check for overlap if (epoch_start <= ann_end) and (epoch_end >= ann_start): bad_epochs.append(idx) # Track which annotation affected this epoch if idx not in bad_annotations: bad_annotations[idx] = [] bad_annotations[idx].append(ann["description"]) # Remove duplicates and sort bad_epochs = sorted(list(set(bad_epochs))) # Mark bad epochs in metadata epochs.metadata["BAD_ANNOTATION"] = [ idx in bad_epochs for idx in range(len(epochs)) ] # Add specific annotation types to metadata for idx, annotations in bad_annotations.items(): for annotation in annotations: col_name = annotation.upper() if col_name not in epochs.metadata.columns: epochs.metadata[col_name] = False epochs.metadata.loc[idx, col_name] = True message( "info", f"Marked {len(bad_epochs)} epochs with bad annotations (not dropped)", ) # Save epochs with bad epochs marked but not dropped self._save_epochs_result( result_data=epochs_clean, stage_name=stage_name ) epochs_clean.drop(bad_epochs, reason="BAD_ANNOTATION") # Reorder metadata after dropping bad epochs if metadata exists if epochs_clean.metadata is not None: message("debug", "reordering metadata after dropping") kept_indices = epochs_clean.selection max_index = epochs.metadata.shape[0] - 1 if kept_indices.max() > max_index: print("Metadata shape:", epochs.metadata.shape) print("Regular indices:", kept_indices) kept_indices = kept_indices - 1 print("Adjusted indices:", kept_indices) epochs_clean.metadata = epochs.metadata.iloc[ kept_indices ].reset_index(drop=True) # Analyze drop log to tally different annotation types drop_log = epochs_clean.drop_log total_epochs = len(drop_log) good_epochs = sum(1 for log in drop_log if len(log) == 0) # Dynamically collect all unique annotation types annotation_types = {} for log in drop_log: if len(log) > 0: # If epoch was dropped for annotation in log: # Convert numpy string to regular string if needed annotation = str(annotation) annotation_types[annotation] = ( annotation_types.get(annotation, 0) + 1 ) message("info", "\nEpoch Drop Log Summary:") message("info", f"Total epochs: {total_epochs}") message("info", f"Good epochs: {good_epochs}") for annotation, count in annotation_types.items(): message("info", f"Epochs with {annotation}: {count}") # Add flags if needed if (good_epochs / total_epochs) < self.EPOCH_RETENTION_THRESHOLD: flagged_reason = ( f"WARNING: Only {good_epochs / total_epochs * 100}% " "of epochs were kept" ) self._update_flagged_status(flagged=True, reason=flagged_reason) # Add good and total to the annotation_types dictionary annotation_types["KEEP"] = good_epochs annotation_types["TOTAL"] = total_epochs # Update metadata metadata = { "duration": tmax - tmin, "reject_by_annotation": reject_by_annotation, "initial_epoch_count": len(epochs), "final_epoch_count": len(epochs_clean), "single_epoch_duration": epochs.times[-1] - epochs.times[0], "single_epoch_samples": epochs.times.shape[0], "initial_duration": (epochs.times[-1] - epochs.times[0]) * len(epochs_clean), "numberSamples": epochs.times.shape[0] * len(epochs_clean), "channelCount": len(epochs.ch_names), "annotation_types": annotation_types, "marked_epochs_file": "post_epochs", "cleaned_epochs_file": "post_drop_bads", "tmin": tmin, "tmax": tmax, } self._update_metadata("step_create_regular_epochs", metadata) # Store epochs if hasattr(self, "config") and self.config.get("run_id"): self.epochs = epochs_clean # Save epochs with default naming self._save_epochs_result( result_data=epochs_clean, stage_name="post_drop_bad_epochs" ) # Export if requested self._auto_export_if_enabled(epochs_clean, stage_name, export) return epochs_clean except Exception as e: message("error", f"Error during regular epoch creation: {str(e)}") raise RuntimeError(f"Failed to create regular epochs: {str(e)}") from e