@@ -1875,7 +1875,8 @@ def build_expression(tracer):
18751875 np .testing .assert_allclose (outputs [key ], expected [key ])
18761876
18771877
1878- def test_nested_function_calls (ctx_factory ):
1878+ @pytest .mark .parametrize ("should_concatenate_bar" , (False , True ))
1879+ def test_nested_function_calls (ctx_factory , should_concatenate_bar ):
18791880 from functools import partial
18801881
18811882 ctx = ctx_factory ()
@@ -1909,6 +1910,14 @@ def call_bar(tracer, x, y):
19091910 "out2" : call_bar (pt .trace_call , x2 , y2 )}
19101911 )
19111912 result = pt .tag_all_calls_to_be_inlined (result )
1913+ if should_concatenate_bar :
1914+ from pytato .transform .calls import CallsiteCollector
1915+ assert len (CallsiteCollector (())(result )) == 4
1916+ result = pt .concatenate_calls (
1917+ result ,
1918+ lambda x : pt .tags .FunctionIdentifier ("bar" ) in x .call .function .tags )
1919+ assert len (CallsiteCollector (())(result )) == 2
1920+
19121921 expect = pt .make_dict_of_named_arrays ({"out1" : call_bar (ref_tracer , x1 , y1 ),
19131922 "out2" : call_bar (ref_tracer , x2 , y2 )}
19141923 )
@@ -1921,6 +1930,109 @@ def call_bar(tracer, x, y):
19211930 np .testing .assert_allclose (result_out [k ], expect_out [k ])
19221931
19231932
def test_concatenate_calls_no_nested(ctx_factory):
    """Concatenate three independent, non-nested calls to ``foo``.

    Checks that, once all calls are inlined, the concatenated graph has
    fewer nodes than the unconcatenated one, and that both graphs evaluate
    to the same outputs for identical inputs.
    """
    rng = np.random.default_rng(0)

    ctx = ctx_factory()
    cq = cl.CommandQueue(ctx)

    def foo(x, y):
        return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y)

    def is_foo_callsite(x):
        return pt.tags.FunctionIdentifier("foo") in x.call.function.tags

    # One placeholder per input name; every input has the same shape/dtype.
    arg_names = ("x1", "x2", "y1", "y2", "z1", "z2")
    x1, x2, y1, y2, z1, z2 = [
        pt.make_placeholder(name, (10, 4), np.float64)
        for name in arg_names
    ]

    result = pt.make_dict_of_named_arrays({
        "out1": 2*pt.trace_call(foo, 2*x1, 3*x2),
        "out2": 4*pt.trace_call(foo, 4*y1, 9*y2),
        "out3": 6*pt.trace_call(foo, 7*z1, 8*z2),
    })

    # Concatenate on the untagged graph, then mark both variants for inlining.
    concatenated_result = pt.concatenate_calls(result, is_foo_callsite)

    result = pt.tag_all_calls_to_be_inlined(result)
    concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result)

    n_nodes_separate = pt.analysis.get_num_nodes(pt.inline_calls(result))
    n_nodes_concatenated = pt.analysis.get_num_nodes(
        pt.inline_calls(concatenated_result))
    assert n_nodes_separate > n_nodes_concatenated

    # A single draw supplies all six input arrays, in the order of arg_names.
    inputs = dict(zip(arg_names, rng.random((6, 10, 4))))

    _, out_dict1 = pt.generate_loopy(result)(cq, **inputs)
    _, out_dict2 = pt.generate_loopy(concatenated_result)(cq, **inputs)

    assert out_dict1.keys() == out_dict2.keys()

    for name in out_dict1:
        np.testing.assert_allclose(out_dict1[name], out_dict2[name])
1980+
1981+
def test_concatenation_via_constant_expressions(ctx_factory):
    """Concatenate two calls to ``resampling`` whose arguments are data wrappers.

    Both calls index the same coordinate array, with index arrays of
    different lengths (17 and 29). The test checks that concatenation
    reduces the number of call sites from 2 to 1, and that both the original
    and the concatenated graph reproduce the NumPy reference
    ``coords_dofs_np[iels_np]``.
    """

    from pytato.transform.calls import CallsiteCollector

    rng = np.random.default_rng(0)

    ctx = ctx_factory()
    cq = cl.CommandQueue(ctx)

    def resampling(coords, iels):
        return coords[iels]

    n_el = 1000
    n_dof = 20
    n_dim = 3

    n_left_els = 17
    n_right_els = 29

    coords_dofs_np = rng.random((n_el, n_dim, n_dof), np.float64)
    left_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_left_els)
    right_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_right_els)

    coords_dofs = pt.make_data_wrapper(coords_dofs_np)
    left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
    right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)

    lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
    rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)

    result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
                                           "rcoords": rcoords})
    result = pt.tag_all_calls_to_be_inlined(result)

    assert len(CallsiteCollector(())(result)) == 2
    concated_result = pt.concatenate_calls(
        result,
        lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
    )
    assert len(CallsiteCollector(())(concated_result)) == 1

    # Reference run: the unconcatenated graph.
    _, out_result = pt.generate_loopy(result)(cq)
    np.testing.assert_allclose(out_result["lcoords"],
                               coords_dofs_np[left_bnd_iels_np])
    np.testing.assert_allclose(out_result["rcoords"],
                               coords_dofs_np[right_bnd_iels_np])

    # Run the *concatenated* graph. Compiling `result` a second time here
    # would leave the numerical output of the concatenation untested.
    _, out_concated_result = pt.generate_loopy(concated_result)(cq)
    np.testing.assert_allclose(out_concated_result["lcoords"],
                               coords_dofs_np[left_bnd_iels_np])
    np.testing.assert_allclose(out_concated_result["rcoords"],
                               coords_dofs_np[right_bnd_iels_np])
2034+
2035+
19242036if __name__ == "__main__" :
19252037 if len (sys .argv ) > 1 :
19262038 exec (sys .argv [1 ])
0 commit comments