Skip to content

Commit 18c2f5f

Browse files
committed
avoid duplication in untag_loopy_call_results
1 parent 89be59f commit 18c2f5f

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

meshmode/array_context.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,13 @@ def transform_dag(self, dag: pt_typ.DictOfNamedArrays) -> pt_typ.DictOfNamedArra
262262
def untag_loopy_call_results(
263263
expr: pt.Array | pt.AbstractResultWithNamedArrays
264264
) -> pt.Array | pt.AbstractResultWithNamedArrays:
265-
if isinstance(expr, pt.NamedArray):
266-
return expr.copy(tags=frozenset(),
267-
axes=(pt.Axis(frozenset()),)*expr.ndim)
265+
if isinstance(expr, pt.loopy.LoopyCallResult):
266+
new_tags = frozenset()
267+
if any(axis.tags for axis in expr.axes):
268+
new_axes = (pt.Axis(frozenset()),)*expr.ndim
269+
else:
270+
new_axes = expr.axes
271+
return expr.replace_if_different(tags=new_tags, axes=new_axes)
268272
else:
269273
return expr
270274

0 commit comments

Comments
 (0)