Source code for autoclean.utils.config

# src/autoclean/utils/config.py
"""
This module contains functions for loading and validating the autoclean configuration file.
"""
# pylint: disable=line-too-long
import base64
import hashlib
import zlib
from pathlib import Path

import yaml
from schema import Optional, Or, Schema

from autoclean.utils.logging import message
from autoclean.utils.montage import VALID_MONTAGES


def load_config(config_file: Path) -> dict:
    """Load and validate the autoclean configuration file.

    The YAML file is parsed with ``yaml.safe_load`` and validated against a
    ``schema.Schema`` describing the expected ``tasks`` / ``stage_files``
    layout. Each task's signal-processing parameters are then checked for
    physical consistency with ``validate_signal_processing_params``.

    Parameters
    ----------
    config_file : Path
        The path to the autoclean configuration file.

    Returns
    -------
    autoclean_dict : dict
        The validated autoclean configuration dictionary.

    Raises
    ------
    schema.SchemaError
        If the file contents do not match the expected structure.
    ValueError
        If a task's signal-processing parameters are inconsistent.
    """
    message("info", f"Loading config: {config_file}")

    # NOTE: typed lists must use schema's list syntax, e.g. ``[Or(int, float)]``
    # (a list whose every element matches). A PEP 585 alias such as
    # ``list[float]`` is NOT understood by ``schema``: it is treated as a plain
    # callable and only checks that ``list(value)`` is truthy, so elements are
    # never validated and even a string like "abc" would be accepted.
    config_schema = Schema(
        {
            "tasks": {
                str: {
                    "mne_task": str,
                    "description": str,
                    "settings": {
                        "filtering": {
                            "enabled": bool,
                            "value": {
                                "l_freq": Or(int, float, None),
                                "h_freq": Or(int, float, None),
                                "notch_freqs": Or(
                                    float, int, [Or(int, float)], None
                                ),
                                "notch_widths": Or(
                                    float, int, [Or(int, float)], None
                                ),
                            },
                        },
                        "resample_step": {
                            "enabled": bool,
                            "value": Or(int, float, None),
                        },
                        "drop_outerlayer": {"enabled": bool, "value": Or(list, None)},
                        "eog_step": {"enabled": bool, "value": Or(list, None)},
                        "trim_step": {"enabled": bool, "value": Or(int, float)},
                        "crop_step": {
                            "enabled": bool,
                            "value": {
                                "start": Or(int, float),
                                "end": Or(int, float, None),
                            },
                        },
                        "reference_step": {
                            "enabled": bool,
                            "value": Or(str, [str], None),
                        },
                        "montage": {"enabled": bool, "value": Or(str, None)},
                        "ICA": {
                            "enabled": bool,
                            "value": {
                                "method": str,  # Required parameter
                                Optional("n_components"): Or(int, float, None),
                                Optional("noise_cov"): Or(dict, None),
                                Optional("random_state"): Or(int, None),
                                Optional("fit_params"): Or(dict, None),
                                Optional("max_iter"): Or(int, str, None),
                                Optional("allow_ref_meg"): Or(bool, None),
                                Optional("decim"): Or(int, None),
                            },
                        },
                        "ICLabel": {
                            "enabled": bool,
                            "value": {
                                "ic_flags_to_reject": list,
                                "ic_rejection_threshold": float,
                            },
                        },
                        "epoch_settings": {
                            "enabled": bool,
                            "value": {
                                "tmin": Or(int, float, None),
                                "tmax": Or(int, float, None),
                            },
                            "event_id": Or(dict, None),
                            "remove_baseline": {
                                "enabled": bool,
                                # Elements may be null so that e.g. [null, 0]
                                # (start-of-epoch to 0 s) remains valid.
                                "window": Or([Or(int, float, None)], None),
                            },
                            "threshold_rejection": {
                                "enabled": bool,
                                "volt_threshold": Or(dict, int, float),
                            },
                        },
                    },
                }
            },
            "stage_files": {str: {"enabled": bool, "suffix": str}},
        }
    )

    with open(config_file, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    autoclean_dict = config_schema.validate(config)

    # Validate signal processing parameters for each task
    for task in autoclean_dict["tasks"]:
        validate_signal_processing_params(autoclean_dict, task)

    return autoclean_dict
def validate_signal_processing_params(autoclean_dict: dict, task: str) -> None:
    """Validate signal processing parameters for physical constraints.

    For the given task this checks that:

    * filtering: the high-pass cutoff ``l_freq`` is below the low-pass cutoff
      ``h_freq`` when both are set;
    * resampling: the target rate is positive and every configured filter
      cutoff lies below the Nyquist frequency of that rate;
    * epoching: ``tmax`` is greater than ``tmin`` when both are set.

    Parameters
    ----------
    autoclean_dict : dict
        Configuration dictionary
    task : str
        Current processing task (a key of ``autoclean_dict["tasks"]``)

    Raises
    ------
    ValueError
        If parameters violate signal processing constraints
    """
    settings = autoclean_dict["tasks"][task]["settings"]

    # Bind the cutoffs up front: the resampling check below reads them even
    # when filtering is disabled, which previously raised UnboundLocalError.
    l_freq = None
    h_freq = None

    # Validate filtering settings
    filtering_settings = settings["filtering"]
    if filtering_settings["enabled"]:
        l_freq = filtering_settings["value"]["l_freq"]
        h_freq = filtering_settings["value"]["h_freq"]
        if l_freq is not None and h_freq is not None and l_freq >= h_freq:
            # l_freq is the lower pass-band edge (high-pass cutoff) and h_freq
            # the upper edge (low-pass cutoff), following MNE's convention.
            message(
                "error",
                f"High-pass filter frequency {l_freq} must be less than low-pass filter frequency {h_freq}",
            )
            raise ValueError(
                f"Invalid filtering settings: l_freq {l_freq} >= h_freq {h_freq}"
            )

    # Validate resampling settings
    resampling_settings = settings["resample_step"]
    if resampling_settings["enabled"]:
        resampling_rate = resampling_settings["value"]
        if resampling_rate is not None:
            if resampling_rate <= 0:
                message(
                    "error",
                    f"Resampling rate {resampling_rate} Hz must be greater than 0",
                )
                raise ValueError(f"Invalid resampling rate: {resampling_rate} Hz")

            # Check each cutoff on its own so that a low-pass-only (or
            # high-pass-only) filter is still validated against Nyquist.
            nyquist = resampling_rate / 2
            for name, freq in (("l_freq", l_freq), ("h_freq", h_freq)):
                if freq is not None and freq >= nyquist:
                    error_msg = (
                        f"Filter frequency {name}={freq} Hz must be below "
                        f"Nyquist frequency {nyquist} Hz"
                    )
                    message("error", error_msg)
                    raise ValueError(error_msg)

    # Validate epoch settings if enabled
    epoch_settings = settings["epoch_settings"]
    if epoch_settings["enabled"]:
        tmin = epoch_settings["value"]["tmin"]
        tmax = epoch_settings["value"]["tmax"]
        if tmin is not None and tmax is not None and tmax <= tmin:
            message(
                "error", f"Epoch tmax ({tmax}s) must be greater than tmin ({tmin}s)"
            )
            raise ValueError(f"Invalid epoch times: tmax {tmax}s <= tmin {tmin}s")

    message("debug", f"Signal processing parameters validated for task {task}")
def validate_eeg_system(autoclean_dict: dict, task: str) -> str | None:  # pylint: disable=line-too-long
    """Validate the EEG system for a given task.

    Checks if the EEG system is in the VALID_MONTAGES dictionary.

    For YAML-based tasks (``task`` is a key of ``autoclean_dict["tasks"]``) the
    system is read from ``settings.montage.value``. Otherwise the task is
    treated as Python-based, and the system is read from
    ``autoclean_dict["task_config"]["montage"]["value"]`` when present.

    Parameters
    ----------
    autoclean_dict : dict
        The autoclean configuration dictionary.
    task : str
        The task to validate the EEG system for.

    Returns
    -------
    eeg_system : str or None
        The validated EEG system, or ``None`` when a Python-based task has no
        montage configured (validation is skipped and a warning is logged).

    Raises
    ------
    ValueError
        If the configured EEG system is not a key of ``VALID_MONTAGES``.
    """
    # Handle both YAML-based and Python-based task configurations.
    # ``or {}`` also covers keys that are present but set to None.
    tasks = autoclean_dict.get("tasks") or {}
    if task in tasks:
        # YAML-based task configuration
        eeg_system = tasks[task]["settings"]["montage"]["value"]
    else:
        # Python-based task - extract from task_config if available
        task_config = autoclean_dict.get("task_config") or {}
        montage_cfg = task_config.get("montage") if isinstance(task_config, dict) else None
        if not isinstance(montage_cfg, dict) or "value" not in montage_cfg:
            # Default or skip validation for Python tasks without explicit montage
            message(
                "warning",
                f"No montage specified for Python task '{task}', skipping EEG system validation",
            )
            return None
        eeg_system = montage_cfg["value"]

    if eeg_system not in VALID_MONTAGES:
        error_msg = (
            f"Invalid EEG system: {eeg_system}. Supported: {', '.join(VALID_MONTAGES.keys())}. "
            "To add a new montage, please edit configs/montage.yaml or request it on GitHub issues."
        )
        message("error", error_msg)
        raise ValueError(error_msg)

    message("success", f"✓ EEG system validated: {eeg_system}")
    return eeg_system
def hash_and_encode_yaml(content: str | dict, is_file: bool = True) -> tuple[str, str]:
    """Produce a SHA-256 digest and a compact encoding of YAML content.

    The content is first normalised: it is parsed with ``yaml.safe_load`` and
    re-emitted with ``yaml.safe_dump(..., sort_keys=True)``. Equivalent
    documents therefore hash and encode identically.

    Parameters
    ----------
    content : str or dict
        A path to a YAML file when ``is_file`` is true, otherwise a mapping
        to serialise.
    is_file : bool
        Whether ``content`` is a file path.

    Returns
    -------
    file_hash : str
        Hex SHA-256 digest of the canonical YAML text.
    compressed_encoded : str
        The canonical YAML, zlib-compressed and then base64-encoded.
    """
    if not is_file:
        raw_text = yaml.safe_dump(content, sort_keys=True)
    else:
        with open(content, "r", encoding="utf-8") as handle:
            raw_text = handle.read()

    # Round-trip through the parser so formatting and key order are canonical.
    canonical_bytes = yaml.safe_dump(yaml.safe_load(raw_text), sort_keys=True).encode("utf-8")

    digest = hashlib.sha256(canonical_bytes).hexdigest()
    packed = base64.b64encode(zlib.compress(canonical_bytes)).decode("utf-8")
    return digest, packed
def decode_compressed_yaml(encoded_str: str) -> dict:
    """Recover a YAML document from its base64-encoded, zlib-compressed form.

    This reverses the encoding produced by ``hash_and_encode_yaml``:
    base64-decode, zlib-decompress, decode as UTF-8, then parse with
    ``yaml.safe_load``.

    Parameters
    ----------
    encoded_str : str
        The compressed and encoded YAML string.

    Returns
    -------
    yaml_dict : dict
        The parsed YAML document.
    """
    raw_bytes = zlib.decompress(base64.b64decode(encoded_str))
    return yaml.safe_load(raw_bytes.decode("utf-8"))