# MLE_Assignments/Aufgabe_5/Aufgabe_5.py
#
# 96 lines
# 2.5 KiB
# Python

import numpy as np
# Paths of the IDX files that getData() and getLabels() open in binary mode.
# NOTE(review): the names say "training", but both paths point at the "t10k"
# files - presumably the 10,000-sample MNIST test split; confirm this is intended.
trainingDataFile = 'data/t10k-images.idx3-ubyte'
trainingLabelFile = 'data/t10k-labels.idx1-ubyte'
# def euclidean_distance(img1, img2):
# return np.sqrt(np.sum((img1 - img2) ** 2))
def getData(path=None):
    """Load every image of an IDX3 image file as a flat pixel vector.

    The header is read as four big-endian 32-bit integers: a magic number
    (skipped, not validated), the image count, the row count and the column
    count. Each image follows as ``rows * columns`` unsigned bytes.

    Args:
        path: File to read. Defaults to the module-level ``trainingDataFile``.

    Returns:
        A list with one 1-D ``numpy`` array per image. The arrays use a
        signed 64-bit dtype so that subtracting two images cannot wrap
        around the way unsigned 8-bit pixel values would.

    Raises:
        ValueError: If the file ends before all announced pixels were read.
    """
    if path is None:
        path = trainingDataFile
    with open(path, mode='rb') as file:
        file.read(4)  # magic number: skipped, this loader does not validate it
        numOfImages = int.from_bytes(file.read(4), 'big')
        height = int.from_bytes(file.read(4), 'big')
        width = int.from_bytes(file.read(4), 'big')
        print(height, width)  # echo the image dimensions as a quick sanity check
        pixelsPerImage = height * width
        data = []
        for index in range(numOfImages):
            # One read per image instead of one read per pixel.
            raw = file.read(pixelsPerImage)
            if len(raw) != pixelsPerImage:
                raise ValueError(
                    "image file is truncated: image " + str(index)
                    + " has " + str(len(raw)) + " of "
                    + str(pixelsPerImage) + " pixel bytes"
                )
            data.append(np.frombuffer(raw, dtype=np.uint8).astype(np.int64))
    return data
def getLabels(path=None):
    """Load every label of an IDX1 label file.

    The header is read as two big-endian 32-bit integers: a magic number
    (skipped, not validated) and the label count. Each label follows as a
    single unsigned byte.

    Args:
        path: File to read. Defaults to the module-level ``trainingLabelFile``.

    Returns:
        A list of ``int`` labels in file order.

    Raises:
        ValueError: If the file ends before all announced labels were read.
    """
    if path is None:
        path = trainingLabelFile
    with open(path, mode='rb') as file:
        file.read(4)  # magic number: skipped, this loader does not validate it
        numOfLabels = int.from_bytes(file.read(4), 'big')
        raw = file.read(numOfLabels)  # all labels in a single read
    if len(raw) != numOfLabels:
        raise ValueError(
            "label file is truncated: expected " + str(numOfLabels)
            + " labels, found " + str(len(raw))
        )
    # Iterating a bytes object yields its values as ints.
    return list(raw)
# NOTE: np.linalg.norm(img1 - img2) would give the (non-squared) Euclidean distance.
def euclidean_distance2(img1, img2):
    """Return the squared Euclidean distance between two equally shaped vectors.

    The square root is deliberately not taken: the result is the sum of the
    squared element-wise differences, which ranks neighbours in the same
    order as the true Euclidean distance.

    Args:
        img1: Sequence or array of numbers (e.g. flattened pixel values).
        img2: Sequence or array with the same shape as ``img1``.

    Returns:
        The sum of squared differences as a NumPy scalar.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    first = np.asarray(img1)
    second = np.asarray(img2)
    if first.shape != second.shape:
        raise ValueError(
            "shape mismatch: " + str(first.shape) + " vs " + str(second.shape)
        )
    # Promote to at least int64 before subtracting so unsigned pixel data
    # (e.g. uint8) cannot wrap around; float inputs stay floating point.
    workDtype = np.result_type(first.dtype, second.dtype, np.int64)
    diff = first.astype(workDtype) - second.astype(workDtype)
    return np.sum(diff * diff)
def nearestNeighbor(image, reference, referenceTags, k=11, metric=None):
    """Classify ``image`` by majority vote among its ``k`` nearest references.

    Args:
        image: The sample to classify.
        reference: Sequence of samples with known labels.
        referenceTags: Labels of ``reference``, in the same order.
        k: Number of nearest references that take part in the vote. If fewer
            than ``k`` references exist, all of them vote.
        metric: Callable ``metric(a, b)`` returning a comparable distance.
            Defaults to the module-level ``euclidean_distance2``.

    Returns:
        The label that occurs most often among the ``k`` nearest references.
        On a tie, the tied label whose closest representative is nearest to
        ``image`` wins.

    Raises:
        ValueError: If ``k`` is smaller than 1, ``reference`` is empty, or
            ``reference`` and ``referenceTags`` differ in length.
    """
    if metric is None:
        metric = euclidean_distance2
    if k < 1:
        raise ValueError("k must be at least 1")
    # len() instead of truthiness: NumPy arrays have no unambiguous truth value.
    if len(reference) == 0:
        raise ValueError("reference must not be empty")
    if len(reference) != len(referenceTags):
        raise ValueError("reference and referenceTags must have the same length")

    scored = [
        (metric(image, candidate), tag)
        for candidate, tag in zip(reference, referenceTags)
    ]
    # sorted() is stable, so equally distant references keep their input order.
    nearest = sorted(scored, key=lambda pair: pair[0])[:k]

    votes = {}
    for _, tag in nearest:
        votes[tag] = votes.get(tag, 0) + 1
    # Tags were inserted in order of increasing distance and max() returns the
    # first maximum it sees, so a tie goes to the label seen closest to image.
    return max(votes, key=votes.get)
# Evaluation run: classify the first `baseVectorsCount` samples against the
# samples from index 7000 onwards and report how many were mislabelled.
labels = getLabels()
data = getData()
baseVectorsCount = 1000
# Samples to classify, together with their known labels.
realLabels = labels[:baseVectorsCount]
realData = data[:baseVectorsCount]
# Labelled reference set the classifier searches for nearest neighbours.
trainingLabels = labels[7000:]
trainingData = data[7000:]
wrong = 0
for i in range(baseVectorsCount):
    predictedLabel = nearestNeighbor(realData[i], trainingData, trainingLabels)
    if predictedLabel != realLabels[i]:
        print("Identified as " + str(predictedLabel) + ", in reality: " + str(realLabels[i]))
        wrong += 1
    # i is zero-based, so i + 1 samples have been classified at this point;
    # the running rate must be divided by that count, not by i.
    processed = i + 1
    if processed % 100 == 0:
        print("Step: " + str(processed) + " Error rate: " + str((wrong/processed)*100) + "%")
print("Final error rate: " + str((wrong/baseVectorsCount)*100) + "%")