renamed vector quantization variables for readability

master
Ruben-FreddyLoafers 2025-12-08 14:57:49 +01:00
parent 0798236e26
commit c88d8d003e
1 changed files with 17 additions and 17 deletions

View File

@ -4,26 +4,26 @@ import matplotlib.pyplot as plt
# Load MNIST # Load MNIST
print("Loading MNIST...") # debugging print("Loading MNIST...") # debugging
(X_train_raw, y_train), (X_test_raw, y_test) = mnist.load_data() (prototype_data_raw, prototype_labels_set), (test_data_raw, test_labels_set) = mnist.load_data()
# print("X_train_raw[0]:") # print("prototype_data_raw[0]:")
# print(X_train_raw[0]) # output: 28x28 vector spelling 5 # print(prototype_data_raw[0]) # output: 28x28 vector spelling 5
# print("y_train[0]:") # print("prototype_labels_set[0]:")
# print(y_train[0]) # output: 5 # print(prototype_labels_set[0]) # output: 5
# print("X_test_raw[0]:") # print("test_data_raw[0]:")
# print(X_test_raw[0]) # output: 28x28 vector spelling 7 # print(test_data_raw[0]) # output: 28x28 vector spelling 7
# print("y_test[0]:") # print("test_labels_set[0]:")
# print(y_test[0]) # output: 7 # print(test_labels_set[0]) # output: 7
# Flatten images from (samples, 28, 28) to (samples, 784) -> one dimensional # Flatten images from (samples, 28, 28) to (samples, 784) -> one dimensional
X_train = X_train_raw.reshape(X_train_raw.shape[0], -1).astype(np.float32) prototype_data = prototype_data_raw.reshape(prototype_data_raw.shape[0], -1).astype(np.float32)
X_test = X_test_raw.reshape(X_test_raw.shape[0], -1).astype(np.float32) test_data = test_data_raw.reshape(test_data_raw.shape[0], -1).astype(np.float32)
print("Train:", X_train.shape, "Test:", X_test.shape) print("Train:", prototype_data.shape, "Test:", test_data.shape)
# Select first 1000 prototype vectors # Select first 1000 prototype vectors
prototypes = X_train[:1000] prototypes = prototype_data[:1000]
prototype_labels = y_train[:1000] prototype_labels = prototype_labels_set[:1000]
print("Using", len(prototypes), "prototype vectors.") # debugging print("Using", len(prototypes), "prototype vectors.") # debugging
@ -47,7 +47,7 @@ def knn_predict_batch(X_batch, k=3):
distance = np.sqrt(np.sum(diff ** 2)) distance = np.sqrt(np.sum(diff ** 2))
distances.append(distance) distances.append(distance)
# Find indices of k nearest neighbors (smallest distances) # Find indices of k nearest neighbors
distances = np.array(distances) distances = np.array(distances)
nearest_k_indices = np.argsort(distances)[:k] # returns indices of array with sorted distances nearest_k_indices = np.argsort(distances)[:k] # returns indices of array with sorted distances
@ -65,8 +65,8 @@ def knn_predict_batch(X_batch, k=3):
N_TEST = 1000 N_TEST = 1000
print(f"Evaluating on {N_TEST} test samples...") # debugging print(f"Evaluating on {N_TEST} test samples...") # debugging
X_eval = X_test[:N_TEST] X_eval = test_data[:N_TEST]
y_eval = y_test[:N_TEST] y_eval = test_labels_set[:N_TEST]
preds = knn_predict_batch(X_eval, k=5) preds = knn_predict_batch(X_eval, k=5)