Source code for autoclean.mixins.viz.ica

"""ICA reporting mixin for autoclean tasks.

This module provides specialized ICA visualization and reporting functionality for
the AutoClean pipeline. It defines methods for generating comprehensive visualizations
and reports of Independent Component Analysis (ICA) results, including:

- Full-duration component activations
- Component properties and classifications
- Rejected components with their properties
- Interactive and static reports

These reports help users understand the ICA decomposition and validate component rejection
decisions to ensure appropriate artifact removal.

"""

import os
from datetime import datetime
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.gridspec import GridSpec

from autoclean.utils.logging import message

# Force matplotlib to use non-interactive backend for async operations
matplotlib.use("Agg")


class ICAReportingMixin:
    """Mixin providing ICA reporting functionality for EEG data.

    The methods in this mixin generate visualizations and reports of ICA
    results: component activations, component properties, and documentation
    of component rejection decisions.

    The host class is expected to provide the attributes these methods read:
    ``self.raw``, ``self.final_ica``, ``self.ica_flags`` (and
    ``self.ica_vision_flags`` for the comparison report), ``self.config``
    with the keys ``"derivatives_dir"`` and ``"bids_path"``, and the method
    ``self._update_metadata``.

    NOTE(review): the methods below do not themselves check any
    ``autoclean_config.yaml`` toggle; enabling/disabling of individual steps
    is presumably handled by the caller -- confirm.

    Available ICA reporting methods include:

    - `plot_ica_full`: Plot all ICA components over the full time series
    - `generate_ica_reports`: Create PDF reports of the ICA decomposition
      (all components and rejected components)
    - `verify_topography_plot`: Fit a basic ICA and plot its topographies to
      verify MEA channel placement
    - `compare_vision_iclabel_classifications`: Compare ICLabel and Vision
      API classifications of the components
    """

    @staticmethod
    def _ensure_image_suffix(path: Path, supported) -> Path:
        """Return *path* with an extension the figure canvas can write.

        Parameters
        ----------
        path : pathlib.Path
            Intended output path of a figure.
        supported : container of str
            File formats (extensions without the leading dot) the canvas
            supports, e.g. the keys of ``fig.canvas.get_supported_filetypes()``.

        Returns
        -------
        pathlib.Path
            ``path`` unchanged when its extension is a supported format
            (case-insensitive); otherwise ``path`` with the extension
            replaced by (or, if missing, set to) ``.png``.
        """
        if path.suffix.lstrip(".").lower() in supported:
            return path
        return path.with_suffix(".png")

    def plot_ica_full(self) -> "plt.Figure":
        """Plot ICA components over the full time series with their labels.

        Each component's time course is rescaled to a common peak-to-peak
        amplitude and drawn on its own row. The y tick labels show the
        component number, its classification and the classification
        confidence; labels of artifact components are colored red.

        The figure is saved into ``self.config["derivatives_dir"]`` and the
        saved file name is recorded through ``self._update_metadata``.

        Returns
        -------
        matplotlib.figure.Figure
            The generated figure with ICA components. The figure is left
            open; closing it is the caller's responsibility.

        Raises
        ------
        ValueError
            If no ICA object is found in the pipeline (``self.final_ica`` is
            missing or ``None``).

        Examples
        --------
        >>> # After performing ICA
        >>> fig = task.plot_ica_full()
        >>> plt.show()

        Notes
        -----
        NOTE(review): this method does not itself consult an
        ``ica_full_plot_step`` configuration toggle; gating presumably
        happens in the caller -- confirm.
        """
        ica = getattr(self, "final_ica", None)
        if ica is None:
            raise ValueError("No ICA object found in the pipeline.")
        ic_labels = self.ica_flags

        # Nothing below mutates the recording, so no defensive copy of the
        # (potentially large) raw object is made.
        raw = self.raw
        ica_data = ica.get_sources(raw).get_data()
        times = raw.times
        n_components = ica_data.shape[0]

        # Rescale every component to the same peak-to-peak amplitude so that
        # all traces are visible. Flat components (ptp == 0) are multiplied
        # by the plain target value to avoid dividing by zero.
        target_ptp = 2.5
        ptp = np.ptp(ica_data, axis=1, keepdims=True)
        ptp[ptp == 0] = 1.0
        ica_data = ica_data * (target_ptp / ptp)

        spacing = 2  # Fixed vertical distance between component traces

        # Figure width grows with the recording duration but is clamped so
        # that short recordings stay readable and long ones stay renderable.
        width_per_second = 0.1
        min_fig_width = 10
        max_fig_width = 200
        total_duration = times[-1] - times[0]
        fig_width = min(
            max(total_duration * width_per_second, min_fig_width), max_fig_width
        )
        fig_height = max(6, n_components * 0.5)  # Ensure a minimum height

        fig, ax = plt.subplots(figsize=(fig_width, fig_height))

        # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
        # Matplotlib 3.9.
        cmap = plt.get_cmap("tab20", n_components)

        # Plot components in original order, one row per component
        for idx in range(n_components):
            ax.plot(
                times,
                ica_data[idx] + idx * spacing,
                color=cmap(idx),
                linewidth=0.5,
            )

        # Read labels positionally so a non-default index on the label table
        # cannot misalign labels and components.
        ic_types = [str(ic_type) for ic_type in ic_labels["ic_type"]]
        confidences = list(ic_labels["confidence"])

        ax.set_yticks([idx * spacing for idx in range(n_components)])
        ax.set_yticklabels(
            [
                f"IC{idx + 1}: {ic_types[idx]} ({confidences[idx]:.2f})"
                for idx in range(n_components)
            ],
            fontsize=8,
        )

        ax.set_xlabel("Time (seconds)", fontsize=12)
        ax.set_title("ICA Component Activations (Full Duration)", fontsize=14)
        ax.set_xlim(times[0], times[-1])
        ax.set_ylim(-spacing, (n_components - 1) * spacing + spacing)
        ax.set_ylabel("")  # The tick labels already identify each row
        ax.invert_yaxis()  # First component at the top

        # Color the labels red for artifact classes, black otherwise
        artifact_types = {"eog", "muscle", "ecg", "ch_noise", "line_noise", "other"}
        for ticklabel, ic_type in zip(ax.get_yticklabels(), ic_types):
            ticklabel.set_color(
                "red" if ic_type.lower() in artifact_types else "black"
            )

        plt.tight_layout()

        derivatives_dir = Path(self.config["derivatives_dir"])
        basename = self.config["bids_path"].basename
        basename = basename.replace("_eeg", "_ica_components_full_duration")
        # The BIDS basename may carry a data-file extension (or none); make
        # sure the target has an image extension so that savefig succeeds
        # and the recorded name matches the file on disk.
        target_figure = self._ensure_image_suffix(
            derivatives_dir / basename, fig.canvas.get_supported_filetypes()
        )

        # Save figure with higher DPI for better resolution of wider plot
        fig.savefig(target_figure, dpi=300, bbox_inches="tight")

        metadata = {
            "artifact_reports": {
                "creationDateTime": datetime.now().isoformat(),
                "ica_components_full_duration": target_figure.name,
            }
        }
        self._update_metadata("plot_ica_full", metadata)

        return fig

    def generate_ica_reports(
        self,
        duration: int = 10,
    ) -> None:
        """Generate the ICA PDF reports via ``_plot_ica_components``.

        Two reports are requested: one for all components and one for the
        rejected components. For every report that was actually written, its
        file name is recorded through ``self._update_metadata``. When no
        report is produced (``_plot_ica_components`` returns ``None``, which
        happens when no component was rejected) no metadata entry is written.

        Parameters
        ----------
        duration : int
            Duration in seconds for plotting time series data.
        """
        report_kinds = (
            ("all", "ica_all_components"),
            ("rejected", "ica_rejected_components"),
        )
        for selection, metadata_key in report_kinds:
            report_filename = self._plot_ica_components(
                duration=duration,
                components=selection,
            )
            if report_filename is None:
                # Nothing was written; do not record a non-existent report.
                continue

            metadata = {
                "artifact_reports": {
                    "creationDateTime": datetime.now().isoformat(),
                    metadata_key: report_filename,
                }
            }
            self._update_metadata("generate_ica_reports", metadata)

    def _plot_ica_components(
        self,
        duration: int = 10,
        components: str = "all",
    ):
        """Plot ICA components with labels and save them as a PDF report.

        The PDF contains, in order: summary tables (20 components per page),
        the component topographies, an overlay plot of the signal before and
        after removing the rejected components (``components="rejected"``
        only), and one detail page per component.

        Parameters
        ----------
        duration : int
            Duration in seconds of the time course shown on the detail pages.
        components : str
            'all' to plot all components, 'rejected' to plot only rejected
            components.

        Returns
        -------
        str or None
            File name (without directory) of the written PDF, or ``None``
            when ``components="rejected"`` and no component was rejected.

        Raises
        ------
        ValueError
            If ``components`` is neither 'all' nor 'rejected'.
        """
        raw = self.raw
        ica = self.final_ica
        ic_labels = self.ica_flags

        # Determine components to plot
        if components == "all":
            component_indices = list(range(ica.n_components_))
            report_name = "ica_components_all"
        elif components == "rejected":
            component_indices = list(ica.exclude)
            report_name = "ica_components_rejected"
            if not component_indices:
                message(
                    "info",
                    "No components were rejected. Skipping rejected components report.",
                )
                return None
        else:
            raise ValueError("components parameter must be 'all' or 'rejected'.")

        rejected = set(ica.exclude)

        # Get ICA activations
        ica_data = ica.get_sources(raw).get_data()

        # Limit the time course to the specified duration
        sfreq = raw.info["sfreq"]
        n_samples = int(duration * sfreq)
        times = raw.times[:n_samples]

        # Create output path for the PDF report.
        # NOTE(review): report_name has no leading underscore, unlike the
        # suffixes used by the other reports, so it is joined directly to the
        # preceding text of the basename. Left unchanged because consumers
        # may rely on the existing file names -- confirm before renaming.
        derivatives_dir = Path(self.config["derivatives_dir"])
        basename = self.config["bids_path"].basename
        basename = basename.replace("_eeg", report_name)
        pdf_path = (derivatives_dir / basename).with_suffix(".pdf")

        # Remove existing file
        if pdf_path.exists():
            pdf_path.unlink()

        # Background colors for the different IC types in the summary table
        color_map = {
            "brain": "#d4edda",  # Light green
            "eog": "#f9e79f",  # Light yellow
            "muscle": "#f5b7b1",  # Light red
            "ecg": "#d7bde2",  # Light purple
            "ch_noise": "#ffd700",  # Gold
            "line_noise": "#add8e6",  # Light blue
            "other": "#f0f0f0",  # Light grey
        }
        # IC types whose detail page title is drawn in red
        red_title_types = {"eog", "muscle", "ch_noise", "line_noise", "ecg"}

        with PdfPages(pdf_path) as pdf:
            components_per_page = 20
            num_pages = int(np.ceil(len(component_indices) / components_per_page))

            # Summary tables, split across pages
            for page in range(num_pages):
                start_idx = page * components_per_page
                end_idx = min(
                    (page + 1) * components_per_page, len(component_indices)
                )
                page_components = component_indices[start_idx:end_idx]

                fig_table = plt.figure(figsize=(11, 8.5))
                ax_table = fig_table.add_subplot(111)
                ax_table.axis("off")

                # Prepare table data for this page
                table_data = []
                colors = []
                for idx in page_components:
                    comp_info = ic_labels.iloc[idx]
                    table_data.append(
                        [
                            f"IC{idx + 1}",
                            comp_info["ic_type"],
                            f"{comp_info['confidence']:.2f}",
                            "Yes" if idx in rejected else "No",
                        ]
                    )
                    colors.append(
                        [color_map.get(comp_info["ic_type"].lower(), "white")] * 4
                    )

                # Create and customize table
                table = ax_table.table(
                    cellText=table_data,
                    colLabels=["Component", "Type", "Confidence", "Rejected"],
                    loc="center",
                    cellLoc="center",
                    cellColours=colors,
                    colWidths=[0.2, 0.3, 0.25, 0.25],
                )
                table.auto_set_font_size(False)
                table.set_fontsize(9)
                table.scale(1.2, 1.5)  # Reduced vertical scaling

                # Title with page information, filename and timestamp
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                fig_table.suptitle(
                    f"ICA Components Summary - {self.config['bids_path'].basename}\n"
                    f"(Page {page + 1} of {num_pages})\n"
                    f"Generated: {timestamp}",
                    fontsize=12,
                    y=0.95,
                )

                # Legend explaining the row colors
                legend_elements = [
                    plt.Rectangle((0, 0), 1, 1, facecolor=color, edgecolor="none")
                    for color in color_map.values()
                ]
                ax_table.legend(
                    legend_elements,
                    list(color_map),
                    loc="upper right",
                    title="Component Types",
                )

                # Add margins
                plt.subplots_adjust(top=0.85, bottom=0.15)

                pdf.savefig(fig_table)
                plt.close(fig_table)

            # Component topographies overview (one or several figures)
            fig_topo = ica.plot_components(picks=component_indices, show=False)
            topo_figures = fig_topo if isinstance(fig_topo, list) else [fig_topo]
            for topo_figure in topo_figures:
                pdf.savefig(topo_figure)
                plt.close(topo_figure)

            # If rejected components, add overlay plot
            if components == "rejected":
                end_time = min(30.0, raw.times[-1])

                # Work on a copy restricted to the channels used in ICA
                # training to avoid a shape mismatch during pre-whitening.
                raw_copy = raw.copy()
                ica_ch_names = ica.ch_names
                if len(ica_ch_names) != len(raw_copy.ch_names):
                    message(
                        "warning",
                        f"Channel count mismatch: ICA has {len(ica_ch_names)} channels, "
                        f"raw has {len(raw_copy.ch_names)}. Using only ICA channels for plotting.",
                    )
                    # Keep only the channels that were used in ICA
                    raw_copy.pick_channels(ica_ch_names)

                fig_overlay = ica.plot_overlay(
                    raw_copy,
                    start=0,
                    stop=end_time,
                    exclude=component_indices,
                    show=False,
                )
                fig_overlay.set_size_inches(15, 10)  # Set size after creating figure

                pdf.savefig(fig_overlay)
                plt.close(fig_overlay)

            # For each component, create a detail page
            for idx in component_indices:
                fig = plt.figure(constrained_layout=True, figsize=(12, 8))
                gs = GridSpec(nrows=3, ncols=3, figure=fig)

                # Axes for ica.plot_properties
                ax_props = [
                    fig.add_subplot(gs[0, 0]),  # Data
                    fig.add_subplot(gs[0, 1]),  # Epochs image
                    fig.add_subplot(gs[0, 2]),  # ERP/ERF
                    fig.add_subplot(gs[1, 0]),  # Spectrum
                    fig.add_subplot(gs[1, 1]),  # Topomap
                ]

                ica.plot_properties(
                    raw,
                    picks=[idx],
                    axes=ax_props,
                    dB=True,
                    plot_std=True,
                    log_scale=False,
                    reject="auto",
                    show=False,
                )

                # Time series plot on the last row, spanning all columns
                ax_timeseries = fig.add_subplot(gs[2, :])
                ax_timeseries.plot(times, ica_data[idx, :n_samples], linewidth=0.5)
                ax_timeseries.set_xlabel("Time (seconds)")
                ax_timeseries.set_ylabel("Amplitude")
                ax_timeseries.set_title(
                    f"Component {idx + 1} Time Course ({duration}s)"
                )

                # Page title with the classification of the component
                comp_info = ic_labels.iloc[idx]
                label_text = (
                    f"Component {comp_info['component']}\n"
                    f"Type: {comp_info['ic_type']}\n"
                    f"Confidence: {comp_info['confidence']:.2f}"
                )
                is_artifact = comp_info["ic_type"].lower() in red_title_types
                fig.suptitle(
                    label_text,
                    fontsize=14,
                    fontweight="bold",
                    color="red" if is_artifact else "black",
                )

                pdf.savefig(fig)
                plt.close(fig)

        message("info", f"Report saved to {pdf_path}")
        return pdf_path.name

    def verify_topography_plot(self) -> bool:
        """Use ICA topographies to verify MEA channel placement.

        Fits a FastICA decomposition on ``self.raw`` with as many components
        as there are channels not marked bad, plots the component
        topographies and saves them into ``self.config["derivatives_dir"]``.
        It is used on mouse files to verify channel placement.

        The first figure is written to ``ica_topography.png``. If the
        topographies do not fit on one figure, the additional figures are
        written to ``ica_topography_2.png``, ``ica_topography_3.png``, ...

        Returns
        -------
        bool
            ``True`` once all figures have been saved.
        """
        # pylint: disable=import-outside-toplevel
        from mne.preprocessing import ICA

        derivatives_dir = Path(self.config["derivatives_dir"])
        n_components = len(self.raw.ch_names) - len(self.raw.info["bads"])

        ica = ICA(  # pylint: disable=not-callable
            n_components=n_components,
            method="fastica",
            random_state=42,
        )
        ica.fit(self.raw)

        figures = ica.plot_components(picks=range(n_components), show=False)
        # plot_components returns a list when several figures are needed
        if not isinstance(figures, list):
            figures = [figures]

        for page, fig in enumerate(figures, start=1):
            filename = (
                "ica_topography.png" if page == 1 else f"ica_topography_{page}.png"
            )
            fig.savefig(derivatives_dir / filename)
            plt.close(fig)  # Release the figure once it is on disk

        return True

    def compare_vision_iclabel_classifications(self):
        """Compare ICLabel and Vision API classifications for ICA components.

        Both classifications are reduced to a binary brain/artifact coding
        and compared per component. The resulting figure holds a bar chart
        (disagreements are marked with a red ``*``) and a table listing the
        original labels, confidences and agreement, followed by the overall
        agreement percentage. The figure is saved into
        ``self.config["derivatives_dir"]`` and recorded through
        ``self._update_metadata``.

        It requires ``self.ica_flags`` (ICLabel results) and
        ``self.ica_vision_flags`` (Vision results) to be available.

        Returns
        -------
        matplotlib.figure.Figure or None
            Figure showing the comparison of classifications, or ``None``
            (after logging an error) when either set of results is missing,
            there are no components, or the Vision results cover fewer
            components than the ICLabel results.
        """
        # Check if both ICLabel and Vision classifications exist
        if getattr(self, "ica_flags", None) is None:
            message("error", "ICLabel results not found. Please run run_ICLabel first.")
            return None

        if getattr(self, "ica_vision_flags", None) is None:
            message(
                "error",
                "Vision classification results not found. Please run classify_ica_components_vision first.",
            )
            return None

        iclabel_results = self.ica_flags
        vision_results = self.ica_vision_flags

        n_components = len(iclabel_results)
        if n_components == 0 or len(vision_results) < n_components:
            message(
                "error",
                f"Cannot compare classifications: ICLabel has {n_components} "
                f"components, Vision has {len(vision_results)}.",
            )
            return None

        # Mapping of ICLabel categories to binary brain/artifact
        iclabel_mapping = {
            "brain": "brain",
            "eog": "artifact",
            "muscle": "artifact",
            "ecg": "artifact",
            "ch_noise": "artifact",
            "line_noise": "artifact",
            "other": "artifact",
        }

        # Pull the columns out once, in positional order
        iclabel_categories = list(iclabel_results["ic_type"])[:n_components]
        iclabel_confs = list(iclabel_results["confidence"])[:n_components]
        vision_labels = list(vision_results["label"])[:n_components]
        vision_confs = list(vision_results["confidence"])[:n_components]

        # Binary coding (1 for brain, 0 for artifact). The same coding drives
        # both the bar chart and the agreement table so they cannot disagree.
        iclabel_binary = np.array(
            [
                1 if iclabel_mapping.get(category.lower(), "artifact") == "brain" else 0
                for category in iclabel_categories
            ]
        )
        vision_binary = np.array(
            [1 if str(label).lower() == "brain" else 0 for label in vision_labels]
        )
        agrees = iclabel_binary == vision_binary

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))

        # First subplot: bar chart comparison
        indices = np.arange(n_components)
        bar_width = 0.4
        ax1.bar(
            indices - bar_width / 2,
            iclabel_binary,
            bar_width,
            label="ICLabel",
            color="blue",
            alpha=0.6,
        )
        ax1.bar(
            indices + bar_width / 2,
            vision_binary,
            bar_width,
            label="Vision API",
            color="orange",
            alpha=0.6,
        )

        # Highlight disagreements
        for idx in np.where(~agrees)[0]:
            ax1.annotate(
                "*",
                xy=(idx, 1.1),
                xytext=(idx, 1.1),
                ha="center",
                va="bottom",
                fontsize=12,
                color="red",
            )

        ax1.set_title("Classification Comparison: ICLabel vs. Vision API", fontsize=14)
        ax1.set_xlabel("Component Number", fontsize=12)
        ax1.set_xticks(indices)
        ax1.set_xticklabels([f"IC{i+1}" for i in range(n_components)])
        ax1.set_yticks([0, 1])
        ax1.set_yticklabels(["Artifact", "Brain"])
        # Annotations anchored in data coordinates are not drawn when their
        # anchor lies outside the axes; leave head-room above the bars so the
        # disagreement markers at y=1.1 are visible.
        ax1.set_ylim(0, 1.3)
        ax1.legend()

        # Second subplot: agreement table
        ax2.axis("tight")
        ax2.axis("off")

        table_data = []
        cell_colors = []
        for i in range(n_components):
            if agrees[i]:
                agreement = "✓"
                bg_color = "#d4edda"  # Light green
            else:
                agreement = "✗"
                bg_color = "#f8d7da"  # Light red

            table_data.append(
                [
                    f"IC{i+1}",
                    iclabel_categories[i],
                    f"{iclabel_confs[i]:.2f}",
                    str(vision_labels[i]).title(),
                    f"{vision_confs[i]:.2f}",
                    agreement,
                ]
            )
            cell_colors.append([bg_color] * 6)

        agreement_count = int(agrees.sum())
        agreement_pct = (agreement_count / n_components) * 100

        # Create and customize table
        table = ax2.table(
            cellText=table_data,
            colLabels=[
                "Component",
                "ICLabel Category",
                "ICLabel Conf.",
                "Vision Type",
                "Vision Conf.",
                "Agreement",
            ],
            loc="center",
            cellLoc="center",
            cellColours=cell_colors,
        )
        table.auto_set_font_size(False)
        table.set_fontsize(9)
        table.scale(1.2, 1.5)

        # Add agreement percentage as text below the table
        ax2.text(
            0.5,
            -0.1,
            f"Overall Agreement: {agreement_pct:.1f}% ({agreement_count}/{n_components} components)",
            ha="center",
            va="center",
            transform=ax2.transAxes,
            fontsize=12,
            fontweight="bold",
        )

        plt.tight_layout()
        fig.subplots_adjust(hspace=0.3)

        # Save the figure
        derivatives_dir = Path(self.config["derivatives_dir"])
        basename = self.config["bids_path"].basename
        basename = basename.replace("_eeg", "_ica_classification_comparison")
        # Guarantee an image extension so savefig succeeds and the recorded
        # name matches the file on disk.
        target_figure = self._ensure_image_suffix(
            derivatives_dir / basename, fig.canvas.get_supported_filetypes()
        )

        # Save figure with higher DPI
        fig.savefig(target_figure, dpi=300, bbox_inches="tight")

        metadata = {
            "artifact_reports": {
                "creationDateTime": datetime.now().isoformat(),
                "ica_classification_comparison": target_figure.name,
            }
        }
        self._update_metadata("compare_vision_iclabel_classifications", metadata)

        return fig