Compare commits

..

2 Commits

6 changed files with 126 additions and 93 deletions

File diff suppressed because one or more lines are too long

View File

@ -13,7 +13,7 @@ import cv2 as cv
TODO create overall description
"""
def load_data(only_demographic:bool=False, path_settings:str="../settings.json"):
def load_data(only_demographic:bool=False, only_diagnosis_ids=False, path_settings:str="../settings.json"):
"""
Loads data from pickle files based on the specified settings.
@ -28,6 +28,10 @@ def load_data(only_demographic:bool=False, path_settings:str="../settings.json")
path_data = settings["data_path"]
labels = settings["labels"]
if only_diagnosis_ids:
with open(f'{path_data}/diagnosis.pkl', 'rb') as f:
return pickle.load(f)
data = {}
if only_demographic:
data = {'age': [], 'diag': [], 'gender': []}

View File

@ -5,6 +5,7 @@ import math
import time
from multiprocessing import Pool
import sqlite3
import random
def get_y_value(ecg_cleaned, indecies):
"""
@ -213,7 +214,6 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
c = conn.cursor()
# get unique data
data_dict = exclude_already_extracted(data_dict, conn)
for label, data in data_dict.items():
print(f"Extracting features for {label} with {len(data)} data entries.")
with Pool(processes=num_processes) as pool:
@ -239,7 +239,7 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], limit=1000):
"""
Extracts the features from the data.
Args:
@ -266,6 +266,8 @@ def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4,
print("No last file in DB")
for label, data in data_dict.items():
# get limit amount of radom samples out of data
data = random.sample(data, min(len(data), limit))
print(f"Extracting features for {label} with {len(data)} data entries.")
for data_idx, record in enumerate(data):
# Skip the records that are already in the database

View File

@ -30,7 +30,7 @@ def get_diagnosis_ids(record):
list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')]
return list_diagnosis
def generate_raw_data(path_to_data, settings, max_counter=100_000):
def generate_raw_data(path_to_data, settings, max_counter=100_000, only_ids=False):
"""
Generates the raw data from the WFDB records.
Args:
@ -43,7 +43,10 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
failed_records = []
categories = settings["labels"]
diag_dict = {k: [] for k in categories.keys()}
if only_ids:
diag_dict = {}
else:
diag_dict = {k: [] for k in categories.keys()}
# Loop through the records
for dir_th in os.listdir(path_to_data):
path_to_1000_records = path_to_data + '/' + dir_th
@ -60,12 +63,15 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
record = wfdb.rdrecord(path_to_100_records + '/' + record_name)
# Get the diagnosis
diagnosis = np.array(get_diagnosis_ids(record))
# check if diagnosis is a subset of one of the categories
for category_name, category_codes in categories.items():
# if any of the diagnosis codes is in the category_codes
if any(i in category_codes for i in diagnosis):
diag_dict[category_name].append(record)
break
if only_ids:
diag_dict[record_name] = diagnosis
else:
# check if diagnosis is a subset of one of the categories
for category_name, category_codes in categories.items():
# if any of the diagnosis codes is in the category_codes
if any(i in category_codes for i in diagnosis):
diag_dict[category_name].append(record)
break
# Increment the counter of how many records we have read
counter += 1
counter_bool = counter >= max_counter
@ -83,7 +89,7 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
break
return diag_dict
def write_data(data_dict, path='./data', file_prefix=''):
def write_data(data_dict, path='./data', file_prefix='', only_ids=False):
"""
Writes the data to a pickle file.
Args:
@ -93,6 +99,13 @@ def write_data(data_dict, path='./data', file_prefix=''):
# if path not exists create it
if not os.path.exists(path):
os.makedirs(path)
if only_ids:
# write to pickle
print(f"Writing diagnosis IDs to pickle with {len(data_dict)} data entries.")
with open(f'{path}/{file_prefix}.pkl', 'wb') as f:
pickle.dump(data_dict, f)
return
# write to pickle
for cat_name, data in data_dict.items():
print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
@ -114,7 +127,7 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
split_ratio = settings['split_ratio']
print(list(os.listdir(input_data_path)))
for file in os.listdir(input_data_path):
if file.endswith(".pkl"):
if file.endswith(".pkl") and not file.startswith("diagnosis"):
print(f"Reading {file}")
with open(f'{input_data_path}/{file}', 'rb') as f:
data = pickle.load(f)
@ -127,13 +140,14 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
print(f"Using {max_processes} processes to extract features.")
feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
else:
feature_extraction.extract_features(data_dict)
print(f"For even distribution of data, the limit is set to the smallest size: 1000.")
feature_extraction.extract_features(data_dict, limit=1000)
# Split the data
feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)
def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
def main(gen_data=True, gen_features=True, gen_diag_ids=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
"""
Main function to generate the data.
Args:
@ -159,6 +173,11 @@ def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, set
if gen_features:
feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel)
ret_data = feature_data_dict
if gen_diag_ids:
raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files, only_ids=True)
write_data(data_dict, path=settings["data_path"], file_prefix='diagnosis', only_ids=True)
ret_data = data_dict
return ret_data
@ -178,6 +197,7 @@ if __name__ == '__main__':
# SB, AFIB, GSVT, SR
# new GSVT, AFIB, SR, SB
# Generate the data
main(gen_data=True, gen_features=False, num_process_files=100_000)
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], parallel=False, num_process_files=100_000)
#main(gen_data=True, gen_features=False, gen_diag_ids=False, num_process_files=100_000)
#main(gen_data=False, gen_features=True, gen_diag_ids=False, split_ratio=[0.8, 0.1, 0.1])
main(gen_data=False, gen_features=False, gen_diag_ids=True)
print("Data generation completed.")