From 815b7a2f5c7c87e62abb638752b6a918a2f00b35 Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Fri, 17 Feb 2017 18:39:00 -0800 Subject: [PATCH 01/10] Change TypeDataset#apply syntax to use a function Before/After: ```scala case class Foo(a: String, b: Int) case class Bar(foo: Foo, c: Double) val ds = TypedDataset.create(spark.createDataset(Seq.empty[Bar])) // After: val a = ds.select(ds(_.foo.a)) val b = ds.select(ds(_.foo.b)) val c = ds.select(ds(_.c)) // Before: val a = ??? // no equivalent val b = ??? // no equivalent val c = ds.select(ds('c)) ``` I am proposing this change because: 1. It typechecks before any macros are invoked 2. It makes editors like IntelliJ happy (because of #1) 3. It seems more idiomatic to me than using `Symbol`s 4. It allows specifying columns of nested structures, which is impossible with the current syntax The downside is that a macro is used. But, a macro is used (indirectly, through Witness.apply) for the existing syntax anyway. Also, that syntax uses implicit conversions, which the proposed syntax doesn't. The macro introduced for the new syntax is reasonably uncomplicated. I changed `apply` and left `col` intact, so that both syntaxes would be available. If this change is distasteful (understandable due to the BC break), it could be renamed to something other than `apply`. However, I really think this syntax is strictly more powerful than what's currently used for `apply`, and should be the default behavior. --- .../main/scala/frameless/TypedColumn.scala | 3 +- .../main/scala/frameless/TypedDataset.scala | 6 +- .../scala/frameless/column/ColumnMacros.scala | 62 +++++++++++++++++++ .../src/test/scala/frameless/ColTests.scala | 35 +++++++++++ .../test/scala/frameless/FilterTests.scala | 6 +- .../test/scala/frameless/SelectTests.scala | 16 ++--- dataset/src/test/scala/frameless/XN.scala | 2 + 7 files changed, 114 insertions(+), 16 deletions(-) create mode 100644 dataset/src/main/scala/frameless/column/ColumnMacros.scala diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index 92719d482..b17d2eafe 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -271,6 +271,7 @@ object TypedColumn { lgen: LabelledGeneric.Aux[T, H], selector: Selector.Aux[H, K, V] ): Exists[T, K, V] = new Exists[T, K, V] {} + } implicit class OrderedTypedColumnSyntax[T, U: CatalystOrdered](col: TypedColumn[T, U]) { @@ -279,4 +280,4 @@ object TypedColumn { def >(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped > other.untyped).typed def >=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped >= other.untyped).typed } -} +} \ No newline at end of file diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index ad9913131..056ae0349 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -62,11 +62,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val * * It is statically checked that column with such name exists and has type `A`. 
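  * A usage sketch of the new selector syntax (the `Foo`/`Bar` types here are the
  * hypothetical ones from the commit message, not part of this file):
  * {{{
  *   case class Foo(a: String, b: Int)
  *   case class Bar(foo: Foo, c: Double)
  *   val ds = TypedDataset.create(Seq.empty[Bar])
  *   val a: TypedColumn[Bar, String] = ds(_.foo.a)
  *   val c: TypedColumn[Bar, Double] = ds(_.c)
  * }}}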
*/ - def apply[A](column: Witness.Lt[Symbol])( - implicit - exists: TypedColumn.Exists[T, column.T, A], + def apply[A](selector: T => A)(implicit encoder: TypedEncoder[A] - ): TypedColumn[T, A] = col(column) + ): TypedColumn[T, A] = macro frameless.column.ColumnMacros.fromFunction[T, A] /** Returns `TypedColumn` of type `A` given it's name. * diff --git a/dataset/src/main/scala/frameless/column/ColumnMacros.scala b/dataset/src/main/scala/frameless/column/ColumnMacros.scala new file mode 100644 index 000000000..f16b3e539 --- /dev/null +++ b/dataset/src/main/scala/frameless/column/ColumnMacros.scala @@ -0,0 +1,62 @@ +package frameless.column + +import frameless.{TypedColumn, TypedEncoder, TypedExpressionEncoder} + +import scala.collection.immutable.Queue +import scala.reflect.macros.whitebox + +class ColumnMacros(val c: whitebox.Context) { + import c.universe._ + + // could be used to reintroduce apply('foo) + def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + val B = weakTypeOf[B].dealias + val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})") + c.typecheck(q"${c.prefix}.col[$B]($witness)") + } + + def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + def fail(tree: Tree) = c.abort( + tree.pos, + s"Could not create a column identifier from $tree - try using _.a.b syntax") + + val A = weakTypeOf[A].dealias + val B = weakTypeOf[B].dealias + + val selectorStr = selector.tree match { + case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => body match { + case `argName`(strs) => strs.mkString(".") + case other => fail(other) + } + case other => fail(other) + } + + val typedCol = appliedType( + weakTypeOf[TypedColumn[_, _]].typeConstructor, A, B + ) + + val TEEObj = reify(TypedExpressionEncoder) + + val datasetCol = c.typecheck( + q"${c.prefix}.dataset.col($selectorStr).as[$B]($TEEObj.apply[$B]($encoder))" + ) + + c.typecheck(q"new $typedCol($datasetCol)") + } + + case class NameExtractor(name: TermName) { + private val This = this + def unapply(tree: Tree): Option[Queue[String]] = { + tree match { + case Ident(`name`) => Some(Queue.empty) + case Select(This(strs), nested) => Some(strs enqueue nested.toString) + case Apply(This(strs), List()) => Some(strs) + case _ => None + } + } + } + + object ArgName { + def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name)) + } +} diff --git a/dataset/src/test/scala/frameless/ColTests.scala b/dataset/src/test/scala/frameless/ColTests.scala index ad62aa068..10b2e1bbb 100644 --- a/dataset/src/test/scala/frameless/ColTests.scala +++ b/dataset/src/test/scala/frameless/ColTests.scala @@ -29,6 +29,41 @@ class ColTests extends TypedDatasetSuite { () } + test("colApply") { + val x4 = TypedDataset.create[X4[Int, String, Long, Boolean]](Nil) + val t4 = TypedDataset.create[(Int, String, Long, Boolean)](Nil) + val x4x4 = TypedDataset.create[X4X4[Int, String, Long, Boolean]](Nil) + + x4(_.a) + t4(_._1) + x4[Int](_.a) + t4[Int](_._1) + + illTyped("x4[String](_.a)", "type mismatch;\n found : Int\n required: String") + + x4(_.b) + t4(_._2) + + x4[String](_.b) + t4[String](_._2) + + illTyped("x4[Int](_.b)", "type mismatch;\n found : String\n required: Int") + + x4x4(_.xa.a) + x4x4[Int](_.xa.a) + x4x4(_.xa.b) + x4x4[String](_.xa.b) + + x4x4(_.xb.a) + x4x4[Int](_.xb.a) + x4x4(_.xb.b) + x4x4[String](_.xb.b) + + illTyped("x4x4[String](_.xa.a)", "type mismatch;\n found : Int\n 
required: String") + + () + } + test("colMany") { type X2X2 = X2[X2[Int, String], X2[Long, Boolean]] val x2x2 = TypedDataset.create[X2X2](Nil) diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala index 837afed2d..81710c69a 100644 --- a/dataset/src/test/scala/frameless/FilterTests.scala +++ b/dataset/src/test/scala/frameless/FilterTests.scala @@ -22,7 +22,7 @@ class FilterTests extends TypedDatasetSuite { test("filter with arithmetic expressions: addition") { check(forAll { (data: Vector[X1[Int]]) => val ds = TypedDataset.create(data) - val res = ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector + val res = ds.filter((ds(_.a) + 1) === (ds(_.a) + 1)).collect().run().toVector res ?= data }) } @@ -31,7 +31,7 @@ class FilterTests extends TypedDatasetSuite { val t = X1(1) :: X1(2) :: X1(3) :: Nil val tds: TypedDataset[X1[Int]] = TypedDataset.create(t) - assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1))) - assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1))) + assert(tds.filter(tds(_.a) * 2 === 2).collect().run().toVector === Vector(X1(1))) + assert(tds.filter(tds(_.a) * 3 === 3).collect().run().toVector === Vector(X1(1))) } } diff --git a/dataset/src/test/scala/frameless/SelectTests.scala b/dataset/src/test/scala/frameless/SelectTests.scala index cd3d6c411..99cf86d01 100644 --- a/dataset/src/test/scala/frameless/SelectTests.scala +++ b/dataset/src/test/scala/frameless/SelectTests.scala @@ -297,7 +297,7 @@ class SelectTests extends TypedDatasetSuite { ): Prop = { val ds = TypedDataset.create(data) - val dataset2 = ds.select(ds('a) + const).collect().run().toVector + val dataset2 = ds.select(ds(_.a) + const).collect().run().toVector val data2 = data.map { case X1(a) => num.plus(a, const) } dataset2 ?= data2 @@ -319,7 +319,7 @@ class SelectTests extends TypedDatasetSuite { ): Prop = { val ds = TypedDataset.create(data) - val dataset2 = ds.select(ds('a) * const).collect().run().toVector + val dataset2 = ds.select(ds(_.a) * const).collect().run().toVector val data2 = data.map { case X1(a) => num.times(a, const) } dataset2 ?= data2 @@ -341,7 +341,7 @@ class SelectTests extends TypedDatasetSuite { ): Prop = { val ds = TypedDataset.create(data) - val dataset2 = ds.select(ds('a) - const).collect().run().toVector + val dataset2 = ds.select(ds(_.a) - const).collect().run().toVector val data2 = data.map { case X1(a) => num.minus(a, const) } dataset2 ?= data2 @@ -363,7 +363,7 @@ class SelectTests extends TypedDatasetSuite { val ds = TypedDataset.create(data) if (const != 0) { - val dataset2 = ds.select(ds('a) / const).collect().run().toVector.asInstanceOf[Vector[A]] + val dataset2 = ds.select(ds(_.a) / const).collect().run().toVector.asInstanceOf[Vector[A]] val data2 = data.map { case X1(a) => frac.div(a, const) } dataset2 ?= data2 } else 0 ?= 0 @@ -379,17 +379,17 @@ class SelectTests extends TypedDatasetSuite { assert(t.select(t.col('_1)).collect().run().toList === List(2)) // Issue #54 val fooT = t.select(t.col('_1)).map(x => Tuple1.apply(x)).as[Foo] - assert(fooT.select(fooT('i)).collect().run().toList === List(2)) + assert(fooT.select(fooT(_.i)).collect().run().toList === List(2)) } test("unary - on arithmetic") { val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil) - assert(e.select(-e('_1)).collect().run().toVector === Vector(-1, -2, -2)) - assert(e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(-3L, 
-6L, -3L)) + assert(e.select(-e(_._1)).collect().run().toVector === Vector(-1, -2, -2)) + assert(e.select(-(e(_._1) + e(_._3))).collect().run().toVector === Vector(-3L, -6L, -3L)) } test("unary - on strings should not type check") { val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil) - illTyped("""e.select( -e('_2) )""") + illTyped("""e.select( -e(_._2) )""") } } \ No newline at end of file diff --git a/dataset/src/test/scala/frameless/XN.scala b/dataset/src/test/scala/frameless/XN.scala index 4fdab552e..0b19771e2 100644 --- a/dataset/src/test/scala/frameless/XN.scala +++ b/dataset/src/test/scala/frameless/XN.scala @@ -64,3 +64,5 @@ object X5 { implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] = Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e)) } + +case class X4X4[A, B, C, D](xa: X4[A, B, C, D], xb: X4[A, B, C, D]) From 82f1ba7a3ef0a3f7dc24b5932e43cf90e0676e5e Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Sat, 18 Feb 2017 11:46:00 -0800 Subject: [PATCH 02/10] Update tuts for new syntax --- docs/src/main/tut/GettingStarted.md | 34 +++++++++---------- .../main/tut/TypedDatasetVsSparkDataset.md | 6 ++-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/src/main/tut/GettingStarted.md b/docs/src/main/tut/GettingStarted.md index dd29ae587..03d037196 100644 --- a/docs/src/main/tut/GettingStarted.md +++ b/docs/src/main/tut/GettingStarted.md @@ -63,13 +63,13 @@ val apartmentsTypedDS2 = spark.createDataset(apartments).typed This is how we select a particular column from a `TypedDataset`: ```tut:book -val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS('city)) +val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS(_.city)) ``` This is completely safe, for instance suppose we misspell `city`: ```tut:book:fail -apartmentsTypedDS.select(apartmentsTypedDS('citi)) +apartmentsTypedDS.select(apartmentsTypedDS(_.citi)) ``` This gets caught at compile-time, whereas with traditional Spark `Dataset` the error appears at run-time. @@ -81,7 +81,7 @@ apartmentsDS.select('citi) `select()` supports arbitrary column operations: ```tut:book -apartmentsTypedDS.select(apartmentsTypedDS('surface) * 10, apartmentsTypedDS('surface) + 2).show().run() +apartmentsTypedDS.select(apartmentsTypedDS(_.surface) * 10, apartmentsTypedDS(_.surface) + 2).show().run() ``` *Note that unlike the standard Spark api, here `show()` is lazy. It requires to apply `run()` for the @@ -91,14 +91,14 @@ apartmentsTypedDS.select(apartmentsTypedDS('surface) * 10, apartmentsTypedDS('su Let us now try to compute the price by surface unit: ```tut:book:fail -val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface)) ^ +val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface)) ^ ``` Argh! Looks like we can't divide a `TypedColumn` of `Double` by `Int`. Well, we can cast our `Int`s to `Double`s explicitly to proceed with the computation. 
```tut:book -val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface).cast[Double]) +val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface).cast[Double]) priceBySurfaceUnit.collect().run() ``` @@ -107,7 +107,7 @@ Alternatively, we can perform the cast implicitly: ```tut:book import frameless.implicits.widen._ -val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface)) +val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface)) priceBySurfaceUnit.collect.run() ``` @@ -115,7 +115,7 @@ Looks like it worked, but that `cast` looks unsafe right? Actually it is safe. Let's try to cast a `TypedColumn` of `String` to `Double`: ```tut:book:fail -apartmentsTypedDS('city).cast[Double] +apartmentsTypedDS(_.city).cast[Double] ``` The compile-time error tells us that to perform the cast, an evidence (in the form of `CatalystCast[String, Double]`) must be available. @@ -136,7 +136,7 @@ The cast is valid and the expression compiles: ```tut:book case class UpdatedSurface(city: String, surface: Int) -val updated = apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('surface) + 2).as[UpdatedSurface] +val updated = apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.surface) + 2).as[UpdatedSurface] updated.show(2).run() ``` @@ -144,7 +144,7 @@ Next we try to cast a `(String, String)` to an `UpdatedSurface` (which has types The cast is not valid and the expression does not compile: ```tut:book:fail -apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('city)).as[UpdatedSurface] +apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.city)).as[UpdatedSurface] ``` ### Projections @@ -161,7 +161,7 @@ import frameless.implicits.widen._ val aptds = apartmentsTypedDS // For shorter expressions case class ApartmentDetails(city: String, price: Double, surface: Int, ratio: Double) -val aptWithRatio = aptds.select(aptds('city), aptds('price), aptds('surface), aptds('price) / aptds('surface)).as[ApartmentDetails] +val aptWithRatio = aptds.select(aptds(_.city), aptds(_.price), aptds(_.surface), aptds(_.price) / aptds(_.surface)).as[ApartmentDetails] ``` Suppose we only want to work with `city` and `ratio`: @@ -222,7 +222,7 @@ val udf = apartmentsTypedDS.makeUDF(priceModifier) val aptds = apartmentsTypedDS // For shorter expressions -val adjustedPrice = aptds.select(aptds('city), udf(aptds('city), aptds('price))) +val adjustedPrice = aptds.select(aptds(_.city), udf(aptds(_.city), aptds(_.price))) adjustedPrice.show().run() ``` @@ -230,12 +230,12 @@ adjustedPrice.show().run() ## GroupBy and Aggregations Let's suppose we wanted to retrieve the average apartment price in each city ```tut:book -val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('price))) +val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.price))) priceByCity.collect().run() ``` Again if we try to aggregate a column that can't be aggregated, we get a compilation error ```tut:book:fail -apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('city))) ^ +apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.city))) ^ ``` Next, we combine `select` and `groupBy` to calculate the average price/surface ratio per city: @@ -243,9 +243,9 @@ Next, we combine `select` and 
`groupBy` to calculate the average price/surface r ```tut:book val aptds = apartmentsTypedDS // For shorter expressions -val cityPriceRatio = aptds.select(aptds('city), aptds('price) / aptds('surface)) +val cityPriceRatio = aptds.select(aptds(_.city), aptds(_.price) / aptds(_.surface)) -cityPriceRatio.groupBy(cityPriceRatio('_1)).agg(avg(cityPriceRatio('_2))).show().run() +cityPriceRatio.groupBy(cityPriceRatio(_._1)).agg(avg(cityPriceRatio(_._2))).show().run() ``` ## Joins @@ -265,7 +265,7 @@ val citiInfoTypedDS = TypedDataset.create(cityInfo) Here is how to join the population information to the apartment's dataset. ```tut:book -val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS('city), citiInfoTypedDS('name)) +val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS(_.city), citiInfoTypedDS(_.name)) withCityInfo.show().run() ``` @@ -278,7 +278,7 @@ We can then select which information we want to continue to work with: case class AptPriceCity(city: String, aptPrice: Double, cityPopulation: Int) withCityInfo.select( - withCityInfo.colMany('_2, 'name), withCityInfo.colMany('_1, 'price), withCityInfo.colMany('_2, 'population) + withCityInfo(_._2.name), withCityInfo(_._1.price), withCityInfo(_._2.population) ).as[AptPriceCity].show().run ``` diff --git a/docs/src/main/tut/TypedDatasetVsSparkDataset.md b/docs/src/main/tut/TypedDatasetVsSparkDataset.md index 5b6775d39..5a85731ff 100644 --- a/docs/src/main/tut/TypedDatasetVsSparkDataset.md +++ b/docs/src/main/tut/TypedDatasetVsSparkDataset.md @@ -116,19 +116,19 @@ with a fully optimized query plan. import frameless.TypedDataset val fds = TypedDataset.create(ds) -fds.filter( fds('i) === 10 ).select( fds('i) ).show().run() +fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).show().run() ``` And the optimized Physical Plan: ```tut:book -fds.filter( fds('i) === 10 ).select( fds('i) ).explain() +fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).explain() ``` And the compiler is our friend. 
```tut:fail -fds.filter( fds('i) === 10 ).select( fds('x) ) +fds.filter( fds(_.i) === 10 ).select( fds(_.x) ) ``` ```tut:invisible From fa0ce2f6b3528cf737379041e43da7db595fef8c Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Sat, 18 Feb 2017 12:07:20 -0800 Subject: [PATCH 03/10] Coverage * Added failure test case for unusable function * Added scoverage tags for known-unreachable branches --- .../main/scala/frameless/column/ColumnMacros.scala | 14 +++++++++++--- dataset/src/test/scala/frameless/ColTests.scala | 1 + 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/dataset/src/main/scala/frameless/column/ColumnMacros.scala b/dataset/src/main/scala/frameless/column/ColumnMacros.scala index f16b3e539..6a4dc5324 100644 --- a/dataset/src/main/scala/frameless/column/ColumnMacros.scala +++ b/dataset/src/main/scala/frameless/column/ColumnMacros.scala @@ -9,16 +9,20 @@ class ColumnMacros(val c: whitebox.Context) { import c.universe._ // could be used to reintroduce apply('foo) + // $COVERAGE-OFF$ Currently unused def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = { val B = weakTypeOf[B].dealias val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})") c.typecheck(q"${c.prefix}.col[$B]($witness)") } + // $COVERAGE-ON$ def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = { - def fail(tree: Tree) = c.abort( - tree.pos, - s"Could not create a column identifier from $tree - try using _.a.b syntax") + def fail(tree: Tree) = { + val err = + s"Could not create a column identifier from $tree - try using _.a.b syntax" + c.abort(tree.pos, err) + } val A = weakTypeOf[A].dealias val B = weakTypeOf[B].dealias @@ -28,7 +32,9 @@ class ColumnMacros(val c: whitebox.Context) { case `argName`(strs) => strs.mkString(".") case other => fail(other) } + // $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked case other => fail(other) + // $COVERAGE-ON$ } val typedCol = appliedType( @@ -50,7 +56,9 @@ class ColumnMacros(val c: whitebox.Context) { tree match { case Ident(`name`) => Some(Queue.empty) case Select(This(strs), nested) => Some(strs enqueue nested.toString) + // $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work case Apply(This(strs), List()) => Some(strs) + // $COVERAGE-ON$ case _ => None } } diff --git a/dataset/src/test/scala/frameless/ColTests.scala b/dataset/src/test/scala/frameless/ColTests.scala index 10b2e1bbb..d6aa27191 100644 --- a/dataset/src/test/scala/frameless/ColTests.scala +++ b/dataset/src/test/scala/frameless/ColTests.scala @@ -60,6 +60,7 @@ class ColTests extends TypedDatasetSuite { x4x4[String](_.xb.b) illTyped("x4x4[String](_.xa.a)", "type mismatch;\n found : Int\n required: String") + illTyped("x4x4(item => item.xa.a * 20)", "Could not create a column identifier from item\\.xa\\.a\\.\\*\\(20\\) - try using _\\.a\\.b syntax") () } From add79eccef34e6f0c9438adb79bdf2570018b22b Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Sun, 19 Feb 2017 21:27:51 -0800 Subject: [PATCH 04/10] More macros! Macro all the things! A new macro for `selectExpr[B](T => B)`. It's like `map`, only it only works if the provided function can be converted to a projection of some column expressions over the original dataset - so it avoids SerDe. It won't work all the time, but (hopefully) gives you good compiler errors when it doesn't work. 
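A rough usage sketch, mirroring the cases exercised by the new `SelectExprTests` further down (`X2`/`X3` are the existing test fixtures; `someHelper` is a hypothetical stand-in):

```scala
val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar")))

// plain projections and column arithmetic compile down to a select - no SerDe round-trip
val pair   = ds.selectExpr(x => (x.a * 10, x.b / 2))     // TypedDataset[(Int, Double)]
val nested = ds.selectExpr(x => X2(X2(x.a, x.b), x.c))   // TypedDataset[X2[X2[Int, Double], String]]

// arbitrary function calls can't be turned into a projection, so this is rejected at compile time:
// ds.selectExpr(x => (x.a, someHelper(x.c)))
```
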
Currently it drops down into `DataFrame` to create the projection plan, and unifies it back into `TypedDataset` (via `TypedDataset.create` over `Dataset`). There are a couple of hairy pieces that could be cleaned up (i.e. it's not clear what to do about SQL functions and UDFs, since they don't operate on the actual type). But it's already at a point where it's pretty useful. --- build.sbt | 10 +- .../main/scala/frameless/TypedDataset.scala | 5 +- .../scala/frameless/column/ColumnMacros.scala | 70 ---- .../scala/frameless/functions/package.scala | 2 + .../scala/frameless/macros/ColumnMacros.scala | 345 ++++++++++++++++++ .../scala/frameless/SelectExprTests.scala | 71 ++++ 6 files changed, 429 insertions(+), 74 deletions(-) delete mode 100644 dataset/src/main/scala/frameless/column/ColumnMacros.scala create mode 100644 dataset/src/main/scala/frameless/macros/ColumnMacros.scala create mode 100644 dataset/src/test/scala/frameless/SelectExprTests.scala diff --git a/build.sbt b/build.sbt index af02be9f9..5c8044d30 100644 --- a/build.sbt +++ b/build.sbt @@ -4,6 +4,10 @@ val scalatest = "3.0.1" val shapeless = "2.3.2" val scalacheck = "1.13.4" +// spark has scalatest and scalactic as a runtime dependency +// which can mess things up if you use a different version in your project +val exclusions = Seq(ExclusionRule("org.scalatest"), ExclusionRule("org.scalactic")) + lazy val root = Project("frameless", file("." + "frameless")).in(file(".")) .aggregate(core, cats, dataset, docs) .settings(framelessSettings: _*) @@ -22,7 +26,7 @@ lazy val cats = project .settings(publishSettings: _*) .settings(libraryDependencies ++= Seq( "org.typelevel" %% "cats" % catsv, - "org.apache.spark" %% "spark-core" % sparkVersion % "provided")) + "org.apache.spark" %% "spark-core" % sparkVersion % "provided" excludeAll(exclusions: _*))) lazy val dataset = project .settings(name := "frameless-dataset") @@ -31,8 +35,8 @@ lazy val dataset = project .settings(framelessTypedDatasetREPL: _*) .settings(publishSettings: _*) .settings(libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-core" % sparkVersion % "provided", - "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" + "org.apache.spark" %% "spark-core" % sparkVersion % "provided" excludeAll(exclusions: _*), + "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" excludeAll(exclusions: _*) )) .dependsOn(core % "test->test;compile->compile") diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index 056ae0349..175d4f684 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -64,7 +64,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val */ def apply[A](selector: T => A)(implicit encoder: TypedEncoder[A] - ): TypedColumn[T, A] = macro frameless.column.ColumnMacros.fromFunction[T, A] + ): TypedColumn[T, A] = macro frameless.macros.ColumnMacros.fromFunction[T, A] /** Returns `TypedColumn` of type `A` given it's name. * @@ -317,6 +317,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } + def selectExpr[B](expr: T => B)(implicit encoder: TypedEncoder[B]): TypedDataset[B] = + macro frameless.macros.ColumnMacros.fromExpr[T, B] + /** Type-safe projection from type T to Tuple2[A,B] * {{{ * d.select( d('a), d('a)+d('b), ... 
) diff --git a/dataset/src/main/scala/frameless/column/ColumnMacros.scala b/dataset/src/main/scala/frameless/column/ColumnMacros.scala deleted file mode 100644 index 6a4dc5324..000000000 --- a/dataset/src/main/scala/frameless/column/ColumnMacros.scala +++ /dev/null @@ -1,70 +0,0 @@ -package frameless.column - -import frameless.{TypedColumn, TypedEncoder, TypedExpressionEncoder} - -import scala.collection.immutable.Queue -import scala.reflect.macros.whitebox - -class ColumnMacros(val c: whitebox.Context) { - import c.universe._ - - // could be used to reintroduce apply('foo) - // $COVERAGE-OFF$ Currently unused - def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = { - val B = weakTypeOf[B].dealias - val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})") - c.typecheck(q"${c.prefix}.col[$B]($witness)") - } - // $COVERAGE-ON$ - - def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = { - def fail(tree: Tree) = { - val err = - s"Could not create a column identifier from $tree - try using _.a.b syntax" - c.abort(tree.pos, err) - } - - val A = weakTypeOf[A].dealias - val B = weakTypeOf[B].dealias - - val selectorStr = selector.tree match { - case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => body match { - case `argName`(strs) => strs.mkString(".") - case other => fail(other) - } - // $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked - case other => fail(other) - // $COVERAGE-ON$ - } - - val typedCol = appliedType( - weakTypeOf[TypedColumn[_, _]].typeConstructor, A, B - ) - - val TEEObj = reify(TypedExpressionEncoder) - - val datasetCol = c.typecheck( - q"${c.prefix}.dataset.col($selectorStr).as[$B]($TEEObj.apply[$B]($encoder))" - ) - - c.typecheck(q"new $typedCol($datasetCol)") - } - - case class NameExtractor(name: TermName) { - private val This = this - def unapply(tree: Tree): Option[Queue[String]] = { - tree match { - case Ident(`name`) => Some(Queue.empty) - case Select(This(strs), nested) => Some(strs enqueue nested.toString) - // $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work - case Apply(This(strs), List()) => Some(strs) - // $COVERAGE-ON$ - case _ => None - } - } - } - - object ArgName { - def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name)) - } -} diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index 06fe5bc09..ba5529c87 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -2,4 +2,6 @@ package frameless package object functions extends Udf { object aggregate extends AggregateFunctions + + } diff --git a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala new file mode 100644 index 000000000..6ea2eeaa7 --- /dev/null +++ b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala @@ -0,0 +1,345 @@ +package frameless.macros + +import frameless.{TypedColumn, TypedEncoder} + +import scala.collection.immutable.Queue +import scala.reflect.macros.whitebox + +class ColumnMacros(val c: whitebox.Context) { + import c.universe._ + + private val TypedExpressionEncoder = reify(frameless.TypedExpressionEncoder) + private val TypedDataset = reify(frameless.TypedDataset) + + private def 
toColumn[A : WeakTypeTag, B : WeakTypeTag]( + selectorStr: String, + encoder: c.Expr[TypedEncoder[B]] + ): Tree = { + + val A = weakTypeOf[A].dealias + val B = weakTypeOf[B].dealias + + val typedCol = appliedType( + weakTypeOf[TypedColumn[_, _]].typeConstructor, A, B + ) + + + val datasetCol = c.typecheck( + q"${c.prefix}.dataset.col($selectorStr).as[$B]($TypedExpressionEncoder.apply[$B]($encoder))" + ) + + c.typecheck(q"new $typedCol($datasetCol)") + } + + // could be used to reintroduce apply('foo) + // $COVERAGE-OFF$ Currently unused + def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + val B = weakTypeOf[B].dealias + val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})") + c.typecheck(q"${c.prefix}.col[$B]($witness)") + } + // $COVERAGE-ON$ + + def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + def fail(tree: Tree) = { + val err = + s"Could not create a column identifier from $tree - try using _.a.b syntax" + c.abort(tree.pos, err) + } + + val A = weakTypeOf[A].dealias + val B = weakTypeOf[B].dealias + + val selectorStr = selector.tree match { + case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => body match { + case `argName`(strs) => strs.mkString(".") + case other => fail(other) + } + // $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked + case other => fail(other) + // $COVERAGE-ON$ + } + + toColumn[A, B](selectorStr, encoder) + } + + def fromExpr[A : WeakTypeTag, B : WeakTypeTag](expr: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = { + def fail(tree: Tree)(reason: String, pos: Position = tree.pos) = c.abort( + tree.pos, + s"Could not expand expression $tree$reason" + ) + + val A = weakTypeOf[A].dealias + val B = weakTypeOf[B].dealias + + val result = expr.tree match { + case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => + val Extractor = ExprExtractor(argName) + val selectCols = body match { + case Extractor(extracted) => extracted match { + // top-level construct; have to unwrap the args + case Right(Construct(tpe, args)) => + Expr.toNamedColumns(args) + case Right(colExpr) => + List(Expr.refold(colExpr)) + case Left(errs) => + errs match { + case (tree, err) :: Nil => fail(tree)(err) + case multi => + multi foreach { + case (tree, err) => c.error(tree.pos, err) + } + fail(body)("multiple errors composing projection") + } + } + + case other => + val o = other + fail(o)("could not compose projection") + } + val df = q"${c.prefix}.dataset.toDF().select(..$selectCols)" + val ds = q"$df.as[$B]($TypedExpressionEncoder.apply[$B])" + val typedDs = q"$TypedDataset.create[$B]($ds)($encoder)" + + typedDs + } + + + c.typecheck(result) + } + + sealed trait Expr { + def tpe: Type + } + + object Expr { + // zip down to the leaves and construct an expression that gives us a Column + // this is some kind of hylo-mylo-bananafanafofylo-morphism. 
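+    // For instance (illustrative only):
+    //   refold(BinaryOperator(Int, Column(Int, "a"), +, LiteralExpr(Int, 1)))
+    // yields the tree for
+    //   new org.apache.spark.sql.ColumnName("a") + org.apache.spark.sql.functions.lit(1)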
+ def refold(expr: Expr): Tree = expr match { + case LiteralExpr(tpe, tree) => q"org.apache.spark.sql.functions.lit($tree)" + case Column(tpe, selectorStr) => + q"new org.apache.spark.sql.ColumnName($selectorStr)" + case FunctionApplication(tpe, src, fn, args) => + c.abort(c.enclosingPosition, "Functions not yet supported in selectExpr") + case BinaryOperator(tpe, lhs, op, rhs) => + val lhsTree = refold(lhs) + val rhsTree = refold(rhs) + Apply(Select(lhsTree, op), List(rhsTree)) + case UnaryRAOperator(tpe, rhs, op) => + val rhsTree = refold(rhs) + Select(rhsTree, op) + case Construct(tpe, args) => + val argTrees = toNamedColumns(args) + q"org.apache.spark.sql.functions.struct(..$argTrees)" + case TrustMeBro(tpe, tree) => tree + } + + def toNamedColumns(exprMap: Map[String, Expr]): List[Tree] = exprMap.toList.map { + case (colName, expr) => + val colTree = refold(expr) + q"$colTree.as($colName)" + } + } + + case class Construct(tpe: Type, args: Map[String, Expr]) extends Expr // Construct tuple or case class + case class UnaryRAOperator(tpe: Type, rhs: Expr, op: TermName) extends Expr // Unary right-associative op + case class BinaryOperator(tpe: Type, lhs: Expr, op: TermName, rhs: Expr) extends Expr // Binary op + case class FunctionApplication(tpe: Type, src: Tree, fn: TermName, args: List[Expr]) extends Expr // SQL function + case class Column(tpe: Type, selectorStr: String) extends Expr // Column of this dataset + case class LiteralExpr(tpe: Type, tree: Tree) extends Expr // Literal expression + case class TrustMeBro(tpe: Type, tree: Tree) extends Expr + + case class ExprExtractor(Rooted: NameExtractor) { + private val This = this + + // substitute these operators with functions (or other trees) + // i.e. two string columns can't be + together; they must be concat(a, b) instead. + val subOperators: Map[MethodSymbol, (Tree, Tree) => Tree] = Map( + weakTypeOf[String].member(TermName("+").encodedName).asMethod -> + ((lhs, rhs) => q"org.apache.spark.sql.functions.concat($lhs, $rhs)") + ) + + def unapply(tree: Tree): Option[Either[List[(Tree, String)], Expr]] = tree match { + + // A single column with no expression around it + case Rooted(strs) if tree.symbol.asMethod.isCaseAccessor => + Some(Right(Column(tree.tpe, strs.mkString(".")))) + + // A unary operator - the operator must exist on org.apache.spark.sql.Column + case Select(This(rhsE), op: TermName) => Some { + if(isColumnOp(op, 0)) { + rhsE.right.map { + rhs => UnaryRAOperator(tree.tpe, rhs, op) + } + } else { + val err = (tree, s"${op.decodedName} is not a valid column operator") + rhsE match { + case Left(errs) => Left(err :: errs) + case Right(_) => Left(err :: Nil) + } + } + } + + // A literal constant (would it be useful to distinguish this from non-constant literal? 
+ case Literal(_) => Some(Right(LiteralExpr(tree.tpe, tree))) + + // Constructing a case class + case Apply(sel @ Select(qualifier, TermName("apply")), AllExprs(argsE)) => Some { + if (isConstructor(sel, tree.tpe)) { + argsE.right.map { + args => + val params = sel.symbol.asMethod.paramLists.head + val names = sel.symbol.asMethod.paramLists.head.zip(args).map { + case (param, paramExpr) => param.name.toString -> paramExpr + }.toMap + Construct(tree.tpe, names) + } + } else { + val err = (tree, "Only constructor can be used here") + Left { + argsE match { + case Left(errs) => err :: errs + case Right(_) => err :: Nil + } + } + } + } + // Constructing a tuple or parameterized case class + case Apply(ta @ TypeApply(sel @ Select(qualifier, TermName("apply")), typArgs), AllExprs(argsE)) => Some { + if(isConstructor(ta, tree.tpe)) { + argsE.right.map { + args => + val params = sel.symbol.asMethod.paramLists.head + val names = sel.symbol.asMethod.paramLists.head.zip(args).map { + case (param, paramExpr) => param.name.toString -> paramExpr + }.toMap + Construct(tree.tpe, names) + } + } else { + val err = (tree, "Only constructor can be used here") + Left { + argsE match { + case Left(errs) => err :: errs + case Right(_) => err :: Nil + } + } + } + } + + // A binary operator - the operator must exist on org.apache.spark.sql.Column + case Apply(sel @ Select(This(lhsE), op: TermName), List(This(rhsE))) if isColumnOp(op, 1) => + Some { + (lhsE, rhsE) match { + case (Left(errsL), Left(errsR)) => Left(errsL ::: errsR) + case (errs @ Left(_), Right(_)) => errs + case (Right(_), errs @ Left(_)) => errs + case (Right(lhs), Right(rhs)) => Right { + subOperators.get(sel.symbol.asMethod).map { + sub => + val trusted = sub(Expr.refold(lhs), Expr.refold(rhs)) + TrustMeBro(tree.tpe, trusted) + }.getOrElse { + BinaryOperator(tree.tpe, lhs, op, rhs) + } + } + } + } + + // A function application - what to do with this? How can we check if it's an OK function? + // Check org.apache.spark.sql.functions? 
+ // I think we have to port all spark functions to typed versions so we can typecheck here + case Apply(Select(src, fn: TermName), AllExprs(argsE)) => + Some(argsE.right.map(args => FunctionApplication(tree.tpe, src, fn, args))) + + case _ => None + } + + private def isColumnOp(name: TermName, numArgs: Int): Boolean = { + val sym = weakTypeOf[org.apache.spark.sql.Column].member(name) + if(sym.isMethod) { + val bothEmptyArgs = numArgs == 0 && sym.asMethod.paramLists.isEmpty + val matchingArgCounts = sym.asMethod.paramLists.map(_.length) == List(numArgs) + if (bothEmptyArgs || matchingArgCounts) + true + else + false + } else { + false + } + } + + def isConstructor(tree: Tree, result: Type, appliedTypes: Option[List[Type]] = None): Boolean = { + val (qualifier, typeArgs) = tree match { + case TypeApply(Select(q, _), args) => (q, args) + case Select(q, _) => (q, Nil) + } + + val MethodType(params, _) = tree.tpe + + if(tree.symbol.isMethod && result.companion =:= qualifier.tpe) { + // must have same args as primary constructor + val meth = tree.symbol.asMethod + result.members.find(s => s.isConstructor && s.asMethod.isPrimaryConstructor) match { + case Some(defaultConstructor) => + defaultConstructor.asMethod.paramLists.head.zip(params).forall { + case (arg1, arg2) => + val sameArg = arg1.name == arg2.name + val arg1Typ = result.member(arg1.asTerm.name).typeSignatureIn(result).finalResultType + val sameType = arg1Typ =:= arg2.typeSignature + sameArg && sameType + } + case _ => false + } + } else false + } + + object AllExprs { + def unapply(trees: List[Tree]): Option[Either[List[(Tree, String)], List[Expr]]] = { + val eithersOpt = trees.map(This.unapply).foldRight[Option[List[Either[(Tree, String), Expr]]]](Some(Nil)) { + (nextOpt, accumOpt) => for { + next <- nextOpt + accum <- accumOpt + } yield next match { + case Left(errs) => errs.map(Left(_)) ::: accum + case Right(expr) => Right(expr) :: accum + } + } + + // wish cats was here + eithersOpt.map { + eithers => + eithers.foldRight[Either[List[(Tree, String)], List[Expr]]](Right(Nil)) { + (next, accum) => next match { + case Left(err) => accum match { + case Left(errs) => Left(err :: errs) + case Right(_) => Left(err :: Nil) + } + case Right(expr) => accum match { + case l @ Left(errs) => l + case Right(exprs) => Right(expr :: exprs) + } + } + } + } + } + } + } + + case class NameExtractor(name: TermName) { + private val This = this + def unapply(tree: Tree): Option[Queue[String]] = { + tree match { + case Ident(`name`) => Some(Queue.empty) + case Select(This(strs), nested) => Some(strs enqueue nested.toString) + // $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work + case Apply(This(strs), List()) => Some(strs) + // $COVERAGE-ON$ + case _ => None + } + } + } + + object ArgName { + def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name)) + } +} diff --git a/dataset/src/test/scala/frameless/SelectExprTests.scala b/dataset/src/test/scala/frameless/SelectExprTests.scala new file mode 100644 index 000000000..8a7e38c17 --- /dev/null +++ b/dataset/src/test/scala/frameless/SelectExprTests.scala @@ -0,0 +1,71 @@ +package frameless + +import org.scalatest.Matchers + +class SelectExprTests extends TypedDatasetSuite with Matchers { + + test("selectExpr with a single column") { + val ds = TypedDataset.create(Seq(X2(20, "twenty"), X2(30, "thirty"))) + val ds2 = ds.selectExpr(_.a) + val ds3 = ds.selectExpr(_.b) + ds2.collect().run() should contain theSameElementsAs Seq(20, 30) + ds3.collect().run() 
should contain theSameElementsAs Seq("twenty", "thirty") + } + + test("selectExpr with a single column and unary operation") { + val ds = TypedDataset.create(Seq(X1(20), X1(30))) + val ds2 = ds.selectExpr(-_.a) + ds2.collect().run() should contain theSameElementsAs Seq(-20, -30) + } + + test("selectExpr with a binary operation between two columns") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => x.a * x.b) + ds2.collect().run() should contain theSameElementsAs Seq(200, 600) + } + + test("selectExpr with a binary operation between a column and a literal") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => x.a * 10) + ds2.collect().run() should contain theSameElementsAs Seq(100, 200) + } + + test("selectExpr constructing a tuple") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => (x.a, x.b)) + ds2.collect().run() should contain theSameElementsAs Seq((10, 20), (20, 30)) + } + + test("selectExpr constructing a tuple with an operation") { + val ds = TypedDataset.create(Seq(X2(10, 20), X2(20, 30))) + val ds2 = ds.selectExpr(x => (x.a * 10, x.b + 1)) + ds2.collect().run() should contain theSameElementsAs Seq((100, 21), (200, 31)) + } + + test("selectExpr constructing a nested tuple with substitute function") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + val ds2 = ds.selectExpr(x => ((x.a * 1.1, x.c + "foo"), x.b / 2)) + ds2.collect().run() should contain theSameElementsAs Seq( + ((11.0, "foofoo"), 10.0), + ((22.0, "barfoo"), 15.0) + ) + } + + test("selectExpr constructing a case class") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + val ds2 = ds.selectExpr(x => X2(x.a, x.b)) + ds2.collect().run() should contain theSameElementsAs Seq( + X2(10, 20.0), + X2(20, 30.0) + ) + } + + test("selectExpr constructing a nested case class") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + val ds2 = ds.selectExpr(x => X2(X2(x.a, x.b), x.c)) + ds2.collect().run() should contain theSameElementsAs Seq( + X2(X2(10, 20.0), "foo"), + X2(X2(20, 30.0), "bar") + ) + } +} From b58503a5c75d3fd4ca6080e8e1e04f1a45568a4b Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Mon, 20 Feb 2017 11:19:16 -0800 Subject: [PATCH 05/10] Better error messaging & test coverage --- .../scala/frameless/macros/ColumnMacros.scala | 107 +++++++------- .../scala/frameless/SelectExprTests.scala | 130 ++++++++++++++++++ 2 files changed, 183 insertions(+), 54 deletions(-) diff --git a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala index 6ea2eeaa7..0b33c2b41 100644 --- a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala +++ b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala @@ -1,15 +1,17 @@ package frameless.macros import frameless.{TypedColumn, TypedEncoder} +import org.apache.spark.sql.ColumnName import scala.collection.immutable.Queue -import scala.reflect.macros.whitebox +import scala.reflect.macros.{TypecheckException, whitebox} class ColumnMacros(val c: whitebox.Context) { import c.universe._ private val TypedExpressionEncoder = reify(frameless.TypedExpressionEncoder) private val TypedDataset = reify(frameless.TypedDataset) + private val ColumnName = weakTypeOf[ColumnName] private def toColumn[A : WeakTypeTag, B : WeakTypeTag]( selectorStr: String, @@ -25,7 +27,7 @@ class ColumnMacros(val c: whitebox.Context) { val 
datasetCol = c.typecheck( - q"${c.prefix}.dataset.col($selectorStr).as[$B]($TypedExpressionEncoder.apply[$B]($encoder))" + q"new $ColumnName($selectorStr).as[$B]($TypedExpressionEncoder.apply[$B]($encoder))" ) c.typecheck(q"new $typedCol($datasetCol)") @@ -84,18 +86,18 @@ class ColumnMacros(val c: whitebox.Context) { List(Expr.refold(colExpr)) case Left(errs) => errs match { - case (tree, err) :: Nil => fail(tree)(err) + case (tree, err) :: Nil => fail(tree)(s": $err") case multi => multi foreach { case (tree, err) => c.error(tree.pos, err) } - fail(body)("multiple errors composing projection") + fail(body)(": multiple errors composing projection") } } case other => val o = other - fail(o)("could not compose projection") + fail(o)(": could not compose projection") } val df = q"${c.prefix}.dataset.toDF().select(..$selectCols)" val ds = q"$df.as[$B]($TypedExpressionEncoder.apply[$B])" @@ -118,9 +120,11 @@ class ColumnMacros(val c: whitebox.Context) { def refold(expr: Expr): Tree = expr match { case LiteralExpr(tpe, tree) => q"org.apache.spark.sql.functions.lit($tree)" case Column(tpe, selectorStr) => - q"new org.apache.spark.sql.ColumnName($selectorStr)" + q"new $ColumnName($selectorStr)" + // $COVERAGE-OFF$ - Functions can't currently be used (and currently give an error before this is reached) case FunctionApplication(tpe, src, fn, args) => c.abort(c.enclosingPosition, "Functions not yet supported in selectExpr") + // $COVERAGE-ON$ case BinaryOperator(tpe, lhs, op, rhs) => val lhsTree = refold(lhs) val rhsTree = refold(rhs) @@ -147,6 +151,8 @@ class ColumnMacros(val c: whitebox.Context) { case class FunctionApplication(tpe: Type, src: Tree, fn: TermName, args: List[Expr]) extends Expr // SQL function case class Column(tpe: Type, selectorStr: String) extends Expr // Column of this dataset case class LiteralExpr(tpe: Type, tree: Tree) extends Expr // Literal expression + + // manually added tree (temporary until functions can be implemented; needed for operator substitution) case class TrustMeBro(tpe: Type, tree: Tree) extends Expr case class ExprExtractor(Rooted: NameExtractor) { @@ -162,7 +168,7 @@ class ColumnMacros(val c: whitebox.Context) { def unapply(tree: Tree): Option[Either[List[(Tree, String)], Expr]] = tree match { // A single column with no expression around it - case Rooted(strs) if tree.symbol.asMethod.isCaseAccessor => + case Rooted(strs) if tree.symbol.isMethod && tree.symbol.asMethod.isCaseAccessor => Some(Right(Column(tree.tpe, strs.mkString(".")))) // A unary operator - the operator must exist on org.apache.spark.sql.Column @@ -171,13 +177,7 @@ class ColumnMacros(val c: whitebox.Context) { rhsE.right.map { rhs => UnaryRAOperator(tree.tpe, rhs, op) } - } else { - val err = (tree, s"${op.decodedName} is not a valid column operator") - rhsE match { - case Left(errs) => Left(err :: errs) - case Right(_) => Left(err :: Nil) - } - } + } else addError(rhsE)(tree, s"${op.decodedName} is not a valid column operator") } // A literal constant (would it be useful to distinguish this from non-constant literal? 
@@ -194,15 +194,7 @@ class ColumnMacros(val c: whitebox.Context) { }.toMap Construct(tree.tpe, names) } - } else { - val err = (tree, "Only constructor can be used here") - Left { - argsE match { - case Left(errs) => err :: errs - case Right(_) => err :: Nil - } - } - } + } else addError(argsE)(tree, "Only constructor can be used here") } // Constructing a tuple or parameterized case class case Apply(ta @ TypeApply(sel @ Select(qualifier, TermName("apply")), typArgs), AllExprs(argsE)) => Some { @@ -215,25 +207,14 @@ class ColumnMacros(val c: whitebox.Context) { }.toMap Construct(tree.tpe, names) } - } else { - val err = (tree, "Only constructor can be used here") - Left { - argsE match { - case Left(errs) => err :: errs - case Right(_) => err :: Nil - } - } - } + } else addError(argsE)(tree, "Only constructor can be used here") } // A binary operator - the operator must exist on org.apache.spark.sql.Column - case Apply(sel @ Select(This(lhsE), op: TermName), List(This(rhsE))) if isColumnOp(op, 1) => - Some { + case Apply(sel @ Select(This(lhsE), op: TermName), List(This(rhsE))) => Some { + if (isColumnOp(op, 1)) { (lhsE, rhsE) match { - case (Left(errsL), Left(errsR)) => Left(errsL ::: errsR) - case (errs @ Left(_), Right(_)) => errs - case (Right(_), errs @ Left(_)) => errs - case (Right(lhs), Right(rhs)) => Right { + case (Right(lhs), Right(rhs)) => Right { subOperators.get(sel.symbol.asMethod).map { sub => val trusted = sub(Expr.refold(lhs), Expr.refold(rhs)) @@ -242,14 +223,19 @@ class ColumnMacros(val c: whitebox.Context) { BinaryOperator(tree.tpe, lhs, op, rhs) } } + case (lhs, rhs) => failFrom(lhs, rhs) } - } + } else addError(lhsE, rhsE)(tree, s"${op.decodedName} is not a valid column operator") + } // A function application - what to do with this? How can we check if it's an OK function? // Check org.apache.spark.sql.functions? 
// I think we have to port all spark functions to typed versions so we can typecheck here case Apply(Select(src, fn: TermName), AllExprs(argsE)) => - Some(argsE.right.map(args => FunctionApplication(tree.tpe, src, fn, args))) + Some(Left(List((tree, "Function application not currently supported")))) + //Some(argsE.right.map(args => FunctionApplication(tree.tpe, src, fn, args))) + + case Apply(_, _) => Some(Left(List((tree, "Function application not currently supported")))) case _ => None } @@ -276,7 +262,15 @@ class ColumnMacros(val c: whitebox.Context) { val MethodType(params, _) = tree.tpe - if(tree.symbol.isMethod && result.companion =:= qualifier.tpe) { + val companion = Option(result.companion).filterNot(_ == NoType).getOrElse { + try { + c.typecheck(Ident(result.typeSymbol.name.toTermName)).tpe + } catch { + case TypecheckException(_, _) => NoType + } + } + + if(tree.symbol.isMethod && companion =:= qualifier.tpe) { // must have same args as primary constructor val meth = tree.symbol.asMethod result.members.find(s => s.isConstructor && s.asMethod.isPrimaryConstructor) match { @@ -309,31 +303,36 @@ class ColumnMacros(val c: whitebox.Context) { eithersOpt.map { eithers => eithers.foldRight[Either[List[(Tree, String)], List[Expr]]](Right(Nil)) { - (next, accum) => next match { - case Left(err) => accum match { - case Left(errs) => Left(err :: errs) - case Right(_) => Left(err :: Nil) - } - case Right(expr) => accum match { - case l @ Left(errs) => l - case Right(exprs) => Right(expr :: exprs) - } - } + (next, accum) => next.fold( + err => Left(err :: accum.left.toOption.getOrElse(Nil)), + expr => accum.right.map(expr :: _) + ) } } } } } + def addError[A]( + exprs: Either[List[(Tree, String)], A]*)( + tree: Tree, err: String + ): Either[List[(Tree, String)], Expr] = failFrom(exprs: _*).left.map(_ :+ (tree -> err)) + + def failFrom[A](exprs: Either[List[(Tree, String)], A]*): Left[List[(Tree, String)], Expr] = Left { + exprs.foldLeft[List[(Tree, String)]](Nil) { + (accum, next) => next.fold( + errs => accum ::: errs, + _ => accum + ) + } + } + case class NameExtractor(name: TermName) { private val This = this def unapply(tree: Tree): Option[Queue[String]] = { tree match { case Ident(`name`) => Some(Queue.empty) case Select(This(strs), nested) => Some(strs enqueue nested.toString) - // $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work - case Apply(This(strs), List()) => Some(strs) - // $COVERAGE-ON$ case _ => None } } diff --git a/dataset/src/test/scala/frameless/SelectExprTests.scala b/dataset/src/test/scala/frameless/SelectExprTests.scala index 8a7e38c17..1cb3ee106 100644 --- a/dataset/src/test/scala/frameless/SelectExprTests.scala +++ b/dataset/src/test/scala/frameless/SelectExprTests.scala @@ -1,6 +1,15 @@ package frameless import org.scalatest.Matchers +import shapeless.test.illTyped + +case class Foo(a: Int, b: String) { + // tests failure of unsupported unary operator + def unary_~ : Foo = copy(a = a + 1) + + // tests failure of binary operator with different arg count (Column#when takes two args) + def when(a: Int): Int = this.a + a +} class SelectExprTests extends TypedDatasetSuite with Matchers { @@ -60,6 +69,17 @@ class SelectExprTests extends TypedDatasetSuite with Matchers { ) } + test("selectExpr constructing an unparameterized case class") { + + val ds = TypedDataset.create(Seq(Foo(10, "ten"), Foo(20, "twenty"))) + val ds2 = ds.selectExpr(f => Foo(f.a * 5, f.b + "foo")) + ds2.collect().run() should contain theSameElementsAs Seq( + Foo(50, 
"tenfoo"), + Foo(100, "twentyfoo") + ) + + } + test("selectExpr constructing a nested case class") { val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) val ds2 = ds.selectExpr(x => X2(X2(x.a, x.b), x.c)) @@ -68,4 +88,114 @@ class SelectExprTests extends TypedDatasetSuite with Matchers { X2(X2(20, 30.0), "bar") ) } + + test("selectExpr accessing a nested case class field") { + val ds = TypedDataset.create(Seq( + X2(X2(11.0, "foofoo"), 10.0), + X2(X2(22.0, "barfoo"), 15.0) + )) + + val ds2 = ds.selectExpr(x => (x.a.a * x.b, x.a.b)) + + ds2.collect().run() should contain theSameElementsAs Seq( + (110.0, "foofoo"), + (330.0, "barfoo") + ) + } + + test("can't use a non-constructor apply method") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + object Foo { + def apply(str: String, i: Int): String = str + i.toString + } + + object FooP { + def apply[A, B](a: A, b: B): X2[A, B] = X2(a, b) + } + + illTyped( + "val ds2 = ds.selectExpr(x => Foo(x.c, x.a))", + ".*Only constructor can be used here" + ) + + illTyped( + "val ds2 = ds.selectExpr(x => FooP(x.c, x.a))", + ".*Only constructor can be used here" + ) + + } + + //TODO: could we just UDF the function in this case? + test("functions not yet supported") { + def strfun(s: String) = s"fun${s}fun" + + object Fun { + def strfun(s: String) = s"Funfun${s}Funfun" + } + + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.a, strfun(x.c)))", + ".*Function application not currently supported" + ) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.a, Fun.strfun(x.c)))", + ".*Function application not currently supported" + ) + } + + test("fails if operator is not available on Column") { + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => x.c substring x.a)", + ".*substring is not a valid column operator" + ) + + val ds3 = TypedDataset.create(Seq(X2(Foo(10, "ten"), "foo"), X2(Foo(20, "twenty"), "bar"))) + + illTyped( + "val ds4 = ds3.selectExpr(x => ~x.a)", + ".*~ is not a valid column operator" + ) + + illTyped( + """val ds5 = ds3.selectExpr(x => x.a when 5)""", + ".*when is not a valid column operator" + // would be nice if this error was clearer about the arity mismatch being the cause + ) + } + + test("multiple errors are all shown before aborting") { + def strfun(s: String) = s"fun${s}fun" + + object Fun { + def strfun(s: String) = s"Funfun${s}Funfun" + } + + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.a, (strfun(x.c), Fun.strfun(x.c))))", + ".*Function application not currently supported" + // this is a limitation of illTyped - it only looks at the first emitted error, rather than the error that aborted + ) + } + + test("errors emitted from nested expressions in allowed binary operator") { + // tests failFrom case in binary operator case + + def intfun(i: Int): Int = i + 1 + + val ds = TypedDataset.create(Seq(X3(10, 20.0, "foo"), X3(20, 30.0, "bar"))) + + illTyped( + "val ds2 = ds.selectExpr(x => (x.c, x.a + intfun(x.a)))", + ".*Function application not currently supported" + ) + + } } From c7550c5f29e175b32fa36f32ab1d1371800e0750 Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Mon, 20 Feb 2017 14:23:30 -0800 Subject: [PATCH 06/10] Added Array TypedEncoder * Added Array TypedEncoder derivation * Primitive optimization on Vector TypedEncoder * Added unary function rewrites 
for `.length` on Array and Vector --- .../main/scala/frameless/TypedEncoder.scala | 88 ++++++++++++++++--- .../scala/frameless/macros/ColumnMacros.scala | 87 ++++++++++++------ .../test/scala/frameless/CreateTests.scala | 40 ++++++++- .../scala/frameless/SelectExprTests.scala | 10 +++ 4 files changed, 187 insertions(+), 38 deletions(-) diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index a434f3d87..4df5d4c88 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -8,6 +8,7 @@ import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import shapeless._ + import scala.reflect.ClassTag abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Serializable { @@ -264,15 +265,30 @@ object TypedEncoder { def targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) def constructorFor(path: Expression): Expression = { - val arrayData = Invoke( - MapObjects( - underlying.constructorFor, - path, - underlying.targetDataType - ), - "array", - ScalaReflection.dataTypeFor[Array[AnyRef]] - ) + val arrayData = Option(underlying.sourceDataType) + .filter(ScalaReflection.isNativeType) + .filter(_ == underlying.targetDataType) + .collect { + case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]] + case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] + case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] + case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] + case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]] + case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] + case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] + }.map { + case (method, typ) => Invoke(path, method, typ) + }.getOrElse { + Invoke( + MapObjects( + underlying.constructorFor, + path, + underlying.targetDataType + ), + "array", + ScalaReflection.dataTypeFor[Array[AnyRef]] + ) + } StaticInvoke( TypedEncoderUtils.getClass, @@ -296,6 +312,58 @@ object TypedEncoder { } } + implicit def arrayEncoder[A]( + implicit + underlying: TypedEncoder[A], + classTag: ClassTag[Array[A]] + ): TypedEncoder[Array[A]] = new TypedEncoder[Array[A]]() { + def nullable: Boolean = false + + def sourceDataType: DataType = FramelessInternals.objectTypeFor[Array[A]](classTag) + + def targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) + + def constructorFor(path: Expression): Expression = { + Option(underlying.sourceDataType) + .filter(ScalaReflection.isNativeType) + .filter(_ == underlying.targetDataType) + .collect { + case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]] + case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] + case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] + case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] + case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]] + case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] + case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] + }.map { + case (method, typ) => Invoke(path, method, typ) + }.getOrElse { + Invoke( + MapObjects( + underlying.constructorFor, + path, + 
underlying.targetDataType + ), + "array", + ScalaReflection.dataTypeFor[Array[AnyRef]] + ) + } + } + + def extractorFor(path: Expression): Expression = { + // if source `path` is already native for Spark, no need to `map` + if (ScalaReflection.isNativeType(underlying.sourceDataType)) { + NewInstance( + classOf[GenericArrayData], + path :: Nil, + dataType = ArrayType(underlying.targetDataType, underlying.nullable) + ) + } else { + MapObjects(underlying.extractorFor, path, underlying.sourceDataType) + } + } + } + /** Encodes things using injection if there is one defined */ implicit def usingInjection[A: ClassTag, B] (implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] = @@ -322,4 +390,4 @@ object TypedEncoder { recordEncoder: Lazy[RecordEncoderFields[G]], classTag: ClassTag[F] ): TypedEncoder[F] = new RecordEncoder[F, G] -} +} \ No newline at end of file diff --git a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala index 0b33c2b41..fd1ab2863 100644 --- a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala +++ b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala @@ -160,11 +160,20 @@ class ColumnMacros(val c: whitebox.Context) { // substitute these operators with functions (or other trees) // i.e. two string columns can't be + together; they must be concat(a, b) instead. - val subOperators: Map[MethodSymbol, (Tree, Tree) => Tree] = Map( + val subBinaryOperators: Map[MethodSymbol, (Tree, Tree) => Tree] = Map( weakTypeOf[String].member(TermName("+").encodedName).asMethod -> ((lhs, rhs) => q"org.apache.spark.sql.functions.concat($lhs, $rhs)") ) + val subUnaryOperators: Map[MethodSymbol, Tree => Tree] = Map( + weakTypeOf[String].member(TermName("length")).asMethod -> + (lhs => q"org.apache.spark.sql.functions.length($lhs)"), + weakTypeOf[Array[_]].typeConstructor.member(TermName("length")).asMethod -> + (lhs => q"org.apache.spark.sql.functions.size($lhs)"), + weakTypeOf[Vector[_]].typeConstructor.member(TermName("length")).asMethod -> + (lhs => q"org.apache.spark.sql.functions.size($lhs)") + ) + def unapply(tree: Tree): Option[Either[List[(Tree, String)], Expr]] = tree match { // A single column with no expression around it @@ -172,12 +181,34 @@ class ColumnMacros(val c: whitebox.Context) { Some(Right(Column(tree.tpe, strs.mkString(".")))) // A unary operator - the operator must exist on org.apache.spark.sql.Column - case Select(This(rhsE), op: TermName) => Some { - if(isColumnOp(op, 0)) { - rhsE.right.map { - rhs => UnaryRAOperator(tree.tpe, rhs, op) + case sel @ Select(This(rhsE), op: TermName) => Some { + subUnaryOperators.get(sel.symbol.asMethod) + .map { + sub => rhsE.right.map(rhs => TrustMeBro(tree.tpe, sub(Expr.refold(rhs)))) + }.getOrElse { + if(isColumnOp(op, 0)) { + rhsE.right.map { + rhs => UnaryRAOperator(tree.tpe, rhs, op) + } + } else addError(rhsE)(tree, s"${op.decodedName} is not a valid column operator") } - } else addError(rhsE)(tree, s"${op.decodedName} is not a valid column operator") + } + + // A unary operator (which is a no-arg method) - the operator must exist on org.apache.spark.sql.Column + case Apply(sel @ Select(This(rhsE), op: TermName), List()) => Some { + subUnaryOperators.get(sel.symbol.asMethod) + .map { + sub => rhsE.right.map(rhs => TrustMeBro(tree.tpe, sub(Expr.refold(rhs)))) + }.getOrElse { + + if(isColumnOp(op, 0)) { + // $COVERAGE-OFF$ - can't find a unary column op that isn't substituted and is a no-arg apply tree + rhsE.right.map { + rhs => 
UnaryRAOperator(tree.tpe, rhs, op) + } + // $COVERAGE-ON$ + } else addError(rhsE)(tree, s"${op.decodedName} is not a valid column operator") + } } // A literal constant (would it be useful to distinguish this from non-constant literal? @@ -212,20 +243,19 @@ class ColumnMacros(val c: whitebox.Context) { // A binary operator - the operator must exist on org.apache.spark.sql.Column case Apply(sel @ Select(This(lhsE), op: TermName), List(This(rhsE))) => Some { - if (isColumnOp(op, 1)) { - (lhsE, rhsE) match { - case (Right(lhs), Right(rhs)) => Right { - subOperators.get(sel.symbol.asMethod).map { - sub => - val trusted = sub(Expr.refold(lhs), Expr.refold(rhs)) - TrustMeBro(tree.tpe, trusted) - }.getOrElse { - BinaryOperator(tree.tpe, lhs, op, rhs) - } - } - case (lhs, rhs) => failFrom(lhs, rhs) - } - } else addError(lhsE, rhsE)(tree, s"${op.decodedName} is not a valid column operator") + subBinaryOperators.get(sel.symbol.asMethod).map { + sub => requireBoth(lhsE, rhsE)( + (lhs, rhs) => failFrom(lhs, rhs), + (lhs, rhs) => Right(TrustMeBro(tree.tpe, sub(Expr.refold(lhs), Expr.refold(rhs)))) + ) + }.getOrElse { + if (isColumnOp(op, 1)) { + requireBoth(lhsE, rhsE)( + (lhs, rhs) => failFrom(lhs, rhs), + (lhs, rhs) => Right(BinaryOperator(tree.tpe, lhs, op, rhs)) + ) + } else addError(lhsE, rhsE)(tree, s"${op.decodedName} is not a valid column operator") + } } // A function application - what to do with this? How can we check if it's an OK function? @@ -242,15 +272,10 @@ class ColumnMacros(val c: whitebox.Context) { private def isColumnOp(name: TermName, numArgs: Int): Boolean = { val sym = weakTypeOf[org.apache.spark.sql.Column].member(name) - if(sym.isMethod) { + sym.isMethod && { val bothEmptyArgs = numArgs == 0 && sym.asMethod.paramLists.isEmpty val matchingArgCounts = sym.asMethod.paramLists.map(_.length) == List(numArgs) - if (bothEmptyArgs || matchingArgCounts) - true - else - false - } else { - false + bothEmptyArgs || matchingArgCounts } } @@ -327,6 +352,14 @@ class ColumnMacros(val c: whitebox.Context) { } } + def requireBoth[A, B, C](first: Either[A, B], second: Either[A, B])( + oneLeft: (Either[A, B], Either[A, B]) => C, + bothRight: (B, B) => C + ): C = (first, second) match { + case (Right(f), Right(s)) => bothRight(f, s) + case _ => oneLeft(first, second) + } + case class NameExtractor(name: TermName) { private val This = this def unapply(tree: Tree): Option[Queue[String]] = { diff --git a/dataset/src/test/scala/frameless/CreateTests.scala b/dataset/src/test/scala/frameless/CreateTests.scala index c1ce719b0..cdc38c170 100644 --- a/dataset/src/test/scala/frameless/CreateTests.scala +++ b/dataset/src/test/scala/frameless/CreateTests.scala @@ -1,6 +1,6 @@ package frameless -import org.scalacheck.Prop +import org.scalacheck.{Arbitrary, Prop} import org.scalacheck.Prop._ import org.scalatest.Matchers @@ -28,6 +28,44 @@ class CreateTests extends TypedDatasetSuite with Matchers { Vector[(Food, Country)]] _)) } + test("array fields") { + + def prop[T](implicit arb: Arbitrary[Array[T]], encoder: TypedEncoder[X1[Array[T]]]) = forAll { + data: Array[T] => + val Seq(X1(arr)) = TypedDataset.create(Seq(X1(data))).collect().run() + Prop(arr.sameElements(data)) + } + + check(prop[Boolean]) + check(prop[Byte]) + check(prop[Short]) + check(prop[Int]) + check(prop[Long]) + check(prop[Float]) + check(prop[Double]) + check(prop[X1[String]]) + + } + + test("vector fields") { + + def prop[T](implicit arb: Arbitrary[Vector[T]], encoder: TypedEncoder[X1[Vector[T]]]) = forAll { + data: Vector[T] => + val 
Seq(X1(vec)) = TypedDataset.create(Seq(X1(data))).collect().run() + Prop(vec == data) + } + + check(prop[Boolean]) + check(prop[Byte]) + check(prop[Short]) + check(prop[Int]) + check(prop[Long]) + check(prop[Float]) + check(prop[Double]) + check(prop[X1[String]]) + + } + test("not alligned columns should throw an exception") { val v = Vector(X2(1,2)) val df = TypedDataset.create(v).dataset.toDF() diff --git a/dataset/src/test/scala/frameless/SelectExprTests.scala b/dataset/src/test/scala/frameless/SelectExprTests.scala index 1cb3ee106..cf38b8a53 100644 --- a/dataset/src/test/scala/frameless/SelectExprTests.scala +++ b/dataset/src/test/scala/frameless/SelectExprTests.scala @@ -58,6 +58,16 @@ class SelectExprTests extends TypedDatasetSuite with Matchers { ((11.0, "foofoo"), 10.0), ((22.0, "barfoo"), 15.0) ) + + val ds3 = ds.selectExpr(x => (x.a, x.c.length)) + } + + test("substitute length() on Array and Vector") { + val ds = TypedDataset.create(Seq( + X2(Array(1, 2, 3), Vector(1, 2)) + )) + val ds2 = ds.selectExpr(x => (x.a.length, x.b.length)) + ds2.collect().run() should contain theSameElementsAs Seq((3, 2)) } test("selectExpr constructing a case class") { From 1e6e019508c890cbaa384c5a7e26250128704153 Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Tue, 21 Feb 2017 16:00:29 -0800 Subject: [PATCH 07/10] Added functions to selectExpr `selectExpr` now supports functions. The scheme I've come up with has a mirror of each Spark column function, but with a type signature that describes the actual arguments and result type of the function. These mirrors (in frameless.functions.quoted) are annotated with a variant of QuotedFunc, which specifies which Spark function they should rewrite to. The macro picks up these annotations and allows the function call to be rewritten to the native spark function once the column expressions are refolded. Roughly half of functions are mirrored at this point; need to finish porting aggregate and misc functions. 
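For example (an illustrative sketch only; the dataset and every call below mirror the new SelectExprTests cases, with `X4` being the existing test fixture):

```scala
import frameless.functions.quoted._

val ds = TypedDataset.create(Seq(X4(10, 20.0, "foo", 10L), X4(-20, 30.0, "bar", 20L)))

// Each mirror is annotated with a QuotedFunc variant; the macro refolds the call into
// the corresponding org.apache.spark.sql.functions call over the selected columns.
val absolutes = ds.selectExpr(x => abs(x.a))                                    // functions.abs
val roots     = ds.selectExpr(x => sqrt(x.b))                                   // functions.sqrt
val labelled  = ds.selectExpr(x => when(x.a == 10, "ten") otherwise "not ten")  // functions.when / otherwise
```

The `x.a == 10` comparison works because the macro also special-cases `==`, rewriting it to Spark's `===`.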
--- .../main/scala/frameless/TypedEncoder.scala | 185 ++++++++--- .../scala/frameless/functions/quoted.scala | 297 ++++++++++++++++++ .../scala/frameless/macros/ColumnMacros.scala | 90 +++++- .../test/scala/frameless/CreateTests.scala | 22 ++ .../scala/frameless/SelectExprTests.scala | 43 ++- 5 files changed, 577 insertions(+), 60 deletions(-) create mode 100644 dataset/src/main/scala/frameless/functions/quoted.scala diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index 4df5d4c88..204763f0a 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -4,11 +4,12 @@ import org.apache.spark.sql.FramelessInternals import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import shapeless._ +import scala.collection.Map import scala.reflect.ClassTag abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Serializable { @@ -28,16 +29,18 @@ abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Seria // - import TypedEncoder.usingInjection // - import TypedEncoder.usingDerivation // """) -object TypedEncoder { +object TypedEncoder extends TypedEncoder0 { def apply[T: TypedEncoder]: TypedEncoder[T] = implicitly[TypedEncoder[T]] implicit val unitEncoder: TypedEncoder[Unit] = new TypedEncoder[Unit] { def nullable: Boolean = true def sourceDataType: DataType = ScalaReflection.dataTypeFor[Unit] + def targetDataType: DataType = NullType def constructorFor(path: Expression): Expression = Literal.create((), sourceDataType) + def extractorFor(path: Expression): Expression = Literal.create(null, targetDataType) } @@ -45,6 +48,7 @@ object TypedEncoder { def nullable: Boolean = true def sourceDataType: DataType = FramelessInternals.objectTypeFor[String] + def targetDataType: DataType = StringType def extractorFor(path: Expression): Expression = @@ -58,9 +62,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = BooleanType + def targetDataType: DataType = BooleanType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -68,9 +74,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = IntegerType + def targetDataType: DataType = IntegerType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -78,9 +86,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = LongType + def targetDataType: DataType = LongType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -88,9 +98,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ShortType + def targetDataType: DataType = ShortType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -99,6 +111,7 @@ object TypedEncoder { implicit val charAsString: Injection[java.lang.Character, String] = new Injection[java.lang.Character, String] { def apply(a: java.lang.Character): String = 
String.valueOf(a) + def invert(b: String): java.lang.Character = { require(b.length == 1) b.charAt(0) @@ -111,9 +124,11 @@ object TypedEncoder { // this line fixes underlying encoder def sourceDataType: DataType = FramelessInternals.objectTypeFor[java.lang.Character] + def targetDataType: DataType = StringType def extractorFor(path: Expression): Expression = underlying.extractorFor(path) + def constructorFor(path: Expression): Expression = underlying.constructorFor(path) } @@ -121,9 +136,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ByteType + def targetDataType: DataType = ByteType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -131,9 +148,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = FloatType + def targetDataType: DataType = FloatType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -141,9 +160,11 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = DoubleType + def targetDataType: DataType = DoubleType def extractorFor(path: Expression): Expression = path + def constructorFor(path: Expression): Expression = path } @@ -151,6 +172,7 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ScalaReflection.dataTypeFor[BigDecimal] + def targetDataType: DataType = DecimalType.SYSTEM_DEFAULT def extractorFor(path: Expression): Expression = @@ -164,6 +186,7 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ScalaReflection.dataTypeFor[SQLDate] + def targetDataType: DataType = DateType def extractorFor(path: Expression): Expression = @@ -183,6 +206,7 @@ object TypedEncoder { def nullable: Boolean = false def sourceDataType: DataType = ScalaReflection.dataTypeFor[SQLTimestamp] + def targetDataType: DataType = TimestampType def extractorFor(path: Expression): Expression = @@ -205,6 +229,7 @@ object TypedEncoder { def nullable: Boolean = true def sourceDataType: DataType = FramelessInternals.objectTypeFor[Option[A]](classTag) + def targetDataType: DataType = underlying.targetDataType def extractorFor(path: Expression): Expression = { @@ -254,15 +279,15 @@ object TypedEncoder { WrapOption(underlying.constructorFor(path), underlying.sourceDataType) } - implicit def vectorEncoder[A]( + implicit def vectorEncoder[A : ClassTag]( implicit underlying: TypedEncoder[A] ): TypedEncoder[Vector[A]] = new TypedEncoder[Vector[A]]() { - def nullable: Boolean = false + val nullable: Boolean = false - def sourceDataType: DataType = FramelessInternals.objectTypeFor[Vector[A]](classTag) + val sourceDataType: DataType = FramelessInternals.objectTypeFor[Vector[A]](classTag) - def targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) + val targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) def constructorFor(path: Expression): Expression = { val arrayData = Option(underlying.sourceDataType) @@ -270,25 +295,25 @@ object TypedEncoder { .filter(_ == underlying.targetDataType) .collect { case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]] - case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] - case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] - case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] - case LongType => "toLongArray" -> 
ScalaReflection.dataTypeFor[Array[Long]] - case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] - case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] + case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] + case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] + case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] + case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]] + case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] + case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] }.map { - case (method, typ) => Invoke(path, method, typ) - }.getOrElse { - Invoke( - MapObjects( - underlying.constructorFor, - path, - underlying.targetDataType - ), - "array", - ScalaReflection.dataTypeFor[Array[AnyRef]] - ) - } + case (method, typ) => Invoke(path, method, typ) + }.getOrElse { + Invoke( + MapObjects( + underlying.constructorFor, + path, + underlying.targetDataType + ), + "array", + FramelessInternals.objectTypeFor[Array[A]] + ) + } StaticInvoke( TypedEncoderUtils.getClass, @@ -312,16 +337,15 @@ object TypedEncoder { } } - implicit def arrayEncoder[A]( + implicit def arrayEncoder[A : ClassTag]( implicit - underlying: TypedEncoder[A], - classTag: ClassTag[Array[A]] + underlying: TypedEncoder[A] ): TypedEncoder[Array[A]] = new TypedEncoder[Array[A]]() { - def nullable: Boolean = false + val nullable: Boolean = false - def sourceDataType: DataType = FramelessInternals.objectTypeFor[Array[A]](classTag) + val sourceDataType: DataType = FramelessInternals.objectTypeFor[Array[A]](classTag) - def targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) + val targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType) def constructorFor(path: Expression): Expression = { Option(underlying.sourceDataType) @@ -329,25 +353,25 @@ object TypedEncoder { .filter(_ == underlying.targetDataType) .collect { case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]] - case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] - case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] - case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] - case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]] - case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] - case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] + case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]] + case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]] + case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]] + case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]] + case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]] + case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]] }.map { - case (method, typ) => Invoke(path, method, typ) - }.getOrElse { - Invoke( - MapObjects( - underlying.constructorFor, - path, - underlying.targetDataType - ), - "array", - ScalaReflection.dataTypeFor[Array[AnyRef]] - ) - } + case (method, typ) => Invoke(path, method, typ) + }.getOrElse { + Invoke( + MapObjects( + underlying.constructorFor, + path, + underlying.targetDataType + ), + "array", + sourceDataType + ) + } } def extractorFor(path: Expression): 
Expression = { @@ -364,6 +388,71 @@ object TypedEncoder { } } + implicit def mapEncoder[A : ClassTag, B : ClassTag]( + implicit + encodeA: TypedEncoder[A], + encodeB: TypedEncoder[B] + ): TypedEncoder[scala.collection.immutable.Map[A, B]] = new TypedEncoder[scala.collection.immutable.Map[A, B]] { + val nullable: Boolean = false + val sourceDataType = FramelessInternals.objectTypeFor[Map[A, B]] + val targetDataType = MapType(encodeA.targetDataType, encodeB.targetDataType, encodeB.nullable) + + implicit val classTagArrayA = implicitly[ClassTag[A]].wrap + implicit val classTagArrayB = implicitly[ClassTag[B]].wrap + + private val arrayA = arrayEncoder[A] + private val arrayB = arrayEncoder[B] + private val vectorA = vectorEncoder[A] + private val vectorB = vectorEncoder[B] + + + private def wrap(arrayData: Expression) = { + StaticInvoke( + scala.collection.mutable.WrappedArray.getClass, + FramelessInternals.objectTypeFor[Seq[_]], + "make", + arrayData :: Nil) + } + + def constructorFor(path: Expression): Expression = { + val keyArrayType = ArrayType(encodeA.targetDataType, false) + val keyData = wrap(arrayA.constructorFor(Invoke(path, "keyArray", keyArrayType))) + + val valueArrayType = ArrayType(encodeB.targetDataType, encodeB.nullable) + val valueData = wrap(arrayB.constructorFor(Invoke(path, "valueArray", valueArrayType))) + + StaticInvoke( + ArrayBasedMapData.getClass, + sourceDataType, + "toScalaMap", + keyData :: valueData :: Nil) + } + + def extractorFor(path: Expression): Expression = { + val keys = + Invoke( + Invoke(path, "keysIterator", FramelessInternals.objectTypeFor[scala.collection.Iterator[A]]), + "toVector", + vectorA.sourceDataType) + val convertedKeys = arrayA.extractorFor(keys) + + val values = + Invoke( + Invoke(path, "valuesIterator", FramelessInternals.objectTypeFor[scala.collection.Iterator[B]]), + "toVector", + vectorB.sourceDataType) + val convertedValues = arrayB.extractorFor(values) + + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = targetDataType) + } + + } +} + +sealed trait TypedEncoder0 { self: TypedEncoder.type => /** Encodes things using injection if there is one defined */ implicit def usingInjection[A: ClassTag, B] (implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] = diff --git a/dataset/src/main/scala/frameless/functions/quoted.scala b/dataset/src/main/scala/frameless/functions/quoted.scala new file mode 100644 index 000000000..78379a472 --- /dev/null +++ b/dataset/src/main/scala/frameless/functions/quoted.scala @@ -0,0 +1,297 @@ +package frameless.functions + +import scala.annotation.{StaticAnnotation, compileTimeOnly} +import scala.math.BigDecimal.RoundingMode +import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.expressions.{Factorial, Hex} +import org.apache.spark.sql.catalyst.util.NumberConverter +import org.apache.spark.sql.{Column, functions => sf} + +/** + * Quoted functions, for being used inside an expression + * Functions annotated with @sparkFunction are rewritten accordingly, whereas other functions are converted to UDFs and + * actually applied. + * + * Note that these functions have simple implementations in order to provide explanation of what they do, and so that + * it won't be catastrophic to call them in the wrong context. But the implementations ought to never be invoked + * at run time. 
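+ *
+ * For example (illustrative only; the concrete cases live in SelectExprTests), a call such as
+ * `ds.selectExpr(x => abs(x.a))` is refolded by the macro into `org.apache.spark.sql.functions.abs`
+ * over the underlying column, while an accidental call outside a quoted expression simply runs
+ * the plain Scala body defined here.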
+ */ +object quoted { + + // $COVERAGE-OFF$ - these functions will never actually be invoked at runtime + + sealed trait QuotedFunc + class sparkFunction0[B](referrent: () => B) extends StaticAnnotation with QuotedFunc + class sparkFunction[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc + class sparkColumnOp[A, B](referrent: Column => A => B) extends StaticAnnotation with QuotedFunc + class sparkAggregate[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc + class sparkWindowFunction[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc + + @sparkFunction(sf.abs) + def abs[T : Numeric](t: T): T = implicitly[Numeric[T]].abs(t) + + @sparkFunction((cols: Seq[Column]) => sf.array(cols:_*)) + def array[T : ClassTag](items: T*): Array[T] = items.toArray + + @sparkFunction((cols: Seq[(Column,Column)]) => sf.map(cols.flatMap(a => Seq(a._1, a._2)):_*)) + def map[T, U](items: (T, U)*): Map[T, U] = items.toMap + + @sparkFunction((cols: Seq[Column]) => sf.coalesce(cols:_*)) + def coalesce[T](items: T*): T = items.find(_ != null).getOrElse(null.asInstanceOf[T]) + + @sparkFunction0(sf.input_file_name) + def input_file_name(): String = "invoked outside of quoted expression" + + @sparkFunction(sf.isnan) + def isnan[T: Numeric](t: T): Boolean = implicitly[Numeric[T]].toDouble(t).isNaN + + @sparkFunction(sf.isnull) + def isnull[T](t: T): Boolean = t == null + + private lazy val longs = Stream.range(0L, Long.MaxValue).iterator + @sparkFunction0(sf.monotonically_increasing_id) + def monotonically_increasing_id(): Long = longs.next() + + @sparkFunction((sf.nanvl _).tupled) + def nanvl(first: Double, second: Double): Double = if(java.lang.Double.isNaN(first)) second else first + + @sparkFunction((sf.nanvl _).tupled) + def nanvl(first: Float, second: Float): Float = if(java.lang.Float.isNaN(first)) second else first + + @sparkFunction(sf.negate) + def negate[T : Numeric](t: T): T = implicitly[Numeric[T]].negate(t) + + @sparkFunction(sf.not) + def not(b: Boolean): Boolean = !b + + @sparkFunction(sf.rand) + def rand(seed: Long): Double = { + scala.util.Random.setSeed(seed) + scala.util.Random.nextDouble() + } + + @sparkFunction0(sf.rand) + def rand(): Double = scala.util.Random.nextDouble() + + @sparkFunction(sf.randn) + def randn(seed: Long): Double = { + scala.util.Random.setSeed(seed) + scala.util.Random.nextGaussian() + } + + @sparkFunction0(sf.randn) + def randn(): Double = scala.util.Random.nextGaussian() + + @sparkFunction0(sf.spark_partition_id) + def spark_partition_id(): Int = -1 + + @sparkFunction(sf.sqrt(_: Column)) + def sqrt[T](t: T)(implicit frac: Fractional[T]): Double = math.sqrt(frac.toDouble(t)) + + + case class CaseBuilder[T]() { + @sparkColumnOp(col => (col.when _).tupled) + def when(condition: Boolean, value: T): CaseBuilder[T] = this + + @sparkColumnOp(_.otherwise) + def otherwise(value: T): T = value + } + + @sparkFunction((sf.when _).tupled) + def when[T](condition: Boolean, value: T): CaseBuilder[T] = CaseBuilder[T]() + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(b: Byte): Byte = (~b).toByte + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(s: Short): Short = (~s).toShort + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(i: Int): Int = ~i + + @sparkFunction(sf.bitwiseNOT) + def bitwiseNOT(l: Long): Long = ~l + + @sparkFunction(sf.acos(_:Column)) + def acos(d: Double): Double = scala.math.acos(d) + + @sparkFunction(sf.asin(_:Column)) + def asin(d: Double): Double = scala.math.asin(d) + + @sparkFunction(sf.atan(_:Column)) + def atan(d: 
Double): Double = scala.math.atan(d) + + @sparkFunction((sf.atan2(_: Column, _: Column)).tupled) + def atan2(l: Double, r: Double): Double = scala.math.atan2(l, r) + + @sparkFunction(sf.bin(_:Column)) + def bin(l: Long): String = java.lang.Long.toBinaryString(l) + + @sparkFunction(sf.cbrt(_:Column)) + def cbrt(d: Double): Double = math.cbrt(d) + + @sparkFunction(sf.ceil(_:Column)) + def ceil(d: Double): Double = math.ceil(d) + + @sparkFunction((sf.conv _).tupled) + def conv(num: String, fromBase: Int, toBase: Int): String = + NumberConverter.convert(num.getBytes(), fromBase, toBase).toString + + @sparkFunction(sf.cos(_:Column)) + def cos(n: Double): Double = math.cos(n) + + @sparkFunction(sf.cosh(_:Column)) + def cosh(n: Double): Double = math.cosh(n) + + @sparkFunction(sf.exp(_:Column)) + def exp(n: Double): Double = math.exp(n) + + @sparkFunction(sf.expm1(_:Column)) + def expm1(n: Double): Double = math.expm1(n) + + @sparkFunction(sf.factorial) + def factorial(n: Int): Long = Factorial.factorial(n) + + @sparkFunction(sf.floor(_:Column)) + def floor(d: Double): Double = math.floor(d) + + @sparkFunction((cols: Seq[Column]) => sf.greatest(cols:_*)) + def greatest[T : Ordering](first: T, rest: T*): T = rest.foldLeft(first)(implicitly[Ordering[T]].max) + + @sparkFunction(sf.hex) + def hex(arr: Array[Byte]): String = Hex.hex(arr).toString + + @sparkFunction(sf.hex) + def hex(l: Long): String = Hex.hex(l).toString + + @sparkFunction(sf.unhex) + def unhex(str: String): Array[Byte] = Hex.unhex(str.getBytes()) + + @sparkFunction((sf.hypot(_:Column,_:Column)).tupled) + def hypot(a: Double, b: Double): Double = math.hypot(a, b) + + @sparkFunction((cols: Seq[Column]) => sf.least(cols:_*)) + def least[T : Ordering](first: T, rest: T*) = rest.foldLeft(first)(implicitly[Ordering[T]].min) + + @sparkFunction(sf.log(_:Column)) + def log(n: Double): Double = math.log(n) + + @sparkFunction((sf.log(_:Double,_:Column)).tupled) + def log(base: Double, n: Double): Double = math.log(n) / math.log(base) + + @sparkFunction(sf.log10(_:Column)) + def log10(n: Double): Double = math.log10(n) + + @sparkFunction(sf.log1p(_:Column)) + def log1p(n: Double): Double = math.log1p(n) + + @sparkFunction(sf.log2(_:Column)) + def log2(n: Double): Double = log(n, 2.0) + + @sparkFunction((sf.pow(_:Column,_:Column)).tupled) + def pow(base: Double, exp: Double): Double = math.pow(base, exp) + + @sparkFunction((sf.pmod _).tupled) + def pmod[T](dividend: T, divisor: T)(implicit int: Integral[T]): T = { + val rem = int.rem(dividend, divisor) + if(int.lt(rem, int.zero)) + int.plus(rem, divisor) + else + rem + } + + @sparkFunction(sf.rint(_:Column)) + def rint(n: Double): Double = math.rint(n) + + @sparkFunction(sf.round(_:Column)) + def round(n: Double): Long = math.round(n) + + @sparkFunction(sf.round(_:Column)) + def round(n: Float): Int = math.round(n) + + @sparkFunction((sf.round(_:Column, _:Int)).tupled) + def round(n: Double, scale: Int): Double = BigDecimal(n).setScale(scale, RoundingMode.HALF_UP).toDouble + + @sparkFunction(sf.bround(_:Column)) + def bround(n: Double): Long = BigDecimal(n).setScale(0, RoundingMode.HALF_EVEN).toLong + + @sparkFunction((sf.bround(_:Column, _:Int)).tupled) + def bround(n: Double, scale: Int): Double = BigDecimal(n).setScale(scale, RoundingMode.HALF_EVEN).toDouble + + @sparkFunction((sf.shiftLeft _).tupled) + def shiftLeft(l: Long, bits: Int): Long = l << bits + + @sparkFunction((sf.shiftLeft _).tupled) + def shiftLeft(i: Int, bits: Int): Int = i << bits + + @sparkFunction((sf.shiftRight _).tupled) 
+ def shiftRight(l: Long, bits: Int): Long = l >> bits + + @sparkFunction((sf.shiftRight _).tupled) + def shiftRight(i: Int, bits: Int): Int = i >> bits + + @sparkFunction((sf.shiftRightUnsigned _).tupled) + def shiftRightUnsigned(l: Long, bits: Int): Long = l >>> bits + + @sparkFunction((sf.shiftRightUnsigned _).tupled) + def shiftRightUnsigned(i: Int, bits: Int): Int = i >>> bits + + @sparkFunction(sf.signum(_:Column)) + def signum[T : Numeric](t: T): Int = implicitly[Numeric[T]].signum(t) + + @sparkFunction(sf.sin(_:Column)) + def sin(d: Double): Double = math.sin(d) + + @sparkFunction(sf.sinh(_:Column)) + def sinh(d: Double): Double = math.sinh(d) + + @sparkFunction(sf.tan(_:Column)) + def tan(d: Double): Double = math.tan(d) + + @sparkFunction(sf.tanh(_:Column)) + def tanh(d: Double): Double = math.tanh(d) + + @sparkFunction(sf.toDegrees(_:Column)) + def toDegrees(d: Double): Double = math.toDegrees(d) + + @sparkFunction(sf.toRadians(_:Column)) + def toRadians(d: Double): Double = math.toRadians(d) + + ///////////////////////// + // Aggregate functions // + ///////////////////////// + private val aggMsg = "Aggregate function can only be used inside a TypedDataset expression" + + @sparkAggregate(sf.approxCountDistinct(_:Column)) + @compileTimeOnly(aggMsg) + def approxCountDistinct[T](t: T): Long = ??? + + @sparkAggregate((sf.approxCountDistinct(_:Column,_:Double)).tupled) + @compileTimeOnly(aggMsg) + def approxCountDistinct[T](a: T, b: Double): Long = ??? + + @sparkAggregate(sf.avg(_:Column)) + @compileTimeOnly(aggMsg) + def avg[T : Numeric](t: T): Double = ??? + + @sparkAggregate(sf.collect_list(_:Column)) + @compileTimeOnly(aggMsg) + def collect_list[T](t: T): Seq[T] = ??? + + @sparkAggregate(sf.collect_set(_:Column)) + @compileTimeOnly(aggMsg) + def collect_set[T](t: T): Seq[T] = ??? + + @sparkAggregate((sf.corr(_:Column,_:Column)).tupled) + @compileTimeOnly(aggMsg) + def corr(a: Double, b: Double): Double = ??? + + @sparkAggregate(sf.count(_:Column)) + @compileTimeOnly(aggMsg) + def count[T](t: T): Long = ??? + + @sparkAggregate(((col: Column, cols: Seq[Column]) => sf.countDistinct(col, cols:_*)).tupled) + @compileTimeOnly(aggMsg) + def countDistinct(columns: Any*): Long = ??? 
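+
+  // Note (illustrative; exercised in SelectExprTests): the aggregate mirrors above are
+  // @compileTimeOnly because they only make sense inside a quoted expression, e.g.
+  // ds.selectExpr(x => count(x.a)), where the macro refolds them into the corresponding
+  // Spark aggregate.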
+ +} diff --git a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala index fd1ab2863..4558fda06 100644 --- a/dataset/src/main/scala/frameless/macros/ColumnMacros.scala +++ b/dataset/src/main/scala/frameless/macros/ColumnMacros.scala @@ -1,7 +1,8 @@ package frameless.macros +import frameless.functions.quoted.QuotedFunc import frameless.{TypedColumn, TypedEncoder} -import org.apache.spark.sql.ColumnName +import org.apache.spark.sql, sql.ColumnName import scala.collection.immutable.Queue import scala.reflect.macros.{TypecheckException, whitebox} @@ -121,10 +122,12 @@ class ColumnMacros(val c: whitebox.Context) { case LiteralExpr(tpe, tree) => q"org.apache.spark.sql.functions.lit($tree)" case Column(tpe, selectorStr) => q"new $ColumnName($selectorStr)" - // $COVERAGE-OFF$ - Functions can't currently be used (and currently give an error before this is reached) case FunctionApplication(tpe, src, fn, args) => - c.abort(c.enclosingPosition, "Functions not yet supported in selectExpr") - // $COVERAGE-ON$ + val argTrees = args.map { + case Left(tree) => tree + case Right(argExpr) => Expr.refold(argExpr) + } + q"$src.$fn(..$argTrees)" case BinaryOperator(tpe, lhs, op, rhs) => val lhsTree = refold(lhs) val rhsTree = refold(rhs) @@ -148,7 +151,7 @@ class ColumnMacros(val c: whitebox.Context) { case class Construct(tpe: Type, args: Map[String, Expr]) extends Expr // Construct tuple or case class case class UnaryRAOperator(tpe: Type, rhs: Expr, op: TermName) extends Expr // Unary right-associative op case class BinaryOperator(tpe: Type, lhs: Expr, op: TermName, rhs: Expr) extends Expr // Binary op - case class FunctionApplication(tpe: Type, src: Tree, fn: TermName, args: List[Expr]) extends Expr // SQL function + case class FunctionApplication(tpe: Type, src: Tree, fn: TermName, args: List[Either[Tree, Expr]]) extends Expr // SQL function case class Column(tpe: Type, selectorStr: String) extends Expr // Column of this dataset case class LiteralExpr(tpe: Type, tree: Tree) extends Expr // Literal expression @@ -250,23 +253,71 @@ class ColumnMacros(val c: whitebox.Context) { ) }.getOrElse { if (isColumnOp(op, 1)) { + // special case - substitute === for == + val realOp = if(op.decodedName.toString == "==") TermName("===").encodedName.toTermName else op requireBoth(lhsE, rhsE)( (lhs, rhs) => failFrom(lhs, rhs), - (lhs, rhs) => Right(BinaryOperator(tree.tpe, lhs, op, rhs)) + (lhs, rhs) => Right(BinaryOperator(tree.tpe, lhs, realOp, rhs)) ) } else addError(lhsE, rhsE)(tree, s"${op.decodedName} is not a valid column operator") } } - // A function application - what to do with this? How can we check if it's an OK function? - // Check org.apache.spark.sql.functions? 
- // I think we have to port all spark functions to typed versions so we can typecheck here - case Apply(Select(src, fn: TermName), AllExprs(argsE)) => - Some(Left(List((tree, "Function application not currently supported")))) + // A function application + case Apply(fn @ (Select(_, _) | TypeApply(_, _)), argTrees) => Some { + // if the function is annotated with a QuotedFunc, it gets rewritten to the native spark function + fn.symbol.annotations.find(_.tree.tpe <:< weakTypeOf[QuotedFunc]).map { + annot => annot.tree match { + case Apply(Select(New(tpt), _), List(sparkFuncTree)) => + val Function(sparkFnArgs, body) = sparkFuncTree match { + case SingleExpression(f@Function(_, _)) => f + case Select(SingleExpression(f @ Function(_, _)), TermName("tupled")) => f + case other => + val o = other + println(o) + EmptyTree + } + + + val sparkArgs = if(sparkFnArgs.nonEmpty) { + sparkFnArgs.zipAll(argTrees.map(Some(_)), sparkFnArgs.last, None).flatMap { + case (_, None) => Nil + case (ValDef(_, _, typ, _), Some(This(Right(expr)))) if typ.tpe <:< weakTypeOf[sql.Column] => + List(Right(expr)) + case (ValDef(_, _, typ, _), Some(This(Right(expr)))) if typ.tpe <:< weakTypeOf[Seq[sql.Column]] => + List(Right(expr)) + case (ValDef(_, _, typ, _), Some(Tuple2Tree(This(Right(a)), This(Right(b))))) + if typ.tpe <:< weakTypeOf[Seq[(sql.Column, sql.Column)]] => + List(Right(a), Right(b)) + case (arg, Some(argTree)) => + List(Left(argTree)) + } + } else { + Nil + } + + val (src, term) = body match { + case Apply(Select(q, t), _) => (q, t) + } + + Right(FunctionApplication(tree.tpe, src, term.toTermName, sparkArgs)) + case other => + println(other) + Left(List((tree, "Function application not currently supported"))) + } + }.getOrElse { + Left(List((tree, "Function application not currently supported"))) + } + + } //Some(argsE.right.map(args => FunctionApplication(tree.tpe, src, fn, args))) + case Apply(This(fn), implicitArgs) => Some(fn) + case Apply(_, _) => Some(Left(List((tree, "Function application not currently supported")))) + case Typed(This(expr), _) => Some(expr) + case _ => None } @@ -374,4 +425,21 @@ class ColumnMacros(val c: whitebox.Context) { object ArgName { def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name)) } + + object Tuple2Tree { + def unapply(tree: Tree): Option[(Tree, Tree)] = tree match { + case q"($a, $b)" => Some((a, b)) + case q"$predef.ArrowAssoc[..$tptA]($a).->[..$tptB]($b)" => Some((a, b)) + case other => + None + } + } + + object SingleExpression { + def unapply(tree: Tree): Option[Tree] = tree match { + case Block(Nil, expr) => Some(expr) + case Block(_, _) => None + case other => Some(other) + } + } } diff --git a/dataset/src/test/scala/frameless/CreateTests.scala b/dataset/src/test/scala/frameless/CreateTests.scala index cdc38c170..d9088c913 100644 --- a/dataset/src/test/scala/frameless/CreateTests.scala +++ b/dataset/src/test/scala/frameless/CreateTests.scala @@ -44,6 +44,7 @@ class CreateTests extends TypedDatasetSuite with Matchers { check(prop[Float]) check(prop[Double]) check(prop[X1[String]]) + check(prop[String]) } @@ -63,6 +64,27 @@ class CreateTests extends TypedDatasetSuite with Matchers { check(prop[Float]) check(prop[Double]) check(prop[X1[String]]) + check(prop[String]) + + } + + test("map fields") { + + def prop[A, B](implicit arb: Arbitrary[Map[A, B]], encoder: TypedEncoder[X1[Map[A, B]]]) = forAll { + data: Map[A, B] => + val Seq(X1(map)) = TypedDataset.create(Seq(X1(data))).collect().run() + Prop(map == data) + } + + check(prop[String, 
Boolean]) + check(prop[String, Byte]) + check(prop[String, Short]) + check(prop[String, Int]) + check(prop[String, Long]) + check(prop[String, Float]) + check(prop[String, Double]) + check(prop[String, X1[String]]) + check(prop[String, String]) } diff --git a/dataset/src/test/scala/frameless/SelectExprTests.scala b/dataset/src/test/scala/frameless/SelectExprTests.scala index cf38b8a53..af368c81b 100644 --- a/dataset/src/test/scala/frameless/SelectExprTests.scala +++ b/dataset/src/test/scala/frameless/SelectExprTests.scala @@ -136,8 +136,49 @@ class SelectExprTests extends TypedDatasetSuite with Matchers { } + test("quoted functions expanded to spark native column functions") { + import frameless.functions.quoted._ + val ds = TypedDataset.create(Seq(X4(10, 20.0, "foo", 10L), X4(-20, 30.0, "bar", 20L))) + + def test[U : TypedEncoder](expected: U*)(fn: TypedDataset[X4[Int, Double, String, Long]] => TypedDataset[U]) = + fn(ds).collect().run() should contain theSameElementsAs expected + + def testConv[T, U : TypedEncoder](expected: T*)(fn: TypedDataset[X4[Int, Double, String, Long]] => TypedDataset[U])(conv: U => T) = + fn(ds).collect().run().map(conv) should contain theSameElementsAs expected + + test(10, 20)(_.selectExpr(x => abs(x.a))) + testConv(List(10, 10, 10), List(-20, -20, -20))(_.selectExpr(x => array(x.a, x.a, x.a)))(_.toList) + test(Map("foo" -> 10), Map("bar" -> -20))(_.selectExpr(x => map(x.c -> x.a))) + test("foo", "bar")(_.selectExpr(x => coalesce(null, x.c))) + test(("foo", ""), ("bar", ""))(_.selectExpr(x => (x.c, input_file_name()))) + test((false, true), (false, true))(_.selectExpr(x => (isnan(x.b), isnan(Double.NaN)))) + test((false, true), (false, true))(_.selectExpr(x => (isnull(x.c), isnull(null: String)))) + test(0L, 1L)(_.selectExpr(x => monotonically_increasing_id())) + test(20.0, 30.0)(_.selectExpr(x => nanvl(Double.NaN, x.b))) + test(-10, 20)(_.selectExpr(x => negate(x.a))) + + // no way to compare output, but test to make sure the functions can be executed + ds.selectExpr(x => (x.a, rand(10L), rand(), randn(10L), randn())).collect().run() + + test(math.sqrt(20.0), math.sqrt(30.0))(_.selectExpr(x => sqrt(x.b))) + test("ten", "not ten")(_.selectExpr(x => when(x.a == 10, "ten") otherwise "not ten")) + test(~10, ~(-20))(_.selectExpr(x => bitwiseNOT(x.a))) + test( + (math.acos(20.0 / 32.0), math.asin(20.0 / 32.0), math.atan(20.0), math.atan2(20.0, 20.0)), + (math.acos(30.0 / 32.0), math.asin(30.0 / 32.0), math.atan(30.0), math.atan2(30.0, 30.0)) + )(_.selectExpr(x => (acos(x.b / 32), asin(x.b / 32), atan(x.b), atan2(x.b, x.b)))) + test(java.lang.Long.toBinaryString(10L), java.lang.Long.toBinaryString(20L))(_.selectExpr(x => bin(x.d))) + + // aggregate functions + test(2L)(_.selectExpr(x => count(x.a))) + test(2L)(_.selectExpr(x => countDistinct(x.a))) + test(2L)(_.selectExpr(x => countDistinct(x.a, x.b))) + + illTyped("val a: Int = 22.22") + } + //TODO: could we just UDF the function in this case? 
- test("functions not yet supported") { + test("arbitrary functions not yet supported") { def strfun(s: String) = s"fun${s}fun" object Fun { From d5b84952400019625bea6b0b634bc640e27d216f Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Tue, 21 Feb 2017 16:18:36 -0800 Subject: [PATCH 08/10] Prioritize injection over derivation --- .../main/scala/frameless/TypedEncoder.scala | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index 204763f0a..63cd9560e 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -450,27 +450,28 @@ object TypedEncoder extends TypedEncoder0 { } } -} -sealed trait TypedEncoder0 { self: TypedEncoder.type => /** Encodes things using injection if there is one defined */ implicit def usingInjection[A: ClassTag, B] - (implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] = - new TypedEncoder[A] { - def nullable: Boolean = trb.nullable - def sourceDataType: DataType = FramelessInternals.objectTypeFor[A](classTag) - def targetDataType: DataType = trb.targetDataType - - def constructorFor(path: Expression): Expression = { - val bexpr = trb.constructorFor(path) - Invoke(Literal.fromObject(inj), "invert", sourceDataType, Seq(bexpr)) - } + (implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] = + new TypedEncoder[A] { + def nullable: Boolean = trb.nullable + def sourceDataType: DataType = FramelessInternals.objectTypeFor[A](classTag) + def targetDataType: DataType = trb.targetDataType + + def constructorFor(path: Expression): Expression = { + val bexpr = trb.constructorFor(path) + Invoke(Literal.fromObject(inj), "invert", sourceDataType, Seq(bexpr)) + } - def extractorFor(path: Expression): Expression = { - val invoke = Invoke(Literal.fromObject(inj), "apply", trb.sourceDataType, Seq(path)) - trb.extractorFor(invoke) - } + def extractorFor(path: Expression): Expression = { + val invoke = Invoke(Literal.fromObject(inj), "apply", trb.sourceDataType, Seq(path)) + trb.extractorFor(invoke) } + } +} + +sealed trait TypedEncoder0 { self: TypedEncoder.type => /** Encodes things as records if there is not Injection defined */ implicit def usingDerivation[F, G <: HList]( From df32af4bdfa40512b9a2fc14d43e230d1c970008 Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Wed, 22 Feb 2017 07:10:43 -0800 Subject: [PATCH 09/10] Added more aggregate functions and misc functions --- .../scala/frameless/functions/quoted.scala | 146 +++++++++++++++++- 1 file changed, 145 insertions(+), 1 deletion(-) diff --git a/dataset/src/main/scala/frameless/functions/quoted.scala b/dataset/src/main/scala/frameless/functions/quoted.scala index 78379a472..c29d955d3 100644 --- a/dataset/src/main/scala/frameless/functions/quoted.scala +++ b/dataset/src/main/scala/frameless/functions/quoted.scala @@ -26,6 +26,7 @@ object quoted { class sparkColumnOp[A, B](referrent: Column => A => B) extends StaticAnnotation with QuotedFunc class sparkAggregate[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc class sparkWindowFunction[A, B](referent: A => B) extends StaticAnnotation with QuotedFunc + class sparkWindowFunction0[B](referent: () => B) extends StaticAnnotation with QuotedFunc @sparkFunction(sf.abs) def abs[T : Numeric](t: T): T = implicitly[Numeric[T]].abs(t) @@ -171,7 +172,7 @@ object quoted { def hypot(a: Double, b: Double): Double = math.hypot(a, b) 
@sparkFunction((cols: Seq[Column]) => sf.least(cols:_*)) - def least[T : Ordering](first: T, rest: T*) = rest.foldLeft(first)(implicitly[Ordering[T]].min) + def least[T : Ordering](first: T, rest: T*): T = rest.foldLeft(first)(implicitly[Ordering[T]].min) @sparkFunction(sf.log(_:Column)) def log(n: Double): Double = math.log(n) @@ -294,4 +295,147 @@ object quoted { @compileTimeOnly(aggMsg) def countDistinct(columns: Any*): Long = ??? + @sparkAggregate((sf.covar_pop(_:Column,_:Column)).tupled) + @compileTimeOnly(aggMsg) + def covar_pop(a: Double, b: Double): Double = ??? + + @sparkAggregate((sf.covar_samp(_:Column,_:Column)).tupled) + @compileTimeOnly(aggMsg) + def covar_samp(a: Double, b: Double): Double = ??? + + @sparkAggregate((sf.first(_:Column,_:Boolean)).tupled) + @compileTimeOnly(aggMsg) + def first[T](col: T, ignoreNulls: Boolean) = ??? + + @sparkAggregate(sf.first(_:Column)) + @compileTimeOnly(aggMsg) + def first[T](col: T): T = first(col, ignoreNulls = false) + + @sparkAggregate(sf.grouping(_:Column)) + @compileTimeOnly(aggMsg) + def grouping(col: Any): Int = ??? + + @sparkAggregate((cols: Seq[Column]) => sf.grouping_id(cols:_*)) + @compileTimeOnly(aggMsg) + def grouping_id(cols: Any*): Int = ??? + + @sparkAggregate(sf.kurtosis(_:Column)) + @compileTimeOnly(aggMsg) + def kurtosis(col: Double): Double = ??? + + @sparkAggregate((sf.last(_:Column,_:Boolean)).tupled) + @compileTimeOnly(aggMsg) + def last[T](col: T, ignoreNulls: Boolean) = ??? + + @sparkAggregate(sf.last(_:Column)) + @compileTimeOnly(aggMsg) + def last[T](col: T): T = last(col, ignoreNulls = false) + + @sparkAggregate(sf.max(_:Column)) + @compileTimeOnly(aggMsg) + def max[T : Ordering](col: T): T = ??? + + @sparkAggregate(sf.mean(_:Column)) + @compileTimeOnly(aggMsg) + def mean[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.min(_:Column)) + @compileTimeOnly(aggMsg) + def min[T : Ordering](col: T): T = ??? + + @sparkAggregate(sf.skewness(_:Column)) + @compileTimeOnly(aggMsg) + def skewness(col: Double): Double = ??? + + @sparkAggregate(sf.stddev(_:Column)) + @compileTimeOnly(aggMsg) + def stddev(col: Double): Double = ??? + + @sparkAggregate(sf.stddev_samp(_:Column)) + @compileTimeOnly(aggMsg) + def stddev_samp(col: Double): Double = ??? + + @sparkAggregate(sf.stddev_pop(_:Column)) + @compileTimeOnly(aggMsg) + def stddev_pop(col: Double): Double = ??? + + @sparkAggregate(sf.sum(_:Column)) + @compileTimeOnly(aggMsg) + def sum[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.sumDistinct(_:Column)) + @compileTimeOnly(aggMsg) + def sumDistinct[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.variance(_:Column)) + @compileTimeOnly(aggMsg) + def variance[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.var_samp(_:Column)) + @compileTimeOnly(aggMsg) + def var_samp[T : Numeric](col: T): T = ??? + + @sparkAggregate(sf.var_pop(_:Column)) + @compileTimeOnly(aggMsg) + def var_pop[T : Numeric](col: T): T = ??? + + ////////////////////// + // Window functions // + ////////////////////// + + @sparkWindowFunction0(sf.cume_dist) + @compileTimeOnly(aggMsg) + def cume_dist(): Double = ??? + + @sparkWindowFunction0(sf.dense_rank) + @compileTimeOnly(aggMsg) + def dense_rank(): Double = ??? + + @sparkWindowFunction((sf.lag(_:Column,_:Int)).tupled) + @compileTimeOnly(aggMsg) + def lag[T](col: T, offset: Int): T = ??? + + @sparkWindowFunction((sf.lead(_:Column,_:Int)).tupled) + @compileTimeOnly(aggMsg) + def lead[T](col: T, offset: Int): T = ??? 
+ + @sparkWindowFunction(sf.ntile) + @compileTimeOnly(aggMsg) + def ntile(n: Int): Int = ??? + + @sparkWindowFunction0(sf.percent_rank) + @compileTimeOnly(aggMsg) + def percent_rank(): Double = ??? + + @sparkWindowFunction0(sf.rank) + @compileTimeOnly(aggMsg) + def rank(): Int = ??? + + @sparkWindowFunction0(sf.row_number) + @compileTimeOnly(aggMsg) + def row_number(): Long = ??? + + //////////////////// + // Misc functions // + //////////////////// + + @sparkFunction(sf.md5) + def md5(col: Array[Byte]): String = org.apache.commons.codec.digest.DigestUtils.md5Hex(col) + + @sparkFunction(sf.sha1) + def sha1(col: Array[Byte]): String = org.apache.commons.codec.digest.DigestUtils.sha1Hex(col) + + @sparkFunction((sf.sha2(_, _)).tupled) + @compileTimeOnly("Implemented only in frameless expressions") + def sha2(col: Array[Byte], numBits: Int): String = ??? + + @sparkFunction(sf.crc32) + @compileTimeOnly("Implemented only in frameless expressions") + def sha2(col: Array[Byte]): Long = ??? + + @sparkFunction(sf.hash) + @compileTimeOnly("Implemented only in frameless expressions") + def hash(cols: Any*): Int = ??? + + } From aa1e55b04effccf839c9e6f74930f0d23b48ee18 Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Wed, 22 Feb 2017 07:28:36 -0800 Subject: [PATCH 10/10] Revert prioritization --- dataset/src/main/scala/frameless/TypedEncoder.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index 63cd9560e..89512a2e2 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -29,7 +29,7 @@ abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Seria // - import TypedEncoder.usingInjection // - import TypedEncoder.usingDerivation // """) -object TypedEncoder extends TypedEncoder0 { +object TypedEncoder { def apply[T: TypedEncoder]: TypedEncoder[T] = implicitly[TypedEncoder[T]] implicit val unitEncoder: TypedEncoder[Unit] = new TypedEncoder[Unit] { @@ -469,9 +469,6 @@ object TypedEncoder extends TypedEncoder0 { trb.extractorFor(invoke) } } -} - -sealed trait TypedEncoder0 { self: TypedEncoder.type => /** Encodes things as records if there is not Injection defined */ implicit def usingDerivation[F, G <: HList](