"""k-nearest-neighbour digit classification over MNIST-style IDX files.

Run as a script, this loads the image and label files named by the module
constants below, classifies the first 1000 samples against the samples from
index 7000 onwards, and prints the resulting error rate.
"""
from collections import Counter

import numpy as np

trainingDataFile = 'data/t10k-images.idx3-ubyte'
trainingLabelFile = 'data/t10k-labels.idx1-ubyte'


def _widen(values):
    """Return *values* as an array that is safe for signed arithmetic.

    Integer and boolean inputs are promoted to int64 so that differences of
    small unsigned pixel values cannot wrap around; everything else becomes
    float64. No copy is made when the input already has the target dtype.
    """
    arr = np.asarray(values)
    target = np.int64 if arr.dtype.kind in "biu" else np.float64
    return arr.astype(target, copy=False)


def getData(path=trainingDataFile):
    """Load the images stored in an IDX image file.

    The header is read as four big-endian 32-bit fields: a magic number
    (skipped, not validated), the image count, the height and the width.
    One unsigned byte per pixel follows.

    Args:
        path: File to read; defaults to ``trainingDataFile``.

    Returns:
        A list with one flat int64 array of ``height * width`` pixel values
        per image.

    Raises:
        ValueError: If the file ends before the header or the announced
            number of pixels has been read.
    """
    with open(path, mode='rb') as file:
        header = file.read(16)
        if len(header) != 16:
            raise ValueError("truncated image file header: " + str(path))
        numOfImages = int.from_bytes(header[4:8], 'big')
        height = int.from_bytes(header[8:12], 'big')
        width = int.from_bytes(header[12:16], 'big')
        print(height, width)
        expected = numOfImages * height * width
        # One bulk read instead of one read() call per pixel.
        raw = file.read(expected)
    if len(raw) != expected:
        raise ValueError(
            "image file holds " + str(len(raw)) + " pixel bytes, expected "
            + str(expected))
    # Widen right away: uint8 arithmetic would overflow in the distance.
    pixels = np.frombuffer(raw, dtype=np.uint8).astype(np.int64)
    return list(pixels.reshape(numOfImages, height * width))


def getLabels(path=trainingLabelFile):
    """Load the labels stored in an IDX label file.

    The header is a magic number (skipped, not validated) followed by the
    big-endian 32-bit label count; one unsigned byte per label follows.

    Args:
        path: File to read; defaults to ``trainingLabelFile``.

    Returns:
        A list of ints, one per label.

    Raises:
        ValueError: If the file ends before the header or the announced
            number of labels has been read.
    """
    with open(path, mode='rb') as file:
        header = file.read(8)
        if len(header) != 8:
            raise ValueError("truncated label file header: " + str(path))
        numOfImages = int.from_bytes(header[4:8], 'big')
        raw = file.read(numOfImages)
    if len(raw) != numOfImages:
        raise ValueError(
            "label file holds " + str(len(raw)) + " labels, expected "
            + str(numOfImages))
    return list(raw)


def euclidean_distance2(img1, img2):
    """Return the squared Euclidean distance between two images.

    No square root is taken: the result is only used for ranking, and the
    square root does not change the order.

    Args:
        img1: Sequence or array of pixel values.
        img2: Sequence or array with the same shape as ``img1``.

    Returns:
        The sum of the squared element-wise differences.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    first = _widen(img1)
    second = _widen(img2)
    if first.shape != second.shape:
        raise ValueError("images must have the same shape")
    diff = (first - second).ravel()
    return np.dot(diff, diff)


def nearestNeighbor(image, reference, referenceTags, k=11):
    """Classify *image* by majority vote among its k nearest references.

    Args:
        image: Flat sequence or array of pixel values to classify.
        reference: Sequence of reference images (or a 2-D array with one
            image per row), each with as many pixels as ``image``.
        referenceTags: Label of each reference image, in the same order.
        k: Number of nearest references that take part in the vote. When
            fewer than ``k`` references exist, all of them vote.

    Returns:
        The label with the most votes. When several labels share the top
        count, the one whose closest representative is nearest wins.

    Raises:
        ValueError: If ``k`` is smaller than 1, ``reference`` is empty,
            there are fewer tags than references, or the pixel counts of
            ``image`` and the references differ.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    refs = _widen(reference)
    if refs.ndim == 0 or refs.shape[0] == 0:
        raise ValueError("reference must contain at least one image")
    if len(referenceTags) < refs.shape[0]:
        raise ValueError("every reference image needs a tag")
    refs = refs.reshape(refs.shape[0], -1)
    query = _widen(image).ravel()
    if refs.shape[1] != query.shape[0]:
        raise ValueError("image and reference images differ in size")

    diff = refs - query
    distances = (diff * diff).sum(axis=1)
    # A stable sort keeps equally distant references in their input order.
    nearest = np.argsort(distances, kind="stable")[:k].tolist()
    votes = Counter(referenceTags[index] for index in nearest)
    # Counter remembers insertion order and max() returns the first maximum,
    # so ties go to the label that was met first, i.e. the nearest one.
    return max(votes, key=votes.get)


def main():
    """Evaluate the classifier and print per-error and summary statistics."""
    labels = getLabels()
    data = getData()
    baseVectorsCount = 1000
    realLabels = labels[:baseVectorsCount]
    realData = data[:baseVectorsCount]
    trainingLabels = labels[7000:]
    # Stack the references once so each query is a single array operation.
    trainingData = np.asarray(data[7000:])

    wrong = 0
    evaluated = 0
    for i, (sample, expected) in enumerate(zip(realData, realLabels)):
        predicted = nearestNeighbor(sample, trainingData, trainingLabels)
        evaluated = i + 1
        if predicted != expected:
            print("Identified as " + str(predicted) + ", in reality: "
                  + str(expected))
            wrong += 1
        if i != 0 and i % 100 == 0:
            # Samples 0..i have been classified, hence the i + 1 denominator.
            print("Step: " + str(i) + " Error rate: "
                  + str((wrong / evaluated) * 100) + "%")
    if evaluated == 0:
        raise ValueError("no samples available for evaluation")
    print("Final error rate: " + str((wrong / evaluated) * 100) + "%")


if __name__ == "__main__":
    main()