import wfdb.processing
|
|
import numpy as np
|
|
import neurokit2 as nk
|
|
import math
|
|
import time
|
|
from multiprocessing import Pool
|
|
import sqlite3
|
|
import random
|
|
|
|
def get_y_value(ecg_cleaned, indecies):
    """
    Look up the amplitude of the cleaned ECG signal at each valid index.

    Args:
        ecg_cleaned (list): The cleaned ECG signal.
        indecies (list): Sample indices; NaN entries are skipped.

    Returns:
        list: The signal amplitudes at the given (non-NaN) indices.
    """
    values = []
    for idx in indecies:
        # delineation marks missing peaks as NaN — skip those
        if math.isnan(idx):
            continue
        values.append(ecg_cleaned[int(idx)])
    return values
|
|
|
|
|
|
def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0):
    """
    Calculate the R and T axis (in degrees) of the ECG signal.

    Args:
        record (object): The record object containing the ECG signal.
        wave_peak (dict): Wave peak indices as returned by nk.ecg_delineate.
        r_peak_idx (list): The R peak indices.
        sampling_rate (int): The sampling rate of the ECG signal in Hz.
        aVF (int): Column index of the aVF lead in record.p_signal.
        I (int): Column index of the I lead in record.p_signal.

    Returns:
        tuple: (r_axis, t_axis), both in degrees.
    """
    # Clean both leads; the axes are derived from net amplitudes in leads I and aVF.
    ecg_signal_avf = record.p_signal[:, aVF]
    ecg_signal_avf_cleaned = nk.ecg_clean(ecg_signal_avf, sampling_rate=sampling_rate)

    ecg_signal_i = record.p_signal[:, I]
    ecg_signal_i_cleaned = nk.ecg_clean(ecg_signal_i, sampling_rate=sampling_rate)

    # R axis: amplitudes of the Q, S and R peaks in each lead
    q_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_Q_Peaks'])
    s_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_S_Peaks'])
    r_peaks_avf = get_y_value(ecg_signal_avf_cleaned, r_peak_idx)

    q_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_Q_Peaks'])
    s_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_S_Peaks'])
    r_peaks_i = get_y_value(ecg_signal_i_cleaned, r_peak_idx)

    # average peak amplitude per lead
    q_peaks_i_avg = np.mean(q_peaks_i)
    s_peaks_i_avg = np.mean(s_peaks_i)
    r_peaks_i_avg = np.mean(r_peaks_i)

    q_peaks_avf_avg = np.mean(q_peaks_avf)
    s_peaks_avf_avg = np.mean(s_peaks_avf)
    r_peaks_avf_avg = np.mean(r_peaks_avf)

    # Net QRS deflection in each lead
    net_qrs_i = r_peaks_i_avg - (q_peaks_i_avg + s_peaks_i_avg)
    net_qrs_avf = r_peaks_avf_avg - (q_peaks_avf_avg + s_peaks_avf_avg)

    # T axis.
    # BUGFIX: the leads were swapped here — the lead-I T amplitudes were
    # read from the aVF signal and vice versa, corrupting the T axis.
    t_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_T_Peaks'])
    t_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_T_Peaks'])

    net_t_i = np.mean(t_peaks_i)
    net_t_avf = np.mean(t_peaks_avf)

    # Convert the (I, aVF) net-amplitude vector to an angle in degrees
    r_axis = np.arctan2(net_qrs_avf, net_qrs_i) * (180 / np.pi)

    t_axis = np.arctan2(net_t_avf, net_t_i) * (180 / np.pi)

    return r_axis, t_axis
|
|
|
|
def get_features(record, label, sampling_rate=500, used_channels=None, conn=None):
    """
    Extract demographic and signal features from a single record.

    Args:
        record (object): The record object containing the ECG signal.
        label (str): The label of the record.
        sampling_rate (int): The sampling rate of the ECG signal in Hz.
        used_channels (list): Channel indices (accepted for interface
            compatibility; not referenced directly in this function).
        conn (object): Optional sqlite3 connection, closed on interrupt.

    Returns:
        dict: The extracted features, or None if extraction failed.
    """
    # BUGFIX: avoid a shared mutable default argument
    if used_channels is None:
        used_channels = [0, 1, 2, 3, 4, 5]

    age = record.comments[0].split(' ')[1]
    gender = record.comments[1].split(' ')[1]
    if age == 'NaN' or gender == 'NaN':
        return None

    features = {}
    # Target label
    features['y'] = label

    # Demographic features
    features['age'] = int(age)
    features['gender'] = True if gender == 'Male' else False

    # Signal features: delineate the ECG signal (lead 0)
    try:
        ecg_signal = record.p_signal[:, 0]
        ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
        _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
        r_peaks = rpeaks['ECG_R_Peaks']
        _, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak")
    except KeyboardInterrupt:
        # BUGFIX: conn defaults to None — only close it when one was given
        if conn is not None:
            conn.close()
        raise
    except Exception:
        # Delineation can fail on noisy/short signals; treat as unextractable.
        return None

    # TODO Other features and check features
    # NOTE(review): assumes a 10 s recording (count * 6 = per minute) — confirm
    atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6
    ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)))

    # key spelled 'artial_rate' historically; kept to match the existing DB schema
    features['artial_rate'] = atrial_rate
    features['ventricular_rate'] = ventricular_rate

    # mean Q->S distance in samples (NaN-aware over missing peaks)
    qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks']))
    features['qrs_duration'] = qrs_duration

    # mean Q->T-offset distance in samples
    qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks']))
    features['qt_length'] = qt_interval

    q_peak = waves_peak['ECG_Q_Peaks']
    s_peak = waves_peak['ECG_S_Peaks']
    # count beats where Q, R and S are all present, i.e. a solid QRS complex exists
    qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False)
    features['qrs_count'] = qrs_count

    features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))

    # BUGFIX: forward the caller's sampling_rate instead of hard-coding 500
    r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=sampling_rate, aVF=5, I=0)
    features['r_axis'] = r_axis
    features['t_axis'] = t_axis

    return features
|
|
|
|
def exclude_already_extracted(data_dict, conn):
    """
    Exclude the records whose ids are already stored in the database.

    Args:
        data_dict (dict): Mapping label -> list of record objects.
        conn (sqlite3.Connection): Open connection to the features database.

    Returns:
        dict: Same structure as data_dict, restricted to unseen records.
    """
    # BUGFIX: on a fresh database the 'features' table does not exist yet,
    # which used to raise sqlite3.OperationalError; treat that as "nothing
    # extracted so far". Also use a set for O(1) membership tests.
    try:
        c = conn.cursor()
        c.execute('SELECT id FROM features')
        db_ids = {row[0] for row in c.fetchall()}
    except sqlite3.OperationalError:
        db_ids = set()

    unique_data_dict = {}
    for label, data in data_dict.items():
        unique_data_dict[label] = [record for record in data if record.record_name not in db_ids]
    return unique_data_dict
|
|
|
|
|
|
def process_record(data_dict_item, sampling_rate=500, used_channels=None, conn=None, c=None):
    """
    Process a record: extract its features and persist them to the database.

    Args:
        data_dict_item (tuple): (label, record, cursor, connection).
        sampling_rate (int): The sampling rate of the ECG signal in Hz.
        used_channels (list): Channel indices forwarded to get_features.
        conn (object): Unused; the connection is taken from data_dict_item
            (kept for interface compatibility).
        c (object): Unused; the cursor is taken from data_dict_item
            (kept for interface compatibility).

    Returns:
        tuple: (record_name, features) — features is None on failure.
    """
    # BUGFIX: avoid a shared mutable default argument
    if used_channels is None:
        used_channels = [0, 1, 2, 3, 4, 5]

    # NOTE(review): the cursor/connection travel inside data_dict_item and
    # shadow the keyword parameters; sqlite3 objects are not picklable, so
    # this only works in-process — confirm before relying on multiprocessing.
    label, record, c, conn = data_dict_item
    # BUGFIX: forward conn so get_features can close it on KeyboardInterrupt
    features = get_features(record, label, sampling_rate=sampling_rate, used_channels=used_channels, conn=conn)
    if features is None:
        print(f"Failed to extract features for record {record.record_name}")
        return record.record_name, None

    # Define the feature names (they double as the table's column names)
    feature_names = list(features.keys())
    # Create the table if it doesn't exist
    c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})")
    # Insert the record into the table (values are parameterized)
    c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})",
              [record.record_name] + list(features.values()))
    conn.commit()

    return record.record_name, features
|
|
|
|
def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_channels=None):
    """
    Extract features from all records using a process pool and store them.

    Args:
        data_dict (dict): Mapping label -> list of record objects.
        num_processes (int): The number of worker processes to use.
        sampling_rate (int): The sampling rate of the ECG signal in Hz.
        used_channels (list): Channel indices forwarded to the workers.
    """
    # BUGFIX: avoid a shared mutable default argument
    if used_channels is None:
        used_channels = [0, 1, 2, 3, 4, 5]

    start_time = time.time()
    failed_records = []
    processed_records = 0

    conn = sqlite3.connect('features.db')
    c = conn.cursor()
    # skip records that are already in the database
    data_dict = exclude_already_extracted(data_dict, conn)
    for label, data in data_dict.items():
        print(f"Extracting features for {label} with {len(data)} data entries.")
        # NOTE(review): sqlite3 connections/cursors are not picklable, so
        # shipping (c, conn) to pool workers is fragile — confirm this runs;
        # workers should ideally open their own connection.
        with Pool(processes=num_processes) as pool:
            results = pool.map(process_record, [(label, record, c, conn) for record in data])
        for result in results:
            processed_records += 1
            if processed_records % 100 == 0:
                stop_time = time.time()
                print(f"Extracted features for {processed_records} records. Time taken: {stop_time - start_time:.2f}s")
                start_time = time.time()
            # BUGFIX: unpack outside any try/except — the old except handler
            # referenced record_name before it was bound (NameError).
            record_name, features = result
            if features is not None:
                print(f"Extracted features for record {record_name}")
            else:
                failed_records.append(record_name)
                print(f"Sum of failed records: {len(failed_records)}")
    conn.close()
|
|
|
|
|
|
|
|
|
|
def extract_features(data_dict, sampling_rate=500, used_channels=None, limit=1000):
    """
    Extract features sequentially and store them in features.db.

    Args:
        data_dict (dict): Mapping label -> list of record objects.
        sampling_rate (int): The sampling rate of the ECG signal in Hz.
        used_channels (list): Channel indices forwarded to get_features.
        limit (int): Maximum number of random samples drawn per label.
    """
    # BUGFIX: avoid a shared mutable default argument
    if used_channels is None:
        used_channels = [0, 1, 2, 3, 4, 5]

    start_time = time.time()
    failed_records = []

    conn = sqlite3.connect('features.db')
    c = conn.cursor()

    # Collect the ids already in the database so they can be skipped below.
    try:
        # print how many records are already in the database
        c.execute('SELECT COUNT(*) FROM features')
        print("Records in DB:", c.fetchall())

        c.execute('SELECT id FROM features')
        # BUGFIX: a set makes the per-record membership test below O(1)
        db_ids = {row[0] for row in c.fetchall()}
    except sqlite3.OperationalError:
        # BUGFIX: catch only the "table does not exist" case, not everything
        db_ids = set()
        print("No last file in DB")

    for label, data in data_dict.items():
        # draw at most `limit` random samples out of data for this label
        data = random.sample(data, min(len(data), limit))
        print(f"Extracting features for {label} with {len(data)} data entries.")
        for data_idx, record in enumerate(data):
            # Skip the records that are already in the database
            if record.record_name in db_ids:
                continue
            # periodic progress output
            if data_idx % 100 == 0:
                stop_time = time.time()
                print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
                start_time = time.time()
            # get the features
            features = get_features(record, label, sampling_rate=sampling_rate, used_channels=used_channels, conn=conn)
            if features is None:
                failed_records.append(record.record_name)
                print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}")
                continue

            # Feature names double as the table's column names
            feature_names = list(features.keys())
            # Create the table if it doesn't exist
            c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})")
            # Insert the record into the table (values are parameterized)
            c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})",
                      [record.record_name] + list(features.values()))
            conn.commit()
    conn.close()
|
|
|
|
|
|
|
|
def split_and_shuffle_data(split_ratio, db=None):
    """
    Shuffle the features table and split it into train/test/validation tables.

    Args:
        split_ratio (list): Fractions for the training, test, and validation
            sets (should sum to 1).
        db (str): The database filename; defaults to 'features.db'.
    """
    print(f"Splitting data with ratio {split_ratio}")

    if db is None:
        db = 'features.db'
    conn = sqlite3.connect(db)
    c = conn.cursor()

    # Shuffle by materialising the rows in random order with fresh ROWIDs.
    # BUGFIX: drop any stale copy first — CREATE TABLE IF NOT EXISTS would
    # silently keep an old shuffle, and the subsequent DROP of `features`
    # would then destroy the newer data.
    c.execute("DROP TABLE IF EXISTS features_random")
    c.execute("CREATE TABLE features_random AS SELECT * FROM features ORDER BY RANDOM()")
    # Replace the old features table with the shuffled one
    c.execute("DROP TABLE features")
    c.execute("ALTER TABLE features_random RENAME TO features")

    # Split indices based on the shuffled row count
    c.execute('SELECT COUNT(*) FROM features')
    length = c.fetchone()[0]
    train_idx = int(length * split_ratio[0])
    test_idx = int(length * (split_ratio[0] + split_ratio[1]))

    # BUGFIX: recreate the split tables instead of appending to existing
    # ones, so rerunning the split does not duplicate rows.
    c.execute("DROP TABLE IF EXISTS train")
    c.execute("DROP TABLE IF EXISTS test")
    c.execute("DROP TABLE IF EXISTS validation")
    # Empty clones of the features schema (WHERE 0 copies no rows)
    c.execute("CREATE TABLE train AS SELECT * FROM features WHERE 0")
    c.execute("CREATE TABLE test AS SELECT * FROM features WHERE 0")
    c.execute("CREATE TABLE validation AS SELECT * FROM features WHERE 0")

    # Partition the shuffled rows by ROWID into the three sets
    c.execute("INSERT INTO train SELECT * FROM features WHERE ROWID <= ?", (train_idx,))
    c.execute("INSERT INTO test SELECT * FROM features WHERE ROWID > ? AND ROWID <= ?", (train_idx, test_idx,))
    c.execute("INSERT INTO validation SELECT * FROM features WHERE ROWID > ?", (test_idx,))

    conn.commit()
    conn.close()