diff --git a/scripts/feature_extraction.py b/scripts/feature_extraction.py new file mode 100644 index 0000000..c015142 --- /dev/null +++ b/scripts/feature_extraction.py @@ -0,0 +1,178 @@ + +from matplotlib import pyplot as plt +import wfdb.processing +import sys +import json +import scipy +import numpy as np +import neurokit2 as nk +import math +import time + +def get_y_value(ecg_cleaned, indecies): + """ + Get the y value of the ECG signal at the given indices. + Args: + ecg_cleaned (list): The cleaned ECG signal. + indecies (list): The list of indices. + Returns: + list: The list of y values at the given indices. + """ + return [ecg_cleaned[int(i)] for i in indecies if not math.isnan(i)] + + +def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0): + """ + Calculate the R and T axis of the ECG signal. + Args: + record (object): The record object containing the ECG signal. + wave_peak (dict): The dictionary containing the wave peaks. + r_peak_idx (list): The list containing the R peak indices. + sampling_rate (int): The sampling rate of the ECG signal. + aVF (int): The index of the aVF lead. + I (int): The index of the I lead. + Returns: + tuple: The R and T axis of the ECG signal. + """ + # Calculate the net QRS in each lead + ecg_signal_avf = record.p_signal[:, aVF] + ecg_signal_avf_cleaned = nk.ecg_clean(ecg_signal_avf, sampling_rate=sampling_rate) + + ecg_signal_i = record.p_signal[:, I] + ecg_signal_i_cleaned = nk.ecg_clean(ecg_signal_i, sampling_rate=sampling_rate) + + # r axis + # get amplitude of peaks + q_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_Q_Peaks']) + s_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_S_Peaks']) + r_peaks_avf = get_y_value(ecg_signal_avf_cleaned, r_peak_idx) + + q_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_Q_Peaks']) + s_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_S_Peaks']) + r_peaks_i = get_y_value(ecg_signal_i_cleaned, r_peak_idx) + + # calculate avg peal amplitude + q_peaks_i_avg = np.mean(q_peaks_i) + s_peaks_i_avg = np.mean(s_peaks_i) + r_peaks_i_avg = np.mean(r_peaks_i) + + q_peaks_avf_avg = np.mean(q_peaks_avf) + s_peaks_avf_avg = np.mean(s_peaks_avf) + r_peaks_avf_avg = np.mean(r_peaks_avf) + + # Calculate net QRS in lead + net_qrs_i = r_peaks_i_avg - (q_peaks_i_avg + s_peaks_i_avg) + net_qrs_avf = r_peaks_avf_avg - (q_peaks_avf_avg + s_peaks_avf_avg) + + # t axis + t_peaks_i = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_T_Peaks']) + t_peaks_avf = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_T_Peaks']) + + net_t_i = np.mean(t_peaks_i) + net_t_avf = np.mean(t_peaks_avf) + + #print("amplitude I", net_qrs.get(I, 0)) + #print("amplitude aVF", net_qrs.get(aVF, 0)) + + # Calculate the R axis (Convert to degrees) + r_axis = np.arctan2(net_qrs_avf, net_qrs_i) * (180 / np.pi) + + # Calculate the T axis (Convert to degrees) + t_axis = np.arctan2(net_t_avf, net_t_i) * (180 / np.pi) + + return r_axis, t_axis + + +def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]): + """ + Extracts the features from the data. + Args: + data_dict (dict): The dictionary containing the data. + Returns: + dict: The dictionary containing the extracted features. + """ + start_time = time.time() + feature_data = {} + failed_records = [] + + for label, data in data_dict.items(): + print(f"Extracting features for {label} with {len(data)} data entries.") + for data_idx, record in enumerate(data): + if data_idx % 100 == 0: + stop_time = time.time() + print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s") + start_time = time.time() + + age = record.comments[0].split(' ')[1] + gender = record.comments[1].split(' ')[1] + if age == 'NaN' or gender == 'NaN': + continue + features = {} + # Extract the features + features['y'] = label + # Demographic features + features['age'] = int(age) + features['gender'] = True if gender == 'Male' else False + # Signal features + + ecg_signal = record.p_signal[:, 0] + ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate) + _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate) + r_peaks = rpeaks['ECG_R_Peaks'] + # Delineate the ECG signal + try: + _, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak") + except: + failed_records.append(record.record_name) + print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}") + continue + + # TODO Other features and check features + atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6 + ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned))) + + features['artial_rate'] = atrial_rate + features['ventricular_rate'] = ventricular_rate + + qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks'])) + features['qrs_duration'] = qrs_duration + + qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks'])) + features['qt_length'] = qt_interval + + q_peak = waves_peak['ECG_Q_Peaks'] + s_peak = waves_peak['ECG_S_Peaks'] + # check if q_peak, r_peak, s_peak are not nan and therefore a solid qrs complex exists + qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False) + features['qrs_count'] = qrs_count + + features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak)) + + r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=500, aVF=5, I=0) + features['r_axis'] = r_axis + features['t_axis'] = t_axis + + # print the features + #print(json.dumps(features, indent=4)) + feature_data[record.record_name] = features + + return feature_data + + +def split_data(feature_data, split_ratio): + print(f"Splitting data with ratio {split_ratio}") + #flatten dictionary + feature_data = {k: v for k, v in feature_data.items()} + # print keys + print("Keys:") + print(len(feature_data.keys())) + # shuffle the data + keys = list(feature_data.keys()) + np.random.shuffle(keys) + # split the data + split_data = {} + split_data['train'] = {k: feature_data[k] for k in keys[:int(len(keys) * split_ratio[0])]} + split_data['test'] = {k: feature_data[k] for k in keys[int(len(keys) * split_ratio[0]):int(len(keys) * (split_ratio[0] + split_ratio[1]))]} + split_data['validation'] = {k: feature_data[k] for k in keys[int(len(keys) * (split_ratio[0] + split_ratio[1])):]} + + return split_data \ No newline at end of file diff --git a/scripts/generate_data.py b/scripts/generate_data.py index df772d0..f676b37 100644 --- a/scripts/generate_data.py +++ b/scripts/generate_data.py @@ -1,78 +1,180 @@ +""" +This script reads the WFDB records and extracts the diagnosis information from the comments. +The diagnosis information is then used to classify the records into categories. +The categories are defined by the diagnosis codes in the comments. +The records are then saved to pickle files based on the categories. +""" + import wfdb import os -import pickle -import bz2 import numpy as np -import pandas as pd +import pickle +import json -# Funktionen zum Bearbeiten der Daten +import feature_extraction + +# Functions def get_diagnosis_ids(record): + """ + Extracts diagnosis IDs from a record and returns them as a list. + Args: + record (object): The record object containing the diagnosis information. + Returns: + list: A list of diagnosis IDs extracted from the record. + """ + # Get the diagnosis diagnosis = record.comments[2] + # clean the diagnosis diagnosis = diagnosis.replace('Dx: ', '') list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')] return list_diagnosis -def get_diagnosis_name(diagnosis): - name = [diagnosis_lookup[diagnosis_lookup['Snomed_CT'] == x]['Full Name'].to_string(index=False) for x in diagnosis] - return name +def generate_raw_data(path_to_data, settings, max_counter=100_000): + """ + Generates the raw data from the WFDB records. + Args: + path_to_data (str): The path to the directory containing the WFDB records. + max_counter (int): The maximum number of records to read. + Returns: + dict: A dictionary containing the raw data. + """ + counter = 0 + failed_records = [] + categories = settings["labels"] -def filter_signal_df_on_diag(df_dict, diagnosis_dict, filter_codes_df): - filter_cod_li = list(filter_codes_df['Snomed_CT']) + [0] - filter_dict_diag = {k: v for k, v in diagnosis_dict.items() if all(i in filter_cod_li for i in v)} - filtered_df_dict = {i: df.loc[df.index.isin(filter_dict_diag.keys())] for i, df in df_dict.items()} - return filtered_df_dict - -# Verzeichnisse und Dateipfade -project_dir = 'C:/Users/arman/PycharmProjects/pythonProject/DSA/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0' -data_dir = project_dir + '/WFDBRecords' -path_diag_lookup = "C:/Users/arman/PycharmProjects/pythonProject/DSA/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0/ConditionNames_SNOMED-CT.csv" - -# Daten erkunden -diagnosis_lookup = pd.read_csv(path_diag_lookup) - -categories = { - 'SB': [426177001], - 'AFIB': [164889003, 164890007], - 'GSVT': [426761007, 713422000, 233896004, 233897008, 713422000], - 'SR': [426783006, 427393009] -} - -diag_dict = {k: [] for k in categories.keys()} -counter = 0 -max_counter = 100_000 - -for dir_th in os.listdir(data_dir): - path_to_1000_records = data_dir + '/' + dir_th - for dir_hd in os.listdir(path_to_1000_records): - path_to_100_records = path_to_1000_records + '/' + dir_hd - for record_name in os.listdir(path_to_100_records): - if '.hea' not in record_name: - continue - record_name = record_name.replace('.hea', '') - try: - record = wfdb.rdrecord(path_to_100_records + '/' + record_name) - diagnosis = np.array(get_diagnosis_ids(record)) - for category_name, category_codes in categories.items(): - if any(i in category_codes for i in diagnosis): - diag_dict[category_name].append(record) + diag_dict = {k: [] for k in categories.keys()} + # Loop through the records + for dir_th in os.listdir(path_to_data): + path_to_1000_records = path_to_data + '/' + dir_th + for dir_hd in os.listdir(path_to_1000_records): + path_to_100_records = path_to_1000_records + '/' + dir_hd + for record_name in os.listdir(path_to_100_records): + # check if .hea is in the record_name + if '.hea' not in record_name: + continue + # Remove the .hea extension from record_name + record_name = record_name.replace('.hea', '') + try: + # Read the record + record = wfdb.rdrecord(path_to_100_records + '/' + record_name) + # Get the diagnosis + diagnosis = np.array(get_diagnosis_ids(record)) + # check if diagnosis is a subset of one of the categories + for category_name, category_codes in categories.items(): + # if any of the diagnosis codes is in the category_codes + if any(i in category_codes for i in diagnosis): + diag_dict[category_name].append(record) + break + # Increment the counter of how many records we have read + counter += 1 + counter_bool = counter >= max_counter + # Break the loop if we have read max_counter records + if counter % 100 == 0: + print(f"Read {counter} records") + if counter_bool: break - counter += 1 - counter_bool = counter >= max_counter - if counter % 100 == 0: - print(f"Gelesen {counter} Datensätze") - if counter_bool: - break - except Exception as e: - print(f"Fehler beim Lesen des Datensatzes {record_name}: {e}") + except Exception as e: + failed_records.append(record_name) + print(f"Failed to read record {record_name} due to ValueError. Sum of failed records: {len(failed_records)}") + if counter_bool: + break if counter_bool: break - if counter_bool: - break + return diag_dict -for cat_name, records in diag_dict.items(): - print(f"Schreibe {cat_name} in eine komprimierte Datei mit {len(records)} Datensätzen") - if not os.path.exists('./data'): - os.makedirs('./data') - compressed_filename = f'./data/{cat_name}.pkl.bz2' - with bz2.open(compressed_filename, 'wb') as f: - pickle.dump(records, f) +def write_data(data_dict, path='./data', file_prefix=''): + """ + Writes the data to a pickle file. + Args: + data_dict (dict): The data to be written to the file. + dir_name (str): The directory where the file will be saved. + """ + # if path not exists create it + if not os.path.exists(path): + os.makedirs(path) + # write to pickle + for cat_name, data in data_dict.items(): + print(f"Writing {cat_name} to pickle with {len(data)} data entries.") + with open(f'{path}/{file_prefix}{cat_name}.pkl', 'wb') as f: + pickle.dump(data, f) + +def generate_feature_data(input_data_path, output_data_path, settings, prefix='feature_', split_ratio=None): + """ + Generates the feature data from the raw data. + Args: + input_data_path (str): The path to the directory containing the raw data. + output_data_path (str): The path to the directory where the feature data will be saved. + settings (dict): The settings dictionary. + prefix (str): The prefix to be added to the feature files. + split_ratio (list): The ratio in which the data will be split into training, test, and validation sets. + + """ + if split_ratio is None: + split_ratio = settings['split_ratio'] + data_dict = {} + for file in os.listdir(input_data_path): + if file.endswith(".pkl"): + print(f"Reading {file}") + with open(f'{input_data_path}/{file}', 'rb') as f: + data = pickle.load(f) + data_dict[file.replace('.pkl', '')] = data + # Extract the features + feature_data = feature_extraction.extract_features(data_dict) + # Split the data + splited_data = feature_extraction.split_data(feature_data, split_ratio) + if not os.path.exists(f'{output_data_path}/ml_dataset/'): + os.makedirs(f'{output_data_path}/ml_dataset/') + for file_name, data in splited_data.items(): + print(f"Writing {file_name} to pickle with {len(data)} data entries.") + with open(f'{output_data_path}/ml_dataset/{prefix}{file_name}', 'wb') as f: + pickle.dump(data, f) + return splited_data + +def main(gen_data=True, gen_features=True, split_ratio=None, settings_path='./settings.json', num_process_files=-1): + """ + Main function to generate the data. + Args: + gen_data (bool): If True, generates the raw data. + gen_features (bool): If True, generates the feature data. + split_ratio (list): The ratio in which the data will be split into training, test, and validation sets. + settings_path (str): The path to the settings file. + num_process_files (int): The maximum number of records to process. + Returns: + dict: The generated data. + """ + ret_data = None + settings = json.load(open(settings_path)) + if num_process_files < 0: + num_process_files = 100_000 + if split_ratio is None: + split_ratio = settings['split_ratio'] + if gen_data: + raw_data_dir = settings["wfdb_path"] + '/WFDBRecords' + data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files) + write_data(data_dict, path=settings["data_path"]) + ret_data = data_dict + if gen_features: + feature_data_dict = generate_feature_data(settings["data_path"], settings["data_path"], settings, split_ratio=split_ratio) + ret_data = feature_data_dict + + return ret_data + + +# -------------------------------------------------------------------------------- +# Generate the data +# -------------------------------------------------------------------------------- +if __name__ == '__main__': + """ + The following categories are used to classify the records: + + SB, Sinusbradykardie + AFIB, Vorhofflimmern und Vorhofflattern (AFL) + GSVT, supraventrikulärer Tachykardie, Vorhoftachykardie, AV-Knoten-Reentry-Tachykardie, AV-Reentry-Tachykardie, Vorhofschrittmacher + SR Sinusrhythmus und Sinusunregelmäßigkeiten + """ + # SB, AFIB, GSVT, SR + # new GSVT, AFIB, SR, SB + # Generate the data + main(gen_data=True, gen_features=False, num_process_files=100_000) + #main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], num_process_files=100_000) + print("Data generation completed.") diff --git a/settings.json b/settings.json index 3446038..e401605 100644 --- a/settings.json +++ b/settings.json @@ -1,11 +1,15 @@ { - "data_path_comment": "Path to the data folder. This is the folder where the data is stored.", + "wfdb_path_comment": "Path to the WFDB data. This is the folder where the WFDB data is stored.", + "wfdb_path": "C:/Studium/dsa/large_12_ecg_data/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0", + "data_path_comment": "Path to the data folder. This is the folder where the genereated data is stored.", "data_path": "C:/Studium/dsa/data", "labels_comment": "Labels for the different classes. The labels are the SNOMED CT codes.", "labels": { - "SB": [426177001], - "AFIB": [164889003, 164890007], "GSVT": [426761007, 713422000, 233896004, 233897008, 713422000], - "SR": [426783006, 427393009] - } + "AFIB": [164889003, 164890007], + "SR": [426783006, 427393009], + "SB": [426177001] + }, + "split_ratio_comment": "Ratio for the train-test-validation split. The first value is the ratio for the training data, the second value is the ratio for the test data.", + "split_ratio": [0.8, 0.1, 0.1] } \ No newline at end of file