From c1a27932a9a1851dbd7ac390498e42bfabb11ece Mon Sep 17 00:00:00 2001 From: Michael Chavinda Date: Thu, 26 Feb 2026 15:05:58 -0800 Subject: [PATCH 1/7] feat: Typed dataframe API --- dataframe.cabal | 12 +- src/DataFrame/Internal/Column.hs | 111 +++------ src/DataFrame/Typed.hs | 219 ++++++++++++++++++ src/DataFrame/Typed/Access.hs | 55 +++++ src/DataFrame/Typed/Aggregate.hs | 97 ++++++++ src/DataFrame/Typed/Expr.hs | 280 ++++++++++++++++++++++ src/DataFrame/Typed/Freeze.hs | 93 ++++++++ src/DataFrame/Typed/Join.hs | 77 ++++++ src/DataFrame/Typed/Operations.hs | 373 ++++++++++++++++++++++++++++++ src/DataFrame/Typed/Schema.hs | 347 +++++++++++++++++++++++++++ src/DataFrame/Typed/TH.hs | 91 ++++++++ src/DataFrame/Typed/Types.hs | 117 ++++++++++ 12 files changed, 1794 insertions(+), 78 deletions(-) create mode 100644 src/DataFrame/Typed.hs create mode 100644 src/DataFrame/Typed/Access.hs create mode 100644 src/DataFrame/Typed/Aggregate.hs create mode 100644 src/DataFrame/Typed/Expr.hs create mode 100644 src/DataFrame/Typed/Freeze.hs create mode 100644 src/DataFrame/Typed/Join.hs create mode 100644 src/DataFrame/Typed/Operations.hs create mode 100644 src/DataFrame/Typed/Schema.hs create mode 100644 src/DataFrame/Typed/TH.hs create mode 100644 src/DataFrame/Typed/Types.hs diff --git a/dataframe.cabal b/dataframe.cabal index c10646b..38c0d69 100644 --- a/dataframe.cabal +++ b/dataframe.cabal @@ -91,7 +91,17 @@ library DataFrame.Lazy.IO.CSV, DataFrame.Lazy.Internal.DataFrame, DataFrame.Monad, - DataFrame.DecisionTree + DataFrame.DecisionTree, + DataFrame.Typed.Types, + DataFrame.Typed.Schema, + DataFrame.Typed.Freeze, + DataFrame.Typed.Access, + DataFrame.Typed.Operations, + DataFrame.Typed.Join, + DataFrame.Typed.Aggregate, + DataFrame.Typed.TH, + DataFrame.Typed.Expr, + DataFrame.Typed build-depends: base >= 4 && <5, aeson >= 0.11.0.0 && < 3, array >= 0.5.4.0 && < 0.6, diff --git a/src/DataFrame/Internal/Column.hs b/src/DataFrame/Internal/Column.hs index 
f3c3b33..650907e 100644 --- a/src/DataFrame/Internal/Column.hs +++ b/src/DataFrame/Internal/Column.hs @@ -620,83 +620,40 @@ zipColumns (OptionalColumn optcolumn) (OptionalColumn optother) = BoxedColumn (V -- | Merge two columns using `These`. mergeColumns :: Column -> Column -> Column -mergeColumns (BoxedColumn column) (BoxedColumn other) = BoxedColumn (VG.zipWith These column other) -mergeColumns (BoxedColumn column) (UnboxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - (\i -> These (column VG.! i) (other VG.! i)) - ) -mergeColumns (BoxedColumn column) (OptionalColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (other VG.! i) - then This (column VG.! i) - else These (column VG.! i) (fromJust $ other VG.! i) - ) - ) -mergeColumns (UnboxedColumn column) (BoxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - (\i -> These (column VG.! i) (other VG.! i)) - ) -mergeColumns (UnboxedColumn column) (UnboxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - (\i -> These (column VG.! i) (other VG.! i)) - ) -mergeColumns (UnboxedColumn column) (OptionalColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (other VG.! i) - then This (column VG.! i) - else These (column VG.! i) (fromJust $ other VG.! i) - ) - ) -mergeColumns (OptionalColumn column) (BoxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (column VG.! i) - then That (other VG.! i) - else These (fromJust $ column VG.! i) (other VG.! i) - ) - ) -mergeColumns (OptionalColumn column) (UnboxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (column VG.! i) - then That (other VG.! i) - else These (fromJust $ column VG.! 
i) (other VG.! i) - ) - ) -mergeColumns (OptionalColumn column) (OptionalColumn other) = - OptionalColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (column VG.! i) && isNothing (other VG.! i) - then Nothing - else - ( if isNothing (column VG.! i) - then Just (That (fromJust $ other VG.! i)) - else - ( if isNothing (other VG.! i) - then Just (This (fromJust $ column VG.! i)) - else Just (These (fromJust $ column VG.! i) (fromJust $ other VG.! i)) - ) - ) - ) - ) +mergeColumns colA colB = case (colA, colB) of + (OptionalColumn c1, OptionalColumn c2) -> + OptionalColumn $ mkVec c1 c2 $ \v1 v2 -> + case (v1, v2) of + (Nothing, Nothing) -> Nothing + (Just x, Nothing) -> Just (This x) + (Nothing, Just y) -> Just (That y) + (Just x, Just y) -> Just (These x y) + (OptionalColumn c1, BoxedColumn c2) -> optReq c1 c2 + (OptionalColumn c1, UnboxedColumn c2) -> optReq c1 c2 + (BoxedColumn c1, OptionalColumn c2) -> reqOpt c1 c2 + (UnboxedColumn c1, OptionalColumn c2) -> reqOpt c1 c2 + (BoxedColumn c1, BoxedColumn c2) -> reqReq c1 c2 + (BoxedColumn c1, UnboxedColumn c2) -> reqReq c1 c2 + (UnboxedColumn c1, BoxedColumn c2) -> reqReq c1 c2 + (UnboxedColumn c1, UnboxedColumn c2) -> reqReq c1 c2 + where + mkVec c1 c2 combineElements = + VB.generate + (min (VG.length c1) (VG.length c2)) + (\i -> combineElements (c1 VG.! i) (c2 VG.! i)) + {-# INLINE mkVec #-} + + reqReq c1 c2 = BoxedColumn $ mkVec c1 c2 These + + reqOpt c1 c2 = BoxedColumn $ mkVec c1 c2 $ \v1 v2 -> + case v2 of + Nothing -> This v1 + Just y -> These v1 y + + optReq c1 c2 = BoxedColumn $ mkVec c1 c2 $ \v1 v2 -> + case v1 of + Nothing -> That v2 + Just x -> These x v2 {-# INLINE mergeColumns #-} -- | An internal, column version of zipWith. 
diff --git a/src/DataFrame/Typed.hs b/src/DataFrame/Typed.hs new file mode 100644 index 0000000..8a893f2 --- /dev/null +++ b/src/DataFrame/Typed.hs @@ -0,0 +1,219 @@ +{-# LANGUAGE DataKinds #-} + +{- | +Module : DataFrame.Typed +Copyright : (c) 2025 +License : MIT +Maintainer : mschavinda@gmail.com +Stability : experimental + +A type-safe layer over the @dataframe@ library. + +This module provides 'TypedDataFrame', a phantom-typed wrapper around +the untyped 'DataFrame' that tracks column names and types at compile time. +All operations delegate to the untyped core at runtime; the phantom type +is updated at compile time to reflect schema changes. + +== Key difference from untyped API: TExpr + +All expression-taking operations use 'TExpr' (typed expressions) instead +of raw @Expr@. Column references are validated at compile time: + +@ +{\-\# LANGUAGE DataKinds, TypeApplications, TypeOperators \#-\} +import qualified DataFrame.Typed as T + +type People = '[T.Column \"name\" Text, T.Column \"age\" Int] + +main = do + raw <- D.readCsv \"people.csv\" + case T.freeze \@People raw of + Nothing -> putStrLn \"Schema mismatch!\" + Just df -> do + let adults = T.filterWhere (T.col \@\"age\" T..>=. T.lit 18) df + let names = T.columnAsList \@\"name\" adults -- :: [Text] + print names +@ + +Column references like @T.col \@\"age\"@ are checked at compile time — if the +column doesn't exist or has the wrong type, you get a type error, not a +runtime exception. 
+ +== filterAllJust tracks Maybe-stripping + +@ +df :: TypedDataFrame '[Column \"x\" (Maybe Double), Column \"y\" Int] +T.filterAllJust df :: TypedDataFrame '[Column \"x\" Double, Column \"y\" Int] +@ + +== Typed aggregation (Option B) + +@ +result = T.aggregate + (T.agg \@\"total\" (T.tsum (T.col \@\"salary\")) + $ T.agg \@\"count\" (T.tcount (T.col \@\"salary\")) + $ T.aggNil) + (T.groupBy \@'[\"dept\"] employees) +@ +-} +module DataFrame.Typed ( + -- * Core types + TypedDataFrame, + Column, + TypedGrouped, + These (..), + + -- * Typed expressions + TExpr (..), + col, + lit, + ifThenElse, + tlift, + tlift2, + + -- * Comparison operators + (.==.), + (./=.), + (.<.), + (.<=.), + (.>=.), + (.>.), + + -- * Logical operators + (.&&.), + (.||.), + tnot, + + -- * Aggregation expression combinators + tsum, + tmean, + tcount, + tminimum, + tmaximum, + tcollect, + + -- * Typed sort orders + TSortOrder (..), + asc, + desc, + + -- * Named expression helper + DataFrame.Typed.Expr.as, + + -- * Freeze / thaw boundary + freeze, + freezeWithError, + thaw, + unsafeFreeze, + + -- * Typed column access + columnAsVector, + columnAsList, + + -- * Schema-preserving operations + filterWhere, + filter, + filterBy, + filterAllJust, + filterJust, + filterNothing, + sortBy, + take, + takeLast, + drop, + dropLast, + range, + cube, + distinct, + sample, + shuffle, + + -- * Schema-modifying operations + derive, + select, + exclude, + rename, + renameMany, + insert, + insertColumn, + insertVector, + cloneColumn, + dropColumn, + replaceColumn, + + -- * Metadata + dimensions, + nRows, + nColumns, + columnNames, + + -- * Vertical merge + append, + + -- * Joins + innerJoin, + leftJoin, + rightJoin, + fullOuterJoin, + + -- * GroupBy and Aggregation (Option B) + groupBy, + agg, + aggNil, + aggregate, + aggregateUntyped, + + -- * Template Haskell + deriveSchema, + + -- * Schema type families (for advanced use) + Lookup, + HasName, + SubsetSchema, + ExcludeSchema, + RenameInSchema, + RemoveColumn, + 
Append, + Reverse, + StripAllMaybe, + StripMaybeAt, + GroupKeyColumns, + InnerJoinSchema, + LeftJoinSchema, + RightJoinSchema, + FullOuterJoinSchema, + AssertAbsent, + AssertPresent, + + -- * Constraints + KnownSchema (..), + AllKnownSymbol (..), + + -- * Pipe operator + (|>), +) where + +import Prelude hiding (drop, filter, take) + +import DataFrame.Typed.Access (columnAsList, columnAsVector) +import DataFrame.Typed.Aggregate ( + agg, + aggNil, + aggregate, + aggregateUntyped, + groupBy, + ) +import DataFrame.Typed.Expr +import DataFrame.Typed.Freeze (freeze, freezeWithError, thaw, unsafeFreeze) +import DataFrame.Typed.Join (fullOuterJoin, innerJoin, leftJoin, rightJoin) +import DataFrame.Typed.Operations +import DataFrame.Typed.Schema +import DataFrame.Typed.TH (deriveSchema) +import DataFrame.Typed.Types ( + Column, + TExpr (..), + TSortOrder (..), + These (..), + TypedDataFrame, + TypedGrouped, + ) diff --git a/src/DataFrame/Typed/Access.hs b/src/DataFrame/Typed/Access.hs new file mode 100644 index 0000000..268fa91 --- /dev/null +++ b/src/DataFrame/Typed/Access.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +module DataFrame.Typed.Access ( + -- * Typed column access + columnAsVector, + columnAsList, +) where + +import Control.Exception (throw) +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import qualified Data.Vector as V +import GHC.TypeLits (KnownSymbol, symbolVal) + +import DataFrame.Internal.Column (Columnable) +import DataFrame.Internal.Expression (Expr (Col)) +import qualified DataFrame.Operations.Core as D +import DataFrame.Typed.Schema (AssertPresent, Lookup) +import DataFrame.Typed.Types (TypedDataFrame (..)) + +{- | Retrieve a column as a boxed 'Vector', with the type determined by +the schema. 
The column must exist (enforced at compile time). +-} +columnAsVector :: + forall name cols a. + ( KnownSymbol name + , a ~ Lookup name cols + , Columnable a + , AssertPresent name cols + ) => + TypedDataFrame cols -> V.Vector a +columnAsVector (TDF df) = + either throw id $ D.columnAsVector (Col @a colName) df + where + colName = T.pack (symbolVal (Proxy @name)) + +-- | Retrieve a column as a list, with the type determined by the schema. +columnAsList :: + forall name cols a. + ( KnownSymbol name + , a ~ Lookup name cols + , Columnable a + , AssertPresent name cols + ) => + TypedDataFrame cols -> [a] +columnAsList (TDF df) = + D.columnAsList (Col @a colName) df + where + colName = T.pack (symbolVal (Proxy @name)) diff --git a/src/DataFrame/Typed/Aggregate.hs b/src/DataFrame/Typed/Aggregate.hs new file mode 100644 index 0000000..341315f --- /dev/null +++ b/src/DataFrame/Typed/Aggregate.hs @@ -0,0 +1,97 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +module DataFrame.Typed.Aggregate ( + -- * Typed groupBy + groupBy, + + -- * Typed aggregation builder (Option B) + agg, + aggNil, + + -- * Running aggregations + aggregate, + + -- * Escape hatch + aggregateUntyped, +) where + +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) + +import DataFrame.Internal.Column (Columnable) +import qualified DataFrame.Internal.DataFrame as D +import DataFrame.Internal.Expression (Expr, NamedExpr, UExpr (..)) +import qualified DataFrame.Operations.Aggregation as DA + +import DataFrame.Typed.Freeze (unsafeFreeze) +import DataFrame.Typed.Schema +import DataFrame.Typed.Types + +{- | Group a typed DataFrame by one or more key columns. 
+ +@ +grouped = groupBy \@'[\"department\"] employees +@ +-} +groupBy :: + forall (keys :: [Symbol]) cols. + (AllKnownSymbol keys) => + TypedDataFrame cols -> TypedGrouped keys cols +groupBy (TDF df) = TGD (DA.groupBy (symbolVals @keys) df) + +-- | The empty aggregation — no output columns beyond the group keys. +aggNil :: TAgg keys cols '[] +aggNil = TAggNil + +{- | Add one aggregation to the builder. + +Each call prepends a @Column name a@ to the result schema and records +the runtime 'NamedExpr'. The expression is validated against the +source schema @cols@ at compile time. + +@ +agg \@\"total_sales\" (tsum (col \@\"salary\")) + $ agg \@\"avg_price\" (tmean (col \@\"price\")) + $ aggNil +@ +-} +agg :: + forall name a keys cols aggs. + ( KnownSymbol name + , Columnable a + ) => + TExpr cols a -> TAgg keys cols aggs -> TAgg keys cols (Column name a ': aggs) +agg = TAggCons colName + where + colName = T.pack (symbolVal (Proxy @name)) + +{- | Run a typed aggregation. + +Result schema = grouping key columns ++ aggregated columns (in declaration order). + +@ +result = aggregate + (agg \@\"total\" (tsum salary) $ agg \@\"count\" (tcount salary) $ aggNil) + (groupBy \@'[\"dept\"] employees) +-- result :: TDF '[Column \"dept\" Text, Column \"total\" Double, Column \"count\" Int] +@ +-} +aggregate :: + forall keys cols aggs. + TAgg keys cols aggs -> + TypedGrouped keys cols -> + TypedDataFrame (Append (GroupKeyColumns keys cols) (Reverse aggs)) +aggregate tagg (TGD gdf) = + unsafeFreeze (DA.aggregate (taggToNamedExprs tagg) gdf) + +-- | Escape hatch: run an untyped aggregation and return a raw 'DataFrame'. 
+aggregateUntyped :: [NamedExpr] -> TypedGrouped keys cols -> D.DataFrame +aggregateUntyped exprs (TGD gdf) = DA.aggregate exprs gdf diff --git a/src/DataFrame/Typed/Expr.hs b/src/DataFrame/Typed/Expr.hs new file mode 100644 index 0000000..25647db --- /dev/null +++ b/src/DataFrame/Typed/Expr.hs @@ -0,0 +1,280 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +{- | Type-safe expression construction for typed DataFrames. + +Unlike the untyped @Expr a@ where column references are unchecked strings, +'TExpr' ensures at compile time that: + +* Referenced columns exist in the schema +* Column types match the expression type + +== Example + +@ +type Schema = '[Column \"age\" Int, Column \"salary\" Double] + +-- This compiles: +goodExpr :: TExpr Schema Double +goodExpr = col \@\"salary\" + +-- This gives a compile-time error (column not found): +badExpr :: TExpr Schema Double +badExpr = col \@\"nonexistent\" + +-- This gives a compile-time error (type mismatch): +wrongType :: TExpr Schema Int +wrongType = col \@\"salary\" -- salary is Double, not Int +@ +-} +module DataFrame.Typed.Expr ( + -- * Core typed expression type (re-exported from Types) + TExpr (..), + + -- * Column reference (schema-checked) + col, + + -- * Literals + lit, + + -- * Conditional + ifThenElse, + + -- * Unary / binary lifting + tlift, + tlift2, + + -- * Comparison operators + (.==.), + (./=.), + (.<.), + (.<=.), + (.>=.), + (.>.), + + -- * Logical operators + (.&&.), + (.||.), + tnot, + + -- * Aggregation combinators + tsum, + tmean, + tcount, + tminimum, + tmaximum, + tcollect, + + -- * Named expression helper + as, + + -- * Sort helpers + asc, + 
desc, +) where + +import Data.Kind (Type) +import Data.Proxy (Proxy (..)) +import Data.String (IsString (..)) +import qualified Data.Text as T +import qualified Data.Vector.Unboxed as VU +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) + +import DataFrame.Internal.Column (Columnable) +import DataFrame.Internal.Expression ( + AggStrategy (..), + BinaryOp (..), + Expr (..), + NamedExpr, + UExpr (..), + UnaryOp (..), + ) +import DataFrame.Typed.Schema (AssertPresent, Lookup) +import DataFrame.Typed.Types (Column, TExpr (..), TSortOrder (..)) + +------------------------------------------------------------------------------- +-- Column reference — the core type-safe construction point +------------------------------------------------------------------------------- + +{- | Create a typed column reference. This is the key type-safety entry point. + +The column name must exist in @cols@ and its type must match @a@. +Both checks happen at compile time via type families. + +@ +salary :: TExpr '[Column \"salary\" Double] Double +salary = col \@\"salary\" +@ +-} +col :: + forall (name :: Symbol) cols a. + ( KnownSymbol name + , a ~ Lookup name cols + , Columnable a + , AssertPresent name cols + ) => + TExpr cols a +col = TExpr (Col (T.pack (symbolVal (Proxy @name)))) + +{- | Create a literal expression. Valid for any schema since it +references no columns. +-} +lit :: (Columnable a) => a -> TExpr cols a +lit = TExpr . Lit + +-- | Conditional expression. 
+ifThenElse :: + (Columnable a) => TExpr cols Bool -> TExpr cols a -> TExpr cols a -> TExpr cols a +ifThenElse (TExpr c) (TExpr t) (TExpr e) = TExpr (If c t e) + +------------------------------------------------------------------------------- +-- Numeric instances (mirror Expr's instances) +------------------------------------------------------------------------------- + +instance (Num a, Columnable a) => Num (TExpr cols a) where + (TExpr a) + (TExpr b) = TExpr (a + b) + (TExpr a) - (TExpr b) = TExpr (a - b) + (TExpr a) * (TExpr b) = TExpr (a * b) + negate (TExpr a) = TExpr (negate a) + abs (TExpr a) = TExpr (abs a) + signum (TExpr a) = TExpr (signum a) + fromInteger = TExpr . fromInteger + +instance (Fractional a, Columnable a) => Fractional (TExpr cols a) where + fromRational = TExpr . fromRational + (TExpr a) / (TExpr b) = TExpr (a / b) + +instance (Floating a, Columnable a) => Floating (TExpr cols a) where + pi = TExpr pi + exp (TExpr a) = TExpr (exp a) + sqrt (TExpr a) = TExpr (sqrt a) + log (TExpr a) = TExpr (log a) + (TExpr a) ** (TExpr b) = TExpr (a ** b) + logBase (TExpr a) (TExpr b) = TExpr (logBase a b) + sin (TExpr a) = TExpr (sin a) + cos (TExpr a) = TExpr (cos a) + tan (TExpr a) = TExpr (tan a) + asin (TExpr a) = TExpr (asin a) + acos (TExpr a) = TExpr (acos a) + atan (TExpr a) = TExpr (atan a) + sinh (TExpr a) = TExpr (sinh a) + cosh (TExpr a) = TExpr (cosh a) + asinh (TExpr a) = TExpr (asinh a) + acosh (TExpr a) = TExpr (acosh a) + atanh (TExpr a) = TExpr (atanh a) + +instance (IsString a, Columnable a) => IsString (TExpr cols a) where + fromString = TExpr . fromString + +------------------------------------------------------------------------------- +-- Lifting arbitrary functions +------------------------------------------------------------------------------- + +-- | Lift a unary function into a typed expression. 
+tlift :: + (Columnable a, Columnable b) => (a -> b) -> TExpr cols a -> TExpr cols b +tlift f (TExpr e) = TExpr (Unary (MkUnaryOp f "unaryUdf" Nothing) e) + +-- | Lift a binary function into typed expressions. +tlift2 :: + (Columnable a, Columnable b, Columnable c) => + (a -> b -> c) -> TExpr cols a -> TExpr cols b -> TExpr cols c +tlift2 f (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp f "binaryUdf" Nothing False 0) a b) + +------------------------------------------------------------------------------- +-- Comparison operators +------------------------------------------------------------------------------- + +infixl 4 .==., ./=., .<., .<=., .>=., .>. +infixr 3 .&&. +infixr 2 .||. + +(.==.) :: + (Columnable a, Eq a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.==.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (==) "eq" (Just "==") True 4) a b) + +(./=.) :: + (Columnable a, Eq a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(./=.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (/=) "neq" (Just "/=") True 4) a b) + +(.<.) :: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.<.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (<) "lt" (Just "<") False 4) a b) + +(.<=.) :: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.<=.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (<=) "leq" (Just "<=") False 4) a b) + +(.>=.) :: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.>=.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (>=) "geq" (Just ">=") False 4) a b) + +(.>.) :: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.>.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (>) "gt" (Just ">") False 4) a b) + +(.&&.) :: TExpr cols Bool -> TExpr cols Bool -> TExpr cols Bool +(.&&.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (&&) "and" (Just "&&") True 3) a b) + +(.||.) :: TExpr cols Bool -> TExpr cols Bool -> TExpr cols Bool +(.||.) 
(TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (||) "or" (Just "||") True 2) a b) + +tnot :: TExpr cols Bool -> TExpr cols Bool +tnot (TExpr e) = TExpr (Unary (MkUnaryOp not "not" (Just "!")) e) + +------------------------------------------------------------------------------- +-- Aggregation combinators +------------------------------------------------------------------------------- + +tsum :: (Columnable a, Num a) => TExpr cols a -> TExpr cols a +tsum (TExpr e) = TExpr (Agg (FoldAgg "sum" Nothing (+)) e) + +tmean :: (Columnable a, Real a, VU.Unbox a) => TExpr cols a -> TExpr cols Double +tmean (TExpr e) = TExpr (Agg (CollectAgg "mean" mean') e) + where + mean' v = + let s = VU.foldl' (\acc x -> acc + realToFrac x) (0 :: Double) v + n = VU.length v + in if n == 0 then 0 else s / fromIntegral n + +tcount :: (Columnable a) => TExpr cols a -> TExpr cols Int +tcount (TExpr e) = TExpr (Agg (FoldAgg "count" (Just 0) (\acc _ -> acc + 1)) e) + +tminimum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a +tminimum (TExpr e) = TExpr (Agg (FoldAgg "minimum" Nothing min) e) + +tmaximum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a +tmaximum (TExpr e) = TExpr (Agg (FoldAgg "maximum" Nothing max) e) + +tcollect :: (Columnable a) => TExpr cols a -> TExpr cols [a] +tcollect (TExpr e) = TExpr (Agg (FoldAgg "collect" (Just []) (flip (:))) e) + +------------------------------------------------------------------------------- +-- Named expression helper +------------------------------------------------------------------------------- + +-- | Create a 'NamedExpr' for use with 'aggregateUntyped'. +as :: (Columnable a) => TExpr cols a -> T.Text -> NamedExpr +as (TExpr e) name = (name, UExpr e) + +------------------------------------------------------------------------------- +-- Sort helpers +------------------------------------------------------------------------------- + +-- | Create an ascending sort order from a typed expression. 
+asc :: (Columnable a) => TExpr cols a -> TSortOrder cols +asc = Asc + +-- | Create a descending sort order from a typed expression. +desc :: (Columnable a) => TExpr cols a -> TSortOrder cols +desc = Desc diff --git a/src/DataFrame/Typed/Freeze.hs b/src/DataFrame/Typed/Freeze.hs new file mode 100644 index 0000000..79963df --- /dev/null +++ b/src/DataFrame/Typed/Freeze.hs @@ -0,0 +1,93 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module DataFrame.Typed.Freeze ( + -- * Safe boundary + freeze, + freezeWithError, + + -- * Escape hatches + thaw, + unsafeFreeze, +) where + +import Data.Kind (Type) +import qualified Data.Map as M +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import Type.Reflection (SomeTypeRep, someTypeRep) + +import qualified DataFrame.Internal.Column as C +import qualified DataFrame.Internal.DataFrame as D +import DataFrame.Operations.Core (columnNames) +import DataFrame.Typed.Schema (KnownSchema (..)) +import DataFrame.Typed.Types (TypedDataFrame (..)) + +{- | Validate that an untyped 'DataFrame' matches the expected schema @cols@, +then wrap it. Returns 'Nothing' on mismatch. +-} +freeze :: + forall cols. (KnownSchema cols) => D.DataFrame -> Maybe (TypedDataFrame cols) +freeze df = case validateSchema @cols df of + Left _ -> Nothing + Right _ -> Just (TDF df) + +-- | Like 'freeze' but returns a descriptive error message on failure. +freezeWithError :: + forall cols. + (KnownSchema cols) => + D.DataFrame -> Either T.Text (TypedDataFrame cols) +freezeWithError df = case validateSchema @cols df of + Left err -> Left err + Right _ -> Right (TDF df) + +{- | Unwrap a typed DataFrame back to the untyped representation. +Always safe; discards type information. 
+-} +thaw :: TypedDataFrame cols -> D.DataFrame +thaw (TDF df) = df + +{- | Wrap an untyped DataFrame without any validation. +Used internally after delegation where the library guarantees schema correctness. +-} +unsafeFreeze :: D.DataFrame -> TypedDataFrame cols +unsafeFreeze = TDF + +------------------------------------------------------------------------------- +-- Internal validation +------------------------------------------------------------------------------- + +validateSchema :: + forall cols. + (KnownSchema cols) => + D.DataFrame -> Either T.Text () +validateSchema df = mapM_ checkCol (schemaEvidence @cols) + where + checkCol :: (T.Text, SomeTypeRep) -> Either T.Text () + checkCol (name, expectedRep) = case D.getColumn name df of + Nothing -> + Left $ + "Column '" + <> name + <> "' not found in DataFrame. " + <> "Available columns: " + <> T.pack (show (columnNames df)) + Just col -> + if matchesType expectedRep col + then Right () + else + Left $ + "Type mismatch on column '" + <> name + <> "': expected " + <> T.pack (show expectedRep) + <> ", got " + <> T.pack (C.columnTypeString col) + +-- | Check if a Column's element type matches the expected SomeTypeRep. 
+matchesType :: SomeTypeRep -> C.Column -> Bool +matchesType expected col = T.pack (show expected) == T.pack (C.columnTypeString col) diff --git a/src/DataFrame/Typed/Join.hs b/src/DataFrame/Typed/Join.hs new file mode 100644 index 0000000..2c8622a --- /dev/null +++ b/src/DataFrame/Typed/Join.hs @@ -0,0 +1,77 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} + +module DataFrame.Typed.Join ( + -- * Typed joins + innerJoin, + leftJoin, + rightJoin, + fullOuterJoin, +) where + +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) + +import DataFrame.Internal.Column (Columnable) +import qualified DataFrame.Internal.DataFrame as D +import qualified DataFrame.Operations.Core as D +import qualified DataFrame.Operations.Join as DJ + +import DataFrame.Typed.Freeze (thaw, unsafeFreeze) +import DataFrame.Typed.Schema +import DataFrame.Typed.Types (TypedDataFrame (..)) + +-- | Typed inner join on one or more key columns. +innerJoin :: + forall (keys :: [Symbol]) left right. + (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (InnerJoinSchema keys left right) +innerJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.innerJoin keyNames r l) + where + keyNames = symbolVals @keys + +-- | Typed left join. +leftJoin :: + forall (keys :: [Symbol]) left right. + (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (LeftJoinSchema keys left right) +leftJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.leftJoin keyNames r l) + where + keyNames = symbolVals @keys + +-- | Typed right join. +rightJoin :: + forall (keys :: [Symbol]) left right. 
+ (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (RightJoinSchema keys left right) +rightJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.rightJoin keyNames r l) + where + keyNames = symbolVals @keys + +-- | Typed full outer join. +fullOuterJoin :: + forall (keys :: [Symbol]) left right. + (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (FullOuterJoinSchema keys left right) +fullOuterJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.fullOuterJoin keyNames r l) + where + keyNames = symbolVals @keys diff --git a/src/DataFrame/Typed/Operations.hs b/src/DataFrame/Typed/Operations.hs new file mode 100644 index 0000000..26955d4 --- /dev/null +++ b/src/DataFrame/Typed/Operations.hs @@ -0,0 +1,373 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +module DataFrame.Typed.Operations ( + -- * Schema-preserving operations + filterWhere, + filter, + filterBy, + filterAllJust, + filterJust, + filterNothing, + sortBy, + take, + takeLast, + drop, + dropLast, + range, + cube, + distinct, + sample, + shuffle, + + -- * Schema-modifying operations + derive, + select, + exclude, + rename, + renameMany, + insert, + insertColumn, + insertVector, + cloneColumn, + dropColumn, + replaceColumn, + + -- * Metadata + dimensions, + nRows, + nColumns, + columnNames, + + -- * Vertical merge + append, + + -- * Pipe operator + (|>), +) where + +import qualified Data.Foldable as F +import Data.Function ((&)) +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import qualified Data.Vector as V +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) +import System.Random (RandomGen) +import Prelude hiding (drop, filter, take) + +import 
DataFrame.Internal.Column (Columnable) +import qualified DataFrame.Internal.Column as C +import DataFrame.Internal.Expression (Expr (..)) +import qualified DataFrame.Operations.Aggregation as DA +import qualified DataFrame.Operations.Core as D +import DataFrame.Operations.Merge () +import qualified DataFrame.Operations.Permutation as D +import qualified DataFrame.Operations.Subset as D +import qualified DataFrame.Operations.Transformations as D + +-- Semigroup instance + +import DataFrame.Typed.Freeze (thaw, unsafeFreeze) +import DataFrame.Typed.Schema +import DataFrame.Typed.Types (TExpr (..), TSortOrder (..), TypedDataFrame (..)) +import qualified DataFrame.Typed.Types as T + +-- | Pipe operator, re-exported for convenience. +(|>) :: a -> (a -> b) -> b +(|>) = (&) + +infixl 1 |> + +------------------------------------------------------------------------------- +-- Schema-preserving operations +------------------------------------------------------------------------------- + +{- | Filter rows where a boolean expression evaluates to True. +The expression is validated against the schema at compile time. +-} +filterWhere :: TExpr cols Bool -> TypedDataFrame cols -> TypedDataFrame cols +filterWhere (TExpr expr) (TDF df) = TDF (D.filterWhere expr df) + +-- | Filter rows by applying a predicate to a typed expression. +filter :: + (Columnable a) => + TExpr cols a -> (a -> Bool) -> TypedDataFrame cols -> TypedDataFrame cols +filter (TExpr expr) pred' (TDF df) = TDF (D.filter expr pred' df) + +-- | Filter rows by a predicate on a column expression (flipped argument order). +filterBy :: + (Columnable a) => + (a -> Bool) -> TExpr cols a -> TypedDataFrame cols -> TypedDataFrame cols +filterBy pred' (TExpr expr) (TDF df) = TDF (D.filterBy pred' expr df) + +{- | Keep only rows where ALL Optional columns have Just values. +Strips 'Maybe' from all column types in the result schema. 
+ +@ +df :: TDF '[Column \"x\" (Maybe Double), Column \"y\" Int] +filterAllJust df :: TDF '[Column \"x\" Double, Column \"y\" Int] +@ +-} +filterAllJust :: TypedDataFrame cols -> TypedDataFrame (StripAllMaybe cols) +filterAllJust (TDF df) = unsafeFreeze (D.filterAllJust df) + +{- | Keep only rows where the named column has Just values. +Strips 'Maybe' from that column's type in the result schema. + +@ +filterJust \@\"x\" df +@ +-} +filterJust :: + forall name cols. + ( KnownSymbol name + , AssertPresent name cols + ) => + TypedDataFrame cols -> TypedDataFrame (StripMaybeAt name cols) +filterJust (TDF df) = unsafeFreeze (D.filterJust colName df) + where + colName = T.pack (symbolVal (Proxy @name)) + +{- | Keep only rows where the named column has Nothing. +Schema is preserved (column types unchanged, just fewer rows). +-} +filterNothing :: + forall name cols. + ( KnownSymbol name + , AssertPresent name cols + ) => + TypedDataFrame cols -> TypedDataFrame cols +filterNothing (TDF df) = TDF (D.filterNothing colName df) + where + colName = T.pack (symbolVal (Proxy @name)) + +{- | Sort by the given typed sort orders. +Sort orders reference columns that are validated against the schema. +-} +sortBy :: [TSortOrder cols] -> TypedDataFrame cols -> TypedDataFrame cols +sortBy ords (TDF df) = TDF (D.sortBy (map toUntypedSort ords) df) + where + toUntypedSort :: TSortOrder cols -> D.SortOrder + toUntypedSort (Asc (TExpr e)) = D.Asc e + toUntypedSort (Desc (TExpr e)) = D.Desc e + +-- | Take the first @n@ rows. +take :: Int -> TypedDataFrame cols -> TypedDataFrame cols +take n (TDF df) = TDF (D.take n df) + +-- | Take the last @n@ rows. +takeLast :: Int -> TypedDataFrame cols -> TypedDataFrame cols +takeLast n (TDF df) = TDF (D.takeLast n df) + +-- | Drop the first @n@ rows. +drop :: Int -> TypedDataFrame cols -> TypedDataFrame cols +drop n (TDF df) = TDF (D.drop n df) + +-- | Drop the last @n@ rows. 
+dropLast :: Int -> TypedDataFrame cols -> TypedDataFrame cols +dropLast n (TDF df) = TDF (D.dropLast n df) + +-- | Take rows in the given range (start, end). +range :: (Int, Int) -> TypedDataFrame cols -> TypedDataFrame cols +range r (TDF df) = TDF (D.range r df) + +-- | Take a sub-cube of the DataFrame. +cube :: (Int, Int) -> TypedDataFrame cols -> TypedDataFrame cols +cube c (TDF df) = TDF (D.cube c df) + +-- | Remove duplicate rows. +distinct :: TypedDataFrame cols -> TypedDataFrame cols +distinct (TDF df) = TDF (DA.distinct df) + +-- | Randomly sample a fraction of rows. +sample :: + (RandomGen g) => g -> Double -> TypedDataFrame cols -> TypedDataFrame cols +sample g frac (TDF df) = TDF (D.sample g frac df) + +-- | Shuffle all rows randomly. +shuffle :: (RandomGen g) => g -> TypedDataFrame cols -> TypedDataFrame cols +shuffle g (TDF df) = TDF (D.shuffle g df) + +------------------------------------------------------------------------------- +-- Schema-modifying operations +------------------------------------------------------------------------------- + +{- | Derive a new column from a typed expression. The column name must NOT +already exist in the schema (enforced at compile time via 'AssertAbsent'). +The expression is validated against the current schema. + +@ +df' = derive \@\"total\" (col \@\"price\" * col \@\"qty\") df +-- df' :: TDF (Column \"total\" Double ': originalCols) +@ +-} +derive :: + forall name a cols. + ( KnownSymbol name + , Columnable a + , AssertAbsent name cols + ) => + TExpr cols a -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols) +derive (TExpr expr) (TDF df) = unsafeFreeze (D.derive colName expr df) + where + colName = T.pack (symbolVal (Proxy @name)) + +-- | Select a subset of columns by name. +select :: + forall (names :: [Symbol]) cols. 
+ (AllKnownSymbol names) => + TypedDataFrame cols -> TypedDataFrame (SubsetSchema names cols) +select (TDF df) = unsafeFreeze (D.select (symbolVals @names) df) + +-- | Exclude columns by name. +exclude :: + forall (names :: [Symbol]) cols. + (AllKnownSymbol names) => + TypedDataFrame cols -> TypedDataFrame (ExcludeSchema names cols) +exclude (TDF df) = unsafeFreeze (D.exclude (symbolVals @names) df) + +-- | Rename a column. +rename :: + forall old new cols. + (KnownSymbol old, KnownSymbol new) => + TypedDataFrame cols -> TypedDataFrame (RenameInSchema old new cols) +rename (TDF df) = unsafeFreeze (D.rename oldName newName df) + where + oldName = T.pack (symbolVal (Proxy @old)) + newName = T.pack (symbolVal (Proxy @new)) + +-- | Rename multiple columns from a type-level list of pairs. +renameMany :: + forall (pairs :: [(Symbol, Symbol)]) cols. + (AllKnownPairs pairs) => + TypedDataFrame cols -> TypedDataFrame (RenameManyInSchema pairs cols) +renameMany (TDF df) = unsafeFreeze (foldRenames (pairVals @pairs) df) + where + foldRenames [] df' = df' + foldRenames ((old, new) : rest) df' = foldRenames rest (D.rename old new df') + +-- | Insert a new column from a Foldable container. +insert :: + forall name a cols t. + ( KnownSymbol name + , Columnable a + , Foldable t + , AssertAbsent name cols + ) => + t a -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols) +insert xs (TDF df) = unsafeFreeze (D.insert colName xs df) + where + colName = T.pack (symbolVal (Proxy @name)) + +-- | Insert a raw 'Column' value. +insertColumn :: + forall name a cols. + ( KnownSymbol name + , Columnable a + , AssertAbsent name cols + ) => + C.Column -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols) +insertColumn col (TDF df) = unsafeFreeze (D.insertColumn colName col df) + where + colName = T.pack (symbolVal (Proxy @name)) + +-- | Insert a boxed 'Vector'. +insertVector :: + forall name a cols. 
+ ( KnownSymbol name + , Columnable a + , AssertAbsent name cols + ) => + V.Vector a -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols) +insertVector vec (TDF df) = unsafeFreeze (D.insertVector colName vec df) + where + colName = T.pack (symbolVal (Proxy @name)) + +-- | Clone an existing column under a new name. +cloneColumn :: + forall old new cols. + ( KnownSymbol old + , KnownSymbol new + , AssertPresent old cols + , AssertAbsent new cols + ) => + TypedDataFrame cols -> TypedDataFrame (T.Column new (Lookup old cols) ': cols) +cloneColumn (TDF df) = unsafeFreeze (D.cloneColumn oldName newName df) + where + oldName = T.pack (symbolVal (Proxy @old)) + newName = T.pack (symbolVal (Proxy @new)) + +-- | Drop a column by name. +dropColumn :: + forall name cols. + ( KnownSymbol name + , AssertPresent name cols + ) => + TypedDataFrame cols -> TypedDataFrame (RemoveColumn name cols) +dropColumn (TDF df) = unsafeFreeze (D.exclude [colName] df) + where + colName = T.pack (symbolVal (Proxy @name)) + +{- | Replace an existing column with new values derived from a typed expression. +The column must already exist and the new type must match. +-} +replaceColumn :: + forall name a cols. + ( KnownSymbol name + , Columnable a + , a ~ Lookup name cols + , AssertPresent name cols + ) => + TExpr cols a -> TypedDataFrame cols -> TypedDataFrame cols +replaceColumn (TExpr expr) (TDF df) = unsafeFreeze (D.derive colName expr df) + where + colName = T.pack (symbolVal (Proxy @name)) + +-- | Vertically merge two DataFrames with the same schema. 
+append :: TypedDataFrame cols -> TypedDataFrame cols -> TypedDataFrame cols +append (TDF a) (TDF b) = TDF (a <> b) + +------------------------------------------------------------------------------- +-- Metadata (pass-through) +------------------------------------------------------------------------------- + +dimensions :: TypedDataFrame cols -> (Int, Int) +dimensions (TDF df) = D.dimensions df + +nRows :: TypedDataFrame cols -> Int +nRows (TDF df) = D.nRows df + +nColumns :: TypedDataFrame cols -> Int +nColumns (TDF df) = D.nColumns df + +columnNames :: TypedDataFrame cols -> [T.Text] +columnNames (TDF df) = D.columnNames df + +------------------------------------------------------------------------------- +-- Internal helpers +------------------------------------------------------------------------------- + +-- | Helper class for extracting [(Text, Text)] from type-level pairs. +class AllKnownPairs (pairs :: [(Symbol, Symbol)]) where + pairVals :: [(T.Text, T.Text)] + +instance AllKnownPairs '[] where + pairVals = [] + +instance + (KnownSymbol a, KnownSymbol b, AllKnownPairs rest) => + AllKnownPairs ('(a, b) ': rest) + where + pairVals = + ( T.pack (symbolVal (Proxy @a)) + , T.pack (symbolVal (Proxy @b)) + ) + : pairVals @rest diff --git a/src/DataFrame/Typed/Schema.hs b/src/DataFrame/Typed/Schema.hs new file mode 100644 index 0000000..499c70d --- /dev/null +++ b/src/DataFrame/Typed/Schema.hs @@ -0,0 +1,347 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module DataFrame.Typed.Schema ( + -- * Type families for schema manipulation + Lookup, + 
HasName, + RemoveColumn, + SubsetSchema, + ExcludeSchema, + RenameInSchema, + RenameManyInSchema, + Append, + Reverse, + ColumnNames, + AssertAbsent, + AssertPresent, + IsElem, + + -- * Maybe-stripping families + StripAllMaybe, + StripMaybeAt, + + -- * Join schema families + SharedNames, + UniqueLeft, + InnerJoinSchema, + LeftJoinSchema, + RightJoinSchema, + FullOuterJoinSchema, + WrapMaybe, + WrapMaybeColumns, + CollidingColumns, + + -- * GroupBy helpers + GroupKeyColumns, + + -- * KnownSchema class + KnownSchema (..), + + -- * Helpers + AllKnownSymbol (..), +) where + +import Data.Kind (Constraint, Type) +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import Data.These (These) +import GHC.TypeLits +import Type.Reflection (SomeTypeRep, Typeable, someTypeRep, typeRep) + +import DataFrame.Internal.Column (Columnable) +import DataFrame.Typed.Types (Column) + +------------------------------------------------------------------------------- +-- Core type families +------------------------------------------------------------------------------- + +-- | Look up the element type of a column by name. +type family Lookup (name :: Symbol) (cols :: [Type]) :: Type where + Lookup name (Column name a ': _) = a + Lookup name (Column _ _ ': rest) = Lookup name rest + Lookup name '[] = + TypeError + ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") + +-- | Check whether a column name exists in a schema (type-level Bool). +type family HasName (name :: Symbol) (cols :: [Type]) :: Bool where + HasName name (Column name _ ': _) = 'True + HasName name (Column _ _ ': rest) = HasName name rest + HasName name '[] = 'False + +-- | Remove a column by name from a schema. +type family RemoveColumn (name :: Symbol) (cols :: [Type]) :: [Type] where + RemoveColumn name (Column name _ ': rest) = rest + RemoveColumn name (col ': rest) = col ': RemoveColumn name rest + RemoveColumn name '[] = '[] + +-- | Select a subset of columns by a list of names. 
+type family SubsetSchema (names :: [Symbol]) (cols :: [Type]) :: [Type] where + SubsetSchema '[] cols = '[] + SubsetSchema (n ': ns) cols = Column n (Lookup n cols) ': SubsetSchema ns cols + +-- | Exclude columns by a list of names. +type family ExcludeSchema (names :: [Symbol]) (cols :: [Type]) :: [Type] where + ExcludeSchema names '[] = '[] + ExcludeSchema names (Column n a ': rest) = + If + (IsElem n names) + (ExcludeSchema names rest) + (Column n a ': ExcludeSchema names rest) + +-- | Type-level if +type family If (b :: Bool) (t :: k) (f :: k) :: k where + If 'True t _ = t + If 'False _ f = f + +-- | Type-level elem for Symbols +type family IsElem (x :: Symbol) (xs :: [Symbol]) :: Bool where + IsElem x '[] = 'False + IsElem x (x ': _) = 'True + IsElem x (_ ': xs) = IsElem x xs + +-- | Rename a column in the schema. +type family RenameInSchema (old :: Symbol) (new :: Symbol) (cols :: [Type]) :: [Type] where + RenameInSchema old new (Column old a ': rest) = Column new a ': rest + RenameInSchema old new (col ': rest) = col ': RenameInSchema old new rest + RenameInSchema old new '[] = + TypeError + ('Text "Cannot rename: column '" ':<>: 'Text old ':<>: 'Text "' not found") + +-- | Rename multiple columns. +type family RenameManyInSchema (pairs :: [(Symbol, Symbol)]) (cols :: [Type]) :: [Type] where + RenameManyInSchema '[] cols = cols + RenameManyInSchema ('(old, new) ': rest) cols = + RenameManyInSchema rest (RenameInSchema old new cols) + +-- | Append two type-level lists. +type family Append (xs :: [k]) (ys :: [k]) :: [k] where + Append '[] ys = ys + Append (x ': xs) ys = x ': Append xs ys + +-- | Reverse a type-level list. +type family Reverse (xs :: [Type]) :: [Type] where + Reverse xs = ReverseAcc xs '[] + +type family ReverseAcc (xs :: [Type]) (acc :: [Type]) :: [Type] where + ReverseAcc '[] acc = acc + ReverseAcc (x ': xs) acc = ReverseAcc xs (x ': acc) + +-- | Extract column names as a type-level list of Symbols. 
+type family ColumnNames (cols :: [Type]) :: [Symbol] where + ColumnNames '[] = '[] + ColumnNames (Column n _ ': rest) = n ': ColumnNames rest + +-- | Assert that a column name is absent from the schema (for derive/insert). +type family AssertAbsent (name :: Symbol) (cols :: [Type]) :: Constraint where + AssertAbsent name cols = AssertAbsentHelper name (HasName name cols) cols + +type family + AssertAbsentHelper (name :: Symbol) (found :: Bool) (cols :: [Type]) :: + Constraint + where + AssertAbsentHelper name 'False cols = () + AssertAbsentHelper name 'True cols = + TypeError + ( 'Text "Column '" + ':<>: 'Text name + ':<>: 'Text "' already exists in schema. " + ':<>: 'Text "Use replaceColumn to overwrite." + ) + +-- | Assert that a column name is present in the schema. +type family AssertPresent (name :: Symbol) (cols :: [Type]) :: Constraint where + AssertPresent name cols = AssertPresentHelper name (HasName name cols) cols + +type family + AssertPresentHelper (name :: Symbol) (found :: Bool) (cols :: [Type]) :: + Constraint + where + AssertPresentHelper name 'True cols = () + AssertPresentHelper name 'False cols = + TypeError + ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") + +------------------------------------------------------------------------------- +-- Maybe-stripping families +------------------------------------------------------------------------------- + +{- | Strip 'Maybe' from all columns. Used by 'filterAllJust'. + +@Column "x" (Maybe Double)@ becomes @Column "x" Double@. +@Column "y" Int@ stays @Column "y" Int@. +-} +type family StripAllMaybe (cols :: [Type]) :: [Type] where + StripAllMaybe '[] = '[] + StripAllMaybe (Column n (Maybe a) ': rest) = Column n a ': StripAllMaybe rest + StripAllMaybe (Column n a ': rest) = Column n a ': StripAllMaybe rest + +{- | Strip 'Maybe' from a single named column. Used by 'filterJust'. 
+ +@StripMaybeAt "x" '[Column "x" (Maybe Double), Column "y" Int]@ + = @'[Column "x" Double, Column "y" Int]@ +-} +type family StripMaybeAt (name :: Symbol) (cols :: [Type]) :: [Type] where + StripMaybeAt name (Column name (Maybe a) ': rest) = Column name a ': rest + StripMaybeAt name (Column name a ': rest) = Column name a ': rest + StripMaybeAt name (col ': rest) = col ': StripMaybeAt name rest + StripMaybeAt name '[] = + TypeError + ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") + +------------------------------------------------------------------------------- +-- Join schema families +------------------------------------------------------------------------------- + +-- | Extract column names that appear in both schemas. +type family SharedNames (left :: [Type]) (right :: [Type]) :: [Symbol] where + SharedNames '[] right = '[] + SharedNames (Column n _ ': rest) right = + If (HasName n right) (n ': SharedNames rest right) (SharedNames rest right) + +-- | Columns from @left@ whose names do NOT appear in @right@. +type family UniqueLeft (left :: [Type]) (rightNames :: [Symbol]) :: [Type] where + UniqueLeft '[] _ = '[] + UniqueLeft (Column n a ': rest) rn = + If (IsElem n rn) (UniqueLeft rest rn) (Column n a ': UniqueLeft rest rn) + +-- | Wrap column types in Maybe. +type family WrapMaybe (cols :: [Type]) :: [Type] where + WrapMaybe '[] = '[] + WrapMaybe (Column n a ': rest) = Column n (Maybe a) ': WrapMaybe rest + +-- | Wrap selected columns in Maybe by name list. +type family WrapMaybeColumns (names :: [Symbol]) (cols :: [Type]) :: [Type] where + WrapMaybeColumns names '[] = '[] + WrapMaybeColumns names (Column n a ': rest) = + If + (IsElem n names) + (Column n (Maybe a) ': WrapMaybeColumns names rest) + (Column n a ': WrapMaybeColumns names rest) + +-- | Columns in left whose names collide with right (excluding keys). 
+type family CollidingColumns (left :: [Type]) (right :: [Type]) (keys :: [Symbol]) :: [Type] where + CollidingColumns '[] _ _ = '[] + CollidingColumns (Column n a ': rest) right keys = + If + (IsElem n keys) + (CollidingColumns rest right keys) + ( If + (HasName n right) + (Column n (These a (Lookup n right)) ': CollidingColumns rest right keys) + (CollidingColumns rest right keys) + ) + +-- | Inner join result schema. +type family InnerJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) :: [Type] where + InnerJoinSchema keys left right = + Append + (SubsetSchema keys left) + ( Append + (UniqueLeft left (Append keys (ColumnNames right))) + ( Append + (UniqueLeft right (Append keys (ColumnNames left))) + (CollidingColumns left right keys) + ) + ) + +-- | Left join result schema. +type family LeftJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) :: [Type] where + LeftJoinSchema keys left right = + Append + (SubsetSchema keys left) + ( Append + (UniqueLeft left (Append keys (ColumnNames right))) + ( Append + (WrapMaybe (UniqueLeft right (Append keys (ColumnNames left)))) + (CollidingColumns left right keys) + ) + ) + +-- | Right join result schema. +type family RightJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) :: [Type] where + RightJoinSchema keys left right = + Append + (SubsetSchema keys right) + ( Append + (WrapMaybe (UniqueLeft left (Append keys (ColumnNames right)))) + ( Append + (UniqueLeft right (Append keys (ColumnNames left))) + (CollidingColumns left right keys) + ) + ) + +-- | Full outer join result schema. 
+type family + FullOuterJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) :: + [Type] + where + FullOuterJoinSchema keys left right = + Append + (WrapMaybe (SubsetSchema keys left)) + ( Append + (WrapMaybe (UniqueLeft left (Append keys (ColumnNames right)))) + ( Append + (WrapMaybe (UniqueLeft right (Append keys (ColumnNames left)))) + (CollidingColumns left right keys) + ) + ) + +------------------------------------------------------------------------------- +-- GroupBy helpers +------------------------------------------------------------------------------- + +-- | Extract Column entries from a schema whose names appear in @keys@. +type family GroupKeyColumns (keys :: [Symbol]) (cols :: [Type]) :: [Type] where + GroupKeyColumns keys '[] = '[] + GroupKeyColumns keys (Column n a ': rest) = + If + (IsElem n keys) + (Column n a ': GroupKeyColumns keys rest) + (GroupKeyColumns keys rest) + +------------------------------------------------------------------------------- +-- KnownSchema class +------------------------------------------------------------------------------- + +-- | Provides runtime evidence of a schema: a list of (name, TypeRep) pairs. +class KnownSchema (cols :: [Type]) where + schemaEvidence :: [(T.Text, SomeTypeRep)] + +instance KnownSchema '[] where + schemaEvidence = [] + +instance + (KnownSymbol name, Typeable a, Columnable a, KnownSchema rest) => + KnownSchema (Column name a ': rest) + where + schemaEvidence = + (T.pack (symbolVal (Proxy @name)), someTypeRep (Proxy @a)) + : schemaEvidence @rest + +------------------------------------------------------------------------------- +-- AllKnownSymbol helper +------------------------------------------------------------------------------- + +-- | A class that provides a list of 'Text' values for a type-level list of Symbols. 
+class AllKnownSymbol (names :: [Symbol]) where + symbolVals :: [T.Text] + +instance AllKnownSymbol '[] where + symbolVals = [] + +instance (KnownSymbol n, AllKnownSymbol ns) => AllKnownSymbol (n ': ns) where + symbolVals = T.pack (symbolVal (Proxy @n)) : symbolVals @ns diff --git a/src/DataFrame/Typed/TH.hs b/src/DataFrame/Typed/TH.hs new file mode 100644 index 0000000..acaf058 --- /dev/null +++ b/src/DataFrame/Typed/TH.hs @@ -0,0 +1,91 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} + +module DataFrame.Typed.TH ( + -- * Schema inference + deriveSchema, + + -- * Re-export for TH splices + TypedDataFrame, + Column, +) where + +import qualified Data.List as L +import qualified Data.Map as M +import qualified Data.Text as T + +import Language.Haskell.TH +import Language.Haskell.TH.Syntax (Lift (..)) + +import qualified DataFrame.Internal.Column as C +import qualified DataFrame.Internal.DataFrame as D +import DataFrame.Typed.Types (Column, TypedDataFrame) + +{- | Generate a type synonym for a schema based on an existing 'DataFrame'. + +@ +-} + +{- $(deriveSchema \"IrisSchema\" irisDF) +-- Generates: type IrisSchema = '[Column \"sepal_length\" Double, ...] 
+@ +-} + +deriveSchema :: String -> D.DataFrame -> DecsQ +deriveSchema typeName df = do + let cols = getSchemaInfo df + let names = map fst cols + case findDuplicate names of + Just dup -> fail $ "Duplicate column name in DataFrame: " ++ T.unpack dup + Nothing -> pure () + colTypes <- mapM mkColumnType cols + let schemaType = foldr (\t acc -> PromotedConsT `AppT` t `AppT` acc) PromotedNilT colTypes + let synName = mkName typeName + pure [TySynD synName [] schemaType] + +------------------------------------------------------------------------------- +-- Internal helpers +------------------------------------------------------------------------------- + +getSchemaInfo :: D.DataFrame -> [(T.Text, String)] +getSchemaInfo df = + let orderedNames = + map fst $ + L.sortBy (\(_, a) (_, b) -> compare a b) $ + M.toList (D.columnIndices df) + in map (\name -> (name, getColumnTypeStr name df)) orderedNames + +getColumnTypeStr :: T.Text -> D.DataFrame -> String +getColumnTypeStr name df = case D.getColumn name df of + Just col -> C.columnTypeString col + Nothing -> error $ "Column not found: " ++ T.unpack name + +mkColumnType :: (T.Text, String) -> Q Type +mkColumnType (name, tyStr) = do + ty <- parseTypeString tyStr + let nameLit = LitT (StrTyLit (T.unpack name)) + pure $ ConT ''Column `AppT` nameLit `AppT` ty + +parseTypeString :: String -> Q Type +parseTypeString "Int" = pure $ ConT ''Int +parseTypeString "Double" = pure $ ConT ''Double +parseTypeString "Float" = pure $ ConT ''Float +parseTypeString "Bool" = pure $ ConT ''Bool +parseTypeString "Char" = pure $ ConT ''Char +parseTypeString "String" = pure $ ConT ''String +parseTypeString "Text" = pure $ ConT ''T.Text +parseTypeString "Integer" = pure $ ConT ''Integer +parseTypeString s + | "Maybe " `L.isPrefixOf` s = do + inner <- parseTypeString (L.drop 6 s) + pure $ ConT ''Maybe `AppT` inner +parseTypeString s = fail $ "Unsupported column type in schema inference: " ++ s + +findDuplicate :: (Eq a) => [a] -> Maybe a 
+findDuplicate [] = Nothing +findDuplicate (x : xs) + | x `elem` xs = Just x + | otherwise = findDuplicate xs diff --git a/src/DataFrame/Typed/Types.hs b/src/DataFrame/Typed/Types.hs new file mode 100644 index 0000000..286a3bf --- /dev/null +++ b/src/DataFrame/Typed/Types.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} + +module DataFrame.Typed.Types ( + -- * Core phantom-typed wrapper + TypedDataFrame (..), + + -- * Column phantom type (no constructors) + Column, + + -- * Typed expressions (schema-validated) + TExpr (..), + + -- * Typed sort orders + TSortOrder (..), + + -- * Grouped typed dataframe + TypedGrouped (..), + + -- * Typed aggregation builder (Option B) + TAgg (..), + taggToNamedExprs, + + -- * Re-export These + These (..), +) where + +import Data.Kind (Type) +import Data.These (These (..)) +import GHC.TypeLits (Symbol) + +import qualified Data.Text as T +import DataFrame.Internal.Column (Columnable) +import qualified DataFrame.Internal.DataFrame as D +import DataFrame.Internal.Expression (Expr, NamedExpr, UExpr (..)) + +{- | A phantom-typed wrapper over the untyped 'DataFrame'. + +The type parameter @cols@ is a type-level list of @Column name ty@ entries +that tracks the schema at compile time. All operations delegate to the +untyped core at runtime and update the phantom type at compile time. +-} +newtype TypedDataFrame (cols :: [Type]) = TDF {unTDF :: D.DataFrame} + +instance Show (TypedDataFrame cols) where + show (TDF df) = show df + +instance Eq (TypedDataFrame cols) where + (TDF a) == (TDF b) = a == b + +{- | A phantom type that pairs a type-level column name ('Symbol') +with its element type. 
Has no value-level constructors — used +purely at the type level to describe schemas. +-} +data Column (name :: Symbol) (a :: Type) + +{- | A typed expression validated against schema @cols@, producing values of type @a@. + +Unlike the untyped 'Expr a', a 'TExpr' can only be constructed through +type-safe combinators ('col', 'lit', arithmetic operations) that verify +column references exist in the schema with the correct type. + +Use 'unTExpr' to extract the underlying 'Expr' for delegation to the untyped API. +-} +newtype TExpr (cols :: [Type]) a = TExpr {unTExpr :: Expr a} + +-- | A typed sort order validated against schema @cols@. +data TSortOrder (cols :: [Type]) where + Asc :: (Columnable a) => TExpr cols a -> TSortOrder cols + Desc :: (Columnable a) => TExpr cols a -> TSortOrder cols + +-- | A phantom-typed wrapper over 'GroupedDataFrame'. +newtype TypedGrouped (keys :: [Symbol]) (cols :: [Type]) + = TGD {unTGD :: D.GroupedDataFrame} + +{- | A typed aggregation builder (Option B). + +Accumulates 'NamedExpr' values at the term level while building +the result schema at the type level. Each @agg@ call prepends a +'Column' to the @aggs@ phantom list. + +Usage: + +@ +agg \@\"total\" (F.sum salary) + $ agg \@\"avg_age\" (F.mean age) + $ aggNil +@ +-} +data TAgg (keys :: [Symbol]) (cols :: [Type]) (aggs :: [Type]) where + TAggNil :: TAgg keys cols '[] + TAggCons :: + (Columnable a) => + -- | column name + T.Text -> + -- | typed aggregation expression + TExpr cols a -> + -- | rest + TAgg keys cols aggs -> + TAgg keys cols (Column name a ': aggs) + +{- | Extract the runtime 'NamedExpr' list from a 'TAgg', in +declaration order (reversed from the cons-built order). +-} +taggToNamedExprs :: TAgg keys cols aggs -> [NamedExpr] +taggToNamedExprs = reverse . 
go + where + go :: TAgg keys cols aggs -> [NamedExpr] + go TAggNil = [] + go (TAggCons name (TExpr expr) rest) = (name, UExpr expr) : go rest From ecb804d8ccd3646f68ae77f26aa5ab64d4bd097b Mon Sep 17 00:00:00 2001 From: Michael Chavinda Date: Thu, 26 Feb 2026 15:21:51 -0800 Subject: [PATCH 2/7] fix: Linting. --- src/DataFrame/Typed/Expr.hs | 3 ++- src/DataFrame/Typed/TH.hs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/DataFrame/Typed/Expr.hs b/src/DataFrame/Typed/Expr.hs index 25647db..f2706d9 100644 --- a/src/DataFrame/Typed/Expr.hs +++ b/src/DataFrame/Typed/Expr.hs @@ -134,7 +134,8 @@ lit = TExpr . Lit -- | Conditional expression. ifThenElse :: - (Columnable a) => TExpr cols Bool -> TExpr cols a -> TExpr cols a -> TExpr cols a + (Columnable a) => + TExpr cols Bool -> TExpr cols a -> TExpr cols a -> TExpr cols a ifThenElse (TExpr c) (TExpr t) (TExpr e) = TExpr (If c t e) ------------------------------------------------------------------------------- diff --git a/src/DataFrame/Typed/TH.hs b/src/DataFrame/Typed/TH.hs index acaf058..51ff840 100644 --- a/src/DataFrame/Typed/TH.hs +++ b/src/DataFrame/Typed/TH.hs @@ -1,7 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TemplateHaskellQuotes #-} {-# LANGUAGE TypeApplications #-} module DataFrame.Typed.TH ( From c89c5a16179dbee999fec78a32fdc974c890b9d9 Mon Sep 17 00:00:00 2001 From: Michael Chavinda Date: Fri, 27 Feb 2026 22:34:53 -0800 Subject: [PATCH 3/7] feat: Typed derive should add new columns to the end of the schema. 
--- src/DataFrame/Typed/Operations.hs | 2 +- src/DataFrame/Typed/Schema.hs | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/DataFrame/Typed/Operations.hs b/src/DataFrame/Typed/Operations.hs index 26955d4..b5b8a62 100644 --- a/src/DataFrame/Typed/Operations.hs +++ b/src/DataFrame/Typed/Operations.hs @@ -215,7 +215,7 @@ derive :: , Columnable a , AssertAbsent name cols ) => - TExpr cols a -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols) + TExpr cols a -> TypedDataFrame cols -> TypedDataFrame (Snoc cols (T.Column name a)) derive (TExpr expr) (TDF df) = unsafeFreeze (D.derive colName expr df) where colName = T.pack (symbolVal (Proxy @name)) diff --git a/src/DataFrame/Typed/Schema.hs b/src/DataFrame/Typed/Schema.hs index 499c70d..b2323f1 100644 --- a/src/DataFrame/Typed/Schema.hs +++ b/src/DataFrame/Typed/Schema.hs @@ -23,6 +23,7 @@ module DataFrame.Typed.Schema ( RenameInSchema, RenameManyInSchema, Append, + Snoc, Reverse, ColumnNames, AssertAbsent, @@ -76,6 +77,11 @@ type family Lookup (name :: Symbol) (cols :: [Type]) :: Type where TypeError ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") +-- | Add type to the end of a list. +type family Snoc (xs :: [k]) (x :: k) :: [k] where + Snoc '[] x = '[x] + Snoc (y ': ys) x = y ': Snoc ys x + -- | Check whether a column name exists in a schema (type-level Bool). type family HasName (name :: Symbol) (cols :: [Type]) :: Bool where HasName name (Column name _ ': _) = 'True From c445a0c7e184007189ef4eed5bd14378efb00cc6 Mon Sep 17 00:00:00 2001 From: Michael Chavinda Date: Sun, 1 Mar 2026 22:23:03 -0800 Subject: [PATCH 4/7] feat: Add TH function for deriving schema from CSV file. 
--- src/DataFrame/Typed.hs | 3 ++- src/DataFrame/Typed/Operations.hs | 4 +++- src/DataFrame/Typed/Schema.hs | 2 +- src/DataFrame/Typed/TH.hs | 8 ++++++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/DataFrame/Typed.hs b/src/DataFrame/Typed.hs index 8a893f2..d450347 100644 --- a/src/DataFrame/Typed.hs +++ b/src/DataFrame/Typed.hs @@ -165,6 +165,7 @@ module DataFrame.Typed ( -- * Template Haskell deriveSchema, + deriveSchemaFromCsvFile, -- * Schema type families (for advanced use) Lookup, @@ -208,7 +209,7 @@ import DataFrame.Typed.Freeze (freeze, freezeWithError, thaw, unsafeFreeze) import DataFrame.Typed.Join (fullOuterJoin, innerJoin, leftJoin, rightJoin) import DataFrame.Typed.Operations import DataFrame.Typed.Schema -import DataFrame.Typed.TH (deriveSchema) +import DataFrame.Typed.TH (deriveSchema, deriveSchemaFromCsvFile) import DataFrame.Typed.Types ( Column, TExpr (..), diff --git a/src/DataFrame/Typed/Operations.hs b/src/DataFrame/Typed/Operations.hs index b5b8a62..638425e 100644 --- a/src/DataFrame/Typed/Operations.hs +++ b/src/DataFrame/Typed/Operations.hs @@ -215,7 +215,9 @@ derive :: , Columnable a , AssertAbsent name cols ) => - TExpr cols a -> TypedDataFrame cols -> TypedDataFrame (Snoc cols (T.Column name a)) + TExpr cols a -> + TypedDataFrame cols -> + TypedDataFrame (Snoc cols (T.Column name a)) derive (TExpr expr) (TDF df) = unsafeFreeze (D.derive colName expr df) where colName = T.pack (symbolVal (Proxy @name)) diff --git a/src/DataFrame/Typed/Schema.hs b/src/DataFrame/Typed/Schema.hs index b2323f1..b7592fb 100644 --- a/src/DataFrame/Typed/Schema.hs +++ b/src/DataFrame/Typed/Schema.hs @@ -79,7 +79,7 @@ type family Lookup (name :: Symbol) (cols :: [Type]) :: Type where -- | Add type to the end of a list. type family Snoc (xs :: [k]) (x :: k) :: [k] where - Snoc '[] x = '[x] + Snoc '[] x = '[x] Snoc (y ': ys) x = y ': Snoc ys x -- | Check whether a column name exists in a schema (type-level Bool). 
diff --git a/src/DataFrame/Typed/TH.hs b/src/DataFrame/Typed/TH.hs index 51ff840..47c076c 100644 --- a/src/DataFrame/Typed/TH.hs +++ b/src/DataFrame/Typed/TH.hs @@ -7,12 +7,14 @@ module DataFrame.Typed.TH ( -- * Schema inference deriveSchema, + deriveSchemaFromCsvFile, -- * Re-export for TH splices TypedDataFrame, Column, ) where +import Control.Monad.IO.Class import qualified Data.List as L import qualified Data.Map as M import qualified Data.Text as T @@ -20,6 +22,7 @@ import qualified Data.Text as T import Language.Haskell.TH import Language.Haskell.TH.Syntax (Lift (..)) +import qualified DataFrame.IO.CSV as D import qualified DataFrame.Internal.Column as C import qualified DataFrame.Internal.DataFrame as D import DataFrame.Typed.Types (Column, TypedDataFrame) @@ -46,6 +49,11 @@ deriveSchema typeName df = do let synName = mkName typeName pure [TySynD synName [] schemaType] +deriveSchemaFromCsvFile :: String -> String -> DecsQ +deriveSchemaFromCsvFile typeName path = do + df <- liftIO (D.readCsv path) + deriveSchema typeName df + ------------------------------------------------------------------------------- -- Internal helpers ------------------------------------------------------------------------------- From 4ed1e0ee178a4667d8b564be8ec42e5a40b74f65 Mon Sep 17 00:00:00 2001 From: Michael Chavinda Date: Mon, 2 Mar 2026 11:50:02 -0800 Subject: [PATCH 5/7] feat: Add type level impute function. 
--- src/DataFrame/Typed.hs | 3 ++- src/DataFrame/Typed/Aggregate.hs | 2 +- src/DataFrame/Typed/Expr.hs | 3 +-- src/DataFrame/Typed/Freeze.hs | 5 +---- src/DataFrame/Typed/Join.hs | 9 ++------- src/DataFrame/Typed/Operations.hs | 20 +++++++++++++++++--- src/DataFrame/Typed/Schema.hs | 17 +++++++++++------ src/DataFrame/Typed/TH.hs | 1 - 8 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/DataFrame/Typed.hs b/src/DataFrame/Typed.hs index d450347..72f0cbb 100644 --- a/src/DataFrame/Typed.hs +++ b/src/DataFrame/Typed.hs @@ -130,6 +130,7 @@ module DataFrame.Typed ( -- * Schema-modifying operations derive, + impute, select, exclude, rename, @@ -174,6 +175,7 @@ module DataFrame.Typed ( ExcludeSchema, RenameInSchema, RemoveColumn, + Impute, Append, Reverse, StripAllMaybe, @@ -212,7 +214,6 @@ import DataFrame.Typed.Schema import DataFrame.Typed.TH (deriveSchema, deriveSchemaFromCsvFile) import DataFrame.Typed.Types ( Column, - TExpr (..), TSortOrder (..), These (..), TypedDataFrame, diff --git a/src/DataFrame/Typed/Aggregate.hs b/src/DataFrame/Typed/Aggregate.hs index 341315f..15f230d 100644 --- a/src/DataFrame/Typed/Aggregate.hs +++ b/src/DataFrame/Typed/Aggregate.hs @@ -28,7 +28,7 @@ import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) import DataFrame.Internal.Column (Columnable) import qualified DataFrame.Internal.DataFrame as D -import DataFrame.Internal.Expression (Expr, NamedExpr, UExpr (..)) +import DataFrame.Internal.Expression (NamedExpr) import qualified DataFrame.Operations.Aggregation as DA import DataFrame.Typed.Freeze (unsafeFreeze) diff --git a/src/DataFrame/Typed/Expr.hs b/src/DataFrame/Typed/Expr.hs index f2706d9..9fc7ff3 100644 --- a/src/DataFrame/Typed/Expr.hs +++ b/src/DataFrame/Typed/Expr.hs @@ -83,7 +83,6 @@ module DataFrame.Typed.Expr ( desc, ) where -import Data.Kind (Type) import Data.Proxy (Proxy (..)) import Data.String (IsString (..)) import qualified Data.Text as T @@ -100,7 +99,7 @@ import DataFrame.Internal.Expression ( 
UnaryOp (..), ) import DataFrame.Typed.Schema (AssertPresent, Lookup) -import DataFrame.Typed.Types (Column, TExpr (..), TSortOrder (..)) +import DataFrame.Typed.Types (TExpr (..), TSortOrder (..)) ------------------------------------------------------------------------------- -- Column reference — the core type-safe construction point diff --git a/src/DataFrame/Typed/Freeze.hs b/src/DataFrame/Typed/Freeze.hs index 79963df..bc2f450 100644 --- a/src/DataFrame/Typed/Freeze.hs +++ b/src/DataFrame/Typed/Freeze.hs @@ -15,11 +15,8 @@ module DataFrame.Typed.Freeze ( unsafeFreeze, ) where -import Data.Kind (Type) -import qualified Data.Map as M -import Data.Proxy (Proxy (..)) import qualified Data.Text as T -import Type.Reflection (SomeTypeRep, someTypeRep) +import Type.Reflection (SomeTypeRep) import qualified DataFrame.Internal.Column as C import qualified DataFrame.Internal.DataFrame as D diff --git a/src/DataFrame/Typed/Join.hs b/src/DataFrame/Typed/Join.hs index 2c8622a..fdb4928 100644 --- a/src/DataFrame/Typed/Join.hs +++ b/src/DataFrame/Typed/Join.hs @@ -15,16 +15,11 @@ module DataFrame.Typed.Join ( fullOuterJoin, ) where -import Data.Proxy (Proxy (..)) -import qualified Data.Text as T -import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) +import GHC.TypeLits (Symbol) -import DataFrame.Internal.Column (Columnable) -import qualified DataFrame.Internal.DataFrame as D -import qualified DataFrame.Operations.Core as D import qualified DataFrame.Operations.Join as DJ -import DataFrame.Typed.Freeze (thaw, unsafeFreeze) +import DataFrame.Typed.Freeze (unsafeFreeze) import DataFrame.Typed.Schema import DataFrame.Typed.Types (TypedDataFrame (..)) diff --git a/src/DataFrame/Typed/Operations.hs b/src/DataFrame/Typed/Operations.hs index 638425e..52854aa 100644 --- a/src/DataFrame/Typed/Operations.hs +++ b/src/DataFrame/Typed/Operations.hs @@ -30,6 +30,7 @@ module DataFrame.Typed.Operations ( -- * Schema-modifying operations derive, + impute, select, exclude, rename, @@ -54,7 
+55,6 @@ module DataFrame.Typed.Operations ( (|>), ) where -import qualified Data.Foldable as F import Data.Function ((&)) import Data.Proxy (Proxy (..)) import qualified Data.Text as T @@ -63,9 +63,9 @@ import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) import System.Random (RandomGen) import Prelude hiding (drop, filter, take) +import qualified DataFrame.Functions as DF import DataFrame.Internal.Column (Columnable) import qualified DataFrame.Internal.Column as C -import DataFrame.Internal.Expression (Expr (..)) import qualified DataFrame.Operations.Aggregation as DA import qualified DataFrame.Operations.Core as D import DataFrame.Operations.Merge () @@ -75,7 +75,7 @@ import qualified DataFrame.Operations.Transformations as D -- Semigroup instance -import DataFrame.Typed.Freeze (thaw, unsafeFreeze) +import DataFrame.Typed.Freeze (unsafeFreeze) import DataFrame.Typed.Schema import DataFrame.Typed.Types (TExpr (..), TSortOrder (..), TypedDataFrame (..)) import qualified DataFrame.Typed.Types as T @@ -222,6 +222,20 @@ derive (TExpr expr) (TDF df) = unsafeFreeze (D.derive colName expr df) where colName = T.pack (symbolVal (Proxy @name)) +impute :: + forall name a cols. + ( KnownSymbol name + , Columnable a + ) => + a -> + TypedDataFrame cols -> + TypedDataFrame (Impute name cols) +impute value (TDF df) = + unsafeFreeze + (D.derive colName (DF.fromMaybe value (DF.col @(Maybe a) colName)) df) + where + colName = T.pack (symbolVal (Proxy @name)) + -- | Select a subset of columns by name. select :: forall (names :: [Symbol]) cols. 
diff --git a/src/DataFrame/Typed/Schema.hs b/src/DataFrame/Typed/Schema.hs index b7592fb..5e8c4c1 100644 --- a/src/DataFrame/Typed/Schema.hs +++ b/src/DataFrame/Typed/Schema.hs @@ -18,6 +18,7 @@ module DataFrame.Typed.Schema ( Lookup, HasName, RemoveColumn, + Impute, SubsetSchema, ExcludeSchema, RenameInSchema, @@ -60,9 +61,10 @@ import Data.Proxy (Proxy (..)) import qualified Data.Text as T import Data.These (These) import GHC.TypeLits -import Type.Reflection (SomeTypeRep, Typeable, someTypeRep, typeRep) +import Type.Reflection (SomeTypeRep, Typeable, someTypeRep) import DataFrame.Internal.Column (Columnable) +import DataFrame.Internal.Types (If) import DataFrame.Typed.Types (Column) ------------------------------------------------------------------------------- @@ -77,6 +79,14 @@ type family Lookup (name :: Symbol) (cols :: [Type]) :: Type where TypeError ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") +type family Impute (name :: Symbol) (cols :: [Type]) :: [Type] where + Impute name (Column name (Maybe a) ': rest) = Column name a ': rest + Impute name (Column name _ ': rest) = + TypeError + ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' is not of kind Maybe *") + Impute name (col ': rest) = col ': Impute name rest + Impute name '[] = '[] + -- | Add type to the end of a list. 
type family Snoc (xs :: [k]) (x :: k) :: [k] where Snoc '[] x = '[x] @@ -108,11 +118,6 @@ type family ExcludeSchema (names :: [Symbol]) (cols :: [Type]) :: [Type] where (ExcludeSchema names rest) (Column n a ': ExcludeSchema names rest) --- | Type-level if -type family If (b :: Bool) (t :: k) (f :: k) :: k where - If 'True t _ = t - If 'False _ f = f - -- | Type-level elem for Symbols type family IsElem (x :: Symbol) (xs :: [Symbol]) :: Bool where IsElem x '[] = 'False diff --git a/src/DataFrame/Typed/TH.hs b/src/DataFrame/Typed/TH.hs index 47c076c..2048341 100644 --- a/src/DataFrame/Typed/TH.hs +++ b/src/DataFrame/Typed/TH.hs @@ -20,7 +20,6 @@ import qualified Data.Map as M import qualified Data.Text as T import Language.Haskell.TH -import Language.Haskell.TH.Syntax (Lift (..)) import qualified DataFrame.IO.CSV as D import qualified DataFrame.Internal.Column as C From 7a2cc8eb36961b03eddfb66b1235e86bba0249bb Mon Sep 17 00:00:00 2001 From: Michael Chavinda Date: Mon, 2 Mar 2026 12:42:12 -0800 Subject: [PATCH 6/7] chore: Remove t prefix from operations. 
--- src/DataFrame/Typed.hs | 18 ++++---- src/DataFrame/Typed/Aggregate.hs | 2 +- src/DataFrame/Typed/Expr.hs | 73 +++++++++++++------------------- src/DataFrame/Typed/Freeze.hs | 4 -- src/DataFrame/Typed/Schema.hs | 21 +-------- src/DataFrame/Typed/TH.hs | 4 -- 6 files changed, 40 insertions(+), 82 deletions(-) diff --git a/src/DataFrame/Typed.hs b/src/DataFrame/Typed.hs index 72f0cbb..2af5917 100644 --- a/src/DataFrame/Typed.hs +++ b/src/DataFrame/Typed.hs @@ -68,8 +68,8 @@ module DataFrame.Typed ( col, lit, ifThenElse, - tlift, - tlift2, + lift, + lift2, -- * Comparison operators (.==.), @@ -82,15 +82,15 @@ module DataFrame.Typed ( -- * Logical operators (.&&.), (.||.), - tnot, + DataFrame.Typed.Expr.not, -- * Aggregation expression combinators - tsum, - tmean, - tcount, - tminimum, - tmaximum, - tcollect, + DataFrame.Typed.Expr.sum, + mean, + count, + DataFrame.Typed.Expr.minimum, + DataFrame.Typed.Expr.maximum, + collect, -- * Typed sort orders TSortOrder (..), diff --git a/src/DataFrame/Typed/Aggregate.hs b/src/DataFrame/Typed/Aggregate.hs index 15f230d..ac538c1 100644 --- a/src/DataFrame/Typed/Aggregate.hs +++ b/src/DataFrame/Typed/Aggregate.hs @@ -79,7 +79,7 @@ Result schema = grouping key columns ++ aggregated columns (in declaration order @ result = aggregate - (agg \@\"total\" (tsum salary) $ agg \@\"count\" (tcount salary) $ aggNil) + (agg \@\"total\" (sum (col \@\"salary\")) $ agg \@\"count\" (count (col \@\"salary\")) $ aggNil) (groupBy \@'[\"dept\"] employees) -- result :: TDF '[Column \"dept\" Text, Column \"total\" Double, Column \"count\" Int] @ diff --git a/src/DataFrame/Typed/Expr.hs b/src/DataFrame/Typed/Expr.hs index 9fc7ff3..61c1900 100644 --- a/src/DataFrame/Typed/Expr.hs +++ b/src/DataFrame/Typed/Expr.hs @@ -51,8 +51,8 @@ module DataFrame.Typed.Expr ( ifThenElse, -- * Unary / binary lifting - tlift, - tlift2, + lift, + lift2, -- * Comparison operators (.==.), @@ -65,15 +65,15 @@ module DataFrame.Typed.Expr ( -- * Logical operators (.&&.), (.||.), 
- tnot, + DataFrame.Typed.Expr.not, -- * Aggregation combinators - tsum, - tmean, - tcount, - tminimum, - tmaximum, - tcollect, + sum, + mean, + count, + minimum, + maximum, + collect, -- * Named expression helper as, @@ -98,12 +98,10 @@ import DataFrame.Internal.Expression ( UExpr (..), UnaryOp (..), ) +import DataFrame.Internal.Statistics import DataFrame.Typed.Schema (AssertPresent, Lookup) import DataFrame.Typed.Types (TExpr (..), TSortOrder (..)) - -------------------------------------------------------------------------------- --- Column reference — the core type-safe construction point -------------------------------------------------------------------------------- +import Prelude hiding (maximum, minimum, sum) {- | Create a typed column reference. This is the key type-safety entry point. @@ -181,19 +179,15 @@ instance (IsString a, Columnable a) => IsString (TExpr cols a) where ------------------------------------------------------------------------------- -- | Lift a unary function into a typed expression. -tlift :: +lift :: (Columnable a, Columnable b) => (a -> b) -> TExpr cols a -> TExpr cols b -tlift f (TExpr e) = TExpr (Unary (MkUnaryOp f "unaryUdf" Nothing) e) +lift f (TExpr e) = TExpr (Unary (MkUnaryOp f "unaryUdf" Nothing) e) -- | Lift a binary function into typed expressions. -tlift2 :: +lift2 :: (Columnable a, Columnable b, Columnable c) => (a -> b -> c) -> TExpr cols a -> TExpr cols b -> TExpr cols c -tlift2 f (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp f "binaryUdf" Nothing False 0) a b) - -------------------------------------------------------------------------------- --- Comparison operators -------------------------------------------------------------------------------- +lift2 f (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp f "binaryUdf" Nothing False 0) a b) infixl 4 .==., ./=., .<., .<=., .>=., .>. infixr 3 .&&. @@ -229,35 +223,30 @@ infixr 2 .||. (.||.) :: TExpr cols Bool -> TExpr cols Bool -> TExpr cols Bool (.||.) 
(TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (||) "or" (Just "||") True 2) a b) -tnot :: TExpr cols Bool -> TExpr cols Bool -tnot (TExpr e) = TExpr (Unary (MkUnaryOp not "not" (Just "!")) e) +not :: TExpr cols Bool -> TExpr cols Bool +not (TExpr e) = TExpr (Unary (MkUnaryOp Prelude.not "not" (Just "!")) e) ------------------------------------------------------------------------------- -- Aggregation combinators ------------------------------------------------------------------------------- -tsum :: (Columnable a, Num a) => TExpr cols a -> TExpr cols a -tsum (TExpr e) = TExpr (Agg (FoldAgg "sum" Nothing (+)) e) +sum :: (Columnable a, Num a) => TExpr cols a -> TExpr cols a +sum (TExpr e) = TExpr (Agg (FoldAgg "sum" Nothing (+)) e) -tmean :: (Columnable a, Real a, VU.Unbox a) => TExpr cols a -> TExpr cols Double -tmean (TExpr e) = TExpr (Agg (CollectAgg "mean" mean') e) - where - mean' v = - let s = VU.foldl' (\acc x -> acc + realToFrac x) (0 :: Double) v - n = VU.length v - in if n == 0 then 0 else s / fromIntegral n +mean :: (Columnable a, Real a, VU.Unbox a) => TExpr cols a -> TExpr cols Double +mean (TExpr e) = TExpr (Agg (CollectAgg "mean" mean') e) -tcount :: (Columnable a) => TExpr cols a -> TExpr cols Int -tcount (TExpr e) = TExpr (Agg (FoldAgg "count" (Just 0) (\acc _ -> acc + 1)) e) +count :: (Columnable a) => TExpr cols a -> TExpr cols Int +count (TExpr e) = TExpr (Agg (FoldAgg "count" (Just 0) (\acc _ -> acc + 1)) e) -tminimum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -tminimum (TExpr e) = TExpr (Agg (FoldAgg "minimum" Nothing min) e) +minimum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a +minimum (TExpr e) = TExpr (Agg (FoldAgg "minimum" Nothing min) e) -tmaximum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -tmaximum (TExpr e) = TExpr (Agg (FoldAgg "maximum" Nothing max) e) +maximum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a +maximum (TExpr e) = TExpr (Agg (FoldAgg "maximum" Nothing max) e) -tcollect 
:: (Columnable a) => TExpr cols a -> TExpr cols [a] -tcollect (TExpr e) = TExpr (Agg (FoldAgg "collect" (Just []) (flip (:))) e) +collect :: (Columnable a) => TExpr cols a -> TExpr cols [a] +collect (TExpr e) = TExpr (Agg (FoldAgg "collect" (Just []) (flip (:))) e) ------------------------------------------------------------------------------- -- Named expression helper @@ -267,10 +256,6 @@ tcollect (TExpr e) = TExpr (Agg (FoldAgg "collect" (Just []) (flip (:))) e) as :: (Columnable a) => TExpr cols a -> T.Text -> NamedExpr as (TExpr e) name = (name, UExpr e) -------------------------------------------------------------------------------- --- Sort helpers -------------------------------------------------------------------------------- - -- | Create an ascending sort order from a typed expression. asc :: (Columnable a) => TExpr cols a -> TSortOrder cols asc = Asc diff --git a/src/DataFrame/Typed/Freeze.hs b/src/DataFrame/Typed/Freeze.hs index bc2f450..fb00de2 100644 --- a/src/DataFrame/Typed/Freeze.hs +++ b/src/DataFrame/Typed/Freeze.hs @@ -54,10 +54,6 @@ Used internally after delegation where the library guarantees schema correctness unsafeFreeze :: D.DataFrame -> TypedDataFrame cols unsafeFreeze = TDF -------------------------------------------------------------------------------- --- Internal validation -------------------------------------------------------------------------------- - validateSchema :: forall cols. 
(KnownSchema cols) => diff --git a/src/DataFrame/Typed/Schema.hs b/src/DataFrame/Typed/Schema.hs index 5e8c4c1..7510b7a 100644 --- a/src/DataFrame/Typed/Schema.hs +++ b/src/DataFrame/Typed/Schema.hs @@ -67,10 +67,6 @@ import DataFrame.Internal.Column (Columnable) import DataFrame.Internal.Types (If) import DataFrame.Typed.Types (Column) -------------------------------------------------------------------------------- --- Core type families -------------------------------------------------------------------------------- - -- | Look up the element type of a column by name. type family Lookup (name :: Symbol) (cols :: [Type]) :: Type where Lookup name (Column name a ': _) = a @@ -79,6 +75,7 @@ type family Lookup (name :: Symbol) (cols :: [Type]) :: Type where TypeError ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") +-- | Unwrap a Maybe from a type after we impute values. type family Impute (name :: Symbol) (cols :: [Type]) :: [Type] where Impute name (Column name (Maybe a) ': rest) = Column name a ': rest Impute name (Column name _ ': rest) = @@ -186,10 +183,6 @@ type family TypeError ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") -------------------------------------------------------------------------------- --- Maybe-stripping families -------------------------------------------------------------------------------- - {- | Strip 'Maybe' from all columns. Used by 'filterAllJust'. @Column "x" (Maybe Double)@ becomes @Column "x" Double@. @@ -213,10 +206,6 @@ type family StripMaybeAt (name :: Symbol) (cols :: [Type]) :: [Type] where TypeError ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema") -------------------------------------------------------------------------------- --- Join schema families -------------------------------------------------------------------------------- - -- | Extract column names that appear in both schemas. 
type family SharedNames (left :: [Type]) (right :: [Type]) :: [Symbol] where SharedNames '[] right = '[] @@ -324,10 +313,6 @@ type family GroupKeyColumns (keys :: [Symbol]) (cols :: [Type]) :: [Type] where (Column n a ': GroupKeyColumns keys rest) (GroupKeyColumns keys rest) -------------------------------------------------------------------------------- --- KnownSchema class -------------------------------------------------------------------------------- - -- | Provides runtime evidence of a schema: a list of (name, TypeRep) pairs. class KnownSchema (cols :: [Type]) where schemaEvidence :: [(T.Text, SomeTypeRep)] @@ -343,10 +328,6 @@ instance (T.pack (symbolVal (Proxy @name)), someTypeRep (Proxy @a)) : schemaEvidence @rest -------------------------------------------------------------------------------- --- AllKnownSymbol helper -------------------------------------------------------------------------------- - -- | A class that provides a list of 'Text' values for a type-level list of Symbols. class AllKnownSymbol (names :: [Symbol]) where symbolVals :: [T.Text] diff --git a/src/DataFrame/Typed/TH.hs b/src/DataFrame/Typed/TH.hs index 2048341..c1d460c 100644 --- a/src/DataFrame/Typed/TH.hs +++ b/src/DataFrame/Typed/TH.hs @@ -53,10 +53,6 @@ deriveSchemaFromCsvFile typeName path = do df <- liftIO (D.readCsv path) deriveSchema typeName df -------------------------------------------------------------------------------- --- Internal helpers -------------------------------------------------------------------------------- - getSchemaInfo :: D.DataFrame -> [(T.Text, String)] getSchemaInfo df = let orderedNames = From b12e9c338efcf75ce10cb35c7a3ef3d86af53199 Mon Sep 17 00:00:00 2001 From: Michael Chavinda Date: Mon, 2 Mar 2026 13:40:26 -0800 Subject: [PATCH 7/7] chore: Add tests for typed API. 
--- tests/Main.hs | 23 ++++++------ tests/Operations/Aggregations.hs | 60 +++++++++++++++++++++++++++++++- tests/Operations/Derive.hs | 28 +++++++++++++++ tests/Operations/Join.hs | 53 ++++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 12 deletions(-) diff --git a/tests/Main.hs b/tests/Main.hs index 37820ed..6569e1b 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -5134,15 +5134,16 @@ isSuccessful _ = False main :: IO () main = do result <- runTestTT tests - -- Property tests - propRes <- - mapM - (quickCheckWithResult stdArgs) - Operations.Subset.tests - monadRes <- mapM (quickCheckWithResult stdArgs) Monad.tests - if failures result > 0 - || errors result > 0 - || not (all isSuccessful propRes) - || not (all isSuccessful monadRes) + if failures result > 0 || errors result > 0 then Exit.exitFailure - else Exit.exitSuccess + else do + -- Property tests + propRes <- + mapM + (quickCheckWithResult stdArgs) + Operations.Subset.tests + monadRes <- mapM (quickCheckWithResult stdArgs) Monad.tests + if not (all isSuccessful propRes) + || not (all isSuccessful monadRes) + then Exit.exitFailure + else Exit.exitSuccess diff --git a/tests/Operations/Aggregations.hs b/tests/Operations/Aggregations.hs index d43b0ae..7bb2307 100644 --- a/tests/Operations/Aggregations.hs +++ b/tests/Operations/Aggregations.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -7,6 +8,7 @@ import qualified Data.Text as T import qualified DataFrame as D import qualified DataFrame.Functions as F import qualified DataFrame.Internal.Column as DI +import qualified DataFrame.Typed as DT import Data.Function import DataFrame.Operators @@ -18,7 +20,7 @@ values = , ("test2", DI.fromList ([12, 11 .. 1] :: [Int])) , ("test3", DI.fromList ([1 .. 12] :: [Int])) , ("test4", DI.fromList ['a' .. 'l']) - , ("test4", DI.fromList (map show ['a' .. 'l'])) + , ("test5", DI.fromList (map show ['a' .. 'l'])) , ("test6", DI.fromList ([1 .. 
12] :: [Integer])) ] @@ -42,6 +44,33 @@ foldAggregation = ) ) +foldAggregationTyped :: Test +foldAggregationTyped = + TestCase + ( assertEqual + "Typed counting elements after grouping gives correct numbers" + ( D.fromNamedColumns + [ ("test1", DI.fromList [1 :: Int, 2, 3]) + , ("test2_count", DI.fromList [6 :: Int, 3, 3]) + ] + ) + ( testData + & either (error . show) id + . DT.freezeWithError + @[ DT.Column "test1" Int + , DT.Column "test2" Int + , DT.Column "test3" Int + , DT.Column "test4" Char + , DT.Column "test5" String + , DT.Column "test6" Integer + ] + & DT.groupBy @'["test1"] + & DT.aggregate (DT.agg @"test2_count" (DT.count (DT.col @"test2")) DT.aggNil) + & DT.sortBy [DT.asc (DT.col @"test1")] + & DT.thaw + ) + ) + numericAggregation :: Test numericAggregation = TestCase @@ -59,6 +88,33 @@ numericAggregation = ) ) +numericAggregationTyped :: Test +numericAggregationTyped = + TestCase + ( assertEqual + "Typed mean works for ints" + ( D.fromNamedColumns + [ ("test1", DI.fromList [1 :: Int, 2, 3]) + , ("test2_mean", DI.fromList [6.5 :: Double, 8.0, 5.0]) + ] + ) + ( testData + & either (error . show) id + . 
DT.freezeWithError + @[ DT.Column "test1" Int + , DT.Column "test2" Int + , DT.Column "test3" Int + , DT.Column "test4" Char + , DT.Column "test5" String + , DT.Column "test6" Integer + ] + & DT.groupBy @'["test1"] + & DT.aggregate (DT.agg @"test2_mean" (DT.mean (DT.col @"test2")) DT.aggNil) + & DT.sortBy [DT.asc (DT.col @"test1")] + & DT.thaw + ) + ) + numericAggregationOfUnaggregatedUnaryOp :: Test numericAggregationOfUnaggregatedUnaryOp = TestCase @@ -154,7 +210,9 @@ aggregationOnNoRows = tests :: [Test] tests = [ TestLabel "foldAggregation" foldAggregation + , TestLabel "foldAggregationTyped" foldAggregationTyped , TestLabel "numericAggregation" numericAggregation + , TestLabel "numericAggregationTyped" numericAggregationTyped , TestLabel "numericAggregationOfUnaggregatedUnaryOp" numericAggregationOfUnaggregatedUnaryOp diff --git a/tests/Operations/Derive.hs b/tests/Operations/Derive.hs index 88692cf..0675ea3 100644 --- a/tests/Operations/Derive.hs +++ b/tests/Operations/Derive.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -10,6 +11,7 @@ import qualified DataFrame as D import qualified DataFrame.Functions as F import qualified DataFrame.Internal.Column as DI import qualified DataFrame.Internal.DataFrame as DI +import qualified DataFrame.Typed as DT import Test.HUnit @@ -44,7 +46,33 @@ deriveWAI = ) ) +deriveWAITyped :: Test +deriveWAITyped = + TestCase + ( assertEqual + "typed derive works with column expression" + (zipWith (\n c -> show n ++ [c]) [1 .. 26] ['a' .. 'z']) + ( DT.columnAsList @"test4" $ + DT.derive + @"test4" + ( DT.lift2 + (++) + (DT.lift show (DT.col @"test1")) + (DT.lift (: ([] :: [Char])) (DT.col @"test3")) + ) + ( either + (error . 
show) + id + ( DT.freezeWithError + @[DT.Column "test1" Int, DT.Column "test2" String, DT.Column "test3" Char] + testData + ) + ) + ) + ) + tests :: [Test] tests = [ TestLabel "deriveWAI" deriveWAI + , TestLabel "deriveWAITyped" deriveWAITyped ] diff --git a/tests/Operations/Join.hs b/tests/Operations/Join.hs index e34757d..24cde38 100644 --- a/tests/Operations/Join.hs +++ b/tests/Operations/Join.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -8,6 +9,7 @@ import Data.These import qualified DataFrame as D import qualified DataFrame.Functions as F import DataFrame.Operations.Join +import qualified DataFrame.Typed as DT import Test.HUnit df1 :: D.DataFrame @@ -66,6 +68,54 @@ testRightJoin = (D.sortBy [D.Asc (F.col @Text "key")] (rightJoin ["key"] df2 df1)) ) +tdf1 :: DT.TypedDataFrame [DT.Column "key" Text, DT.Column "A" Text] +tdf1 = either (error . show) id (DT.freezeWithError df1) + +tdf2 :: DT.TypedDataFrame [DT.Column "key" Text, DT.Column "B" Text] +tdf2 = either (error . 
show) id (DT.freezeWithError df2) + +testInnerJoinTyped :: Test +testInnerJoinTyped = + TestCase + ( assertEqual + "Test typed inner join with single key" + ( D.fromNamedColumns + [ ("key", D.fromList ["K0" :: Text, "K1", "K2"]) + , ("A", D.fromList ["A0" :: Text, "A1", "A2"]) + , ("B", D.fromList ["B0" :: Text, "B1", "B2"]) + ] + ) + (DT.thaw $ DT.sortBy [DT.asc (DT.col @"key")] (DT.innerJoin @'["key"] tdf1 tdf2)) + ) + +testLeftJoinTyped :: Test +testLeftJoinTyped = + TestCase + ( assertEqual + "Test typed left join with single key" + ( D.fromNamedColumns + [ ("key", D.fromList ["K0" :: Text, "K1", "K2", "K3", "K4", "K5"]) + , ("A", D.fromList ["A0" :: Text, "A1", "A2", "A3", "A4", "A5"]) + , ("B", D.fromList [Just "B0", Just "B1" :: Maybe Text, Just "B2"]) + ] + ) + (DT.thaw $ DT.sortBy [DT.asc (DT.col @"key")] (DT.leftJoin @'["key"] tdf1 tdf2)) + ) + +testRightJoinTyped :: Test +testRightJoinTyped = + TestCase + ( assertEqual + "Test typed right join with single key" + ( D.fromNamedColumns + [ ("key", D.fromList ["K0" :: Text, "K1", "K2"]) + , ("A", D.fromList [Just "A0" :: Maybe Text, Just "A1", Just "A2"]) + , ("B", D.fromList ["B0" :: Text, "B1", "B2"]) + ] + ) + (DT.thaw $ DT.sortBy [DT.asc (DT.col @"key")] (DT.rightJoin @'["key"] tdf1 tdf2)) + ) + staffDf :: D.DataFrame staffDf = D.fromRows @@ -206,8 +256,11 @@ testOuterJoinWithCollisions = tests :: [Test] tests = [ TestLabel "innerJoin" testInnerJoin + , TestLabel "testInnerJoinTyped" testInnerJoinTyped , TestLabel "leftJoin" testLeftJoin + , TestLabel "testLeftJoinTyped" testLeftJoinTyped , TestLabel "rightJoin" testRightJoin + , TestLabel "testRightJoinTyped" testRightJoinTyped , TestLabel "fullOuterJoin" testFullOuterJoin , TestLabel "innerJoinWithCollisions" testInnerJoinWithCollisions , TestLabel "leftJoinWithCollisions" testLeftJoinWithCollisions