"""
Read WFDB records and sort them into diagnosis categories.

The diagnosis codes of each record are taken from the ``Dx:`` line of its
header comments. Every record is assigned to the first category in
``settings["labels"]`` that shares at least one code with it; records that
match no category are dropped. The records of each category are pickled to
one file per category, from which feature data can later be generated.
"""
import wfdb
import os
import numpy as np
import pickle
import json
import multiprocessing
import feature_extraction


# Functions

def get_diagnosis_ids(record):
    """
    Extract the diagnosis codes of a WFDB record.

    The codes are read from the header comment whose label is ``Dx``
    (e.g. ``"Dx: 426177001,164934002"``). Searching for that label, instead of
    assuming a fixed position in ``record.comments``, keeps working when the
    comment order differs between records.

    Args:
        record (object): The record returned by ``wfdb.rdrecord``; only its
            ``comments`` attribute is used.

    Returns:
        list: The diagnosis codes as ints, in the order they appear.

    Raises:
        ValueError: If the record has no ``Dx:`` comment or a code is not an
            integer.
    """
    for comment in record.comments:
        label, sep, codes = comment.partition(':')
        if sep and label.strip() == 'Dx':
            # Skip empty fields so that a trailing comma does not break int().
            return [int(code) for code in codes.split(',') if code.strip()]
    raise ValueError("record has no 'Dx:' comment")


def _iter_record_paths(path_to_data):
    """
    Yield the path (without the ``.hea`` extension) of every WFDB header file
    found two directory levels below ``path_to_data``.

    The expected layout is ``<path_to_data>/<dir>/<subdir>/<record>.hea``.
    Entries are visited in ``os.listdir`` order; non-directory entries at the
    two upper levels are skipped instead of crashing ``os.listdir``.
    """
    for dir_th in os.listdir(path_to_data):
        path_to_1000_records = os.path.join(path_to_data, dir_th)
        if not os.path.isdir(path_to_1000_records):
            continue
        for dir_hd in os.listdir(path_to_1000_records):
            path_to_100_records = os.path.join(path_to_1000_records, dir_hd)
            if not os.path.isdir(path_to_100_records):
                continue
            for file_name in os.listdir(path_to_100_records):
                if file_name.endswith('.hea'):
                    yield os.path.join(path_to_100_records, file_name[:-len('.hea')])


def generate_raw_data(path_to_data, settings, max_counter=100_000):
    """
    Read WFDB records and group them by diagnosis category.

    Args:
        path_to_data (str): Root directory of the WFDB records (layout
            ``<root>/<dir>/<subdir>/<record>.hea``).
        settings (dict): Settings; ``settings["labels"]`` maps each category
            name to the diagnosis codes belonging to it. Categories are tried
            in the mapping's order and a record goes to the first match.
        max_counter (int): The maximum number of records to read successfully.
            Records that fail to load do not count towards this limit.

    Returns:
        dict: Maps each category name to the list of records assigned to it.
    """
    counter = 0
    failed_records = []
    categories = settings["labels"]
    diag_dict = {k: [] for k in categories}

    for record_path in _iter_record_paths(path_to_data):
        if counter >= max_counter:
            break
        record_name = os.path.basename(record_path)
        try:
            record = wfdb.rdrecord(record_path)
            diagnosis = set(get_diagnosis_ids(record))
        except Exception as e:
            # Best effort: one unreadable record must not abort the whole run.
            failed_records.append(record_name)
            print(f"Failed to read record {record_name} due to {type(e).__name__}: {e}. "
                  f"Sum of failed records: {len(failed_records)}")
            continue

        # Assign the record to the first category sharing at least one code.
        for category_name, category_codes in categories.items():
            if not diagnosis.isdisjoint(category_codes):
                diag_dict[category_name].append(record)
                break

        counter += 1
        if counter % 100 == 0:
            print(f"Read {counter} records")

    if failed_records:
        print(f"Finished reading with {len(failed_records)} failed records.")
    return diag_dict


def write_data(data_dict, path='./data', file_prefix=''):
    """
    Pickle every entry of ``data_dict`` to its own file.

    Args:
        data_dict (dict): Maps a category name to the data to store.
        path (str): Output directory; created (including parents) if missing.
        file_prefix (str): Prefix for the file names. The data of category
            ``name`` is written to ``<path>/<file_prefix><name>.pkl``.
    """
    os.makedirs(path, exist_ok=True)
    for cat_name, data in data_dict.items():
        print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
        with open(os.path.join(path, f'{file_prefix}{cat_name}.pkl'), 'wb') as f:
            pickle.dump(data, f)


def generate_feature_data(input_data_path, settings, parallel=False, split_ratio=None):
    """
    Extract features from every pickled category file, then split the data.

    Each ``*.pkl`` file in ``input_data_path`` is loaded and passed to
    ``feature_extraction`` as ``{<file name without .pkl>: <data>}``. After all
    files are processed, ``feature_extraction.split_and_shuffle_data`` is
    called once. Results are handled by ``feature_extraction``; this function
    returns nothing.

    Args:
        input_data_path (str): Directory containing the pickled raw data.
        settings (dict): Settings; ``settings['split_ratio']`` is used when
            ``split_ratio`` is None.
        parallel (bool): If True, extract features with
            ``feature_extraction.extract_features_parallel``.
        split_ratio (list): The ratio in which the data will be split into
            training, test, and validation sets.
    """
    if split_ratio is None:
        split_ratio = settings['split_ratio']
    file_names = os.listdir(input_data_path)
    print(list(file_names))
    for file in file_names:
        if not file.endswith(".pkl"):
            continue
        print(f"Reading {file}")
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # this script (write_data) produced.
        with open(os.path.join(input_data_path, file), 'rb') as f:
            data = pickle.load(f)
        data_dict = {os.path.splitext(file)[0]: data}

        if parallel:
            # cpu_count() - 1 workers, clamped to at least one so that a
            # single-core machine does not request zero processes.
            max_processes = max(1, multiprocessing.cpu_count() - 1)
            print(f"Using {max_processes} processes to extract features.")
            feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
        else:
            feature_extraction.extract_features(data_dict)

    feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)


def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False,
         settings_path='./settings.json', num_process_files=-1):
    """
    Main function to generate the data.

    Args:
        gen_data (bool): If True, reads the WFDB records from
            ``<settings["wfdb_path"]>/WFDBRecords`` and pickles them per
            category to ``settings["data_path"]``.
        gen_features (bool): If True, generates the feature data from the
            pickles in ``settings["data_path"]``.
        split_ratio (list): The ratio in which the data will be split into
            training, test, and validation sets; defaults to
            ``settings['split_ratio']``.
        parallel (bool): If True, feature extraction uses multiple processes.
        settings_path (str): The path to the JSON settings file.
        num_process_files (int): The maximum number of records to read; a
            negative value means 100,000.

    Returns:
        dict | None: The raw data from ``generate_raw_data`` when ``gen_data``
        is True, otherwise None. Feature generation returns nothing, so it does
        not change the return value.
    """
    ret_data = None
    with open(settings_path, encoding='utf-8') as f:
        settings = json.load(f)
    if num_process_files < 0:
        num_process_files = 100_000
    if split_ratio is None:
        split_ratio = settings['split_ratio']

    if gen_data:
        raw_data_dir = os.path.join(settings["wfdb_path"], 'WFDBRecords')
        data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files)
        write_data(data_dict, path=settings["data_path"])
        ret_data = data_dict

    if gen_features:
        generate_feature_data(settings["data_path"], settings,
                              split_ratio=split_ratio, parallel=parallel)

    return ret_data


# --------------------------------------------------------------------------------
# Generate the data
# --------------------------------------------------------------------------------
if __name__ == '__main__':
    # The records are classified into the following categories:
    #   SB    sinus bradycardia
    #   AFIB  atrial fibrillation and atrial flutter (AFL)
    #   GSVT  supraventricular tachycardia, atrial tachycardia, AV nodal
    #         reentrant tachycardia, AV reentrant tachycardia, atrial pacemaker
    #   SR    sinus rhythm and sinus irregularity
    # Category order, old: SB, AFIB, GSVT, SR; new: GSVT, AFIB, SR, SB.
    # NOTE(review): the codes of each category come from settings.json
    # ("labels"); confirm the order there matches the intended one.

    # Read the WFDB records and pickle them per category.
    main(gen_data=True, gen_features=False, num_process_files=100_000)
    # Alternative invocation: extract features from previously pickled data.
    # main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], parallel=False, num_process_files=100_000)
    print("Data generation completed.")