diff --git a/src/cuda/tile/__init__.py b/src/cuda/tile/__init__.py index 0ce4537..2ef8436 100644 --- a/src/cuda/tile/__init__.py +++ b/src/cuda/tile/__init__.py @@ -55,6 +55,7 @@ Scalar, Tile, + abs, add, arange, argmax, @@ -183,6 +184,7 @@ "Scalar", "Tile", + "abs", "add", "arange", "argmax", diff --git a/src/cuda/tile/_stub.py b/src/cuda/tile/_stub.py index ddb2578..1832bed 100644 --- a/src/cuda/tile/_stub.py +++ b/src/cuda/tile/_stub.py @@ -127,6 +127,10 @@ def __eq__(self, other) -> "TileOrScalar": def __ne__(self, other) -> "TileOrScalar": ... + @function + def __abs__(self) -> "TileOrScalar": + ... + Scalar = int | float | ScalarProtocol @@ -370,6 +374,10 @@ def __eq__(self, other) -> "Tile": def __ne__(self, other) -> "Tile": ... + @function + def __abs__(self) -> "Tile": + ... + Shape = Union[int, tuple[int, ...]] Order = Union[tuple[int, ...], Literal['C'], Literal['F']] @@ -1755,6 +1763,12 @@ def ceil(x, /) -> TileOrScalar: pass +@_doc_unary_op +@function +def abs(x, /) -> TileOrScalar: + pass + + @function def negative(x, /) -> TileOrScalar: """Same as `-x`. diff --git a/test/test_unary_elementwise.py b/test/test_unary_elementwise.py index 8c45a20..e94d8d7 100644 --- a/test/test_unary_elementwise.py +++ b/test/test_unary_elementwise.py @@ -223,6 +223,15 @@ def test_array_abs(shape, tile, dtype, tmp_path): assert_equal(y, abs(x)) +@pytest.mark.parametrize("dtype", bool_dtypes + int_dtypes + float_dtypes, ids=dtype_id) +def test_array_ct_abs(shape, tile, dtype, tmp_path): + x = make_tensor(shape, dtype=dtype, device='cuda') + y = torch.zeros_like(x, device="cuda") + kernel = array_kernel('ct_abs', "ty = ct.abs(tx)", tmp_path) + launch_unary(kernel, x, y, tile) + assert_equal(y, abs(x)) + + @pytest.mark.parametrize("is_constant", [False, True]) @pytest.mark.parametrize("dtype", int_dtypes + float_dtypes, ids=dtype_id) def test_scalar_abs(shape, tile, is_constant, dtype, tmp_path): @@ -241,6 +250,24 @@ def test_scalar_abs(shape, tile, is_constant, dtype, tmp_path): assert_equal(y, abs(x)) +@pytest.mark.parametrize("is_constant", [False, True]) +@pytest.mark.parametrize("dtype", int_dtypes + float_dtypes, ids=dtype_id) +def test_scalar_ct_abs(shape, tile, is_constant, dtype, tmp_path): + if dtype in int_dtypes: + x = -5 + dtype_str = "int" + else: + x = -5.0 + dtype_str = "float" + y = torch.zeros(shape, dtype=dtype, device='cuda') + if not is_constant: + kernel = scalar_kernel('ct_abs', 'c = ct.abs(x)', tmp_path) + else: + kernel = const_scalar_kernel('ct_abs', dtype_str, 'c = ct.abs(x)', tmp_path) + launch_unary(kernel, x, y, tile) + assert_equal(y, abs(x)) + + @pytest.mark.parametrize("bitwise_not_func", ['~', 'ct.bitwise_not']) @pytest.mark.parametrize("dtype", int_dtypes + bool_dtypes, ids=dtype_id) def test_array_bitwise_not(shape, tile, dtype, tmp_path, bitwise_not_func):