Skip to content

Commit fdb69b8

Browse files
committed
updating the PR
1 parent 19fad73 commit fdb69b8

File tree

3 files changed

+17
-42
lines changed

3 files changed

+17
-42
lines changed

egs/wsj/s5/steps/nnet3/chain/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,9 @@ def process_args(args):
239239
If you have GPUs and have nvcc installed, go to src/ and do
240240
./configure; make""")
241241

242-
run_opts.train_queue_opt = "--gpu 1 --mem 10G"
242+
run_opts.train_queue_opt = "--gpu 1"
243243
run_opts.parallel_train_opts = "--use-gpu={}".format(args.use_gpu)
244-
run_opts.combine_queue_opt = "--gpu 1 --mem 10G"
244+
run_opts.combine_queue_opt = "--gpu 1"
245245
run_opts.combine_gpu_opt = "--use-gpu={}".format(args.use_gpu)
246246

247247
else:

src/nnet3/decodable-online-looped.cc

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,6 @@ void DecodableNnetLoopedOnlineBase::AdvanceChunk() {
162162
input_frame = num_feature_frames_ready - 1;
163163
input_features_->GetFrame(input_frame, &this_row);
164164
}
165-
166-
// dump 1st chunk of features (debug),
167-
if (num_chunks_computed_ == 0)
168-
KALDI_VLOG(100) << "feature-dump " << this_feats;
169-
170165
feats_chunk.Swap(&this_feats);
171166
}
172167
computer_.AcceptInput("input", &feats_chunk);
@@ -205,11 +200,6 @@ void DecodableNnetLoopedOnlineBase::AdvanceChunk() {
205200
Matrix<BaseFloat> ivectors(num_ivectors,
206201
ivector.Dim());
207202
ivectors.CopyRowsFromVec(ivector);
208-
209-
// dump 1st chunk of ivectors (debug),
210-
if (num_chunks_computed_ == 0)
211-
KALDI_VLOG(100) << "ivector-dump " << ivectors;
212-
213203
CuMatrix<BaseFloat> cu_ivectors;
214204
cu_ivectors.Swap(&ivectors);
215205
computer_.AcceptInput("ivector", &cu_ivectors);

src/online2bin/online2-tcp-nnet3-decode-faster.cc

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
#include <arpa/inet.h>
3939
#include <unistd.h>
4040
#include <string>
41-
#include <iomanip>
4241

4342
namespace kaldi {
4443

@@ -86,25 +85,20 @@ std::string LatticeToString(const Lattice &lat, const fst::SymbolTable &word_sym
8685
return msg.str();
8786
}
8887

89-
std::string GetTimeString(int32 t_beg, int32 t_end, int32 subsamp) {
88+
std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit) {
9089
char buffer[100];
91-
double t_beg2 = t_beg * subsamp / 100.;
92-
double t_end2 = t_end * subsamp / 100.;
93-
snprintf(buffer, 100, "{ %5.2f --> %5.2f }", t_beg2, t_end2);
90+
double t_beg2 = t_beg * time_unit;
91+
double t_end2 = t_end * time_unit;
92+
snprintf(buffer, 100, "%.2f %.2f", t_beg2, t_end2);
9493
return std::string(buffer);
9594
}
9695

9796
int32 GetLatticeTimeSpan(const Lattice& lat) {
98-
// convert the lattice,
99-
CompactLattice clat;
100-
ConvertLattice(lat, &clat);
101-
// get the word-alignment,
97+
LatticeWeight weight;
98+
std::vector<int32> alignment;
10299
std::vector<int32> words;
103-
std::vector<int32> begin_times;
104-
std::vector<int32> lengths;
105-
CompactLatticeToWordAlignment(clat, &words, &begin_times, &lengths);
106-
// get ending time of last word,
107-
return begin_times.back() + lengths.back();
100+
GetLinearSymbolSequence(lat, &alignment, &words, &weight);
101+
return alignment.size();
108102
}
109103

110104
std::string LatticeToString(const CompactLattice &clat, const fst::SymbolTable &word_syms) {
@@ -169,7 +163,7 @@ int main(int argc, char *argv[]) {
169163
po.Register("port-num", &port_num,
170164
"Port number the server will listen on.");
171165
po.Register("produce-time", &produce_time,
172-
"Send 'sentence' begin/end times based on end-points");
166+
"Prepend begin/end times between endpoints (e.g. '5.46 6.81 <text_output>0', in seconds)");
173167

174168
feature_opts.Register(&po);
175169
decodable_opts.Register(&po);
@@ -189,6 +183,9 @@ int main(int argc, char *argv[]) {
189183

190184
OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
191185

186+
BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
187+
int32 frame_subsampling = decodable_opts.frame_subsampling_factor;
188+
192189
KALDI_VLOG(1) << "Loading AM...";
193190

194191
TransitionModel trans_model;
@@ -256,7 +253,6 @@ int main(int argc, char *argv[]) {
256253
eos = !server.ReadChunk(chunk_len);
257254

258255
if (eos) {
259-
KALDI_VLOG(2) << "eos detected";
260256
feature_pipeline.InputFinished();
261257
decoder.AdvanceDecoding();
262258
decoder.FinalizeDecoding();
@@ -270,8 +266,7 @@ int main(int argc, char *argv[]) {
270266
if (produce_time) {
271267
int32 t_beg = frame_offset - decoder.NumFramesDecoded();
272268
int32 t_end = frame_offset;
273-
int32 s = decodable_opts.frame_subsampling_factor;
274-
msg = GetTimeString(t_beg, t_end, s) + " " + msg;
269+
msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
275270
}
276271

277272
KALDI_VLOG(1) << "EndOfAudio, sending message: " << msg;
@@ -284,13 +279,7 @@ int main(int argc, char *argv[]) {
284279

285280
Vector<BaseFloat> wave_part = server.GetChunk();
286281
feature_pipeline.AcceptWaveform(samp_freq, wave_part);
287-
288-
// dump 1st chunk of 'raw' audio (debug),
289-
if (samp_count == 0)
290-
KALDI_VLOG(100) << "raw audio " << wave_part;
291-
292282
samp_count += chunk_len;
293-
KALDI_VLOG(2) << "samp_count " << samp_count;
294283

295284
if (silence_weighting.Active() &&
296285
feature_pipeline.IvectorFeature() != NULL) {
@@ -299,10 +288,8 @@ int main(int argc, char *argv[]) {
299288
&delta_weights);
300289
feature_pipeline.UpdateFrameWeights(delta_weights,
301290
frame_offset * decodable_opts.frame_subsampling_factor);
302-
KALDI_VLOG(2) << "silence weighting";
303291
}
304292

305-
KALDI_VLOG(2) << "Advance decoding";
306293
decoder.AdvanceDecoding();
307294

308295
if (samp_count > check_count) {
@@ -315,8 +302,7 @@ int main(int argc, char *argv[]) {
315302
if (produce_time) {
316303
int32 t_beg = frame_offset;
317304
int32 t_end = frame_offset + GetLatticeTimeSpan(lat);
318-
int32 s = decodable_opts.frame_subsampling_factor;
319-
msg = GetTimeString(t_beg, t_end, s) + " " + msg;
305+
msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
320306
}
321307

322308
KALDI_VLOG(1) << "Temporary transcript: " << msg;
@@ -336,8 +322,7 @@ int main(int argc, char *argv[]) {
336322
if (produce_time) {
337323
int32 t_beg = frame_offset - decoder.NumFramesDecoded();
338324
int32 t_end = frame_offset;
339-
int32 s = decodable_opts.frame_subsampling_factor;
340-
msg = GetTimeString(t_beg, t_end, s) + " " + msg;
325+
msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
341326
}
342327

343328
KALDI_VLOG(1) << "Endpoint, sending message: " << msg;

0 commit comments

Comments
 (0)