# Recovered from patch 8257062dc35d74dd46c84a7061ad5f9ff35768e4
# ("First implementation of VQ", Ruben Seitz, Sun, 7 Dec 2025 14:49:03 +0100),
# which creates 05_mnist_vectorquant/vector_quantization.py.
"""k-nearest-prototype classification of MNIST digits.

The first ``N_PROTOTYPES`` training images act as labelled prototype vectors.
Each test image is assigned the majority label among its ``k`` nearest
prototypes (Euclidean distance on the flattened 784-pixel vectors).
"""

import numpy as np

# Number of training images used as prototypes / test images evaluated.
N_PROTOTYPES = 1000
N_TEST = 1000

# Default prototype set used by knn_predict_batch() when the caller does not
# pass one explicitly.  Filled in by main().
prototypes = None
prototype_labels = None


def knn_predict_batch(X_batch, k=3, protos=None, proto_labels=None):
    """Predict labels for a batch of vectors with a vectorized k-NN vote.

    Args:
        X_batch: array of shape (batch_size, n_features).
        k: number of nearest prototypes that vote; 1 <= k <= number of
            prototypes.
        protos: array of shape (n_prototypes, n_features).  Defaults to the
            module-level ``prototypes``.
        proto_labels: non-negative integer labels, shape (n_prototypes,).
            Defaults to the module-level ``prototype_labels``.

    Returns:
        Integer array of shape (batch_size,) with the predicted labels.  Ties
        in the vote go to the smallest label.

    Raises:
        ValueError: if no prototypes are available, if the label count does
            not match the prototype count, or if ``k`` is out of range.
    """
    if protos is None:
        protos = prototypes
    if proto_labels is None:
        proto_labels = prototype_labels
    if protos is None or proto_labels is None:
        raise ValueError(
            "no prototypes available: pass protos/proto_labels or run main()"
        )

    X = np.asarray(X_batch)
    P = np.asarray(protos)
    labels = np.asarray(proto_labels, dtype=np.intp)

    # Squaring integer pixels (e.g. uint8) would silently wrap around.
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(np.float64)
    if not np.issubdtype(P.dtype, np.floating):
        P = P.astype(np.float64)

    n_protos = P.shape[0]
    if labels.shape[0] != n_protos:
        raise ValueError("proto_labels must have one entry per prototype")
    if not 1 <= k <= n_protos:
        raise ValueError(f"k must be in [1, {n_protos}], got {k}")

    # sq_dist[i, j] = ||X[i] - P[j]||^2, expanded as a^2 - 2ab + b^2 so the
    # whole matrix comes from one matrix product.
    sq_batch = np.sum(X * X, axis=1, keepdims=True)  # (N, 1)
    sq_protos = np.sum(P * P, axis=1)                # (n_protos,)
    sq_dist = sq_batch - 2 * (X @ P.T) + sq_protos   # (N, n_protos)

    # Rounding can push the expanded form slightly below zero for (near-)
    # identical vectors; clamp so those stay the *closest* candidates.  The
    # square root is skipped because it does not change the ranking.
    np.maximum(sq_dist, 0, out=sq_dist)

    # Indices of the k smallest distances per row (unordered within the k).
    nearest = np.argpartition(sq_dist, k - 1, axis=1)[:, :k]
    neighbor_labels = labels[nearest]                # (N, k)

    # Majority vote for all rows at once: shift row i's labels into its own
    # slice [i * n_classes, (i + 1) * n_classes) and count with one bincount.
    n_rows = X.shape[0]
    n_classes = int(labels.max()) + 1
    row_offsets = np.arange(n_rows)[:, None] * n_classes
    votes = np.bincount(
        (neighbor_labels + row_offsets).ravel(),
        minlength=n_rows * n_classes,
    ).reshape(n_rows, n_classes)

    return votes.argmax(axis=1)


def main():
    """Load MNIST, pick the prototypes and report accuracy on N_TEST images."""
    global prototypes, prototype_labels

    # Imported here so the classifier above is usable without keras installed.
    from keras.datasets import mnist

    # Load MNIST
    print("Loading MNIST...")  # debugging
    (X_train_raw, y_train), (X_test_raw, y_test) = mnist.load_data()

    # Flatten images from (samples, 28, 28) to (samples, 784)
    X_train = X_train_raw.reshape(X_train_raw.shape[0], -1).astype(np.float32)
    X_test = X_test_raw.reshape(X_test_raw.shape[0], -1).astype(np.float32)

    print("Train:", X_train.shape, "Test:", X_test.shape)

    # Select the first N_PROTOTYPES training vectors as prototypes
    prototypes = X_train[:N_PROTOTYPES]
    prototype_labels = y_train[:N_PROTOTYPES]

    print("Using", len(prototypes), "prototype vectors.")  # debugging

    # Evaluate on the first N_TEST test samples
    print(f"Evaluating on {N_TEST} test samples...")  # debugging

    X_eval = X_test[:N_TEST]
    y_eval = y_test[:N_TEST]

    preds = knn_predict_batch(X_eval, k=3)

    accuracy = np.mean(preds == y_eval)

    print("Predictions:", preds[:20])
    print("True labels:", y_eval[:20])
    print("Accuracy:", accuracy)


if __name__ == "__main__":
    main()