diff --git a/spark-rapids/spark-rapids.sh b/spark-rapids/spark-rapids.sh
index 235eb51b7..82f3415b1 100644
--- a/spark-rapids/spark-rapids.sh
+++ b/spark-rapids/spark-rapids.sh
@@ -269,8 +269,7 @@ function execute_with_retries() {
   return 1
 }
 
-function install_spark_rapids() {
-  local -r nvidia_repo_url='https://repo1.maven.org/maven2/com/nvidia'
+function install_gpu_xgboost() {
   local -r dmlc_repo_url='https://repo.maven.apache.org/maven2/ml/dmlc'
 
   wget -nv --timeout=30 --tries=5 --retry-connrefused \
@@ -279,6 +278,29 @@ function install_spark_rapids() {
   wget -nv --timeout=30 --tries=5 --retry-connrefused \
     "${dmlc_repo_url}/xgboost4j-gpu_2.12/${XGBOOST_VERSION}/xgboost4j-gpu_2.12-${XGBOOST_VERSION}.jar" \
     -P /usr/lib/spark/jars/
+}
+
+# Report (and signal via exit status) whether a RAPIDS Spark plugin JAR is
+# already present under /usr/lib/spark/jars (e.g. preinstalled on ML images).
+function check_spark_rapids_jar() {
+  local jars_found
+  jars_found=$(ls /usr/lib/spark/jars/rapids-4-spark_*.jar 2>/dev/null | wc -l)
+  if [[ $jars_found -gt 0 ]]; then
+    echo "RAPIDS Spark plugin JAR found"
+    return 0
+  else
+    echo "RAPIDS Spark plugin JAR not found"
+    return 1
+  fi
+}
+
+# Delete any preinstalled RAPIDS Spark plugin JAR so the requested version
+# can be installed in its place.
+function remove_spark_rapids_jar() {
+  rm -f /usr/lib/spark/jars/rapids-4-spark_*.jar
+  echo "Existing RAPIDS Spark plugin JAR removed successfully"
+}
+
+function install_spark_rapids() {
+
+  local -r nvidia_repo_url='https://repo1.maven.org/maven2/com/nvidia'
+
   wget -nv --timeout=30 --tries=5 --retry-connrefused \
     "${nvidia_repo_url}/rapids-4-spark_2.12/${SPARK_RAPIDS_VERSION}/rapids-4-spark_2.12-${SPARK_RAPIDS_VERSION}.jar" \
     -P /usr/lib/spark/jars/
@@ -807,27 +829,38 @@ function remove_old_backports {
 
 function main() {
-  if is_debian && [[ $(echo "${DATAPROC_IMAGE_VERSION} <= 2.1" | bc -l) == 1 ]]; then
-    remove_old_backports
-  fi
-  check_os_and_secure_boot
-  setup_gpu_yarn
-  if [[ "${RUNTIME}" == "SPARK" ]]; then
+  # If a RAPIDS Spark plugin JAR is already installed (common on ML images,
+  # which ship with Spark RAPIDS and GPU drivers), replace it with the
+  # requested version.
+  if check_spark_rapids_jar; then
+    # This ensures the cluster always uses the desired RAPIDS version, even if a default is present
+    remove_spark_rapids_jar
     install_spark_rapids
-    configure_spark
-    echo "RAPIDS initialized with Spark runtime"
+    echo "RAPIDS Spark plugin JAR replaced successfully"
   else
-    echo "Unsupported RAPIDS Runtime: ${RUNTIME}"
-    exit 1
-  fi
+    # Install GPU drivers and set up the Spark RAPIDS JAR for non-ML images
+    if is_debian && [[ $(echo "${DATAPROC_IMAGE_VERSION} <= 2.1" | bc -l) == 1 ]]; then
+      remove_old_backports
+    fi
+    check_os_and_secure_boot
+    setup_gpu_yarn
+    if [[ "${RUNTIME}" == "SPARK" ]]; then
+      install_spark_rapids
+      install_gpu_xgboost
+      configure_spark
+      echo "RAPIDS initialized with Spark runtime"
+    else
+      echo "Unsupported RAPIDS Runtime: ${RUNTIME}"
+      exit 1
+    fi
 
-  for svc in resourcemanager nodemanager; do
-    if [[ $(systemctl show hadoop-yarn-${svc}.service -p SubState --value) == 'running' ]]; then
-      systemctl restart hadoop-yarn-${svc}.service
+    for svc in resourcemanager nodemanager; do
+      if [[ $(systemctl show hadoop-yarn-${svc}.service -p SubState --value) == 'running' ]]; then
+        systemctl restart hadoop-yarn-${svc}.service
+      fi
+    done
+    if is_debian || is_ubuntu ; then
+      apt-get clean
     fi
-  done
-  if is_debian || is_ubuntu ; then
-    apt-get clean
   fi
 }
 
 main