@@ -413,23 +413,32 @@ def test_read_interlaced_png():
413413
414414
415415@needs_cuda
416- @pytest .mark .parametrize (
417- "img_path" ,
418- [pytest .param (jpeg_path , id = _get_safe_image_name (jpeg_path )) for jpeg_path in get_images (IMAGE_ROOT , ".jpg" )],
419- )
420416@pytest .mark .parametrize ("mode" , [ImageReadMode .UNCHANGED , ImageReadMode .GRAY , ImageReadMode .RGB ])
421417@pytest .mark .parametrize ("scripted" , (False , True ))
422- def test_decode_jpeg_cuda (mode , img_path , scripted ):
423- if "cmyk" in img_path :
424- pytest .xfail ("Decoding a CMYK jpeg isn't supported" )
418+ def test_decode_jpegs_cuda (mode , scripted ):
419+ encoded_images = []
420+ for jpeg_path in get_images (IMAGE_ROOT , ".jpg" ):
421+ if "cmyk" in jpeg_path :
422+ continue
423+ encoded_image = read_file (jpeg_path )
424+ encoded_images .append (encoded_image )
425+ decoded_images_cpu = decode_jpeg (encoded_images , mode = mode )
426+ decode_fn = torch .jit .script (decode_jpeg ) if scripted else decode_jpeg
425427
426- data = read_file (img_path )
427- img = decode_image (data , mode = mode )
428- f = torch .jit .script (decode_jpeg ) if scripted else decode_jpeg
429- img_nvjpeg = f (data , mode = mode , device = "cuda" )
428+ # test multithreaded decoding
429+ # in the current version we prevent this by using a lock but we still want to test it
430+ num_workers = 10
430431
431- # Some difference expected between jpeg implementations
432- assert (img .float () - img_nvjpeg .cpu ().float ()).abs ().mean () < 2
432+ with concurrent .futures .ThreadPoolExecutor (max_workers = num_workers ) as executor :
433+ futures = [executor .submit (decode_fn , encoded_images , mode , "cuda" ) for _ in range (num_workers )]
434+ decoded_images_threaded = [future .result () for future in futures ]
435+ assert len (decoded_images_threaded ) == num_workers
436+ for decoded_images in decoded_images_threaded :
437+ assert len (decoded_images ) == len (encoded_images )
438+ for decoded_image_cuda , decoded_image_cpu in zip (decoded_images , decoded_images_cpu ):
439+ assert decoded_image_cuda .shape == decoded_image_cpu .shape
440+ assert decoded_image_cuda .dtype == decoded_image_cpu .dtype == torch .uint8
441+ assert (decoded_image_cuda .cpu ().float () - decoded_image_cpu .cpu ().float ()).abs ().mean () < 2
433442
434443
435444@needs_cuda
@@ -440,25 +449,95 @@ def test_decode_image_cuda_raises():
440449
441450
442451@needs_cuda
443- @pytest .mark .parametrize ("cuda_device" , ("cuda" , "cuda:0" , torch .device ("cuda" )))
444- def test_decode_jpeg_cuda_device_param (cuda_device ):
445- """Make sure we can pass a string or a torch.device as device param"""
452+ def test_decode_jpeg_cuda_device_param ():
446453 path = next (path for path in get_images (IMAGE_ROOT , ".jpg" ) if "cmyk" not in path )
447454 data = read_file (path )
448- decode_jpeg (data , device = cuda_device )
455+ current_device = torch .cuda .current_device ()
456+ current_stream = torch .cuda .current_stream ()
457+ num_devices = torch .cuda .device_count ()
458+ devices = ["cuda" , torch .device ("cuda" )] + [torch .device (f"cuda:{ i } " ) for i in range (num_devices )]
459+ results = []
460+ for device in devices :
461+ results .append (decode_jpeg (data , device = device ))
462+ assert len (results ) == len (devices )
463+ for result in results :
464+ assert torch .all (result .cpu () == results [0 ].cpu ())
465+ assert current_device == torch .cuda .current_device ()
466+ assert current_stream == torch .cuda .current_stream ()
449467
450468
451469@needs_cuda
452470def test_decode_jpeg_cuda_errors ():
453471 data = read_file (next (get_images (IMAGE_ROOT , ".jpg" )))
454472 with pytest .raises (RuntimeError , match = "Expected a non empty 1-dimensional tensor" ):
455473 decode_jpeg (data .reshape (- 1 , 1 ), device = "cuda" )
456- with pytest .raises (RuntimeError , match = "input tensor must be on CPU" ):
474+ with pytest .raises (ValueError , match = "must be tensors" ):
475+ decode_jpeg ([1 , 2 , 3 ])
476+ with pytest .raises (ValueError , match = "Input tensor must be a CPU tensor" ):
457477 decode_jpeg (data .to ("cuda" ), device = "cuda" )
458478 with pytest .raises (RuntimeError , match = "Expected a torch.uint8 tensor" ):
459479 decode_jpeg (data .to (torch .float ), device = "cuda" )
460- with pytest .raises (RuntimeError , match = "Expected a cuda device" ):
461- torch .ops .image .decode_jpeg_cuda (data , ImageReadMode .UNCHANGED .value , "cpu" )
480+ with pytest .raises (RuntimeError , match = "Expected the device parameter to be a cuda device" ):
481+ torch .ops .image .decode_jpegs_cuda ([data ], ImageReadMode .UNCHANGED .value , "cpu" )
482+ with pytest .raises (ValueError , match = "Input tensor must be a CPU tensor" ):
483+ decode_jpeg (
484+ torch .empty ((100 ,), dtype = torch .uint8 , device = "cuda" ),
485+ )
486+ with pytest .raises (ValueError , match = "Input list must contain tensors on CPU" ):
487+ decode_jpeg (
488+ [
489+ torch .empty ((100 ,), dtype = torch .uint8 , device = "cuda" ),
490+ torch .empty ((100 ,), dtype = torch .uint8 , device = "cuda" ),
491+ ]
492+ )
493+
494+ with pytest .raises (ValueError , match = "Input list must contain tensors on CPU" ):
495+ decode_jpeg (
496+ [
497+ torch .empty ((100 ,), dtype = torch .uint8 , device = "cuda" ),
498+ torch .empty ((100 ,), dtype = torch .uint8 , device = "cuda" ),
499+ ],
500+ device = "cuda" ,
501+ )
502+
503+ with pytest .raises (ValueError , match = "Input list must contain tensors on CPU" ):
504+ decode_jpeg (
505+ [
506+ torch .empty ((100 ,), dtype = torch .uint8 , device = "cpu" ),
507+ torch .empty ((100 ,), dtype = torch .uint8 , device = "cuda" ),
508+ ],
509+ device = "cuda" ,
510+ )
511+
512+ with pytest .raises (RuntimeError , match = "Expected a torch.uint8 tensor" ):
513+ decode_jpeg (
514+ [
515+ torch .empty ((100 ,), dtype = torch .uint8 ),
516+ torch .empty ((100 ,), dtype = torch .float32 ),
517+ ],
518+ device = "cuda" ,
519+ )
520+
521+ with pytest .raises (RuntimeError , match = "Expected a non empty 1-dimensional tensor" ):
522+ decode_jpeg (
523+ [
524+ torch .empty ((100 ,), dtype = torch .uint8 ),
525+ torch .empty ((1 , 100 ), dtype = torch .uint8 ),
526+ ],
527+ device = "cuda" ,
528+ )
529+
530+ with pytest .raises (RuntimeError , match = "Error while decoding JPEG images" ):
531+ decode_jpeg (
532+ [
533+ torch .empty ((100 ,), dtype = torch .uint8 ),
534+ torch .empty ((100 ,), dtype = torch .uint8 ),
535+ ],
536+ device = "cuda" ,
537+ )
538+
539+ with pytest .raises (ValueError , match = "Input list must contain at least one element" ):
540+ decode_jpeg ([], device = "cuda" )
462541
463542
464543def test_encode_jpeg_errors ():
@@ -515,12 +594,10 @@ def test_encode_jpeg_cuda_device_param():
515594 devices = ["cuda" , torch .device ("cuda" )] + [torch .device (f"cuda:{ i } " ) for i in range (num_devices )]
516595 results = []
517596 for device in devices :
518- print (f"python: device: { device } " )
519597 results .append (encode_jpeg (data .to (device = device )))
520598 assert len (results ) == len (devices )
521599 for result in results :
522600 assert torch .all (result .cpu () == results [0 ].cpu ())
523-
524601 assert current_device == torch .cuda .current_device ()
525602 assert current_stream == torch .cuda .current_stream ()
526603
0 commit comments