main
romanamo 2024-05-09 00:41:02 +02:00
parent 823ef138a8
commit ef7be8ac54
1 changed files with 5 additions and 14 deletions

View File

@ -29,8 +29,8 @@ class RBM:
def reset(self) -> None:
self.weights = np.random.randn(self.visible_size, self.hidden_size)
self.visible_bias = np.ones(self.visible_size) * 0.1
self.hidden_bias = np.ones(self.hidden_size) * 0.1
self.visible_bias = np.zeros(self.visible_size) * 0.1
self.hidden_bias = np.zeros(self.hidden_size) * 0.1
def activate(self, v0):
return sigmoid(np.matmul(v0.T, self.weights) + self.hidden_bias)
@ -46,14 +46,6 @@ class RBM:
return self.weights
def gibbs(self, v0):
for _ in range(self.k):
hidden_probs = self.activate(v0)
h0 = np.random.rand(len(v0), self.hidden_size) < hidden_probs
visible_probs = self.sigmoid(np.dot(h0, self.weights.T) + self.visible_bias)
v0 = np.random.rand(len(v0), self.visible_size) < visible_probs
return v0, hidden_probs
def train(self, v0):
for _ in range(self.epochs):
h0 = self.activate(v0) # Aktivieren versteckter Schicht
@ -85,15 +77,14 @@ class RBM:
def validate(idx):
test = mnist[idx].flatten()
#rbm.reset()
test = mnist[idx].flatten()/255
rbm.train(test)
(hid, out) = rbm.run(test)
return (hid.reshape((28, 28)), out.reshape((28,28)))
return (hid.reshape((5, 5)), out.reshape((28,28)))
rbm = RBM(28**2, 28**2, 0.1)
rbm = RBM(28**2, 25, 0.1)
rows, columns = (4,4)
fig = plt.figure(figsize=(10, 7))