diff --git a/src/et_models/et_models.jl b/src/et_models/et_models.jl index 698b931e2..73fa48729 100644 --- a/src/et_models/et_models.jl +++ b/src/et_models/et_models.jl @@ -14,5 +14,8 @@ include("et_pair.jl") # converstion utilities: convert from 0.8 style ACE models to ET based models include("convert.jl") +# utilities to convert radial embeddings to splined versions +# for simplicity and performance and to freeze parameters +include("splinify.jl") end \ No newline at end of file diff --git a/src/et_models/splinify.jl b/src/et_models/splinify.jl new file mode 100644 index 000000000..e36359cea --- /dev/null +++ b/src/et_models/splinify.jl @@ -0,0 +1,60 @@ + +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +# These implementations of `splinify` expect a very specific structure of the +# pair potential basis. In principle it is possible to relax this +# considerably but it needs a little bit of thinking and planning/design +# work before just diving in. To be discussed when needed. + +function splinify(et_pair::ETPairModel, et_ps, et_st; + Nspl = 30) + + # polynomial basis taking y = y(r) as input + trans_y = et_pair.rembed.layer.rbasis.trans + polys_y = et_pair.rembed.layer.rbasis.basis + # weights for learnable radials + WW = et_ps.rembed.rbasis.post.W + # use P4ML to generate individual cubic splines + splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] + # selects the correct spline based on the (Zi, Zj) pair + selector2 = et_pair.rembed.layer.rbasis.post.selector + # envelope multiplying the spline + envelope = et_pair.rembed.layer.envelope + + spl_rbasis = ET.trans_splines(trans_y, splines, selector2, envelope) + + return ETPairModel( ET.EdgeEmbed(spl_rbasis), et_pair.readout ) +end + + +function splinify(et_model::ETACE, et_ps, et_st; Nspl = 50) + + rembed = et_model.rembed.layer # radial embedding, edgeembed stripped + trans = rembed.trans # x -> y dp_transform + rpolys_env = rembed.basis # polynomials * envelope + polys_y = rpolys_env.l.layers.layer_1 # polynomial basis + yenv_func = rpolys_env.l.layers.layer_2.func # envelope function + + # envelope multiplying the spline, apply the transformation a second + # time until we figure out how to reuse it conveniently + trans_yenv = ET.dp_transform( + (x, st) -> yenv_func(trans.f(x, st)), + trans.refstate ) + # selects the correct spline based on the (Zi, Zj) pair + selector2 = rembed.post.selector + # generate the splines using P4ML + WW = et_ps.rembed.post.W + splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] + + rembed_spl = ET.trans_splines(trans, splines, selector2, trans_yenv) + ace_spl = ETACE( ET.EdgeEmbed(rembed_spl), + et_model.yembed, + et_model.basis, + et_model.readout ) + return ace_spl +end \ No newline at end of file diff --git a/test/etmodels/test_etbackend.jl b/test/etmodels/test_etace.jl similarity index 64% rename from test/etmodels/test_etbackend.jl rename to test/etmodels/test_etace.jl index 83895a813..96dc3b87b 100644 --- a/test/etmodels/test_etbackend.jl +++ b/test/etmodels/test_etace.jl @@ -32,12 +32,13 @@ level = M.TotalDegree() max_level = 10 order = 3 maxl = 6 +rcut = 5.5 # modify rin0cuts to have same cutoff for all elements # TODO: there is currently a bug with variable cutoffs # (?is there? The radials seem fine? check again) rin0cuts = M._default_rin0cuts(elements) -rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) model = M.ace_model(; elements = elements, order = order, @@ -48,6 +49,10 @@ model = M.ace_model(; elements = elements, order = order, ps, st = Lux.setup(rng, model) +# wrap the old ACE model into a calculator +calc_model = ACEpotentials.ACEPotential(model, ps, st) + + # Missing issues: # Vref = 0 => this will not be tested # pair potential will also not be tested @@ -62,8 +67,8 @@ end # Convert the v0.8 model to an ET backend based model based on the # implementation in ETM # -et_model_2 = ETM.convert2et(model) -et_ps_2, et_st_2 = LuxCore.setup(MersenneTwister(1234), et_model_2) +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) ## # fixup all the parameters to make sure they match @@ -71,32 +76,38 @@ et_ps_2, et_st_2 = LuxCore.setup(MersenneTwister(1234), et_model_2) # is because meta["mb_spec"] only gives the original ordering before basis # construction ... something to look into. nnll = M.get_nnll_spec(model.tensor) -et_nnll_2 = et_model_2.basis.meta["mb_spec"] +et_nnll = et_model.basis.meta["mb_spec"] @info("Check basis ordering") -println_slim(@test nnll == et_nnll_2) +println_slim(@test nnll == et_nnll) # but this is also identical ... @info("Check symmetrization operator") -@show ( model.tensor.A2Bmaps[1] == et_model_2.basis.A2Bmaps[1] ) +@show ( model.tensor.A2Bmaps[1] == et_model.basis.A2Bmaps[1] ) -# radial basis parameters for et_model_2 -et_ps_2.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] -et_ps_2.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] -et_ps_2.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] -et_ps_2.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] +# radial basis parameters for et_model +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] -# many-body basis parameters for et_model_2 -et_ps_2.readout.W[1, :, 1] .= ps.WB[:, 1] -et_ps_2.readout.W[1, :, 2] .= ps.WB[:, 2] +# many-body basis parameters for et_model +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps.readout.W[1, :, 2] .= ps.WB[:, 2] ## -# wrap the old ACE model into a calculator -calc_model = ACEpotentials.ACEPotential(model, ps, st) +# setup two splined ACE models + +spl_50 = ETM.splinify(et_model, et_ps, et_st; Nspl = 50) +ps_50, st_50 = Lux.setup(rng, spl_50) +ps_50.readout.W[:] .= et_ps.readout.W[:] + +spl_200 = ETM.splinify(et_model, et_ps, et_st; Nspl = 200) +ps_200, st_200 = Lux.setup(rng, spl_200) +ps_200.readout.W[:] .= et_ps.readout.W[:] + +## -# we will also need to get the cutoff radius which we didn't track -# (Another TODO!!!) -rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) function rand_struct() sys = AtomsBuilder.bulk(:Si) * (2,2,1) @@ -105,21 +116,25 @@ function rand_struct() return sys end -function energy_new_2(sys, et_model) +function energy_new(sys, _model, _ps, _st) G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - Ei, _ = et_model_2(G, et_ps_2, et_st_2) + Ei, _ = _model(G, _ps, _st) return sum(Ei) end ## +Random.seed!(1234) @info("Check total energies match") for ntest = 1:30 sys = rand_struct() - G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - E1 = AtomsCalculators.potential_energy(sys, calc_model) - E3 = energy_new_2(sys, et_model_2) - print_tf( @test abs(ustrip(E1) - ustrip(E3)) < 1e-6 ) + E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) + E2 = energy_new(sys, et_model, et_ps, et_st) + E3 = energy_new(sys, spl_50, ps_50, st_50) + E4 = energy_new(sys, spl_200, ps_200, st_200) + print_tf( @test abs(E1 - E2) < 1e-6 ) + print_tf( @test abs(E2 - E3) / (1+abs(E2)+abs(E3)) < 1e-2 ) + print_tf( @test abs(E2 - E4) / (1+abs(E2)+abs(E4)) < 1e-4 ) end println() @@ -131,12 +146,20 @@ using Zygote, ForwardDiff sys = rand_struct() G = ET.Atoms.interaction_graph(sys, rcut * u"Å") -∂G2a = Zygote.gradient(G -> sum(et_model_2(G, et_ps_2, et_st_2)[1]), G)[1] -∂G2b = ETM.site_grads(et_model_2, G, et_ps_2, et_st_2) +∂G2a = Zygote.gradient(G -> sum(et_model(G, et_ps, et_st)[1]), G)[1] +∂G2b = ETM.site_grads(et_model, G, et_ps, et_st) + +∂G_50 = ETM.site_grads(spl_50, G, ps_50, st_50) +∂G_200 = ETM.site_grads(spl_200, G, ps_200, st_200) @info("confirm consistency of Zygote and site_grads") println(@test all(∂G2a.edge_data .≈ ∂G2b.edge_data)) +err_50 = maximum(norm.(∂G2b.edge_data - ∂G_50.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_50.edge_data))) +err_200 = maximum(norm.(∂G2b.edge_data - ∂G_200.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_200.edge_data))) +println_slim(@test err_50 < 1) +println_slim(@test err_200 < 0.01) + ## # test gradient against ForwardDiff @@ -167,7 +190,7 @@ end @info("confirm consistency of gradients with ForwardDiff") -∇E_fd = grad_fd(G, et_model_2, et_ps_2, et_st_2) +∇E_fd = grad_fd(G, et_model, et_ps, et_st) println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data)) ## @@ -177,11 +200,11 @@ println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data)) G = ET.Atoms.interaction_graph(sys, rcut * u"Å") nnodes = length(G.node_data) -iZ = et_model_2.readout.selector.(G.node_data) -WW = et_ps_2.readout.W +iZ = et_model.readout.selector.(G.node_data) +WW = et_ps.readout.W -𝔹1 = ETM.site_basis(et_model_2, G, et_ps_2, et_st_2) -𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model_2, G, et_ps_2, et_st_2) +𝔹1 = ETM.site_basis(et_model, G, et_ps, et_st) +𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model, G, et_ps, et_st) ## @@ -189,9 +212,19 @@ WW = et_ps_2.readout.W println_slim(@test 𝔹1 ≈ 𝔹2) Ei_a = [ dot(𝔹2[i, :], WW[1, :, iZ[i]]) for (i, iz) in enumerate(iZ) ] -Ei_b = et_model_2(G, et_ps_2, et_st_2)[1][:] +Ei_b = et_model(G, et_ps, et_st)[1][:] println_slim(@test Ei_a ≈ Ei_b) +## + +@info("splined site basis") +𝔹_200 = ETM.site_basis(spl_200, G, ps_200, st_200) +𝔹2_200, ∂𝔹2_200 = ETM.site_basis_jacobian(spl_200, G, ps_200, st_200) + +println_slim(@test 𝔹_200 ≈ 𝔹2_200 ) +println_slim(@test norm(𝔹1 - 𝔹_200, Inf) < 3e-3) +println_slim(@test maximum(norm.(∂𝔹2 - ∂𝔹2_200)) < 0.1) + ## @info("Confirm correctness of Jacobian against gradient") @@ -229,21 +262,26 @@ G_32 = ET.float32(G) # move all data to the device G_32_dev = dev(G_32) -ps_dev_2 = dev(ET.float32(et_ps_2)) -st_dev_2 = dev(ET.float32(et_st_2)) -ps_32_2 = ET.float32(et_ps_2) -st_32_2 = ET.float32(et_st_2) +ps_dev = dev(ET.float32(et_ps)) +st_dev = dev(ET.float32(et_st)) +ps_32 = ET.float32(et_ps) +st_32 = ET.float32(et_st) E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) -E4 = sum(et_model_2(G_32_dev, ps_dev_2, st_dev_2)[1]) +E4 = sum(et_model(G_32_dev, ps_dev, st_dev)[1]) println_slim( @test abs(E1 - E4) / (abs(E1) + abs(E4) + 1e-7) < 1e-5 ) +# Something still wrong evaluating the splines on GPU +# ps_50_dev = dev(ET.float32(ps_50)) +# st_50_dev = dev(ET.float32(st_50)) +# E5 = sum(spl_50(G_32_dev, ps_50_dev, st_50_dev)[1]) + ## # gradients on GPU @info("Check Evaluation of gradient on GPU") -g1 = ETM.site_grads(et_model_2, G_32, ps_32_2, st_32_2) -g2_dev = ETM.site_grads(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +g1 = ETM.site_grads(et_model, G_32, ps_32, st_32) +g2_dev = ETM.site_grads(et_model, G_32_dev, ps_dev, st_dev) ∇1 = g1.edge_data ∇2 = Array(g2_dev.edge_data) println_slim( @test all(∇1 .≈ ∇2) ) @@ -252,14 +290,14 @@ println_slim( @test all(∇1 .≈ ∇2) ) @info("Basis evaluation on GPU") -𝔹1 = ETM.site_basis(et_model_2, G_32, ps_32_2, st_32_2) -𝔹2_dev = ETM.site_basis(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +𝔹1 = ETM.site_basis(et_model, G_32, ps_32, st_32) +𝔹2_dev = ETM.site_basis(et_model, G_32_dev, ps_dev, st_dev) 𝔹2 = Array(𝔹2_dev) println_slim( @test 𝔹1 ≈ 𝔹2 ) @info("Basis jacobian evaluation on GPU") -𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model_2, G_32, ps_32_2, st_32_2) -𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model, G_32, ps_32, st_32) +𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model, G_32_dev, ps_dev, st_dev) 𝔹2 = Array(𝔹2_dev) ∂𝔹2 = Array(∂𝔹2_dev) @@ -270,4 +308,4 @@ println_slim( @test maximum(err_jac) < 1e-4 ) @show maximum(err_jac) @info("The jacobian error feels a bit large. This may need further investigation.") -=# \ No newline at end of file +=# \ No newline at end of file diff --git a/test/etmodels/test_etpair.jl b/test/etmodels/test_etpair.jl index 73c5c9dd4..3084dd90e 100644 --- a/test/etmodels/test_etpair.jl +++ b/test/etmodels/test_etpair.jl @@ -74,6 +74,22 @@ et_ps.rembed.rbasis.post.W[:, :, 4] = ps.pairbasis.Wnlq[:, :, 2, 2] et_ps.readout.W[1, :, 1] .= ps.Wpair[:, 1] et_ps.readout.W[1, :, 2] .= ps.Wpair[:, 2] +## +# +# make a splined version of the et_pair model +# + +spl_50 = ETM.splinify(et_pair, et_ps, et_st; Nspl = 50) +ps_50, st_50 = Lux.setup(rng, spl_50) + +spl_200 = ETM.splinify(et_pair, et_ps, et_st; Nspl = 200) +ps_200, st_200 = Lux.setup(rng, spl_200) + +# many-body basis parameters for et_model_2 +ps_50.readout.W[:] = et_ps.readout.W +ps_200.readout.W[:] = et_ps.readout.W + + ## # # test energy evaluations @@ -88,21 +104,26 @@ function rand_struct() return sys end -function energy_new(sys, et_model) +function energy_new(sys, et_model, ps, st) G = ET.Atoms.interaction_graph(sys, rcut * u"Å") - Ei, _ = et_model(G, et_ps, et_st) + Ei, _ = et_model(G, ps, st) return sum(Ei) end ## +Random.seed!(1234) @info("Check total energies match") for ntest = 1:30 sys = rand_struct() G = ET.Atoms.interaction_graph(sys, rcut * u"Å") E1 = AtomsCalculators.potential_energy(sys, calc_model) - E2 = energy_new(sys, et_pair) + E2 = energy_new(sys, et_pair, et_ps, et_st) + E_50 = energy_new(sys, spl_50, ps_50, st_50) + E_200 = energy_new(sys, spl_200, ps_200, st_200) print_tf( @test abs(ustrip(E1) - ustrip(E2)) < 1e-6 ) + print_tf( @test abs(ustrip(E2) - ustrip(E_50)) < 1e-2 ) + print_tf( @test abs(ustrip(E2) - ustrip(E_200)) < 1e-4 ) end ## @@ -116,15 +137,20 @@ iZ = et_pair.readout.selector.(G.node_data) WW = et_ps.readout.W # gradient of model w.r.t. positions -∂G = ETM.site_grads(et_pair, G, et_ps, et_st) # test run +∂G = ETM.site_grads(et_pair, G, et_ps, et_st) +∂G_200 = ETM.site_grads(spl_200, G, ps_200, st_200) # basis 𝔹1 = ETM.site_basis(et_pair, G, et_ps, et_st) +𝔹1_200 = ETM.site_basis(spl_200, G, ps_200, st_200) # basis jacobian 𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_pair, G, et_ps, et_st) +𝔹2_200, ∂𝔹2_200 = ETM.site_basis_jacobian(spl_200, G, ps_200, st_200) println_slim(@test 𝔹1 ≈ 𝔹2) +println_slim(@test 𝔹1_200 ≈ 𝔹2_200) +println_slim(@test norm(𝔹1 - 𝔹1_200, Inf) < 1e-4) ∇Ei2 = reduce( hcat, ∂𝔹2[:, i, :] * WW[1, :, iZ[i]] for (i, iz) in enumerate(iZ) ) @@ -132,6 +158,11 @@ println_slim(@test 𝔹1 ≈ 𝔹2) ∇E_𝔹_edges = ET.rev_reshape_embedding(∇Ei3, G)[:] println_slim(@test all(∇E_𝔹_edges .≈ ∂G.edge_data)) +# check error in site energy gradients for splines +println_slim(@test maximum(norm.(∂G.edge_data - ∂G_200.edge_data)) < 1e-3) +# check error in basis jacobian for splines +println_slim(@test maximum(norm.(∂𝔹2 - ∂𝔹2_200)) < 1e-3) + ## @@ -157,11 +188,27 @@ E2_dev, st_dev = et_pair(G_dev, ps_dev, st_dev) E2 = Array(E2_dev) println_slim(@test E1 ≈ E2) +## + +@info(" .... with splines") +ps_50_32 = ET.float32(ps_50) +st_50_32 = ET.float32(st_50) +ps_50_dev = dev(ET.float32(ps_50)) +st_50_dev = dev(ET.float32(st_50)) +E3a, _ = spl_50(G_32, ps_50_32, st_50_32) +E3b_dev, _ = spl_50(G_dev, ps_50_dev, st_50_dev) +E3b = Array(E3b_dev) +println_slim(@test E3a ≈ E3b) + +## + +@info(" .... gradients on GPU") g1 = ETM.site_grads(et_pair, G_32, ps_32, st_32) g2_dev = ETM.site_grads(et_pair, G_dev, ps_dev, st_dev) g2_edge = Array(g2_dev.edge_data) println_slim(@test all(g1.edge_data .≈ g2_edge)) +@info(" .... basis on GPU") b1 = ETM.site_basis(et_pair, G_32, ps_32, st_32) b2_dev = ETM.site_basis(et_pair, G_dev, ps_dev, st_dev) b2 = Array(b2_dev) @@ -176,4 +223,5 @@ jacerr = norm.(∂db1 .- ∂db2) ./ (1 .+ norm.(∂db1) + norm.(∂db2)) @show maximum(jacerr) println_slim( @test maximum(jacerr) < 1e-4 ) +## =# \ No newline at end of file diff --git a/test/etmodels/test_splines.jl b/test/etmodels/test_splines.jl new file mode 100644 index 000000000..f5cff75b4 --- /dev/null +++ b/test/etmodels/test_splines.jl @@ -0,0 +1,152 @@ +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) + +## + +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful, + AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase, + ForwardDiff + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML +import DecoratedParticles as DP + +using Polynomials4ML.Testing: print_tf, println_slim + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## + +# Generate an ACE model in the v0.8 style but +# - with fixed rcut. (relaxe this requirement later!!) +# get the pair potential component, compare with ETPairModel +# make pair_learnable = true to prevent splinification. + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 +rcut = 5.5 + +# modify rin0cuts to have same cutoff for all elements +# TODO: there is currently a bug with variable cutoffs +# (?is there? The radials seem fine? check again) +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) + + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true ) + +Random.seed!(1234) # new seed to make sure the tests are consistent +ps, st = Lux.setup(rng, model) + +# confirm that the E0s are all zero +@assert all( collect(values(model.Vref.E0)) .== 0 ) + +# set the many-body parameters to zero to isolate the pair potential +ps.WB[:] .= 0 + +## +# +# construct an ETPairModel that is consistent with `model` +# fixup the parameters to match the ACE model + +et_pair = ETM.convertpair(model) +et_ps, et_st = Lux.setup(rng, et_pair) + +# radial basis parameters for et_model_2 +et_ps.rembed.rbasis.post.W[:, :, 1] = ps.pairbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.rbasis.post.W[:, :, 2] = ps.pairbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.rbasis.post.W[:, :, 3] = ps.pairbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.rbasis.post.W[:, :, 4] = ps.pairbasis.Wnlq[:, :, 2, 2] + +# many-body basis parameters for et_model_2 +et_ps.readout.W[1, :, 1] .= ps.Wpair[:, 1] +et_ps.readout.W[1, :, 2] .= ps.Wpair[:, 2] + +## +# convert the pair basis to a splined version + +# overkill spline accuracy to check errors +Nspl = 200 + +# polynomial basis taking y = y(r) as input +polys_y = et_pair.rembed.layer.rbasis.basis +# weights for cat-1 +WW = et_ps.rembed.rbasis.post.W +splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] +states = [ P4ML._init_luxstate(spl) for spl in splines ] +selector2 = et_pair.rembed.layer.rbasis.post.selector +trans_y = et_pair.rembed.layer.rbasis.trans +envelope = et_pair.rembed.layer.envelope + +spl_rbasis = ET.trans_splines(trans_y, splines, selector2, envelope) +ps_spl, st_spl = LuxCore.setup(rng, spl_rbasis) + +poly_rbasis = et_pair.rembed.layer +ps_poly = et_ps.rembed +st_poly = et_st.rembed + +## + +function rand_X() + sys = AtomsBuilder.bulk(:Si) * (2,2,1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + return G.edge_data +end + + +## + +@info("Checking spline accuracy against polynomial basis") +Random.seed!(1234) # new seed to make sure the tests are ok. +for ntest = 1:30 + X = rand_X() + P1, _ = poly_rbasis(X, ps_poly, st_poly) + P2, _ = spl_rbasis(X, ps_spl, st_spl) + spl_err = abs.(P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1) + # @show maximum(spl_err) + print_tf(@test maximum(spl_err) < 1e-5) + + (P1a, dP1), _ = ET.evaluate_ed(poly_rbasis, X, ps_poly, st_poly) + (P2a, dP2), _ = ET.evaluate_ed(spl_rbasis, X, ps_spl, st_spl) + print_tf(@test P2 ≈ P2a) + dspl_err = norm.(dP1 - dP2) ./ (1 .+ abs.(P1) + abs.(P2)) + # @show maximum(dspl_err) + print_tf(@test maximum(dspl_err) < 1e-3) +end + +## + +@info("Checking machine precision derivative accuracy ") +# NOTE: This test should really be in ET and not here ... + +X = rand_X() +rand_u() = ( u = (@SVector randn(3)); DP.VState(𝐫 = u/norm(u)) ) +U = [ rand_u() for _ = 1:length(X) ] + +f(t) = spl_rbasis(X + t * U, ps_spl, st_spl)[1] +df0 = ForwardDiff.derivative(f, 0.0) + +(P2a, dP2), _ = ET.evaluate_ed(spl_rbasis, X, ps_spl, st_spl) +dp = [ dot(U[i], dP2[i, j]) for i in 1:length(U), j = 1:size(dP2, 2) ] +println_slim(@test df0 ≈ dp) + + +## diff --git a/test/runtests.jl b/test/runtests.jl index 3da7fde65..624c7b65f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,7 @@ using ACEpotentials, Test, LazyArtifacts @testset "Weird bugs" begin include("test_bugs.jl") end # new ET backend tests - @testset "ET ACE" begin include("etmodels/test_etbackend.jl") end + @testset "ET ACE" begin include("etmodels/test_etace.jl") end @testset "ET OneBody" begin include("etmodels/test_etonebody.jl") end @testset "ET Pair" begin include("etmodels/test_etpair.jl") end