239 lines
8.2 KiB
Python
239 lines
8.2 KiB
Python
import pickle
|
|
import json
|
|
import copy
|
|
from matplotlib import pyplot as plt
|
|
import numpy as np
|
|
import wfdb.processing
|
|
import scipy.signal
|
|
from scipy.signal import butter, lfilter
|
|
from statsmodels.nonparametric.smoothers_lowess import lowess
|
|
import cv2 as cv
|
|
|
|
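"""
Utilities for loading pickled signal records and filtering every channel of
their ``p_signal`` arrays (Butterworth low-pass, LOWESS and Non-Local Means).
"""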
|
|
|
|
def load_data(only_demographic: bool = False, only_diagnosis_ids: bool = False, path_settings: str = "../settings.json"):
    """
    Load pickled records for every label listed in a JSON settings file.

    The settings file must provide ``"data_path"`` (directory holding the
    pickle files) and ``"labels"`` (one ``<label>.pkl`` file is read per key).

    Args:
        only_demographic (bool, optional): If True, return demographic data
            (age, diagnosis label, gender) parsed from each record's
            ``comments`` instead of the records themselves. Records whose age
            or gender is the string ``'NaN'`` are skipped. Defaults to False.
        only_diagnosis_ids (bool, optional): If True, return the unpickled
            contents of ``<data_path>/diagnosis.pkl`` and ignore
            ``only_demographic``. Defaults to False.
        path_settings (str, optional): Path to the JSON settings file.
            Defaults to "../settings.json".

    Returns:
        - ``only_diagnosis_ids=True``: whatever object ``diagnosis.pkl`` holds.
        - ``only_demographic=True``: a dict of parallel lists
          ``{'age': [int, ...], 'diag': [label, ...], 'gender': [str, ...]}``.
        - otherwise: a dict mapping each label to the object unpickled from
          ``<label>.pkl``.

    Note:
        ``pickle.load`` can execute arbitrary code; only point the settings
        at trusted data files.
    """
    # Context manager so the settings file handle is closed deterministically.
    with open(path_settings, encoding="utf-8") as settings_file:
        settings = json.load(settings_file)
    path_data = settings["data_path"]
    labels = settings["labels"]

    if only_diagnosis_ids:
        with open(f'{path_data}/diagnosis.pkl', 'rb') as f:
            return pickle.load(f)

    data = {'age': [], 'diag': [], 'gender': []} if only_demographic else {}

    for cat_name in labels:
        print(f"Reading {cat_name}")
        with open(f'{path_data}/{cat_name}.pkl', 'rb') as f:
            records = pickle.load(f)

        if not only_demographic:
            data[cat_name] = records
            continue

        for record in records:
            # Presumably comments[0] / comments[1] look like "<key> <value>"
            # for age / gender; the value is the second space-separated token.
            age = record.comments[0].split(' ')[1]
            gender = record.comments[1].split(' ')[1]
            if age == 'NaN' or gender == 'NaN':
                continue
            data['age'].append(int(age))
            data['diag'].append(cat_name)
            data['gender'].append(gender)

    return data
|
|
|
|
|
|
def format_data_input(data):
    """
    Normalize the supported input shapes into a ``{label: [record, ...]}`` dict.

    Conversions cascade in this order:
      * ``np.ndarray``  -> ``wfdb.Record`` holding a copy of the array as ``p_signal``
      * ``wfdb.Record`` -> a one-element list
      * ``list``        -> ``{'temp_key': list}``
    A dict, or any other type, is returned untouched.

    Parameters:
        data (np.ndarray or wfdb.Record or list or dict): The input to normalize.

    Returns:
        dict: The normalized data (or the original object if no rule applied).
    """
    converted = data

    if isinstance(converted, np.ndarray):
        # Copy so later in-place edits of p_signal never touch the caller's array.
        converted = wfdb.Record(p_signal=converted.copy())

    if isinstance(converted, wfdb.Record):
        converted = [converted]

    if isinstance(converted, list):
        # 'temp_key' marks a synthetic wrapper that format_data_output strips again.
        converted = {'temp_key': converted}

    return converted
|
|
|
|
|
|
def format_data_output(data):
    """
    Strip redundant wrapping from filter results.

    Unwraps, in order:
      1. a dict whose only key is ``'temp_key'`` -> its value;
      2. a single-element list -> that element;
      3. a ``wfdb.Record`` whose ``p_signal`` is 1-D -> that ``p_signal`` array.
    Anything else is returned unchanged.

    Args:
        data (dict, list, wfdb.Record, or ndarray): The data to be formatted.

    Returns:
        The unwrapped data (dict, list, wfdb.Record, or ndarray).
    """
    # Only dicts have keys; the unguarded ``data.keys()`` raised AttributeError
    # for the list / Record / ndarray inputs listed above.
    if isinstance(data, dict) and len(data) == 1 and 'temp_key' in data:
        data = data['temp_key']

    if isinstance(data, list) and len(data) == 1:
        data = data[0]

    if isinstance(data, wfdb.Record):
        signal = data.p_signal
        # A 1-D p_signal is one channel: return the whole array. The previous
        # ``p_signal[0]`` returned only its first sample. Also tolerate
        # records with no p_signal (None) instead of crashing.
        if signal is not None and len(signal.shape) == 1:
            data = signal

    return data
|
|
|
|
def butterlowpass_filter(data, cutoff: float, fs: float, order: int = 5):
    """
    Apply a Butterworth low-pass filter to every channel of the input data.

    The input is deep-copied first, so the caller's objects are not modified.
    Each column of each record's ``p_signal`` is filtered independently with
    ``scipy.signal.lfilter``. This is a single forward pass, so the output is
    phase-delayed relative to the input.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The input data to be filtered (normalized via ``format_data_input``;
        each record's ``p_signal`` must be 2-D, samples x channels).
    - cutoff: float
        The cutoff frequency, in the same units as ``fs``.
    - fs: float
        The sampling frequency of the input data.
    - order: int, optional
        The order of the filter (default is 5).

    Returns:
    - data: (dict, list, wfdb.Record, or ndarray)
        The filtered data, unwrapped by ``format_data_output``. Note that a
        2-D ndarray input is returned as a wfdb.Record, not as an ndarray.

    Raises:
    - ValueError: from ``scipy.signal.butter`` when ``cutoff / (fs / 2)``
      is not strictly between 0 and 1.
    """
    # The coefficients depend only on the filter parameters, so design the
    # filter once instead of once per channel. This also fails fast on an
    # invalid cutoff.
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)

    data = format_data_input(copy.deepcopy(data))

    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            signals = wfdb_obj.p_signal
            for idx in range(signals.shape[1]):
                signals[:, idx] = lfilter(b, a, signals[:, idx])

    return format_data_output(data)
|
|
|
|
|
|
def lowess_filter(data, frac: float = 0.03, it: int = 1):
    """
    Smooth every channel of the input with LOWESS (locally weighted regression).

    The input is deep-copied and normalized with ``format_data_input``. Each
    column of each record's ``p_signal`` is then regressed against its sample
    index and replaced by the fitted values.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The data to be smoothed.
    - frac (float):
        Fraction of the samples used when computing each fitted value. Default is 0.03.
    - it (int):
        Number of robustifying iterations of the smoothing process. Default is 1.

    Returns:
        (dict, list, wfdb.Record, or ndarray): The smoothed data, unwrapped by
        ``format_data_output``.
    """
    working = format_data_input(copy.deepcopy(data))

    for records in working.values():
        for record in records:
            channel_count = record.p_signal.shape[1]
            for channel in range(channel_count):
                column = record.p_signal[:, channel]
                positions = np.arange(len(column))
                fitted = lowess(column, positions, is_sorted=True, frac=frac, it=it)
                # lowess returns (x, y_fit) rows; keep only the smoothed y values.
                record.p_signal[:, channel] = fitted[:, 1]

    return format_data_output(working)
|
|
|
|
|
|
def non_local_means_filter(data, filter_strength: float = 50, template_window_size: int = 7, search_window_size: int = 21):
    """
    Apply OpenCV's Non-Local Means denoising to every channel of the input data.

    Each channel is min-max scaled to 8-bit integers and reshaped into an
    N x 1 grayscale image. It is denoised with ``cv.fastNlMeansDenoising`` and
    mapped back onto its original value range. The 8-bit step means the output
    resolution is limited to 256 levels per channel. Constant channels are
    left unchanged, since there is nothing to denoise and scaling them would
    divide by zero.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The data to be filtered (normalized via ``format_data_input``;
        each record's ``p_signal`` must be 2-D, samples x channels).
    - filter_strength (float):
        OpenCV ``h`` parameter controlling the strength of the filtering. Default is 50.
    - template_window_size (int):
        Size in pixels of the template patch used to compute weights
        (OpenCV recommends an odd value). Default is 7.
    - search_window_size (int):
        Size in pixels of the window used to compute the weighted average for
        each pixel (OpenCV recommends an odd value). Default is 21.

    Returns:
        (dict, list, wfdb.Record, or ndarray): The filtered data, unwrapped by
        ``format_data_output``.
    """
    data = format_data_input(copy.deepcopy(data))

    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            for idx in range(wfdb_obj.p_signal.shape[1]):
                signal = wfdb_obj.p_signal[:, idx]
                low = np.min(signal)
                span = np.max(signal) - low
                if span == 0:
                    # Previously 0/0 produced NaN, triggered a RuntimeWarning and
                    # relied on an undefined NaN->uint8 cast; a constant channel
                    # is already its own denoised version.
                    continue
                # Min-max scale to 0..255 (rounded, not truncated) as an N x 1 image.
                image = np.uint8(np.round((signal.reshape(-1, 1) - low) / span * 255))
                denoised = cv.fastNlMeansDenoising(image, None, filter_strength, template_window_size, search_window_size)
                # Map the 8-bit result back onto the channel's original range.
                wfdb_obj.p_signal[:, idx] = denoised.reshape(-1) * span / 255 + low

    return format_data_output(data)
|
|
|
|
def filter_data(data, filter_params: dict):
    """
    Apply the filters named in ``filter_params['names']`` to the input data.

    Selected filters always run in this fixed order, regardless of their order
    in ``names``: 'butterlowpass' -> 'loess' -> 'non_local_means'. Each one
    receives the previous one's output.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The input data to be filtered. It is never modified.
    - filter_params: dict
        Must contain 'names' (a container tested with ``in``). It must also
        contain the keys read by each selected filter:
          'butterlowpass':   'cutoff', 'fs', 'order'
          'loess':           'frac', 'it'
          'non_local_means': 'filter_strength', 'template_window_size', 'search_window_size'

    Returns:
    - data: (dict, list, wfdb.Record, or ndarray)
        The filtered data. If no recognised filter name is given, a warning
        is printed and a deep copy of the input is returned.

    Raises:
    - KeyError: if 'names' or a parameter required by a selected filter is missing.
    """
    names = filter_params['names']
    applied = False

    # Each filter deep-copies its own input, so no upfront copy is needed here;
    # the result never aliases the caller's data.
    if 'butterlowpass' in names:
        data = butterlowpass_filter(data, filter_params['cutoff'], filter_params['fs'], filter_params['order'])
        applied = True

    if 'loess' in names:
        data = lowess_filter(data, filter_params['frac'], filter_params['it'])
        applied = True

    if 'non_local_means' in names:
        data = non_local_means_filter(data, filter_params['filter_strength'], filter_params['template_window_size'], filter_params['search_window_size'])
        applied = True

    if not applied:
        print("Warning: No valid filter names found in filter_params['names']. Data will be returned as is.")
        # Keep the "always returns a new object" contract when nothing ran.
        return copy.deepcopy(data)

    return data
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load every category's full records from the local settings file.
    dataset = load_data(only_demographic=False, path_settings="./settings.json")

    # Report how many records were loaded for each category.
    for category, records in dataset.items():
        print(f"{category}: {len(records)}")

    # Butterworth low-pass parameters for the checks below.
    order, fs, cutoff = 1, 500.0, 25

    # Exercise each supported input shape: the whole dict, one category's
    # record list, and a single record.
    filtered_all = butterlowpass_filter(dataset, cutoff, fs, order)
    filtered_category = butterlowpass_filter(dataset['SB'], cutoff, fs, order)
    filtered_record = butterlowpass_filter(dataset['SB'][0], cutoff, fs, order)
|
|
|