Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ import org.apache.spark.sql.{DataFrameReader, QueryTest}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.Utils

abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession with KafkaTest {
abstract class KafkaRelationSuiteBase
extends QueryTest
with SharedClassicSparkSession
with KafkaTest {

import testImplicits._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ import org.apache.spark.annotation.Stable
* @since 1.6.0
*/
@Stable
abstract class DatasetHolder[T] {
class DatasetHolder[T](ds: Dataset[T]) {

// This is declared with parentheses to prevent the Scala compiler from treating
// `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
def toDS(): Dataset[T]
def toDS(): Dataset[T] = ds

// This is declared with parentheses to prevent the Scala compiler from treating
// `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
def toDF(): DataFrame
def toDF(): DataFrame = ds.toDF()

def toDF(colNames: String*): DataFrame
def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*)
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@ abstract class SQLImplicits extends EncoderImplicits with Serializable {
* Creates a [[Dataset]] from a local Seq.
* @since 1.6.0
*/
implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T]
implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] =
new DatasetHolder[T](session.createDataset(s))

/**
* Creates a [[Dataset]] from an RDD.
*
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T]
implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T] =
new DatasetHolder[T](session.createDataset(rdd))

/**
* An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ abstract class SQLImplicits private[sql] (override val session: SparkSession)
new DatasetHolder[T](session.createDataset(rdd))
}

class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U] {
class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U](ds) {
override def toDS(): Dataset[U] = ds
override def toDF(): DataFrame = ds.toDF()
override def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connector.catalog.{CatalogManager, Column, Identifier, InMemoryChangelogCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -72,7 +72,7 @@ import org.apache.spark.util.Utils
// scalastyle:on
class ProtoToParsedPlanTestSuite
extends SparkFunSuite
with SharedSparkSession
with SharedClassicSparkSession
with ResourceHelper {

private val cleanOrphanedGoldenFiles: Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey, SparkConnectService}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.util.CloseableIterator

/**
* Base class and utilities for a test suite that starts and tests the real SparkConnectService
* with a real SparkConnectClient, communicating over RPC, but both in-process.
*/
trait SparkConnectServerTest extends SharedSparkSession {
trait SparkConnectServerTest extends SharedClassicSparkSession {

// Server port
val serverPort: Int =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimeType}
import org.apache.spark.unsafe.types.UTF8String

/**
* Testing trait for SparkConnect tests with some helper methods to make it easier to create new
* test cases.
*/
trait SparkConnectPlanTest extends SharedSparkSession {
trait SparkConnectPlanTest extends SharedClassicSparkSession {
def transform(rel: proto.Relation): logical.LogicalPlan = {
SparkConnectPlannerTestUtils.transform(spark, rel)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteKey, ExecuteStatus, SessionStatus, SparkConnectAnalyzeHandler, SparkConnectService, SparkListenerConnectOperationStarted}
import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
Expand All @@ -61,7 +61,7 @@ import org.apache.spark.util.Utils
* Testing Connect Service implementation.
*/
class SparkConnectServiceSuite
extends SharedSparkSession
extends SharedClassicSparkSession
with MockitoSugar
with Logging
with SparkConnectPlanTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class StreamingForeachBatchHelperSuite extends SharedSparkSession with MockitoSugar {
class StreamingForeachBatchHelperSuite extends SharedClassicSparkSession with MockitoSugar {

private def mockQuery(): StreamingQuery = {
val query = mock[StreamingQuery]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest}
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class DummyPlugin extends RelationPlugin {
override def transform(
Expand Down Expand Up @@ -119,7 +119,7 @@ class ExampleCommandPlugin extends CommandPlugin {
}
}

class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConnectPlanTest {
class SparkConnectPluginRegistrySuite extends SharedClassicSparkSession with SparkConnectPlanTest {

override def beforeEach(): Unit = {
if (SparkEnv.get.conf.contains(Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ArtifactStatusesResponse
import org.apache.spark.network.util.JavaUtils.sha256Hex
import org.apache.spark.sql.connect.ResourceHelper
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.ThreadUtils

private class DummyStreamObserver(p: Promise[ArtifactStatusesResponse])
Expand All @@ -38,7 +38,7 @@ private class DummyStreamObserver(p: Promise[ArtifactStatusesResponse])
override def onCompleted(): Unit = {}
}

class ArtifactStatusesHandlerSuite extends SharedSparkSession with ResourceHelper {
class ArtifactStatusesHandlerSuite extends SharedClassicSparkSession with ResourceHelper {

val sessionId = UUID.randomUUID().toString

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.GetStatusResponse
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.plugin.{GetStatusPlugin, SparkConnectPluginRegistry}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.ThreadUtils

/**
Expand Down Expand Up @@ -104,7 +104,7 @@ class FailingGetStatusPlugin extends GetStatusPlugin {
throw new RuntimeException("operation plugin failure")
}

class GetStatusHandlerSuite extends SharedSparkSession {
class GetStatusHandlerSuite extends SharedClassicSparkSession {

// Default userId matching SparkConnectTestUtils.createDummySessionHolder default
private val defaultUserId = "testUser"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ package org.apache.spark.sql.connect.service
import java.util.UUID

import org.apache.spark.SparkSQLException
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class SparkConnectCloneSessionSuite extends SharedSparkSession {
class SparkConnectCloneSessionSuite extends SharedClassicSparkSession {

override def beforeEach(): Unit = {
super.beforeEach()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package org.apache.spark.sql.connect.service

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

/**
* Test suite for SparkConnectExecutionManager.
*/
class SparkConnectExecutionManagerSuite extends SharedSparkSession {
class SparkConnectExecutionManagerSuite extends SharedClassicSparkSession {

protected override def afterEach(): Unit = {
super.afterEach()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ import org.apache.spark.sql.connect.execution.ExecuteResponseObserver
import org.apache.spark.sql.connect.planner.SparkConnectStreamingQueryListenerHandler
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener}
import org.apache.spark.sql.streaming.Trigger.ProcessingTime
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class SparkConnectListenerBusListenerSuite
extends SparkFunSuite
with SharedSparkSession
with SharedClassicSparkSession
with MockitoSugar {

override def afterEach(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, Spark
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.ArrayImplicits._

class SparkConnectSessionHolderSuite extends SharedSparkSession {
class SparkConnectSessionHolderSuite extends SharedClassicSparkSession {

test("DataFrame cache: Successful put and get") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ import org.apache.spark.SparkSQLException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class SparkConnectSessionManagerSuite extends SharedSparkSession {
class SparkConnectSessionManagerSuite extends SharedClassicSparkSession {

override def beforeEach(): Unit = {
super.beforeEach()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ abstract class SQLImplicits extends sql.SQLImplicits {
new DatasetHolder[T](session.createDataset(rdd))
}

class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U] {
class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U](ds) {
override def toDS(): Dataset[U] = ds
override def toDF(): DataFrame = ds.toDF()
override def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.storage.StorageLevel.{MEMORY_AND_DISK_2, MEMORY_ONLY}
Expand All @@ -65,9 +65,9 @@ private case class BigData(s: String)

@SlowSQLTest
class CachedTableSuite extends QueryTest
with SharedSparkSession
with SharedClassicSparkSession
with AdaptiveSparkPlanHelper {
import testImplicits._
import classicTestImplicits._

override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.plans.AsOfJoinDirection
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.tags.SlowSQLTest

@SlowSQLTest
class DataFrameAsOfJoinSuite extends QueryTest
with SharedSparkSession
with SharedClassicSparkSession
with AdaptiveSparkPlanHelper {

def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.tags.ExtendedSQLTest

/**
* Test suite for functions in [[org.apache.spark.sql.functions]].
*/
@ExtendedSQLTest
class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
import testImplicits._
class DataFrameFunctionsSuite extends QueryTest with SharedClassicSparkSession {
import classicTestImplicits._

test("DataFrame function and SQL function parity") {
// This test compares the available list of DataFrame functions in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import org.apache.spark.sql.classic.{Dataset => DatasetImpl}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, explode, sum, year}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
import testImplicits._
class DataFrameSelfJoinSuite extends QueryTest with SharedClassicSparkSession {
import classicTestImplicits._

test("join - join using self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.functions.{col, lit, struct, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}

class DataFrameStatSuite extends QueryTest with SharedSparkSession {
class DataFrameStatSuite extends QueryTest with SharedClassicSparkSession {
import testImplicits._

private def toLetter(i: Int): String = (i + 97).toChar.toString
Expand Down Expand Up @@ -608,7 +608,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
}


class DataFrameStatPerfSuite extends QueryTest with SharedSparkSession with Logging {
class DataFrameStatPerfSuite extends QueryTest with SharedClassicSparkSession with Logging {

// Turn on this test if you want to test the performance of approximate quantiles.
ignore("computing quantiles should not take much longer than describe()") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.{SharedClassicSparkSession, SharedSparkSession}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ArrayImplicits._
Expand All @@ -72,9 +72,9 @@ object TestForTypeAlias {
}

class DatasetSuite extends QueryTest
with SharedSparkSession
with SharedClassicSparkSession
with AdaptiveSparkPlanHelper {
import testImplicits._
import classicTestImplicits._

private implicit val ordering: Ordering[ClassData] = Ordering.by((c: ClassData) => c.a -> c.b)

Expand Down
Loading