-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild.rs
More file actions
136 lines (115 loc) · 4.85 KB
/
build.rs
File metadata and controls
136 lines (115 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use std::env;
use std::path::PathBuf;
/// Build-script entry point.
///
/// Runs the native-toolkit setup for each GPU backend whose Cargo feature is enabled:
/// `cuda` triggers the CUDA setup and `rocm` triggers the ROCm/HIP setup. When neither
/// feature is enabled, nothing happens.
fn main() {
    #[cfg(feature = "cuda")]
    {
        configure_cuda();
    }

    #[cfg(feature = "rocm")]
    {
        configure_rocm();
    }
}
/// Locates the CUDA toolkit, emits the linker directives for `cudart`, and generates Rust
/// FFI bindings for `wrapper_cuda.h` into `$OUT_DIR/cuda_bindings.rs`.
///
/// The toolkit root is read from the `CUDA_PATH` environment variable. When it is unset, a
/// CUDA 12.0 default location for the build host's OS is assumed. Headers are taken from
/// `<root>/include`. Libraries are searched in `<root>/lib64`, `<root>/lib` and, when
/// present, `<root>/lib/x64`.
///
/// # Panics
///
/// Panics, which aborts the build with the message, if bindgen cannot generate the bindings
/// or the bindings file cannot be written.
#[cfg(feature = "cuda")]
fn configure_cuda() {
    // Once any `rerun-if-*` directive is printed, Cargo only reruns this script for the
    // inputs declared explicitly. The env var therefore has to be listed too; otherwise
    // changing CUDA_PATH would leave stale bindings and search paths in place.
    println!("cargo:rerun-if-changed=wrapper_cuda.h");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");

    // The toolkit is installed on the build host, so the host OS (`cfg!`) picks the default.
    let cuda_path = env::var("CUDA_PATH").unwrap_or_else(|_| {
        if cfg!(target_os = "windows") {
            "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.0".to_string()
        } else if cfg!(target_os = "macos") {
            "/Developer/NVIDIA/CUDA-12.0".to_string()
        } else {
            "/usr/local/cuda-12.0".to_string()
        }
    });
    let cuda_root = PathBuf::from(&cuda_path);

    // Configure linker to find CUDA libraries.
    println!(
        "cargo:rustc-link-search=native={}",
        cuda_root.join("lib64").display()
    );
    println!(
        "cargo:rustc-link-search=native={}",
        cuda_root.join("lib").display()
    );
    // The Windows toolkit keeps its import libraries (e.g. cudart.lib) under lib/x64.
    let windows_lib_dir = cuda_root.join("lib").join("x64");
    if windows_lib_dir.is_dir() {
        println!(
            "cargo:rustc-link-search=native={}",
            windows_lib_dir.display()
        );
    }
    println!("cargo:rustc-link-lib=cudart");

    // Configure bindgen to generate CUDA bindings.
    let bindings = bindgen::Builder::default()
        .header("wrapper_cuda.h")
        .clang_arg(format!("-I{}", cuda_root.join("include").display()))
        .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
        .allowlist_function("cuda.*")
        .allowlist_type("cuda.*")
        .generate()
        .expect("Unable to generate CUDA bindings");

    // Write the bindings to the $OUT_DIR/cuda_bindings.rs file.
    let out_path =
        PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set by Cargo for build scripts"));
    bindings
        .write_to_file(out_path.join("cuda_bindings.rs"))
        .expect("Couldn't write CUDA bindings!");
}
/// Locates the ROCm installation, emits linker directives for the HIP runtime libraries,
/// and generates Rust FFI bindings for `wrapper_hip.h` into `$OUT_DIR/rocm_bindings.rs`.
/// Finally, it ensures that `src/backend/hip` exists in the package directory.
///
/// The install root is read from the `ROCM_PATH` environment variable. When it is unset, a
/// default location for the build host's OS is used. Headers are taken from
/// `<root>/include` and libraries from `<root>/lib`.
///
/// # Panics
///
/// Panics, which aborts the build with the message, if:
/// - bindgen cannot generate the bindings,
/// - the bindings file cannot be written, or
/// - the `src/backend/hip` directory cannot be created.
#[cfg(feature = "rocm")]
fn configure_rocm() {
    // Once any `rerun-if-*` directive is printed, Cargo only reruns this script for the
    // inputs declared explicitly, so declare all of them up front.
    println!("cargo:rerun-if-changed=wrapper_hip.h");
    println!("cargo:rerun-if-env-changed=ROCM_PATH");

    // Find ROCm installation. The defaults describe the build host's filesystem, so the host
    // OS (`cfg!`) is what gets tested here.
    let rocm_path = env::var("ROCM_PATH").unwrap_or_else(|_| {
        if cfg!(target_os = "windows") {
            "C:/Program Files/AMD/ROCm".to_string()
        } else if cfg!(target_os = "macos") {
            // NOTE(review): ROCm is not known to ship for macOS; this default is a guess —
            // confirm or drop it.
            "/opt/rocm-5.4.0".to_string()
        } else {
            "/opt/rocm".to_string()
        }
    });
    // Plain stdout lines (not `cargo:`-prefixed) are only recorded in the build-script output.
    println!("Using ROCm path: {}", rocm_path);

    let rocm_root = PathBuf::from(&rocm_path);
    let hip_include_path = rocm_root.join("include");
    let hip_lib_path = rocm_root.join("lib");

    // Configure linker to find and link the ROCm libraries.
    println!("cargo:rustc-link-search=native={}", hip_lib_path.display());
    println!("cargo:rustc-link-lib=amdhip64");
    // `hip_hcc` is only shipped by older ROCm versions. Requesting it unconditionally makes the
    // link fail on installs that lack it, so link it only when the library is actually there.
    if hip_lib_path.join("libhip_hcc.so").exists() {
        println!("cargo:rustc-link-lib=hip_hcc");
    }
    println!("cargo:rustc-link-lib=hiprtc"); // HIP runtime compilation library

    // Configure bindgen to generate ROCm/HIP bindings.
    let bindings = bindgen::Builder::default()
        .header("wrapper_hip.h")
        .clang_arg(format!("-I{}", hip_include_path.display()))
        // bindgen drives plain clang rather than hipcc, so the HIP platform macro the headers
        // dispatch on must be supplied explicitly. It is needed on every host OS, not just Linux.
        .clang_arg("-D__HIP_PLATFORM_AMD__")
        .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
        // bindgen anchors these patterns to the whole name, so `hip.*` already covers
        // typedefs such as hipDevice_t, hipCtx_t, hipModule_t and hipStream_t.
        .allowlist_function("hip.*")
        .allowlist_type("hip.*")
        .allowlist_var("HIP_.*")
        .allowlist_var("hip.*")
        .blocklist_type("size_t") // Avoid conflicts with libc
        .blocklist_type("hipError") // Use hipError_t instead
        .generate()
        .expect("Unable to generate ROCm/HIP bindings");

    // Write the bindings to the $OUT_DIR/rocm_bindings.rs file.
    let out_path =
        PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set by Cargo for build scripts"));
    bindings
        .write_to_file(out_path.join("rocm_bindings.rs"))
        .expect("Couldn't write ROCm/HIP bindings!");

    // Create directories for the backend module if they don't exist.
    // NOTE(review): Cargo advises build scripts not to modify files outside OUT_DIR; this
    // step is kept for compatibility — confirm it is still needed.
    let manifest_dir = PathBuf::from(
        env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is set by Cargo for build scripts"),
    );
    let hip_path = manifest_dir.join("src").join("backend").join("hip");
    std::fs::create_dir_all(&hip_path).expect("Failed to create hip directory");
}