Hello, Fruit API !

In this tutorial, we use the Arcade Learning Environment (ALE) to demonstrate Fruit API. In particular, we use the A3C method [paper] to train an AI agent to play Breakout. The following screenshot shows the game interface. The agent's mission is to control the paddle to catch the ball, keep it alive, and break the wall to get a high score.

figure 3


Start The Engine

  • Open PyCharm.
  • figure 4


  • Click Open.
  • figure 5


  • Browse to Fruit-API/fruit/samples/basic/a3c_test.py.
  • figure 6


  • We will use the sample code in a3c_test.py to demonstrate the framework.
  • Put the Breakout ROM (breakout.bin) into the folder fruit/envs/roms/.
  • Run the script. The program trains the agent to play Breakout; training takes 1-2 days.
  • If the program reports errors related to matplotlib, go to Settings -> Tools -> Python Scientific and disable "Show plots in tool window".
from fruit.agents.factory import AgentFactory
from fruit.configs.a3c import AtariA3CConfig
from fruit.envs.ale import ALEEnvironment
from fruit.learners.a3c import A3CLearner
from fruit.networks.policy import PolicyNetwork


def train_ale_environment():
    # Create an ALE for Breakout
    environment = ALEEnvironment(ALEEnvironment.BREAKOUT)

    # Create a network configuration for Atari A3C
    network_config = AtariA3CConfig(environment, initial_learning_rate=0.004, debug_mode=True)

    # Create a shared network for A3C agent
    network = PolicyNetwork(network_config, max_num_of_checkpoints=40)

    # Create an A3C agent
    agent = AgentFactory.create(A3CLearner, network, environment, 
                num_of_epochs=40, steps_per_epoch=1e6,
                checkpoint_frequency=1e6, 
                log_dir='./train/breakout/a3c_checkpoints')

    # Train it
    agent.train()


if __name__ == '__main__':
    train_ale_environment()

Bingo !

Congratulations! You have just trained an agent to play Breakout that can beat a professional human player. The agent can reach a score of 860 in Breakout, as shown in the following reward distribution over the course of training (40 million steps, or 160 million frames).

figure 7
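
The step and frame counts quoted above can be checked against the training settings. The short sketch below multiplies num_of_epochs by steps_per_epoch from the training script. The factor of 4 frames per step is an assumption inferred from the 40 million to 160 million ratio in the text, not a value read from Fruit API.

```python
# Training budget implied by the settings in the training script.
num_of_epochs = 40
steps_per_epoch = 1e6

# Assumption: each agent step spans 4 emulator frames (inferred from the
# 40 million steps / 160 million frames ratio quoted in the text).
assumed_frames_per_step = 4

total_steps = int(num_of_epochs * steps_per_epoch)
total_frames = total_steps * assumed_frames_per_step

print(f"total steps:  {total_steps:,}")
print(f"total frames: {total_frames:,}")
```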

Eat the pie !

Now, we will evaluate the trained model to see the agent in action.

  • During training, the program saves the neural network's parameters (a checkpoint, or model) at different training steps, as shown in the figure below.
  • figure 8


  • The following program will load the model (after 39 million training steps) and evaluate it.

  • from fruit.agents.factory import AgentFactory
    from fruit.configs.a3c import AtariA3CConfig
    from fruit.envs.ale import ALEEnvironment
    from fruit.learners.a3c import A3CLearner
    from fruit.networks.policy import PolicyNetwork
    
    
    def evaluate_ale_environment():
        # Create an ALE for Breakout and enable GUI
        environment = ALEEnvironment(ALEEnvironment.BREAKOUT, is_render=True)
    
        # Create a network configuration for Atari A3C
        network_config = AtariA3CConfig(environment)
    
        # Create a shared network for A3C agent
        network = PolicyNetwork(network_config, 
                    load_model_path='./train/breakout/a3c_checkpoints_10-23-2019-02-13/model-39030506')
    
        # Create an A3C agent, use only one learner as we want to show a GUI
        agent = AgentFactory.create(A3CLearner, network, environment, 
                    num_of_epochs=1, steps_per_epoch=10000,
                    num_of_learners=1, 
                    log_dir='./test/breakout/a3c_checkpoints')
    
        # Evaluate it
        agent.evaluate()
    
    
    if __name__ == '__main__':
        evaluate_ale_environment()

  • Now we can watch the agent play the game.

  • figure 9
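
The evaluation script above hard-codes the checkpoint path model-39030506. If you would rather pick the most recent checkpoint automatically, a small helper along the following lines may be enough. This is only a sketch: it assumes checkpoint files are named model-<step>, possibly with a file extension, as the path in the script suggests. The function name latest_checkpoint and the example file names are hypothetical and not part of Fruit API.

```python
import re


def latest_checkpoint(names):
    """Return the 'model-<step>' prefix with the highest step, or None.

    Assumes checkpoint files are named 'model-<step>' with an optional
    extension; this naming is inferred from the tutorial's example path.
    """
    pattern = re.compile(r"^model-(\d+)")
    best_step = None
    for name in names:
        match = pattern.match(name)
        if match:
            step = int(match.group(1))
            if best_step is None or step > best_step:
                best_step = step
    return None if best_step is None else f"model-{best_step}"


# Hypothetical directory listing; only model-39030506 appears in the tutorial.
example_names = [
    "model-1000212.index",
    "model-39030506.index",
    "model-20014077.index",
    "checkpoint",
]
print(latest_checkpoint(example_names))
```

The returned prefix can then be joined with the checkpoint folder and passed as load_model_path.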

Code Inspection

The code is quite self-explanatory.

  • The first step is to create the game (environment). In this case, we developed a wrapper, ALEEnvironment, which is a subclass of BaseEnvironment. ALEEnvironment acts as a bridge between ALE and Fruit API through a common interface (declared by BaseEnvironment). Therefore, to integrate another environment into Fruit API, we create a subclass of BaseEnvironment and implement all functions that BaseEnvironment declares.
  • The second step is to create a configuration, AtariA3CConfig, which includes a network architecture, training operations, an objective function, and a suitable optimizer.
  • The configuration is then plugged into a PolicyNetwork, which initializes the neural network from the provided configuration.
  • Finally, the algorithm (learner) A3CLearner is created by the AgentFactory.
  • Therefore, to implement a different RL algorithm, we create a configuration and a learner and plug them into the framework via PolicyNetwork and AgentFactory.
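
To make the first step more concrete, the sketch below shows the general shape of wrapping a custom environment behind an abstract interface. It does not use Fruit API itself: SketchBaseEnvironment is a hypothetical stand-in for BaseEnvironment, and its method names (reset, step, get_state, is_terminal) are assumptions for illustration. Check the real BaseEnvironment for the functions it actually declares.

```python
from abc import ABC, abstractmethod


class SketchBaseEnvironment(ABC):
    """Hypothetical stand-in for BaseEnvironment; method names are assumed."""

    @abstractmethod
    def reset(self):
        """Put the environment back into its initial state."""

    @abstractmethod
    def step(self, action):
        """Apply an action and return the reward."""

    @abstractmethod
    def get_state(self):
        """Return the current observation."""

    @abstractmethod
    def is_terminal(self):
        """Return True when the episode is over."""


class CorridorEnvironment(SketchBaseEnvironment):
    """Toy environment: walk right along a corridor to reach the last cell."""

    def __init__(self, length=5):
        self.length = length
        self.position = 0

    def reset(self):
        self.position = 0

    def step(self, action):
        # Action 1 moves right; any other action moves left.
        move = 1 if action == 1 else -1
        # Clamp the position to the corridor bounds.
        self.position = min(max(self.position + move, 0), self.length - 1)
        return 1.0 if self.is_terminal() else 0.0

    def get_state(self):
        return self.position

    def is_terminal(self):
        return self.position == self.length - 1


env = CorridorEnvironment(length=4)
env.reset()
rewards = [env.step(1) for _ in range(3)]
print(rewards, env.is_terminal())
```

Because the learner only talks to the abstract interface, swapping the toy environment for another game would not require changes on the learner side, which is the point of the wrapper design described above.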