Skip to content

Commit e060c67

Browse files
committed
implements inlined; codegen works now
1 parent da04a11 commit e060c67

8 files changed

Lines changed: 86 additions & 84 deletions

File tree

pytato/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def set_debug_enabled(flag: bool) -> None:
9090
from pytato.visualization import (get_dot_graph, show_dot_graph,
9191
get_ascii_graph, show_ascii_graph,
9292
get_dot_graph_from_partition)
93+
from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls
9394
import pytato.analysis as analysis
9495
import pytato.tags as tags
9596
import pytato.tracing as tracing
@@ -154,6 +155,8 @@ def set_debug_enabled(flag: bool) -> None:
154155
"make_distributed_recv", "make_distributed_send", "DistributedRecv",
155156
"DistributedSend", "staple_distributed_send", "DistributedSendRefHolder",
156157

158+
"tag_all_calls_to_be_inlined", "inline_calls",
159+
157160
"find_distributed_partition",
158161
"number_distributed_tags",
159162
"execute_distributed_partition",

pytato/array.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,17 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC):
767767
tags: FrozenSet[Tag] = attrs.field(kw_only=True)
768768
_mapper_method: ClassVar[str]
769769

770+
def _is_eq_valid(self) -> bool:
771+
return self.__class__.__eq__ is AbstractResultWithNamedArrays.__eq__
772+
773+
def __post_init__(self) -> None:
774+
# ensure that a developer does not uses dataclass' "__eq__"
775+
# or "__hash__" implementation as they have exponential complexity.
776+
assert self._is_eq_valid()
777+
778+
def __attrs_post_init__(self) -> None:
779+
return self.__post_init__()
780+
770781
@abstractmethod
771782
def __contains__(self, name: object) -> bool:
772783
pass
@@ -779,6 +790,13 @@ def __getitem__(self, name: str) -> NamedArray:
779790
def __len__(self) -> int:
780791
pass
781792

793+
def __eq__(self, other: Any) -> bool:
794+
if self is other:
795+
return True
796+
797+
from pytato.equality import EqualityComparer
798+
return EqualityComparer()(self, other)
799+
782800

783801
@attrs.define(frozen=True, eq=False, init=False)
784802
class DictOfNamedArrays(AbstractResultWithNamedArrays):
@@ -825,13 +843,6 @@ def __len__(self) -> int:
825843
def __iter__(self) -> Iterator[str]:
826844
return iter(self._data)
827845

828-
def __eq__(self, other: Any) -> bool:
829-
if self is other:
830-
return True
831-
832-
from pytato.equality import EqualityComparer
833-
return EqualityComparer()(self, other)
834-
835846
def __repr__(self) -> str:
836847
return "DictOfNamedArrays(" + str(self._data) + ")"
837848

pytato/codegen.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ def __init__(self, target: Target) -> None:
118118
self.target = target
119119
self.kernels_seen: Dict[str, lp.LoopKernel] = {}
120120

121+
def clone(self) -> CodeGenPreprocessor:
122+
new_mapper = CodeGenPreprocessor(self.target)
123+
new_mapper.kernels_seen = self.kernels_seen
124+
return new_mapper
125+
121126
def map_size_param(self, expr: SizeParam) -> Array:
122127
name = expr.name
123128
assert name is not None
@@ -602,6 +607,7 @@ def map_non_contiguous_advanced_index(self,
602607
var_to_reduction_descr=Map(),
603608
tags=expr.tags,
604609
)
610+
605611
# }}}
606612

607613

@@ -697,11 +703,11 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
697703

698704
# }}}
699705

700-
mapper = CodeGenPreprocessor(target)
701-
702-
new_outputs = copy_dict_of_named_arrays(outputs, mapper)
703706
new_outputs = inline_calls(outputs)
704707

708+
mapper = CodeGenPreprocessor(target)
709+
new_outputs = copy_dict_of_named_arrays(new_outputs, mapper)
710+
705711
return PreprocessResult(outputs=new_outputs,
706712
compute_order=tuple(output_order),
707713
bound_arguments=mapper.bound_arguments)

pytato/equality.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@
3131
IndexBase, IndexLambda, NamedArray,
3232
Reshape, Roll, Stack, AbstractResultWithNamedArrays,
3333
Array, DictOfNamedArrays, Placeholder, SizeParam)
34-
from pytato.tracing import (TracePlaceholder, Call, NamedCallResult,
35-
FunctionDefinition)
34+
from pytato.tracing import Call, NamedCallResult, FunctionDefinition
3635

3736
if TYPE_CHECKING:
3837
from pytato.loopy import LoopyCall, LoopyCallResult
@@ -276,10 +275,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
276275
) -> bool:
277276
return (expr1.__class__ is expr2.__class__
278277
and expr1.identifier == expr2.identifier
279-
and len(expr1.parameters.keys()) == len(expr2.parameters.keys())
280-
and all(self.rec(expr1_pl, expr2_pl)
281-
for expr1_pl, expr2_pl in zip(expr1.parameters,
282-
expr2.parameters))
278+
and expr1.parameters == expr2.parameters
283279
and (set(expr1.returns.keys()) == set(expr2.returns.keys()))
284280
and all(self.rec(expr1.returns[k], expr2.returns[k])
285281
for k in expr1.returns)
@@ -299,16 +295,6 @@ def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool:
299295
and expr1.name == expr2.name
300296
and self.rec(expr1._container, expr2._container))
301297

302-
def map_trace_placeholder(self, expr1: TracePlaceholder, expr2: Any) -> bool:
303-
return (expr1.__class__ is expr2.__class__
304-
and expr1.name == expr2.name
305-
and expr1.identifier == expr2.identifier
306-
and expr1.shape == expr2.shape
307-
and expr1.dtype == expr2.dtype
308-
and expr1.tags == expr2.tags
309-
and expr1.axes == expr2.axes
310-
)
311-
312298
# }}}
313299

314300
# vim: fdm=marker

pytato/loopy.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,6 @@ def __len__(self) -> int:
118118
def __iter__(self) -> Iterator[str]:
119119
return iter(self._result_names)
120120

121-
def __eq__(self, other: Any) -> bool:
122-
if self is other:
123-
return True
124-
125-
if not isinstance(other, LoopyCall):
126-
return False
127-
128-
if ((self.entrypoint == other.entrypoint)
129-
and (self.bindings == other.bindings)
130-
and (self.translation_unit == other.translation_unit)):
131-
return True
132-
return False
133-
134121

135122
class LoopyCallResult(NamedArray):
136123
"""

pytato/tracing.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838

3939
import attrs
4040

41-
from typing import (Callable, Dict, FrozenSet, Tuple, Union, TypeVar, Optional, Hashable,
42-
Sequence, ClassVar)
41+
from typing import (Callable, Dict, FrozenSet, Tuple, Union, TypeVar, Optional,
42+
Hashable, Sequence, ClassVar)
4343
from immutables import Map
4444
from pytato.array import (Array, AbstractResultWithNamedArrays, AxesT,
4545
Placeholder, NamedArray, ShapeType, _dtype_any)
@@ -56,7 +56,7 @@
5656

5757
# {{{ Call/NamedCallResult
5858

59-
@attrs.define(frozen=True, eq=False, repr=False)
59+
@attrs.define(frozen=True, repr=False, eq=False)
6060
class FunctionDefinition(AbstractResultWithNamedArrays):
6161
r"""
6262
A function definition that represents its outputs as instances of
@@ -67,9 +67,7 @@ class FunctionDefinition(AbstractResultWithNamedArrays):
6767
6868
.. attribute:: parameters
6969
70-
The inputs to the function node. A mapping from the names of the
71-
function's parameters to its corresponding placeholder in the
72-
outputs' expression graph.
70+
Names of the inputs to the function node.
7371
7472
.. attribute:: returns
7573
@@ -85,16 +83,11 @@ class FunctionDefinition(AbstractResultWithNamedArrays):
8583
:class:`pytato.Array`\ s, then *returns* uses the same mapping.
8684
"""
8785
identifier: Hashable
88-
parameters: Map[str, Placeholder]
86+
parameters: FrozenSet[str]
8987
returns: Map[str, Array]
9088

9189
_mapper_method: ClassVar[str] = "map_function_definition"
9290

93-
def __post_init__(self):
94-
if __debug__:
95-
assert all(name == param.name
96-
for name, param in self.parameters.items())
97-
9891
def __contains__(self, name: object) -> bool:
9992
return name in self.returns
10093

@@ -189,15 +182,17 @@ class Call(AbstractResultWithNamedArrays):
189182
result_tags: Map[str, FrozenSet[Tag]] = attrs.field(kw_only=True)
190183
result_axes: Map[str, AxesT] = attrs.field(kw_only=True)
191184

185+
_mapper_method: ClassVar[str] = "map_call"
186+
187+
copy = attrs.evolve
188+
192189
def __post_init__(self):
193190
if __debug__:
194191
# check that the invocation parameters and the function definition
195192
# parameters agree with each other.
196-
assert set(self.bindings) == set(self.function.parameters.keys())
197-
assert set(self.tags.keys()) == set(self.function.returns.keys())
198-
assert set(self.axes.keys()) == set(self.function.returns.keys())
199-
200-
_mapper_method: ClassVar[str] = "map_call"
193+
assert frozenset(self.bindings) == self.function.parameters
194+
assert set(self.result_tags.keys()) == set(self.function.returns.keys())
195+
assert set(self.result_axes.keys()) == set(self.function.returns.keys())
201196

202197
def __contains__(self, name: object) -> bool:
203198
return name in self.function.returns
@@ -211,8 +206,8 @@ def __getitem__(self, name: str) -> NamedCallResult:
211206
def __len__(self) -> int:
212207
return len(self.function.returns)
213208

214-
def with_tagged_axis(self, name: str, iaxis: int,
215-
tags: Union[Sequence[Tag], Tag]) -> Self:
209+
def with_result_axis_tagged(self, name: str, iaxis: int,
210+
tags: Union[Sequence[Tag], Tag]) -> Self:
216211
"""
217212
Returns a copy of *self* with the result corresponding to *name*\'s
218213
*iaxis*-th axis tagged with *tags*. Also, see
@@ -230,14 +225,15 @@ def with_tagged_axis(self, name: str, iaxis: int,
230225
return replace(self,
231226
result_axes=self.result_axes.set(name, new_axes))
232227

233-
def tagged(self, name: str, tags: Union[Sequence[Tag], Tag]) -> Self:
228+
def with_result_tagged(self, name: str, tags: Union[Sequence[Tag], Tag]) -> Self:
234229
"""
235230
Returns a copy of *self* with the result corresponding to *name*\'s
236231
tagged with *tags*. Also, see :meth:`pytato.Array.tagged`.
237232
"""
238233
from attrs import evolve as replace
239234
from pytools.tag import check_tag_uniqueness, normalize_tags
240-
new_tags = check_tag_uniqueness(normalize_tags(tags) | self.result_tags[name])
235+
new_tags = check_tag_uniqueness(normalize_tags(tags)
236+
| self.result_tags[name])
241237

242238
return replace(self, result_tags=self.result_tags.set(name, new_tags))
243239

@@ -299,8 +295,7 @@ def trace_call(f: Callable[..., ReturnT],
299295
# construct the function
300296
function = FunctionDefinition(
301297
identifier,
302-
(Map({pl.name: pl for pl in pl_args})
303-
.update(Map({pl.name: pl for pl in pl_kwargs.values()}))),
298+
frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs),
304299
Map(returns),
305300
tags=_get_default_tags()
306301
)

0 commit comments

Comments
 (0)