DSA_SS24/scripts/feature_extraction.py

178 lines
6.9 KiB
Python
Raw Normal View History

2024-06-05 14:24:31 +02:00
from matplotlib import pyplot as plt
import wfdb.processing
import sys
import json
import scipy
import numpy as np
import neurokit2 as nk
import math
import time
def get_y_value(ecg_cleaned, indecies):
    """
    Look up signal amplitudes at the given sample positions.

    Positions that are NaN are skipped, so the returned list can be shorter
    than ``indecies``. Non-NaN positions are truncated with ``int`` before
    indexing.

    Args:
        ecg_cleaned (list): The cleaned ECG signal (any indexable sequence).
        indecies (list): Sample positions to read; may contain NaN.
    Returns:
        list: The values of ``ecg_cleaned`` at every non-NaN position, in
        the same order as ``indecies``.
    """
    values = []
    for position in indecies:
        if math.isnan(position):
            continue
        values.append(ecg_cleaned[int(position)])
    return values
def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0):
    """
    Calculate the R and T axis of the ECG signal.

    Leads I and aVF are each cleaned with ``nk.ecg_clean``. The given fiducial
    positions (Q, R, S and T peaks) are then used to read amplitudes from both
    cleaned leads. Each axis is the angle of the vector
    (lead I value, aVF value) computed with ``np.arctan2``, in degrees.

    Args:
        record (object): Record whose ``p_signal`` is indexed as
            ``p_signal[:, lead]``.
        wave_peak (dict): Delineation result containing at least
            'ECG_Q_Peaks', 'ECG_S_Peaks' and 'ECG_T_Peaks' (sample positions,
            NaN entries are skipped).
        r_peak_idx (list): Sample positions of the R peaks.
        sampling_rate (int): The sampling rate of the ECG signal.
        aVF (int): Column index of the aVF lead in ``p_signal``.
        I (int): Column index of lead I in ``p_signal``.
    Returns:
        tuple: ``(r_axis, t_axis)`` in degrees, within [-180, 180]. A value is
        NaN when no valid peaks were available to average.
    """
    # Clean each lead independently.
    ecg_signal_avf_cleaned = nk.ecg_clean(record.p_signal[:, aVF], sampling_rate=sampling_rate)
    ecg_signal_i_cleaned = nk.ecg_clean(record.p_signal[:, I], sampling_rate=sampling_rate)

    # R axis: average Q, R and S amplitudes in each lead.
    q_peaks_i_avg = np.mean(get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_Q_Peaks']))
    s_peaks_i_avg = np.mean(get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_S_Peaks']))
    r_peaks_i_avg = np.mean(get_y_value(ecg_signal_i_cleaned, r_peak_idx))
    q_peaks_avf_avg = np.mean(get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_Q_Peaks']))
    s_peaks_avf_avg = np.mean(get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_S_Peaks']))
    r_peaks_avf_avg = np.mean(get_y_value(ecg_signal_avf_cleaned, r_peak_idx))

    # Net QRS in each lead.
    # NOTE(review): Q/S values are signed amplitudes (typically negative), so
    # ``R - (Q + S)`` adds their magnitudes; the conventional net deflection
    # would be ``R + Q + S``. Behavior kept unchanged — confirm intent.
    net_qrs_i = r_peaks_i_avg - (q_peaks_i_avg + s_peaks_i_avg)
    net_qrs_avf = r_peaks_avf_avg - (q_peaks_avf_avg + s_peaks_avf_avg)

    # T axis: mean T-peak amplitude, each read from its own lead's signal.
    net_t_i = np.mean(get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_T_Peaks']))
    net_t_avf = np.mean(get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_T_Peaks']))

    # arctan2(y=aVF, x=I) in radians, converted to degrees.
    r_axis = np.arctan2(net_qrs_avf, net_qrs_i) * (180 / np.pi)
    t_axis = np.arctan2(net_t_avf, net_t_i) * (180 / np.pi)
    return r_axis, t_axis
def extract_features(data_dict, sampling_rate=500, used_channels=(0, 1, 2, 3, 4, 5)):
    """
    Extract demographic and ECG features for every record in ``data_dict``.

    Records whose age or gender is 'NaN' are skipped. Records for which
    ``nk.ecg_delineate`` raises are skipped too; their names are collected and
    the running count is printed.

    Args:
        data_dict (dict): Maps a label to a list of records. Each record must
            expose ``comments`` (age is the second space-separated token of
            ``comments[0]``, gender that of ``comments[1]``), ``p_signal``
            and ``record_name``.
        sampling_rate (int): Sampling rate of the signals in Hz.
        used_channels (sequence of int): Currently unused; kept for API
            compatibility.
    Returns:
        dict: Maps ``record.record_name`` to a feature dict with the keys
        'y', 'age', 'gender', 'artial_rate', 'ventricular_rate',
        'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis' and
        't_axis'.
    """
    start_time = time.time()
    feature_data = {}
    failed_records = []
    for label, data in data_dict.items():
        print(f"Extracting features for {label} with {len(data)} data entries.")
        for data_idx, record in enumerate(data):
            if data_idx % 100 == 0:
                stop_time = time.time()
                print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
                start_time = time.time()

            age = record.comments[0].split(' ')[1]
            gender = record.comments[1].split(' ')[1]
            if age == 'NaN' or gender == 'NaN':
                continue

            features = {
                'y': label,
                # Demographic features
                'age': int(age),
                'gender': gender == 'Male',
            }

            # Signal features are computed on column 0 of p_signal.
            ecg_signal = record.p_signal[:, 0]
            ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
            _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
            r_peaks = rpeaks['ECG_R_Peaks']

            # Delineate the same cleaned signal the R peaks were detected on.
            try:
                _, waves_peak = nk.ecg_delineate(ecg_cleaned, r_peaks, sampling_rate=sampling_rate, method="peak")
            except Exception:
                failed_records.append(record.record_name)
                print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}")
                continue

            # Rates in beats per minute over the recording duration. NaN
            # entries mark beats without a detected P peak and are not counted.
            duration_s = len(ecg_cleaned) / sampling_rate
            n_p_peaks = sum(1 for p in waves_peak['ECG_P_Peaks'] if not math.isnan(p))
            # Key spelling 'artial_rate' kept for downstream compatibility.
            features['artial_rate'] = n_p_peaks * 60 / duration_s
            features['ventricular_rate'] = np.mean(
                nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned))
            )

            q_peak = waves_peak['ECG_Q_Peaks']
            s_peak = waves_peak['ECG_S_Peaks']
            # Durations are differences of sample positions, i.e. in samples;
            # NaN differences (missing peaks) are ignored by nanmean.
            features['qrs_duration'] = np.nanmean(np.array(s_peak) - np.array(q_peak))
            features['qt_length'] = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(q_peak))

            # Number of beats where Q, R and S were all detected (none is NaN).
            features['qrs_count'] = sum(
                1
                for q, r, s in zip(q_peak, r_peaks, s_peak)
                if not (math.isnan(q) or math.isnan(r) or math.isnan(s))
            )
            features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))

            r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=sampling_rate, aVF=5, I=0)
            features['r_axis'] = r_axis
            features['t_axis'] = t_axis

            feature_data[record.record_name] = features
    return feature_data
def split_data(feature_data, split_ratio):
    """
    Randomly split records into train, test and validation sets.

    The keys are shuffled with ``np.random.shuffle``, so the split is
    reproducible through ``np.random.seed``. The first ``split_ratio[0]``
    fraction of keys goes to 'train', the next ``split_ratio[1]`` fraction to
    'test', and every remaining key to 'validation'. Further entries of
    ``split_ratio`` are ignored.

    Args:
        feature_data (dict): Maps record name to its feature dict.
        split_ratio (sequence of float): Fractions for train and test.
    Returns:
        dict: ``{'train': ..., 'test': ..., 'validation': ...}``, each mapping
        record name to the original (not copied) feature dict.
    """
    print(f"Splitting data with ratio {split_ratio}")
    print("Keys:")
    print(len(feature_data))

    keys = list(feature_data)
    np.random.shuffle(keys)
    n_keys = len(keys)

    def boundary(fraction):
        # floor(n_keys * fraction), rounded first so float noise such as
        # 0.7 + 0.2 == 0.8999999999999999 does not drop a whole record.
        return int(round(n_keys * fraction, 9))

    train_end = boundary(split_ratio[0])
    test_end = boundary(split_ratio[0] + split_ratio[1])

    return {
        'train': {k: feature_data[k] for k in keys[:train_end]},
        'test': {k: feature_data[k] for k in keys[train_end:test_end]},
        'validation': {k: feature_data[k] for k in keys[test_end:]},
    }