Compare commits
2 Commits
ba692a6d8b
...
0d6e85013f
Author | SHA1 | Date |
---|---|---|
Felix Jan Michael Mucha | 0d6e85013f | |
Felix Jan Michael Mucha | 52c66982e0 |
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
|
@ -13,7 +13,7 @@ import cv2 as cv
|
|||
TODO create overall description
|
||||
"""
|
||||
|
||||
def load_data(only_demographic:bool=False, path_settings:str="../settings.json"):
|
||||
def load_data(only_demographic:bool=False, only_diagnosis_ids=False, path_settings:str="../settings.json"):
|
||||
"""
|
||||
Loads data from pickle files based on the specified settings.
|
||||
|
||||
|
@ -28,6 +28,10 @@ def load_data(only_demographic:bool=False, path_settings:str="../settings.json")
|
|||
path_data = settings["data_path"]
|
||||
labels = settings["labels"]
|
||||
|
||||
if only_diagnosis_ids:
|
||||
with open(f'{path_data}/diagnosis.pkl', 'rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
data = {}
|
||||
if only_demographic:
|
||||
data = {'age': [], 'diag': [], 'gender': []}
|
||||
|
|
|
@ -5,6 +5,7 @@ import math
|
|||
import time
|
||||
from multiprocessing import Pool
|
||||
import sqlite3
|
||||
import random
|
||||
|
||||
def get_y_value(ecg_cleaned, indecies):
|
||||
"""
|
||||
|
@ -213,7 +214,6 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
|
|||
c = conn.cursor()
|
||||
# get unique data
|
||||
data_dict = exclude_already_extracted(data_dict, conn)
|
||||
|
||||
for label, data in data_dict.items():
|
||||
print(f"Extracting features for {label} with {len(data)} data entries.")
|
||||
with Pool(processes=num_processes) as pool:
|
||||
|
@ -239,7 +239,7 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
|
|||
|
||||
|
||||
|
||||
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
|
||||
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], limit=1000):
|
||||
"""
|
||||
Extracts the features from the data.
|
||||
Args:
|
||||
|
@ -266,6 +266,8 @@ def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4,
|
|||
print("No last file in DB")
|
||||
|
||||
for label, data in data_dict.items():
|
||||
# get limit amount of radom samples out of data
|
||||
data = random.sample(data, min(len(data), limit))
|
||||
print(f"Extracting features for {label} with {len(data)} data entries.")
|
||||
for data_idx, record in enumerate(data):
|
||||
# Skip the records that are already in the database
|
||||
|
|
|
@ -30,7 +30,7 @@ def get_diagnosis_ids(record):
|
|||
list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')]
|
||||
return list_diagnosis
|
||||
|
||||
def generate_raw_data(path_to_data, settings, max_counter=100_000):
|
||||
def generate_raw_data(path_to_data, settings, max_counter=100_000, only_ids=False):
|
||||
"""
|
||||
Generates the raw data from the WFDB records.
|
||||
Args:
|
||||
|
@ -43,6 +43,9 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
|
|||
failed_records = []
|
||||
categories = settings["labels"]
|
||||
|
||||
if only_ids:
|
||||
diag_dict = {}
|
||||
else:
|
||||
diag_dict = {k: [] for k in categories.keys()}
|
||||
# Loop through the records
|
||||
for dir_th in os.listdir(path_to_data):
|
||||
|
@ -60,6 +63,9 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
|
|||
record = wfdb.rdrecord(path_to_100_records + '/' + record_name)
|
||||
# Get the diagnosis
|
||||
diagnosis = np.array(get_diagnosis_ids(record))
|
||||
if only_ids:
|
||||
diag_dict[record_name] = diagnosis
|
||||
else:
|
||||
# check if diagnosis is a subset of one of the categories
|
||||
for category_name, category_codes in categories.items():
|
||||
# if any of the diagnosis codes is in the category_codes
|
||||
|
@ -83,7 +89,7 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
|
|||
break
|
||||
return diag_dict
|
||||
|
||||
def write_data(data_dict, path='./data', file_prefix=''):
|
||||
def write_data(data_dict, path='./data', file_prefix='', only_ids=False):
|
||||
"""
|
||||
Writes the data to a pickle file.
|
||||
Args:
|
||||
|
@ -93,6 +99,13 @@ def write_data(data_dict, path='./data', file_prefix=''):
|
|||
# if path not exists create it
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
if only_ids:
|
||||
# write to pickle
|
||||
print(f"Writing diagnosis IDs to pickle with {len(data_dict)} data entries.")
|
||||
with open(f'{path}/{file_prefix}.pkl', 'wb') as f:
|
||||
pickle.dump(data_dict, f)
|
||||
return
|
||||
# write to pickle
|
||||
for cat_name, data in data_dict.items():
|
||||
print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
|
||||
|
@ -114,7 +127,7 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
|
|||
split_ratio = settings['split_ratio']
|
||||
print(list(os.listdir(input_data_path)))
|
||||
for file in os.listdir(input_data_path):
|
||||
if file.endswith(".pkl"):
|
||||
if file.endswith(".pkl") and not file.startswith("diagnosis"):
|
||||
print(f"Reading {file}")
|
||||
with open(f'{input_data_path}/{file}', 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
@ -127,13 +140,14 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
|
|||
print(f"Using {max_processes} processes to extract features.")
|
||||
feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
|
||||
else:
|
||||
feature_extraction.extract_features(data_dict)
|
||||
print(f"For even distribution of data, the limit is set to the smallest size: 1000.")
|
||||
feature_extraction.extract_features(data_dict, limit=1000)
|
||||
# Split the data
|
||||
feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)
|
||||
|
||||
|
||||
|
||||
def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
|
||||
def main(gen_data=True, gen_features=True, gen_diag_ids=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
|
||||
"""
|
||||
Main function to generate the data.
|
||||
Args:
|
||||
|
@ -159,6 +173,11 @@ def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, set
|
|||
if gen_features:
|
||||
feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel)
|
||||
ret_data = feature_data_dict
|
||||
if gen_diag_ids:
|
||||
raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
|
||||
data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files, only_ids=True)
|
||||
write_data(data_dict, path=settings["data_path"], file_prefix='diagnosis', only_ids=True)
|
||||
ret_data = data_dict
|
||||
|
||||
return ret_data
|
||||
|
||||
|
@ -178,6 +197,7 @@ if __name__ == '__main__':
|
|||
# SB, AFIB, GSVT, SR
|
||||
# new GSVT, AFIB, SR, SB
|
||||
# Generate the data
|
||||
main(gen_data=True, gen_features=False, num_process_files=100_000)
|
||||
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], parallel=False, num_process_files=100_000)
|
||||
#main(gen_data=True, gen_features=False, gen_diag_ids=False, num_process_files=100_000)
|
||||
#main(gen_data=False, gen_features=True, gen_diag_ids=False, split_ratio=[0.8, 0.1, 0.1])
|
||||
main(gen_data=False, gen_features=False, gen_diag_ids=True)
|
||||
print("Data generation completed.")
|
||||
|
|
Loading…
Reference in New Issue