fixed splitting bug
parent
98b2d2e3c0
commit
75766ad784
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -56,15 +56,20 @@ def pad_sequences(sequences, max_len, pad_index):
|
|||
|
||||
|
||||
def split_data(X, y, test_size=0.1, val_size=0.1):
    """Split ``X`` and ``y`` into train, validation, and test subsets.

    The split is done in two stages with ``train_test_split`` (imported
    elsewhere in this module; presumably scikit-learn's -- confirm):

    1. A holdout of fraction ``test_size + val_size`` is separated from the
       training data.
    2. The holdout is divided between validation and test, with the second
       call's ``test_size`` set to ``1 - val_size / (test_size + val_size)``.

    Both calls use ``random_state=42``. The size of every subset is printed
    before returning.

    Args:
        X: Samples, passed straight to ``train_test_split``.
        y: Targets, passed alongside ``X``.
        test_size: Fraction of the full data intended for the test subset.
        val_size: Fraction of the full data intended for the validation
            subset.

    Returns:
        A dict with keys ``'train'``, ``'test'`` and ``'val'``; each value is
        a dict with keys ``'X'`` and ``'y'`` holding that subset.
    """
    holdout_size = test_size + val_size

    # Stage 1: separate the training data from the combined val+test holdout.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=holdout_size, random_state=42
    )

    # Stage 2: divide the holdout (never the training data) into val and
    # test, so the three subsets do not overlap.
    val_split_ratio = val_size / holdout_size
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=1 - val_split_ratio, random_state=42
    )

    ret_dict = {
        'train': {'X': X_train, 'y': y_train},
        'test': {'X': X_test, 'y': y_test},
        'val': {'X': X_val, 'y': y_val},
    }

    # Report the size of each subset as a quick sanity check.
    for key, subset in ret_dict.items():
        print(key, len(subset['X']), len(subset['y']))

    return ret_dict
|
||||
|
||||
def save_data(data_dict, path, prefix, vocab_size=0, emb_dim=None):
|
||||
|
|
|
|||
Loading…
Reference in New Issue