Skip to content

Commit ea85eed

Browse files
authored
[src] Add batched xvector computation (#3643)
1 parent 880df12 commit ea85eed

File tree

6 files changed

+705
-6
lines changed

6 files changed

+705
-6
lines changed

src/nnet3/nnet-general-component.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ void StatisticsPoolingComponent::InitFromConfig(ConfigLine *cfl) {
573573
StatisticsPoolingComponent::StatisticsPoolingComponent():
574574
input_dim_(-1), input_period_(1), left_context_(-1), right_context_(-1),
575575
num_log_count_features_(0), output_stddevs_(false),
576-
variance_floor_(1.0e-10) { }
576+
variance_floor_(1.0e-10), require_direct_input_(false) { }
577577

578578

579579
StatisticsPoolingComponent::StatisticsPoolingComponent(
@@ -582,7 +582,8 @@ StatisticsPoolingComponent::StatisticsPoolingComponent(
582582
left_context_(other.left_context_), right_context_(other.right_context_),
583583
num_log_count_features_(other.num_log_count_features_),
584584
output_stddevs_(other.output_stddevs_),
585-
variance_floor_(1.0e-10) {
585+
variance_floor_(other.variance_floor_),
586+
require_direct_input_(other.require_direct_input_) {
586587
Check();
587588
}
588589

@@ -614,6 +615,9 @@ void StatisticsPoolingComponent::Read(std::istream &is, bool binary) {
614615
ExpectToken(is, binary, "<VarianceFloor>");
615616
ReadBasicType(is, binary, &variance_floor_);
616617
ExpectToken(is, binary, "</StatisticsPoolingComponent>");
618+
require_direct_input_ = false; // This is not written to disk, it's only used
619+
// temporarily, in memory (see
620+
// nnet3-xvector-compute-batched.cc).
617621
Check();
618622
}
619623

src/nnet3/nnet-general-component.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ class StatisticsExtractionComponentPrecomputedIndexes:
331331
or whatever, instead of just component-name, because its output is only defined at multiples
332332
of its input-period.
333333
334-
The output of StatisticsPoolingComponent will only be defined if at least one input was defined.
334+
The output of StatisticsPoolingComponent will only be defined if at least one
335+
input was defined.
335336
*/
336337
class StatisticsPoolingComponent: public Component {
337338
public:
@@ -396,6 +397,11 @@ class StatisticsPoolingComponent: public Component {
396397
const std::vector<Index> &output_indexes,
397398
bool need_backprop) const;
398399

400+
// Used in computing the 'real' context of networks involving this component;
401+
// with the default value of false, the left/right context will always appear
402+
// to be 0.
403+
void SetRequireDirectInput(bool b) { require_direct_input_ = b; }
404+
399405
private:
400406
// Checks that the parameters are valid.
401407
void Check() const;
@@ -411,6 +417,13 @@ class StatisticsPoolingComponent: public Component {
411417
int32 num_log_count_features_;
412418
bool output_stddevs_;
413419
BaseFloat variance_floor_;
420+
// If require_direct_input_ is set to true, in order for a particular 't'
421+
// value to be available at the output of this component, it will require that
422+
// 't' value to be computable at the input. This is used in computing the
423+
// "real" left/right context of the network, but this member isn't currently
424+
// written to disk and will default to false when read.
425+
bool require_direct_input_;
426+
414427
};
415428

416429
class StatisticsPoolingComponentPrecomputedIndexes:

src/nnet3/nnet-utils.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,15 @@ void SetNnetAsGradient(Nnet *nnet) {
300300
}
301301
}
302302

303+
void SetRequireDirectInput(bool b, Nnet *nnet) {
304+
for (int32 c = 0; c < nnet->NumComponents(); c++) {
305+
Component *comp = nnet->GetComponent(c);
306+
if (dynamic_cast<StatisticsPoolingComponent*>(comp) != NULL)
307+
dynamic_cast<StatisticsPoolingComponent*>(comp)->SetRequireDirectInput(b);
308+
}
309+
}
310+
311+
303312
void ScaleNnet(BaseFloat scale, Nnet *nnet) {
304313
if (scale == 1.0) return;
305314
else {
@@ -724,7 +733,7 @@ class SvdApplier {
724733
<< " components to FixedAffineComponent.";
725734
}
726735

727-
// This function finds the minimum index of
736+
// This function finds the minimum index of
728737
// the Descending order sorted [input_vector],
729738
// over a range of indices from [lower] to [upper] index,
730739
// for which the sum of elements up to the found min. index is greater
@@ -743,7 +752,7 @@ class SvdApplier {
743752
}
744753
return (i+1);
745754
}
746-
755+
747756
// Here we perform SVD based refactoring of an input Affine component.
748757
// After applying SVD, we sort the singular values in descending order,
749758
// and take the subset of values which contribute to energy_threshold times
@@ -777,7 +786,7 @@ class SvdApplier {
777786
if (energy_threshold_ > 0) {
778787
BaseFloat min_singular_sum = energy_threshold_ * s2_sum_orig;
779788
bottleneck_dim_ = GetReducedDimension(s2, 0, s2.Dim()-1, min_singular_sum);
780-
}
789+
}
781790
SubVector<BaseFloat> this_part(s2, 0, bottleneck_dim_);
782791
BaseFloat s2_sum_reduced = this_part.Sum();
783792
BaseFloat shrinkage_ratio =

src/nnet3/nnet-utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ void ScaleNnet(BaseFloat scale, Nnet *nnet);
118118
/// learning_rate_ to 1 for each UpdatableComponent in nnet
119119
void SetNnetAsGradient(Nnet *nnet);
120120

121+
122+
/// Calls the corresponding function in any component of type
123+
/// StatisticsPoolingComponent; used as a way to compute the 'real' left-right
124+
/// context of networks including StatisticsPoolingComponent, which will give you
125+
/// the minimum chunk size they can consume.
126+
void SetRequireDirectInput(bool b, Nnet *nnet);
127+
128+
121129
/// Does *dest += alpha * src (affects nnet parameters and
122130
/// stored stats).
123131
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest);

src/nnet3bin/Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \
1919
nnet3-discriminative-subset-egs nnet3-get-egs-simple \
2020
nnet3-discriminative-compute-from-egs nnet3-latgen-faster-looped \
2121
nnet3-egs-augment-image nnet3-xvector-get-egs nnet3-xvector-compute \
22+
nnet3-xvector-compute-batched \
2223
nnet3-latgen-grammar nnet3-compute-batch nnet3-latgen-faster-batch \
2324
cuda-gpu-available cuda-compiled
2425

@@ -36,4 +37,5 @@ ADDLIBS = ../nnet3/kaldi-nnet3.a ../chain/kaldi-chain.a \
3637
../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
3738
../base/kaldi-base.a
3839

40+
3941
include ../makefiles/default_rules.mk

0 commit comments

Comments
 (0)