@@ -11,7 +11,6 @@ using namespace exif_private;
1111torch::Tensor decode_png (
1212 const torch::Tensor& data,
1313 ImageReadMode mode,
14- bool allow_16_bits,
1514 bool apply_exif_orientation) {
1615 TORCH_CHECK (
1716 false , " decode_png: torchvision not compiled with libPNG support" );
@@ -26,7 +25,6 @@ bool is_little_endian() {
2625torch::Tensor decode_png (
2726 const torch::Tensor& data,
2827 ImageReadMode mode,
29- bool allow_16_bits,
3028 bool apply_exif_orientation) {
3129 C10_LOG_API_USAGE_ONCE (" torchvision.csrc.io.image.cpu.decode_png.decode_png" );
3230 // Check that the input tensor dtype is uint8
@@ -99,12 +97,12 @@ torch::Tensor decode_png(
9997 TORCH_CHECK (retval == 1 , " Could read image metadata from content." )
10098 }
10199
102- auto max_bit_depth = allow_16_bits ? 16 : 8 ;
103- auto err_msg = " At most " + std::to_string (max_bit_depth) +
104- " -bit PNG images are supported currently." ;
105- if (bit_depth > max_bit_depth) {
100+ if (bit_depth > 8 && bit_depth != 16 ) {
106101 png_destroy_read_struct (&png_ptr, &info_ptr, nullptr );
107- TORCH_CHECK (false , err_msg)
102+ TORCH_CHECK (
103+ false ,
104+ " bit depth of png image is " + std::to_string (bit_depth) +
105+ " . Only <=8 and 16 are supported." )
108106 }
109107
110108 int channels = png_get_channels (png_ptr, info_ptr);
@@ -199,45 +197,20 @@ torch::Tensor decode_png(
199197 }
200198
201199 auto num_pixels_per_row = width * channels;
200+ auto is_16_bits = bit_depth == 16 ;
202201 auto tensor = torch::empty (
203202 {int64_t (height), int64_t (width), channels},
204- bit_depth <= 8 ? torch::kU8 : torch::kI32 );
205-
206- if (bit_depth <= 8 ) {
207- auto t_ptr = tensor.accessor <uint8_t , 3 >().data ();
208- for (int pass = 0 ; pass < number_of_passes; pass++) {
209- for (png_uint_32 i = 0 ; i < height; ++i) {
210- png_read_row (png_ptr, t_ptr, nullptr );
211- t_ptr += num_pixels_per_row;
212- }
213- t_ptr = tensor.accessor <uint8_t , 3 >().data ();
214- }
215- } else {
216- // We're reading a 16bits png, but pytorch doesn't support uint16.
217- // So we read each row in a 16bits tmp_buffer which we then cast into
218- // a int32 tensor instead.
219- if (is_little_endian ()) {
220- png_set_swap (png_ptr);
221- }
222- int32_t * t_ptr = tensor.accessor <int32_t , 3 >().data ();
223-
224- // We create a tensor instead of malloc-ing for automatic memory management
225- auto tmp_buffer_tensor = torch::empty (
226- {int64_t (num_pixels_per_row * sizeof (uint16_t ))}, torch::kU8 );
227- uint16_t * tmp_buffer =
228- (uint16_t *)tmp_buffer_tensor.accessor <uint8_t , 1 >().data ();
229-
230- for (int pass = 0 ; pass < number_of_passes; pass++) {
231- for (png_uint_32 i = 0 ; i < height; ++i) {
232- png_read_row (png_ptr, (uint8_t *)tmp_buffer, nullptr );
233- // Now we copy the uint16 values into the int32 tensor.
234- for (size_t j = 0 ; j < num_pixels_per_row; ++j) {
235- t_ptr[j] = (int32_t )tmp_buffer[j];
236- }
237- t_ptr += num_pixels_per_row;
238- }
239- t_ptr = tensor.accessor <int32_t , 3 >().data ();
203+ is_16_bits ? at::kUInt16 : torch::kU8 );
204+ if (is_little_endian ()) {
205+ png_set_swap (png_ptr);
206+ }
207+ auto t_ptr = (uint8_t *)tensor.data_ptr ();
208+ for (int pass = 0 ; pass < number_of_passes; pass++) {
209+ for (png_uint_32 i = 0 ; i < height; ++i) {
210+ png_read_row (png_ptr, t_ptr, nullptr );
211+ t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1 );
240212 }
213+ t_ptr = (uint8_t *)tensor.data_ptr ();
241214 }
242215
243216 int exif_orientation = -1 ;
0 commit comments