# Source code for autoclean.mixins.signal_processing.channels

"""Channel operations mixin for autoclean tasks."""

from typing import Dict, List, Union

import mne

from autoclean.utils.logging import message


def _as_channel_list(value) -> List:
    """Return *value* as a list of channel specifiers.

    ``None`` maps to an empty list and a bare string maps to a one-element
    list, so a single channel name (e.g. a YAML scalar such as ``"Cz"``) is
    handled the same way as a list of names. Any other iterable is
    materialised with ``list``.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def _task_step_settings(config, step: str) -> dict:
    """Return ``config["tasks"][config["task"]]["settings"][step]`` or ``{}``.

    Every level is looked up with ``.get`` and a missing/empty/``None`` level
    is treated as absent, so an incomplete configuration never raises here.
    """
    task = config.get("task")
    tasks = config.get("tasks") or {}
    settings = (tasks.get(task) or {}).get("settings") or {}
    return settings.get(step) or {}


class ChannelsMixin:
    """Mixin class providing channel operations functionality for EEG data.

    The host class is expected to provide ``_get_data_object``,
    ``_update_metadata``, ``_save_raw_result``, ``_update_instance_data``,
    ``_check_step_enabled``, ``raw``, ``BAD_CHANNEL_THRESHOLD``, ``flagged``
    and ``flagged_reasons``; ``config`` is optional.
    """

    def clean_bad_channels(
        self,
        data: Union[mne.io.Raw, None] = None,
        correlation_thresh: float = 0.35,
        deviation_thresh: float = 2.5,
        ransac_sample_prop: float = 0.35,
        ransac_corr_thresh: float = 0.65,
        ransac_frac_bad: float = 0.25,
        ransac_channel_wise: bool = False,
        random_state: int = 1337,
        cleaning_method: Union[str, None] = "interpolate",
        reset_bads: bool = True,
        stage_name: str = "post_bad_channels",
    ) -> mne.io.Raw:
        """Detect and mark bad channels using various methods.

        Bad channels are detected with
        ``autoclean.functions.artifacts.channels.detect_bad_channels``
        (correlation, deviation and RANSAC criteria). Reference channels
        configured in the task's ``reference_step`` are never marked bad.

        Parameters
        ----------
        data : mne.io.Raw, Optional
            The data object to detect bad channels from. If None, uses self.raw.
            Any ``mne.io.BaseRaw`` subclass is accepted. The object itself is
            not modified; all work happens on a copy.
        correlation_thresh : float, Optional
            Threshold for correlation-based detection.
        deviation_thresh : float, Optional
            Threshold for deviation-based detection.
        ransac_sample_prop : float, Optional
            Proportion of samples to use for RANSAC.
        ransac_corr_thresh : float, Optional
            Threshold for RANSAC-based detection.
        ransac_frac_bad : float, Optional
            Fraction of bad channels to use for RANSAC.
        ransac_channel_wise : bool, Optional
            Whether to use channel-wise RANSAC.
        random_state : int, Optional
            Random state for reproducibility.
        cleaning_method : str or None, Optional
            What to do with the detected channels: ``'interpolate'`` (default)
            interpolates them, ``'drop'`` removes them, and ``None`` only marks
            them in ``info['bads']``.
        reset_bads : bool, Optional
            Passed to ``interpolate_bads``; whether to clear ``info['bads']``
            after interpolation.
        stage_name : str, Optional
            Name for saving and metadata.

        Returns
        -------
        result_raw : instance of mne.io.Raw
            The raw data object with bad channels marked or cleaned

        Raises
        ------
        TypeError
            If the data is not an MNE Raw object.
        RuntimeError
            If anything fails during detection, cleaning or bookkeeping.

        See Also
        --------
        :py:class:`pyprep.find_noisy_channels.NoisyChannels` : For more information on the NoisyChannels class
        """
        # Determine which data to use
        data = self._get_data_object(data)

        # mne.io.Raw is a subclass of BaseRaw, so this one check covers every
        # Raw flavour (FIF, EDF, EEGLAB, RawArray, ...).
        if not isinstance(data, mne.io.base.BaseRaw):
            raise TypeError("Data must be an MNE Raw object for bad channel detection")

        try:
            # Work on a copy so the caller's object is never modified in place.
            result_raw = data.copy()
            config = getattr(self, "config", None) or {}

            # If the task's EOG step is disabled, EOG channels are retyped as EEG
            # on the working copy so that they take part in detection.
            if (
                config.get("task")
                and "eog" in result_raw.get_channel_types()
                and not _task_step_settings(config, "eog_step").get("enabled", True)
            ):
                eog_picks = mne.pick_types(result_raw.info, eog=True)
                result_raw.set_channel_types(
                    {result_raw.ch_names[idx]: "eeg" for idx in eog_picks}
                )

            # Keys intentionally match detect_bad_channels' keyword names; the
            # dict is also stored verbatim in the metadata below.
            options = {
                "random_state": random_state,
                "correlation_thresh": correlation_thresh,
                "deviation_thresh": deviation_thresh,
                "ransac_sample_prop": ransac_sample_prop,
                "ransac_corr_thresh": ransac_corr_thresh,
                "ransac_frac_bad": ransac_frac_bad,
                "ransac_channel_wise": ransac_channel_wise,
            }

            # Call standalone function for bad channel detection
            from autoclean.functions.artifacts.channels import detect_bad_channels

            bad_channels = detect_bad_channels(
                data=result_raw,
                return_by_method=True,
                verbose=False,
                **options,
            )

            # Extract individual method results for compatibility
            uncorrelated_channels = bad_channels["correlation"]
            deviation_channels = bad_channels["deviation"]
            ransac_channels = bad_channels["ransac"]
            all_bad_channels = bad_channels.get("combined", [])

            # Reference channels must never be marked bad.
            ref_channels = []
            ref_step = _task_step_settings(config, "reference_step")
            if ref_step.get("enabled") and ref_step.get("value"):
                # A scalar value (e.g. "Cz" or "average") is wrapped in a list so
                # the membership test below compares whole channel names instead
                # of substrings of that one string.
                ref_channels = [str(ch) for ch in _as_channel_list(ref_step["value"])]
                message(
                    "info",
                    f"Excluding reference channel(s) from bad channels: {ref_channels}",
                )
            excluded = set(ref_channels)
            newly_bad = [
                str(ch) for ch in all_bad_channels if str(ch) not in excluded
            ]

            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # saved bads list is deterministic (set() ordering is not).
            bads = list(dict.fromkeys([*result_raw.info["bads"], *newly_bad]))
            result_raw.info["bads"] = list(bads)

            # Channel count before any dropping: the denominator for the
            # bad-channel fraction below.
            n_channels = result_raw.info["nchan"]

            if cleaning_method == "interpolate":
                result_raw.interpolate_bads(reset_bads=reset_bads)
            elif cleaning_method == "drop":
                result_raw.drop_channels(result_raw.info["bads"])
                result_raw.info["bads"] = []

            # Accumulate bad channels across calls on the instance's raw object.
            # NOTE(review): if _update_instance_data replaces self.raw with
            # result_raw, this attribute may not carry over — confirm.
            previous_bads = getattr(self.raw, "bad_channels", None) or []
            self.raw.bad_channels = list(dict.fromkeys([*previous_bads, *bads]))

            bad_fraction = (
                len(self.raw.bad_channels) / n_channels if n_channels else 0.0
            )
            if bad_fraction > self.BAD_CHANNEL_THRESHOLD:
                self.flagged = True
                warning = f"WARNING: {bad_fraction:.2%} bad channels detected"
                self.flagged_reasons.append(warning)
                message("warning", f"Flagging: {warning}")

            message("info", f"Detected {len(bads)} bad channels: {bads}")

            # Update metadata
            metadata = {
                "method": "NoisyChannels",
                "options": options,
                "channelCount": len(result_raw.ch_names),
                "durationSec": int(result_raw.n_times) / result_raw.info["sfreq"],
                "numberSamples": int(result_raw.n_times),
                "bads": bads,
                "uncorrelated_channels": uncorrelated_channels,
                "deviation_channels": deviation_channels,
                "ransac_channels": ransac_channels,
            }
            self._update_metadata("step_clean_bad_channels", metadata)

            # Save the result
            self._save_raw_result(result_raw, stage_name)

            # Update self.raw if we're using it
            self._update_instance_data(data, result_raw)

            return result_raw

        except Exception as e:
            message("error", f"Error during bad channel detection: {str(e)}")
            raise RuntimeError(f"Failed to detect bad channels: {str(e)}") from e

    def drop_channels(
        self,
        data: Union[mne.io.Raw, mne.Epochs, None] = None,
        channels: List[str] = None,
        stage_name: str = "drop_channels",
        use_epochs: bool = False,
    ) -> Union[mne.io.Raw, mne.Epochs]:
        """Drop specified channels from the data.

        Parameters
        ----------
        data : mne.io.Raw or mne.Epochs, Optional
            The data object to drop channels from. If None, uses self.raw or
            self.epochs. Any ``BaseRaw`` / ``BaseEpochs`` subclass is accepted.
        channels : List[str], Optional
            List of channel names to drop. A single name given as a string is
            treated as a one-element list. If None, the value of the
            ``drop_outerlayer`` configuration step is used.
        stage_name : str, Optional
            Name for saving and metadata.
        use_epochs : bool, Optional
            If True and data is None, uses self.epochs instead of self.raw.

        Returns
        -------
        result_data : instance of mne.io.Raw or mne.Epochs
            The data object with channels dropped. If the configuration step
            is disabled or lists no channels, *data* is returned unchanged
            (exactly as passed in, possibly None).

        Raises
        ------
        TypeError
            If the data is not an MNE Raw or Epochs object.
        RuntimeError
            If dropping or bookkeeping fails.

        See Also
        --------
        :py:meth:`mne.io.Raw.drop_channels` : For MNE's raw data channel dropping functionality
        :py:meth:`mne.Epochs.drop_channels` : For MNE's epochs channel dropping functionality
        """
        # Fall back to the configuration when no channels are given.
        if channels is None:
            is_enabled, config_value = self._check_step_enabled("drop_outerlayer")
            if not is_enabled:
                message("info", "Channel dropping is disabled in configuration")
                return data

            channels = config_value
            if not channels:
                message("warning", "No channels specified for dropping in config")
                return data

        # A bare string would otherwise make len() count characters.
        channels = _as_channel_list(channels)

        # Determine which data to use
        data = self._get_data_object(data, use_epochs)

        # mne.io.Raw / mne.Epochs are concrete subclasses; checking the bases
        # also accepts non-FIF Raw (EDF, EEGLAB, RawArray) and read_epochs output.
        if not isinstance(data, (mne.io.base.BaseRaw, mne.BaseEpochs)):
            raise TypeError(
                "Data must be an MNE Raw or Epochs object for dropping channels"
            )

        try:
            message("header", "Dropping channels...")
            result_data = data.copy().drop_channels(channels)
            message("info", f"Dropped {len(channels)} channels: {channels}")

            metadata = {
                "channels_dropped": channels,
                "channels_remaining": len(result_data.ch_names),
            }
            self._update_metadata("step_drop_channels", metadata)

            # Only continuous data is saved at this stage.
            if isinstance(result_data, mne.io.base.BaseRaw):
                self._save_raw_result(result_data, stage_name)

            # Update self.raw or self.epochs
            self._update_instance_data(data, result_data, use_epochs)

            return result_data

        except Exception as e:
            message("error", f"Error during channel dropping: {str(e)}")
            raise RuntimeError(f"Failed to drop channels: {str(e)}") from e

    def set_channel_types(
        self,
        data: Union[mne.io.Raw, mne.Epochs, None] = None,
        ch_types_dict: Dict[str, str] = None,
        stage_name: str = "set_channel_types",
        use_epochs: bool = False,
    ) -> Union[mne.io.Raw, mne.Epochs]:
        """Set channel types for specific channels.

        This method sets the type of specific channels (e.g., marking
        channels as EOG).

        Parameters
        ----------
        data : mne.io.Raw or mne.Epochs, Optional
            The data object to set channel types for. If None, uses self.raw
            or self.epochs. Any ``BaseRaw`` / ``BaseEpochs`` subclass is
            accepted.
        ch_types_dict : dict, Optional
            Dictionary mapping channel names to types (e.g., {'E1': 'eog'}).
            If None or empty, the value of the ``eog_step`` configuration
            step is used.
        stage_name : str, Optional
            Name for saving and metadata.
        use_epochs : bool, Optional
            If True and data is None, uses self.epochs instead of self.raw.

        Returns
        -------
        result_data : instance of mne.io.Raw or mne.Epochs
            The data object with updated channel types. If the configuration
            step is disabled or empty, *data* is returned unchanged (exactly
            as passed in, possibly None).

        Raises
        ------
        TypeError
            If the data is not an MNE Raw or Epochs object.
        RuntimeError
            If setting the types or bookkeeping fails.
        """
        if not ch_types_dict:
            is_enabled, config_value = self._check_step_enabled("eog_step")
            if not is_enabled:
                message("info", "Channel type setting is disabled in configuration")
                return data

            # NOTE(review): assumes the eog_step value is a mapping of channel
            # name -> type; if the config stores a list of channel indices,
            # set_channel_types below will fail — confirm against the schema.
            ch_types_dict = config_value
            if not ch_types_dict:
                message("warning", "No channel types specified in config")
                return data

        # Determine which data to use
        data = self._get_data_object(data, use_epochs)

        # Base-class check: accepts every Raw flavour and all epochs objects.
        if not isinstance(data, (mne.io.base.BaseRaw, mne.BaseEpochs)):
            raise TypeError(
                "Data must be an MNE Raw or Epochs object for setting channel types"
            )

        try:
            message("header", "Setting channel types...")
            result_data = data.copy().set_channel_types(ch_types_dict)
            message("info", f"Set types for {len(ch_types_dict)} channels")

            # Metadata key kept as-is (no "step_" prefix) for compatibility
            # with existing consumers.
            metadata = {"channel_types": ch_types_dict}
            self._update_metadata("set_channel_types", metadata)

            # Only continuous data is saved at this stage.
            if isinstance(result_data, mne.io.base.BaseRaw):
                self._save_raw_result(result_data, stage_name)

            # Update self.raw or self.epochs
            self._update_instance_data(data, result_data, use_epochs)

            return result_data

        except Exception as e:
            message("error", f"Error during setting channel types: {str(e)}")
            raise RuntimeError(f"Failed to set channel types: {str(e)}") from e