diff --git a/05_mnist_vectorquant/vector_quantization.py b/05_mnist_vectorquant/vector_quantization.py index fc1e91c..8c30fb0 100644 --- a/05_mnist_vectorquant/vector_quantization.py +++ b/05_mnist_vectorquant/vector_quantization.py @@ -4,26 +4,26 @@ import matplotlib.pyplot as plt # Load MNIST print("Loading MNIST...") # debugging -(X_train_raw, y_train), (X_test_raw, y_test) = mnist.load_data() +(prototype_data_raw, prototype_labels_set), (test_data_raw, test_labels_set) = mnist.load_data() -# print("X_train_raw[0]:") -# print(X_train_raw[0]) # output: 28x28 vector spelling 5 -# print("y_train[0]:") -# print(y_train[0]) # output: 5 -# print("X_test_raw[0]:") -# print(X_test_raw[0]) # output: 28x28 vector spelling 7 -# print("y_test[0]:") -# print(y_test[0]) # output: 7 +# print("prototype_data_raw[0]:") +# print(prototype_data_raw[0]) # output: 28x28 vector spelling 5 +# print("prototype_labels_set[0]:") +# print(prototype_labels_set[0]) # output: 5 +# print("test_data_raw[0]:") +# print(test_data_raw[0]) # output: 28x28 vector spelling 7 +# print("test_labels_set[0]:") +# print(test_labels_set[0]) # output: 7 # Flatten images from (samples, 28, 28) to (samples, 784) -> one dimensional -X_train = X_train_raw.reshape(X_train_raw.shape[0], -1).astype(np.float32) -X_test = X_test_raw.reshape(X_test_raw.shape[0], -1).astype(np.float32) +prototype_data = prototype_data_raw.reshape(prototype_data_raw.shape[0], -1).astype(np.float32) +test_data = test_data_raw.reshape(test_data_raw.shape[0], -1).astype(np.float32) -print("Train:", X_train.shape, "Test:", X_test.shape) +print("Train:", prototype_data.shape, "Test:", test_data.shape) # Select first 1000 prototype vectors -prototypes = X_train[:1000] -prototype_labels = y_train[:1000] +prototypes = prototype_data[:1000] +prototype_labels = prototype_labels_set[:1000] print("Using", len(prototypes), "prototype vectors.") # debugging @@ -47,7 +47,7 @@ def knn_predict_batch(X_batch, k=3): distance = np.sqrt(np.sum(diff ** 2)) distances.append(distance) - # Find indices of k nearest neighbors (smallest distances) + # Find indices of k nearest neighbors distances = np.array(distances) nearest_k_indices = np.argsort(distances)[:k] # returns indices of array with sorted distances @@ -65,8 +65,8 @@ def knn_predict_batch(X_batch, k=3): N_TEST = 1000 print(f"Evaluating on {N_TEST} test samples...") # debugging -X_eval = X_test[:N_TEST] -y_eval = y_test[:N_TEST] +X_eval = test_data[:N_TEST] +y_eval = test_labels_set[:N_TEST] preds = knn_predict_batch(X_eval, k=5)