Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/et_models/et_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@ include("et_pair.jl")
# converstion utilities: convert from 0.8 style ACE models to ET based models
include("convert.jl")

# utilities to convert radial embeddings to splined versions
# for simplicity and performance and to freeze parameters
include("splinify.jl")

end
60 changes: 60 additions & 0 deletions src/et_models/splinify.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

import EquivariantTensors as ET
import Polynomials4ML as P4ML

# These implementations of `splinify` expect a very specific structure of the
# pair potential basis. In principle it is possible to relax this
# considerably but it needs a little bit of thinking and planning/design
# work before just diving in. To be discussed when needed.

function splinify(et_pair::ETPairModel, et_ps, et_st;
Nspl = 30)

# polynomial basis taking y = y(r) as input
trans_y = et_pair.rembed.layer.rbasis.trans
polys_y = et_pair.rembed.layer.rbasis.basis
# weights for learnable radials
WW = et_ps.rembed.rbasis.post.W
# use P4ML to generate individual cubic splines
splines = [
P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl )
for i in 1:size(WW, 3) ]
# selects the correct spline based on the (Zi, Zj) pair
selector2 = et_pair.rembed.layer.rbasis.post.selector
# envelope multiplying the spline
envelope = et_pair.rembed.layer.envelope

spl_rbasis = ET.trans_splines(trans_y, splines, selector2, envelope)

return ETPairModel( ET.EdgeEmbed(spl_rbasis), et_pair.readout )
end


function splinify(et_model::ETACE, et_ps, et_st; Nspl = 50)

rembed = et_model.rembed.layer # radial embedding, edgeembed stripped
trans = rembed.trans # x -> y dp_transform
rpolys_env = rembed.basis # polynomials * envelope
polys_y = rpolys_env.l.layers.layer_1 # polynomial basis
yenv_func = rpolys_env.l.layers.layer_2.func # envelope function

# envelope multiplying the spline, apply the transformation a second
# time until we figure out how to reuse it conveniently
trans_yenv = ET.dp_transform(
(x, st) -> yenv_func(trans.f(x, st)),
trans.refstate )
# selects the correct spline based on the (Zi, Zj) pair
selector2 = rembed.post.selector
# generate the splines using P4ML
WW = et_ps.rembed.post.W
splines = [
P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl )
for i in 1:size(WW, 3) ]

rembed_spl = ET.trans_splines(trans, splines, selector2, trans_yenv)
ace_spl = ETACE( ET.EdgeEmbed(rembed_spl),
et_model.yembed,
et_model.basis,
et_model.readout )
return ace_spl
end
128 changes: 83 additions & 45 deletions test/etmodels/test_etbackend.jl → test/etmodels/test_etace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ level = M.TotalDegree()
max_level = 10
order = 3
maxl = 6
rcut = 5.5

# modify rin0cuts to have same cutoff for all elements
# TODO: there is currently a bug with variable cutoffs
# (?is there? The radials seem fine? check again)
rin0cuts = M._default_rin0cuts(elements)
rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts)
rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts)


model = M.ace_model(; elements = elements, order = order,
Expand All @@ -48,6 +49,10 @@ model = M.ace_model(; elements = elements, order = order,

ps, st = Lux.setup(rng, model)

# wrap the old ACE model into a calculator
calc_model = ACEpotentials.ACEPotential(model, ps, st)


# Missing issues:
# Vref = 0 => this will not be tested
# pair potential will also not be tested
Expand All @@ -62,41 +67,47 @@ end
# Convert the v0.8 model to an ET backend based model based on the
# implementation in ETM
#
et_model_2 = ETM.convert2et(model)
et_ps_2, et_st_2 = LuxCore.setup(MersenneTwister(1234), et_model_2)
et_model = ETM.convert2et(model)
et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model)

##
# fixup all the parameters to make sure they match
# the basis ordering appears to be identical, but it is not clear it really
# is because meta["mb_spec"] only gives the original ordering before basis
# construction ... something to look into.
nnll = M.get_nnll_spec(model.tensor)
et_nnll_2 = et_model_2.basis.meta["mb_spec"]
et_nnll = et_model.basis.meta["mb_spec"]
@info("Check basis ordering")
println_slim(@test nnll == et_nnll_2)
println_slim(@test nnll == et_nnll)

# but this is also identical ...
@info("Check symmetrization operator")
@show ( model.tensor.A2Bmaps[1] == et_model_2.basis.A2Bmaps[1] )
@show ( model.tensor.A2Bmaps[1] == et_model.basis.A2Bmaps[1] )

# radial basis parameters for et_model_2
et_ps_2.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1]
et_ps_2.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2]
et_ps_2.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1]
et_ps_2.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2]
# radial basis parameters for et_model
et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1]
et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2]
et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1]
et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2]

# many-body basis parameters for et_model_2
et_ps_2.readout.W[1, :, 1] .= ps.WB[:, 1]
et_ps_2.readout.W[1, :, 2] .= ps.WB[:, 2]
# many-body basis parameters for et_model
et_ps.readout.W[1, :, 1] .= ps.WB[:, 1]
et_ps.readout.W[1, :, 2] .= ps.WB[:, 2]

##

# wrap the old ACE model into a calculator
calc_model = ACEpotentials.ACEPotential(model, ps, st)
# setup two splined ACE models

spl_50 = ETM.splinify(et_model, et_ps, et_st; Nspl = 50)
ps_50, st_50 = Lux.setup(rng, spl_50)
ps_50.readout.W[:] .= et_ps.readout.W[:]

spl_200 = ETM.splinify(et_model, et_ps, et_st; Nspl = 200)
ps_200, st_200 = Lux.setup(rng, spl_200)
ps_200.readout.W[:] .= et_ps.readout.W[:]

##

# we will also need to get the cutoff radius which we didn't track
# (Another TODO!!!)
rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts)

function rand_struct()
sys = AtomsBuilder.bulk(:Si) * (2,2,1)
Expand All @@ -105,21 +116,25 @@ function rand_struct()
return sys
end

function energy_new_2(sys, et_model)
function energy_new(sys, _model, _ps, _st)
G = ET.Atoms.interaction_graph(sys, rcut * u"Å")
Ei, _ = et_model_2(G, et_ps_2, et_st_2)
Ei, _ = _model(G, _ps, _st)
return sum(Ei)
end

##

Random.seed!(1234)
@info("Check total energies match")
for ntest = 1:30
sys = rand_struct()
G = ET.Atoms.interaction_graph(sys, rcut * u"Å")
E1 = AtomsCalculators.potential_energy(sys, calc_model)
E3 = energy_new_2(sys, et_model_2)
print_tf( @test abs(ustrip(E1) - ustrip(E3)) < 1e-6 )
E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model))
E2 = energy_new(sys, et_model, et_ps, et_st)
E3 = energy_new(sys, spl_50, ps_50, st_50)
E4 = energy_new(sys, spl_200, ps_200, st_200)
print_tf( @test abs(E1 - E2) < 1e-6 )
print_tf( @test abs(E2 - E3) / (1+abs(E2)+abs(E3)) < 1e-2 )
print_tf( @test abs(E2 - E4) / (1+abs(E2)+abs(E4)) < 1e-4 )
end
println()

Expand All @@ -131,12 +146,20 @@ using Zygote, ForwardDiff

sys = rand_struct()
G = ET.Atoms.interaction_graph(sys, rcut * u"Å")
∂G2a = Zygote.gradient(G -> sum(et_model_2(G, et_ps_2, et_st_2)[1]), G)[1]
∂G2b = ETM.site_grads(et_model_2, G, et_ps_2, et_st_2)
∂G2a = Zygote.gradient(G -> sum(et_model(G, et_ps, et_st)[1]), G)[1]
∂G2b = ETM.site_grads(et_model, G, et_ps, et_st)

∂G_50 = ETM.site_grads(spl_50, G, ps_50, st_50)
∂G_200 = ETM.site_grads(spl_200, G, ps_200, st_200)

@info("confirm consistency of Zygote and site_grads")
println(@test all(∂G2a.edge_data .≈ ∂G2b.edge_data))

err_50 = maximum(norm.(∂G2b.edge_data - ∂G_50.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_50.edge_data)))
err_200 = maximum(norm.(∂G2b.edge_data - ∂G_200.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_200.edge_data)))
println_slim(@test err_50 < 1)
println_slim(@test err_200 < 0.01)

##
# test gradient against ForwardDiff

Expand Down Expand Up @@ -167,7 +190,7 @@ end

@info("confirm consistency of gradients with ForwardDiff")

∇E_fd = grad_fd(G, et_model_2, et_ps_2, et_st_2)
∇E_fd = grad_fd(G, et_model, et_ps, et_st)
println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data))

##
Expand All @@ -177,21 +200,31 @@ println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data))

G = ET.Atoms.interaction_graph(sys, rcut * u"Å")
nnodes = length(G.node_data)
iZ = et_model_2.readout.selector.(G.node_data)
WW = et_ps_2.readout.W
iZ = et_model.readout.selector.(G.node_data)
WW = et_ps.readout.W

𝔹1 = ETM.site_basis(et_model_2, G, et_ps_2, et_st_2)
𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model_2, G, et_ps_2, et_st_2)
𝔹1 = ETM.site_basis(et_model, G, et_ps, et_st)
𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model, G, et_ps, et_st)

##

@info("confirm correctness of site basis")

println_slim(@test 𝔹1 ≈ 𝔹2)
Ei_a = [ dot(𝔹2[i, :], WW[1, :, iZ[i]]) for (i, iz) in enumerate(iZ) ]
Ei_b = et_model_2(G, et_ps_2, et_st_2)[1][:]
Ei_b = et_model(G, et_ps, et_st)[1][:]
println_slim(@test Ei_a ≈ Ei_b)

##

@info("splined site basis")
𝔹_200 = ETM.site_basis(spl_200, G, ps_200, st_200)
𝔹2_200, ∂𝔹2_200 = ETM.site_basis_jacobian(spl_200, G, ps_200, st_200)

println_slim(@test 𝔹_200 ≈ 𝔹2_200 )
println_slim(@test norm(𝔹1 - 𝔹_200, Inf) < 3e-3)
println_slim(@test maximum(norm.(∂𝔹2 - ∂𝔹2_200)) < 0.1)

##

@info("Confirm correctness of Jacobian against gradient")
Expand Down Expand Up @@ -229,21 +262,26 @@ G_32 = ET.float32(G)

# move all data to the device
G_32_dev = dev(G_32)
ps_dev_2 = dev(ET.float32(et_ps_2))
st_dev_2 = dev(ET.float32(et_st_2))
ps_32_2 = ET.float32(et_ps_2)
st_32_2 = ET.float32(et_st_2)
ps_dev = dev(ET.float32(et_ps))
st_dev = dev(ET.float32(et_st))
ps_32 = ET.float32(et_ps)
st_32 = ET.float32(et_st)

E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model))
E4 = sum(et_model_2(G_32_dev, ps_dev_2, st_dev_2)[1])
E4 = sum(et_model(G_32_dev, ps_dev, st_dev)[1])
println_slim( @test abs(E1 - E4) / (abs(E1) + abs(E4) + 1e-7) < 1e-5 )

# Something still wrong evaluating the splines on GPU
# ps_50_dev = dev(ET.float32(ps_50))
# st_50_dev = dev(ET.float32(st_50))
# E5 = sum(spl_50(G_32_dev, ps_50_dev, st_50_dev)[1])

##
# gradients on GPU

@info("Check Evaluation of gradient on GPU")
g1 = ETM.site_grads(et_model_2, G_32, ps_32_2, st_32_2)
g2_dev = ETM.site_grads(et_model_2, G_32_dev, ps_dev_2, st_dev_2)
g1 = ETM.site_grads(et_model, G_32, ps_32, st_32)
g2_dev = ETM.site_grads(et_model, G_32_dev, ps_dev, st_dev)
∇1 = g1.edge_data
∇2 = Array(g2_dev.edge_data)
println_slim( @test all(∇1 .≈ ∇2) )
Expand All @@ -252,14 +290,14 @@ println_slim( @test all(∇1 .≈ ∇2) )

@info("Basis evaluation on GPU")

𝔹1 = ETM.site_basis(et_model_2, G_32, ps_32_2, st_32_2)
𝔹2_dev = ETM.site_basis(et_model_2, G_32_dev, ps_dev_2, st_dev_2)
𝔹1 = ETM.site_basis(et_model, G_32, ps_32, st_32)
𝔹2_dev = ETM.site_basis(et_model, G_32_dev, ps_dev, st_dev)
𝔹2 = Array(𝔹2_dev)
println_slim( @test 𝔹1 ≈ 𝔹2 )

@info("Basis jacobian evaluation on GPU")
𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model_2, G_32, ps_32_2, st_32_2)
𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model_2, G_32_dev, ps_dev_2, st_dev_2)
𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model, G_32, ps_32, st_32)
𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model, G_32_dev, ps_dev, st_dev)

𝔹2 = Array(𝔹2_dev)
∂𝔹2 = Array(∂𝔹2_dev)
Expand All @@ -270,4 +308,4 @@ println_slim( @test maximum(err_jac) < 1e-4 )
@show maximum(err_jac)
@info("The jacobian error feels a bit large. This may need further investigation.")

=#
=#
Loading
Loading