diff --git a/manywheel/build_rocm.sh b/manywheel/build_rocm.sh index 20f6bca8a..73421e00c 100755 --- a/manywheel/build_rocm.sh +++ b/manywheel/build_rocm.sh @@ -74,6 +74,8 @@ fi ROCM_VERSION_WITH_PATCH=rocm${ROCM_VERSION_MAJOR}.${ROCM_VERSION_MINOR}.${ROCM_VERSION_PATCH} ROCM_INT=$(($ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH)) +PYTORCH_VERSION=$(cat $PYTORCH_ROOT/version.txt | grep -oP "[0-9]+\.[0-9]+\.[0-9]+") + # Required ROCm libraries ROCM_SO_FILES=( "libMIOpen.so" @@ -120,6 +122,10 @@ if [[ $ROCM_INT -ge 60200 ]]; then ROCM_SO_FILES+=("librocm-core.so") fi +if [[ $(ver $PYTORCH_VERSION) -ge $(ver 2.8) ]]; then + HEAVYWEIGHT_ROCM_SO_FILES+=("libhipsparselt.so") +fi + OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release` if [[ "$OS_NAME" == *"CentOS Linux"* || "$OS_NAME" == *"AlmaLinux"* ]]; then LIBGOMP_PATH="/usr/lib64/libgomp.so.1" @@ -306,7 +312,6 @@ ver() { # Add triton install dependency # No triton dependency till pytorch 2.3 on 3.12 # since torch.compile doesn't work. -PYTORCH_VERSION=$(cat $PYTORCH_ROOT/version.txt | grep -oP "[0-9]+\.[0-9]+\.[0-9]+") # Assuming PYTORCH_VERSION=x.y.z, if x >= 2 if [ ${PYTORCH_VERSION%%\.*} -ge 2 ]; then if [[ $(uname) == "Linux" ]] && [[ "$DESIRED_PYTHON" != "3.12" || $(ver $PYTORCH_VERSION) -ge $(ver 2.4) ]]; then