Commit 47c864e

[ALL] Code cleaning, fixed 0.4 PyTorch old code
1 parent 3860bbe commit 47c864e

9 files changed: +293 -290 lines changed


const.py

Lines changed: 4 additions & 1 deletion
@@ -8,7 +8,7 @@
 ## Dtype of the tensors depending on CUDA
 DEVICE = torch.device("cuda") if CUDA else torch.device("cpu")
 ## Number of self-play parallel games
-PARALLEL_SELF_PLAY = 3
+PARALLEL_SELF_PLAY = 2
 ## Number of evaluation parallel games
 PARALLEL_EVAL = 2
 ## MCTS parallel
@@ -43,6 +43,9 @@
 BATCH_SIZE_EVAL = 4
 ## Number of self-play before training
 SELF_PLAY_MATCH = 2 * PARALLEL_SELF_PLAY
+## Number of moves before changing temperature to stop
+## exploration
+TEMPERATURE_MOVE = 5
 
 
 ##### TRAINING
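
For context, the new TEMPERATURE_MOVE constant is read in lib/game.py (added later in this commit): once a game exceeds TEMPERATURE_MOVE moves, move selection switches from exploratory to competitive. A minimal, self-contained sketch of that gating, reusing the same constant value:

## Sketch of the TEMPERATURE_MOVE gating used in lib/game.py: exploration is
## allowed for the first TEMPERATURE_MOVE moves, then play turns competitive.
TEMPERATURE_MOVE = 5   # same value as in const.py

for moves in range(10):
    competitive = moves > TEMPERATURE_MOVE
    print("move %d -> competitive=%s" % (moves, competitive))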

lib/dataset.py

Lines changed: 3 additions & 7 deletions
@@ -1,9 +1,10 @@
-from torch.utils.data import Dataset, DataLoader
-from const import *
 import numpy as np
 import timeit
+from torch.utils.data import Dataset, DataLoader
+from const import *
 from . import utils
 
+
 class SelfPlayDataset(Dataset):
     """
     Self-play dataset containing state, probabilities
@@ -47,8 +48,3 @@ def update(self, game):
         winners[np.where(winners != -1)] = 1
         self.winners[:number_moves] = winners
         return number_moves
-
-
-    def update_batch(self, raw_dataset):
-        for game in raw_dataset:
-            self.update(game)
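
With update_batch removed, a caller presumably unpacks each pickled self-play result and feeds it to update() directly. A hedged sketch of such a caller; ingest_results, raw_results and the payload layout are assumptions based on Game.__call__ below, not part of this commit:

import pickle

## Hypothetical caller-side replacement for the removed update_batch helper:
## unpack each pickled self-play payload and hand it to SelfPlayDataset.update.
def ingest_results(dataset, raw_results):
    for raw in raw_results:
        game = pickle.loads(raw)
        dataset.update(game)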

lib/evaluate.py

Lines changed: 11 additions & 0 deletions
@@ -1,3 +1,4 @@
+import timeit
 from .play import play
 from const import *
 
@@ -7,7 +8,14 @@ def evaluate(player, new_player):
         the newly trained model """
 
     print("[EVALUATION] Starting to evaluate trained model !")
+    start_time = timeit.default_timer()
+    ## Play the matches and get the results
     results = play(player, opponent=new_player)
+    final_time = timeit.default_timer() - start_time
+    print("[EVALUATION] Total duration: %.3f seconds, average duration:"
+          " %0.3f seconds" % ((final_time, final_time / EVAL_MATCHS)))
+
+    ## Count the number of wins for each players
     black_wins = 0
     white_wins = 0
     for result in results:
@@ -18,6 +26,9 @@ def evaluate(player, new_player):
 
     print("[EVALUATION] black wins: %d vs %d for white"\
                     % (black_wins, white_wins))
+
+    ## Check if the trained player (black) is better than
+    ## the current best player depending on the threshold
     if black_wins >= EVAL_THRESH * len(results):
        return True
    return False
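
The added comments describe the promotion rule: the trained model, playing black, replaces the current best player only if it wins at least EVAL_THRESH of the evaluation matches. A small worked example of that arithmetic; the EVAL_THRESH value and match count are assumed here, the real ones are defined in const.py and not shown in this commit:

## Assumed values for illustration only; the real ones live in const.py.
EVAL_THRESH = 0.55
total_matches = 20
black_wins = 12

## Promote the trained model only if black won at least 55% of the games.
print(black_wins >= EVAL_THRESH * total_matches)   # True: 12 >= 11.0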

lib/game.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+import numpy as np
+import pickle
+from const import *
+from models.mcts import MCTS
+from .go import GoEnv as Board
+from .utils import _prepare_state
+
+
+class Game:
+    """ A single process that is used to play a game between 2 agents """
+
+    def __init__(self, player, id, color="black", mcts_flag=MCTS_FLAG, goban_size=GOBAN_SIZE, opponent=False):
+        self.goban_size = goban_size
+        self.id = id + 1
+        self.human_pass = False
+        self.board = self._create_board(color)
+        self.player_color = 2 if color == "black" else 1
+        self.mcts = mcts_flag
+        if mcts_flag:
+            self.mcts = MCTS()
+        self.player = player
+        self.opponent = opponent
+
+
+    def _create_board(self, color):
+        """
+        Create a board with a goban_size and the color is
+        for the starting player
+        """
+
+        board = Board(color, self.goban_size)
+        board.reset()
+        return board
+
+
+    def _swap_color(self):
+        if self.player_color == 1:
+            self.player_color = 2
+        else:
+            self.player_color = 1
+
+
+    def _get_move(self, board, probas):
+        """ Select a move without MCTS """
+
+        player_move = None
+        legal_moves = board.get_legal_moves()
+
+        while player_move not in legal_moves and len(legal_moves) > 0:
+            player_move = np.random.choice(probas.shape[0], p=probas)
+            if player_move not in legal_moves:
+                old_proba = probas[player_move]
+                probas = probas + (old_proba / (probas.shape[0] - 1))
+                probas[player_move] = 0
+
+        return player_move
+
+
+    def _play(self, state, player, other_pass, competitive=False):
+        """ Choose a move depending on MCTS or not """
+
+        if self.mcts:
+            if player.passed is True or other_pass:
+                action_scores = np.zeros((self.goban_size ** 2 + 1,))
+                action_scores[-1] = 1
+                action = self.goban_size ** 2
+            else:
+                action_scores, action = self.mcts.search(self.board, player,\
+                                                         competitive=competitive)
+
+            if action == self.goban_size ** 2:
+                player.passed = True
+
+        else:
+            feature_maps = player.extractor(state)
+            probas = player.policy_net(feature_maps)[0] \
+                           .cpu().data.numpy()
+            if player.passed is True:
+                action = self.goban_size ** 2
+            else:
+                action = self._get_move(self.board, probas)
+
+            if action == self.goban_size ** 2:
+                player.passed = True
+
+            action_scores = np.zeros((self.goban_size ** 2 + 1),)
+            action_scores[action] = 1
+
+        state, reward, done = self.board.step(action)
+        return state, reward, done, action_scores, action
+
+
+    def __call__(self):
+        """
+        Make a game between the player and the opponent and return all the states
+        and the associated move. Also returns the winner in order to create the
+        training dataset
+        """
+
+        done = False
+        state = self.board.reset()
+        dataset = []
+        moves = 0
+        comp = False
+
+        while not done:
+            ## Prevent cycling in 2 atari situations
+            if moves > MOVE_LIMIT:
+                return pickle.dumps((dataset, self.board.get_winner()))
+
+            ## Magic ratio for adaptative temperature
+            if moves > TEMPERATURE_MOVE:
+                comp = True
+
+            ## For evaluation
+            if self.opponent:
+                state, reward, done, _, action = self._play(_prepare_state(state), \
+                                        self.player, self.opponent.passed, competitive=True)
+                state, reward, done, _, action = self._play(_prepare_state(state), \
+                                        self.opponent, self.player.passed, competitive=True)
+                moves += 2
+
+            ## For self-play
+            else:
+                state = _prepare_state(state)
+                new_state, reward, done, probas, action = self._play(state, self.player, \
+                                                                     False, competitive=comp)
+                self._swap_color()
+                dataset.append((state.cpu().data.numpy(), probas, \
+                                self.player_color, action))
+                state = new_state
+                moves += 1
+
+        ## Pickle the result because multiprocessing
+        if self.opponent:
+            print("[EVALUATION] Match %d done in eval, winner %s" % (self.id, "black" if reward == 0 else "white"))
+            self.opponent.passed = False
+            return pickle.dumps([reward])
+
+        self.player.passed = False
+        return pickle.dumps((dataset, reward))
+
+
+    def solo_play(self, move=None):
+        """ Used to play against a human or for GTP, cant be called
+        in a multiprocess scenario """
+
+        ## Agent plays the first move of the game
+        if move is None:
+            state = _prepare_state(self.board.state)
+            state, reward, done, probas, move = self._play(state, self.player, self.human_pass, competitive=True)
+            self._swap_color()
+            return move
+        ## Otherwise just play a move and answer it
+        else:
+            state, reward, done = self.board.step(move)
+            if move != self.board.board_size ** 2:
+                self.mcts.advance(move)
+            else:
+                self.human_pass = True
+            self._swap_color()
+            return True
+
+
+    def reset(self):
+        state = self.board.reset()
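
A hedged usage sketch of the new Game class for one self-play game; `player` is assumed to be an agent constructed elsewhere in the project (its class is not part of this commit), and the import path assumes the repository root is on sys.path:

import pickle
from lib.game import Game

def self_play_once(player, game_id=0):
    ## No opponent is passed, so Game runs in self-play mode and returns a
    ## pickled (moves, reward) payload, as in Game.__call__ above.
    game = Game(player, game_id)
    raw = game()
    moves, reward = pickle.loads(raw)
    return moves, reward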
