From b490986bcd40be703f14b1eb2a0e8037ba65004f Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Sun, 14 Dec 2025 23:38:20 -0500 Subject: [PATCH 1/9] Add SSMProductDistribution and code quality tests Introduces the SSMProductDistribution wrapper for product distributions of sequential sampling models, enabling custom logpdf methods for NamedTuple data. Refactors related methods in product_distribution.jl, updates exports, and adds a codequality.jl test file for Aqua and JET checks. Project.toml is updated to reflect new weak dependencies, test targets, and extras. --- Project.toml | 17 ++++++----- src/MDFT.jl | 2 +- src/SequentialSamplingModels.jl | 2 ++ src/product_distribution.jl | 53 ++++++++++++++++++++++++++------- test/codequality.jl | 28 +++++++++++++++++ 5 files changed, 83 insertions(+), 19 deletions(-) create mode 100644 test/codequality.jl diff --git a/Project.toml b/Project.toml index 19b75583..fe40e45d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,13 +5,9 @@ version = "0.12.7" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FunctionZeros = "b21f74c0-b399-568f-9643-d20f4fa2c814" HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -20,12 +16,15 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [extensions] -PlotsExt = "Plots" -TuringExt = "Turing" +PlotsExt = ["Plots", "Interpolations", "KernelDensity"] +TuringExt = ["Turing", "DynamicPPL"] [compat] Distributions = "v0.24.6, 0.25" @@ -46,6 +45,7 @@ StatsAPI = "1.0.0" StatsBase = "0.33.0,0.34.0" Turing = "0.35.0,0.36.0,0.37,0.38.0,0.39.0" julia = "1" +JET = "0.11.2" [extras] Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" @@ -56,6 +56,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" [targets] -test = ["Interpolations", "KernelDensity", "Plots", "QuadGK", "SafeTestsets", "StatsBase", "Statistics", "Test", "Turing"] +test = ["Interpolations", "KernelDensity", "Plots", "QuadGK", "SafeTestsets", "StatsBase", "Statistics", "Test", "Turing", "Aqua", "JET", "JuliaFormatter"] diff --git a/src/MDFT.jl b/src/MDFT.jl index b2bac371..ca1127aa 100644 --- a/src/MDFT.jl +++ b/src/MDFT.jl @@ -248,7 +248,7 @@ make_default_contrast(3) -0.5 -0.5 1.0 ``` """ -function make_default_contrast(n) +function make_default_contrast(n::Integer) C = fill(0.0, n, n) C .= -1 / (n - 1) for r ∈ 1:n diff --git a/src/SequentialSamplingModels.jl b/src/SequentialSamplingModels.jl index d9b7d1aa..3fc0ee49 100644 --- a/src/SequentialSamplingModels.jl +++ b/src/SequentialSamplingModels.jl @@ -63,6 +63,7 @@ export PoissonRace export ShiftedLogNormal export SSM1D export SSM2D +export SSMProductDistribution export stDDM export ContinuousMultivariateSSM export Wald @@ -86,6 +87,7 @@ export plot_model! export plot_quantiles export plot_quantiles! export predict_distribution +export product_distribution export rand export simulate export std diff --git a/src/product_distribution.jl b/src/product_distribution.jl index d19cbd02..73761f63 100644 --- a/src/product_distribution.jl +++ b/src/product_distribution.jl @@ -1,7 +1,38 @@ +""" + SSMProductDistribution + +Wrapper around `ProductDistribution` for sequential sampling models. +This type allows us to define `logpdf` methods for `NamedTuple` data +without type piracy. +""" +struct SSMProductDistribution{D <: ProductDistribution} + dist::D +end + +""" + product_distribution(dists) + +Create a product distribution from a vector of distributions. +Returns an `SSMProductDistribution` for SSM types, or a standard +`ProductDistribution` for other types. +""" +function product_distribution(dists::AbstractVector) + pd = ProductDistribution(dists) + # Check if this is an SSM that produces NamedTuple data + if eltype(dists) <: SSM2D + return SSMProductDistribution(pd) + else + return pd + end +end + +Base.size(s::SSMProductDistribution, dims...) = size(s.dist, dims...) +Base.length(s::SSMProductDistribution) = length(s.dist) + function rand( rng::AbstractRNG, - s::Sampleable{T, R} -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} + s::SSMProductDistribution +) n = size(s, 2) data = (; choice = fill(0, n), rt = fill(0.0, n)) return rand!(rng, s, data) @@ -9,9 +40,9 @@ end function rand( rng::AbstractRNG, - s::Sampleable{T, R}, + s::SSMProductDistribution, dims::Dims -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} +) n = size(s, 2) ax = map(Base.OneTo, dims) data = [(; choice = fill(0, n), rt = fill(0.0, n)) for _ in Iterators.product(ax...)] @@ -20,23 +51,23 @@ end function rand!( rng::AbstractRNG, - s::Sampleable{T, R}, + s::SSMProductDistribution, data::NamedTuple -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} +) for i ∈ 1:size(s, 2) - data.choice[i], data.rt[i] = rand(rng, s.dists[i]) + data.choice[i], data.rt[i] = rand(rng, s.dist.dists[i]) end return data end -function logpdf(d::ProductDistribution, data_array::Array{<:NamedTuple, N}) where {N} +function logpdf(d::SSMProductDistribution, data_array::Array{<:NamedTuple, N}) where {N} return [logpdf(d, data) for data ∈ data_array] end -function logpdf(d::ProductDistribution, data::NamedTuple) +function logpdf(d::SSMProductDistribution, data::NamedTuple) LL = 0.0 - for i ∈ 1:length(d.dists) - LL += logpdf(d.dists[i], data.choice[i], data.rt[i]) + for i ∈ 1:length(d.dist.dists) + LL += logpdf(d.dist.dists[i], data.choice[i], data.rt[i]) end return LL end diff --git a/test/codequality.jl b/test/codequality.jl new file mode 100644 index 00000000..fe13b50e --- /dev/null +++ b/test/codequality.jl @@ -0,0 +1,28 @@ +@safetestset "Code Quality" begin + + # check code is formatted + # @safetestset "code formatting" begin + # using JuliaFormatter + # using SequentialSamplingModels + # @test JuliaFormatter.format( + # SequentialSamplingModels; verbose = false, overwrite = false + # ) + # end + + # check code quality via Aqua + @safetestset "Aqua" begin + using Aqua + using SequentialSamplingModels + Aqua.test_all( + SequentialSamplingModels; ambiguities = false, + deps_compat = (check_extras = false,) + ) + end + + # test JET + @safetestset "JET" begin + using JET + using SequentialSamplingModels + JET.test_package(SequentialSamplingModels; target_defined_modules = true) + end +end From 9f202b3efaba3d28be756ba2e290436d896b6010 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 15 Dec 2025 00:01:01 -0500 Subject: [PATCH 2/9] Update dependencies and improve product_distribution tests Expanded [extensions] and [compat] in Project.toml, added new test dependencies and compat entries in test/Project.toml, and removed JET from main compat. In src, removed export of predict_distribution. In tests, explicitly imported product_distribution in relevant test sets and updated code quality checks to address package compatibility issues. --- Project.toml | 4 ++-- src/SequentialSamplingModels.jl | 1 - test/Project.toml | 10 ++++++++-- test/codequality.jl | 10 +++++++--- test/product_distribution_tests.jl | 8 ++++++++ 5 files changed, 25 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index c9967168..50f1fa3f 100644 --- a/Project.toml +++ b/Project.toml @@ -23,11 +23,12 @@ KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [extensions] -PlotsExt = "Plots" +PlotsExt = ["Plots", "Interpolations", "KernelDensity"] [compat] ArgCheck = "2.5.0" Distributions = "v0.24.6, 0.25" +DynamicPPL = "0.25 - 0.39" FunctionZeros = "0.2.0,0.3.0, 1" HCubature = "1" Interpolations = "0.14.0,0.15.0,0.16.0" @@ -41,7 +42,6 @@ Statistics = "1" StatsAPI = "1.0.0" StatsBase = "0.33.0,0.34.0" julia = "1" -JET = "0.11.2" [extras] Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/SequentialSamplingModels.jl b/src/SequentialSamplingModels.jl index 4630813d..1a3d19ce 100644 --- a/src/SequentialSamplingModels.jl +++ b/src/SequentialSamplingModels.jl @@ -86,7 +86,6 @@ export plot_model export plot_model! export plot_quantiles export plot_quantiles! -export predict_distribution export product_distribution export rand export simulate diff --git a/test/Project.toml b/test/Project.toml index 85d21439..496cebba 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,9 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" @@ -14,5 +17,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" TuringUtilities = "35dc62cd-6c01-44e1-a736-6cc36bfce0cc" -[sources.TuringUtilities] -url = "https://github.com/itsdfish/TuringUtilities.jl" \ No newline at end of file +[sources] +TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"} + +[compat] +JET = "0.11.2" \ No newline at end of file diff --git a/test/codequality.jl b/test/codequality.jl index fe13b50e..710e74cb 100644 --- a/test/codequality.jl +++ b/test/codequality.jl @@ -1,6 +1,8 @@ @safetestset "Code Quality" begin - # check code is formatted + # check code is formatted. Disabled as JET requires JuliaSynatx v1.0, but JuliaFormatter uses v0.4. + # This causes conflicts when both packages are loaded. Once JuluiaFormatter supports JuliaSyntax v1.0, + # this can be re-enabled. # @safetestset "code formatting" begin # using JuliaFormatter # using SequentialSamplingModels @@ -14,8 +16,10 @@ using Aqua using SequentialSamplingModels Aqua.test_all( - SequentialSamplingModels; ambiguities = false, - deps_compat = (check_extras = false,) + SequentialSamplingModels; + ambiguities = false, + deps_compat = (check_extras = false,), + project_extras = false ) end diff --git a/test/product_distribution_tests.jl b/test/product_distribution_tests.jl index 876dd8e6..72dacf9b 100644 --- a/test/product_distribution_tests.jl +++ b/test/product_distribution_tests.jl @@ -2,6 +2,7 @@ @safetestset "rand SSM1D 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -15,6 +16,7 @@ @safetestset "rand SSM1D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -28,6 +30,7 @@ @safetestset "rand logpdf 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -41,6 +44,7 @@ @safetestset "logpdf SSM1D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -54,6 +58,7 @@ @safetestset "rand SSM2D 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [ @@ -70,6 +75,7 @@ @safetestset "rand SSM2D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [ @@ -86,6 +92,7 @@ @safetestset "logpdf SSM2D 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [ @@ -103,6 +110,7 @@ @safetestset "logpdf SSM2D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [ From 79b14f0605c98d7a217db6f0aeb1bdd66ab798b4 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 15 Dec 2025 00:03:36 -0500 Subject: [PATCH 3/9] Update CI matrix. --- .github/workflows/CI.yml | 3 ++- test/codequality.jl | 11 ----------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 32d5e024..77f80360 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,8 @@ jobs: matrix: version: - '1.11' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'. - os: + - '1' + os: - ubuntu-latest arch: - x64 diff --git a/test/codequality.jl b/test/codequality.jl index 710e74cb..47b8e5b8 100644 --- a/test/codequality.jl +++ b/test/codequality.jl @@ -1,16 +1,5 @@ @safetestset "Code Quality" begin - # check code is formatted. Disabled as JET requires JuliaSynatx v1.0, but JuliaFormatter uses v0.4. - # This causes conflicts when both packages are loaded. Once JuluiaFormatter supports JuliaSyntax v1.0, - # this can be re-enabled. - # @safetestset "code formatting" begin - # using JuliaFormatter - # using SequentialSamplingModels - # @test JuliaFormatter.format( - # SequentialSamplingModels; verbose = false, overwrite = false - # ) - # end - # check code quality via Aqua @safetestset "Aqua" begin using Aqua From 474405ee1ea15c059713e4af2afdd6b8aba1224d Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 15 Dec 2025 00:07:26 -0500 Subject: [PATCH 4/9] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 496cebba..981a4c05 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,4 +21,4 @@ TuringUtilities = "35dc62cd-6c01-44e1-a736-6cc36bfce0cc" TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"} [compat] -JET = "0.11.2" \ No newline at end of file +JET = ">=0.9.0" \ No newline at end of file From f2ca4ad6a3dae3f47538c0d849e48330e7c778e7 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 29 Dec 2025 11:03:12 -0500 Subject: [PATCH 5/9] Update Project.toml --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 981a4c05..f3093ab8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,7 +3,6 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" From b5209626f3bb122d27ee3d5c945b4112e31446e4 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 29 Dec 2025 11:20:56 -0500 Subject: [PATCH 6/9] Update test/Project.toml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index f3093ab8..5681a33c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -20,4 +20,4 @@ TuringUtilities = "35dc62cd-6c01-44e1-a736-6cc36bfce0cc" TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"} [compat] -JET = ">=0.9.0" \ No newline at end of file +JET = "0.9, 1" \ No newline at end of file From 62f1e502dff7aab20b7dd12c2625d45c52e6ffba Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 29 Dec 2025 11:24:37 -0500 Subject: [PATCH 7/9] Update .github/workflows/CI.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 77f80360..41aca6ef 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -17,7 +17,7 @@ jobs: version: - '1.11' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'. - '1' - os: + os: - ubuntu-latest arch: - x64 From fe219b156ab9beee0d5b330638eaff376093c0b5 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 29 Dec 2025 11:44:22 -0500 Subject: [PATCH 8/9] Update Project.toml --- test/Project.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 5681a33c..efaa9327 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,7 +17,4 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" TuringUtilities = "35dc62cd-6c01-44e1-a736-6cc36bfce0cc" [sources] -TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"} - -[compat] -JET = "0.9, 1" \ No newline at end of file +TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"} \ No newline at end of file From c186cf38e3f1540bb2c40b44e1c0a04ddbf0efc0 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 29 Dec 2025 12:19:11 -0500 Subject: [PATCH 9/9] Update codequality.jl --- test/codequality.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/codequality.jl b/test/codequality.jl index 47b8e5b8..50c37775 100644 --- a/test/codequality.jl +++ b/test/codequality.jl @@ -16,6 +16,6 @@ @safetestset "JET" begin using JET using SequentialSamplingModels - JET.test_package(SequentialSamplingModels; target_defined_modules = true) + JET.test_package(SequentialSamplingModels; target_modules = (SequentialSamplingModels,)) end end