Skip to content

Commit 3031e26

Browse files
chenzhehuaidanpovey
authored andcommitted
[src] Optimization to decoders for speed (#2168)
1 parent f861b00 commit 3031e26

File tree

4 files changed

+126
-32
lines changed

4 files changed

+126
-32
lines changed

src/decoder/lattice-faster-decoder.cc

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
44
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
55
// 2014 Guoguo Chen
6+
// 2018 Zhehuai Chen
67

78
// See ../../COPYING for clarification regarding multiple authors
89
//
@@ -68,7 +69,7 @@ void LatticeFasterDecoder::InitDecoding() {
6869
active_toks_[0].toks = start_tok;
6970
toks_.Insert(start_state, start_tok);
7071
num_toks_++;
71-
ProcessNonemitting(config_.beam);
72+
ProcessNonemittingWrapper(config_.beam);
7273
}
7374

7475
// Returns true if any kind of traceback is available (not necessarily from
@@ -84,8 +85,8 @@ bool LatticeFasterDecoder::Decode(DecodableInterface *decodable) {
8485
while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
8586
if (NumFramesDecoded() % config_.prune_interval == 0)
8687
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
87-
BaseFloat cost_cutoff = ProcessEmitting(decodable);
88-
ProcessNonemitting(cost_cutoff);
88+
BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable);
89+
ProcessNonemittingWrapper(cost_cutoff);
8990
}
9091
FinalizeDecoding();
9192

@@ -588,8 +589,8 @@ void LatticeFasterDecoder::AdvanceDecoding(DecodableInterface *decodable,
588589
if (NumFramesDecoded() % config_.prune_interval == 0) {
589590
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
590591
}
591-
BaseFloat cost_cutoff = ProcessEmitting(decodable);
592-
ProcessNonemitting(cost_cutoff);
592+
BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable);
593+
ProcessNonemittingWrapper(cost_cutoff);
593594
}
594595
}
595596

@@ -683,6 +684,7 @@ BaseFloat LatticeFasterDecoder::GetCutoff(Elem *list_head, size_t *tok_count,
683684
}
684685
}
685686

687+
template <typename FstType>
686688
BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
687689
KALDI_ASSERT(active_toks_.size() > 0);
688690
int32 frame = active_toks_.size() - 1; // frame is the frame-index
@@ -707,6 +709,7 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
707709

708710
BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
709711
// dynamic range.
712+
const FstType &fst = dynamic_cast<const FstType&>(fst_);
710713

711714
// First process the best token to get a hopefully
712715
// reasonably tight bound on the next cutoff. The only
@@ -715,15 +718,13 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
715718
StateId state = best_elem->key;
716719
Token *tok = best_elem->val;
717720
cost_offset = - tok->tot_cost;
718-
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
721+
for (fst::ArcIterator<FstType> aiter(fst, state);
719722
!aiter.Done();
720723
aiter.Next()) {
721-
Arc arc = aiter.Value();
724+
const Arc &arc = aiter.Value();
722725
if (arc.ilabel != 0) { // propagate..
723-
arc.weight = Times(arc.weight,
724-
Weight(cost_offset -
725-
decodable->LogLikelihood(frame, arc.ilabel)));
726-
BaseFloat new_weight = arc.weight.Value() + tok->tot_cost;
726+
BaseFloat new_weight = arc.weight.Value() + cost_offset -
727+
decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost;
727728
if (new_weight + adaptive_beam < next_cutoff)
728729
next_cutoff = new_weight + adaptive_beam;
729730
}
@@ -744,7 +745,7 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
744745
StateId state = e->key;
745746
Token *tok = e->val;
746747
if (tok->tot_cost <= cur_cutoff) {
747-
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
748+
for (fst::ArcIterator<FstType> aiter(fst, state);
748749
!aiter.Done();
749750
aiter.Next()) {
750751
const Arc &arc = aiter.Value();
@@ -775,12 +776,31 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
775776
return next_cutoff;
776777
}
777778

779+
template BaseFloat LatticeFasterDecoder::ProcessEmitting<fst::ConstFst<fst::StdArc>>(
780+
DecodableInterface *decodable);
781+
template BaseFloat LatticeFasterDecoder::ProcessEmitting<fst::VectorFst<fst::StdArc>>(
782+
DecodableInterface *decodable);
783+
template BaseFloat LatticeFasterDecoder::ProcessEmitting<fst::Fst<fst::StdArc>>(
784+
DecodableInterface *decodable);
785+
786+
BaseFloat LatticeFasterDecoder::ProcessEmittingWrapper(DecodableInterface *decodable) {
787+
if (fst_.Type() == "const") {
788+
return LatticeFasterDecoder::ProcessEmitting<fst::ConstFst<Arc>>(decodable);
789+
} else if (fst_.Type() == "vector") {
790+
return LatticeFasterDecoder::ProcessEmitting<fst::VectorFst<Arc>>(decodable);
791+
} else {
792+
return LatticeFasterDecoder::ProcessEmitting<fst::Fst<Arc>>(decodable);
793+
}
794+
}
795+
796+
template <typename FstType>
778797
void LatticeFasterDecoder::ProcessNonemitting(BaseFloat cutoff) {
779798
KALDI_ASSERT(!active_toks_.empty());
780799
int32 frame = static_cast<int32>(active_toks_.size()) - 2;
781800
// Note: "frame" is the time-index we just processed, or -1 if
782801
// we are processing the nonemitting transitions before the
783802
// first frame (called from InitDecoding()).
803+
const FstType &fst = dynamic_cast<const FstType&>(fst_);
784804

785805
// Processes nonemitting arcs for one frame. Propagates within toks_.
786806
// Note-- this queue structure is is not very optimal as
@@ -812,7 +832,7 @@ void LatticeFasterDecoder::ProcessNonemitting(BaseFloat cutoff) {
812832
// but since most states are emitting it's not a huge issue.
813833
tok->DeleteForwardLinks(); // necessary when re-visiting
814834
tok->links = NULL;
815-
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
835+
for (fst::ArcIterator<FstType> aiter(fst, state);
816836
!aiter.Done();
817837
aiter.Next()) {
818838
const Arc &arc = aiter.Value();
@@ -837,6 +857,22 @@ void LatticeFasterDecoder::ProcessNonemitting(BaseFloat cutoff) {
837857
} // while queue not empty
838858
}
839859

860+
template void LatticeFasterDecoder::ProcessNonemitting<fst::ConstFst<fst::StdArc>>(
861+
BaseFloat cutoff);
862+
template void LatticeFasterDecoder::ProcessNonemitting<fst::VectorFst<fst::StdArc>>(
863+
BaseFloat cutoff);
864+
template void LatticeFasterDecoder::ProcessNonemitting<fst::Fst<fst::StdArc>>(
865+
BaseFloat cutoff);
866+
867+
void LatticeFasterDecoder::ProcessNonemittingWrapper(BaseFloat cost_cutoff) {
868+
if (fst_.Type() == "const") {
869+
return LatticeFasterDecoder::ProcessNonemitting<fst::ConstFst<Arc>>(cost_cutoff);
870+
} else if (fst_.Type() == "vector") {
871+
return LatticeFasterDecoder::ProcessNonemitting<fst::VectorFst<Arc>>(cost_cutoff);
872+
} else {
873+
return LatticeFasterDecoder::ProcessNonemitting<fst::ConstFst<Arc>>(cost_cutoff);
874+
}
875+
}
840876

841877
void LatticeFasterDecoder::DeleteElems(Elem *list) {
842878
for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {

src/decoder/lattice-faster-decoder.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
44
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
55
// 2014 Guoguo Chen
6+
// 2018 Zhehuai Chen
67

78
// See ../../COPYING for clarification regarding multiple authors
89
//
@@ -339,12 +340,18 @@ class LatticeFasterDecoder {
339340

340341
/// Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_.
341342
/// Returns the cost cutoff for subsequent ProcessNonemitting() to use.
342-
BaseFloat ProcessEmitting(DecodableInterface *decodable);
343+
/// Templated on FST type for speed; called via ProcessEmittingWrapper().
344+
template <typename FstType> BaseFloat ProcessEmitting(DecodableInterface *decodable);
345+
346+
BaseFloat ProcessEmittingWrapper(DecodableInterface *decodable);
343347

344348
/// Processes nonemitting (epsilon) arcs for one frame. Called after
345349
/// ProcessEmitting() on each frame. The cost cutoff is computed by the
346350
/// preceding ProcessEmitting().
347-
void ProcessNonemitting(BaseFloat cost_cutoff);
351+
/// the templated design is similar to ProcessEmitting()
352+
template <typename FstType> void ProcessNonemitting(BaseFloat cost_cutoff);
353+
354+
void ProcessNonemittingWrapper(BaseFloat cost_cutoff);
348355

349356
// HashList defined in ../util/hash-list.h. It actually allows us to maintain
350357
// more than one list (e.g. for current and previous frames), but only one of

src/decoder/lattice-faster-online-decoder.cc

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
55
// 2014 Guoguo Chen
66
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
7+
// 2018 Zhehuai Chen
78

89
// See ../../COPYING for clarification regarding multiple authors
910
//
@@ -68,7 +69,7 @@ void LatticeFasterOnlineDecoder::InitDecoding() {
6869
active_toks_[0].toks = start_tok;
6970
toks_.Insert(start_state, start_tok);
7071
num_toks_++;
71-
ProcessNonemitting(config_.beam);
72+
ProcessNonemittingWrapper(config_.beam);
7273
}
7374

7475
// Returns true if any kind of traceback is available (not necessarily from
@@ -84,8 +85,8 @@ bool LatticeFasterOnlineDecoder::Decode(DecodableInterface *decodable) {
8485
while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
8586
if (NumFramesDecoded() % config_.prune_interval == 0)
8687
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
87-
BaseFloat cost_cutoff = ProcessEmitting(decodable); // Note: the value returned by
88-
ProcessNonemitting(cost_cutoff);
88+
BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable); // Note: the value returned by
89+
ProcessNonemittingWrapper(cost_cutoff);
8990
}
9091
FinalizeDecoding();
9192

@@ -763,8 +764,8 @@ void LatticeFasterOnlineDecoder::AdvanceDecoding(DecodableInterface *decodable,
763764
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
764765
}
765766
// note: ProcessEmitting() increments NumFramesDecoded().
766-
BaseFloat cost_cutoff = ProcessEmitting(decodable);
767-
ProcessNonemitting(cost_cutoff);
767+
BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable);
768+
ProcessNonemittingWrapper(cost_cutoff);
768769
}
769770
}
770771

@@ -861,6 +862,7 @@ BaseFloat LatticeFasterOnlineDecoder::GetCutoff(Elem *list_head, size_t *tok_cou
861862
}
862863

863864

865+
template <typename FstType>
864866
BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
865867
DecodableInterface *decodable) {
866868
KALDI_ASSERT(active_toks_.size() > 0);
@@ -883,6 +885,7 @@ BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
883885

884886
BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
885887
// dynamic range.
888+
const FstType &fst = dynamic_cast<const FstType&>(fst_);
886889

887890
// First process the best token to get a hopefully
888891
// reasonably tight bound on the next cutoff. The only
@@ -891,15 +894,13 @@ BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
891894
StateId state = best_elem->key;
892895
Token *tok = best_elem->val;
893896
cost_offset = - tok->tot_cost;
894-
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
897+
for (fst::ArcIterator<FstType> aiter(fst, state);
895898
!aiter.Done();
896899
aiter.Next()) {
897-
Arc arc = aiter.Value();
900+
const Arc &arc = aiter.Value();
898901
if (arc.ilabel != 0) { // propagate..
899-
arc.weight = Times(arc.weight,
900-
Weight(cost_offset -
901-
decodable->LogLikelihood(frame, arc.ilabel)));
902-
BaseFloat new_weight = arc.weight.Value() + tok->tot_cost;
902+
BaseFloat new_weight = arc.weight.Value() + cost_offset -
903+
decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost;
903904
if (new_weight + adaptive_beam < next_cutoff)
904905
next_cutoff = new_weight + adaptive_beam;
905906
}
@@ -919,8 +920,8 @@ BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
919920
// loop this way because we delete "e" as we go.
920921
StateId state = e->key;
921922
Token *tok = e->val;
922-
if (tok->tot_cost <= cur_cutoff) {
923-
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
923+
if (tok->tot_cost <= cur_cutoff) {
924+
for (fst::ArcIterator<FstType> aiter(fst, state);
924925
!aiter.Done();
925926
aiter.Next()) {
926927
const Arc &arc = aiter.Value();
@@ -951,12 +952,35 @@ BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
951952
return next_cutoff;
952953
}
953954

955+
template BaseFloat LatticeFasterOnlineDecoder::
956+
ProcessEmitting<fst::ConstFst<fst::StdArc>>(DecodableInterface *decodable);
957+
template BaseFloat LatticeFasterOnlineDecoder::
958+
ProcessEmitting<fst::VectorFst<fst::StdArc>>(DecodableInterface *decodable);
959+
template BaseFloat LatticeFasterOnlineDecoder::
960+
ProcessEmitting<fst::Fst<fst::StdArc>>(DecodableInterface *decodable);
961+
962+
BaseFloat LatticeFasterOnlineDecoder::ProcessEmittingWrapper(
963+
DecodableInterface *decodable) {
964+
if (fst_.Type() == "const") {
965+
return LatticeFasterOnlineDecoder::
966+
ProcessEmitting<fst::ConstFst<Arc>>(decodable);
967+
} else if (fst_.Type() == "vector") {
968+
return LatticeFasterOnlineDecoder::
969+
ProcessEmitting<fst::VectorFst<Arc>>(decodable);
970+
} else {
971+
return LatticeFasterOnlineDecoder::
972+
ProcessEmitting<fst::Fst<Arc>>(decodable);
973+
}
974+
}
975+
976+
template <typename FstType>
954977
void LatticeFasterOnlineDecoder::ProcessNonemitting(BaseFloat cutoff) {
955978
KALDI_ASSERT(!active_toks_.empty());
956979
int32 frame = static_cast<int32>(active_toks_.size()) - 2;
957980
// Note: "frame" is the time-index we just processed, or -1 if
958981
// we are processing the nonemitting transitions before the
959982
// first frame (called from InitDecoding()).
983+
const FstType &fst = dynamic_cast<const FstType&>(fst_);
960984

961985
// Processes nonemitting arcs for one frame. Propagates within toks_.
962986
// Note-- this queue structure is is not very optimal as
@@ -988,7 +1012,7 @@ void LatticeFasterOnlineDecoder::ProcessNonemitting(BaseFloat cutoff) {
9881012
// but since most states are emitting it's not a huge issue.
9891013
tok->DeleteForwardLinks(); // necessary when re-visiting
9901014
tok->links = NULL;
991-
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
1015+
for (fst::ArcIterator<FstType> aiter(fst, state);
9921016
!aiter.Done();
9931017
aiter.Next()) {
9941018
const Arc &arc = aiter.Value();
@@ -1013,6 +1037,26 @@ void LatticeFasterOnlineDecoder::ProcessNonemitting(BaseFloat cutoff) {
10131037
} // while queue not empty
10141038
}
10151039

1040+
template void LatticeFasterOnlineDecoder::
1041+
ProcessNonemitting<fst::ConstFst<fst::StdArc>>(BaseFloat cutoff);
1042+
template void LatticeFasterOnlineDecoder::
1043+
ProcessNonemitting<fst::VectorFst<fst::StdArc>>(BaseFloat cutoff);
1044+
template void LatticeFasterOnlineDecoder::
1045+
ProcessNonemitting<fst::Fst<fst::StdArc>>(BaseFloat cutoff);
1046+
1047+
void LatticeFasterOnlineDecoder::ProcessNonemittingWrapper(
1048+
BaseFloat cost_cutoff) {
1049+
if (fst_.Type() == "const") {
1050+
return LatticeFasterOnlineDecoder::
1051+
ProcessNonemitting<fst::ConstFst<Arc>>(cost_cutoff);
1052+
} else if (fst_.Type() == "vector") {
1053+
return LatticeFasterOnlineDecoder::
1054+
ProcessNonemitting<fst::VectorFst<Arc>>(cost_cutoff);
1055+
} else {
1056+
return LatticeFasterOnlineDecoder::
1057+
ProcessNonemitting<fst::ConstFst<Arc>>(cost_cutoff);
1058+
}
1059+
}
10161060

10171061
void LatticeFasterOnlineDecoder::DeleteElems(Elem *list) {
10181062
for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {

src/decoder/lattice-faster-online-decoder.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
44
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
55
// 2014 Guoguo Chen
6+
// 2018 Zhehuai Chen
67

78
// See ../../COPYING for clarification regarding multiple authors
89
//
@@ -337,12 +338,18 @@ class LatticeFasterOnlineDecoder {
337338

338339
/// Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_.
339340
/// Returns the cost cutoff for subsequent ProcessNonemitting() to use.
340-
BaseFloat ProcessEmitting(DecodableInterface *decodable);
341+
/// Templated on FST type for speed; called via ProcessEmittingWrapper().
342+
template <typename FstType> BaseFloat ProcessEmitting(DecodableInterface *decodable);
343+
344+
BaseFloat ProcessEmittingWrapper(DecodableInterface *decodable);
341345

342346
/// Processes nonemitting (epsilon) arcs for one frame. Called after
343347
/// ProcessEmitting() on each frame. The cost cutoff is computed by the
344348
/// preceding ProcessEmitting().
345-
void ProcessNonemitting(BaseFloat cost_cutoff);
349+
/// the templated design is similar to ProcessEmitting()
350+
template <typename FstType> void ProcessNonemitting(BaseFloat cost_cutoff);
351+
352+
void ProcessNonemittingWrapper(BaseFloat cost_cutoff);
346353

347354
// HashList defined in ../util/hash-list.h. It actually allows us to maintain
348355
// more than one list (e.g. for current and previous frames), but only one of
@@ -361,7 +368,7 @@ class LatticeFasterOnlineDecoder {
361368
// make it class member to avoid internal new/delete.
362369
const fst::Fst<fst::StdArc> &fst_;
363370
bool delete_fst_;
364-
std::vector<BaseFloat> cost_offsets_; // This contains, for each
371+
std::vector<BaseFloat> cost_offsets_; // This contains, for each
365372
// frame, an offset that was added to the acoustic log-likelihoods on that
366373
// frame in order to keep everything in a nice dynamic range i.e. close to
367374
// zero, to reduce roundoff errors.

0 commit comments

Comments
 (0)