98 lines
3.2 KiB
Python
98 lines
3.2 KiB
Python
import pygame as pyg
|
|
import numpy as np
|
|
import random
|
|
|
|
# die Q-Updates machen
|
|
def updateQ(reward, state, action, nextState):
|
|
global er_re, er_s, er_a, er_ns, tick, Q, alpha, gamma
|
|
# Replay-Buffer füllen
|
|
er_re[tick%400]= reward # experience replay Belohnung
|
|
er_s[tick%400] = state # experience replay Zustand
|
|
er_a[tick%400] = action # experience replay Aktion
|
|
er_ns[tick%400]= nextState# experience replay nächster Zustand
|
|
|
|
for i in range(batch_size):
|
|
r = random.randint(0,399)
|
|
# Q[s][a]+=r+alpha*(gamma * max_a' Q(s',a')-Q(s,a))
|
|
Q[int(er_s[r])][int(er_a[r])] += er_re[r] + alpha*(gamma * np.max(Q[int(er_ns[r])]) - Q[int(er_s[r])][int(er_a[r])])
|
|
|
|
# Nächste Aktion anfragen
|
|
def getAction(state): # gibt -1 für Schläger links oder +1 für rechts zurück
|
|
global epsilon, Q
|
|
if np.random.rand() <= epsilon:
|
|
return np.random.choice([-1, 1])
|
|
return (np.argmax(Q[int(state)]) * 2) - 1
|
|
|
|
# Macht aus 5 Koordinaten -> 1 Koordinate
|
|
# damit bekommt jeder Zustand eine eindeutige Nummer
|
|
def getState(x_ball, y_ball, vx_ball, vy_ball, x_racket):
|
|
return (((x_ball*13 +y_ball)*2 +(vx_ball+1)/2)*2 +(vy_ball+1)/2)*12 +x_racket
|
|
|
|
# Parameter für Q-Learning und Experience Replay
|
|
epsilon = 1
|
|
alpha = 0.1
|
|
gamma = 0.95
|
|
batch_size = 32
|
|
er_re = np.zeros(400)
|
|
er_s = np.zeros(400)
|
|
er_a = np.zeros(400)
|
|
er_ns = np.zeros(400)
|
|
|
|
tick = 0 # zählt bei jeder Zustandsändrung hoch
|
|
episode = 0 # zählt die Episoden
|
|
|
|
num_of_states = 13*12*2*2*12
|
|
num_of_actions = 2
|
|
Q = np.random.rand(num_of_states, num_of_actions)/1000.0
|
|
|
|
pyg.init()
|
|
screen = pyg.display.set_mode((240, 260))
|
|
pyg.display.set_caption("Q-Learning Experience-Replay")
|
|
file = open('reward_experience_replay.txt','w')
|
|
x_racket, x_ball, y_ball, vx_ball, vy_ball, score = 5, 1, 1, 1, 1, 0
|
|
|
|
cont = True
|
|
#clock = pyg.time.Clock()
|
|
while cont:
|
|
for event in pyg.event.get():
|
|
if event.type == pyg.QUIT:
|
|
cont = False
|
|
|
|
epsilon -= 1/400
|
|
if (epsilon<0):
|
|
epsilon=0
|
|
screen.fill((0,0,0))
|
|
font = pyg.font.SysFont("arial", 15)
|
|
t = font.render("Score:"+str(score)+" Episode:"+str(episode), True, (255,255,255))
|
|
screen.blit(t, t.get_rect(centerx = screen.get_rect().centerx))
|
|
pyg.draw.rect(screen, (0, 128, 255), pyg.Rect(x_racket*20, 250, 80, 10))
|
|
pyg.draw.rect(screen, (255, 100, 0), pyg.Rect(x_ball*20, y_ball*20, 20, 20))
|
|
|
|
state = getState(x_ball, y_ball, vx_ball, vy_ball, x_racket)
|
|
action = getAction(state)
|
|
|
|
x_racket = x_racket + action # Aktion ausführen
|
|
if x_racket>11: x_racket = 11
|
|
if x_racket<0: x_racket = 0
|
|
|
|
x_ball, y_ball = x_ball + vx_ball, y_ball + vy_ball
|
|
if x_ball > 10 or x_ball < 1: vx_ball *= -1
|
|
if y_ball > 11 or y_ball < 1: vy_ball *= -1
|
|
|
|
reward = 0
|
|
if y_ball == 12:
|
|
reward = -1 # Annahme: Ball daneben
|
|
if x_ball >= x_racket and x_ball <= x_racket + 4:
|
|
reward = +1 # Ball doch getroffen
|
|
episode += 1
|
|
score = score + reward
|
|
|
|
nextState = getState(x_ball, y_ball, vx_ball, vy_ball, x_racket)
|
|
updateQ(reward, state, (action+1)//2, nextState)
|
|
|
|
tick += 1
|
|
if reward!=0:
|
|
file.write(str(reward)+",")
|
|
file.flush()
|
|
#clock.tick(60) # Refresh-Zeiten festlegen 60 FPS
|
|
pyg.display.flip() |