stuff persistent
parent
c88d8d003e
commit
3fb0afd80e
|
|
@ -1,6 +1,8 @@
|
|||
import ast
import json
import math
import os

import pygame

import reinforcement_learning as rl
|
||||
|
||||
# Initialize pygame
|
||||
pygame.init()
|
||||
|
|
@ -120,19 +122,113 @@ def move_pacman(pacman, a):
|
|||
if a == 3: # down
|
||||
pacman.move(0, 1)
|
||||
|
||||
# Main game function
|
||||
def main():
|
||||
def save_q_table(q, filename="q_table.json"):
    """Save Q-table to JSON file.

    Args:
        q: Q-table dict mapping state tuples to per-action value lists.
        filename: Path of the JSON file to write.
    """
    # Convert tuple keys to strings for JSON serialization
    # (JSON object keys must be strings).
    q_json = {str(k): v for k, v in q.items()}
    with open(filename, 'w') as f:
        json.dump(q_json, f)
    # Fix: the filename placeholder was lost from this f-string; report
    # the actual destination path instead of a literal "(unknown)".
    print(f"Q-table saved to {filename}")
|
||||
|
||||
def load_q_table(filename="q_table.json"):
    """Load Q-table from JSON file, or return None if file doesn't exist.

    Args:
        filename: Path of the JSON file written by save_q_table().

    Returns:
        The Q-table as a dict with tuple keys, or None when no file exists.
    """
    if not os.path.exists(filename):
        # Fix: restore the filename placeholder lost from this f-string.
        print(f"No saved Q-table found at {filename}. Starting fresh.")
        return None

    with open(filename, 'r') as f:
        q_json = json.load(f)

    # Convert string keys back to tuples. Fix: use ast.literal_eval
    # instead of eval() — the file contents are external input, and
    # eval() would execute arbitrary code embedded in a key.
    q = {ast.literal_eval(k): v for k, v in q_json.items()}
    print(f"Q-table loaded from {filename}")
    return q
|
||||
|
||||
# Training function (without visualization)
def train(q, num_iterations=10000):
    """Train the agent for num_iterations without pygame visualization.

    Repeatedly resets the maze and runs Q-learning episodes until the
    total step budget is exhausted, mutating `q` in place.

    Args:
        q: Q-table dict mapping state tuples (pacman_x, pacman_y,
           ghost_x, ghost_y) to per-action value containers.
        num_iterations: total number of environment steps across all
            episodes (not number of episodes).

    Returns:
        The (mutated) Q-table `q`.

    NOTE(review): relies on module-level COLS, ROWS, ALPHA, GAMMA and
    the `rl` helper module, none of which are visible here — confirm
    their definitions before changing this loop.
    """
    global labyrinth  # the maze string list is shared module state

    outer_iter = 0        # completed episodes (for progress logging)
    total_iterations = 0  # environment steps consumed so far

    while total_iterations < num_iterations:
        # Fresh maze at the start of every episode.
        labyrinth = [
            "##########",
            "#........#",
            "#.##..##.#",
            "#........#",
            "##########"
        ]
        running = True
        iter = 0  # per-episode step counter (NB: shadows the builtin iter)

        # Initialize Pacman and Ghost positions (no visual objects needed)
        pacman_x, pacman_y = 1, 1
        ghost_x, ghost_y = COLS - 2, ROWS - 2
        s = (pacman_x, pacman_y, ghost_x, ghost_y)

        while running and total_iterations < num_iterations:
            iter = iter + 1
            total_iterations += 1

            # Check for collisions: ghost caught Pacman ends the episode.
            if pacman_x == ghost_x and pacman_y == ghost_y:
                running = False
                break

            # Eat cookies: blank out the '.' under Pacman (strings are
            # immutable, so rebuild the row).
            if labyrinth[pacman_y][pacman_x] == ".":
                labyrinth[pacman_y] = labyrinth[pacman_y][:pacman_x] + " " + labyrinth[pacman_y][pacman_x+1:]

            # Check if all cookies are eaten — episode won.
            if all("." not in row for row in labyrinth):
                running = False
                break

            # Q-Learning update: choose action, observe transition,
            # apply the one-step TD update.
            a = rl.epsilon_greedy(q, s)
            s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
            q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
            s = s_new

            # Update Pacman position. Movement is re-derived here from
            # `a` with wall checks; NOTE(review): this duplicates (and
            # may disagree with) whatever rl.take_action computed for
            # s_new — confirm the two stay consistent.
            if a == 0: # left
                pacman_x = max(1, pacman_x - 1) if labyrinth[pacman_y][pacman_x - 1] != "#" else pacman_x
            elif a == 1: # right
                pacman_x = min(COLS - 2, pacman_x + 1) if labyrinth[pacman_y][pacman_x + 1] != "#" else pacman_x
            elif a == 2: # up
                pacman_y = max(1, pacman_y - 1) if labyrinth[pacman_y - 1][pacman_x] != "#" else pacman_y
            elif a == 3: # down
                pacman_y = min(ROWS - 2, pacman_y + 1) if labyrinth[pacman_y + 1][pacman_x] != "#" else pacman_y

            # Ghost movement: every 3rd step, greedily step toward
            # Pacman along the first unblocked axis (x preferred).
            if iter % 3 == 0:
                if ghost_x < pacman_x and labyrinth[ghost_y][ghost_x + 1] != "#":
                    ghost_x += 1
                elif ghost_x > pacman_x and labyrinth[ghost_y][ghost_x - 1] != "#":
                    ghost_x -= 1
                elif ghost_y < pacman_y and labyrinth[ghost_y + 1][ghost_x] != "#":
                    ghost_y += 1
                elif ghost_y > pacman_y and labyrinth[ghost_y - 1][ghost_x] != "#":
                    ghost_y -= 1

            # Rebuild the state from the locally tracked positions —
            # this deliberately overwrites the `s = s_new` above.
            s = (pacman_x, pacman_y, ghost_x, ghost_y)

        outer_iter += 1
        if outer_iter % 100 == 0:
            print(f"Training iteration {outer_iter}, Total steps: {total_iterations}")

    return q
|
||||
|
||||
# Visualization function (with pygame)
|
||||
def visualize(q, num_games=10):
|
||||
"""Visualize the trained agent playing the game."""
|
||||
global labyrinth
|
||||
q = rl.q_init()
|
||||
|
||||
clock = pygame.time.Clock()
|
||||
|
||||
# Game loop
|
||||
not_won = True
|
||||
outer_iter = 0
|
||||
|
||||
while not_won:
|
||||
|
||||
for game_num in range(num_games):
|
||||
labyrinth = [
|
||||
"##########",
|
||||
"#........#",
|
||||
|
|
@ -146,22 +242,17 @@ def main():
|
|||
# Initialize Pacman and Ghost positions
|
||||
pacman = Pacman(1, 1)
|
||||
ghost = Ghost(COLS - 2, ROWS - 2)
|
||||
s = (pacman.x, pacman.y, ghost.x, ghost.y) # as a tuple so the state becomes hashable
|
||||
s = (pacman.x, pacman.y, ghost.x, ghost.y)
|
||||
|
||||
print(f"Game {game_num + 1}/{num_games}")
|
||||
|
||||
# Handle events
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
not_won = False
|
||||
|
||||
print(outer_iter)
|
||||
while running or iter < 100:
|
||||
screen.fill(BLACK)
|
||||
iter = iter + 1
|
||||
|
||||
# Check for collisions (game over if ghost catches pacman)
|
||||
# Check for collisions
|
||||
if pacman.x == ghost.x and pacman.y == ghost.y:
|
||||
print("Game Over! The ghost caught Pacman.")
|
||||
outer_iter = outer_iter + 1
|
||||
running = False
|
||||
break
|
||||
|
||||
|
|
@ -169,49 +260,52 @@ def main():
|
|||
if labyrinth[pacman.y][pacman.x] == ".":
|
||||
labyrinth[pacman.y] = labyrinth[pacman.y][:pacman.x] + " " + labyrinth[pacman.y][pacman.x+1:]
|
||||
|
||||
# Check if all cookies are eaten (game over)
|
||||
# Check if all cookies are eaten
|
||||
if all("." not in row for row in labyrinth):
|
||||
print("You Win! Pacman ate all the cookies.")
|
||||
running = False
|
||||
not_won = False
|
||||
break
|
||||
|
||||
# Q-Learning part ############################################################################
|
||||
|
||||
a = rl.epsilon_greedy(q, s) # 0 = Left; 1 = Right ; 2 = Up ; 3 = Down
|
||||
# Q-Learning
|
||||
a = rl.epsilon_greedy(q, s)
|
||||
s_new, r, labyrinth = rl.take_action(s, a, labyrinth)
|
||||
# print(s) # debugging
|
||||
# print(q[s]) # debugging
|
||||
|
||||
q[s][a] += ALPHA * (r + GAMMA * rl.max_q(q, s_new, labyrinth) - q[s][a])
|
||||
|
||||
s = s_new
|
||||
|
||||
|
||||
move_pacman(pacman, a)
|
||||
|
||||
if iter % 3 == 0:
|
||||
# Ghost moves towards Pacman
|
||||
ghost.move_towards_pacman(pacman)
|
||||
# Update state
|
||||
s = (pacman.x, pacman.y, ghost.x, ghost.y)
|
||||
|
||||
# End of Q-Learning part ######################################################################
|
||||
|
||||
# Draw the labyrinth, pacman, and ghost
|
||||
# Draw
|
||||
draw_labyrinth()
|
||||
pacman.draw()
|
||||
ghost.draw()
|
||||
|
||||
# Update display
|
||||
pygame.display.flip()
|
||||
|
||||
# Cap the frame rate
|
||||
# tick_speed = 100
|
||||
tick_speed = 5 if outer_iter % 20 == 0 else 50
|
||||
tick_speed = 20 # if game_num % 20 == 0 else 100
|
||||
clock.tick(tick_speed)
|
||||
|
||||
# Main function
def main():
    """Entry point: load or create a Q-table, train, visualize, persist.

    Side effects: reads/writes q_table.json, runs pygame, mutates the
    module-level labyrinth.
    """
    global labyrinth

    # Load existing Q-table or create new one
    q = load_q_table("q_table.json")
    if q is None:
        q = rl.q_init()

    # Fix: the progress message hard-coded "10000" while train() was
    # called with 20000 — derive both from a single variable so the
    # message can never drift from the actual budget again.
    num_iterations = 20000
    print(f"Training for {num_iterations} iterations...")
    q = train(q, num_iterations=num_iterations)

    print("\nTraining complete! Starting visualization...")
    visualize(q, num_games=10)

    pygame.quit()

    # Save Q-table when exiting
    save_q_table(q, "q_table.json")

if __name__ == "__main__":
    main()
|
||||
|
|
@ -136,6 +136,11 @@ def take_action(s, a, labyrinth):
|
|||
row_list[s_new[0]] = " "
|
||||
labyrinth[s_new[1]] = "".join(row_list)
|
||||
|
||||
# Check if all cookies are eaten
|
||||
if all("." not in row for row in labyrinth):
|
||||
r = 100.0
|
||||
#print("All cookies eaten")
|
||||
|
||||
return tuple(s_new), r, labyrinth
|
||||
|
||||
def max_q(q, s_new, labyrinth):
|
||||
|
|
@ -148,4 +153,22 @@ def max_q(q, s_new, labyrinth):
|
|||
if q[s_new][a] is not None: # Only consider valid (non-blocked) actions
|
||||
q_max = max(q_max, q[s_new][a])
|
||||
|
||||
return q_max
|
||||
return q_max
|
||||
|
||||
def get_nearest_cookie(pacman_x, pacman_y, labyrinth):
    """Return the unit direction (dx, dy) from Pacman toward the closest cookie.

    The closest cookie is chosen by Manhattan distance; each returned
    component is -1, 0, or 1. Returns (0, 0) when the maze holds no
    cookies.
    """
    # Collect the grid coordinates of every remaining cookie.
    cookie_positions = []
    for row_idx, row in enumerate(labyrinth):
        for col_idx, cell in enumerate(row):
            if cell == ".":
                cookie_positions.append((col_idx, row_idx))

    if not cookie_positions:
        return 0, 0

    def manhattan(pos):
        return abs(pos[0] - pacman_x) + abs(pos[1] - pacman_y)

    target_x, target_y = min(cookie_positions, key=manhattan)
    # Collapse each offset to its sign so the caller gets a single step.
    cookie_dx = int(np.sign(target_x - pacman_x))
    cookie_dy = int(np.sign(target_y - pacman_y))
    return cookie_dx, cookie_dy
|
||||
Loading…
Reference in New Issue