Skip to content

Commit 9571832

Browse files
committed
dropout_schedule: Add set-dropout-proportion in nnet3 utils
1 parent ca5bdf9 commit 9571832

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

src/nnet3/nnet-utils.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,29 @@ void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet) {
688688
if (outputs_remaining == 0)
689689
KALDI_ERR << "All outputs were removed.";
690690
nnet->RemoveSomeNodes(nodes_to_remove);
691+
} else if (directive == "set-dropout-proportion") {
692+
std::string name_pattern = "*";
693+
// name_pattern defaults to '*' if none is given. This pattern
694+
// matches names of components, not nodes.
695+
config_line.GetValue("name", &name_pattern);
696+
BaseFloat proportion = -1;
697+
if (!config_line.GetValue("proportion", &proportion)) {
698+
KALDI_ERR << "In edits-config, expected proportion to be set in line: "
699+
<< config_line.WholeLine();
700+
}
701+
UpdatableComponent *component = NULL;
702+
int32 num_dropout_proportions_set = 0;
703+
for (int32 c = 0; c < nnet->NumComponents(); c++) {
704+
if (NameMatchesPattern(nnet->GetComponentName(c).c_str(),
705+
name_pattern.c_str()) &&
706+
(component =
707+
dynamic_cast<DropoutComponent*>(nnet->GetComponent(c)))) {
708+
component->SetDropoutProportion(proportion);
709+
num_dropout_proportions_set++;
710+
}
711+
}
712+
KALDI_LOG << "Set dropout proportions for "
713+
<< num_dropout_proportions_set << " nodes.";
691714
} else {
692715
KALDI_ERR << "Directive '" << directive << "' is not currently "
693716
"supported (reading edit-config).";

src/nnet3/nnet-utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ void FindOrphanNodes(const Nnet &nnet, std::vector<int32> *nodes);
233233
remove internal nodes directly; instead you should use the command
234234
'remove-orphans'.
235235
236+
set-dropout-proportion [name=<name-pattern>] proportion=<dropout-proportion>
237+
Sets the dropout rates for any components of type DropoutComponent whose
238+
names match the given <name-pattern> (e.g. lstm*). <name-pattern> defaults to "*".
236239
\endverbatim
237240
*/
238241
void ReadEditConfig(std::istream &config_file, Nnet *nnet);

src/nnet3bin/nnet3-copy.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {
4242

4343
bool binary_write = true;
4444
BaseFloat learning_rate = -1,
45-
dropout = 0.0;
45+
dropout = -1;
4646
std::string nnet_config, edits_config, edits_str;
4747
BaseFloat scale = 1.0;
4848

@@ -64,7 +64,10 @@ int main(int argc, char *argv[]) {
6464
"will be converted to newlines before parsing. E.g. "
6565
"'--edits=remove-orphans'.");
6666
po.Register("set-dropout-proportion", &dropout, "Set dropout proportion "
67-
"in all DropoutComponent to this value.");
67+
"in all DropoutComponent to this value. "
68+
"This option is deprecated. Use set-dropout-proportion "
69+
"option in edits-config. See comments in ReadEditConfig() "
70+
"in nnet3/nnet-utils.h.");
6871
po.Register("scale", &scale, "The parameter matrices are scaled"
6972
" by the specified value.");
7073
po.Read(argc, argv);
@@ -92,7 +95,10 @@ int main(int argc, char *argv[]) {
9295
ScaleNnet(scale, &nnet);
9396

9497
if (dropout > 0)
95-
SetDropoutProportion(dropout, &nnet);
98+
KALDI_ERR << "--dropout option is deprecated. "
99+
<< "Use set-dropout-proportion "
100+
<< "option in edits-config. See comments in ReadEditConfig() "
101+
<< "in nnet3/nnet-utils.h.";
96102

97103
if (!edits_config.empty()) {
98104
Input ki(edits_config);

0 commit comments

Comments
 (0)