Skip to content

Commit 789e7ac

Browse files
committed
add note to docs about assumptions when handling conditional expressions
1 parent 22167db commit 789e7ac

1 file changed

Lines changed: 24 additions & 1 deletion

File tree

pytato/analysis/__init__.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1029,11 +1029,20 @@ def get_default_op_name_to_num_flops() -> dict[str, int]:
10291029
"max": 1}
10301030

10311031

1032+
# FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10321033
def get_num_flops(
10331034
expr: ArrayOrNames,
10341035
op_name_to_num_flops: Mapping[str, int] | None = None,
10351036
) -> ArrayOrScalar:
1036-
"""Count the total number of floating point operations in the DAG *expr*."""
1037+
"""
1038+
Count the total number of floating point operations in the DAG *expr*.
1039+
1040+
.. note::
1041+
1042+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1043+
this function assumes a SIMT-like model of computation in which the per-entry
1044+
cost is the maximum(??? FIXME) of the costs of the two branches.
1045+
"""
10371046
from pytato.codegen import normalize_outputs
10381047
expr = normalize_outputs(expr)
10391048
expr = _normalize_materialization(expr)
@@ -1049,13 +1058,20 @@ def get_num_flops(
10491058
+ sum(fc.call_to_nflops.values()))
10501059

10511060

1061+
# FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10521062
def get_materialized_node_flop_counts(
10531063
expr: ArrayOrNames,
10541064
op_name_to_num_flops: Mapping[str, int] | None = None,
10551065
) -> dict[Array, ArrayOrScalar]:
10561066
"""
10571067
Returns a dictionary mapping materialized nodes in DAG *expr* to their floating
10581068
point operation count.
1069+
1070+
.. note::
1071+
1072+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1073+
this function assumes a SIMT-like model of computation in which the per-entry
1074+
cost is the maximum(??? FIXME) of the costs of the two branches.
10591075
"""
10601076
from pytato.codegen import normalize_outputs
10611077
expr = normalize_outputs(expr)
@@ -1070,6 +1086,7 @@ def get_materialized_node_flop_counts(
10701086
return fc.materialized_node_to_nflops
10711087

10721088

1089+
# FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10731090
def get_unmaterialized_node_flop_counts(
10741091
expr: ArrayOrNames,
10751092
op_name_to_num_flops: Mapping[str, int] | None = None,
@@ -1078,6 +1095,12 @@ def get_unmaterialized_node_flop_counts(
10781095
Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a
10791096
:class:`UnmaterializedNodeFlopCounts` containing floating-point operation count
10801097
information.
1098+
1099+
.. note::
1100+
1101+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1102+
this function assumes a SIMT-like model of computation in which the per-entry
1103+
cost is the maximum(??? FIXME) of the costs of the two branches.
10811104
"""
10821105
from pytato.codegen import normalize_outputs
10831106
expr = normalize_outputs(expr)

0 commit comments

Comments
 (0)