Commit d298ed0

[ALL] Loss function brainlag, cleaning for MCTS, started dihedral rotation sampler
1 parent 70a02f8

8 files changed, +88 / -56 lines


README.md

Lines changed: 4 additions & 2 deletions
```diff
@@ -6,19 +6,21 @@ Ongoing project.
 
 # TODO (in order of priority)
 
+* Dihedral group of board for more training samples
+* Sample random rotation or reflection in dihedral group during MCTS
 * File of constants that match the paper constants
 * OGS / KGS API
 * Better Komi ?
 * Use logging instead of prints ?
 
 # CURRENTLY DOING
 
+* Brainlag on loss : cross entropy or KLDiv (crossentropy - entropy) ??
 * MCTS
 * Tree search
-* Rotation of board for more training samples
 * Adaptative temperature (close to 1 during the first 30 moves of self-play, close to 0 after and during evaluation)
 * Dirichlet noise to prior probabilities in the rootnode
-* Multiprocessing of search
+* Multiprocessing of search ?
 
 # DONE
 
```

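Note on the "cross entropy or KLDiv" question in the TODO: for a fixed target distribution p (e.g. normalized MCTS visit counts), KL(p ‖ q) = CE(p, q) − H(p), and H(p) does not depend on the network output q, so minimizing either objective gives the same gradients. A quick NumPy check of that identity (not part of this commit; the distributions are made up for illustration):

```python
import numpy as np

# Hypothetical target distribution p and network output q; any two
# valid probability distributions will do.
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.5, 0.3, 0.2])

cross_entropy = -np.sum(p * np.log(q))     # CE(p, q)
entropy       = -np.sum(p * np.log(p))     # H(p)
kl_divergence = np.sum(p * np.log(p / q))  # KL(p || q)

# KL(p || q) == CE(p, q) - H(p): since H(p) is constant for a fixed
# target, minimizing CE and minimizing KL w.r.t. q are equivalent.
assert np.isclose(kl_divergence, cross_entropy - entropy)
```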
const.py

Lines changed: 6 additions & 8 deletions
```diff
@@ -8,9 +8,7 @@
 ## Dtype of the tensors depending on CUDA
 DTYPE_FLOAT = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
 DTYPE_LONG = torch.cuda.LongTensor if CUDA else torch.LongTensor
-## Number of process, used for parallel matching atm
 ## Number of self-play parallel games
-# PARRALEL_SELF_PLAY = multiprocessing.cpu_count() - 2
 PARRALEL_SELF_PLAY = 1
 ## Number of evaluation parralel games
 PARRALEL_EVAL = 2
@@ -40,7 +38,9 @@
 ## Momentum
 MOMENTUM = 0.92
 ## Activate MCTS
-MCTS_FLAG = False
+MCTS_FLAG = True
+## Alpha for Dirichlet noise
+EPS = 0.25
 
 #####
 
@@ -49,8 +49,6 @@
 
 ## Number of self-play before training
 SELF_PLAY_MATCH = 40
-## Number of matches to run per process
-NUM_MATCHES = SELF_PLAY_MATCH // PARRALEL_SELF_PLAY
 
 #####
 
@@ -60,7 +58,7 @@
 ## Number of moves to consider when creating the batch
 MOVES = 10000
 ## Number of mini-batch before evaluation during training
-BATCH_SIZE = 128
+BATCH_SIZE = 64
 ## Number of channels of the output feature maps
 OUTPLANES_MAP = 10
 ## Shape of the input state
@@ -72,10 +70,10 @@
 ## Number of training step before evaluating
 TRAIN_STEPS = 200
 ## Optimizer
-ADAM = True
+ADAM = False
 ## Learning rate annealing factor
 LR_DECAY = 0.1
-## Learning rate anmnealing interval
+## Learning rate annnealing interval
 LR_DECAY_ITE = 50 * TRAIN_STEPS
 
 #####
```

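Note: the new EPS = 0.25 matches the mixing weight ε used for root Dirichlet noise in the AlphaGo Zero paper, where the root priors become P(s,a) = (1 − ε)·p_a + ε·η_a with η ~ Dir(α) (α = 0.03 for 19×19 Go); the constant actually added here is the mixing weight, not the Dirichlet α. A possible sketch of that mixing step, assuming the root node exposes a prior vector (ALPHA and the function name are assumptions, not code from this commit):

```python
import numpy as np

EPS = 0.25    # mixing weight, as in const.py
ALPHA = 0.03  # assumed Dirichlet concentration; not defined in this commit

def add_dirichlet_noise(priors, eps=EPS, alpha=ALPHA):
    """Mix Dirichlet noise into the root prior probabilities:
    P = (1 - eps) * p + eps * eta, with eta ~ Dir(alpha)."""
    noise = np.random.dirichlet([alpha] * len(priors))
    return (1 - eps) * np.asarray(priors) + eps * noise
```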
lib/dataset.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -14,7 +14,7 @@ def __init__(self, mcts_flag=MCTS_FLAG):
 
         self.states = np.zeros((MOVES, (HISTORY + 1) * 2 + 1, GOBAN_SIZE, GOBAN_SIZE))
         self.mcts_flag = mcts_flag
-        if not mcts_flag:
+        if mcts_flag:
             self.plays = np.zeros((MOVES, GOBAN_SIZE ** 2 + 1))
         else:
             self.plays = np.zeros(MOVES)
@@ -42,7 +42,7 @@ def update(self, game):
         self.states[:number_moves] = np.vstack(dataset[:,0])
 
         self.plays = np.roll(self.plays, number_moves, axis=0)
-        if not self.mcts_flag:
+        if self.mcts_flag:
             self.plays[np.arange(number_moves),np.hstack(dataset[:,1])] = 1
         else:
             self.plays[:number_moves] = np.hstack(dataset[:,1])
@@ -54,6 +54,7 @@ def update(self, game):
         self.winners[:number_moves] = winners
         return number_moves
 
+
     def update_batch(self, raw_dataset):
         for game in raw_dataset:
             self.update(game)
```

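With the inverted conditions fixed, MCTS mode stores a full one-hot vector over the GOBAN_SIZE² + 1 actions (every intersection plus pass) per position, while the non-MCTS mode keeps only the integer index of the played move. A toy illustration of the two layouts, assuming a hypothetical GOBAN_SIZE of 9 (not part of the commit):

```python
import numpy as np

GOBAN_SIZE = 9                    # hypothetical value for illustration
N_ACTIONS = GOBAN_SIZE ** 2 + 1   # every intersection plus the pass move

# MCTS mode: a probability/one-hot vector over actions per position.
plays_mcts = np.zeros((3, N_ACTIONS))
plays_mcts[np.arange(3), [0, 41, 81]] = 1  # one-hot targets for 3 moves

# Non-MCTS mode: just the index of the move that was played.
plays_plain = np.array([0, 41, 81])
```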
lib/play.py

Lines changed: 13 additions & 31 deletions
```diff
@@ -184,24 +184,6 @@ def _swap_color(self):
         else:
             self.player_color = 1
 
-
-    def _draw_move(self, action_scores, competitive=False):
-        """
-        Find the best move, either deterministically for competitive play
-        or stochiasticly according to some temperature constant
-        """
-
-        if competitive:
-            move = np.argmax(action_scores)
-
-        else:
-            action_scores = np.power(action_scores, (1. / TEMP))
-            total = np.sum(action_scores)
-            probas = action_scores / total
-            move = np.random.choice(action_scores.shape[0], p=probas)
-
-        return move
-
 
     def _get_move(self, board, probas):
         """ Select a move without MCTS """
@@ -235,22 +217,22 @@ def _get_move(self, board, probas):
     def _play(self, state, player):
         """ Choose a move depending on MCTS or not """
 
-        if not self.mcts_flag:
-            action_scores = player.mcts.search()
+        # if self.mcts_flag:
+        #     action_scores = player.mcts.search()
+        # else:
+        feature_maps = player.extractor(state)
+        probas = player.policy_net(feature_maps)[0] \
+                        .cpu().data.numpy()
+        if player.passed is True:
+            player_move = self.goban_size ** 2
         else:
-            feature_maps = player.extractor(state)
-            probas = player.policy_net(feature_maps)[0] \
-                            .cpu().data.numpy()
-            if player.passed is True:
-                player_move = self.goban_size ** 2
-            else:
-                player_move = self._get_move(self.board, probas)
+            player_move = self._get_move(self.board, probas)
 
-            if player_move == self.goban_size ** 2:
-                player.passed = True
+        if player_move == self.goban_size ** 2:
+            player.passed = True
 
-            state, reward, done = self.board.step(player_move)
-            return state, reward, done, player_move
+        state, reward, done = self.board.step(player_move)
+        return state, reward, done, player_move
 
 
     def __call__(self):
```

lib/train.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 import numpy as np
 import pickle
 import time
@@ -20,16 +21,18 @@
 class AlphaLoss(torch.nn.Module):
     """ Custom loss as defined in the paper """
 
-    def __init__(self, mcts_flag=MCTS_FLAG):
-        self.mcts_flag = mcts_flag
+    def __init__(self):
         super(AlphaLoss, self).__init__()
+        self.log_softmax = nn.LogSoftmax()
 
+    # def forward(self, winner, self_play_winner, probas, self_play_probas):
+    #     value_error = F.mse_loss(winner, self_play_winner)
+    #     policy_error = torch.mean(torch.sum(-self_play_probas * self.log_softmax(probas), 1))
+    #     return value_error + policy_error
+
     def forward(self, winner, self_play_winner, probas, self_play_probas):
         value_error = F.mse_loss(winner, self_play_winner)
-        if not self.mcts_flag:
-            policy_error = F.binary_cross_entropy(probas, self_play_probas)
-        else:
-            policy_error = F.cross_entropy(probas, self_play_probas)
+        policy_error = F.kl_div(probas, self_play_probas)
         return value_error + policy_error
 
 
@@ -167,11 +170,11 @@ def new_agent(result):
     except KeyboardInterrupt:
         client.close()
         pool.terminate()
-
+
     example = {
         'state': Variable(state).type(DTYPE_FLOAT),
         'winner': Variable(winner).type(DTYPE_FLOAT),
-        'move' : Variable(move).type(DTYPE_FLOAT if not MCTS_FLAG else DTYPE_LONG)
+        'move' : Variable(move).type(DTYPE_FLOAT if MCTS_FLAG else DTYPE_LONG)
     }
     loss = train_epoch(new_player, optimizer, example, criterion)
 
```

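One caveat about the new policy term: PyTorch's `F.kl_div` expects its first argument to already be log-probabilities, so passing raw probabilities (or logits) computes a different quantity. Since the self-play target distribution is fixed, cross-entropy and KL differ only by the constant entropy of the target, which is why the commented-out log-softmax variant above works. A sketch of that variant made explicit, assuming `probas` are raw policy logits and `self_play_probas` are normalized visit counts, both of shape [batch, actions] (a sketch under those assumptions, not the author's final loss):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlphaLossSketch(nn.Module):
    """Sketch of the AlphaGo Zero loss: (z - v)^2 - pi^T log p."""

    def __init__(self):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, winner, self_play_winner, probas, self_play_probas):
        value_error = F.mse_loss(winner, self_play_winner)
        # Cross-entropy against a soft target; equal to KL(target || pred)
        # up to the constant entropy of the fixed target distribution.
        policy_error = torch.mean(
            torch.sum(-self_play_probas * self.log_softmax(probas), dim=1))
        return value_error + policy_error
```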
lib/utils.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -1,5 +1,8 @@
 import os
 from models.agent import Player
+import numpy as np
+from const import *
+import random
 
 
 def get_ite(folder_path, ite):
@@ -62,3 +65,25 @@ def get_player(current_time, improvements):
     player = Player()
     player.load_models(path, models)
     return player, improvements + 1
+
+
+def sample_rotation(state, num=8):
+    dh_group = [(0, 0) (np.rot90, 1), (np.rot90, 2), (np.rot90, 3),
+                (np.fliplr, 0), (np.flipud, 0), (np.flipud, (np.rot90, 1)), (np.fliplr, (np.rot90, 1))]
+
+    dh_group = random.shuffle(dh_group)
+    states = []
+    for i in num:
+        print(i)
+        assert 0
+
+    return state
+
+
+if __name__ == "__main__":
+    pass
+
+
+
+
+
```

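The commit message says the dihedral rotation sampler is only started, and the version above is still placeholder code (missing comma in `dh_group`, `random.shuffle` returns None, `for i in num`, `assert 0`). A possible complete version, assuming `state` is a (planes, size, size) NumPy array and the transforms act on the two board axes (a sketch under those assumptions, not the author's final implementation):

```python
import random
import numpy as np

def sample_rotation(state, num=8):
    """Return `num` boards sampled without replacement from the 8 dihedral
    transforms of `state`: identity, 3 rotations, 2 flips, 2 flipped rotations.
    Assumes `state` has shape (planes, size, size); transforms act on axes (1, 2)."""
    transforms = [
        lambda x: x,
        lambda x: np.rot90(x, 1, axes=(1, 2)),
        lambda x: np.rot90(x, 2, axes=(1, 2)),
        lambda x: np.rot90(x, 3, axes=(1, 2)),
        lambda x: np.flip(x, axis=2),                            # left-right flip
        lambda x: np.flip(x, axis=1),                            # up-down flip
        lambda x: np.flip(np.rot90(x, 1, axes=(1, 2)), axis=1),  # diagonal reflection
        lambda x: np.flip(np.rot90(x, 1, axes=(1, 2)), axis=2),  # anti-diagonal reflection
    ]
    # random.shuffle works in place and returns None; sample instead.
    chosen = random.sample(transforms, num)
    return np.stack([t(state) for t in chosen])
```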
main.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -23,8 +23,8 @@ def main(folder, ite):
     try:
         x = pool.apply_async(self_play, args=(current_time, ite,))
         y = pool.apply_async(train, args=(current_time, ite,))
-        x.get()
-        # y.get()
+        # x.get()
+        y.get()
     except KeyboardInterrupt:
         pool.terminate()
     else:
```

models/mcts.py

Lines changed: 24 additions & 3 deletions
```diff
@@ -8,6 +8,9 @@ def __init__(self, move, probas):
         self.n = 0
         self.w = 0
         self.q = 0
+
+    def expand(self):
+        pass
 
 
 class MCTS():
@@ -19,6 +22,24 @@ def __init__(self, c_puct, extractor, value_net, policy_net):
         self.c_puct = c_puct
 
 
+    def _draw_move(self, action_scores, competitive=False):
+        """
+        Find the best move, either deterministically for competitive play
+        or stochiasticly according to some temperature constant
+        """
+
+        if competitive:
+            move = np.argmax(action_scores)
+
+        else:
+            action_scores = np.power(action_scores, (1. / TEMP))
+            total = np.sum(action_scores)
+            probas = action_scores / total
+            move = np.random.choice(action_scores.shape[0], p=probas)
+
+        return move
+
+
     def _puct(self, proba, total_count, count):
         """
         Function of P and N that increases if an action hasn't been explored
@@ -30,7 +51,7 @@ def _puct(self, proba, total_count, count):
         return action_score
 
 
-    def select(self, nodes):
+    def _select(self, nodes):
         """
         Select the move that maximises the mean value of the next state +
         the result of the PUCT function
@@ -45,8 +66,8 @@ def select(self, nodes):
 
         return max(action_scores)
 
-    def search(self, game):
-        x = random.choice(actions)
+
+    def search(self, game, competitive=False):
         return x
 
 
```

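`search()` is still a stub (it returns an undefined `x`) and `expand()` is empty. For reference, a minimal sketch of how the selection / expansion / backup loop usually fits together with the pieces above (PUCT selection, temperature-based move drawing over visit counts); the node interface and function names are assumptions, not the author's design, and terminal-state handling and root Dirichlet noise are omitted:

```python
import numpy as np

def mcts_search(root, evaluate, n_simulations=160, c_puct=1.0,
                competitive=False, temp=1.0):
    """Hypothetical search loop over nodes carrying p (prior), n, w, q,
    children, and a cached board state. `evaluate(state)` is assumed to
    return (prior_probas, value) from the nets, with value given from the
    perspective of the player to move at that state."""
    for _ in range(n_simulations):
        node, path = root, [root]
        # Selection: descend with PUCT, Q + c_puct * P * sqrt(sum_b N_b) / (1 + N).
        while node.children:
            total = sum(child.n for child in node.children)
            node = max(node.children,
                       key=lambda c: c.q + c_puct * c.p * np.sqrt(total) / (1 + c.n))
            path.append(node)
        # Expansion + evaluation of the leaf with the networks.
        probas, value = evaluate(node.state)
        node.expand(probas)
        # Backup: each node stores the value for the player who moved into it,
        # so flip the sign at every ply on the way back up (root has no action).
        for visited in reversed(path[1:]):
            value = -value
            visited.n += 1
            visited.w += value
            visited.q = visited.w / visited.n
    # Play proportionally to N ** (1 / temp), or greedily when competitive.
    visits = np.array([child.n for child in root.children], dtype=np.float64)
    if competitive:
        return int(np.argmax(visits))
    pi = visits ** (1.0 / temp)
    pi /= pi.sum()
    return int(np.random.choice(len(visits), p=pi))
```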