From 69dc4cd59b91e5bc6578f973326a91b0d05f87cc Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Dec 2025 21:28:57 -0800 Subject: [PATCH 01/10] some draft spline code --- src/et_models/et_splines.jl | 46 ++++++++ test/etmodels/test_splines.jl | 206 ++++++++++++++++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 src/et_models/et_splines.jl create mode 100644 test/etmodels/test_splines.jl diff --git a/src/et_models/et_splines.jl b/src/et_models/et_splines.jl new file mode 100644 index 00000000..528add6f --- /dev/null +++ b/src/et_models/et_splines.jl @@ -0,0 +1,46 @@ + + +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +import DecoratedParticles: XState + +import LuxCore: AbstractLuxLayer +using ConcreteStructs: @concrete + + +@concrete struct TransSelSplineBasis <: AbstractLuxLayer + trans # transform + envelope # envelope + selector # selector + ref_spl # reference spline basis (ignore the stored parameters) + states # reference spline parameters (frozen hence states) +end + + +(l::TransSelSplineBasis)(x, ps, st) = _apply_etsplinebasis(l, x, ps, st), st + + +function _apply_etsplinebasis(l::TransSelSplineBasis, + X::AbstractVector{<: XState}, + ps, st) + # transform + Y = l.trans(X) + # select the spline parameters + i_sel = map(l.selector, X) + # allocate + S = similar(Y, eltype(Y), (length(X), length(l.ref_spl))) + + for (idx, y) in enumerate(Y) + spl_idx = st.states[i_sel[idx]] + S[idx, :] = P4ML.evaluate(l.ref_spl, y, spl_idx) + end + + if envelope != nothing + ee, _ = l.envelope(X, ps.envelope, st.envelope) + S .= ee .* S + end + + return S +end + diff --git a/test/etmodels/test_splines.jl b/test/etmodels/test_splines.jl new file mode 100644 index 00000000..0c757223 --- /dev/null +++ b/test/etmodels/test_splines.jl @@ -0,0 +1,206 @@ +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) + +## + +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful, + AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML +import DecoratedParticles as DP + +using Polynomials4ML.Testing: print_tf, println_slim + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## + +# Generate an ACE model in the v0.8 style but +# - with fixed rcut. (relaxe this requirement later!!) +# get the pair potential component, compare with ETPairModel +# make pair_learnable = true to prevent splinification. + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 +rcut = 5.5 + +# modify rin0cuts to have same cutoff for all elements +# TODO: there is currently a bug with variable cutoffs +# (?is there? The radials seem fine? check again) +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) + + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true ) + +ps, st = Lux.setup(rng, model) + +# confirm that the E0s are all zero +@assert all( collect(values(model.Vref.E0)) .== 0 ) + +# set the many-body parameters to zero to isolate the pair potential +ps.WB[:] .= 0 + +## +# +# construct an ETPairModel that is consistent with `model` +# fixup the parameters to match the ACE model + +et_pair = ETM.convertpair(model) +et_ps, et_st = Lux.setup(rng, et_pair) + +# radial basis parameters for et_model_2 +et_ps.rembed.rbasis.post.W[:, :, 1] = ps.pairbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.rbasis.post.W[:, :, 2] = ps.pairbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.rbasis.post.W[:, :, 3] = ps.pairbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.rbasis.post.W[:, :, 4] = ps.pairbasis.Wnlq[:, :, 2, 2] + +# many-body basis parameters for et_model_2 +et_ps.readout.W[1, :, 1] .= ps.Wpair[:, 1] +et_ps.readout.W[1, :, 2] .= ps.Wpair[:, 2] + +## +# convert the pair basis to a splined version + +Nspl = 100 + +# polynomial basis taking y = y(r) as input +polys_y = et_pair.rembed.layer.rbasis.basis +# weights for cat-1 +WW = et_ps.rembed.rbasis.post.W +splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), 0.0, rcut, Nspl ) + for i in 1:size(WW, 3) ] +states = [ P4ML._init_luxstate(spl) for spl in splines ] +selector2 = et_pair.rembed.layer.rbasis.post.selector +trans_y = et_pair.rembed.layer.rbasis.trans +env = et_pair.rembed.layer.envelope + +poly_rbasis = et_pair.rembed.layer.rbasis + + +spl_rbasis = ET.EnvRBranchL(env, ) + + + + + + +## +# +# test energy evaluations +# + +calc_model = ACEpotentials.ACEPotential(model, ps, st) + +function rand_struct() + sys = AtomsBuilder.bulk(:Si) * (2,2,1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +function energy_new(sys, et_model) + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + Ei, _ = et_model(G, et_ps, et_st) + return sum(Ei) +end + +## + +@info("Check total energies match") +for ntest = 1:30 + sys = rand_struct() + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + E1 = AtomsCalculators.potential_energy(sys, calc_model) + E2 = energy_new(sys, et_pair) + print_tf( @test abs(ustrip(E1) - ustrip(E2)) < 1e-6 ) +end + +## + +@info("Check gradients and jacobians") + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +nnodes = length(G.node_data) +iZ = et_pair.readout.selector.(G.node_data) +WW = et_ps.readout.W + +# gradient of model w.r.t. positions +∂G = ETM.site_grads(et_pair, G, et_ps, et_st) # test run + +# basis +𝔹1 = ETM.site_basis(et_pair, G, et_ps, et_st) + +# basis jacobian +𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_pair, G, et_ps, et_st) + +println_slim(@test 𝔹1 ≈ 𝔹2) + +∇Ei2 = reduce( hcat, ∂𝔹2[:, i, :] * WW[1, :, iZ[i]] + for (i, iz) in enumerate(iZ) ) +∇Ei3 = reshape(∇Ei2, size(∇Ei2)..., 1) +∇E_𝔹_edges = ET.rev_reshape_embedding(∇Ei3, G)[:] +println_slim(@test all(∇E_𝔹_edges .≈ ∂G.edge_data)) + +## + + +# turn off during CI -- need to sort out CI for GPU tests + +#= + +@info("Check GPU evaluation") +using Metal +dev = Metal.mtl +ps_32 = ET.float32(et_ps) +st_32 = ET.float32(et_st) +ps_dev = dev(ps_32) +st_dev = dev(st_32) + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, 5.0 * u"Å") +G_32 = ET.float32(G) +G_dev = dev(G_32) + +E1, st = et_pair(G_32, ps_32, st_32) +E2_dev, st_dev = et_pair(G_dev, ps_dev, st_dev) +E2 = Array(E2_dev) +println_slim(@test E1 ≈ E2) + +g1 = ETM.site_grads(et_pair, G_32, ps_32, st_32) +g2_dev = ETM.site_grads(et_pair, G_dev, ps_dev, st_dev) +g2_edge = Array(g2_dev.edge_data) +println_slim(@test all(g1.edge_data .≈ g2_edge)) + +b1 = ETM.site_basis(et_pair, G_32, ps_32, st_32) +b2_dev = ETM.site_basis(et_pair, G_dev, ps_dev, st_dev) +b2 = Array(b2_dev) +println_slim(@test b1 ≈ b2) + +b1, ∂db1 = ETM.site_basis_jacobian(et_pair, G_32, ps_32, st_32) +b2_dev, ∂db2_dev = ETM.site_basis_jacobian(et_pair, G_dev, ps_dev, st_dev) +b2 = Array(b2_dev) +∂db2 = Array(∂db2_dev) +println_slim(@test b1 ≈ b2) +jacerr = norm.(∂db1 .- ∂db2) ./ (1 .+ norm.(∂db1) + norm.(∂db2)) +@show maximum(jacerr) +println_slim( @test maximum(jacerr) < 1e-4 ) + +=# \ No newline at end of file From 37fb93fc42107f67916a696642424ebd596e3197 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Dec 2025 09:55:48 -0800 Subject: [PATCH 02/10] first passing spline test --- test/etmodels/test_splines.jl | 52 ++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/test/etmodels/test_splines.jl b/test/etmodels/test_splines.jl index 0c757223..0d47c693 100644 --- a/test/etmodels/test_splines.jl +++ b/test/etmodels/test_splines.jl @@ -1,6 +1,6 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) using TestEnv; TestEnv.activate(); -# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) @@ -48,6 +48,7 @@ model = M.ace_model(; elements = elements, order = order, init_WB = :glorot_normal, init_Wpair = :glorot_normal, pair_learnable = true ) +Random.seed!(1234) # new seed to make sure the tests are consistent ps, st = Lux.setup(rng, model) # confirm that the E0s are all zero @@ -77,52 +78,54 @@ et_ps.readout.W[1, :, 2] .= ps.Wpair[:, 2] ## # convert the pair basis to a splined version -Nspl = 100 +# overkill spline accuracy to check errors +Nspl = 200 # polynomial basis taking y = y(r) as input polys_y = et_pair.rembed.layer.rbasis.basis # weights for cat-1 WW = et_ps.rembed.rbasis.post.W splines = [ - P4ML.splinify( y -> WW[:, :, i] * polys_y(y), 0.0, rcut, Nspl ) + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) for i in 1:size(WW, 3) ] states = [ P4ML._init_luxstate(spl) for spl in splines ] selector2 = et_pair.rembed.layer.rbasis.post.selector trans_y = et_pair.rembed.layer.rbasis.trans -env = et_pair.rembed.layer.envelope +envelope = et_pair.rembed.layer.envelope -poly_rbasis = et_pair.rembed.layer.rbasis +spl_rbasis = ET.TransSelSplines(trans_y, envelope, selector2, splines[1], states) +ps_spl, st_spl = LuxCore.setup(rng, spl_rbasis) +poly_rbasis = et_pair.rembed.layer +ps_poly = et_ps.rembed +st_poly = et_st.rembed -spl_rbasis = ET.EnvRBranchL(env, ) - - - - - - -## -# -# test energy evaluations -# - -calc_model = ACEpotentials.ACEPotential(model, ps, st) +## -function rand_struct() +function rand_X() sys = AtomsBuilder.bulk(:Si) * (2,2,1) rattle!(sys, 0.2u"Å") AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) - return sys + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + return G.edge_data end -function energy_new(sys, et_model) - G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - Ei, _ = et_model(G, et_ps, et_st) - return sum(Ei) +## + +Random.seed!(1234) # new seed to make sure the tests are ok. +for ntest = 1:30 + X = rand_X() + P1, _ = poly_rbasis(X, ps_poly, st_poly) + P2, _ = spl_rbasis(X, ps_spl, st_spl) + spl_err = abs.(P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1) + # @show maximum(spl_err) + print_tf(@test maximum(spl_err) < 1e-5) end ## +#= + @info("Check total energies match") for ntest = 1:30 sys = rand_struct() @@ -164,7 +167,6 @@ println_slim(@test all(∇E_𝔹_edges .≈ ∂G.edge_data)) # turn off during CI -- need to sort out CI for GPU tests -#= @info("Check GPU evaluation") using Metal From 44c2d7eb5dc410b4d3199645e6ac7107f1148a14 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Dec 2025 09:56:19 -0800 Subject: [PATCH 03/10] move spline implementation to ET --- src/et_models/et_splines.jl | 46 ------------------------------------- 1 file changed, 46 deletions(-) delete mode 100644 src/et_models/et_splines.jl diff --git a/src/et_models/et_splines.jl b/src/et_models/et_splines.jl deleted file mode 100644 index 528add6f..00000000 --- a/src/et_models/et_splines.jl +++ /dev/null @@ -1,46 +0,0 @@ - - -import EquivariantTensors as ET -import Polynomials4ML as P4ML - -import DecoratedParticles: XState - -import LuxCore: AbstractLuxLayer -using ConcreteStructs: @concrete - - -@concrete struct TransSelSplineBasis <: AbstractLuxLayer - trans # transform - envelope # envelope - selector # selector - ref_spl # reference spline basis (ignore the stored parameters) - states # reference spline parameters (frozen hence states) -end - - -(l::TransSelSplineBasis)(x, ps, st) = _apply_etsplinebasis(l, x, ps, st), st - - -function _apply_etsplinebasis(l::TransSelSplineBasis, - X::AbstractVector{<: XState}, - ps, st) - # transform - Y = l.trans(X) - # select the spline parameters - i_sel = map(l.selector, X) - # allocate - S = similar(Y, eltype(Y), (length(X), length(l.ref_spl))) - - for (idx, y) in enumerate(Y) - spl_idx = st.states[i_sel[idx]] - S[idx, :] = P4ML.evaluate(l.ref_spl, y, spl_idx) - end - - if envelope != nothing - ee, _ = l.envelope(X, ps.envelope, st.envelope) - S .= ee .* S - end - - return S -end - From fca477141fa8417a68955276e7b75f0f420b0970 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Dec 2025 13:20:35 -0800 Subject: [PATCH 04/10] tests for spline derivatives --- test/etmodels/test_splines.jl | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/test/etmodels/test_splines.jl b/test/etmodels/test_splines.jl index 0d47c693..8380fb3c 100644 --- a/test/etmodels/test_splines.jl +++ b/test/etmodels/test_splines.jl @@ -7,7 +7,8 @@ Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) ## using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful, - AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase + AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase, + ForwardDiff M = ACEpotentials.Models ETM = ACEpotentials.ETModels @@ -110,8 +111,10 @@ function rand_X() return G.edge_data end + ## +@info("Checking spline accuracy against polynomial basis") Random.seed!(1234) # new seed to make sure the tests are ok. for ntest = 1:30 X = rand_X() @@ -120,10 +123,35 @@ for ntest = 1:30 spl_err = abs.(P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1) # @show maximum(spl_err) print_tf(@test maximum(spl_err) < 1e-5) + + (P1a, dP1), _ = ET.evaluate_ed(poly_rbasis, X, ps_poly, st_poly) + (P2a, dP2), _ = ET.evaluate_ed(spl_rbasis, X, ps_spl, st_spl) + print_tf(@test P2 ≈ P2a) + dspl_err = norm.(dP1 - dP2) ./ (1 .+ abs.(P1) + abs.(P2)) + # @show maximum(dspl_err) + print_tf(@test maximum(dspl_err) < 1e-3) end ## +@info("Checking machine precision derivative accuracy ") +# NOTE: This test should really be in ET and not here ... + +X = rand_X() +rand_u() = ( u = (@SVector randn(3)); DP.VState(𝐫 = u/norm(u)) ) +U = [ rand_u() for _ = 1:length(X) ] + +f(t) = spl_rbasis(X + t * U, ps_spl, st_spl)[1] +df0 = ForwardDiff.derivative(f, 0.0) + +(P2a, dP2), _ = ET.evaluate_ed(spl_rbasis, X, ps_spl, st_spl) +dp = [ dot(U[i], dP2[i, j]) for i in 1:length(U), j = 1:size(dP2, 2) ] +println_slim(@test df0 ≈ dp) + + +## + + #= @info("Check total energies match") From b9de17157ea3ba03676a488b04eba0c25e626339 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Dec 2025 15:08:47 -0800 Subject: [PATCH 05/10] splinification of pair model + tests --- src/et_models/et_models.jl | 3 +++ src/et_models/splinify.jl | 27 ++++++++++++++++++++++ test/etmodels/test_etpair.jl | 45 ++++++++++++++++++++++++++++++------ 3 files changed, 68 insertions(+), 7 deletions(-) create mode 100644 src/et_models/splinify.jl diff --git a/src/et_models/et_models.jl b/src/et_models/et_models.jl index 698b931e..73fa4872 100644 --- a/src/et_models/et_models.jl +++ b/src/et_models/et_models.jl @@ -14,5 +14,8 @@ include("et_pair.jl") # converstion utilities: convert from 0.8 style ACE models to ET based models include("convert.jl") +# utilities to convert radial embeddings to splined versions +# for simplicity and performance and to freeze parameters +include("splinify.jl") end \ No newline at end of file diff --git a/src/et_models/splinify.jl b/src/et_models/splinify.jl new file mode 100644 index 00000000..6e90d77d --- /dev/null +++ b/src/et_models/splinify.jl @@ -0,0 +1,27 @@ + +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +function splinify(et_pair::ETPairModel, et_ps, et_st; + Nspl = 30) + + # polynomial basis taking y = y(r) as input + trans_y = et_pair.rembed.layer.rbasis.trans + polys_y = et_pair.rembed.layer.rbasis.basis + # weights for learnable radials + WW = et_ps.rembed.rbasis.post.W + # use P4ML to generate individual cubic splines + splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] + # extract the spline parameters into an array of parameter sets + states = [ P4ML._init_luxstate(spl) for spl in splines ] + # selects the correct spline based on the (Zi, Zj) pair + selector2 = et_pair.rembed.layer.rbasis.post.selector + # envelope multiplying the spline + envelope = et_pair.rembed.layer.envelope + + spl_rbasis = ET.TransSelSplines(trans_y, envelope, selector2, splines[1], states) + + return ETPairModel( ET.EdgeEmbed(spl_rbasis), et_pair.readout ) +end \ No newline at end of file diff --git a/test/etmodels/test_etpair.jl b/test/etmodels/test_etpair.jl index 73c5c9dd..86de71a2 100644 --- a/test/etmodels/test_etpair.jl +++ b/test/etmodels/test_etpair.jl @@ -1,6 +1,6 @@ -# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) -# using TestEnv; TestEnv.activate(); -# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using TestEnv; TestEnv.activate(); +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) @@ -74,6 +74,22 @@ et_ps.rembed.rbasis.post.W[:, :, 4] = ps.pairbasis.Wnlq[:, :, 2, 2] et_ps.readout.W[1, :, 1] .= ps.Wpair[:, 1] et_ps.readout.W[1, :, 2] .= ps.Wpair[:, 2] +## +# +# make a splined version of the et_pair model +# + +spl_50 = ETM.splinify(et_pair, et_ps, et_st; Nspl = 50) +ps_50, st_50 = Lux.setup(rng, spl_50) + +spl_200 = ETM.splinify(et_pair, et_ps, et_st; Nspl = 200) +ps_200, st_200 = Lux.setup(rng, spl_200) + +# many-body basis parameters for et_model_2 +ps_50.readout.W[:] = et_ps.readout.W +ps_200.readout.W[:] = et_ps.readout.W + + ## # # test energy evaluations @@ -88,21 +104,26 @@ function rand_struct() return sys end -function energy_new(sys, et_model) +function energy_new(sys, et_model, ps, st) G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - Ei, _ = et_model(G, et_ps, et_st) + Ei, _ = et_model(G, ps, st) return sum(Ei) end ## +Random.seed!(1234) @info("Check total energies match") for ntest = 1:30 sys = rand_struct() G = ET.Atoms.interaction_graph(sys, rcut * u"Å") E1 = AtomsCalculators.potential_energy(sys, calc_model) - E2 = energy_new(sys, et_pair) + E2 = energy_new(sys, et_pair, et_ps, et_st) + E_50 = energy_new(sys, spl_50, ps_50, st_50) + E_200 = energy_new(sys, spl_200, ps_200, st_200) print_tf( @test abs(ustrip(E1) - ustrip(E2)) < 1e-6 ) + print_tf( @test abs(ustrip(E2) - ustrip(E_50)) < 1e-2 ) + print_tf( @test abs(ustrip(E2) - ustrip(E_200)) < 1e-4 ) end ## @@ -116,15 +137,20 @@ iZ = et_pair.readout.selector.(G.node_data) WW = et_ps.readout.W # gradient of model w.r.t. positions -∂G = ETM.site_grads(et_pair, G, et_ps, et_st) # test run +∂G = ETM.site_grads(et_pair, G, et_ps, et_st) +∂G_200 = ETM.site_grads(spl_200, G, ps_200, st_200) # basis 𝔹1 = ETM.site_basis(et_pair, G, et_ps, et_st) +𝔹1_200 = ETM.site_basis(spl_200, G, ps_200, st_200) # basis jacobian 𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_pair, G, et_ps, et_st) +𝔹2_200, ∂𝔹2_200 = ETM.site_basis_jacobian(spl_200, G, ps_200, st_200) println_slim(@test 𝔹1 ≈ 𝔹2) +println_slim(@test 𝔹1_200 ≈ 𝔹2_200) +println_slim(@test norm(𝔹1 - 𝔹1_200, Inf) < 1e-4) ∇Ei2 = reduce( hcat, ∂𝔹2[:, i, :] * WW[1, :, iZ[i]] for (i, iz) in enumerate(iZ) ) @@ -132,6 +158,11 @@ println_slim(@test 𝔹1 ≈ 𝔹2) ∇E_𝔹_edges = ET.rev_reshape_embedding(∇Ei3, G)[:] println_slim(@test all(∇E_𝔹_edges .≈ ∂G.edge_data)) +# check error in site energy gradients for splines +println_slim(@test maximum(norm.(∂G.edge_data - ∂G_200.edge_data)) < 1e-3) +# check error in basis jacobian for splines +println_slim(@test maximum(norm.(∂𝔹2 - ∂𝔹2_200)) < 1e-3) + ## From 052789901f903d9181ae028fd7f6d4e42b38e626 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Dec 2025 15:55:43 -0800 Subject: [PATCH 06/10] rename -> test_etace --- test/etmodels/{test_etbackend.jl => test_etace.jl} | 0 test/runtests.jl | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename test/etmodels/{test_etbackend.jl => test_etace.jl} (100%) diff --git a/test/etmodels/test_etbackend.jl b/test/etmodels/test_etace.jl similarity index 100% rename from test/etmodels/test_etbackend.jl rename to test/etmodels/test_etace.jl diff --git a/test/runtests.jl b/test/runtests.jl index 3da7fde6..624c7b65 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,7 @@ using ACEpotentials, Test, LazyArtifacts @testset "Weird bugs" begin include("test_bugs.jl") end # new ET backend tests - @testset "ET ACE" begin include("etmodels/test_etbackend.jl") end + @testset "ET ACE" begin include("etmodels/test_etace.jl") end @testset "ET OneBody" begin include("etmodels/test_etonebody.jl") end @testset "ET Pair" begin include("etmodels/test_etpair.jl") end From 30ae6a7095a6d9364e2763181db8f915e3c041da Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Dec 2025 16:39:19 -0800 Subject: [PATCH 07/10] prototype splinification of ETACE --- src/et_models/splinify.jl | 38 +++++++++++++ test/etmodels/test_etace.jl | 109 ++++++++++++++++++++---------------- 2 files changed, 100 insertions(+), 47 deletions(-) diff --git a/src/et_models/splinify.jl b/src/et_models/splinify.jl index 6e90d77d..8b3f33d0 100644 --- a/src/et_models/splinify.jl +++ b/src/et_models/splinify.jl @@ -2,6 +2,11 @@ import EquivariantTensors as ET import Polynomials4ML as P4ML +# These implementations of `splinify` expect a very specific structure of the +# pair potential basis. In principle it is possible to relax this +# considerably but it needs a little bit of thinking and planning/design +# work before just diving in. To be discussed when needed. + function splinify(et_pair::ETPairModel, et_ps, et_st; Nspl = 30) @@ -24,4 +29,37 @@ function splinify(et_pair::ETPairModel, et_ps, et_st; spl_rbasis = ET.TransSelSplines(trans_y, envelope, selector2, splines[1], states) return ETPairModel( ET.EdgeEmbed(spl_rbasis), et_pair.readout ) +end + + +function splinify(et_model::ETACE, et_ps, et_st; Nspl = 50) + + rembed = et_model.rembed.layer # radial embedding, edgeembed stripped + trans = rembed.trans # x -> y dp_transform + rpolys_env = rembed.basis # polynomials * envelope + polys_y = rpolys_env.l.layers.layer_1 # polynomial basis + yenv_func = rpolys_env.l.layers.layer_2.func # envelope function + + # envelope multiplying the spline, apply the transformation a second + # time until we figure out how to reuse it conveniently + trans_yenv = ET.dp_transform( + (x, st) -> yenv_func(trans.f(x, st)), + trans.refstate ) + # selects the correct spline based on the (Zi, Zj) pair + selector2 = rembed.post.selector + # generate the splines using P4ML + WW = et_ps.rembed.post.W + splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] + # extract the spline parameters into an array of parameter sets + states = [ P4ML._init_luxstate(spl) for spl in splines ] + + rembed_spl = ET.TransSelSplines(trans, trans_yenv, selector2, + splines[1], states) + ace_spl = ETACE( ET.EdgeEmbed(rembed_spl), + et_model.yembed, + et_model.basis, + et_model.readout ) + return ace_spl end \ No newline at end of file diff --git a/test/etmodels/test_etace.jl b/test/etmodels/test_etace.jl index 83895a81..3cd9fa32 100644 --- a/test/etmodels/test_etace.jl +++ b/test/etmodels/test_etace.jl @@ -1,6 +1,6 @@ -# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) -# using TestEnv; TestEnv.activate(); -# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using TestEnv; TestEnv.activate(); +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "Polynomials4ML.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) @@ -32,12 +32,13 @@ level = M.TotalDegree() max_level = 10 order = 3 maxl = 6 +rcut = 5.5 # modify rin0cuts to have same cutoff for all elements # TODO: there is currently a bug with variable cutoffs # (?is there? The radials seem fine? check again) rin0cuts = M._default_rin0cuts(elements) -rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) model = M.ace_model(; elements = elements, order = order, @@ -48,6 +49,10 @@ model = M.ace_model(; elements = elements, order = order, ps, st = Lux.setup(rng, model) +# wrap the old ACE model into a calculator +calc_model = ACEpotentials.ACEPotential(model, ps, st) + + # Missing issues: # Vref = 0 => this will not be tested # pair potential will also not be tested @@ -62,8 +67,8 @@ end # Convert the v0.8 model to an ET backend based model based on the # implementation in ETM # -et_model_2 = ETM.convert2et(model) -et_ps_2, et_st_2 = LuxCore.setup(MersenneTwister(1234), et_model_2) +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) ## # fixup all the parameters to make sure they match @@ -71,32 +76,38 @@ et_ps_2, et_st_2 = LuxCore.setup(MersenneTwister(1234), et_model_2) # is because meta["mb_spec"] only gives the original ordering before basis # construction ... something to look into. nnll = M.get_nnll_spec(model.tensor) -et_nnll_2 = et_model_2.basis.meta["mb_spec"] +et_nnll = et_model.basis.meta["mb_spec"] @info("Check basis ordering") -println_slim(@test nnll == et_nnll_2) +println_slim(@test nnll == et_nnll) # but this is also identical ... @info("Check symmetrization operator") -@show ( model.tensor.A2Bmaps[1] == et_model_2.basis.A2Bmaps[1] ) +@show ( model.tensor.A2Bmaps[1] == et_model.basis.A2Bmaps[1] ) -# radial basis parameters for et_model_2 -et_ps_2.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] -et_ps_2.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] -et_ps_2.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] -et_ps_2.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] +# radial basis parameters for et_model +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] -# many-body basis parameters for et_model_2 -et_ps_2.readout.W[1, :, 1] .= ps.WB[:, 1] -et_ps_2.readout.W[1, :, 2] .= ps.WB[:, 2] +# many-body basis parameters for et_model +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps.readout.W[1, :, 2] .= ps.WB[:, 2] ## -# wrap the old ACE model into a calculator -calc_model = ACEpotentials.ACEPotential(model, ps, st) +# setup two splined ACE models + +spl_50 = ETM.splinify(et_model, et_ps, et_st; Nspl = 50) +ps_50, st_50 = Lux.setup(rng, spl_50) +ps_50.readout.W[:] .= et_ps.readout.W[:] + +spl_200 = ETM.splinify(et_model, et_ps, et_st; Nspl = 200) +ps_200, st_200 = Lux.setup(rng, spl_200) +ps_200.readout.W[:] .= et_ps.readout.W[:] + +## -# we will also need to get the cutoff radius which we didn't track -# (Another TODO!!!) -rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) function rand_struct() sys = AtomsBuilder.bulk(:Si) * (2,2,1) @@ -105,21 +116,25 @@ function rand_struct() return sys end -function energy_new_2(sys, et_model) +function energy_new(sys, _model, _ps, _st) G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - Ei, _ = et_model_2(G, et_ps_2, et_st_2) + Ei, _ = _model(G, _ps, _st) return sum(Ei) end ## +Random.seed!(1234) @info("Check total energies match") for ntest = 1:30 sys = rand_struct() - G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - E1 = AtomsCalculators.potential_energy(sys, calc_model) - E3 = energy_new_2(sys, et_model_2) - print_tf( @test abs(ustrip(E1) - ustrip(E3)) < 1e-6 ) + E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) + E2 = energy_new(sys, et_model, et_ps, et_st) + E3 = energy_new(sys, spl_50, ps_50, st_50) + E4 = energy_new(sys, spl_200, ps_200, st_200) + print_tf( @test abs(E1 - E2) < 1e-6 ) + print_tf( @test abs(E2 - E3) / (1+abs(E2)+abs(E3)) < 1e-2 ) + print_tf( @test abs(E2 - E4) / (1+abs(E2)+abs(E4)) < 1e-4 ) end println() @@ -131,8 +146,8 @@ using Zygote, ForwardDiff sys = rand_struct() G = ET.Atoms.interaction_graph(sys, rcut * u"Å") -∂G2a = Zygote.gradient(G -> sum(et_model_2(G, et_ps_2, et_st_2)[1]), G)[1] -∂G2b = ETM.site_grads(et_model_2, G, et_ps_2, et_st_2) +∂G2a = Zygote.gradient(G -> sum(et_model(G, et_ps, et_st)[1]), G)[1] +∂G2b = ETM.site_grads(et_model, G, et_ps, et_st) @info("confirm consistency of Zygote and site_grads") println(@test all(∂G2a.edge_data .≈ ∂G2b.edge_data)) @@ -167,7 +182,7 @@ end @info("confirm consistency of gradients with ForwardDiff") -∇E_fd = grad_fd(G, et_model_2, et_ps_2, et_st_2) +∇E_fd = grad_fd(G, et_model, et_ps, et_st) println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data)) ## @@ -177,11 +192,11 @@ println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data)) G = ET.Atoms.interaction_graph(sys, rcut * u"Å") nnodes = length(G.node_data) -iZ = et_model_2.readout.selector.(G.node_data) -WW = et_ps_2.readout.W +iZ = et_model.readout.selector.(G.node_data) +WW = et_ps.readout.W -𝔹1 = ETM.site_basis(et_model_2, G, et_ps_2, et_st_2) -𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model_2, G, et_ps_2, et_st_2) +𝔹1 = ETM.site_basis(et_model, G, et_ps, et_st) +𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model, G, et_ps, et_st) ## @@ -189,7 +204,7 @@ WW = et_ps_2.readout.W println_slim(@test 𝔹1 ≈ 𝔹2) Ei_a = [ dot(𝔹2[i, :], WW[1, :, iZ[i]]) for (i, iz) in enumerate(iZ) ] -Ei_b = et_model_2(G, et_ps_2, et_st_2)[1][:] +Ei_b = et_model(G, et_ps, et_st)[1][:] println_slim(@test Ei_a ≈ Ei_b) ## @@ -229,21 +244,21 @@ G_32 = ET.float32(G) # move all data to the device G_32_dev = dev(G_32) -ps_dev_2 = dev(ET.float32(et_ps_2)) -st_dev_2 = dev(ET.float32(et_st_2)) -ps_32_2 = ET.float32(et_ps_2) -st_32_2 = ET.float32(et_st_2) +ps_dev = dev(ET.float32(et_ps)) +st_dev = dev(ET.float32(et_st)) +ps_32 = ET.float32(et_ps) +st_32 = ET.float32(et_st) E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) -E4 = sum(et_model_2(G_32_dev, ps_dev_2, st_dev_2)[1]) +E4 = sum(et_model(G_32_dev, ps_dev, st_dev)[1]) println_slim( @test abs(E1 - E4) / (abs(E1) + abs(E4) + 1e-7) < 1e-5 ) ## # gradients on GPU @info("Check Evaluation of gradient on GPU") -g1 = ETM.site_grads(et_model_2, G_32, ps_32_2, st_32_2) -g2_dev = ETM.site_grads(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +g1 = ETM.site_grads(et_model, G_32, ps_32, st_32) +g2_dev = ETM.site_grads(et_model, G_32_dev, ps_dev, st_dev) ∇1 = g1.edge_data ∇2 = Array(g2_dev.edge_data) println_slim( @test all(∇1 .≈ ∇2) ) @@ -252,14 +267,14 @@ println_slim( @test all(∇1 .≈ ∇2) ) @info("Basis evaluation on GPU") -𝔹1 = ETM.site_basis(et_model_2, G_32, ps_32_2, st_32_2) -𝔹2_dev = ETM.site_basis(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +𝔹1 = ETM.site_basis(et_model, G_32, ps_32, st_32) +𝔹2_dev = ETM.site_basis(et_model, G_32_dev, ps_dev, st_dev) 𝔹2 = Array(𝔹2_dev) println_slim( @test 𝔹1 ≈ 𝔹2 ) @info("Basis jacobian evaluation on GPU") -𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model_2, G_32, ps_32_2, st_32_2) -𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model, G_32, ps_32, st_32) +𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model, G_32_dev, ps_dev, st_dev) 𝔹2 = Array(𝔹2_dev) ∂𝔹2 = Array(∂𝔹2_dev) From f3ab005257cdd965528bc18a4f7ebb8fed688d27 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Dec 2025 16:46:00 -0800 Subject: [PATCH 08/10] all spline tests except gpu --- test/etmodels/test_etace.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/etmodels/test_etace.jl b/test/etmodels/test_etace.jl index 3cd9fa32..a1567b33 100644 --- a/test/etmodels/test_etace.jl +++ b/test/etmodels/test_etace.jl @@ -149,9 +149,17 @@ G = ET.Atoms.interaction_graph(sys, rcut * u"Å") ∂G2a = Zygote.gradient(G -> sum(et_model(G, et_ps, et_st)[1]), G)[1] ∂G2b = ETM.site_grads(et_model, G, et_ps, et_st) +∂G_50 = ETM.site_grads(spl_50, G, ps_50, st_50) +∂G_200 = ETM.site_grads(spl_200, G, ps_200, st_200) + @info("confirm consistency of Zygote and site_grads") println(@test all(∂G2a.edge_data .≈ ∂G2b.edge_data)) +err_50 = maximum(norm.(∂G2b.edge_data - ∂G_50.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_50.edge_data))) +err_200 = maximum(norm.(∂G2b.edge_data - ∂G_200.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_200.edge_data))) +println_slim(@test err_50 < 1) +println_slim(@test err_200 < 0.01) + ## # test gradient against ForwardDiff @@ -207,6 +215,16 @@ Ei_a = [ dot(𝔹2[i, :], WW[1, :, iZ[i]]) for (i, iz) in enumerate(iZ) ] Ei_b = et_model(G, et_ps, et_st)[1][:] println_slim(@test Ei_a ≈ Ei_b) +## + +@info("splined site basis") +𝔹_200 = ETM.site_basis(spl_200, G, ps_200, st_200) +𝔹2_200, ∂𝔹2_200 = ETM.site_basis_jacobian(spl_200, G, ps_200, st_200) + +println_slim(@test 𝔹_200 ≈ 𝔹2_200 ) +println_slim(@test norm(𝔹1 - 𝔹_200, Inf) < 3e-3) +println_slim(@test maximum(norm.(∂𝔹2 - ∂𝔹2_200)) < 0.1) + ## @info("Confirm correctness of Jacobian against gradient") From 7589eaa8e7b35a53f9c75999b57f708a1b12d921 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 31 Dec 2025 16:05:55 -0800 Subject: [PATCH 09/10] adjust splines code to ET 0.4.3 --- src/et_models/splinify.jl | 9 ++------- test/etmodels/test_etace.jl | 13 +++++++++---- test/etmodels/test_etpair.jl | 19 ++++++++++++++++++- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/et_models/splinify.jl b/src/et_models/splinify.jl index 8b3f33d0..e36359ce 100644 --- a/src/et_models/splinify.jl +++ b/src/et_models/splinify.jl @@ -19,14 +19,12 @@ function splinify(et_pair::ETPairModel, et_ps, et_st; splines = [ P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) for i in 1:size(WW, 3) ] - # extract the spline parameters into an array of parameter sets - states = [ P4ML._init_luxstate(spl) for spl in splines ] # selects the correct spline based on the (Zi, Zj) pair selector2 = et_pair.rembed.layer.rbasis.post.selector # envelope multiplying the spline envelope = et_pair.rembed.layer.envelope - spl_rbasis = ET.TransSelSplines(trans_y, envelope, selector2, splines[1], states) + spl_rbasis = ET.trans_splines(trans_y, splines, selector2, envelope) return ETPairModel( ET.EdgeEmbed(spl_rbasis), et_pair.readout ) end @@ -52,11 +50,8 @@ function splinify(et_model::ETACE, et_ps, et_st; Nspl = 50) splines = [ P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) for i in 1:size(WW, 3) ] - # extract the spline parameters into an array of parameter sets - states = [ P4ML._init_luxstate(spl) for spl in splines ] - rembed_spl = ET.TransSelSplines(trans, trans_yenv, selector2, - splines[1], states) + rembed_spl = ET.trans_splines(trans, splines, selector2, trans_yenv) ace_spl = ETACE( ET.EdgeEmbed(rembed_spl), et_model.yembed, et_model.basis, diff --git a/test/etmodels/test_etace.jl b/test/etmodels/test_etace.jl index a1567b33..96dc3b87 100644 --- a/test/etmodels/test_etace.jl +++ b/test/etmodels/test_etace.jl @@ -1,6 +1,6 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) -using TestEnv; TestEnv.activate(); -Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "Polynomials4ML.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) @@ -271,6 +271,11 @@ E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) E4 = sum(et_model(G_32_dev, ps_dev, st_dev)[1]) println_slim( @test abs(E1 - E4) / (abs(E1) + abs(E4) + 1e-7) < 1e-5 ) +# Something still wrong evaluating the splines on GPU +# ps_50_dev = dev(ET.float32(ps_50)) +# st_50_dev = dev(ET.float32(st_50)) +# E5 = sum(spl_50(G_32_dev, ps_50_dev, st_50_dev)[1]) + ## # gradients on GPU @@ -303,4 +308,4 @@ println_slim( @test maximum(err_jac) < 1e-4 ) @show maximum(err_jac) @info("The jacobian error feels a bit large. This may need further investigation.") -=# \ No newline at end of file +=# \ No newline at end of file diff --git a/test/etmodels/test_etpair.jl b/test/etmodels/test_etpair.jl index 86de71a2..350fc6e3 100644 --- a/test/etmodels/test_etpair.jl +++ b/test/etmodels/test_etpair.jl @@ -1,6 +1,6 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) using TestEnv; TestEnv.activate(); -Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) @@ -188,11 +188,27 @@ E2_dev, st_dev = et_pair(G_dev, ps_dev, st_dev) E2 = Array(E2_dev) println_slim(@test E1 ≈ E2) +## + +@info(" .... with splines") +ps_50_32 = ET.float32(ps_50) +st_50_32 = ET.float32(st_50) +ps_50_dev = dev(ET.float32(ps_50)) +st_50_dev = dev(ET.float32(st_50)) +E3a, _ = spl_50(G_32, ps_50_32, st_50_32) +E3b_dev, _ = spl_50(G_dev, ps_50_dev, st_50_dev) +E3b = Array(E3b_dev) +println_slim(@test E3a ≈ E3b) + +## + +@info(" .... gradients on GPU") g1 = ETM.site_grads(et_pair, G_32, ps_32, st_32) g2_dev = ETM.site_grads(et_pair, G_dev, ps_dev, st_dev) g2_edge = Array(g2_dev.edge_data) println_slim(@test all(g1.edge_data .≈ g2_edge)) +@info(" .... basis on GPU") b1 = ETM.site_basis(et_pair, G_32, ps_32, st_32) b2_dev = ETM.site_basis(et_pair, G_dev, ps_dev, st_dev) b2 = Array(b2_dev) @@ -207,4 +223,5 @@ jacerr = norm.(∂db1 .- ∂db2) ./ (1 .+ norm.(∂db1) + norm.(∂db2)) @show maximum(jacerr) println_slim( @test maximum(jacerr) < 1e-4 ) +## =# \ No newline at end of file From 4a336931e03a9d8c6d262ea2a555ddeaf1fbffc8 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 31 Dec 2025 16:08:23 -0800 Subject: [PATCH 10/10] cleanuo --- test/etmodels/test_etpair.jl | 4 +- test/etmodels/test_splines.jl | 94 ++--------------------------------- 2 files changed, 7 insertions(+), 91 deletions(-) diff --git a/test/etmodels/test_etpair.jl b/test/etmodels/test_etpair.jl index 350fc6e3..3084dd90 100644 --- a/test/etmodels/test_etpair.jl +++ b/test/etmodels/test_etpair.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) -using TestEnv; TestEnv.activate(); +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) diff --git a/test/etmodels/test_splines.jl b/test/etmodels/test_splines.jl index 8380fb3c..f5cff75b 100644 --- a/test/etmodels/test_splines.jl +++ b/test/etmodels/test_splines.jl @@ -1,7 +1,7 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) -using TestEnv; TestEnv.activate(); -Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) -Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) # Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) ## @@ -94,7 +94,7 @@ selector2 = et_pair.rembed.layer.rbasis.post.selector trans_y = et_pair.rembed.layer.rbasis.trans envelope = et_pair.rembed.layer.envelope -spl_rbasis = ET.TransSelSplines(trans_y, envelope, selector2, splines[1], states) +spl_rbasis = ET.trans_splines(trans_y, splines, selector2, envelope) ps_spl, st_spl = LuxCore.setup(rng, spl_rbasis) poly_rbasis = et_pair.rembed.layer @@ -150,87 +150,3 @@ println_slim(@test df0 ≈ dp) ## - - -#= - -@info("Check total energies match") -for ntest = 1:30 - sys = rand_struct() - G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - E1 = AtomsCalculators.potential_energy(sys, calc_model) - E2 = energy_new(sys, et_pair) - print_tf( @test abs(ustrip(E1) - ustrip(E2)) < 1e-6 ) -end - -## - -@info("Check gradients and jacobians") - -sys = rand_struct() -G = ET.Atoms.interaction_graph(sys, rcut * u"Å") -nnodes = length(G.node_data) -iZ = et_pair.readout.selector.(G.node_data) -WW = et_ps.readout.W - -# gradient of model w.r.t. positions -∂G = ETM.site_grads(et_pair, G, et_ps, et_st) # test run - -# basis -𝔹1 = ETM.site_basis(et_pair, G, et_ps, et_st) - -# basis jacobian -𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_pair, G, et_ps, et_st) - -println_slim(@test 𝔹1 ≈ 𝔹2) - -∇Ei2 = reduce( hcat, ∂𝔹2[:, i, :] * WW[1, :, iZ[i]] - for (i, iz) in enumerate(iZ) ) -∇Ei3 = reshape(∇Ei2, size(∇Ei2)..., 1) -∇E_𝔹_edges = ET.rev_reshape_embedding(∇Ei3, G)[:] -println_slim(@test all(∇E_𝔹_edges .≈ ∂G.edge_data)) - -## - - -# turn off during CI -- need to sort out CI for GPU tests - - -@info("Check GPU evaluation") -using Metal -dev = Metal.mtl -ps_32 = ET.float32(et_ps) -st_32 = ET.float32(et_st) -ps_dev = dev(ps_32) -st_dev = dev(st_32) - -sys = rand_struct() -G = ET.Atoms.interaction_graph(sys, 5.0 * u"Å") -G_32 = ET.float32(G) -G_dev = dev(G_32) - -E1, st = et_pair(G_32, ps_32, st_32) -E2_dev, st_dev = et_pair(G_dev, ps_dev, st_dev) -E2 = Array(E2_dev) -println_slim(@test E1 ≈ E2) - -g1 = ETM.site_grads(et_pair, G_32, ps_32, st_32) -g2_dev = ETM.site_grads(et_pair, G_dev, ps_dev, st_dev) -g2_edge = Array(g2_dev.edge_data) -println_slim(@test all(g1.edge_data .≈ g2_edge)) - -b1 = ETM.site_basis(et_pair, G_32, ps_32, st_32) -b2_dev = ETM.site_basis(et_pair, G_dev, ps_dev, st_dev) -b2 = Array(b2_dev) -println_slim(@test b1 ≈ b2) - -b1, ∂db1 = ETM.site_basis_jacobian(et_pair, G_32, ps_32, st_32) -b2_dev, ∂db2_dev = ETM.site_basis_jacobian(et_pair, G_dev, ps_dev, st_dev) -b2 = Array(b2_dev) -∂db2 = Array(∂db2_dev) -println_slim(@test b1 ≈ b2) -jacerr = norm.(∂db1 .- ∂db2) ./ (1 .+ norm.(∂db1) + norm.(∂db2)) -@show maximum(jacerr) -println_slim( @test maximum(jacerr) < 1e-4 ) - -=# \ No newline at end of file