Skip to content

Commit 34bcd56

Browse files
committed
Review fixes - 2

Modified resize operator size calculation.

Signed-off-by: Marek Dabek <mdabek@nvidia.com>
1 parent e4ade7a commit 34bcd56

4 files changed

Lines changed: 157 additions & 38 deletions

File tree

dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,22 @@ def resize(
4040
effective_size, mode = Resize.infer_effective_size(size, max_size)
4141
interpolation = Resize.interpolation_modes[interpolation]
4242

43+
if isinstance(img, ndd.Tensor):
44+
img_shape = img.shape
45+
elif isinstance(img, ndd.Batch):
46+
img_shape = img.shape[0] # Batches have uniform layout
47+
else:
48+
raise TypeError(f"Input must be ndd.Tensor or ndd.Batch, got {type(img)}")
49+
50+
if img.layout in ["HWC", "NHWC"]:
51+
original_h = img_shape[-3]
52+
original_w = img_shape[-2]
53+
elif img.layout in ["CHW", "NCHW"]:
54+
original_h = img_shape[-2]
55+
original_w = img_shape[-1]
56+
4357
target_h, target_w = Resize.calculate_target_size(
44-
img.shape, effective_size, max_size, size is None
58+
(original_h, original_w), effective_size, max_size, size is None
4559
)
4660

4761
# Shorter edge limited by max size

dali/python/nvidia/dali/experimental/torchvision/v2/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _kernel(self, data_input):
206206

207207
def __call__(self, data_input):
208208

209-
Operator.verify_data(data_input)
209+
type(self).verify_data(data_input)
210210

211211
if self.device == "gpu":
212212
data_input = data_input.gpu()

dali/python/nvidia/dali/experimental/torchvision/v2/resize.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,36 @@
2424
import numpy as np
2525

2626

27+
def get_inputHW(data_input):
28+
"""
29+
Gets the height and width of the input data.
30+
31+
Parameters
32+
----------
33+
data_input : Tensor
34+
Input data to get the height and width of.
35+
36+
Returns
37+
-------
38+
input_height : int
39+
Height of the input data.
40+
input_width : int
41+
Width of the input data.
42+
"""
43+
layout = data_input.property("layout")[0]
44+
45+
# CWH
46+
if layout == np.frombuffer(bytes("C", "utf-8"), dtype=np.uint8)[0]:
47+
input_height = data_input.shape()[-1]
48+
input_width = data_input.shape()[-2]
49+
# HWC
50+
else:
51+
input_height = data_input.shape()[-3]
52+
input_width = data_input.shape()[-2]
53+
54+
return input_height, input_width, data_input
55+
56+
2757
class VerificationSize(ArgumentVerificationRule):
2858
@classmethod
2959
def verify(cls, *, size, max_size, interpolation, **_):
@@ -84,7 +114,9 @@ class Resize(Operator):
84114
InterpolationMode.HAMMING: DALIInterpType.INTERP_GAUSSIAN, # TODO:
85115
InterpolationMode.LANCZOS: DALIInterpType.INTERP_LANCZOS3,
86116
}
117+
87118
arg_rules = [VerificationSize]
119+
preprocess_data = get_inputHW
88120

89121
@classmethod
90122
def infer_effective_size(
@@ -120,6 +152,7 @@ def calculate_target_size(
120152
):
121153
orig_h = orig_size[0]
122154
orig_w = orig_size[1]
155+
123156
target_h = effective_size[0]
124157
target_w = effective_size[1]
125158

@@ -160,15 +193,24 @@ def _kernel(self, data_input):
160193
with ``torchvision.transforms.Resize`` documentation and applies DALI operator on the
161194
``data_input``.
162195
"""
196+
input_height, input_width, data_input = data_input
163197

164198
target_h, target_w = Resize.calculate_target_size(
165-
data_input.shape(), self.effective_size, self.max_size, self.size is None
199+
orig_size = (input_height, input_width),
200+
effective_size = self.effective_size,
201+
max_size = self.max_size,
202+
no_size = self.size is None
166203
)
167204

168205
# Shorter edge limited by max size
169206
if self.mode == "resize_shorter":
170207
return fn.resize(
171-
data_input, device=self.device, resize_shorter=target_h, max_size=self.max_size
208+
data_input,
209+
device=self.device,
210+
resize_shorter=target_h,
211+
max_size=self.max_size,
212+
antialias=self.antialias,
213+
interp_type=self.interpolation,
172214
)
173215

174216
return fn.resize(
@@ -179,4 +221,6 @@ def _kernel(self, data_input):
179221
fn.cast(target_w, dtype=dali.types.FLOAT),
180222
),
181223
mode=self.mode,
224+
antialias=self.antialias,
225+
interp_type=self.interpolation,
182226
)

dali/test/python/torchvision/test_tv_resize.py

Lines changed: 95 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16-
from typing import Sequence
16+
from typing import Sequence, Literal, Union
1717

1818
import numpy as np
1919
from nose2.tools import params, cartesian_params
@@ -50,6 +50,7 @@ def build_resize_transform(
5050
max_size: int = None,
5151
interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR,
5252
antialias: bool = False,
53+
device: Literal["cpu", "gpu"] = "cpu",
5354
):
5455
t = transforms.Compose(
5556
[
@@ -61,13 +62,59 @@ def build_resize_transform(
6162
td = Compose(
6263
[
6364
Resize(
64-
size=resize, max_size=max_size, interpolation=interpolation, antialias=antialias
65+
size=resize,
66+
max_size=max_size,
67+
interpolation=interpolation,
68+
antialias=antialias,
69+
device=device,
6570
),
6671
]
6772
)
6873
return t, td
6974

7075

76+
def _internal_loop(
77+
input_data: Union[Image.Image, torch.Tensor],
78+
t: transforms.Resize,
79+
td: Resize,
80+
resize: int | Sequence[int],
81+
max_size: int = None,
82+
interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR,
83+
antialias: bool = False,
84+
):
85+
out_fn = fn_tv.resize(
86+
input_data,
87+
size=resize,
88+
max_size=max_size,
89+
interpolation=interpolation,
90+
antialias=antialias,
91+
)
92+
out_dali_fn = fn_dali.resize(
93+
input_data,
94+
size=resize,
95+
max_size=max_size,
96+
interpolation=interpolation,
97+
antialias=antialias,
98+
)
99+
out_tv = t(input_data)
100+
out_dali_tv = td(input_data)
101+
102+
if isinstance(input_data, Image.Image):
103+
out_tv = transforms.functional.pil_to_tensor(out_tv).unsqueeze(0).permute(0, 2, 3, 1)
104+
out_dali_tv = (
105+
transforms.functional.pil_to_tensor(out_dali_tv).unsqueeze(0).permute(0, 2, 3, 1)
106+
)
107+
out_fn = transforms.functional.pil_to_tensor(out_fn)
108+
out_dali_fn = transforms.functional.pil_to_tensor(out_dali_fn)
109+
110+
assert torch.allclose(
111+
torch.tensor(out_tv.shape[1:3]), torch.tensor(out_dali_tv.shape[1:3]), rtol=0, atol=1
112+
), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}"
113+
assert torch.allclose(
114+
torch.tensor(out_fn.shape[1:3]), torch.tensor(out_dali_fn.shape[1:3]), rtol=0, atol=1
115+
), f"Should be:{out_fn.shape} is:{out_dali_fn.shape}"
116+
117+
71118
def loop_images_test_no_build(
72119
t: transforms.Resize,
73120
td: Resize,
@@ -78,52 +125,66 @@ def loop_images_test_no_build(
78125
):
79126
for fn in test_files:
80127
img = Image.open(fn)
81-
out_fn = transforms.functional.pil_to_tensor(
82-
fn_tv.resize(
83-
img,
84-
size=resize,
85-
max_size=max_size,
86-
interpolation=interpolation,
87-
antialias=antialias,
88-
)
89-
)
90-
out_dali_fn = transforms.functional.pil_to_tensor(
91-
fn_dali.resize(
92-
img,
93-
size=resize,
94-
max_size=max_size,
95-
interpolation=interpolation,
96-
antialias=antialias,
97-
)
98-
)
128+
_internal_loop(img, t, td, resize, max_size, interpolation, antialias)
129+
# assert torch.equal(out_tv, out_dali_tv)
99130

100-
out_tv = transforms.functional.pil_to_tensor(t(img)).unsqueeze(0).permute(0, 2, 3, 1)
101-
out_dali_tv = transforms.functional.pil_to_tensor(td(img)).unsqueeze(0).permute(0, 2, 3, 1)
102131

103-
assert torch.allclose(
104-
torch.tensor(out_tv.shape[1:3]), torch.tensor(out_dali_tv.shape[1:3]), rtol=0, atol=1
105-
), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}"
106-
assert torch.allclose(
107-
torch.tensor(out_fn.shape[1:3]), torch.tensor(out_dali_fn.shape[1:3]), rtol=0, atol=1
108-
), f"Should be:{out_fn.shape} is:{out_dali_fn.shape}"
132+
def build_tensors(max_size: int = 512, channels: int = 3):
133+
h = torch.randint(10, max_size, (1,)).item()
134+
w = torch.randint(10, max_size, (1,)).item()
135+
tensors = [
136+
torch.ones((channels, max_size, max_size)),
137+
torch.ones((1, channels, max_size, max_size)),
138+
torch.ones((10, channels, max_size, max_size)),
139+
torch.ones((channels, max_size // 2, max_size)),
140+
torch.ones((1, channels, max_size // 2, max_size)),
141+
torch.ones((10, channels, max_size // 2, max_size)),
142+
torch.ones((channels, max_size, max_size // 2)),
143+
torch.ones((1, channels, max_size, max_size // 2)),
144+
torch.ones((10, channels, max_size, max_size // 2)),
145+
torch.ones((channels, h, w)),
146+
torch.ones((1, channels, h, w)),
147+
torch.ones((10, channels, h, w)),
148+
]
149+
150+
return tensors
151+
152+
153+
def loop_tensors_test(
154+
resize: int | Sequence[int],
155+
max_size: int = None,
156+
interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR,
157+
antialias: bool = False,
158+
device: Literal["cpu", "gpu"] = "cpu",
159+
):
160+
t, td = build_resize_transform(resize, max_size, interpolation, antialias, device)
161+
tensors = build_tensors()
109162

110-
# assert torch.equal(out_tv, out_dali_tv)
163+
for tn in tensors:
164+
_internal_loop(tn, t, td, resize, max_size, interpolation, antialias)
111165

112166

113167
def loop_images_test(
114168
resize: int | Sequence[int],
115169
max_size: int = None,
116170
interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR,
117171
antialias: bool = False,
172+
device: Literal["cpu", "gpu"] = "cpu",
118173
):
119-
t, td = build_resize_transform(resize, max_size, interpolation, antialias)
174+
t, td = build_resize_transform(resize, max_size, interpolation, antialias, device)
120175
loop_images_test_no_build(t, td, resize, max_size, interpolation, antialias)
121176

122177

123-
@params(512, 2048, ([512, 512]), ([2048, 2048]))
124-
def test_resize_sizes(resize):
178+
@cartesian_params((512, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu"))
179+
def test_resize_sizes_images(resize, device):
180+
# Resize with single int (preserve aspect ratio)
181+
loop_images_test(resize=resize, device=device)
182+
183+
184+
@cartesian_params((512, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu"))
185+
def test_resize_sizes_tensors(resize, device):
125186
# Resize with single int (preserve aspect ratio)
126-
loop_images_test(resize=resize)
187+
loop_tensors_test(resize=resize, device=device)
127188

128189

129190
@params((480, 512), (100, 124), (None, 512), (1024, 512), ([256, 256], 512), (None, None))
@@ -180,7 +241,7 @@ def test_resize_max_sizes(resize, max_size):
180241
([256, 256], transforms.InterpolationMode.BILINEAR),
181242
(640, transforms.InterpolationMode.BICUBIC),
182243
)
183-
def test_resize_interploation(resize, interpolation):
244+
def test_resize_interpolation(resize, interpolation):
184245
loop_images_test(resize=resize, interpolation=interpolation)
185246

186247

0 commit comments

Comments (0)