Skip to content

Commit 4b13a6c

Browse files
committed
addressing the case when output region for repeat operation is too big
1 parent 472d6d0 commit 4b13a6c

4 files changed

Lines changed: 64 additions & 10 deletions

File tree

cunumeric/module.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,7 +2046,6 @@ def repeat(a, repeats, axis=None):
20462046
--------
20472047
Multiple GPUs, Multiple CPUs
20482048
"""
2049-
20502049
# when array is a scalar
20512050
if np.ndim(a) == 0:
20522051
if np.ndim(repeats) == 0:
@@ -2100,11 +2099,36 @@ def repeat(a, repeats, axis=None):
21002099
category=UserWarning,
21012100
)
21022101
repeats = np.int64(repeats)
2103-
result = array._thunk.repeat(
2104-
repeats=repeats,
2105-
axis=axis,
2106-
scalar_repeats=True,
2107-
)
2102+
if repeats < 0:
2103+
return ValueError(
2104+
"'repeats' should not be negative: {}".format(repeats)
2105+
)
2106+
2107+
# check output shape (if it will fit to GPU or not)
2108+
out_shape = list(array.shape)
2109+
out_shape[axis] *= repeats
2110+
out_shape = tuple(out_shape)
2111+
size = sum(out_shape) * array.itemsize
2112+
# check if size of the output array is less 8GB. In this case we can
2113+
# use output regions, otherwise we will use statcally allocated
2114+
# array
2115+
if size < 8589934592 / 2:
2116+
2117+
result = array._thunk.repeat(
2118+
repeats=repeats, axis=axis, scalar_repeats=True
2119+
)
2120+
else:
2121+
# this implementation is taken from CuPy
2122+
result = ndarray(shape=out_shape, dtype=array.dtype)
2123+
a_index = [slice(None)] * len(out_shape)
2124+
res_index = list(a_index)
2125+
offset = 0
2126+
for i in range(a._shape[axis]):
2127+
a_index[axis] = slice(i, i + 1)
2128+
res_index[axis] = slice(offset, offset + repeats)
2129+
result[res_index] = array[a_index]
2130+
offset += repeats
2131+
return result
21082132
# repeats is an array
21092133
else:
21102134
# repeats should be integer type
@@ -2116,9 +2140,31 @@ def repeat(a, repeats, axis=None):
21162140
repeats = repeats.astype(np.int64)
21172141
if repeats.shape[0] != array.shape[axis]:
21182142
return ValueError("incorrect shape of repeats array")
2119-
result = array._thunk.repeat(
2120-
repeats=repeats._thunk, axis=axis, scalar_repeats=False
2121-
)
2143+
2144+
# check output shape (if it will fit to GPU or not)
2145+
out_shape = list(array.shape)
2146+
n_repeats = sum(repeats)
2147+
out_shape[axis] = n_repeats
2148+
out_shape = tuple(out_shape)
2149+
size = sum(out_shape) * array.itemsize
2150+
# check if size of the output array is less 8GB. In this case we can
2151+
# use output regions, otherwise we will use statcally allocated
2152+
# array
2153+
if size < 8589934592 / 2:
2154+
result = array._thunk.repeat(
2155+
repeats=repeats._thunk, axis=axis, scalar_repeats=False
2156+
)
2157+
else: # this implementation is taken from CuPy
2158+
result = ndarray(shape=out_shape, dtype=array.dtype)
2159+
a_index = [slice(None)] * len(out_shape)
2160+
res_index = list(a_index)
2161+
offset = 0
2162+
for i in range(a._shape[axis]):
2163+
a_index[axis] = slice(i, i + 1)
2164+
res_index[axis] = slice(offset, offset + repeats[i])
2165+
result[res_index] = array[a_index]
2166+
offset += repeats[i]
2167+
return result
21222168
return ndarray(shape=result.shape, thunk=result)
21232169

21242170

src/cunumeric/index/repeat.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ struct RepeatImplBody<VariantKind::CPU, CODE, DIM> {
6969
int64_t out_idx = 0;
7070
for (size_t in_idx = 0; in_idx < volume; ++in_idx) {
7171
auto p = in_pitches.unflatten(in_idx, in_rect.lo);
72+
// TODO replace assert with Legate exception handeling interface when available
73+
assert(repeats[p] >= 0);
7274
for (size_t r = 0; r < repeats[p]; r++) out[out_idx++] = in[p];
7375
}
7476
}
@@ -88,6 +90,8 @@ struct RepeatImplBody<VariantKind::CPU, CODE, DIM> {
8890
for (int64_t idx = in_rect.lo[axis]; idx <= in_rect.hi[axis]; ++idx) {
8991
p[axis] = idx;
9092
offsets[off_idx++] = sum;
93+
// TODO replace assert with Legate exception handeling interface when available
94+
assert(repeats[p] >= 0);
9195
sum += repeats[p];
9296
}
9397

src/cunumeric/index/repeat.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
4141
if (offset < extent) {
4242
auto p = origin;
4343
p[axis] += offset;
44+
// TODO replace assert with Legate exception handeling interface when available
45+
assert(repeats[p] >= 0);
4446
auto val = repeats[p];
4547
offsets[offset] = val;
4648
SumReduction<int64_t>::fold<true>(value, val);

src/cunumeric/index/repeat_omp.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ struct RepeatImplBody<VariantKind::OMP, CODE, DIM> {
7777
int64_t axis_lo = p[axis];
7878
#pragma omp for schedule(static) private(p)
7979
for (int64_t idx = 0; idx < axis_extent; ++idx) {
80-
p[axis] = axis_lo + idx;
80+
p[axis] = axis_lo + idx;
81+
// TODO replace assert with Legate exception handeling interface when available
82+
assert(repeats[p] >= 0);
8183
auto val = repeats[p];
8284
offsets[idx] = val;
8385
local_sums[tid] += val;

0 commit comments

Comments
 (0)