Compare commits

2 Commits: ba692a6d8b ... 0d6e85013f

| Author | SHA1 | Date |
|---|---|---|
| Felix Jan Michael Mucha | 0d6e85013f | |
| Felix Jan Michael Mucha | 52c66982e0 | |
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
```diff
@@ -13,7 +13,7 @@ import cv2 as cv
 TODO create overall description
 """
 
-def load_data(only_demographic:bool=False, path_settings:str="../settings.json"):
+def load_data(only_demographic:bool=False, only_diagnosis_ids=False, path_settings:str="../settings.json"):
     """
     Loads data from pickle files based on the specified settings.
 
@@ -28,6 +28,10 @@ def load_data(only_demographic:bool=False, path_settings:str="../settings.json")
     path_data = settings["data_path"]
     labels = settings["labels"]
 
+    if only_diagnosis_ids:
+        with open(f'{path_data}/diagnosis.pkl', 'rb') as f:
+            return pickle.load(f)
+
    data = {}
    if only_demographic:
        data = {'age': [], 'diag': [], 'gender': []}
```
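The second hunk adds a fast path that returns the diagnosis-ID pickle directly, skipping all other loading. A minimal usage sketch, assuming the module is importable as `load_data` (the file name is not preserved in the diff) and that `settings.json` points `data_path` at the pickle directory:

```python
from load_data import load_data  # module name is an assumption

# Returns the dict stored in {data_path}/diagnosis.pkl
# (record name -> diagnosis IDs) without touching the other pickles.
diagnosis_ids = load_data(only_diagnosis_ids=True, path_settings="../settings.json")
print(f"{len(diagnosis_ids)} records with diagnosis IDs")
```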
```diff
@@ -5,6 +5,7 @@ import math
 import time
 from multiprocessing import Pool
 import sqlite3
+import random
 
 def get_y_value(ecg_cleaned, indecies):
     """
@@ -213,7 +214,6 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
     c = conn.cursor()
     # get unique data
     data_dict = exclude_already_extracted(data_dict, conn)
-
     for label, data in data_dict.items():
         print(f"Extracting features for {label} with {len(data)} data entries.")
         with Pool(processes=num_processes) as pool:
@@ -239,7 +239,7 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
 
 
 
-def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
+def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], limit=1000):
     """
     Extracts the features from the data.
     Args:
@@ -266,6 +266,8 @@ def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4,
         print("No last file in DB")
 
     for label, data in data_dict.items():
+        # get limit amount of random samples out of data
+        data = random.sample(data, min(len(data), limit))
         print(f"Extracting features for {label} with {len(data)} data entries.")
         for data_idx, record in enumerate(data):
             # Skip the records that are already in the database
```
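The new `limit` parameter caps how many records per label are processed, and `random.sample` draws them without replacement. A self-contained sketch of just that subsampling step, with toy lists standing in for the ECG records:

```python
import random

# Toy stand-in for data_dict: label -> list of records.
data_dict = {"SB": list(range(2500)), "AFIB": list(range(800))}
limit = 1000

for label, data in data_dict.items():
    # min() keeps labels with fewer than `limit` records intact; without it,
    # random.sample would raise ValueError when len(data) < limit.
    data = random.sample(data, min(len(data), limit))
    print(label, len(data))  # SB -> 1000, AFIB -> 800
```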
```diff
@@ -30,7 +30,7 @@ def get_diagnosis_ids(record):
     list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')]
     return list_diagnosis
 
-def generate_raw_data(path_to_data, settings, max_counter=100_000):
+def generate_raw_data(path_to_data, settings, max_counter=100_000, only_ids=False):
     """
     Generates the raw data from the WFDB records.
     Args:
@@ -43,7 +43,10 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
     failed_records = []
     categories = settings["labels"]
 
-    diag_dict = {k: [] for k in categories.keys()}
+    if only_ids:
+        diag_dict = {}
+    else:
+        diag_dict = {k: [] for k in categories.keys()}
     # Loop through the records
     for dir_th in os.listdir(path_to_data):
         path_to_1000_records = path_to_data + '/' + dir_th
@@ -60,12 +63,15 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
             record = wfdb.rdrecord(path_to_100_records + '/' + record_name)
             # Get the diagnosis
             diagnosis = np.array(get_diagnosis_ids(record))
-            # check if diagnosis is a subset of one of the categories
-            for category_name, category_codes in categories.items():
-                # if any of the diagnosis codes is in the category_codes
-                if any(i in category_codes for i in diagnosis):
-                    diag_dict[category_name].append(record)
-                    break
+            if only_ids:
+                diag_dict[record_name] = diagnosis
+            else:
+                # check if diagnosis is a subset of one of the categories
+                for category_name, category_codes in categories.items():
+                    # if any of the diagnosis codes is in the category_codes
+                    if any(i in category_codes for i in diagnosis):
+                        diag_dict[category_name].append(record)
+                        break
             # Increment the counter of how many records we have read
             counter += 1
             counter_bool = counter >= max_counter
```
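With `only_ids=True`, `generate_raw_data` now returns a flat mapping from record name to diagnosis-ID array instead of full `wfdb` records grouped by category. A toy illustration of the two shapes (the record names and ID values here are invented for the example):

```python
import numpy as np

# only_ids=True: record name -> np.array of diagnosis IDs
diag_ids = {
    "record_001": np.array([426783006]),            # hypothetical IDs
    "record_002": np.array([164889003, 59118001]),
}

# only_ids=False (previous behavior): category -> list of wfdb.Record objects,
# e.g. {"GSVT": [...], "AFIB": [...], "SR": [...], "SB": [...]}

for name, ids in diag_ids.items():
    print(name, ids.tolist())
```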
```diff
@@ -83,7 +89,7 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
         break
     return diag_dict
 
-def write_data(data_dict, path='./data', file_prefix=''):
+def write_data(data_dict, path='./data', file_prefix='', only_ids=False):
     """
     Writes the data to a pickle file.
     Args:
@@ -93,6 +99,13 @@ def write_data(data_dict, path='./data', file_prefix=''):
     # if path not exists create it
     if not os.path.exists(path):
         os.makedirs(path)
+
+    if only_ids:
+        # write to pickle
+        print(f"Writing diagnosis IDs to pickle with {len(data_dict)} data entries.")
+        with open(f'{path}/{file_prefix}.pkl', 'wb') as f:
+            pickle.dump(data_dict, f)
+        return
     # write to pickle
     for cat_name, data in data_dict.items():
         print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
```
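The `only_ids` branch writes the whole dict to a single `{file_prefix}.pkl` instead of one file per category. A short round-trip sketch under the arguments used by `main` below (`file_prefix='diagnosis'`, `path` pointing at the data directory, here assumed to be `./data`):

```python
import pickle

# After write_data(diag_ids, path='./data', file_prefix='diagnosis', only_ids=True)
# the whole mapping lives in one file; reading it back:
with open('./data/diagnosis.pkl', 'rb') as f:
    diag_ids = pickle.load(f)

print(type(diag_ids), len(diag_ids))  # dict: record name -> diagnosis IDs
```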
```diff
@@ -114,7 +127,7 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
         split_ratio = settings['split_ratio']
     print(list(os.listdir(input_data_path)))
     for file in os.listdir(input_data_path):
-        if file.endswith(".pkl"):
+        if file.endswith(".pkl") and not file.startswith("diagnosis"):
             print(f"Reading {file}")
             with open(f'{input_data_path}/{file}', 'rb') as f:
                 data = pickle.load(f)
@@ -127,13 +140,14 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
         print(f"Using {max_processes} processes to extract features.")
         feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
     else:
-        feature_extraction.extract_features(data_dict)
+        print(f"For even distribution of data, the limit is set to the smallest size: 1000.")
+        feature_extraction.extract_features(data_dict, limit=1000)
     # Split the data
     feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)
 
 
 
-def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
+def main(gen_data=True, gen_features=True, gen_diag_ids=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
     """
     Main function to generate the data.
     Args:
@@ -159,6 +173,11 @@ def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, set
     if gen_features:
         feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel)
         ret_data = feature_data_dict
+    if gen_diag_ids:
+        raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
+        data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files, only_ids=True)
+        write_data(data_dict, path=settings["data_path"], file_prefix='diagnosis', only_ids=True)
+        ret_data = data_dict
 
     return ret_data
 
@@ -178,6 +197,7 @@ if __name__ == '__main__':
     # SB, AFIB, GSVT, SR
     # new GSVT, AFIB, SR, SB
     # Generate the data
-    main(gen_data=True, gen_features=False, num_process_files=100_000)
-    #main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], parallel=False, num_process_files=100_000)
+    #main(gen_data=True, gen_features=False, gen_diag_ids=False, num_process_files=100_000)
+    #main(gen_data=False, gen_features=True, gen_diag_ids=False, split_ratio=[0.8, 0.1, 0.1])
+    main(gen_data=False, gen_features=False, gen_diag_ids=True)
     print("Data generation completed.")
```
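Taken together, the `__main__` block now drives three modes. A summary sketch (the module name `generate_data` is an assumption; the flag combinations are copied from the calls above):

```python
from generate_data import main  # module name is an assumption

# Mode 1: read WFDB records and group them by category into pickles.
# main(gen_data=True, gen_features=False, gen_diag_ids=False, num_process_files=100_000)

# Mode 2: extract features from the category pickles (limited to 1000 per label).
# main(gen_data=False, gen_features=True, gen_diag_ids=False, split_ratio=[0.8, 0.1, 0.1])

# Mode 3 (new in this commit): write only record-name -> diagnosis-ID mappings
# to diagnosis.pkl.
main(gen_data=False, gen_features=False, gen_diag_ids=True)
```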