@@ -1029,11 +1029,20 @@ def get_default_op_name_to_num_flops() -> dict[str, int]:
10291029 "max" : 1 }
10301030
10311031
1032+ # FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10321033def get_num_flops (
10331034 expr : ArrayOrNames ,
10341035 op_name_to_num_flops : Mapping [str , int ] | None = None ,
10351036 ) -> ArrayOrScalar :
1036- """Count the total number of floating point operations in the DAG *expr*."""
1037+ """
1038+ Count the total number of floating point operations in the DAG *expr*.
1039+
1040+ .. note::
1041+
1042+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1043+ this function assumes a SIMT-like model of computation in which the per-entry
1044+ cost is the maximum(??? FIXME) of the costs of the two branches.
1045+ """
10371046 from pytato .codegen import normalize_outputs
10381047 expr = normalize_outputs (expr )
10391048 expr = _normalize_materialization (expr )
@@ -1049,13 +1058,20 @@ def get_num_flops(
10491058 + sum (fc .call_to_nflops .values ()))
10501059
10511060
1061+ # FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10521062def get_materialized_node_flop_counts (
10531063 expr : ArrayOrNames ,
10541064 op_name_to_num_flops : Mapping [str , int ] | None = None ,
10551065 ) -> dict [Array , ArrayOrScalar ]:
10561066 """
10571067 Returns a dictionary mapping materialized nodes in DAG *expr* to their floating
10581068 point operation count.
1069+
1070+ .. note::
1071+
1072+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1073+ this function assumes a SIMT-like model of computation in which the per-entry
1074+ cost is the maximum(??? FIXME) of the costs of the two branches.
10591075 """
10601076 from pytato .codegen import normalize_outputs
10611077 expr = normalize_outputs (expr )
@@ -1070,6 +1086,7 @@ def get_materialized_node_flop_counts(
10701086 return fc .materialized_node_to_nflops
10711087
10721088
1089+ # FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10731090def get_unmaterialized_node_flop_counts (
10741091 expr : ArrayOrNames ,
10751092 op_name_to_num_flops : Mapping [str , int ] | None = None ,
@@ -1078,6 +1095,12 @@ def get_unmaterialized_node_flop_counts(
10781095 Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a
10791096 :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count
10801097 information.
1098+
1099+ .. note::
1100+
1101+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1102+ this function assumes a SIMT-like model of computation in which the per-entry
1103+ cost is the maximum(??? FIXME) of the costs of the two branches.
10811104 """
10821105 from pytato .codegen import normalize_outputs
10831106 expr = normalize_outputs (expr )
0 commit comments