diff --git a/pytato/array.py b/pytato/array.py index 68d48098c..4ea42079f 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -736,7 +736,8 @@ def dtype(self) -> np.dtype[Any]: return self.expr.dtype -class AbstractResultWithNamedArrays(Mapping[str, NamedArray], ABC): +@attrs.define(frozen=True, eq=False) +class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC): r"""An abstract array computation that results in multiple :class:`Array`\ s, each named. The way in which the values of these arrays are computed is determined by concrete subclasses of this class, e.g. @@ -751,7 +752,7 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], ABC): This container deliberately does not implement arithmetic. """ - + tags: FrozenSet[Tag] = attrs.field(kw_only=True) _mapper_method: ClassVar[str] @abstractmethod @@ -767,6 +768,7 @@ def __len__(self) -> int: pass +@attrs.define(frozen=True, eq=False, init=False) class DictOfNamedArrays(AbstractResultWithNamedArrays): """A container of named results, each of which can be computed as an array expression provided to the constructor. @@ -775,12 +777,21 @@ class DictOfNamedArrays(AbstractResultWithNamedArrays): .. automethod:: __init__ """ + _data: Mapping[str, Array] _mapper_method: ClassVar[str] = "map_dict_of_named_arrays" - def __init__(self, data: Mapping[str, Array]): - super().__init__() - self._data = data + def __init__(self, data: Mapping[str, Array], *, + tags: Optional[FrozenSet[Tag]] = None) -> None: + if tags is None: + from warnings import warn + warn("Passing `tags=None` is deprecated and will result" + " in an error from 2023. To remove this message either" + " call make_dict_of_named_arrays or pass the `tags` argument.") + tags = frozenset() + + object.__setattr__(self, "_data", data) + object.__setattr__(self, "tags", tags) def __hash__(self) -> int: return hash(frozenset(self._data.items())) @@ -1977,12 +1988,14 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], # {{{ make_dict_of_named_arrays -def make_dict_of_named_arrays(data: Dict[str, Array]) -> DictOfNamedArrays: +def make_dict_of_named_arrays(data: Dict[str, Array], *, + tags: FrozenSet[Tag] = frozenset() + ) -> DictOfNamedArrays: """Make a :class:`DictOfNamedArrays` object. :param data: member keys and arrays """ - return DictOfNamedArrays(data) + return DictOfNamedArrays(data, tags=(tags | _get_default_tags())) # }}} diff --git a/pytato/codegen.py b/pytato/codegen.py index eb42d74cd..7fa4bea1f 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -36,7 +36,7 @@ InputArgumentBase, Einsum, AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, BasicIndex, - NormalizedSlice) + NormalizedSlice, make_dict_of_named_arrays) from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, INT_CLASSES, IntegralT) @@ -183,7 +183,9 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: return LoopyCall(translation_unit=translation_unit, bindings=bindings, - entrypoint=entrypoint) + entrypoint=entrypoint, + tags=expr.tags + ) def map_data_wrapper(self, expr: DataWrapper) -> Array: name = _generate_name_for_temp(expr, self.var_name_gen, "_pt_data") @@ -616,9 +618,9 @@ def normalize_outputs(result: Union[Array, DictOfNamedArrays, "either an Array or a DictOfNamedArrays") if isinstance(result, Array): - outputs = DictOfNamedArrays({"_pt_out": result}) + outputs = make_dict_of_named_arrays({"_pt_out": result}) elif isinstance(result, dict): - outputs = DictOfNamedArrays(result) + outputs = make_dict_of_named_arrays(result) else: assert isinstance(result, DictOfNamedArrays) outputs = result diff --git a/pytato/loopy.py b/pytato/loopy.py index b33e5a11c..66a5c9aff 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -26,10 +26,11 @@ import numpy as np +import attrs import loopy as lp import pymbolic.primitives as prim from typing import (Dict, Optional, Any, Iterator, FrozenSet, Union, Sequence, - Tuple, Iterable, Mapping) + Tuple, Iterable, Mapping, ClassVar) from numbers import Number from pytato.array import (AbstractResultWithNamedArrays, Array, ShapeType, NamedArray, ArrayOrScalar, SizeParam, AxesT) @@ -61,26 +62,25 @@ """ +@attrs.define(eq=False, frozen=True) class LoopyCall(AbstractResultWithNamedArrays): """ An array expression node representing a call to an entrypoint in a :mod:`loopy` translation unit. """ - _mapper_method = "map_loopy_call" + translation_unit: "lp.TranslationUnit" + bindings: Dict[str, ArrayOrScalar] + entrypoint: str - def __init__(self, - translation_unit: "lp.TranslationUnit", - bindings: Dict[str, ArrayOrScalar], - entrypoint: str): - entry_kernel = translation_unit[entrypoint] - super().__init__() - self._result_names = frozenset({name - for name, lp_arg in entry_kernel.arg_dict.items() - if lp_arg.is_output}) - - self.translation_unit = translation_unit - self.bindings = bindings - self.entrypoint = entrypoint + _mapper_method: ClassVar[str] = "map_loopy_call" + + copy = attrs.evolve + + @property + def _result_names(self) -> FrozenSet[str]: + return frozenset({name + for name, lp_arg in self._entry_kernel.arg_dict.items() + if lp_arg.is_output}) @memoize_method def _to_pytato(self, expr: ScalarExpression) -> ScalarExpression: @@ -207,6 +207,8 @@ def call_loopy(translation_unit: "lp.TranslationUnit", to :class:`pytato.array.Array`. :arg entrypoint: the entrypoint of the ``translation_unit`` parameter. """ + from pytato.array import _get_default_tags + if entrypoint is None: if len(translation_unit.entrypoints) != 1: raise ValueError("cannot infer entrypoint") @@ -285,7 +287,8 @@ def call_loopy(translation_unit: "lp.TranslationUnit", translation_unit = translation_unit.with_entrypoints(frozenset()) - return LoopyCall(translation_unit, bindings, entrypoint) + return LoopyCall(translation_unit, bindings, entrypoint, + tags=_get_default_tags()) # {{{ shape inference diff --git a/pytato/partition.py b/pytato/partition.py index 0f7bc754a..e72fc7166 100644 --- a/pytato/partition.py +++ b/pytato/partition.py @@ -35,7 +35,7 @@ from pytato.transform import EdgeCachedMapper, CachedWalkMapper from pytato.array import ( Array, AbstractResultWithNamedArrays, Placeholder, - DictOfNamedArrays, make_placeholder) + DictOfNamedArrays, make_placeholder, make_dict_of_named_arrays) from pytato.target import BoundProgram @@ -359,7 +359,7 @@ def find_partition(outputs: DictOfNamedArrays, _check_partition_disjointness(result) from pytato.analysis import get_num_nodes - num_nodes_per_part = [get_num_nodes(DictOfNamedArrays( + num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays( {x: result.var_name_to_result[x] for x in part.output_names})) for part in result.parts.values()] @@ -423,7 +423,7 @@ def generate_code_for_partition(partition: GraphPartition) \ for part in sorted(partition.parts.values(), key=lambda part_: sorted(part_.output_names)): - d = DictOfNamedArrays( + d = make_dict_of_named_arrays( {var_name: partition.var_name_to_result[var_name] for var_name in part.output_names }) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 2553cea55..f8436b49b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -303,7 +303,9 @@ def map_named_array(self, expr: NamedArray) -> Array: def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: return DictOfNamedArrays({key: self.rec(val.expr) - for key, val in expr.items()}) + for key, val in expr.items()}, + tags=expr.tags + ) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array) @@ -312,7 +314,9 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: return LoopyCall(translation_unit=expr.translation_unit, bindings=bindings, - entrypoint=expr.entrypoint) + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: return LoopyCallResult( @@ -489,7 +493,9 @@ def map_named_array(self, expr: NamedArray, *args: Any, **kwargs: Any) -> Array: def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: Any, **kwargs: Any) -> DictOfNamedArrays: return DictOfNamedArrays({key: self.rec(val.expr, *args, **kwargs) - for key, val in expr.items()}) + for key, val in expr.items()}, + tags=expr.tags, + ) def map_loopy_call(self, expr: LoopyCall, *args: Any, **kwargs: Any) -> LoopyCall: @@ -500,7 +506,9 @@ def map_loopy_call(self, expr: LoopyCall, return LoopyCall(translation_unit=expr.translation_unit, bindings=bindings, - entrypoint=expr.entrypoint) + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult, *args: Any, **kwargs: Any) -> Array: @@ -1222,11 +1230,12 @@ def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, items in *source_dict* """ if not source_dict: - return DictOfNamedArrays({}) + data = {} + else: + data = {name: copy_mapper(val.expr) + for name, val in sorted(source_dict.items())} - data = {name: copy_mapper(val.expr) - for name, val in sorted(source_dict.items())} - return DictOfNamedArrays(data) + return DictOfNamedArrays(data, tags=source_dict.tags) def get_dependencies(expr: DictOfNamedArrays) -> Dict[str, FrozenSet[Array]]: @@ -1308,7 +1317,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: for name, ary in expr.items(): new_data[name] = materializer(ary.expr).expr - return DictOfNamedArrays(new_data) + return DictOfNamedArrays(new_data, tags=expr.tags) # }}} @@ -1693,6 +1702,7 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: name: self.handle_edge(expr, child) if isinstance(child, Array) else child for name, child in sorted(expr.bindings.items())}, + tags=expr.tags, ) def map_distributed_send_ref_holder( diff --git a/test/test_codegen.py b/test/test_codegen.py index 5bc690e15..ea1c58e0a 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -183,7 +183,7 @@ def test_codegen_with_DictOfNamedArrays(ctx_factory): # noqa x_in = np.array([1, 2, 3, 4, 5]) y_in = np.array([6, 7, 8, 9, 10]) - result = pt.DictOfNamedArrays(dict(x_out=x, y_out=y)) + result = pt.make_dict_of_named_arrays(dict(x_out=x, y_out=y)) # With return_dict. prog = pt.generate_loopy(result) @@ -525,7 +525,7 @@ def test_dict_of_named_array_codegen_avoids_recomputation(): y = 2*x z = y + 4*x - yz = pt.DictOfNamedArrays({"y": y, "z": z}) + yz = pt.make_dict_of_named_arrays({"y": y, "z": z}) knl = pt.generate_loopy(yz).kernel assert ("y" in knl.id_to_insn["z_store"].read_dependency_names()) @@ -750,7 +750,7 @@ def test_call_loopy_with_same_callee_names(ctx_factory): cuatro_u = 2*call_loopy(twice, {"x": u}, "callee")["y"] nueve_u = 3*call_loopy(thrice, {"x": u}, "callee")["y"] - out = pt.DictOfNamedArrays({"cuatro_u": cuatro_u, "nueve_u": nueve_u}) + out = pt.make_dict_of_named_arrays({"cuatro_u": cuatro_u, "nueve_u": nueve_u}) evt, out_dict = pt.generate_loopy(out, options=lp.Options(return_dict=True))(queue) @@ -1376,7 +1376,7 @@ def test_random_dag_against_numpy(ctx_factory): ref_result = make_random_dag(rdagc_np) dag = make_random_dag(rdagc_pt) from pytato.transform import materialize_with_mpms - dict_named_arys = pt.DictOfNamedArrays({"result": dag}) + dict_named_arys = pt.make_dict_of_named_arrays({"result": dag}) dict_named_arys = materialize_with_mpms(dict_named_arys) if 0: pt.show_dot_graph(dict_named_arys) @@ -1407,7 +1407,7 @@ def test_partitioner(ctx_factory): ref_result = make_random_dag(rdagc_np) from pytato.transform import materialize_with_mpms - dict_named_arys = materialize_with_mpms(pt.DictOfNamedArrays( + dict_named_arys = materialize_with_mpms(pt.make_dict_of_named_arrays( {"result": make_random_dag(rdagc_pt)})) from dataclasses import dataclass diff --git a/test/test_distributed.py b/test/test_distributed.py index c7604d6fa..4d07aee2f 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -92,7 +92,7 @@ def _do_test_distributed_execution_basic(ctx_factory): y = x+halo # Find the partition - outputs = pt.DictOfNamedArrays({"out": y}) + outputs = pt.make_dict_of_named_arrays({"out": y}) distributed_parts = find_distributed_partition(outputs) prg_per_partition = generate_code_for_partition(distributed_parts) @@ -167,7 +167,8 @@ def gen_comm(rdagc): additional_generators=[ (comm_fake_prob, gen_comm) ]) - pt_dag = pt.DictOfNamedArrays({"result": make_random_dag(rdagc_comm)}) + pt_dag = pt.make_dict_of_named_arrays( + {"result": make_random_dag(rdagc_comm)}) x_comm = pt.transform.materialize_with_mpms(pt_dag) distributed_partition = find_distributed_partition(x_comm) diff --git a/test/test_jax.py b/test/test_jax.py index cbbf8eeb5..0dd8f6eb4 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -102,7 +102,7 @@ def test_random_dag_against_numpy(jit): ref_result = make_random_dag(rdagc_np) dag = make_random_dag(rdagc_pt) from pytato.transform import materialize_with_mpms - dict_named_arys = pt.DictOfNamedArrays({"result": dag}) + dict_named_arys = pt.make_dict_of_named_arrays({"result": dag}) dict_named_arys = materialize_with_mpms(dict_named_arys) if 0: pt.show_dot_graph(dict_named_arys)