# MLE/04_pacman_rl/reinforcement_learning.py
"""
Entwickeln Sie einen Reinforcement Learning (RL) Agenten, der in
einem minimalistischen Pacman-Spiel (bereitgestellt auf meiner
Homepage) effektiv Punkte sammelt, während er dem Geist
ausweicht und somit vermeidet gefressen zu werden.
"""
import numpy as np
from collections import deque

# Labyrinth layout: "#" is a wall, "." is a cell holding a cookie
LABYRINTH = [
    "##########",
    "#........#",
    "#.##..##.#",
    "#........#",
    "##########",
]


def q_init():
    """Fill every possible action in every state with a small positive value."""
    # Configuration
    NUM_ACTIONS = 4
    INITIAL_Q_VALUE = 3.0  # Positive start value; 0 marks wall-blocked actions
    # Pacman (s0, s1) and ghost (s2, s3) coordinates: columns 1-8, rows 1-3
    s0_range = range(1, 9)
    s1_range = range(1, 4)
    s2_range = range(1, 9)
    s3_range = range(1, 4)
    # Columns where the middle row (row 2) of the labyrinth is open
    s_constrained_values = {1, 4, 5, 8}
    # The Q-table dictionary
    q_table = {}
    # Iterate through all possible combinations of s0, s1, s2, s3
    for s0 in s0_range:
        for s1 in s1_range:
            for s2 in s2_range:
                for s3 in s3_range:
                    # Skip impossible states: in the middle row only the open
                    # columns are reachable
                    if s1 == 2 and s0 not in s_constrained_values:
                        continue
                    if s3 == 2 and s2 not in s_constrained_values:
                        continue
                    # Skip states where Pacman and the ghost share a cell
                    if s0 == s2 and s1 == s3:
                        continue
                    # Assign every valid state a list of initial Q-values
                    state_key = (s0, s1, s2, s3)
                    q_values = [INITIAL_Q_VALUE] * NUM_ACTIONS
                    # Zero out the actions blocked by walls
                    # Action 0: move left (s0 - 1)
                    if LABYRINTH[s1][s0 - 1] == "#":
                        q_values[0] = 0
                    # Action 1: move right (s0 + 1)
                    if LABYRINTH[s1][s0 + 1] == "#":
                        q_values[1] = 0
                    # Action 2: move up (s1 - 1)
                    if LABYRINTH[s1 - 1][s0] == "#":
                        q_values[2] = 0
                    # Action 3: move down (s1 + 1)
                    if LABYRINTH[s1 + 1][s0] == "#":
                        q_values[3] = 0
                    q_table[state_key] = q_values
    # print(f"Total number of valid states initialized: {len(q_table)}")  # debugging
    # print(list(q_table.items())[:5])  # Uncomment to see the first 5 entries
    return q_table
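

# A minimal sanity check of the table (an illustrative sketch; the state
# (1, 1, 8, 3) -- Pacman in the top-left corner, ghost in the bottom-right --
# is a hypothetical example, not prescribed by the assignment):
def demo_q_init():
    q = q_init()
    print(f"States initialized: {len(q)}")
    # Expect [0, 3.0, 0, 3.0]: left and up hit walls, right and down are open
    print(q[(1, 1, 8, 3)])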


def epsilon_greedy(q, s, epsilon=0.2):
    """
    Return the direction Pacman should move to, chosen epsilon-greedily.
    With probability epsilon explore with a random action, otherwise exploit
    the action with the highest Q-value. Actions with Q == 0 are blocked by
    walls and are never chosen as long as an alternative exists.
    """
    if np.random.random() < epsilon:
        # Explore: choose a random action, excluding wall-blocked ones (Q == 0)
        valid_actions = [i for i in range(len(q[s])) if q[s][i] > 0]
        if valid_actions:
            return np.random.choice(valid_actions)
        return np.random.randint(0, len(q[s]))
    # Exploit: choose the best action among the unblocked ones
    valid_q_values = [(i, q[s][i]) for i in range(len(q[s])) if q[s][i] > 0]
    if valid_q_values:
        return max(valid_q_values, key=lambda x: x[1])[0]
    return 0


def bfs_distance(start, end, labyrinth):
    """
    Calculate the shortest path distance between two points using BFS.
    Returns the number of steps, or infinity if no path exists.
    """
    if start == end:
        return 0
    queue = deque([(start, 0)])  # (position, distance)
    visited = {start}
    while queue:
        (x, y), dist = queue.popleft()
        # Check all 4 directions
        for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            nx, ny = x + dx, y + dy
            if (nx, ny) == end:
                return dist + 1
            if 0 <= ny < len(labyrinth) and 0 <= nx < len(labyrinth[0]):
                if (nx, ny) not in visited and labyrinth[ny][nx] != "#":
                    visited.add((nx, ny))
                    queue.append(((nx, ny), dist + 1))
    return float('inf')  # No path found
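

# Quick check of the maze metric (illustrative): column 2 is walled in the
# middle row, so the path from (2, 1) to (2, 3) must detour through column 1,
# giving distance 4 where the Manhattan distance would suggest 2.
def demo_bfs_distance():
    print(bfs_distance((1, 1), (1, 3), LABYRINTH))  # 2: straight down via (1, 2)
    print(bfs_distance((2, 1), (2, 3), LABYRINTH))  # 4: detour around the wall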


def take_action(s, a, labyrinth):
    s_new = list(s)
    if a == 0:    # left
        s_new[0] -= 1
    elif a == 1:  # right
        s_new[0] += 1
    elif a == 2:  # up
        s_new[1] -= 1
    elif a == 3:  # down
        s_new[1] += 1
    # Recompute the Pacman-ghost distance along the maze using actual pathfinding
    pacman_pos_new = (s_new[0], s_new[1])
    ghost_pos = (s[2], s[3])
    distance_new = bfs_distance(pacman_pos_new, ghost_pos, labyrinth)
    # Shaping bonus that grows with the distance to the ghost (asymptotes to 1),
    # so moves away from the ghost are preferred; an unreachable ghost counts
    # as infinitely far away
    r = distance_new / (1.0 + distance_new) if distance_new != float('inf') else 1.0
    # Reward for eating a cookie, penalty for an empty step
    r += 5.0 if labyrinth[s_new[1]][s_new[0]] == "." else -2.0
    # Keep rewards positive: epsilon_greedy encodes wall-blocked actions as
    # Q == 0, so learned Q-values of legal actions must stay above 0
    r = max(r, 0.01)
    return tuple(s_new), r
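

# A minimal tabular Q-learning loop sketched around the pieces above. This is
# illustrative only: the actual game from the homepage drives the ghost and
# removes eaten cookies, while here the ghost is assumed to take one uniformly
# random step to an adjacent open cell per turn and the cookie layout is
# treated as static. The start state and hyperparameters are hypothetical.
def train(episodes=500, alpha=0.1, gamma=0.9, epsilon=0.2):
    q = q_init()
    for _ in range(episodes):
        s = (1, 1, 8, 3)  # Pacman top-left, ghost bottom-right
        for _ in range(100):  # cap the episode length
            a = epsilon_greedy(q, s, epsilon)
            s_new, r = take_action(s, a, LABYRINTH)
            # Assumed ghost model: one random step to an adjacent open cell
            gx, gy = s_new[2], s_new[3]
            steps = [(gx + dx, gy + dy)
                     for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1))
                     if LABYRINTH[gy + dy][gx + dx] != "#"]
            gx, gy = steps[np.random.randint(len(steps))]
            s_new = (s_new[0], s_new[1], gx, gy)
            if (s_new[0], s_new[1]) == (gx, gy):
                # Caught: pull this action's value toward the 0.01 reward floor
                # (Q == 0 stays reserved for walls) and end the episode
                q[s][a] += alpha * (0.01 - q[s][a])
                break
            # Standard Q-learning update
            q[s][a] += alpha * (r + gamma * max(q[s_new]) - q[s][a])
            s = s_new
    return q


if __name__ == "__main__":
    demo_q_init()
    demo_epsilon_greedy()
    demo_bfs_distance()
    train()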