DSA_SS24/scripts/data_helper.py

235 lines
8.1 KiB
Python

import pickle
import json
import copy
from matplotlib import pyplot as plt
import numpy as np
import wfdb.processing
import scipy.signal
from scipy.signal import butter, lfilter
from statsmodels.nonparametric.smoothers_lowess import lowess
import cv2 as cv
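"""
Helper functions for loading the pickled wfdb records listed in settings.json
and for denoising their signals (Butterworth low-pass, LOWESS, non-local means).
"""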
def load_data(only_demographic: bool = False, path_settings: str = "../settings.json"):
    """
    Load pickled records for every label category listed in the settings file.

    The settings file is a JSON document. Its ``"data_path"`` entry names the
    directory holding the pickles. Its ``"labels"`` entry is a mapping whose
    keys are the category names. For each category ``<name>`` the file
    ``<data_path>/<name>.pkl`` is unpickled.

    Args:
        only_demographic (bool, optional): If True, return only demographic
            data (age, diagnosis, gender) parsed from each record's
            ``comments``. Defaults to False.
        path_settings (str, optional): Path to the settings file.
            Defaults to "../settings.json".

    Returns:
        dict: If ``only_demographic`` is False, maps each category name to the
        object unpickled from its ``.pkl`` file. Otherwise a dict of three
        parallel lists: ``'age'`` (int), ``'diag'`` (category name) and
        ``'gender'`` (str). Records whose age or gender is ``'NaN'`` are skipped.

    Warning:
        Uses ``pickle.load``, which can execute arbitrary code. Only point
        ``data_path`` at files from a trusted source.
    """
    # context manager so the settings file handle is closed promptly
    with open(path_settings, encoding="utf-8") as settings_file:
        settings = json.load(settings_file)
    path_data = settings["data_path"]
    labels = settings["labels"]

    if only_demographic:
        data = {'age': [], 'diag': [], 'gender': []}
    else:
        data = {}

    for cat_name in labels:
        print(f"Reading {cat_name}")
        with open(f'{path_data}/{cat_name}.pkl', 'rb') as f:
            records = pickle.load(f)

        if not only_demographic:
            data[cat_name] = records
            continue

        for record in records:
            # Take the second space-separated token of comments[0] / comments[1].
            # Presumably these look like "Age: 45" and "Sex: Male" -- TODO confirm.
            age = record.comments[0].split(' ')[1]
            gender = record.comments[1].split(' ')[1]
            if age == 'NaN' or gender == 'NaN':
                continue
            data['age'].append(int(age))
            data['diag'].append(cat_name)
            data['gender'].append(gender)
    return data
def format_data_input(data):
    """
    Normalize supported inputs into a ``{label: [wfdb.Record, ...]}`` dict.

    The conversions below are applied in order, each feeding the next:
      1. ``np.ndarray``  -> ``wfdb.Record`` whose ``p_signal`` is a copy of the array.
      2. ``wfdb.Record`` -> a single-element list holding that record.
      3. ``list``        -> ``{'temp_key': <the list>}``.
    Any other value (notably a dict) is returned unchanged, as the same object.

    Parameters:
    data (np.ndarray or wfdb.Record or list or dict): The input data to be formatted.
    Returns:
    dict: The formatted data (the ``'temp_key'`` wrapper is undone by format_data_output).
    """
    converted = data
    if isinstance(converted, np.ndarray):
        # copy so the record never shares memory with the caller's array
        converted = wfdb.Record(p_signal=converted.copy())
    if isinstance(converted, wfdb.Record):
        converted = [converted]
    if isinstance(converted, list):
        converted = dict(temp_key=converted)
    return converted
def format_data_output(data):
    """
    Undo the wrapping added by ``format_data_input`` to give a less redundant result.

    Unwrapping steps, applied in order:
      1. A dict whose only key is ``'temp_key'`` -> the value stored there.
      2. A single-element list -> that element.
      3. A ``wfdb.Record`` whose ``p_signal`` is one-dimensional -> that 1-D array.
    Anything not matched by a step is passed on unchanged.

    Args:
        data (dict, list, wfdb.Record, or ndarray): The data to be formatted,
            normally the (filtered) dict produced by ``format_data_input``.
    Returns:
        The formatted data.
    """
    # guard with isinstance so non-dict inputs (as documented) don't hit .keys()
    if isinstance(data, dict) and list(data) == ['temp_key']:
        data = data['temp_key']
    if isinstance(data, list) and len(data) == 1:
        data = data[0]
    if isinstance(data, wfdb.Record) and np.ndim(data.p_signal) == 1:
        # return the whole 1-D signal array, not just its first sample
        data = data.p_signal
    return data
def butterlowpass_filter(data, cutoff: float, fs: float, order: int = 5):
    """
    Apply a Butterworth low-pass filter to every channel of the input data.

    The input is normalized with ``format_data_input``. Each column of every
    record's ``p_signal`` is filtered, and the result is reduced with
    ``format_data_output``. The caller's object is not modified, because a
    deep copy is filtered.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The input data to be filtered. An ndarray is wrapped into a
        wfdb.Record, so a Record is returned for it.
    - cutoff: float
        The cutoff frequency of the filter, in the same unit as ``fs``.
        Must lie strictly between 0 and ``fs / 2``, otherwise ``butter`` raises.
    - fs: float
        The sampling frequency of the input data.
    - order: int, optional
        The order of the filter (default is 5).
    Returns:
    - data: (dict, list, wfdb.Record, or ndarray)
        The filtered output data.
    """
    # The coefficients depend only on (order, cutoff, fs): design them once
    # instead of once per channel of every record.
    nyq = 0.5 * fs
    b, a = butter(order, cutoff / nyq, btype='low', analog=False)

    data = format_data_input(copy.deepcopy(data))
    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            for idx in range(wfdb_obj.p_signal.shape[1]):
                # lfilter runs a single forward (causal) pass, so the output
                # carries the filter's phase delay.
                wfdb_obj.p_signal[:, idx] = lfilter(b, a, wfdb_obj.p_signal[:, idx])
    return format_data_output(data)
def lowess_filter(data, frac: float = 0.03, it: int = 1):
    """
    Apply LOWESS (locally weighted scatterplot smoothing) to every channel.

    Each column of every record's ``p_signal`` is smoothed against its sample
    index ``0..N-1``. The input is normalized with ``format_data_input``, and
    the result is reduced with ``format_data_output``. The caller's object is
    not modified, because a deep copy is filtered.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The data to be filtered.
    - frac (float):
        The fraction of the data used to compute each fitted value. Default is 0.03.
    - it (int):
        The number of robustifying iterations passed to statsmodels' lowess. Default is 1.
    Returns:
    (dict, list, wfdb.Record, or ndarray): The filtered data.
    """
    data = format_data_input(copy.deepcopy(data))
    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            signals = wfdb_obj.p_signal
            # Sample index as the x-axis. It is the same for every channel of
            # this record and already sorted, hence is_sorted=True.
            sample_idx = np.arange(signals.shape[0])
            for idx in range(signals.shape[1]):
                # return_sorted=False returns one fitted value per input sample,
                # in input order. The result therefore always matches the
                # channel length, unlike the (x, y) array of the sorted mode.
                signals[:, idx] = lowess(
                    signals[:, idx],
                    sample_idx,
                    frac=frac,
                    it=it,
                    is_sorted=True,
                    return_sorted=False,
                )
    return format_data_output(data)
def non_local_means_filter(data, filter_strength: int = 50, template_window_size: int = 7, search_window_size: int = 21):
    """
    Apply OpenCV's Non-Local Means denoising to every channel.

    Each column of every record's ``p_signal`` is handled as follows:
      1. Min-max scale the channel to 0..255.
      2. Treat it as an N x 1 single-channel 8-bit image and denoise it with
         ``cv.fastNlMeansDenoising``.
      3. Map the result back to the channel's original value range.

    Channels whose range is zero or not finite (constant, or containing
    NaN/inf) cannot be scaled and are left unchanged. The caller's object is
    not modified, because a deep copy is filtered.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The data to be filtered.
    - filter_strength (int):
        Parameter controlling the strength of the filtering process (OpenCV's ``h``). Default is 50.
    - template_window_size (int):
        Size in pixels of the template patch that is used to compute weights. Default is 7.
    - search_window_size (int):
        Size in pixels of the window that is used to compute weighted average for given pixel. Default is 21.
    Returns:
    (dict, list, wfdb.Record, or ndarray): The filtered data.
    """
    data = format_data_input(copy.deepcopy(data))
    for wfdb_objs in data.values():
        for wfdb_obj in wfdb_objs:
            for idx in range(wfdb_obj.p_signal.shape[1]):
                signal = wfdb_obj.p_signal[:, idx]
                lo = np.min(signal)
                span = np.max(signal) - lo
                if not np.isfinite(span) or span == 0:
                    # Scaling would divide by zero or produce NaN, whose uint8
                    # cast is undefined. Leave such a channel as it is.
                    continue
                # Round before the uint8 cast. Plain truncation biases every
                # sample downward by up to one quantisation step.
                scaled = np.rint((signal - lo) / span * 255).astype(np.uint8).reshape(-1, 1)
                denoised = cv.fastNlMeansDenoising(
                    scaled, None, filter_strength, template_window_size, search_window_size
                )
                # rescale the 8-bit result back to the channel's original range
                wfdb_obj.p_signal[:, idx] = denoised.reshape(-1) * span / 255 + lo
    return format_data_output(data)
def filter_data(data, filter_params: dict):
    """
    Apply the filters named in ``filter_params['names']`` to the input data.

    Selected filters run in this fixed order, each on the previous one's output:
      - ``'butterlowpass'``: needs ``'cutoff'``, ``'fs'`` and ``'order'``.
      - ``'loess'`` (LOWESS smoothing): needs ``'frac'`` and ``'it'``.
      - ``'non_local_means'``: needs ``'filter_strength'``,
        ``'template_window_size'`` and ``'search_window_size'``.
    A filter's parameters are looked up only if that filter is selected.
    Selection uses the ``in`` operator on ``filter_params['names']``.

    Parameters:
    - data: (dict, list, wfdb.Record, or ndarray)
        The input data to be filtered. It is never modified in place.
    - filter_params: dict
        The parameters of the filters to be applied (see above).
    Returns:
    - data: (dict, list, wfdb.Record, or ndarray)
        The filtered output data. If no known filter name is present, a
        warning is printed and an unfiltered deep copy of ``data`` is returned.
    """
    # Ordered dispatch table. The lambdas defer the parameter lookups until
    # the filter is actually used.
    steps = (
        ('butterlowpass', lambda d: butterlowpass_filter(
            d, filter_params['cutoff'], filter_params['fs'], filter_params['order'])),
        ('loess', lambda d: lowess_filter(
            d, filter_params['frac'], filter_params['it'])),
        ('non_local_means', lambda d: non_local_means_filter(
            d, filter_params['filter_strength'], filter_params['template_window_size'],
            filter_params['search_window_size'])),
    )

    applied = False
    for name, step in steps:
        if name in filter_params['names']:
            # Each filter deep-copies its input itself, so no up-front copy of
            # the (possibly large) dataset is needed here.
            data = step(data)
            applied = True

    if not applied:
        print("Warning: No valid filter names found in filter_params['names']. Data will be returned as is.")
        # Still hand back a copy, so the result never aliases the caller's data.
        return copy.deepcopy(data)
    return data
if __name__ == "__main__":
    # Smoke test: load the full dataset, then low-pass filter it in three of
    # the supported input forms.
    data = load_data(only_demographic=False, path_settings="./settings.json")

    # report how many records were loaded per category
    for cat_name, records in data.items():
        print(f"{cat_name}: {len(records)}")

    # Butterworth parameters
    fs = 500.0  # sampling frequency (presumably Hz -- confirm against the dataset)
    cutoff = 25
    order = 1

    # the whole dict of categories, one category's list, and a single record
    data_test = butterlowpass_filter(data, cutoff, fs, order)
    data_test = butterlowpass_filter(data['SB'], cutoff, fs, order)
    data_test = butterlowpass_filter(data['SB'][0], cutoff, fs, order)