fixed splitting bug

main
Felix Jan Michael Mucha 2025-02-07 23:08:24 +01:00
parent 98b2d2e3c0
commit 75766ad784
5 changed files with 8 additions and 3 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -56,15 +56,20 @@ def pad_sequences(sequences, max_len, pad_index):
def split_data(X, y, test_size=0.1, val_size=0.1):
    """Split ``X`` and ``y`` into train, validation and test subsets.

    The split is done in two stages with a fixed ``random_state`` of 42:

    1. A fraction ``test_size + val_size`` of the data is held out; the
       remainder becomes the training set.
    2. The held-out part is divided again so that the validation set gets
       ``val_size / (test_size + val_size)`` of it and the test set gets
       the rest.

    The size of each resulting subset is printed before returning.

    Args:
        X: Samples, passed straight through to ``train_test_split``.
        y: Targets, passed straight through to ``train_test_split``.
        test_size: Fraction of the full data intended for the test set.
        val_size: Fraction of the full data intended for the validation set.

    Returns:
        A dict with the keys ``'train'``, ``'test'`` and ``'val'``; each
        value is a dict with the keys ``'X'`` and ``'y'``.
    """
    # NOTE(review): train_test_split is presumably
    # sklearn.model_selection.train_test_split, imported elsewhere in this
    # file -- confirm.
    holdout_size = test_size + val_size

    # Stage 1: separate the training data from the combined hold-out set.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=holdout_size, random_state=42
    )

    # Stage 2: divide the hold-out set (not the training set) into
    # validation and test parts. The second output of train_test_split is
    # the "test" side, so it receives the complement of the validation share.
    val_split_ratio = val_size / holdout_size
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=1 - val_split_ratio, random_state=42
    )

    ret_dict = {
        'train': {'X': X_train, 'y': y_train},
        'test': {'X': X_test, 'y': y_test},
        'val': {'X': X_val, 'y': y_val}
    }

    # Report the size of every subset as a quick sanity check.
    for key, part in ret_dict.items():
        print(key, len(part['X']), len(part['y']))
    return ret_dict
def save_data(data_dict, path, prefix, vocab_size=0, emb_dim=None):