finally
parent
823ef138a8
commit
ef7be8ac54
|
@ -29,8 +29,8 @@ class RBM:
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.weights = np.random.randn(self.visible_size, self.hidden_size)
|
self.weights = np.random.randn(self.visible_size, self.hidden_size)
|
||||||
self.visible_bias = np.ones(self.visible_size) * 0.1
|
self.visible_bias = np.zeros(self.visible_size) * 0.1
|
||||||
self.hidden_bias = np.ones(self.hidden_size) * 0.1
|
self.hidden_bias = np.zeros(self.hidden_size) * 0.1
|
||||||
|
|
||||||
def activate(self, v0):
|
def activate(self, v0):
|
||||||
return sigmoid(np.matmul(v0.T, self.weights) + self.hidden_bias)
|
return sigmoid(np.matmul(v0.T, self.weights) + self.hidden_bias)
|
||||||
|
@ -45,14 +45,6 @@ class RBM:
|
||||||
self.weights += self.learnrate * (postive_gradient - negative_gradient)
|
self.weights += self.learnrate * (postive_gradient - negative_gradient)
|
||||||
|
|
||||||
return self.weights
|
return self.weights
|
||||||
|
|
||||||
def gibbs(self, v0):
|
|
||||||
for _ in range(self.k):
|
|
||||||
hidden_probs = self.activate(v0)
|
|
||||||
h0 = np.random.rand(len(v0), self.hidden_size) < hidden_probs
|
|
||||||
visible_probs = self.sigmoid(np.dot(h0, self.weights.T) + self.visible_bias)
|
|
||||||
v0 = np.random.rand(len(v0), self.visible_size) < visible_probs
|
|
||||||
return v0, hidden_probs
|
|
||||||
|
|
||||||
def train(self, v0):
|
def train(self, v0):
|
||||||
for _ in range(self.epochs):
|
for _ in range(self.epochs):
|
||||||
|
@ -85,15 +77,14 @@ class RBM:
|
||||||
|
|
||||||
def validate(idx):
|
def validate(idx):
|
||||||
|
|
||||||
test = mnist[idx].flatten()
|
test = mnist[idx].flatten()/255
|
||||||
#rbm.reset()
|
|
||||||
rbm.train(test)
|
rbm.train(test)
|
||||||
(hid, out) = rbm.run(test)
|
(hid, out) = rbm.run(test)
|
||||||
|
|
||||||
return (hid.reshape((28, 28)), out.reshape((28,28)))
|
return (hid.reshape((5, 5)), out.reshape((28,28)))
|
||||||
|
|
||||||
|
|
||||||
rbm = RBM(28**2, 28**2, 0.1)
|
rbm = RBM(28**2, 25, 0.1)
|
||||||
|
|
||||||
rows, columns = (4,4)
|
rows, columns = (4,4)
|
||||||
fig = plt.figure(figsize=(10, 7))
|
fig = plt.figure(figsize=(10, 7))
|
||||||
|
|
Loading…
Reference in New Issue