Training an AI to catch fruit using Deep Reinforcement Learning
This article builds on tutorials on Reinforcement Learning (DQN, or Deep Q Network), such as this one. I recommend checking that out for the basics of DQNs. That tutorial uses a pre-made “gymnasium” environment in which a DQN agent can learn to act to maximise reward / minimise punishment. However, I wanted to try and build my own environment and DQN from scratch. This article details my efforts.
The Results#
The video above compares the agent’s performance at episode (i.e. epoch) -1, 50, and 130. Before training, the agent responds to the environment randomly. After some training, the agent catches some fruit but still achieves a negative score. After much training, the agent is thinking ahead to catch fruit, seems to understand when pursuing fruit will prove fruitless, and exhibits signs of making sensible decisions between complex paths. More compute and a more complex model are likely required to perfect play.
These results might seem unimpressive, but, as we’ll see, getting a neural network to “see” the game and make good choices is more difficult than it seems.
NB: I provide complete code chunks below for ease of replication. The downside of this is that they are quite long. Feel free to scroll past without guilt.
You can also find the complete code here.
The Environment#
I coded a basic fruit catching game in Python. The game is human-playable with the arrow keys when run as a script, and playable by an AI agent when imported as a module.
Fruit randomly spawn at the top of the game field and fall. If the player catches a fruit, the score is incremented by one. Missed fruit cost the player one point each.
import pygame
import numpy as np


class Field:
    # class for compiling the array that the DQN will interpret
    def __init__(self, height=10, width=5):
        self.width = width
        self.height = height
        self.clear_field()

    def clear_field(self):
        self.body = np.zeros(shape=(self.height, self.width))

    def update_field(self, fruits, player):
        # start from an empty grid so objects from the previous tick do not linger
        # (assumption: the listing does not show where the field is cleared)
        self.clear_field()
        # draw fruits, clipped to the field bounds
        for fruit in fruits:
            if not fruit.out_of_field:
                for y in range(fruit.y, min(fruit.y + fruit.height, self.height)):
                    for x in range(fruit.x, min(fruit.x + fruit.width, self.width)):
                        self.body[y][x] = 1
        # draw player
        for i in range(player.width):
            self.body[player.y][player.x + i] = 2
class Fruit:
    # class for the fruit
    def __init__(self, height=1, width=1, x=None, y=0, speed=1, field=None):
        self.field = field
        self.height = height
        self.width = width
        self.x = self.generate_x() if x is None else x
        self.y = y
        self.speed = speed
        self.out_of_field = False
        self.is_caught = 0  # 0 = still falling, 1 = caught, -1 = missed

    def generate_x(self):
        # note: randint's upper bound is exclusive, so the rightmost valid column is never chosen
        return np.random.randint(0, self.field.width - self.width)

    def set_out_of_field(self):
        self.out_of_field = self.y > self.field.height - 1

    def move(self):
        self.y += self.speed

    def set_is_caught(self, player):
        if self.y != player.y:
            # not on the player's row, so neither caught nor missed
            self.is_caught = 0
        elif self.x + self.width > player.x and self.x < player.x + player.width:
            # overlaps the player horizontally on the player's row
            self.is_caught = 1
        else:
            self.is_caught = -1
class Player:
    # class for the player
    def __init__(self, height=1, width=1, field=None):
        self.field = field
        self.height = height
        self.width = width
        self.x = int(self.field.width / 2 - width / 2)
        self.last_x = self.x
        self.y = self.field.height - 1
        self.dir = 0
        self.colour = "blue"

    def move(self):
        self.last_x = self.x
        self.x += self.dir
        self.dir = 0

    def action(self, action):
        # 0 = stay still, 1 = move left, 2 = move right
        if action == 1:
            self.dir = -1
        elif action == 2:
            self.dir = 1
        else:
            self.dir = 0

    def constrain(self):
        # wrap around the x axis instead of stopping at the walls
        if self.x < 0:
            self.x = self.field.width - self.width
        elif (self.x + self.width) > self.field.width:
            self.x = 0
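class Environment:
    # class for the environment
    F_WIDTH = 12
    ACTION_SPACE = [0, 1, 2]
    score = 0
    game_tick = 0
    FPS = 20
    next_spawn_tick = 0
    FRUIT_COLOURS = {-1: "red", 0: "black", 1: "green"}
    # The constants below are used by the methods, but their values are not shown
    # in this listing. F_HEIGHT follows from the 12x12 state and REWARD/PUNISHMENT
    # from the scoring rules; the rest are placeholder assumptions to tune.
    F_HEIGHT = 12
    MAX_VAL = 2  # largest cell value (the player), used to scale the state
    PLAYER_WIDTH = 1
    FRUIT_WIDTH = 1
    REWARD = 1
    PUNISHMENT = -1
    MAX_FRUIT = 1
    INCREASE_MAX_FRUIT_EVERY = 500
    SPAWN_FRUIT_EVERY_MIN = 4
    SPAWN_FRUIT_EVERY_MAX = 12
    MOVE_FRUIT_EVERY = 1
    LOSS_SCORE = -10
    WIN_SCORE = 100
    DRAW_MUL = 30
    WINDOW_WIDTH = F_WIDTH * DRAW_MUL
    WINDOW_HEIGHT = F_HEIGHT * DRAW_MUL

    def __init__(self):
        self.reset()

    def get_state(self):
        return self.field.body / self.MAX_VAL

    def reset(self):
        self.game_tick = 0
        self.game_over = False
        self.game_won = False
        self.field = Field(height=self.F_HEIGHT, width=self.F_WIDTH)
        self.player = Player(field=self.field, width=self.PLAYER_WIDTH)
        self.score = 0
        self.fruits = []
        self.field.update_field(self.fruits, self.player)
        return self.get_state()

    def spawn_fruit(self):
        if len(self.fruits) < self.MAX_FRUIT:
            self.fruits.append(Fruit(field=self.field, height=self.FRUIT_WIDTH, width=self.FRUIT_WIDTH))

    def set_next_spawn_tick(self):
        self.next_spawn_tick = self.game_tick + np.random.randint(self.SPAWN_FRUIT_EVERY_MIN, self.SPAWN_FRUIT_EVERY_MAX)

    def step(self, action=None):
        # this runs every step of the game
        # the DQN can pass an action to the game, and in return gets next game state, reward, etc.
        # (the bodies of several branches are missing from this listing; the calls
        # filled in below are reconstructed from the method names and are assumptions)
        self.game_tick += 1
        if self.game_tick % self.INCREASE_MAX_FRUIT_EVERY == 0:
            self.MAX_FRUIT += 1
        if self.game_tick >= self.next_spawn_tick or len(self.fruits) == 0:
            self.spawn_fruit()
            self.set_next_spawn_tick()
        if action is not None:
            self.player.action(action)
        self.player.move()
        self.player.constrain()
        reward = 0
        if self.game_tick % self.MOVE_FRUIT_EVERY == 0:
            in_field_fruits = []
            for fruit in self.fruits:
                fruit.move()
                fruit.set_is_caught(self.player)
                if fruit.is_caught == 1:
                    self.update_score(self.REWARD)
                    reward = self.REWARD
                elif fruit.is_caught == -1:
                    self.update_score(self.PUNISHMENT)
                    reward = self.PUNISHMENT
                fruit.set_out_of_field()
                if not fruit.out_of_field:
                    in_field_fruits.append(fruit)
            self.fruits = in_field_fruits
        self.field.update_field(fruits=self.fruits, player=self.player)
        if self.score <= self.LOSS_SCORE:
            self.game_over = True
        if self.score >= self.WIN_SCORE:
            self.game_won = True
        return self.get_state(), reward, self.game_over or self.game_won, self.score

    def update_score(self, delta):
        self.score += delta

    def render(self, screen, solo=True, x_offset=0, y_offset=0):
        # for rendering the game
        if solo:
            screen.fill((255, 255, 255))  # assumption: background colour is not shown in this listing
            pygame.display.set_caption(f"Score: {self.score}")
        # draw player
        pygame.draw.rect(screen, self.player.colour,
                         ((self.player.x * self.DRAW_MUL + x_offset, self.player.y * self.DRAW_MUL + y_offset),
                          (self.player.width * self.DRAW_MUL, self.player.height * self.DRAW_MUL)))
        # draw fruit, coloured by whether it was caught, missed, or is still falling
        for fruit in self.fruits:
            pygame.draw.rect(screen, self.FRUIT_COLOURS[fruit.is_caught],
                             ((fruit.x * self.DRAW_MUL + x_offset, fruit.y * self.DRAW_MUL + y_offset),
                              (fruit.width * self.DRAW_MUL, fruit.height * self.DRAW_MUL)))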
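def main():
    # if run as a script, the game is human playable at env.FPS frames per second
    pygame.init()
    env = Environment()
    screen = pygame.display.set_mode((env.WINDOW_WIDTH, env.WINDOW_HEIGHT))
    clock = pygame.time.Clock()
    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
        keys = pygame.key.get_pressed()
        if keys[pygame.K_LEFT]:
            env.player.dir = -1
        if keys[pygame.K_RIGHT]:
            env.player.dir = 1
        # the lines from here to the end of the loop are missing from the listing;
        # this is a reconstruction of the usual step / draw / tick sequence
        state, reward, done, score = env.step()
        if done:
            env.reset()  # assumption: restart after a win or a loss
        env.render(screen)
        pygame.display.flip()
        clock.tick(env.FPS)
    pygame.quit()
    return 0


if __name__ == "__main__":
    main()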
For each tick of the game, the AI passes one of three actions to the game: {stay still, move left, move right}. In return it gets some information: the game state, the reward/punishment, whether the game has ended, and the score. With this information, the AI chooses its next move and learns from past experience.
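The exchange described above can be sketched with a stand-in environment. `ToyEnv` and `run_episode` below are hypothetical names used only for illustration; the toy reward rule is invented, and only the `(state, reward, done, score)` return shape mirrors `Environment.step()`.

```python
import random


class ToyEnv:
    """Hypothetical stand-in that mimics the return shape of Environment.step()."""
    ACTION_SPACE = [0, 1, 2]  # stay still, move left, move right

    def __init__(self, width=5, max_ticks=20):
        self.width = width
        self.max_ticks = max_ticks

    def reset(self):
        self.tick = 0
        self.score = 0
        self.x = self.width // 2
        return self.x

    def step(self, action):
        self.tick += 1
        if action == 1:
            self.x = (self.x - 1) % self.width
        elif action == 2:
            self.x = (self.x + 1) % self.width
        # invented toy reward: +1 while standing in column 0, otherwise 0
        reward = 1 if self.x == 0 else 0
        self.score += reward
        done = self.tick >= self.max_ticks
        return self.x, reward, done, self.score


def run_episode(env, policy):
    """Play one episode and collect (state, action, reward, next_state, done) tuples."""
    state = env.reset()
    done = False
    transitions = []
    while not done:
        action = policy(state)
        next_state, reward, done, score = env.step(action)
        transitions.append((state, action, reward, next_state, done))
        state = next_state
    return transitions, score


random.seed(0)
transitions, final_score = run_episode(ToyEnv(), lambda s: random.choice(ToyEnv.ACTION_SPACE))
```

Each collected tuple has the same layout as the experiences the agent stores in its memory.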
This game is deceptively difficult for an AI. First, the AI has to “see” what is being represented in two dimensions, then it has to predict the rewards associated with different actions. Rewards are usually the delayed consequence of previous actions. Sometimes there is much fruit visible, other times very little. Fruit also fall frequently enough that catching them all is impossible, and so tough decisions need to be made. Perfect strategy will necessitate ignoring some fruit in order to catch more fruit in the future.
![The agent receives the game state as a 12x12 grid. 1 represents fruit, and 2 represents the player.](/falling_fruit_dqn/fig_1.png)
![The same state as above, displayed on a graph](/falling_fruit_dqn/fig_2.png)
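As a rough illustration of the grid encoding in the figures, the sketch below builds a state by hand in plain Python. `build_state` is a hypothetical helper, and dividing by a maximum cell value of 2 is an assumption based on the `get_state()` method.

```python
def build_state(height, width, fruit_cells, player_x, player_width=1, max_val=2):
    """Return a height x width grid: 1 marks fruit, 2 marks the player, scaled by max_val."""
    grid = [[0] * width for _ in range(height)]
    for y, x in fruit_cells:
        if 0 <= y < height and 0 <= x < width:
            grid[y][x] = 1
    # the player always sits on the bottom row
    for i in range(player_width):
        grid[height - 1][player_x + i] = 2
    return [[cell / max_val for cell in row] for row in grid]


# two fruits (row, column) and a player in column 6 of a 12x12 field
state = build_state(12, 12, fruit_cells=[(0, 3), (5, 9)], player_x=6)
```

After scaling, fruit cells hold 0.5 and the player cell holds 1.0, so every input to the network lies between 0 and 1.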
The DQN#
The agent is initialised with the following parameters:
- The environment state shape — i.e. (12, 12, 1).
- The number of actions possible — i.e. 3
- Starting learning rate, and learning rate decay
- Gamma — i.e. how much to discount future rewards (1=no discount)
- Memory size — i.e. how many examples to keep in memory before discarding the oldest
- Exploration max, min, and decay — exploration rate is the proportion of actions made by random chance
The DQN has functions for remembering past states with their rewards, and for replaying and training from its memory.
The DQN also has two versions of its own model. The base model is updated every time the model trains, but the target model is only updated at set intervals and is used for long-term reward estimation. This ensures that long-term reward predictions are less affected by short-term changes in the model, which could otherwise introduce instability.
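The two-model idea can be sketched without a neural network. In this toy example, `TwoCopyLearner` and its `sync_every` interval are hypothetical; the point is only that the target copy lags behind the trained copy until the next sync.

```python
class TwoCopyLearner:
    """Toy illustration: one set of weights trains every step, the other is synced periodically."""

    def __init__(self, sync_every=5):
        self.weights = [0.0, 0.0]
        self.target_weights = list(self.weights)
        self.sync_every = sync_every
        self.train_steps = 0

    def train_step(self, gradient, lr=0.1):
        # the base weights move on every training step
        self.weights = [w - lr * g for w, g in zip(self.weights, gradient)]
        self.train_steps += 1
        # the target weights only catch up at fixed intervals
        if self.train_steps % self.sync_every == 0:
            self.target_weights = list(self.weights)


learner = TwoCopyLearner(sync_every=5)
for _ in range(4):
    learner.train_step([1.0, -1.0])
lagging = list(learner.target_weights)  # still the initial weights
learner.train_step([1.0, -1.0])         # the fifth step triggers the sync
```

With a Keras model, the equivalent sync is usually a call such as `target_model.set_weights(model.get_weights())`.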
import random

import numpy as np
import tensorflow as tf


class DQN:
    def __init__(self, state_shape, action_size, learning_rate_max=0.001, learning_rate_decay=0.995, gamma=0.75,
                 memory_size=2000, batch_size=32, exploration_max=1.0, exploration_min=0.01, exploration_decay=0.995):
        self.state_shape = state_shape
        self.state_tensor_shape = (-1,) + state_shape
        self.action_size = action_size
        self.learning_rate_max = learning_rate_max
        self.learning_rate = learning_rate_max
        self.learning_rate_decay = learning_rate_decay
        self.gamma = gamma
        self.memory_size = memory_size
        self.memory = PrioritizedReplayBuffer(capacity=memory_size)
        self.batch_size = batch_size
        self.exploration_rate = exploration_max
        self.exploration_max = exploration_max
        self.exploration_min = exploration_min
        self.exploration_decay = exploration_decay
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

    def _build_model(self):
        # the actual neural network structure
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', kernel_initializer='he_uniform', input_shape=self.state_shape))
        model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_uniform'))
        # flatten the feature maps so the dense layers output one value per action
        # (no Flatten layer appears in the listing, but the dense head needs one)
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(128, activation='relu', kernel_initializer='he_uniform'))
        model.add(tf.keras.layers.Dense(128, activation='relu', kernel_initializer='he_uniform'))
        model.add(tf.keras.layers.Dense(self.action_size, activation='linear', name='action_values', kernel_initializer='he_uniform'))
        model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
        return model

    def update_target_model(self):
        # copy the base model's weights into the target model
        # (body missing from the listing; this is the conventional sync)
        self.target_model.set_weights(self.model.get_weights())

    def remember(self, state, action, reward, next_state, done):
        self.memory.push((state, action, reward, next_state, done))

    def act(self, state, epsilon=None):
        if epsilon is None:
            epsilon = self.exploration_rate
        if np.random.rand() < epsilon:
            return random.randrange(self.action_size)
        state = np.reshape(state, self.state_tensor_shape)
        return int(np.argmax(self.target_model.predict(state, verbose=0)[0]))

    def replay(self, episode=0):
        if self.memory.length() < self.batch_size:
            return None
        experiences, indices, weights = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = [list(arr) for arr in zip(*experiences)]
        # Convert states to tensors; keep the rest as NumPy arrays for indexing
        states = tf.reshape(tf.convert_to_tensor(states), self.state_tensor_shape)
        next_states = tf.reshape(tf.convert_to_tensor(next_states), self.state_tensor_shape)
        actions = np.asarray(actions, dtype=np.int32)
        rewards = np.asarray(rewards, dtype=np.float32)
        dones = np.asarray(dones, dtype=np.float32)
        # Compute Q values and next Q values
        target_q_values = self.target_model.predict(next_states, verbose=0)
        q_values = self.model.predict(states, verbose=0)
        # Compute target values using the Bellman equation
        max_target_q_values = np.max(target_q_values, axis=1)
        targets = rewards + (1 - dones) * self.gamma * max_target_q_values
        # Compute TD errors
        batch_indices = np.arange(self.batch_size)
        q_values_current_action = q_values[batch_indices, actions]
        td_errors = targets - q_values_current_action
        self.memory.update_priorities(indices, np.abs(td_errors))
        # For learning: Adjust Q values of taken actions to match the computed targets
        q_values[batch_indices, actions] = targets
        loss = self.model.train_on_batch(states, q_values, sample_weight=weights)
        # Decay the exploration rate and learning rate by episode
        self.exploration_rate = self.exploration_max * self.exploration_decay ** episode
        self.exploration_rate = max(self.exploration_min, self.exploration_rate)
        self.learning_rate = self.learning_rate_max * self.learning_rate_decay ** episode
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.learning_rate)
        return loss

    def load(self, name):
        self.model = tf.keras.models.load_model(name)
        self.target_model = tf.keras.models.load_model(name)

    def save(self, name):
        # body missing from the listing; saving the base model matches load() above
        self.model.save(name)
Basic DQNs replay scenarios from a simple buffer. Sparse environments, however, can make it very challenging for agents to learn. My environment is pretty sparse — rewards are not available every tick. I found it necessary to implement a Prioritised Replay Buffer that prioritises learning from memories with high error associated with them. This means that the DQN will spend more time learning from events it doesn’t understand well. With alpha set to 0.8, this Buffer quite aggressively favours memories the DQN currently predicts badly.
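To give a feel for how strongly alpha = 0.8 skews sampling, here is a small dependency-free sketch of the probability and importance-weight calculations. The priority values are made up for illustration, and the function names are hypothetical.

```python
import random


def sampling_probabilities(priorities, alpha=0.8):
    """Turn priorities into sampling probabilities; alpha=0 would make sampling uniform."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probabilities, indices, beta=0.4):
    """Down-weight frequently sampled memories, normalised so the largest weight is 1."""
    n = len(probabilities)
    raw = [(n * probabilities[i]) ** (-beta) for i in indices]
    top = max(raw)
    return [w / top for w in raw]


# made-up absolute prediction errors for four stored memories
priorities = [0.01, 0.1, 1.0, 2.0]
probs = sampling_probabilities(priorities)

random.seed(0)
indices = random.choices(range(len(priorities)), weights=probs, k=8)
weights = importance_weights(probs, indices)
```

With these made-up numbers, the memory with the largest error takes more than half of the sampling probability, and the importance weights partly compensate by shrinking its contribution to the loss.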
class PrioritizedReplayBuffer:
    # note: the name of the list holding the experiences was lost from this listing;
    # `self.buffer` is an assumed name
    def __init__(self, capacity, epsilon=1e-6, alpha=0.8, beta=0.4, beta_increment=0.001):
        self.capacity = capacity
        self.epsilon = epsilon
        self.alpha = alpha  # how much prioritisation is used
        self.beta = beta  # for importance sampling weights
        self.beta_increment = beta_increment
        self.priority_buffer = np.zeros(self.capacity)
        self.buffer = []
        self.position = 0

    def length(self):
        return len(self.buffer)

    def push(self, experience):
        # new memories get the current maximum priority so they are sampled at least once
        max_priority = np.max(self.priority_buffer) if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            # buffer is full: overwrite the oldest memory
            self.buffer[self.position] = experience
        self.priority_buffer[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        priorities = self.priority_buffer[:len(self.buffer)]
        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
        experiences = [self.buffer[i] for i in indices]
        total = len(self.buffer)
        weights = (total * probabilities[indices]) ** (-self.beta)
        weights /= weights.max()
        self.beta = np.min([1., self.beta + self.beta_increment])
        return experiences, indices, weights

    def update_priorities(self, indices, errors):
        for idx, error in zip(indices, errors):
            self.priority_buffer[idx] = error + self.epsilon
Training is relatively straightforward. The agent plays a number of games until they are lost, committing each game state to memory, and retraining from a sample of memories every tick.
Each episode, the learning rate falls, as does the exploration rate. Lowering the learning rate helps the model converge, and starting off with a high exploration rate helps the agent discover strategies and also prevents the model from locking itself into a bad strategy. Models are also saved for future reference.
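The two schedules follow the same exponential form used in `replay()`: the starting value multiplied by the decay factor raised to the episode number, with a floor for the exploration rate. The sketch below evaluates them at the three episodes compared in the video, using the default values from the `DQN` constructor (the values used for the actual training run may have differed).

```python
def decayed(start, decay, episode, floor=0.0):
    """Exponential decay by episode, never dropping below `floor`."""
    return max(floor, start * decay ** episode)


episodes = (0, 50, 130)
exploration = [decayed(1.0, 0.995, e, floor=0.01) for e in episodes]
learning = [decayed(0.001, 0.995, e) for e in episodes]

for e, eps, lr in zip(episodes, exploration, learning):
    print(f"episode {e}: exploration {eps:.3f}, learning rate {lr:.6f}")
```

With a decay factor of 0.995, both rates fall slowly: after 130 episodes each is still a little over half of its starting value.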
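import os
import numpy as np
from collections import deque
import game
import dqn

GAMMA = 0.975
# BATCH_SIZE and EPISODES are used below but their values are not shown in this
# listing; the values here are placeholder assumptions
BATCH_SIZE = 64
EPISODES = 150

env = game.Environment()
# the constructor arguments are cut off in this listing; the state shape and
# action count follow the parameter list given earlier, the rest are assumptions
agent = dqn.DQN(state_shape=(12, 12, 1), action_size=len(env.ACTION_SPACE), gamma=GAMMA, batch_size=BATCH_SIZE)
state = env.reset()
state = np.expand_dims(state, axis=0)
most_recent_losses = deque(maxlen=BATCH_SIZE)
log = []
os.makedirs("models", exist_ok=True)

# fill up memory before training starts
while agent.memory.length() < BATCH_SIZE:
    action = agent.act(state)
    next_state, reward, done, score = env.step(action)
    next_state = np.expand_dims(next_state, axis=0)
    agent.remember(state, action, reward, next_state, done)
    state = next_state

for e in range(0, EPISODES):
    state = env.reset()
    state = np.expand_dims(state, axis=0)
    done = False
    step = 0
    ma_loss = None
    while not done:
        action = agent.act(state)
        next_state, reward, done, score = env.step(action)
        next_state = np.expand_dims(next_state, axis=0)
        agent.remember(state, action, reward, next_state, done)
        state = next_state
        step += 1
        loss = agent.replay(episode=e)
        if loss is not None:
            most_recent_losses.append(loss)
            ma_loss = np.array(most_recent_losses).mean()
            print(f"Step: {step}. Score: {score}. -- Loss: {loss}", end=" \r")
        if done:
            # assumption: the listing does not show where the target model is
            # synced; here it is synced at the end of every episode
            agent.update_target_model()
            print(f"Episode {e}/{EPISODES-1} completed with {step} steps. Score: {score:.0f}. LR: {agent.learning_rate:.6f}. EP: {agent.exploration_rate:.2f}. MA loss: {ma_loss:.6f}")
            log.append([e, step, score, agent.learning_rate, agent.exploration_rate, ma_loss])
            agent.save(f'models/{e}.h5')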
Replaying memories#
If you want to see how your agent performs at any particular point, you can load it into a fresh environment like this:
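import numpy as np
import pygame
import game
import dqn

# forward slash: in a normal Python string, "\5" is read as an escape sequence
model_path = "models/5.h5"

env = game.Environment()
# the constructor arguments are cut off in this listing; these follow the parameter list given earlier
agent = dqn.DQN(state_shape=(12, 12, 1), action_size=len(env.ACTION_SPACE))
agent.load(model_path)
state = env.reset()
state = np.expand_dims(state, axis=0)

pygame.init()
screen = pygame.display.set_mode((env.WINDOW_WIDTH, env.WINDOW_HEIGHT))
clock = pygame.time.Clock()
running = True
score = 0
while running:
    pygame.display.set_caption(f"Score: {score}")
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
    # epsilon of 0 means the agent never acts at random
    action = agent.act(state, 0)
    state, reward, done, score = env.step(action)
    state = np.expand_dims(state, axis=0)
    # the drawing and timing calls are missing from the listing; this is the usual sequence
    if done:
        state = np.expand_dims(env.reset(), axis=0)  # assumption: restart after a win or a loss
    env.render(screen)
    pygame.display.flip()
    clock.tick(env.FPS)
pygame.quit()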
Learnings & Strategies#
Getting things working well is often a frustrating case of trial and error. With enough compute, one could grid search for ideal parameters, but I proceeded mainly by feel.
Much training required#
Models in well-designed spaces converged on decent parameters after around 100 episodes. In each episode, the agent might train around 50–100 times on batches of 64 or more examples, so 100 episodes amounts to roughly 320,000 to 640,000 examples processed. While that is still far fewer than the theoretical number of possible states, it is a lot of training to master a very simple environment.
Model complexity#
This is a computer vision + game-playing task, so it requires a fairly complex neural network. With some trial and error I was able to find a model spec that trained quickly enough on an old laptop CPU, but better performance undoubtedly requires more complex neural networks.
Sparse rewards#
DQNs find it difficult to learn in sparse reward environments. In most ticks of my game, no rewards are available. In addition, actions are often rewarded only in several moves’ time. Some of this is addressed below, but I also made things as easy as possible for the agent in early prototypes, while getting to grips with the code and parameters.
In the first proof-of-concept, the agent was told how far away it was from the fruit in the x-axis — so it had very little work to do to predict the best direction of movement. In the next iteration, the agent received the x and y coordinates of the lowest fruit, and its own x coordinate. Only when I was satisfied that these were working did I move onto passing the agent the full state of the board.
I also tweaked the environment to assist with learning, for example by ensuring there is always at least one fruit available. I found that higher batch sizes led to more stable learning, as they were more likely to contain a representative sample of rewards/punishments.
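The batch size effect can be made concrete with a little arithmetic. Assuming, purely for illustration, that 5% of stored memories carry a non-zero reward and that batches are drawn uniformly at random, the chance that a batch contains at least one of them is 1 - (1 - p)^n.

```python
def p_batch_has_reward(reward_fraction, batch_size):
    """Probability that a uniformly sampled batch holds at least one rewarded memory."""
    return 1 - (1 - reward_fraction) ** batch_size


# 5% is a made-up figure for illustration, not a measurement from the game
for batch_size in (8, 32, 64):
    print(batch_size, round(p_batch_has_reward(0.05, batch_size), 3))
```

Under these assumptions, small batches often contain no reward signal at all, while a batch of 64 almost always does. Prioritised sampling changes the exact numbers, but the direction of the effect is the same.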
Auxiliary outputs#
An auxiliary task can be given to the agent’s model, so that it does not just predict the reward for each action — in this case, it also predicts the next state.
I experimented with auxiliary outputs and found that this helped the agent learn in some cases, but proved a distraction in more complex cases.
Future rewards#
A feature of DQNs is that they can predict future rewards. However, this is only possible if the network is predicting immediate rewards properly. The agent uses its own predictions to account for future rewards, which obviously fails when rewards are not being predicted accurately.
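The dependence on the network's own predictions is visible in the training target itself. A minimal sketch of the target for a single transition, using the same Bellman rule as the `replay()` method (`td_target` is a hypothetical helper name, and the Q values are made up):

```python
def td_target(reward, next_q_values, done, gamma=0.975):
    """Target for the action taken: immediate reward plus discounted best predicted future value."""
    if done:
        # no future beyond a terminal state
        return reward
    return reward + gamma * max(next_q_values)


# no reward this tick, but the network predicts that moving left will pay off next
mid_game = td_target(0.0, next_q_values=[0.2, 0.8, -0.1], done=False)
# a missed fruit that ends the game: only the punishment counts
terminal = td_target(-1.0, next_q_values=[0.2, 0.8, -0.1], done=True)
```

If the predicted values in `next_q_values` are wrong, the target is wrong too, which is why poor immediate-reward predictions undermine the long-term ones.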
One solution to this is to use “curriculum learning” — i.e. initially training the model with a simpler version of the problem. In simpler versions of the game, I found that this was key to success. An agent initially learns only from instances in which the fruit is on the row above it. This vastly simplifies the problem, and greatly increases the frequency of rewards/punishments in memory. Once the agent learns how to catch fruit directly above it, it can start to figure out how to move towards fruit that is further away. At first, the agent’s memory contains only occasions where fruit is very low. Gradually other examples are introduced and slowly the agent’s memory becomes representative of the whole game.
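One way to express such a curriculum is as a filter on which transitions are stored. The sketch below is hypothetical and is not the code used for this article: `rows_allowed` widens the window of rows above the player over the episodes, and `keep_transition` decides whether a memory is kept.

```python
def rows_allowed(episode, start_rows=1, grow_every=10, field_height=12):
    """How many rows above the player count as 'close enough' at this episode."""
    return min(field_height, start_rows + episode // grow_every)


def keep_transition(fruit_rows, player_row, max_rows_above):
    """Keep a transition only if some fruit is on the player's row or within max_rows_above rows of it."""
    return any(0 <= player_row - y <= max_rows_above for y in fruit_rows)


# early on, only fruit right next to the player is remembered
early = keep_transition(fruit_rows=[3], player_row=11, max_rows_above=rows_allowed(0))
# later, the same situation is allowed into memory
later = keep_transition(fruit_rows=[3], player_row=11, max_rows_above=rows_allowed(80))
```

The schedule parameters (`start_rows`, `grow_every`) are illustrative and would need tuning alongside the exploration and learning rates.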
The Prioritised Replay Buffer is another solution to this problem, with the advantage that it requires no manipulation of the agent’s experience of the environment. I found this to be the much more performant solution in the final case.
Sensitivity to reward structure#
Agents exhibit sensitivity to reward structure in ways that can be unexpected, and that can hinder learning. Earlier versions of my game would prevent the player from moving off the screen to the left or right. The inadvertent side-effect of this is that holding down one of the keys appears to generate rewards, because you are actually staying still while fruit falls on you. Agents would therefore often get stuck on the sides, and get into a learning hole. The solution to this was to let the player wrap around the x axis.
Interestingly, the agent learns how to take full advantage of the wrap-around mechanic and successfully chases fruit off each side. The downside is that the agent will have to choose more frequently between moving left and right if fruit can be caught off the sides of the field.
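The difference between the two boundary rules is small in code. In this sketch, `clamp_move` is a hypothetical version of the earlier behaviour, and `wrap_move` mirrors the logic of `Player.constrain()`.

```python
def clamp_move(x, direction, field_width, player_width=1):
    """Earlier behaviour: the player stops at the walls."""
    return max(0, min(field_width - player_width, x + direction))


def wrap_move(x, direction, field_width, player_width=1):
    """Current behaviour: leaving one side re-enters from the other."""
    new_x = x + direction
    if new_x < 0:
        return field_width - player_width
    if new_x + player_width > field_width:
        return 0
    return new_x


# pressing left at the left wall of a 12-wide field
stuck = clamp_move(0, -1, 12)   # the player stays in column 0
wrapped = wrap_move(0, -1, 12)  # the player re-enters on the right
```

With clamping, "move left" at the left wall is indistinguishable from "stay still", which is how the misleading rewards arose.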
Learning rate annealing#
I found it useful to gradually adjust the learning rate downwards using a decay function. Some trial and error is required to get a feel for when the model starts to get stuck.
Exploration rate decay#
It is common to start training with a very high exploration rate (e.g. 1) so that the agent is acting at random. The rate is gradually lowered as the agent trains. Remember to set this to 0 when observing agent behaviour, otherwise the agent will act randomly a proportion of the time.
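The effect of the exploration rate on action selection can be sketched in a few lines of plain Python. `epsilon_greedy` is a hypothetical stand-alone version of the logic in the agent's `act()` method, and the Q values are made up.

```python
import random


def epsilon_greedy(q_values, epsilon, rng=random):
    """With probability epsilon pick a random action, otherwise the highest-valued one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


q_values = [0.1, 0.9, 0.3]  # made-up values for: stay still, move left, move right

random.seed(0)
greedy_choices = {epsilon_greedy(q_values, 0.0) for _ in range(100)}
random_choices = {epsilon_greedy(q_values, 1.0) for _ in range(200)}
```

At an exploration rate of 0 the agent always picks the same action for a given state, which is what you want when judging what it has learned.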