import wfdb.processing import numpy as np import neurokit2 as nk import math import time from multiprocessing import Pool import sqlite3 import random def get_y_value(ecg_cleaned, indecies): """ Get the y value of the ECG signal at the given indices. Args: ecg_cleaned (list): The cleaned ECG signal. indecies (list): The list of indices. Returns: list: The list of y values at the given indices. """ return [ecg_cleaned[int(i)] for i in indecies if not math.isnan(i)] def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0): """ Calculate the R and T axis of the ECG signal. Args: record (object): The record object containing the ECG signal. wave_peak (dict): The dictionary containing the wave peaks. r_peak_idx (list): The list containing the R peak indices. sampling_rate (int): The sampling rate of the ECG signal. aVF (int): The index of the aVF lead. I (int): The index of the I lead. Returns: tuple: The R and T axis of the ECG signal. """ # Calculate the net QRS in each lead ecg_signal_avf = record.p_signal[:, aVF] ecg_signal_avf_cleaned = nk.ecg_clean(ecg_signal_avf, sampling_rate=sampling_rate) ecg_signal_i = record.p_signal[:, I] ecg_signal_i_cleaned = nk.ecg_clean(ecg_signal_i, sampling_rate=sampling_rate) # r axis # get amplitude of peaks q_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_Q_Peaks']) s_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_S_Peaks']) r_peaks_avf = get_y_value(ecg_signal_avf_cleaned, r_peak_idx) q_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_Q_Peaks']) s_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_S_Peaks']) r_peaks_i = get_y_value(ecg_signal_i_cleaned, r_peak_idx) # calculate avg peal amplitude q_peaks_i_avg = np.mean(q_peaks_i) s_peaks_i_avg = np.mean(s_peaks_i) r_peaks_i_avg = np.mean(r_peaks_i) q_peaks_avf_avg = np.mean(q_peaks_avf) s_peaks_avf_avg = np.mean(s_peaks_avf) r_peaks_avf_avg = np.mean(r_peaks_avf) # Calculate net QRS in lead net_qrs_i = r_peaks_i_avg - (q_peaks_i_avg + s_peaks_i_avg) net_qrs_avf = r_peaks_avf_avg - (q_peaks_avf_avg + s_peaks_avf_avg) # t axis t_peaks_i = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_T_Peaks']) t_peaks_avf = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_T_Peaks']) net_t_i = np.mean(t_peaks_i) net_t_avf = np.mean(t_peaks_avf) #print("amplitude I", net_qrs.get(I, 0)) #print("amplitude aVF", net_qrs.get(aVF, 0)) # Calculate the R axis (Convert to degrees) r_axis = np.arctan2(net_qrs_avf, net_qrs_i) * (180 / np.pi) # Calculate the T axis (Convert to degrees) t_axis = np.arctan2(net_t_avf, net_t_i) * (180 / np.pi) return r_axis, t_axis def get_features(record, label, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], conn=None): """ Extracts the features from the record. Args: record (object): The record object containing the ECG signal. label (str): The label of the record. sampling_rate (int): The sampling rate of the ECG signal. Returns: dict: The dictionary containing the extracted features. """ age = record.comments[0].split(' ')[1] gender = record.comments[1].split(' ')[1] if age == 'NaN' or gender == 'NaN': return None features = {} # Extract the features features['y'] = label # Demographic features features['age'] = int(age) features['gender'] = True if gender == 'Male' else False # Signal features # Delineate the ECG signal try: ecg_signal = record.p_signal[:, 0] ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate) _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate) r_peaks = rpeaks['ECG_R_Peaks'] _, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak") except KeyboardInterrupt: conn.close() raise except: return None # TODO Other features and check features atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6 ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned))) features['artial_rate'] = atrial_rate features['ventricular_rate'] = ventricular_rate qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks'])) features['qrs_duration'] = qrs_duration qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks'])) features['qt_length'] = qt_interval q_peak = waves_peak['ECG_Q_Peaks'] s_peak = waves_peak['ECG_S_Peaks'] # check if q_peak, r_peak, s_peak are not nan and therefore a solid qrs complex exists qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False) features['qrs_count'] = qrs_count features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak)) r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=500, aVF=5, I=0) features['r_axis'] = r_axis features['t_axis'] = t_axis return features def exclude_already_extracted(data_dict, conn): """ Exclude the records that are already in the database. Args: data_dict (dict): The dictionary containing the data. Returns: dict: The dictionary containing the unique data. """ unique_data_dict = {} record_ids = [] for label, data in data_dict.items(): for record in data: record_ids.append(record.record_name) # get id column from database and subtract it from record_ids c = conn.cursor() c.execute('SELECT id FROM features') db_ids = c.fetchall() db_ids = [x[0] for x in db_ids] unique_ids = list(set(record_ids) - set(db_ids)) for label, data in data_dict.items(): unique_data_dict[label] = [record for record in data if record.record_name in unique_ids] return unique_data_dict def process_record(data_dict_item, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], conn=None, c=None): """ Process a record to extract the features. Args: data_dict_item (tuple): The tuple containing the data dictionary item. sampling_rate (int): The sampling rate of the ECG signal. used_channels (list): The list of used channels. conn (object): The connection object to the database. c (object): The cursor object. Returns: tuple: The tuple containing the record name and the extracted features. """ label, record, c, conn = data_dict_item features = get_features(record,label, sampling_rate=sampling_rate, used_channels=used_channels) if features is None: print(f"Failed to extract features for record {record.record_name}") return record.record_name, None # TODO: Insert the record into the database #feature_data[record_name] = features # Define the feature names feature_names = list(features.keys()) # Create the table if it doesn't exist c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})") # Insert the record into the table c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})", [record.record_name] + list(features.values())) conn.commit() return record.record_name, features def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]): """ Extracts the features from the data. Args: data_dict (dict): The dictionary containing the data. num_processes (int): The number of processes to use. """ start_time = time.time() failed_records = [] processed_records = 0 conn = sqlite3.connect('features.db') c = conn.cursor() # get unique data data_dict = exclude_already_extracted(data_dict, conn) for label, data in data_dict.items(): print(f"Extracting features for {label} with {len(data)} data entries.") with Pool(processes=num_processes) as pool: results = pool.map(process_record, [(label, record, c, conn) for record in data]) for result in results: processed_records += 1 if processed_records % 100 == 0: stop_time = time.time() print(f"Extracted features for {processed_records} records. Time taken: {stop_time - start_time:.2f}s") start_time = time.time() try: record_name, features = result except Exception as exc: print(f'{record_name} generated an exception: {exc}') else: if features is not None: print(f"Extracted features for record {record_name}") else: failed_records.append(record_name) print(f"Sum of failed records: {len(failed_records)}") conn.close() def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], limit=1000): """ Extracts the features from the data. Args: data_dict (dict): The dictionary containing the data. """ start_time = time.time() failed_records = [] conn = sqlite3.connect('features.db') c = conn.cursor() # get last file in the database try: # print how many records are already in the database c.execute('SELECT COUNT(*) FROM features') print("Records in DB:", c.fetchall()) c = conn.cursor() c.execute('SELECT id FROM features') db_ids = c.fetchall() db_ids = [x[0] for x in db_ids] except: db_ids = [] print("No last file in DB") for label, data in data_dict.items(): # get limit amount of radom samples out of data data = random.sample(data, min(len(data), limit)) print(f"Extracting features for {label} with {len(data)} data entries.") for data_idx, record in enumerate(data): # Skip the records that are already in the database if record.record_name in db_ids: continue # print current status if data_idx % 100 == 0: stop_time = time.time() print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s") start_time = time.time() # get the features features = get_features(record, label, sampling_rate=sampling_rate, used_channels=used_channels, conn=conn) if features is None: failed_records.append(record.record_name) print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}") continue # Define the feature names feature_names = list(features.keys()) # Create the table if it doesn't exist c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})") # Insert the record into the table c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})", [record.record_name] + list(features.values())) conn.commit() conn.close() def split_and_shuffle_data(split_ratio, db=None): """ Splits the data into training, test, and validation sets. Args: split_ratio (list): The ratio in which the data will be split into training, test, and validation sets. db (str): The name of the database. """ print(f"Splitting data with ratio {split_ratio}") if db is None: db = 'features.db' conn = sqlite3.connect(db) c = conn.cursor() # shuffle the data (rows) # Randomize the rows and create a new table with a new ROWID c.execute("CREATE TABLE IF NOT EXISTS features_random AS SELECT * FROM features ORDER BY RANDOM()") # Drop the old features table c.execute("DROP TABLE features") # Rename the new table to features c.execute("ALTER TABLE features_random RENAME TO features") # get length of the data c.execute('SELECT COUNT(*) FROM features') length = c.fetchone()[0] # split the data into training, test, and validation sets train_idx = int(length * split_ratio[0]) test_idx = int(length * (split_ratio[0] + split_ratio[1])) # create 3 tables for training, test, and validation sets # Create the train, test, and validation tables if they don't exist c.execute("CREATE TABLE IF NOT EXISTS train AS SELECT * FROM features WHERE 0") c.execute("CREATE TABLE IF NOT EXISTS test AS SELECT * FROM features WHERE 0") c.execute("CREATE TABLE IF NOT EXISTS validation AS SELECT * FROM features WHERE 0") # Insert rows into the train table c.execute("INSERT INTO train SELECT * FROM features WHERE ROWID <= ?", (train_idx,)) c.execute("INSERT INTO test SELECT * FROM features WHERE ROWID > ? AND ROWID <= ?", (train_idx, test_idx,)) c.execute("INSERT INTO validation SELECT * FROM features WHERE ROWID > ?", (test_idx,)) # drop the features table # c.execute('DROP TABLE features') conn.commit() conn.close()