Skip to content
This repository was archived by the owner on Dec 11, 2020. It is now read-only.

Commit d5f320d

Browse files
committed
Big refactor from FAIR team
This commit includes a *ton* of refactoring goodness and various stability fixes.
1 parent 1b6859f commit d5f320d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+2699
-1673
lines changed

.gitmodules

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
[submodule "third_party/concurrentqueue"]
22
path = third_party/concurrentqueue
33
url = https://github.com/cameron314/concurrentqueue.git
4-
[submodule "third_party/cppzmq"]
5-
path = third_party/cppzmq
6-
url = https://github.com/zeromq/cppzmq.git
7-
[submodule "third_party/googletest"]
8-
path = third_party/googletest
9-
url = https://github.com/google/googletest.git
10-
[submodule "third_party/json"]
11-
path = third_party/json
12-
url = https://github.com/nlohmann/json.git
134
[submodule "third_party/pybind11"]
145
path = third_party/pybind11
156
url = https://github.com/pybind/pybind11.git
167
[submodule "third_party/spdlog"]
178
path = third_party/spdlog
189
url = https://github.com/gabime/spdlog.git
10+
[submodule "third_party/json"]
11+
path = third_party/json
12+
url = https://github.com/nlohmann/json.git
13+
[submodule "third_party/googletest"]
14+
path = third_party/googletest
15+
url = https://github.com/google/googletest.git
1916
[submodule "third_party/tbb"]
2017
path = third_party/tbb
2118
url = https://github.com/01org/tbb.git
19+
ignore = untracked
20+
[submodule "third_party/cppzmq"]
21+
path = third_party/cppzmq
22+
url = https://github.com/zeromq/cppzmq.git

scripts/elfgames/go/analysis.sh

Lines changed: 0 additions & 22 deletions
This file was deleted.

src_cpp/elf/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ set(ELF_SOURCES
1414
options/Pybind.cc
1515
)
1616

17-
set(ELF_TEST_SOURCES
18-
options/OptionMapTest.cc
19-
options/OptionSpecTest.cc
20-
)
17+
# set(ELF_TEST_SOURCES
18+
# options/OptionMapTest.cc
19+
# options/OptionSpecTest.cc
20+
# )
2121

2222
# Main ELF library
2323

@@ -26,10 +26,10 @@ target_compile_definitions(elf PUBLIC
2626
GIT_COMMIT_HASH=${GIT_COMMIT_HASH}
2727
GIT_STAGED=${GIT_STAGED_STRING}
2828
)
29+
2930
target_link_libraries(elf PUBLIC
3031
#${Boost_LIBRARIES}
3132
concurrentqueue
32-
cppzmq
3333
nlohmann_json
3434
pybind11
3535
$<BUILD_INTERFACE:${PYTHON_LIBRARIES}>
@@ -40,7 +40,7 @@ target_link_libraries(elf PUBLIC
4040
# Tests
4141

4242
enable_testing()
43-
add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})
43+
# add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})
4444

4545
# Python bindings
4646

src_cpp/elf/ai/ai.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ class AIClientT : public AI_T<S, A> {
8888
status == comm::ReplyStatus::UNKNOWN;
8989
}
9090

91-
virtual bool act_batch(
91+
bool act_batch(
9292
const std::vector<const S*>& batch_s,
93-
const std::vector<A*>& batch_a) {
93+
const std::vector<A*>& batch_a) override {
9494
std::vector<elf::FuncsWithState> funcs_s =
9595
client_->BindStateToFunctions(targets_, batch_s);
9696
std::vector<elf::FuncsWithState> funcs_a =

src_cpp/elf/ai/tree_search/tree_search.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class TreeSearchSingleThreadT {
7070
: threadId_(thread_id), options_(options) {
7171
if (options_.verbose) {
7272
std::string log_file =
73-
"tree_search_" + std::to_string(thread_id) + ".txt";
73+
options_.log_prefix + std::to_string(thread_id) + ".txt";
7474
output_.reset(new std::ofstream(log_file));
7575
}
7676
}
@@ -374,10 +374,7 @@ class TreeSearchT {
374374
return searchTree_.printTree();
375375
}
376376

377-
MCTSResult runPolicyOnly(const State& /*root_state*/) {
378-
// TODO Policy only doesn't work.
379-
assert(false);
380-
/*
377+
MCTSResult runPolicyOnly(const State& root_state) {
381378
if (actors_.empty() || treeSearches_.empty()) {
382379
throw std::range_error(
383380
"TreeSearch::runPolicyOnly works when there is at least one thread");
@@ -386,15 +383,17 @@ class TreeSearchT {
386383

387384
// Some hack here.
388385
Node* root = searchTree_.getRootNode();
389-
treeSearches_[0]->visit(*actors_[0], root);
390386

391-
// return StrongestPrior(root->getStateActions());
392-
*/
387+
if (!root->isVisited()) {
388+
NodeResponseT<Action> resp;
389+
actors_[0]->evaluate(*root->getStatePtr(), &resp);
390+
root->setEvaluation(resp);
391+
}
393392

394393
MCTSResult result;
395-
// result.action_rank_method = MCTSResult::PRIOR;
396-
// result.addActions(root->getStateActions());
397-
394+
result.action_rank_method = MCTSResult::PRIOR;
395+
result.addActions(root->getStateActions());
396+
result.root_value = root->getValue();
398397
return result;
399398
}
400399

@@ -490,6 +489,8 @@ class TreeSearchT {
490489

491490
// Pick the best solution.
492491
MCTSResult result;
492+
result.root_value = root->getValue();
493+
493494
// MCTSResult result2;
494495
if (options_.pick_method == "strongest_prior") {
495496
result.action_rank_method = MCTSResult::PRIOR;

src_cpp/elf/ai/tree_search/tree_search_base.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ struct StateTrait {
6161
return s1 == s2;
6262
}
6363

64-
static bool
65-
moves_since(const S& s, size_t* next_move_number, std::vector<A>* moves) {
64+
static bool moves_since(
65+
const S& /*s*/,
66+
size_t* /*next_move_number*/,
67+
std::vector<A>* /*moves*/) {
6668
// By default it is not provided.
6769
return false;
6870
}
@@ -84,7 +86,7 @@ struct ActionTrait {
8486
template <typename Actor>
8587
struct ActorTrait {
8688
public:
87-
static std::string to_string(const Actor& a) {
89+
static std::string to_string(const Actor&) {
8890
return "";
8991
}
9092
};
@@ -213,6 +215,7 @@ struct MCTSResultT {
213215
enum RankCriterion { MOST_VISITED = 0, PRIOR = 1, UNIFORM_RANDOM };
214216

215217
Action best_action;
218+
float root_value;
216219
float max_score;
217220
EdgeInfo best_edge_info;
218221
MCTSPolicy<Action> mcts_policy;
@@ -224,6 +227,7 @@ struct MCTSResultT {
224227
// action_edges ssengupta@fb.com
225228
MCTSResultT()
226229
: best_action(ActionTrait<Action>::default_value()),
230+
root_value(0.0),
227231
max_score(std::numeric_limits<float>::lowest()),
228232
best_edge_info(0),
229233
total_visits(0),

src_cpp/elf/ai/tree_search/tree_search_options.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct TSOptions {
8585
bool persistent_tree = false;
8686
float root_epsilon = 0.0;
8787
float root_alpha = 0.0;
88+
std::string log_prefix = "";
8889

8990
// [TODO] Not a good design.
9091
// string pick_method = "strongest_prior";
@@ -102,6 +103,7 @@ struct TSOptions {
102103
ss << "Maximal #moves (0 = no constraint): " << max_num_moves
103104
<< std::endl;
104105
ss << "Seed: " << seed << std::endl;
106+
ss << "Log Prefix: " << log_prefix << std::endl;
105107
ss << "#Threads: " << num_threads << std::endl;
106108
ss << "#Rollout per thread: " << num_rollouts_per_thread
107109
<< ", #rollouts per batch: " << num_rollouts_per_batch << std::endl;
@@ -156,6 +158,9 @@ struct TSOptions {
156158
if (t1.pick_method != t2.pick_method) {
157159
return false;
158160
}
161+
if (t1.log_prefix != t2.log_prefix) {
162+
return false;
163+
}
159164
if (t1.root_epsilon != t2.root_epsilon) {
160165
return false;
161166
}
@@ -181,6 +186,7 @@ struct TSOptions {
181186
JSON_SAVE(j, seed);
182187
JSON_SAVE(j, persistent_tree);
183188
JSON_SAVE(j, pick_method);
189+
JSON_SAVE(j, log_prefix);
184190
JSON_SAVE(j, root_epsilon);
185191
JSON_SAVE(j, root_alpha);
186192
JSON_SAVE(j, virtual_loss);
@@ -198,6 +204,7 @@ struct TSOptions {
198204
JSON_LOAD(opt, j, seed);
199205
JSON_LOAD(opt, j, persistent_tree);
200206
JSON_LOAD(opt, j, pick_method);
207+
JSON_LOAD(opt, j, log_prefix);
201208
JSON_LOAD(opt, j, root_epsilon);
202209
JSON_LOAD(opt, j, root_alpha);
203210
JSON_LOAD(opt, j, virtual_loss);
@@ -213,6 +220,7 @@ struct TSOptions {
213220
verbose,
214221
persistent_tree,
215222
pick_method,
223+
log_prefix,
216224
virtual_loss,
217225
verbose_time,
218226
alg_opt,

src_cpp/elf/base/context.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class Context {
124124

125125
void start() {
126126
th_.reset(new std::thread([&]() {
127-
assert(nice(10) == 10);
127+
// assert(nice(10) == 10);
128128
collectAndSendBatch();
129129
}));
130130
}
@@ -183,11 +183,17 @@ class Context {
183183
}
184184
}
185185
smem_->waitBatchFillMem(server_);
186-
// LOG(INFO) << "Receiver: Batch received. #batch = "
187-
// << batch.size() << std::endl;
186+
// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
187+
// received. #batch = "
188+
// << smem_->getEffectiveBatchSize() << std::endl;
189+
188190
comm::ReplyStatus batch_status =
189191
batchClient_->sendWait(smem_.get(), {""});
190192

193+
// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
194+
// releasing. #batch = "
195+
// << smem_->getEffectiveBatchSize() << std::endl;
196+
191197
// LOG(INFO) << "Receiver: Release batch" << std::endl;
192198
smem_->waitReplyReleaseBatch(server_, batch_status);
193199
}
@@ -280,7 +286,7 @@ class Context {
280286
auto* client = getClient();
281287
for (int i = 0; i < num_games_; ++i) {
282288
game_threads_.emplace_back([i, client, this]() {
283-
assert(nice(19) == 19);
289+
// assert(nice(19) == 19);
284290
client->start();
285291
game_cb_(i, client);
286292
client->End();
@@ -329,7 +335,7 @@ class Context {
329335
std::atomic<bool> tmp_thread_done(false);
330336

331337
std::thread tmp_thread([&]() {
332-
assert(nice(10) == 10);
338+
// assert(nice(10) == 10);
333339

334340
std::cout << "Prepare to stop ..." << std::endl;
335341
client_->prepareToStop();

0 commit comments

Comments
 (0)