Fix Spark 3.5.x support, add Spark 3.4.3 support, deprecate Python 3.7, and support Python 3.10 #411

Merged: 12 commits, Jun 18, 2024
4 changes: 2 additions & 2 deletions .github/workflows/pypi.yml
@@ -34,10 +34,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
- name: Set up Python 3.7
- name: Set up Python 3.9
uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
with:
python-version: 3.7
python-version: 3.9
- name: Set up JDK 1.8
uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
with:
4 changes: 2 additions & 2 deletions .github/workflows/ray_nightly_test.yml
@@ -31,8 +31,8 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
python-version: [3.8, 3.9]
spark-version: [3.1.3, 3.2.4, 3.3.2, 3.4.0]
python-version: [3.8, 3.9, 3.10.14]
spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0]

runs-on: ${{ matrix.os }}

6 changes: 3 additions & 3 deletions .github/workflows/raydp.yml
@@ -33,8 +33,8 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
python-version: [3.8, 3.9]
spark-version: [3.1.3, 3.2.4, 3.3.2, 3.4.0]
python-version: [3.8, 3.9, 3.10.14]
spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0]

runs-on: ${{ matrix.os }}

@@ -82,7 +82,7 @@ jobs:
else
pip install torch
fi
pip install pyarrow==6.0.1 ray[train] pytest koalas tensorflow==2.13.1 tabulate grpcio-tools wget
pip install pyarrow==6.0.1 ray[train] pytest tensorflow==2.13.1 tabulate grpcio-tools wget
pip install "xgboost_ray[default]<=0.1.13"
pip install torchmetrics
- name: Cache Maven
4 changes: 2 additions & 2 deletions .github/workflows/raydp_nightly.yml
@@ -34,10 +34,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
- name: Set up Python 3.7
- name: Set up Python 3.9
uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
with:
python-version: 3.7
python-version: 3.9
- name: Set up JDK 1.8
uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
with:
@@ -17,21 +17,19 @@

package org.apache.spark.sql.raydp


import com.intel.raydp.shims.SparkShimLoader
import io.ray.api.{ActorHandle, ObjectRef, PyActorHandle, Ray}
import io.ray.runtime.AbstractRayRuntime
import java.io.ByteArrayOutputStream
import java.util.{List, UUID}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
import java.util.function.{Function => JFunction}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import io.ray.api.{ActorHandle, ObjectRef, PyActorHandle, Ray}
import io.ray.runtime.AbstractRayRuntime
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types.pojo.Schema
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{RayDPException, SparkContext}
import org.apache.spark.deploy.raydp._
@@ -105,7 +103,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
Iterator(iter)
}

val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val arrowSchema = SparkShimLoader.getSparkShims.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"ray object store writer", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
@@ -217,7 +215,7 @@ object ObjectStoreWriter {
def toArrowSchema(df: DataFrame): Schema = {
val conf = df.queryExecution.sparkSession.sessionState.conf
val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
ArrowUtils.toArrowSchema(df.schema, timeZoneId)
SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId)
}

def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = {
@@ -17,9 +17,11 @@

package com.intel.raydp.shims

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.executor.RayDPExecutorBackendFactory
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SparkSession}

sealed abstract class ShimDescriptor
@@ -36,4 +38,6 @@ trait SparkShims {
def getExecutorBackendFactory(): RayDPExecutorBackendFactory

def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema
}
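
Note: the toArrowSchema hook added to SparkShims above is what the ObjectStoreWriter hunks earlier in this diff rely on; callers resolve the version-specific implementation through SparkShimLoader.getSparkShims instead of calling ArrowUtils.toArrowSchema directly, whose call shape differs across Spark releases. A minimal usage sketch follows (not part of this PR; the driver object and example schema are illustrative):

// Minimal sketch, assuming only the names visible in this diff
// (SparkShimLoader, SparkShims.toArrowSchema). ShimUsageSketch itself
// and the example schema are hypothetical.
import com.intel.raydp.shims.SparkShimLoader
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.{LongType, StructField, StructType}

object ShimUsageSketch {
  def main(args: Array[String]): Unit = {
    val sparkSchema = StructType(Seq(StructField("id", LongType, nullable = false)))
    // Dispatches to the shim matching the running Spark version, hiding the
    // ArrowUtils.toArrowSchema signature differences between 3.2.x and 3.5.x.
    val arrowSchema: Schema =
      SparkShimLoader.getSparkShims.toArrowSchema(sparkSchema, "UTC")
    println(arrowSchema)
  }
}
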
@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark322._
import org.apache.spark.spark322.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark322.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark322Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark322Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
@@ -17,12 +17,19 @@

package org.apache.spark.sql.spark322

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session))
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark330._
import org.apache.spark.spark330.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark330.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark330Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark330Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
@@ -17,12 +17,19 @@

package org.apache.spark.sql.spark330

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
ArrowConverters.toDataFrame(rdd, schema, session)
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
@@ -23,7 +23,9 @@ object SparkShimProvider {
val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0)
val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1)
val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR")
val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR",
s"$SPARK343_DESCRIPTOR")
val DESCRIPTOR = SPARK341_DESCRIPTOR
}
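
Note: listing SPARK343_DESCRIPTOR in DESCRIPTOR_STRINGS is what lets the spark340 shim module claim Spark 3.4.3 at runtime. A hypothetical sketch of the kind of lookup a provider performs against this list (the actual matching code is not shown in this diff and may differ):

// Hypothetical sketch; only SparkShimDescriptor and the descriptor-string
// pattern come from the diff above, the rest is illustrative.
import com.intel.raydp.shims.SparkShimDescriptor

object DescriptorMatchSketch {
  val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3)
  // Assumes SparkShimDescriptor(3, 4, 3) renders as "3.4.3" via toString,
  // mirroring how DESCRIPTOR_STRINGS is built above.
  val DESCRIPTOR_STRINGS = Seq("3.4.0", "3.4.1", "3.4.2", s"$SPARK343_DESCRIPTOR")

  def supports(sparkVersion: String): Boolean =
    DESCRIPTOR_STRINGS.contains(sparkVersion)

  def main(args: Array[String]): Unit =
    println(supports("3.4.3")) // true now that 3.4.3 is listed
}
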

@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark340._
import org.apache.spark.spark340.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark340.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark340Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark340Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
@@ -17,11 +17,13 @@

package org.apache.spark.sql.spark340

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(
@@ -36,4 +38,8 @@ object SparkSqlUtils {
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
@@ -21,7 +21,8 @@ import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}

object SparkShimProvider {
val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR")
val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR")
val DESCRIPTOR = SPARK350_DESCRIPTOR
}

@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark350._
import org.apache.spark.spark350.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark350.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark350Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark350Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
@@ -17,11 +17,13 @@

package org.apache.spark.sql.spark350

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(
Expand All @@ -36,4 +38,8 @@ object SparkSqlUtils {
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false)
}
}
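
Note: the Spark 3.5 shim above is the only implementation that passes errorOnDuplicatedFieldNames = false; Spark 3.5.x's ArrowUtils.toArrowSchema takes this extra argument, while the 3.2-3.4 shims call the two-argument form, which is why the conversion is routed through the shim layer at all. A sketch of calling this wrapper directly (the schema and time zone are illustrative, not from the PR):

// Sketch only, assuming a Spark 3.5.x classpath. SparkSqlUtils.toArrowSchema
// comes from the diff above; the object, schema and time zone are made up.
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.spark350.SparkSqlUtils
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object Spark350ArrowSchemaSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("name", StringType)))
    // Forwards to ArrowUtils.toArrowSchema with timeZoneId = "UTC" and
    // errorOnDuplicatedFieldNames = false under Spark 3.5.x.
    val arrowSchema: Schema = SparkSqlUtils.toArrowSchema(schema, "UTC")
    println(arrowSchema)
  }
}
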
4 changes: 2 additions & 2 deletions python/raydp/spark/interfaces.py
@@ -20,8 +20,8 @@

from raydp.utils import convert_to_spark

DF = Union["pyspark.sql.DataFrame", "koalas.DataFrame"]
OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["koalas.DataFrame"]]
DF = Union["pyspark.sql.DataFrame", "pyspark.pandas.DataFrame"]
OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["pyspark.pandas.DataFrame"]]


class SparkEstimatorInterface:
20 changes: 11 additions & 9 deletions python/raydp/tests/test_spark_utils.py
@@ -18,7 +18,9 @@
import math
import sys

import databricks.koalas as ks
# https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html
# import databricks.koalas as ks
import pyspark.pandas as ps
import pyspark
import pytest

@@ -27,13 +29,13 @@

def test_df_type_check(spark_session):
spark_df = spark_session.range(0, 10)
koalas_df = ks.range(0, 10)
koalas_df = ps.range(0, 10)
assert utils.df_type_check(spark_df)
assert utils.df_type_check(koalas_df)

other_df = "df"
error_msg = (f"The type: {type(other_df)} is not supported, only support " +
"pyspark.sql.DataFrame and databricks.koalas.DataFrame")
"pyspark.sql.DataFrame and pyspark.pandas.DataFrame")
with pytest.raises(Exception) as exinfo:
utils.df_type_check(other_df)
assert str(exinfo.value) == error_msg
@@ -45,15 +47,15 @@ def test_convert_to_spark(spark_session):
assert is_spark_df
assert spark_df is converted

koalas_df = ks.range(0, 10)
converted, is_spark_df = utils.convert_to_spark(koalas_df)
pandas_on_spark_df = ps.range(0, 10)
converted, is_spark_df = utils.convert_to_spark(pandas_on_spark_df)
assert not is_spark_df
assert isinstance(converted, pyspark.sql.DataFrame)
assert converted.count() == 10

other_df = "df"
error_msg = (f"The type: {type(other_df)} is not supported, only support " +
"pyspark.sql.DataFrame and databricks.koalas.DataFrame")
"pyspark.sql.DataFrame and pyspark.pandas.DataFrame")
with pytest.raises(Exception) as exinfo:
utils.df_type_check(other_df)
assert str(exinfo.value) == error_msg
@@ -64,10 +66,10 @@ def test_random_split(spark_session):
splits = utils.random_split(spark_df, [0.7, 0.3])
assert len(splits) == 2

koalas_df = ks.range(0, 10)
koalas_df = ps.range(0, 10)
splits = utils.random_split(koalas_df, [0.7, 0.3])
assert isinstance(splits[0], ks.DataFrame)
assert isinstance(splits[1], ks.DataFrame)
assert isinstance(splits[0], ps.DataFrame)
assert isinstance(splits[1], ps.DataFrame)
assert len(splits) == 2

