# kNN classifier for MNIST digits (original listing: 63 lines, 1.9 KiB, Python).
import numpy as np
from keras.datasets import mnist
# ---------------------------------------------------------------------------
# Data: load MNIST and build the prototype set used by the kNN classifier
# ---------------------------------------------------------------------------
print("Loading MNIST...")  # debugging
(X_train_raw, y_train), (X_test_raw, y_test) = mnist.load_data()

# Images arrive as (samples, 28, 28); flatten each one into a single float32
# row of 784 pixel values so distances can be computed with matrix products.
X_train, X_test = (
    raw.reshape(len(raw), -1).astype(np.float32)
    for raw in (X_train_raw, X_test_raw)
)
print("Train:", X_train.shape, "Test:", X_test.shape)

# The first 1000 training images (with their labels) serve as kNN prototypes.
prototypes, prototype_labels = X_train[:1000], y_train[:1000]
print("Using", len(prototypes), "prototype vectors.")  # debugging
# Fully vectorized kNN classifier
def knn_predict_batch(X_batch, k=3, *, protos=None, labels=None):
    """Predict labels for a batch of vectors by k-nearest-neighbour voting.

    Squared Euclidean distances to every prototype are computed at once via
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b (a single matrix product). The
    k closest prototypes are then selected per row, and their labels vote.

    Args:
        X_batch: array of shape (batch_size, n_features), e.g. (N, 784).
            Integer dtypes are accepted; the computation is done in float64.
        k: number of neighbours that vote; must satisfy
            1 <= k <= number of prototypes.
        protos: optional prototype matrix of shape (n_prototypes, n_features).
            Defaults to the module-level ``prototypes``.
        labels: optional non-negative integer labels of shape (n_prototypes,).
            Defaults to the module-level ``prototype_labels``.

    Returns:
        Integer array of shape (batch_size,) holding the most frequent label
        among each row's k nearest prototypes. Ties go to the smallest label,
        matching ``np.bincount(...).argmax()``.

    Raises:
        ValueError: if ``k`` is not in [1, n_prototypes].
    """
    ref = prototypes if protos is None else protos
    ref_labels = prototype_labels if labels is None else labels

    # float64 avoids overflow when squaring integer inputs (e.g. raw uint8
    # pixels) and reduces cancellation error in a^2 + b^2 - 2ab.
    X = np.asarray(X_batch, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    ref_labels = np.asarray(ref_labels)

    n_ref = ref.shape[0]
    if not 1 <= k <= n_ref:
        raise ValueError(f"k must be in [1, {n_ref}], got {k}")

    a2 = np.sum(X**2, axis=1, keepdims=True)  # (N, 1)
    b2 = np.sum(ref**2, axis=1)               # (M,)
    ab = X @ ref.T                            # (N, M)
    # Rounding can push the distance of an exact match slightly below zero.
    # sqrt would turn that into NaN and rank the match last, so clamp at 0.
    # sqrt is skipped entirely: it is monotonic, so the ranking is unchanged.
    sq_dist = np.maximum(a2 - 2.0 * ab + b2, 0.0)  # (N, M)

    # kth=k-1 places the k smallest distances in columns [0, k). Unlike
    # kth=k, it stays valid when k equals the number of prototypes.
    knn_idx = np.argpartition(sq_dist, k - 1, axis=1)[:, :k]
    knn_labels = ref_labels[knn_idx]          # (N, k)

    # Vote without a Python loop: for each row, count how many of the k
    # neighbour labels equal each class id, then pick the most common class.
    n_classes = int(ref_labels.max()) + 1
    votes = (knn_labels[:, :, None] == np.arange(n_classes)).sum(axis=1)  # (N, C)
    return votes.argmax(axis=1)
# ---------------------------------------------------------------------------
# Evaluation: classify the first N_TEST test images and report accuracy
# ---------------------------------------------------------------------------
N_TEST = 1000
print(f"Evaluating on {N_TEST} test samples...")  # debugging

X_eval, y_eval = X_test[:N_TEST], y_test[:N_TEST]
preds = knn_predict_batch(X_eval, k=3)

# Fraction of predictions that agree with the ground-truth labels.
accuracy = (preds == y_eval).mean()

print("Predictions:", preds[:20])
print("True labels:", y_eval[:20])
print("Accuracy:", accuracy)