
Commit 4d27a54

[ALL] MCTS done, fixing dihedral rotation of the board during training; project should be close to done
1 parent ec66e85 commit 4d27a54

16 files changed, 359 additions and 212 deletions

.gitignore
Lines changed: 0 additions & 1 deletion

@@ -2,7 +2,6 @@ __pycache__
 data/*
 Sabaki/
 *.pyc
-test.py
 saved_models/
 *.py.lprof
 pytorch/

README.md
Lines changed: 9 additions & 8 deletions

@@ -6,8 +6,9 @@ Ongoing project.
 
 # TODO (in order of priority)
 
-* Dihedral group of board for more training samples
-* Sample random rotation or reflection in dihedral group during MCTS
+* Optimization ?
+* MCTS
+* Multithreading of search (can't multiprocess because of virtual loss, but useless in Python) ?
 * File of constants that match the paper constants
 * OGS / KGS API
 * Better Komi ?
@@ -16,14 +17,15 @@ Ongoing project.
 # CURRENTLY DOING
 
 * Brainlag on loss: cross entropy or KLDiv (cross entropy - entropy) ??
-* MCTS
-* Tree search
-* Adaptive temperature (close to 1 during the first 30 moves of self-play, close to 0 after and during evaluation)
-* Dirichlet noise to prior probabilities in the root node
-* Multithreading of search (can't multiprocess because of virtual loss, but useless in Python) ?
+* Dihedral group of board for more training samples
+* Sample random rotation or reflection in dihedral group during MCTS
 
 # DONE
 
+* MCTS
+* Tree search
+* Dirichlet noise to prior probabilities in the root node
+* Adaptive temperature (either take max or proportionally)
 * Learning without MCTS doesn't seem to work
 * Resume training
 * GTP on trained models (human.py, to plug with Sabaki)
@@ -41,7 +43,6 @@ Ongoing project.
 
 * Compile my own version of Sabaki to watch games automatically while training
 * Statistics
-* Optimization ?
 * Tromp Taylor scoring ?
 * Resignation ?
 * Training on a big computer / server once everything is ready ?
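Note on the "Adaptive temperature (either take max or proportionally)" item now listed under DONE: in AlphaGo Zero-style self-play the move is drawn from the MCTS visit counts, with a temperature close to 1 for the early moves (proportional sampling) and close to 0 afterwards and during evaluation (argmax). A minimal sketch of that selection rule, with illustrative names (select_action, visit_counts) that are not taken from this repository:

import numpy as np

def select_action(visit_counts, temperature=1.0):
    """Pick a move index from MCTS visit counts.

    temperature ~ 1 samples proportionally to the visit counts
    (exploratory, early self-play moves); temperature -> 0 degenerates
    to the argmax (competitive play and evaluation).
    """
    counts = np.asarray(visit_counts, dtype=np.float64)
    if temperature < 1e-3:
        return int(np.argmax(counts))            # greedy: most visited move
    probas = counts ** (1.0 / temperature)
    probas /= probas.sum()
    return int(np.random.choice(len(counts), p=probas))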

const.py
Lines changed: 12 additions & 14 deletions

@@ -9,7 +9,7 @@
 DTYPE_FLOAT = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
 DTYPE_LONG = torch.cuda.LongTensor if CUDA else torch.LongTensor
 ## Number of self-play parallel games
-PARRALEL_SELF_PLAY = 1
+PARRALEL_SELF_PLAY = 3
 ## Number of evaluation parralel games
 PARRALEL_EVAL = 2
 ## MCTS parallel
@@ -21,22 +21,20 @@
 
 ## Size of the Go board
 GOBAN_SIZE = 9
+## Number of move to end a game
+MOVE_LIMIT = GOBAN_SIZE ** 2 * 2.2
 ## Number of last states to keep
 HISTORY = 7
 ## Learning rate
 LR = 0.01
-## Number of epochs
-EPOCHS = 100
 ## Number of MCTS simulation
-MCTS_SIM = 200
-## Temperature
-TEMP = 2
+MCTS_SIM = 5
 ## Exploration constant
 C_PUCT = 0.2
 ## L2 Regularization
-L2_REG = 0.0001
+L2_REG = 0.001
 ## Momentum
-MOMENTUM = 0.92
+MOMENTUM = 0.9
 ## Activate MCTS
 MCTS_FLAG = True
 ## Epsilon for Dirichlet noise
@@ -50,27 +48,27 @@
 ##### SELF-PLAY
 
 ## Number of self-play before training
-SELF_PLAY_MATCH = 40
+SELF_PLAY_MATCH = 2 * PARRALEL_SELF_PLAY
 
 #####
 
 
 ##### TRAINING
 
 ## Number of moves to consider when creating the batch
-MOVES = 10000
+MOVES = 2000
 ## Number of mini-batch before evaluation during training
-BATCH_SIZE = 64
+BATCH_SIZE = 32
 ## Number of channels of the output feature maps
 OUTPLANES_MAP = 10
 ## Shape of the input state
 INPLANES = (HISTORY + 1) * 2 + 1
 ## Probabilities for all moves + pass
 OUTPLANES = (GOBAN_SIZE ** 2) + 1
 ## Number of residual blocks
-BLOCKS = 10
+BLOCKS = 5
 ## Number of training step before evaluating
-TRAIN_STEPS = 200
+TRAIN_STEPS = 400
 ## Optimizer
 ADAM = False
 ## Learning rate annealing factor
@@ -85,7 +83,7 @@
 
 ## Number of matches against its old version to evaluate
 ## the newly trained network
-EVAL_MATCHS = 50
+EVAL_MATCHS = 20
 ## Threshold to keep the new neural net
 EVAL_THRESH = 0.53
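The shape constants above presumably follow the AlphaGo Zero input encoding: with the committed defaults (GOBAN_SIZE = 9, HISTORY = 7) the network sees 17 input planes (8 past board states per colour plus one colour-to-play plane) and outputs 82 move probabilities (81 intersections plus pass). A quick check of that arithmetic, using the same expressions as const.py:

GOBAN_SIZE = 9
HISTORY = 7

INPLANES = (HISTORY + 1) * 2 + 1    # 8 states per colour + 1 colour plane = 17
OUTPLANES = (GOBAN_SIZE ** 2) + 1   # 81 intersections + 1 pass move = 82
MOVE_LIMIT = GOBAN_SIZE ** 2 * 2.2  # games are aborted after ~178 moves

print(INPLANES, OUTPLANES, MOVE_LIMIT)   # 17 82 178.2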

lib/dataset.py
Lines changed: 5 additions & 1 deletion

@@ -2,6 +2,7 @@
 from const import *
 import numpy as np
 import timeit
+from . import utils
 
 class SelfPlayDataset(Dataset):
     """
@@ -23,8 +24,11 @@ def __len__(self):
 
 
     def __getitem__(self, idx):
-        return self.states[idx], self.plays[idx], \
+
+        return utils.sample_rotation(self.states[idx]), self.plays[idx], \
               self.winners[idx]
+        # return self.states[idx], self.plays[idx], \
+        #        self.winners[idx]
 
 
     def update(self, game):
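The new utils.sample_rotation helper is not part of this diff. Assuming it applies a random element of the dihedral group of the square (4 rotations, optionally mirrored, i.e. 8 symmetries) to the stacked feature planes, it could look roughly like the sketch below; names and shapes are assumptions, not code from the repository:

import numpy as np

def sample_rotation(state):
    """Apply one random board symmetry (rotation and/or reflection) to a
    (planes, GOBAN_SIZE, GOBAN_SIZE) state; every plane gets the same
    transformation so the history planes stay consistent."""
    k = np.random.randint(4)        # number of 90-degree rotations
    flip = np.random.randint(2)     # whether to also mirror the board
    new_state = np.rot90(state, k, axes=(1, 2))
    if flip:
        new_state = np.flip(new_state, axis=2)
    return np.ascontiguousarray(new_state)

Note that transforming only the state is not enough on its own: the training target self.plays[idx] (the move probabilities) would need to be permuted with the same symmetry to stay aligned with the board, which may be what the "fixing dihedral rotation of the board during training" part of the commit message refers to.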

lib/go.py
Lines changed: 4 additions & 1 deletion

@@ -90,7 +90,7 @@ def get_legal_moves(self):
 
         for pachi_move in legal_moves:
             move = _coord_to_action(self.board, pachi_move)
-            if self.test_move(move):
+            if move != 81 or self.test_move(move):
                final_moves.append(move)
 
         if len(final_moves) == 0:
@@ -172,6 +172,9 @@ def step(self, action):
 
 
     def __deepcopy__(self, memo):
+        """ Used to overwrite the deepcopy implicit method since
+            the board cannot be deepcopied """
+
         cls = self.__class__
         result = cls.__new__(cls)
         memo[id(self)] = result
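The rest of the __deepcopy__ override lies outside this hunk. The usual pattern for such an override, sketched below with an assumed clone() call on the wrapped pachi board (not confirmed by this diff), is to copy every attribute normally and special-case the one object that deepcopy cannot handle:

import copy

def __deepcopy__(self, memo):
    """Copy all attributes, special-casing the board object that
    deepcopy cannot handle (illustrative sketch, not the repo's code)."""
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for key, value in self.__dict__.items():
        if key == "board":
            setattr(result, key, self.board.clone())   # assumed board-cloning API
        else:
            setattr(result, key, copy.deepcopy(value, memo))
    return result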

lib/gtp.py
Lines changed: 5 additions & 5 deletions

@@ -129,11 +129,11 @@ def __init__(self, game, komi=7.5, board_size=19, version="0.2", name="AlphaGo")
     def send(self, message):
         message_id, command, arguments = parse_message(message)
         if command in self.known_commands:
-            try:
-                return format_success(
-                    message_id, getattr(self, "cmd_" + command)(arguments))
-            except ValueError as exception:
-                return format_error(message_id, exception.args[0])
+            # try:
+            return format_success(
+                message_id, getattr(self, "cmd_" + command)(arguments))
+            # except ValueError as exception:
+            #     return format_error(message_id, exception.args[0])
         else:
             return format_error(message_id, "unknown command")

lib/play.py
Lines changed: 51 additions & 36 deletions

@@ -46,7 +46,7 @@ def create_matches(player, opponent=None, cores=1, match_number=10):
 
 
 
-def self_play(current_time, ite):
+def self_play(current_time, loaded_version):
     """
     Used to create a learning dataset for the value and policy network.
     Play against itself and backtrack the winner to maximize winner moves
@@ -56,21 +56,25 @@ def self_play(current_time, ite):
     client = MongoClient()
     collection = client.superGo[current_time]
     game_id = 0
-    improvements = 1
+    current_version = 1
     player = False
 
     while True:
 
         ## Load the player when restarting traning
-        if ite:
-            new_player, improvements = load_player(current_time, ite)
+        if loaded_version:
+            new_player, checkpoint = load_player(current_time,
+                                        loaded_version)
             game_id = collection.find().count()
-            ite = False
+            current_version = checkpoint['version'] + 1
+            loaded_version = False
         else:
-            new_player, improvements = get_player(current_time, improvements)
+            new_player, checkpoint = get_player(current_time, current_version)
+            if new_player:
+                current_version = checkpoint['version'] + 1
 
-        print("[PLAY] Current improvement level: %d" % improvements)
-        if improvements == 1 and not player and not new_player:
+        print("[PLAY] Current improvement level: %d" % current_version)
+        if current_version == 1 and not player and not new_player:
             print("[PLAY] Waiting for first player")
             time.sleep(5)
             continue
@@ -92,7 +96,6 @@ def self_play(current_time, ite):
            game_id += 1
         print("[PLAY] Done fetching")
         queue.close()
-        time.sleep(15)
 
 
 def play(player, opponent):
@@ -152,17 +155,11 @@ def __init__(self, player, id, color="black", mcts_flag=MCTS_FLAG, goban_size=GO
         self.id = id + 1
         self.board = self._create_board(color)
         self.player_color = 2 if color == "black" else 1
+        self.mcts = mcts_flag
         if mcts_flag:
-            if opponent:
-                self.player = (player, MCTS(player, competitive=True))
-                self.opponent = (opponent, MCTS(player, competitive=True))
-            else:
-                self.player = (player, MCTS(player))
-                self.opponent = False
-        else:
-            self.player = (player, False)
-            self.opponent = False
-
+            self.mcts = MCTS()
+        self.player = player
+        self.opponent = opponent
 
     def _create_board(self, color):
         """
@@ -200,28 +197,39 @@ def _get_move(self, board, probas):
         return player_move
 
 
-    def _play(self, state, player):
+    def _play(self, state, player, other_pass, competitive=False):
         """ Choose a move depending on MCTS or not """
 
-        if player[1]:
-            action_scores, action = player[1].search(self.board)
+        if self.mcts:
+            if player.passed is True or other_pass:
+                action_scores = np.zeros((self.goban_size ** 2 + 1,))
+                action_scores[-1] = 1
+                action = self.goban_size ** 2
+            else:
+                action_scores, action = self.mcts.search(self.board, player,\
+                                            competitive=competitive)
+
+            if action == self.goban_size ** 2:
+                player.passed = True
+
         else:
             feature_maps = player.extractor(state)
             probas = player.policy_net(feature_maps)[0] \
                                 .cpu().data.numpy()
             if player.passed is True:
-                player_move = self.goban_size ** 2
+                action = self.goban_size ** 2
             else:
-                player_move = self._get_move(self.board, probas)
+                action = self._get_move(self.board, probas)
 
-            if player_move == self.goban_size ** 2:
+            if action == self.goban_size ** 2:
                player.passed = True
 
             action_scores = np.zeros((self.goban_size ** 2 + 1),)
-            action_scores[player_move] = 1
+            action_scores[action] = 1
 
-        state, reward, done = self.board.step(player_move)
-        return state, reward, done, action_scores
+        state, reward, done = self.board.step(action)
+        self.board.render()
+        return state, reward, done, action_scores, action
 
 
     def __call__(self):
@@ -235,23 +243,28 @@ def __call__(self):
         state = self.board.reset()
         dataset = []
         moves = 0
+        comp = False
 
         while not done:
             ## Prevent cycling in 2 atari situations
-            if moves > 60 * self.goban_size:
+            if moves > MOVE_LIMIT:
                 return False
+
+            if moves > MOVE_LIMIT / 24:
+                comp = True
 
             ## For evaluation
             if self.opponent:
-                state, reward, done, _ = self._play(_prepare_state(state), \
-                                            self.player)
-                state, reward, done, _ = self._play(_prepare_state(state), \
-                                            self.opponent)
+                state, reward, done, _, action = self._play(_prepare_state(state), \
+                                            self.player, self.opponent.passed, competitive=True)
+                state, reward, done, _, action = self._play(_prepare_state(state), \
+                                            self.opponent, self.player.passed, competitive=True)
                 moves += 2
+
             ## For self-play
             else:
                 state = _prepare_state(state)
-                new_state, reward, done, probas = self._play(state, self.player)
+                new_state, reward, done, probas, _ = self._play(state, self.player, False, competitive=comp)
                 self._swap_color()
                 dataset.append((state.cpu().data.numpy(), probas, \
                                 self.player_color))
@@ -260,7 +273,9 @@ def __call__(self):
 
         ## Pickle the result because multiprocessing
         if self.opponent:
+            self.opponent.passed = False
            return pickle.dumps([reward])
+        self.player.passed = False
         return pickle.dumps((dataset, reward))
 
 
@@ -271,9 +286,9 @@ def solo_play(self, move=None):
         ## Agent plays the first move of the game
         if move is None:
            state = _prepare_state(self.board.state)
-            state, reward, done, player_move = self._play(state, self.player)
+            state, reward, done, probas, move = self._play(state, self.player)
             self._swap_color()
-            return player_move
+            return move
         ## Otherwise just play a move and answer it
         else:
            state, reward, done = self.board.step(move)
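GameState._play above now delegates to a single MCTS instance via self.mcts.search(self.board, player, competitive=competitive). The MCTS class itself lives elsewhere in the project and is not part of this diff; the stub below only illustrates the interface the new code relies on (the shapes are assumptions and the visit counts are faked):

import numpy as np

GOBAN_SIZE = 9
N_ACTIONS = GOBAN_SIZE ** 2 + 1   # 81 intersections + pass


class MCTS:
    """Interface sketch of the tree search used by GameState._play."""

    def search(self, board, player, competitive=False):
        # The real implementation would run MCTS_SIM simulations guided by
        # the player's policy/value network; uniform counts stand in here.
        visit_counts = np.ones(N_ACTIONS)
        action_scores = visit_counts / visit_counts.sum()   # proba over all moves + pass
        if competitive:
            action = int(np.argmax(action_scores))          # greedy (evaluation, late moves)
        else:
            action = int(np.random.choice(N_ACTIONS, p=action_scores))
        return action_scores, action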
