"""
|
|
This script reads the WFDB records and extracts the diagnosis information from the comments.
|
|
The diagnosis information is then used to classify the records into categories.
|
|
The categories are defined by the diagnosis codes in the comments.
|
|
The records are then saved to pickle files based on the categories.
|
|
"""
|
|
|
|
import wfdb
import os
import numpy as np
import pickle
import json
import multiprocessing

import feature_extraction
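
# Expected shape of settings.json (illustrative sketch only; the actual diagnosis
# codes depend on the dataset's code mapping):
#
#   {
#       "wfdb_path": "<path to the dataset root containing WFDBRecords>",
#       "data_path": "./data",
#       "split_ratio": [0.8, 0.1, 0.1],
#       "labels": {"SB": [...], "AFIB": [...], "GSVT": [...], "SR": [...]}
#   }
#
# "labels" maps each category name to the list of diagnosis codes belonging to it.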

# Functions
def get_diagnosis_ids(record):
    """
    Extracts diagnosis IDs from a record and returns them as a list.

    Args:
        record (object): The record object containing the diagnosis information.

    Returns:
        list: A list of diagnosis IDs extracted from the record.
    """
    # The diagnosis line is expected in the third header comment, e.g. "Dx: <code>,<code>,...".
    diagnosis = record.comments[2]
    # Strip the "Dx: " prefix and parse the comma-separated codes as integers.
    diagnosis = diagnosis.replace('Dx: ', '')
    list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')]
    return list_diagnosis
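
# Example (illustrative values only): for a record whose header comments contain
#     "Dx: 164889003,59118001"
# get_diagnosis_ids(record) returns [164889003, 59118001].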

def generate_raw_data(path_to_data, settings, max_counter=100_000, only_ids=False):
    """
    Generates the raw data from the WFDB records.

    Args:
        path_to_data (str): The path to the directory containing the WFDB records.
        settings (dict): The settings dictionary; its "labels" entry maps category names to diagnosis codes.
        max_counter (int): The maximum number of records to read.
        only_ids (bool): If True, collect only the diagnosis IDs per record name instead of the full records.

    Returns:
        dict: A dictionary containing the raw data, keyed by category name
            (or by record name when only_ids is True).
    """
    counter = 0
    counter_bool = False
    failed_records = []
    categories = settings["labels"]

    if only_ids:
        diag_dict = {}
    else:
        diag_dict = {k: [] for k in categories.keys()}

    # Loop through the records (two levels of grouping directories, then the record files)
    for dir_th in os.listdir(path_to_data):
        path_to_1000_records = path_to_data + '/' + dir_th
        for dir_hd in os.listdir(path_to_1000_records):
            path_to_100_records = path_to_1000_records + '/' + dir_hd
            for record_name in os.listdir(path_to_100_records):
                # Only process header files
                if '.hea' not in record_name:
                    continue
                # Remove the .hea extension from record_name
                record_name = record_name.replace('.hea', '')
                try:
                    # Read the record
                    record = wfdb.rdrecord(path_to_100_records + '/' + record_name)
                    # Get the diagnosis codes
                    diagnosis = np.array(get_diagnosis_ids(record))

                    if only_ids:
                        diag_dict[record_name] = diagnosis
                    else:
                        # Assign the record to the first category that shares a diagnosis code with it
                        for category_name, category_codes in categories.items():
                            if any(i in category_codes for i in diagnosis):
                                diag_dict[category_name].append(record)
                                break

                    # Increment the counter of how many records we have read
                    counter += 1
                    counter_bool = counter >= max_counter
                    if counter % 100 == 0:
                        print(f"Read {counter} records")
                    # Break the loop if we have read max_counter records
                    if counter_bool:
                        break
                except Exception as e:
                    failed_records.append(record_name)
                    print(f"Failed to read record {record_name}: {e}. Total failed records: {len(failed_records)}")
            if counter_bool:
                break
        if counter_bool:
            break

    return diag_dict
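
# Directory layout assumed by generate_raw_data above (illustrative):
#     <path_to_data>/<group_dir>/<subgroup_dir>/<record>.hea   (plus the matching signal file)
# i.e. two levels of grouping directories, reflected in the variable names
# path_to_1000_records and path_to_100_records.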

def write_data(data_dict, path='./data', file_prefix='', only_ids=False):
    """
    Writes the data to pickle files.

    Args:
        data_dict (dict): The data to be written to the file(s).
        path (str): The directory where the file(s) will be saved.
        file_prefix (str): Prefix prepended to the output file names.
        only_ids (bool): If True, data_dict is written to a single pickle file;
            otherwise one pickle file is written per category.
    """
    # Create the output directory if it does not exist
    if not os.path.exists(path):
        os.makedirs(path)

    if only_ids:
        # write to pickle
        print(f"Writing diagnosis IDs to pickle with {len(data_dict)} data entries.")
        with open(f'{path}/{file_prefix}.pkl', 'wb') as f:
            pickle.dump(data_dict, f)
        return

    # write one pickle per category
    for cat_name, data in data_dict.items():
        print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
        with open(f'{path}/{file_prefix}{cat_name}.pkl', 'wb') as f:
            pickle.dump(data, f)
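
# Output naming produced by write_data above (illustrative): one pickle per category,
# e.g. <data_path>/SB.pkl, <data_path>/AFIB.pkl, ..., or a single
# <data_path>/<file_prefix>.pkl (e.g. diagnosis.pkl) when only_ids is True.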

def generate_feature_data(input_data_path, settings, parallel=False, split_ratio=None):
    """
    Generates the feature data from the raw data.

    Args:
        input_data_path (str): The path to the directory containing the raw data pickles.
        settings (dict): The settings dictionary.
        parallel (bool): If True, the features are extracted with multiple processes.
        split_ratio (list): The ratio in which the data will be split into training, test,
            and validation sets. Defaults to the value from the settings.
    """
    if split_ratio is None:
        split_ratio = settings['split_ratio']
    print(list(os.listdir(input_data_path)))
    for file in os.listdir(input_data_path):
        if file.endswith(".pkl") and not file.startswith("diagnosis"):
            print(f"Reading {file}")
            with open(f'{input_data_path}/{file}', 'rb') as f:
                data = pickle.load(f)
            data_dict = {file.replace('.pkl', ''): data}
            # Extract the features
            if parallel:
                # Use all but one of the available CPU cores
                max_processes = multiprocessing.cpu_count() - 1
                print(f"Using {max_processes} processes to extract features.")
                feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
            else:
                print("For even distribution of data, the limit is set to the smallest size: 1000.")
                feature_extraction.extract_features(data_dict, limit=1000)
    # Split the data
    feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)
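
# Example usage (illustrative; mirrors the call in main() below): rebuild features from
# previously pickled raw data in settings["data_path"]:
#     settings = json.load(open('./settings.json'))
#     generate_feature_data(settings["data_path"], settings, parallel=True)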

def main(gen_data=True, gen_features=True, gen_diag_ids=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
    """
    Main function to generate the data.

    Args:
        gen_data (bool): If True, generates the raw data.
        gen_features (bool): If True, generates the feature data.
        gen_diag_ids (bool): If True, extracts the diagnosis IDs per record and writes them to a pickle file.
        split_ratio (list): The ratio in which the data will be split into training, test, and validation sets.
        parallel (bool): If True, the features are extracted with multiple processes.
        settings_path (str): The path to the settings file.
        num_process_files (int): The maximum number of records to process; negative values
            fall back to the default cap of 100,000.

    Returns:
        dict: The data generated by the last step that was run.
    """
    ret_data = None
    with open(settings_path) as f:
        settings = json.load(f)
    if num_process_files < 0:
        num_process_files = 100_000
    if split_ratio is None:
        split_ratio = settings['split_ratio']
    if gen_data:
        raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
        data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files)
        write_data(data_dict, path=settings["data_path"])
        ret_data = data_dict
    if gen_features:
        feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel)
        ret_data = feature_data_dict
    if gen_diag_ids:
        raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
        data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files, only_ids=True)
        write_data(data_dict, path=settings["data_path"], file_prefix='diagnosis', only_ids=True)
        ret_data = data_dict

    return ret_data

# --------------------------------------------------------------------------------
# Generate the data
# --------------------------------------------------------------------------------
if __name__ == '__main__':
    """
    The following categories are used to classify the records:

    SB, sinus bradycardia
    AFIB, atrial fibrillation and atrial flutter (AFL)
    GSVT, supraventricular tachycardia, atrial tachycardia, AV nodal reentrant tachycardia, AV reentrant tachycardia, atrial pacemaker
    SR, sinus rhythm and sinus irregularities
    """
    # Category order: SB, AFIB, GSVT, SR
    # New category order: GSVT, AFIB, SR, SB
    # Generate the data
    #main(gen_data=True, gen_features=False, gen_diag_ids=False, num_process_files=100_000)
    #main(gen_data=False, gen_features=True, gen_diag_ids=False, split_ratio=[0.8, 0.1, 0.1])
    main(gen_data=False, gen_features=False, gen_diag_ids=True)
    print("Data generation completed.")