# Forked from 2211275/gnn (original listing: 45 lines, 961 B, Python).
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Initial network state (o1, o2) fed into the first recurrent step.
start = np.zeros(2, dtype=int)

# Bias vector [b1, b2] added after the weighted sum of the previous outputs.
bias = np.array([-3.37, 0.125])

# Recurrent weight matrix [[w11, w12], [w21, w22]]: row k holds the weights
# that combine the previous (o1, o2) into the new output k.
weights = np.array([[-4, 1.5], [-1.5, 0]])
def activate(input):
    """Return the pre-activation of the network for the state ``input``.

    Computes ``weights @ input + bias`` using the module-level ``weights``
    matrix and ``bias`` vector. Written out for the two-unit network:

        o1 = w11 * o1 + w12 * o2 + b1
        o2 = w21 * o1 + w22 * o2 + b2

    The parameter keeps its original name ``input`` (even though it shadows
    the builtin) so callers passing it by keyword keep working.
    """
    weighted = weights @ input
    return weighted + bias
def predict(n):
    """Iterate the recurrent network ``n`` times, starting from ``start``.

    Each step feeds the previous output back through ``activate`` and
    squashes the result with ``tanh``:

        state_{t+1} = tanh(weights @ state_t + bias)

    Args:
        n: Number of time steps to simulate.

    Returns:
        Float array of shape ``(len(start), n)`` whose column ``i`` is the
        network output after step ``i + 1``.

    Raises:
        ValueError: if ``n`` is negative (raised by ``np.zeros``).
    """
    # Defensive copy so the module-level initial state can never be altered
    # through the loop variable.
    state = start.copy()

    # Size the buffer from the initial state rather than hard-coding 2, so a
    # network with a different number of units needs no change here.
    points = np.zeros((len(start), n))

    for i in range(n):
        state = np.tanh(activate(state))
        # Slice assignment copies the values into ``points``; an explicit
        # .copy() of ``state`` is unnecessary.
        points[:, i] = state

    return points
# Number of recurrent steps to simulate and plot.
timespan = 50
timespan_range = range(timespan)

# Trajectory of the network: one row per output unit, one column per step.
predictions = predict(timespan)

# Row 0 is output o1, row 1 is output o2.
o1, o2 = predictions[0], predictions[1]

plt.title("Recurrent Neural Network")
plt.plot(timespan_range, o1)
plt.plot(timespan_range, o2)
plt.legend(["o1", "o2"], loc="upper left")
plt.show()
|