This patch ignores build/ in .gitignore, adds a pyproject.toml that declares
the setuptools build backend, replaces the requires_python keyword with
python_requires in setup.py, and adds train.py, a DQN/NFSP training script.

Review note: the patch had been flattened onto three lines, so line breaks
and leading whitespace were reconstructed from syntax. Check the context
lines of the .gitignore and setup.py hunks against the target tree before
applying. The train.py payload was also revised (docstring, explicit
ValueError for an unknown algorithm, comment fixes), so the blob id on its
index line no longer matches the payload.

diff --git a/.gitignore b/.gitignore
index 8082a5b48..3c818fcb7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,5 +15,6 @@ docs/rst
 docs/sphinx
 experiments/
 dist/
+build/
 rlcard/games/doudizhu/jsondata/
 rlcard/agents/gin_rummy_human_agent/gui_cards/cards_png
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..8fe2f47af
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/setup.py b/setup.py
index 9366e8e22..92587f656 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,7 @@ def _get_version():
         'termcolor'
     ],
     extras_require=extras,
-    requires_python='>=3.7',
+    python_requires='>=3.7',
     classifiers=[
         "Programming Language :: Python :: 3.11",
         "Programming Language :: Python :: 3.10",
diff --git a/train.py b/train.py
new file mode 100644
index 000000000..3f0515980
--- /dev/null
+++ b/train.py
@@ -0,0 +1,174 @@
+import os
+import argparse
+
+import torch
+
+import rlcard
+from rlcard.agents import RandomAgent
+from rlcard.utils import (
+    get_device,
+    set_seed,
+    tournament,
+    reorganize,
+    Logger,
+    plot_curve,
+)
+
+
+def train(args):
+    """Train a DQN or NFSP agent against random opponents and save it.
+
+    Args:
+        args: Parsed command-line namespace. Reads ``env``, ``algorithm``,
+            ``seed``, ``num_episodes``, ``num_eval_games``,
+            ``evaluate_every`` and ``log_dir``.
+
+    Raises:
+        ValueError: If ``args.algorithm`` is neither 'dqn' nor 'nfsp'.
+    """
+
+    # Check whether gpu is available
+    device = get_device()
+
+    # Seed numpy, torch, random
+    set_seed(args.seed)
+
+    # Make the environment with seed
+    env = rlcard.make(
+        args.env,
+        config={
+            'seed': args.seed,
+        }
+    )
+
+    # Initialize the agent and use random agents as opponents
+    if args.algorithm == 'dqn':
+        from rlcard.agents import DQNAgent
+        agent = DQNAgent(
+            num_actions=env.num_actions,
+            state_shape=env.state_shape[0],
+            mlp_layers=[64, 64],
+            device=device,
+        )
+    elif args.algorithm == 'nfsp':
+        from rlcard.agents import NFSPAgent
+        agent = NFSPAgent(
+            num_actions=env.num_actions,
+            state_shape=env.state_shape[0],
+            hidden_layers_sizes=[64, 64],
+            q_mlp_layers=[64, 64],
+            device=device,
+        )
+    else:
+        # Fail early with a clear message instead of hitting an unbound
+        # `agent` below when train() is called without the argparse choices.
+        raise ValueError('Unsupported algorithm: {!r}'.format(args.algorithm))
+    agents = [agent]
+    for _ in range(1, env.num_players):
+        agents.append(RandomAgent(num_actions=env.num_actions))
+    env.set_agents(agents)
+
+    # Start training
+    with Logger(args.log_dir) as logger:
+        for episode in range(args.num_episodes):
+
+            if args.algorithm == 'nfsp':
+                agents[0].sample_episode_policy()
+
+            # Generate data from the environment
+            trajectories, payoffs = env.run(is_training=True)
+
+            # Reorganize the data to be state, action, reward, next_state, done
+            trajectories = reorganize(trajectories, payoffs)
+
+            # Feed transitions into agent memory, and train the agent
+            # Here, we assume that the learning agent always plays the first
+            # position and the other players play randomly (if any)
+            for ts in trajectories[0]:
+                agent.feed(ts)
+
+            # Evaluate the performance. Play with random agents.
+            if episode % args.evaluate_every == 0:
+                logger.log_performance(
+                    episode,
+                    tournament(
+                        env,
+                        args.num_eval_games,
+                    )[0]
+                )
+
+        # Get the paths
+        csv_path, fig_path = logger.csv_path, logger.fig_path
+
+    # Plot the learning curve
+    plot_curve(csv_path, fig_path, args.algorithm)
+
+    # Save model
+    save_path = os.path.join(args.log_dir, 'model.pth')
+    torch.save(agent, save_path)
+    print('Model saved in', save_path)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
+    parser.add_argument(
+        '--env',
+        type=str,
+        default='leduc-holdem',
+        choices=[
+            'blackjack',
+            'leduc-holdem',
+            'limit-holdem',
+            'doudizhu',
+            'mahjong',
+            'no-limit-holdem',
+            'uno',
+            'gin-rummy',
+            'bridge',
+        ],
+    )
+    parser.add_argument(
+        '--algorithm',
+        type=str,
+        default='dqn',
+        choices=[
+            'dqn',
+            'nfsp',
+        ],
+    )
+    parser.add_argument(
+        '--cuda',
+        type=str,
+        default='',
+    )
+    parser.add_argument(
+        '--seed',
+        type=int,
+        default=42,
+    )
+    parser.add_argument(
+        '--num_episodes',
+        type=int,
+        default=5000,
+    )
+    parser.add_argument(
+        '--num_eval_games',
+        type=int,
+        default=2000,
+    )
+    parser.add_argument(
+        '--evaluate_every',
+        type=int,
+        default=100,
+    )
+    parser.add_argument(
+        '--log_dir',
+        type=str,
+        default='experiments/leduc_holdem_dqn_result/',
+    )
+
+    args = parser.parse_args()
+
+    # Set the visible GPUs before train() calls get_device()
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
+    train(args)