From ef7be8ac54ebe0c5826255f03f390d2086d7574a Mon Sep 17 00:00:00 2001 From: romanamo Date: Thu, 9 May 2024 00:41:02 +0200 Subject: [PATCH] finally --- uebungen/uebung3.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/uebungen/uebung3.py b/uebungen/uebung3.py index a0b7b8a..5aa1627 100644 --- a/uebungen/uebung3.py +++ b/uebungen/uebung3.py @@ -29,8 +29,8 @@ class RBM: def reset(self) -> None: self.weights = np.random.randn(self.visible_size, self.hidden_size) - self.visible_bias = np.ones(self.visible_size) * 0.1 - self.hidden_bias = np.ones(self.hidden_size) * 0.1 + self.visible_bias = np.zeros(self.visible_size) * 0.1 + self.hidden_bias = np.zeros(self.hidden_size) * 0.1 def activate(self, v0): return sigmoid(np.matmul(v0.T, self.weights) + self.hidden_bias) @@ -45,14 +45,6 @@ class RBM: self.weights += self.learnrate * (postive_gradient - negative_gradient) return self.weights - - def gibbs(self, v0): - for _ in range(self.k): - hidden_probs = self.activate(v0) - h0 = np.random.rand(len(v0), self.hidden_size) < hidden_probs - visible_probs = self.sigmoid(np.dot(h0, self.weights.T) + self.visible_bias) - v0 = np.random.rand(len(v0), self.visible_size) < visible_probs - return v0, hidden_probs def train(self, v0): for _ in range(self.epochs): @@ -85,15 +77,14 @@ class RBM: def validate(idx): - test = mnist[idx].flatten() - #rbm.reset() + test = mnist[idx].flatten()/255 rbm.train(test) (hid, out) = rbm.run(test) - return (hid.reshape((28, 28)), out.reshape((28,28))) + return (hid.reshape((5, 5)), out.reshape((28,28))) -rbm = RBM(28**2, 28**2, 0.1) +rbm = RBM(28**2, 25, 0.1) rows, columns = (4,4) fig = plt.figure(figsize=(10, 7))