gnn/uebungen/uebung2.py

95 lines
2.2 KiB
Python

import numpy as np
# Sigmoide Aktivierungsfunktion und ihre Ableitung
def sigmoid(x):
return 1 / (1 + np.exp(-x)) # Sigmoidfunktion
def deriv_sigmoid(x):
return x * (1 - x) # Ableitung der Sigmoiden
# Das XOR-Problem, input [bias, x, y] und Target-Daten
inp = np.array([[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1]])
target = np.array([[0], [1], [1], [0]])
# Die Architektur des neuronalen Netzes
inp_size = 3 # Eingabeneuronen
hid_size = 4 # Hidden-Neuronen
out_size = 1 # Ausgabeneuron
delta_L1 = np.ones((inp_size, hid_size)) * 0.125
delta_L2 = np.ones((hid_size, out_size)) * 0.125
delta_min = 0
delta_max = 50
# Gewichte zufällig initialisieren (Mittelwert = 0)
w0 = np.random.random((inp_size, hid_size)) - 0.5
w1 = np.random.random((hid_size, out_size)) - 0.5
def multiply_learnrate(old, new):
if old * new > 0:
return 1.2
elif old * new < 0:
return 0.5
return 1
v_multiply_learnrate = np.vectorize(multiply_learnrate)
L2_grad_old = np.zeros((4, 1))
L1_grad_old = np.zeros((3, 4))
# Netzwerk trainieren
for i in range(100):
# Vorwärtsaktivierung
L0 = inp
L1 = sigmoid(np.matmul(L0, w0))
L1[0] = 1 # Bias-Neuron in der Hiddenschicht
L2 = sigmoid(np.matmul(L1, w1))
# Fehler berechnen
L2_error = L2 - target
# Backpropagation
L2_delta = L2_error * deriv_sigmoid(L2) # Gradient eL2
L1_error = np.matmul(L2_delta, w1.T)
L1_delta = L1_error * deriv_sigmoid(L1)
# Gradienten
L2_grad_new = np.matmul(L1.T, L2_delta)
L1_grad_new = np.matmul(L0.T, L1_delta)
# Gewichte aktualisieren
learnrate = 0.1
delta_L1 = np.clip(
delta_L1 * v_multiply_learnrate(L1_grad_old, L1_grad_new), 0, 50)
delta_L2 = np.clip(
delta_L2 * v_multiply_learnrate(L2_grad_old, L2_grad_new), 0, 50)
w1 -= learnrate * np.sign(L2_grad_new) * delta_L2
w0 -= learnrate * np.sign(L1_grad_new) * delta_L1
# Gradienten aktualisieren
L1_grad_old = np.copy(L1_grad_new)
L2_grad_old = np.copy(L2_grad_new)
# Netzwerk testen
L0 = inp
L1 = sigmoid(np.matmul(inp, w0))
L1[0] = 1 # Bias-Neuron in der Hiddenschicht
L2 = sigmoid(np.matmul(L1, w1))
print(L2)