92 lines
2.8 KiB
Python
92 lines
2.8 KiB
Python
import numpy as np
|
|
from keras.datasets import mnist
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Fetch MNIST: one (images, labels) pair used as the prototype pool and one
# pair used for testing.
print("Loading MNIST...")  # debugging
(
    (prototype_data_raw, prototype_labels_set),
    (test_data_raw, test_labels_set),
) = mnist.load_data()

# Sanity-check notes from inspecting the raw data:
#   prototype_data_raw[0] is a 28x28 image spelling 5, prototype_labels_set[0] is 5
#   test_data_raw[0] is a 28x28 image spelling 7, test_labels_set[0] is 7

# Collapse every (28, 28) image into a single 784-long float32 row so that
# each sample is a one-dimensional vector.
prototype_data = prototype_data_raw.reshape(len(prototype_data_raw), -1).astype(np.float32)
test_data = test_data_raw.reshape(len(test_data_raw), -1).astype(np.float32)

print("Train:", prototype_data.shape, "Test:", test_data.shape)

# Keep only the first 1000 samples (and their labels) as kNN prototypes.
prototypes, prototype_labels = prototype_data[:1000], prototype_labels_set[:1000]

print("Using", len(prototypes), "prototype vectors.")  # debugging
|
|
|
|
# kNN function: explicit loop over test images, vectorized distances per image
def _as_float_array(values):
    """Return *values* as an ndarray, casting non-float dtypes to float32.

    Unsigned integer inputs (such as raw 8-bit image data) wrap around on
    subtraction instead of going negative, which would corrupt the distances.
    """
    arr = np.asarray(values)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float32)
    return arr


def knn_predict_batch(X_batch, k=3, protos=None, proto_labels=None):
    """
    Predict labels for a batch of test vectors using k-nearest-neighbours.

    Each test vector is compared (Euclidean distance) with every prototype;
    the labels of the k closest prototypes are tallied and the most frequent
    label wins. When several labels are equally frequent, the smallest label
    value is returned (behaviour of ``argmax`` on the vote counts).

    X_batch: array-like of shape (batch_size, n_features), e.g. (batch_size, 784)
    k: number of neighbours that vote; must be >= 1. If k exceeds the number
        of prototypes, all prototypes vote.
    protos: optional array-like of shape (n_prototypes, n_features). Defaults
        to the module-level ``prototypes``.
    proto_labels: optional array-like of non-negative integer labels, one per
        prototype. Defaults to the module-level ``prototype_labels``.

    returns: integer ndarray of shape (batch_size,)
    raises: ValueError if k < 1, if there are no prototypes, or if the number
        of prototypes and labels differ.
    """
    if k < 1:
        # k == 0 would silently predict class 0 for everything and a negative
        # k would slice "all but the last |k|" neighbours.
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # Fall back to the script-level prototype set when none is supplied.
    if protos is None:
        protos = prototypes
    if proto_labels is None:
        proto_labels = prototype_labels

    protos = _as_float_array(protos)
    proto_labels = np.asarray(proto_labels)
    if len(protos) == 0:
        raise ValueError("at least one prototype is required")
    if len(protos) != len(proto_labels):
        raise ValueError(
            f"got {len(protos)} prototypes but {len(proto_labels)} labels"
        )

    X_batch = _as_float_array(X_batch)

    preds = []

    # For each test image
    for test_img in X_batch:
        # Euclidean distance to every prototype at once:
        # distance = sqrt(sum((test_img - prototype)^2))
        diffs = protos - test_img
        distances = np.sqrt(np.sum(diffs ** 2, axis=1))

        # Indices of the k nearest prototypes; a stable sort makes equal
        # distances resolve to the lower prototype index deterministically.
        nearest_k_indices = np.argsort(distances, kind="stable")[:k]

        # Get labels of the k nearest neighbors
        nearest_k_labels = proto_labels[nearest_k_indices]

        # Majority vote; minlength=10 keeps the tally at least ten classes wide.
        prediction = np.bincount(nearest_k_labels, minlength=10).argmax()
        preds.append(prediction)

    # Explicit dtype so an empty batch still yields an integer array.
    return np.array(preds, dtype=np.intp)
|
|
|
|
|
|
# Score the classifier on the leading N_TEST test samples.
N_TEST = 1000
print(f"Evaluating on {N_TEST} test samples...")  # debugging

X_eval, y_eval = test_data[:N_TEST], test_labels_set[:N_TEST]

preds = knn_predict_batch(X_eval, k=5)

# Fraction of predictions that match the true labels.
accuracy = (preds == y_eval).mean()

print("Predictions:", preds[:20])
print("True labels:", y_eval[:20])
print("Accuracy:", accuracy)

# Plot the first 20 evaluated images on a 4x5 grid, one image per axis,
# each titled with the predicted and the true label.
fig, axes = plt.subplots(4, 5, figsize=(12, 10))
axes = axes.flatten()

for i, ax in enumerate(axes):
    # Restore the flattened 784-vector to its 28x28 image form for display.
    img = X_eval[i].reshape(28, 28)
    ax.imshow(img, cmap='gray')
    ax.set_title(f"Pred: {preds[i]}, True: {y_eval[i]}")
    ax.axis('off')

plt.tight_layout()
plt.show()
|