In this tutorial, we use the Arcade Learning Environment (ALE) to demonstrate Fruit API. In particular, we use the A3C method [paper] to train an AI agent to play Breakout. The following screenshot shows the game interface. The agent's mission is to control the paddle to catch the ball, keep it in play, and break the brick wall to get a high score.
The full example can be found at Fruit-API/fruit/samples/basic/a3c_test.py. We use a3c_test.py to demonstrate the framework. First, download the Breakout ROM (breakout.bin) into the folder fruit/envs/roms/. If you run the example in PyCharm and use matplotlib, go to Settings -> Tools -> Python Scientific -> Disable "Show plots in tool window".

```python
from fruit.agents.factory import AgentFactory
from fruit.configs.a3c import AtariA3CConfig
from fruit.envs.ale import ALEEnvironment
from fruit.learners.a3c import A3CLearner
from fruit.networks.policy import PolicyNetwork


def train_ale_environment():
    # Create an ALE for Breakout
    environment = ALEEnvironment(ALEEnvironment.BREAKOUT)

    # Create a network configuration for Atari A3C
    network_config = AtariA3CConfig(environment, initial_learning_rate=0.004, debug_mode=True)

    # Create a shared network for A3C agent
    network = PolicyNetwork(network_config, max_num_of_checkpoints=40)

    # Create an A3C agent
    agent = AgentFactory.create(A3CLearner, network, environment, num_of_epochs=40,
                                steps_per_epoch=1e6, checkpoint_frequency=1e6,
                                log_dir='./train/breakout/a3c_checkpoints')

    # Train it
    agent.train()


if __name__ == '__main__':
    train_ale_environment()
```
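To see how the hyper-parameters above fit together: each epoch runs steps_per_epoch environment steps, and ALE's default frame skip of 4 means one agent step consumes four emulator frames. A quick sanity check in plain Python (independent of Fruit API):

```python
# Training schedule implied by the agent's parameters above
num_of_epochs = 40
steps_per_epoch = 1_000_000
checkpoint_frequency = 1_000_000
frame_skip = 4  # ALE default: one agent step = 4 emulator frames

total_steps = num_of_epochs * steps_per_epoch           # total environment steps
total_frames = total_steps * frame_skip                 # total game frames
num_checkpoints = total_steps // checkpoint_frequency   # saved models

print(total_steps, total_frames, num_checkpoints)  # 40000000 160000000 40
```

The 40 checkpoints match max_num_of_checkpoints=40 in the PolicyNetwork above.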
Congratulations! You have just trained an agent that plays Breakout well enough to beat a professional human player. The agent reaches a high score of 860 in Breakout, as shown in the following reward distribution over the course of training (40 million steps, or 160 million frames).
Now we will evaluate the trained model to see the agent in action.
```python
from fruit.agents.factory import AgentFactory
from fruit.configs.a3c import AtariA3CConfig
from fruit.envs.ale import ALEEnvironment
from fruit.learners.a3c import A3CLearner
from fruit.networks.policy import PolicyNetwork


def evaluate_ale_environment():
    # Create an ALE for Breakout and enable GUI
    environment = ALEEnvironment(ALEEnvironment.BREAKOUT, is_render=True)

    # Create a network configuration for Atari A3C
    network_config = AtariA3CConfig(environment)

    # Create a shared network for A3C agent and load the trained model
    network = PolicyNetwork(network_config,
                            load_model_path='./train/breakout/a3c_checkpoints_10-23-2019-02-13/model-39030506')

    # Create an A3C agent; use only one learner as we want to show a GUI
    agent = AgentFactory.create(A3CLearner, network, environment, num_of_epochs=1,
                                steps_per_epoch=10000, num_of_learners=1,
                                log_dir='./test/breakout/a3c_checkpoints')

    # Evaluate it
    agent.evaluate()


if __name__ == '__main__':
    evaluate_ale_environment()
```
The code is quite self-explanatory. ALEEnvironment is a subclass of BaseEnvironment. ALEEnvironment acts as a bridge between ALE and Fruit API, exposing ALE through the unified interface declared by BaseEnvironment. Therefore, to integrate other environments into Fruit API, we need to create a subclass of BaseEnvironment and implement all functions declared by BaseEnvironment.
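To make the adapter idea concrete, here is a toy example in the same spirit. Note that the method names below (reset, step, get_state, is_terminal) are illustrative only; the actual interface declared by BaseEnvironment may differ, so check the Fruit API source before implementing a real environment:

```python
class BaseEnvironment:
    """Minimal stand-in for Fruit API's BaseEnvironment (illustrative only)."""
    def reset(self):
        raise NotImplementedError

    def step(self, action):
        raise NotImplementedError

    def get_state(self):
        raise NotImplementedError

    def is_terminal(self):
        raise NotImplementedError


class CountingEnvironment(BaseEnvironment):
    """A trivial environment: the episode ends at state 3, with +1 reward per step."""
    def __init__(self):
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        self.state += 1
        return 1  # reward for this step

    def get_state(self):
        return self.state

    def is_terminal(self):
        return self.state >= 3


# An agent-side loop only needs the unified interface, never the concrete class
env = CountingEnvironment()
env.reset()
total_reward = 0
while not env.is_terminal():
    total_reward += env.step(action=0)
print(total_reward)  # 3
```

Any environment implementing the same interface can be dropped into the training loop unchanged; this is exactly what ALEEnvironment does for ALE.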
The network configuration is defined by AtariA3CConfig, which basically includes a network architecture, training operations, an objective function, and a suitable optimizer. The configuration is then passed to a PolicyNetwork, which uses it to initialize the neural network.
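For reference, the objective function such an A3C configuration encodes combines a policy-gradient term weighted by the advantage, a value-function regression term for the critic, and an entropy bonus that encourages exploration. The following NumPy sketch of the per-step loss terms is illustrative only, not Fruit API's actual implementation:

```python
import numpy as np


def a3c_loss_terms(logits, action, reward_return, value, beta=0.01):
    """Per-step A3C loss for a discrete policy (illustrative sketch).

    In a real implementation the advantage is treated as a constant
    (stop-gradient) in the policy term; here we only compute the scalar value.
    """
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                       # softmax policy pi(a|s)
    advantage = reward_return - value          # A = R - V(s)
    policy_loss = -np.log(probs[action]) * advantage
    value_loss = 0.5 * advantage ** 2          # squared error for the critic
    entropy = -np.sum(probs * np.log(probs))   # exploration bonus
    return policy_loss + value_loss - beta * entropy


loss = a3c_loss_terms(np.array([1.0, 2.0, 0.5]), action=1, reward_return=1.0, value=0.6)
```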
Finally, an A3C agent with learners of type A3CLearner is generated by the AgentFactory.
The training and evaluation scripts share the same building blocks, PolicyNetwork and AgentFactory.