from typing import Optional, Tuple

import numpy as np
import pandas as pd


def load_data(path: str = 'mnist_test_final.csv') -> np.ndarray:
    """Load digit images from a CSV file, dropping its ``label`` column.

    Args:
        path (str): CSV file to read. Defaults to ``'mnist_test_final.csv'``.

    Returns:
        np.ndarray: 2-dimensional array with one row per CSV row and one column
        per remaining (pixel) column. The demo below assumes 28 * 28 = 784
        pixel columns per row.
    """
    df_digits = pd.read_csv(path).drop('label', axis=1)
    return df_digits.to_numpy()


def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-x))


class RBM:
    """Restricted Boltzmann machine with one visible and one hidden layer.

    Training is a deterministic (mean-field) form of contrastive divergence:
    activation probabilities are propagated directly and no binary states are
    sampled.

    Attributes:
        learnrate: Step size for weight and bias updates.
        visible_size: Number of visible units.
        hidden_size: Number of hidden units.
        k: Intended number of Gibbs steps. NOTE(review): currently not read
            anywhere; ``train`` always performs a single step (CD-1).
        epochs: Number of update iterations ``train`` performs per call.
        weights: Weight matrix of shape (visible_size, hidden_size).
        visible_bias: Visible-layer bias of shape (visible_size,).
        hidden_bias: Hidden-layer bias of shape (hidden_size,).
    """

    def __init__(self, visible_size: int, hidden_size: int,
                 learnrate: float = 0.1) -> None:
        self.learnrate = learnrate
        self.visible_size = visible_size
        self.hidden_size = hidden_size
        self.k = 2  # NOTE(review): unused; train() performs one CD step
        self.epochs = 10
        self.reset()

    def reset(self) -> None:
        """Re-initialise weights from a standard normal and zero both biases."""
        self.weights = np.random.randn(self.visible_size, self.hidden_size)
        self.visible_bias = np.zeros(self.visible_size)
        self.hidden_bias = np.zeros(self.hidden_size)

    def activate(self, v0: np.ndarray) -> np.ndarray:
        """Return hidden-unit activation probabilities for visible input ``v0``.

        ``v0`` may be one vector of length ``visible_size`` or a batch of shape
        (n, visible_size); the result has the matching leading shape.
        """
        return sigmoid(v0 @ self.weights + self.hidden_bias)

    def reactivate(self, h0: np.ndarray) -> np.ndarray:
        """Return visible-unit activation probabilities reconstructed from ``h0``.

        ``h0`` may be one vector of length ``hidden_size`` or a batch of shape
        (n, hidden_size); the result has the matching leading shape.
        """
        return sigmoid(h0 @ self.weights.T + self.visible_bias)

    def contrastive_divergence(self, v0, h0, v1, h1):
        """Apply one contrastive-divergence weight update and return the weights.

        The update is ``learnrate * (outer(v0, h0) - outer(v1, h1))``.

        Args:
            v0: Visible input vector.
            h0: Hidden activation computed from ``v0``.
            v1: Visible reconstruction computed from ``h0``.
            h1: Hidden activation computed from ``v1``.

        Returns:
            np.ndarray: The updated weight matrix (also stored in ``self.weights``).
        """
        positive_gradient = np.outer(v0, h0)
        # The negative phase pairs the reconstruction with the hidden response
        # to that reconstruction (h1), not with the data-driven h0.
        negative_gradient = np.outer(v1, h1)
        self.weights += self.learnrate * (positive_gradient - negative_gradient)
        return self.weights

    def train(self, v0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run ``epochs`` CD-1 updates of weights and biases on one input ``v0``.

        Returns:
            Tuple[np.ndarray, np.ndarray]: (hidden activation, visible
            reconstruction) from the last iteration, computed before that
            iteration's parameter update.
        """
        for _ in range(self.epochs):
            h0 = self.activate(v0)    # hidden activation from the data
            v1 = self.reactivate(h0)  # reconstruction of the visible layer
            h1 = self.activate(v1)    # hidden activation from the reconstruction
            self.contrastive_divergence(v0, h0, v1, h1)
            self.visible_bias += self.learnrate * (v0 - v1)
            self.hidden_bias += self.learnrate * (h0 - h1)
        return h0, v1

    def run(self, v0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """run Runs the Restricted Boltzmann machine on some input vector v0.

        Args:
            v0 (np.ndarray): 1-dimensional Input vector

        Returns:
            Tuple[np.ndarray, np.ndarray]: (hidden activation, visible reactivation)
        """
        h0 = self.activate(v0)
        v1 = self.reactivate(h0)
        return h0, v1


def validate(idx: int, model: Optional[RBM] = None,
             data: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Train ``model`` on image ``idx`` of ``data``, then return its images for plotting.

    NOTE: despite the name, this *trains* (mutates) the model on the sample
    before running it.

    Args:
        idx (int): Row index of the image in ``data``.
        model (RBM, optional): Model to use; defaults to the module-level ``rbm``
            created by the script section.
        data (np.ndarray, optional): Image rows; defaults to the module-level
            ``mnist`` created by the script section.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Hidden activation reshaped to (5, 5)
        (requires 25 hidden units) and reconstruction reshaped to (28, 28)
        (requires 784 visible units).
    """
    if model is None:
        model = rbm
    if data is None:
        data = mnist
    # Dividing by 255 assumes 8-bit pixel values; maps them into [0, 1].
    sample = data[idx].flatten() / 255
    model.train(sample)
    hid, out = model.run(sample)
    return hid.reshape((5, 5)), out.reshape((28, 28))


if __name__ == "__main__":
    # Imported here so that importing this module does not require matplotlib.
    import matplotlib.pyplot as plt

    mnist = load_data()
    rbm = RBM(28**2, 25, 0.1)

    rows, columns = (4, 4)
    fig = plt.figure(figsize=(10, 7))
    # Each sample fills two adjacent cells: hidden code, then reconstruction.
    for i in range(0, rows * columns, 2):
        hid, out = validate(i)
        for offset, image in enumerate((hid, out), start=1):
            ax = fig.add_subplot(rows, columns, i + offset)
            ax.imshow(image, cmap='gray')
            ax.axis('off')
    plt.show()