@@ -2046,7 +2046,6 @@ def repeat(a, repeats, axis=None):
20462046 --------
20472047 Multiple GPUs, Multiple CPUs
20482048 """
2049-
20502049 # when array is a scalar
20512050 if np .ndim (a ) == 0 :
20522051 if np .ndim (repeats ) == 0 :
@@ -2100,11 +2099,36 @@ def repeat(a, repeats, axis=None):
21002099 category = UserWarning ,
21012100 )
21022101 repeats = np .int64 (repeats )
2103- result = array ._thunk .repeat (
2104- repeats = repeats ,
2105- axis = axis ,
2106- scalar_repeats = True ,
2107- )
2102+ if repeats < 0 :
2103+ raise ValueError (
2104+ "'repeats' should not be negative: {}" .format (repeats )
2105+ )
2106+
2107+ # check output shape (if it will fit to GPU or not)
2108+ out_shape = list (array .shape )
2109+ out_shape [axis ] *= repeats
2110+ out_shape = tuple (out_shape )
2111+ size = np .prod (out_shape ) * array .itemsize
2112+ # check if the size of the output array is less than 4 GiB (half of
2113+ # 8 GiB). In this case we can use output regions; otherwise we will
2114+ # use a statically allocated array
2115+ if size < 8589934592 / 2 :
2116+
2117+ result = array ._thunk .repeat (
2118+ repeats = repeats , axis = axis , scalar_repeats = True
2119+ )
2120+ else :
2121+ # this implementation is taken from CuPy
2122+ result = ndarray (shape = out_shape , dtype = array .dtype )
2123+ a_index = [slice (None )] * len (out_shape )
2124+ res_index = list (a_index )
2125+ offset = 0
2126+ for i in range (array .shape [axis ]):
2127+ a_index [axis ] = slice (i , i + 1 )
2128+ res_index [axis ] = slice (offset , offset + repeats )
2129+ result [res_index ] = array [a_index ]
2130+ offset += repeats
2131+ return result
21082132 # repeats is an array
21092133 else :
21102134 # repeats should be integer type
@@ -2116,9 +2140,31 @@ def repeat(a, repeats, axis=None):
21162140 repeats = repeats .astype (np .int64 )
21172141 if repeats .shape [0 ] != array .shape [axis ]:
21182142 raise ValueError ("incorrect shape of repeats array" )
2119- result = array ._thunk .repeat (
2120- repeats = repeats ._thunk , axis = axis , scalar_repeats = False
2121- )
2143+
2144+ # check output shape (if it will fit to GPU or not)
2145+ out_shape = list (array .shape )
2146+ n_repeats = sum (repeats )
2147+ out_shape [axis ] = n_repeats
2148+ out_shape = tuple (out_shape )
2149+ size = np .prod (out_shape ) * array .itemsize
2150+ # check if the size of the output array is less than 4 GiB (half of
2151+ # 8 GiB). In this case we can use output regions; otherwise we will
2152+ # use a statically allocated array
2153+ if size < 8589934592 / 2 :
2154+ result = array ._thunk .repeat (
2155+ repeats = repeats ._thunk , axis = axis , scalar_repeats = False
2156+ )
2157+ else : # this implementation is taken from CuPy
2158+ result = ndarray (shape = out_shape , dtype = array .dtype )
2159+ a_index = [slice (None )] * len (out_shape )
2160+ res_index = list (a_index )
2161+ offset = 0
2162+ for i in range (array .shape [axis ]):
2163+ a_index [axis ] = slice (i , i + 1 )
2164+ res_index [axis ] = slice (offset , offset + repeats [i ])
2165+ result [res_index ] = array [a_index ]
2166+ offset += repeats [i ]
2167+ return result
21222168 return ndarray (shape = result .shape , thunk = result )
21232169
21242170
0 commit comments