xgboost model

main
Felix Jan Michael Mucha 2024-06-05 23:18:20 +02:00
parent f924d962b8
commit 52a8e8eeca
10 changed files with 990 additions and 92 deletions

.gitignore

@@ -1,2 +1,2 @@
 /data/
-/settings.json
+settings.json

BIN
features.db 100644

Binary file not shown.

File diff suppressed because one or more lines are too long


File diff suppressed because one or more lines are too long
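The two suppressed diffs above most likely hold the notebook code for the XGBoost model this commit is named after. Because that content is not shown, the following is only a minimal, hypothetical sketch of training such a classifier on the train and validation tables produced further below; every hyperparameter and preprocessing choice here is an assumption, not the committed code.

# Hypothetical sketch, not the suppressed notebook: fit an XGBoost classifier
# on the split tables created by split_and_shuffle_data in feature_extraction.py.
import sqlite3
import pandas as pd
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

conn = sqlite3.connect('features.db')
train = pd.read_sql_query("SELECT * FROM train", conn)
val = pd.read_sql_query("SELECT * FROM validation", conn)
conn.close()

# 'id' is the record name and 'y' the class label; the remaining columns are features.
feature_cols = [col for col in train.columns if col not in ('id', 'y')]
le = LabelEncoder()
X_train, y_train = train[feature_cols], le.fit_transform(train['y'])
X_val, y_val = val[feature_cols], le.transform(val['y'])

# XGBoost tolerates the NaN values left behind by failed delineations.
model = xgb.XGBClassifier(n_estimators=200, max_depth=4, learning_rate=0.1)
model.fit(X_train, y_train)
print("validation accuracy:", accuracy_score(y_val, model.predict(X_val)))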


@@ -0,0 +1,37 @@
import sqlite3
import pandas as pd
conn = sqlite3.connect('features.db')
c = conn.cursor()
# print names of available tables
c.execute("SELECT name FROM sqlite_master WHERE type='table';")
print("Table names: ", c.fetchall())
# for each split table created by split_and_shuffle_data, print the number of rows
for table in ['train', 'test', 'validation']:
    c.execute(f'SELECT COUNT(*) FROM {table}')
    print(f"Number of rows in the {table} table: ", c.fetchall()[0][0])
# print the number of rows in features table
c.execute('SELECT COUNT(*) FROM features')
print("Number of rows in the features table: ", c.fetchall()[0][0])
# print column names
c.execute('PRAGMA table_info(features)')
print("Column names in the features table: ", c.fetchall())
# count for each label how many rows there are
c.execute('SELECT y, COUNT(*) FROM features GROUP BY y')
print("Number of rows for each label: ", c.fetchall())
# Load data from the features table into a DataFrame
df = pd.read_sql_query("SELECT * FROM features", conn)
# Now you can work with the data in the df DataFrame
print(df.head(15))
# close the connection
conn.close()
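Note that the loop over the train, test, and validation tables only succeeds after split_and_shuffle_data (defined in feature_extraction.py below) has been run; on a freshly extracted features.db those tables do not exist yet and sqlite3 raises an OperationalError.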


@@ -1,13 +1,10 @@
from matplotlib import pyplot as plt
import wfdb.processing
import sys
import json
import scipy
import numpy as np
import neurokit2 as nk
import math
import time
from multiprocessing import Pool
import sqlite3
def get_y_value(ecg_cleaned, indecies):
"""
@@ -82,97 +79,261 @@ def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0)
return r_axis, t_axis
def get_features(record, label, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], conn=None):
"""
Extracts the features from the record.
Args:
record (object): The record object containing the ECG signal.
label (str): The label of the record.
sampling_rate (int): The sampling rate of the ECG signal.
Returns:
dict: The dictionary containing the extracted features.
"""
age = record.comments[0].split(' ')[1]
gender = record.comments[1].split(' ')[1]
if age == 'NaN' or gender == 'NaN':
return None
features = {}
# Extract the features
features['y'] = label
# Demographic features
features['age'] = int(age)
features['gender'] = True if gender == 'Male' else False
# Signal features
# Delineate the ECG signal
try:
ecg_signal = record.p_signal[:, 0]
ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
_, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
r_peaks = rpeaks['ECG_R_Peaks']
_, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak")
except KeyboardInterrupt:
conn.close()
raise
except:
return None
# TODO Other features and check features
atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6  # P-peak count * 6 approximates beats per minute (assumes 10-second recordings)
ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)))
features['artial_rate'] = atrial_rate
features['ventricular_rate'] = ventricular_rate
qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks']))
features['qrs_duration'] = qrs_duration
qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks']))
features['qt_length'] = qt_interval
q_peak = waves_peak['ECG_Q_Peaks']
s_peak = waves_peak['ECG_S_Peaks']
# check if q_peak, r_peak, s_peak are not nan and therefore a solid qrs complex exists
qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False)
features['qrs_count'] = qrs_count
features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))
r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=500, aVF=5, I=0)
features['r_axis'] = r_axis
features['t_axis'] = t_axis
return features
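# Note: nk.ecg_peaks and nk.ecg_delineate return sample indices, so the
# qrs_duration and qt_length stored above are measured in samples, not
# milliseconds. A minimal sketch of the conversion a consumer of features.db
# might apply (an assumed helper, not part of this commit), using the 500 Hz
# default from get_features:
SAMPLING_RATE = 500  # Hz

def samples_to_ms(n_samples, sampling_rate=SAMPLING_RATE):
    # one sample lasts 1000 / sampling_rate milliseconds
    return 1000.0 * n_samples / sampling_rate

print(samples_to_ms(45))  # a 45-sample QRS complex is 90 ms at 500 Hz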
def exclude_already_extracted(data_dict, conn):
"""
Exclude the records that are already in the database.
Args:
data_dict (dict): The dictionary containing the data.
Returns:
dict: The dictionary containing the unique data.
"""
unique_data_dict = {}
record_ids = []
for label, data in data_dict.items():
for record in data:
record_ids.append(record.record_name)
# get id column from database and subtract it from record_ids
c = conn.cursor()
c.execute('SELECT id FROM features')
db_ids = c.fetchall()
db_ids = [x[0] for x in db_ids]
unique_ids = list(set(record_ids) - set(db_ids))
for label, data in data_dict.items():
unique_data_dict[label] = [record for record in data if record.record_name in unique_ids]
return unique_data_dict
def process_record(data_dict_item, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], conn=None, c=None):
"""
Process a record to extract the features.
Args:
data_dict_item (tuple): Tuple of (label, record, cursor, connection) for one record.
sampling_rate (int): The sampling rate of the ECG signal.
used_channels (list): The list of used channels.
conn (object): The connection object to the database.
c (object): The cursor object.
Returns:
tuple: The tuple containing the record name and the extracted features.
"""
label, record, c, conn = data_dict_item
features = get_features(record,label, sampling_rate=sampling_rate, used_channels=used_channels)
if features is None:
print(f"Failed to extract features for record {record.record_name}")
return record.record_name, None
# TODO: Insert the record into the database
#feature_data[record_name] = features
# Define the feature names
feature_names = list(features.keys())
# Create the table if it doesn't exist
c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})")
# Insert the record into the table
c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})",
[record.record_name] + list(features.values()))
conn.commit()
return record.record_name, features
def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
"""
Extracts the features from the data.
Args:
data_dict (dict): The dictionary containing the data.
num_processes (int): The number of processes to use.
"""
start_time = time.time()
failed_records = []
processed_records = 0
conn = sqlite3.connect('features.db')
c = conn.cursor()
# get unique data
data_dict = exclude_already_extracted(data_dict, conn)
for label, data in data_dict.items():
print(f"Extracting features for {label} with {len(data)} data entries.")
with Pool(processes=num_processes) as pool:
results = pool.map(process_record, [(label, record, c, conn) for record in data])
for result in results:
processed_records += 1
if processed_records % 100 == 0:
stop_time = time.time()
print(f"Extracted features for {processed_records} records. Time taken: {stop_time - start_time:.2f}s")
start_time = time.time()
try:
record_name, features = result
except Exception as exc:
print(f'{record_name} generated an exception: {exc}')
else:
if features is not None:
print(f"Extracted features for record {record_name}")
else:
failed_records.append(record_name)
print(f"Sum of failed records: {len(failed_records)}")
conn.close()
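# Caveat for the parallel path above: sqlite3 connections and cursors cannot be
# pickled, so passing conn and c to the Pool workers through pool.map is fragile.
# A hedged alternative (an assumption, not the committed code) is to let each
# worker open its own short-lived connection:
def process_record_worker(data_dict_item, db='features.db'):
    # the tuple now carries only picklable objects
    label, record = data_dict_item
    features = get_features(record, label)
    if features is None:
        return record.record_name, None
    feature_names = list(features.keys())
    # each worker opens and closes its own connection; the timeout lets
    # concurrent INSERTs wait for the write lock instead of failing immediately
    conn = sqlite3.connect(db, timeout=30)
    try:
        c = conn.cursor()
        c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})")
        c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})",
                  [record.record_name] + list(features.values()))
        conn.commit()
    finally:
        conn.close()
    return record.record_name, features
# usage inside extract_features_parallel (assumed):
# results = pool.map(process_record_worker, [(label, record) for record in data])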
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
"""
Extracts the features from the data.
Args:
data_dict (dict): The dictionary containing the data.
Returns:
dict: The dictionary containing the extracted features.
"""
start_time = time.time()
feature_data = {}
failed_records = []
conn = sqlite3.connect('features.db')
c = conn.cursor()
# get the ids of the records that are already in the database
try:
# print how many records are already in the database
c.execute('SELECT COUNT(*) FROM features')
print("Records in DB:", c.fetchall())
c = conn.cursor()
c.execute('SELECT id FROM features')
db_ids = c.fetchall()
db_ids = [x[0] for x in db_ids]
except:
db_ids = []
print("No last file in DB")
for label, data in data_dict.items():
print(f"Extracting features for {label} with {len(data)} data entries.")
for data_idx, record in enumerate(data):
# Skip the records that are already in the database
if record.record_name in db_ids:
continue
# print current status
if data_idx % 100 == 0:
stop_time = time.time()
print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
start_time = time.time()
age = record.comments[0].split(' ')[1]
gender = record.comments[1].split(' ')[1]
if age == 'NaN' or gender == 'NaN':
continue
features = {}
# Extract the features
features['y'] = label
# Demographic features
features['age'] = int(age)
features['gender'] = True if gender == 'Male' else False
# Signal features
ecg_signal = record.p_signal[:, 0]
ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
_, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
r_peaks = rpeaks['ECG_R_Peaks']
# Delineate the ECG signal
try:
_, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak")
except:
# get the features
features = get_features(record, label, sampling_rate=sampling_rate, used_channels=used_channels, conn=conn)
if features is None:
failed_records.append(record.record_name)
print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}")
continue
# TODO Other features and check features
atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6
ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)))
features['artial_rate'] = atrial_rate
features['ventricular_rate'] = ventricular_rate
qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks']))
features['qrs_duration'] = qrs_duration
qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks']))
features['qt_length'] = qt_interval
q_peak = waves_peak['ECG_Q_Peaks']
s_peak = waves_peak['ECG_S_Peaks']
# check if q_peak, r_peak, s_peak are not nan and therefore a solid qrs complex exists
qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False)
features['qrs_count'] = qrs_count
features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))
r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=500, aVF=5, I=0)
features['r_axis'] = r_axis
features['t_axis'] = t_axis
# print the features
#print(json.dumps(features, indent=4))
feature_data[record.record_name] = features
return feature_data
# Define the feature names
feature_names = list(features.keys())
# Create the table if it doesn't exist
c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})")
# Insert the record into the table
c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})",
[record.record_name] + list(features.values()))
conn.commit()
conn.close()
def split_data(feature_data, split_ratio):
def split_and_shuffle_data(split_ratio, db=None):
"""
Splits the data into training, test, and validation sets.
Args:
split_ratio (list): The ratio in which the data will be split into training, test, and validation sets.
db (str): The name of the database.
"""
print(f"Splitting data with ratio {split_ratio}")
#flatten dictionary
feature_data = {k: v for k, v in feature_data.items()}
# print keys
print("Keys:")
print(len(feature_data.keys()))
# shuffle the data
keys = list(feature_data.keys())
np.random.shuffle(keys)
# split the data
split_data = {}
split_data['train'] = {k: feature_data[k] for k in keys[:int(len(keys) * split_ratio[0])]}
split_data['test'] = {k: feature_data[k] for k in keys[int(len(keys) * split_ratio[0]):int(len(keys) * (split_ratio[0] + split_ratio[1]))]}
split_data['validation'] = {k: feature_data[k] for k in keys[int(len(keys) * (split_ratio[0] + split_ratio[1])):]}
return split_data
if db is None:
db = 'features.db'
conn = sqlite3.connect(db)
c = conn.cursor()
# Randomize the rows into a new table so ROWID reflects the shuffled order
c.execute("CREATE TABLE IF NOT EXISTS features_random AS SELECT * FROM features ORDER BY RANDOM()")
# Drop the old features table
c.execute("DROP TABLE features")
# Rename the new table to features
c.execute("ALTER TABLE features_random RENAME TO features")
# get length of the data
c.execute('SELECT COUNT(*) FROM features')
length = c.fetchone()[0]
# split the data into training, test, and validation sets
train_idx = int(length * split_ratio[0])
test_idx = int(length * (split_ratio[0] + split_ratio[1]))
# Create the train, test, and validation tables if they don't exist
c.execute("CREATE TABLE IF NOT EXISTS train AS SELECT * FROM features WHERE 0")
c.execute("CREATE TABLE IF NOT EXISTS test AS SELECT * FROM features WHERE 0")
c.execute("CREATE TABLE IF NOT EXISTS validation AS SELECT * FROM features WHERE 0")
# Insert rows into the train table
c.execute("INSERT INTO train SELECT * FROM features WHERE ROWID <= ?", (train_idx,))
c.execute("INSERT INTO test SELECT * FROM features WHERE ROWID > ? AND ROWID <= ?", (train_idx, test_idx,))
c.execute("INSERT INTO validation SELECT * FROM features WHERE ROWID > ?", (test_idx,))
# drop the features table
# c.execute('DROP TABLE features')
conn.commit()
conn.close()
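Because the train, test, and validation tables are created with IF NOT EXISTS while the INSERT statements are unconditional, calling split_and_shuffle_data a second time on the same database appends every row again. A small guard one might run before re-splitting (an assumption, not part of this commit):

import sqlite3

conn = sqlite3.connect('features.db')
c = conn.cursor()
# drop any previous split so the next split_and_shuffle_data call starts clean
for table in ('train', 'test', 'validation'):
    c.execute(f'DROP TABLE IF EXISTS {table}')
conn.commit()
conn.close()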


@@ -10,6 +10,7 @@ import os
import numpy as np
import pickle
import json
import multiprocessing
import feature_extraction
@@ -98,7 +99,7 @@ def write_data(data_dict, path='./data', file_prefix=''):
with open(f'{path}/{file_prefix}{cat_name}.pkl', 'wb') as f:
pickle.dump(data, f)
def generate_feature_data(input_data_path, output_data_path, settings, prefix='feature_', split_ratio=None):
def generate_feature_data(input_data_path, settings, parallel=False, split_ratio=None):
"""
Generates the feature data from the raw data.
Args:
@@ -111,26 +112,28 @@ def generate_feature_data(input_data_path, output_data_path, settings, prefix='f
"""
if split_ratio is None:
split_ratio = settings['split_ratio']
data_dict = {}
print(list(os.listdir(input_data_path)))
for file in os.listdir(input_data_path):
if file.endswith(".pkl"):
print(f"Reading {file}")
with open(f'{input_data_path}/{file}', 'rb') as f:
data = pickle.load(f)
data_dict[file.replace('.pkl', '')] = data
# Extract the features
feature_data = feature_extraction.extract_features(data_dict)
#data_dict[file.replace('.pkl', '')] = data
data_dict = {file.replace('.pkl', ''): data}
# Extract the features
if parallel:
# get max number of processes
max_processes = multiprocessing.cpu_count() - 1
print(f"Using {max_processes} processes to extract features.")
feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
else:
feature_extraction.extract_features(data_dict)
# Split the data
splited_data = feature_extraction.split_data(feature_data, split_ratio)
if not os.path.exists(f'{output_data_path}/ml_dataset/'):
os.makedirs(f'{output_data_path}/ml_dataset/')
for file_name, data in splited_data.items():
print(f"Writing {file_name} to pickle with {len(data)} data entries.")
with open(f'{output_data_path}/ml_dataset/{prefix}{file_name}', 'wb') as f:
pickle.dump(data, f)
return splited_data
feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)
def main(gen_data=True, gen_features=True, split_ratio=None, settings_path='./settings.json', num_process_files=-1):
def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
"""
Main function to generate the data.
Args:
@@ -154,7 +157,7 @@ def main(gen_data=True, gen_features=True, split_ratio=None, settings_path='./se
write_data(data_dict, path=settings["data_path"])
ret_data = data_dict
if gen_features:
feature_data_dict = generate_feature_data(settings["data_path"], settings["data_path"], settings, split_ratio=split_ratio)
feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel)
ret_data = feature_data_dict
return ret_data
@@ -176,5 +179,5 @@ if __name__ == '__main__':
# new GSVT, AFIB, SR, SB
# Generate the data
main(gen_data=True, gen_features=False, num_process_files=100_000)
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], num_process_files=100_000)
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], parallel=False, num_process_files=100_000)
print("Data generation completed.")