Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
DistributeMapperBase)
from pymbolic.mapper.stringifier import (StringifyMapper as
StringifyMapperBase)
from pymbolic.mapper.equality import (EqualityMapper as
EqualityMapperBase)
from pymbolic.mapper import CombineMapper as CombineMapperBase
from pymbolic.mapper.collector import TermCollector as TermCollectorBase
from immutables import Map
Expand Down Expand Up @@ -184,6 +186,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
bounds_expr = "{" + bounds_expr + "}"
return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")


class EqualityMapper(EqualityMapperBase):
def map_reduce(self, expr: Reduce, other: Reduce) -> bool:
return (
len(expr.bounds) == len(other.bounds)
and all(k == other_k
and self.rec(lb, other_lb) and self.rec(ub, other_ub)
for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip(
sorted(expr.bounds.items()),
sorted(other.bounds.items())))
and expr.op == other.op
and self.rec(expr.inner_expr, other.inner_expr)
)

# }}}


Expand Down Expand Up @@ -240,6 +256,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
# {{{ custom scalar expression nodes

class ExpressionBase(prim.Expression):
def make_equality_mapper(self) -> EqualityMapper:
return EqualityMapper()

def make_stringifier(self, originating_stringifier: Any = None) -> str:
return StringifyMapper()

Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
git+https://github.com/inducer/genpy.git#egg=genpy
git+https://github.com/inducer/loopy.git#egg=loopy
git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy

asciidag
2 changes: 2 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ def test_userscollector():


def test_asciidag():
pytest.importorskip("asciidag")

n = pt.make_size_param("n")
array = pt.make_placeholder(name="array", shape=n, dtype=np.float64)
stack = pt.stack([array, 2*array, array + 6])
Expand Down