new data generator process

main
Felix Jan Michael Mucha 2024-06-05 14:24:31 +02:00
parent 5eb1c44403
commit ea37e07108
3 changed files with 352 additions and 68 deletions

View File

@ -0,0 +1,178 @@
from matplotlib import pyplot as plt
import wfdb.processing
import sys
import json
import scipy
import numpy as np
import neurokit2 as nk
import math
import time
def get_y_value(ecg_cleaned, indecies):
"""
Get the y value of the ECG signal at the given indices.
Args:
ecg_cleaned (list): The cleaned ECG signal.
indecies (list): The list of indices.
Returns:
list: The list of y values at the given indices.
"""
return [ecg_cleaned[int(i)] for i in indecies if not math.isnan(i)]
def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0):
"""
Calculate the R and T axis of the ECG signal.
Args:
record (object): The record object containing the ECG signal.
wave_peak (dict): The dictionary containing the wave peaks.
r_peak_idx (list): The list containing the R peak indices.
sampling_rate (int): The sampling rate of the ECG signal.
aVF (int): The index of the aVF lead.
I (int): The index of the I lead.
Returns:
tuple: The R and T axis of the ECG signal.
"""
# Calculate the net QRS in each lead
ecg_signal_avf = record.p_signal[:, aVF]
ecg_signal_avf_cleaned = nk.ecg_clean(ecg_signal_avf, sampling_rate=sampling_rate)
ecg_signal_i = record.p_signal[:, I]
ecg_signal_i_cleaned = nk.ecg_clean(ecg_signal_i, sampling_rate=sampling_rate)
# r axis
# get amplitude of peaks
q_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_Q_Peaks'])
s_peaks_avf = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_S_Peaks'])
r_peaks_avf = get_y_value(ecg_signal_avf_cleaned, r_peak_idx)
q_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_Q_Peaks'])
s_peaks_i = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_S_Peaks'])
r_peaks_i = get_y_value(ecg_signal_i_cleaned, r_peak_idx)
# calculate avg peal amplitude
q_peaks_i_avg = np.mean(q_peaks_i)
s_peaks_i_avg = np.mean(s_peaks_i)
r_peaks_i_avg = np.mean(r_peaks_i)
q_peaks_avf_avg = np.mean(q_peaks_avf)
s_peaks_avf_avg = np.mean(s_peaks_avf)
r_peaks_avf_avg = np.mean(r_peaks_avf)
# Calculate net QRS in lead
net_qrs_i = r_peaks_i_avg - (q_peaks_i_avg + s_peaks_i_avg)
net_qrs_avf = r_peaks_avf_avg - (q_peaks_avf_avg + s_peaks_avf_avg)
# t axis
t_peaks_i = get_y_value(ecg_signal_avf_cleaned, wave_peak['ECG_T_Peaks'])
t_peaks_avf = get_y_value(ecg_signal_i_cleaned, wave_peak['ECG_T_Peaks'])
net_t_i = np.mean(t_peaks_i)
net_t_avf = np.mean(t_peaks_avf)
#print("amplitude I", net_qrs.get(I, 0))
#print("amplitude aVF", net_qrs.get(aVF, 0))
# Calculate the R axis (Convert to degrees)
r_axis = np.arctan2(net_qrs_avf, net_qrs_i) * (180 / np.pi)
# Calculate the T axis (Convert to degrees)
t_axis = np.arctan2(net_t_avf, net_t_i) * (180 / np.pi)
return r_axis, t_axis
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
"""
Extracts the features from the data.
Args:
data_dict (dict): The dictionary containing the data.
Returns:
dict: The dictionary containing the extracted features.
"""
start_time = time.time()
feature_data = {}
failed_records = []
for label, data in data_dict.items():
print(f"Extracting features for {label} with {len(data)} data entries.")
for data_idx, record in enumerate(data):
if data_idx % 100 == 0:
stop_time = time.time()
print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
start_time = time.time()
age = record.comments[0].split(' ')[1]
gender = record.comments[1].split(' ')[1]
if age == 'NaN' or gender == 'NaN':
continue
features = {}
# Extract the features
features['y'] = label
# Demographic features
features['age'] = int(age)
features['gender'] = True if gender == 'Male' else False
# Signal features
ecg_signal = record.p_signal[:, 0]
ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
_, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
r_peaks = rpeaks['ECG_R_Peaks']
# Delineate the ECG signal
try:
_, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak")
except:
failed_records.append(record.record_name)
print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}")
continue
# TODO Other features and check features
atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6
ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)))
features['artial_rate'] = atrial_rate
features['ventricular_rate'] = ventricular_rate
qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks']))
features['qrs_duration'] = qrs_duration
qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks']))
features['qt_length'] = qt_interval
q_peak = waves_peak['ECG_Q_Peaks']
s_peak = waves_peak['ECG_S_Peaks']
# check if q_peak, r_peak, s_peak are not nan and therefore a solid qrs complex exists
qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False)
features['qrs_count'] = qrs_count
features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))
r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=500, aVF=5, I=0)
features['r_axis'] = r_axis
features['t_axis'] = t_axis
# print the features
#print(json.dumps(features, indent=4))
feature_data[record.record_name] = features
return feature_data
def split_data(feature_data, split_ratio):
print(f"Splitting data with ratio {split_ratio}")
#flatten dictionary
feature_data = {k: v for k, v in feature_data.items()}
# print keys
print("Keys:")
print(len(feature_data.keys()))
# shuffle the data
keys = list(feature_data.keys())
np.random.shuffle(keys)
# split the data
split_data = {}
split_data['train'] = {k: feature_data[k] for k in keys[:int(len(keys) * split_ratio[0])]}
split_data['test'] = {k: feature_data[k] for k in keys[int(len(keys) * split_ratio[0]):int(len(keys) * (split_ratio[0] + split_ratio[1]))]}
split_data['validation'] = {k: feature_data[k] for k in keys[int(len(keys) * (split_ratio[0] + split_ratio[1])):]}
return split_data

View File

@ -1,78 +1,180 @@
"""
This script reads the WFDB records and extracts the diagnosis information from the comments.
The diagnosis information is then used to classify the records into categories.
The categories are defined by the diagnosis codes in the comments.
The records are then saved to pickle files based on the categories.
"""
import wfdb
import os
import pickle
import bz2
import numpy as np
import pandas as pd
import pickle
import json
# Funktionen zum Bearbeiten der Daten
import feature_extraction
# Functions
def get_diagnosis_ids(record):
"""
Extracts diagnosis IDs from a record and returns them as a list.
Args:
record (object): The record object containing the diagnosis information.
Returns:
list: A list of diagnosis IDs extracted from the record.
"""
# Get the diagnosis
diagnosis = record.comments[2]
# clean the diagnosis
diagnosis = diagnosis.replace('Dx: ', '')
list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')]
return list_diagnosis
def get_diagnosis_name(diagnosis):
name = [diagnosis_lookup[diagnosis_lookup['Snomed_CT'] == x]['Full Name'].to_string(index=False) for x in diagnosis]
return name
def generate_raw_data(path_to_data, settings, max_counter=100_000):
"""
Generates the raw data from the WFDB records.
Args:
path_to_data (str): The path to the directory containing the WFDB records.
max_counter (int): The maximum number of records to read.
Returns:
dict: A dictionary containing the raw data.
"""
counter = 0
failed_records = []
categories = settings["labels"]
def filter_signal_df_on_diag(df_dict, diagnosis_dict, filter_codes_df):
filter_cod_li = list(filter_codes_df['Snomed_CT']) + [0]
filter_dict_diag = {k: v for k, v in diagnosis_dict.items() if all(i in filter_cod_li for i in v)}
filtered_df_dict = {i: df.loc[df.index.isin(filter_dict_diag.keys())] for i, df in df_dict.items()}
return filtered_df_dict
# Verzeichnisse und Dateipfade
project_dir = 'C:/Users/arman/PycharmProjects/pythonProject/DSA/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0'
data_dir = project_dir + '/WFDBRecords'
path_diag_lookup = "C:/Users/arman/PycharmProjects/pythonProject/DSA/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0/ConditionNames_SNOMED-CT.csv"
# Daten erkunden
diagnosis_lookup = pd.read_csv(path_diag_lookup)
categories = {
'SB': [426177001],
'AFIB': [164889003, 164890007],
'GSVT': [426761007, 713422000, 233896004, 233897008, 713422000],
'SR': [426783006, 427393009]
}
diag_dict = {k: [] for k in categories.keys()}
counter = 0
max_counter = 100_000
for dir_th in os.listdir(data_dir):
path_to_1000_records = data_dir + '/' + dir_th
diag_dict = {k: [] for k in categories.keys()}
# Loop through the records
for dir_th in os.listdir(path_to_data):
path_to_1000_records = path_to_data + '/' + dir_th
for dir_hd in os.listdir(path_to_1000_records):
path_to_100_records = path_to_1000_records + '/' + dir_hd
for record_name in os.listdir(path_to_100_records):
# check if .hea is in the record_name
if '.hea' not in record_name:
continue
# Remove the .hea extension from record_name
record_name = record_name.replace('.hea', '')
try:
# Read the record
record = wfdb.rdrecord(path_to_100_records + '/' + record_name)
# Get the diagnosis
diagnosis = np.array(get_diagnosis_ids(record))
# check if diagnosis is a subset of one of the categories
for category_name, category_codes in categories.items():
# if any of the diagnosis codes is in the category_codes
if any(i in category_codes for i in diagnosis):
diag_dict[category_name].append(record)
break
# Increment the counter of how many records we have read
counter += 1
counter_bool = counter >= max_counter
# Break the loop if we have read max_counter records
if counter % 100 == 0:
print(f"Gelesen {counter} Datensätze")
print(f"Read {counter} records")
if counter_bool:
break
except Exception as e:
print(f"Fehler beim Lesen des Datensatzes {record_name}: {e}")
failed_records.append(record_name)
print(f"Failed to read record {record_name} due to ValueError. Sum of failed records: {len(failed_records)}")
if counter_bool:
break
if counter_bool:
break
return diag_dict
for cat_name, records in diag_dict.items():
print(f"Schreibe {cat_name} in eine komprimierte Datei mit {len(records)} Datensätzen")
if not os.path.exists('./data'):
os.makedirs('./data')
compressed_filename = f'./data/{cat_name}.pkl.bz2'
with bz2.open(compressed_filename, 'wb') as f:
pickle.dump(records, f)
def write_data(data_dict, path='./data', file_prefix=''):
"""
Writes the data to a pickle file.
Args:
data_dict (dict): The data to be written to the file.
dir_name (str): The directory where the file will be saved.
"""
# if path not exists create it
if not os.path.exists(path):
os.makedirs(path)
# write to pickle
for cat_name, data in data_dict.items():
print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
with open(f'{path}/{file_prefix}{cat_name}.pkl', 'wb') as f:
pickle.dump(data, f)
def generate_feature_data(input_data_path, output_data_path, settings, prefix='feature_', split_ratio=None):
"""
Generates the feature data from the raw data.
Args:
input_data_path (str): The path to the directory containing the raw data.
output_data_path (str): The path to the directory where the feature data will be saved.
settings (dict): The settings dictionary.
prefix (str): The prefix to be added to the feature files.
split_ratio (list): The ratio in which the data will be split into training, test, and validation sets.
"""
if split_ratio is None:
split_ratio = settings['split_ratio']
data_dict = {}
for file in os.listdir(input_data_path):
if file.endswith(".pkl"):
print(f"Reading {file}")
with open(f'{input_data_path}/{file}', 'rb') as f:
data = pickle.load(f)
data_dict[file.replace('.pkl', '')] = data
# Extract the features
feature_data = feature_extraction.extract_features(data_dict)
# Split the data
splited_data = feature_extraction.split_data(feature_data, split_ratio)
if not os.path.exists(f'{output_data_path}/ml_dataset/'):
os.makedirs(f'{output_data_path}/ml_dataset/')
for file_name, data in splited_data.items():
print(f"Writing {file_name} to pickle with {len(data)} data entries.")
with open(f'{output_data_path}/ml_dataset/{prefix}{file_name}', 'wb') as f:
pickle.dump(data, f)
return splited_data
def main(gen_data=True, gen_features=True, split_ratio=None, settings_path='./settings.json', num_process_files=-1):
"""
Main function to generate the data.
Args:
gen_data (bool): If True, generates the raw data.
gen_features (bool): If True, generates the feature data.
split_ratio (list): The ratio in which the data will be split into training, test, and validation sets.
settings_path (str): The path to the settings file.
num_process_files (int): The maximum number of records to process.
Returns:
dict: The generated data.
"""
ret_data = None
settings = json.load(open(settings_path))
if num_process_files < 0:
num_process_files = 100_000
if split_ratio is None:
split_ratio = settings['split_ratio']
if gen_data:
raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files)
write_data(data_dict, path=settings["data_path"])
ret_data = data_dict
if gen_features:
feature_data_dict = generate_feature_data(settings["data_path"], settings["data_path"], settings, split_ratio=split_ratio)
ret_data = feature_data_dict
return ret_data
# --------------------------------------------------------------------------------
# Generate the data
# --------------------------------------------------------------------------------
if __name__ == '__main__':
"""
The following categories are used to classify the records:
SB, Sinusbradykardie
AFIB, Vorhofflimmern und Vorhofflattern (AFL)
GSVT, supraventrikulärer Tachykardie, Vorhoftachykardie, AV-Knoten-Reentry-Tachykardie, AV-Reentry-Tachykardie, Vorhofschrittmacher
SR Sinusrhythmus und Sinusunregelmäßigkeiten
"""
# SB, AFIB, GSVT, SR
# new GSVT, AFIB, SR, SB
# Generate the data
main(gen_data=True, gen_features=False, num_process_files=100_000)
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], num_process_files=100_000)
print("Data generation completed.")

View File

@ -1,11 +1,15 @@
{
"data_path_comment": "Path to the data folder. This is the folder where the data is stored.",
"wfdb_path_comment": "Path to the WFDB data. This is the folder where the WFDB data is stored.",
"wfdb_path": "C:/Studium/dsa/large_12_ecg_data/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0",
"data_path_comment": "Path to the data folder. This is the folder where the genereated data is stored.",
"data_path": "C:/Studium/dsa/data",
"labels_comment": "Labels for the different classes. The labels are the SNOMED CT codes.",
"labels": {
"SB": [426177001],
"AFIB": [164889003, 164890007],
"GSVT": [426761007, 713422000, 233896004, 233897008, 713422000],
"SR": [426783006, 427393009]
}
"AFIB": [164889003, 164890007],
"SR": [426783006, 427393009],
"SB": [426177001]
},
"split_ratio_comment": "Ratio for the train-test-validation split. The first value is the ratio for the training data, the second value is the ratio for the test data.",
"split_ratio": [0.8, 0.1, 0.1]
}