Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/targets/gpu/kernels/include/migraphx/kernels/nonzero.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,22 @@ __device__ void nonzero(Input input, Output output)
constexpr auto block_size = decltype(idx.max_nlocal()){};
static_assert(block_size % MIGRAPHX_WAVEFRONTSIZE == 0,
"Block size must be a multiple of wavefront size");
// Accumulate the prefix sum as uint32_t, which covers any input size
// (elem_num) we realistically see. A narrower type such as uint8_t wraps
// once the prefix sum exceeds its range, producing negative out_loc
// values and out-of-bounds stores.
block_scan(
idx,
op::sum{},
0,
index_int{0},
elem_num,
[&](auto j) -> uint8_t { return float_equal(input[j], 0) ? 0 : 1; },
[&](auto j) -> index_int { return float_equal(input[j], 0) ? 0 : 1; },
[&](auto j, auto value) {
MIGRAPHX_ASSERT(j < elem_num);
if(float_equal(input[j], 0))
return;
const auto out_loc = value - 1;
const auto multi_idx = in_shape.multi(j);
MIGRAPHX_ASSERT(value > 0);
const index_int out_loc = value - 1;
const auto multi_idx = in_shape.multi(j);
for(auto k = 0; k < multi_idx.size(); ++k)
{
output[make_array<index_int>(k, out_loc)] = multi_idx[k];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -47,10 +47,11 @@ __device__ void scatternd(const T& indices_t, const U& updates_t, const V& outpu
auto indices_idx = indices_shape.multi(0);
copy(updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());

auto index_start = indices_t.begin() + indices_shape.index(indices_idx);
auto index_end = index_start + k;
// begin_at is stride-aware; raw begin()+offset would only be correct
// for packed indices.
auto index_start = indices_t.begin_at(indices_idx);
auto out_idx = output_shape.multi(0);
copy(index_start, index_end, out_idx.begin());
copy(index_start, index_start + k, out_idx.begin());
copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);

f(output_t[out_idx], updates_t[i]);
Expand Down
Loading