diff --git a/bin/raydp-submit b/bin/raydp-submit index 55b561a4..164733b1 100755 --- a/bin/raydp-submit +++ b/bin/raydp-submit @@ -44,6 +44,10 @@ fi if [ -z "${RAY_HOME}" ]; then RAY_HOME=$(python3 -c "import os, ray; print(os.path.dirname(ray.__file__))") fi +# set RAY_DRIVER_NODE_IP +if [ -z "${RAY_DRIVER_NODE_IP}" ]; then + RAY_DRIVER_NODE_IP=$(python3 -c "import ray; print(ray.util.get_node_ip_address())") +fi # set log4j versions for spark driver and executors inside ray worker @@ -115,6 +119,10 @@ added_confs+=("-D$SPARK_LOG4J_CONFIG_FILE_NAME_KEY=$SPARK_LOG4J_CONFIG_FILE_NAME added_confs+=("-Dspark.javaagent=$raydp_agent_jar") added_args+=("--conf") added_args+=("spark.ray.log4j.config.file.name=$RAY_LOG4J_CONFIG_FILE_NAME") +added_args+=("--conf") +added_args+=("spark.driver.host=$RAY_DRIVER_NODE_IP") +added_args+=("--conf") +added_args+=("spark.driver.bindAddress=$RAY_DRIVER_NODE_IP") # Find the java binary if [ -n "${JAVA_HOME}" ]; then