"""Helpers for loading pickled wfdb records and filtering their signals.

TODO create overall description
"""

import pickle
import json
import copy

from matplotlib import pyplot as plt
import numpy as np
import wfdb.processing
import scipy.signal
from scipy.signal import butter, lfilter
from statsmodels.nonparametric.smoothers_lowess import lowess
import cv2 as cv


def load_data(only_demographic: bool = False, only_diagnosis_ids=False, path_settings: str = "../settings.json"):
    """
    Loads data from pickle files based on the specified settings.

    The settings file is JSON and must provide the keys ``"data_path"``
    (directory holding the pickle files) and ``"labels"`` (its keys are the
    category names; one ``<category>.pkl`` file is read per key).

    Args:
        only_demographic (bool, optional): If True, only loads demographic data
            (age, diagnosis, gender). Records whose age or gender is the string
            'NaN' are skipped. Defaults to False.
        only_diagnosis_ids (bool, optional): If True, returns the unpickled
            content of ``diagnosis.pkl`` and nothing else. Defaults to False.
        path_settings (str, optional): Path to the settings file.
            Defaults to "../settings.json".

    Returns:
        dict: With ``only_demographic`` the dict has the keys 'age', 'diag' and
        'gender' (parallel lists); otherwise it maps each category name to the
        unpickled records. With ``only_diagnosis_ids`` the unpickled object is
        returned as is.
    """
    # Context manager so the settings file handle is closed deterministically.
    with open(path_settings) as settings_file:
        settings = json.load(settings_file)
    path_data = settings["data_path"]
    labels = settings["labels"]

    # NOTE: pickle.load can execute arbitrary code contained in the file; the
    # data directory must only contain trusted files.
    if only_diagnosis_ids:
        with open(f'{path_data}/diagnosis.pkl', 'rb') as f:
            return pickle.load(f)

    data = {'age': [], 'diag': [], 'gender': []} if only_demographic else {}
    for cat_name in labels:
        print(f"Reading {cat_name}")
        with open(f'{path_data}/{cat_name}.pkl', 'rb') as f:
            records = pickle.load(f)

        if not only_demographic:
            data[cat_name] = records
            continue

        for record in records:
            # The second space-separated token of the first two comment lines
            # is read as age and gender respectively.
            age = record.comments[0].split(' ')[1]
            gender = record.comments[1].split(' ')[1]
            if age == 'NaN' or gender == 'NaN':
                continue
            data['age'].append(int(age))
            data['diag'].append(cat_name)
            data['gender'].append(gender)
    return data


def format_data_input(data):
    """
    Formats the input data into a standardized format.

    An ndarray is copied and wrapped into a ``wfdb.Record``, a single record is
    wrapped into a list, and a list is stored in a dict under the placeholder
    key 'temp_key'. A dict is returned unchanged.

    Parameters:
        data (np.ndarray or wfdb.Record or list or dict): The input data to be formatted.

    Returns:
        dict: The formatted data.
    """
    if isinstance(data, np.ndarray):
        data = wfdb.Record(p_signal=data.copy())
    if isinstance(data, wfdb.Record):
        data = [data]
    if isinstance(data, list):
        data = {'temp_key': data}
    return data


def format_data_output(data):
    """
    Formats the output data into a less redundant format.

    Undoes the wrapping of ``format_data_input``: a dict holding only the
    placeholder key 'temp_key' is reduced to its list, and a one-element list
    is reduced to its element. Note that a one-element list passed in by the
    caller is therefore returned as the bare element as well.

    Args:
        data (dict, list, wfdb.Record, or ndarray): The input data to be formatted.

    Returns:
        The formatted data.
    """
    # Only dicts have the placeholder key; the guard lets lists, records and
    # arrays pass through instead of failing on a missing ``keys`` attribute.
    if isinstance(data, dict) and len(data) == 1 and 'temp_key' in data:
        data = data['temp_key']
    if isinstance(data, list) and len(data) == 1:
        data = data[0]
    # NOTE(review): for a 1-D p_signal this returns its first sample, not the
    # whole array - confirm this is intended.
    if isinstance(data, wfdb.Record) and len(data.p_signal.shape) == 1:
        data = data.p_signal[0]
    return data


def _apply_to_channels(data, channel_filter):
    """Run ``channel_filter`` on every p_signal column of every record in a deep copy of ``data``."""
    data = format_data_input(copy.deepcopy(data))
    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            for idx in range(wfdb_obj.p_signal.shape[1]):
                wfdb_obj.p_signal[:, idx] = channel_filter(wfdb_obj.p_signal[:, idx])
    return format_data_output(data)


def butterlowpass_filter(data, cutoff: float, fs: float, order: int = 5):
    """
    Apply a Butterworth lowpass filter to the input data.

    The input is deep-copied; the caller's data is not modified.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The input data to be filtered.
    - cutoff: float
        The cutoff frequency of the filter.
    - fs: float
        The sampling frequency of the input data.
    - order: int, optional
        The order of the filter (default is 5).

    Returns:
    - data: (dict, list, wfdb.Record, or ndarray)
        The filtered output data.
    """
    # The filter design only depends on the arguments, so it is computed once
    # instead of once per channel.
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return _apply_to_channels(data, lambda signal: lfilter(b, a, signal))


def lowess_filter(data, frac: float = 0.03, it: int = 1):
    """
    Applies the lowess filter to the given data.

    The input is deep-copied; the caller's data is not modified.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The data to be filtered.
    - frac (float): The fraction of the data used to compute each fitted value. Default is 0.03.
    - it (int): The number of iterations for the smoothing process. Default is 1.

    Returns:
        (dict, list, wfdb.Record, or ndarray): The filtered data.
    """
    def smooth(signal):
        d_range = np.arange(len(signal))
        # [:, 1] needed to get only the smoothed values
        return lowess(signal, d_range, is_sorted=True, frac=frac, it=it)[:, 1]

    return _apply_to_channels(data, smooth)


def non_local_means_filter(data, filter_strength: int = 50, template_window_size: int = 7, search_window_size: int = 21):
    """
    Applies the Non-Local Means filter to the given data.

    Each channel is min-max scaled to 0..255, cast to uint8, denoised as a
    one-column image and scaled back to its original range. A constant channel
    is left unchanged. The input is deep-copied; the caller's data is not
    modified.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The data to be filtered.
    - filter_strength (int): Parameter controlling the strength of the filtering process. Default is 50.
    - template_window_size (int): Size in pixels of the template patch that is used to compute weights. Default is 7.
    - search_window_size (int): Size in pixels of the window that is used to compute weighted average for given pixel. Default is 21.

    Returns:
        (dict, list, wfdb.Record, or ndarray): The filtered data.
    """
    def denoise(signal):
        sig_min = np.min(signal)
        sig_range = np.max(signal) - sig_min
        # A constant channel cannot be min-max scaled (division by zero) and
        # contains nothing to denoise.
        if sig_range == 0:
            return signal.copy()
        # reshape data to 2d for image like processing
        d_2d = np.reshape(signal, (-1, 1))
        # max min scaling
        d_2d_scaled = np.uint8((d_2d - sig_min) / sig_range * 255)
        # apply non local means filter
        d_2d_filtered = cv.fastNlMeansDenoising(d_2d_scaled, None, filter_strength, template_window_size, search_window_size)
        # Rescale the denoised signal back to the original range
        return np.reshape(d_2d_filtered, -1) * sig_range / 255 + sig_min

    return _apply_to_channels(data, denoise)


def filter_data(data, filter_params: dict):
    """
    Apply the filters selected in ``filter_params`` to the input data.

    Filters are applied in the fixed order butterlowpass, loess,
    non_local_means. The caller's data is not modified.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The input data to be filtered.
    - filter_params: dict
        The parameters of the filter to be applied. 'names' selects the
        filters; the other keys are only read for the selected filters:
        'butterlowpass' needs 'cutoff', 'fs' and 'order'; 'loess' needs 'frac'
        and 'it'; 'non_local_means' needs 'filter_strength',
        'template_window_size' and 'search_window_size'.

    Returns:
    - data: (dict, list, wfdb.Record, or ndarray)
        The filtered output data, or a deep copy of the input if no valid
        filter name is given.
    """
    names = filter_params['names']
    filter_applied = False

    # Every filter works on its own deep copy, so no extra copy is needed here.
    if 'butterlowpass' in names:
        data = butterlowpass_filter(data, filter_params['cutoff'], filter_params['fs'], filter_params['order'])
        filter_applied = True
    if 'loess' in names:
        data = lowess_filter(data, filter_params['frac'], filter_params['it'])
        filter_applied = True
    if 'non_local_means' in names:
        data = non_local_means_filter(data, filter_params['filter_strength'], filter_params['template_window_size'], filter_params['search_window_size'])
        filter_applied = True

    if not filter_applied:
        print("Warning: No valid filter names found in filter_params['names']. Data will be returned as is.")
        # Still hand back a copy so the caller never receives its own object.
        data = copy.deepcopy(data)
    return data


if __name__ == "__main__":
    data = load_data(only_demographic=False, path_settings="./settings.json")

    # print number of records for each category
    for cat_name, records in data.items():
        print(f"{cat_name}: {len(records)}")

    order = 1
    fs = 500.0
    cutoff = 25

    # Apply the filter to a whole dict, a list of records and a single record
    data_test = butterlowpass_filter(data, cutoff, fs, order)
    data_test = butterlowpass_filter(data['SB'], cutoff, fs, order)
    data_test = butterlowpass_filter(data['SB'][0], cutoff, fs, order)