DSA_SS24/scripts/generate_data.py

105 lines
4.6 KiB
Python
Raw Normal View History

2024-05-15 20:20:01 +02:00
"""
This script reads the WFDB records and extracts the diagnosis codes from the header comments.
Each record is assigned to the first of a fixed set of rhythm categories whose codes overlap its diagnosis codes;
records that match no category are discarded.
The records of each category are then saved to one pickle file per category.
"""
2024-05-01 09:56:36 +02:00
import wfdb
import os
import numpy as np
2024-05-01 12:53:33 +02:00
import pickle
2024-05-01 09:56:36 +02:00
# Directories and file paths
# --------------------------------------------------------------------------------
# NOTE: Specify the directory where the WFDB records are stored. Either edit the
# default below or set the WFDB_PROJECT_DIR environment variable; the variable
# takes precedence when it is set to a non-empty value, so the script can run on
# other machines without editing this file.
project_dir = os.environ.get('WFDB_PROJECT_DIR') or (
    'C:/Users/felix/OneDrive/Studium/Master MDS/1 Semester/DSA/physionet/large_12_ecg_data/'
    'a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0'
)
# Root of the WFDB record tree that is scanned for .hea header files.
data_dir = project_dir + '/WFDBRecords'
# CSV of condition names for the SNOMED-CT codes (not read anywhere in this script).
path_diag_lookup = project_dir + "/ConditionNames_SNOMED-CT.csv"
# --------------------------------------------------------------------------------
2024-05-15 20:20:01 +02:00
# Functions
2024-05-01 09:56:36 +02:00
def get_diagnosis_ids(record):
    """
    Extract the diagnosis codes from a record's header comments.

    The comment line that starts with ``Dx:`` is located by its prefix rather
    than assumed to sit at a fixed index, and its comma-separated codes are
    parsed as integers. Empty entries (e.g. from a trailing comma) are ignored.

    Args:
        record (object): Any object with a ``comments`` attribute holding the
            header comment strings (e.g. the result of ``wfdb.rdrecord``).
    Returns:
        list: The diagnosis codes as ints, in the order they appear.
    Raises:
        ValueError: If no ``Dx:`` comment exists or a code is not an integer.
    """
    for comment in record.comments:
        comment = comment.strip()
        if comment.startswith('Dx:'):
            codes = comment[len('Dx:'):]
            # int() tolerates surrounding whitespace; skip empty tokens.
            return [int(code) for code in codes.split(',') if code.strip()]
    raise ValueError("record has no 'Dx:' comment")
# --------------------------------------------------------------------------------
2024-05-15 20:20:01 +02:00
# Generate the data
2024-05-01 09:56:36 +02:00
# --------------------------------------------------------------------------------
2024-05-15 20:20:01 +02:00
def _iter_header_stems(root_dir):
    """Yield the path (without the .hea extension) of every WFDB header below *root_dir*.

    Headers are expected two directory levels down
    (``<root_dir>/<group>/<subgroup>/<record>.hea``). Entries are visited in
    sorted order so that a run capped by a maximum record count is
    reproducible (``os.listdir`` order is arbitrary), and non-directory entries
    at the two directory levels are skipped instead of crashing ``os.listdir``.
    """
    for dir_th in sorted(os.listdir(root_dir)):
        path_to_1000_records = os.path.join(root_dir, dir_th)
        if not os.path.isdir(path_to_1000_records):
            continue
        for dir_hd in sorted(os.listdir(path_to_1000_records)):
            path_to_100_records = os.path.join(path_to_1000_records, dir_hd)
            if not os.path.isdir(path_to_100_records):
                continue
            for file_name in sorted(os.listdir(path_to_100_records)):
                if file_name.endswith('.hea'):
                    yield os.path.join(path_to_100_records, file_name[:-len('.hea')])


if __name__ == '__main__':
    # The following categories are used to classify the records:
    #   SB   - sinus bradycardia
    #   AFIB - atrial fibrillation and atrial flutter (AFL)
    #   GSVT - supraventricular tachycardia, atrial tachycardia, AV nodal reentry
    #          tachycardia, AV reentry tachycardia, atrial pacemaker
    #   SR   - sinus rhythm and sinus irregularity
    # A record goes to the FIRST category (in this dict's order) that shares at
    # least one code with its diagnosis; records matching none are dropped.
    categories = {
        'SB': [426177001],
        'AFIB': [164889003, 164890007],
        # NOTE(review): 713422000 was listed twice here. The description above also
        # mentions an atrial pacemaker, for which no separate code appears -
        # confirm whether a code is missing from this list.
        'GSVT': [426761007, 713422000, 233896004, 233897008],
        'SR': [426783006, 427393009],
    }
    # Sets make each code-membership test O(1).
    category_sets = {name: set(codes) for name, codes in categories.items()}
    diag_dict = {k: [] for k in categories}

    counter = 0            # number of successfully read records
    max_counter = 100_000  # stop after this many successfully read records
    failed_records = []

    for record_path in _iter_header_stems(data_dir):
        record_name = os.path.basename(record_path)
        try:
            record = wfdb.rdrecord(record_path)
            diagnosis = get_diagnosis_ids(record)
        except Exception as e:
            # Best effort: log the unreadable record and keep going.
            failed_records.append(record_name)
            print(f"Failed to read record {record_name} due to {type(e).__name__}: {e}. "
                  f"Sum of failed records: {len(failed_records)}")
            continue

        for category_name, category_codes in category_sets.items():
            if any(code in category_codes for code in diagnosis):
                diag_dict[category_name].append(record)
                break

        counter += 1
        if counter % 100 == 0:
            print(f"Read {counter} records")
        if counter >= max_counter:
            break

    # Write one pickle per category (relative to the current working directory).
    os.makedirs('./data', exist_ok=True)
    for cat_name, records in diag_dict.items():
        print(f"Writing {cat_name} to pickle with {len(records)} records")
        with open(f'./data/{cat_name}.pkl', 'wb') as f:
            pickle.dump(records, f)