MLE/05_mnist_vectorquant/vector_quantization.py

94 lines
3.2 KiB
Python

import numpy as np
from keras.datasets import mnist
import matplotlib.pyplot as plt
# Load MNIST. The training split serves as the pool of labelled prototype
# vectors; the test split holds the images we classify.
print("Loading MNIST...")
(prototype_data_raw, prototype_labels_set), (test_data_raw, test_labels_set) = mnist.load_data()

# Flatten images from (samples, 28, 28) to (samples, 784) -> one dimensional.
# The cast to float32 matters: subtracting raw uint8 pixel arrays (as the kNN
# distance computation does) would wrap around modulo 256 instead of going
# negative, corrupting the distances.
prototype_data = prototype_data_raw.reshape(prototype_data_raw.shape[0], -1).astype(np.float32)
test_data = test_data_raw.reshape(test_data_raw.shape[0], -1).astype(np.float32)

# Use only the first N_PROTOTYPES training images as prototypes; kNN compares
# every test image against every prototype, so this bounds the search cost.
N_PROTOTYPES = 1000
prototypes = prototype_data[:N_PROTOTYPES]
prototype_labels = prototype_labels_set[:N_PROTOTYPES]
# kNN classifier: one explicit loop over test images for readability; the
# distance to every prototype is computed in a single vectorized step.
def knn_predict_batch(X_batch, k=3, protos=None, proto_labels=None):
    """
    Predicts labels for a batch of test vectors using kNN.

    X_batch: shape (batch_size, n_features), e.g. (batch_size, 784) for MNIST
    k: number of nearest neighbours that vote; must be >= 1
    protos: prototype vectors, shape (n_prototypes, n_features);
        defaults to the module-level ``prototypes``
    proto_labels: non-negative integer labels of ``protos``, shape
        (n_prototypes,); defaults to the module-level ``prototype_labels``
    returns: integer array of shape (batch_size,)
    raises: ValueError if k < 1 or there are no prototypes

    Ties in the majority vote are broken in favour of the nearest tied
    neighbour, rather than the smallest label value.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if protos is None:
        protos = prototypes
    if proto_labels is None:
        proto_labels = prototype_labels
    # Work in float64 so integer (e.g. raw uint8) inputs cannot wrap around
    # when subtracted.
    protos = np.asarray(protos, dtype=np.float64)
    proto_labels = np.asarray(proto_labels)
    if len(protos) == 0:
        raise ValueError("at least one prototype is required")

    preds = []
    # For each test image
    for test_img in X_batch:
        # Euclidean distance from this test image to every prototype at once
        # (test_img is broadcast over the rows of protos).
        distances = np.sqrt(np.sum((protos - test_img) ** 2, axis=1))
        # Indices of the k closest prototypes, nearest first. A stable sort keeps
        # equal distances in prototype order, so results are deterministic.
        nearest_k_indices = np.argsort(distances, kind="stable")[:k]
        nearest_k_labels = proto_labels[nearest_k_indices]
        # Majority vote: votes[d] counts how many of the k neighbours have label d.
        votes = np.bincount(nearest_k_labels, minlength=10)
        is_winner = votes == votes.max()
        # Walk the neighbours from nearest to farthest and take the first whose
        # label has the most votes: with a unique winner that is simply the
        # winner; with a tie it is the nearest tied label.
        prediction = next(int(label) for label in nearest_k_labels if is_winner[label])
        preds.append(prediction)  # prediction for this test image
    return np.array(preds, dtype=int)  # prediction for every test image
# Score the classifier on the first N_TEST images of the test split.
N_TEST = 1000
print(f"Evaluating on {N_TEST} test samples...")
data_eval, label_eval = test_data[:N_TEST], test_labels_set[:N_TEST]
preds = knn_predict_batch(data_eval, k=3)

# Accuracy = share of evaluated images whose predicted digit equals the true one.
accuracy = (preds == label_eval).mean()

# Report a sample of predictions next to the ground truth, then the score.
print("Predictions:", preds[:20])
print("True labels:", label_eval[:20])
print("Accuracy:", accuracy)
# Visualize the first 50 evaluated test images on a 10x5 grid, each titled
# with its predicted and true label.
fig, axes = plt.subplots(10, 5, figsize=(12, 10))
axes = axes.flatten()
for i, ax in enumerate(axes):
    # Panels beyond the number of evaluated samples (N_TEST < 50) stay blank
    # instead of raising IndexError.
    if i < len(data_eval):
        # Reshape the flattened 784-vector back to a 28x28 image.
        ax.imshow(data_eval[i].reshape(28, 28), cmap='gray')
        ax.set_title(f"Pred: {preds[i]}, True: {label_eval[i]}")
    ax.axis('off')
plt.tight_layout()
plt.show()