renamed vector quantization variables for readability
parent
0798236e26
commit
c88d8d003e
|
|
@ -4,26 +4,26 @@ import matplotlib.pyplot as plt
|
|||
|
||||
# Load MNIST
|
||||
print("Loading MNIST...") # debugging
|
||||
(X_train_raw, y_train), (X_test_raw, y_test) = mnist.load_data()
|
||||
(prototype_data_raw, prototype_labels_set), (test_data_raw, test_labels_set) = mnist.load_data()
|
||||
|
||||
# print("X_train_raw[0]:")
|
||||
# print(X_train_raw[0]) # output: 28x28 vector spelling 5
|
||||
# print("y_train[0]:")
|
||||
# print(y_train[0]) # output: 5
|
||||
# print("X_test_raw[0]:")
|
||||
# print(X_test_raw[0]) # output: 28x28 vector spelling 7
|
||||
# print("y_test[0]:")
|
||||
# print(y_test[0]) # output: 7
|
||||
# print("prototype_data_raw[0]:")
|
||||
# print(prototype_data_raw[0]) # output: 28x28 array depicting the digit 5
|
||||
# print("prototype_labels_set[0]:")
|
||||
# print(prototype_labels_set[0]) # output: 5
|
||||
# print("test_data_raw[0]:")
|
||||
# print(test_data_raw[0]) # output: 28x28 array depicting the digit 7
|
||||
# print("test_labels_set[0]:")
|
||||
# print(test_labels_set[0]) # output: 7
|
||||
|
||||
# Flatten images from (samples, 28, 28) to (samples, 784), so each image becomes a one-dimensional vector
|
||||
X_train = X_train_raw.reshape(X_train_raw.shape[0], -1).astype(np.float32)
|
||||
X_test = X_test_raw.reshape(X_test_raw.shape[0], -1).astype(np.float32)
|
||||
prototype_data = prototype_data_raw.reshape(prototype_data_raw.shape[0], -1).astype(np.float32)
|
||||
test_data = test_data_raw.reshape(test_data_raw.shape[0], -1).astype(np.float32)
|
||||
|
||||
print("Train:", X_train.shape, "Test:", X_test.shape)
|
||||
print("Train:", prototype_data.shape, "Test:", test_data.shape)
|
||||
|
||||
# Select the first 1000 flattened images (and their labels) as prototype vectors
|
||||
prototypes = X_train[:1000]
|
||||
prototype_labels = y_train[:1000]
|
||||
prototypes = prototype_data[:1000]
|
||||
prototype_labels = prototype_labels_set[:1000]
|
||||
|
||||
print("Using", len(prototypes), "prototype vectors.") # debugging
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ def knn_predict_batch(X_batch, k=3):
|
|||
distance = np.sqrt(np.sum(diff ** 2))
|
||||
distances.append(distance)
|
||||
|
||||
# Find indices of k nearest neighbors (smallest distances)
|
||||
# Find indices of k nearest neighbors
|
||||
distances = np.array(distances)
|
||||
nearest_k_indices = np.argsort(distances)[:k] # returns indices of array with sorted distances
|
||||
|
||||
|
|
@ -65,8 +65,8 @@ def knn_predict_batch(X_batch, k=3):
|
|||
N_TEST = 1000
|
||||
print(f"Evaluating on {N_TEST} test samples...") # debugging
|
||||
|
||||
X_eval = X_test[:N_TEST]
|
||||
y_eval = y_test[:N_TEST]
|
||||
X_eval = test_data[:N_TEST]
|
||||
y_eval = test_labels_set[:N_TEST]
|
||||
|
||||
preds = knn_predict_batch(X_eval, k=5)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue