"""GFP-based epoch cleaning mixin for autoclean tasks.
This module provides functionality for cleaning epochs based on Global Field Power (GFP),
a measure of the spatial standard deviation across all electrodes at each time point.
GFP is a useful metric for identifying epochs with abnormal activity patterns that
may represent artifacts.
The GFPCleanEpochsMixin class implements methods for calculating GFP values for each
epoch, identifying outliers based on statistical thresholds, and optionally creating
visualization plots to help understand the distribution of GFP values across epochs.
GFP-based cleaning is particularly useful for removing epochs with widespread artifacts
that affect multiple channels simultaneously, such as movement artifacts, muscle activity,
or other global disturbances in the EEG signal.
"""
import random
from pathlib import Path
from typing import Optional, Union
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
from autoclean.utils.logging import message
[docs]
class GFPCleanEpochsMixin:
"""Mixin class providing functionality to clean epochs based on Global Field Power (GFP)."""
[docs]
def gfp_clean_epochs(
self,
epochs: Union[mne.Epochs, None] = None,
gfp_threshold: float = 3.0,
number_of_epochs: Optional[int] = None,
random_seed: Optional[int] = None,
stage_name: str = "post_gfp_clean",
export: bool = False,
) -> mne.Epochs:
"""Clean an MNE Epochs object by removing outlier epochs based on Global Field Power.
Parameters
----------
epochs: mne.Epochs, Optional
The epochs object to clean. If None, uses self.epochs.
gfp_threshold: float
The z-score threshold for GFP-based outlier detection (default: 3.0).
number_of_epochs: int
If specified, randomly selects this number of epochs from the cleaned data.
random_seed: int
Seed for random number generator when selecting epochs.
stage_name: str
Name for saving and metadata tracking. By default "post_gfp_clean".
export : bool, optional
If True, exports the cleaned epochs to the stage directory. Default is False.
Returns
-------
epochs_final : instance of mne.Epochs
The cleaned epochs object with outlier epochs removed
See Also
--------
mne.Epochs
mne.Epochs.get_data
mne.Epochs.copy
Example:
```python
# Clean epochs with default parameters
clean_epochs = task.gfp_clean_epochs()
# Clean epochs with custom parameters and select a specific number of epochs
clean_epochs = task.gfp_clean_epochs(
gfp_threshold=2.5,
number_of_epochs=40,
random_seed=42
)
```
Notes
-----
This method calculates the Global Field Power (GFP) for each epoch, identifies
epochs with abnormal GFP values (outliers), and removes them from the dataset.
GFP is calculated as the standard deviation across all scalp electrodes at each
time point, providing a measure of the spatial variability of the electric field.
The method focuses on scalp electrodes for GFP calculation, excluding non-EEG
channels like EOG or reference electrodes. It can also optionally select a random
subset of the cleaned epochs, which is useful for ensuring a consistent number
of epochs across subjects or conditions.
The method generates visualization plots showing the GFP distribution and outliers,
which are saved to the output directory if a run_id is specified in the configuration.
"""
# Determine which data to use
epochs = self._get_data_object(epochs, use_epochs=True)
# Type checking
if not isinstance(
epochs, mne.Epochs
): # pylint: disable=isinstance-second-argument-not-valid-type
raise TypeError("Data must be an MNE Epochs object for GFP cleaning")
try:
message("header", "Cleaning epochs based on Global Field Power (GFP)")
# Force preload to avoid RuntimeError
if not epochs.preload:
epochs.load_data()
# Create a copy to work with
epochs_clean = epochs.copy()
# Define non-scalp electrodes to exclude
channel_region_map = {
"E17": "OTHER",
"E38": "OTHER",
"E43": "OTHER",
"E44": "OTHER",
"E48": "OTHER",
"E49": "OTHER",
"E56": "OTHER",
"E73": "OTHER",
"E81": "OTHER",
"E88": "OTHER",
"E94": "OTHER",
"E107": "OTHER",
"E113": "OTHER",
"E114": "OTHER",
"E119": "OTHER",
"E120": "OTHER",
"E121": "OTHER",
"E125": "OTHER",
"E126": "OTHER",
"E127": "OTHER",
"E128": "OTHER",
}
# Get scalp electrode indices (all channels except those in channel_region_map)
non_scalp_channels = list(channel_region_map.keys())
all_channels = epochs_clean.ch_names
scalp_channels = [ch for ch in all_channels if ch not in non_scalp_channels]
scalp_indices = [
epochs_clean.ch_names.index(ch)
for ch in scalp_channels
if ch in epochs_clean.ch_names
]
# Calculate Global Field Power (GFP) only for scalp electrodes
message(
"info",
"Calculating Global Field Power (GFP) for each epoch using only scalp electrodes",
)
gfp = np.sqrt(
np.mean(epochs_clean.get_data()[:, scalp_indices, :] ** 2, axis=(1, 2))
) # Shape: (n_epochs,)
# Epoch Statistics
epoch_stats = pd.DataFrame(
{
"epoch": np.arange(len(gfp)),
"gfp": gfp,
"mean_amplitude": epochs_clean.get_data()[:, scalp_indices, :].mean(
axis=(1, 2)
),
"max_amplitude": epochs_clean.get_data()[:, scalp_indices, :].max(
axis=(1, 2)
),
"min_amplitude": epochs_clean.get_data()[:, scalp_indices, :].min(
axis=(1, 2)
),
"std_amplitude": epochs_clean.get_data()[:, scalp_indices, :].std(
axis=(1, 2)
),
}
)
# Remove Outlier Epochs based on GFP
message("info", "Removing outlier epochs based on GFP z-scores")
gfp_mean = epoch_stats["gfp"].mean()
gfp_std = epoch_stats["gfp"].std()
z_scores = np.abs((epoch_stats["gfp"] - gfp_mean) / gfp_std)
good_epochs_mask = z_scores < gfp_threshold
removed_by_gfp = np.sum(~good_epochs_mask)
epochs_final = epochs_clean[good_epochs_mask]
epoch_stats_final = epoch_stats[good_epochs_mask]
message("info", f"Outlier epochs removed based on GFP: {removed_by_gfp}")
# Handle epoch selection with warning if needed
requested_epochs_exceeded = False
if number_of_epochs is not None:
if len(epochs_final) < number_of_epochs:
warning_msg = (
f"Requested number_of_epochs={number_of_epochs} "
f"exceeds the available cleaned epochs={len(epochs_final)}. "
"Using all available epochs."
)
message("warning", warning_msg)
requested_epochs_exceeded = True
number_of_epochs = len(epochs_final)
if random_seed is not None:
random.seed(random_seed)
selected_indices = random.sample(
range(len(epochs_final)), number_of_epochs
)
epochs_final = epochs_final[selected_indices]
epoch_stats_final = epoch_stats_final.iloc[selected_indices]
message(
"info", f"Selected {number_of_epochs} epochs from the cleaned data"
)
# Analyze drop log to tally different annotation types
drop_log = epochs.drop_log
total_epochs = len(drop_log)
good_epochs = sum(1 for log in drop_log if len(log) == 0)
# Dynamically collect all unique annotation types
annotation_types = {}
for log in drop_log:
if len(log) > 0: # If epoch was dropped
for annotation in log:
# Convert numpy string to regular string if needed
annotation = str(annotation)
annotation_types[annotation] = (
annotation_types.get(annotation, 0) + 1
)
# Add good and total to the annotation_types dictionary
annotation_types["KEEP"] = good_epochs
annotation_types["TOTAL"] = total_epochs
# Create GFP barplot if we have a pipeline with a derivative path
if hasattr(self, "pipeline") and hasattr(self, "config"):
self._create_gfp_plots(epoch_stats, epoch_stats_final)
# Update metadata
metadata = {
"initial_epochs": len(epochs),
"final_epochs": len(epochs_final),
"removed_by_gfp": int(removed_by_gfp),
"mean_amplitude": float(epoch_stats_final["mean_amplitude"].mean()),
"max_amplitude": float(epoch_stats_final["max_amplitude"].max()),
"min_amplitude": float(epoch_stats_final["min_amplitude"].min()),
"std_amplitude": float(epoch_stats_final["std_amplitude"].mean()),
"mean_gfp": float(epoch_stats_final["gfp"].mean()),
"gfp_threshold": float(gfp_threshold),
"removed_total": int(removed_by_gfp),
"annotation_types": annotation_types,
"epoch_duration": epochs.times[-1] - epochs.times[0],
"samples_per_epoch": epochs.times.shape[0],
"total_duration_sec": (epochs.times[-1] - epochs.times[0])
* len(epochs_final),
"total_samples": epochs.times.shape[0] * len(epochs_final),
"channel_count": len(epochs.ch_names),
"scalp_channels_used": scalp_channels,
"requested_epochs_exceeded": requested_epochs_exceeded,
}
self._update_metadata("step_gfp_clean_epochs", metadata)
self._update_instance_data(epochs, epochs_final, use_epochs=True)
# Save epochs with default naming
self._save_epochs_result(epochs_final, stage_name)
# Export if requested
self._auto_export_if_enabled(epochs_final, stage_name, export)
message("info", "Epoch GFP cleaning process completed")
return epochs_final
except Exception as e:
message("error", f"Error during GFP epoch cleaning: {str(e)}")
raise RuntimeError(f"Failed to clean epochs using GFP: {str(e)}") from e
def _create_gfp_plots(self, epoch_stats, epoch_stats_final):
"""Create GFP plots and save them to the derivatives path.
Args:
epochs: Original epochs object
epoch_stats: DataFrame with statistics for all epochs
epoch_stats_final: DataFrame with statistics for kept epochs
"""
try:
# Get derivative path
derivatives_path = self.config["derivatives_dir"]
# Create GFP barplot
plt.figure(figsize=(12, 4))
# Plot all epochs in red first (marking removed epochs)
plt.bar(
epoch_stats.index, epoch_stats["gfp"], width=0.8, color="red", alpha=0.3
)
# Then overlay kept epochs in blue
plt.bar(
epoch_stats_final.index,
epoch_stats_final["gfp"],
width=0.8,
color="blue",
)
plt.xlabel("Epoch Number")
plt.ylabel("Global Field Power (GFP)")
plt.title("GFP Values by Epoch (Red = Removed, Blue = Kept)")
# Save plot
plot_fname = derivatives_path.copy().update(
suffix="gfp", extension="png", check=False
)
plt.savefig(Path(plot_fname), dpi=150, bbox_inches="tight")
plt.close()
# Create GFP heatmap with larger figure size and improved readability
plt.figure(figsize=(30, 18))
# Calculate number of rows and columns for grid layout
n_epochs = len(epoch_stats)
n_cols = 8 # Reduced number of columns for larger cells
n_rows = int(np.ceil(n_epochs / n_cols))
# Create a grid of values
grid = np.full((n_rows, n_cols), np.nan)
for i, (idx, gfp) in enumerate(epoch_stats["gfp"].items()):
row = i // n_cols
col = i % n_cols
grid[row, col] = gfp
# Create heatmap with larger spacing between cells
im = plt.imshow(grid, cmap="RdYlBu_r", aspect="auto")
plt.colorbar(im, label="GFP Value (×10⁻⁶)", fraction=0.02, pad=0.04)
# Add text annotations with increased font size and spacing
for i, (idx, gfp) in enumerate(epoch_stats["gfp"].items()):
row = i // n_cols
col = i % n_cols
kept = idx in epoch_stats_final.index
color = "black" if kept else "red"
plt.text(
col,
row,
f"ID: {idx}\nGFP: {gfp:.1e}",
ha="center",
va="center",
color=color,
fontsize=10,
bbox=dict(facecolor="white", alpha=0.8, pad=0.8),
)
# Improve title and labels with larger font sizes
plt.title(
"GFP Heatmap by Epoch (Red = Removed, Black = Kept)",
fontsize=14,
pad=20,
)
plt.xlabel("Column", fontsize=12, labelpad=10)
plt.ylabel("Row", fontsize=12, labelpad=10)
# Adjust layout to prevent text overlap
plt.tight_layout()
# Save heatmap plot with higher DPI for better quality
plot_fname = derivatives_path.copy().update(
suffix="gfp-heatmap", extension="png", check=False
)
plt.savefig(Path(plot_fname), dpi=300, bbox_inches="tight")
plt.close()
except Exception as e: # pylint: disable=broad-exception-caught
message("warning", f"Could not create GFP plots: {str(e)}")
# Continue without plots