Merge branch 'main' of https://gitty.informatik.hs-mannheim.de/1826514/DSA_SS24
commit f924d962b8

@@ -0,0 +1,178 @@
from matplotlib import pyplot as plt
import wfdb.processing
import sys
import json
import scipy
import numpy as np
import neurokit2 as nk
import math
import time


def get_y_value(ecg_cleaned, indices):
    """
    Get the y values of the ECG signal at the given indices.

    Args:
        ecg_cleaned (list): The cleaned ECG signal.
        indices (list): The list of indices.

    Returns:
        list: The list of y values at the given indices.
    """
    return [ecg_cleaned[int(i)] for i in indices if not math.isnan(i)]
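# Minimal doctest-style sketch of get_y_value (made-up values); NaN indices are skipped:
#   >>> get_y_value([0.1, 0.5, -0.2, 0.3], [0, 2.0, float("nan")])
#   [0.1, -0.2]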


def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0):
    """
    Calculate the R and T axis of the ECG signal.

    Args:
        record (object): The record object containing the ECG signal.
        wave_peak (dict): The dictionary containing the wave peaks.
        r_peak_idx (list): The list containing the R peak indices.
        sampling_rate (int): The sampling rate of the ECG signal.
        aVF (int): The channel index of the aVF lead.
        I (int): The channel index of lead I.

    Returns:
        tuple: The R and T axis of the ECG signal in degrees.
    """
    # Clean the aVF and lead I signals
    ecg_signal_avf = record.p_signal[:, aVF]
    ecg_signal_avf_cleaned = nk.ecg_clean(ecg_signal_avf, sampling_rate=sampling_rate)

    ecg_signal_i = record.p_signal[:, I]
    ecg_signal_i_cleaned = nk.ecg_clean(ecg_signal_i, sampling_rate=sampling_rate)

    # R axis
    # Get the amplitudes of the Q, S and R peaks in both leads
    q_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_Q_Peaks'])
    s_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_S_Peaks'])
    r_peaks_avf = get_y_value(ecg_signal_avf_cleaned, r_peak_idx)

    q_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_Q_Peaks'])
    s_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_S_Peaks'])
    r_peaks_i = get_y_value(ecg_signal_i_cleaned, r_peak_idx)

    # Calculate the average peak amplitudes
    q_peaks_i_avg = np.mean(q_peaks_i)
    s_peaks_i_avg = np.mean(s_peaks_i)
    r_peaks_i_avg = np.mean(r_peaks_i)

    q_peaks_avf_avg = np.mean(q_peaks_avf)
    s_peaks_avf_avg = np.mean(s_peaks_avf)
    r_peaks_avf_avg = np.mean(r_peaks_avf)

    # Calculate the net QRS deflection in each lead
    net_qrs_i = r_peaks_i_avg - (q_peaks_i_avg + s_peaks_i_avg)
    net_qrs_avf = r_peaks_avf_avg - (q_peaks_avf_avg + s_peaks_avf_avg)

    # T axis
    # Use each lead's own cleaned signal for its T-peak amplitudes
    t_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_T_Peaks'])
    t_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_T_Peaks'])

    net_t_i = np.mean(t_peaks_i)
    net_t_avf = np.mean(t_peaks_avf)

    # Calculate the R axis (convert to degrees)
    r_axis = np.arctan2(net_qrs_avf, net_qrs_i) * (180 / np.pi)

    # Calculate the T axis (convert to degrees)
    t_axis = np.arctan2(net_t_avf, net_t_i) * (180 / np.pi)

    return r_axis, t_axis
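# Worked example for the axis computation above (illustrative numbers, not from a record):
# with equal net QRS deflections in lead I and aVF, the mean electrical axis lands
# halfway between lead I (0 degrees) and aVF (+90 degrees):
#   >>> float(np.arctan2(1.0, 1.0) * (180 / np.pi))
#   45.0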


def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
    """
    Extracts the features from the data.

    Args:
        data_dict (dict): The dictionary containing the data, keyed by label.
        sampling_rate (int): The sampling rate of the ECG signals.
        used_channels (list): The channel indices to use (currently not used in the function body).

    Returns:
        dict: The dictionary containing the extracted features, keyed by record name.
    """
    start_time = time.time()
    feature_data = {}
    failed_records = []

    for label, data in data_dict.items():
        print(f"Extracting features for {label} with {len(data)} data entries.")
        for data_idx, record in enumerate(data):
            if data_idx % 100 == 0:
                stop_time = time.time()
                print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
                start_time = time.time()

            age = record.comments[0].split(' ')[1]
            gender = record.comments[1].split(' ')[1]
            if age == 'NaN' or gender == 'NaN':
                continue
            features = {}
            # Extract the features
            features['y'] = label
            # Demographic features
            features['age'] = int(age)
            features['gender'] = gender == 'Male'  # True for male, False otherwise
            # Signal features
            ecg_signal = record.p_signal[:, 0]
            ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
            _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
            r_peaks = rpeaks['ECG_R_Peaks']
            # Delineate the ECG signal
            try:
                _, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak")
            except Exception:
                failed_records.append(record.record_name)
                print(f"Failed to extract features for record {record.record_name}. Sum of failed records: {len(failed_records)}")
                continue

            # TODO Other features and check features
            # Approximate atrial rate in bpm (P-peak count * 6, assuming 10 second recordings)
            atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6
            ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)))

            features['atrial_rate'] = atrial_rate
            features['ventricular_rate'] = ventricular_rate

            # QRS duration in samples (Q peak to S peak)
            qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks']))
            features['qrs_duration'] = qrs_duration

            # QT interval in samples (Q peak to T offset)
            qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks']))
            features['qt_length'] = qt_interval

            q_peak = waves_peak['ECG_Q_Peaks']
            s_peak = waves_peak['ECG_S_Peaks']
            # Check whether Q, R and S peaks are all present (not NaN) and therefore a solid QRS complex exists
            qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False)
            features['qrs_count'] = qrs_count

            features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))

            r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=sampling_rate, aVF=5, I=0)
            features['r_axis'] = r_axis
            features['t_axis'] = t_axis

            # print the features
            #print(json.dumps(features, indent=4))
            feature_data[record.record_name] = features

    return feature_data
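# Minimal usage sketch (assumes `records` is a list of wfdb Record objects stored under
# the label 'AFIB'; the record name and values shown are illustrative, not real output):
#   >>> feature_data = extract_features({'AFIB': records}, sampling_rate=500)
#   >>> feature_data['JS00001']['y'], feature_data['JS00001']['age']
#   ('AFIB', 64)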


def split_data(feature_data, split_ratio):
    """
    Splits the feature data into training, test and validation sets.

    Args:
        feature_data (dict): The extracted features, keyed by record name.
        split_ratio (list): The ratios for the training, test and validation sets.

    Returns:
        dict: A dictionary with the keys 'train', 'test' and 'validation'.
    """
    print(f"Splitting data with ratio {split_ratio}")
    # Work on a shallow copy of the feature dictionary
    feature_data = {k: v for k, v in feature_data.items()}
    # Print the number of records
    print("Keys:")
    print(len(feature_data.keys()))
    # Shuffle the record names
    keys = list(feature_data.keys())
    np.random.shuffle(keys)
    # Split the data
    split_data = {}
    split_data['train'] = {k: feature_data[k] for k in keys[:int(len(keys) * split_ratio[0])]}
    split_data['test'] = {k: feature_data[k] for k in keys[int(len(keys) * split_ratio[0]):int(len(keys) * (split_ratio[0] + split_ratio[1]))]}
    split_data['validation'] = {k: feature_data[k] for k in keys[int(len(keys) * (split_ratio[0] + split_ratio[1])):]}

    return split_data
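# Usage sketch (ratios as in settings.json):
#   >>> splits = split_data(feature_data, [0.8, 0.1, 0.1])
#   >>> sorted(splits.keys())
#   ['test', 'train', 'validation']
# With 1000 records this yields roughly 800 train, 100 test and 100 validation entries.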


@@ -1,78 +1,180 @@
"""
|
||||
This script reads the WFDB records and extracts the diagnosis information from the comments.
|
||||
The diagnosis information is then used to classify the records into categories.
|
||||
The categories are defined by the diagnosis codes in the comments.
|
||||
The records are then saved to pickle files based on the categories.
|
||||
"""
|
||||
|
||||
import wfdb
|
||||
import os
|
||||
import pickle
|
||||
import bz2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import json
|
||||
|
||||
# Funktionen zum Bearbeiten der Daten
|
||||
import feature_extraction
|
||||
|
||||
# Functions
|
||||
def get_diagnosis_ids(record):
|
||||
"""
|
||||
Extracts diagnosis IDs from a record and returns them as a list.
|
||||
Args:
|
||||
record (object): The record object containing the diagnosis information.
|
||||
Returns:
|
||||
list: A list of diagnosis IDs extracted from the record.
|
||||
"""
|
||||
# Get the diagnosis
|
||||
diagnosis = record.comments[2]
|
||||
# clean the diagnosis
|
||||
diagnosis = diagnosis.replace('Dx: ', '')
|
||||
list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')]
|
||||
return list_diagnosis
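# Example of the expected comment format (codes taken from the labels used in this
# project; the combination is illustrative):
#   record.comments[2] == 'Dx: 164889003,426177001'
#   get_diagnosis_ids(record)  ->  [164889003, 426177001]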


def get_diagnosis_name(diagnosis):
    """
    Looks up the full names of the given SNOMED CT diagnosis codes.

    Note: expects a module-level DataFrame `diagnosis_lookup` loaded from
    ConditionNames_SNOMED-CT.csv with the columns 'Snomed_CT' and 'Full Name'.

    Args:
        diagnosis (list): A list of SNOMED CT diagnosis codes.

    Returns:
        list: The full names of the diagnosis codes.
    """
    name = [diagnosis_lookup[diagnosis_lookup['Snomed_CT'] == x]['Full Name'].to_string(index=False) for x in diagnosis]
    return name


def filter_signal_df_on_diag(df_dict, diagnosis_dict, filter_codes_df):
    """
    Keeps only the records whose diagnosis codes are all contained in filter_codes_df.
    Note: not called anywhere in this script.
    """
    filter_cod_li = list(filter_codes_df['Snomed_CT']) + [0]
    filter_dict_diag = {k: v for k, v in diagnosis_dict.items() if all(i in filter_cod_li for i in v)}
    filtered_df_dict = {i: df.loc[df.index.isin(filter_dict_diag.keys())] for i, df in df_dict.items()}
    return filtered_df_dict


def generate_raw_data(path_to_data, settings, max_counter=100_000):
    """
    Generates the raw data from the WFDB records.

    Args:
        path_to_data (str): The path to the directory containing the WFDB records.
        settings (dict): The settings dictionary containing the category labels.
        max_counter (int): The maximum number of records to read.

    Returns:
        dict: A dictionary mapping each category to a list of WFDB records.
    """
    counter = 0
    counter_bool = False
    failed_records = []
    categories = settings["labels"]
    diag_dict = {k: [] for k in categories.keys()}

    # Loop through the records (two directory levels below path_to_data)
    for dir_th in os.listdir(path_to_data):
        path_to_1000_records = path_to_data + '/' + dir_th
        for dir_hd in os.listdir(path_to_1000_records):
            path_to_100_records = path_to_1000_records + '/' + dir_hd
            for record_name in os.listdir(path_to_100_records):
                # Only process header files
                if '.hea' not in record_name:
                    continue
                # Remove the .hea extension from record_name
                record_name = record_name.replace('.hea', '')
                try:
                    # Read the record
                    record = wfdb.rdrecord(path_to_100_records + '/' + record_name)
                    # Get the diagnosis codes
                    diagnosis = np.array(get_diagnosis_ids(record))
                    # Check whether any diagnosis code falls into one of the categories
                    for category_name, category_codes in categories.items():
                        # if any of the diagnosis codes is in the category_codes
                        if any(i in category_codes for i in diagnosis):
                            diag_dict[category_name].append(record)
                            break
                    # Increment the counter of how many records we have read
                    counter += 1
                    counter_bool = counter >= max_counter
                    if counter % 100 == 0:
                        print(f"Read {counter} records")
                    # Break the loop if we have read max_counter records
                    if counter_bool:
                        break
                except Exception as e:
                    failed_records.append(record_name)
                    print(f"Failed to read record {record_name}: {e}. Sum of failed records: {len(failed_records)}")
            if counter_bool:
                break
        if counter_bool:
            break
    return diag_dict
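# Usage sketch (directory names are illustrative; the nested loops above expect two
# levels of subdirectories below path_to_data, each holding WFDB .hea/.mat pairs):
#   <path_to_data>/01/010/<record>.hea
#   >>> raw = generate_raw_data(settings["wfdb_path"] + '/WFDBRecords', settings, max_counter=1000)
#   >>> sorted(raw.keys())
#   ['AFIB', 'GSVT', 'SB', 'SR']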


def write_data(data_dict, path='./data', file_prefix=''):
    """
    Writes the data to pickle files, one file per category.

    Args:
        data_dict (dict): The data to be written, keyed by category name.
        path (str): The directory where the files will be saved.
        file_prefix (str): A prefix added to each file name.
    """
    # Create the output directory if it does not exist
    if not os.path.exists(path):
        os.makedirs(path)
    # Write each category to its own pickle file
    for cat_name, data in data_dict.items():
        print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
        with open(f'{path}/{file_prefix}{cat_name}.pkl', 'wb') as f:
            pickle.dump(data, f)
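# Usage sketch (path taken from settings.json; writes e.g. <data_path>/SB.pkl,
# <data_path>/AFIB.pkl, ... which generate_feature_data below reads back in):
#   >>> write_data(raw, path=settings["data_path"])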


def generate_feature_data(input_data_path, output_data_path, settings, prefix='feature_', split_ratio=None):
    """
    Generates the feature data from the raw data.

    Args:
        input_data_path (str): The path to the directory containing the raw data.
        output_data_path (str): The path to the directory where the feature data will be saved.
        settings (dict): The settings dictionary.
        prefix (str): The prefix to be added to the feature files.
        split_ratio (list): The ratio in which the data will be split into training, test, and validation sets.

    Returns:
        dict: The split feature data.
    """
    if split_ratio is None:
        split_ratio = settings['split_ratio']
    data_dict = {}
    for file in os.listdir(input_data_path):
        if file.endswith(".pkl"):
            print(f"Reading {file}")
            with open(f'{input_data_path}/{file}', 'rb') as f:
                data = pickle.load(f)
                data_dict[file.replace('.pkl', '')] = data
    # Extract the features
    feature_data = feature_extraction.extract_features(data_dict)
    # Split the data into train, test and validation sets
    split_data_dict = feature_extraction.split_data(feature_data, split_ratio)
    if not os.path.exists(f'{output_data_path}/ml_dataset/'):
        os.makedirs(f'{output_data_path}/ml_dataset/')
    for file_name, data in split_data_dict.items():
        print(f"Writing {file_name} to pickle with {len(data)} data entries.")
        with open(f'{output_data_path}/ml_dataset/{prefix}{file_name}', 'wb') as f:
            pickle.dump(data, f)
    return split_data_dict
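# Usage sketch (paths come from settings.json; writes <data_path>/ml_dataset/feature_train,
# feature_test and feature_validation):
#   >>> splits = generate_feature_data(settings["data_path"], settings["data_path"], settings)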


def main(gen_data=True, gen_features=True, split_ratio=None, settings_path='./settings.json', num_process_files=-1):
    """
    Main function to generate the data.

    Args:
        gen_data (bool): If True, generates the raw data.
        gen_features (bool): If True, generates the feature data.
        split_ratio (list): The ratio in which the data will be split into training, test, and validation sets.
        settings_path (str): The path to the settings file.
        num_process_files (int): The maximum number of records to process.

    Returns:
        dict: The generated data.
    """
    ret_data = None
    with open(settings_path) as f:
        settings = json.load(f)
    if num_process_files < 0:
        num_process_files = 100_000
    if split_ratio is None:
        split_ratio = settings['split_ratio']
    if gen_data:
        raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
        data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files)
        write_data(data_dict, path=settings["data_path"])
        ret_data = data_dict
    if gen_features:
        feature_data_dict = generate_feature_data(settings["data_path"], settings["data_path"], settings, split_ratio=split_ratio)
        ret_data = feature_data_dict

    return ret_data


# --------------------------------------------------------------------------------
# Generate the data
# --------------------------------------------------------------------------------
if __name__ == '__main__':
    """
    The following categories are used to classify the records:

    SB: sinus bradycardia
    AFIB: atrial fibrillation and atrial flutter (AFL)
    GSVT: supraventricular tachycardia, atrial tachycardia, AV nodal reentrant tachycardia, AV reentrant tachycardia, atrial pacing
    SR: sinus rhythm and sinus irregularity
    """
    # Old label order: SB, AFIB, GSVT, SR
    # New label order: GSVT, AFIB, SR, SB
    # Generate the data
    main(gen_data=True, gen_features=False, num_process_files=100_000)
    #main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], num_process_files=100_000)
    print("Data generation completed.")


@@ -1,11 +1,15 @@
{
    "wfdb_path_comment": "Path to the WFDB data. This is the folder where the WFDB data is stored.",
    "wfdb_path": "C:/Studium/dsa/large_12_ecg_data/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0",
    "data_path_comment": "Path to the data folder. This is the folder where the generated data is stored.",
    "data_path": "C:/Studium/dsa/data",
    "labels_comment": "Labels for the different classes. The labels are the SNOMED CT codes.",
    "labels": {
        "GSVT": [426761007, 713422000, 233896004, 233897008, 713422000],
        "AFIB": [164889003, 164890007],
        "SR": [426783006, 427393009],
        "SB": [426177001]
    },
    "split_ratio_comment": "Ratio for the train-test-validation split. The first value is the ratio for the training data, the second value is the ratio for the test data, the third value is the ratio for the validation data.",
    "split_ratio": [0.8, 0.1, 0.1]
}