Compare commits

..

No commits in common. "d6268c2cace9dd3653df080af07c7782ac5723b1" and "458a3bade64566faf48b386ea95bb72b9c4e6129" have entirely different histories.

14 changed files with 762 additions and 19 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -167,6 +167,20 @@
"test_accuracy = accuracy_score(test_y, test_pred)\n", "test_accuracy = accuracy_score(test_y, test_pred)\n",
"print(f'Testgenauigkeit: {test_accuracy}')\n" "print(f'Testgenauigkeit: {test_accuracy}')\n"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Die Validierungsgenauigkeit des Modells liegt bei 75,5%, was darauf hinweist, dass das Modell in etwa drei Vierteln der Fälle korrekte Vorhersagen auf den Validierungsdaten macht. Dies zeigt eine recht solide Leistung, deutet jedoch auch darauf hin, dass es noch Verbesserungspotenzial gibt, insbesondere bei der Verfeinerung des Modells, um die Fehlerquote zu senken"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mit einer Testgenauigkeit von 79% klassifiziert das Modell die Testdaten überwiegend korrekt. Dieses Ergebnis ist ein Indikator dafür, dass das Modell eine gute Generalisierungsfähigkeit aufweist und zuverlässig auf neuen, unbekannten Daten agieren kann. "
]
} }
], ],
"metadata": { "metadata": {

View File

@ -13,7 +13,7 @@ import cv2 as cv
TODO create overall description TODO create overall description
""" """
def load_data(only_demographic:bool=False, path_settings:str="../settings.json"): def load_data(only_demographic:bool=False, only_diagnosis_ids=False, path_settings:str="../settings.json"):
""" """
Loads data from pickle files based on the specified settings. Loads data from pickle files based on the specified settings.
@ -28,6 +28,10 @@ def load_data(only_demographic:bool=False, path_settings:str="../settings.json")
path_data = settings["data_path"] path_data = settings["data_path"]
labels = settings["labels"] labels = settings["labels"]
if only_diagnosis_ids:
with open(f'{path_data}/diagnosis.pkl', 'rb') as f:
return pickle.load(f)
data = {} data = {}
if only_demographic: if only_demographic:
data = {'age': [], 'diag': [], 'gender': []} data = {'age': [], 'diag': [], 'gender': []}

View File

@ -5,6 +5,7 @@ import math
import time import time
from multiprocessing import Pool from multiprocessing import Pool
import sqlite3 import sqlite3
import random
def get_y_value(ecg_cleaned, indecies): def get_y_value(ecg_cleaned, indecies):
""" """
@ -213,7 +214,6 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
c = conn.cursor() c = conn.cursor()
# get unique data # get unique data
data_dict = exclude_already_extracted(data_dict, conn) data_dict = exclude_already_extracted(data_dict, conn)
for label, data in data_dict.items(): for label, data in data_dict.items():
print(f"Extracting features for {label} with {len(data)} data entries.") print(f"Extracting features for {label} with {len(data)} data entries.")
with Pool(processes=num_processes) as pool: with Pool(processes=num_processes) as pool:
@ -239,7 +239,7 @@ def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]): def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], limit=1000):
""" """
Extracts the features from the data. Extracts the features from the data.
Args: Args:
@ -266,6 +266,8 @@ def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4,
print("No last file in DB") print("No last file in DB")
for label, data in data_dict.items(): for label, data in data_dict.items():
# get limit amount of radom samples out of data
data = random.sample(data, min(len(data), limit))
print(f"Extracting features for {label} with {len(data)} data entries.") print(f"Extracting features for {label} with {len(data)} data entries.")
for data_idx, record in enumerate(data): for data_idx, record in enumerate(data):
# Skip the records that are already in the database # Skip the records that are already in the database

View File

@ -30,7 +30,7 @@ def get_diagnosis_ids(record):
list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')] list_diagnosis = [int(x.strip()) for x in diagnosis.split(',')]
return list_diagnosis return list_diagnosis
def generate_raw_data(path_to_data, settings, max_counter=100_000): def generate_raw_data(path_to_data, settings, max_counter=100_000, only_ids=False):
""" """
Generates the raw data from the WFDB records. Generates the raw data from the WFDB records.
Args: Args:
@ -43,7 +43,10 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
failed_records = [] failed_records = []
categories = settings["labels"] categories = settings["labels"]
diag_dict = {k: [] for k in categories.keys()} if only_ids:
diag_dict = {}
else:
diag_dict = {k: [] for k in categories.keys()}
# Loop through the records # Loop through the records
for dir_th in os.listdir(path_to_data): for dir_th in os.listdir(path_to_data):
path_to_1000_records = path_to_data + '/' + dir_th path_to_1000_records = path_to_data + '/' + dir_th
@ -60,12 +63,15 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
record = wfdb.rdrecord(path_to_100_records + '/' + record_name) record = wfdb.rdrecord(path_to_100_records + '/' + record_name)
# Get the diagnosis # Get the diagnosis
diagnosis = np.array(get_diagnosis_ids(record)) diagnosis = np.array(get_diagnosis_ids(record))
# check if diagnosis is a subset of one of the categories if only_ids:
for category_name, category_codes in categories.items(): diag_dict[record_name] = diagnosis
# if any of the diagnosis codes is in the category_codes else:
if any(i in category_codes for i in diagnosis): # check if diagnosis is a subset of one of the categories
diag_dict[category_name].append(record) for category_name, category_codes in categories.items():
break # if any of the diagnosis codes is in the category_codes
if any(i in category_codes for i in diagnosis):
diag_dict[category_name].append(record)
break
# Increment the counter of how many records we have read # Increment the counter of how many records we have read
counter += 1 counter += 1
counter_bool = counter >= max_counter counter_bool = counter >= max_counter
@ -83,7 +89,7 @@ def generate_raw_data(path_to_data, settings, max_counter=100_000):
break break
return diag_dict return diag_dict
def write_data(data_dict, path='./data', file_prefix=''): def write_data(data_dict, path='./data', file_prefix='', only_ids=False):
""" """
Writes the data to a pickle file. Writes the data to a pickle file.
Args: Args:
@ -93,6 +99,13 @@ def write_data(data_dict, path='./data', file_prefix=''):
# if path not exists create it # if path not exists create it
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
if only_ids:
# write to pickle
print(f"Writing diagnosis IDs to pickle with {len(data_dict)} data entries.")
with open(f'{path}/{file_prefix}.pkl', 'wb') as f:
pickle.dump(data_dict, f)
return
# write to pickle # write to pickle
for cat_name, data in data_dict.items(): for cat_name, data in data_dict.items():
print(f"Writing {cat_name} to pickle with {len(data)} data entries.") print(f"Writing {cat_name} to pickle with {len(data)} data entries.")
@ -114,7 +127,7 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
split_ratio = settings['split_ratio'] split_ratio = settings['split_ratio']
print(list(os.listdir(input_data_path))) print(list(os.listdir(input_data_path)))
for file in os.listdir(input_data_path): for file in os.listdir(input_data_path):
if file.endswith(".pkl"): if file.endswith(".pkl") and not file.startswith("diagnosis"):
print(f"Reading {file}") print(f"Reading {file}")
with open(f'{input_data_path}/{file}', 'rb') as f: with open(f'{input_data_path}/{file}', 'rb') as f:
data = pickle.load(f) data = pickle.load(f)
@ -127,13 +140,14 @@ def generate_feature_data(input_data_path, settings, parallel=False, split_ratio
print(f"Using {max_processes} processes to extract features.") print(f"Using {max_processes} processes to extract features.")
feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes) feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
else: else:
feature_extraction.extract_features(data_dict) print(f"For even distribution of data, the limit is set to the smallest size: 1000.")
feature_extraction.extract_features(data_dict, limit=1000)
# Split the data # Split the data
feature_extraction.split_and_shuffle_data(split_ratio=split_ratio) feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)
def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1): def main(gen_data=True, gen_features=True, gen_diag_ids=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
""" """
Main function to generate the data. Main function to generate the data.
Args: Args:
@ -159,6 +173,11 @@ def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, set
if gen_features: if gen_features:
feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel) feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel)
ret_data = feature_data_dict ret_data = feature_data_dict
if gen_diag_ids:
raw_data_dir = settings["wfdb_path"] + '/WFDBRecords'
data_dict = generate_raw_data(raw_data_dir, settings, max_counter=num_process_files, only_ids=True)
write_data(data_dict, path=settings["data_path"], file_prefix='diagnosis', only_ids=True)
ret_data = data_dict
return ret_data return ret_data
@ -178,6 +197,7 @@ if __name__ == '__main__':
# SB, AFIB, GSVT, SR # SB, AFIB, GSVT, SR
# new GSVT, AFIB, SR, SB # new GSVT, AFIB, SR, SB
# Generate the data # Generate the data
main(gen_data=True, gen_features=False, num_process_files=100_000) #main(gen_data=True, gen_features=False, gen_diag_ids=False, num_process_files=100_000)
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], parallel=False, num_process_files=100_000) #main(gen_data=False, gen_features=True, gen_diag_ids=False, split_ratio=[0.8, 0.1, 0.1])
main(gen_data=False, gen_features=False, gen_diag_ids=True)
print("Data generation completed.") print("Data generation completed.")

View File

@ -1,8 +1,8 @@
{ {
"wfdb_path_comment": "Path to the WFDB data. This is the folder where the WFDB data is stored.", "wfdb_path_comment": "Path to the WFDB data. This is the folder where the WFDB data is stored.",
"wfdb_path": "C:/Users/arman/PycharmProjects/pythonProject/DSA/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0", "wfdb_path": "C:/Studium/dsa/large_12_ecg_data/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0",
"data_path_comment": "Path to the data folder. This is the folder where the genereated data is stored.", "data_path_comment": "Path to the data folder. This is the folder where the genereated data is stored.",
"data_path": "C:/Users/arman/PycharmProjects/pythonProject/DSA/DSA_SS24/data", "data_path": "C:/Studium/dsa/data",
"labels_comment": "Labels for the different classes. The labels are the SNOMED CT codes.", "labels_comment": "Labels for the different classes. The labels are the SNOMED CT codes.",
"labels": { "labels": {
"GSVT": [426761007, 713422000, 233896004, 233897008, 713422000], "GSVT": [426761007, 713422000, 233896004, 233897008, 713422000],