Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,8 @@ def dtype(self) -> np.dtype[Any]:
return self.expr.dtype


class AbstractResultWithNamedArrays(Mapping[str, NamedArray], ABC):
@attrs.define(frozen=True, eq=False)
class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC):
r"""An abstract array computation that results in multiple :class:`Array`\ s,
each named. The way in which the values of these arrays are computed
is determined by concrete subclasses of this class, e.g.
Expand All @@ -751,7 +752,7 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], ABC):

This container deliberately does not implement arithmetic.
"""

tags: FrozenSet[Tag] = attrs.field(kw_only=True)
_mapper_method: ClassVar[str]

@abstractmethod
Expand All @@ -767,6 +768,7 @@ def __len__(self) -> int:
pass


@attrs.define(frozen=True, eq=False, init=False)
class DictOfNamedArrays(AbstractResultWithNamedArrays):
"""A container of named results, each of which can be computed as an
array expression provided to the constructor.
Expand All @@ -775,12 +777,21 @@ class DictOfNamedArrays(AbstractResultWithNamedArrays):

.. automethod:: __init__
"""
_data: Mapping[str, Array]

_mapper_method: ClassVar[str] = "map_dict_of_named_arrays"

def __init__(self, data: Mapping[str, Array]):
super().__init__()
self._data = data
def __init__(self, data: Mapping[str, Array], *,
tags: Optional[FrozenSet[Tag]] = None) -> None:
if tags is None:
from warnings import warn
warn("Passing `tags=None` is deprecated and will result"
" in an error from 2023. To remove this message either"
" call make_dict_of_named_arrays or pass the `tags` argument.")
tags = frozenset()

object.__setattr__(self, "_data", data)
object.__setattr__(self, "tags", tags)

def __hash__(self) -> int:
return hash(frozenset(self._data.items()))
Expand Down Expand Up @@ -1977,12 +1988,14 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]],

# {{{ make_dict_of_named_arrays

def make_dict_of_named_arrays(data: Dict[str, Array]) -> DictOfNamedArrays:
def make_dict_of_named_arrays(data: Dict[str, Array], *,
tags: FrozenSet[Tag] = frozenset()
) -> DictOfNamedArrays:
"""Make a :class:`DictOfNamedArrays` object.

:param data: member keys and arrays
"""
return DictOfNamedArrays(data)
return DictOfNamedArrays(data, tags=(tags | _get_default_tags()))

# }}}

Expand Down
10 changes: 6 additions & 4 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
InputArgumentBase, Einsum,
AdvancedIndexInContiguousAxes,
AdvancedIndexInNoncontiguousAxes, BasicIndex,
NormalizedSlice)
NormalizedSlice, make_dict_of_named_arrays)

from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression,
INT_CLASSES, IntegralT)
Expand Down Expand Up @@ -183,7 +183,9 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:

return LoopyCall(translation_unit=translation_unit,
bindings=bindings,
entrypoint=entrypoint)
entrypoint=entrypoint,
tags=expr.tags
)

def map_data_wrapper(self, expr: DataWrapper) -> Array:
name = _generate_name_for_temp(expr, self.var_name_gen, "_pt_data")
Expand Down Expand Up @@ -616,9 +618,9 @@ def normalize_outputs(result: Union[Array, DictOfNamedArrays,
"either an Array or a DictOfNamedArrays")

if isinstance(result, Array):
outputs = DictOfNamedArrays({"_pt_out": result})
outputs = make_dict_of_named_arrays({"_pt_out": result})
elif isinstance(result, dict):
outputs = DictOfNamedArrays(result)
outputs = make_dict_of_named_arrays(result)
else:
assert isinstance(result, DictOfNamedArrays)
outputs = result
Expand Down
35 changes: 19 additions & 16 deletions pytato/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@


import numpy as np
import attrs
import loopy as lp
import pymbolic.primitives as prim
from typing import (Dict, Optional, Any, Iterator, FrozenSet, Union, Sequence,
Tuple, Iterable, Mapping)
Tuple, Iterable, Mapping, ClassVar)
from numbers import Number
from pytato.array import (AbstractResultWithNamedArrays, Array, ShapeType,
NamedArray, ArrayOrScalar, SizeParam, AxesT)
Expand Down Expand Up @@ -61,26 +62,25 @@
"""


@attrs.define(eq=False, frozen=True)
class LoopyCall(AbstractResultWithNamedArrays):
"""
An array expression node representing a call to an entrypoint in a
:mod:`loopy` translation unit.
"""
_mapper_method = "map_loopy_call"
translation_unit: "lp.TranslationUnit"
bindings: Dict[str, ArrayOrScalar]
entrypoint: str

def __init__(self,
translation_unit: "lp.TranslationUnit",
bindings: Dict[str, ArrayOrScalar],
entrypoint: str):
entry_kernel = translation_unit[entrypoint]
super().__init__()
self._result_names = frozenset({name
for name, lp_arg in entry_kernel.arg_dict.items()
if lp_arg.is_output})

self.translation_unit = translation_unit
self.bindings = bindings
self.entrypoint = entrypoint
_mapper_method: ClassVar[str] = "map_loopy_call"

copy = attrs.evolve

@property
def _result_names(self) -> FrozenSet[str]:
return frozenset({name
for name, lp_arg in self._entry_kernel.arg_dict.items()
if lp_arg.is_output})

@memoize_method
def _to_pytato(self, expr: ScalarExpression) -> ScalarExpression:
Expand Down Expand Up @@ -207,6 +207,8 @@ def call_loopy(translation_unit: "lp.TranslationUnit",
to :class:`pytato.array.Array`.
:arg entrypoint: the entrypoint of the ``translation_unit`` parameter.
"""
from pytato.array import _get_default_tags

if entrypoint is None:
if len(translation_unit.entrypoints) != 1:
raise ValueError("cannot infer entrypoint")
Expand Down Expand Up @@ -285,7 +287,8 @@ def call_loopy(translation_unit: "lp.TranslationUnit",

translation_unit = translation_unit.with_entrypoints(frozenset())

return LoopyCall(translation_unit, bindings, entrypoint)
return LoopyCall(translation_unit, bindings, entrypoint,
tags=_get_default_tags())


# {{{ shape inference
Expand Down
6 changes: 3 additions & 3 deletions pytato/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pytato.transform import EdgeCachedMapper, CachedWalkMapper
from pytato.array import (
Array, AbstractResultWithNamedArrays, Placeholder,
DictOfNamedArrays, make_placeholder)
DictOfNamedArrays, make_placeholder, make_dict_of_named_arrays)

from pytato.target import BoundProgram

Expand Down Expand Up @@ -359,7 +359,7 @@ def find_partition(outputs: DictOfNamedArrays,
_check_partition_disjointness(result)

from pytato.analysis import get_num_nodes
num_nodes_per_part = [get_num_nodes(DictOfNamedArrays(
num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays(
{x: result.var_name_to_result[x] for x in part.output_names}))
for part in result.parts.values()]

Expand Down Expand Up @@ -423,7 +423,7 @@ def generate_code_for_partition(partition: GraphPartition) \

for part in sorted(partition.parts.values(),
key=lambda part_: sorted(part_.output_names)):
d = DictOfNamedArrays(
d = make_dict_of_named_arrays(
{var_name: partition.var_name_to_result[var_name]
for var_name in part.output_names
})
Expand Down
28 changes: 19 additions & 9 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def map_named_array(self, expr: NamedArray) -> Array:
def map_dict_of_named_arrays(self,
expr: DictOfNamedArrays) -> DictOfNamedArrays:
return DictOfNamedArrays({key: self.rec(val.expr)
for key, val in expr.items()})
for key, val in expr.items()},
tags=expr.tags
)

def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array)
Expand All @@ -312,7 +314,9 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:

return LoopyCall(translation_unit=expr.translation_unit,
bindings=bindings,
entrypoint=expr.entrypoint)
entrypoint=expr.entrypoint,
tags=expr.tags,
)

def map_loopy_call_result(self, expr: LoopyCallResult) -> Array:
return LoopyCallResult(
Expand Down Expand Up @@ -489,7 +493,9 @@ def map_named_array(self, expr: NamedArray, *args: Any, **kwargs: Any) -> Array:
def map_dict_of_named_arrays(self,
expr: DictOfNamedArrays, *args: Any, **kwargs: Any) -> DictOfNamedArrays:
return DictOfNamedArrays({key: self.rec(val.expr, *args, **kwargs)
for key, val in expr.items()})
for key, val in expr.items()},
tags=expr.tags,
)

def map_loopy_call(self, expr: LoopyCall,
*args: Any, **kwargs: Any) -> LoopyCall:
Expand All @@ -500,7 +506,9 @@ def map_loopy_call(self, expr: LoopyCall,

return LoopyCall(translation_unit=expr.translation_unit,
bindings=bindings,
entrypoint=expr.entrypoint)
entrypoint=expr.entrypoint,
tags=expr.tags,
)

def map_loopy_call_result(self, expr: LoopyCallResult,
*args: Any, **kwargs: Any) -> Array:
Expand Down Expand Up @@ -1222,11 +1230,12 @@ def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays,
items in *source_dict*
"""
if not source_dict:
return DictOfNamedArrays({})
data = {}
else:
data = {name: copy_mapper(val.expr)
for name, val in sorted(source_dict.items())}

data = {name: copy_mapper(val.expr)
for name, val in sorted(source_dict.items())}
return DictOfNamedArrays(data)
return DictOfNamedArrays(data, tags=source_dict.tags)


def get_dependencies(expr: DictOfNamedArrays) -> Dict[str, FrozenSet[Array]]:
Expand Down Expand Up @@ -1308,7 +1317,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
for name, ary in expr.items():
new_data[name] = materializer(ary.expr).expr

return DictOfNamedArrays(new_data)
return DictOfNamedArrays(new_data, tags=expr.tags)

# }}}

Expand Down Expand Up @@ -1693,6 +1702,7 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
name: self.handle_edge(expr, child)
if isinstance(child, Array) else child
for name, child in sorted(expr.bindings.items())},
tags=expr.tags,
)

def map_distributed_send_ref_holder(
Expand Down
10 changes: 5 additions & 5 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_codegen_with_DictOfNamedArrays(ctx_factory): # noqa
x_in = np.array([1, 2, 3, 4, 5])
y_in = np.array([6, 7, 8, 9, 10])

result = pt.DictOfNamedArrays(dict(x_out=x, y_out=y))
result = pt.make_dict_of_named_arrays(dict(x_out=x, y_out=y))

# With return_dict.
prog = pt.generate_loopy(result)
Expand Down Expand Up @@ -525,7 +525,7 @@ def test_dict_of_named_array_codegen_avoids_recomputation():
y = 2*x
z = y + 4*x

yz = pt.DictOfNamedArrays({"y": y, "z": z})
yz = pt.make_dict_of_named_arrays({"y": y, "z": z})

knl = pt.generate_loopy(yz).kernel
assert ("y" in knl.id_to_insn["z_store"].read_dependency_names())
Expand Down Expand Up @@ -750,7 +750,7 @@ def test_call_loopy_with_same_callee_names(ctx_factory):
cuatro_u = 2*call_loopy(twice, {"x": u}, "callee")["y"]
nueve_u = 3*call_loopy(thrice, {"x": u}, "callee")["y"]

out = pt.DictOfNamedArrays({"cuatro_u": cuatro_u, "nueve_u": nueve_u})
out = pt.make_dict_of_named_arrays({"cuatro_u": cuatro_u, "nueve_u": nueve_u})

evt, out_dict = pt.generate_loopy(out,
options=lp.Options(return_dict=True))(queue)
Expand Down Expand Up @@ -1376,7 +1376,7 @@ def test_random_dag_against_numpy(ctx_factory):
ref_result = make_random_dag(rdagc_np)
dag = make_random_dag(rdagc_pt)
from pytato.transform import materialize_with_mpms
dict_named_arys = pt.DictOfNamedArrays({"result": dag})
dict_named_arys = pt.make_dict_of_named_arrays({"result": dag})
dict_named_arys = materialize_with_mpms(dict_named_arys)
if 0:
pt.show_dot_graph(dict_named_arys)
Expand Down Expand Up @@ -1407,7 +1407,7 @@ def test_partitioner(ctx_factory):
ref_result = make_random_dag(rdagc_np)

from pytato.transform import materialize_with_mpms
dict_named_arys = materialize_with_mpms(pt.DictOfNamedArrays(
dict_named_arys = materialize_with_mpms(pt.make_dict_of_named_arrays(
{"result": make_random_dag(rdagc_pt)}))

from dataclasses import dataclass
Expand Down
5 changes: 3 additions & 2 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _do_test_distributed_execution_basic(ctx_factory):
y = x+halo

# Find the partition
outputs = pt.DictOfNamedArrays({"out": y})
outputs = pt.make_dict_of_named_arrays({"out": y})
distributed_parts = find_distributed_partition(outputs)
prg_per_partition = generate_code_for_partition(distributed_parts)

Expand Down Expand Up @@ -167,7 +167,8 @@ def gen_comm(rdagc):
additional_generators=[
(comm_fake_prob, gen_comm)
])
pt_dag = pt.DictOfNamedArrays({"result": make_random_dag(rdagc_comm)})
pt_dag = pt.make_dict_of_named_arrays(
{"result": make_random_dag(rdagc_comm)})
x_comm = pt.transform.materialize_with_mpms(pt_dag)

distributed_partition = find_distributed_partition(x_comm)
Expand Down
2 changes: 1 addition & 1 deletion test/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_random_dag_against_numpy(jit):
ref_result = make_random_dag(rdagc_np)
dag = make_random_dag(rdagc_pt)
from pytato.transform import materialize_with_mpms
dict_named_arys = pt.DictOfNamedArrays({"result": dag})
dict_named_arys = pt.make_dict_of_named_arrays({"result": dag})
dict_named_arys = materialize_with_mpms(dict_named_arys)
if 0:
pt.show_dot_graph(dict_named_arys)
Expand Down