epsilon greedy > 85%
parent 476b67fa71
commit fa2acdcecd
@@ -149,7 +149,6 @@ def train(q, num_iterations=10000):
    """Train the agent for num_iterations without pygame visualization."""
    global labyrinth

    outer_iter = 0
    total_iterations = 0

    while total_iterations < num_iterations:
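The string-slicing in the next hunk relies on labyrinth being a global list of strings, one per row, with "." marking an uneaten cookie. A tiny illustration of the eat-a-cookie rewrite, using a made-up three-row maze (the real maze is defined elsewhere in the script):

labyrinth = [
    "#####",
    "#...#",
    "#####",
]  # made-up maze: "#" = wall, "." = cookie, " " = empty

pacman_x, pacman_y = 2, 1
if labyrinth[pacman_y][pacman_x] == ".":
    # Rows are immutable strings, so the whole row is rebuilt with the cookie replaced by a space
    labyrinth[pacman_y] = labyrinth[pacman_y][:pacman_x] + " " + labyrinth[pacman_y][pacman_x+1:]

print(labyrinth[1])                               # "#. .#"
print(all("." not in row for row in labyrinth))   # False: two cookies remain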
@@ -168,12 +167,25 @@ def train(q, num_iterations=10000):
        ghost_x, ghost_y = COLS - 2, ROWS - 2
        s = (pacman_x, pacman_y, ghost_x, ghost_y)

-       while running and total_iterations < num_iterations:
+       while running:
            iter = iter + 1
            total_iterations += 1

            # Check for collisions
            if pacman_x == ghost_x and pacman_y == ghost_y:
                running = False
                # total_iterations += 1

            # Eat cookies
            if labyrinth[pacman_y][pacman_x] == ".":
                labyrinth[pacman_y] = labyrinth[pacman_y][:pacman_x] + " " + labyrinth[pacman_y][pacman_x+1:]

            # Check if all cookies are eaten
            if all("." not in row for row in labyrinth):
                running = False
                total_iterations += 1

            # Q-Learning
-           a = rl.epsilon_greedy(q, s)
+           a = rl.epsilon_greedy(q, s, 0.025)
            s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
            q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
            s = s_new
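rl.epsilon_greedy itself is not shown in this diff. A minimal sketch of what it plausibly does, assuming the Q-table maps each state tuple to a dict of per-action values and that the action set is the four grid moves (both assumptions, not taken from the commit):

import random

ACTIONS = ["up", "down", "left", "right"]  # assumed action set

def epsilon_greedy(q, s, epsilon=0.1):
    # With probability epsilon explore a random action, otherwise exploit the best-known one
    if random.random() < epsilon:
        return random.choice(ACTIONS)
    return max(q[s], key=q[s].get)

With the 0.025 passed above, only about one step in forty is random, so the policy is almost greedy.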
@@ -200,24 +212,9 @@ def train(q, num_iterations=10000):
                ghost_y -= 1

            s = (pacman_x, pacman_y, ghost_x, ghost_y)

            # Check for collisions
            if pacman_x == ghost_x and pacman_y == ghost_y:
                running = False
                break

            # Eat cookies
            if labyrinth[pacman_y][pacman_x] == ".":
                labyrinth[pacman_y] = labyrinth[pacman_y][:pacman_x] + " " + labyrinth[pacman_y][pacman_x+1:]

            # Check if all cookies are eaten
            if all("." not in row for row in labyrinth):
                running = False
                break

        outer_iter += 1
-       if outer_iter % 100 == 0:
-           print(f"Training iteration {outer_iter}, Total steps: {total_iterations}")
+       if total_iterations % 500 == 0:
+           print(f"Training iteration {total_iterations}")

    return q
@@ -226,6 +223,9 @@ def visualize(q, num_games=10):
    """Visualize the trained agent playing the game."""
    global labyrinth

    games_won = 0
    games_lost = 0

    clock = pygame.time.Clock()

    for game_num in range(num_games):
@@ -246,12 +246,30 @@ def visualize(q, num_games=10):

        print(f"Game {game_num + 1}/{num_games}")

-       while running or iter < 100:
+       while running or iter < 300:
            screen.fill(BLACK)
            iter = iter + 1

            # Check for collisions
            if pacman.x == ghost.x and pacman.y == ghost.y:
                print("Game Over! The ghost caught Pacman.")
                running = False
                games_lost += 1
                break

            # Eat cookies
            if labyrinth[pacman.y][pacman.x] == ".":
                labyrinth[pacman.y] = labyrinth[pacman.y][:pacman.x] + " " + labyrinth[pacman.y][pacman.x+1:]

            # Check if all cookies are eaten
            if all("." not in row for row in labyrinth):
                print("You Win! Pacman ate all the cookies.")
                running = False
                games_won += 1
                break

            # Q-Learning
-           a = rl.epsilon_greedy(q, s, epsilon=0.025)
+           a = rl.epsilon_greedy(q, s, 0.025)
            s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
            q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
            s = s_new
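The update line repeated in both loops is the standard tabular Q-learning rule Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)). For reference, the same step factored into a small helper; ALPHA and GAMMA are the script's constants, and the values shown here are placeholders rather than the ones used in the repository:

ALPHA = 0.1  # placeholder learning rate; the real constant is defined elsewhere in the script
GAMMA = 0.9  # placeholder discount factor; likewise a placeholder

def td_update(q, s, a, r, max_q_next):
    # Tabular Q-learning: move Q(s, a) toward the bootstrapped target r + GAMMA * max_a' Q(s', a')
    q[s][a] += ALPHA * (r + GAMMA * max_q_next - q[s][a])

The line in the diff is exactly this update with max_q_next = rl.max_q(q, s_new, labyrinth).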
@@ -263,30 +281,15 @@ def visualize(q, num_games=10):

            s = (pacman.x, pacman.y, ghost.x, ghost.y)

            # Check for collisions
            if pacman.x == ghost.x and pacman.y == ghost.y:
                print("Game Over! The ghost caught Pacman.")
                running = False
                break

            # Eat cookies
            if labyrinth[pacman.y][pacman.x] == ".":
                labyrinth[pacman.y] = labyrinth[pacman.y][:pacman.x] + " " + labyrinth[pacman.y][pacman.x+1:]

            # Check if all cookies are eaten
            if all("." not in row for row in labyrinth):
                print("You Win! Pacman ate all the cookies.")
                running = False
                break

            # Draw
            draw_labyrinth()
            pacman.draw()
            ghost.draw()
            pygame.display.flip()

-           tick_speed = 10  # if game_num % 20 == 0 else 100
+           tick_speed = 200  # if game_num % 20 == 0 else 100
            clock.tick(tick_speed)
    print("winrate: " + str(games_won / num_games))

# Main function
def main():
@@ -298,10 +301,10 @@ def main():
    q = rl.q_init()

    print("Training for 10000 iterations...")
-   q = train(q, num_iterations=5000)
+   q = train(q, num_iterations=10000)

    print("\nTraining complete! Starting visualization...")
-   visualize(q, num_games=10)
+   visualize(q, num_games=100)

    pygame.quit()
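main() builds the table with rl.q_init(), which is also not part of this diff. One shape it could have, consistent with the q[s][a] indexing used above (the defaultdict layout and the ACTIONS list are assumptions):

from collections import defaultdict

ACTIONS = ["up", "down", "left", "right"]  # assumed action set

def q_init():
    # Every state seen for the first time starts with all action values at 0.0
    return defaultdict(lambda: {a: 0.0 for a in ACTIONS})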
@@ -127,6 +127,17 @@ def take_action(s, a, labyrinth):
        # print("Invalid action")
    else:
        r = calc_reward(tuple(s_new), labyrinth)

        # Boost reward if moving closer to nearest cookie
        cookie_dx, cookie_dy = get_nearest_cookie(s[0], s[1], labyrinth)
        old_distance = abs(cookie_dx) + abs(cookie_dy)

        new_cookie_dx = int(np.sign(get_nearest_cookie(s_new[0], s_new[1], labyrinth)[0] - s_new[0]))
        new_cookie_dy = int(np.sign(get_nearest_cookie(s_new[0], s_new[1], labyrinth)[1] - s_new[1]))
        new_distance = abs(new_cookie_dx) + abs(new_cookie_dy)

        if new_distance < old_distance:
            r += 2  # Bonus for moving closer to cookie

        # Mark new Pacman position as eaten (if it's a cookie)
        if labyrinth[s_new[1]][s_new[0]] == ".":
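get_nearest_cookie is only called in this hunk, not defined in it. Note that the first call reads its return value as a (dx, dy) offset (its absolute values are summed into a Manhattan distance), while the later np.sign lines read it as absolute coordinates, so the two distances are computed on different scales. The sketch below follows the offset reading and ignores walls; both choices are assumptions:

def get_nearest_cookie(x, y, labyrinth):
    # Sketch: offset (dx, dy) from (x, y) to the closest remaining cookie by Manhattan
    # distance, ignoring walls; (0, 0) if no cookies are left.
    best, best_dist = (0, 0), None
    for row_idx, row in enumerate(labyrinth):
        for col_idx, cell in enumerate(row):
            if cell == ".":
                dist = abs(col_idx - x) + abs(row_idx - y)
                if best_dist is None or dist < best_dist:
                    best, best_dist = (col_idx - x, row_idx - y), dist
    return best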
@@ -135,11 +146,6 @@ def take_action(s, a, labyrinth):
            row_list[s_new[0]] = " "
            labyrinth[s_new[1]] = "".join(row_list)

        # Check if all cookies are eaten
        if all("." not in row for row in labyrinth):
            r = 100.0
            # print("All cookies eaten")

    return tuple(s_new), r, labyrinth

def max_q(q, s_new, labyrinth):
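The hunk ends at the def max_q(q, s_new, labyrinth) signature without its body. A minimal version consistent with how it is used in the TD target; the labyrinth parameter suggests the real implementation masks moves into walls, which this sketch does not attempt:

def max_q(q, s_new, labyrinth):
    # Value of the best action available in s_new; labyrinth is accepted but unused here
    return max(q[s_new].values())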