@@ -250,7 +250,8 @@ void BroadcastToNode::propagate(State& state) const {
250250 assert (([&]() {
251251 std::vector<ssize_t > multi_index =
252252 unravel_index (update.index , array_ptr_->shape ());
253- multi_index.insert (multi_index.begin (), this ->ndim () - array_ptr_->ndim (), 0 );
253+ multi_index.insert (multi_index.begin (), this ->ndim () - array_ptr_->ndim (),
254+ 0 );
254255 const ssize_t assert_index = ravel_multi_index (multi_index, this ->shape ());
255256 return assert_index == index;
256257 })() &&
@@ -1464,42 +1465,57 @@ ArrayNode* TransposeNode::predeccesor_check_(ArrayNode* array_ptr) const {
14641465 return array_ptr;
14651466}
14661467
1467- // a TransposeNodes shape and strides are the reverse of its predecessor
1468+ // A TransposeNodes shape and strides are the reverse of its predecessor.
14681469std::unique_ptr<ssize_t []> reverse_span_helper (const std::span<const ssize_t > span,
14691470 const ssize_t size) {
14701471 std::unique_ptr<ssize_t []> reverse_span = std::make_unique<ssize_t []>(size);
14711472 std::reverse_copy (span.begin (), span.end (), reverse_span.get ());
14721473 return reverse_span;
14731474}
14741475
1476+ std::vector<ssize_t > array_indices_per_stride_helper (const std::span<const ssize_t > array_shape,
1477+ const ssize_t ndim) {
1478+ std::vector<ssize_t > axis_index_strides;
1479+ axis_index_strides.reserve (ndim);
1480+
1481+ ssize_t indices_per_stride = 1 ;
1482+ // Traverse the array axes in forward order.
1483+ for (ssize_t i = 0 ; i < ndim; ++i) {
1484+ // Record the number of indices traversed when moving along the ith axis.
1485+ axis_index_strides.push_back (indices_per_stride);
1486+ // Account for indices in ith axis.
1487+ indices_per_stride *= array_shape[i];
1488+ }
1489+ return axis_index_strides;
1490+ }
1491+
14751492TransposeNode::TransposeNode (ArrayNode* array_ptr)
14761493 : array_ptr_(predeccesor_check_(array_ptr)),
14771494 ndim_(array_ptr->ndim ()),
14781495 shape_(reverse_span_helper(array_ptr->shape (), ndim_)),
14791496 strides_(reverse_span_helper(array_ptr->strides (), ndim_)),
1497+ array_indices_per_stride_(array_indices_per_stride_helper(array_ptr->shape (), ndim_)),
14801498 contiguous_(is_contiguous(ndim_, shape_.get(), strides_.get())),
14811499 values_info_(array_ptr) {
14821500 add_predecessor (array_ptr);
14831501}
14841502
1485- // this node simply points to the predecessor buff
1503+ // This node simply points to the predecessor buff.
14861504double const * TransposeNode::buff (const State& state) const { return array_ptr_->buff (state); }
14871505
14881506ssize_t TransposeNode::ndim () const { return ndim_; }
14891507
14901508std::span<const ssize_t > TransposeNode::shape (const State& state) const {
1491- if (ndim_ <= 1 ) { // predecessor is vector and may be dynamic
1492- return array_ptr_->shape (state);
1493- }
1494- // predecessor is (>=2)-D array and shape is static
1509+ // Predecessor is vector and may be dynamic.
1510+ if (ndim_ <= 1 ) return array_ptr_->shape (state);
1511+ // Predecessor is (>=2)-D array and shape is static.
14951512 return std::span<const ssize_t >(shape_.get (), ndim_);
14961513}
14971514
14981515std::span<const ssize_t > TransposeNode::shape () const {
1499- if (ndim_ <= 1 ) { // predecessor is vector and may be dynamic
1500- return array_ptr_->shape ();
1501- }
1502- // predecessor is (>=2)-D array and shape is fixed
1516+ // Predecessor is vector and may be dynamic.
1517+ if (ndim_ <= 1 ) return array_ptr_->shape ();
1518+ // Predecessor is (>=2)-D array and shape is fixed.
15031519 return std::span<const ssize_t >(shape_.get (), ndim_);
15041520}
15051521
@@ -1533,73 +1549,53 @@ class TransposeNodeDiffData : public NodeStateData {
15331549std::span<const Update> TransposeNode::diff (const State& state) const {
15341550 // If the predecessor is a vector, the transpose does nothing and the diff
15351551 // of this node is simply the diff of the predecessor node.
1536- if (ndim_ <= 1 ) { // predecessor is vector
1537- return array_ptr_->diff (state);
1538- }
1552+ if (ndim_ <= 1 ) return array_ptr_->diff (state);
15391553 // Otherwise, we use the stored diff data.
15401554 return data_ptr<TransposeNodeDiffData>(state)->diff ;
15411555}
15421556
15431557ssize_t TransposeNode::size_diff (const State& state) const { return array_ptr_->size_diff (state); }
15441558
15451559void TransposeNode::initialize_state (State& state) const {
1546- if (ndim_ <= 1 ) {
1547- return Node::initialize_state (state); // stateless
1548- }
1560+ if (ndim_ <= 1 ) return Node::initialize_state (state); // stateless
15491561 // Construct diff data if predecessor is (>=2)-D array
15501562 emplace_data_ptr<TransposeNodeDiffData>(state);
15511563}
15521564
15531565Update TransposeNode::convert_predecessor_update_ (Update update) const {
1554- if (ndim_ <= 1 ) { // predecessor is vector
1555- return update;
1556- }
1557-
15581566 const std::span<const ssize_t > array_shape = array_ptr_->shape ();
15591567 ssize_t transpose_flat_index = 0 ;
1560- // when constructing a flat index of the transpose, it is helpful to know
1561- // the # of indices contributed when you move along a fixed axes.
1562- // `transpose_axis_index_stride` is initialized by the # of indices
1563- // contributed when moving along the 0th axis of the transpose.
1564- ssize_t transpose_axis_index_stride = std::accumulate (
1565- array_shape.begin (), array_shape.end () - 1 , 1 , std::multiplies<ssize_t >());
1566-
1567- // traverse the predecessor axes in backward (reverse) order and the
1568- // transpose axes in forward order
1568+ assert (ndim_ > 1 );
1569+ assert (array_indices_per_stride_.size () == array_shape.size ());
1570+
1571+ // Traverse the predecessor axes in backward (reverse) order and the
1572+ // transpose axes in forward order.
15691573 for (ssize_t i = ndim_ - 1 ; i >= 0 ; --i) {
1570- // grab predecessor shape along the ith axis
1574+ // Grab predecessor shape along the ith axis.
15711575 const ssize_t axis_shape = array_shape[i];
15721576 assert (0 <= axis_shape &&
15731577 " all dimensions of (>=2)-D array must be non-negative for transpose operation" );
1574- // determine the multidimensional index of `flat_index` along the ith
1575- // axis of predecessor. Note: this is the multidimensional index along
1576- // the (ndim_ - 1 - i)th axis of the transpose
1578+ // Determine the multidimensional index along the ith axis of
1579+ // predecessor. Note: this is the multidimensional index along the
1580+ // (ndim_ - 1 - i)th axis of the transpose.
15771581 const ssize_t multidimensional_index = update.index % axis_shape;
1578- // reassign flat_index to the correct index along the (i - 1)th axes of predecessor
1582+ // Weight the multidimensional index along the (ndim_ - 1 - i)th axis
1583+ // of the transpose by # of indices contributed by moving along the ith
1584+ // axis of the predecessor.
1585+ transpose_flat_index += multidimensional_index * array_indices_per_stride_[i];
1586+ // Reassign the index to the correct index along the (i - 1)th axes of
1587+ // predecessor. Note we are using integer division here.
15791588 update.index /= axis_shape;
1580-
1581- // weight the multidimensional index along the (ndim_ - 1 - i)th axis
1582- // of the transpose by # of indices contributed by moving along axis
1583- transpose_flat_index += multidimensional_index * transpose_axis_index_stride;
1584-
1585- // recall we are traversing the tranpose axes in forward order.
1586- // the # of indices contributed by moving along the (ndim - 2 - i)th
1587- // axis is the same as (the # of indices contributed by moving along the
1588- // (ndim_ - 1 - i)th axis) / shape(ndim_ - i - 1)
1589- transpose_axis_index_stride /= array_shape[ndim_ - i - 1 ];
15901589 }
15911590
15921591 update.index = transpose_flat_index;
1593-
15941592 return update;
15951593}
15961594
15971595void TransposeNode::propagate (State& state) const {
15981596 const std::span<const Update> array_diff = array_ptr_->diff (state);
1599-
1600- if (array_diff.empty () || ndim_ <= 1 ) {
1601- return ; // Nothing to do or predecessor is vector (transpose of vector is vector)
1602- }
1597+ // Nothing to do or predecessor is vector (transpose of vector is vector).
1598+ if (array_diff.empty () || ndim_ <= 1 ) return ;
16031599
16041600 // Predecessor is a non-dynamic (>=2)-D array.
16051601 std::vector<Update>& transpose_diff = data_ptr<TransposeNodeDiffData>(state)->diff ;
@@ -1608,34 +1604,26 @@ void TransposeNode::propagate(State& state) const {
16081604
16091605 for (const Update& u : array_diff) {
16101606 assert (([&]() {
1611- // make a copy of the update
1612- Update u_copy = u;
1613- // convert flat index of predecessor update to multidimensional indices
1614- std::vector<ssize_t > multi_index = unravel_index (u_copy.index , array_ptr_->shape ());
1615- // reverse multidimensional indices to obtain the multidimensional
1616- // transpose indices
1607+ // Convert flat index of predecessor update to multidimensional indices.
1608+ std::vector<ssize_t > multi_index = unravel_index (u.index , array_ptr_->shape ());
1609+ // Reverse indices to obtain the transpose indices.
16171610 std::reverse (multi_index.begin (), multi_index.end ());
1618- // convert multidimensional transpose indices to transpose flat index
1619- // and check conversion
1620- return ravel_multi_index (multi_index, this ->shape ()) ==
1621- convert_predecessor_update_ (u_copy).index ;
1611+ // Convert to transpose flat index and check conversion.
1612+ return ravel_multi_index (multi_index, shape ()) == convert_predecessor_update_ (u).index ;
16221613 })());
1623- // Make a copy of the update and convert the index to the respective
1624- // transpose index
1614+ // Copy update and convert predecessor index to transpose index.
16251615 transpose_diff.emplace_back (convert_predecessor_update_ (u));
16261616 }
16271617}
16281618
16291619void TransposeNode::commit (State& state) const {
1630- if (ndim_ > 1 ) {
1631- data_ptr<TransposeNodeDiffData>(state)->commit ();
1632- } // otherwise, stateless
1620+ if (ndim_ > 1 ) data_ptr<TransposeNodeDiffData>(state)->commit ();
1621+ // otherwise, stateless
16331622};
16341623
16351624void TransposeNode::revert (State& state) const {
1636- if (ndim_ > 1 ) {
1637- data_ptr<TransposeNodeDiffData>(state)->revert ();
1638- } // otherwise, stateless
1625+ if (ndim_ > 1 ) data_ptr<TransposeNodeDiffData>(state)->revert ();
1626+ // otherwise, stateless
16391627}
16401628
16411629} // namespace dwave::optimization
0 commit comments