Skip to content

Commit d055533

Browse files
committed
dropout_schedule: Add set-dropout-proportion in nnet3 utils
1 parent ca5bdf9 commit d055533

File tree

3 files changed

+35
-13
lines changed

3 files changed

+35
-13
lines changed

src/nnet3/nnet-utils.cc

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -523,16 +523,6 @@ std::string NnetInfo(const Nnet &nnet) {
523523
return ostr.str();
524524
}
525525

526-
void SetDropoutProportion(BaseFloat dropout_proportion,
527-
Nnet *nnet) {
528-
for (int32 c = 0; c < nnet->NumComponents(); c++) {
529-
Component *comp = nnet->GetComponent(c);
530-
DropoutComponent *dc = dynamic_cast<DropoutComponent*>(comp);
531-
if (dc != NULL)
532-
dc->SetDropoutProportion(dropout_proportion);
533-
}
534-
}
535-
536526
void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components) {
537527
int32 num_components = nnet.NumComponents(), num_nodes = nnet.NumNodes();
538528
std::vector<bool> is_used(num_components, false);
@@ -688,6 +678,29 @@ void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet) {
688678
if (outputs_remaining == 0)
689679
KALDI_ERR << "All outputs were removed.";
690680
nnet->RemoveSomeNodes(nodes_to_remove);
681+
} else if (directive == "set-dropout-proportion") {
682+
std::string name_pattern = "*";
683+
// name_pattern defaults to '*' if none is given. This pattern
684+
// matches names of components, not nodes.
685+
config_line.GetValue("name", &name_pattern);
686+
BaseFloat proportion = -1;
687+
if (!config_line.GetValue("proportion", &proportion)) {
688+
KALDI_ERR << "In edits-config, expected proportion to be set in line: "
689+
<< config_line.WholeLine();
690+
}
691+
DropoutComponent *component = NULL;
692+
int32 num_dropout_proportions_set = 0;
693+
for (int32 c = 0; c < nnet->NumComponents(); c++) {
694+
if (NameMatchesPattern(nnet->GetComponentName(c).c_str(),
695+
name_pattern.c_str()) &&
696+
(component =
697+
dynamic_cast<DropoutComponent*>(nnet->GetComponent(c)))) {
698+
component->SetDropoutProportion(proportion);
699+
num_dropout_proportions_set++;
700+
}
701+
}
702+
KALDI_LOG << "Set dropout proportions for "
703+
<< num_dropout_proportions_set << " nodes.";
691704
} else {
692705
KALDI_ERR << "Directive '" << directive << "' is not currently "
693706
"supported (reading edit-config).";

src/nnet3/nnet-utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ void FindOrphanNodes(const Nnet &nnet, std::vector<int32> *nodes);
233233
remove internal nodes directly; instead you should use the command
234234
'remove-orphans'.
235235
236+
set-dropout-proportion [name=<name-pattern>] proportion=<dropout-proportion>
237+
Sets the dropout rates for any components of type DropoutComponent whose
238+
names match the given <name-pattern> (e.g. lstm*). <name-pattern> defaults to "*".
236239
\endverbatim
237240
*/
238241
void ReadEditConfig(std::istream &config_file, Nnet *nnet);

src/nnet3bin/nnet3-copy.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {
4242

4343
bool binary_write = true;
4444
BaseFloat learning_rate = -1,
45-
dropout = 0.0;
45+
dropout = -1;
4646
std::string nnet_config, edits_config, edits_str;
4747
BaseFloat scale = 1.0;
4848

@@ -64,7 +64,10 @@ int main(int argc, char *argv[]) {
6464
"will be converted to newlines before parsing. E.g. "
6565
"'--edits=remove-orphans'.");
6666
po.Register("set-dropout-proportion", &dropout, "Set dropout proportion "
67-
"in all DropoutComponent to this value.");
67+
"in all DropoutComponent to this value. "
68+
"This option is deprecated. Use set-dropout-proportion "
69+
"option in edits-config. See comments in ReadEditConfig() "
70+
"in nnet3/nnet-utils.h.");
6871
po.Register("scale", &scale, "The parameter matrices are scaled"
6972
" by the specified value.");
7073
po.Read(argc, argv);
@@ -92,7 +95,10 @@ int main(int argc, char *argv[]) {
9295
ScaleNnet(scale, &nnet);
9396

9497
if (dropout > 0)
95-
SetDropoutProportion(dropout, &nnet);
98+
KALDI_ERR << "--dropout option is deprecated. "
99+
<< "Use set-dropout-proportion "
100+
<< "option in edits-config. See comments in ReadEditConfig() "
101+
<< "in nnet3/nnet-utils.h.";
96102

97103
if (!edits_config.empty()) {
98104
Input ki(edits_config);

0 commit comments

Comments
 (0)