"""k-nearest-neighbour (kNN) digit classification on MNIST.

Run as a script, this loads MNIST through Keras, keeps the first
``N_PROTOTYPES`` training images as prototypes, classifies the first
``N_TEST`` test images with kNN on raw pixel distances, prints the accuracy
and plots the first 20 predictions.

``knn_predict_batch`` depends only on NumPy, so it can be imported and used
without Keras or matplotlib being installed.
"""

import numpy as np

# Number of training images kept as kNN prototypes.
N_PROTOTYPES = 1000
# Number of test images that are classified and scored.
N_TEST = 1000
# Length of the vote histogram: one bin per digit class (0-9).
NUM_CLASSES = 10

# Default reference set for knn_predict_batch(); populated by main().
prototypes = None
prototype_labels = None


def knn_predict_batch(X_batch, k=3, reference_vectors=None, reference_labels=None):
    """Predict a label for every row of ``X_batch`` by k-nearest-neighbour vote.

    Args:
        X_batch: array-like of shape (batch_size, n_features); for the MNIST
            script n_features is 784.
        k: number of nearest prototypes that vote. Values larger than the
            number of prototypes are clamped to that number.
        reference_vectors: optional array-like of shape
            (n_prototypes, n_features). Defaults to the module-level
            ``prototypes``.
        reference_labels: optional array-like of shape (n_prototypes,) with
            non-negative integer labels. Defaults to the module-level
            ``prototype_labels``.

    Returns:
        Integer array of shape (batch_size,) with the predicted labels. Vote
        ties are resolved in favour of the smallest label; distance ties in
        favour of the prototype with the lowest index.

    Raises:
        ValueError: if ``k < 1``, if no reference set is available, or if the
            shapes of the inputs are inconsistent.
    """
    if k < 1:
        # k == 0 would silently predict class 0 for everything, and a negative
        # k would slice the wrong end of the sorted neighbour list.
        raise ValueError(f"k must be at least 1, got {k}")

    if reference_vectors is None:
        reference_vectors = prototypes
    if reference_labels is None:
        reference_labels = prototype_labels
    if reference_vectors is None or reference_labels is None:
        raise ValueError(
            "no prototypes available: pass reference_vectors/reference_labels "
            "or set the module-level prototypes/prototype_labels first"
        )

    # float64 keeps sums of squared pixel differences exact for integer-valued
    # images and prevents wrap-around when raw uint8 images are passed in.
    references = np.asarray(reference_vectors, dtype=np.float64)
    labels = np.asarray(reference_labels)
    queries = np.asarray(X_batch, dtype=np.float64)

    if references.ndim != 2 or len(references) == 0:
        raise ValueError("reference_vectors must be a non-empty 2-D array")
    if labels.shape != (len(references),):
        raise ValueError("reference_labels must hold exactly one label per prototype")
    if queries.size == 0:
        return np.empty(0, dtype=np.intp)
    if queries.ndim != 2 or queries.shape[1] != references.shape[1]:
        raise ValueError(
            f"X_batch must have shape (batch_size, {references.shape[1]}), "
            f"got {queries.shape}"
        )

    k = min(k, len(references))
    predictions = np.empty(len(queries), dtype=np.intp)

    # One query at a time keeps the temporary at (n_prototypes, n_features)
    # instead of a (batch_size, n_prototypes, n_features) cube.
    for row, query in enumerate(queries):
        diff = references - query
        # Squared Euclidean distance; the square root is monotonic, so it is
        # not needed to rank neighbours.
        squared_distances = np.sum(diff * diff, axis=1)
        # A stable sort makes equal distances resolve deterministically.
        nearest = np.argsort(squared_distances, kind="stable")[:k]
        votes = np.bincount(labels[nearest], minlength=NUM_CLASSES)
        predictions[row] = votes.argmax()

    return predictions


def _load_flattened_mnist():
    """Load MNIST via Keras and flatten each 28x28 image to a 784-vector.

    Returns:
        (train_vectors, train_labels, test_vectors, test_labels); the vector
        arrays are float32 with shape (samples, 784).
    """
    # Imported lazily so the classifier above stays usable without Keras.
    from keras.datasets import mnist

    (train_raw, train_labels), (test_raw, test_labels) = mnist.load_data()
    # Per the original debugging notes, train_raw[0] is a 28x28 image of the
    # digit 5 and test_raw[0] is a 7.

    # Flatten images from (samples, 28, 28) to (samples, 784).
    train_vectors = train_raw.reshape(train_raw.shape[0], -1).astype(np.float32)
    test_vectors = test_raw.reshape(test_raw.shape[0], -1).astype(np.float32)
    return train_vectors, train_labels, test_vectors, test_labels


def _show_predictions(images, predicted, actual, n_show=20):
    """Plot up to ``n_show`` flattened 28x28 images with predicted/true labels."""
    # Imported lazily so importing this module never needs a plotting backend.
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(4, 5, figsize=(12, 10))
    axes = axes.flatten()
    for ax in axes:
        ax.axis('off')

    # Bounded by the data as well as the grid, so fewer than 20 evaluated
    # samples no longer raises IndexError.
    shown = min(n_show, len(images), len(axes))
    for ax, image, pred, true in zip(axes, images[:shown], predicted, actual):
        # Reshape the flattened image back to 28x28 for display.
        ax.imshow(image.reshape(28, 28), cmap='gray')
        ax.set_title(f"Pred: {pred}, True: {true}")

    plt.tight_layout()
    plt.show()


def main():
    """Run the MNIST kNN experiment: load, classify, report and plot."""
    global prototypes, prototype_labels

    print("Loading MNIST...")
    prototype_data, prototype_labels_set, test_data, test_labels_set = _load_flattened_mnist()
    print("Train:", prototype_data.shape, "Test:", test_data.shape)

    # Select the first N_PROTOTYPES training vectors as prototypes. They are
    # stored at module level so knn_predict_batch() can be called without
    # passing them explicitly.
    prototypes = prototype_data[:N_PROTOTYPES]
    prototype_labels = prototype_labels_set[:N_PROTOTYPES]
    print("Using", len(prototypes), "prototype vectors.")

    # Evaluate on the first N_TEST test samples.
    print(f"Evaluating on {N_TEST} test samples...")
    X_eval = test_data[:N_TEST]
    y_eval = test_labels_set[:N_TEST]

    preds = knn_predict_batch(X_eval, k=5)
    accuracy = np.mean(preds == y_eval)

    print("Predictions:", preds[:20])
    print("True labels:", y_eval[:20])
    print("Accuracy:", accuracy)

    # Visualize the first 20 predictions.
    _show_predictions(X_eval, preds, y_eval, n_show=20)


if __name__ == "__main__":
    main()