diff --git a/build.sbt b/build.sbt index af02be9f9..5c8044d30 100644 --- a/build.sbt +++ b/build.sbt @@ -4,6 +4,10 @@ val scalatest = "3.0.1" val shapeless = "2.3.2" val scalacheck = "1.13.4" +// spark has scalatest and scalactic as a runtime dependency +// which can mess things up if you use a different version in your project +val exclusions = Seq(ExclusionRule("org.scalatest"), ExclusionRule("org.scalactic")) + lazy val root = Project("frameless", file("." + "frameless")).in(file(".")) .aggregate(core, cats, dataset, docs) .settings(framelessSettings: _*) @@ -22,7 +26,7 @@ lazy val cats = project .settings(publishSettings: _*) .settings(libraryDependencies ++= Seq( "org.typelevel" %% "cats" % catsv, - "org.apache.spark" %% "spark-core" % sparkVersion % "provided")) + "org.apache.spark" %% "spark-core" % sparkVersion % "provided" excludeAll(exclusions: _*))) lazy val dataset = project .settings(name := "frameless-dataset") @@ -31,8 +35,8 @@ lazy val dataset = project .settings(framelessTypedDatasetREPL: _*) .settings(publishSettings: _*) .settings(libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-core" % sparkVersion % "provided", - "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" + "org.apache.spark" %% "spark-core" % sparkVersion % "provided" excludeAll(exclusions: _*), + "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" excludeAll(exclusions: _*) )) .dependsOn(core % "test->test;compile->compile") diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index 92719d482..b17d2eafe 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -271,6 +271,7 @@ object TypedColumn { lgen: LabelledGeneric.Aux[T, H], selector: Selector.Aux[H, K, V] ): Exists[T, K, V] = new Exists[T, K, V] {} + } implicit class OrderedTypedColumnSyntax[T, U: CatalystOrdered](col: TypedColumn[T, U]) { @@ -279,4 +280,4 @@ object TypedColumn { def >(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped > other.untyped).typed def >=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped >= other.untyped).typed } -} +} \ No newline at end of file diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index ad9913131..175d4f684 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -62,11 +62,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val * * It is statically checked that column with such name exists and has type `A`. */ - def apply[A](column: Witness.Lt[Symbol])( - implicit - exists: TypedColumn.Exists[T, column.T, A], + def apply[A](selector: T => A)(implicit encoder: TypedEncoder[A] - ): TypedColumn[T, A] = col(column) + ): TypedColumn[T, A] = macro frameless.macros.ColumnMacros.fromFunction[T, A] /** Returns `TypedColumn` of type `A` given it's name. * @@ -319,6 +317,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } + def selectExpr[B](expr: T => B)(implicit encoder: TypedEncoder[B]): TypedDataset[B] = + macro frameless.macros.ColumnMacros.fromExpr[T, B] + /** Type-safe projection from type T to Tuple2[A,B] * {{{ * d.select( d('a), d('a)+d('b), ... 
) diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index a434f3d87..89512a2e2 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -4,10 +4,12 @@ import org.apache.spark.sql.FramelessInternals import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import shapeless._ + +import scala.collection.Map import scala.reflect.ClassTag abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Serializable { @@ -34,9 +36,11 @@ object TypedEncoder { def nullable: Boolean = true def sourceDataType: DataType = ScalaReflection.dataTypeFor[Unit] + def targetDataType: DataType = NullType def constructorFor(path: Expression): Expression = Literal.create((), sourceDataType) + def extractorFor(path: Expression): Expression = Literal.create(null, targetDataType) } @@ -44,6 +48,7 @@ object TypedEncoder { def nullable: Boolean = true def sourceDataType: DataType = FramelessInternals.objectTypeFor[String] + def targetDataType: DataType = StringType def extractorFor(path: Expression): Expression = @@ -57,9 +62,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = BooleanType + def targetDataType: DataType = BooleanType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -67,9 +74,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = IntegerType + def targetDataType: DataType = IntegerType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -77,9 +86,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = LongType + def targetDataType: DataType = LongType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -87,9 +98,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ShortType + def targetDataType: DataType = ShortType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -98,6 +111,7 @@ object TypedEncoder { implicit val charAsString: Injection[java.lang.Character, String] = new Injection[java.lang.Character, String] { def apply(a: java.lang.Character): String = String.valueOf(a) + def invert(b: String): java.lang.Character = { require(b.length == 1) b.charAt(0) @@ -110,9 +124,11 @@ object TypedEncoder { // this line fixes underlying encoder def sourceDataType: DataType = FramelessInternals.objectTypeFor[java.lang.Character] + def targetDataType: DataType = StringType def extractorFor(path: Expression): Expression = underlying.extractorFor(path) + def constructorFor(path: Expression): Expression = underlying.constructorFor(path) } @@ -120,9 +136,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ByteType + def targetDataType: DataType = ByteType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -130,9 +148,11 @@ object TypedEncoder { def 
nullable: Boolean = false def sourceDataType: DataType = FloatType + def targetDataType: DataType = FloatType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -140,9 +160,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = DoubleType + def targetDataType: DataType = DoubleType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -150,6 +172,7 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ScalaReflection.dataTypeFor[BigDecimal] + def targetDataType: DataType = DecimalType.SYSTEM_DEFAULT def extractorFor(path: Expression): Expression = @@ -163,6 +186,7 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ScalaReflection.dataTypeFor[SQLDate] + def targetDataType: DataType = DateType def extractorFor(path: Expression): Expression = @@ -182,6 +206,7 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ScalaReflection.dataTypeFor[SQLTimestamp] + def targetDataType: DataType = TimestampType def extractorFor(path: Expression): Expression = @@ -204,6 +229,7 @@ object TypedEncoder { def nullable: Boolean = true def sourceDataType: DataType = FramelessInternals.objectTypeFor[Option[A]](classTag) + def targetDataType: DataType = underlying.targetDataType def extractorFor(path: Expression): Expression = { @@ -253,26 +279,41 @@ object TypedEncoder { WrapOption(underlying.constructorFor(path), underlying.sourceDataType) } - implicit def vectorEncoder[A]( + implicit def vectorEncoder[A : ClassTag]( implicit underlying: TypedEncoder[A] ): TypedEncoder[Vector[A]] = new TypedEncoder[Vector[A]]() { - def nullable: Boolean = false + val nullable: Boolean = false - def sourceDataType: DataType = FramelessInternals.objectTypeFor[Vector[A]](classTag) + val sourceDataType: DataType = FramelessInternals.objectTypeFor[Vector[A]](classTag) - def targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) + val targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) def constructorFor(path: Expression): Expression = { - val arrayData = Invoke( - MapObjects( - underlying.constructorFor, - path, - underlying.targetDataType - ), - "array", - ScalaReflection.dataTypeFor[Array[AnyRef]] - ) + val arrayData = Option(underlying.sourceDataType) + .filter(ScalaReflection.isNativeType) + .filter(_ == underlying.targetDataType) + .collect { + case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]] + case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] + case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] + case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] + case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]] + case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] + case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] + }.map { + case (method, typ) => Invoke(path, method, typ) + }.getOrElse { + Invoke( + MapObjects( + underlying.constructorFor, + path, + underlying.targetDataType + ), + "array", + FramelessInternals.objectTypeFor[Array[A]] + ) + } StaticInvoke( TypedEncoderUtils.getClass, @@ -296,24 +337,138 @@ object TypedEncoder { } } + implicit def arrayEncoder[A : ClassTag]( + implicit + underlying: TypedEncoder[A] + ): 
TypedEncoder[Array[A]] = new TypedEncoder[Array[A]]() { + val nullable: Boolean = false + + val sourceDataType: DataType = FramelessInternals.objectTypeFor[Array[A]](classTag) + + val targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) + + def constructorFor(path: Expression): Expression = { + Option(underlying.sourceDataType) + .filter(ScalaReflection.isNativeType) + .filter(_ == underlying.targetDataType) + .collect { + case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]] + case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] + case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] + case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] + case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]] + case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] + case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] + }.map { + case (method, typ) => Invoke(path, method, typ) + }.getOrElse { + Invoke( + MapObjects( + underlying.constructorFor, + path, + underlying.targetDataType + ), + "array", + sourceDataType + ) + } + } + + def extractorFor(path: Expression): Expression = { + // if source `path` is already native for Spark, no need to `map` + if (ScalaReflection.isNativeType(underlying.sourceDataType)) { + NewInstance( + classOf[GenericArrayData], + path :: Nil, + dataType = ArrayType(underlying.targetDataType, underlying.nullable) + ) + } else { + MapObjects(underlying.extractorFor, path, underlying.sourceDataType) + } + } + } + + implicit def mapEncoder[A : ClassTag, B : ClassTag]( + implicit + encodeA: TypedEncoder[A], + encodeB: TypedEncoder[B] + ): TypedEncoder[scala.collection.immutable.Map[A, B]] = new TypedEncoder[scala.collection.immutable.Map[A, B]] { + val nullable: Boolean = false + val sourceDataType = FramelessInternals.objectTypeFor[Map[A, B]] + val targetDataType = MapType(encodeA.targetDataType, encodeB.targetDataType, encodeB.nullable) + + implicit val classTagArrayA = implicitly[ClassTag[A]].wrap + implicit val classTagArrayB = implicitly[ClassTag[B]].wrap + + private val arrayA = arrayEncoder[A] + private val arrayB = arrayEncoder[B] + private val vectorA = vectorEncoder[A] + private val vectorB = vectorEncoder[B] + + + private def wrap(arrayData: Expression) = { + StaticInvoke( + scala.collection.mutable.WrappedArray.getClass, + FramelessInternals.objectTypeFor[Seq[_]], + "make", + arrayData :: Nil) + } + + def constructorFor(path: Expression): Expression = { + val keyArrayType = ArrayType(encodeA.targetDataType, false) + val keyData = wrap(arrayA.constructorFor(Invoke(path, "keyArray", keyArrayType))) + + val valueArrayType = ArrayType(encodeB.targetDataType, encodeB.nullable) + val valueData = wrap(arrayB.constructorFor(Invoke(path, "valueArray", valueArrayType))) + + StaticInvoke( + ArrayBasedMapData.getClass, + sourceDataType, + "toScalaMap", + keyData :: valueData :: Nil) + } + + def extractorFor(path: Expression): Expression = { + val keys = + Invoke( + Invoke(path, "keysIterator", FramelessInternals.objectTypeFor[scala.collection.Iterator[A]]), + "toVector", + vectorA.sourceDataType) + val convertedKeys = arrayA.extractorFor(keys) + + val values = + Invoke( + Invoke(path, "valuesIterator", FramelessInternals.objectTypeFor[scala.collection.Iterator[B]]), + "toVector", + vectorB.sourceDataType) + val convertedValues = arrayB.extractorFor(values) + + NewInstance( + 
classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = targetDataType) + } + + } + /** Encodes things using injection if there is one defined */ implicit def usingInjection[A: ClassTag, B] - (implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] = - new TypedEncoder[A] { - def nullable: Boolean = trb.nullable - def sourceDataType: DataType = FramelessInternals.objectTypeFor[A](classTag) - def targetDataType: DataType = trb.targetDataType - - def constructorFor(path: Expression): Expression = { - val bexpr = trb.constructorFor(path) - Invoke(Literal.fromObject(inj), "invert", sourceDataType, Seq(bexpr)) - } + (implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] = + new TypedEncoder[A] { + def nullable: Boolean = trb.nullable + def sourceDataType: DataType = FramelessInternals.objectTypeFor[A](classTag) + def targetDataType: DataType = trb.targetDataType + + def constructorFor(path: Expression): Expression = { + val bexpr = trb.constructorFor(path) + Invoke(Literal.fromObject(inj), "invert", sourceDataType, Seq(bexpr)) + } - def extractorFor(path: Expression): Expression = { - val invoke = Invoke(Literal.fromObject(inj), "apply", trb.sourceDataType, Seq(path)) - trb.extractorFor(invoke) - } + def extractorFor(path: Expression): Expression = { + val invoke = Invoke(Literal.fromObject(inj), "apply", trb.sourceDataType, Seq(path)) + trb.extractorFor(invoke) } + } /** Encodes things as records if there is not Injection defined */ implicit def usingDerivation[F, G <: HList]( @@ -322,4 +477,4 @@ object TypedEncoder { recordEncoder: Lazy[RecordEncoderFields[G]], classTag: ClassTag[F] ): TypedEncoder[F] = new RecordEncoder[F, G] -} +} \ No newline at end of file diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index 06fe5bc09..ba5529c87 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -2,4 +2,6 @@ package frameless package object functions extends Udf { object aggregate extends AggregateFunctions + + } diff --git a/dataset/src/main/scala/frameless/functions/quoted.scala b/dataset/src/main/scala/frameless/functions/quoted.scala new file mode 100644 index 000000000..c29d955d3 --- /dev/null +++ b/dataset/src/main/scala/frameless/functions/quoted.scala @@ -0,0 +1,441 @@ +package frameless.functions + +import scala.annotation.{StaticAnnotation, compileTimeOnly} +import scala.math.BigDecimal.RoundingMode +import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.expressions.{Factorial, Hex} +import org.apache.spark.sql.catalyst.util.NumberConverter +import org.apache.spark.sql.{Column, functions => sf} + +/** + * Quoted functions, for being used inside an expression + * Functions annotated with @sparkFunction are rewritten accordingly, whereas other functions are converted to UDFs and + * actually applied. + * + * Note that these functions have simple implementations in order to provide explanation of what they do, and so that + * it won't be catastrophic to call them in the wrong context. But the implementations ought to never be invoked + * at run time. 
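+ *
+ * A usage sketch (hypothetical dataset `ds`; assumes the `selectExpr` macro added elsewhere in this change):
+ * {{{
+ *   import frameless.functions.quoted._
+ *   // `abs` is rewritten to the native Spark `abs` column function rather than invoked directly
+ *   ds.selectExpr(x => (abs(x.a), x.c))
+ * }}}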
+ */ +object quoted { + + // $COVERAGE-OFF$ - these functions will never actually be invoked at runtime + + sealed trait QuotedFunc + class sparkFunction0[B](referrent: () => B) extends StaticAnnotation with QuotedFunc + class sparkFunction[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc + class sparkColumnOp[A, B](referrent: Column => A => B) extends StaticAnnotation with QuotedFunc + class sparkAggregate[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc + class sparkWindowFunction[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc + class sparkWindowFunction0[B](referent: () => B) extends StaticAnnotation with QuotedFunc + + @sparkFunction(sf.abs) + def abs[T : Numeric](t: T): T = implicitly[Numeric[T]].abs(t) + + @sparkFunction((cols: Seq[Column]) => sf.array(cols:_*)) + def array[T : ClassTag](items: T*): Array[T] = items.toArray + + @sparkFunction((cols: Seq[(Column,Column)]) => sf.map(cols.flatMap(a => Seq(a._1, a._2)):_*)) + def map[T, U](items: (T, U)*): Map[T, U] = items.toMap + + @sparkFunction((cols: Seq[Column]) => sf.coalesce(cols:_*)) + def coalesce[T](items: T*): T = items.find(_ != null).getOrElse(null.asInstanceOf[T]) + + @sparkFunction0(sf.input_file_name) + def input_file_name(): String = "invoked outside of quoted expression" + + @sparkFunction(sf.isnan) + def isnan[T: Numeric](t: T): Boolean = implicitly[Numeric[T]].toDouble(t).isNaN + + @sparkFunction(sf.isnull) + def isnull[T](t: T): Boolean = t == null + + private lazy val longs = Stream.range(0L, Long.MaxValue).iterator + @sparkFunction0(sf.monotonically_increasing_id) + def monotonically_increasing_id(): Long = longs.next() + + @sparkFunction((sf.nanvl _).tupled) + def nanvl(first: Double, second: Double): Double = if(java.lang.Double.isNaN(first)) second else first + + @sparkFunction((sf.nanvl _).tupled) + def nanvl(first: Float, second: Float): Float = if(java.lang.Float.isNaN(first)) second else first + + @sparkFunction(sf.negate) + def negate[T : Numeric](t: T): T = implicitly[Numeric[T]].negate(t) + + @sparkFunction(sf.not) + def not(b: Boolean): Boolean = !b + + @sparkFunction(sf.rand) + def rand(seed: Long): Double = { + scala.util.Random.setSeed(seed) + scala.util.Random.nextDouble() + } + + @sparkFunction0(sf.rand) + def rand(): Double = scala.util.Random.nextDouble() + + @sparkFunction(sf.randn) + def randn(seed: Long): Double = { + scala.util.Random.setSeed(seed) + scala.util.Random.nextGaussian() + } + + @sparkFunction0(sf.randn) + def randn(): Double = scala.util.Random.nextGaussian() + + @sparkFunction0(sf.spark_partition_id) + def spark_partition_id(): Int = -1 + + @sparkFunction(sf.sqrt(_: Column)) + def sqrt[T](t: T)(implicit frac: Fractional[T]): Double = math.sqrt(frac.toDouble(t)) + + + case class CaseBuilder[T]() { + @sparkColumnOp(col => (col.when _).tupled) + def when(condition: Boolean, value: T): CaseBuilder[T] = this + + @sparkColumnOp(_.otherwise) + def otherwise(value: T): T = value + } + + @sparkFunction((sf.when _).tupled) + def when[T](condition: Boolean, value: T): CaseBuilder[T] = CaseBuilder[T]() + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(b: Byte): Byte = (~b).toByte + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(s: Short): Short = (~s).toShort + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(i: Int): Int = ~i + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(l: Long): Long = ~l + + @sparkFunction(sf.acos(_:Column)) + def acos(d: Double): Double = scala.math.acos(d) + + @sparkFunction(sf.asin(_:Column)) + def asin(d: 
Double): Double = scala.math.asin(d) + + @sparkFunction(sf.atan(_:Column)) + def atan(d: Double): Double = scala.math.atan(d) + + @sparkFunction((sf.atan2(_: Column, _: Column)).tupled) + def atan2(l: Double, r: Double): Double = scala.math.atan2(l, r) + + @sparkFunction(sf.bin(_:Column)) + def bin(l: Long): String = java.lang.Long.toBinaryString(l) + + @sparkFunction(sf.cbrt(_:Column)) + def cbrt(d: Double): Double = math.cbrt(d) + + @sparkFunction(sf.ceil(_:Column)) + def ceil(d: Double): Double = math.ceil(d) + + @sparkFunction((sf.conv _).tupled) + def conv(num: String, fromBase: Int, toBase: Int): String = + NumberConverter.convert(num.getBytes(), fromBase, toBase).toString + + @sparkFunction(sf.cos(_:Column)) + def cos(n: Double): Double = math.cos(n) + + @sparkFunction(sf.cosh(_:Column)) + def cosh(n: Double): Double = math.cosh(n) + + @sparkFunction(sf.exp(_:Column)) + def exp(n: Double): Double = math.exp(n) + + @sparkFunction(sf.expm1(_:Column)) + def expm1(n: Double): Double = math.expm1(n) + + @sparkFunction(sf.factorial) + def factorial(n: Int): Long = Factorial.factorial(n) + + @sparkFunction(sf.floor(_:Column)) + def floor(d: Double): Double = math.floor(d) + + @sparkFunction((cols: Seq[Column]) => sf.greatest(cols:_*)) + def greatest[T : Ordering](first: T, rest: T*): T = rest.foldLeft(first)(implicitly[Ordering[T]].max) + + @sparkFunction(sf.hex) + def hex(arr: Array[Byte]): String = Hex.hex(arr).toString + + @sparkFunction(sf.hex) + def hex(l: Long): String = Hex.hex(l).toString + + @sparkFunction(sf.unhex) + def unhex(str: String): Array[Byte] = Hex.unhex(str.getBytes()) + + @sparkFunction((sf.hypot(_:Column,_:Column)).tupled) + def hypot(a: Double, b: Double): Double = math.hypot(a, b) + + @sparkFunction((cols: Seq[Column]) => sf.least(cols:_*)) + def least[T : Ordering](first: T, rest: T*): T = rest.foldLeft(first)(implicitly[Ordering[T]].min) + + @sparkFunction(sf.log(_:Column)) + def log(n: Double): Double = math.log(n) + + @sparkFunction((sf.log(_:Double,_:Column)).tupled) + def log(base: Double, n: Double): Double = math.log(n) / math.log(base) + + @sparkFunction(sf.log10(_:Column)) + def log10(n: Double): Double = math.log10(n) + + @sparkFunction(sf.log1p(_:Column)) + def log1p(n: Double): Double = math.log1p(n) + + @sparkFunction(sf.log2(_:Column)) + def log2(n: Double): Double = log(n, 2.0) + + @sparkFunction((sf.pow(_:Column,_:Column)).tupled) + def pow(base: Double, exp: Double): Double = math.pow(base, exp) + + @sparkFunction((sf.pmod _).tupled) + def pmod[T](dividend: T, divisor: T)(implicit int: Integral[T]): T = { + val rem = int.rem(dividend, divisor) + if(int.lt(rem, int.zero)) + int.plus(rem, divisor) + else + rem + } + + @sparkFunction(sf.rint(_:Column)) + def rint(n: Double): Double = math.rint(n) + + @sparkFunction(sf.round(_:Column)) + def round(n: Double): Long = math.round(n) + + @sparkFunction(sf.round(_:Column)) + def round(n: Float): Int = math.round(n) + + @sparkFunction((sf.round(_:Column, _:Int)).tupled) + def round(n: Double, scale: Int): Double = BigDecimal(n).setScale(scale, RoundingMode.HALF_UP).toDouble + + @sparkFunction(sf.bround(_:Column)) + def bround(n: Double): Long = BigDecimal(n).setScale(0, RoundingMode.HALF_EVEN).toLong + + @sparkFunction((sf.bround(_:Column, _:Int)).tupled) + def bround(n: Double, scale: Int): Double = BigDecimal(n).setScale(scale, RoundingMode.HALF_EVEN).toDouble + + @sparkFunction((sf.shiftLeft _).tupled) + def shiftLeft(l: Long, bits: Int): Long = l << bits + + @sparkFunction((sf.shiftLeft _).tupled) + def 
shiftLeft(i: Int, bits: Int): Int = i << bits + + @sparkFunction((sf.shiftRight _).tupled) + def shiftRight(l: Long, bits: Int): Long = l >> bits + + @sparkFunction((sf.shiftRight _).tupled) + def shiftRight(i: Int, bits: Int): Int = i >> bits + + @sparkFunction((sf.shiftRightUnsigned _).tupled) + def shiftRightUnsigned(l: Long, bits: Int): Long = l >>> bits + + @sparkFunction((sf.shiftRightUnsigned _).tupled) + def shiftRightUnsigned(i: Int, bits: Int): Int = i >>> bits + + @sparkFunction(sf.signum(_:Column)) + def signum[T : Numeric](t: T): Int = implicitly[Numeric[T]].signum(t) + + @sparkFunction(sf.sin(_:Column)) + def sin(d: Double): Double = math.sin(d) + + @sparkFunction(sf.sinh(_:Column)) + def sinh(d: Double): Double = math.sinh(d) + + @sparkFunction(sf.tan(_:Column)) + def tan(d: Double): Double = math.tan(d) + + @sparkFunction(sf.tanh(_:Column)) + def tanh(d: Double): Double = math.tanh(d) + + @sparkFunction(sf.toDegrees(_:Column)) + def toDegrees(d: Double): Double = math.toDegrees(d) + + @sparkFunction(sf.toRadians(_:Column)) + def toRadians(d: Double): Double = math.toRadians(d) + + ///////////////////////// + // Aggregate functions // + ///////////////////////// + private val aggMsg = "Aggregate function can only be used inside a TypedDataset expression" + + @sparkAggregate(sf.approxCountDistinct(_:Column)) + @compileTimeOnly(aggMsg) + def approxCountDistinct[T](t: T): Long = ??? + + @sparkAggregate((sf.approxCountDistinct(_:Column,_:Double)).tupled) + @compileTimeOnly(aggMsg) + def approxCountDistinct[T](a: T, b: Double): Long = ??? + + @sparkAggregate(sf.avg(_:Column)) + @compileTimeOnly(aggMsg) + def avg[T : Numeric](t: T): Double = ??? + + @sparkAggregate(sf.collect_list(_:Column)) + @compileTimeOnly(aggMsg) + def collect_list[T](t: T): Seq[T] = ??? + + @sparkAggregate(sf.collect_set(_:Column)) + @compileTimeOnly(aggMsg) + def collect_set[T](t: T): Seq[T] = ??? + + @sparkAggregate((sf.corr(_:Column,_:Column)).tupled) + @compileTimeOnly(aggMsg) + def corr(a: Double, b: Double): Double = ??? + + @sparkAggregate(sf.count(_:Column)) + @compileTimeOnly(aggMsg) + def count[T](t: T): Long = ??? + + @sparkAggregate(((col: Column, cols: Seq[Column]) => sf.countDistinct(col, cols:_*)).tupled) + @compileTimeOnly(aggMsg) + def countDistinct(columns: Any*): Long = ??? + + @sparkAggregate((sf.covar_pop(_:Column,_:Column)).tupled) + @compileTimeOnly(aggMsg) + def covar_pop(a: Double, b: Double): Double = ??? + + @sparkAggregate((sf.covar_samp(_:Column,_:Column)).tupled) + @compileTimeOnly(aggMsg) + def covar_samp(a: Double, b: Double): Double = ??? + + @sparkAggregate((sf.first(_:Column,_:Boolean)).tupled) + @compileTimeOnly(aggMsg) + def first[T](col: T, ignoreNulls: Boolean) = ??? + + @sparkAggregate(sf.first(_:Column)) + @compileTimeOnly(aggMsg) + def first[T](col: T): T = first(col, ignoreNulls = false) + + @sparkAggregate(sf.grouping(_:Column)) + @compileTimeOnly(aggMsg) + def grouping(col: Any): Int = ??? + + @sparkAggregate((cols: Seq[Column]) => sf.grouping_id(cols:_*)) + @compileTimeOnly(aggMsg) + def grouping_id(cols: Any*): Int = ??? + + @sparkAggregate(sf.kurtosis(_:Column)) + @compileTimeOnly(aggMsg) + def kurtosis(col: Double): Double = ??? + + @sparkAggregate((sf.last(_:Column,_:Boolean)).tupled) + @compileTimeOnly(aggMsg) + def last[T](col: T, ignoreNulls: Boolean) = ??? 
+ + @sparkAggregate(sf.last(_:Column)) + @compileTimeOnly(aggMsg) + def last[T](col: T): T = last(col, ignoreNulls = false) + + @sparkAggregate(sf.max(_:Column)) + @compileTimeOnly(aggMsg) + def max[T : Ordering](col: T): T = ??? + + @sparkAggregate(sf.mean(_:Column)) + @compileTimeOnly(aggMsg) + def mean[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.min(_:Column)) + @compileTimeOnly(aggMsg) + def min[T : Ordering](col: T): T = ??? + + @sparkAggregate(sf.skewness(_:Column)) + @compileTimeOnly(aggMsg) + def skewness(col: Double): Double = ??? + + @sparkAggregate(sf.stddev(_:Column)) + @compileTimeOnly(aggMsg) + def stddev(col: Double): Double = ??? + + @sparkAggregate(sf.stddev_samp(_:Column)) + @compileTimeOnly(aggMsg) + def stddev_samp(col: Double): Double = ??? + + @sparkAggregate(sf.stddev_pop(_:Column)) + @compileTimeOnly(aggMsg) + def stddev_pop(col: Double): Double = ??? + + @sparkAggregate(sf.sum(_:Column)) + @compileTimeOnly(aggMsg) + def sum[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.sumDistinct(_:Column)) + @compileTimeOnly(aggMsg) + def sumDistinct[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.variance(_:Column)) + @compileTimeOnly(aggMsg) + def variance[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.var_samp(_:Column)) + @compileTimeOnly(aggMsg) + def var_samp[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.var_pop(_:Column)) + @compileTimeOnly(aggMsg) + def var_pop[T : Numeric](col: T): T = ??? + + ////////////////////// + // Window functions // + ////////////////////// + + @sparkWindowFunction0(sf.cume_dist) + @compileTimeOnly(aggMsg) + def cume_dist(): Double = ??? + + @sparkWindowFunction0(sf.dense_rank) + @compileTimeOnly(aggMsg) + def dense_rank(): Double = ??? + + @sparkWindowFunction((sf.lag(_:Column,_:Int)).tupled) + @compileTimeOnly(aggMsg) + def lag[T](col: T, offset: Int): T = ??? + + @sparkWindowFunction((sf.lead(_:Column,_:Int)).tupled) + @compileTimeOnly(aggMsg) + def lead[T](col: T, offset: Int): T = ??? + + @sparkWindowFunction(sf.ntile) + @compileTimeOnly(aggMsg) + def ntile(n: Int): Int = ??? + + @sparkWindowFunction0(sf.percent_rank) + @compileTimeOnly(aggMsg) + def percent_rank(): Double = ??? + + @sparkWindowFunction0(sf.rank) + @compileTimeOnly(aggMsg) + def rank(): Int = ??? + + @sparkWindowFunction0(sf.row_number) + @compileTimeOnly(aggMsg) + def row_number(): Long = ??? + + //////////////////// + // Misc functions // + //////////////////// + + @sparkFunction(sf.md5) + def md5(col: Array[Byte]): String = org.apache.commons.codec.digest.DigestUtils.md5Hex(col) + + @sparkFunction(sf.sha1) + def sha1(col: Array[Byte]): String = org.apache.commons.codec.digest.DigestUtils.sha1Hex(col) + + @sparkFunction((sf.sha2(_, _)).tupled) + @compileTimeOnly("Implemented only in frameless expressions") + def sha2(col: Array[Byte], numBits: Int): String = ??? + + @sparkFunction(sf.crc32) + @compileTimeOnly("Implemented only in frameless expressions") + def sha2(col: Array[Byte]): Long = ??? + + @sparkFunction(sf.hash) + @compileTimeOnly("Implemented only in frameless expressions") + def hash(cols: Any*): Int = ??? 
+ + +} diff --git a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala new file mode 100644 index 000000000..4558fda06 --- /dev/null +++ b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala @@ -0,0 +1,445 @@ +package frameless.macros + +import frameless.functions.quoted.QuotedFunc +import frameless.{TypedColumn, TypedEncoder} +import org.apache.spark.sql, sql.ColumnName + +import scala.collection.immutable.Queue +import scala.reflect.macros.{TypecheckException, whitebox} + +class ColumnMacros(val c: whitebox.Context) { + import c.universe._ + + private val TypedExpressionEncoder = reify(frameless.TypedExpressionEncoder) + private val TypedDataset = reify(frameless.TypedDataset) + private val ColumnName = weakTypeOf[ColumnName] + + private def toColumn[A : WeakTypeTag, B : WeakTypeTag]( + selectorStr: String, + encoder: c.Expr[TypedEncoder[B]] + ): Tree = { + + val A = weakTypeOf[A].dealias + val B = weakTypeOf[B].dealias + + val typedCol = appliedType( + weakTypeOf[TypedColumn[_, _]].typeConstructor, A, B + ) + + + val datasetCol = c.typecheck( + q"new $ColumnName($selectorStr).as[$B]($TypedExpressionEncoder.apply[$B]($encoder))" + ) + + c.typecheck(q"new $typedCol($datasetCol)") + } + + // could be used to reintroduce apply('foo) + // $COVERAGE-OFF$ Currently unused + def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + val B = weakTypeOf[B].dealias + val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})") + c.typecheck(q"${c.prefix}.col[$B]($witness)") + } + // $COVERAGE-ON$ + + def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + def fail(tree: Tree) = { + val err = + s"Could not create a column identifier from $tree - try using _.a.b syntax" + c.abort(tree.pos, err) + } + + val A = weakTypeOf[A].dealias + val B = weakTypeOf[B].dealias + + val selectorStr = selector.tree match { + case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => body match { + case `argName`(strs) => strs.mkString(".") + case other => fail(other) + } + // $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked + case other => fail(other) + // $COVERAGE-ON$ + } + + toColumn[A, B](selectorStr, encoder) + } + + def fromExpr[A : WeakTypeTag, B : WeakTypeTag](expr: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + def fail(tree: Tree)(reason: String, pos: Position = tree.pos) = c.abort( + tree.pos, + s"Could not expand expression $tree$reason" + ) + + val A = weakTypeOf[A].dealias + val B = weakTypeOf[B].dealias + + val result = expr.tree match { + case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => + val Extractor = ExprExtractor(argName) + val selectCols = body match { + case Extractor(extracted) => extracted match { + // top-level construct; have to unwrap the args + case Right(Construct(tpe, args)) => + Expr.toNamedColumns(args) + case Right(colExpr) => + List(Expr.refold(colExpr)) + case Left(errs) => + errs match { + case (tree, err) :: Nil => fail(tree)(s": $err") + case multi => + multi foreach { + case (tree, err) => c.error(tree.pos, err) + } + fail(body)(": multiple errors composing projection") + } + } + + case other => + val o = other + fail(o)(": could not compose projection") + } + val df = q"${c.prefix}.dataset.toDF().select(..$selectCols)" + val ds = 
q"$df.as[$B]($TypedExpressionEncoder.apply[$B])" + val typedDs = q"$TypedDataset.create[$B]($ds)($encoder)" + + typedDs + } + + + c.typecheck(result) + } + + sealed trait Expr { + def tpe: Type + } + + object Expr { + // zip down to the leaves and construct an expression that gives us a Column + // this is some kind of hylo-mylo-bananafanafofylo-morphism. + def refold(expr: Expr): Tree = expr match { + case LiteralExpr(tpe, tree) => q"org.apache.spark.sql.functions.lit($tree)" + case Column(tpe, selectorStr) => + q"new $ColumnName($selectorStr)" + case FunctionApplication(tpe, src, fn, args) => + val argTrees = args.map { + case Left(tree) => tree + case Right(argExpr) => Expr.refold(argExpr) + } + q"$src.$fn(..$argTrees)" + case BinaryOperator(tpe, lhs, op, rhs) => + val lhsTree = refold(lhs) + val rhsTree = refold(rhs) + Apply(Select(lhsTree, op), List(rhsTree)) + case UnaryRAOperator(tpe, rhs, op) => + val rhsTree = refold(rhs) + Select(rhsTree, op) + case Construct(tpe, args) => + val argTrees = toNamedColumns(args) + q"org.apache.spark.sql.functions.struct(..$argTrees)" + case TrustMeBro(tpe, tree) => tree + } + + def toNamedColumns(exprMap: Map[String, Expr]): List[Tree] = exprMap.toList.map { + case (colName, expr) => + val colTree = refold(expr) + q"$colTree.as($colName)" + } + } + + case class Construct(tpe: Type, args: Map[String, Expr]) extends Expr // Construct tuple or case class + case class UnaryRAOperator(tpe: Type, rhs: Expr, op: TermName) extends Expr // Unary right-associative op + case class BinaryOperator(tpe: Type, lhs: Expr, op: TermName, rhs: Expr) extends Expr // Binary op + case class FunctionApplication(tpe: Type, src: Tree, fn: TermName, args: List[Either[Tree, Expr]]) extends Expr // SQL function + case class Column(tpe: Type, selectorStr: String) extends Expr // Column of this dataset + case class LiteralExpr(tpe: Type, tree: Tree) extends Expr // Literal expression + + // manually added tree (temporary until functions can be implemented; needed for operator substitution) + case class TrustMeBro(tpe: Type, tree: Tree) extends Expr + + case class ExprExtractor(Rooted: NameExtractor) { + private val This = this + + // substitute these operators with functions (or other trees) + // i.e. two string columns can't be + together; they must be concat(a, b) instead. 
+ val subBinaryOperators: Map[MethodSymbol, (Tree, Tree) => Tree] = Map( + weakTypeOf[String].member(TermName("+").encodedName).asMethod -> + ((lhs, rhs) => q"org.apache.spark.sql.functions.concat($lhs, $rhs)") + ) + + val subUnaryOperators: Map[MethodSymbol, Tree => Tree] = Map( + weakTypeOf[String].member(TermName("length")).asMethod -> + (lhs => q"org.apache.spark.sql.functions.length($lhs)"), + weakTypeOf[Array[_]].typeConstructor.member(TermName("length")).asMethod -> + (lhs => q"org.apache.spark.sql.functions.size($lhs)"), + weakTypeOf[Vector[_]].typeConstructor.member(TermName("length")).asMethod -> + (lhs => q"org.apache.spark.sql.functions.size($lhs)") + ) + + def unapply(tree: Tree): Option[Either[List[(Tree, String)], Expr]] = tree match { + + // A single column with no expression around it + case Rooted(strs) if tree.symbol.isMethod && tree.symbol.asMethod.isCaseAccessor => + Some(Right(Column(tree.tpe, strs.mkString(".")))) + + // A unary operator - the operator must exist on org.apache.spark.sql.Column + case sel @ Select(This(rhsE), op: TermName) => Some { + subUnaryOperators.get(sel.symbol.asMethod) + .map { + sub => rhsE.right.map(rhs => TrustMeBro(tree.tpe, sub(Expr.refold(rhs)))) + }.getOrElse { + if(isColumnOp(op, 0)) { + rhsE.right.map { + rhs => UnaryRAOperator(tree.tpe, rhs, op) + } + } else addError(rhsE)(tree, s"${op.decodedName} is not a valid column operator") + } + } + + // A unary operator (which is a no-arg method) - the operator must exist on org.apache.spark.sql.Column + case Apply(sel @ Select(This(rhsE), op: TermName), List()) => Some { + subUnaryOperators.get(sel.symbol.asMethod) + .map { + sub => rhsE.right.map(rhs => TrustMeBro(tree.tpe, sub(Expr.refold(rhs)))) + }.getOrElse { + + if(isColumnOp(op, 0)) { + // $COVERAGE-OFF$ - can't find a unary column op that isn't substituted and is a no-arg apply tree + rhsE.right.map { + rhs => UnaryRAOperator(tree.tpe, rhs, op) + } + // $COVERAGE-ON$ + } else addError(rhsE)(tree, s"${op.decodedName} is not a valid column operator") + } + } + + // A literal constant (would it be useful to distinguish this from non-constant literal? 
+ case Literal(_) => Some(Right(LiteralExpr(tree.tpe, tree))) + + // Constructing a case class + case Apply(sel @ Select(qualifier, TermName("apply")), AllExprs(argsE)) => Some { + if (isConstructor(sel, tree.tpe)) { + argsE.right.map { + args => + val params = sel.symbol.asMethod.paramLists.head + val names = sel.symbol.asMethod.paramLists.head.zip(args).map { + case (param, paramExpr) => param.name.toString -> paramExpr + }.toMap + Construct(tree.tpe, names) + } + } else addError(argsE)(tree, "Only constructor can be used here") + } + // Constructing a tuple or parameterized case class + case Apply(ta @ TypeApply(sel @ Select(qualifier, TermName("apply")), typArgs), AllExprs(argsE)) => Some { + if(isConstructor(ta, tree.tpe)) { + argsE.right.map { + args => + val params = sel.symbol.asMethod.paramLists.head + val names = sel.symbol.asMethod.paramLists.head.zip(args).map { + case (param, paramExpr) => param.name.toString -> paramExpr + }.toMap + Construct(tree.tpe, names) + } + } else addError(argsE)(tree, "Only constructor can be used here") + } + + // A binary operator - the operator must exist on org.apache.spark.sql.Column + case Apply(sel @ Select(This(lhsE), op: TermName), List(This(rhsE))) => Some { + subBinaryOperators.get(sel.symbol.asMethod).map { + sub => requireBoth(lhsE, rhsE)( + (lhs, rhs) => failFrom(lhs, rhs), + (lhs, rhs) => Right(TrustMeBro(tree.tpe, sub(Expr.refold(lhs), Expr.refold(rhs)))) + ) + }.getOrElse { + if (isColumnOp(op, 1)) { + // special case - substitute === for == + val realOp = if(op.decodedName.toString == "==") TermName("===").encodedName.toTermName else op + requireBoth(lhsE, rhsE)( + (lhs, rhs) => failFrom(lhs, rhs), + (lhs, rhs) => Right(BinaryOperator(tree.tpe, lhs, realOp, rhs)) + ) + } else addError(lhsE, rhsE)(tree, s"${op.decodedName} is not a valid column operator") + } + } + + // A function application + case Apply(fn @ (Select(_, _) | TypeApply(_, _)), argTrees) => Some { + // if the function is annotated with a QuotedFunc, it gets rewritten to the native spark function + fn.symbol.annotations.find(_.tree.tpe <:< weakTypeOf[QuotedFunc]).map { + annot => annot.tree match { + case Apply(Select(New(tpt), _), List(sparkFuncTree)) => + val Function(sparkFnArgs, body) = sparkFuncTree match { + case SingleExpression(f@Function(_, _)) => f + case Select(SingleExpression(f @ Function(_, _)), TermName("tupled")) => f + case other => + val o = other + println(o) + EmptyTree + } + + + val sparkArgs = if(sparkFnArgs.nonEmpty) { + sparkFnArgs.zipAll(argTrees.map(Some(_)), sparkFnArgs.last, None).flatMap { + case (_, None) => Nil + case (ValDef(_, _, typ, _), Some(This(Right(expr)))) if typ.tpe <:< weakTypeOf[sql.Column] => + List(Right(expr)) + case (ValDef(_, _, typ, _), Some(This(Right(expr)))) if typ.tpe <:< weakTypeOf[Seq[sql.Column]] => + List(Right(expr)) + case (ValDef(_, _, typ, _), Some(Tuple2Tree(This(Right(a)), This(Right(b))))) + if typ.tpe <:< weakTypeOf[Seq[(sql.Column, sql.Column)]] => + List(Right(a), Right(b)) + case (arg, Some(argTree)) => + List(Left(argTree)) + } + } else { + Nil + } + + val (src, term) = body match { + case Apply(Select(q, t), _) => (q, t) + } + + Right(FunctionApplication(tree.tpe, src, term.toTermName, sparkArgs)) + case other => + println(other) + Left(List((tree, "Function application not currently supported"))) + } + }.getOrElse { + Left(List((tree, "Function application not currently supported"))) + } + + } + //Some(argsE.right.map(args => FunctionApplication(tree.tpe, src, fn, args))) + + case Apply(This(fn), 
implicitArgs) => Some(fn) + + case Apply(_, _) => Some(Left(List((tree, "Function application not currently supported")))) + + case Typed(This(expr), _) => Some(expr) + + case _ => None + } + + private def isColumnOp(name: TermName, numArgs: Int): Boolean = { + val sym = weakTypeOf[org.apache.spark.sql.Column].member(name) + sym.isMethod && { + val bothEmptyArgs = numArgs == 0 && sym.asMethod.paramLists.isEmpty + val matchingArgCounts = sym.asMethod.paramLists.map(_.length) == List(numArgs) + bothEmptyArgs || matchingArgCounts + } + } + + def isConstructor(tree: Tree, result: Type, appliedTypes: Option[List[Type]] = None): Boolean = { + val (qualifier, typeArgs) = tree match { + case TypeApply(Select(q, _), args) => (q, args) + case Select(q, _) => (q, Nil) + } + + val MethodType(params, _) = tree.tpe + + val companion = Option(result.companion).filterNot(_ == NoType).getOrElse { + try { + c.typecheck(Ident(result.typeSymbol.name.toTermName)).tpe + } catch { + case TypecheckException(_, _) => NoType + } + } + + if(tree.symbol.isMethod && companion =:= qualifier.tpe) { + // must have same args as primary constructor + val meth = tree.symbol.asMethod + result.members.find(s => s.isConstructor && s.asMethod.isPrimaryConstructor) match { + case Some(defaultConstructor) => + defaultConstructor.asMethod.paramLists.head.zip(params).forall { + case (arg1, arg2) => + val sameArg = arg1.name == arg2.name + val arg1Typ = result.member(arg1.asTerm.name).typeSignatureIn(result).finalResultType + val sameType = arg1Typ =:= arg2.typeSignature + sameArg && sameType + } + case _ => false + } + } else false + } + + object AllExprs { + def unapply(trees: List[Tree]): Option[Either[List[(Tree, String)], List[Expr]]] = { + val eithersOpt = trees.map(This.unapply).foldRight[Option[List[Either[(Tree, String), Expr]]]](Some(Nil)) { + (nextOpt, accumOpt) => for { + next <- nextOpt + accum <- accumOpt + } yield next match { + case Left(errs) => errs.map(Left(_)) ::: accum + case Right(expr) => Right(expr) :: accum + } + } + + // wish cats was here + eithersOpt.map { + eithers => + eithers.foldRight[Either[List[(Tree, String)], List[Expr]]](Right(Nil)) { + (next, accum) => next.fold( + err => Left(err :: accum.left.toOption.getOrElse(Nil)), + expr => accum.right.map(expr :: _) + ) + } + } + } + } + } + + def addError[A]( + exprs: Either[List[(Tree, String)], A]*)( + tree: Tree, err: String + ): Either[List[(Tree, String)], Expr] = failFrom(exprs: _*).left.map(_ :+ (tree -> err)) + + def failFrom[A](exprs: Either[List[(Tree, String)], A]*): Left[List[(Tree, String)], Expr] = Left { + exprs.foldLeft[List[(Tree, String)]](Nil) { + (accum, next) => next.fold( + errs => accum ::: errs, + _ => accum + ) + } + } + + def requireBoth[A, B, C](first: Either[A, B], second: Either[A, B])( + oneLeft: (Either[A, B], Either[A, B]) => C, + bothRight: (B, B) => C + ): C = (first, second) match { + case (Right(f), Right(s)) => bothRight(f, s) + case _ => oneLeft(first, second) + } + + case class NameExtractor(name: TermName) { + private val This = this + def unapply(tree: Tree): Option[Queue[String]] = { + tree match { + case Ident(`name`) => Some(Queue.empty) + case Select(This(strs), nested) => Some(strs enqueue nested.toString) + case _ => None + } + } + } + + object ArgName { + def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name)) + } + + object Tuple2Tree { + def unapply(tree: Tree): Option[(Tree, Tree)] = tree match { + case q"($a, $b)" => Some((a, b)) + case 
q"$predef.ArrowAssoc[..$tptA]($a).->[..$tptB]($b)" => Some((a, b)) + case other => + None + } + } + + object SingleExpression { + def unapply(tree: Tree): Option[Tree] = tree match { + case Block(Nil, expr) => Some(expr) + case Block(_, _) => None + case other => Some(other) + } + } +} diff --git a/dataset/src/test/scala/frameless/ColTests.scala b/dataset/src/test/scala/frameless/ColTests.scala index ad62aa068..d6aa27191 100644 --- a/dataset/src/test/scala/frameless/ColTests.scala +++ b/dataset/src/test/scala/frameless/ColTests.scala @@ -29,6 +29,42 @@ class ColTests extends TypedDatasetSuite { () } + test("colApply") { + val x4 = TypedDataset.create[X4[Int, String, Long, Boolean]](Nil) + val t4 = TypedDataset.create[(Int, String, Long, Boolean)](Nil) + val x4x4 = TypedDataset.create[X4X4[Int, String, Long, Boolean]](Nil) + + x4(_.a) + t4(_._1) + x4[Int](_.a) + t4[Int](_._1) + + illTyped("x4[String](_.a)", "type mismatch;\n found : Int\n required: String") + + x4(_.b) + t4(_._2) + + x4[String](_.b) + t4[String](_._2) + + illTyped("x4[Int](_.b)", "type mismatch;\n found : String\n required: Int") + + x4x4(_.xa.a) + x4x4[Int](_.xa.a) + x4x4(_.xa.b) + x4x4[String](_.xa.b) + + x4x4(_.xb.a) + x4x4[Int](_.xb.a) + x4x4(_.xb.b) + x4x4[String](_.xb.b) + + illTyped("x4x4[String](_.xa.a)", "type mismatch;\n found : Int\n required: String") + illTyped("x4x4(item => item.xa.a * 20)", "Could not create a column identifier from item\\.xa\\.a\\.\\*\\(20\\) - try using _\\.a\\.b syntax") + + () + } + test("colMany") { type X2X2 = X2[X2[Int, String], X2[Long, Boolean]] val x2x2 = TypedDataset.create[X2X2](Nil) diff --git a/dataset/src/test/scala/frameless/CreateTests.scala b/dataset/src/test/scala/frameless/CreateTests.scala index c1ce719b0..d9088c913 100644 --- a/dataset/src/test/scala/frameless/CreateTests.scala +++ b/dataset/src/test/scala/frameless/CreateTests.scala @@ -1,6 +1,6 @@ package frameless -import org.scalacheck.Prop +import org.scalacheck.{Arbitrary, Prop} import org.scalacheck.Prop._ import org.scalatest.Matchers @@ -28,6 +28,66 @@ class CreateTests extends TypedDatasetSuite with Matchers { Vector[(Food, Country)]] _)) } + test("array fields") { + + def prop[T](implicit arb: Arbitrary[Array[T]], encoder: TypedEncoder[X1[Array[T]]]) = forAll { + data: Array[T] => + val Seq(X1(arr)) = TypedDataset.create(Seq(X1(data))).collect().run() + Prop(arr.sameElements(data)) + } + + check(prop[Boolean]) + check(prop[Byte]) + check(prop[Short]) + check(prop[Int]) + check(prop[Long]) + check(prop[Float]) + check(prop[Double]) + check(prop[X1[String]]) + check(prop[String]) + + } + + test("vector fields") { + + def prop[T](implicit arb: Arbitrary[Vector[T]], encoder: TypedEncoder[X1[Vector[T]]]) = forAll { + data: Vector[T] => + val Seq(X1(vec)) = TypedDataset.create(Seq(X1(data))).collect().run() + Prop(vec == data) + } + + check(prop[Boolean]) + check(prop[Byte]) + check(prop[Short]) + check(prop[Int]) + check(prop[Long]) + check(prop[Float]) + check(prop[Double]) + check(prop[X1[String]]) + check(prop[String]) + + } + + test("map fields") { + + def prop[A, B](implicit arb: Arbitrary[Map[A, B]], encoder: TypedEncoder[X1[Map[A, B]]]) = forAll { + data: Map[A, B] => + val Seq(X1(map)) = TypedDataset.create(Seq(X1(data))).collect().run() + Prop(map == data) + } + + check(prop[String, Boolean]) + check(prop[String, Byte]) + check(prop[String, Short]) + check(prop[String, Int]) + check(prop[String, Long]) + check(prop[String, Float]) + check(prop[String, Double]) + check(prop[String, X1[String]]) + 
check(prop[String, String]) + + } + test("not alligned columns should throw an exception") { val v = Vector(X2(1,2)) val df = TypedDataset.create(v).dataset.toDF() diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala index 837afed2d..81710c69a 100644 --- a/dataset/src/test/scala/frameless/FilterTests.scala +++ b/dataset/src/test/scala/frameless/FilterTests.scala @@ -22,7 +22,7 @@ class FilterTests extends TypedDatasetSuite { test("filter with arithmetic expressions: addition") { check(forAll { (data: Vector[X1[Int]]) => val ds = TypedDataset.create(data) - val res = ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector + val res = ds.filter((ds(_.a) + 1) === (ds(_.a) + 1)).collect().run().toVector res ?= data }) } @@ -31,7 +31,7 @@ class FilterTests extends TypedDatasetSuite { val t = X1(1) :: X1(2) :: X1(3) :: Nil val tds: TypedDataset[X1[Int]] = TypedDataset.create(t) - assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1))) - assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1))) + assert(tds.filter(tds(_.a) * 2 === 2).collect().run().toVector === Vector(X1(1))) + assert(tds.filter(tds(_.a) * 3 === 3).collect().run().toVector === Vector(X1(1))) } } diff --git a/dataset/src/test/scala/frameless/SelectExprTests.scala b/dataset/src/test/scala/frameless/SelectExprTests.scala new file mode 100644 index 000000000..af368c81b --- /dev/null +++ b/dataset/src/test/scala/frameless/SelectExprTests.scala @@ -0,0 +1,252 @@ +package frameless + +import org.scalatest.Matchers +import shapeless.test.illTyped + +case class Foo(a: Int, b: String) { + // tests failure of unsupported unary operator + def unary_~ : Foo = copy(a = a + 1) + + // tests failure of binary operator with different arg count (Column#when takes two args) + def when(a: Int): Int = this.a + a +} + +class SelectExprTests extends TypedDatasetSuite with Matchers { + + test("selectExpr with a single column") { + val ds = TypedDataset.create(Seq(X2(20, "twenty"), X2(30, "thirty"))) + val ds2 = ds.selectExpr(_.a) + val ds3 = ds.selectExpr(_.b) + ds2.collect().run() should contain theSameElementsAs Seq(20, 30) + ds3.collect().run() should contain theSameElementsAs Seq("twenty", "thirty") + } + + test("selectExpr with a single column and unary operation") { + val ds = TypedDataset.create(Seq(X1(20), X1(30))) + val ds2 = ds.selectExpr(-_.a) + ds2.collect().run() should contain theSameElementsAs Seq(-20, -30) + } + + test("selectExpr with a binary operation between two columns") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => x.a * x.b) + ds2.collect().run() should contain theSameElementsAs Seq(200, 600) + } + + test("selectExpr with a binary operation between a column and a literal") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => x.a * 10) + ds2.collect().run() should contain theSameElementsAs Seq(100, 200) + } + + test("selectExpr constructing a tuple") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => (x.a, x.b)) + ds2.collect().run() should contain theSameElementsAs Seq((10, 20), (20, 30)) + } + + test("selectExpr constructing a tuple with an operation") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => (x.a * 10, x.b + 1)) + ds2.collect().run() should contain theSameElementsAs Seq((100, 21), (200, 31)) + } + + test("selectExpr constructing a 
nested tuple with substitute function") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + val ds2 = ds.selectExpr(x => ((x.a * 1.1, x.c + "foo"), x.b / 2)) + ds2.collect().run() should contain theSameElementsAs Seq( + ((11.0, "foofoo"), 10.0), + ((22.0, "barfoo"), 15.0) + ) + + val ds3 = ds.selectExpr(x => (x.a, x.c.length)) + } + + test("substitute length() on Array and Vector") { + val ds = TypedDataset.create(Seq( + X2(Array(1, 2, 3), Vector(1, 2)) + )) + val ds2 = ds.selectExpr(x => (x.a.length, x.b.length)) + ds2.collect().run() should contain theSameElementsAs Seq((3, 2)) + } + + test("selectExpr constructing a case class") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + val ds2 = ds.selectExpr(x => X2(x.a, x.b)) + ds2.collect().run() should contain theSameElementsAs Seq( + X2(10, 20.0), + X2(20, 30.0) + ) + } + + test("selectExpr constructing an unparameterized case class") { + + val ds = TypedDataset.create(Seq(Foo(10, "ten"), Foo(20, "twenty"))) + val ds2 = ds.selectExpr(f => Foo(f.a * 5, f.b + "foo")) + ds2.collect().run() should contain theSameElementsAs Seq( + Foo(50, "tenfoo"), + Foo(100, "twentyfoo") + ) + + } + + test("selectExpr constructing a nested case class") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + val ds2 = ds.selectExpr(x => X2(X2(x.a, x.b), x.c)) + ds2.collect().run() should contain theSameElementsAs Seq( + X2(X2(10, 20.0), "foo"), + X2(X2(20, 30.0), "bar") + ) + } + + test("selectExpr accessing a nested case class field") { + val ds = TypedDataset.create(Seq( + X2(X2(11.0, "foofoo"), 10.0), + X2(X2(22.0, "barfoo"), 15.0) + )) + + val ds2 = ds.selectExpr(x => (x.a.a * x.b, x.a.b)) + + ds2.collect().run() should contain theSameElementsAs Seq( + (110.0, "foofoo"), + (330.0, "barfoo") + ) + } + + test("can't use a non-constructor apply method") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + object Foo { + def apply(str: String, i: Int): String = str + i.toString + } + + object FooP { + def apply[A, B](a: A, b: B): X2[A, B] = X2(a, b) + } + + illTyped( + "val ds2 = ds.selectExpr(x => Foo(x.c, x.a))", + ".*Only constructor can be used here" + ) + + illTyped( + "val ds2 = ds.selectExpr(x => FooP(x.c, x.a))", + ".*Only constructor can be used here" + ) + + } + + test("quoted functions expanded to spark native column functions") { + import frameless.functions.quoted._ + val ds = TypedDataset.create(Seq(X4(10, 20.0, "foo", 10L), X4(-20, 30.0, "bar", 20L))) + + def test[U : TypedEncoder](expected: U*)(fn: TypedDataset[X4[Int, Double, String, Long]] => TypedDataset[U]) = + fn(ds).collect().run() should contain theSameElementsAs expected + + def testConv[T, U : TypedEncoder](expected: T*)(fn: TypedDataset[X4[Int, Double, String, Long]] => TypedDataset[U])(conv: U => T) = + fn(ds).collect().run().map(conv) should contain theSameElementsAs expected + + test(10, 20)(_.selectExpr(x => abs(x.a))) + testConv(List(10, 10, 10), List(-20, -20, -20))(_.selectExpr(x => array(x.a, x.a, x.a)))(_.toList) + test(Map("foo" -> 10), Map("bar" -> -20))(_.selectExpr(x => map(x.c -> x.a))) + test("foo", "bar")(_.selectExpr(x => coalesce(null, x.c))) + test(("foo", ""), ("bar", ""))(_.selectExpr(x => (x.c, input_file_name()))) + test((false, true), (false, true))(_.selectExpr(x => (isnan(x.b), isnan(Double.NaN)))) + test((false, true), (false, true))(_.selectExpr(x => (isnull(x.c), isnull(null: String)))) + test(0L, 1L)(_.selectExpr(x => 
monotonically_increasing_id())) + test(20.0, 30.0)(_.selectExpr(x => nanvl(Double.NaN, x.b))) + test(-10, 20)(_.selectExpr(x => negate(x.a))) + + // no way to compare output, but test to make sure the functions can be executed + ds.selectExpr(x => (x.a, rand(10L), rand(), randn(10L), randn())).collect().run() + + test(math.sqrt(20.0), math.sqrt(30.0))(_.selectExpr(x => sqrt(x.b))) + test("ten", "not ten")(_.selectExpr(x => when(x.a == 10, "ten") otherwise "not ten")) + test(~10, ~(-20))(_.selectExpr(x => bitwiseNOT(x.a))) + test( + (math.acos(20.0 / 32.0), math.asin(20.0 / 32.0), math.atan(20.0), math.atan2(20.0, 20.0)), + (math.acos(30.0 / 32.0), math.asin(30.0 / 32.0), math.atan(30.0), math.atan2(30.0, 30.0)) + )(_.selectExpr(x => (acos(x.b / 32), asin(x.b / 32), atan(x.b), atan2(x.b, x.b)))) + test(java.lang.Long.toBinaryString(10L), java.lang.Long.toBinaryString(20L))(_.selectExpr(x => bin(x.d))) + + // aggregate functions + test(2L)(_.selectExpr(x => count(x.a))) + test(2L)(_.selectExpr(x => countDistinct(x.a))) + test(2L)(_.selectExpr(x => countDistinct(x.a, x.b))) + + illTyped("val a: Int = 22.22") + } + + //TODO: could we just UDF the function in this case? + test("arbitrary functions not yet supported") { + def strfun(s: String) = s"fun${s}fun" + + object Fun { + def strfun(s: String) = s"Funfun${s}Funfun" + } + + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.a, strfun(x.c)))", + ".*Function application not currently supported" + ) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.a, Fun.strfun(x.c)))", + ".*Function application not currently supported" + ) + } + + test("fails if operator is not available on Column") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => x.c substring x.a)", + ".*substring is not a valid column operator" + ) + + val ds3 = TypedDataset.create(Seq(X2(Foo(10, "ten"), "foo"), X2(Foo(20, "twenty"), "bar"))) + + illTyped( + "val ds4 = ds3.selectExpr(x => ~x.a)", + ".*~ is not a valid column operator" + ) + + illTyped( + """val ds5 = ds3.selectExpr(x => x.a when 5)""", + ".*when is not a valid column operator" + // would be nice if this error was clearer about the arity mismatch being the cause + ) + } + + test("multiple errors are all shown before aborting") { + def strfun(s: String) = s"fun${s}fun" + + object Fun { + def strfun(s: String) = s"Funfun${s}Funfun" + } + + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.a, (strfun(x.c), Fun.strfun(x.c))))", + ".*Function application not currently supported" + // this is a limitation of illTyped - it only looks at the first emitted error, rather than the error that aborted + ) + } + + test("errors emitted from nested expressions in allowed binary operator") { + // tests failFrom case in binary operator case + + def intfun(i: Int): Int = i + 1 + + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.c, x.a + intfun(x.a)))", + ".*Function application not currently supported" + ) + + } +} diff --git a/dataset/src/test/scala/frameless/SelectTests.scala b/dataset/src/test/scala/frameless/SelectTests.scala index cd3d6c411..99cf86d01 100644 --- a/dataset/src/test/scala/frameless/SelectTests.scala +++ b/dataset/src/test/scala/frameless/SelectTests.scala @@ -297,7 +297,7 @@ class SelectTests extends TypedDatasetSuite { ): 
Prop = { val ds = TypedDataset.create(data) - val dataset2 = ds.select(ds('a) + const).collect().run().toVector + val dataset2 = ds.select(ds(_.a) + const).collect().run().toVector val data2 = data.map { case X1(a) => num.plus(a, const) } dataset2 ?= data2 @@ -319,7 +319,7 @@ class SelectTests extends TypedDatasetSuite { ): Prop = { val ds = TypedDataset.create(data) - val dataset2 = ds.select(ds('a) * const).collect().run().toVector + val dataset2 = ds.select(ds(_.a) * const).collect().run().toVector val data2 = data.map { case X1(a) => num.times(a, const) } dataset2 ?= data2 @@ -341,7 +341,7 @@ class SelectTests extends TypedDatasetSuite { ): Prop = { val ds = TypedDataset.create(data) - val dataset2 = ds.select(ds('a) - const).collect().run().toVector + val dataset2 = ds.select(ds(_.a) - const).collect().run().toVector val data2 = data.map { case X1(a) => num.minus(a, const) } dataset2 ?= data2 @@ -363,7 +363,7 @@ class SelectTests extends TypedDatasetSuite { val ds = TypedDataset.create(data) if (const != 0) { - val dataset2 = ds.select(ds('a) / const).collect().run().toVector.asInstanceOf[Vector[A]] + val dataset2 = ds.select(ds(_.a) / const).collect().run().toVector.asInstanceOf[Vector[A]] val data2 = data.map { case X1(a) => frac.div(a, const) } dataset2 ?= data2 } else 0 ?= 0 @@ -379,17 +379,17 @@ class SelectTests extends TypedDatasetSuite { assert(t.select(t.col('_1)).collect().run().toList === List(2)) // Issue #54 val fooT = t.select(t.col('_1)).map(x => Tuple1.apply(x)).as[Foo] - assert(fooT.select(fooT('i)).collect().run().toList === List(2)) + assert(fooT.select(fooT(_.i)).collect().run().toList === List(2)) } test("unary - on arithmetic") { val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil) - assert(e.select(-e('_1)).collect().run().toVector === Vector(-1, -2, -2)) - assert(e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(-3L, -6L, -3L)) + assert(e.select(-e(_._1)).collect().run().toVector === Vector(-1, -2, -2)) + assert(e.select(-(e(_._1) + e(_._3))).collect().run().toVector === Vector(-3L, -6L, -3L)) } test("unary - on strings should not type check") { val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil) - illTyped("""e.select( -e('_2) )""") + illTyped("""e.select( -e(_._2) )""") } } \ No newline at end of file diff --git a/dataset/src/test/scala/frameless/XN.scala b/dataset/src/test/scala/frameless/XN.scala index 4fdab552e..0b19771e2 100644 --- a/dataset/src/test/scala/frameless/XN.scala +++ b/dataset/src/test/scala/frameless/XN.scala @@ -64,3 +64,5 @@ object X5 { implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] = Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e)) } + +case class X4X4[A, B, C, D](xa: X4[A, B, C, D], xb: X4[A, B, C, D]) diff --git a/docs/src/main/tut/GettingStarted.md b/docs/src/main/tut/GettingStarted.md index dd29ae587..03d037196 100644 --- a/docs/src/main/tut/GettingStarted.md +++ b/docs/src/main/tut/GettingStarted.md @@ -63,13 +63,13 @@ val apartmentsTypedDS2 = spark.createDataset(apartments).typed This is how we select a particular column from a `TypedDataset`: ```tut:book -val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS('city)) +val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS(_.city)) ``` This is completely safe, for instance suppose we misspell `city`: ```tut:book:fail 
-apartmentsTypedDS.select(apartmentsTypedDS('citi))
+apartmentsTypedDS.select(apartmentsTypedDS(_.citi))
 ```
 
 This gets caught at compile-time, whereas with traditional Spark `Dataset` the error appears at run-time.
@@ -81,7 +81,7 @@ apartmentsDS.select('citi)
 `select()` supports arbitrary column operations:
 
 ```tut:book
-apartmentsTypedDS.select(apartmentsTypedDS('surface) * 10, apartmentsTypedDS('surface) + 2).show().run()
+apartmentsTypedDS.select(apartmentsTypedDS(_.surface) * 10, apartmentsTypedDS(_.surface) + 2).show().run()
 ```
 
 *Note that unlike the standard Spark api, here `show()` is lazy. It requires to apply `run()` for the
@@ -91,14 +91,14 @@ apartmentsTypedDS.select(apartmentsTypedDS('surface) * 10, apartmentsTypedDS('su
 Let us now try to compute the price by surface unit:
 
 ```tut:book:fail
-val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface)) ^
+val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface)) ^
 ```
 
 Argh! Looks like we can't divide a `TypedColumn` of `Double` by `Int`. Well, we can cast our `Int`s to `Double`s explicitly to proceed with the computation.
 
 ```tut:book
-val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface).cast[Double])
+val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface).cast[Double])
 priceBySurfaceUnit.collect().run()
 ```
@@ -107,7 +107,7 @@ Alternatively, we can perform the cast implicitly:
 
 ```tut:book
 import frameless.implicits.widen._
-val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface))
+val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface))
 priceBySurfaceUnit.collect.run()
 ```
@@ -115,7 +115,7 @@ Looks like it worked, but that `cast` looks unsafe right? Actually it is safe.
 Let's try to cast a `TypedColumn` of `String` to `Double`:
 
 ```tut:book:fail
-apartmentsTypedDS('city).cast[Double]
+apartmentsTypedDS(_.city).cast[Double]
 ```
 
 The compile-time error tells us that to perform the cast, an evidence (in the form of `CatalystCast[String, Double]`) must be available.
@@ -136,7 +136,7 @@ The cast is valid and the expression compiles:
 
 ```tut:book
 case class UpdatedSurface(city: String, surface: Int)
-val updated = apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('surface) + 2).as[UpdatedSurface]
+val updated = apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.surface) + 2).as[UpdatedSurface]
 updated.show(2).run()
 ```
@@ -144,7 +144,7 @@ Next we try to cast a `(String, String)` to an `UpdatedSurface` (which has types
 The cast is not valid and the expression does not compile:
 
 ```tut:book:fail
-apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('city)).as[UpdatedSurface]
+apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.city)).as[UpdatedSurface]
 ```
 
 ### Projections
@@ -161,7 +161,7 @@ import frameless.implicits.widen._
 val aptds = apartmentsTypedDS // For shorter expressions
 
 case class ApartmentDetails(city: String, price: Double, surface: Int, ratio: Double)
-val aptWithRatio = aptds.select(aptds('city), aptds('price), aptds('surface), aptds('price) / aptds('surface)).as[ApartmentDetails]
+val aptWithRatio = aptds.select(aptds(_.city), aptds(_.price), aptds(_.surface), aptds(_.price) / aptds(_.surface)).as[ApartmentDetails]
 ```
 
 Suppose we only want to work with `city` and `ratio`:
@@ -222,7 +222,7 @@ val udf = apartmentsTypedDS.makeUDF(priceModifier)
 
 val aptds = apartmentsTypedDS // For shorter expressions
 
-val adjustedPrice = aptds.select(aptds('city), udf(aptds('city), aptds('price)))
+val adjustedPrice = aptds.select(aptds(_.city), udf(aptds(_.city), aptds(_.price)))
 adjustedPrice.show().run()
 ```
@@ -230,12 +230,12 @@ adjustedPrice.show().run()
 ## GroupBy and Aggregations
 Let's suppose we wanted to retrieve the average apartment price in each city
 ```tut:book
-val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('price)))
+val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.price)))
 priceByCity.collect().run()
 ```
 Again if we try to aggregate a column that can't be aggregated, we get a compilation error
 ```tut:book:fail
-apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('city))) ^
+apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.city))) ^
 ```
 
 Next, we combine `select` and `groupBy` to calculate the average price/surface ratio per city:
@@ -243,9 +243,9 @@ Next, we combine `select` and `groupBy` to calculate the average price/surface r
 ```tut:book
 val aptds = apartmentsTypedDS // For shorter expressions
 
-val cityPriceRatio = aptds.select(aptds('city), aptds('price) / aptds('surface))
+val cityPriceRatio = aptds.select(aptds(_.city), aptds(_.price) / aptds(_.surface))
 
-cityPriceRatio.groupBy(cityPriceRatio('_1)).agg(avg(cityPriceRatio('_2))).show().run()
+cityPriceRatio.groupBy(cityPriceRatio(_._1)).agg(avg(cityPriceRatio(_._2))).show().run()
 ```
 
 ## Joins
@@ -265,7 +265,7 @@ val citiInfoTypedDS = TypedDataset.create(cityInfo)
 
 Here is how to join the population information to the apartment's dataset.
 
 ```tut:book
-val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS('city), citiInfoTypedDS('name))
+val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS(_.city), citiInfoTypedDS(_.name))
 withCityInfo.show().run()
 ```
@@ -278,7 +278,7 @@ We can then select which information we want to continue to work with:
 case class AptPriceCity(city: String, aptPrice: Double, cityPopulation: Int)
 
 withCityInfo.select(
-   withCityInfo.colMany('_2, 'name), withCityInfo.colMany('_1, 'price), withCityInfo.colMany('_2, 'population)
+   withCityInfo(_._2.name), withCityInfo(_._1.price), withCityInfo(_._2.population)
 ).as[AptPriceCity].show().run
 ```
diff --git a/docs/src/main/tut/TypedDatasetVsSparkDataset.md b/docs/src/main/tut/TypedDatasetVsSparkDataset.md
index 5b6775d39..5a85731ff 100644
--- a/docs/src/main/tut/TypedDatasetVsSparkDataset.md
+++ b/docs/src/main/tut/TypedDatasetVsSparkDataset.md
@@ -116,19 +116,19 @@ with a fully optimized query plan.
 import frameless.TypedDataset
 val fds = TypedDataset.create(ds)
-fds.filter( fds('i) === 10 ).select( fds('i) ).show().run()
+fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).show().run()
 ```
 
 And the optimized Physical Plan:
 
 ```tut:book
-fds.filter( fds('i) === 10 ).select( fds('i) ).explain()
+fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).explain()
 ```
 
 And the compiler is our friend.
 
 ```tut:fail
-fds.filter( fds('i) === 10 ).select( fds('x) )
+fds.filter( fds(_.i) === 10 ).select( fds(_.x) )
 ```
 
 ```tut:invisible