# DSA_SS24/scripts/feature_extraction.py
# (source-viewer metadata: 178 lines, 6.9 KiB, Python)

from matplotlib import pyplot as plt
import wfdb.processing
import sys
import json
import scipy
import numpy as np
import neurokit2 as nk
import math
import time
def get_y_value(ecg_cleaned, indecies):
    """
    Look up the signal amplitude at each given sample position.

    Positions that are NaN are skipped, so the returned list can be shorter
    than ``indecies``. Non-NaN positions are truncated to ``int`` before
    indexing.

    Args:
        ecg_cleaned (list): The cleaned ECG signal (anything indexable by int).
        indecies (list): Sample positions; entries may be NaN.
    Returns:
        list: ``ecg_cleaned[int(p)]`` for every non-NaN position ``p``, in order.
    """
    amplitudes = []
    for position in indecies:
        if math.isnan(position):
            # No wave was located at this beat; nothing to look up.
            continue
        amplitudes.append(ecg_cleaned[int(position)])
    return amplitudes
def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0):
    """
    Calculate the R (QRS) and T axis of the ECG signal from leads I and aVF.

    Each axis is the angle, in degrees within (-180, 180], of the vector whose
    x component is the net amplitude in lead I and whose y component is the
    net amplitude in lead aVF (``np.arctan2(avf, i)``).

    Args:
        record (object): Record object; columns ``aVF`` and ``I`` of
            ``record.p_signal`` are read.
        wave_peak (dict): Wave peak sample indices (as returned by
            ``nk.ecg_delineate``); the 'ECG_Q_Peaks', 'ECG_S_Peaks' and
            'ECG_T_Peaks' entries are used. NaN entries are ignored.
        r_peak_idx (list): The R peak sample indices.
        sampling_rate (int): Sampling rate of the ECG signal, passed to
            ``nk.ecg_clean``.
        aVF (int): Column index of the aVF lead in ``record.p_signal``.
        I (int): Column index of lead I in ``record.p_signal``.
    Returns:
        tuple: ``(r_axis, t_axis)`` in degrees. A value is NaN when no valid
        peaks of a required type exist (``np.mean`` of an empty list is NaN).
    """
    # Clean each lead separately; the same wave indices are applied to both.
    ecg_avf_cleaned = nk.ecg_clean(record.p_signal[:, aVF], sampling_rate=sampling_rate)
    ecg_i_cleaned = nk.ecg_clean(record.p_signal[:, I], sampling_rate=sampling_rate)

    def mean_amplitude(signal, indices):
        # Average amplitude of `signal` at the non-NaN indices.
        return np.mean(get_y_value(signal, indices))

    # R axis: net QRS deflection per lead.
    # NOTE(review): Q/S values are signed amplitudes of the cleaned signal, so
    # subtracting them *adds* their depth when they are negative. The common
    # net-QRS definition is R - |Q| - |S| (= R + Q + S for negative Q/S);
    # confirm which convention is intended. Formula kept unchanged.
    net_qrs_i = mean_amplitude(ecg_i_cleaned, r_peak_idx) - (
        mean_amplitude(ecg_i_cleaned, wave_peak['ECG_Q_Peaks'])
        + mean_amplitude(ecg_i_cleaned, wave_peak['ECG_S_Peaks'])
    )
    net_qrs_avf = mean_amplitude(ecg_avf_cleaned, r_peak_idx) - (
        mean_amplitude(ecg_avf_cleaned, wave_peak['ECG_Q_Peaks'])
        + mean_amplitude(ecg_avf_cleaned, wave_peak['ECG_S_Peaks'])
    )

    # T axis: mean T-peak amplitude per lead. Each lead's value must come
    # from that lead's own signal (previously the two leads were swapped).
    net_t_i = mean_amplitude(ecg_i_cleaned, wave_peak['ECG_T_Peaks'])
    net_t_avf = mean_amplitude(ecg_avf_cleaned, wave_peak['ECG_T_Peaks'])

    r_axis = np.degrees(np.arctan2(net_qrs_avf, net_qrs_i))
    t_axis = np.degrees(np.arctan2(net_t_avf, net_t_i))
    return r_axis, t_axis
def extract_features(data_dict, sampling_rate=500, used_channels=(0, 1, 2, 3, 4, 5)):
    """
    Extract per-record features from ECG records grouped by label.

    Records whose age or gender comment is 'NaN' are skipped. Records for
    which ``nk.ecg_delineate`` raises are collected in a local
    ``failed_records`` list (reported via print) and skipped. Rate, interval
    and peak features are computed from column 0 of ``record.p_signal``; the
    R/T axes use columns 0 (lead I) and 5 (aVF).

    Args:
        data_dict (dict): Maps a label to a sequence of records. Each record
            must provide ``comments`` (age as the second space-separated
            token of ``comments[0]``, gender likewise in ``comments[1]``),
            ``p_signal`` (indexed as ``p_signal[:, lead]``) and ``record_name``.
        sampling_rate (int): Sampling rate of the signals in Hz.
        used_channels (sequence of int): Currently unused; kept for backward
            compatibility.
    Returns:
        dict: Maps ``record.record_name`` to a feature dict with keys 'y',
        'age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration',
        'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis'.
        'qrs_duration' and 'qt_length' are in samples, not seconds.
    """
    start_time = time.time()
    feature_data = {}
    failed_records = []
    for label, data in data_dict.items():
        print(f"Extracting features for {label} with {len(data)} data entries.")
        for data_idx, record in enumerate(data):
            # Progress report: time spent on the previous 100 records.
            if data_idx % 100 == 0:
                stop_time = time.time()
                print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
                start_time = time.time()

            # Demographic features
            age = record.comments[0].split(' ')[1]
            gender = record.comments[1].split(' ')[1]
            if age == 'NaN' or gender == 'NaN':
                continue
            features = {
                'y': label,
                'age': int(age),
                'gender': gender == 'Male',
            }

            # Signal features (lead in column 0)
            ecg_signal = record.p_signal[:, 0]
            ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
            _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
            r_peaks = rpeaks['ECG_R_Peaks']

            # Delineate on the cleaned signal: that is what ecg_delineate
            # expects as its first argument.
            try:
                _, waves_peak = nk.ecg_delineate(ecg_cleaned, r_peaks, sampling_rate=sampling_rate, method="peak")
            except Exception:
                failed_records.append(record.record_name)
                print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}")
                continue

            # TODO Other features and check features
            # Atrial rate = detected P waves per minute of recording. Wave
            # entries may be NaN (wave not found for that beat), so only
            # non-NaN entries count; the per-minute scaling uses the actual
            # signal length instead of assuming a 10 s record.
            p_peaks = np.asarray(waves_peak['ECG_P_Peaks'], dtype=float)
            duration_min = len(ecg_cleaned) / sampling_rate / 60
            # Key spelling 'artial_rate' kept as-is: downstream data uses it.
            features['artial_rate'] = np.count_nonzero(~np.isnan(p_peaks)) / duration_min
            features['ventricular_rate'] = np.mean(
                nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned))
            )

            # Interval features, in samples; NaN beats are ignored by nanmean.
            q_peak = waves_peak['ECG_Q_Peaks']
            s_peak = waves_peak['ECG_S_Peaks']
            features['qrs_duration'] = np.nanmean(np.array(s_peak) - np.array(q_peak))
            features['qt_length'] = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(q_peak))

            # Number of beats with a complete QRS complex (Q, R and S all found).
            features['qrs_count'] = sum(
                1
                for q, r, s in zip(q_peak, r_peaks, s_peak)
                if not (math.isnan(q) or math.isnan(r) or math.isnan(s))
            )
            features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))

            r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=sampling_rate, aVF=5, I=0)
            features['r_axis'] = r_axis
            features['t_axis'] = t_axis

            feature_data[record.record_name] = features
    return feature_data
def split_data(feature_data, split_ratio, seed=None):
    """
    Randomly split a feature dict into train/test/validation subsets.

    Args:
        feature_data (dict): Maps record name to its feature dict.
        split_ratio (sequence of float): ``split_ratio[0]`` is the train
            fraction and ``split_ratio[1]`` the test fraction; all remaining
            keys go to validation (further entries are not read).
        seed (int | None): If given, shuffle with ``np.random.default_rng(seed)``
            for a reproducible split; if None, use NumPy's global random
            state (the previous behavior).
    Returns:
        dict: ``{'train': ..., 'test': ..., 'validation': ...}``, each a
        sub-dict of ``feature_data``; every key lands in exactly one subset.
    """
    print(f"Splitting data with ratio {split_ratio}")
    print("Keys:")
    print(len(feature_data))

    keys = list(feature_data)
    if seed is None:
        np.random.shuffle(keys)
    else:
        np.random.default_rng(seed).shuffle(keys)

    n_keys = len(keys)
    # Round away float noise before truncating: 0.7 + 0.2 == 0.8999999999999999,
    # so int(10 * (0.7 + 0.2)) would give 8 instead of 9.
    train_end = int(round(n_keys * split_ratio[0], 9))
    test_end = int(round(n_keys * (split_ratio[0] + split_ratio[1]), 9))

    return {
        'train': {k: feature_data[k] for k in keys[:train_end]},
        'test': {k: feature_data[k] for k in keys[train_end:test_end]},
        'validation': {k: feature_data[k] for k in keys[test_end:]},
    }