"""k-nearest-neighbour digit classification on MNIST using raw pixel vectors.

Run as a script to load MNIST, classify the first ``N_TEST`` test images
against the first ``N_PROTOTYPES`` training images, print the accuracy and
show the first 20 predictions. Importing the module only defines the
classifier; nothing is downloaded or plotted at import time.
"""
import numpy as np

# Number of training images used as kNN prototypes.
N_PROTOTYPES = 1000
# Number of test images the script evaluates.
N_TEST = 1000

# Prototype set used by knn_predict_batch() when none is passed explicitly.
# main() fills these in after loading MNIST.
prototypes = None
prototype_labels = None


def _as_float(arr):
    """Return *arr* as an ndarray, promoting non-float dtypes to float32."""
    arr = np.asarray(arr)
    if np.issubdtype(arr.dtype, np.floating):
        return arr
    # Unsigned integer pixels would wrap around on subtraction and yield
    # wrong distances, so integer input is converted before any arithmetic.
    return arr.astype(np.float32)


def knn_predict_batch(X_batch, k=3, protos=None, labels=None):
    """Predict a label for each row of *X_batch* by k-nearest-neighbour vote.

    Args:
        X_batch: Array of shape (batch_size, n_features), one flattened
            image per row (n_features is 784 for 28x28 MNIST images).
        k: Number of nearest prototypes that take part in the vote.
        protos: Optional array of shape (n_prototypes, n_features). When
            omitted, the module-level ``prototypes`` is used.
        labels: Optional non-negative integer label per prototype. When
            omitted, the module-level ``prototype_labels`` is used.

    Returns:
        Array of shape (batch_size,) holding the predicted label per row.

    Raises:
        ValueError: If ``k`` is smaller than 1, if no prototype set is
            available, if it is empty, or if prototypes and labels differ
            in length.
    """
    if k < 1:
        # With k == 0 the vote would be empty and argmax would silently
        # answer 0 for every input.
        raise ValueError(f"k must be at least 1, got {k}")

    if protos is None:
        protos = prototypes
    if labels is None:
        labels = prototype_labels
    if protos is None or labels is None:
        raise ValueError(
            "no prototypes available: pass protos/labels or run main() first"
        )

    X_batch = _as_float(X_batch)
    protos = _as_float(protos)
    labels = np.asarray(labels)
    if len(protos) == 0 or len(protos) != len(labels):
        raise ValueError(
            "protos and labels must be non-empty and of equal length"
        )

    preds = []
    for test_img in X_batch:
        # Euclidean distance to every prototype at once:
        # distance = sqrt(sum((test_img - prototype)^2))
        diff = test_img - protos
        distances = np.sqrt(np.sum(diff ** 2, axis=1))

        # Indices of the k smallest distances.
        nearest_k_indices = np.argsort(distances)[:k]
        nearest_k_labels = labels[nearest_k_indices]

        # Majority vote; argmax returns the first maximum, so a tied vote
        # goes to the smallest label.
        prediction = np.bincount(nearest_k_labels, minlength=10).argmax()
        preds.append(prediction)

    return np.array(preds)


def main():
    """Load MNIST, evaluate the kNN classifier and plot sample predictions."""
    global prototypes, prototype_labels

    # Imported here rather than at module level so that the classifier can
    # be imported without the Keras and matplotlib stacks installed.
    from keras.datasets import mnist
    import matplotlib.pyplot as plt

    print("Loading MNIST...")
    # For reference: X_train_raw[0] is a 28x28 image of the digit 5
    # (y_train[0] == 5) and X_test_raw[0] is a 7 (y_test[0] == 7).
    (X_train_raw, y_train), (X_test_raw, y_test) = mnist.load_data()

    # Flatten images from (samples, 28, 28) to (samples, 784).
    X_train = X_train_raw.reshape(X_train_raw.shape[0], -1).astype(np.float32)
    X_test = X_test_raw.reshape(X_test_raw.shape[0], -1).astype(np.float32)
    print("Train:", X_train.shape, "Test:", X_test.shape)

    # The first N_PROTOTYPES training images serve as the prototype set.
    prototypes = X_train[:N_PROTOTYPES]
    prototype_labels = y_train[:N_PROTOTYPES]
    print("Using", len(prototypes), "prototype vectors.")

    # Evaluate on the first N_TEST test samples.
    print(f"Evaluating on {N_TEST} test samples...")
    X_eval = X_test[:N_TEST]
    y_eval = y_test[:N_TEST]

    preds = knn_predict_batch(X_eval, k=5)
    accuracy = np.mean(preds == y_eval)

    print("Predictions:", preds[:20])
    print("True labels:", y_eval[:20])
    print("Accuracy:", accuracy)

    # Visualize the first 20 predictions on a 4x5 grid.
    fig, axes = plt.subplots(4, 5, figsize=(12, 10))
    for i, ax in enumerate(axes.flatten()):
        # Reshape the flattened image back to 28x28 for display.
        ax.imshow(X_eval[i].reshape(28, 28), cmap='gray')
        ax.set_title(f"Pred: {preds[i]}, True: {y_eval[i]}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()