diff --git a/dwave/optimization/include/dwave-optimization/nodes/reduce.hpp b/dwave/optimization/include/dwave-optimization/nodes/reduce.hpp index d5eb4537..b770dca2 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/reduce.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/reduce.hpp @@ -86,6 +86,9 @@ class ReduceNode : public ArrayOutputMixin { using ArrayOutputMixin::size; ssize_t size(const State& state) const override; + /// @copydoc Array::sizeinfo() + SizeInfo sizeinfo() const override { return sizeinfo_; }; + /// @copydoc Array::size_diff() ssize_t size_diff(const State& state) const override; @@ -111,6 +114,8 @@ class ReduceNode : public ArrayOutputMixin { // as whether we're integral or not. const ValuesInfo values_info_; + const SizeInfo sizeinfo_; + // During propagation, we need to take a predecessor's `update` and apply // it to the correct `index` in the ReduceNode buffer. This method // determines the correct buffer `index` given an `update.index`. diff --git a/dwave/optimization/src/nodes/reduce.cpp b/dwave/optimization/src/nodes/reduce.cpp index 6d758ca6..15060ff3 100644 --- a/dwave/optimization/src/nodes/reduce.cpp +++ b/dwave/optimization/src/nodes/reduce.cpp @@ -24,6 +24,8 @@ #include #include "../functional_.hpp" +#include "dwave-optimization/array.hpp" +#include "dwave-optimization/state.hpp" namespace dwave::optimization { @@ -281,7 +283,7 @@ class ReduceNodeData : public NodeStateData { if (flags_[index] == ReductionFlag::unchanged) { reductions_diff_.emplace_back(index, reductions_[index]); } - flags_[index] = ReductionFlag::invalid; + flags_[index] = ReductionFlag::invalid; return; } reduction = std::move(inverse.value()); @@ -581,6 +583,19 @@ ValuesInfo values_info(const Array* array_ptr, std::span axes, return bounds; } +SizeInfo reducenode_calculate_sizeinfo(const Array* node_ptr, const Array* array_ptr, + std::span axes) { + // Node is statically sized. Note: If node_shape = {}, product(node_shape) = 1. 
+ if (!node_ptr->dynamic()) return SizeInfo(product(node_ptr->shape())); + assert(node_ptr->shape().size() && node_ptr->shape().front() == -1); + + // Node is dynamically sized but predecessor is always empty. + if (array_ptr->size() == 0) return SizeInfo(node_ptr); + + // Node size is derived from its predecessor's size. + return array_ptr->sizeinfo() / product(keep_axes(array_ptr->shape(), axes)); +} + template ReduceNode::ReduceNode(ArrayNode* array_ptr) : ReduceNode(array_ptr, {}) {} @@ -591,7 +606,8 @@ ReduceNode::ReduceNode(ArrayNode* array_ptr, std::span initial(initial), array_ptr_(array_ptr), axes_(normalize_axes(array_ptr, axes)), - values_info_(values_info(array_ptr_, axes_, initial)) { + values_info_(values_info(array_ptr_, axes_, initial)), + sizeinfo_(reducenode_calculate_sizeinfo(this, array_ptr_, axes_)) { add_predecessor(array_ptr); } diff --git a/releasenotes/notes/reducenode_cache_sizeinfo-cba3557baf3f7901.yaml b/releasenotes/notes/reducenode_cache_sizeinfo-cba3557baf3f7901.yaml new file mode 100644 index 00000000..b837470b --- /dev/null +++ b/releasenotes/notes/reducenode_cache_sizeinfo-cba3557baf3f7901.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + The symbols `All`, `Any`, `Max`, `Min`, `Prod`, and `Sum` can infer + their size from their predecessor's size. 
diff --git a/tests/cpp/nodes/test_reduce.cpp b/tests/cpp/nodes/test_reduce.cpp index a60eeaee..0135a514 100644 --- a/tests/cpp/nodes/test_reduce.cpp +++ b/tests/cpp/nodes/test_reduce.cpp @@ -45,6 +45,7 @@ TEMPLATE_TEST_CASE("ReduceNode", "", // THEN("The output shape is scalar") { CHECK(r_ptr->ndim() == 0); CHECK(r_ptr->size() == 1); + CHECK(r_ptr->sizeinfo() == SizeInfo(1)); } THEN("The constant is the operand") { @@ -77,6 +78,7 @@ TEMPLATE_TEST_CASE("ReduceNode", "", // THEN("The output shape is scalar") { CHECK(r_ptr->ndim() == 0); CHECK(r_ptr->size() == 1); + CHECK(r_ptr->sizeinfo() == SizeInfo(1)); } THEN("The constant is the operand") { @@ -110,6 +112,7 @@ TEMPLATE_TEST_CASE("ReduceNode", "", // THEN("The output shape is scalar") { CHECK(r_ptr->ndim() == 0); CHECK(r_ptr->size() == 1); + CHECK(r_ptr->sizeinfo() == SizeInfo(1)); } THEN("The constant is the operand") { @@ -137,6 +140,7 @@ TEMPLATE_TEST_CASE("ReduceNode", "", // THEN("The output shape is scalar") { CHECK(r_ptr->ndim() == 0); CHECK(r_ptr->size() == 1); + CHECK(r_ptr->sizeinfo() == SizeInfo(1)); } THEN("The set is the operand") { @@ -202,7 +206,11 @@ TEMPLATE_TEST_CASE("ReduceNode", "", // graph.emplace_node(r_ptr); // this is equivalent to a reduction, so the output is a scalar - CHECK(r_ptr->ndim() == 0); + THEN("The output shape is scalar") { + CHECK(r_ptr->ndim() == 0); + CHECK(r_ptr->size() == 1); + CHECK(r_ptr->sizeinfo() == SizeInfo(1)); + } auto values = std::vector{1, 2, 3, 4, 5}; @@ -262,8 +270,10 @@ TEST_CASE("AllNode/AnyNode") { THEN("y,z are logical and scalar") { CHECK(y_ptr->logical()); CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); CHECK(z_ptr->logical()); CHECK(z_ptr->ndim() == 0); + CHECK(z_ptr->sizeinfo() == SizeInfo(1)); } WHEN("x == [0, 0, 0, 0, 0]") { @@ -365,6 +375,15 @@ TEST_CASE("MaxNode/MinNode") { CHECK(min_ptr->min() == 0); CHECK(min_ptr->max() == 6); // beause the list can be empty + THEN("min and max shape's are scalar") { + CHECK(min_ptr->ndim() 
== 0); + CHECK(min_ptr->size() == 1); + CHECK(min_ptr->sizeinfo() == SizeInfo(1)); + CHECK(max_ptr->ndim() == 0); + CHECK(max_ptr->size() == 1); + CHECK(max_ptr->sizeinfo() == SizeInfo(1)); + } + AND_GIVEN("An initial state of [ 1 2 3 | 0 4 ]") { auto state = graph.empty_state(); list_ptr->initialize_state(state, {1, 2, 3, 0, 4}); @@ -488,6 +507,12 @@ TEST_CASE("MaxNode") { CHECK(y_ptr->max() == 2); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("x = IntegerNode(3, -5, 2), y = x.max(init=-.5)") { @@ -499,12 +524,24 @@ TEST_CASE("MaxNode") { CHECK(y_ptr->max() == 2); CHECK(!y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("x = IntegerNode(5, 0, 10), y x.max()") { auto x_ptr = graph.emplace_node(3, 0, 20); auto y_ptr = graph.emplace_node(x_ptr); + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } + auto state = graph.empty_state(); x_ptr->initialize_state(state, {0, 5, 10}); graph.initialize_state(state); @@ -557,6 +594,12 @@ TEST_CASE("MinNode") { CHECK(y_ptr->max() == 2); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("x = IntegerNode(3, -5, 2), y = x.min(init=-.5)") { @@ -568,6 +611,12 @@ TEST_CASE("MinNode") { CHECK(y_ptr->max() == -.5); CHECK(!y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } } @@ -583,6 +632,12 @@ TEST_CASE("ProdNode") { CHECK(y_ptr->max() == -5 * -5 * 2); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + 
CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("x = IntegerNode(3, -5, 2), y = x.prod(init=-.5)") { @@ -594,12 +649,24 @@ TEST_CASE("ProdNode") { CHECK(y_ptr->max() == -5 * -5 * -5 * -.5); CHECK(!y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("Given a list node with a prod over it") { auto list_ptr = graph.emplace_node(5, 0, 5); auto prod_ptr = graph.emplace_node(list_ptr, std::vector{}, 1); + THEN("prod's shape is scalar") { + CHECK(prod_ptr->ndim() == 0); + CHECK(prod_ptr->size() == 1); + CHECK(prod_ptr->sizeinfo() == SizeInfo(1)); + } + AND_GIVEN("An initial state of [ 1 2 3 | 0 4 ]") { auto state = graph.empty_state(); list_ptr->initialize_state(state, {1, 2, 3, 0, 4}); @@ -707,6 +774,12 @@ TEST_CASE("ProdNode") { auto list_ptr = graph.emplace_node(5, 0, 5); auto prod_ptr = graph.emplace_node(list_ptr, std::vector{}, 0); + THEN("prod's shape is scalar") { + CHECK(prod_ptr->ndim() == 0); + CHECK(prod_ptr->size() == 1); + CHECK(prod_ptr->sizeinfo() == SizeInfo(1)); + } + AND_GIVEN("An initial state of [ 1 2 3 | 0 4 ]") { auto state = graph.empty_state(); list_ptr->initialize_state(state, {1, 2, 3, 0, 4}); @@ -748,10 +821,6 @@ TEST_CASE("ProdNode") { graph.emplace_node(r_ptr_1); graph.emplace_node(r_ptr_2); - CHECK(r_ptr_0->ndim() == 2); - CHECK(r_ptr_1->ndim() == 2); - CHECK(r_ptr_2->ndim() == 2); - WHEN("We make a state") { auto state = graph.initialize_state(); @@ -759,14 +828,17 @@ TEST_CASE("ProdNode") { CHECK(r_ptr_0->ndim() == 2); CHECK(r_ptr_0->size(state) == 4); CHECK(r_ptr_0->shape(state).size() == 2); + CHECK(r_ptr_0->sizeinfo() == ptr->sizeinfo() / 2); CHECK(r_ptr_1->ndim() == 2); CHECK(r_ptr_1->size(state) == 4); CHECK(r_ptr_1->shape(state).size() == 2); + CHECK(r_ptr_1->sizeinfo() == ptr->sizeinfo() / 2); CHECK(r_ptr_2->ndim() == 2); CHECK(r_ptr_2->size(state) == 4); CHECK(r_ptr_2->shape(state).size() == 2); + 
CHECK(r_ptr_2->sizeinfo() == ptr->sizeinfo() / 2); /// Check with /// A = np.arange(8).reshape((2, 2, 2)) @@ -793,6 +865,12 @@ TEST_CASE("SumNode") { CHECK(y_ptr->max() == 2 + 2 + 2); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("x = IntegerNode(3, -5, 2), y = x.sum(init=-.5)") { @@ -804,6 +882,12 @@ TEST_CASE("SumNode") { CHECK(y_ptr->max() == 2 + 2 + 2 + -.5); CHECK(!y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("a = [0, 3, 2], x = SetNode(3), y = a[x].sum()") { @@ -817,6 +901,12 @@ TEST_CASE("SumNode") { CHECK(y_ptr->max() == 9); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("a = [1, 3, 2], x = SetNode(3), y = a[x].sum()") { @@ -830,6 +920,12 @@ TEST_CASE("SumNode") { CHECK(y_ptr->max() == 9); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("a = [2, 3, 2], x = SetNode(3, 2, 3), y = a[x].sum()") { @@ -843,6 +939,12 @@ TEST_CASE("SumNode") { CHECK(y_ptr->max() == 9); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("a = [1, -3, 2], x = SetNode(3), y = a[x].sum()") { @@ -856,6 +958,12 @@ TEST_CASE("SumNode") { CHECK(y_ptr->max() == 6); CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("a = [-1, -3, -2], x = SetNode(3), y = a[x].sum()") { @@ -869,6 +977,12 @@ TEST_CASE("SumNode") { CHECK(y_ptr->max() == 0); 
CHECK(y_ptr->integral()); } + + THEN("y's shape is scalar") { + CHECK(y_ptr->ndim() == 0); + CHECK(y_ptr->size() == 1); + CHECK(y_ptr->sizeinfo() == SizeInfo(1)); + } } GIVEN("A set reduced") { @@ -878,6 +992,7 @@ TEST_CASE("SumNode") { THEN("The output shape is scalar") { CHECK(r_ptr->ndim() == 0); CHECK(r_ptr->size() == 1); + CHECK(r_ptr->sizeinfo() == SizeInfo(1)); } WHEN("We make a state - defaulting the set to populated") { @@ -996,14 +1111,17 @@ TEST_CASE("SumNode") { CHECK(r_ptr_0->ndim() == 2); CHECK(r_ptr_0->size(state) == 4); CHECK(r_ptr_0->shape(state).size() == 2); + CHECK(r_ptr_0->sizeinfo() == ptr->sizeinfo() / 2); CHECK(r_ptr_1->ndim() == 2); CHECK(r_ptr_1->size(state) == 4); CHECK(r_ptr_1->shape(state).size() == 2); + CHECK(r_ptr_1->sizeinfo() == ptr->sizeinfo() / 2); CHECK(r_ptr_2->ndim() == 2); CHECK(r_ptr_2->size(state) == 4); CHECK(r_ptr_2->shape(state).size() == 2); + CHECK(r_ptr_2->sizeinfo() == ptr->sizeinfo() / 2); /// Check with /// A = np.arange(8).reshape((2, 2, 2)) @@ -1043,14 +1161,17 @@ TEST_CASE("SumNode") { CHECK(r_ptr_01->ndim() == 1); CHECK(r_ptr_01->size(state) == 4); CHECK(r_ptr_01->shape(state).size() == 1); + CHECK(r_ptr_01->sizeinfo() == array_ptr->sizeinfo() / (2 * 3)); CHECK(r_ptr_02->ndim() == 1); CHECK(r_ptr_02->size(state) == 3); CHECK(r_ptr_02->shape(state).size() == 1); + CHECK(r_ptr_02->sizeinfo() == array_ptr->sizeinfo() / (2 * 4)); CHECK(r_ptr_12->ndim() == 1); CHECK(r_ptr_12->size(state) == 2); CHECK(r_ptr_12->shape(state).size() == 1); + CHECK(r_ptr_12->sizeinfo() == array_ptr->sizeinfo() / (3 * 4)); /// Check with /// A = np.arange(24).reshape((2, 3, 4)) @@ -1086,14 +1207,17 @@ TEST_CASE("SumNode") { CHECK(r_ptr_0->ndim() == 2); CHECK(r_ptr_0->size(state) == 6); CHECK(r_ptr_0->shape(state).size() == 2); + CHECK(r_ptr_0->sizeinfo() == ptr->sizeinfo() / 2); CHECK(r_ptr_1->ndim() == 2); CHECK(r_ptr_1->size(state) == 4); CHECK(r_ptr_1->shape(state).size() == 2); + CHECK(r_ptr_1->sizeinfo() == ptr->sizeinfo() / 3); 
CHECK(r_ptr_2->ndim() == 2); CHECK(r_ptr_2->size(state) == 6); CHECK(r_ptr_2->shape(state).size() == 2); + CHECK(r_ptr_2->sizeinfo() == ptr->sizeinfo() / 2); CHECK(std::ranges::equal(r_ptr_0->view(state), std::vector(6, 0))); CHECK(std::ranges::equal(r_ptr_1->view(state), std::vector(4, 0))); @@ -1171,10 +1295,12 @@ TEST_CASE("SumNode") { CHECK(r_ptr_0->ndim() == 1); CHECK(r_ptr_0->size(state) == 2); CHECK(r_ptr_0->shape(state).size() == 1); + CHECK(r_ptr_0->sizeinfo() == ptr->sizeinfo() / 2); CHECK(r_ptr_1->ndim() == 1); CHECK(r_ptr_1->size(state) == 2); CHECK(r_ptr_1->shape(state).size() == 1); + CHECK(r_ptr_1->sizeinfo() == ptr->sizeinfo() / 2); /// Check with /// A = np.arange(8).reshape((2, 2, 2)) @@ -1206,6 +1332,7 @@ TEST_CASE("SumNode") { AND_GIVEN("x = sum(arr, initial=2)") { auto x_ptr = graph.emplace_node(arr_ptr, std::vector{}, 2); + CHECK(x_ptr->sizeinfo() == SizeInfo(1)); graph.emplace_node(x_ptr); auto state = graph.empty_state(); @@ -1239,6 +1366,7 @@ TEST_CASE("SumNode") { AND_GIVEN("x = sum(arr, axes=(0,), initial=2)") { auto x_ptr = graph.emplace_node(arr_ptr, std::vector{0}, 2); + CHECK(x_ptr->sizeinfo() == SizeInfo(12)); graph.emplace_node(x_ptr); CHECK_THAT(x_ptr->shape(), RangeEquals({3, 4})); @@ -1342,6 +1470,7 @@ TEST_CASE("SumNode") { // shape is as expected CHECK(sum_ptr->size() == 1); + CHECK(sum_ptr->sizeinfo() == SizeInfo(1)); CHECK_THAT(sum_ptr->shape(), RangeEquals(std::vector{})); // as are the array values @@ -1381,6 +1510,7 @@ TEST_CASE("SumNode") { // shape is as expected CHECK(sum_ptr->size() == 0); + CHECK(sum_ptr->sizeinfo() == SizeInfo(0)); CHECK_THAT(sum_ptr->shape(), RangeEquals({0})); // as are the array values @@ -1410,6 +1540,7 @@ TEST_CASE("SumNode") { // shape is as expected CHECK(sum_ptr->size() == -1); + CHECK(sum_ptr->sizeinfo() == SizeInfo(sum_ptr)); CHECK_THAT(sum_ptr->shape(), RangeEquals({-1})); // as are the array values @@ -1432,13 +1563,16 @@ TEST_CASE("SumNode") { CHECK_THAT(sum_ptr->view(state), 
RangeEquals({3, 3, 3, 3})); } - GIVEN("Dynamic array of shape (-1, 2) with min/max of 1/3 on first dim, and a sum across the second dim") { - auto x_ptr = graph.emplace_node(std::initializer_list{-1, 2}, 0.0, 10.0, true, 2, 6); + GIVEN("Dynamic array of shape (-1, 2) with min/max of 1/3 on first dim, and a sum across the" + "second dim") { + auto x_ptr = graph.emplace_node( + std::initializer_list{-1, 2}, 0.0, 10.0, true, 2, 6); auto sum_ptr = graph.emplace_node(x_ptr, std::vector{1}); THEN("The shape of the sum is correct") { CHECK(sum_ptr->dynamic()); CHECK_THAT(sum_ptr->shape(), RangeEquals({-1})); + CHECK(sum_ptr->sizeinfo() == x_ptr->sizeinfo() / 2); } AND_GIVEN("An initialized state") { @@ -1453,7 +1587,6 @@ TEST_CASE("SumNode") { } AND_WHEN("We grow the dynamic node") { - x_ptr->grow(state, {3, 4}); graph.propagate(state); REQUIRE_THAT(x_ptr->shape(state), RangeEquals({2, 2})); diff --git a/tests/test_symbols.py b/tests/test_symbols.py index 9220e288..6a412255 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -361,6 +361,15 @@ def generate_symbols(self): model.lock() yield from nodes + def test_sizeinfo_awareness(self): + model = Model() + s = model.set(10) + r = s.reshape([-1, 1]) + x = r.all(axis=1) + y = r.all(axis=1) + # only possible if both `x` and `y` know their size is derived from `r`. + z = y + x + def test_empty(self): model = Model() empty = model.constant([]).all() @@ -462,6 +471,15 @@ def generate_symbols(self): model.lock() yield from nodes + def test_sizeinfo_awareness(self): + model = Model() + s = model.set(10) + r = s.reshape([-1, 1]) + x = r.any(axis=1) + y = r.any(axis=1) + # only possible if both `x` and `y` know their size is derived from `r`. 
+ z = y + x + def test_empty(self): model = Model() empty = model.constant([]).any() @@ -2589,6 +2607,15 @@ class TestMax(utils.ReduceTests): def op(self, x, *args, **kwargs): return x.max(*args, **kwargs) + def test_sizeinfo_awareness(self): + model = Model() + s = model.set(10) + r = s.reshape([-1, 1]) + x = r.max(axis=1) + y = r.max(axis=1) + # only possible if both `x` and `y` know their size is derived from `r`. + z = y + x + def test_empty(self): model = Model() with self.assertRaisesRegex(ValueError, "no identity"): @@ -2677,6 +2704,15 @@ class TestMin(utils.ReduceTests): def op(self, x, *args, **kwargs): return x.min(*args, **kwargs) + def test_sizeinfo_awareness(self): + model = Model() + s = model.set(10) + r = s.reshape([-1, 1]) + x = r.min(axis=1) + y = r.min(axis=1) + # only possible if both `x` and `y` know their size is derived from `r`. + z = y + x + def test_empty(self): model = Model() with self.assertRaisesRegex(ValueError, "no identity"): @@ -3214,6 +3250,15 @@ class TestProd(utils.ReduceTests): def op(self, x, *args, **kwargs): return x.prod(*args, **kwargs) + def test_sizeinfo_awareness(self): + model = Model() + s = model.set(10) + r = s.reshape([-1, 1]) + x = r.prod(axis=1) + y = r.prod(axis=1) + # only possible if both `x` and `y` know their size is derived from `r`. + z = y + x + def test_empty(self): model = Model() empty = model.constant([]).prod() @@ -3871,6 +3916,15 @@ class TestSum(utils.ReduceTests): def op(self, x, *args, **kwargs): return x.sum(*args, **kwargs) + def test_sizeinfo_awareness(self): + model = Model() + s = model.set(10) + r = s.reshape([-1, 1]) + x = r.sum(axis=1) + y = r.sum(axis=1) + # only possible if both `x` and `y` know their size is derived from `r`. + z = y + x + def test_axis(self): model = Model() model.states.resize(1)