# ./src/autoclean/mixins/viz/visualization.py
# pylint: disable=too-many-lines
# pylint: disable=line-too-long
# pylint: disable=invalid-name
"""Visualization mixin for EEG data in autoclean tasks.
This module provides specialized visualization functionality for EEG data in the AutoClean
pipeline. It defines methods for generating plots that visualize different aspects of
EEG processing results, such as:
- Raw data overlays comparing original and cleaned data
- Bad channel visualizations with topographies
- PSD and topographical maps
- MMN ERP analyses
These visualizations help users understand the effects of preprocessing steps and
validate the quality of processed data.
"""
import os
from datetime import datetime
from pathlib import Path
from typing import Any, List, Optional, Tuple
import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from autoclean.utils.logging import message
# Force matplotlib to use non-interactive backend for async operations
matplotlib.use("Agg")
class VisualizationMixin:
    """Mixin providing visualization methods for EEG data.

    This mixin supplies specialized methods for generating plots and
    visualizations of EEG data — a set of tools for assessing data quality
    and understanding the effects of preprocessing steps.  It is intended to
    be combined with the pipeline's other mixins; it relies on the host class
    providing ``self.config``, ``self._update_metadata`` and
    ``self._check_step_enabled``.
    (NOTE(review): the original text said it "extends the base
    ReportingMixin", but no base class is visible here — confirm.)

    Visualization methods that have a toggle in ``autoclean_config.yaml``
    check it via ``_check_step_enabled`` before executing, so each can be
    individually enabled or disabled.

    Available visualization methods include:

    - ``plot_raw_vs_cleaned_overlay``: Overlay of original and cleaned EEG data
    - ``plot_bad_channels_with_topography``: Bad channel visualization with topographies
    - ``psd_topo_figure``: Combined PSD and topographical maps for frequency bands
    - ``step_psd_topo_figure``: Wrapper for backwards compatibility
    """
def plot_raw_vs_cleaned_overlay(
    self,
    raw_original: mne.io.Raw,
    raw_cleaned: mne.io.Raw,
) -> None:
    """Overlay original and cleaned raw EEG traces over the full recording.

    Each channel is DC-removed, normalized by the standard deviation of the
    *original* signal (so both traces share one vertical scale), decimated
    to roughly 100 Hz, and stacked vertically.  Original data is drawn in
    red, cleaned data in black.  The figure is saved to the derivatives
    directory and its filename recorded via ``_update_metadata``.

    Parameters
    ----------
    raw_original : mne.io.Raw
        Original raw EEG data before cleaning.
    raw_cleaned : mne.io.Raw
        Cleaned raw EEG data after preprocessing.

    Returns
    -------
    None
    """
    # Both recordings must describe the same channels and time base.
    if raw_original.ch_names != raw_cleaned.ch_names:
        raise ValueError(
            "Channel names in raw_original and raw_cleaned do not match."
        )
    if raw_original.times.shape != raw_cleaned.times.shape:
        raise ValueError(
            "Time vectors in raw_original and raw_cleaned do not match."
        )

    labels = raw_original.ch_names
    n_ch = len(labels)
    sfreq = raw_original.info["sfreq"]
    time_axis = raw_original.times
    orig = raw_original.get_data()
    clean = raw_cleaned.get_data()

    # Decimate to ~100 Hz so the saved PNG stays reasonably small.
    target_sfreq = 100
    step = int(sfreq // target_sfreq)
    if step > 1:
        orig = orig[:, ::step]
        clean = clean[:, ::step]
        time_axis = time_axis[::step]

    spacing = 10  # vertical distance between stacked channels
    scaling_factor = 2  # amplitude multiplier for visibility
    orig_scaled = np.zeros_like(orig)
    clean_scaled = np.zeros_like(clean)
    for ch, (row_orig, row_clean) in enumerate(zip(orig, clean)):
        centered_orig = row_orig - np.mean(row_orig)
        # Normalize by the original signal's std so both traces share a
        # scale; guard against flat channels (std == 0).
        std = np.std(centered_orig)
        if std == 0:
            std = 1
        centered_clean = row_clean - np.mean(row_clean)
        orig_scaled[ch] = centered_orig / std * scaling_factor
        clean_scaled[ch] = centered_clean / std * scaling_factor

    offsets = np.arange(n_ch) * spacing

    # Figure size scales with duration and channel count (width capped).
    duration = time_axis[-1] - time_axis[0]
    fig_width = min(duration * 0.1, 50)
    fig_height = max(6, n_ch * 0.25)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    for ch in range(n_ch):
        ax.plot(
            time_axis,
            orig_scaled[ch] + offsets[ch],
            color="red",
            linewidth=0.5,
            linestyle="-",
        )
        ax.plot(
            time_axis,
            clean_scaled[ch] + offsets[ch],
            color="black",
            linewidth=0.5,
            linestyle="-",
        )

    ax.set_yticks(offsets)
    ax.set_yticklabels(labels, fontsize=8)
    ax.set_xlabel("Time (seconds)", fontsize=12)
    ax.set_title(
        "Raw Data Channels: Original vs Cleaned (Full Duration)", fontsize=14
    )
    ax.set_xlim(time_axis[0], time_axis[-1])
    ax.set_ylim(-spacing, offsets[-1] + spacing)
    ax.set_ylabel("")
    ax.invert_yaxis()

    handles = [
        Line2D([0], [0], color="red", lw=0.5, linestyle="-", label="Original Data"),
        Line2D(
            [0], [0], color="black", lw=0.5, linestyle="-", label="Cleaned Data"
        ),
    ]
    ax.legend(handles=handles, loc="upper right", fontsize=8)
    plt.tight_layout()

    # Derive the output path from the BIDS basename.
    derivatives_dir = self.config["derivatives_dir"]
    stem = self.config["bids_path"].basename.replace(
        "_eeg", "_raw_vs_cleaned_overlay"
    )
    target_figure = derivatives_dir / f"{stem}.png"
    fig.savefig(target_figure, dpi=150, bbox_inches="tight")
    plt.close(fig)
    message(
        "info", f"Raw channels overlay full duration plot saved to {target_figure}"
    )
    self._update_metadata(
        "plot_raw_vs_cleaned_overlay",
        {
            "artifact_reports": {
                "creationDateTime": datetime.now().isoformat(),
                "plot_raw_vs_cleaned_overlay": Path(target_figure).name,
            }
        },
    )
def plot_bad_channels_with_topography(
    self,
    raw_original: mne.io.Raw,
    raw_cleaned: mne.io.Raw,
    pipeline: Any,
    zoom_duration: float = 30,
    zoom_start: float = 0,
) -> None:
    """Plot bad channels with a topographical map and time series overlays.

    Two reports are generated in sequence:

    1. A figure driven by ``pipeline.flags["ch"]``: a topomap of the flagged
       channels, full-duration and zoomed overlays of original vs
       interpolated traces, and a text panel listing the flag reasons.
       Saved under ``pipeline.run_path / "derivatives" / "figures"``.
    2. A figure driven by ``raw_cleaned.info["bads"]`` with per-channel
       full-duration and zoomed time series, saved as a BIDS derivative.

    Parameters
    ----------
    raw_original : mne.io.Raw
        Original raw EEG data before cleaning.
    raw_cleaned : mne.io.Raw
        Cleaned raw EEG data after interpolation of bad channels.
    pipeline : Any
        Pipeline object exposing ``flags``, ``run_path``, ``add_metadata``
        and ``get_derivative_path``.
    zoom_duration : float, optional
        Duration in seconds for the zoomed-in time series plot.
    zoom_start : float, optional
        Start time in seconds for the zoomed-in window.

    Returns
    -------
    None

    Notes
    -----
    - If no bad channels are found, the method returns without creating a plot.
    - The resulting plots are saved to the derivatives directories.
    """
    # BUG FIX: _check_step_enabled returns an (is_enabled, config) tuple
    # (see the unpacking call site later in this method), so testing the
    # tuple itself was always truthy and the disabled-check never fired.
    # Unpack it instead.
    is_enabled, _ = self._check_step_enabled("bad_channel_report_step")
    if not is_enabled:
        message("info", "✗ Bad channel visualization step disabled in config")
        return
    message("info", "✓ Generating bad channel visualization")

    # Map each flagged channel to the list of reasons it was flagged.
    bad_channels_info = {}
    for reason, channels in pipeline.flags.get("ch", {}).items():
        for ch in channels:
            if ch in bad_channels_info:
                if reason not in bad_channels_info[ch]:
                    bad_channels_info[ch].append(reason)
            else:
                bad_channels_info[ch] = [reason]
    bad_channels = list(bad_channels_info.keys())
    if not bad_channels:
        message("info", "No bad channels were identified.")
        return
    message("info", f"Identified Bad Channels: {bad_channels}")

    # Good channels are reported only as debugging context.
    all_channels = raw_original.ch_names
    good_channels = [ch for ch in all_channels if ch not in bad_channels]
    message("info", f"Number of Good Channels: {len(good_channels)}")

    # Extract the bad-channel data from both recordings.
    picks_bad_original = mne.pick_channels(raw_original.ch_names, bad_channels)
    picks_bad_cleaned = mne.pick_channels(raw_cleaned.ch_names, bad_channels)
    if len(picks_bad_original) == 0:
        message("info", "No bad channels found in original data.")
        return
    if len(picks_bad_cleaned) == 0:
        message("info", "No bad channels found in cleaned data.")
        return
    data_original, times = raw_original.get_data(
        picks=picks_bad_original, return_times=True
    )
    data_cleaned = raw_cleaned.get_data(picks=picks_bad_cleaned)
    channel_labels = [raw_original.ch_names[i] for i in picks_bad_original]
    n_channels = len(channel_labels)
    message("info", f"Number of Bad Channels to Plot: {n_channels}")

    # Downsample to ~100 Hz to keep the figure light.
    sfreq = raw_original.info["sfreq"]
    desired_sfreq = 100
    downsample_factor = int(sfreq // desired_sfreq)
    if downsample_factor > 1:
        data_original = data_original[:, ::downsample_factor]
        data_cleaned = data_cleaned[:, ::downsample_factor]
        times = times[::downsample_factor]
        message(
            "info",
            f"Data downsampled by a factor of {downsample_factor} to {desired_sfreq} Hz.",
        )

    # Normalize each channel by the std of the *original* signal so both
    # traces share one scale, then apply a fixed gain for visibility.
    data_original_normalized = np.zeros_like(data_original)
    data_cleaned_normalized = np.zeros_like(data_cleaned)
    spacing = 10  # vertical offset between stacked channels
    for idx in range(n_channels):
        channel_data_original = data_original[idx]
        channel_data_original = channel_data_original - np.mean(
            channel_data_original
        )  # Remove DC offset
        std = np.std(channel_data_original)
        if std == 0:
            std = 1  # Avoid division by zero on flat channels
        data_original_normalized[idx] = channel_data_original / std
        channel_data_cleaned = data_cleaned[idx]
        channel_data_cleaned = channel_data_cleaned - np.mean(
            channel_data_cleaned
        )  # Remove DC offset
        data_cleaned_normalized[idx] = channel_data_cleaned / std
    scaling_factor = 2
    data_original_scaled = data_original_normalized * scaling_factor
    data_cleaned_scaled = data_cleaned_normalized * scaling_factor
    offsets = np.arange(n_channels) * spacing

    # --- Figure 1: pipeline-flags report ---------------------------------
    fig = plt.figure(figsize=(20, 15))
    gs = GridSpec(2, 2, height_ratios=[1, 2], width_ratios=[1, 2])

    # Topographical subplot highlighting bad channel positions.
    ax_topo = fig.add_subplot(gs[0, 0])
    self._plot_bad_channels_topo(raw_original, bad_channels, ax_topo)

    # Full-duration overlay: original (red) vs interpolated (black).
    ax_full = fig.add_subplot(gs[0, 1])
    for idx in range(n_channels):
        ax_full.plot(
            times,
            data_original_scaled[idx] + offsets[idx],
            color="red",
            linewidth=1,
            linestyle="-",
        )
        ax_full.plot(
            times,
            data_cleaned_scaled[idx] + offsets[idx],
            color="black",
            linewidth=1,
            linestyle="-",
        )
    ax_full.set_xlabel("Time (seconds)", fontsize=14)
    ax_full.set_ylabel("Bad Channels", fontsize=14)
    ax_full.set_title(
        "Bad Channels: Original vs Interpolated (Full Duration)", fontsize=16
    )
    ax_full.set_xlim(times[0], times[-1])
    ax_full.set_ylim(-spacing, offsets[-1] + spacing)
    ax_full.set_yticks([])  # Hide y-ticks
    ax_full.invert_yaxis()
    legend_elements = [
        Line2D([0], [0], color="red", lw=2, linestyle="-", label="Original Data"),
        Line2D(
            [0], [0], color="black", lw=2, linestyle="-", label="Interpolated Data"
        ),
    ]
    ax_full.legend(handles=legend_elements, loc="upper right", fontsize=12)

    # Zoomed-in overlay of the same traces.
    ax_zoom = fig.add_subplot(gs[1, 1])
    for idx in range(n_channels):
        ax_zoom.plot(
            times,
            data_original_scaled[idx] + offsets[idx],
            color="red",
            linewidth=1,
            linestyle="-",
        )
        ax_zoom.plot(
            times,
            data_cleaned_scaled[idx] + offsets[idx],
            color="black",
            linewidth=1,
            linestyle="-",
        )
    # Clamp the zoom window to the recording's end.
    zoom_end = min(zoom_start + zoom_duration, times[-1])
    ax_zoom.set_xlim(zoom_start, zoom_end)
    ax_zoom.set_ylim(-spacing, offsets[-1] + spacing)
    ax_zoom.set_xlabel("Time (seconds)", fontsize=14)
    ax_zoom.set_ylabel("Bad Channels", fontsize=14)
    ax_zoom.set_title(
        f"Bad Channels: Original vs Interpolated (Zoom: {zoom_start}-{zoom_end} s)",
        fontsize=16,
    )
    ax_zoom.set_yticks([])  # Hide y-ticks
    ax_zoom.invert_yaxis()
    ax_zoom.legend(handles=legend_elements, loc="upper right", fontsize=12)

    # Text panel listing each bad channel and why it was flagged.
    ax_info = fig.add_subplot(gs[1, 0])
    ax_info.axis("off")
    info_text = "Bad Channels Information:\n\n"
    for ch, reasons in bad_channels_info.items():
        info_text += f"{ch}: {', '.join(reasons)}\n"
    ax_info.text(
        0.05,
        0.95,
        info_text,
        transform=ax_info.transAxes,
        fontsize=12,
        verticalalignment="top",
        bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.8),
    )
    fig.tight_layout()

    # Save figure 1 under the pipeline run directory.
    output_dir = Path(pipeline.run_path) / "derivatives" / "figures"
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    output_file = output_dir / f"bad_channels_{timestamp}.png"
    fig.savefig(output_file, dpi=150, bbox_inches="tight")
    plt.close(fig)
    message("success", f"Bad channels plot saved to {output_file}")
    pipeline.add_metadata(
        {
            "bad_channels_plot": {
                "file": str(output_file),
                "n_bad_channels": len(bad_channels),
                "bad_channels": bad_channels,
                "reasons": {
                    ch: reasons for ch, reasons in bad_channels_info.items()
                },
            }
        }
    )

    # --- Figure 2: BIDS-derivative report driven by raw_cleaned's bads ---
    if hasattr(self, "_check_step_enabled"):
        is_enabled, _ = self._check_step_enabled("bad_channel_report_step")
        if not is_enabled:
            message(
                "info",
                "✗ Bad channel report generation is disabled in configuration",
            )
            return
    # From here on, "bad channels" means the cleaned recording's bads list.
    bad_channels = raw_cleaned.info["bads"]
    if not bad_channels:
        message(
            "info", "No bad channels found. Skipping bad channel visualization."
        )
        return
    message(
        "info",
        f"Generating bad channel visualization for {len(bad_channels)} channels: {', '.join(bad_channels)}",
    )
    derivatives_path = pipeline.get_derivative_path(self.config["bids_path"])
    target_figure = str(
        derivatives_path.copy().update(
            suffix="step_bad_channels_topo", extension=".png", datatype="eeg"
        )
    )
    n_bad_channels = len(bad_channels)
    fig = plt.figure(figsize=(12, 8 + n_bad_channels * 1.5))
    gs = GridSpec(
        2 + n_bad_channels, 2, height_ratios=[2] + [1.5] * n_bad_channels + [0.5]
    )
    # 1. Topomap of all channels with the bads highlighted.
    ax_topo = fig.add_subplot(gs[0, :])
    self._plot_bad_channels_topo(raw_original, bad_channels, ax_topo)
    # 2. Per-channel time series, full duration (left column).
    message("info", "Generating full duration time series plots for bad channels")
    for i, bad_ch in enumerate(bad_channels):
        ax_full = fig.add_subplot(gs[i + 1, 0])
        self._plot_bad_channel_timeseries(
            raw_original,
            raw_cleaned,
            bad_ch,
            ax_full,
            full_duration=True,
            title_suffix="Full Duration",
        )
    # 3. Per-channel time series, zoomed window (right column).
    message("info", "Generating zoomed time series plots for bad channels")
    for i, bad_ch in enumerate(bad_channels):
        ax_zoom = fig.add_subplot(gs[i + 1, 1])
        self._plot_bad_channel_timeseries(
            raw_original,
            raw_cleaned,
            bad_ch,
            ax_zoom,
            full_duration=False,
            zoom_start=zoom_start,
            zoom_duration=zoom_duration,
            title_suffix=f"Zoom {zoom_start}-{zoom_start + zoom_duration}s",
        )
    # Shared legend in the bottom row.
    ax_legend = fig.add_subplot(gs[-1, :])
    ax_legend.axis("off")
    ax_legend.legend(
        [
            plt.Line2D([0], [0], color="red", lw=1),
            plt.Line2D([0], [0], color="blue", lw=1),
        ],
        ["Original Data", "Cleaned/Interpolated Data"],
        loc="center",
        ncol=2,
        frameon=False,
    )
    plt.tight_layout()
    fig.savefig(target_figure, dpi=150, bbox_inches="tight")
    plt.close(fig)
    message("info", f"Bad channels visualization saved to {target_figure}")
    metadata = {
        "artifact_reports": {
            "creationDateTime": datetime.now().isoformat(),
            "bad_channels_visualization": Path(target_figure).name,
            "bad_channels": bad_channels,
        }
    }
    self._update_metadata("bad_channels_visualization", metadata)
def _plot_bad_channels_topo(
    self, raw: mne.io.Raw, bad_channels: List[str], ax
) -> None:
    """Draw a topographic map with the bad channels marked.

    Builds a per-channel indicator vector (1.0 for flagged channels, 0.0
    otherwise) and renders it as a topomap on ``ax``.

    Parameters
    ----------
    raw : mne.io.Raw
        Raw EEG data supplying channel names and montage info.
    bad_channels : list of str
        Names of the channels flagged as bad.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    """
    # Indicator array: 1.0 where the channel is flagged, 0.0 elsewhere.
    flagged = set(bad_channels)
    indicator = np.array(
        [1.0 if name in flagged else 0.0 for name in raw.ch_names]
    )
    # NOTE(review): vmin/vmax were replaced by ``vlim`` in newer MNE
    # releases — confirm these kwargs against the pinned MNE version.
    mne.viz.plot_topomap(
        indicator,
        raw.info,
        axes=ax,
        show=False,
        cmap="RdBu_r",
        vmin=0,
        vmax=1,
        outlines="head",
        contours=0,
        sensors=True,
    )
    ax.set_title(f"Bad Channels Topography (n={len(bad_channels)})")
def _plot_bad_channel_timeseries(
self,
raw_original: mne.io.Raw,
raw_cleaned: mne.io.Raw,
channel: str,
ax,
full_duration: bool = True,
zoom_start: float = 0,
zoom_duration: float = 30,
title_suffix: str = "",
) -> None:
"""Plot time series for a specific channel.
Parameters:
-----------
raw_original : mne.io.Raw
Original raw EEG data before cleaning.
raw_cleaned : mne.io.Raw
Cleaned raw EEG data after preprocessing.
channel : str
Channel name to plot.
ax : matplotlib.axes.Axes
Axes to plot on.
full_duration : bool, optional
Whether to plot the full duration or a zoomed-in section. Default is True.
zoom_start : float, optional
Start time in seconds for the zoomed-in window. Default is 0 seconds.
zoom_duration : float, optional
Duration in seconds for the zoomed-in time series plot. Default is 30 seconds.
title_suffix : str, optional
Suffix to add to the plot title.
"""
# Get channel data
sfreq = raw_original.info["sfreq"]
ch_idx = raw_original.ch_names.index(channel)
if full_duration:
data_orig = raw_original.get_data(picks=[ch_idx])[0]
data_clean = raw_cleaned.get_data(picks=[ch_idx])[0]
times = raw_original.times
# Downsample for faster plotting if too many data points
if len(times) > 10000: # Arbitrary threshold
downsample_factor = max(1, int(len(times) / 10000))
data_orig = data_orig[::downsample_factor]
data_clean = data_clean[::downsample_factor]
times = times[::downsample_factor]
else:
# Extract zoomed-in data
start_idx = int(zoom_start * sfreq)
end_idx = int((zoom_start + zoom_duration) * sfreq)
data_orig = raw_original.get_data(
picks=[ch_idx], start=start_idx, stop=end_idx
)[0]
data_clean = raw_cleaned.get_data(
picks=[ch_idx], start=start_idx, stop=end_idx
)[0]
times = np.arange(len(data_orig)) / sfreq + zoom_start
# Plot data
ax.plot(times, data_orig, color="red", lw=0.8)
ax.plot(times, data_clean, color="blue", lw=0.8)
# Add labels and title
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude (µV)")
ax.set_title(f"Channel: {channel} - {title_suffix}")
# Add grid
ax.grid(True, linestyle="--", alpha=0.6)
def _plot_psd(
self,
fig,
gs,
freqs,
psd_original_mean_mV2,
psd_cleaned_mean_mV2,
psd_original_rel,
psd_cleaned_rel,
num_bands,
) -> None:
"""Helper function to create PSD plots.
Parameters:
-----------
fig : matplotlib.figure.Figure
Figure object to plot on
gs : matplotlib.gridspec.GridSpec
GridSpec for organizing subplots
freqs : numpy.ndarray
Frequency vector for PSD
psd_original_mean_mV2 : numpy.ndarray
Original mean PSD in mV²/Hz
psd_cleaned_mean_mV2 : numpy.ndarray
Cleaned mean PSD in mV²/Hz
psd_original_rel : numpy.ndarray
Original relative PSD in %
psd_cleaned_rel : numpy.ndarray
Cleaned relative PSD in %
num_bands : int
Number of frequency bands for subplot width
"""
# Plot Absolute PSD
ax1 = fig.add_subplot(gs[0, : num_bands // 2])
ax1.semilogy(freqs, psd_original_mean_mV2, color="red", label="Original")
ax1.semilogy(freqs, psd_cleaned_mean_mV2, color="black", label="Cleaned")
ax1.set_xlabel("Frequency (Hz)")
ax1.set_ylabel("PSD (mV²/Hz)")
ax1.set_title("Absolute PSD")
ax1.legend(loc="upper right")
ax1.grid(True, which="both", ls="--")
# Plot Relative PSD
ax2 = fig.add_subplot(gs[0, num_bands // 2 :])
ax2.plot(freqs, psd_original_rel, color="red", label="Original")
ax2.plot(freqs, psd_cleaned_rel, color="black", label="Cleaned")
ax2.set_xlabel("Frequency (Hz)")
ax2.set_ylabel("Relative Power (%)")
ax2.set_title("Relative PSD")
ax2.legend(loc="upper right")
ax2.grid(True)
def _plot_topomaps(
    self,
    fig,
    gs,
    bands,
    band_powers_orig,
    band_powers_clean,
    raw_original,
    raw_cleaned,
    outlier_channels_orig,
    outlier_channels_clean,
):
    """Draw per-band topographic power maps for original and cleaned data.

    Row 1 of ``gs`` shows the original data, row 2 the cleaned data; each
    column is one frequency band.  Outlier channels for a band are listed
    in red beneath the corresponding map.
    """
    # Both rows are drawn by the same loop, parameterized by data source.
    rows = (
        (1, "Original", band_powers_orig, raw_original, outlier_channels_orig),
        (2, "Cleaned", band_powers_clean, raw_cleaned, outlier_channels_clean),
    )
    for row, label, powers, raw, outlier_map in rows:
        for col, (band, power) in enumerate(zip(bands, powers)):
            band_name, l_freq, h_freq = band
            ax = fig.add_subplot(gs[row, col])
            mne.viz.plot_topomap(
                power, raw.info, axes=ax, show=False, contours=0, cmap="jet"
            )
            mean_power = np.mean(power)
            ax.set_title(
                f"{label}: {band_name}\n({l_freq}-{h_freq} Hz)\nMean Power: {mean_power:.2e} mV²",
                fontsize=10,
            )
            # List the band's outlier channels under the map, if any.
            outliers = outlier_map[band_name]
            if outliers:
                ax.annotate(
                    f"Outliers:\n{', '.join(outliers)}",
                    xy=(0.5, -0.15),
                    xycoords="axes fraction",
                    ha="center",
                    va="top",
                    fontsize=8,
                    color="red",
                )