@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
2626 MultiFileReader (const std::vector<std::string>& file_names,
2727 const std::vector<framework::DDim>& dims, size_t thread_num,
2828 size_t buffer_size)
29- : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
29+ : buffer_size_(buffer_size) {
30+ readers_.reserve (file_names.size ());
31+ for (const std::string& f_name : file_names) {
32+ readers_.emplace_back (CreateReaderByFileName (f_name, dims));
33+ }
3034 prefetchers_.resize (thread_num);
3135 StartNewScheduler ();
3236 }
@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
4044 void StartNewScheduler ();
4145 void EndScheduler ();
4246 void ScheduleThreadFunc ();
43- void PrefetchThreadFunc (std::string file_name , size_t thread_idx);
47+ void PrefetchThreadFunc (size_t reader_idx , size_t thread_idx);
4448
45- std::vector<std::string> file_names_;
46- std::vector<framework::DDim> dims_;
49+ std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
4750 std::thread scheduler_;
4851 std::vector<std::thread> prefetchers_;
4952 size_t buffer_size_;
50- reader::BlockingQueue<size_t >* waiting_file_idx_ ;
53+ reader::BlockingQueue<size_t >* waiting_reader_idx_ ;
5154 reader::BlockingQueue<size_t >* available_thread_idx_;
5255 reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
5356};
@@ -65,15 +68,15 @@ void MultiFileReader::ReInit() {
6568
6669void MultiFileReader::StartNewScheduler () {
6770 size_t thread_num = prefetchers_.size ();
68- waiting_file_idx_ = new reader::BlockingQueue<size_t >(file_names_ .size ());
71+ waiting_reader_idx_ = new reader::BlockingQueue<size_t >(readers_ .size ());
6972 available_thread_idx_ = new reader::BlockingQueue<size_t >(thread_num);
7073 buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
7174 buffer_size_);
7275
73- for (size_t i = 0 ; i < file_names_ .size (); ++i) {
74- waiting_file_idx_ ->Send (i);
76+ for (size_t i = 0 ; i < readers_ .size (); ++i) {
77+ waiting_reader_idx_ ->Send (i);
7578 }
76- waiting_file_idx_ ->Close ();
79+ waiting_reader_idx_ ->Close ();
7780 for (size_t i = 0 ; i < thread_num; ++i) {
7881 available_thread_idx_->Send (i);
7982 }
@@ -84,13 +87,13 @@ void MultiFileReader::StartNewScheduler() {
8487void MultiFileReader::EndScheduler () {
8588 available_thread_idx_->Close ();
8689 buffer_->Close ();
87- waiting_file_idx_ ->Close ();
90+ waiting_reader_idx_ ->Close ();
8891 if (scheduler_.joinable ()) {
8992 scheduler_.join ();
9093 }
9194 delete buffer_;
9295 delete available_thread_idx_;
93- delete waiting_file_idx_ ;
96+ delete waiting_reader_idx_ ;
9497}
9598
9699void MultiFileReader::ScheduleThreadFunc () {
@@ -102,12 +105,11 @@ void MultiFileReader::ScheduleThreadFunc() {
102105 if (prefetcher.joinable ()) {
103106 prefetcher.join ();
104107 }
105- size_t file_idx ;
106- if (waiting_file_idx_ ->Receive (&file_idx )) {
108+ size_t reader_idx ;
109+ if (waiting_reader_idx_ ->Receive (&reader_idx )) {
107110 // Still have files to read. Start a new prefetch thread.
108- std::string file_name = file_names_[file_idx];
109- prefetcher = std::thread ([this , file_name, thread_idx] {
110- PrefetchThreadFunc (file_name, thread_idx);
111+ prefetcher = std::thread ([this , reader_idx, thread_idx] {
112+ PrefetchThreadFunc (reader_idx, thread_idx);
111113 });
112114 } else {
113115 // No more file to read.
@@ -129,23 +131,22 @@ void MultiFileReader::ScheduleThreadFunc() {
129131 VLOG (5 ) << " MultiFileReader schedule thread terminates." ;
130132}
131133
132- void MultiFileReader::PrefetchThreadFunc (std::string file_name,
133- size_t thread_idx) {
134- VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' starts." ;
135- std::unique_ptr<framework::ReaderBase> reader =
136- CreateReaderByFileName (file_name, dims_);
134+ void MultiFileReader::PrefetchThreadFunc (size_t reader_idx, size_t thread_idx) {
135+ VLOG (5 ) << " The prefetch thread of file idx '" << reader_idx << " ' starts." ;
136+ std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
137137 while (true ) {
138138 std::vector<framework::LoDTensor> ins;
139139 reader->ReadNext (&ins);
140140 if (ins.empty ()) {
141+ reader->ReInit ();
141142 break ;
142143 }
143144 try {
144145 buffer_->Send (std::move (ins));
145146 } catch (paddle::platform::EnforceNotMet e) {
146147 VLOG (5 ) << " WARNING: The buffer channel has been closed. The prefetch "
147- " thread of file '"
148- << file_name << " ' will terminate." ;
148+ " thread of file idx '"
149+ << reader_idx << " ' will terminate." ;
149150 break ;
150151 }
151152 }
@@ -154,7 +155,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
154155 VLOG (5 ) << " WARNING: The available_thread_idx_ channel has been closed. "
155156 " Fail to send thread_idx." ;
156157 }
157- VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' terminates." ;
158+ VLOG (5 ) << " The prefetch thread of file idx '" << reader_idx
159+ << " ' terminates." ;
158160}
159161
160162class OpenFilesOp : public framework ::OperatorBase {
0 commit comments