Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class ReduceNode : public ArrayOutputMixin<ArrayNode> {
using ArrayOutputMixin::size;
ssize_t size(const State& state) const override;

/// @copydoc Array::sizeinfo()
SizeInfo sizeinfo() const override { return sizeinfo_; };

/// @copydoc Array::size_diff()
ssize_t size_diff(const State& state) const override;

Expand All @@ -111,6 +114,8 @@ class ReduceNode : public ArrayOutputMixin<ArrayNode> {
// as whether we're integral or not.
const ValuesInfo values_info_;

const SizeInfo sizeinfo_;

// During propagation, we need to take a predecessor's `update` and apply
// it to the correct `index` in the ReduceNode buffer. This method
// determines the correct buffer `index` given an `update.index`.
Expand Down
20 changes: 18 additions & 2 deletions dwave/optimization/src/nodes/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <utility>

#include "../functional_.hpp"
#include "dwave-optimization/array.hpp"
#include "dwave-optimization/state.hpp"

namespace dwave::optimization {

Expand Down Expand Up @@ -281,7 +283,7 @@ class ReduceNodeData : public NodeStateData {
if (flags_[index] == ReductionFlag::unchanged) {
reductions_diff_.emplace_back(index, reductions_[index]);
}
flags_[index] = ReductionFlag::invalid;
flags_[index] = ReductionFlag::invalid;
return;
}
reduction = std::move(inverse.value());
Expand Down Expand Up @@ -581,6 +583,19 @@ ValuesInfo values_info(const Array* array_ptr, std::span<const ssize_t> axes,
return bounds;
}

SizeInfo reducenode_calculate_sizeinfo(const Array* node_ptr, const Array* array_ptr,
std::span<const ssize_t> axes) {
// Node is statically sized. Note: If node_shape = {}, product(node_shape) = 1.
if (!node_ptr->dynamic()) return SizeInfo(product(node_ptr->shape()));
assert(node_ptr->shape().size() && node_ptr->shape().front() == -1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert(node_ptr->shape().size() && node_ptr->shape().front() == -1);
assert(node_ptr->shape().size() and node_ptr->shape().front() == -1);

Copy link
Contributor Author

@fastbodin fastbodin Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of #399?

Copy link
Contributor Author

@fastbodin fastbodin Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, shouldn't the ! above be a not?


// Node is dynamically sized but predecessor is always empty.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a bit more detail here, something like "this means we can't easily make any deductions about the resulting size as it likely depends on predecessors of the input array" (assuming that's correct).

if (array_ptr->size() == 0) return SizeInfo(node_ptr);

// Node size is derived from its predecessor's size.
return array_ptr->sizeinfo() / product(keep_axes(array_ptr->shape(), axes));
}

template <class BinaryOp>
ReduceNode<BinaryOp>::ReduceNode(ArrayNode* array_ptr) : ReduceNode(array_ptr, {}) {}

Expand All @@ -591,7 +606,8 @@ ReduceNode<BinaryOp>::ReduceNode(ArrayNode* array_ptr, std::span<const ssize_t>
initial(initial),
array_ptr_(array_ptr),
axes_(normalize_axes(array_ptr, axes)),
values_info_(values_info<BinaryOp>(array_ptr_, axes_, initial)) {
values_info_(values_info<BinaryOp>(array_ptr_, axes_, initial)),
sizeinfo_(reducenode_calculate_sizeinfo(this, array_ptr_, axes_)) {
add_predecessor(array_ptr);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
The symbol: `All`, `Any`, `Max`, `Min`, `Prod`, and `Sum` can infer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The symbol: `All`, `Any`, `Max`, `Min`, `Prod`, and `Sum` can infer
The symbols `All`, `Any`, `Max`, `Min`, `Prod`, and `Sum` can infer

their size from their predecessor's size.
Loading