Skip to content

Commit 5cc7ce0

Browse files
LvHangdanpovey
authored andcommitted
[src] Update Insert function of hashlist and decoders (kaldi-asr#3402)
makes interface of HashList more standard; slight speed improvement.
1 parent 14cc156 commit 5cc7ce0

File tree

8 files changed

+87
-84
lines changed

8 files changed

+87
-84
lines changed

src/decoder/biglm-faster-decoder.h

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -397,18 +397,14 @@ class BiglmFasterDecoder {
397397
if (new_weight < next_weight_cutoff) { // not pruned..
398398
PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
399399
Token *new_tok = new Token(arc, ac_weight, tok);
400-
Elem *e_found = toks_.Find(next_pair);
400+
Elem *e_found = toks_.Insert(next_pair, new_tok);
401401
if (new_weight + adaptive_beam < next_weight_cutoff)
402402
next_weight_cutoff = new_weight + adaptive_beam;
403-
if (e_found == NULL) {
404-
toks_.Insert(next_pair, new_tok);
403+
if ( *(e_found->val) < *new_tok ) {
404+
Token::TokenDelete(e_found->val);
405+
e_found->val = new_tok;
405406
} else {
406-
if ( *(e_found->val) < *new_tok ) {
407-
Token::TokenDelete(e_found->val);
408-
e_found->val = new_tok;
409-
} else {
410-
Token::TokenDelete(new_tok);
411-
}
407+
Token::TokenDelete(new_tok);
412408
}
413409
}
414410
}
@@ -426,11 +422,12 @@ class BiglmFasterDecoder {
426422
// Processes nonemitting arcs for one frame.
427423
KALDI_ASSERT(queue_.empty());
428424
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
429-
queue_.push_back(e->key);
425+
queue_.push_back(e);
430426
while (!queue_.empty()) {
431-
PairId state_pair = queue_.back();
427+
const Elem *e = queue_.back();
432428
queue_.pop_back();
433-
Token *tok = toks_.Find(state_pair)->val; // would segfault if state not
429+
PairId state_pair = e->key;
430+
Token *tok = e->val; // would segfault if state not
434431
// in toks_ but this can't happen.
435432
if (tok->weight_.Value() > cutoff) { // Don't bother processing successors.
436433
continue;
@@ -450,15 +447,14 @@ class BiglmFasterDecoder {
450447
if (new_tok->weight_.Value() > cutoff) { // prune
451448
Token::TokenDelete(new_tok);
452449
} else {
453-
Elem *e_found = toks_.Find(next_pair);
454-
if (e_found == NULL) {
455-
toks_.Insert(next_pair, new_tok);
456-
queue_.push_back(next_pair);
450+
Elem *e_found = toks_.Insert(next_pair, new_tok);
451+
if (e_found->val == new_tok) {
452+
queue_.push_back(e_found);
457453
} else {
458454
if ( *(e_found->val) < *new_tok ) {
459455
Token::TokenDelete(e_found->val);
460456
e_found->val = new_tok;
461-
queue_.push_back(next_pair);
457+
queue_.push_back(e_found);
462458
} else {
463459
Token::TokenDelete(new_tok);
464460
}
@@ -477,7 +473,7 @@ class BiglmFasterDecoder {
477473
fst::DeterministicOnDemandFst<fst::StdArc> *lm_diff_fst_;
478474
BiglmFasterDecoderOptions opts_;
479475
bool warned_noarc_;
480-
std::vector<PairId> queue_; // temp variable used in ProcessNonemitting,
476+
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
481477
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
482478
// make it class member to avoid internal new/delete.
483479

src/decoder/faster-decoder.cc

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -277,18 +277,14 @@ double FasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
277277
double new_weight = arc.weight.Value() + tok->cost_ + ac_cost;
278278
if (new_weight < next_weight_cutoff) { // not pruned..
279279
Token *new_tok = new Token(arc, ac_cost, tok);
280-
Elem *e_found = toks_.Find(arc.nextstate);
280+
Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
281281
if (new_weight + adaptive_beam < next_weight_cutoff)
282282
next_weight_cutoff = new_weight + adaptive_beam;
283-
if (e_found == NULL) {
284-
toks_.Insert(arc.nextstate, new_tok);
283+
if ( *(e_found->val) < *new_tok ) {
284+
Token::TokenDelete(e_found->val);
285+
e_found->val = new_tok;
285286
} else {
286-
if ( *(e_found->val) < *new_tok ) {
287-
Token::TokenDelete(e_found->val);
288-
e_found->val = new_tok;
289-
} else {
290-
Token::TokenDelete(new_tok);
291-
}
287+
Token::TokenDelete(new_tok);
292288
}
293289
}
294290
}
@@ -307,11 +303,12 @@ void FasterDecoder::ProcessNonemitting(double cutoff) {
307303
// Processes nonemitting arcs for one frame.
308304
KALDI_ASSERT(queue_.empty());
309305
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
310-
queue_.push_back(e->key);
306+
queue_.push_back(e);
311307
while (!queue_.empty()) {
312-
StateId state = queue_.back();
308+
const Elem* e = queue_.back();
313309
queue_.pop_back();
314-
Token *tok = toks_.Find(state)->val; // would segfault if state not
310+
StateId state = e->key;
311+
Token *tok = e->val; // would segfault if state not
315312
// in toks_ but this can't happen.
316313
if (tok->cost_ > cutoff) { // Don't bother processing successors.
317314
continue;
@@ -326,15 +323,14 @@ void FasterDecoder::ProcessNonemitting(double cutoff) {
326323
if (new_tok->cost_ > cutoff) { // prune
327324
Token::TokenDelete(new_tok);
328325
} else {
329-
Elem *e_found = toks_.Find(arc.nextstate);
330-
if (e_found == NULL) {
331-
toks_.Insert(arc.nextstate, new_tok);
332-
queue_.push_back(arc.nextstate);
326+
Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
327+
if (e_found->val == new_tok) {
328+
queue_.push_back(e_found);
333329
} else {
334330
if ( *(e_found->val) < *new_tok ) {
335331
Token::TokenDelete(e_found->val);
336332
e_found->val = new_tok;
337-
queue_.push_back(arc.nextstate);
333+
queue_.push_back(e_found);
338334
} else {
339335
Token::TokenDelete(new_tok);
340336
}

src/decoder/faster-decoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class FasterDecoder {
170170
HashList<StateId, Token*> toks_;
171171
const fst::Fst<fst::StdArc> &fst_;
172172
FasterDecoderOptions config_;
173-
std::vector<StateId> queue_; // temp variable used in ProcessNonemitting,
173+
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
174174
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
175175
// make it class member to avoid internal new/delete.
176176

src/decoder/lattice-biglm-faster-decoder.h

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -312,14 +312,14 @@ class LatticeBiglmFasterDecoder {
312312
// for the current frame. [note: it's inserted if necessary into hash toks_
313313
// and also into the singly linked list of tokens active on this frame
314314
// (whose head is at active_toks_[frame]).
315-
inline Token *FindOrAddToken(PairId state_pair, int32 frame, BaseFloat tot_cost,
316-
bool emitting, bool *changed) {
315+
inline Elem *FindOrAddToken(PairId state_pair, int32 frame,
316+
BaseFloat tot_cost, bool emitting, bool *changed) {
317317
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
318318
// if the token was newly created or the cost changed.
319319
KALDI_ASSERT(frame < active_toks_.size());
320320
Token *&toks = active_toks_[frame].toks;
321-
Elem *e_found = toks_.Find(state_pair);
322-
if (e_found == NULL) { // no such token presently.
321+
Elem *e_found = toks_.Insert(state_pair, NULL);
322+
if (e_found->val == NULL) { // no such token presently.
323323
const BaseFloat extra_cost = 0.0;
324324
// tokens on the currently final frame have zero extra_cost
325325
// as any of them could end up
@@ -328,9 +328,9 @@ class LatticeBiglmFasterDecoder {
328328
// NULL: no forward links yet
329329
toks = new_tok;
330330
num_toks_++;
331-
toks_.Insert(state_pair, new_tok);
331+
e_found->val = new_tok;
332332
if (changed) *changed = true;
333-
return new_tok;
333+
return e_found;
334334
} else {
335335
Token *tok = e_found->val; // There is an existing Token for this state.
336336
if (tok->tot_cost > tot_cost) { // replace old token
@@ -346,7 +346,7 @@ class LatticeBiglmFasterDecoder {
346346
} else {
347347
if (changed) *changed = false;
348348
}
349-
return tok;
349+
return e_found;
350350
}
351351
}
352352

@@ -744,11 +744,11 @@ class LatticeBiglmFasterDecoder {
744744
else if (tot_cost + config_.beam < next_cutoff)
745745
next_cutoff = tot_cost + config_.beam; // prune by best current token
746746
PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
747-
Token *next_tok = FindOrAddToken(next_pair, frame, tot_cost, true, NULL);
747+
Elem *e_next = FindOrAddToken(next_pair, frame, tot_cost, true, NULL);
748748
// true: emitting, NULL: no change indicator needed
749749

750750
// Add ForwardLink from tok to next_tok (put on head of list tok->links)
751-
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
751+
tok->links = new ForwardLink(e_next->val, arc.ilabel, arc.olabel,
752752
graph_cost, ac_cost, tok->links);
753753
}
754754
} // for all arcs
@@ -770,7 +770,7 @@ class LatticeBiglmFasterDecoder {
770770
KALDI_ASSERT(queue_.empty());
771771
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
772772
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
773-
queue_.push_back(e->key);
773+
queue_.push_back(e);
774774
// for pruning with current best token
775775
best_cost = std::min(best_cost, static_cast<BaseFloat>(e->val->tot_cost));
776776
}
@@ -784,11 +784,12 @@ class LatticeBiglmFasterDecoder {
784784
BaseFloat cutoff = best_cost + config_.beam;
785785

786786
while (!queue_.empty()) {
787-
PairId state_pair = queue_.back();
787+
const Elem *e = queue_.back();
788788
queue_.pop_back();
789789

790-
Token *tok = toks_.Find(state_pair)->val; // would segfault if state not in
791-
// toks_ but this can't happen.
790+
PairId state_pair = e->key;
791+
Token *tok = e->val; // would segfault if state not in
792+
// toks_ but this can't happen.
792793
BaseFloat cur_cost = tok->tot_cost;
793794
if (cur_cost > cutoff) // Don't bother processing successors.
794795
continue;
@@ -812,15 +813,15 @@ class LatticeBiglmFasterDecoder {
812813
if (tot_cost < cutoff) {
813814
bool changed;
814815
PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
815-
Token *new_tok = FindOrAddToken(next_pair, frame, tot_cost,
816-
false, &changed); // false: non-emit
816+
Elem *e_new = FindOrAddToken(next_pair, frame, tot_cost,
817+
false, &changed); // false: non-emit
817818

818-
tok->links = new ForwardLink(new_tok, 0, arc.olabel,
819+
tok->links = new ForwardLink(e_new->val, 0, arc.olabel,
819820
graph_cost, 0, tok->links);
820821

821822
// "changed" tells us whether the new token has a different
822823
// cost from before, or is new [if so, add into queue].
823-
if (changed) queue_.push_back(next_pair);
824+
if (changed) queue_.push_back(e_new);
824825
}
825826
}
826827
} // for all arcs
@@ -835,7 +836,7 @@ class LatticeBiglmFasterDecoder {
835836
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
836837
// frame (members of TokenList are toks, must_prune_forward_links,
837838
// must_prune_tokens).
838-
std::vector<PairId> queue_; // temp variable used in ProcessNonemitting,
839+
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
839840
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
840841
// make it class member to avoid internal new/delete.
841842
const fst::Fst<fst::StdArc> &fst_;

src/decoder/lattice-faster-decoder.cc

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,16 @@ void LatticeFasterDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
263263
// and also into the singly linked list of tokens active on this frame
264264
// (whose head is at active_toks_[frame]).
265265
template <typename FST, typename Token>
266-
inline Token* LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
266+
inline typename LatticeFasterDecoderTpl<FST, Token>::Elem*
267+
LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
267268
StateId state, int32 frame_plus_one, BaseFloat tot_cost,
268269
Token *backpointer, bool *changed) {
269270
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
270271
// if the token was newly created or the cost changed.
271272
KALDI_ASSERT(frame_plus_one < active_toks_.size());
272273
Token *&toks = active_toks_[frame_plus_one].toks;
273-
Elem *e_found = toks_.Find(state);
274-
if (e_found == NULL) { // no such token presently.
274+
Elem *e_found = toks_.Insert(state, NULL);
275+
if (e_found->val == NULL) { // no such token presently.
275276
const BaseFloat extra_cost = 0.0;
276277
// tokens on the currently final frame have zero extra_cost
277278
// as any of them could end up
@@ -280,9 +281,9 @@ inline Token* LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
280281
// NULL: no forward links yet
281282
toks = new_tok;
282283
num_toks_++;
283-
toks_.Insert(state, new_tok);
284+
e_found->val = new_tok;
284285
if (changed) *changed = true;
285-
return new_tok;
286+
return e_found;
286287
} else {
287288
Token *tok = e_found->val; // There is an existing Token for this state.
288289
if (tok->tot_cost > tot_cost) { // replace old token
@@ -301,7 +302,7 @@ inline Token* LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
301302
} else {
302303
if (changed) *changed = false;
303304
}
304-
return tok;
305+
return e_found;
305306
}
306307
}
307308

@@ -800,12 +801,12 @@ BaseFloat LatticeFasterDecoderTpl<FST, Token>::ProcessEmitting(
800801
next_cutoff = tot_cost + adaptive_beam; // prune by best current token
801802
// Note: the frame indexes into active_toks_ are one-based,
802803
// hence the + 1.
803-
Token *next_tok = FindOrAddToken(arc.nextstate,
804-
frame + 1, tot_cost, tok, NULL);
804+
Elem *e_next = FindOrAddToken(arc.nextstate,
805+
frame + 1, tot_cost, tok, NULL);
805806
// NULL: no change indicator needed
806807

807808
// Add ForwardLink from tok to next_tok (put on head of list tok->links)
808-
tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
809+
tok->links = new ForwardLinkT(e_next->val, arc.ilabel, arc.olabel,
809810
graph_cost, ac_cost, tok->links);
810811
}
811812
} // for all arcs
@@ -855,14 +856,15 @@ void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
855856
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
856857
StateId state = e->key;
857858
if (fst_->NumInputEpsilons(state) != 0)
858-
queue_.push_back(state);
859+
queue_.push_back(e);
859860
}
860861

861862
while (!queue_.empty()) {
862-
StateId state = queue_.back();
863+
const Elem *e = queue_.back();
863864
queue_.pop_back();
864865

865-
Token *tok = toks_.Find(state)->val; // would segfault if state not in toks_ but this can't happen.
866+
StateId state = e->key;
867+
Token *tok = e->val; // would segfault if e is a NULL pointer but this can't happen.
866868
BaseFloat cur_cost = tok->tot_cost;
867869
if (cur_cost > cutoff) // Don't bother processing successors.
868870
continue;
@@ -882,16 +884,16 @@ void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
882884
if (tot_cost < cutoff) {
883885
bool changed;
884886

885-
Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
887+
Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
886888
tok, &changed);
887889

888-
tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
890+
tok->links = new ForwardLinkT(e_new->val, 0, arc.olabel,
889891
graph_cost, 0, tok->links);
890892

891893
// "changed" tells us whether the new token has a different
892894
// cost from before, or is new [if so, add into queue].
893895
if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
894-
queue_.push_back(arc.nextstate);
896+
queue_.push_back(e_new);
895897
}
896898
}
897899
} // for all arcs

src/decoder/lattice-faster-decoder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,9 @@ class LatticeFasterDecoderTpl {
380380
// token was newly created or the cost changed.
381381
// If Token == StdToken, the 'backpointer' argument has no purpose (and will
382382
// hopefully be optimized out).
383-
inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
384-
BaseFloat tot_cost, Token *backpointer,
385-
bool *changed);
383+
inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
384+
BaseFloat tot_cost, Token *backpointer,
385+
bool *changed);
386386

387387
// prunes outgoing links for all tokens in active_toks_[frame]
388388
// it's called by PruneActiveTokens
@@ -464,7 +464,7 @@ class LatticeFasterDecoderTpl {
464464
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
465465
// frame (members of TokenList are toks, must_prune_forward_links,
466466
// must_prune_tokens).
467-
std::vector<StateId> queue_; // temp variable used in ProcessNonemitting,
467+
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
468468
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
469469

470470
// fst_ is a pointer to the FST we are decoding from.

src/util/hash-list-inl.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,24 @@ HashList<I, T>::~HashList() {
121121
}
122122
}
123123

124-
125124
template<class I, class T>
126-
void HashList<I, T>::Insert(I key, T val) {
125+
inline typename HashList<I, T>::Elem* HashList<I, T>::Insert(I key, T val) {
127126
size_t index = (static_cast<size_t>(key) % hash_size_);
128127
HashBucket &bucket = buckets_[index];
128+
// Check the element is existing or not.
129+
if (bucket.last_elem != NULL) {
130+
Elem *head = (bucket.prev_bucket == static_cast<size_t>(-1) ?
131+
list_head_ :
132+
buckets_[bucket.prev_bucket].last_elem->tail),
133+
*tail = bucket.last_elem->tail;
134+
for (Elem *e = head; e != tail; e = e->tail)
135+
if (e->key == key) return e;
136+
}
137+
138+
// This is a new element. Insert it.
129139
Elem *elem = New();
130140
elem->key = key;
131141
elem->val = val;
132-
133142
if (bucket.last_elem == NULL) { // Unoccupied bucket. Insert at
134143
// head of bucket list (which is tail of regular list, they go in
135144
// opposite directions).
@@ -152,6 +161,7 @@ void HashList<I, T>::Insert(I key, T val) {
152161
bucket.last_elem->tail = elem;
153162
bucket.last_elem = elem;
154163
}
164+
return elem;
155165
}
156166

157167
template<class I, class T>

0 commit comments

Comments
 (0)