import pickle
import json
import copy
from matplotlib import pyplot as plt
import numpy as np
import wfdb.processing
import scipy.signal
from scipy.signal import butter, lfilter
from statsmodels.nonparametric.smoothers_lowess import lowess
import cv2 as cv
# TODO: add an overall module description.
def load_data(only_demographic: bool = False, path_settings: str = "../settings.json"):
    """Load pickled record lists for every label listed in the settings file.

    The settings JSON must provide ``"data_path"`` (a directory) and
    ``"labels"`` (a mapping whose keys are category names). For every key
    ``cat_name`` the file ``{data_path}/{cat_name}.pkl`` is unpickled.

    Args:
        only_demographic (bool, optional): If True, collect only demographic
            information (age, diagnosis, gender) instead of returning the
            records themselves. Defaults to False.
        path_settings (str, optional): Path to the settings JSON file.
            Defaults to "../settings.json".

    Returns:
        dict: If ``only_demographic`` is False, maps each category name to the
        list of records loaded from its pickle file. Otherwise a dict with the
        parallel lists ``'age'``, ``'diag'`` and ``'gender'``, holding one
        entry per record whose age and gender are both known. ``'diag'``
        holds the category name the record was loaded from. Age and gender
        are the raw strings parsed from the record comments.

    Warning:
        ``pickle.load`` can execute arbitrary code; only point the settings
        at trusted data files.
    """
    # Context manager so the settings file handle is closed promptly.
    with open(path_settings, encoding="utf-8") as settings_file:
        settings = json.load(settings_file)
    path_data = settings["data_path"]
    labels = settings["labels"]

    if only_demographic:
        data = {'age': [], 'diag': [], 'gender': []}
    else:
        data = {}

    for cat_name in labels:
        print(f"Reading {cat_name}")
        with open(f'{path_data}/{cat_name}.pkl', 'rb') as f:
            records = pickle.load(f)

        if not only_demographic:
            data[cat_name] = records
            continue

        for record in records:
            # The value is the second space-separated token of the first and
            # second comment lines (presumably "Age: 56" / "Sex: Female" --
            # NOTE(review): confirm against the data).
            age = record.comments[0].split(' ')[1]
            gender = record.comments[1].split(' ')[1]
            # Skip records with missing demographic values.
            if age == 'NaN' or gender == 'NaN':
                continue
            data['age'].append(age)
            data['diag'].append(cat_name)
            data['gender'].append(gender)
    return data
def format_data_input(data):
    """Normalize the supported input kinds into a ``{label: [records]}`` dict.

    Promotion happens step by step:
    ``np.ndarray`` -> ``wfdb.Record`` -> ``[wfdb.Record]`` -> ``{'temp_key': [...]}``.
    An ndarray is copied before being wrapped. A dict (or any other
    unrecognized object) is returned unchanged. ``format_data_output``
    removes the ``'temp_key'`` wrapper again.

    Args:
        data (np.ndarray, wfdb.Record, list, or dict): The input to normalize.

    Returns:
        dict: The normalized data.
    """
    wrapped = data
    if isinstance(wrapped, np.ndarray):
        wrapped = wfdb.Record(p_signal=wrapped.copy())
    if isinstance(wrapped, wfdb.Record):
        wrapped = [wrapped]
    if not isinstance(wrapped, list):
        return wrapped
    # Placeholder label so every caller can iterate uniformly over dict items.
    return {'temp_key': wrapped}
def format_data_output(data):
    """Undo the wrapping added by ``format_data_input`` where possible.

    Unwrapping is applied in sequence:

    - ``{'temp_key': x}`` becomes ``x``.
    - A single-element list becomes its element.
    - A ``wfdb.Record`` whose ``p_signal`` is 1-D becomes that 1-D array.

    Anything else is returned unchanged.

    Args:
        data (dict, list, or wfdb.Record): The data to unwrap.

    Returns:
        dict, list, wfdb.Record, or np.ndarray: The least-wrapped equivalent.
    """
    # Only a dict can carry the placeholder key; other inputs pass through.
    if isinstance(data, dict) and len(data) == 1 and 'temp_key' in data:
        data = data['temp_key']
    if isinstance(data, list) and len(data) == 1:
        data = data[0]
    # Return the whole 1-D signal; indexing it with [0] would return only its
    # first sample.
    if isinstance(data, wfdb.Record) and data.p_signal.ndim == 1:
        data = data.p_signal
    return data
def butterlowpass_filter(data, cutoff: float, fs: float, order: int = 5):
    """Apply a causal Butterworth low-pass filter to every channel of the input.

    Args:
        data (dict, list, wfdb.Record, or np.ndarray): Signals to filter. The
            input is deep-copied first, so the caller's object is not modified.
        cutoff (float): Cutoff frequency, in the same units as ``fs``.
        fs (float): Sampling frequency of the input data.
        order (int, optional): Order of the filter. Defaults to 5.

    Returns:
        dict, list, wfdb.Record, or np.ndarray: The filtered data, unwrapped
        by ``format_data_output``.

    Raises:
        ValueError: From ``scipy.signal.butter`` when ``cutoff`` is not
            strictly between 0 and the Nyquist frequency ``fs / 2``.
    """
    data = format_data_input(copy.deepcopy(data))

    # The coefficients depend only on the parameters, so design the filter
    # once rather than once per channel.
    nyq = 0.5 * fs
    b, a = butter(order, cutoff / nyq, btype='low', analog=False)

    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            # Filter all channels (columns) in one call along the time axis.
            # Writing through [:] keeps p_signal's dtype, as the original
            # per-column assignment did.
            wfdb_obj.p_signal[:] = lfilter(b, a, wfdb_obj.p_signal, axis=0)
    return format_data_output(data)
def lowess_filter(data, frac: float = 0.03, it: int = 1):
    """Smooth every channel of the input with LOWESS (statsmodels ``lowess``).

    Each channel is regressed against its sample index 0..n-1, and the fitted
    values replace the channel.

    Args:
        data (dict, list, wfdb.Record, or np.ndarray): Signals to smooth. The
            input is deep-copied first, so the caller's object is not modified.
        frac (float, optional): Fraction of the samples used to compute each
            fitted value. Defaults to 0.03.
        it (int, optional): Number of robustifying iterations. Defaults to 1.

    Returns:
        dict, list, wfdb.Record, or np.ndarray: The smoothed data, unwrapped
        by ``format_data_output``.
    """
    data = copy.deepcopy(data)
    data = format_data_input(data)
    for records in data.values():
        for record in records:
            signals = record.p_signal
            positions = np.arange(signals.shape[0])
            for channel in range(signals.shape[1]):
                fitted = lowess(signals[:, channel], positions, is_sorted=True, frac=frac, it=it)
                # Column 0 holds the x positions, column 1 the smoothed values.
                signals[:, channel] = fitted[:, 1]
    return format_data_output(data)
def non_local_means_filter(data, filter_strength: int = 50, template_window_size: int = 7, search_window_size: int = 21):
    """Denoise every channel of the input with OpenCV's Non-Local Means filter.

    Each channel is min-max scaled to 0..255, treated as an (n_samples x 1)
    8-bit image and denoised with ``cv.fastNlMeansDenoising``. The result is
    then mapped back to the channel's original range. The 8-bit round-trip
    quantizes the output to 256 levels. Constant channels are left unchanged.

    Args:
        data (dict, list, wfdb.Record, or np.ndarray): Signals to denoise. The
            input is deep-copied first, so the caller's object is not modified.
        filter_strength (int, optional): Filter strength, passed as OpenCV's
            ``h`` parameter. Defaults to 50.
        template_window_size (int, optional): Size in pixels of the template
            patch used to compute weights. Defaults to 7.
        search_window_size (int, optional): Size in pixels of the window used
            to compute the weighted average for a given pixel. Defaults to 21.

    Returns:
        dict, list, wfdb.Record, or np.ndarray: The denoised data, unwrapped
        by ``format_data_output``.
    """
    data = format_data_input(copy.deepcopy(data))
    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            for idx in range(wfdb_obj.p_signal.shape[1]):
                signal = wfdb_obj.p_signal[:, idx]
                sig_min = np.min(signal)
                sig_range = np.max(signal) - sig_min
                # A constant channel cannot be min-max scaled (division by
                # zero) and has nothing to denoise, so keep it as is.
                if sig_range == 0:
                    continue
                # Scale to 0..255 and round (plain truncation would bias every
                # sample downward), then reshape to a single-column image.
                scaled = np.round((signal - sig_min) / sig_range * 255).astype(np.uint8)
                image = scaled.reshape(-1, 1)
                denoised = cv.fastNlMeansDenoising(image, None, filter_strength, template_window_size, search_window_size)
                # Map the 0..255 result back to the channel's original range.
                wfdb_obj.p_signal[:, idx] = denoised.reshape(-1) * (sig_range / 255) + sig_min
    return format_data_output(data)
def _present_params(params: dict, keys) -> dict:
    """Return the subset of ``params`` whose keys appear in ``keys``."""
    return {key: params[key] for key in keys if key in params}


def filter_data(data, filter_params: dict):
    """Apply one or more filters, selected by name, to the input data.

    Filters always run in this fixed order, regardless of their order in
    ``filter_params['names']``: 'butterlowpass', 'loess', 'non_local_means'.

    Args:
        data (dict, list, wfdb.Record, or np.ndarray): The input data. It is
            not modified, because every filter works on a deep copy.
        filter_params (dict): Must contain ``'names'``, a collection of filter
            names. Additional keys per filter:

            - 'butterlowpass': 'cutoff' and 'fs' (required); 'order' (optional).
            - 'loess': 'frac' and 'it' (optional).
            - 'non_local_means': 'filter_strength', 'template_window_size' and
              'search_window_size' (optional).

            Absent optional keys fall back to the defaults of the
            corresponding filter function.

    Returns:
        dict, list, wfdb.Record, or np.ndarray: The filtered data. If no valid
        filter name is given, a warning is printed and a deep copy of the
        input is returned.

    Raises:
        KeyError: If 'names' is missing, or if 'cutoff'/'fs' is missing while
            'butterlowpass' is requested.
    """
    names = filter_params['names']
    applied = False
    # No up-front deepcopy: each filter already deep-copies its input.
    if 'butterlowpass' in names:
        data = butterlowpass_filter(data, filter_params['cutoff'], filter_params['fs'],
                                    **_present_params(filter_params, ('order',)))
        applied = True
    if 'loess' in names:
        data = lowess_filter(data, **_present_params(filter_params, ('frac', 'it')))
        applied = True
    if 'non_local_means' in names:
        data = non_local_means_filter(
            data,
            **_present_params(filter_params, ('filter_strength', 'template_window_size', 'search_window_size')),
        )
        applied = True
    if not applied:
        print("Warning: No valid filter names found in filter_params['names']. Data will be returned as is.")
        # Like the filtered paths, never hand back the caller's own object.
        data = copy.deepcopy(data)
    return data
if __name__ == "__main__":
    # Smoke test: load every category, report its size, then run the
    # low-pass filter on three input forms (the full dict, one category's
    # list, and a single record). The results are not inspected.
    data = load_data(only_demographic=False, path_settings="./settings.json")

    # print shape of data for each category
    for cat_name in data:
        print(f"{cat_name}: {len(data[cat_name])}")

    order, fs, cutoff = 1, 500.0, 25

    data_test = butterlowpass_filter(data, cutoff, fs, order)
    data_test = butterlowpass_filter(data['SB'], cutoff, fs, order)
    data_test = butterlowpass_filter(data['SB'][0], cutoff, fs, order)