From 815b7a2f5c7c87e62abb638752b6a918a2f00b35 Mon Sep 17 00:00:00 2001
From: Jeremy Smith
Date: Fri, 17 Feb 2017 18:39:00 -0800
Subject: [PATCH 1/6] Change TypedDataset#apply syntax to use a function

Before/After:

```scala
case class Foo(a: String, b: Int)
case class Bar(foo: Foo, c: Double)

val ds = TypedDataset.create(spark.createDataset(Seq.empty[Bar]))

// Before:
val a = ???                // no equivalent
val b = ???                // no equivalent
val c = ds.select(ds('c))

// After:
val a = ds.select(ds(_.foo.a))
val b = ds.select(ds(_.foo.b))
val c = ds.select(ds(_.c))
```

I am proposing this change because:

1. It typechecks before any macros are invoked
2. It makes editors like IntelliJ happy (because of #1)
3. It seems more idiomatic to me than using `Symbol`s
4. It allows specifying columns of nested structures, which is impossible with the current syntax

The downside is that a macro is used. However, a macro is already used (indirectly, through `Witness.apply`) for the existing syntax anyway, and that syntax additionally relies on implicit conversions, which the proposed syntax doesn't. The macro introduced for the new syntax is reasonably uncomplicated.

I changed `apply` and left `col` intact, so that both syntaxes remain available. If this change is distasteful (understandable, given the compatibility break), it could be renamed to something other than `apply`. However, I really think this syntax is strictly more powerful than the current `apply`, and should be the default behavior.
---
 .../main/scala/frameless/TypedColumn.scala    |  3 +-
 .../main/scala/frameless/TypedDataset.scala   |  6 +-
 .../scala/frameless/column/ColumnMacros.scala | 62 +++++++++++++++++++
 .../src/test/scala/frameless/ColTests.scala   | 35 +++++++++++
 .../test/scala/frameless/FilterTests.scala    |  6 +-
 .../test/scala/frameless/SelectTests.scala    | 16 ++---
 dataset/src/test/scala/frameless/XN.scala     |  2 +
 7 files changed, 114 insertions(+), 16 deletions(-)
 create mode 100644 dataset/src/main/scala/frameless/column/ColumnMacros.scala

diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala
index 92719d482..b17d2eafe 100644
--- a/dataset/src/main/scala/frameless/TypedColumn.scala
+++ b/dataset/src/main/scala/frameless/TypedColumn.scala
@@ -271,6 +271,7 @@ object TypedColumn {
       lgen: LabelledGeneric.Aux[T, H],
       selector: Selector.Aux[H, K, V]
     ): Exists[T, K, V] = new Exists[T, K, V] {}
+
   }
 
   implicit class OrderedTypedColumnSyntax[T, U: CatalystOrdered](col: TypedColumn[T, U]) {
@@ -279,4 +280,4 @@ object TypedColumn {
     def >(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped > other.untyped).typed
     def >=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped >= other.untyped).typed
   }
-}
+}
\ No newline at end of file
diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index ad9913131..056ae0349 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -62,11 +62,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
    *
    * It is statically checked that column with such name exists and has type `A`.
    */
-  def apply[A](column: Witness.Lt[Symbol])(
-    implicit
-    exists: TypedColumn.Exists[T, column.T, A],
+  def apply[A](selector: T => A)(implicit
     encoder: TypedEncoder[A]
-  ): TypedColumn[T, A] = col(column)
+  ): TypedColumn[T, A] = macro frameless.column.ColumnMacros.fromFunction[T, A]
 
   /** Returns `TypedColumn` of type `A` given it's name.
    *
diff --git a/dataset/src/main/scala/frameless/column/ColumnMacros.scala b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
new file mode 100644
index 000000000..f16b3e539
--- /dev/null
+++ b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
@@ -0,0 +1,62 @@
+package frameless.column
+
+import frameless.{TypedColumn, TypedEncoder, TypedExpressionEncoder}
+
+import scala.collection.immutable.Queue
+import scala.reflect.macros.whitebox
+
+class ColumnMacros(val c: whitebox.Context) {
+  import c.universe._
+
+  // could be used to reintroduce apply('foo)
+  def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
+    val B = weakTypeOf[B].dealias
+    val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})")
+    c.typecheck(q"${c.prefix}.col[$B]($witness)")
+  }
+
+  def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
+    def fail(tree: Tree) = c.abort(
+      tree.pos,
+      s"Could not create a column identifier from $tree - try using _.a.b syntax")
+
+    val A = weakTypeOf[A].dealias
+    val B = weakTypeOf[B].dealias
+
+    val selectorStr = selector.tree match {
+      case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => body match {
+        case `argName`(strs) => strs.mkString(".")
+        case other => fail(other)
+      }
+      case other => fail(other)
+    }
+
+    val typedCol = appliedType(
+      weakTypeOf[TypedColumn[_, _]].typeConstructor, A, B
+    )
+
+    val TEEObj = reify(TypedExpressionEncoder)
+
+    val datasetCol = c.typecheck(
+      q"${c.prefix}.dataset.col($selectorStr).as[$B]($TEEObj.apply[$B]($encoder))"
+    )
+
+    c.typecheck(q"new $typedCol($datasetCol)")
+  }
+
+  case class NameExtractor(name: TermName) {
+    private val This = this
+    def unapply(tree: Tree): Option[Queue[String]] = {
+      tree match {
+        case Ident(`name`) => Some(Queue.empty)
+        case Select(This(strs), nested) => Some(strs enqueue nested.toString)
+        case Apply(This(strs), List()) => Some(strs)
+        case _ => None
+      }
+    }
+  }
+
+  object ArgName {
+    def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name))
+  }
+}
diff --git a/dataset/src/test/scala/frameless/ColTests.scala b/dataset/src/test/scala/frameless/ColTests.scala
index ad62aa068..10b2e1bbb 100644
--- a/dataset/src/test/scala/frameless/ColTests.scala
+++ b/dataset/src/test/scala/frameless/ColTests.scala
@@ -29,6 +29,41 @@ class ColTests extends TypedDatasetSuite {
     ()
   }
 
+  test("colApply") {
+    val x4 = TypedDataset.create[X4[Int, String, Long, Boolean]](Nil)
+    val t4 = TypedDataset.create[(Int, String, Long, Boolean)](Nil)
+    val x4x4 = TypedDataset.create[X4X4[Int, String, Long, Boolean]](Nil)
+
+    x4(_.a)
+    t4(_._1)
+    x4[Int](_.a)
+    t4[Int](_._1)
+
+    illTyped("x4[String](_.a)", "type mismatch;\n found : Int\n required: String")
+
+    x4(_.b)
+    t4(_._2)
+
+    x4[String](_.b)
+    t4[String](_._2)
+
+    illTyped("x4[Int](_.b)", "type mismatch;\n found : String\n required: Int")
+
+    x4x4(_.xa.a)
+    x4x4[Int](_.xa.a)
+    x4x4(_.xa.b)
+    x4x4[String](_.xa.b)
+
+    x4x4(_.xb.a)
+    x4x4[Int](_.xb.a)
+    x4x4(_.xb.b)
+    x4x4[String](_.xb.b)
+
+    illTyped("x4x4[String](_.xa.a)", "type mismatch;\n found : Int\n required: String")
+
+    ()
+  }
+
   test("colMany") {
     type X2X2 = X2[X2[Int, String], X2[Long, Boolean]]
     val x2x2 = TypedDataset.create[X2X2](Nil)
diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala
index 837afed2d..81710c69a 100644
--- a/dataset/src/test/scala/frameless/FilterTests.scala
+++ b/dataset/src/test/scala/frameless/FilterTests.scala
@@ -22,7 +22,7 @@ class FilterTests extends TypedDatasetSuite {
   test("filter with arithmetic expressions: addition") {
     check(forAll { (data: Vector[X1[Int]]) =>
       val ds = TypedDataset.create(data)
-      val res = ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector
+      val res = ds.filter((ds(_.a) + 1) === (ds(_.a) + 1)).collect().run().toVector
       res ?= data
     })
   }
@@ -31,7 +31,7 @@ class FilterTests extends TypedDatasetSuite {
     val t = X1(1) :: X1(2) :: X1(3) :: Nil
     val tds: TypedDataset[X1[Int]] = TypedDataset.create(t)
 
-    assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
-    assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
+    assert(tds.filter(tds(_.a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
+    assert(tds.filter(tds(_.a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
   }
 }
diff --git a/dataset/src/test/scala/frameless/SelectTests.scala b/dataset/src/test/scala/frameless/SelectTests.scala
index cd3d6c411..99cf86d01 100644
--- a/dataset/src/test/scala/frameless/SelectTests.scala
+++ b/dataset/src/test/scala/frameless/SelectTests.scala
@@ -297,7 +297,7 @@ class SelectTests extends TypedDatasetSuite {
   ): Prop = {
     val ds = TypedDataset.create(data)
 
-    val dataset2 = ds.select(ds('a) + const).collect().run().toVector
+    val dataset2 = ds.select(ds(_.a) + const).collect().run().toVector
     val data2 = data.map { case X1(a) => num.plus(a, const) }
 
     dataset2 ?= data2
@@ -319,7 +319,7 @@ class SelectTests extends TypedDatasetSuite {
   ): Prop = {
     val ds = TypedDataset.create(data)
 
-    val dataset2 = ds.select(ds('a) * const).collect().run().toVector
+    val dataset2 = ds.select(ds(_.a) * const).collect().run().toVector
     val data2 = data.map { case X1(a) => num.times(a, const) }
 
     dataset2 ?= data2
@@ -341,7 +341,7 @@ class SelectTests extends TypedDatasetSuite {
   ): Prop = {
     val ds = TypedDataset.create(data)
 
-    val dataset2 = ds.select(ds('a) - const).collect().run().toVector
+    val dataset2 = ds.select(ds(_.a) - const).collect().run().toVector
     val data2 = data.map { case X1(a) => num.minus(a, const) }
 
     dataset2 ?= data2
@@ -363,7 +363,7 @@ class SelectTests extends TypedDatasetSuite {
     val ds = TypedDataset.create(data)
 
     if (const != 0) {
-      val dataset2 = ds.select(ds('a) / const).collect().run().toVector.asInstanceOf[Vector[A]]
+      val dataset2 = ds.select(ds(_.a) / const).collect().run().toVector.asInstanceOf[Vector[A]]
       val data2 = data.map { case X1(a) => frac.div(a, const) }
       dataset2 ?= data2
     } else 0 ?= 0
@@ -379,17 +379,17 @@ class SelectTests extends TypedDatasetSuite {
     assert(t.select(t.col('_1)).collect().run().toList === List(2))
 
     // Issue #54
     val fooT = t.select(t.col('_1)).map(x => Tuple1.apply(x)).as[Foo]
-    assert(fooT.select(fooT('i)).collect().run().toList === List(2))
+    assert(fooT.select(fooT(_.i)).collect().run().toList === List(2))
   }
 
   test("unary - on arithmetic") {
     val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
-    assert(e.select(-e('_1)).collect().run().toVector === Vector(-1, -2, -2))
-    assert(e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(-3L, -6L, -3L))
+    assert(e.select(-e(_._1)).collect().run().toVector === Vector(-1, -2, -2))
+    assert(e.select(-(e(_._1) + e(_._3))).collect().run().toVector === Vector(-3L, -6L, -3L))
   }
 
   test("unary - on strings should not type check") {
     val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
(2, "b", 4L) :: (2, "b", 1L) :: Nil) - illTyped("""e.select( -e('_2) )""") + illTyped("""e.select( -e(_._2) )""") } } \ No newline at end of file diff --git a/dataset/src/test/scala/frameless/XN.scala b/dataset/src/test/scala/frameless/XN.scala index 4fdab552e..0b19771e2 100644 --- a/dataset/src/test/scala/frameless/XN.scala +++ b/dataset/src/test/scala/frameless/XN.scala @@ -64,3 +64,5 @@ object X5 { implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] = Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e)) } + +case class X4X4[A, B, C, D](xa: X4[A, B, C, D], xb: X4[A, B, C, D]) From 82f1ba7a3ef0a3f7dc24b5932e43cf90e0676e5e Mon Sep 17 00:00:00 2001 From: Jeremy Smith Date: Sat, 18 Feb 2017 11:46:00 -0800 Subject: [PATCH 2/6] Update tuts for new syntax --- docs/src/main/tut/GettingStarted.md | 34 +++++++++---------- .../main/tut/TypedDatasetVsSparkDataset.md | 6 ++-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/src/main/tut/GettingStarted.md b/docs/src/main/tut/GettingStarted.md index dd29ae587..03d037196 100644 --- a/docs/src/main/tut/GettingStarted.md +++ b/docs/src/main/tut/GettingStarted.md @@ -63,13 +63,13 @@ val apartmentsTypedDS2 = spark.createDataset(apartments).typed This is how we select a particular column from a `TypedDataset`: ```tut:book -val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS('city)) +val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS(_.city)) ``` This is completely safe, for instance suppose we misspell `city`: ```tut:book:fail -apartmentsTypedDS.select(apartmentsTypedDS('citi)) +apartmentsTypedDS.select(apartmentsTypedDS(_.citi)) ``` This gets caught at compile-time, whereas with traditional Spark `Dataset` the error appears at run-time. @@ -81,7 +81,7 @@ apartmentsDS.select('citi) `select()` supports arbitrary column operations: ```tut:book -apartmentsTypedDS.select(apartmentsTypedDS('surface) * 10, apartmentsTypedDS('surface) + 2).show().run() +apartmentsTypedDS.select(apartmentsTypedDS(_.surface) * 10, apartmentsTypedDS(_.surface) + 2).show().run() ``` *Note that unlike the standard Spark api, here `show()` is lazy. It requires to apply `run()` for the @@ -91,14 +91,14 @@ apartmentsTypedDS.select(apartmentsTypedDS('surface) * 10, apartmentsTypedDS('su Let us now try to compute the price by surface unit: ```tut:book:fail -val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface)) ^ +val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface)) ^ ``` Argh! Looks like we can't divide a `TypedColumn` of `Double` by `Int`. Well, we can cast our `Int`s to `Double`s explicitly to proceed with the computation. 
 
 ```tut:book
-val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface).cast[Double])
+val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface).cast[Double])
 priceBySurfaceUnit.collect().run()
 ```
 
 Alternatively, we can perform the cast implicitly:
 
 ```tut:book
 import frameless.implicits.widen._
 
-val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface))
+val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface))
 priceBySurfaceUnit.collect.run()
 ```
 
 Looks like it worked, but that `cast` looks unsafe right? Actually it is safe.
 Let's try to cast a `TypedColumn` of `String` to `Double`:
 
 ```tut:book:fail
-apartmentsTypedDS('city).cast[Double]
+apartmentsTypedDS(_.city).cast[Double]
 ```
 
 The compile-time error tells us that to perform the cast, an evidence
 (in the form of `CatalystCast[String, Double]`) must be available.
 
 The cast is valid and the expression compiles:
 
 ```tut:book
 case class UpdatedSurface(city: String, surface: Int)
-val updated = apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('surface) + 2).as[UpdatedSurface]
+val updated = apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.surface) + 2).as[UpdatedSurface]
 updated.show(2).run()
 ```
 
 Next we try to cast a `(String, String)` to an `UpdatedSurface` (which has types
 The cast is not valid and the expression does not compile:
 
 ```tut:book:fail
-apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('city)).as[UpdatedSurface]
+apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.city)).as[UpdatedSurface]
 ```
 
 ### Projections
 
 ```tut:book
 import frameless.implicits.widen._
 
 val aptds = apartmentsTypedDS // For shorter expressions
 
 case class ApartmentDetails(city: String, price: Double, surface: Int, ratio: Double)
-val aptWithRatio = aptds.select(aptds('city), aptds('price), aptds('surface), aptds('price) / aptds('surface)).as[ApartmentDetails]
+val aptWithRatio = aptds.select(aptds(_.city), aptds(_.price), aptds(_.surface), aptds(_.price) / aptds(_.surface)).as[ApartmentDetails]
 ```
 
 Suppose we only want to work with `city` and `ratio`:
 
 val udf = apartmentsTypedDS.makeUDF(priceModifier)
 
 val aptds = apartmentsTypedDS // For shorter expressions
 
-val adjustedPrice = aptds.select(aptds('city), udf(aptds('city), aptds('price)))
+val adjustedPrice = aptds.select(aptds(_.city), udf(aptds(_.city), aptds(_.price)))
 
 adjustedPrice.show().run()
 ```
 
 ## GroupBy and Aggregations
 Let's suppose we wanted to retrieve the average apartment price in each city
 ```tut:book
-val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('price)))
+val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.price)))
 priceByCity.collect().run()
 ```
 Again if we try to aggregate a column that can't be aggregated, we get a compilation error
 ```tut:book:fail
-apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('city))) ^
+apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.city))) ^
 ```
 
 Next, we combine `select` and `groupBy` to calculate the average price/surface ratio per city:
 
 ```tut:book
 val aptds = apartmentsTypedDS // For shorter expressions
 
-val cityPriceRatio = aptds.select(aptds('city), aptds('price) / aptds('surface))
+val cityPriceRatio = aptds.select(aptds(_.city), aptds(_.price) / aptds(_.surface))
 
-cityPriceRatio.groupBy(cityPriceRatio('_1)).agg(avg(cityPriceRatio('_2))).show().run()
+cityPriceRatio.groupBy(cityPriceRatio(_._1)).agg(avg(cityPriceRatio(_._2))).show().run()
 ```
 
 ## Joins
 
 val citiInfoTypedDS = TypedDataset.create(cityInfo)
 
 Here is how to join the population information to the apartment's dataset.
 
 ```tut:book
-val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS('city), citiInfoTypedDS('name))
+val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS(_.city), citiInfoTypedDS(_.name))
 
 withCityInfo.show().run()
 ```
 
 We can then select which information we want to continue to work with:
 
 ```tut:book
 case class AptPriceCity(city: String, aptPrice: Double, cityPopulation: Int)
 
 withCityInfo.select(
-   withCityInfo.colMany('_2, 'name), withCityInfo.colMany('_1, 'price), withCityInfo.colMany('_2, 'population)
+   withCityInfo(_._2.name), withCityInfo(_._1.price), withCityInfo(_._2.population)
 ).as[AptPriceCity].show().run
 ```
diff --git a/docs/src/main/tut/TypedDatasetVsSparkDataset.md b/docs/src/main/tut/TypedDatasetVsSparkDataset.md
index 5b6775d39..5a85731ff 100644
--- a/docs/src/main/tut/TypedDatasetVsSparkDataset.md
+++ b/docs/src/main/tut/TypedDatasetVsSparkDataset.md
@@ -116,19 +116,19 @@ with a fully optimized query plan.
 import frameless.TypedDataset
 val fds = TypedDataset.create(ds)
 
-fds.filter( fds('i) === 10 ).select( fds('i) ).show().run()
+fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).show().run()
 ```
 
 And the optimized Physical Plan:
 
 ```tut:book
-fds.filter( fds('i) === 10 ).select( fds('i) ).explain()
+fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).explain()
 ```
 
 And the compiler is our friend.
 
 ```tut:fail
-fds.filter( fds('i) === 10 ).select( fds('x) )
+fds.filter( fds(_.i) === 10 ).select( fds(_.x) )
 ```
 
 ```tut:invisible

From fa0ce2f6b3528cf737379041e43da7db595fef8c Mon Sep 17 00:00:00 2001
From: Jeremy Smith
Date: Sat, 18 Feb 2017 12:07:20 -0800
Subject: [PATCH 3/6] Coverage

* Added failure test case for unusable function
* Added scoverage tags for known-unreachable branches
---
 .../main/scala/frameless/column/ColumnMacros.scala | 14 +++++++++++---
 dataset/src/test/scala/frameless/ColTests.scala    |  1 +
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/dataset/src/main/scala/frameless/column/ColumnMacros.scala b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
index f16b3e539..6a4dc5324 100644
--- a/dataset/src/main/scala/frameless/column/ColumnMacros.scala
+++ b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
@@ -9,16 +9,20 @@ class ColumnMacros(val c: whitebox.Context) {
   import c.universe._
 
   // could be used to reintroduce apply('foo)
+  // $COVERAGE-OFF$ Currently unused
   def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
     val B = weakTypeOf[B].dealias
     val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})")
     c.typecheck(q"${c.prefix}.col[$B]($witness)")
   }
+  // $COVERAGE-ON$
 
   def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
-    def fail(tree: Tree) = c.abort(
-      tree.pos,
-      s"Could not create a column identifier from $tree - try using _.a.b syntax")
+    def fail(tree: Tree) = {
+      val err =
+        s"Could not create a column identifier from $tree - try using _.a.b syntax"
+      c.abort(tree.pos, err)
+    }
 
     val A = weakTypeOf[A].dealias
     val B = weakTypeOf[B].dealias
@@ -28,7 +32,9 @@ class ColumnMacros(val c: whitebox.Context) {
         case `argName`(strs) => strs.mkString(".")
         case other => fail(other)
       }
+      // $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked
       case other => fail(other)
+      // $COVERAGE-ON$
     }
 
@@ -50,7 +56,9 @@ class ColumnMacros(val c: whitebox.Context) {
       tree match {
         case Ident(`name`) => Some(Queue.empty)
         case Select(This(strs), nested) => Some(strs enqueue nested.toString)
+        // $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work
         case Apply(This(strs), List()) => Some(strs)
+        // $COVERAGE-ON$
         case _ => None
       }
     }
diff --git a/dataset/src/test/scala/frameless/ColTests.scala b/dataset/src/test/scala/frameless/ColTests.scala
index 10b2e1bbb..d6aa27191 100644
--- a/dataset/src/test/scala/frameless/ColTests.scala
+++ b/dataset/src/test/scala/frameless/ColTests.scala
@@ -60,6 +60,7 @@ class ColTests extends TypedDatasetSuite {
     x4x4[String](_.xb.b)
 
     illTyped("x4x4[String](_.xa.a)", "type mismatch;\n found : Int\n required: String")
+    illTyped("x4x4(item => item.xa.a * 20)", "Could not create a column identifier from item\\.xa\\.a\\.\\*\\(20\\) - try using _\\.a\\.b syntax")
 
     ()
   }

From 6909cf5230900cec1b59ddd2404105f5ce51a7f3 Mon Sep 17 00:00:00 2001
From: Jeremy Smith
Date: Wed, 22 Feb 2017 12:14:33 -0800
Subject: [PATCH 4/6] Clean up macro

---
 .../scala/frameless/column/ColumnMacros.scala | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/dataset/src/main/scala/frameless/column/ColumnMacros.scala b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
index 6a4dc5324..a8dc9dffd 100644
--- a/dataset/src/main/scala/frameless/column/ColumnMacros.scala
+++ b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
@@ -28,10 +28,8 @@ class ColumnMacros(val c: whitebox.Context) {
     val B = weakTypeOf[B].dealias
 
     val selectorStr = selector.tree match {
-      case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => body match {
-        case `argName`(strs) => strs.mkString(".")
-        case other => fail(other)
-      }
+      case NameExtractor(str) => str
+      case Function(_, body) => fail(body)
       // $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked
       case other => fail(other)
       // $COVERAGE-ON$
@@ -50,21 +48,23 @@ class ColumnMacros(val c: whitebox.Context) {
     c.typecheck(q"new $typedCol($datasetCol)")
   }
 
-  case class NameExtractor(name: TermName) {
-    private val This = this
+  case class NameExtractor(name: TermName) { Self =>
     def unapply(tree: Tree): Option[Queue[String]] = {
       tree match {
         case Ident(`name`) => Some(Queue.empty)
-        case Select(This(strs), nested) => Some(strs enqueue nested.toString)
+        case Select(Self(strs), nested) => Some(strs enqueue nested.toString)
         // $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work
-        case Apply(This(strs), List()) => Some(strs)
+        case Apply(Self(strs), List()) => Some(strs)
         // $COVERAGE-ON$
         case _ => None
       }
     }
   }
 
-  object ArgName {
-    def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name))
+  object NameExtractor {
+    def unapply(tree: Tree): Option[String] = tree match {
+      case Function(List(ValDef(_, name, argTyp, _)), body) => NameExtractor(name).unapply(body).map(_.mkString("."))
+      case _ => None
+    }
   }
 }

From 72171d90e45aec02f50becfa197076f1f9bdff3c Mon Sep 17 00:00:00 2001
From: Jeremy Smith
Date: Thu, 21 Sep 2017 12:05:29 -0700
Subject: [PATCH 5/6] Update tests to new syntax

---
 .../src/test/scala/frameless/ColumnTests.scala    |  2 +-
 .../src/test/scala/frameless/FilterTests.scala    |  2 +-
 .../test/scala/frameless/WithColumnTest.scala     | 10 +++++-----
 .../functions/AggregateFunctionsTests.scala       | 16 ++++++++--------
 .../frameless/functions/UnaryFunctionsTest.scala  |  8 ++++----
 .../scala/frameless/ops/ColumnTypesTest.scala     |  2 +-
 .../src/test/scala/frameless/ops/PivotTest.scala  | 14 +++++++-------
 7 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/dataset/src/test/scala/frameless/ColumnTests.scala b/dataset/src/test/scala/frameless/ColumnTests.scala
index f42ef920f..5ff581cec 100644
--- a/dataset/src/test/scala/frameless/ColumnTests.scala
+++ b/dataset/src/test/scala/frameless/ColumnTests.scala
@@ -41,6 +41,6 @@ class ColumnTests extends TypedDatasetSuite {
   test("toString") {
     val t = TypedDataset.create((1,2)::Nil)
-    t('_1).toString ?= t.dataset.col("_1").toString()
+    t(_._1).toString ?= t.dataset.col("_1").toString()
   }
 }
diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala
index c84211fef..4b581df22 100644
--- a/dataset/src/test/scala/frameless/FilterTests.scala
+++ b/dataset/src/test/scala/frameless/FilterTests.scala
@@ -67,7 +67,7 @@ class FilterTests extends TypedDatasetSuite {
   test("filter with values (not columns): addition") {
     check(forAll { (data: Vector[X1[Int]], const: Int) =>
       val ds = TypedDataset.create(data)
-      val res = ds.filter(ds('a) > const).collect().run().toVector
+      val res = ds.filter(ds(_.a) > const).collect().run().toVector
       res ?= data.filter(_.a > const)
     })
   }
diff --git a/dataset/src/test/scala/frameless/WithColumnTest.scala b/dataset/src/test/scala/frameless/WithColumnTest.scala
index abf9a05da..6bac9d293 100644
--- a/dataset/src/test/scala/frameless/WithColumnTest.scala
+++ b/dataset/src/test/scala/frameless/WithColumnTest.scala
@@ -7,11 +7,11 @@ class WithColumnTest extends TypedDatasetSuite {
   test("append five columns") {
     def prop[A: TypedEncoder](value: A): Prop = {
       val d = TypedDataset.create(X1(value) :: Nil)
-      val d1 = d.withColumn(d('a))
-      val d2 = d1.withColumn(d1('_1))
-      val d3 = d2.withColumn(d2('_2))
-      val d4 = d3.withColumn(d3('_3))
-      val d5 = d4.withColumn(d4('_4))
+      val d1 = d.withColumn(d(_.a))
+      val d2 = d1.withColumn(d1(_._1))
+      val d3 = d2.withColumn(d2(_._2))
+      val d4 = d3.withColumn(d3(_._3))
+      val d5 = d4.withColumn(d4(_._4))
 
       (value, value, value, value, value, value) ?= d5.collect().run().head
     }
diff --git a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala
index 8090efb83..6d6d11ede 100644
--- a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala
+++ b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala
@@ -282,7 +282,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
     check {
       forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] =>
         val tds = TypedDataset.create(xs)
-        val tdsRes: Seq[(Int, Long)] = tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run()
+        val tdsRes: Seq[(Int, Long)] = tds.groupBy(tds(_._1)).agg(countDistinct(tds(_._2))).collect().run()
         tdsRes.toMap ?= xs.groupBy(_._1).mapValues(_.map(_._2).distinct.size.toLong).toSeq.toMap
       }
     }
@@ -300,7 +300,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
       forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] =>
         val tds = TypedDataset.create(xs)
         val tdsRes: Seq[(Int, Long, Long)] =
-          tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2))).collect().run()
+          tds.groupBy(tds(_._1)).agg(countDistinct(tds(_._2)), approxCountDistinct(tds(_._2))).collect().run()
         tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2) }
       }
     }
@@ -310,7 +310,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
         val tds = TypedDataset.create(xs)
         val allowedError = 0.1 // 10%
         val tdsRes: Seq[(Int, Long, Long)] =
-          tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2), allowedError)).collect().run()
+          tds.groupBy(tds(_._1)).agg(countDistinct(tds(_._2)), approxCountDistinct(tds(_._2), allowedError)).collect().run()
         tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2, allowedError) }
       }
     }
@@ -319,7 +319,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
   test("collectList") {
     def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = {
       val tds = TypedDataset.create(xs)
-      val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run()
+      val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds(_.a)).agg(collectList(tds(_.b))).collect().run()
 
       tdsRes.toMap.mapValues(_.sorted) ?= xs.groupBy(_.a).mapValues(_.map(_.b).toVector.sorted)
     }
@@ -333,7 +333,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
   test("collectSet") {
     def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = {
       val tds = TypedDataset.create(xs)
-      val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run()
+      val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds(_.a)).agg(collectSet(tds(_.b))).collect().run()
 
       tdsRes.toMap.mapValues(_.toSet) ?= xs.groupBy(_.a).mapValues(_.map(_.b).toSet)
     }
@@ -347,7 +347,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
{ test("lit") { def prop[A: TypedEncoder](xs: List[X1[A]], l: A): Prop = { val tds = TypedDataset.create(xs) - tds.select(tds('a), lit(l)).collect().run() ?= xs.map(x => (x.a, l)) + tds.select(tds(_.a), lit(l)).collect().run() ?= xs.map(x => (x.a, l)) } check(forAll(prop[Long] _)) @@ -379,7 +379,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val tds = TypedDataset.create(xs) // Typed implementation of bivar stats function - val tdBivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b), tds('c))).deserialized.map(kv => + val tdBivar = tds.groupBy(tds(_.a)).agg(framelessFun(tds(_.b), tds(_.c))).deserialized.map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)) ).collect().run() @@ -416,7 +416,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val tds = TypedDataset.create(xs) //typed implementation of univariate stats function - val tdUnivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b))).deserialized.map(kv => + val tdUnivar = tds.groupBy(tds(_.a)).agg(framelessFun(tds(_.b))).deserialized.map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)) ).collect().run() diff --git a/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala b/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala index d50d6c021..e492f456a 100644 --- a/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala +++ b/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala @@ -11,7 +11,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite { def prop[A: TypedEncoder](xs: List[X1[Vector[A]]]): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(size(tds('a))).collect().run().toVector + val framelessResults = tds.select(size(tds(_.a))).collect().run().toVector val scalaResults = xs.map(x => x.a.size).toVector framelessResults ?= scalaResults @@ -26,7 +26,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite { def prop[A: TypedEncoder : Ordering](xs: List[X1[Vector[A]]]): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(sortAscending(tds('a))).collect().run().toVector + val framelessResults = tds.select(sortAscending(tds(_.a))).collect().run().toVector val scalaResults = xs.map(x => x.a.sorted).toVector framelessResults ?= scalaResults @@ -42,7 +42,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite { def prop[A: TypedEncoder : Ordering](xs: List[X1[Vector[A]]]): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(sortDescending(tds('a))).collect().run().toVector + val framelessResults = tds.select(sortDescending(tds(_.a))).collect().run().toVector val scalaResults = xs.map(x => x.a.sorted.reverse).toVector framelessResults ?= scalaResults @@ -58,7 +58,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite { def prop[A: TypedEncoder](xs: List[X1[Vector[A]]]): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(explode(tds('a))).collect().run().toSet + val framelessResults = tds.select(explode(tds(_.a))).collect().run().toSet val scalaResults = xs.flatMap(_.a).toSet framelessResults ?= scalaResults diff --git a/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala b/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala index 303eb2cbd..59011df9c 100644 --- a/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala +++ b/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala @@ -10,7 +10,7 @@ class ColumnTypesTest extends TypedDatasetSuite { test("test summoning") { def prop[A: TypedEncoder, 
       val d: TypedDataset[X4[A, B, C, D]] = TypedDataset.create(data)
-      val hlist = d('a) :: d('b) :: d('c) :: d('d) :: HNil
+      val hlist = d(_.a) :: d(_.b) :: d(_.c) :: d(_.d) :: HNil
 
       type TC[N] = TypedColumn[X4[A,B,C,D], N]
 
diff --git a/dataset/src/test/scala/frameless/ops/PivotTest.scala b/dataset/src/test/scala/frameless/ops/PivotTest.scala
index 359d97c4b..771783c73 100644
--- a/dataset/src/test/scala/frameless/ops/PivotTest.scala
+++ b/dataset/src/test/scala/frameless/ops/PivotTest.scala
@@ -22,9 +22,9 @@ class PivotTest extends TypedDatasetSuite {
   test("X4[Boolean, String, Int, Boolean] pivot on String") {
     def prop(data: Vector[X4[String, String, Int, Boolean]]): Prop = {
       val d = TypedDataset.create(data)
-      val frameless = d.groupBy(d('a)).
-        pivot(d('b)).on("a", "b", "c").
-        agg(sum(d('c)), first(d('d))).collect().run().toVector
+      val frameless = d.groupBy(d(_.a)).
+        pivot(d(_.b)).on("a", "b", "c").
+        agg(sum(d(_.c)), first(d(_.d))).collect().run().toVector
 
       val spark = d.dataset.groupBy("a")
         .pivot("b", Seq("a", "b", "c"))
@@ -45,8 +45,8 @@ class PivotTest extends TypedDatasetSuite {
   test("Pivot on Boolean") {
     val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false))
     val d = TypedDataset.create(x)
-    d.groupByMany(d('a)).
-      pivot(d('c)).on(true, false).
+    d.groupByMany(d(_.a)).
+      pivot(d(_.c)).on(true, false).
       agg(count[X3[String, Boolean, Boolean]]()).
       collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) // two true one false
   }
@@ -54,8 +54,8 @@ class PivotTest extends TypedDatasetSuite {
   test("Pivot with groupBy on two columns, pivot on Long") {
     val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20))
     val d = TypedDataset.create(x)
-    d.groupBy(d('a), d('b)).
-      pivot(d('c)).on(1L, 20L).
+    d.groupBy(d(_.a), d(_.b)).
+      pivot(d(_.c)).on(1L, 20L).
       agg(count[X3[String, String, Long]]()).
      collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L)))
   }

From ff238a091640295eb9f1b106b597908a920041a4 Mon Sep 17 00:00:00 2001
From: Jeremy Smith
Date: Thu, 21 Sep 2017 12:12:32 -0700
Subject: [PATCH 6/6] Add additional check for `isCaseAccessorLike`

---
 dataset/src/main/scala/frameless/column/ColumnMacros.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/dataset/src/main/scala/frameless/column/ColumnMacros.scala b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
index a8dc9dffd..20da16089 100644
--- a/dataset/src/main/scala/frameless/column/ColumnMacros.scala
+++ b/dataset/src/main/scala/frameless/column/ColumnMacros.scala
@@ -1,11 +1,12 @@
 package frameless.column
 
 import frameless.{TypedColumn, TypedEncoder, TypedExpressionEncoder}
+import shapeless.CaseClassMacros
 
 import scala.collection.immutable.Queue
 import scala.reflect.macros.whitebox
 
-class ColumnMacros(val c: whitebox.Context) {
+class ColumnMacros(val c: whitebox.Context) extends CaseClassMacros {
   import c.universe._
 
   // could be used to reintroduce apply('foo)
@@ -52,7 +53,8 @@ class ColumnMacros(val c: whitebox.Context) {
     def unapply(tree: Tree): Option[Queue[String]] = {
       tree match {
         case Ident(`name`) => Some(Queue.empty)
-        case Select(Self(strs), nested) => Some(strs enqueue nested.toString)
+        case s @ Select(Self(strs), nested) if s.symbol.isTerm && isCaseAccessorLike(s.symbol.asTerm) =>
+          Some(strs enqueue nested.toString)
         // $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work
         case Apply(Self(strs), List()) => Some(strs)
         // $COVERAGE-ON$
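
Note for reviewers (not part of the patches): below is a rough sketch of what the new `apply` macro expands to, pieced together from the quasiquotes in `fromFunction` above. The `ds`/`Bar` names reuse the example from the first commit message; the actual trees the compiler emits may differ in detail, so treat this as an illustration rather than the exact expansion.

```scala
// Hypothetical illustration: approximately what ds(_.foo.a) becomes after
// macro expansion, given the commit-message example
//   case class Foo(a: String, b: Int)
//   case class Bar(foo: Foo, c: Double)
//
// 1. NameExtractor walks the selector body and accumulates
//    Queue("foo", "a"), which fromFunction joins into the column
//    path "foo.a" (selectorStr).
// 2. The untyped column is looked up on the underlying Spark dataset and
//    re-typed through the implicit TypedEncoder[String] in scope, via
//    TypedExpressionEncoder - mirroring the quasiquote
//    q"${c.prefix}.dataset.col($selectorStr).as[$B](...)".
val a: TypedColumn[Bar, String] =
  new TypedColumn[Bar, String](
    ds.dataset.col("foo.a").as[String](TypedExpressionEncoder[String])
  )
```

Because the path is assembled purely from field selections (and patch 6 tightens this to case-class accessors via `isCaseAccessorLike`), any other shape of function body — e.g. `item => item.xa.a * 20` — is rejected at compile time with the "Could not create a column identifier" error exercised in `ColTests`.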