3838
3939import attrs
4040
41- from typing import (Callable , Dict , FrozenSet , Tuple , Union , TypeVar , Optional , Hashable ,
42- Sequence , ClassVar )
41+ from typing import (Callable , Dict , FrozenSet , Tuple , Union , TypeVar , Optional ,
42+ Hashable , Sequence , ClassVar )
4343from immutables import Map
4444from pytato .array import (Array , AbstractResultWithNamedArrays , AxesT ,
4545 Placeholder , NamedArray , ShapeType , _dtype_any )
5656
5757# {{{ Call/NamedCallResult
5858
59- @attrs .define (frozen = True , eq = False , repr = False )
59+ @attrs .define (frozen = True , repr = False , eq = False )
6060class FunctionDefinition (AbstractResultWithNamedArrays ):
6161 r"""
6262 A function definition that represents its outputs as instances of
@@ -67,9 +67,7 @@ class FunctionDefinition(AbstractResultWithNamedArrays):
6767
6868 .. attribute:: parameters
6969
70- The inputs to the function node. A mapping from the names of the
71- function's parameters to its corresponding placeholder in the
72- outputs' expression graph.
70+ Names of the inputs to the function node.
7371
7472 .. attribute:: returns
7573
@@ -85,16 +83,11 @@ class FunctionDefinition(AbstractResultWithNamedArrays):
8583 :class:`pytato.Array`\ s, then *returns* uses the same mapping.
8684 """
8785 identifier : Hashable
88- parameters : Map [str , Placeholder ]
86+ parameters : FrozenSet [str ]
8987 returns : Map [str , Array ]
9088
9189 _mapper_method : ClassVar [str ] = "map_function_definition"
9290
93- def __post_init__ (self ):
94- if __debug__ :
95- assert all (name == param .name
96- for name , param in self .parameters .items ())
97-
9891 def __contains__ (self , name : object ) -> bool :
9992 return name in self .returns
10093
@@ -189,15 +182,17 @@ class Call(AbstractResultWithNamedArrays):
189182 result_tags : Map [str , FrozenSet [Tag ]] = attrs .field (kw_only = True )
190183 result_axes : Map [str , AxesT ] = attrs .field (kw_only = True )
191184
185+ _mapper_method : ClassVar [str ] = "map_call"
186+
187+ copy = attrs .evolve
188+
192189 def __post_init__ (self ):
193190 if __debug__ :
194191 # check that the invocation parameters and the function definition
195192 # parameters agree with each other.
196- assert set (self .bindings ) == set (self .function .parameters .keys ())
197- assert set (self .tags .keys ()) == set (self .function .returns .keys ())
198- assert set (self .axes .keys ()) == set (self .function .returns .keys ())
199-
200- _mapper_method : ClassVar [str ] = "map_call"
193+ assert frozenset (self .bindings ) == self .function .parameters
194+ assert set (self .result_tags .keys ()) == set (self .function .returns .keys ())
195+ assert set (self .result_axes .keys ()) == set (self .function .returns .keys ())
201196
202197 def __contains__ (self , name : object ) -> bool :
203198 return name in self .function .returns
@@ -211,8 +206,8 @@ def __getitem__(self, name: str) -> NamedCallResult:
211206 def __len__ (self ) -> int :
212207 return len (self .function .returns )
213208
214- def with_tagged_axis (self , name : str , iaxis : int ,
215- tags : Union [Sequence [Tag ], Tag ]) -> Self :
209+ def with_result_axis_tagged (self , name : str , iaxis : int ,
210+ tags : Union [Sequence [Tag ], Tag ]) -> Self :
216211 """
217212 Returns a copy of *self* with the result corresponding to *name*\' s
218213 *iaxis*-th axis tagged with *tags*. Also, see
@@ -230,14 +225,15 @@ def with_tagged_axis(self, name: str, iaxis: int,
230225 return replace (self ,
231226 result_axes = self .result_axes .set (name , new_axes ))
232227
233- def tagged (self , name : str , tags : Union [Sequence [Tag ], Tag ]) -> Self :
228+ def with_result_tagged (self , name : str , tags : Union [Sequence [Tag ], Tag ]) -> Self :
234229 """
235230 Returns a copy of *self* with the result corresponding to *name*\' s
236231 tagged with *tags*. Also, see :meth:`pytato.Array.tagged`.
237232 """
238233 from attrs import evolve as replace
239234 from pytools .tag import check_tag_uniqueness , normalize_tags
240- new_tags = check_tag_uniqueness (normalize_tags (tags ) | self .result_tags [name ])
235+ new_tags = check_tag_uniqueness (normalize_tags (tags )
236+ | self .result_tags [name ])
241237
242238 return replace (self , result_tags = self .result_tags .set (name , new_tags ))
243239
@@ -299,8 +295,7 @@ def trace_call(f: Callable[..., ReturnT],
299295 # construct the function
300296 function = FunctionDefinition (
301297 identifier ,
302- (Map ({pl .name : pl for pl in pl_args })
303- .update (Map ({pl .name : pl for pl in pl_kwargs .values ()}))),
298+ frozenset (pl_arg .name for pl_arg in pl_args ) | frozenset (pl_kwargs ),
304299 Map (returns ),
305300 tags = _get_default_tags ()
306301 )
0 commit comments