diff --git a/uebungen/aufgabe3/uebung3.py b/uebungen/aufgabe3/uebung3.py index 3555e45..8535ec8 100644 --- a/uebungen/aufgabe3/uebung3.py +++ b/uebungen/aufgabe3/uebung3.py @@ -78,9 +78,9 @@ class RBM: return h0, v1 -rbm = RBM(28 ** 2, 100, 0.2, epochs=1) +rbm = RBM(28 ** 2, 100, 0.2, epochs=2) -for i in range(100): +for i in range(100, 600): # normalize mnist data and train number = mnist[i] / 255 rbm.train(number)