xgboost model
parent
f924d962b8
commit
52a8e8eeca
|
@ -1,2 +1,2 @@
|
||||||
/data/
|
/data/
|
||||||
/settings.json
|
settings.json
|
Binary file not shown.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,37 @@
|
||||||
|
import sqlite3
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
conn = sqlite3.connect('features.db')
|
||||||
|
c = conn.cursor()
|
||||||
|
# print names of available tables
|
||||||
|
c.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||||
|
print("Table names: ", c.fetchall())
|
||||||
|
|
||||||
|
# for each table in the database, print the number of rows
|
||||||
|
for table in ['train', 'test', 'validation']:
|
||||||
|
c.execute(f'SELECT COUNT(*) FROM {table}')
|
||||||
|
print(f"Number of rows in the {table} table: ", c.fetchall()[0][0])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# print the number of rows in features table
|
||||||
|
c.execute('SELECT COUNT(*) FROM features')
|
||||||
|
print("Number of rows in the features table: ", c.fetchall()[0][0])
|
||||||
|
# print column names
|
||||||
|
c.execute('PRAGMA table_info(features)')
|
||||||
|
print("Column names in the features table: ", c.fetchall())
|
||||||
|
|
||||||
|
# count for each label how many rows there are
|
||||||
|
c.execute('SELECT y, COUNT(*) FROM features GROUP BY y')
|
||||||
|
print("Number of rows for each label: ", c.fetchall())
|
||||||
|
|
||||||
|
|
||||||
|
# Load data from the features table into a DataFrame
|
||||||
|
df = pd.read_sql_query("SELECT * FROM features", conn)
|
||||||
|
# Now you can work with the data in the df DataFrame
|
||||||
|
print(df.head(15))
|
||||||
|
|
||||||
|
# close the connection
|
||||||
|
|
||||||
|
conn.close()
|
|
@ -1,13 +1,10 @@
|
||||||
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
import wfdb.processing
|
import wfdb.processing
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import scipy
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import neurokit2 as nk
|
import neurokit2 as nk
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
from multiprocessing import Pool
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
def get_y_value(ecg_cleaned, indecies):
|
def get_y_value(ecg_cleaned, indecies):
|
||||||
"""
|
"""
|
||||||
|
@ -82,97 +79,261 @@ def calculate_axis(record, wave_peak, r_peak_idx, sampling_rate=500, aVF=5, I=0)
|
||||||
|
|
||||||
return r_axis, t_axis
|
return r_axis, t_axis
|
||||||
|
|
||||||
|
def get_features(record, label, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], conn=None):
|
||||||
|
"""
|
||||||
|
Extracts the features from the record.
|
||||||
|
Args:
|
||||||
|
record (object): The record object containing the ECG signal.
|
||||||
|
label (str): The label of the record.
|
||||||
|
sampling_rate (int): The sampling rate of the ECG signal.
|
||||||
|
Returns:
|
||||||
|
dict: The dictionary containing the extracted features.
|
||||||
|
"""
|
||||||
|
age = record.comments[0].split(' ')[1]
|
||||||
|
gender = record.comments[1].split(' ')[1]
|
||||||
|
if age == 'NaN' or gender == 'NaN':
|
||||||
|
return None
|
||||||
|
features = {}
|
||||||
|
# Extract the features
|
||||||
|
features['y'] = label
|
||||||
|
# Demographic features
|
||||||
|
features['age'] = int(age)
|
||||||
|
features['gender'] = True if gender == 'Male' else False
|
||||||
|
# Signal features
|
||||||
|
|
||||||
|
|
||||||
|
# Delineate the ECG signal
|
||||||
|
try:
|
||||||
|
ecg_signal = record.p_signal[:, 0]
|
||||||
|
ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
|
||||||
|
_, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
|
||||||
|
r_peaks = rpeaks['ECG_R_Peaks']
|
||||||
|
_, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
conn.close()
|
||||||
|
raise
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# TODO Other features and check features
|
||||||
|
atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6
|
||||||
|
ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)))
|
||||||
|
|
||||||
|
features['artial_rate'] = atrial_rate
|
||||||
|
features['ventricular_rate'] = ventricular_rate
|
||||||
|
|
||||||
|
qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks']))
|
||||||
|
features['qrs_duration'] = qrs_duration
|
||||||
|
|
||||||
|
qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks']))
|
||||||
|
features['qt_length'] = qt_interval
|
||||||
|
|
||||||
|
q_peak = waves_peak['ECG_Q_Peaks']
|
||||||
|
s_peak = waves_peak['ECG_S_Peaks']
|
||||||
|
# check if q_peak, r_peak, s_peak are not nan and therefore a solid qrs complex exists
|
||||||
|
qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False)
|
||||||
|
features['qrs_count'] = qrs_count
|
||||||
|
|
||||||
|
features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))
|
||||||
|
|
||||||
|
r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=500, aVF=5, I=0)
|
||||||
|
features['r_axis'] = r_axis
|
||||||
|
features['t_axis'] = t_axis
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
def exclude_already_extracted(data_dict, conn):
|
||||||
|
"""
|
||||||
|
Exclude the records that are already in the database.
|
||||||
|
Args:
|
||||||
|
data_dict (dict): The dictionary containing the data.
|
||||||
|
Returns:
|
||||||
|
dict: The dictionary containing the unique data.
|
||||||
|
"""
|
||||||
|
unique_data_dict = {}
|
||||||
|
record_ids = []
|
||||||
|
for label, data in data_dict.items():
|
||||||
|
for record in data:
|
||||||
|
record_ids.append(record.record_name)
|
||||||
|
# get id column from database and subtract it from record_ids
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute('SELECT id FROM features')
|
||||||
|
db_ids = c.fetchall()
|
||||||
|
db_ids = [x[0] for x in db_ids]
|
||||||
|
unique_ids = list(set(record_ids) - set(db_ids))
|
||||||
|
for label, data in data_dict.items():
|
||||||
|
unique_data_dict[label] = [record for record in data if record.record_name in unique_ids]
|
||||||
|
return unique_data_dict
|
||||||
|
|
||||||
|
|
||||||
|
def process_record(data_dict_item, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5], conn=None, c=None):
|
||||||
|
"""
|
||||||
|
Process a record to extract the features.
|
||||||
|
Args:
|
||||||
|
data_dict_item (tuple): The tuple containing the data dictionary item.
|
||||||
|
sampling_rate (int): The sampling rate of the ECG signal.
|
||||||
|
used_channels (list): The list of used channels.
|
||||||
|
conn (object): The connection object to the database.
|
||||||
|
c (object): The cursor object.
|
||||||
|
Returns:
|
||||||
|
tuple: The tuple containing the record name and the extracted features.
|
||||||
|
"""
|
||||||
|
|
||||||
|
label, record, c, conn = data_dict_item
|
||||||
|
features = get_features(record,label, sampling_rate=sampling_rate, used_channels=used_channels)
|
||||||
|
if features is None:
|
||||||
|
print(f"Failed to extract features for record {record.record_name}")
|
||||||
|
return record.record_name, None
|
||||||
|
# TODO: Insert the record into the database
|
||||||
|
#feature_data[record_name] = features
|
||||||
|
# Define the feature names
|
||||||
|
feature_names = list(features.keys())
|
||||||
|
# Create the table if it doesn't exist
|
||||||
|
c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})")
|
||||||
|
# Insert the record into the table
|
||||||
|
c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})",
|
||||||
|
[record.record_name] + list(features.values()))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
return record.record_name, features
|
||||||
|
|
||||||
|
def extract_features_parallel(data_dict, num_processes, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
|
||||||
|
"""
|
||||||
|
Extracts the features from the data.
|
||||||
|
Args:
|
||||||
|
data_dict (dict): The dictionary containing the data.
|
||||||
|
num_processes (int): The number of processes to use.
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
failed_records = []
|
||||||
|
processed_records = 0
|
||||||
|
|
||||||
|
conn = sqlite3.connect('features.db')
|
||||||
|
c = conn.cursor()
|
||||||
|
# get unique data
|
||||||
|
data_dict = exclude_already_extracted(data_dict, conn)
|
||||||
|
|
||||||
|
for label, data in data_dict.items():
|
||||||
|
print(f"Extracting features for {label} with {len(data)} data entries.")
|
||||||
|
with Pool(processes=num_processes) as pool:
|
||||||
|
results = pool.map(process_record, [(label, record, c, conn) for record in data])
|
||||||
|
for result in results:
|
||||||
|
processed_records += 1
|
||||||
|
if processed_records % 100 == 0:
|
||||||
|
stop_time = time.time()
|
||||||
|
print(f"Extracted features for {processed_records} records. Time taken: {stop_time - start_time:.2f}s")
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
record_name, features = result
|
||||||
|
except Exception as exc:
|
||||||
|
print(f'{record_name} generated an exception: {exc}')
|
||||||
|
else:
|
||||||
|
if features is not None:
|
||||||
|
print(f"Extracted features for record {record_name}")
|
||||||
|
else:
|
||||||
|
failed_records.append(record_name)
|
||||||
|
print(f"Sum of failed records: {len(failed_records)}")
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
|
def extract_features(data_dict, sampling_rate=500, used_channels=[0, 1, 2, 3, 4, 5]):
|
||||||
"""
|
"""
|
||||||
Extracts the features from the data.
|
Extracts the features from the data.
|
||||||
Args:
|
Args:
|
||||||
data_dict (dict): The dictionary containing the data.
|
data_dict (dict): The dictionary containing the data.
|
||||||
Returns:
|
|
||||||
dict: The dictionary containing the extracted features.
|
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
feature_data = {}
|
|
||||||
failed_records = []
|
failed_records = []
|
||||||
|
|
||||||
|
conn = sqlite3.connect('features.db')
|
||||||
|
c = conn.cursor()
|
||||||
|
# get last file in the database
|
||||||
|
|
||||||
|
try:
|
||||||
|
# print how many records are already in the database
|
||||||
|
c.execute('SELECT COUNT(*) FROM features')
|
||||||
|
print("Records in DB:", c.fetchall())
|
||||||
|
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute('SELECT id FROM features')
|
||||||
|
db_ids = c.fetchall()
|
||||||
|
db_ids = [x[0] for x in db_ids]
|
||||||
|
except:
|
||||||
|
db_ids = []
|
||||||
|
print("No last file in DB")
|
||||||
|
|
||||||
for label, data in data_dict.items():
|
for label, data in data_dict.items():
|
||||||
print(f"Extracting features for {label} with {len(data)} data entries.")
|
print(f"Extracting features for {label} with {len(data)} data entries.")
|
||||||
for data_idx, record in enumerate(data):
|
for data_idx, record in enumerate(data):
|
||||||
|
# Skip the records that are already in the database
|
||||||
|
if record.record_name in db_ids:
|
||||||
|
continue
|
||||||
|
# print current status
|
||||||
if data_idx % 100 == 0:
|
if data_idx % 100 == 0:
|
||||||
stop_time = time.time()
|
stop_time = time.time()
|
||||||
print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
|
print(f"Extracted features for {data_idx} records. Time taken: {stop_time - start_time:.2f}s")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
# get the features
|
||||||
age = record.comments[0].split(' ')[1]
|
features = get_features(record, label, sampling_rate=sampling_rate, used_channels=used_channels, conn=conn)
|
||||||
gender = record.comments[1].split(' ')[1]
|
if features is None:
|
||||||
if age == 'NaN' or gender == 'NaN':
|
|
||||||
continue
|
|
||||||
features = {}
|
|
||||||
# Extract the features
|
|
||||||
features['y'] = label
|
|
||||||
# Demographic features
|
|
||||||
features['age'] = int(age)
|
|
||||||
features['gender'] = True if gender == 'Male' else False
|
|
||||||
# Signal features
|
|
||||||
|
|
||||||
ecg_signal = record.p_signal[:, 0]
|
|
||||||
ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
|
|
||||||
_, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)
|
|
||||||
r_peaks = rpeaks['ECG_R_Peaks']
|
|
||||||
# Delineate the ECG signal
|
|
||||||
try:
|
|
||||||
_, waves_peak = nk.ecg_delineate(ecg_signal, r_peaks, sampling_rate=sampling_rate, method="peak")
|
|
||||||
except:
|
|
||||||
failed_records.append(record.record_name)
|
failed_records.append(record.record_name)
|
||||||
print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}")
|
print(f"Failed to extract features for record {record.record_name} Sum of failed records: {len(failed_records)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# TODO Other features and check features
|
|
||||||
atrial_rate = len(waves_peak['ECG_P_Peaks']) * 6
|
|
||||||
ventricular_rate = np.mean(nk.ecg_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)))
|
|
||||||
|
|
||||||
features['artial_rate'] = atrial_rate
|
# Define the feature names
|
||||||
features['ventricular_rate'] = ventricular_rate
|
feature_names = list(features.keys())
|
||||||
|
# Create the table if it doesn't exist
|
||||||
qrs_duration = np.nanmean(np.array(waves_peak['ECG_S_Peaks']) - np.array(waves_peak['ECG_Q_Peaks']))
|
c.execute(f"CREATE TABLE IF NOT EXISTS features (id TEXT PRIMARY KEY, {', '.join(feature_names)})")
|
||||||
features['qrs_duration'] = qrs_duration
|
# Insert the record into the table
|
||||||
|
c.execute(f"INSERT INTO features (id, {', '.join(feature_names)}) VALUES (?, {', '.join('?' for _ in feature_names)})",
|
||||||
qt_interval = np.nanmean(np.array(waves_peak['ECG_T_Offsets']) - np.array(waves_peak['ECG_Q_Peaks']))
|
[record.record_name] + list(features.values()))
|
||||||
features['qt_length'] = qt_interval
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
q_peak = waves_peak['ECG_Q_Peaks']
|
|
||||||
s_peak = waves_peak['ECG_S_Peaks']
|
|
||||||
# check if q_peak, r_peak, s_peak are not nan and therefore a solid qrs complex exists
|
|
||||||
qrs_count = [any([math.isnan(q_peak[i]), math.isnan(r_peaks[i]), math.isnan(s_peak[i])]) for i in range(len(q_peak))].count(False)
|
|
||||||
features['qrs_count'] = qrs_count
|
|
||||||
|
|
||||||
features['q_peak'] = np.mean(get_y_value(ecg_cleaned, q_peak))
|
|
||||||
|
|
||||||
r_axis, t_axis = calculate_axis(record, waves_peak, r_peaks, sampling_rate=500, aVF=5, I=0)
|
|
||||||
features['r_axis'] = r_axis
|
|
||||||
features['t_axis'] = t_axis
|
|
||||||
|
|
||||||
# print the features
|
|
||||||
#print(json.dumps(features, indent=4))
|
|
||||||
feature_data[record.record_name] = features
|
|
||||||
|
|
||||||
return feature_data
|
|
||||||
|
|
||||||
|
|
||||||
def split_data(feature_data, split_ratio):
|
|
||||||
|
def split_and_shuffle_data(split_ratio, db=None):
|
||||||
|
"""
|
||||||
|
Splits the data into training, test, and validation sets.
|
||||||
|
Args:
|
||||||
|
split_ratio (list): The ratio in which the data will be split into training, test, and validation sets.
|
||||||
|
db (str): The name of the database.
|
||||||
|
"""
|
||||||
print(f"Splitting data with ratio {split_ratio}")
|
print(f"Splitting data with ratio {split_ratio}")
|
||||||
#flatten dictionary
|
|
||||||
feature_data = {k: v for k, v in feature_data.items()}
|
|
||||||
# print keys
|
|
||||||
print("Keys:")
|
|
||||||
print(len(feature_data.keys()))
|
|
||||||
# shuffle the data
|
|
||||||
keys = list(feature_data.keys())
|
|
||||||
np.random.shuffle(keys)
|
|
||||||
# split the data
|
|
||||||
split_data = {}
|
|
||||||
split_data['train'] = {k: feature_data[k] for k in keys[:int(len(keys) * split_ratio[0])]}
|
|
||||||
split_data['test'] = {k: feature_data[k] for k in keys[int(len(keys) * split_ratio[0]):int(len(keys) * (split_ratio[0] + split_ratio[1]))]}
|
|
||||||
split_data['validation'] = {k: feature_data[k] for k in keys[int(len(keys) * (split_ratio[0] + split_ratio[1])):]}
|
|
||||||
|
|
||||||
return split_data
|
if db is None:
|
||||||
|
db = 'features.db'
|
||||||
|
conn = sqlite3.connect(db)
|
||||||
|
c = conn.cursor()
|
||||||
|
# shuffle the data (rows)
|
||||||
|
# Randomize the rows and create a new table with a new ROWID
|
||||||
|
c.execute("CREATE TABLE IF NOT EXISTS features_random AS SELECT * FROM features ORDER BY RANDOM()")
|
||||||
|
# Drop the old features table
|
||||||
|
c.execute("DROP TABLE features")
|
||||||
|
# Rename the new table to features
|
||||||
|
c.execute("ALTER TABLE features_random RENAME TO features")
|
||||||
|
|
||||||
|
# get length of the data
|
||||||
|
c.execute('SELECT COUNT(*) FROM features')
|
||||||
|
length = c.fetchone()[0]
|
||||||
|
# split the data into training, test, and validation sets
|
||||||
|
train_idx = int(length * split_ratio[0])
|
||||||
|
test_idx = int(length * (split_ratio[0] + split_ratio[1]))
|
||||||
|
# create 3 tables for training, test, and validation sets
|
||||||
|
# Create the train, test, and validation tables if they don't exist
|
||||||
|
c.execute("CREATE TABLE IF NOT EXISTS train AS SELECT * FROM features WHERE 0")
|
||||||
|
c.execute("CREATE TABLE IF NOT EXISTS test AS SELECT * FROM features WHERE 0")
|
||||||
|
c.execute("CREATE TABLE IF NOT EXISTS validation AS SELECT * FROM features WHERE 0")
|
||||||
|
# Insert rows into the train table
|
||||||
|
c.execute("INSERT INTO train SELECT * FROM features WHERE ROWID <= ?", (train_idx,))
|
||||||
|
c.execute("INSERT INTO test SELECT * FROM features WHERE ROWID > ? AND ROWID <= ?", (train_idx, test_idx,))
|
||||||
|
c.execute("INSERT INTO validation SELECT * FROM features WHERE ROWID > ?", (test_idx,))
|
||||||
|
|
||||||
|
# drop the features table
|
||||||
|
# c.execute('DROP TABLE features')
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
|
@ -10,6 +10,7 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pickle
|
import pickle
|
||||||
import json
|
import json
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
import feature_extraction
|
import feature_extraction
|
||||||
|
|
||||||
|
@ -98,7 +99,7 @@ def write_data(data_dict, path='./data', file_prefix=''):
|
||||||
with open(f'{path}/{file_prefix}{cat_name}.pkl', 'wb') as f:
|
with open(f'{path}/{file_prefix}{cat_name}.pkl', 'wb') as f:
|
||||||
pickle.dump(data, f)
|
pickle.dump(data, f)
|
||||||
|
|
||||||
def generate_feature_data(input_data_path, output_data_path, settings, prefix='feature_', split_ratio=None):
|
def generate_feature_data(input_data_path, settings, parallel=False, split_ratio=None):
|
||||||
"""
|
"""
|
||||||
Generates the feature data from the raw data.
|
Generates the feature data from the raw data.
|
||||||
Args:
|
Args:
|
||||||
|
@ -111,26 +112,28 @@ def generate_feature_data(input_data_path, output_data_path, settings, prefix='f
|
||||||
"""
|
"""
|
||||||
if split_ratio is None:
|
if split_ratio is None:
|
||||||
split_ratio = settings['split_ratio']
|
split_ratio = settings['split_ratio']
|
||||||
data_dict = {}
|
print(list(os.listdir(input_data_path)))
|
||||||
for file in os.listdir(input_data_path):
|
for file in os.listdir(input_data_path):
|
||||||
if file.endswith(".pkl"):
|
if file.endswith(".pkl"):
|
||||||
print(f"Reading {file}")
|
print(f"Reading {file}")
|
||||||
with open(f'{input_data_path}/{file}', 'rb') as f:
|
with open(f'{input_data_path}/{file}', 'rb') as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
data_dict[file.replace('.pkl', '')] = data
|
#data_dict[file.replace('.pkl', '')] = data
|
||||||
# Extract the features
|
data_dict = {file.replace('.pkl', ''): data}
|
||||||
feature_data = feature_extraction.extract_features(data_dict)
|
# Extract the features
|
||||||
|
if parallel:
|
||||||
|
# get max number of processes
|
||||||
|
max_processes = multiprocessing.cpu_count() - 1
|
||||||
|
print(f"Using {max_processes} processes to extract features.")
|
||||||
|
feature_extraction.extract_features_parallel(data_dict, num_processes=max_processes)
|
||||||
|
else:
|
||||||
|
feature_extraction.extract_features(data_dict)
|
||||||
# Split the data
|
# Split the data
|
||||||
splited_data = feature_extraction.split_data(feature_data, split_ratio)
|
feature_extraction.split_and_shuffle_data(split_ratio=split_ratio)
|
||||||
if not os.path.exists(f'{output_data_path}/ml_dataset/'):
|
|
||||||
os.makedirs(f'{output_data_path}/ml_dataset/')
|
|
||||||
for file_name, data in splited_data.items():
|
|
||||||
print(f"Writing {file_name} to pickle with {len(data)} data entries.")
|
|
||||||
with open(f'{output_data_path}/ml_dataset/{prefix}{file_name}', 'wb') as f:
|
|
||||||
pickle.dump(data, f)
|
|
||||||
return splited_data
|
|
||||||
|
|
||||||
def main(gen_data=True, gen_features=True, split_ratio=None, settings_path='./settings.json', num_process_files=-1):
|
|
||||||
|
|
||||||
|
def main(gen_data=True, gen_features=True, split_ratio=None, parallel=False, settings_path='./settings.json', num_process_files=-1):
|
||||||
"""
|
"""
|
||||||
Main function to generate the data.
|
Main function to generate the data.
|
||||||
Args:
|
Args:
|
||||||
|
@ -154,7 +157,7 @@ def main(gen_data=True, gen_features=True, split_ratio=None, settings_path='./se
|
||||||
write_data(data_dict, path=settings["data_path"])
|
write_data(data_dict, path=settings["data_path"])
|
||||||
ret_data = data_dict
|
ret_data = data_dict
|
||||||
if gen_features:
|
if gen_features:
|
||||||
feature_data_dict = generate_feature_data(settings["data_path"], settings["data_path"], settings, split_ratio=split_ratio)
|
feature_data_dict = generate_feature_data(settings["data_path"], settings, split_ratio=split_ratio, parallel=parallel)
|
||||||
ret_data = feature_data_dict
|
ret_data = feature_data_dict
|
||||||
|
|
||||||
return ret_data
|
return ret_data
|
||||||
|
@ -176,5 +179,5 @@ if __name__ == '__main__':
|
||||||
# new GSVT, AFIB, SR, SB
|
# new GSVT, AFIB, SR, SB
|
||||||
# Generate the data
|
# Generate the data
|
||||||
main(gen_data=True, gen_features=False, num_process_files=100_000)
|
main(gen_data=True, gen_features=False, num_process_files=100_000)
|
||||||
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], num_process_files=100_000)
|
#main(gen_data=False, gen_features=True, split_ratio=[0.8, 0.1, 0.1], parallel=False, num_process_files=100_000)
|
||||||
print("Data generation completed.")
|
print("Data generation completed.")
|
||||||
|
|
Loading…
Reference in New Issue