|
13 | 13 | limitations under the License. */ |
14 | 14 |
|
15 | 15 | #include "paddle/framework/lod_tensor.h" |
| 16 | +#include "paddle/framework/saver.pb.h" |
| 17 | + |
| 18 | +#include "paddle/memory/memcpy.h" |
| 19 | +#include "paddle/memory/memory.h" |
| 20 | + |
| 21 | +#include <stdint.h> |
| 22 | +#include <string.h> |
| 23 | +#include <algorithm> |
| 24 | +#include <iterator> |
16 | 25 |
|
17 | 26 | #include <glog/logging.h> |
18 | 27 |
|
@@ -112,5 +121,140 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, |
112 | 121 | lod_ = new_lod; |
113 | 122 | } |
114 | 123 |
|
| 124 | +std::string LoDTensor::SerializeToString() const { |
| 125 | + LoDTensorProto desc; |
| 126 | + |
| 127 | + // set data_type |
| 128 | + if (this->type() == typeid(int8_t)) desc.set_data_type(DataType::BOOL); |
| 129 | + if (this->type() == typeid(int16_t)) desc.set_data_type(DataType::INT16); |
| 130 | + if (this->type() == typeid(int32_t)) desc.set_data_type(DataType::INT32); |
| 131 | + if (this->type() == typeid(int64_t)) desc.set_data_type(DataType::INT64); |
| 132 | + // FIXME(dzh): there is no fp16 in standard c++ |
| 133 | + |
| 134 | + if (this->type() == typeid(float)) // NOLINT |
| 135 | + desc.set_data_type(DataType::FP32); |
| 136 | + if (this->type() == typeid(double)) // NOLINT |
| 137 | + desc.set_data_type(DataType::FP64); |
| 138 | + |
| 139 | + for (int i = 0; i < dims().size(); ++i) { |
| 140 | + desc.add_dims(dims()[i]); |
| 141 | + } |
| 142 | + |
| 143 | + // set lod information |
| 144 | + desc.set_lod_level(this->NumLevels()); |
| 145 | + for (size_t i = 0; i < this->NumLevels(); ++i) { |
| 146 | + LoDInfo* lod = desc.add_levels(); |
| 147 | + for (size_t j = 0; j < lod_[i].size(); ++j) { |
| 148 | + lod->add_level(lod_[i][j]); |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + desc.set_version(0); |
| 153 | + |
| 154 | + std::string desc_bytes = desc.SerializeAsString(); |
| 155 | + |
| 156 | + // FIXME(dzh) : implement fix chunk size buffer. |
| 157 | + size_t DESC_SIZE = desc_bytes.size(); |
| 158 | + size_t DATA_SIZE = holder_->size() - offset_; |
| 159 | + |
| 160 | + const size_t BUFFER_SIZE = DESC_SIZE + DATA_SIZE + 2 * sizeof(size_t); |
| 161 | + char* buffer = |
| 162 | + static_cast<char*>(memory::Alloc(platform::CPUPlace(), BUFFER_SIZE)); |
| 163 | + |
| 164 | + // format: desc_size data_size, desc_bytes, data_bytes. |
| 165 | + platform::CPUPlace src_place; |
| 166 | + platform::CPUPlace dst_place; |
| 167 | + |
| 168 | + memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t)); |
| 169 | + memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE, |
| 170 | + sizeof(size_t)); |
| 171 | + memory::Copy(dst_place, buffer + sizeof(size_t) * 2, src_place, |
| 172 | + desc_bytes.c_str(), desc_bytes.size()); |
| 173 | + |
| 174 | + PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!"); |
| 175 | + |
| 176 | + platform::Place place = holder_->place(); |
| 177 | + int element_width = holder_->size() / this->numel(); |
| 178 | + |
| 179 | + if (platform::is_cpu_place(place)) { |
| 180 | + memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(), |
| 181 | + boost::get<platform::CPUPlace>(place), |
| 182 | + static_cast<char*>(holder_->ptr()) + offset_ / element_width, |
| 183 | + DATA_SIZE); |
| 184 | + } |
| 185 | +#ifdef PADDLE_WITH_GPU |
| 186 | + if (platform::is_gpu_place(place)) { |
| 187 | + memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(), |
| 188 | + boost::get<platform::GPUPlace>(place), |
| 189 | + static_cast<char*>(holder_->ptr()) + offset_ / element_width, |
| 190 | + DATA_SIZE); |
| 191 | + } |
| 192 | +#endif |
| 193 | + |
| 194 | + std::string ret(buffer, BUFFER_SIZE); |
| 195 | + memory::Free(platform::CPUPlace(), buffer); |
| 196 | + return ret; |
| 197 | +} |
| 198 | + |
| 199 | +void LoDTensor::DeserializeFromString(const std::string& s, |
| 200 | + const platform::Place& dst_place) { |
| 201 | + size_t DESC_SIZE, BUFFER_SIZE; |
| 202 | + platform::CPUPlace src_place; |
| 203 | + |
| 204 | + memory::Copy(src_place, &BUFFER_SIZE, src_place, s.c_str(), sizeof(size_t)); |
| 205 | + memory::Copy(src_place, &DESC_SIZE, src_place, s.c_str() + sizeof(size_t), |
| 206 | + sizeof(size_t)); |
| 207 | + |
| 208 | + const size_t DATA_SIZE = BUFFER_SIZE - DESC_SIZE - sizeof(size_t) * 2; |
| 209 | + |
| 210 | + // parse LoDTensorDesc |
| 211 | + LoDTensorProto desc; |
| 212 | + desc.ParseFromArray(s.c_str() + sizeof(size_t) * 2, DESC_SIZE); |
| 213 | + |
| 214 | + std::vector<int64_t> dims; |
| 215 | + std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); |
| 216 | + this->Resize(make_ddim(dims)); |
| 217 | + |
| 218 | + // parse data type |
| 219 | + void* ptr = nullptr; |
| 220 | + if (desc.data_type() == DataType::BOOL) |
| 221 | + ptr = this->mutable_data<bool>(dst_place); |
| 222 | + if (desc.data_type() == DataType::INT16) |
| 223 | + ptr = this->mutable_data<int16_t>(dst_place); |
| 224 | + if (desc.data_type() == DataType::INT32) |
| 225 | + ptr = this->mutable_data<int32_t>(dst_place); |
| 226 | + if (desc.data_type() == DataType::INT64) |
| 227 | + ptr = this->mutable_data<int64_t>(dst_place); |
| 228 | + // FIXME(dzh): there is no fp16 in standard c++ |
| 229 | + |
| 230 | + if (desc.data_type() == DataType::FP32) |
| 231 | + ptr = this->mutable_data<float>(dst_place); |
| 232 | + if (desc.data_type() == DataType::FP64) |
| 233 | + ptr = this->mutable_data<double>(dst_place); |
| 234 | + |
| 235 | + LoD lod; |
| 236 | + std::vector<size_t> levels; |
| 237 | + for (int i = 0; i < desc.levels().size(); ++i) { |
| 238 | + auto current_level = desc.levels()[i].level(); |
| 239 | + std::copy(current_level.begin(), current_level.end(), |
| 240 | + std::back_inserter(levels)); |
| 241 | + lod.emplace_back(levels); |
| 242 | + levels.clear(); |
| 243 | + } |
| 244 | + |
| 245 | + this->set_lod(lod); |
| 246 | + |
| 247 | + if (platform::is_cpu_place(dst_place)) { |
| 248 | + memory::Copy(boost::get<platform::CPUPlace>(dst_place), ptr, src_place, |
| 249 | + s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE); |
| 250 | + } |
| 251 | +#ifdef PADDLE_WITH_GPU |
| 252 | + if (platform::is_gpu_place(dst_place)) { |
| 253 | + memory::Copy(boost::get<platform::GPUPlace>(dst_place), ptr, src_place, |
| 254 | + s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE); |
| 255 | + } |
| 256 | +#endif |
| 257 | +} |
| 258 | + |
115 | 259 | } // namespace framework |
116 | 260 | } // namespace paddle |
0 commit comments