epsilon greedy > 85%
parent
476b67fa71
commit
fa2acdcecd
|
|
@ -149,7 +149,6 @@ def train(q, num_iterations=10000):
|
||||||
"""Train the agent for num_iterations without pygame visualization."""
|
"""Train the agent for num_iterations without pygame visualization."""
|
||||||
global labyrinth
|
global labyrinth
|
||||||
|
|
||||||
outer_iter = 0
|
|
||||||
total_iterations = 0
|
total_iterations = 0
|
||||||
|
|
||||||
while total_iterations < num_iterations:
|
while total_iterations < num_iterations:
|
||||||
|
|
@ -168,12 +167,25 @@ def train(q, num_iterations=10000):
|
||||||
ghost_x, ghost_y = COLS - 2, ROWS - 2
|
ghost_x, ghost_y = COLS - 2, ROWS - 2
|
||||||
s = (pacman_x, pacman_y, ghost_x, ghost_y)
|
s = (pacman_x, pacman_y, ghost_x, ghost_y)
|
||||||
|
|
||||||
while running and total_iterations < num_iterations:
|
while running:
|
||||||
iter = iter + 1
|
iter = iter + 1
|
||||||
total_iterations += 1
|
|
||||||
|
# Check for collisions
|
||||||
|
if pacman_x == ghost_x and pacman_y == ghost_y:
|
||||||
|
running = False
|
||||||
|
# total_iterations += 1
|
||||||
|
|
||||||
|
# Eat cookies
|
||||||
|
if labyrinth[pacman_y][pacman_x] == ".":
|
||||||
|
labyrinth[pacman_y] = labyrinth[pacman_y][:pacman_x] + " " + labyrinth[pacman_y][pacman_x+1:]
|
||||||
|
|
||||||
|
# Check if all cookies are eaten
|
||||||
|
if all("." not in row for row in labyrinth):
|
||||||
|
running = False
|
||||||
|
total_iterations += 1
|
||||||
|
|
||||||
# Q-Learning
|
# Q-Learning
|
||||||
a = rl.epsilon_greedy(q, s)
|
a = rl.epsilon_greedy(q, s, 0.025)
|
||||||
s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
|
s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
|
||||||
q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
|
q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
|
||||||
s = s_new
|
s = s_new
|
||||||
|
|
@ -201,23 +213,8 @@ def train(q, num_iterations=10000):
|
||||||
|
|
||||||
s = (pacman_x, pacman_y, ghost_x, ghost_y)
|
s = (pacman_x, pacman_y, ghost_x, ghost_y)
|
||||||
|
|
||||||
# Check for collisions
|
if total_iterations % 500 == 0:
|
||||||
if pacman_x == ghost_x and pacman_y == ghost_y:
|
print(f"Training iteration {total_iterations}")
|
||||||
running = False
|
|
||||||
break
|
|
||||||
|
|
||||||
# Eat cookies
|
|
||||||
if labyrinth[pacman_y][pacman_x] == ".":
|
|
||||||
labyrinth[pacman_y] = labyrinth[pacman_y][:pacman_x] + " " + labyrinth[pacman_y][pacman_x+1:]
|
|
||||||
|
|
||||||
# Check if all cookies are eaten
|
|
||||||
if all("." not in row for row in labyrinth):
|
|
||||||
running = False
|
|
||||||
break
|
|
||||||
|
|
||||||
outer_iter += 1
|
|
||||||
if outer_iter % 100 == 0:
|
|
||||||
print(f"Training iteration {outer_iter}, Total steps: {total_iterations}")
|
|
||||||
|
|
||||||
return q
|
return q
|
||||||
|
|
||||||
|
|
@ -226,6 +223,9 @@ def visualize(q, num_games=10):
|
||||||
"""Visualize the trained agent playing the game."""
|
"""Visualize the trained agent playing the game."""
|
||||||
global labyrinth
|
global labyrinth
|
||||||
|
|
||||||
|
games_won = 0
|
||||||
|
games_lost = 0
|
||||||
|
|
||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
|
|
||||||
for game_num in range(num_games):
|
for game_num in range(num_games):
|
||||||
|
|
@ -246,12 +246,30 @@ def visualize(q, num_games=10):
|
||||||
|
|
||||||
print(f"Game {game_num + 1}/{num_games}")
|
print(f"Game {game_num + 1}/{num_games}")
|
||||||
|
|
||||||
while running or iter < 100:
|
while running or iter < 300:
|
||||||
screen.fill(BLACK)
|
screen.fill(BLACK)
|
||||||
iter = iter + 1
|
iter = iter + 1
|
||||||
|
|
||||||
|
# Check for collisions
|
||||||
|
if pacman.x == ghost.x and pacman.y == ghost.y:
|
||||||
|
print("Game Over! The ghost caught Pacman.")
|
||||||
|
running = False
|
||||||
|
games_lost += 1
|
||||||
|
break
|
||||||
|
|
||||||
|
# Eat cookies
|
||||||
|
if labyrinth[pacman.y][pacman.x] == ".":
|
||||||
|
labyrinth[pacman.y] = labyrinth[pacman.y][:pacman.x] + " " + labyrinth[pacman.y][pacman.x+1:]
|
||||||
|
|
||||||
|
# Check if all cookies are eaten
|
||||||
|
if all("." not in row for row in labyrinth):
|
||||||
|
print("You Win! Pacman ate all the cookies.")
|
||||||
|
running = False
|
||||||
|
games_won += 1
|
||||||
|
break
|
||||||
|
|
||||||
# Q-Learning
|
# Q-Learning
|
||||||
a = rl.epsilon_greedy(q, s, epsilon=0.025)
|
a = rl.epsilon_greedy(q, s, 0.025)
|
||||||
s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
|
s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
|
||||||
q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
|
q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
|
||||||
s = s_new
|
s = s_new
|
||||||
|
|
@ -263,30 +281,15 @@ def visualize(q, num_games=10):
|
||||||
|
|
||||||
s = (pacman.x, pacman.y, ghost.x, ghost.y)
|
s = (pacman.x, pacman.y, ghost.x, ghost.y)
|
||||||
|
|
||||||
# Check for collisions
|
|
||||||
if pacman.x == ghost.x and pacman.y == ghost.y:
|
|
||||||
print("Game Over! The ghost caught Pacman.")
|
|
||||||
running = False
|
|
||||||
break
|
|
||||||
|
|
||||||
# Eat cookies
|
|
||||||
if labyrinth[pacman.y][pacman.x] == ".":
|
|
||||||
labyrinth[pacman.y] = labyrinth[pacman.y][:pacman.x] + " " + labyrinth[pacman.y][pacman.x+1:]
|
|
||||||
|
|
||||||
# Check if all cookies are eaten
|
|
||||||
if all("." not in row for row in labyrinth):
|
|
||||||
print("You Win! Pacman ate all the cookies.")
|
|
||||||
running = False
|
|
||||||
break
|
|
||||||
|
|
||||||
# Draw
|
# Draw
|
||||||
draw_labyrinth()
|
draw_labyrinth()
|
||||||
pacman.draw()
|
pacman.draw()
|
||||||
ghost.draw()
|
ghost.draw()
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
|
|
||||||
tick_speed = 10 # if game_num % 20 == 0 else 100
|
tick_speed = 200 # if game_num % 20 == 0 else 100
|
||||||
clock.tick(tick_speed)
|
clock.tick(tick_speed)
|
||||||
|
print("winrate: " + str(games_won / num_games))
|
||||||
|
|
||||||
# Main function
|
# Main function
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -298,10 +301,10 @@ def main():
|
||||||
q = rl.q_init()
|
q = rl.q_init()
|
||||||
|
|
||||||
print("Training for 10000 iterations...")
|
print("Training for 10000 iterations...")
|
||||||
q = train(q, num_iterations=5000)
|
q = train(q, num_iterations=10000)
|
||||||
|
|
||||||
print("\nTraining complete! Starting visualization...")
|
print("\nTraining complete! Starting visualization...")
|
||||||
visualize(q, num_games=10)
|
visualize(q, num_games=100)
|
||||||
|
|
||||||
pygame.quit()
|
pygame.quit()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,17 @@ def take_action(s, a, labyrinth):
|
||||||
else:
|
else:
|
||||||
r = calc_reward(tuple(s_new), labyrinth)
|
r = calc_reward(tuple(s_new), labyrinth)
|
||||||
|
|
||||||
|
# Boost reward if moving closer to nearest cookie
|
||||||
|
cookie_dx, cookie_dy = get_nearest_cookie(s[0], s[1], labyrinth)
|
||||||
|
old_distance = abs(cookie_dx) + abs(cookie_dy)
|
||||||
|
|
||||||
|
new_cookie_dx = int(np.sign(get_nearest_cookie(s_new[0], s_new[1], labyrinth)[0] - s_new[0]))
|
||||||
|
new_cookie_dy = int(np.sign(get_nearest_cookie(s_new[0], s_new[1], labyrinth)[1] - s_new[1]))
|
||||||
|
new_distance = abs(new_cookie_dx) + abs(new_cookie_dy)
|
||||||
|
|
||||||
|
if new_distance < old_distance:
|
||||||
|
r += 2 # Bonus for moving closer to cookie
|
||||||
|
|
||||||
# Mark new Pacman position as eaten (if it's a cookie)
|
# Mark new Pacman position as eaten (if it's a cookie)
|
||||||
if labyrinth[s_new[1]][s_new[0]] == ".":
|
if labyrinth[s_new[1]][s_new[0]] == ".":
|
||||||
# Convert string row to list, modify it, then convert back to string
|
# Convert string row to list, modify it, then convert back to string
|
||||||
|
|
@ -135,11 +146,6 @@ def take_action(s, a, labyrinth):
|
||||||
row_list[s_new[0]] = " "
|
row_list[s_new[0]] = " "
|
||||||
labyrinth[s_new[1]] = "".join(row_list)
|
labyrinth[s_new[1]] = "".join(row_list)
|
||||||
|
|
||||||
# Check if all cookies are eaten
|
|
||||||
if all("." not in row for row in labyrinth):
|
|
||||||
r = 100.0
|
|
||||||
#print("All cookies eaten")
|
|
||||||
|
|
||||||
return tuple(s_new), r, labyrinth
|
return tuple(s_new), r, labyrinth
|
||||||
|
|
||||||
def max_q(q, s_new, labyrinth):
|
def max_q(q, s_new, labyrinth):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue