diff --git a/data/embedded_padded/test.pt b/data/embedded_padded/test.pt index eebf938..5456946 100644 Binary files a/data/embedded_padded/test.pt and b/data/embedded_padded/test.pt differ diff --git a/data/embedded_padded/val.pt b/data/embedded_padded/val.pt index 8a46d1e..a63f08e 100644 Binary files a/data/embedded_padded/val.pt and b/data/embedded_padded/val.pt differ diff --git a/data/idx_based_padded/test.pt b/data/idx_based_padded/test.pt index 91a44cd..7542d35 100644 Binary files a/data/idx_based_padded/test.pt and b/data/idx_based_padded/test.pt differ diff --git a/data/idx_based_padded/val.pt b/data/idx_based_padded/val.pt index c59c227..4756a64 100644 Binary files a/data/idx_based_padded/val.pt and b/data/idx_based_padded/val.pt differ diff --git a/dataset_generator.py b/dataset_generator.py index 35ae04e..aa14657 100644 --- a/dataset_generator.py +++ b/dataset_generator.py @@ -56,15 +56,20 @@ def pad_sequences(sequences, max_len, pad_index): def split_data(X, y, test_size=0.1, val_size=0.1): - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size + val_size, random_state=42) - val_split_ratio = val_size / (val_size + test_size) - X_test, X_val, y_test, y_val = train_test_split(X_train, y_train, test_size=val_split_ratio, random_state=42) + X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=test_size + val_size, random_state=42) + val_split_ratio = val_size / (test_size + val_size) + X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=1 - val_split_ratio, random_state=42) ret_dict = { 'train': {'X': X_train, 'y': y_train}, 'test': {'X': X_test, 'y': y_test}, 'val': {'X': X_val, 'y': y_val} } + + # for each print len + for key in ret_dict.keys(): + print(key, len(ret_dict[key]['X']), len(ret_dict[key]['y'])) + return ret_dict def save_data(data_dict, path, prefix, vocab_size=0, emb_dim=None):