@@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr(
6767 src->pub .next_input_byte = src->data ;
6868}
6969
70+ inline unsigned char clamped_cmyk_rgb_convert (
71+ unsigned char k,
72+ unsigned char cmy) {
73+ // Inspired from Pillow:
74+ // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
75+ int v = k * cmy + 128 ;
76+ v = ((v >> 8 ) + v) >> 8 ;
77+ return std::clamp (k - v, 0 , 255 );
78+ }
79+
80+ void convert_line_cmyk_to_rgb (
81+ j_decompress_ptr cinfo,
82+ const unsigned char * cmyk_line,
83+ unsigned char * rgb_line) {
84+ int width = cinfo->output_width ;
85+ for (int i = 0 ; i < width; ++i) {
86+ int c = cmyk_line[i * 4 + 0 ];
87+ int m = cmyk_line[i * 4 + 1 ];
88+ int y = cmyk_line[i * 4 + 2 ];
89+ int k = cmyk_line[i * 4 + 3 ];
90+
91+ rgb_line[i * 3 + 0 ] = clamped_cmyk_rgb_convert (k, 255 - c);
92+ rgb_line[i * 3 + 1 ] = clamped_cmyk_rgb_convert (k, 255 - m);
93+ rgb_line[i * 3 + 2 ] = clamped_cmyk_rgb_convert (k, 255 - y);
94+ }
95+ }
96+
97+ inline unsigned char rgb_to_gray (int r, int g, int b) {
98+ // Inspired from Pillow:
99+ // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
100+ return (r * 19595 + g * 38470 + b * 7471 + 0x8000 ) >> 16 ;
101+ }
102+
103+ void convert_line_cmyk_to_gray (
104+ j_decompress_ptr cinfo,
105+ const unsigned char * cmyk_line,
106+ unsigned char * gray_line) {
107+ int width = cinfo->output_width ;
108+ for (int i = 0 ; i < width; ++i) {
109+ int c = cmyk_line[i * 4 + 0 ];
110+ int m = cmyk_line[i * 4 + 1 ];
111+ int y = cmyk_line[i * 4 + 2 ];
112+ int k = cmyk_line[i * 4 + 3 ];
113+
114+ int r = clamped_cmyk_rgb_convert (k, 255 - c);
115+ int g = clamped_cmyk_rgb_convert (k, 255 - m);
116+ int b = clamped_cmyk_rgb_convert (k, 255 - y);
117+
118+ gray_line[i] = rgb_to_gray (r, g, b);
119+ }
120+ }
121+
70122} // namespace
71123
72124torch::Tensor decode_jpeg (const torch::Tensor& data, ImageReadMode mode) {
@@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
102154 jpeg_read_header (&cinfo, TRUE );
103155
104156 int channels = cinfo.num_components ;
157+ bool cmyk_to_rgb_or_gray = false ;
105158
106159 if (mode != IMAGE_READ_MODE_UNCHANGED) {
107160 switch (mode) {
108161 case IMAGE_READ_MODE_GRAY:
109- if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
162+ if (cinfo.jpeg_color_space == JCS_CMYK ||
163+ cinfo.jpeg_color_space == JCS_YCCK) {
164+ cinfo.out_color_space = JCS_CMYK;
165+ cmyk_to_rgb_or_gray = true ;
166+ } else {
110167 cinfo.out_color_space = JCS_GRAYSCALE;
111- channels = 1 ;
112168 }
169+ channels = 1 ;
113170 break ;
114171 case IMAGE_READ_MODE_RGB:
115- if (cinfo.jpeg_color_space != JCS_RGB) {
172+ if (cinfo.jpeg_color_space == JCS_CMYK ||
173+ cinfo.jpeg_color_space == JCS_YCCK) {
174+ cinfo.out_color_space = JCS_CMYK;
175+ cmyk_to_rgb_or_gray = true ;
176+ } else {
116177 cinfo.out_color_space = JCS_RGB;
117- channels = 3 ;
118178 }
179+ channels = 3 ;
119180 break ;
120181 /*
121182 * Libjpeg does not support converting from CMYK to grayscale etc. There
@@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
139200 auto tensor =
140201 torch::empty ({int64_t (height), int64_t (width), channels}, torch::kU8 );
141202 auto ptr = tensor.data_ptr <uint8_t >();
203+ torch::Tensor cmyk_line_tensor;
204+ if (cmyk_to_rgb_or_gray) {
205+ cmyk_line_tensor = torch::empty ({int64_t (width), 4 }, torch::kU8 );
206+ }
207+
142208 while (cinfo.output_scanline < cinfo.output_height ) {
143209 /* jpeg_read_scanlines expects an array of pointers to scanlines.
144210 * Here the array is only one element long, but you could ask for
145211 * more than one scanline at a time if that's more convenient.
146212 */
147- jpeg_read_scanlines (&cinfo, &ptr, 1 );
213+ if (cmyk_to_rgb_or_gray) {
214+ auto cmyk_line_ptr = cmyk_line_tensor.data_ptr <uint8_t >();
215+ jpeg_read_scanlines (&cinfo, &cmyk_line_ptr, 1 );
216+
217+ if (channels == 3 ) {
218+ convert_line_cmyk_to_rgb (&cinfo, cmyk_line_ptr, ptr);
219+ } else if (channels == 1 ) {
220+ convert_line_cmyk_to_gray (&cinfo, cmyk_line_ptr, ptr);
221+ }
222+ } else {
223+ jpeg_read_scanlines (&cinfo, &ptr, 1 );
224+ }
148225 ptr += stride;
149226 }
150227
0 commit comments