@@ -1012,11 +1012,6 @@ void RecurrentGradientMachine::generateSequence() {
10121012 /*  width */ 
10131013 false ,
10141014 /*  useGpu */ false );
1015-  Matrix::resizeOrCreate (generator_.outArg .value ,
1016-  /*  height */ 
1017-  /*  width */ 1 ,
1018-  false ,
1019-  /*  useGpu */ false );
10201015 }
10211016 ICpuGpuVector::resizeOrCreate (generator_.outArg .sequenceStartPositions ,
10221017 numSequences + 1 ,
@@ -1026,7 +1021,7 @@ void RecurrentGradientMachine::generateSequence() {
10261021 } else  {
10271022 oneWaySearch (numSequences);
10281023 }
1029-  if  (dataArgsSize_) createDataOutlink (batchMachineIdVec_ );
1024+  if  (dataArgsSize_) createDataOutlink ();
10301025
10311026 size_t  size = generator_.ids .size ();
10321027 generator_.outArg .ids ->resize (size);
@@ -1106,6 +1101,7 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
11061101 }
11071102
11081103 batchMachineIdVec_.clear ();
1104+  batchMachineStartPos_.clear ();
11091105 int * starts = generator_.outArg .sequenceStartPositions ->getMutableData (false );
11101106 starts[0 ] = 0 ;
11111107 generator_.ids .clear ();
@@ -1312,13 +1308,20 @@ void RecurrentGradientMachine::fillGenOutputs() {
13121308 finalPaths_[i].resize (minFinalPathsSize);
13131309 }
13141310
1315-  batchMachineIdVec_.clear ();
13161311 generator_.ids .clear ();
13171312 int * starts = generator_.outArg .sequenceStartPositions ->getMutableData (false );
13181313 starts[0 ] = 0 ;
13191314 if  (numResults > 1 ) {
1320-  real* probs = generator_.outArg .in ->getData ();
1315+  int  idsProbSaveSize = 0 ;
1316+  for  (auto  inSeq : finalPaths_) {
1317+  for  (auto  path : inSeq) idsProbSaveSize += path.ids .size ();
1318+  idsProbSaveSize += inSeq.size ();
1319+  }
1320+  Matrix::resizeOrCreate (
1321+  generator_.outArg .value , idsProbSaveSize, 1 , false , false );
13211322 real* idsProb = generator_.outArg .value ->getData ();
1323+ 
1324+  real* probs = generator_.outArg .in ->getData ();
13221325 size_t  curPos = 0 ;
13231326 for  (size_t  i = 0 ; i < finalPaths_.size (); ++i) {
13241327 for  (size_t  j = 0 ; j < finalPaths_[i].size (); ++j) {
@@ -1333,24 +1336,16 @@ void RecurrentGradientMachine::fillGenOutputs() {
13331336 curPos += genLen;
13341337 idsProb[curPos++] = -1.0 ;
13351338 probs[i * numResults + j] = path.logProb ;
1336- 
1337-  if  (!j && dataArgsSize_) {
1338-  //  in beam search, here only reserved the top 1 generated result
1339-  //  for out_links that are not the generated word indices.
1340-  batchMachineIdVec_.insert (batchMachineIdVec_.end (),
1341-  path.machineIdVec .begin (),
1342-  path.machineIdVec .end ());
1343-  }
13441339 }
13451340 starts[i + 1 ] = generator_.ids .size ();
13461341 }
13471342 } else  {
13481343 for  (size_t  i = 0 ; i < finalPaths_.size (); ++i) {
13491344 CHECK (!finalPaths_[i].empty ());
1350-  generator_. ids . insert (generator_. ids . begin (), 
1351-   finalPaths_[i][ 0 ] .ids .begin (), 
1352-   finalPaths_[i][ 0 ] .ids .end ());
1353-  starts[i + 1 ] = starts[i] + finalPaths_[i][ 0 ] .ids .size ();
1345+  Path& path = finalPaths_[i][ 0 ]; 
1346+  generator_ .ids .insert ( 
1347+  generator_. ids . begin (), path. ids . begin (), path .ids .end ());
1348+  starts[i + 1 ] = starts[i] + path .ids .size ();
13541349 }
13551350 }
13561351}
@@ -1364,25 +1359,76 @@ void RecurrentGradientMachine::copyDataOutlinkFrame(size_t machineCur) {
13641359 }
13651360}
13661361
1367- void  RecurrentGradientMachine::createDataOutlink (
1368-  std::vector<int >& machineIdVec) {
1369-  size_t  seqNum =
1370-  getBeamSize () > 1UL  ? finalPaths_.size () : finalPaths_[0 ].size ();
1371-  std::vector<int > starts (seqNum + 1 , 0 );
1372-  for  (size_t  i = 0 ; i < seqNum; ++i) {
1373-  size_t  seqLen = getBeamSize () > 1UL  ? finalPaths_[i][0 ].ids .size ()
1374-  : finalPaths_[0 ][i].ids .size ();
1375-  starts[i + 1 ] = starts[i] + seqLen;
1362+ void  RecurrentGradientMachine::createDataOutlinkSelRowsInfo (
1363+  bool  isSeq, std::vector<Argument>& outArgs) {
1364+  batchMachineIdVec_.clear ();
1365+ 
1366+  size_t  seqIdx = 0 ;
1367+  for  (size_t  i = 0 ; i < finalPaths_.size (); ++i) {
1368+  for  (size_t  j = 0 ; j < finalPaths_[i].size (); ++j) {
1369+  std::vector<int >& machineIdVec = finalPaths_[i][j].machineIdVec ;
1370+  if  (isSeq) {
1371+  for  (size_t  i = 0 ; i < machineIdVec.size (); ++i) {
1372+  size_t  rowId = machineIdVec[i];
1373+  int * seqPos =
1374+  outArgs[i].sequenceStartPositions ->getMutableData (false );
1375+  batchMachineIdVec_.push_back (seqPos[rowId]);
1376+  }
1377+  } else  {
1378+  batchMachineIdVec_.insert (
1379+  batchMachineIdVec_.end (), machineIdVec.begin (), machineIdVec.end ());
1380+  }
1381+  seqIdx++;
1382+  }
1383+  }
1384+ }
1385+ 
1386+ void  RecurrentGradientMachine::createDataOutlinkCopySizeInfo (
1387+  bool  isSeq, std::vector<Argument>& outArgs, std::vector<int >& copySize) {
1388+  size_t  totalSeqNum = std::accumulate (
1389+  finalPaths_.begin (),
1390+  finalPaths_.end (),
1391+  0UL ,
1392+  [](size_t  a, const  std::vector<Path>& b) { return  a + b.size (); });
1393+  copySize.resize (totalSeqNum, 1 );
1394+ 
1395+  batchMachineStartPos_.resize (totalSeqNum + 1 , 0 );
1396+  if  (isSeq) {
1397+  ICpuGpuVectorPtr inputSeqStartPos = outArgs[0 ].sequenceStartPositions ;
1398+  CHECK_EQ (static_cast <size_t >(inputSeqStartPos->getSize () - 1 ),
1399+  getBeamSize () > 1  ? finalPaths_.size () : finalPaths_[0 ].size ());
1400+  int * starts = inputSeqStartPos->getMutableData (false );
1401+  int  seqId = 0 ;
1402+  for  (int  i = 0 ; i < finalPaths_.size (); ++i) {
1403+  for  (int  j = 0 ; j < finalPaths_[i].size (); ++j) {
1404+  copySize[seqId] = getBeamSize () > 1  ? starts[i + 1 ] - starts[i]
1405+  : starts[j + 1 ] - starts[j];
1406+  batchMachineStartPos_[seqId + 1 ] =
1407+  batchMachineStartPos_[seqId] + finalPaths_[i][j].ids .size ();
1408+  seqId++;
1409+  }
1410+  }
1411+  } else  {
1412+  for  (size_t  i = 0 ; i < finalPaths_[0 ].size (); ++i)
1413+  batchMachineStartPos_[i + 1 ] =
1414+  batchMachineStartPos_[i] + finalPaths_[0 ][i].ids .size ();
13761415 }
1416+ }
13771417
1418+ void  RecurrentGradientMachine::createDataOutlink () {
13781419 for  (size_t  i = 0 ; i < dataArgsSize_; i++) {
1420+  bool  isSeq = dataArgsFrame_[i][0 ].hasSeq ();
1421+  std::vector<int > copySize;
1422+  createDataOutlinkCopySizeInfo (isSeq, dataArgsFrame_[i], copySize);
1423+  createDataOutlinkSelRowsInfo (isSeq, dataArgsFrame_[i]);
1424+ 
13791425 dataArgs_[i].concat (dataArgsFrame_[i],
1380-  machineIdVec,
1381-  starts,
1426+  batchMachineIdVec_,
1427+  batchMachineStartPos_,
1428+  copySize,
13821429 useGpu_,
13831430 HPPL_STREAM_1,
13841431 PASS_TEST);
1385- 
13861432 auto  dataAgent =
13871433 dynamic_cast <DataLayer*>(outFrameLines_[i + 1 ].agentLayer .get ());
13881434 CHECK_NOTNULL (dataAgent);
0 commit comments