Skip to content

Commit 5768e69

Browse files
committed
TransposeNode::propagate() improvements
`TransposeNode::convert_predecessor_update_()` recomputed the same information (the number of indices of the predecessor node traversed by taking a stride along an axis) for each update. Since this node is already extremely lightweight (its buffer is simply its predecessor's buffer), we can easily cache this data.
1 parent 9008655 commit 5768e69

2 files changed

Lines changed: 58 additions & 68 deletions

File tree

dwave/optimization/include/dwave-optimization/nodes/manipulation.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,8 @@ class TransposeNode : public ArrayNode {
537537
const ssize_t ndim_;
538538
const std::unique_ptr<ssize_t[]> shape_;
539539
const std::unique_ptr<ssize_t[]> strides_;
540+
/// Number of flat transpose indices traversed per step along each axis of the predecessor array.
541+
const std::vector<ssize_t> array_indices_per_stride_;
540542
const bool contiguous_;
541543
const ValuesInfo values_info_;
542544

dwave/optimization/src/nodes/manipulation.cpp

Lines changed: 56 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ void BroadcastToNode::propagate(State& state) const {
250250
assert(([&]() {
251251
std::vector<ssize_t> multi_index =
252252
unravel_index(update.index, array_ptr_->shape());
253-
multi_index.insert(multi_index.begin(), this->ndim() - array_ptr_->ndim(), 0);
253+
multi_index.insert(multi_index.begin(), this->ndim() - array_ptr_->ndim(),
254+
0);
254255
const ssize_t assert_index = ravel_multi_index(multi_index, this->shape());
255256
return assert_index == index;
256257
})() &&
@@ -1464,42 +1465,57 @@ ArrayNode* TransposeNode::predeccesor_check_(ArrayNode* array_ptr) const {
14641465
return array_ptr;
14651466
}
14661467

1467-
// a TransposeNodes shape and strides are the reverse of its predecessor
1468+
// A TransposeNode's shape and strides are the reverse of its predecessor's.
14681469
std::unique_ptr<ssize_t[]> reverse_span_helper(const std::span<const ssize_t> span,
14691470
const ssize_t size) {
14701471
std::unique_ptr<ssize_t[]> reverse_span = std::make_unique<ssize_t[]>(size);
14711472
std::reverse_copy(span.begin(), span.end(), reverse_span.get());
14721473
return reverse_span;
14731474
}
14741475

1476+
std::vector<ssize_t> array_indices_per_stride_helper(const std::span<const ssize_t> array_shape,
1477+
const ssize_t ndim) {
1478+
std::vector<ssize_t> axis_index_strides;
1479+
axis_index_strides.reserve(ndim);
1480+
1481+
ssize_t indices_per_stride = 1;
1482+
// Traverse the array axes in forward order.
1483+
for (ssize_t i = 0; i < ndim; ++i) {
1484+
// Record the number of indices traversed when moving along the ith axis.
1485+
axis_index_strides.push_back(indices_per_stride);
1486+
// Account for indices in ith axis.
1487+
indices_per_stride *= array_shape[i];
1488+
}
1489+
return axis_index_strides;
1490+
}
1491+
14751492
TransposeNode::TransposeNode(ArrayNode* array_ptr)
14761493
: array_ptr_(predeccesor_check_(array_ptr)),
14771494
ndim_(array_ptr->ndim()),
14781495
shape_(reverse_span_helper(array_ptr->shape(), ndim_)),
14791496
strides_(reverse_span_helper(array_ptr->strides(), ndim_)),
1497+
array_indices_per_stride_(array_indices_per_stride_helper(array_ptr->shape(), ndim_)),
14801498
contiguous_(is_contiguous(ndim_, shape_.get(), strides_.get())),
14811499
values_info_(array_ptr) {
14821500
add_predecessor(array_ptr);
14831501
}
14841502

1485-
// this node simply points to the predecessor buff
1503+
// This node simply points to the predecessor buff.
14861504
double const* TransposeNode::buff(const State& state) const { return array_ptr_->buff(state); }
14871505

14881506
ssize_t TransposeNode::ndim() const { return ndim_; }
14891507

14901508
std::span<const ssize_t> TransposeNode::shape(const State& state) const {
1491-
if (ndim_ <= 1) { // predecessor is vector and may be dynamic
1492-
return array_ptr_->shape(state);
1493-
}
1494-
// predecessor is (>=2)-D array and shape is static
1509+
// Predecessor is a vector and may be dynamic.
1510+
if (ndim_ <= 1) return array_ptr_->shape(state);
1511+
// Predecessor is (>=2)-D array and shape is static.
14951512
return std::span<const ssize_t>(shape_.get(), ndim_);
14961513
}
14971514

14981515
std::span<const ssize_t> TransposeNode::shape() const {
1499-
if (ndim_ <= 1) { // predecessor is vector and may be dynamic
1500-
return array_ptr_->shape();
1501-
}
1502-
// predecessor is (>=2)-D array and shape is fixed
1516+
// Predecessor is a vector and may be dynamic.
1517+
if (ndim_ <= 1) return array_ptr_->shape();
1518+
// Predecessor is (>=2)-D array and shape is fixed.
15031519
return std::span<const ssize_t>(shape_.get(), ndim_);
15041520
}
15051521

@@ -1533,73 +1549,53 @@ class TransposeNodeDiffData : public NodeStateData {
15331549
std::span<const Update> TransposeNode::diff(const State& state) const {
15341550
// If the predecessor is a vector, the transpose does nothing and the diff
15351551
// of this node is simply the diff of the predecessor node.
1536-
if (ndim_ <= 1) { // predecessor is vector
1537-
return array_ptr_->diff(state);
1538-
}
1552+
if (ndim_ <= 1) return array_ptr_->diff(state);
15391553
// Otherwise, we use the stored diff data.
15401554
return data_ptr<TransposeNodeDiffData>(state)->diff;
15411555
}
15421556

15431557
ssize_t TransposeNode::size_diff(const State& state) const { return array_ptr_->size_diff(state); }
15441558

15451559
void TransposeNode::initialize_state(State& state) const {
1546-
if (ndim_ <= 1) {
1547-
return Node::initialize_state(state); // stateless
1548-
}
1560+
if (ndim_ <= 1) return Node::initialize_state(state); // stateless
15491561
// Construct diff data if predecessor is (>=2)-D array
15501562
emplace_data_ptr<TransposeNodeDiffData>(state);
15511563
}
15521564

15531565
Update TransposeNode::convert_predecessor_update_(Update update) const {
1554-
if (ndim_ <= 1) { // predecessor is vector
1555-
return update;
1556-
}
1557-
15581566
const std::span<const ssize_t> array_shape = array_ptr_->shape();
15591567
ssize_t transpose_flat_index = 0;
1560-
// when constructing a flat index of the transpose, it is helpful to know
1561-
// the # of indices contributed when you move along a fixed axes.
1562-
// `transpose_axis_index_stride` is initialized by the # of indices
1563-
// contributed when moving along the 0th axis of the transpose.
1564-
ssize_t transpose_axis_index_stride = std::accumulate(
1565-
array_shape.begin(), array_shape.end() - 1, 1, std::multiplies<ssize_t>());
1566-
1567-
// traverse the predecessor axes in backward (reverse) order and the
1568-
// transpose axes in forward order
1568+
assert(ndim_ > 1);
1569+
assert(array_indices_per_stride_.size() == array_shape.size());
1570+
1571+
// Traverse the predecessor axes in backward (reverse) order and the
1572+
// transpose axes in forward order.
15691573
for (ssize_t i = ndim_ - 1; i >= 0; --i) {
1570-
// grab predecessor shape along the ith axis
1574+
// Grab predecessor shape along the ith axis.
15711575
const ssize_t axis_shape = array_shape[i];
15721576
assert(0 <= axis_shape &&
15731577
"all dimensions of (>=2)-D array must be non-negative for transpose operation");
1574-
// determine the multidimensional index of `flat_index` along the ith
1575-
// axis of predecessor. Note: this is the multidimensional index along
1576-
// the (ndim_ - 1 - i)th axis of the transpose
1578+
// Determine the multidimensional index along the ith axis of
1579+
// predecessor. Note: this is the multidimensional index along the
1580+
// (ndim_ - 1 - i)th axis of the transpose.
15771581
const ssize_t multidimensional_index = update.index % axis_shape;
1578-
// reassign flat_index to the correct index along the (i - 1)th axes of predecessor
1582+
// Weight the multidimensional index along the (ndim_ - 1 - i)th axis
1583+
// of the transpose by # of indices contributed by moving along the ith
1584+
// axis of the predecessor.
1585+
transpose_flat_index += multidimensional_index * array_indices_per_stride_[i];
1586+
// Reassign the index to the correct index along the (i - 1)th axis of
1587+
// predecessor. Note we are using integer division here.
15791588
update.index /= axis_shape;
1580-
1581-
// weight the multidimensional index along the (ndim_ - 1 - i)th axis
1582-
// of the transpose by # of indices contributed by moving along axis
1583-
transpose_flat_index += multidimensional_index * transpose_axis_index_stride;
1584-
1585-
// recall we are traversing the tranpose axes in forward order.
1586-
// the # of indices contributed by moving along the (ndim - 2 - i)th
1587-
// axis is the same as (the # of indices contributed by moving along the
1588-
// (ndim_ - 1 - i)th axis) / shape(ndim_ - i - 1)
1589-
transpose_axis_index_stride /= array_shape[ndim_ - i - 1];
15901589
}
15911590

15921591
update.index = transpose_flat_index;
1593-
15941592
return update;
15951593
}
15961594

15971595
void TransposeNode::propagate(State& state) const {
15981596
const std::span<const Update> array_diff = array_ptr_->diff(state);
1599-
1600-
if (array_diff.empty() || ndim_ <= 1) {
1601-
return; // Nothing to do or predecessor is vector (transpose of vector is vector)
1602-
}
1597+
// Nothing to do or predecessor is vector (transpose of vector is vector).
1598+
if (array_diff.empty() || ndim_ <= 1) return;
16031599

16041600
// Predecessor is a non-dynamic (>=2)-D array.
16051601
std::vector<Update>& transpose_diff = data_ptr<TransposeNodeDiffData>(state)->diff;
@@ -1608,34 +1604,26 @@ void TransposeNode::propagate(State& state) const {
16081604

16091605
for (const Update& u : array_diff) {
16101606
assert(([&]() {
1611-
// make a copy of the update
1612-
Update u_copy = u;
1613-
// convert flat index of predecessor update to multidimensional indices
1614-
std::vector<ssize_t> multi_index = unravel_index(u_copy.index, array_ptr_->shape());
1615-
// reverse multidimensional indices to obtain the multidimensional
1616-
// transpose indices
1607+
// Convert flat index of predecessor update to multidimensional indices.
1608+
std::vector<ssize_t> multi_index = unravel_index(u.index, array_ptr_->shape());
1609+
// Reverse indices to obtain the transpose indices.
16171610
std::reverse(multi_index.begin(), multi_index.end());
1618-
// convert multidimensional transpose indices to transpose flat index
1619-
// and check conversion
1620-
return ravel_multi_index(multi_index, this->shape()) ==
1621-
convert_predecessor_update_(u_copy).index;
1611+
// Convert to transpose flat index and check conversion.
1612+
return ravel_multi_index(multi_index, shape()) == convert_predecessor_update_(u).index;
16221613
})());
1623-
// Make a copy of the update and convert the index to the respective
1624-
// transpose index
1614+
// Copy update and convert predecessor index to transpose index.
16251615
transpose_diff.emplace_back(convert_predecessor_update_(u));
16261616
}
16271617
}
16281618

16291619
void TransposeNode::commit(State& state) const {
1630-
if (ndim_ > 1) {
1631-
data_ptr<TransposeNodeDiffData>(state)->commit();
1632-
} // otherwise, stateless
1620+
if (ndim_ > 1) data_ptr<TransposeNodeDiffData>(state)->commit();
1621+
// otherwise, stateless
16331622
};
16341623

16351624
void TransposeNode::revert(State& state) const {
1636-
if (ndim_ > 1) {
1637-
data_ptr<TransposeNodeDiffData>(state)->revert();
1638-
} // otherwise, stateless
1625+
if (ndim_ > 1) data_ptr<TransposeNodeDiffData>(state)->revert();
1626+
// otherwise, stateless
16391627
}
16401628

16411629
} // namespace dwave::optimization

0 commit comments

Comments
 (0)