From bf2348fea98219f135a30a1311d6d7d21eb3f1dc Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Fri, 6 Mar 2026 12:11:18 +0100 Subject: [PATCH] Add position to syntax --- Cargo.toml | 1 + doc/_quarto.yml | 7 +- doc/styles.scss | 5 + doc/syntax/index.qmd | 22 +- doc/syntax/layer/position/dodge.qmd | 46 + doc/syntax/layer/position/identity.qmd | 7 + doc/syntax/layer/position/jitter.qmd | 65 + doc/syntax/layer/position/stack.qmd | 61 + doc/syntax/layer/{ => type}/area.qmd | 25 +- doc/syntax/layer/{ => type}/bar.qmd | 17 +- doc/syntax/layer/{ => type}/boxplot.qmd | 3 +- doc/syntax/layer/{ => type}/density.qmd | 9 +- doc/syntax/layer/{ => type}/histogram.qmd | 13 +- doc/syntax/layer/{ => type}/line.qmd | 4 +- doc/syntax/layer/{ => type}/path.qmd | 8 +- doc/syntax/layer/{ => type}/point.qmd | 17 +- doc/syntax/layer/{ => type}/polygon.qmd | 6 +- doc/syntax/layer/{ => type}/ribbon.qmd | 4 +- doc/syntax/layer/{ => type}/violin.qmd | 9 +- src/Cargo.toml | 1 + src/execute/mod.rs | 53 +- src/execute/position.rs | 325 ++++ src/execute/scale.rs | 46 +- src/parser/builder.rs | 102 ++ src/plot/layer/geom/abline.rs | 9 +- src/plot/layer/geom/area.rs | 8 +- src/plot/layer/geom/arrow.rs | 9 +- src/plot/layer/geom/bar.rs | 14 +- src/plot/layer/geom/boxplot.rs | 13 +- src/plot/layer/geom/density.rs | 5 +- src/plot/layer/geom/errorbar.rs | 9 +- src/plot/layer/geom/histogram.rs | 4 + src/plot/layer/geom/hline.rs | 9 +- src/plot/layer/geom/label.rs | 9 +- src/plot/layer/geom/line.rs | 9 +- src/plot/layer/geom/mod.rs | 25 + src/plot/layer/geom/path.rs | 9 +- src/plot/layer/geom/point.rs | 9 +- src/plot/layer/geom/polygon.rs | 9 +- src/plot/layer/geom/ribbon.rs | 9 +- src/plot/layer/geom/segment.rs | 9 +- src/plot/layer/geom/smooth.rs | 9 +- src/plot/layer/geom/text.rs | 9 +- src/plot/layer/geom/tile.rs | 9 +- src/plot/layer/geom/violin.rs | 176 ++- src/plot/layer/geom/vline.rs | 9 +- src/plot/layer/mod.rs | 52 +- src/plot/layer/position/dodge.rs | 699 +++++++++ 
src/plot/layer/position/identity.rs | 58 + src/plot/layer/position/jitter.rs | 1634 +++++++++++++++++++++ src/plot/layer/position/mod.rs | 352 +++++ src/plot/layer/position/stack.rs | 765 ++++++++++ src/plot/scale/scale_type/binned.rs | 127 +- src/reader/mod.rs | 115 ++ src/writer/vegalite/layer.rs | 223 +-- src/writer/vegalite/mod.rs | 66 + 56 files changed, 5132 insertions(+), 195 deletions(-) create mode 100644 doc/syntax/layer/position/dodge.qmd create mode 100644 doc/syntax/layer/position/identity.qmd create mode 100644 doc/syntax/layer/position/jitter.qmd create mode 100644 doc/syntax/layer/position/stack.qmd rename doc/syntax/layer/{ => type}/area.qmd (60%) rename doc/syntax/layer/{ => type}/bar.qmd (82%) rename doc/syntax/layer/{ => type}/boxplot.qmd (91%) rename doc/syntax/layer/{ => type}/density.qmd (88%) rename doc/syntax/layer/{ => type}/histogram.qmd (81%) rename doc/syntax/layer/{ => type}/line.qmd (83%) rename doc/syntax/layer/{ => type}/path.qmd (84%) rename doc/syntax/layer/{ => type}/point.qmd (67%) rename doc/syntax/layer/{ => type}/polygon.qmd (84%) rename doc/syntax/layer/{ => type}/ribbon.qmd (84%) rename doc/syntax/layer/{ => type}/violin.qmd (88%) create mode 100644 src/execute/position.rs create mode 100644 src/plot/layer/position/dodge.rs create mode 100644 src/plot/layer/position/identity.rs create mode 100644 src/plot/layer/position/jitter.rs create mode 100644 src/plot/layer/position/mod.rs create mode 100644 src/plot/layer/position/stack.rs diff --git a/Cargo.toml b/Cargo.toml index ab823748..3d6a410f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ palette = "0.7" # Utilities regex = "1.10" chrono = "0.4" +rand = "0.8" const_format = "0.2" uuid = { version = "1.0", features = ["v4"] } diff --git a/doc/_quarto.yml b/doc/_quarto.yml index a5801133..27ec54dc 100644 --- a/doc/_quarto.yml +++ b/doc/_quarto.yml @@ -79,7 +79,12 @@ website: href: syntax/clause/label.qmd - section: Layers contents: - - auto: syntax/layer/* + - 
section: Types + contents: + - auto: syntax/layer/type/* + - section: Position adjustment + contents: + - auto: syntax/layer/position/* - section: Scales contents: - section: Types diff --git a/doc/styles.scss b/doc/styles.scss index 1fafc53d..8848bfbb 100644 --- a/doc/styles.scss +++ b/doc/styles.scss @@ -19,6 +19,11 @@ code { font-variant-ligatures: none } +// Add spacing below rendered plots so text doesn't crowd them +.cell-output-display { + margin-bottom: 1.5rem; +} + .hero-banner { padding: 0; margin: 0; diff --git a/doc/syntax/index.qmd b/doc/syntax/index.qmd index 1400280f..45764414 100644 --- a/doc/syntax/index.qmd +++ b/doc/syntax/index.qmd @@ -15,17 +15,17 @@ ggsql augments the standard SQL syntax with a number of new clauses to describe ## Layers There are many different layers to choose from when visualising your data. Some are straightforward translations of your data into visual marks such as a point layer, while others perform more or less complicated calculations like e.g. the histogram layer. A layer is selected by providing the layer name after the `DRAW` clause -- [`point`](layer/point.qmd) is used to create a scatterplot layer -- [`line`](layer/line.qmd) is used to produce lineplots with the data sorted along the x axis -- [`path`](layer/path.qmd) is like `line` above but does not sort the data but plot it according to its own order -- [`area`](layer/area.qmd) is used to display series as an area chart. -- [`ribbon`](layer/ribbon.qmd) is used to display series extrema. -- [`polygon`](layer/polygon.qmd) is used to display arbitrary shapes as polygons. 
-- [`bar`](layer/bar.qmd) creates a bar chart, optionally calculating y from the number of records in each bar -- [`density`](layer/density.qmd) creates univariate kernel density estimates, showing the distribution of a variable -- [`violin`](layer/violin.qmd) displays a rotated kernel density estimate -- [`histogram`](layer/histogram.qmd) bins the data along the x axis and produces a bar for each bin showing the number of records in it -- [`boxplot`](layer/boxplot.qmd) displays continuous variables as 5-number summaries +- [`point`](layer/type/point.qmd) is used to create a scatterplot layer +- [`line`](layer/type/line.qmd) is used to produce lineplots with the data sorted along the x axis +- [`path`](layer/type/path.qmd) is like `line` above but does not sort the data but plot it according to its own order +- [`area`](layer/type/area.qmd) is used to display series as an area chart. +- [`ribbon`](layer/type/ribbon.qmd) is used to display series extrema. +- [`polygon`](layer/type/polygon.qmd) is used to display arbitrary shapes as polygons. +- [`bar`](layer/type/bar.qmd) creates a bar chart, optionally calculating y from the number of records in each bar +- [`density`](layer/type/density.qmd) creates univariate kernel density estimates, showing the distribution of a variable +- [`violin`](layer/type/violin.qmd) displays a rotated kernel density estimate +- [`histogram`](layer/type/histogram.qmd) bins the data along the x axis and produces a bar for each bin showing the number of records in it +- [`boxplot`](layer/type/boxplot.qmd) displays continuous variables as 5-number summaries ## Scales A scale is responsible for translating a data value to an aesthetic literal, e.g. a specific color for the fill aesthetic, or a radius in points for the size aesthetic. 
using the `SETTING` subclause. Read
title: Identity +--- + +> Positions are set within the [`DRAW` clause](../../clause/draw.qmd), using the `SETTING` subclause. Read the documentation for this clause for a thorough description of how to use it. + +The identity position is a position adjustment that does nothing, i.e. it leaves the data where it is. It is used to turn off any position adjustments for layers that default to something else. It takes no arguments and has no requirements. diff --git a/doc/syntax/layer/position/jitter.qmd b/doc/syntax/layer/position/jitter.qmd new file mode 100644 index 00000000..b446f5d0 --- /dev/null +++ b/doc/syntax/layer/position/jitter.qmd @@ -0,0 +1,65 @@ +--- +title: Jitter +--- + +> Positions are set within the [`DRAW` clause](../../clause/draw.qmd), using the `SETTING` subclause. Read the documentation for this clause for a thorough description of how to use it.
* `dodge`: Should dodging be applied before jittering? The dodging behavior follows the [dodge position](dodge.qmd) behavior. Defaults to `true`
using the `SETTING` subclause. Read
either `y` + `yend` or `ymin` + `ymax`) and that all ranges be positive with a baseline of zero. The axis that satisfies this will be used as the stacking direction + +## Settings +Apart from the settings of the layer type, setting `position => 'stack'` will allow these additional settings: + +* `center`: Should the full stack be centered around 0. Can be used in conjunction with area layers to create streamgraphs. Defaults to `false` +* `total`: Sets a value each stack height should be normalised to. Defaults to `null` (no normalisation)
When `position => 'fill'` we're plotting stacked proportions
An alternative is to center the stacks to create a streamgraph
get a dodged bar chart + +```{ggsql} +VISUALISE FROM ggsql:penguins +DRAW bar + MAPPING species AS x, sex AS fill + SETTING position => 'dodge' +``` + Map to y if the dataset already contains the value you want to show ```{ggsql} @@ -87,3 +96,7 @@ DRAW bar SCALE BINNED x SETTING breaks => 10 ``` + +And use with a polar coordinate system to create a pie chart + +**TBD** diff --git a/doc/syntax/layer/boxplot.qmd b/doc/syntax/layer/type/boxplot.qmd similarity index 91% rename from doc/syntax/layer/boxplot.qmd rename to doc/syntax/layer/type/boxplot.qmd index 10e8d2c4..60ad010f 100644 --- a/doc/syntax/layer/boxplot.qmd +++ b/doc/syntax/layer/type/boxplot.qmd @@ -1,7 +1,7 @@ --- title: "Boxplot" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. Boxplots display a summary of a continuous distribution. In the style of Tukey, it displays the median, two hinges and two whiskers as well as outlying points. @@ -23,6 +23,7 @@ The following aesthetics are recognised by the boxplot layer. * `shape` The shape of outlier points. ## Settings +* `position`: Determines the position adjustment to use for the layer (default is `'dodge'`) * `outliers`: Whether to display outliers as points. Defaults to `true`. * `coef`: A number indicating the length of the whiskers as a multiple of the interquartile range (IQR). Defaults to `1.5`. * `width`: Relative width of the boxes. Defaults to `0.9`. 
diff --git a/doc/syntax/layer/density.qmd b/doc/syntax/layer/type/density.qmd similarity index 88% rename from doc/syntax/layer/density.qmd rename to doc/syntax/layer/type/density.qmd index 97319e31..b73bd3a2 100644 --- a/doc/syntax/layer/density.qmd +++ b/doc/syntax/layer/type/density.qmd @@ -2,7 +2,7 @@ title: "Density" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. Visualise the distribution of a single continuous variable by computing a kernel density estimate. It has a similar interpretation as a histogram but smoothing out observations rather than binning them. @@ -21,10 +21,7 @@ The following aesthetics are recognised by the density layer. * `linetype` The dash pattern of the contour line. ## Settings -* `stacking`: Determines how multiple groups are displayed. One of the following: - * `'off'`: The groups `y`-values are displayed as-is (default). - * `'on'`: The `y`-values are stacked per `x` position, accumulating over groups. - * `'fill'`: Like `'on'` but displayed as a fraction of the total per `x` position. +* `position`: Determines the position adjustment to use for the layer (default is `'identity'`) * `bandwidth`: A numerical value setting the smoothing bandwidth to use. If absent (default), the bandwidth will be computed using Silverman's rule of thumb. * `adjust`: A numerical value as multiplier for the `bandwidth` setting, with 1 as default. * `kernel`: Determines the smoothing kernel shape. Can be one of the following: @@ -87,7 +84,7 @@ Stacking the different groups instead of overlaying them. 
```{ggsql} VISUALISE bill_dep AS x, species AS colour FROM ggsql:penguins - DRAW density SETTING stacking => 'on' + DRAW density SETTING position => 'stack' ``` Using weighted estimates by mapping a column to the optional weight aesthetic. Note that the difference in output is subtle. diff --git a/doc/syntax/layer/histogram.qmd b/doc/syntax/layer/type/histogram.qmd similarity index 81% rename from doc/syntax/layer/histogram.qmd rename to doc/syntax/layer/type/histogram.qmd index b0f57b79..08ba69f8 100644 --- a/doc/syntax/layer/histogram.qmd +++ b/doc/syntax/layer/type/histogram.qmd @@ -2,7 +2,7 @@ title: "Histogram" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. Visualise the distribution of a single continuous variable by dividing the x axis into bins and counting the number of observations in each bin. If providing a weight then a weighted histogram is calculated instead. @@ -21,7 +21,7 @@ The following aesthetics are recognised by the bar layer. * `linetype`: The type of stroke, i.e. the dashing pattern ## Settings - +* `position`: Determines the position adjustment to use for the layer (default is `'stack'`) * `bins`: The number of bins to calculate. Defaults to `30` * `binwidth`: The width of each bin. If provided it will override the binwidth calculated from `bins` * `closed`: Either `'left'` or `'right'` (default). Determines whether the bin intervals are closed to the left or right side @@ -60,6 +60,15 @@ DRAW histogram MAPPING body_mass AS x, sex AS fill ``` +The default is to stack multiple histograms. 
To compare them from a baseline of 0 set position to identity + +```{ggsql} +VISUALISE FROM ggsql:penguins +DRAW histogram + MAPPING body_mass AS x, sex AS fill + SETTING position => 'identity' +``` + Make the two histograms the same scale by remapping to density ```{ggsql} diff --git a/doc/syntax/layer/line.qmd b/doc/syntax/layer/type/line.qmd similarity index 83% rename from doc/syntax/layer/line.qmd rename to doc/syntax/layer/type/line.qmd index 6cbb6902..3dc45367 100644 --- a/doc/syntax/layer/line.qmd +++ b/doc/syntax/layer/type/line.qmd @@ -2,7 +2,7 @@ title: "Line" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. The line layer is used to create lineplots. Lineplots always connects records along the x-axis, in contrast to [path layers](path.qmd) which use the order of data to connect records. Lines are divided due to their grouping, which is the combination of the discrete mapped aesthetics and the columns specified in the layers [`PARTITION BY`](../clause/draw.qmd#partition-by). @@ -20,7 +20,7 @@ The following aesthetics are recognised by the line layer. * `linetype`: The type of line, i.e. 
the dashing pattern ## Settings -The line layer has no additional settings +* `position`: Determines the position adjustment to use for the layer (default is `'identity'`) ## Data transformation The line layer does not transform its data but passes it through unchanged diff --git a/doc/syntax/layer/path.qmd b/doc/syntax/layer/type/path.qmd similarity index 84% rename from doc/syntax/layer/path.qmd rename to doc/syntax/layer/type/path.qmd index 439775ab..7b755abf 100644 --- a/doc/syntax/layer/path.qmd +++ b/doc/syntax/layer/type/path.qmd @@ -2,9 +2,9 @@ title: "Path" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. -The path layer is used to create lineplots, but contrary to the [line layer](line.qmd) the data will not be connected along the x-axis. Instead records are connected in the order they appear in the data. Lines are divided due to their grouping, which is the combination of the discrete mapped aesthetics and the columns specified in the layers [`PARTITION BY`](../clause/draw.qmd#partition-by). +The path layer is used to create lineplots, but contrary to the [line layer](line.qmd) the data will not be connected along the x-axis. Instead records are connected in the order they appear in the data. Lines are divided due to their grouping, which is the combination of the discrete mapped aesthetics and the columns specified in the layers [`PARTITION BY`](../../clause/draw.qmd#partition-by). ## Aesthetics The following aesthetics are recognised by the path layer. @@ -20,7 +20,7 @@ The following aesthetics are recognised by the path layer. * `linetype`: The type of path, i.e. 
the dashing pattern ## Settings -The line layer has no additional settings +* `position`: Determines the position adjustment to use for the layer (default is `'identity'`) ## Data transformation The line layer does not transform its data but passes it through unchanged @@ -76,4 +76,4 @@ Compared to polygons, paths don't close their shapes and fill their interiors. VISUALISE x, y FROM df DRAW polygon MAPPING 'Polygon' AS stroke DRAW path MAPPING 'Path' AS stroke -``` \ No newline at end of file +``` diff --git a/doc/syntax/layer/point.qmd b/doc/syntax/layer/type/point.qmd similarity index 67% rename from doc/syntax/layer/point.qmd rename to doc/syntax/layer/type/point.qmd index a1f181cb..cf8875a5 100644 --- a/doc/syntax/layer/point.qmd +++ b/doc/syntax/layer/type/point.qmd @@ -2,7 +2,7 @@ title: "Point" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. The point layer is used to create scatterplots. The scatterplot is most useful for displaying the relationship between two continuous variables. A bubblechart is a scatterplot with a third variable mapped to the size of points. @@ -22,7 +22,7 @@ The following aesthetics are recognised by the point layer. 
@@ -53,3 +53,18 @@ DRAW point MAPPING bill_len AS x, bill_dep AS y, species AS fill FILTER sex = 'female' ``` + +Jitter on two discrete axes at once, here using a normal jitter distribution + +```{ggsql} +VISUALISE species AS x, sex AS y, island AS fill FROM ggsql:penguins
## Settings -The polygon layer has no additional settings +* `position`: Determines the position adjustment to use for the layer (default is `'identity'`) ## Data transformation The polygon layer does not transform its data but passes it through unchanged @@ -62,4 +62,4 @@ Invoking a group through discrete aesthetics works as well. ```{ggsql} VISUALISE x, y FROM df DRAW polygon MAPPING id AS colour -``` \ No newline at end of file +``` diff --git a/doc/syntax/layer/ribbon.qmd b/doc/syntax/layer/type/ribbon.qmd similarity index 84% rename from doc/syntax/layer/ribbon.qmd rename to doc/syntax/layer/type/ribbon.qmd index 2f87f88c..58c36b1f 100644 --- a/doc/syntax/layer/ribbon.qmd +++ b/doc/syntax/layer/type/ribbon.qmd @@ -2,7 +2,7 @@ title: "Ribbon" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. The ribbon layer is used to display extrema over a sorted x-axis. It can be seen as an [area chart](area.qmd) that is unanchored from zero. @@ -22,7 +22,7 @@ The following aesthetics are recognised by the ribbon layer. * `linewidth`: The width of the contour lines. ## Settings -The ribbon layer has no additional settings. +* `position`: Determines the position adjustment to use for the layer (default is `'identity'`) ## Data transformation The ribbon layer does not transform its data but passes it through unchanged. diff --git a/doc/syntax/layer/violin.qmd b/doc/syntax/layer/type/violin.qmd similarity index 88% rename from doc/syntax/layer/violin.qmd rename to doc/syntax/layer/type/violin.qmd index 3283d9d2..88a91d1f 100644 --- a/doc/syntax/layer/violin.qmd +++ b/doc/syntax/layer/type/violin.qmd @@ -2,7 +2,7 @@ title: "Violin" --- -> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). 
Read the documentation for this clause for a thorough description of how to use it. +> Layers are declared with the [`DRAW` clause](../../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. Violin plots display the distribution of a single continuous variable for multiple groups. The violins are mirrored kernel density estimates, similar to the [density](density.qmd) layer, but organised as distinct groups. @@ -23,6 +23,7 @@ The following aesthetics are recognised by the violin layer. * `linetype` The dash pattern of the contour line. ## Settings +* `position`: Determines the position adjustment to use for the layer (default is `'dodge'`) * `bandwidth`: A numerical value setting the smoothing bandwidth to use. If absent (default), the bandwidth will be computed using Silverman's rule of thumb. * `adjust`: A numerical value as multiplier for the `bandwidth` setting, with 1 as default. * `kernel`: Determines the smoothing kernel shape. Can be one of the following: @@ -32,6 +33,7 @@ The following aesthetics are recognised by the violin layer. * `'rectangular'` or `'uniform'` * `'biweight'` or `'quartic'` * `'cosine'` +* `width`: Relative width of the violins. Defaults to `0.9`. ## Data transformation A violin layer uses the same computation as a density layer. See the [density data transformation](density.qmd#data-transformation) section for details. @@ -76,11 +78,8 @@ VISUALISE species AS x, bill_dep AS y FROM ggsql:penguins You can combine groups to expand the categories. 
- - ```{ggsql} -SELECT *, species || ' ' || island AS groups FROM ggsql:penguins -VISUALISE groups AS x, bill_dep AS y, island AS fill +VISUALISE species AS x, bill_dep AS y, island AS fill FROM ggsql:penguins DRAW violin ``` diff --git a/src/Cargo.toml b/src/Cargo.toml index 9757fff3..e86dd347 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -56,6 +56,7 @@ thiserror.workspace = true # Utilities regex.workspace = true chrono.workspace = true +rand.workspace = true sprintf = "0.4" const_format.workspace = true uuid.workspace = true diff --git a/src/execute/mod.rs b/src/execute/mod.rs index a8f17e07..24ea0f30 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -13,6 +13,7 @@ mod casting; mod cte; mod layer; +mod position; mod scale; mod schema; @@ -792,6 +793,18 @@ fn collect_layer_required_columns(layer: &Layer, spec: &Plot) -> HashSet required.insert(naming::ORDER_COLUMN.to_string()); } + // Position offset column for position adjustments that create pos1offset + // This column is created by dodge/jitter positions and is not in layer.mappings + if layer.position.creates_pos1offset() { + required.insert(naming::aesthetic_column("pos1offset")); + } + + // Position offset column for position adjustments that create pos2offset + // This column is created by jitter position for vertical jittering + if layer.position.creates_pos2offset() { + required.insert(naming::aesthetic_column("pos2offset")); + } + required } @@ -1062,7 +1075,9 @@ pub fn prepare_data_with_reader(query: &str, reader: &R) -> Result(query: &str, reader: &R) -> Result = HashSet::new(); @@ -1125,7 +1141,10 @@ pub fn prepare_data_with_reader(query: &str, reader: &R) -> Result(query: &str, reader: &R) -> Result(query: &str, reader: &R) -> Result, +) -> Result<()> { + for idx in 0..spec.layers.len() { + // Skip identity position (no adjustment needed) + if spec.layers[idx].position.position_type() == PositionType::Identity { + continue; + } + + let Some(key) = spec.layers[idx].data_key.clone() 
else { + continue; + }; + + let Some(df) = data_map.get(&key) else { + continue; + }; + + // Delegate to the position's apply_adjustment implementation + // Each position validates its own requirements internally + let (adjusted_df, adjusted_width) = + spec.layers[idx] + .position + .apply_adjustment(df, &spec.layers[idx], spec)?; + + data_map.insert(key.clone(), adjusted_df); + + // Store adjusted width on layer (for writers that need it) + // This does NOT override the user's width parameter + if let Some(width) = adjusted_width { + spec.layers[idx].adjusted_width = Some(width); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::plot::layer::{Geom, Position}; + use crate::plot::{AestheticValue, Mappings, ParameterValue, Scale, ScaleType}; + use polars::prelude::*; + + fn make_continuous_scale(aesthetic: &str) -> Scale { + let mut scale = Scale::new(aesthetic); + scale.scale_type = Some(ScaleType::continuous()); + scale + } + + fn make_discrete_scale(aesthetic: &str) -> Scale { + let mut scale = Scale::new(aesthetic); + scale.scale_type = Some(ScaleType::discrete()); + scale + } + + fn make_test_df() -> DataFrame { + df! 
{ + "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + } + .unwrap() + } + + fn make_test_layer() -> crate::plot::Layer { + let mut layer = crate::plot::Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + // Add fill to partition_by (simulates what add_discrete_columns_to_partition_by does) + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + layer + } + + #[test] + fn test_identity_no_change() { + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.position = Position::identity(); + + let spec = Plot::new(); + let mut data_map = HashMap::new(); + layer.data_key = Some("__ggsql_layer_0__".to_string()); + data_map.insert("__ggsql_layer_0__".to_string(), df.clone()); + + let mut spec_with_layer = spec; + spec_with_layer.layers.push(layer); + + apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); + + // Data should be unchanged + let result_df = data_map.get("__ggsql_layer_0__").unwrap(); + assert_eq!(result_df.height(), 4); + } + + #[test] + fn test_stack_cumsum() { + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.position = Position::stack(); + + let spec = Plot::new(); + let mut data_map = HashMap::new(); + layer.data_key = Some("__ggsql_layer_0__".to_string()); + data_map.insert("__ggsql_layer_0__".to_string(), df); + + let mut spec_with_layer = spec; + spec_with_layer.layers.push(layer); + + apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); + + let 
result_df = data_map.get("__ggsql_layer_0__").unwrap(); + let pos2_col = result_df.column("__ggsql_aes_pos2__").unwrap(); + let pos2end_col = result_df.column("__ggsql_aes_pos2end__").unwrap(); + + // Verify stacking was applied + assert!(pos2_col.f64().is_ok() || pos2_col.i64().is_ok()); + assert!(pos2end_col.f64().is_ok() || pos2end_col.i64().is_ok()); + } + + #[test] + fn test_dodge_offset() { + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.position = Position::dodge(); + + // Create spec with pos1 as discrete and pos2 as continuous + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let mut data_map = HashMap::new(); + layer.data_key = Some("__ggsql_layer_0__".to_string()); + data_map.insert("__ggsql_layer_0__".to_string(), df); + + let mut spec_with_layer = spec; + spec_with_layer.layers.push(layer); + + apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); + + let result_df = data_map.get("__ggsql_layer_0__").unwrap(); + + // Verify pos1offset column was created + let offset_col = result_df.column("__ggsql_aes_pos1offset__"); + assert!(offset_col.is_ok(), "pos1offset column should be created"); + + let offset = offset_col.unwrap().f64().unwrap(); + + // With 2 groups (X, Y) and default width 0.9: + // - adjusted_width = 0.9 / 2 = 0.45 + // - center_offset = 0.5 + // - Group X: center = (0 - 0.5) * 0.45 = -0.225 + // - Group Y: center = (1 - 0.5) * 0.45 = +0.225 + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + assert!( + offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001), + "Should have offset -0.225 for group X, got {:?}", + offsets + ); + assert!( + offsets.iter().any(|&v| (v - 0.225).abs() < 0.001), + "Should have offset +0.225 for group Y, got {:?}", + offsets + ); + + // Verify adjusted_width was set + let adjusted = spec_with_layer.layers[0].adjusted_width; + assert!(adjusted.is_some()); + 
assert!((adjusted.unwrap() - 0.45).abs() < 0.001); + } + + #[test] + fn test_dodge_custom_width() { + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.position = Position::dodge(); + layer + .parameters + .insert("width".to_string(), ParameterValue::Number(0.6)); + + // Create spec with pos1 as discrete and pos2 as continuous + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let mut data_map = HashMap::new(); + layer.data_key = Some("__ggsql_layer_0__".to_string()); + data_map.insert("__ggsql_layer_0__".to_string(), df); + + let mut spec_with_layer = spec; + spec_with_layer.layers.push(layer); + + apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); + + let result_df = data_map.get("__ggsql_layer_0__").unwrap(); + let offset = result_df + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + + // With 2 groups and custom width 0.6: + // - adjusted_width = 0.6 / 2 = 0.3 + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + assert!(offsets.iter().any(|&v| (v - (-0.15)).abs() < 0.001)); + assert!(offsets.iter().any(|&v| (v - 0.15).abs() < 0.001)); + + let adjusted = spec_with_layer.layers[0].adjusted_width; + assert!((adjusted.unwrap() - 0.3).abs() < 0.001); + } + + #[test] + fn test_jitter_offset() { + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.position = Position::jitter(); + + // Create spec with pos1 as discrete and pos2 as continuous + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let mut data_map = HashMap::new(); + layer.data_key = Some("__ggsql_layer_0__".to_string()); + data_map.insert("__ggsql_layer_0__".to_string(), df); + + let mut spec_with_layer = spec; + spec_with_layer.layers.push(layer); + + apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); + + let result_df 
= data_map.get("__ggsql_layer_0__").unwrap(); + + // Verify pos1offset column was created + let offset_col = result_df.column("__ggsql_aes_pos1offset__"); + assert!(offset_col.is_ok()); + + let offset = offset_col.unwrap().f64().unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // With default width 0.9, offsets should be in range [-0.45, 0.45] + for &v in &offsets { + assert!(v >= -0.45 && v <= 0.45); + } + + // No adjusted_width for jitter + assert!(spec_with_layer.layers[0].adjusted_width.is_none()); + } + + #[test] + fn test_jitter_custom_width() { + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.position = Position::jitter(); + layer + .parameters + .insert("width".to_string(), ParameterValue::Number(0.6)); + + // Create spec with pos1 as discrete and pos2 as continuous + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let mut data_map = HashMap::new(); + layer.data_key = Some("__ggsql_layer_0__".to_string()); + data_map.insert("__ggsql_layer_0__".to_string(), df); + + let mut spec_with_layer = spec; + spec_with_layer.layers.push(layer); + + apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); + + let result_df = data_map.get("__ggsql_layer_0__").unwrap(); + let offset = result_df + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // With custom width 0.6, offsets should be in range [-0.3, 0.3] + for &v in &offsets { + assert!(v >= -0.3 && v <= 0.3); + } + } +} diff --git a/src/execute/scale.rs b/src/execute/scale.rs index 3ad4cf2b..b1586b28 100644 --- a/src/execute/scale.rs +++ b/src/execute/scale.rs @@ -904,6 +904,49 @@ pub fn coerce_aesthetic_columns( Ok(()) } +// ============================================================================= +// Scale Type Inference (Early) +// 
============================================================================= + +/// Infer scale types from data for scales that don't have explicit types. +/// +/// This is a lightweight version of resolve_scales that only infers scale_type. +/// Called before position adjustments so dodge/stack can correctly identify +/// continuous vs discrete axes (e.g., stat-generated count columns). +/// +/// Does NOT perform coercion or full resolution - that happens later in resolve_scales. +pub fn infer_scale_types_from_data( + spec: &mut Plot, + data_map: &HashMap, +) -> Result<()> { + let aesthetic_ctx = spec.get_aesthetic_context(); + + for idx in 0..spec.scales.len() { + // Skip scales that already have a type + if spec.scales[idx].scale_type.is_some() { + continue; + } + + let aesthetic = spec.scales[idx].aesthetic.clone(); + + // Find column references for this aesthetic + let column_refs = + find_columns_for_aesthetic(&spec.layers, &aesthetic, data_map, &aesthetic_ctx); + + if column_refs.is_empty() { + continue; + } + + // Infer scale_type from column dtype + spec.scales[idx].scale_type = Some(ScaleType::infer_for_aesthetic( + column_refs[0].dtype(), + &aesthetic, + )); + } + + Ok(()) +} + // ============================================================================= // Scale Resolution // ============================================================================= @@ -956,7 +999,8 @@ pub fn resolve_scales(spec: &mut Plot, data_map: &mut HashMap continue; } - // Infer scale_type if not already set + // Infer scale_type if not already set (fallback - usually already inferred + // by infer_scale_types_from_data() which runs before position adjustments) if spec.scales[idx].scale_type.is_none() { spec.scales[idx].scale_type = Some(ScaleType::infer_for_aesthetic( column_refs[0].dtype(), diff --git a/src/parser/builder.rs b/src/parser/builder.rs index f7bc9cb1..d80313c2 100644 --- a/src/parser/builder.rs +++ b/src/parser/builder.rs @@ -477,7 +477,24 @@ fn 
build_layer(node: &Node, source: &SourceTree) -> Result { } } + // Extract position from parameters if present, otherwise use geom default + let position = if let Some(ParameterValue::String(pos_str)) = parameters.remove("position") { + Position::from_str(&pos_str) + } else { + // Check geom's default_params for position default + geom.default_params() + .iter() + .find(|p| p.name == "position") + .and_then(|p| p.to_parameter_value()) + .and_then(|v| match v { + ParameterValue::String(s) => Some(Position::from_str(&s)), + _ => None, + }) + .unwrap_or_default() + }; + let mut layer = Layer::new(geom); + layer.position = position; layer.mappings = aesthetics; layer.remappings = remappings; layer.parameters = parameters; @@ -3424,4 +3441,89 @@ mod tests { let project = specs[0].project.as_ref().unwrap(); assert_eq!(project.coord.coord_kind(), CoordKind::Cartesian); } + + // ======================================== + // Position Adjustment Parsing Tests + // ======================================== + + #[test] + fn test_position_stack_from_setting() { + let query = r#" + VISUALISE + DRAW bar MAPPING cat AS x, val AS y, grp AS fill + SETTING position => 'stack' + "#; + + let result = parse_test_query(query); + assert!(result.is_ok()); + let specs = result.unwrap(); + + assert_eq!(specs[0].layers.len(), 1); + assert_eq!(specs[0].layers[0].position, Position::stack()); + } + + #[test] + fn test_position_dodge_from_setting() { + let query = r#" + VISUALISE + DRAW bar MAPPING cat AS x, val AS y, grp AS fill + SETTING position => 'dodge' + "#; + + let result = parse_test_query(query); + assert!(result.is_ok()); + let specs = result.unwrap(); + + assert_eq!(specs[0].layers.len(), 1); + assert_eq!(specs[0].layers[0].position, Position::dodge()); + } + + #[test] + fn test_position_jitter_from_setting() { + let query = r#" + VISUALISE + DRAW point MAPPING cat AS x, val AS y + SETTING position => 'jitter' + "#; + + let result = parse_test_query(query); + assert!(result.is_ok()); 
+ let specs = result.unwrap(); + + assert_eq!(specs[0].layers.len(), 1); + assert_eq!(specs[0].layers[0].position, Position::jitter()); + } + + #[test] + fn test_position_geom_defaults() { + // Bar defaults to Stack (ggplot2 behavior) + let query = r#" + VISUALISE + DRAW bar MAPPING cat AS x, val AS y + "#; + let result = parse_test_query(query); + assert!(result.is_ok()); + let specs = result.unwrap(); + assert_eq!(specs[0].layers[0].position, Position::stack()); + + // Point defaults to Identity + let query = r#" + VISUALISE + DRAW point MAPPING cat AS x, val AS y + "#; + let result = parse_test_query(query); + assert!(result.is_ok()); + let specs = result.unwrap(); + assert_eq!(specs[0].layers[0].position, Position::identity()); + + // Boxplot defaults to Dodge + let query = r#" + VISUALISE + DRAW boxplot MAPPING cat AS x, val AS y + "#; + let result = parse_test_query(query); + assert!(result.is_ok()); + let specs = result.unwrap(); + assert_eq!(specs[0].layers[0].position, Position::dodge()); + } } diff --git a/src/plot/layer/geom/abline.rs b/src/plot/layer/geom/abline.rs index ff335dc6..901a4203 100644 --- a/src/plot/layer/geom/abline.rs +++ b/src/plot/layer/geom/abline.rs @@ -1,6 +1,6 @@ //! 
AbLine geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// AbLine geom - lines with slope and intercept @@ -24,6 +24,13 @@ impl GeomTrait for AbLine { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for AbLine { diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index 06e8a7a0..91d02cb7 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -27,10 +27,14 @@ impl GeomTrait for Area { } } + fn default_remappings(&self) -> &'static [(&'static str, DefaultAestheticValue)] { + &[("pos2end", DefaultAestheticValue::Number(0.0))] + } + fn default_params(&self) -> &'static [DefaultParam] { &[DefaultParam { - name: "stacking", - default: DefaultParamValue::String("off"), + name: "position", + default: DefaultParamValue::String("stack"), }] } } diff --git a/src/plot/layer/geom/arrow.rs b/src/plot/layer/geom/arrow.rs index d2eb9e84..538b0446 100644 --- a/src/plot/layer/geom/arrow.rs +++ b/src/plot/layer/geom/arrow.rs @@ -1,6 +1,6 @@ //! 
Arrow geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Arrow geom - line segments with arrowheads @@ -27,6 +27,13 @@ impl GeomTrait for Arrow { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Arrow { diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index 191b4ee2..67ac2e89 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -53,10 +53,16 @@ impl GeomTrait for Bar { } fn default_params(&self) -> &'static [DefaultParam] { - &[DefaultParam { - name: "width", - default: DefaultParamValue::Number(0.9), - }] + &[ + DefaultParam { + name: "width", + default: DefaultParamValue::Number(0.9), + }, + DefaultParam { + name: "position", + default: DefaultParamValue::String("stack"), + }, + ] } fn stat_consumed_aesthetics(&self) -> &'static [&'static str] { diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index dda15397..975db0fb 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -62,6 +62,10 @@ impl GeomTrait for Boxplot { name: "width", default: DefaultParamValue::Number(0.9), }, + DefaultParam { + name: "position", + default: DefaultParamValue::String("dodge"), + }, ] } @@ -569,7 +573,7 @@ mod tests { let boxplot = Boxplot; let params = boxplot.default_params(); - assert_eq!(params.len(), 3); + assert_eq!(params.len(), 4); // Find and verify outliers param let outliers_param = params.iter().find(|p| p.name == "outliers").unwrap(); @@ -589,6 +593,13 @@ mod tests { assert!( matches!(width_param.default, DefaultParamValue::Number(v) if (v - 0.9).abs() < f64::EPSILON) ); + + // Find and verify position param (boxplot defaults to dodge) + let position_param = 
params.iter().find(|p| p.name == "position").unwrap(); + assert!(matches!( + position_param.default, + DefaultParamValue::String("dodge") + )); } #[test] diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index e4a28d3d..b2954b67 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -46,8 +46,8 @@ impl GeomTrait for Density { fn default_params(&self) -> &'static [DefaultParam] { &[ DefaultParam { - name: "stacking", - default: DefaultParamValue::String("off"), + name: "position", + default: DefaultParamValue::String("identity"), }, DefaultParam { name: "bandwidth", @@ -68,6 +68,7 @@ impl GeomTrait for Density { &[ ("pos1", DefaultAestheticValue::Column("pos1")), ("pos2", DefaultAestheticValue::Column("density")), + ("pos2end", DefaultAestheticValue::Number(0.0)), ] } diff --git a/src/plot/layer/geom/errorbar.rs b/src/plot/layer/geom/errorbar.rs index 423d56a1..4db87034 100644 --- a/src/plot/layer/geom/errorbar.rs +++ b/src/plot/layer/geom/errorbar.rs @@ -1,6 +1,6 @@ //! 
ErrorBar geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// ErrorBar geom - error bars (confidence intervals) @@ -27,6 +27,13 @@ impl GeomTrait for ErrorBar { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for ErrorBar { diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index a94bf695..5830c87f 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -61,6 +61,10 @@ impl GeomTrait for Histogram { name: "binwidth", default: DefaultParamValue::Null, }, + DefaultParam { + name: "position", + default: DefaultParamValue::String("stack"), + }, ] } diff --git a/src/plot/layer/geom/hline.rs b/src/plot/layer/geom/hline.rs index e3338c83..a4a251f1 100644 --- a/src/plot/layer/geom/hline.rs +++ b/src/plot/layer/geom/hline.rs @@ -1,6 +1,6 @@ //! HLine geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// HLine geom - horizontal reference lines @@ -23,6 +23,13 @@ impl GeomTrait for HLine { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for HLine { diff --git a/src/plot/layer/geom/label.rs b/src/plot/layer/geom/label.rs index d1892e02..d2f41d1c 100644 --- a/src/plot/layer/geom/label.rs +++ b/src/plot/layer/geom/label.rs @@ -1,6 +1,6 @@ //! 
Label geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Label geom - text labels with background @@ -29,6 +29,13 @@ impl GeomTrait for Label { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Label { diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index fa3dea59..7f283676 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -1,6 +1,6 @@ //! Line geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Line geom - line charts with connected points @@ -24,6 +24,13 @@ impl GeomTrait for Line { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Line { diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index c953c9f7..abbac5f8 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -206,6 +206,22 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { Ok(StatResult::Identity) } + /// Post-process the DataFrame after stat query execution. + /// + /// This method is called after the stat transform query has been executed + /// and allows geoms to modify the resulting data. The default implementation + /// returns the data unchanged. + /// + /// Used by violin to scale the offset column to [0, 0.5 * width] using global + /// max normalization before Vega-Lite rendering. 
+ fn post_process( + &self, + df: DataFrame, + _parameters: &HashMap, + ) -> Result { + Ok(df) + } + /// Returns valid parameter names for SETTING clause. /// /// Combines supported aesthetics with non-aesthetic parameters. @@ -413,6 +429,15 @@ impl Geom { ) } + /// Post-process DataFrame after stat query execution + pub fn post_process( + &self, + df: DataFrame, + parameters: &HashMap, + ) -> Result { + self.0.post_process(df, parameters) + } + /// Get valid settings pub fn valid_settings(&self) -> Vec<&'static str> { self.0.valid_settings() diff --git a/src/plot/layer/geom/path.rs b/src/plot/layer/geom/path.rs index 1d718da4..39726c29 100644 --- a/src/plot/layer/geom/path.rs +++ b/src/plot/layer/geom/path.rs @@ -1,6 +1,6 @@ //! Path geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Path geom - connected line segments in order @@ -24,6 +24,13 @@ impl GeomTrait for Path { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Path { diff --git a/src/plot/layer/geom/point.rs b/src/plot/layer/geom/point.rs index 25a2c1cf..73ce013a 100644 --- a/src/plot/layer/geom/point.rs +++ b/src/plot/layer/geom/point.rs @@ -1,6 +1,6 @@ //! 
Point geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Point geom - scatter plots and similar @@ -26,6 +26,13 @@ impl GeomTrait for Point { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Point { diff --git a/src/plot/layer/geom/polygon.rs b/src/plot/layer/geom/polygon.rs index ad250c79..78b9d9b4 100644 --- a/src/plot/layer/geom/polygon.rs +++ b/src/plot/layer/geom/polygon.rs @@ -1,6 +1,6 @@ //! Polygon geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Polygon geom - arbitrary polygons @@ -25,6 +25,13 @@ impl GeomTrait for Polygon { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Polygon { diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 17777c9a..0bcd4d8b 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -1,6 +1,6 @@ //! 
Ribbon geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Ribbon geom - confidence bands and ranges @@ -26,6 +26,13 @@ impl GeomTrait for Ribbon { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Ribbon { diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index eb60a520..18765e90 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -1,6 +1,6 @@ //! Segment geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Segment geom - line segments between two points @@ -26,6 +26,13 @@ impl GeomTrait for Segment { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Segment { diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index 947dc5db..e8d55854 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -1,6 +1,6 @@ //! 
Smooth geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; use crate::Mappings; @@ -26,6 +26,13 @@ impl GeomTrait for Smooth { } } + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } diff --git a/src/plot/layer/geom/text.rs b/src/plot/layer/geom/text.rs index a185737e..eff8d887 100644 --- a/src/plot/layer/geom/text.rs +++ b/src/plot/layer/geom/text.rs @@ -1,6 +1,6 @@ //! Text geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Text geom - text labels at positions @@ -28,6 +28,13 @@ impl GeomTrait for Text { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Text { diff --git a/src/plot/layer/geom/tile.rs b/src/plot/layer/geom/tile.rs index 870721d3..a13b462a 100644 --- a/src/plot/layer/geom/tile.rs +++ b/src/plot/layer/geom/tile.rs @@ -1,6 +1,6 @@ //! 
Tile geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// Tile geom - heatmaps and tile-based visualizations @@ -25,6 +25,13 @@ impl GeomTrait for Tile { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for Tile { diff --git a/src/plot/layer/geom/violin.rs b/src/plot/layer/geom/violin.rs index 91384a2d..7650dc5d 100644 --- a/src/plot/layer/geom/violin.rs +++ b/src/plot/layer/geom/violin.rs @@ -2,12 +2,14 @@ use super::{DefaultAesthetics, GeomTrait, GeomType, StatResult}; use crate::{ + naming, plot::{ geom::types::get_column_name, DefaultAestheticValue, DefaultParam, DefaultParamValue, ParameterValue, }, - GgsqlError, Mappings, Result, + DataFrame, GgsqlError, Mappings, Result, }; +use polars::prelude::*; use std::collections::HashMap; /// Violin geom - violin plots (mirrored density) @@ -30,7 +32,7 @@ impl GeomTrait for Violin { ("opacity", DefaultAestheticValue::Number(0.8)), ("linewidth", DefaultAestheticValue::Number(1.0)), ("linetype", DefaultAestheticValue::String("solid")), - ("offset", DefaultAestheticValue::Delayed), // Computed by stat + ("offset", DefaultAestheticValue::Delayed), // Computed by stat, used for violin shape ], } } @@ -53,6 +55,14 @@ impl GeomTrait for Violin { name: "kernel", default: DefaultParamValue::String("gaussian"), }, + DefaultParam { + name: "position", + default: DefaultParamValue::String("dodge"), + }, + DefaultParam { + name: "width", + default: DefaultParamValue::Number(0.9), + }, ] } @@ -82,6 +92,31 @@ impl GeomTrait for Violin { ) -> Result { stat_violin(query, aesthetics, group_by, parameters, execute_query) } + + /// Post-process the violin DataFrame to scale offset to [0, 0.5 * width]. 
+ /// + /// Uses global max normalization so relative differences across groups are preserved: + /// - Narrow distributions will have higher peaks (normalized density) + /// - Groups with more data will be wider when using intensity remapping + fn post_process( + &self, + df: DataFrame, + parameters: &HashMap, + ) -> Result { + let offset_col = naming::aesthetic_column("offset"); + + // Get width parameter (default 0.9) + let width = parameters + .get("width") + .and_then(|v| match v { + ParameterValue::Number(n) => Some(*n), + _ => None, + }) + .unwrap_or(0.9); + let half_width = 0.5 * width; + + scale_offset_column(df, &offset_col, half_width) + } } impl std::fmt::Display for Violin { @@ -90,6 +125,40 @@ impl std::fmt::Display for Violin { } } +/// Scale the offset column to [0, half_width] using global max normalization. +/// +/// new_offset = offset * half_width / global_max +fn scale_offset_column(df: DataFrame, offset_col: &str, half_width: f64) -> Result { + // Check if offset column exists + if df.column(offset_col).is_err() { + // No offset column, return unchanged + return Ok(df); + } + + // Get global max of offset column + let max_val = df + .column(offset_col) + .map_err(|e| GgsqlError::InternalError(format!("Failed to get offset column: {}", e)))? + .f64() + .map_err(|e| GgsqlError::InternalError(format!("Offset column must be f64: {}", e)))? 
+ .max() + .unwrap_or(1.0); + + if max_val <= 0.0 { + return Ok(df); + } + + // Scale: new_offset = offset * half_width / max_val + let scale_factor = half_width / max_val; + let scaled = df + .lazy() + .with_column((col(offset_col) * lit(scale_factor)).alias(offset_col)) + .collect() + .map_err(|e| GgsqlError::InternalError(format!("Failed to scale offset: {}", e)))?; + + Ok(scaled) +} + fn stat_violin( query: &str, aesthetics: &Mappings, @@ -293,4 +362,107 @@ mod tests { _ => panic!("Expected Transformed result"), } } + + #[test] + fn test_violin_width_parameter() { + // Verify that the violin geom has a width parameter with default 0.9 + let violin = Violin; + let params = violin.default_params(); + + let width_param = params.iter().find(|p| p.name == "width"); + assert!( + width_param.is_some(), + "Violin should have a 'width' parameter" + ); + + if let Some(param) = width_param { + match param.default { + DefaultParamValue::Number(n) => { + assert!( + (n - 0.9).abs() < 1e-6, + "Default width should be 0.9, got {}", + n + ); + } + _ => panic!("Width parameter should have a numeric default"), + } + } + } + + // ==================== Post-Process Tests ==================== + + #[test] + fn test_violin_post_process_scales_offset() { + let violin = Violin; + let offset_col = naming::aesthetic_column("offset"); + + // Create a DataFrame with offset values + let df = df! 
{ + offset_col.as_str() => [0.0, 0.5, 1.0, 0.25], + "__ggsql_aes_pos2__" => [1.0, 2.0, 3.0, 4.0], + } + .unwrap(); + + // With default width 0.9, half_width = 0.45 + // Offset should be scaled to [0, 0.45] + let parameters = HashMap::new(); + let result = violin.post_process(df, ¶meters).unwrap(); + + let scaled_offset = result.column(&offset_col).unwrap().f64().unwrap(); + let values: Vec = scaled_offset.into_iter().filter_map(|v| v).collect(); + + // Max offset (1.0) should be scaled to 0.45 (half_width) + // Other values should be proportionally scaled + assert!((values[0] - 0.0).abs() < 1e-6, "0.0 should stay 0.0"); + assert!((values[1] - 0.225).abs() < 1e-6, "0.5 should become 0.225"); + assert!((values[2] - 0.45).abs() < 1e-6, "1.0 should become 0.45"); + assert!( + (values[3] - 0.1125).abs() < 1e-6, + "0.25 should become 0.1125" + ); + } + + #[test] + fn test_violin_post_process_custom_width() { + let violin = Violin; + let offset_col = naming::aesthetic_column("offset"); + + // Create a DataFrame with offset values + let df = df! { + offset_col.as_str() => [0.0, 0.5, 1.0], + "__ggsql_aes_pos2__" => [1.0, 2.0, 3.0], + } + .unwrap(); + + // With width 0.6, half_width = 0.3 + let mut parameters = HashMap::new(); + parameters.insert("width".to_string(), ParameterValue::Number(0.6)); + + let result = violin.post_process(df, ¶meters).unwrap(); + + let scaled_offset = result.column(&offset_col).unwrap().f64().unwrap(); + let values: Vec = scaled_offset.into_iter().filter_map(|v| v).collect(); + + // Max offset (1.0) should be scaled to 0.3 (half_width) + assert!((values[0] - 0.0).abs() < 1e-6, "0.0 should stay 0.0"); + assert!((values[1] - 0.15).abs() < 1e-6, "0.5 should become 0.15"); + assert!((values[2] - 0.3).abs() < 1e-6, "1.0 should become 0.3"); + } + + #[test] + fn test_violin_post_process_no_offset_column() { + let violin = Violin; + + // Create a DataFrame without offset column + let df = df! 
{ + "__ggsql_aes_pos2__" => [1.0, 2.0, 3.0], + } + .unwrap(); + + let parameters = HashMap::new(); + let result = violin.post_process(df.clone(), ¶meters).unwrap(); + + // Should return unchanged DataFrame + assert_eq!(result.height(), df.height()); + } } diff --git a/src/plot/layer/geom/vline.rs b/src/plot/layer/geom/vline.rs index 37ec2058..879510d6 100644 --- a/src/plot/layer/geom/vline.rs +++ b/src/plot/layer/geom/vline.rs @@ -1,6 +1,6 @@ //! VLine geom implementation -use super::{DefaultAesthetics, GeomTrait, GeomType}; +use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; use crate::plot::types::DefaultAestheticValue; /// VLine geom - vertical reference lines @@ -23,6 +23,13 @@ impl GeomTrait for VLine { ], } } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }] + } } impl std::fmt::Display for VLine { diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 84076071..a8d79f91 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -6,14 +6,20 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; -// Geom is now a submodule of layer +// Geom is a submodule of layer pub mod geom; +// Position is a submodule of layer +pub mod position; + // Re-export geom types for convenience pub use geom::{ DefaultAesthetics, DefaultParam, DefaultParamValue, Geom, GeomTrait, GeomType, StatResult, }; +// Re-export position types for convenience +pub use position::{Position, PositionTrait, PositionType}; + use crate::plot::types::{AestheticValue, DataSource, Mappings, ParameterValue, SqlExpression}; /// A single visualization layer (from DRAW clause) @@ -21,6 +27,8 @@ use crate::plot::types::{AestheticValue, DataSource, Mappings, ParameterValue, S pub struct Layer { /// Geometric object type pub geom: Geom, + /// Position adjustment for overlapping elements + pub position: Position, /// All aesthetic mappings 
combined from multiple sources: /// /// 1. **MAPPING clause** (from query, highest precedence): @@ -60,6 +68,11 @@ pub struct Layer { /// but may point to another layer's data when queries are deduplicated. #[serde(skip_serializing_if = "Option::is_none")] pub data_key: Option, + /// Adjusted width after position adjustment (e.g., for dodged bars). + /// Set during execution by position::apply_position_adjustments(). + /// Writers can use this to know the actual element width after dodging. + #[serde(skip_serializing_if = "Option::is_none")] + pub adjusted_width: Option, } impl Layer { @@ -67,6 +80,7 @@ impl Layer { pub fn new(geom: Geom) -> Self { Self { geom, + position: Position::default(), mappings: Mappings::new(), remappings: Mappings::new(), parameters: HashMap::new(), @@ -75,9 +89,16 @@ impl Layer { order_by: None, partition_by: Vec::new(), data_key: None, + adjusted_width: None, } } + /// Set the position adjustment + pub fn with_position(mut self, position: Position) -> Self { + self.position = position; + self + } + /// Set the filter expression pub fn with_filter(mut self, filter: SqlExpression) -> Self { self.filter = Some(filter); @@ -167,6 +188,25 @@ impl Layer { } } + /// Apply default position parameter values for any params not specified by user. + /// + /// This is called AFTER apply_default_params() so geom defaults take precedence + /// over position defaults. For example, if a geom defines width => 0.8 and the + /// position (dodge) defines width => 0.9, the geom's 0.8 is used. 
+ pub fn apply_default_position_params(&mut self) { + for param in self.position.default_params() { + if !self.parameters.contains_key(param.name) { + let value = match ¶m.default { + DefaultParamValue::String(s) => ParameterValue::String(s.to_string()), + DefaultParamValue::Number(n) => ParameterValue::Number(*n), + DefaultParamValue::Boolean(b) => ParameterValue::Boolean(*b), + DefaultParamValue::Null => continue, + }; + self.parameters.insert(param.name.to_string(), value); + } + } + } + /// Resolve aesthetics for all supported aesthetics not in MAPPING. /// /// For each supported aesthetic that's not already mapped in MAPPING: @@ -210,15 +250,19 @@ impl Layer { } } - /// Validate that all SETTING parameters are valid for this layer's geom + /// Validate that all SETTING parameters are valid for this layer's geom and position pub fn validate_settings(&self) -> std::result::Result<(), String> { - let valid = self.geom.valid_settings(); + // Combine valid settings from both geom and position + let mut valid = self.geom.valid_settings(); + valid.extend(self.position.valid_settings()); + for param_name in self.parameters.keys() { if !valid.contains(¶m_name.as_str()) { return Err(format!( - "Invalid setting '{}' for geom '{}'. Valid settings are: {}", + "Invalid setting '{}' for geom '{}' with position '{}'. Valid settings are: {}", param_name, self.geom, + self.position, valid.join(", ") )); } diff --git a/src/plot/layer/position/dodge.rs b/src/plot/layer/position/dodge.rs new file mode 100644 index 00000000..0f6e83e6 --- /dev/null +++ b/src/plot/layer/position/dodge.rs @@ -0,0 +1,699 @@ +//! Dodge position adjustment +//! +//! Positions elements side-by-side within groups. Dodge automatically detects +//! which axes are discrete and applies dodge accordingly: +//! - If only pos1 is discrete → dodge horizontally (pos1offset) +//! - If only pos2 is discrete → dodge vertically (pos2offset) +//! 
- If both are discrete → 2D grid dodge (both offsets, arranged in a grid) + +use super::{is_continuous_scale, Layer, PositionTrait, PositionType}; +use crate::plot::types::{DefaultParam, DefaultParamValue, ParameterValue}; +use crate::{naming, DataFrame, GgsqlError, Plot, Result}; +use polars::prelude::*; +use std::collections::HashMap; + +/// Result of computing group indices for dodge/jitter operations. +/// +/// Contains the number of unique groups and the group index for each row. +pub struct GroupIndices { + /// Number of unique groups + pub n_groups: usize, + /// Group index (0 to n_groups-1) for each row + pub indices: Vec, +} + +/// Compute group indices from partition_by columns. +/// +/// Returns None if there are no grouping columns or columns don't exist. +/// Returns Some(GroupIndices) with n_groups=1 if there's only one group. +pub fn compute_group_indices( + df: &DataFrame, + group_cols: &[String], +) -> Result> { + if group_cols.is_empty() { + return Ok(None); + } + + // Check if required grouping columns exist + for col_name in group_cols { + if df.column(col_name).is_err() { + return Ok(None); + } + } + + // Create composite key for each row by concatenating all grouping column values + let n_rows = df.height(); + let mut composite_keys: Vec = Vec::with_capacity(n_rows); + + for row_idx in 0..n_rows { + let mut key_parts: Vec = Vec::with_capacity(group_cols.len()); + for col_name in group_cols { + let col = df.column(col_name).unwrap(); + let val = col.get(row_idx).map_err(|e| { + GgsqlError::InternalError(format!("Failed to get value at row {}: {}", row_idx, e)) + })?; + key_parts.push(format!("{}", val)); + } + composite_keys.push(key_parts.join("\x00")); // Use null byte as separator + } + + // Get unique composite keys and sort them for consistent ordering + let mut unique_keys: Vec = composite_keys.to_vec(); + unique_keys.sort(); + unique_keys.dedup(); + + let n_groups = unique_keys.len(); + + // Create mapping from composite key to index + 
let key_to_idx: HashMap = unique_keys + .into_iter() + .enumerate() + .map(|(idx, key)| (key, idx)) + .collect(); + + // Create index column by mapping each row's composite key + let indices: Vec = composite_keys + .iter() + .map(|key| *key_to_idx.get(key).unwrap()) + .collect(); + + Ok(Some(GroupIndices { n_groups, indices })) +} + +/// Dodge position - position elements side-by-side +#[derive(Debug, Clone, Copy)] +pub struct Dodge; + +impl std::fmt::Display for Dodge { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "dodge") + } +} + +impl PositionTrait for Dodge { + fn position_type(&self) -> PositionType { + PositionType::Dodge + } + + fn default_params(&self) -> &'static [DefaultParam] { + &[DefaultParam { + name: "width", + default: DefaultParamValue::Number(0.9), + }] + } + + fn creates_pos1offset(&self) -> bool { + true + } + + fn creates_pos2offset(&self) -> bool { + true // May create pos2offset when pos2 is discrete + } + + fn apply_adjustment( + &self, + df: &DataFrame, + layer: &Layer, + spec: &Plot, + ) -> Result<(DataFrame, Option)> { + apply_dodge_with_width(df, layer, spec) + } +} + +/// Apply dodge position adjustment and compute adjusted bar width. +/// +/// Automatically detects which axes are discrete and applies dodge accordingly: +/// - Discrete pos1 only → creates pos1offset column (horizontal dodge) +/// - Discrete pos2 only → creates pos2offset column (vertical dodge) +/// - Both discrete → creates both offset columns (2D grid arrangement) +/// - Neither discrete → returns unchanged (no dodge applied) +/// +/// For 2D grid dodge, groups are arranged in a square grid pattern. For example: +/// - 4 groups → 2x2 grid +/// - 8 groups → 3x3 grid (one cell empty) +/// - 9 groups → 3x3 grid (all cells filled) +/// +/// If an existing "offset" column exists (e.g., from violin geom), scales it by n_groups +/// so the layer can use adjusted values for its shape rendering. 
+/// Also returns the adjusted bar width (original width / n_groups for 1D, or +/// original width / grid_size for 2D). +fn apply_dodge_with_width( + df: &DataFrame, + layer: &Layer, + spec: &Plot, +) -> Result<(DataFrame, Option)> { + let offset_col = naming::aesthetic_column("offset"); + let pos1offset_col = naming::aesthetic_column("pos1offset"); + let pos2offset_col = naming::aesthetic_column("pos2offset"); + + // Check which axes should be dodged (discrete axes) + // Since infer_scale_types_from_data() runs before position adjustments, + // scale types are always known, so we use explicit discrete checks. + let dodge_pos1 = is_continuous_scale(spec, "pos1") == Some(false); + let dodge_pos2 = is_continuous_scale(spec, "pos2") == Some(false); + + // If neither is discrete, nothing to dodge + if !dodge_pos1 && !dodge_pos2 { + return Ok((df.clone(), None)); + } + + // Compute group indices + let group_info = match compute_group_indices(df, &layer.partition_by)? { + Some(info) => info, + None => return Ok((df.clone(), None)), + }; + + let GroupIndices { n_groups, indices } = group_info; + + if n_groups <= 1 { + // Only one group - no dodging needed + return Ok((df.clone(), None)); + } + + // Get the default bar width from layer parameters (or use 0.9 as default) + let bar_width = layer + .parameters + .get("width") + .and_then(|v| match v { + ParameterValue::Number(n) => Some(*n), + _ => None, + }) + .unwrap_or(0.9); + + // Check if layer has an existing offset column (e.g., violin density offset) + let has_offset_col = df.column(&offset_col).is_ok(); + + let mut lf = df.clone().lazy(); + + // Compute offsets based on which axes are being dodged + let adjusted_width = if dodge_pos1 && dodge_pos2 { + // 2D grid arrangement: arrange groups in a square grid + let grid_size = (n_groups as f64).sqrt().ceil() as usize; + let adjusted_width = bar_width / grid_size as f64; + let center_offset = (grid_size as f64 - 1.0) / 2.0; + + // Compute grid positions for each group + 
let pos1_offsets: Vec = indices + .iter() + .map(|&idx| { + let col_idx = idx % grid_size; + (col_idx as f64 - center_offset) * adjusted_width + }) + .collect(); + let pos2_offsets: Vec = indices + .iter() + .map(|&idx| { + let row_idx = idx / grid_size; + (row_idx as f64 - center_offset) * adjusted_width + }) + .collect(); + + lf = lf.with_column( + lit(Series::new(pos1offset_col.clone().into(), pos1_offsets)).alias(&pos1offset_col), + ); + lf = lf.with_column( + lit(Series::new(pos2offset_col.clone().into(), pos2_offsets)).alias(&pos2offset_col), + ); + + // If offset column exists (e.g., violin), scale it by grid_size + if has_offset_col { + lf = lf.with_column((col(&offset_col) / lit(grid_size as f64)).alias(&offset_col)); + } + + adjusted_width + } else if dodge_pos1 { + // Horizontal dodge only (original behavior) + let n_groups_f64 = n_groups as f64; + let adjusted_width = bar_width / n_groups_f64; + let center_offset = (n_groups_f64 - 1.0) / 2.0; + + let pos1_offsets: Vec = indices + .iter() + .map(|&idx| (idx as f64 - center_offset) * adjusted_width) + .collect(); + + lf = lf.with_column( + lit(Series::new(pos1offset_col.clone().into(), pos1_offsets)).alias(&pos1offset_col), + ); + + // If offset column exists (e.g., violin), scale it by n_groups + if has_offset_col { + lf = lf.with_column((col(&offset_col) / lit(n_groups_f64)).alias(&offset_col)); + } + + adjusted_width + } else { + // Vertical dodge only (dodge_pos2 is true) + let n_groups_f64 = n_groups as f64; + let adjusted_width = bar_width / n_groups_f64; + let center_offset = (n_groups_f64 - 1.0) / 2.0; + + let pos2_offsets: Vec = indices + .iter() + .map(|&idx| (idx as f64 - center_offset) * adjusted_width) + .collect(); + + lf = lf.with_column( + lit(Series::new(pos2offset_col.clone().into(), pos2_offsets)).alias(&pos2offset_col), + ); + + // If offset column exists (e.g., violin), scale it by n_groups + if has_offset_col { + lf = lf.with_column((col(&offset_col) / 
lit(n_groups_f64)).alias(&offset_col)); + } + + adjusted_width + }; + + // Collect the result + let final_df = lf.collect().map_err(|e| { + GgsqlError::InternalError(format!("Dodge position adjustment failed: {}", e)) + })?; + + Ok((final_df, Some(adjusted_width))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::plot::layer::Geom; + use crate::plot::{AestheticValue, Mappings, Scale, ScaleType}; + + fn make_test_df() -> DataFrame { + df! { + "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + } + .unwrap() + } + + fn make_test_layer() -> Layer { + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + layer + } + + fn make_continuous_scale(aesthetic: &str) -> Scale { + let mut scale = Scale::new(aesthetic); + scale.scale_type = Some(ScaleType::continuous()); + scale + } + + fn make_discrete_scale(aesthetic: &str) -> Scale { + let mut scale = Scale::new(aesthetic); + scale.scale_type = Some(ScaleType::discrete()); + scale + } + + #[test] + fn test_dodge_horizontal_only() { + // When pos1 is discrete and pos2 is continuous, only pos1offset is created + let dodge = Dodge; + assert_eq!(dodge.position_type(), PositionType::Dodge); + + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark pos1 as discrete and pos2 as continuous via scales + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + 
spec.scales.push(make_continuous_scale("pos2")); + + let (result, width) = dodge.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify pos1offset column was created + assert!( + result.column("__ggsql_aes_pos1offset__").is_ok(), + "pos1offset column should be created" + ); + // Verify pos2offset column was NOT created + assert!( + result.column("__ggsql_aes_pos2offset__").is_err(), + "pos2offset column should NOT be created when pos2 is continuous" + ); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + + // With 2 groups (X, Y) and default width 0.9: + // - adjusted_width = 0.9 / 2 = 0.45 + // - center_offset = 0.5 + // - Group X: center = (0 - 0.5) * 0.45 = -0.225 + // - Group Y: center = (1 - 0.5) * 0.45 = +0.225 + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + assert!( + offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001), + "Should have offset -0.225 for group X, got {:?}", + offsets + ); + assert!( + offsets.iter().any(|&v| (v - 0.225).abs() < 0.001), + "Should have offset +0.225 for group Y, got {:?}", + offsets + ); + + // Verify adjusted_width was returned + assert!(width.is_some()); + assert!( + (width.unwrap() - 0.45).abs() < 0.001, + "adjusted_width should be 0.9/2 = 0.45, got {:?}", + width + ); + } + + #[test] + fn test_dodge_vertical_only() { + // When pos1 is continuous and pos2 is discrete, only pos2offset is created + let dodge = Dodge; + + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark pos1 as continuous and pos2 as discrete via scales + let mut spec = Plot::new(); + spec.scales.push(make_continuous_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + + let (result, width) = dodge.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify pos1offset column was NOT created + assert!( + result.column("__ggsql_aes_pos1offset__").is_err(), + "pos1offset column should NOT be created when pos1 is continuous" + ); + // Verify pos2offset 
column was created + assert!( + result.column("__ggsql_aes_pos2offset__").is_ok(), + "pos2offset column should be created" + ); + + let offset = result + .column("__ggsql_aes_pos2offset__") + .unwrap() + .f64() + .unwrap(); + + // With 2 groups (X, Y) and default width 0.9: + // - adjusted_width = 0.9 / 2 = 0.45 + // - center_offset = 0.5 + // - Group X: center = (0 - 0.5) * 0.45 = -0.225 + // - Group Y: center = (1 - 0.5) * 0.45 = +0.225 + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + assert!( + offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001), + "Should have offset -0.225 for group X, got {:?}", + offsets + ); + assert!( + offsets.iter().any(|&v| (v - 0.225).abs() < 0.001), + "Should have offset +0.225 for group Y, got {:?}", + offsets + ); + + // Verify adjusted_width was returned + assert!(width.is_some()); + assert!( + (width.unwrap() - 0.45).abs() < 0.001, + "adjusted_width should be 0.9/2 = 0.45, got {:?}", + width + ); + } + + #[test] + fn test_dodge_bidirectional_2x2_grid() { + // When both axes are discrete, groups are arranged in a 2D grid + let dodge = Dodge; + + let df = make_test_df(); + let layer = make_test_layer(); + + // Both axes must be explicitly marked as discrete + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + + let (result, width) = dodge.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify both offset columns were created + assert!( + result.column("__ggsql_aes_pos1offset__").is_ok(), + "pos1offset column should be created" + ); + assert!( + result.column("__ggsql_aes_pos2offset__").is_ok(), + "pos2offset column should be created" + ); + + // With 2 groups in 2D mode, grid_size = ceil(sqrt(2)) = 2 + // adjusted_width = 0.9 / 2 = 0.45 + // center_offset = (2 - 1) / 2 = 0.5 + // Group 0 (X): col=0, row=0 → pos1=(-0.5)*0.45=-0.225, pos2=(-0.5)*0.45=-0.225 + // Group 1 (Y): col=1, row=0 → pos1=(0.5)*0.45=0.225, 
pos2=(-0.5)*0.45=-0.225 + let pos1_offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let pos2_offset = result + .column("__ggsql_aes_pos2offset__") + .unwrap() + .f64() + .unwrap(); + + let pos1_offsets: Vec<f64> = pos1_offset.into_iter().filter_map(|v| v).collect(); + let pos2_offsets: Vec<f64> = pos2_offset.into_iter().filter_map(|v| v).collect(); + + // Verify we have both expected pos1 offsets + assert!( + pos1_offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001), + "Should have pos1offset -0.225, got {:?}", + pos1_offsets + ); + assert!( + pos1_offsets.iter().any(|&v| (v - 0.225).abs() < 0.001), + "Should have pos1offset +0.225, got {:?}", + pos1_offsets + ); + + // Verify pos2 offsets (in 2x2 grid with 2 groups, both groups are in row 0) + // Group 0: row=0, Group 1: row=0 + // So all pos2 offsets should be (0 - 0.5) * 0.45 = -0.225 + for &v in &pos2_offsets { + assert!( + (v - (-0.225)).abs() < 0.001, + "All pos2 offsets should be -0.225 for 2 groups in 2x2 grid, got {}", + v + ); + } + + // Verify adjusted_width + assert!(width.is_some()); + assert!( + (width.unwrap() - 0.45).abs() < 0.001, + "adjusted_width should be 0.9/2 = 0.45, got {:?}", + width + ); + } + + #[test] + fn test_dodge_bidirectional_four_groups_2x2_grid() { + // Test with 4 groups to verify the full 2x2 grid arrangement (grid_size = ceil(sqrt(4)) = 2) + let dodge = Dodge; + + let df = df!
{ + "__ggsql_aes_pos1__" => ["A", "A", "A", "A"], + "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_fill__" => ["G1", "G2", "G3", "G4"], + } + .unwrap(); + + let mut layer = Layer::new(Geom::point()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + + // Both axes must be explicitly marked as discrete + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + + let (result, width) = dodge.apply_adjustment(&df, &layer, &spec).unwrap(); + + // With 4 groups in 2D mode, grid_size = ceil(sqrt(4)) = 2 + // This gives a 2x2 grid layout: + // G1: col=0, row=0 → (-0.5, -0.5) * adjusted_width + // G2: col=1, row=0 → (+0.5, -0.5) * adjusted_width + // G3: col=0, row=1 → (-0.5, +0.5) * adjusted_width + // G4: col=1, row=1 → (+0.5, +0.5) * adjusted_width + + let pos1_offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let pos2_offset = result + .column("__ggsql_aes_pos2offset__") + .unwrap() + .f64() + .unwrap(); + + let pos1_offsets: Vec = pos1_offset.into_iter().filter_map(|v| v).collect(); + let pos2_offsets: Vec = pos2_offset.into_iter().filter_map(|v| v).collect(); + + // Verify we have both positive and negative offsets in both dimensions + assert!( + pos1_offsets.iter().any(|&v| v < 0.0), + "Should have negative pos1 offsets" + ); + assert!( + pos1_offsets.iter().any(|&v| v > 0.0), + "Should have positive pos1 offsets" + ); + assert!( + pos2_offsets.iter().any(|&v| v < 0.0), + "Should have negative pos2 offsets" + ); + assert!( + pos2_offsets.iter().any(|&v| v > 0.0), + "Should have positive pos2 offsets" + ); + + // Verify 
adjusted_width = 0.9 / 2 = 0.45 + assert!(width.is_some()); + assert!( + (width.unwrap() - 0.45).abs() < 0.001, + "adjusted_width should be 0.9/2 = 0.45 for 4 groups (2x2 grid), got {:?}", + width + ); + } + + #[test] + fn test_dodge_neither_discrete() { + // When both axes are continuous, no offset columns are created + let dodge = Dodge; + + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark both as continuous + let mut spec = Plot::new(); + spec.scales.push(make_continuous_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, width) = dodge.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify neither offset column was created + assert!( + result.column("__ggsql_aes_pos1offset__").is_err(), + "pos1offset column should NOT be created when pos1 is continuous" + ); + assert!( + result.column("__ggsql_aes_pos2offset__").is_err(), + "pos2offset column should NOT be created when pos2 is continuous" + ); + + // No adjusted width when no dodging occurs + assert!(width.is_none()); + } + + #[test] + fn test_dodge_custom_width() { + let dodge = Dodge; + + let df = make_test_df(); + let mut layer = make_test_layer(); + layer + .parameters + .insert("width".to_string(), ParameterValue::Number(0.6)); + + // Mark pos1 as discrete and pos2 as continuous so only pos1offset is created + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, width) = dodge.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + + // With 2 groups and custom width 0.6: + // - adjusted_width = 0.6 / 2 = 0.3 + // - center_offset = 0.5 + // - Group X: center = (0 - 0.5) * 0.3 = -0.15 + // - Group Y: center = (1 - 0.5) * 0.3 = +0.15 + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + assert!( + offsets.iter().any(|&v| (v - (-0.15)).abs() < 0.001), + 
"Should have offset -0.15 for group X, got {:?}", + offsets + ); + assert!( + offsets.iter().any(|&v| (v - 0.15).abs() < 0.001), + "Should have offset +0.15 for group Y, got {:?}", + offsets + ); + + assert!((width.unwrap() - 0.3).abs() < 0.001); + } + + #[test] + fn test_dodge_creates_pos1offset() { + assert!(Dodge.creates_pos1offset()); + } + + #[test] + fn test_dodge_creates_pos2offset() { + assert!(Dodge.creates_pos2offset()); + } + + #[test] + fn test_dodge_default_params() { + let dodge = Dodge; + let params = dodge.default_params(); + assert_eq!(params.len(), 1); + assert_eq!(params[0].name, "width"); + assert!(matches!(params[0].default, DefaultParamValue::Number(0.9))); + } +} diff --git a/src/plot/layer/position/identity.rs b/src/plot/layer/position/identity.rs new file mode 100644 index 00000000..f7c44961 --- /dev/null +++ b/src/plot/layer/position/identity.rs @@ -0,0 +1,58 @@ +//! Identity position adjustment +//! +//! No position adjustment - elements are positioned at their exact data values. + +use super::{Layer, PositionTrait, PositionType}; +use crate::{DataFrame, Plot, Result}; + +/// Identity position - no adjustment applied +#[derive(Debug, Clone, Copy)] +pub struct Identity; + +impl std::fmt::Display for Identity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "identity") + } +} + +impl PositionTrait for Identity { + fn position_type(&self) -> PositionType { + PositionType::Identity + } + + fn apply_adjustment( + &self, + df: &DataFrame, + _layer: &Layer, + _spec: &Plot, + ) -> Result<(DataFrame, Option)> { + // Identity returns data unchanged + Ok((df.clone(), None)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use polars::prelude::*; + + #[test] + fn test_identity_no_change() { + let identity = Identity; + assert_eq!(identity.position_type(), PositionType::Identity); + + let df = df! 
{ + "x" => [1, 2, 3], + "y" => [10, 20, 30], + } + .unwrap(); + + let layer = Layer::new(crate::plot::layer::Geom::point()); + let spec = Plot::new(); + + let (result, width) = identity.apply_adjustment(&df, &layer, &spec).unwrap(); + + assert_eq!(result.height(), 3); + assert!(width.is_none()); + } +} diff --git a/src/plot/layer/position/jitter.rs b/src/plot/layer/position/jitter.rs new file mode 100644 index 00000000..36711401 --- /dev/null +++ b/src/plot/layer/position/jitter.rs @@ -0,0 +1,1634 @@ +//! Jitter position adjustment +//! +//! Adds random displacement to elements to avoid overplotting. +//! Jitter automatically detects which axes are discrete and applies +//! jitter to those axes: +//! - If only pos1 is discrete → jitter horizontally (pos1offset) +//! - If only pos2 is discrete → jitter vertically (pos2offset) +//! - If both are discrete → jitter in both directions +//! +//! When `dodge=true` (default), jitter first applies dodge positioning to separate +//! groups, then applies random jitter within the reduced width of each group's space. +//! +//! The `distribution` parameter controls the shape of the jitter: +//! - `uniform` (default): uniform random distribution across the width +//! - `normal`: normal/Gaussian distribution with ~95% of points within the width + +use super::{compute_group_indices, is_continuous_scale, Layer, PositionTrait, PositionType}; +use crate::plot::types::{DefaultParam, DefaultParamValue, ParameterValue}; +use crate::{naming, DataFrame, GgsqlError, Plot, Result}; +use polars::prelude::*; +use rand::Rng; + +/// Jitter distribution type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum JitterDistribution { + Uniform, + Normal, + /// Per-group normalized density (area under curve = 1). + /// Narrow distributions have higher peaks than wide distributions. + Density, + /// Count-weighted density (not normalized by group size). + /// Groups with more observations have higher peaks. 
+ /// Both density and intensity use global max normalization. + Intensity, +} + +impl JitterDistribution { + fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "normal" | "gaussian" => Self::Normal, + "density" => Self::Density, + "intensity" => Self::Intensity, + _ => Self::Uniform, + } + } + + /// Generate a random jitter value within the given width. + /// + /// For uniform: values are in [-width/2, width/2] + /// For normal: σ = width/4, so ~95% of values fall within [-width/2, width/2] + /// For density/intensity: not applicable, density scaling is handled separately + fn sample(&self, rng: &mut R, width: f64) -> f64 { + match self { + Self::Uniform | Self::Density | Self::Intensity => (rng.gen::() - 0.5) * width, + Self::Normal => { + // Box-Muller transform for normal distribution + // Use σ = width/4 so 95% of values fall within ±2σ = ±width/2 + let sigma = width / 4.0; + let u1: f64 = rng.gen(); + let u2: f64 = rng.gen(); + // Avoid log(0) by ensuring u1 > 0 + let u1 = if u1 == 0.0 { f64::MIN_POSITIVE } else { u1 }; + let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos(); + z * sigma + } + } + } +} + +// ============================================================================ +// Density estimation for density-based jitter +// ============================================================================ + +/// Compute Silverman's rule of thumb bandwidth for KDE. +/// +/// Uses the formula: h = 0.9 * adjust * min(σ, IQR/1.34) * n^(-0.2) +/// +/// This matches the bandwidth calculation used by the density and violin geoms, +/// ensuring consistent density estimates when using `distribution => 'density'` +/// with a violin layer. 
+fn silverman_bandwidth(values: &[f64], adjust: f64) -> f64 { + let n = values.len() as f64; + if n <= 1.0 { + return 1.0; + } + + // Compute mean and standard deviation (population stddev to match SQL STDDEV) + let mean = values.iter().sum::() / n; + let variance = values.iter().map(|x| (x - mean).powi(2)).sum::() / n; + let std_dev = variance.sqrt(); + + // Compute IQR (interquartile range) using linear interpolation + // This matches SQL's QUANTILE_CONT behavior + let mut sorted = values.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + + let q1 = quantile_cont(&sorted, 0.25); + let q3 = quantile_cont(&sorted, 0.75); + let iqr = q3 - q1; + + // Silverman's rule: 0.9 * adjust * min(σ, IQR/1.34) * n^(-0.2) + let scale = if iqr > 0.0 { + std_dev.min(iqr / 1.34) + } else { + std_dev + }; + + if scale == 0.0 { + return 1.0; // Fallback for constant data + } + + 0.9 * adjust * scale * n.powf(-0.2) +} + +/// Compute continuous quantile using linear interpolation. +/// Matches SQL QUANTILE_CONT behavior. +fn quantile_cont(sorted: &[f64], p: f64) -> f64 { + if sorted.is_empty() { + return 0.0; + } + if sorted.len() == 1 { + return sorted[0]; + } + + let n = sorted.len() as f64; + let idx = p * (n - 1.0); + let lo = idx.floor() as usize; + let hi = idx.ceil() as usize; + let frac = idx - lo as f64; + + if lo == hi || hi >= sorted.len() { + sorted[lo] + } else { + sorted[lo] * (1.0 - frac) + sorted[hi] * frac + } +} + +/// Compute density at each point using Gaussian KDE (normalized PDF). +/// +/// For each point xi, computes: f(xi) = (1/nh) * Σ K((xi - xj) / h) +/// where K is the Gaussian kernel. +/// +/// This produces a normalized PDF where the area under the curve equals 1. +/// Narrow distributions will have higher peaks than wide distributions. 
+fn compute_densities(values: &[f64], bandwidth: f64) -> Vec { + let n = values.len() as f64; + let norm_factor = 1.0 / (bandwidth * n * (2.0 * std::f64::consts::PI).sqrt()); + + values + .iter() + .map(|&xi| { + // Sum kernel contributions from all points + let density: f64 = values + .iter() + .map(|&xj| { + let u = (xi - xj) / bandwidth; + (-0.5 * u * u).exp() + }) + .sum(); + density * norm_factor + }) + .collect() +} + +/// Compute intensity at each point using Gaussian KDE (count-weighted, not normalized). +/// +/// For each point xi, computes: f(xi) = (1/h) * Σ K((xi - xj) / h) +/// where K is the Gaussian kernel. +/// +/// Unlike `compute_densities`, this does NOT divide by n, so groups with more +/// observations will have higher values. This makes the width proportional to +/// the number of data points. +fn compute_intensities(values: &[f64], bandwidth: f64) -> Vec { + let norm_factor = 1.0 / (bandwidth * (2.0 * std::f64::consts::PI).sqrt()); + + values + .iter() + .map(|&xi| { + // Sum kernel contributions from all points + let intensity: f64 = values + .iter() + .map(|&xj| { + let u = (xi - xj) / bandwidth; + (-0.5 * u * u).exp() + }) + .sum(); + intensity * norm_factor + }) + .collect() +} + +/// Compute density/intensity scales for grouped data with global normalization. +/// +/// When groups exist, compute density/intensity separately per group, but normalize +/// using the global max across ALL groups. 
This preserves relative differences: +/// - For density: narrow distributions appear wider (higher peaks) +/// - For intensity: groups with more data appear wider +/// +/// # Arguments +/// * `values` - All values from the continuous axis +/// * `group_indices` - Group index for each value +/// * `n_groups` - Number of distinct groups +/// * `explicit_bandwidth` - Optional explicit bandwidth (overrides Silverman's rule) +/// * `adjust` - Bandwidth adjustment multiplier +/// * `use_intensity` - If true, use intensity (count-weighted); if false, use density (normalized PDF) +fn compute_grouped_scales( + values: &[f64], + group_indices: &[usize], + n_groups: usize, + explicit_bandwidth: Option, + adjust: f64, + use_intensity: bool, +) -> Vec { + // Group values by their group index + let mut grouped_values: Vec> = vec![Vec::new(); n_groups]; + let mut grouped_original_indices: Vec> = vec![Vec::new(); n_groups]; + + for (i, (&value, &group_idx)) in values.iter().zip(group_indices.iter()).enumerate() { + grouped_values[group_idx].push(value); + grouped_original_indices[group_idx].push(i); + } + + // Compute raw density/intensity for each group (before normalization) + let mut all_raw_values = vec![0.0; values.len()]; + + for group_idx in 0..n_groups { + let group_vals = &grouped_values[group_idx]; + if group_vals.is_empty() { + continue; + } + + // Use explicit bandwidth if provided, otherwise compute per-group using Silverman's rule + // This matches how violin/density compute bandwidth per group + let bandwidth = explicit_bandwidth + .map(|bw| bw * adjust) + .unwrap_or_else(|| silverman_bandwidth(group_vals, adjust)); + + // Compute raw values using appropriate formula + let raw = if use_intensity { + compute_intensities(group_vals, bandwidth) + } else { + compute_densities(group_vals, bandwidth) + }; + + // Map back to original indices + for (within_group_idx, &original_idx) in + grouped_original_indices[group_idx].iter().enumerate() + { + 
all_raw_values[original_idx] = raw[within_group_idx]; + } + } + + // Global normalization: divide by max across ALL groups + let global_max = all_raw_values.iter().fold(0.0_f64, |a, &b| a.max(b)); + if global_max > 0.0 { + all_raw_values.iter().map(|v| v / global_max).collect() + } else { + vec![1.0; values.len()] + } +} + +/// Jitter position - add random displacement +#[derive(Debug, Clone, Copy)] +pub struct Jitter; + +impl std::fmt::Display for Jitter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "jitter") + } +} + +impl PositionTrait for Jitter { + fn position_type(&self) -> PositionType { + PositionType::Jitter + } + + fn default_params(&self) -> &'static [DefaultParam] { + &[ + DefaultParam { + name: "width", + default: DefaultParamValue::Number(0.9), + }, + DefaultParam { + name: "dodge", + default: DefaultParamValue::Boolean(true), + }, + DefaultParam { + name: "distribution", + default: DefaultParamValue::String("uniform"), + }, + // Density distribution parameters (match violin/density geoms) + DefaultParam { + name: "bandwidth", + default: DefaultParamValue::Null, + }, + DefaultParam { + name: "adjust", + default: DefaultParamValue::Number(1.0), + }, + ] + } + + fn creates_pos1offset(&self) -> bool { + true + } + + fn creates_pos2offset(&self) -> bool { + true + } + + fn apply_adjustment( + &self, + df: &DataFrame, + layer: &Layer, + spec: &Plot, + ) -> Result<(DataFrame, Option)> { + Ok((apply_jitter(df, layer, spec)?, None)) + } +} + +/// Apply jitter position adjustment. +/// +/// Automatically detects which axes are discrete and applies jitter accordingly: +/// - Discrete pos1 → creates pos1offset column +/// - Discrete pos2 → creates pos2offset column +/// - Both discrete → creates both offset columns +/// - Neither discrete → returns unchanged (no jitter applied) +/// +/// When `dodge=true` (default), groups are first dodged to separate positions, +/// then jitter is applied within each group's reduced space. 
+/// +/// The width parameter controls the total jitter range. When dodging is applied, +/// the effective jitter range is reduced by the number of groups. +/// +/// The `distribution` parameter controls the jitter shape: +/// - `uniform`: uniform random distribution across the width (default) +/// - `normal`: Gaussian distribution with ~95% of points within the width +/// - `density`: scales jitter width by local density (requires exactly one continuous axis) +fn apply_jitter(df: &DataFrame, layer: &Layer, spec: &Plot) -> Result { + // Check which axes should be jittered (discrete axes) + // Since infer_scale_types_from_data() runs before position adjustments, + // scale types are always known, so we use explicit discrete checks. + let jitter_pos1 = is_continuous_scale(spec, "pos1") == Some(false); + let jitter_pos2 = is_continuous_scale(spec, "pos2") == Some(false); + + // Get width parameter (default 0.9) + let width = layer + .parameters + .get("width") + .and_then(|v| match v { + ParameterValue::Number(n) => Some(*n), + _ => None, + }) + .unwrap_or(0.9); + + // Get dodge parameter (default true) + let dodge = layer + .parameters + .get("dodge") + .and_then(|v| match v { + ParameterValue::Boolean(b) => Some(*b), + _ => None, + }) + .unwrap_or(true); + + // Get distribution parameter (default "uniform") + let distribution = layer + .parameters + .get("distribution") + .and_then(|v| match v { + ParameterValue::String(s) => Some(JitterDistribution::from_str(s)), + _ => None, + }) + .unwrap_or(JitterDistribution::Uniform); + + // Density/intensity distribution validation: requires exactly one continuous axis + // (one discrete axis to jitter along, one continuous axis for density) + let pos1_continuous = !jitter_pos1; + let pos2_continuous = !jitter_pos2; + let use_density_scaling = distribution == JitterDistribution::Density + || distribution == JitterDistribution::Intensity; + if use_density_scaling { + let continuous_count = [pos1_continuous, pos2_continuous] + 
.iter() + .filter(|&&b| b) + .count(); + if continuous_count != 1 { + let dist_name = if distribution == JitterDistribution::Intensity { + "intensity" + } else { + "density" + }; + return Err(GgsqlError::ValidationError(format!( + "Jitter distribution '{}' requires exactly one continuous axis", + dist_name + ))); + } + } + + let mut rng = rand::thread_rng(); + let n_rows = df.height(); + + // Compute group info for dodge-first behavior + let group_info = if dodge { + compute_group_indices(df, &layer.partition_by)? + } else { + None + }; + + // Determine effective width and dodge offsets based on grouping + let (effective_width, n_groups, group_indices) = match &group_info { + Some(info) if info.n_groups > 1 => { + let adjusted = width / info.n_groups as f64; + (adjusted, info.n_groups, Some(&info.indices)) + } + _ => (width, 1, None), + }; + + // Get density-specific parameters (match violin/density geoms) + let explicit_bandwidth = layer.parameters.get("bandwidth").and_then(|v| match v { + ParameterValue::Number(n) => Some(*n), + _ => None, + }); + + let adjust = layer + .parameters + .get("adjust") + .and_then(|v| match v { + ParameterValue::Number(n) => Some(*n), + _ => None, + }) + .unwrap_or(1.0); + + // For density/intensity distribution, compute scales along the continuous axis + // IMPORTANT: Density must be computed per group, matching how violin computes density. 
+ // Groups are determined by: discrete axis (e.g., species) + partition_by (e.g., color) + let use_intensity = distribution == JitterDistribution::Intensity; + let density_scales = if use_density_scaling { + // Identify axes + let continuous_col = if pos1_continuous { "pos1" } else { "pos2" }; + let discrete_col = if pos1_continuous { "pos2" } else { "pos1" }; + let continuous_col_name = naming::aesthetic_column(continuous_col); + let discrete_col_name = naming::aesthetic_column(discrete_col); + + // Extract values from the continuous axis + let values: Vec = df + .column(&continuous_col_name) + .map_err(|_| { + GgsqlError::InternalError(format!( + "Missing {} column for density jitter", + continuous_col + )) + })? + .cast(&DataType::Float64) + .map_err(|_| { + GgsqlError::InternalError(format!( + "{} must be numeric for density jitter", + continuous_col + )) + })? + .f64() + .map_err(|_| { + GgsqlError::InternalError(format!( + "{} must be numeric for density jitter", + continuous_col + )) + })? 
+ .into_iter() + .map(|v| v.unwrap_or(0.0)) + .collect(); + + // Build density grouping columns: discrete axis + relevant partition_by columns + // This matches how violin computes density per group + let mut density_group_cols = vec![discrete_col_name.clone()]; + for col in &layer.partition_by { + if density_group_cols.contains(col) { + continue; + } + // When dodge is false, only include facet variables (not color/fill groups) + // Facet variables have predictable names: __ggsql_aes_facet1__, __ggsql_aes_facet2__ + if !dodge && !col.contains("_facet") { + continue; + } + density_group_cols.push(col.clone()); + } + + // Compute density grouping + let density_group_info = compute_group_indices(df, &density_group_cols)?; + + // Compute density/intensity scales per group with global normalization + if let Some(info) = density_group_info { + Some(compute_grouped_scales( + &values, + &info.indices, + info.n_groups, + explicit_bandwidth, + adjust, + use_intensity, + )) + } else { + // Single group - compute global density/intensity + let bandwidth = explicit_bandwidth + .map(|bw| bw * adjust) + .unwrap_or_else(|| silverman_bandwidth(&values, adjust)); + let raw = if use_intensity { + compute_intensities(&values, bandwidth) + } else { + compute_densities(&values, bandwidth) + }; + // Normalize to [0, 1] + let max_val = raw.iter().fold(0.0_f64, |a, &b| a.max(b)); + if max_val > 0.0 { + Some(raw.iter().map(|v| v / max_val).collect()) + } else { + Some(vec![1.0; values.len()]) + } + } + } else { + None + }; + + let pos1offset_col = naming::aesthetic_column("pos1offset"); + let pos2offset_col = naming::aesthetic_column("pos2offset"); + + let mut result = df.clone().lazy(); + + // For 1D jitter (only one axis is discrete), use linear dodge layout + // For 2D jitter (both axes discrete), use grid layout like dodge does + if jitter_pos1 && jitter_pos2 && n_groups > 1 { + // 2D grid dodge + jitter + let grid_size = (n_groups as f64).sqrt().ceil() as usize; + let 
grid_adjusted_width = width / grid_size as f64; + let center_offset = (grid_size as f64 - 1.0) / 2.0; + + let indices = group_indices.unwrap(); + + // Pre-generate jitter values for both axes + let jitter1: Vec = (0..indices.len()) + .map(|_| distribution.sample(&mut rng, grid_adjusted_width)) + .collect(); + let jitter2: Vec = (0..indices.len()) + .map(|_| distribution.sample(&mut rng, grid_adjusted_width)) + .collect(); + + // Compute pos1 offsets: dodge center + jitter + let pos1_offsets: Vec = indices + .iter() + .zip(jitter1.iter()) + .map(|(&idx, &jitter)| { + let col_idx = idx % grid_size; + let dodge_center = (col_idx as f64 - center_offset) * grid_adjusted_width; + dodge_center + jitter + }) + .collect(); + + // Compute pos2 offsets: dodge center + jitter + let pos2_offsets: Vec = indices + .iter() + .zip(jitter2.iter()) + .map(|(&idx, &jitter)| { + let row_idx = idx / grid_size; + let dodge_center = (row_idx as f64 - center_offset) * grid_adjusted_width; + dodge_center + jitter + }) + .collect(); + + result = result.with_column( + lit(Series::new(pos1offset_col.clone().into(), pos1_offsets)).alias(&pos1offset_col), + ); + result = result.with_column( + lit(Series::new(pos2offset_col.clone().into(), pos2_offsets)).alias(&pos2offset_col), + ); + } else { + // 1D jitter (or no dodge) + let center_offset = (n_groups as f64 - 1.0) / 2.0; + + // Add pos1offset if pos1 is discrete + if jitter_pos1 { + let offsets: Vec = if let Some(indices) = group_indices { + // Pre-generate jitter values + let jitters: Vec = (0..indices.len()) + .map(|_| distribution.sample(&mut rng, effective_width)) + .collect(); + // Dodge + jitter: deterministic dodge center + random jitter (with density scaling) + indices + .iter() + .zip(jitters.iter()) + .enumerate() + .map(|(i, (&idx, &jitter))| { + let dodge_center = (idx as f64 - center_offset) * effective_width; + // Scale jitter by density if using density distribution + let scaled_jitter = if let Some(ref scales) = density_scales 
{ + jitter * scales[i] + } else { + jitter + }; + dodge_center + scaled_jitter + }) + .collect() + } else { + // Pure jitter (with density scaling) + (0..n_rows) + .map(|i| { + let jitter = distribution.sample(&mut rng, width); + if let Some(ref scales) = density_scales { + jitter * scales[i] + } else { + jitter + } + }) + .collect() + }; + result = result.with_column( + lit(Series::new(pos1offset_col.clone().into(), offsets)).alias(&pos1offset_col), + ); + } + + // Add pos2offset if pos2 is discrete + if jitter_pos2 { + let offsets: Vec = if let Some(indices) = group_indices { + // Pre-generate jitter values + let jitters: Vec = (0..indices.len()) + .map(|_| distribution.sample(&mut rng, effective_width)) + .collect(); + // Dodge + jitter: deterministic dodge center + random jitter (with density scaling) + indices + .iter() + .zip(jitters.iter()) + .enumerate() + .map(|(i, (&idx, &jitter))| { + let dodge_center = (idx as f64 - center_offset) * effective_width; + // Scale jitter by density if using density distribution + let scaled_jitter = if let Some(ref scales) = density_scales { + jitter * scales[i] + } else { + jitter + }; + dodge_center + scaled_jitter + }) + .collect() + } else { + // Pure jitter (with density scaling) + (0..n_rows) + .map(|i| { + let jitter = distribution.sample(&mut rng, width); + if let Some(ref scales) = density_scales { + jitter * scales[i] + } else { + jitter + } + }) + .collect() + }; + result = result.with_column( + lit(Series::new(pos2offset_col.clone().into(), offsets)).alias(&pos2offset_col), + ); + } + } + + result + .collect() + .map_err(|e| GgsqlError::InternalError(format!("Jitter position adjustment failed: {}", e))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::plot::layer::Geom; + use crate::plot::{AestheticValue, Mappings, Scale, ScaleType}; + + fn make_test_df() -> DataFrame { + df! 
{ + "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + } + .unwrap() + } + + fn make_test_layer() -> Layer { + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + layer + } + + fn make_continuous_scale(aesthetic: &str) -> Scale { + let mut scale = Scale::new(aesthetic); + scale.scale_type = Some(ScaleType::continuous()); + scale + } + + fn make_discrete_scale(aesthetic: &str) -> Scale { + let mut scale = Scale::new(aesthetic); + scale.scale_type = Some(ScaleType::discrete()); + scale + } + + #[test] + fn test_jitter_horizontal_only_with_dodge() { + // When pos1 is discrete and pos2 is continuous, only pos1offset is created + // With default dodge=true and 2 groups, offsets should be dodge + jitter + let jitter = Jitter; + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark pos1 as discrete and pos2 as continuous via scales + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, width) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify pos1offset column was created + assert!( + result.column("__ggsql_aes_pos1offset__").is_ok(), + "pos1offset column should be created" + ); + // Verify pos2offset column was NOT created + assert!( + result.column("__ggsql_aes_pos2offset__").is_err(), + "pos2offset column should NOT be created when pos2 is continuous" 
+ ); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // With default width 0.9 and 2 groups (dodge=true): + // effective_width = 0.9 / 2 = 0.45 + // Group X center: -0.225, Group Y center: +0.225 + // With jitter in range [-0.225, +0.225] around each center + // Total range: [-0.45, 0.45] + for &v in &offsets { + assert!( + v >= -0.45 && v <= 0.45, + "Jitter+dodge offset {} should be in range [-0.45, 0.45]", + v + ); + } + + // Verify no adjusted_width is returned for jitter + assert!(width.is_none()); + } + + #[test] + fn test_jitter_horizontal_no_dodge() { + // With dodge=false, should behave like classic jitter + let jitter = Jitter; + let df = make_test_df(); + let mut layer = make_test_layer(); + layer + .parameters + .insert("dodge".to_string(), ParameterValue::Boolean(false)); + + // Mark pos1 as discrete and pos2 as continuous via scales + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // With dodge=false and width 0.9, pure jitter in range [-0.45, 0.45] + for &v in &offsets { + assert!( + v >= -0.45 && v <= 0.45, + "Pure jitter offset {} should be in range [-0.45, 0.45]", + v + ); + } + } + + #[test] + fn test_jitter_vertical_only() { + // When pos1 is continuous and pos2 is discrete, only pos2offset is created + let jitter = Jitter; + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark pos1 as continuous and pos2 as discrete via scales + let mut spec = Plot::new(); + spec.scales.push(make_continuous_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + + let (result, _) = 
jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify pos1offset column was NOT created + assert!( + result.column("__ggsql_aes_pos1offset__").is_err(), + "pos1offset column should NOT be created when pos1 is continuous" + ); + // Verify pos2offset column was created + assert!( + result.column("__ggsql_aes_pos2offset__").is_ok(), + "pos2offset column should be created" + ); + + let offset = result + .column("__ggsql_aes_pos2offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // With default width 0.9 and 2 groups (dodge=true), effective range is [-0.45, 0.45] + for &v in &offsets { + assert!( + v >= -0.45 && v <= 0.45, + "Jitter+dodge offset {} should be in range [-0.45, 0.45]", + v + ); + } + } + + #[test] + fn test_jitter_bidirectional() { + // When both axes are discrete, both offset columns are created + let jitter = Jitter; + let df = make_test_df(); + let layer = make_test_layer(); + + // Both axes must be explicitly marked as discrete + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify both offset columns were created + assert!( + result.column("__ggsql_aes_pos1offset__").is_ok(), + "pos1offset column should be created" + ); + assert!( + result.column("__ggsql_aes_pos2offset__").is_ok(), + "pos2offset column should be created" + ); + } + + #[test] + fn test_jitter_neither_discrete() { + // When both axes are continuous, no offset columns are created + let jitter = Jitter; + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark both as continuous + let mut spec = Plot::new(); + spec.scales.push(make_continuous_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify neither offset column was created + 
assert!( + result.column("__ggsql_aes_pos1offset__").is_err(), + "pos1offset column should NOT be created when pos1 is continuous" + ); + assert!( + result.column("__ggsql_aes_pos2offset__").is_err(), + "pos2offset column should NOT be created when pos2 is continuous" + ); + } + + #[test] + fn test_jitter_custom_width_with_dodge() { + let jitter = Jitter; + + let df = make_test_df(); + let mut layer = make_test_layer(); + layer + .parameters + .insert("width".to_string(), ParameterValue::Number(0.6)); + + // Mark pos1 as discrete and pos2 as continuous so only pos1offset is created + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // With custom width 0.6 and 2 groups (dodge=true): + // effective_width = 0.6 / 2 = 0.3 + // Total range: [-0.3, 0.3] + for &v in &offsets { + assert!( + v >= -0.3 && v <= 0.3, + "Jitter+dodge offset {} should be in range [-0.3, 0.3] with width 0.6", + v + ); + } + } + + #[test] + fn test_jitter_groups_separate_with_dodge() { + // With dodge=true, different groups should have different center positions + let jitter = Jitter; + + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark pos1 as discrete and pos2 as continuous so only pos1offset is created + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let fill_col = result.column("__ggsql_aes_fill__").unwrap(); + + // Collect offsets by group + let mut group_x_offsets = vec![]; + let mut group_y_offsets = 
vec![]; + + for i in 0..result.height() { + let fill_val = fill_col.get(i).unwrap(); + let offset_val = offset.get(i).unwrap(); + if fill_val.to_string().contains("X") { + group_x_offsets.push(offset_val); + } else { + group_y_offsets.push(offset_val); + } + } + + // With dodge, group X should have negative-centered offsets + // and group Y should have positive-centered offsets + let x_mean: f64 = group_x_offsets.iter().sum::() / group_x_offsets.len() as f64; + let y_mean: f64 = group_y_offsets.iter().sum::() / group_y_offsets.len() as f64; + + // The means should be on opposite sides of 0 (X negative, Y positive) + // Allow some variance due to jitter randomness + assert!( + x_mean < y_mean, + "Group X mean ({}) should be less than Group Y mean ({})", + x_mean, + y_mean + ); + } + + #[test] + fn test_jitter_no_groups_no_dodge() { + // Without partition_by columns, no dodge is applied + let jitter = Jitter; + + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.partition_by = vec![]; // No grouping + + // Mark pos1 as discrete and pos2 as continuous so only pos1offset is created + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // Without groups, pure jitter with full width range [-0.45, 0.45] + for &v in &offsets { + assert!( + v >= -0.45 && v <= 0.45, + "Pure jitter offset {} should be in range [-0.45, 0.45]", + v + ); + } + } + + #[test] + fn test_jitter_creates_pos1offset() { + assert!(Jitter.creates_pos1offset()); + } + + #[test] + fn test_jitter_creates_pos2offset() { + assert!(Jitter.creates_pos2offset()); + } + + #[test] + fn test_jitter_default_params() { + let jitter = Jitter; + let params = jitter.default_params(); + 
assert_eq!(params.len(), 5); + assert_eq!(params[0].name, "width"); + assert!(matches!(params[0].default, DefaultParamValue::Number(0.9))); + assert_eq!(params[1].name, "dodge"); + assert!(matches!( + params[1].default, + DefaultParamValue::Boolean(true) + )); + assert_eq!(params[2].name, "distribution"); + assert!(matches!( + params[2].default, + DefaultParamValue::String("uniform") + )); + // Density distribution parameters (match violin/density geoms) + assert_eq!(params[3].name, "bandwidth"); + assert!(matches!(params[3].default, DefaultParamValue::Null)); + assert_eq!(params[4].name, "adjust"); + assert!(matches!(params[4].default, DefaultParamValue::Number(1.0))); + } + + #[test] + fn test_jitter_normal_distribution() { + // Normal distribution should have ~95% of values within the width + let jitter = Jitter; + + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.partition_by = vec![]; // No grouping for pure jitter + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("normal".to_string()), + ); + + // Mark pos1 as discrete and pos2 as continuous so only pos1offset is created + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // Normal distribution is centered at 0 + // Values can exceed the width bounds (unlike uniform), but should be centered + let mean: f64 = offsets.iter().sum::() / offsets.len() as f64; + assert!( + mean.abs() < 0.3, // Should be roughly centered (with 4 values, some variance expected) + "Normal distribution mean {} should be close to 0", + mean + ); + } + + #[test] + fn test_jitter_distribution_from_str() { + assert_eq!( + JitterDistribution::from_str("uniform"), + 
JitterDistribution::Uniform + ); + assert_eq!( + JitterDistribution::from_str("normal"), + JitterDistribution::Normal + ); + assert_eq!( + JitterDistribution::from_str("gaussian"), + JitterDistribution::Normal + ); + assert_eq!( + JitterDistribution::from_str("density"), + JitterDistribution::Density + ); + assert_eq!( + JitterDistribution::from_str("DENSITY"), + JitterDistribution::Density + ); + assert_eq!( + JitterDistribution::from_str("NORMAL"), + JitterDistribution::Normal + ); + assert_eq!( + JitterDistribution::from_str("intensity"), + JitterDistribution::Intensity + ); + assert_eq!( + JitterDistribution::from_str("INTENSITY"), + JitterDistribution::Intensity + ); + assert_eq!( + JitterDistribution::from_str("unknown"), + JitterDistribution::Uniform + ); + } + + #[test] + fn test_jitter_density_requires_one_continuous_axis() { + // Density distribution requires exactly one continuous axis + let jitter = Jitter; + + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.partition_by = vec![]; // No grouping + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("density".to_string()), + ); + + // Test 1: Both axes discrete - should fail + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + let result = jitter.apply_adjustment(&df, &layer, &spec); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("requires exactly one continuous axis")); + + // Test 2: Both axes continuous - should fail + let mut spec = Plot::new(); + spec.scales.push(make_continuous_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + let result = jitter.apply_adjustment(&df, &layer, &spec); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("requires exactly one continuous axis")); + + // Test 3: Only pos2 continuous (pos1 discrete) - should succeed + let mut spec = Plot::new(); + 
spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + let result = jitter.apply_adjustment(&df, &layer, &spec); + assert!(result.is_ok()); + } + + #[test] + fn test_jitter_density_distribution() { + // Density distribution should create violin-like spread + // Points in dense regions should have larger jitter amplitude (due to density scaling) + let jitter = Jitter; + + // Create data with clear density peaks + // Values 1.0 appears 5 times, values 2.0 and 3.0 appear once each + let df = df! { + "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m + }; + layer.partition_by = vec![]; + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("density".to_string()), + ); + + // Mark pos1 as discrete and pos2 as continuous (density computed along pos2) + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // Due to randomness, we can't assert exact values + // But we can verify that offsets were generated + assert_eq!(offsets.len(), 7); + } + + #[test] + fn test_jitter_density_per_group() { + // When groups exist, density should be computed separately per group + let jitter = 
Jitter; + + // Create data with two groups, each with different density distributions + // Group X: dense at 1.0 + // Group Y: dense at 3.0 + let df = df! { + "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 3.0, 3.0, 3.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => ["X", "X", "X", "Y", "Y", "Y"], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("density".to_string()), + ); + + // Mark pos1 as discrete and pos2 as continuous + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify offsets were created and are within expected bounds + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + assert_eq!(offsets.len(), 6); + + // With 2 groups, we should see separated dodge positions + // Group X centered at negative, Group Y centered at positive + let fill_col = result.column("__ggsql_aes_fill__").unwrap(); + let mut group_x_offsets = vec![]; + let mut group_y_offsets = vec![]; + + for i in 0..result.height() { + let fill_val = fill_col.get(i).unwrap(); + let offset_val = offset.get(i).unwrap(); + if fill_val.to_string().contains("X") { + 
group_x_offsets.push(offset_val); + } else { + group_y_offsets.push(offset_val); + } + } + + // Groups should be separated due to dodge + let x_mean: f64 = group_x_offsets.iter().sum::() / group_x_offsets.len() as f64; + let y_mean: f64 = group_y_offsets.iter().sum::() / group_y_offsets.len() as f64; + assert!( + x_mean < y_mean, + "Group X mean ({}) should be less than Group Y mean ({})", + x_mean, + y_mean + ); + } + + #[test] + fn test_silverman_bandwidth() { + // Test Silverman bandwidth computation with default adjust=1.0 + let values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let bandwidth = super::silverman_bandwidth(&values, 1.0); + // Should return a positive value + assert!(bandwidth > 0.0); + + // Constant data should return fallback + let constant = vec![5.0, 5.0, 5.0, 5.0, 5.0]; + let bandwidth = super::silverman_bandwidth(&constant, 1.0); + assert_eq!(bandwidth, 1.0); + + // Single value should return fallback + let single = vec![5.0]; + let bandwidth = super::silverman_bandwidth(&single, 1.0); + assert_eq!(bandwidth, 1.0); + + // Test adjust parameter + let values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let bw_default = super::silverman_bandwidth(&values, 1.0); + let bw_double = super::silverman_bandwidth(&values, 2.0); + assert!( + (bw_double - bw_default * 2.0).abs() < 1e-10, + "Bandwidth with adjust=2.0 should be twice the default" + ); + } + + #[test] + fn test_compute_densities() { + // Test that densities are computed correctly + let values = vec![0.0, 0.0, 0.0, 5.0, 10.0]; + let bandwidth = 1.0; + let densities = super::compute_densities(&values, bandwidth); + + // Values near 0.0 (3 points) should have higher density than value at 10.0 + assert!(densities[0] > densities[4]); + assert!(densities[1] > densities[4]); + assert!(densities[2] > densities[4]); + } + + #[test] + fn test_compute_intensities() { + // Test that intensities differ from densities by not dividing by n + let values = vec![1.0, 1.0, 1.0, 5.0, 10.0]; + let bandwidth = 1.0; + let densities 
= super::compute_densities(&values, bandwidth); + let intensities = super::compute_intensities(&values, bandwidth); + + // Intensities should be n times larger than densities + let n = values.len() as f64; + for (d, i) in densities.iter().zip(intensities.iter()) { + assert!( + (i - d * n).abs() < 1e-10, + "Intensity {} should be {} times density {}", + i, + n, + d + ); + } + } + + #[test] + fn test_jitter_intensity_requires_one_continuous_axis() { + // Intensity distribution requires exactly one continuous axis + let jitter = Jitter; + + let df = make_test_df(); + let mut layer = make_test_layer(); + layer.partition_by = vec![]; // No grouping + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("intensity".to_string()), + ); + + // Both axes discrete - should fail + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + let result = jitter.apply_adjustment(&df, &layer, &spec); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("requires exactly one continuous axis")); + + // Only pos2 continuous (pos1 discrete) - should succeed + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + let result = jitter.apply_adjustment(&df, &layer, &spec); + assert!(result.is_ok()); + } + + #[test] + fn test_jitter_intensity_distribution() { + // Intensity distribution should create violin-like spread + let jitter = Jitter; + + let df = df! 
{ + "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m + }; + layer.partition_by = vec![]; + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("intensity".to_string()), + ); + + // Mark pos1 as discrete and pos2 as continuous (density computed along pos2) + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + + // Due to randomness, we can't assert exact values + // But we can verify that offsets were generated + assert_eq!(offsets.len(), 7); + } + + #[test] + fn test_jitter_intensity_global_normalization() { + // Test that intensity uses global max normalization across groups + // Group A has 5 points, Group B has 2 points + // With intensity distribution, Group A should have larger scales (more data) + let jitter = Jitter; + + let df = df! 
{ + "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "B", "B"], + "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m + }; + layer.partition_by = vec![]; + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("intensity".to_string()), + ); + // Use dodge=false to avoid group separation + layer + .parameters + .insert("dodge".to_string(), ParameterValue::Boolean(false)); + + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = jitter.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Verify offsets were created + let offset = result + .column("__ggsql_aes_pos1offset__") + .unwrap() + .f64() + .unwrap(); + let offsets: Vec = offset.into_iter().filter_map(|v| v).collect(); + assert_eq!(offsets.len(), 7); + } + + #[test] + fn test_jitter_density_explicit_bandwidth() { + // Test that explicit bandwidth parameter is used + let jitter = Jitter; + + let df = df! 
{ + "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m + }; + layer.partition_by = vec![]; + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("density".to_string()), + ); + // Set explicit bandwidth matching what violin might use + layer + .parameters + .insert("bandwidth".to_string(), ParameterValue::Number(0.5)); + + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let result = jitter.apply_adjustment(&df, &layer, &spec); + assert!(result.is_ok(), "Should succeed with explicit bandwidth"); + } + + #[test] + fn test_jitter_density_adjust_parameter() { + // Test that adjust parameter scales bandwidth + let jitter = Jitter; + + let df = df! 
{ + "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m + }; + layer.partition_by = vec![]; + layer.parameters.insert( + "distribution".to_string(), + ParameterValue::String("density".to_string()), + ); + // Set adjust parameter (scales auto-computed bandwidth) + layer + .parameters + .insert("adjust".to_string(), ParameterValue::Number(2.0)); + + let mut spec = Plot::new(); + spec.scales.push(make_discrete_scale("pos1")); + spec.scales.push(make_continuous_scale("pos2")); + + let result = jitter.apply_adjustment(&df, &layer, &spec); + assert!(result.is_ok(), "Should succeed with adjust parameter"); + } + + #[test] + fn test_quantile_cont() { + // Test quantile interpolation + let sorted = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + + // Exact quartiles + let q0 = super::quantile_cont(&sorted, 0.0); + assert!((q0 - 1.0).abs() < 1e-10); + + let q1 = super::quantile_cont(&sorted, 1.0); + assert!((q1 - 5.0).abs() < 1e-10); + + // Median + let q50 = super::quantile_cont(&sorted, 0.5); + assert!((q50 - 3.0).abs() < 1e-10); + + // Interpolated values + let q25 = super::quantile_cont(&sorted, 0.25); + assert!((q25 - 2.0).abs() < 1e-10); + + let q75 = super::quantile_cont(&sorted, 0.75); + assert!((q75 - 4.0).abs() < 1e-10); + } +} diff --git a/src/plot/layer/position/mod.rs b/src/plot/layer/position/mod.rs new file mode 100644 index 00000000..182f5672 --- /dev/null +++ b/src/plot/layer/position/mod.rs @@ -0,0 +1,352 @@ +//! Position adjustment trait and implementations +//! +//! 
This module provides a trait-based design for position adjustments in ggsql. +//! Each position type is implemented as its own struct, mirroring the geom pattern. +//! +//! # Architecture +//! +//! - `PositionType`: Enum for pattern matching and serialization +//! - `PositionTrait`: Trait defining position adjustment behavior +//! - `Position`: Wrapper struct holding a boxed trait object + +mod dodge; +mod identity; +mod jitter; +mod stack; + +use crate::plot::types::{DefaultParam, DefaultParamValue, ParameterValue}; +use crate::plot::ScaleTypeKind; +use crate::{DataFrame, Plot, Result}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Check if an aesthetic has a continuous scale type. +/// Returns None if no scale is defined (defer to data type). +/// This is the shared helper used by position adjustments. +pub fn is_continuous_scale(spec: &Plot, aesthetic: &str) -> Option { + spec.scales + .iter() + .find(|s| s.aesthetic == aesthetic) + .and_then(|s| s.scale_type.as_ref()) + .map(|st| st.scale_type_kind() == ScaleTypeKind::Continuous) +} + +// Re-export position implementations +pub use dodge::{compute_group_indices, Dodge, GroupIndices}; +pub use identity::Identity; +pub use jitter::Jitter; +pub use stack::Stack; + +use super::Layer; + +/// Enum of all position types for pattern matching and serialization +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum PositionType { + Identity, + Stack, + Dodge, + Jitter, +} + +impl std::fmt::Display for PositionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + PositionType::Identity => "identity", + PositionType::Stack => "stack", + PositionType::Dodge => "dodge", + PositionType::Jitter => "jitter", + }; + write!(f, "{}", s) + } +} + +/// Core trait for position adjustment behavior +/// +/// Each position type implements this trait. 
Most methods have sensible defaults; +/// only `position_type()` and `apply_adjustment()` are typically required. +pub trait PositionTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { + /// Returns which position type this is (for pattern matching) + fn position_type(&self) -> PositionType; + + /// Returns default parameter values for this position + fn default_params(&self) -> &'static [DefaultParam] { + &[] + } + + /// Returns valid parameter names for SETTING validation + fn valid_settings(&self) -> Vec<&'static str> { + self.default_params().iter().map(|p| p.name).collect() + } + + /// Whether this position creates a pos1offset column + fn creates_pos1offset(&self) -> bool { + false + } + + /// Whether this position creates a pos2offset column + fn creates_pos2offset(&self) -> bool { + false + } + + /// Apply the position adjustment to the DataFrame + /// + /// Returns the adjusted DataFrame and optionally an adjusted width + /// (for position types like dodge that modify element width) + fn apply_adjustment( + &self, + df: &DataFrame, + layer: &Layer, + spec: &Plot, + ) -> Result<(DataFrame, Option)>; +} + +/// Wrapper struct for position trait objects +/// +/// This provides a convenient interface for working with positions while hiding +/// the complexity of trait objects. 
+#[derive(Clone)] +pub struct Position(Arc); + +impl Position { + /// Create an Identity position (no adjustment) + pub fn identity() -> Self { + Self(Arc::new(Identity)) + } + + /// Create a Stack position + pub fn stack() -> Self { + Self(Arc::new(Stack)) + } + + /// Create a Dodge position + pub fn dodge() -> Self { + Self(Arc::new(Dodge)) + } + + /// Create a Jitter position + pub fn jitter() -> Self { + Self(Arc::new(Jitter)) + } + + /// Parse a position from a string value + pub fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "stack" => Self::stack(), + "dodge" => Self::dodge(), + "jitter" => Self::jitter(), + _ => Self::identity(), + } + } + + /// Create a Position from a PositionType + pub fn from_type(t: PositionType) -> Self { + match t { + PositionType::Identity => Self::identity(), + PositionType::Stack => Self::stack(), + PositionType::Dodge => Self::dodge(), + PositionType::Jitter => Self::jitter(), + } + } + + /// Get the position type + pub fn position_type(&self) -> PositionType { + self.0.position_type() + } + + /// Get default parameters + pub fn default_params(&self) -> &'static [DefaultParam] { + self.0.default_params() + } + + /// Get valid settings for SETTING validation + pub fn valid_settings(&self) -> Vec<&'static str> { + self.0.valid_settings() + } + + /// Check if this position creates a pos1offset column + pub fn creates_pos1offset(&self) -> bool { + self.0.creates_pos1offset() + } + + /// Check if this position creates a pos2offset column + pub fn creates_pos2offset(&self) -> bool { + self.0.creates_pos2offset() + } + + /// Apply the position adjustment + pub fn apply_adjustment( + &self, + df: &DataFrame, + layer: &Layer, + spec: &Plot, + ) -> Result<(DataFrame, Option)> { + self.0.apply_adjustment(df, layer, spec) + } + + /// Apply default position parameter values to a layer + /// + /// For each parameter defined in default_params(), if the layer doesn't + /// already have that parameter set, insert the default 
value. + pub fn apply_defaults_to_layer(&self, layer: &mut Layer) { + for param in self.default_params() { + if !layer.parameters.contains_key(param.name) { + let value = match ¶m.default { + DefaultParamValue::String(s) => ParameterValue::String(s.to_string()), + DefaultParamValue::Number(n) => ParameterValue::Number(*n), + DefaultParamValue::Boolean(b) => ParameterValue::Boolean(*b), + DefaultParamValue::Null => continue, + }; + layer.parameters.insert(param.name.to_string(), value); + } + } + } +} + +impl Default for Position { + fn default() -> Self { + Self::identity() + } +} + +impl std::fmt::Debug for Position { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Position::{:?}", self.position_type()) + } +} + +impl std::fmt::Display for Position { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl PartialEq for Position { + fn eq(&self, other: &Self) -> bool { + self.position_type() == other.position_type() + } +} + +impl Eq for Position {} + +impl Serialize for Position { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + self.position_type().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Position { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let position_type = PositionType::deserialize(deserializer)?; + Ok(Position::from_type(position_type)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_position_creation() { + let identity = Position::identity(); + assert_eq!(identity.position_type(), PositionType::Identity); + + let stack = Position::stack(); + assert_eq!(stack.position_type(), PositionType::Stack); + + let dodge = Position::dodge(); + assert_eq!(dodge.position_type(), PositionType::Dodge); + } + + #[test] + fn test_position_equality() { + let p1 = Position::identity(); + let p2 = Position::identity(); + let p3 = 
Position::stack(); + + assert_eq!(p1, p2); + assert_ne!(p1, p3); + } + + #[test] + fn test_position_display() { + assert_eq!(format!("{}", Position::identity()), "identity"); + assert_eq!(format!("{}", Position::stack()), "stack"); + assert_eq!(format!("{}", Position::dodge()), "dodge"); + } + + #[test] + fn test_position_from_str() { + assert_eq!( + Position::from_str("stack").position_type(), + PositionType::Stack + ); + assert_eq!( + Position::from_str("dodge").position_type(), + PositionType::Dodge + ); + assert_eq!( + Position::from_str("jitter").position_type(), + PositionType::Jitter + ); + assert_eq!( + Position::from_str("unknown").position_type(), + PositionType::Identity + ); + } + + #[test] + fn test_position_serialization() { + let dodge = Position::dodge(); + let json = serde_json::to_string(&dodge).unwrap(); + assert_eq!(json, "\"dodge\""); + + let deserialized: Position = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.position_type(), PositionType::Dodge); + } + + #[test] + fn test_creates_pos1offset() { + assert!(!Position::identity().creates_pos1offset()); + assert!(!Position::stack().creates_pos1offset()); + assert!(Position::dodge().creates_pos1offset()); + assert!(Position::jitter().creates_pos1offset()); + } + + #[test] + fn test_creates_pos2offset() { + assert!(!Position::identity().creates_pos2offset()); + assert!(!Position::stack().creates_pos2offset()); + assert!(Position::dodge().creates_pos2offset()); // Dodge now supports vertical/2D + assert!(Position::jitter().creates_pos2offset()); + } + + #[test] + fn test_is_continuous_scale() { + use crate::plot::{Scale, ScaleType}; + + // No scale defined - returns None + let spec = crate::plot::Plot::new(); + assert!(is_continuous_scale(&spec, "pos1").is_none()); + + // Continuous scale defined + let mut spec = crate::plot::Plot::new(); + let mut scale = Scale::new("pos1"); + scale.scale_type = Some(ScaleType::continuous()); + spec.scales.push(scale); + 
assert_eq!(is_continuous_scale(&spec, "pos1"), Some(true)); + + // Discrete scale defined + let mut spec = crate::plot::Plot::new(); + let mut scale = Scale::new("pos1"); + scale.scale_type = Some(ScaleType::discrete()); + spec.scales.push(scale); + assert_eq!(is_continuous_scale(&spec, "pos1"), Some(false)); + } +} diff --git a/src/plot/layer/position/stack.rs b/src/plot/layer/position/stack.rs new file mode 100644 index 00000000..17b91128 --- /dev/null +++ b/src/plot/layer/position/stack.rs @@ -0,0 +1,765 @@ +//! Stack position adjustments +//! +//! Implements stacking of elements: normal, fill (normalized), and center. +//! +//! Stacking automatically detects which axis is continuous and stacks accordingly: +//! - If pos2 is continuous → stack vertically (modify pos2/pos2end) +//! - If pos1 is continuous and pos2 is discrete → stack horizontally (modify pos1/pos1end) + +use super::{is_continuous_scale, Layer, PositionTrait, PositionType}; +use crate::plot::types::{DefaultParam, DefaultParamValue, ParameterValue}; +use crate::{naming, DataFrame, GgsqlError, Plot, Result}; +use polars::prelude::*; + +/// Stack mode for position adjustments +#[derive(Clone, Copy)] +enum StackMode { + /// Normal stacking (cumsum from 0) + Normal, + /// Normalized stacking (cumsum / total, then scaled to target) + Fill(f64), + /// Centered stacking (cumsum - total/2, centered at 0) + Center, +} + +/// Stack position - stack elements vertically +#[derive(Debug, Clone, Copy)] +pub struct Stack; + +impl std::fmt::Display for Stack { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "stack") + } +} + +impl PositionTrait for Stack { + fn position_type(&self) -> PositionType { + PositionType::Stack + } + + fn default_params(&self) -> &'static [DefaultParam] { + &[ + DefaultParam { + name: "center", + default: DefaultParamValue::Boolean(false), + }, + DefaultParam { + name: "total", + default: DefaultParamValue::Null, + }, + ] + } + + fn apply_adjustment( + 
&self, + df: &DataFrame, + layer: &Layer, + spec: &Plot, + ) -> Result<(DataFrame, Option)> { + let center = layer + .parameters + .get("center") + .and_then(|v| match v { + ParameterValue::Boolean(b) => Some(*b), + _ => None, + }) + .unwrap_or(false); + + let total = layer.parameters.get("total").and_then(|v| match v { + ParameterValue::Number(n) => Some(*n), + _ => None, + }); + + let mode = match (center, total) { + (true, _) => StackMode::Center, + (false, Some(target)) => StackMode::Fill(target), + (false, None) => StackMode::Normal, + }; + Ok((apply_stack(df, layer, spec, mode)?, None)) + } +} + +/// Direction for stacking +#[derive(Clone, Copy)] +enum StackDirection { + /// Stack vertically (modify pos2/pos2end, group by pos1) + Vertical, + /// Stack horizontally (modify pos1/pos1end, group by pos2) + Horizontal, +} + +/// Check if an axis is stackable. +/// +/// An axis is stackable if: +/// 1. It has a continuous scale (scale type is always known after infer_scale_types_from_data) +/// 2. It has a pos/posend pair (e.g., pos2/pos2end) or posmin/posmax pair +/// 3. 
Every row has a zero baseline in one of the range columns +fn is_axis_stackable(spec: &Plot, layer: &Layer, df: &DataFrame, axis: &str) -> bool { + // Must be continuous (scale type always known after infer_scale_types_from_data) + if is_continuous_scale(spec, axis) != Some(true) { + return false; + } + + // Check for pos/posend pair (e.g., pos2/pos2end) + let end_aesthetic = format!("{}end", axis); + let has_end_pair = + layer.mappings.contains_key(axis) && layer.mappings.contains_key(&end_aesthetic); + + // Check for posmin/posmax pair (e.g., pos2min/pos2max) + let min_aesthetic = format!("{}min", axis); + let max_aesthetic = format!("{}max", axis); + let has_minmax_pair = + layer.mappings.contains_key(&min_aesthetic) && layer.mappings.contains_key(&max_aesthetic); + + if !has_end_pair && !has_minmax_pair { + return false; + } + + // Check that each row has zero baseline in one of the range columns + if has_end_pair { + let pos_col = naming::aesthetic_column(axis); + let end_col = naming::aesthetic_column(&end_aesthetic); + if has_zero_baseline_per_row(df, &pos_col, &end_col) { + return true; + } + } + if has_minmax_pair { + let min_col = naming::aesthetic_column(&min_aesthetic); + let max_col = naming::aesthetic_column(&max_aesthetic); + if has_zero_baseline_per_row(df, &min_col, &max_col) { + return true; + } + } + false +} + +/// Check that for every row, at least one of the two columns is zero. +fn has_zero_baseline_per_row(df: &DataFrame, col_a: &str, col_b: &str) -> bool { + let (Ok(a), Ok(b)) = (df.column(col_a), df.column(col_b)) else { + return false; + }; + let (Ok(a_f64), Ok(b_f64)) = (a.f64(), b.f64()) else { + return false; + }; + // For each row, either a or b must be 0 + a_f64 + .into_iter() + .zip(b_f64.into_iter()) + .all(|(a_val, b_val)| a_val == Some(0.0) || b_val == Some(0.0)) +} + +/// Determine stacking direction based on scale types and axis configuration. 
+/// +/// An axis is stackable if it's continuous AND has pos/posend or posmin/posmax pairs +/// AND has zero baseline per row. +/// +/// Returns: +/// - Vertical if pos2 is stackable and pos1 is not +/// - Horizontal if pos1 is stackable and pos2 is not +/// - Vertical as default (for backward compatibility) +fn determine_stack_direction(spec: &Plot, layer: &Layer, df: &DataFrame) -> Option { + let pos1_stackable = is_axis_stackable(spec, layer, df, "pos1"); + let pos2_stackable = is_axis_stackable(spec, layer, df, "pos2"); + + match (pos1_stackable, pos2_stackable) { + (false, true) => Some(StackDirection::Vertical), + (true, false) => Some(StackDirection::Horizontal), + _ => Some(StackDirection::Vertical), // Default + } +} + +/// Apply stack position adjustment. +/// +/// Automatically detects stacking direction based on scale types: +/// - Vertical stacking: for each unique pos1 value, compute cumulative sums of pos2 +/// - Horizontal stacking: for each unique pos2 value, compute cumulative sums of pos1 +/// +/// Modes: +/// - Normal: standard stacking from 0 +/// - Fill: normalized to sum to 1 (100% stacked) +/// - Center: centered around 0 (streamgraph style) +fn apply_stack(df: &DataFrame, layer: &Layer, spec: &Plot, mode: StackMode) -> Result { + // Determine stacking direction + let direction = determine_stack_direction(spec, layer, df).unwrap_or(StackDirection::Vertical); + + // Set up column names based on direction + let (stack_col, stack_end_col, group_col) = match direction { + StackDirection::Vertical => ( + naming::aesthetic_column("pos2"), + naming::aesthetic_column("pos2end"), + naming::aesthetic_column("pos1"), + ), + StackDirection::Horizontal => ( + naming::aesthetic_column("pos1"), + naming::aesthetic_column("pos1end"), + naming::aesthetic_column("pos2"), + ), + }; + + // Check if required columns exist + if df.column(&stack_col).is_err() { + return Ok(df.clone()); + } + + // Stacking currently only supports non-negative values + let 
min_result = df + .clone() + .lazy() + .select([col(&stack_col).min()]) + .collect() + .map_err(|e| GgsqlError::InternalError(format!("Failed to check min value: {}", e)))?; + + if let Some(min_col) = min_result.get_columns().first() { + if let Ok(min_val) = min_col.get(0) { + if let Ok(min) = min_val.try_extract::() { + if min < 0.0 { + let axis = match direction { + StackDirection::Vertical => "y", + StackDirection::Horizontal => "x", + }; + return Err(GgsqlError::ValidationError(format!( + "position 'stack' requires non-negative {} values", + axis + ))); + } + } + } + } + + // Convert to lazy for transformations + let lf = df.clone().lazy(); + + // Sort by group column and partition_by columns to ensure consistent stacking order + // This ensures that within each group (e.g., x position), the stacking order is + // consistent even if data arrives in different orders or has missing values + let mut sort_cols = vec![col(&group_col)]; + for partition_col in &layer.partition_by { + sort_cols.push(col(partition_col)); + } + let sort_options = SortMultipleOptions::default(); + let lf = lf.sort_by_exprs(&sort_cols, sort_options); + + // For stacking, compute cumulative sums within each group: + // 1. stack_col = cumulative sum (the bar top/end) + // 2. 
stack_end_col = lag(stack_col, 1, 0) - the bar bottom/start (previous stack top) + // The cumsum naturally stacks across the grouping column values + + // Treat NA heights as 0 for stacking + let lf = lf.with_column(col(&stack_col).fill_null(lit(0.0)).alias(&stack_col)); + + match mode { + StackMode::Normal => { + let stack_expr = col(&stack_col) + .cum_sum(false) + .over([col(&group_col)]) + .alias(&stack_col); + + let stack_end_expr = col(&stack_col) + .cum_sum(false) + .shift(lit(1)) + .fill_null(lit(0.0)) + .over([col(&group_col)]) + .alias(&stack_end_col); + + lf.with_columns([stack_expr, stack_end_expr]) + .collect() + .map_err(|e| { + GgsqlError::InternalError(format!("Stack position adjustment failed: {}", e)) + }) + } + StackMode::Fill(target) => { + // Normalize by total sum within each group, then scale to target + let lf = lf + .with_column( + col(&stack_col) + .sum() + .over([col(&group_col)]) + .alias("__total__"), + ) + .with_column( + col(&stack_col) + .cum_sum(false) + .over([col(&group_col)]) + .alias("__cumsum__"), + ) + .with_column( + col(&stack_col) + .cum_sum(false) + .shift(lit(1)) + .fill_null(lit(0.0)) + .over([col(&group_col)]) + .alias("__cumsum_lag__"), + ); + + let stack_expr = (col("__cumsum__") / col("__total__") * lit(target)).alias(&stack_col); + let stack_end_expr = + (col("__cumsum_lag__") / col("__total__") * lit(target)).alias(&stack_end_col); + + let result = lf + .with_columns([stack_expr, stack_end_expr]) + .collect() + .map_err(|e| { + GgsqlError::InternalError(format!("Stack position adjustment failed: {}", e)) + })?; + + result + .drop("__total__") + .and_then(|df| df.drop("__cumsum__")) + .and_then(|df| df.drop("__cumsum_lag__")) + .map_err(|e| { + GgsqlError::InternalError(format!("Failed to drop temp column: {}", e)) + }) + } + StackMode::Center => { + // Center around 0 by subtracting half the total + let lf = lf + .with_column( + (col(&stack_col).sum() / lit(2.0)) + .over([col(&group_col)]) + 
.alias("__half_total__"), + ) + .with_column( + col(&stack_col) + .cum_sum(false) + .over([col(&group_col)]) + .alias("__cumsum__"), + ) + .with_column( + col(&stack_col) + .cum_sum(false) + .shift(lit(1)) + .fill_null(lit(0.0)) + .over([col(&group_col)]) + .alias("__cumsum_lag__"), + ); + + let stack_expr = (col("__cumsum__") - col("__half_total__")).alias(&stack_col); + let stack_end_expr = + (col("__cumsum_lag__") - col("__half_total__")).alias(&stack_end_col); + + let result = lf + .with_columns([stack_expr, stack_end_expr]) + .collect() + .map_err(|e| { + GgsqlError::InternalError(format!("Stack position adjustment failed: {}", e)) + })?; + + result + .drop("__half_total__") + .and_then(|df| df.drop("__cumsum__")) + .and_then(|df| df.drop("__cumsum_lag__")) + .map_err(|e| { + GgsqlError::InternalError(format!("Failed to drop temp column: {}", e)) + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::plot::layer::Geom; + use crate::plot::{AestheticValue, Mappings}; + + fn make_test_df() -> DataFrame { + df! 
{ + "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + } + .unwrap() + } + + fn make_test_layer() -> Layer { + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + layer + } + + #[test] + fn test_stack_cumsum() { + let stack = Stack; + assert_eq!(stack.position_type(), PositionType::Stack); + + let df = make_test_df(); + let layer = make_test_layer(); + let spec = Plot::new(); + + let (result, width) = stack.apply_adjustment(&df, &layer, &spec).unwrap(); + + assert!(width.is_none()); + let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); + let pos2end_col = result.column("__ggsql_aes_pos2end__").unwrap(); + + assert!(pos2_col.f64().is_ok() || pos2_col.i64().is_ok()); + assert!(pos2end_col.f64().is_ok() || pos2end_col.i64().is_ok()); + } + + #[test] + fn test_stack_default_params() { + let stack = Stack; + let params = stack.default_params(); + assert_eq!(params.len(), 2); + assert_eq!(params[0].name, "center"); + assert!(matches!( + params[0].default, + DefaultParamValue::Boolean(false) + )); + assert_eq!(params[1].name, "total"); + assert!(matches!(params[1].default, DefaultParamValue::Null)); + } + + #[test] + fn test_stack_center_parameter() { + let stack = Stack; + let df = make_test_df(); + let spec = Plot::new(); + + // Test default (center = false) - should stack from 0 + let layer = make_test_layer(); + let (result_normal, _) = stack.apply_adjustment(&df, 
&layer, &spec).unwrap(); + + // Test with center = true - should center around 0 + let mut layer_centered = make_test_layer(); + layer_centered + .parameters + .insert("center".to_string(), ParameterValue::Boolean(true)); + let (result_centered, _) = stack.apply_adjustment(&df, &layer_centered, &spec).unwrap(); + + // Normal stacking should have pos2end starting at 0 + let pos2end_normal = result_normal.column("__ggsql_aes_pos2end__").unwrap(); + let first_normal = pos2end_normal.get(0).unwrap(); + // First element's pos2end should be 0 for normal stack + if let polars::prelude::AnyValue::Float64(v) = first_normal { + assert_eq!(v, 0.0); + } + + // Centered stacking should have negative values + let pos2end_centered = result_centered.column("__ggsql_aes_pos2end__").unwrap(); + let first_centered = pos2end_centered.get(0).unwrap(); + // First element's pos2end should be negative for centered stack (shifted by -total/2) + if let polars::prelude::AnyValue::Float64(v) = first_centered { + assert!( + v < 0.0, + "Centered stack should have negative pos2end for first element" + ); + } + } + + fn make_continuous_scale(aesthetic: &str) -> crate::plot::Scale { + let mut scale = crate::plot::Scale::new(aesthetic); + scale.scale_type = Some(crate::plot::ScaleType::continuous()); + scale + } + + fn make_discrete_scale(aesthetic: &str) -> crate::plot::Scale { + let mut scale = crate::plot::Scale::new(aesthetic); + scale.scale_type = Some(crate::plot::ScaleType::discrete()); + scale + } + + #[test] + fn test_stack_vertical_when_pos2_continuous() { + // Default case: pos2 continuous -> stack vertically + let stack = Stack; + let df = make_test_df(); + let layer = make_test_layer(); + + // Mark pos2 as continuous + let mut spec = Plot::new(); + spec.scales.push(make_continuous_scale("pos2")); + + let (result, _) = stack.apply_adjustment(&df, &layer, &spec).unwrap(); + + // pos2 should be modified (stacked) + assert!(result.column("__ggsql_aes_pos2__").is_ok()); + 
assert!(result.column("__ggsql_aes_pos2end__").is_ok()); + } + + #[test] + fn test_stack_horizontal_when_pos1_continuous() { + // When pos1 is continuous and pos2 is discrete -> stack horizontally + let stack = Stack; + + // Create data with numeric pos1 values and pos1end column with zero baselines + let df = df! { + "__ggsql_aes_pos1__" => [10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos1end__" => [0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_pos2__" => ["A", "A", "B", "B"], + "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos1end", + AestheticValue::standard_column("__ggsql_aes_pos1end__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + + // Mark pos1 as continuous, pos2 as discrete + let mut spec = Plot::new(); + spec.scales.push(make_continuous_scale("pos1")); + spec.scales.push(make_discrete_scale("pos2")); + + let (result, _) = stack.apply_adjustment(&df, &layer, &spec).unwrap(); + + // pos1 should be modified (stacked horizontally) + assert!( + result.column("__ggsql_aes_pos1__").is_ok(), + "pos1 column should exist" + ); + assert!( + result.column("__ggsql_aes_pos1end__").is_ok(), + "pos1end column should be created for horizontal stacking" + ); + + // Verify stacking occurred - values should be cumulative sums + let pos1_col = result.column("__ggsql_aes_pos1__").unwrap(); + let pos1_vals: Vec = pos1_col + .f64() + .unwrap() + .into_iter() + .filter_map(|v| v) + .collect(); + + // Should have cumulative sums (10, 30, 15, 40) for groups A and B + assert!( + pos1_vals.iter().any(|&v| v > 20.0), + "Should have cumulative values > original max, got {:?}", + pos1_vals + 
); + } + + #[test] + fn test_stack_total_parameter() { + let stack = Stack; + let df = make_test_df(); + let spec = Plot::new(); + + // Test with total = 100 (percentage stacking) + let mut layer = make_test_layer(); + layer + .parameters + .insert("total".to_string(), ParameterValue::Number(100.0)); + + let (result, _) = stack.apply_adjustment(&df, &layer, &spec).unwrap(); + + // pos2 should sum to 100 within each group (A and B) + let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); + let pos2_vals: Vec = pos2_col + .f64() + .unwrap() + .into_iter() + .filter_map(|v| v) + .collect(); + + // For group A: values 10, 20 -> normalized: 10/30, 20/30 -> cumsum: 10/30, 30/30 + // Multiplied by 100: ~33.33, 100 + // For group B: values 15, 25 -> normalized: 15/40, 25/40 -> cumsum: 15/40, 40/40 + // Multiplied by 100: 37.5, 100 + // So max values should be 100 + let max_val = pos2_vals.iter().cloned().fold(f64::MIN, f64::max); + assert!( + (max_val - 100.0).abs() < 0.01, + "Expected max value ~100 for normalized stack, got {}", + max_val + ); + } + + #[test] + fn test_stack_total_parameter_arbitrary_value() { + let stack = Stack; + let df = make_test_df(); + let spec = Plot::new(); + + // Test with total = 1 (normalized to 1, like old stack_fill behavior) + let mut layer = make_test_layer(); + layer + .parameters + .insert("total".to_string(), ParameterValue::Number(1.0)); + + let (result, _) = stack.apply_adjustment(&df, &layer, &spec).unwrap(); + + let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); + let pos2_vals: Vec = pos2_col + .f64() + .unwrap() + .into_iter() + .filter_map(|v| v) + .collect(); + + // Max values should be 1 (normalized to sum to 1) + let max_val = pos2_vals.iter().cloned().fold(f64::MIN, f64::max); + assert!( + (max_val - 1.0).abs() < 0.01, + "Expected max value ~1 for normalized stack with total=1, got {}", + max_val + ); + } + + #[test] + fn test_stack_na_values_treated_as_zero() { + let stack = Stack; + + // Create data with NA 
values in pos2 + let df = df! { + "__ggsql_aes_pos1__" => ["A", "A", "A", "B", "B", "B"], + "__ggsql_aes_pos2__" => [Some(10.0), None, Some(20.0), Some(15.0), Some(25.0), None], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => ["X", "Y", "Z", "X", "Y", "Z"], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + + let spec = Plot::new(); + let (result, _) = stack.apply_adjustment(&df, &layer, &spec).unwrap(); + + // Get pos2 values - should have no nulls after stacking + let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); + let pos2_vals: Vec> = pos2_col.f64().unwrap().into_iter().collect(); + + // All values should be non-null (NA treated as 0) + assert!( + pos2_vals.iter().all(|v| v.is_some()), + "Expected no null values after stacking, got {:?}", + pos2_vals + ); + + // For group A: 10, 0 (NA), 20 -> cumsum: 10, 10, 30 + // For group B: 15, 25, 0 (NA) -> cumsum: 15, 40, 40 + // Check that the cumsum for group A ends at 30 (10 + 0 + 20) + let group_a_max = pos2_vals[2].unwrap(); // Third row is last for group A + assert!( + (group_a_max - 30.0).abs() < 0.01, + "Expected group A max ~30 (NA treated as 0), got {}", + group_a_max + ); + } + + #[test] + fn test_stack_consistent_order_with_shuffled_data() { + let stack = Stack; + + // Create data in shuffled order - categories not in order within groups + let df = df! 
{ + "__ggsql_aes_pos1__" => ["A", "B", "A", "B", "A", "B"], + "__ggsql_aes_pos2__" => [10.0, 15.0, 30.0, 35.0, 20.0, 25.0], + "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => ["X", "X", "Z", "Z", "Y", "Y"], + } + .unwrap(); + + let mut layer = Layer::new(Geom::bar()); + layer.mappings = { + let mut m = Mappings::new(); + m.insert( + "pos1", + AestheticValue::standard_column("__ggsql_aes_pos1__"), + ); + m.insert( + "pos2", + AestheticValue::standard_column("__ggsql_aes_pos2__"), + ); + m.insert( + "pos2end", + AestheticValue::standard_column("__ggsql_aes_pos2end__"), + ); + m.insert( + "fill", + AestheticValue::standard_column("__ggsql_aes_fill__"), + ); + m + }; + layer.partition_by = vec!["__ggsql_aes_fill__".to_string()]; + + let spec = Plot::new(); + let (result, _) = stack.apply_adjustment(&df, &layer, &spec).unwrap(); + + // After sorting by pos1 then fill, the order should be: + // A-X(10), A-Y(20), A-Z(30) -> cumsum: 10, 30, 60 + // B-X(15), B-Y(25), B-Z(35) -> cumsum: 15, 40, 75 + + // Check that data is sorted consistently + let pos1_col = result.column("__ggsql_aes_pos1__").unwrap(); + let fill_col = result.column("__ggsql_aes_fill__").unwrap(); + let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); + + let pos1_vals: Vec<&str> = pos1_col.str().unwrap().into_iter().flatten().collect(); + let fill_vals: Vec<&str> = fill_col.str().unwrap().into_iter().flatten().collect(); + let pos2_vals: Vec = pos2_col.f64().unwrap().into_iter().flatten().collect(); + + // Should be sorted: A-X, A-Y, A-Z, B-X, B-Y, B-Z + assert_eq!(pos1_vals, vec!["A", "A", "A", "B", "B", "B"]); + assert_eq!(fill_vals, vec!["X", "Y", "Z", "X", "Y", "Z"]); + + // Group A cumsum: 10, 30, 60 + assert!((pos2_vals[0] - 10.0).abs() < 0.01, "A-X should be 10"); + assert!((pos2_vals[1] - 30.0).abs() < 0.01, "A-Y should be 30"); + assert!((pos2_vals[2] - 60.0).abs() < 0.01, "A-Z should be 60"); + + // Group B cumsum: 15, 40, 75 + assert!((pos2_vals[3] - 
15.0).abs() < 0.01, "B-X should be 15"); + assert!((pos2_vals[4] - 40.0).abs() < 0.01, "B-Y should be 40"); + assert!((pos2_vals[5] - 75.0).abs() < 0.01, "B-Z should be 75"); + } +} diff --git a/src/plot/scale/scale_type/binned.rs b/src/plot/scale/scale_type/binned.rs index 1a655eb4..40779a90 100644 --- a/src/plot/scale/scale_type/binned.rs +++ b/src/plot/scale/scale_type/binned.rs @@ -345,9 +345,15 @@ impl ScaleTypeTrait for Binned { .iter() .map(|elem| resolved_transform.parse_value(elem)) .collect(); - // Filter breaks to input range (explicit breaks always filtered) - let filtered = if let Some(ref range) = scale.input_range { - super::super::super::breaks::filter_breaks_to_range(&converted, range) + // Only filter breaks to input range if BOTH explicit breaks AND explicit input range + // were provided. If the user only provided breaks (no FROM clause), their breaks + // should be used as-is to define the bin boundaries (input_range is derived later). + let filtered = if scale.explicit_input_range { + if let Some(ref range) = scale.input_range { + super::super::super::breaks::filter_breaks_to_range(&converted, range) + } else { + converted + } } else { converted }; @@ -1935,4 +1941,119 @@ mod tests { assert!(err.contains("Boolean")); assert!(err.contains("DISCRETE")); } + + // ========================================================================= + // Break Resolution Tests + // ========================================================================= + + #[test] + fn test_explicit_breaks_preserved_without_explicit_range() { + // Regression test: explicit breaks extending beyond data range should NOT + // be filtered when no explicit FROM clause is provided. + // Issue: breaks like [2600, 3550, 4050, 4750, 6400] were getting terminal + // breaks removed when data range was ~[2700, 6300]. 
+ use super::ScaleTypeTrait; + use polars::prelude::DataType; + + let binned = Binned; + let mut scale = Scale::new("fill"); + + // User provides explicit breaks that extend beyond data range + scale.properties.insert( + "breaks".to_string(), + ParameterValue::Array(vec![ + ArrayElement::Number(2600.0), + ArrayElement::Number(3550.0), + ArrayElement::Number(4050.0), + ArrayElement::Number(4750.0), + ArrayElement::Number(6400.0), + ]), + ); + // No explicit input range (no FROM clause) + scale.explicit_input_range = false; + + // Data context with narrower range than breaks + let context = ScaleDataContext { + range: Some(InputRange::Continuous(vec![ + ArrayElement::Number(2700.0), + ArrayElement::Number(6300.0), + ])), + dtype: Some(DataType::Float64), + is_discrete: false, + }; + + binned.resolve(&mut scale, &context, "fill").unwrap(); + + // All 5 breaks should be preserved (not filtered to data range) + let resolved_breaks = match scale.properties.get("breaks") { + Some(ParameterValue::Array(arr)) => arr.clone(), + _ => panic!("breaks should be an array"), + }; + assert_eq!( + resolved_breaks.len(), + 5, + "All explicit breaks should be preserved: {:?}", + resolved_breaks + ); + + // Verify the exact values + let values: Vec = resolved_breaks.iter().filter_map(|e| e.to_f64()).collect(); + assert_eq!(values, vec![2600.0, 3550.0, 4050.0, 4750.0, 6400.0]); + } + + #[test] + fn test_explicit_breaks_filtered_with_explicit_range() { + // When BOTH explicit breaks AND explicit range are provided, + // breaks should be filtered to the range. 
+ use super::ScaleTypeTrait; + use polars::prelude::DataType; + + let binned = Binned; + let mut scale = Scale::new("fill"); + + // User provides explicit breaks + scale.properties.insert( + "breaks".to_string(), + ParameterValue::Array(vec![ + ArrayElement::Number(2600.0), + ArrayElement::Number(3550.0), + ArrayElement::Number(4050.0), + ArrayElement::Number(4750.0), + ArrayElement::Number(6400.0), + ]), + ); + // WITH explicit input range (FROM clause) + scale.input_range = Some(vec![ + ArrayElement::Number(3000.0), + ArrayElement::Number(6000.0), + ]); + scale.explicit_input_range = true; + + let context = ScaleDataContext { + range: Some(InputRange::Continuous(vec![ + ArrayElement::Number(2700.0), + ArrayElement::Number(6300.0), + ])), + dtype: Some(DataType::Float64), + is_discrete: false, + }; + + binned.resolve(&mut scale, &context, "fill").unwrap(); + + // Breaks should be filtered to [3000, 6000] + // Only 3550, 4050, 4750 are within range + // Plus range boundaries 3000 and 6000 are added + let resolved_breaks = match scale.properties.get("breaks") { + Some(ParameterValue::Array(arr)) => arr.clone(), + _ => panic!("breaks should be an array"), + }; + + // Should have: 3000 (boundary), 3550, 4050, 4750, 6000 (boundary) + let values: Vec = resolved_breaks.iter().filter_map(|e| e.to_f64()).collect(); + assert_eq!( + values, + vec![3000.0, 3550.0, 4050.0, 4750.0, 6000.0], + "Breaks should be filtered to explicit range with boundaries added" + ); + } } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 25771b79..186c0fe4 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -929,4 +929,119 @@ mod tests { ); } } + + #[test] + fn test_stacked_bar_chart() { + // Test stacked bar chart via position => 'stack' + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let query = r#" + SELECT * FROM (VALUES + ('A', 'X', 10), + ('A', 'Y', 20), + ('B', 'X', 15), + ('B', 'Y', 25) + ) AS t(cat, grp, val) + VISUALISE + DRAW bar MAPPING 
cat AS x, val AS y, grp AS fill + SETTING position => 'stack' + "#; + + let spec = reader.execute(query).unwrap(); + let writer = VegaLiteWriter::new(); + let result = writer.render(&spec).unwrap(); + + let json: serde_json::Value = serde_json::from_str(&result).unwrap(); + let layer = json["layer"].as_array().unwrap().first().unwrap(); + + // Verify y and y2 encodings exist (stacked bars use y/y2 for range) + let encoding = &layer["encoding"]; + assert!(encoding["y"].is_object(), "Should have y encoding"); + assert!( + encoding["y2"].is_object(), + "Should have y2 encoding for stacked bars" + ); + + // Verify Vega-Lite stacking is disabled (we handle it ourselves) + assert!( + encoding["y"]["stack"].is_null(), + "y encoding should have stack: null to disable VL stacking. Got: {}", + serde_json::to_string_pretty(&encoding["y"]).unwrap() + ); + } + + #[test] + fn test_dodged_bar_chart() { + // Test dodged bar chart via position => 'dodge' + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let query = r#" + SELECT * FROM (VALUES + ('A', 'X', 10), + ('A', 'Y', 20), + ('B', 'X', 15), + ('B', 'Y', 25) + ) AS t(cat, grp, val) + VISUALISE + DRAW bar MAPPING cat AS x, val AS y, grp AS fill + SETTING position => 'dodge' + "#; + + let spec = reader.execute(query).unwrap(); + let writer = VegaLiteWriter::new(); + let result = writer.render(&spec).unwrap(); + + let json: serde_json::Value = serde_json::from_str(&result).unwrap(); + let layer = json["layer"].as_array().unwrap().first().unwrap(); + + // Verify xOffset encoding exists (dodged bars use xOffset for displacement) + let encoding = &layer["encoding"]; + assert!( + encoding["xOffset"].is_object(), + "Should have xOffset encoding for dodged bars. 
Encoding: {}", + serde_json::to_string_pretty(encoding).unwrap() + ); + + // Verify bar width uses bandwidth expression with adjusted_width for dodged bars + // For 2 groups with default width 0.9: adjusted_width = 0.9 / 2 = 0.45 + let mark = &layer["mark"]; + let width_expr = mark["width"]["expr"].as_str(); + assert!( + width_expr.is_some(), + "Dodged bars should have expression-based width. Mark: {}", + serde_json::to_string_pretty(mark).unwrap() + ); + let expr = width_expr.unwrap(); + assert!( + expr.contains("bandwidth('x')") && expr.contains("0.45"), + "Width expression should use bandwidth('x') * adjusted_width, got: {}", + expr + ); + } + + #[test] + fn test_position_identity_default() { + // Test that identity position (default) doesn't modify data + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let query = r#" + SELECT * FROM (VALUES + ('A', 10), + ('B', 20) + ) AS t(cat, val) + VISUALISE + DRAW bar MAPPING cat AS x, val AS y + "#; + + let spec = reader.execute(query).unwrap(); + let writer = VegaLiteWriter::new(); + let result = writer.render(&spec).unwrap(); + + let json: serde_json::Value = serde_json::from_str(&result).unwrap(); + let layer = json["layer"].as_array().unwrap().first().unwrap(); + + // Verify no xOffset encoding (identity position) + let encoding = &layer["encoding"]; + assert!( + encoding.get("xOffset").is_none(), + "Identity position should not have xOffset encoding" + ); + } } diff --git a/src/writer/vegalite/layer.rs b/src/writer/vegalite/layer.rs index 311b3381..19fb22f7 100644 --- a/src/writer/vegalite/layer.rs +++ b/src/writer/vegalite/layer.rs @@ -198,9 +198,20 @@ impl GeomRenderer for BarRenderer { Some(ParameterValue::Number(w)) => *w, _ => 0.9, }; + + // For dodged bars, use expression-based width with the adjusted width + // For non-dodged bars, use band-relative width + let width_value = if let Some(adjusted) = layer.adjusted_width { + // Use bandwidth expression for dodged bars + 
json!({"expr": format!("bandwidth('x') * {}", adjusted)}) + } else { + json!({"band": width}) + }; + layer_spec["mark"] = json!({ "type": "bar", - "width": {"band": width}, + "width": width_value, + "align": "center", "clip": true }); Ok(()) @@ -245,34 +256,11 @@ impl GeomRenderer for RibbonRenderer { // Area Renderer // ============================================================================= -/// Renderer for area geom - handles stacking options +/// Renderer for area geom - uses DefaultRenderer behavior +/// Stacking is handled by the position infrastructure via `position => 'stack'` pub struct AreaRenderer; -impl GeomRenderer for AreaRenderer { - fn modify_encoding(&self, encoding: &mut Map, layer: &Layer) -> Result<()> { - if let Some(mut y) = encoding.remove("y") { - let stack_value; - if let Some(ParameterValue::String(stack)) = layer.parameters.get("stacking") { - stack_value = match stack.as_str() { - "on" => json!("zero"), - "off" => Value::Null, - "fill" => json!("normalize"), - _ => { - return Err(GgsqlError::ValidationError(format!( - "Area layer's `stacking` must be \"on\", \"off\" or \"fill\", not \"{}\"", - stack - ))); - } - } - } else { - stack_value = Value::Null - } - y["stack"] = stack_value; - encoding.insert("y".to_string(), y); - } - Ok(()) - } -} +impl GeomRenderer for AreaRenderer {} // ============================================================================= // Polygon Renderer @@ -318,10 +306,6 @@ impl GeomRenderer for ViolinRenderer { }); let offset_col = naming::aesthetic_column("offset"); - // Mirror the density on both sides. - // It'll be implemented as an offset. - let violin_offset = format!("[datum.{offset}, -datum.{offset}]", offset = offset_col); - // We use an order calculation to create a proper closed shape. 
// Right side (+ offset), sort by -y (top -> bottom) // Left side (- offset), sort by +y (bottom -> top) @@ -331,8 +315,8 @@ impl GeomRenderer for ViolinRenderer { ); // Filter threshold to trim very low density regions (removes thin tails) - // In theory, this depends on the grid resolution and might be better - // handled upstream, but for now it seems not unreasonable. + // The offset is pre-scaled to [0, 0.5 * width] by geom post_process, + // but this filter still catches extremely low values. let filter_expr = format!("datum.{} > 0.001", offset_col); // Preserve existing transforms (e.g., source filter) and extend with violin-specific transforms @@ -342,6 +326,12 @@ impl GeomRenderer for ViolinRenderer { .cloned() .unwrap_or_default(); + // Mirror the offset on both sides (offset is already scaled by post_process) + let violin_offset = format!("[datum.{offset}, -datum.{offset}]", offset = offset_col); + + // Check if pos1offset exists (from dodging) - we'll combine it with violin offset + let pos1offset_col = naming::aesthetic_column("pos1offset"); + let mut transforms = existing_transforms; transforms.extend(vec![ json!({ @@ -349,6 +339,7 @@ impl GeomRenderer for ViolinRenderer { "filter": filter_expr }), json!({ + // Mirror offset on both sides (offset is pre-scaled to [0, 0.5 * width]) "calculate": violin_offset, "as": "violin_offsets" }), @@ -356,6 +347,15 @@ impl GeomRenderer for ViolinRenderer { "flatten": ["violin_offsets"], "as": ["__violin_offset"] }), + json!({ + // Add pos1offset (dodge displacement) if it exists, otherwise use violin offset directly + // This positions the violin correctly when dodging + "calculate": format!( + "datum.{pos1offset} != null ? 
datum.__violin_offset + datum.{pos1offset} : datum.__violin_offset", + pos1offset = pos1offset_col + ), + "as": "__final_offset" + }), json!({ "calculate": calc_order, "as": "__order" @@ -432,8 +432,11 @@ impl GeomRenderer for ViolinRenderer { encoding.insert( "xOffset".to_string(), json!({ - "field": "__violin_offset", - "type": "quantitative" + "field": "__final_offset", + "type": "quantitative", + "scale": { + "domain": [-0.5, 0.5] + } }), ); encoding.insert( @@ -453,8 +456,6 @@ impl GeomRenderer for ViolinRenderer { /// Metadata for boxplot rendering struct BoxplotMetadata { - /// Grouping column names - grouping_cols: Vec, /// Whether there are any outliers has_outliers: bool, } @@ -465,30 +466,15 @@ pub struct BoxplotRenderer; impl BoxplotRenderer { /// Prepare boxplot data by splitting into type-specific datasets. /// - /// Returns a HashMap of type_suffix -> data_values, plus grouping_cols and has_outliers. + /// Returns a HashMap of type_suffix -> data_values, plus has_outliers flag. 
/// Type suffixes are: "lower_whisker", "upper_whisker", "box", "median", "outlier" - #[allow(clippy::type_complexity)] fn prepare_components( &self, data: &DataFrame, binned_columns: &HashMap>, - ) -> Result<(HashMap>, Vec, bool)> { + ) -> Result<(HashMap>, bool)> { let type_col = naming::aesthetic_column("type"); let type_col = type_col.as_str(); - let value_col = naming::aesthetic_column("pos2"); - let value_col = value_col.as_str(); - let value2_col = naming::aesthetic_column("pos2end"); - let value2_col = value2_col.as_str(); - - // Find grouping columns (all columns except type, value, value2) - let grouping_cols: Vec = data - .get_column_names() - .iter() - .filter(|&col| { - col.as_str() != type_col && col.as_str() != value_col && col.as_str() != value2_col - }) - .map(|s| s.to_string()) - .collect(); // Get the type column for filtering let type_series = data @@ -527,7 +513,7 @@ impl BoxplotRenderer { type_datasets.insert(type_name.to_string(), values); } - Ok((type_datasets, grouping_cols, has_outliers)) + Ok((type_datasets, has_outliers)) } /// Render boxplot layers using filter transforms on the unified dataset. 
@@ -538,7 +524,6 @@ impl BoxplotRenderer { prototype: Value, layer: &Layer, base_key: &str, - grouping_cols: &[String], has_outliers: bool, ) -> Result> { let mut layers: Vec = Vec::new(); @@ -553,7 +538,8 @@ impl BoxplotRenderer { .ok_or_else(|| { GgsqlError::WriterError("Boxplot requires 'x' aesthetic mapping".to_string()) })?; - let y_col = layer + // Validate y aesthetic exists (not used directly but required for boxplot) + layer .mappings .get("pos2") .and_then(|y| y.column_name()) @@ -563,23 +549,26 @@ impl BoxplotRenderer { // Set orientation let is_horizontal = x_col == value_col; - let group_col = if is_horizontal { y_col } else { x_col }; - let offset = if is_horizontal { "yOffset" } else { "xOffset" }; let value_var1 = if is_horizontal { "x" } else { "y" }; let value_var2 = if is_horizontal { "x2" } else { "y2" }; - // Find dodge groups (grouping cols minus the axis group col) - let dodge_groups: Vec<&str> = grouping_cols - .iter() - .filter(|col| col.as_str() != group_col) - .map(|s| s.as_str()) - .collect(); - // Get width parameter - let mut width = 0.9; - if let Some(ParameterValue::Number(num)) = layer.parameters.get("width") { - width = *num; - } + let base_width = layer + .parameters + .get("width") + .and_then(|v| match v { + ParameterValue::Number(n) => Some(*n), + _ => None, + }) + .unwrap_or(0.9); + + // For dodged boxplots, use expression-based width with adjusted_width + // For non-dodged boxplots, use band-relative width + let width_value = if let Some(adjusted) = layer.adjusted_width { + json!({"expr": format!("bandwidth('x') * {}", adjusted)}) + } else { + json!({"band": base_width}) + }; // Helper to create filter transform for source selection let make_source_filter = |type_suffix: &str| -> Value { @@ -620,11 +609,6 @@ impl BoxplotRenderer { points["mark"]["filled"] = json!(true); } - // Add dodging offset - if !dodge_groups.is_empty() { - points["encoding"][offset] = json!({"field": dodge_groups[0]}); - } - layers.push(points); } @@ 
-688,7 +672,7 @@ impl BoxplotRenderer { "box", json!({ "type": "bar", - "width": {"band": width}, + "width": width_value, "align": "center" }), ); @@ -701,21 +685,12 @@ impl BoxplotRenderer { "median", json!({ "type": "tick", - "width": {"band": width}, + "width": width_value, "align": "center" }), ); median_line["encoding"][value_var1] = y_encoding; - // Add dodging to all summary layers - if !dodge_groups.is_empty() { - let offset_val = json!({"field": dodge_groups[0]}); - lower_whiskers["encoding"][offset] = offset_val.clone(); - upper_whiskers["encoding"][offset] = offset_val.clone(); - box_part["encoding"][offset] = offset_val.clone(); - median_line["encoding"][offset] = offset_val; - } - layers.push(lower_whiskers); layers.push(upper_whiskers); layers.push(box_part); @@ -732,15 +707,11 @@ impl GeomRenderer for BoxplotRenderer { _data_key: &str, binned_columns: &HashMap>, ) -> Result { - let (components, grouping_cols, has_outliers) = - self.prepare_components(df, binned_columns)?; + let (components, has_outliers) = self.prepare_components(df, binned_columns)?; Ok(PreparedData::Composite { components, - metadata: Box::new(BoxplotMetadata { - grouping_cols, - has_outliers, - }), + metadata: Box::new(BoxplotMetadata { has_outliers }), }) } @@ -766,13 +737,7 @@ impl GeomRenderer for BoxplotRenderer { GgsqlError::InternalError("Failed to downcast boxplot metadata".to_string()) })?; - self.render_layers( - prototype, - layer, - data_key, - &info.grouping_cols, - info.has_outliers, - ) + self.render_layers(prototype, layer, data_key, info.has_outliers) } } @@ -893,4 +858,64 @@ mod tests { ])) ); } + + #[test] + fn test_violin_mirroring() { + use crate::naming; + + let renderer = ViolinRenderer; + + let layer = Layer::new(crate::plot::Geom::violin()); + let mut layer_spec = json!({ + "mark": {"type": "line"}, + "encoding": { + "x": {"field": "species", "type": "nominal"}, + "y": {"field": naming::aesthetic_column("pos2"), "type": "quantitative"} + } + }); + + 
renderer.modify_spec(&mut layer_spec, &layer).unwrap(); + + // Verify transforms include mirroring (violin_offsets) + let transforms = layer_spec["transform"].as_array().unwrap(); + + // Find the violin_offsets calculation (mirrors offset on both sides) + let mirror_calc = transforms + .iter() + .find(|t| t.get("as").and_then(|a| a.as_str()) == Some("violin_offsets")); + assert!( + mirror_calc.is_some(), + "Should have violin_offsets mirroring calculation" + ); + + let calc_expr = mirror_calc.unwrap()["calculate"].as_str().unwrap(); + let offset_col = naming::aesthetic_column("offset"); + // Should mirror the offset column: [datum.offset, -datum.offset] + assert!( + calc_expr.contains(&offset_col), + "Mirror calculation should use offset column: {}", + calc_expr + ); + assert!( + calc_expr.contains("-datum"), + "Mirror calculation should negate: {}", + calc_expr + ); + + // Verify flatten transform exists + let flatten = transforms.iter().find(|t| t.get("flatten").is_some()); + assert!( + flatten.is_some(), + "Should have flatten transform for violin_offsets" + ); + + // Verify __final_offset calculation (combines with dodge offset) + let final_offset = transforms + .iter() + .find(|t| t.get("as").and_then(|a| a.as_str()) == Some("__final_offset")); + assert!( + final_offset.is_some(), + "Should have __final_offset calculation" + ); + } } diff --git a/src/writer/vegalite/mod.rs b/src/writer/vegalite/mod.rs index 7e5c7ef6..0d59dd02 100644 --- a/src/writer/vegalite/mod.rs +++ b/src/writer/vegalite/mod.rs @@ -291,6 +291,72 @@ fn build_layer_encoding( encoding.insert("detail".to_string(), detail); } + // Add xOffset encoding for dodged positions (pos1offset column) + // This column is created by position::apply_dodge() for Position::Dodge + // The offset values are centered around 0 (e.g., -0.3, 0, +0.3 for 3 groups) + // We set domain [-0.5, 0.5] to ensure the scale is symmetric and maps to full band width + let pos1offset_col = naming::aesthetic_column("pos1offset"); 
+ if df.column(&pos1offset_col).is_ok() { + // Map to appropriate offset channel based on coord type + let offset_channel = match coord_kind { + CoordKind::Cartesian => "xOffset", + CoordKind::Polar => "thetaOffset", + }; + encoding.insert( + offset_channel.to_string(), + json!({ + "field": pos1offset_col, + "type": "quantitative", + "scale": { + "domain": [-0.5, 0.5] + } + }), + ); + } + + // Add yOffset encoding for vertical jitter (pos2offset column) + // This column is created by position::Jitter when pos2 axis is discrete + let pos2offset_col = naming::aesthetic_column("pos2offset"); + if df.column(&pos2offset_col).is_ok() { + // Map to appropriate offset channel based on coord type + let offset_channel = match coord_kind { + CoordKind::Cartesian => "yOffset", + CoordKind::Polar => "radiusOffset", + }; + encoding.insert( + offset_channel.to_string(), + json!({ + "field": pos2offset_col, + "type": "quantitative", + "scale": { + "domain": [-0.5, 0.5] + } + }), + ); + } + + // Disable Vega-Lite's automatic stacking - we handle position adjustments ourselves + // This prevents Vega-Lite from applying its own stack/dodge logic on top of ours + // Set stack: null on both y and y2 channels (pos2 and pos2end in our terminology) + let y_channel = match coord_kind { + CoordKind::Cartesian => "y", + CoordKind::Polar => "radius", + }; + let y2_channel = match coord_kind { + CoordKind::Cartesian => "y2", + CoordKind::Polar => "radius2", + }; + if let Some(y_enc) = encoding.get_mut(y_channel) { + if let Some(obj) = y_enc.as_object_mut() { + obj.insert("stack".to_string(), Value::Null); + } + } + if let Some(y2_enc) = encoding.get_mut(y2_channel) { + if let Some(obj) = y2_enc.as_object_mut() { + obj.insert("stack".to_string(), Value::Null); + } + } + // Apply geom-specific encoding modifications via renderer let renderer = get_renderer(&layer.geom); renderer.modify_encoding(&mut encoding, layer)?;