Skip to content

Commit 50b0c68

Browse files
committed
Add regressions for pt.concatenate_calls
1 parent 45bb4a7 commit 50b0c68

1 file changed

Lines changed: 113 additions & 1 deletion

File tree

test/test_codegen.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1875,7 +1875,8 @@ def build_expression(tracer):
18751875
np.testing.assert_allclose(outputs[key], expected[key])
18761876

18771877

1878-
def test_nested_function_calls(ctx_factory):
1878+
@pytest.mark.parametrize("should_concatenate_bar", (False, True))
1879+
def test_nested_function_calls(ctx_factory, should_concatenate_bar):
18791880
from functools import partial
18801881

18811882
ctx = ctx_factory()
@@ -1909,6 +1910,14 @@ def call_bar(tracer, x, y):
19091910
"out2": call_bar(pt.trace_call, x2, y2)}
19101911
)
19111912
result = pt.tag_all_calls_to_be_inlined(result)
1913+
if should_concatenate_bar:
1914+
from pytato.transform.calls import CallsiteCollector
1915+
assert len(CallsiteCollector(())(result)) == 4
1916+
result = pt.concatenate_calls(
1917+
result,
1918+
lambda x: pt.tags.FunctionIdentifier("bar") in x.call.function.tags)
1919+
assert len(CallsiteCollector(())(result)) == 2
1920+
19121921
expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1),
19131922
"out2": call_bar(ref_tracer, x2, y2)}
19141923
)
@@ -1921,6 +1930,109 @@ def call_bar(tracer, x, y):
19211930
np.testing.assert_allclose(result_out[k], expect_out[k])
19221931

19231932

1933+
def test_concatenate_calls_no_nested(ctx_factory):
1934+
rng = np.random.default_rng(0)
1935+
1936+
ctx = ctx_factory()
1937+
cq = cl.CommandQueue(ctx)
1938+
1939+
def foo(x, y):
1940+
return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y)
1941+
1942+
x1 = pt.make_placeholder("x1", (10, 4), np.float64)
1943+
x2 = pt.make_placeholder("x2", (10, 4), np.float64)
1944+
1945+
y1 = pt.make_placeholder("y1", (10, 4), np.float64)
1946+
y2 = pt.make_placeholder("y2", (10, 4), np.float64)
1947+
1948+
z1 = pt.make_placeholder("z1", (10, 4), np.float64)
1949+
z2 = pt.make_placeholder("z2", (10, 4), np.float64)
1950+
1951+
result = pt.make_dict_of_named_arrays({"out1": 2*pt.trace_call(foo, 2*x1, 3*x2),
1952+
"out2": 4*pt.trace_call(foo, 4*y1, 9*y2),
1953+
"out3": 6*pt.trace_call(foo, 7*z1, 8*z2)
1954+
})
1955+
1956+
concatenated_result = pt.concatenate_calls(
1957+
result, lambda x: pt.tags.FunctionIdentifier("foo") in x.call.function.tags)
1958+
1959+
result = pt.tag_all_calls_to_be_inlined(result)
1960+
concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result)
1961+
1962+
assert (pt.analysis.get_num_nodes(pt.inline_calls(result))
1963+
> pt.analysis.get_num_nodes(pt.inline_calls(concatenated_result)))
1964+
1965+
x1_np, x2_np, y1_np, y2_np, z1_np, z2_np = rng.random((6, 10, 4))
1966+
1967+
_, out_dict1 = pt.generate_loopy(result)(cq,
1968+
x1=x1_np, x2=x2_np,
1969+
y1=y1_np, y2=y2_np,
1970+
z1=z1_np, z2=z2_np)
1971+
1972+
_, out_dict2 = pt.generate_loopy(concatenated_result)(cq,
1973+
x1=x1_np, x2=x2_np,
1974+
y1=y1_np, y2=y2_np,
1975+
z1=z1_np, z2=z2_np)
1976+
assert out_dict1.keys() == out_dict2.keys()
1977+
1978+
for key in out_dict1:
1979+
np.testing.assert_allclose(out_dict1[key], out_dict2[key])
1980+
1981+
1982+
def test_concatenation_via_constant_expressions(ctx_factory):
1983+
1984+
from pytato.transform.calls import CallsiteCollector
1985+
1986+
rng = np.random.default_rng(0)
1987+
1988+
ctx = ctx_factory()
1989+
cq = cl.CommandQueue(ctx)
1990+
1991+
def resampling(coords, iels):
1992+
return coords[iels]
1993+
1994+
n_el = 1000
1995+
n_dof = 20
1996+
n_dim = 3
1997+
1998+
n_left_els = 17
1999+
n_right_els = 29
2000+
2001+
coords_dofs_np = rng.random((n_el, n_dim, n_dof), np.float64)
2002+
left_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_left_els)
2003+
right_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_right_els)
2004+
2005+
coords_dofs = pt.make_data_wrapper(coords_dofs_np)
2006+
left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
2007+
right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)
2008+
2009+
lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
2010+
rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)
2011+
2012+
result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
2013+
"rcoords": rcoords})
2014+
result = pt.tag_all_calls_to_be_inlined(result)
2015+
2016+
assert len(CallsiteCollector(())(result)) == 2
2017+
concated_result = pt.concatenate_calls(
2018+
result,
2019+
lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
2020+
)
2021+
assert len(CallsiteCollector(())(concated_result)) == 1
2022+
2023+
_, out_result = pt.generate_loopy(result)(cq)
2024+
np.testing.assert_allclose(out_result["lcoords"],
2025+
coords_dofs_np[left_bnd_iels_np])
2026+
np.testing.assert_allclose(out_result["rcoords"],
2027+
coords_dofs_np[right_bnd_iels_np])
2028+
2029+
_, out_concated_result = pt.generate_loopy(result)(cq)
2030+
np.testing.assert_allclose(out_concated_result["lcoords"],
2031+
coords_dofs_np[left_bnd_iels_np])
2032+
np.testing.assert_allclose(out_concated_result["rcoords"],
2033+
coords_dofs_np[right_bnd_iels_np])
2034+
2035+
19242036
if __name__ == "__main__":
19252037
if len(sys.argv) > 1:
19262038
exec(sys.argv[1])

0 commit comments

Comments
 (0)