1313# limitations under the License.
1414
1515import os
16- from typing import Sequence
16+ from typing import Sequence , Literal , Union
1717
1818import numpy as np
1919from nose2 .tools import params , cartesian_params
@@ -50,6 +50,7 @@ def build_resize_transform(
5050 max_size : int = None ,
5151 interpolation : transforms .InterpolationMode = transforms .InterpolationMode .BILINEAR ,
5252 antialias : bool = False ,
53+ device : Literal ["cpu" , "gpu" ] = "cpu" ,
5354):
5455 t = transforms .Compose (
5556 [
@@ -61,13 +62,59 @@ def build_resize_transform(
6162 td = Compose (
6263 [
6364 Resize (
64- size = resize , max_size = max_size , interpolation = interpolation , antialias = antialias
65+ size = resize ,
66+ max_size = max_size ,
67+ interpolation = interpolation ,
68+ antialias = antialias ,
69+ device = device ,
6570 ),
6671 ]
6772 )
6873 return t , td
6974
7075
76+ def _internal_loop (
77+ input_data : Union [Image .Image , torch .Tensor ],
78+ t : transforms .Resize ,
79+ td : Resize ,
80+ resize : int | Sequence [int ],
81+ max_size : int = None ,
82+ interpolation : transforms .InterpolationMode = transforms .InterpolationMode .BILINEAR ,
83+ antialias : bool = False ,
84+ ):
85+ out_fn = fn_tv .resize (
86+ input_data ,
87+ size = resize ,
88+ max_size = max_size ,
89+ interpolation = interpolation ,
90+ antialias = antialias ,
91+ )
92+ out_dali_fn = fn_dali .resize (
93+ input_data ,
94+ size = resize ,
95+ max_size = max_size ,
96+ interpolation = interpolation ,
97+ antialias = antialias ,
98+ )
99+ out_tv = t (input_data )
100+ out_dali_tv = td (input_data )
101+
102+ if isinstance (input_data , Image .Image ):
103+ out_tv = transforms .functional .pil_to_tensor (out_tv ).unsqueeze (0 ).permute (0 , 2 , 3 , 1 )
104+ out_dali_tv = (
105+ transforms .functional .pil_to_tensor (out_dali_tv ).unsqueeze (0 ).permute (0 , 2 , 3 , 1 )
106+ )
107+ out_fn = transforms .functional .pil_to_tensor (out_fn )
108+ out_dali_fn = transforms .functional .pil_to_tensor (out_dali_fn )
109+
110+ assert torch .allclose (
111+ torch .tensor (out_tv .shape [1 :3 ]), torch .tensor (out_dali_tv .shape [1 :3 ]), rtol = 0 , atol = 1
112+ ), f"Should be:{ out_tv .shape } is:{ out_dali_tv .shape } "
113+ assert torch .allclose (
114+ torch .tensor (out_fn .shape [1 :3 ]), torch .tensor (out_dali_fn .shape [1 :3 ]), rtol = 0 , atol = 1
115+ ), f"Should be:{ out_fn .shape } is:{ out_dali_fn .shape } "
116+
117+
71118def loop_images_test_no_build (
72119 t : transforms .Resize ,
73120 td : Resize ,
@@ -78,52 +125,66 @@ def loop_images_test_no_build(
78125):
79126 for fn in test_files :
80127 img = Image .open (fn )
81- out_fn = transforms .functional .pil_to_tensor (
82- fn_tv .resize (
83- img ,
84- size = resize ,
85- max_size = max_size ,
86- interpolation = interpolation ,
87- antialias = antialias ,
88- )
89- )
90- out_dali_fn = transforms .functional .pil_to_tensor (
91- fn_dali .resize (
92- img ,
93- size = resize ,
94- max_size = max_size ,
95- interpolation = interpolation ,
96- antialias = antialias ,
97- )
98- )
128+ _internal_loop (img , t , td , resize , max_size , interpolation , antialias )
129+ # assert torch.equal(out_tv, out_dali_tv)
99130
100- out_tv = transforms .functional .pil_to_tensor (t (img )).unsqueeze (0 ).permute (0 , 2 , 3 , 1 )
101- out_dali_tv = transforms .functional .pil_to_tensor (td (img )).unsqueeze (0 ).permute (0 , 2 , 3 , 1 )
102131
103- assert torch .allclose (
104- torch .tensor (out_tv .shape [1 :3 ]), torch .tensor (out_dali_tv .shape [1 :3 ]), rtol = 0 , atol = 1
105- ), f"Should be:{ out_tv .shape } is:{ out_dali_tv .shape } "
106- assert torch .allclose (
107- torch .tensor (out_fn .shape [1 :3 ]), torch .tensor (out_dali_fn .shape [1 :3 ]), rtol = 0 , atol = 1
108- ), f"Should be:{ out_fn .shape } is:{ out_dali_fn .shape } "
132+ def build_tensors (max_size : int = 512 , channels : int = 3 ):
133+ h = torch .randint (10 , max_size , (1 ,)).item ()
134+ w = torch .randint (10 , max_size , (1 ,)).item ()
135+ tensors = [
136+ torch .ones ((channels , max_size , max_size )),
137+ torch .ones ((1 , channels , max_size , max_size )),
138+ torch .ones ((10 , channels , max_size , max_size )),
139+ torch .ones ((channels , max_size // 2 , max_size )),
140+ torch .ones ((1 , channels , max_size // 2 , max_size )),
141+ torch .ones ((10 , channels , max_size // 2 , max_size )),
142+ torch .ones ((channels , max_size , max_size // 2 )),
143+ torch .ones ((1 , channels , max_size , max_size // 2 )),
144+ torch .ones ((10 , channels , max_size , max_size // 2 )),
145+ torch .ones ((channels , h , w )),
146+ torch .ones ((1 , channels , h , w )),
147+ torch .ones ((10 , channels , h , w )),
148+ ]
149+
150+ return tensors
151+
152+
153+ def loop_tensors_test (
154+ resize : int | Sequence [int ],
155+ max_size : int = None ,
156+ interpolation : transforms .InterpolationMode = transforms .InterpolationMode .BILINEAR ,
157+ antialias : bool = False ,
158+ device : Literal ["cpu" , "gpu" ] = "cpu" ,
159+ ):
160+ t , td = build_resize_transform (resize , max_size , interpolation , antialias , device )
161+ tensors = build_tensors ()
109162
110- # assert torch.equal(out_tv, out_dali_tv)
163+ for tn in tensors :
164+ _internal_loop (tn , t , td , resize , max_size , interpolation , antialias )
111165
112166
113167def loop_images_test (
114168 resize : int | Sequence [int ],
115169 max_size : int = None ,
116170 interpolation : transforms .InterpolationMode = transforms .InterpolationMode .BILINEAR ,
117171 antialias : bool = False ,
172+ device : Literal ["cpu" , "gpu" ] = "cpu" ,
118173):
119- t , td = build_resize_transform (resize , max_size , interpolation , antialias )
174+ t , td = build_resize_transform (resize , max_size , interpolation , antialias , device )
120175 loop_images_test_no_build (t , td , resize , max_size , interpolation , antialias )
121176
122177
123- @params (512 , 2048 , ([512 , 512 ]), ([2048 , 2048 ]))
124- def test_resize_sizes (resize ):
178+ @cartesian_params ((512 , 2048 , ([512 , 512 ]), ([2048 , 2048 ])), ("cpu" , "gpu" ))
179+ def test_resize_sizes_images (resize , device ):
180+ # Resize with single int (preserve aspect ratio)
181+ loop_images_test (resize = resize , device = device )
182+
183+
184+ @cartesian_params ((512 , 2048 , ([512 , 512 ]), ([2048 , 2048 ])), ("cpu" , "gpu" ))
185+ def test_resize_sizes_tensors (resize , device ):
125186 # Resize with single int (preserve aspect ratio)
126- loop_images_test (resize = resize )
187+ loop_tensors_test (resize = resize , device = device )
127188
128189
129190@params ((480 , 512 ), (100 , 124 ), (None , 512 ), (1024 , 512 ), ([256 , 256 ], 512 ), (None , None ))
@@ -180,7 +241,7 @@ def test_resize_max_sizes(resize, max_size):
180241 ([256 , 256 ], transforms .InterpolationMode .BILINEAR ),
181242 (640 , transforms .InterpolationMode .BICUBIC ),
182243)
183- def test_resize_interploation (resize , interpolation ):
244+ def test_resize_interpoation (resize , interpolation ):
184245 loop_images_test (resize = resize , interpolation = interpolation )
185246
186247
0 commit comments