MLE-Pacman/ReinforcmentLearning/util.py

169 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from enum import Enum
import random
import pygame
import numpy as np
import data.classes_consts as consts
class Direction(Enum):
    """The four grid moves, used both as agent actions and cookie bearings.

    Member order matters: iterating the enum yields UP, RIGHT, DOWN, LEFT,
    which is the order the Q-table builders enumerate directions in.
    Coordinates follow the labyrinth's ``[y][x]`` indexing, so y grows
    downwards.
    """

    UP = 0     # towards smaller y (previous row)
    RIGHT = 1  # towards larger x (next column)
    DOWN = 2   # towards larger y (next row)
    LEFT = 3   # towards smaller x (previous column)
def initial_q_fill():
    """Build a dict-based Q-table with small random starting values.

    Keys are ``((dx, dy, cookie_direction), action)``. They cover every
    relative ghost offset ``dx`` in [-7, 7] and ``dy`` in [-2, 2]. They also
    cover every cookie ``Direction`` and every action ``Direction``, giving
    15 * 5 * 4 * 4 = 1200 entries. The state part matches the tuple shape
    produced by ``calc_current_state``. Offsets outside these ranges get no
    entries.

    Each value is ``random.random() * 0.2 - 0.1``, i.e. uniform in
    [-0.1, 0.1).
    """
    return {
        ((dx, dy, cookie_dir), move): random.random() * 0.2 - 0.1
        for dx in range(-7, 8)
        for dy in range(-2, 3)
        for cookie_dir in Direction
        for move in Direction
    }
def initial_q_fill2():
    """Initialize an array-based Q-table addressed by a linear state index.

    Returns:
        tuple: ``(q_table, indexer)``.

        ``q_table`` is a NumPy array of shape ``(indexer.total_states, 4)``,
        with one row per state index and one column per action. It is filled
        uniformly from [-0.1, 0.1) using NumPy's global random generator.

        ``indexer`` is ``consts.indexer``, handed back so callers can
        (presumably) translate state tuples into row indices.
    """
    indexer = consts.indexer
    # Build a 2D array indexed as [state_index, action].
    # NOTE(review): the column count assumes exactly four actions (one per
    # Direction member). With the state space of initial_q_fill this would
    # be 300 states x 4 actions = 1200 entries; confirm against
    # indexer.total_states.
    q_table = np.random.uniform(
        low=-0.1,
        high=0.1,
        size=(indexer.total_states, 4),
    )
    return q_table, indexer
def calc_current_state(labyrinth, pac_x, pac_y, ghost_x, ghost_y):
    """Return the agent's observed state as a 3-tuple.

    The tuple is ``(ghost_dx, ghost_dy, cookie_direction)``:

    - ``ghost_dx`` and ``ghost_dy`` are Pac-Man's position minus the
      ghost's position.
    - ``cookie_direction`` is the Direction from
      ``get_closest_cookie_direction`` for Pac-Man's current cell.
    """
    # Tuple elements are evaluated left to right: offsets first, then the
    # cookie search.
    return (
        pac_x - ghost_x,
        pac_y - ghost_y,
        get_closest_cookie_direction(labyrinth, pac_x, pac_y),
    )
def get_closest_cookie_direction(labyrinth, pac_x, pac_y):
    """Return the Direction Pac-Man should head in to reach the nearest cookie.

    Cookies are cells equal to ``"."``, and ``labyrinth`` is indexed as
    ``labyrinth[y][x]``. The nearest cookie is chosen by Manhattan distance.
    Ties go to the first cookie in row-major order (top row first, then left
    to right).

    With ``dx``/``dy`` = cookie position minus Pac-Man position, the shorter
    gap is closed first:

    - If ``|dx| >= |dy|``: return a vertical move (DOWN for ``dy > 0``, UP
      for ``dy < 0``). Only when ``dy == 0`` return RIGHT (``dx > 0``) or
      LEFT (otherwise).
    - If ``|dy| > |dx|``: return a horizontal move (RIGHT/LEFT by the sign
      of ``dx``). Only when ``dx == 0`` return DOWN (``dy > 0``) or UP.

    If Pac-Man stands on the cookie itself (``dx == dy == 0``), the result
    is ``Direction.LEFT``.

    Args:
        labyrinth: sequence of rows, each an iterable of cell characters.
        pac_x: Pac-Man's column index.
        pac_y: Pac-Man's row index.

    Returns:
        Direction: the move towards the nearest cookie.

    Raises:
        ValueError: if the labyrinth contains no cookies.
    """
    cookies = (
        (abs(pac_x - x) + abs(pac_y - y), x, y)
        for y, row in enumerate(labyrinth)
        for x, cell in enumerate(row)
        if cell == "."
    )
    # Compare on distance only so ties keep the first cookie found
    # (row-major order).
    closest = min(cookies, key=lambda item: item[0], default=None)
    if closest is None:
        raise ValueError("labyrinth contains no cookies ('.') to head towards")

    _, cookie_x, cookie_y = closest
    dx = cookie_x - pac_x
    dy = cookie_y - pac_y

    if abs(dx) >= abs(dy):
        # Horizontal gap is at least as large: close the vertical gap first.
        if dy != 0:
            return Direction.DOWN if dy > 0 else Direction.UP
        # Same row (or on the cookie): move horizontally.
        return Direction.RIGHT if dx > 0 else Direction.LEFT

    # Vertical gap is larger: close the horizontal gap first.
    if dx != 0:
        return Direction.RIGHT if dx > 0 else Direction.LEFT
    # Same column: dy is necessarily non-zero here.
    return Direction.DOWN if dy > 0 else Direction.UP
def epsilon_greedy(q_values, state, epsilon):
    """Pick an action for ``state`` with an epsilon-greedy policy.

    The greedy lookup always runs first. Then, with probability ``epsilon``
    (``random.random() < epsilon``), the policy explores. Exploring picks
    uniformly among the actions that have Q-entries for ``state``, which
    includes the greedy one. If the state has no entries, it falls back to
    a random Direction. Otherwise the highest-valued action is returned.

    Args:
        q_values: dict mapping ``(state, action)`` pairs to Q-values.
        state: the state to act in.
        epsilon: probability of exploring instead of exploiting.

    Returns:
        The chosen action.
    """
    greedy_choice, candidates = get_best_q_action(q_values, state)
    if random.random() >= epsilon:
        # Exploit.
        return greedy_choice
    # Explore.
    if candidates:
        return random.choice(candidates)
    return get_random_direction()
def get_best_q_action(q_values, state):
    """Return the greedy action for ``state`` plus every action known for it.

    Every entry of ``q_values`` is scanned, so each call costs
    O(len(q_values)).

    Args:
        q_values: dict mapping ``(state, action)`` pairs to Q-values.
        state: the state whose actions should be compared.

    Returns:
        tuple: ``(best_action, actions)``.

        ``actions`` lists every action that has an entry for ``state``, in
        ``q_values`` iteration order. ``best_action`` is the action with the
        highest value; ties go to the earliest in that order. If ``state``
        has no entries, ``actions`` is empty and ``best_action`` is a random
        Direction.
    """
    best_action = None
    best_value = None
    actions_for_epsilon = []
    for (q_state, q_action), value in q_values.items():
        if q_state != state:
            continue
        actions_for_epsilon.append(q_action)
        # Strict ">" keeps the first action on ties.
        if best_value is None or value > best_value:
            best_value = value
            best_action = q_action
    # Test for None explicitly: a falsy action (an int 0, or an IntEnum
    # member whose value is 0) is still a valid greedy choice and must not
    # be replaced by a random move.
    if best_action is None:
        best_action = get_random_direction()
    return best_action, actions_for_epsilon
def get_random_direction():
    """Return one Direction member chosen uniformly via ``random.choice``."""
    members = tuple(Direction)
    return random.choice(members)
def calc_time_reward(amount_iterations):
    """Return a reward that shrinks as the episode takes more iterations.

    - Below 1000 iterations: ``10`` (int).
    - Above 10000 iterations: ``1`` (int).
    - In between: the float ramp ``11 - amount_iterations / 1000``, which
      gives 10.0 at 1000 and 1.0 at 10000.
    """
    if amount_iterations < 1000:
        reward = 10
    elif amount_iterations > 10000:
        reward = 1
    else:
        # Linear decay joining the two flat regions.
        reward = 11 - (1 / 1000) * amount_iterations
    return reward
def draw_labyrinth(screen, labyrinth):
    """Render the labyrinth's walls and cookies onto ``screen``.

    Each cell is ``consts.CELL_SIZE`` pixels square, and cell ``(x, y)`` has
    its top-left corner at ``(x * CELL_SIZE, y * CELL_SIZE)``.

    - Walls (``"#"``) are drawn as filled ``consts.BLUE`` rectangles.
    - Cookies (``"."``) are drawn as ``consts.WHITE`` circles of radius 5,
      centred in their cell.
    - Any other cell is left untouched.
    """
    size = consts.CELL_SIZE
    wall_colour = consts.BLUE
    cookie_colour = consts.WHITE
    for row_index, row in enumerate(labyrinth):
        top = row_index * size
        for col_index, cell in enumerate(row):
            left = col_index * size
            if cell == "#":
                pygame.draw.rect(screen, wall_colour, (left, top, size, size))
            elif cell == ".":
                centre = (left + size // 2, top + size // 2)
                pygame.draw.circle(screen, cookie_colour, centre, 5)