-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathLeakyReLU.mlu
More file actions
122 lines (95 loc) · 2.76 KB
/
Copy pathLeakyReLU.mlu
File metadata and controls
122 lines (95 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <limits>

#include <bang.h>
#include <torch/extension.h>
#include <cnrt.h>
#define CHUNK_SIZE 4096
// LeakyReLU kernel: output[i] = relu(x) + negative_slope * (x - relu(x)),
// i.e. x for x > 0 and negative_slope * x otherwise.
//
// input          - GDRAM source buffer of `total` floats.
// output         - GDRAM destination buffer of `total` floats.
// total          - number of elements to process (assumed >= 0).
// negative_slope - multiplier applied to the non-positive part of x.
//
// The elements are divided across the taskDim tasks of the launch. Each task
// streams its slice through NRAM in tiles of at most CHUNK_SIZE elements.
__mlu_entry__ void leakyrelu_kernel(
    float *input,
    float *output,
    int total,
    float negative_slope) {
  // NRAM scratch tiles:
  //   tile_x   - input values copied in from GDRAM
  //   tile_pos - relu(x)
  //   tile_out - working buffer; finally holds the LeakyReLU result
  __nram__ float tile_x[CHUNK_SIZE];
  __nram__ float tile_pos[CHUNK_SIZE];
  __nram__ float tile_out[CHUNK_SIZE];

  // Balanced split: every task gets `base` elements, and the first `extra`
  // tasks take one additional element each.
  const uint32_t task = taskId;
  const uint32_t tasks = taskDim;
  const uint32_t base = total / tasks;
  const uint32_t extra = total % tasks;
  const uint32_t lead = (task < extra) ? task : extra;
  const uint32_t my_begin = task * base + lead;
  const uint32_t my_count = (task < extra) ? base + 1 : base;

  float *src = input + my_begin;
  float *dst = output + my_begin;

  uint32_t done = 0;
  while (done < my_count) {
    // Full tile unless fewer than CHUNK_SIZE elements remain.
    uint32_t len = CHUNK_SIZE;
    if (my_count - done < CHUNK_SIZE) {
      len = my_count - done;
    }
    // The vector ops run on a 64-element-padded length. Because CHUNK_SIZE is
    // a multiple of 64, this never exceeds the tile size. Padding lanes are
    // computed but never copied back.
    const uint32_t kAlign = 64;
    const uint32_t padded = ((len + kAlign - 1) / kAlign) * kAlign;

    __memcpy(tile_x, src + done, len * sizeof(float), GDRAM2NRAM);

    // tile_pos = relu(x)
    __bang_active_relu(tile_pos, tile_x, padded);
    // tile_out = x - relu(x), which equals min(0, x)
    __bang_sub(tile_out, tile_x, tile_pos, padded);
    // tile_out = negative_slope * min(0, x)
    __bang_mul_scalar(tile_out, tile_out, negative_slope, padded);
    // tile_out = relu(x) + negative_slope * min(0, x)
    __bang_add(tile_out, tile_pos, tile_out, padded);

    __memcpy(dst + done, tile_out, len * sizeof(float), NRAM2GDRAM);

    done += len;
  }
}
/// Applies LeakyReLU element-wise using leakyrelu_kernel.
///
/// The input is presumably resident on the MLU device; that is not checked
/// here (TODO confirm callers guarantee it).
///
/// @param input          Tensor to transform; must be contiguous. Non-float32
///                       inputs are converted to float32 for the kernel, and
///                       the result is converted back to the input's dtype.
/// @param negative_slope Multiplier for negative inputs; narrowed to float
///                       before being passed to the kernel.
/// @return A new tensor with the same shape and dtype as `input`.
/// @throws c10::Error (via TORCH_CHECK) if `input` is not contiguous or has
///         more than INT_MAX elements. The kernel takes an `int` element count.
torch::Tensor bang_func(
    torch::Tensor input,
    double negative_slope) {
  TORCH_CHECK(
      input.is_contiguous(),
      "Input must be contiguous");

  // numel() is 64-bit but the kernel's element count is `int`. Narrowing
  // silently would wrap the count and make the kernel read and write out of
  // bounds, so reject oversized tensors explicitly.
  const int64_t numel = input.numel();
  TORCH_CHECK(
      numel <= static_cast<int64_t>(std::numeric_limits<int>::max()),
      "Input has too many elements for leakyrelu_kernel (numel = ", numel, ")");

  // Nothing to compute: avoid launching a kernel on (possibly null) empty
  // buffers. empty_like keeps the original shape and dtype.
  if (numel == 0) {
    return torch::empty_like(input);
  }

  // Remember the caller's dtype so the result can be converted back.
  const auto original_dtype = input.scalar_type();

  // The kernel only handles float32; convert other dtypes (dtype only, the
  // layout stays contiguous).
  torch::Tensor input_fp32 = input;
  if (original_dtype != torch::kFloat) {
    input_fp32 = input.to(torch::kFloat);
  }
  auto output_fp32 = torch::empty_like(input_fp32);
  const int total = static_cast<int>(numel);

  cnrtQueue_t queue =
      torch_mlu::getCurMLUStream();
  // 4 tasks in a Union1 launch. This presumably matches one 4-core cluster
  // on the target MLU; TODO confirm for other devices.
  cnrtDim3_t dim = {4, 1, 1};
  cnrtFunctionType_t ktype =
      cnrtFuncTypeUnion1;
  leakyrelu_kernel<<<dim, ktype, queue>>>(
      input_fp32.data_ptr<float>(),
      output_fp32.data_ptr<float>(),
      total,
      static_cast<float>(negative_slope));

  // Convert back to the caller's original dtype.
  if (original_dtype != torch::kFloat) {
    return output_fp32.to(original_dtype);
  }
  return output_fp32;
}