diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 0c93e872b..5ad456172 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -41,6 +41,8 @@ DistributeMapperBase) from pymbolic.mapper.stringifier import (StringifyMapper as StringifyMapperBase) +from pymbolic.mapper.equality import (EqualityMapper as + EqualityMapperBase) from pymbolic.mapper import CombineMapper as CombineMapperBase from pymbolic.mapper.collector import TermCollector as TermCollectorBase from immutables import Map @@ -184,6 +186,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str: bounds_expr = "{" + bounds_expr + "}" return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})") + +class EqualityMapper(EqualityMapperBase): + def map_reduce(self, expr: Reduce, other: Reduce) -> bool: + return ( + len(expr.bounds) == len(other.bounds) + and all(k == other_k + and self.rec(lb, other_lb) and self.rec(ub, other_ub) + for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip( + sorted(expr.bounds.items()), + sorted(other.bounds.items()))) + and expr.op == other.op + and self.rec(expr.inner_expr, other.inner_expr) + ) + # }}} @@ -240,6 +256,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(), # {{{ custom scalar expression nodes class ExpressionBase(prim.Expression): + def make_equality_mapper(self) -> EqualityMapper: + return EqualityMapper() + def make_stringifier(self, originating_stringifier: Any = None) -> str: return StringifyMapper() diff --git a/requirements.txt b/requirements.txt index a9cf2c76e..85d451ec3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1 -git+https://github.com/inducer/pymbolic.git#egg=pymbolic +git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy asciidag diff --git a/test/test_pytato.py b/test/test_pytato.py index c6d3b1e11..777a5b3d5 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -355,6 +355,8 @@ def test_userscollector(): def test_asciidag(): + pytest.importorskip("asciidag") + n = pt.make_size_param("n") array = pt.make_placeholder(name="array", shape=n, dtype=np.float64) stack = pt.stack([array, 2*array, array + 6])