diff --git a/Project.toml b/Project.toml index 961df9590..77d88e76c 100644 --- a/Project.toml +++ b/Project.toml @@ -65,7 +65,7 @@ ConcreteStructs = "0.2.3" DecoratedParticles = "0.1.3" DynamicPolynomials = "0.6" EmpiricalPotentials = "0.2" -EquivariantTensors = "0.4" +EquivariantTensors = "0.4.3" ExtXYZ = "0.2.0" Folds = "0.2" ForwardDiff = "0.10, 1" @@ -79,7 +79,7 @@ OffsetArrays = "1" Optim = "1" Optimisers = "0.3.4, 0.4" OrderedCollections = "1" -Polynomials4ML = "0.5.6" +Polynomials4ML = "0.5" PrettyTables = "1.3, 2" Reexport = "1" Roots = "2" diff --git a/benchmark/benchmark_full_model.jl b/benchmark/benchmark_full_model.jl new file mode 100644 index 000000000..1a4c859f0 --- /dev/null +++ b/benchmark/benchmark_full_model.jl @@ -0,0 +1,198 @@ +# Benchmark: Full model (1+2+many body) with StackedCalculator +# Compares ACE CPU vs ETACE CPU vs ETACE GPU for energy and forces + +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators +using StaticArrays, Lux, Random, LuxCore, LinearAlgebra +using AtomsBase, AtomsBuilder, Unitful +using BenchmarkTools +using Printf + +# GPU detection +dev = identity +has_cuda = false + +try + using CUDA + if CUDA.functional() + @info "Using CUDA" + CUDA.versioninfo() + global has_cuda = true + global dev = cu + else + @info "CUDA is not functional" + end +catch e + @info "Couldn't load CUDA: $e" +end + +if !has_cuda + @info "No GPU available. Using CPU only." +end + +rng = Random.MersenneTwister(1234) + +# Build models with E0s and pair potential enabled +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +# E0s for one-body +E0s = Dict(:Si => -158.54496821, :O => -2042.0330099956639) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true, # Keep learnable for ET conversion + E0s = E0s) + +ps, st = Lux.setup(rng, model) + +# Create old ACE calculator (full model with E0s and pair) +ace_calc = M.ACEPotential(model, ps, st) + +# Convert to full ETACE with StackedCalculator +et_calc = ETM.convert2et_full(model, ps, st) + +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# Function to create system of given size +function make_system(n_repeat) + sys = AtomsBuilder.bulk(:Si, cubic=true) * n_repeat + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +# Benchmark configurations +configs = [ + (2, 2, 2), # 64 atoms + (3, 3, 2), # 144 atoms + (4, 4, 2), # 256 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("=" ^ 90) +println("BENCHMARK: Full Model (1+2+many body) - ACE vs ETACE StackedCalculator") +println("=" ^ 90) +println() + +# --- ENERGY BENCHMARK --- +println("### ENERGY ###") +println() + +if has_cuda + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE GPU (ms) | CPU Speedup | GPU Speedup |") + println("|-------|---------|--------------|----------------|----------------|-------------|-------------|") +else + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") + println("|-------|---------|--------------|----------------|-------------|") +end + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.potential_energy(sys, ace_calc) + + # Warmup ETACE CPU + _ = AtomsCalculators.potential_energy(sys, et_calc) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.potential_energy($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU + t_etace_cpu = @belapsed AtomsCalculators.potential_energy($sys, $et_calc) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + if has_cuda + # For GPU we need to handle the StackedCalculator with GPU-capable models + # TODO: GPU version of StackedCalculator + t_etace_gpu_ms = NaN + gpu_speedup = NaN + + @printf("| %5d | %7d | %12.2f | %14.2f | %14s | %10.1fx | %10s |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, "N/A", cpu_speedup, "N/A") + else + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) + end +end + +println() + +# --- FORCES BENCHMARK --- +println("### FORCES ###") +println() + +if has_cuda + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE GPU (ms) | CPU Speedup | GPU Speedup |") + println("|-------|---------|--------------|----------------|----------------|-------------|-------------|") +else + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") + println("|-------|---------|--------------|----------------|-------------|") +end + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.forces(sys, ace_calc) + + # Warmup ETACE CPU + _ = AtomsCalculators.forces(sys, et_calc) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.forces($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU + t_etace_cpu = @belapsed AtomsCalculators.forces($sys, $et_calc) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + if has_cuda + t_etace_gpu_ms = NaN + gpu_speedup = NaN + + @printf("| %5d | %7d | %12.2f | %14.2f | %14s | %10.1fx | %10s |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, "N/A", cpu_speedup, "N/A") + else + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) + end +end + +println() +println("Notes:") +println("- ACE CPU: Original ACEpotentials model (full: E0 + pair + many-body)") +println("- ETACE CPU: StackedCalculator with ETOneBody + ETPairModel + ETACE") +println("- CPU Speedup = ACE CPU / ETACE CPU") +println("- Graph construction time included in ETACE timings") diff --git a/benchmark/gpu_benchmark.jl b/benchmark/gpu_benchmark.jl new file mode 100644 index 000000000..e861be532 --- /dev/null +++ b/benchmark/gpu_benchmark.jl @@ -0,0 +1,317 @@ +# GPU Benchmark for ETACE Models +# Run with: julia --project=test benchmark/gpu_benchmark.jl +# +# Tests: ETOneBody, ETPairModel, ETACE (many-body), and combined full model + +using CUDA +using LuxCUDA + +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +using Lux, LuxCore, Random +using AtomsBase, AtomsBuilder, Unitful +using Printf + +println("CUDA available: ", CUDA.functional()) + +# Build model +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +E0s = Dict(:Si => -158.54496821, :O => -2042.0330099956639) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true, + E0s = E0s) + +rng = Random.MersenneTwister(1234) +ps, st = Lux.setup(rng, model) + +rcut = 5.5 +NZ = 2 + +# ============================================================================ +# 1. ETACE (many-body only) +# ============================================================================ +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(rng, et_model) + +# Copy parameters +for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + et_ps.rembed.post.W[:, :, idx] .= ps.rbasis.Wnlq[:, :, i, j] +end +for iz in 1:NZ + et_ps.readout.W[1, :, iz] .= ps.WB[:, iz] +end + +# ============================================================================ +# 2. ETPairModel +# ============================================================================ +et_pair = ETM.convertpair(model) +pair_ps, pair_st = LuxCore.setup(rng, et_pair) + +# Copy pair parameters +for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + pair_ps.rembed.rbasis.post.W[:, :, idx] .= ps.pairbasis.Wnlq[:, :, i, j] +end +for iz in 1:NZ + pair_ps.readout.W[1, :, iz] .= ps.Wpair[:, iz] +end + +# ============================================================================ +# 3. ETOneBody +# ============================================================================ +zlist = ChemicalSpecies.((:Si, :O)) +E0_dict = Dict(z => E0s[Symbol(z)] for z in zlist) +et_onebody = ETM.one_body(E0_dict, x -> x.z) +onebody_ps, onebody_st = LuxCore.setup(rng, et_onebody) + +# GPU device +gdev = Lux.gpu_device() +println("GPU device: ", gdev) + +# Benchmark configurations +configs = [ + (2, 2, 2), # 64 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("="^80) +println("GPU BENCHMARK: ETACE Models (with P4ML v0.5.8)") +println("="^80) + +# ============================================================================ +# SECTION 1: Many-Body Only (ETACE) +# ============================================================================ +println() +println("### MANY-BODY ONLY (ETACE) - ENERGY ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup CPU + _ = et_model(G, et_ps, et_st) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:10 + et_model(G, et_ps, et_st) + end + t_cpu_ms = (t_cpu / 10) * 1000 + + # GPU setup + G_gpu = gdev(G) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + # Warmup GPU + CUDA.@sync et_model(G_gpu, et_ps_gpu, et_st_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:10 + CUDA.@sync et_model(G_gpu, et_ps_gpu, et_st_gpu) + end + t_gpu_ms = (t_gpu / 10) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +println() +println("### MANY-BODY ONLY (ETACE) - FORCES ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup CPU + _ = ETM.site_grads(et_model, G, et_ps, et_st) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:5 + ETM.site_grads(et_model, G, et_ps, et_st) + end + t_cpu_ms = (t_cpu / 5) * 1000 + + # GPU setup + G_gpu = gdev(G) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + # Warmup GPU + CUDA.@sync ETM.site_grads(et_model, G_gpu, et_ps_gpu, et_st_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:5 + CUDA.@sync ETM.site_grads(et_model, G_gpu, et_ps_gpu, et_st_gpu) + end + t_gpu_ms = (t_gpu / 5) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +# ============================================================================ +# SECTION 2: Full Model (E0 + Pair + Many-Body) +# ============================================================================ +println() +println("="^80) +println("### FULL MODEL (E0 + Pair + Many-Body) - ENERGY ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # CPU: evaluate all three models + function full_energy_cpu(G) + E_onebody, _ = et_onebody(G, onebody_ps, onebody_st) + E_pair, _ = et_pair(G, pair_ps, pair_st) + E_mb, _ = et_model(G, et_ps, et_st) + return sum(E_onebody) + sum(E_pair) + sum(E_mb) + end + + # Warmup CPU + _ = full_energy_cpu(G) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:10 + full_energy_cpu(G) + end + t_cpu_ms = (t_cpu / 10) * 1000 + + # GPU setup - all models + G_gpu = gdev(G) + onebody_ps_gpu = gdev(onebody_ps) + onebody_st_gpu = gdev(onebody_st) + pair_ps_gpu = gdev(pair_ps) + pair_st_gpu = gdev(pair_st) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + function full_energy_gpu(G_gpu) + E_onebody, _ = et_onebody(G_gpu, onebody_ps_gpu, onebody_st_gpu) + E_pair, _ = et_pair(G_gpu, pair_ps_gpu, pair_st_gpu) + E_mb, _ = et_model(G_gpu, et_ps_gpu, et_st_gpu) + return sum(E_onebody) + sum(E_pair) + sum(E_mb) + end + + # Warmup GPU + CUDA.@sync full_energy_gpu(G_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:10 + CUDA.@sync full_energy_gpu(G_gpu) + end + t_gpu_ms = (t_gpu / 10) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +println() +println("### FULL MODEL (E0 + Pair + Many-Body) - FORCES ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # CPU: evaluate all three gradients + function full_grads_cpu(G) + ∂onebody = ETM.site_grads(et_onebody, G, onebody_ps, onebody_st) + ∂pair = ETM.site_grads(et_pair, G, pair_ps, pair_st) + ∂mb = ETM.site_grads(et_model, G, et_ps, et_st) + return (∂onebody, ∂pair, ∂mb) + end + + # Warmup CPU + _ = full_grads_cpu(G) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:5 + full_grads_cpu(G) + end + t_cpu_ms = (t_cpu / 5) * 1000 + + # GPU setup - all models + G_gpu = gdev(G) + onebody_ps_gpu = gdev(onebody_ps) + onebody_st_gpu = gdev(onebody_st) + pair_ps_gpu = gdev(pair_ps) + pair_st_gpu = gdev(pair_st) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + function full_grads_gpu(G_gpu) + ∂onebody = ETM.site_grads(et_onebody, G_gpu, onebody_ps_gpu, onebody_st_gpu) + ∂pair = ETM.site_grads(et_pair, G_gpu, pair_ps_gpu, pair_st_gpu) + ∂mb = ETM.site_grads(et_model, G_gpu, et_ps_gpu, et_st_gpu) + return (∂onebody, ∂pair, ∂mb) + end + + # Warmup GPU + CUDA.@sync full_grads_gpu(G_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:5 + CUDA.@sync full_grads_gpu(G_gpu) + end + t_gpu_ms = (t_gpu / 5) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +println() diff --git a/docs/Project.toml b/docs/Project.toml index 1facd1b7c..8b19a1034 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,15 +5,19 @@ AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EquivariantTensors = "5e107534-7145-4f8f-b06f-47a52840c895" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/docs/make.jl b/docs/make.jl index ed112b1e3..285a3c7b2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -28,8 +28,14 @@ Literate.markdown(_tutorial_src * "/dataset_analysis.jl", Literate.markdown(_tutorial_src * "/descriptor.jl", _tutorial_out; documenter = true) -Literate.markdown(_tutorial_src * "/asp.jl", +Literate.markdown(_tutorial_src * "/asp.jl", _tutorial_out; documenter = true) + +# ETACE tutorial lives in examples/ to avoid duplication +_examples_src = joinpath(@__DIR__(), "..", "examples", "etmodels") +Literate.markdown(_examples_src * "/etace_tutorial.jl", + _tutorial_out; documenter = true) + # Literate.markdown(_tutorial_src * "/first_example_model.jl", # _tutorial_out; documenter = true) @@ -70,9 +76,10 @@ makedocs(; "literate_tutorials/basic_julia_workflow.md", "literate_tutorials/smoothness_priors.md", "literate_tutorials/dataset_analysis.md", - "tutorials/scripting.md", + "tutorials/scripting.md", "literate_tutorials/descriptor.md", "literate_tutorials/asp.md", + "literate_tutorials/etace_tutorial.md", ], "Additional Topics" => Any[ "gettingstarted/parallel-fitting.md", diff --git a/docs/plans/et_calculators_plan.md b/docs/plans/et_calculators_plan.md new file mode 100644 index 000000000..8522b0577 --- /dev/null +++ b/docs/plans/et_calculators_plan.md @@ -0,0 +1,231 @@ +# Plan: ETACE Calculator Interface and Training Support + +## Overview + +Create calculator wrappers and training assembly for the new ETACE backend, integrating with EquivariantTensors.jl. + +**Status**: ✅ Core implementation complete. GPU acceleration working. PR #313 under review. + +**Branch**: `jrk/etcalculators` (rebased on `acesuit/co/etback` including `co/etpair` merge) + +**PR**: https://github.com/ACEsuit/ACEpotentials.jl/pull/313 + +--- + +## Progress Summary + +| Phase | Description | Status | +|-------|-------------|--------| +| Phase 1 | ETACEPotential with AtomsCalculators interface | ✅ Complete | +| Phase 2 | WrappedSiteCalculator + StackedCalculator | ✅ Complete | +| Phase 3 | E0Model + PairModel | ✅ Complete (upstream ETOneBody, ETPairModel) | +| Phase 5 | Training assembly functions | ✅ Complete (many-body only) | +| Phase 6 | Full model integration | ✅ Complete | +| Benchmarks | CPU + GPU performance comparison | ✅ Complete | + +### Key Design Decision: Unified Architecture + +**All upstream ETACE-pattern models share the same interface:** + +| Method | ETACE | ETPairModel | ETOneBody | +|--------|-------|-------------|-----------| +| `model(G, ps, st)` | site energies | site energies | site energies | +| `site_grads(model, G, ps, st)` | edge gradients | edge gradients | zero gradients | +| `site_basis(model, G, ps, st)` | basis matrix | basis matrix | empty | +| `site_basis_jacobian(model, G, ps, st)` | (basis, jac) | (basis, jac) | (empty, empty) | + +This enables a **unified `WrappedSiteCalculator`** that works with all three model types directly. + +--- + +## Benchmark Results + +### GPU Benchmarks (Many-Body Only - ETACE) + +**Energy:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2146 | 3.38 | 0.54 | **6.3x** | +| 512 | 17176 | 27.77 | 0.66 | **41.9x** | +| 800 | 26868 | 37.12 | 0.78 | **47.6x** | + +**Forces:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2146 | 46.46 | 14.42 | **3.2x** | +| 512 | 17178 | 104.39 | 15.12 | **6.9x** | +| 800 | 26860 | 289.32 | 16.33 | **17.7x** | + +### GPU Benchmarks (Full Model - E0 + Pair + Many-Body) + +**Energy:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2140 | 3.40 | 0.94 | **3.6x** | +| 512 | 17166 | 31.18 | 0.95 | **32.9x** | +| 800 | 26858 | 45.16 | 1.24 | **36.4x** | + +**Forces:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2134 | 24.05 | 19.34 | **1.2x** | +| 512 | 17178 | ~110 | ~20 | **~5x** | +| 800 | 26860 | ~300 | ~22 | **~14x** | + +### CPU Benchmarks (ETACE vs Classic ACE) + +**Forces (Full Model):** +| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE Speedup | +|-------|-------|--------------|----------------|---------------| +| 64 | 2146 | 73.6 | 30.5 | **2.4x** | +| 256 | 8596 | 307.7 | 74.4 | **4.1x** | +| 800 | 26886 | 975.0 | 225.6 | **4.3x** | + +**Notes:** +- GPU forces require Polynomials4ML v0.5.8+ (bug fix Dec 29, 2024) +- GPU shows excellent scaling: larger systems see better speedups +- Full model GPU speedups are lower than many-body only due to graph construction overhead +- CPU forces are 2-4x faster with ETACE due to Zygote AD through ET graph + +--- + +## Architecture + +### Current Implementation (Complete) + +``` +StackedCalculator +├── WrappedSiteCalculator{ETOneBody} # One-body reference energies +├── WrappedSiteCalculator{ETPairModel} # Pair potential +└── WrappedSiteCalculator{ETACE} # Many-body ACE +``` + +### Core Components + +**WrappedSiteCalculator{M, PS, ST}** (`et_calculators.jl`) +- Unified wrapper for any ETACE-pattern model +- Provides AtomsCalculators interface (energy, forces, virial) +- Mutable to allow parameter updates during training + +**ETACEPotential** - Type alias for `WrappedSiteCalculator{ETACE, PS, ST}` + +**StackedCalculator{N, C}** (`stackedcalc.jl`) +- Combines multiple calculators by summing contributions +- Uses @generated functions for type-stable loop unrolling + +### Conversion Functions + +```julia +convert2et(model) # Many-body ACE → ETACE +convertpair(model) # Pair potential → ETPairModel +convert2et_full(model, ps, st) # Full model → StackedCalculator +``` + +### Training Assembly (Many-Body Only) + +```julia +length_basis(calc) # Total linear parameters +get_linear_parameters(calc) # Extract θ vector +set_linear_parameters!(calc, θ) # Set θ vector +potential_energy_basis(sys, calc) # Energy design row +energy_forces_virial_basis(sys, calc) # Full EFV design row +``` + +--- + +## Files + +### Source Files +- `src/et_models/et_ace.jl` - ETACE model implementation +- `src/et_models/et_pair.jl` - ETPairModel implementation +- `src/et_models/onebody.jl` - ETOneBody implementation +- `src/et_models/et_calculators.jl` - WrappedSiteCalculator, ETACEPotential, training assembly +- `src/et_models/stackedcalc.jl` - StackedCalculator with @generated +- `src/et_models/convert.jl` - Model conversion utilities +- `src/et_models/et_envbranch.jl` - EnvRBranchL for envelope × radial basis +- `src/et_models/et_models.jl` - Module includes and exports + +### Test Files +- `test/etmodels/test_etbackend.jl` - ETACE tests +- `test/etmodels/test_etpair.jl` - ETPairModel tests +- `test/etmodels/test_etonebody.jl` - ETOneBody tests + +### Benchmark Files +- `benchmark/gpu_benchmark.jl` - GPU energy/forces benchmarks +- `benchmark/benchmark_full_model.jl` - CPU comparison benchmarks + +--- + +## Outstanding Work + +### ~~1. Training Assembly for Pair Model~~ ✅ Complete +**Status**: Implemented in `et_calculators.jl` and `stackedcalc.jl` + +**What was done**: +- Added `ETPairPotential` type alias with full training assembly support +- Added `ETOneBodyPotential` type alias (returns empty arrays - no learnable params) +- Implemented `length_basis`, `energy_forces_virial_basis`, `potential_energy_basis`, `get_linear_parameters`, `set_linear_parameters!` for all calculator types +- Extended `StackedCalculator` to concatenate basis functions from all components +- Added `ACEfit.basis_size` dispatch for all calculator types + +### ~~2. ACEfit.assemble Dispatch Integration~~ ✅ Complete +**Status**: Works out-of-the-box after extending `length_basis` and `energy_forces_virial_basis` + +**What was done**: +- Added empty function declarations in `models/models.jl` for `length_basis`, `energy_forces_virial_basis`, `potential_energy_basis` +- ETModels now imports and extends these functions +- `ACEfit.feature_matrix(d::AtomsData, calc)` works with ETACE calculators +- `ACEfit.assemble(data, calc)` works with `StackedCalculator` + +### 3. Committee Support +**Priority**: Low +**Description**: Extend committee/uncertainty quantification to work with StackedCalculator. + +### 4. Basis Index Design Discussion +**Priority**: Needs Discussion +**Description**: Moderator raised concern about basis indices: + +> "I realized I made a mistake in the design of the basis interface. I'm returning the site energy basis but for each center-atom species, the basis occupies the same indices. We need to perform a transformation so that bases for different species occupy separate indices." + +**Current Implementation**: Species separation is handled at the **calculator level** in `energy_forces_virial_basis` using `p = (s-1) * nbasis + k`. Each species gets separate parameter indices. + +**Options**: +1. Keep current approach (calculator-level separation) +2. Move to site potential model level +3. Handle at WrappedSiteCalculator level + +Moderator wants discussion before making changes. + +--- + +## Dependencies + +- EquivariantTensors.jl >= 0.4.3 +- Polynomials4ML.jl >= 0.5.8 (for GPU forces) +- LuxCUDA (for GPU support, test dependency) + +--- + +## Test Status + +All tests pass: **946 passed, 1 broken** (known Julia 1.12 hash ordering issue) + +```bash +# Run ET model tests +julia --project=test -e 'using Pkg; Pkg.test("ACEpotentials"; test_args=["etmodels"])' + +# Run GPU benchmark +julia --project=test benchmark/gpu_benchmark.jl +``` + +--- + +## Notes + +- Virial formula: `V = -∑ ∂E/∂𝐫ij ⊗ 𝐫ij` +- GPU time scales sub-linearly with system size +- Forces speedup (CPU) larger than energy speedup due to Zygote AD efficiency +- StackedCalculator uses @generated functions for zero-overhead composition +- Upstream `ETOneBody` stores E0s in state (`st.E0s`) for float type flexibility +- All models use `VState` for edge gradients in `site_grads()` return +- `ETOneBody.site_grads()` returns `fill(VState(), length(edges))` for type stability (empty VState acts as additive identity) diff --git a/docs/src/all_exported.md b/docs/src/all_exported.md index 9e250d258..16dacd79e 100644 --- a/docs/src/all_exported.md +++ b/docs/src/all_exported.md @@ -3,13 +3,13 @@ ### Exported ```@autodocs -Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat] +Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat, ACEpotentials.ETModels] Private = false -``` +``` ### Not exported ```@autodocs -Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat] +Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat, ACEpotentials.ETModels] Public = false ``` diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 98ee4560c..23da5738e 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -1,11 +1,11 @@ # Tutorials Overview -* [Basic Julia Workflow](../literate_tutorials/basic_julia_workflow.md) : minimal example to fit a potential to an existing dataset using a Julia script +* [Basic Julia Workflow](../literate_tutorials/basic_julia_workflow.md) : minimal example to fit a potential to an existing dataset using a Julia script * [Basic Shell Workflow](scripting.md) : basic workflow for fitting via the command line * [Smoothness Priors](../literate_tutorials/smoothness_priors.md) : brief introduction to smoothness priors * [Basic Dataset Analysis](../literate_tutorials/dataset_analysis.md) : basic techniques to visualize training datasets and correlate such observations to the choice of geometric priors -* [Descriptors](../literate_tutorials/descriptor.md) : `ACEpotentials` can be used as descriptors of atomic environments or structures, which is described here. +* [Descriptors](../literate_tutorials/descriptor.md) : `ACEpotentials` can be used as descriptors of atomic environments or structures, which is described here. * [Sparse Solvers](../literate_tutorials/asp.md) : basic tutorial on using the `ASP` and `OMP` solvers. - +* [ETACE Models](../literate_tutorials/etace_tutorial.md) : using the EquivariantTensors backend for ACE models, including conversion from standard ACE models and creating ETACE models from scratch. diff --git a/examples/etmodels/etace_tutorial.jl b/examples/etmodels/etace_tutorial.jl new file mode 100644 index 000000000..dd5f31054 --- /dev/null +++ b/examples/etmodels/etace_tutorial.jl @@ -0,0 +1,348 @@ +# # ETACE Models Tutorial +# +# This tutorial demonstrates how to use the EquivariantTensors (ET) backend +# for ACE models in ACEpotentials.jl. The ET backend provides: +# - Graph-based evaluation (edge-centric computation) +# - Automatic differentiation via Zygote +# - GPU-ready architecture via KernelAbstractions +# - Lux.jl layer integration +# +# We cover two approaches: +# 1. **Converting from an existing ACE model** - The recommended approach +# 2. **Creating an ETACE model from scratch** - For advanced users +# + +## Load required packages +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful +using AtomsCalculators, Random, LinearAlgebra + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +rng = Random.MersenneTwister(1234) + +# ============================================================================= +# Part 1: Converting from an Existing ACE Model (Recommended) +# ============================================================================= +# +# The simplest way to get an ETACE model is to convert from a standard ACE model. +# This approach ensures consistency with the familiar ACE model construction API. + +## Define model hyperparameters +elements = (:Si, :O) +order = 3 # correlation order (body-order = order + 1) +max_level = 10 # total polynomial degree +maxl = 6 # maximum angular momentum +rcut = 5.5 # cutoff radius in Angstrom + +## Create the standard ACE model +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) + +## Note: pair_learnable=true is required for ET conversion +## (default uses splines which aren't yet supported by convert2et) +model = M.ace_model(; + elements = elements, + order = order, + Ytype = :solid, + level = M.TotalDegree(), + max_level = max_level, + maxl = maxl, + pair_maxn = max_level, + rin0cuts = rin0cuts, + E0s = Dict(:Si => -0.846, :O => -1.023), # reference energies + pair_learnable = true # required for ET conversion +) + +## Initialize parameters with Lux +ps, st = Lux.setup(rng, model) + +@info "Standard ACE model created" +@info " Number of basis functions: $(M.length_basis(model))" + +# ----------------------------------------------------------------------------- +# Method A: Convert full model (E0 + Pair + Many-body) to StackedCalculator +# ----------------------------------------------------------------------------- + +## convert2et_full creates a StackedCalculator combining: +## - ETOneBody (reference energies per species) +## - ETPairModel (pair potential) +## - ETACE (many-body ACE potential) + +et_calc_full = ETM.convert2et_full(model, ps, st; rng=rng) + +@info "Full conversion to StackedCalculator" +@info " Contains: ETOneBody + ETPairPotential + ETACEPotential" +@info " Total linear parameters: $(ETM.length_basis(et_calc_full))" + +# ----------------------------------------------------------------------------- +# Method B: Convert only the many-body ACE component +# ----------------------------------------------------------------------------- + +## convert2et creates just the ETACE model (many-body only, no E0 or pair) +et_ace = ETM.convert2et(model) +et_ace_ps, et_ace_st = Lux.setup(rng, et_ace) + +## Copy parameters from the original model +ETM.copy_ace_params!(et_ace_ps, ps, model) + +## Wrap in calculator for AtomsCalculators interface +et_ace_calc = ETM.ETACEPotential(et_ace, et_ace_ps, et_ace_st, rcut) + +@info "Many-body only conversion" +@info " ETACE basis size: $(ETM.length_basis(et_ace_calc))" + +# ----------------------------------------------------------------------------- +# Method C: Convert only the pair potential +# ----------------------------------------------------------------------------- + +## convertpair creates an ETPairModel +et_pair = ETM.convertpair(model) +et_pair_ps, et_pair_st = Lux.setup(rng, et_pair) + +## Copy parameters from the original model +ETM.copy_pair_params!(et_pair_ps, ps, model) + +## Wrap in calculator +et_pair_calc = ETM.ETPairPotential(et_pair, et_pair_ps, et_pair_st, rcut) + +@info "Pair potential only conversion" +@info " ETPairModel basis size: $(ETM.length_basis(et_pair_calc))" + + +# ============================================================================= +# Part 2: Using ETACE Calculators +# ============================================================================= + +## Create a test system +sys = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys, 0.1u"Å") +AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + +@info "Test system: $(length(sys)) atoms" + +## Evaluate energy, forces, virial using AtomsCalculators interface +E = AtomsCalculators.potential_energy(sys, et_calc_full) +F = AtomsCalculators.forces(sys, et_calc_full) +V = AtomsCalculators.virial(sys, et_calc_full) + +@info "Energy evaluation with full ETACE calculator" +@info " Energy: $E" +@info " Max force magnitude: $(maximum(norm.(F)))" + +## Combined evaluation (more efficient) +efv = AtomsCalculators.energy_forces_virial(sys, et_calc_full) +@info " Combined EFV evaluation successful" + + +# ============================================================================= +# Part 3: Training Assembly (for Linear Fitting) +# ============================================================================= +# +# The ETACE calculators support training assembly functions for ACEfit integration. +# These compute the design matrix rows for linear least squares fitting. + +## Energy-only basis evaluation (fastest) +E_basis = ETM.potential_energy_basis(sys, et_ace_calc) +@info "Energy basis: $(length(E_basis)) components" + +## Full energy, forces, virial basis +efv_basis = ETM.energy_forces_virial_basis(sys, et_ace_calc) +@info "EFV basis shapes:" +@info " Energy: $(size(efv_basis.energy))" +@info " Forces: $(size(efv_basis.forces))" +@info " Virial: $(size(efv_basis.virial))" + +## Get/set linear parameters +params = ETM.get_linear_parameters(et_ace_calc) +@info "Linear parameters: $(length(params)) values" + +## Parameters can be updated for fitting: +## ETM.set_linear_parameters!(et_ace_calc, new_params) + + +# ============================================================================= +# Part 4: Creating an ETACE Model from Scratch (Advanced) +# ============================================================================= +# +# For advanced users who want direct control over the model architecture. +# This requires understanding the EquivariantTensors.jl API. + +## Define model parameters +scratch_elements = [:Si, :O] +scratch_maxn = 6 # number of radial basis functions +scratch_maxl = 4 # maximum angular momentum +scratch_order = 2 # correlation order +scratch_rcut = 5.5 # cutoff radius + +## Species information +zlist = ChemicalSpecies.(scratch_elements) +NZ = length(zlist) + +# ----------------------------------------------------------------------------- +# Build the radial embedding (Rnl) +# ----------------------------------------------------------------------------- + +## Radial specification (n, l pairs) +Rnl_spec = [(n=n, l=l) for n in 1:scratch_maxn for l in 0:scratch_maxl] + +## Distance transform: r -> transformed coordinate y +## Using standard Agnesi transform parameters +f_trans = let rcut = scratch_rcut + (x, st) -> begin + r = norm(x.𝐫) + ## Simple polynomial transform (normalized to [-1, 1]) + y = 1 - 2 * r / rcut + return y + end +end +trans = ET.NTtransformST(f_trans, NamedTuple()) + +## Envelope function: smooth cutoff +f_env = y -> (1 - y^2)^2 # quartic envelope + +## Polynomial basis (Chebyshev) +polys = P4ML.ChebBasis(scratch_maxn) +Penv = P4ML.wrapped_basis(Lux.BranchLayer( + polys, + Lux.WrappedFunction(y -> f_env.(y)), + fusion = Lux.WrappedFunction(Pe -> Pe[2] .* Pe[1]) +)) + +## Species-pair selector for radial weights +selector_ij = let zlist = tuple(zlist...) + xij -> ET.catcat2idx(zlist, xij.z0, xij.z1) +end + +## Linear layer: P(yij) -> W[(Zi, Zj)] * P(yij) +linl = ET.SelectLinL(scratch_maxn, length(Rnl_spec), NZ^2, selector_ij) + +## Complete radial embedding +rbasis = ET.EmbedDP(trans, Penv, linl) +rembed = ET.EdgeEmbed(rbasis) + +# ----------------------------------------------------------------------------- +# Build the angular embedding (Ylm) +# ----------------------------------------------------------------------------- + +## Spherical harmonics basis +ylm_basis = P4ML.real_sphericalharmonics(scratch_maxl) +Ylm_spec = P4ML.natural_indices(ylm_basis) + +## Angular embedding: edge direction -> spherical harmonics +ybasis = ET.EmbedDP( + ET.NTtransformST((x, st) -> x.𝐫, NamedTuple()), + ylm_basis +) +yembed = ET.EdgeEmbed(ybasis) + +# ----------------------------------------------------------------------------- +# Build the many-body basis (sparse ACE) +# ----------------------------------------------------------------------------- + +## Define the many-body specification +## This specifies which (n,l) combinations appear in each correlation +## For simplicity, use all 1-correlations up to given degree +mb_spec = [[(n=n, l=l)] for n in 1:scratch_maxn for l in 0:scratch_maxl] + +## Create sparse equivariant tensor (ACE basis) +mb_basis = ET.sparse_equivariant_tensor( + L = 0, # scalar (invariant) output + mb_spec = mb_spec, + Rnl_spec = Rnl_spec, + Ylm_spec = Ylm_spec, + basis = real # real-valued basis +) + +# ----------------------------------------------------------------------------- +# Build the readout layer +# ----------------------------------------------------------------------------- + +## Species selector for readout +selector_i = let zlist = zlist + x -> ET.cat2idx(zlist, x.z) +end + +## Readout: basis values -> site energies +readout = ET.SelectLinL( + mb_basis.lens[1], # input dimension (basis length) + 1, # output dimension (site energy) + NZ, # number of species categories + selector_i +) + +# ----------------------------------------------------------------------------- +# Assemble the ETACE model +# ----------------------------------------------------------------------------- + +scratch_etace = ETM.ETACE(rembed, yembed, mb_basis, readout) + +## Initialize with Lux +scratch_ps, scratch_st = Lux.setup(rng, scratch_etace) + +@info "ETACE model created from scratch" +@info " Radial basis size: $(length(Rnl_spec))" +@info " Angular basis size: $(length(Ylm_spec))" +@info " Many-body basis size: $(mb_basis.lens[1])" + +## Wrap in calculator +scratch_calc = ETM.ETACEPotential(scratch_etace, scratch_ps, scratch_st, scratch_rcut) + +## Test evaluation +E_scratch = AtomsCalculators.potential_energy(sys, scratch_calc) +@info "Scratch model energy: $E_scratch" + + +# ============================================================================= +# Part 5: Creating One-Body and Pair Models from Scratch +# ============================================================================= + +# ----------------------------------------------------------------------------- +# ETOneBody: Reference energies +# ----------------------------------------------------------------------------- + +## Define reference energies per species +E0_dict = Dict(ChemicalSpecies(:Si) => -0.846, + ChemicalSpecies(:O) => -1.023) + +## Category function extracts species from atom state +catfun = x -> x.z # x.z is the ChemicalSpecies + +## Create one-body model +et_onebody = ETM.one_body(E0_dict, catfun) +_, onebody_st = Lux.setup(rng, et_onebody) + +## Wrap in calculator (uses small cutoff since no neighbors needed) +onebody_calc = ETM.ETOneBodyPotential(et_onebody, nothing, onebody_st, 3.0) + +@info "ETOneBody model created" +@info " Reference energies: $E0_dict" + +E_onebody = AtomsCalculators.potential_energy(sys, onebody_calc) +@info " One-body energy for test system: $E_onebody" + + +# ============================================================================= +# Part 6: Combining Models with StackedCalculator +# ============================================================================= +# +# StackedCalculator combines multiple calculators by summing their contributions. + +## Stack our from-scratch models +combined_calc = ETM.StackedCalculator((onebody_calc, scratch_calc)) + +@info "StackedCalculator created" +@info " Components: ETOneBody + ETACE" +@info " Total basis size: $(ETM.length_basis(combined_calc))" + +## Evaluate combined model +E_combined = AtomsCalculators.potential_energy(sys, combined_calc) +@info " Combined energy: $E_combined" + +## Training assembly works on StackedCalculator too +efv_combined = ETM.energy_forces_virial_basis(sys, combined_calc) +@info " Combined EFV basis shapes: E=$(size(efv_combined.energy)), F=$(size(efv_combined.forces))" + +@info "Tutorial complete!" diff --git a/src/et_models/et_ace.jl b/src/et_models/et_ace.jl index 9c57f1d2c..df2446453 100644 --- a/src/et_models/et_ace.jl +++ b/src/et_models/et_ace.jl @@ -69,9 +69,11 @@ function site_basis(l::ETACE, X::ET.ETGraph, ps, st) end -function site_basis_jacobian(l::ETACE, X::ET.ETGraph, ps, st) +function site_basis_jacobian(l::ETACE, X::ET.ETGraph, ps, st) (R, ∂R), _ = ET.evaluate_ed(l.rembed, X, ps.rembed, st.rembed) (Y, ∂Y), _ = ET.evaluate_ed(l.yembed, X, ps.yembed, st.yembed) + # _jacobian_X for SparseACEbasis takes (basis, Rnl, Ylm, dRnl, dYlm, ps, st) + # Requires EquivariantTensors >= 0.4.2 (𝔹,), (∂𝔹,) = ET._jacobian_X(l.basis, R, Y, ∂R, ∂Y, ps.basis, st.basis) return 𝔹, ∂𝔹 end diff --git a/src/et_models/et_calculators.jl b/src/et_models/et_calculators.jl new file mode 100644 index 000000000..ac1c36749 --- /dev/null +++ b/src/et_models/et_calculators.jl @@ -0,0 +1,784 @@ + +# Calculator interfaces for ETACE models +# Provides AtomsCalculators-compatible energy/forces/virial evaluation +# +# Architecture: +# - WrappedSiteCalculator: Unified wrapper for ETACE-pattern models (ETACE, ETPairModel, ETOneBody) +# - ETACEPotential: Type alias for WrappedSiteCalculator with ETACE model +# - StackedCalculator: Combines multiple calculators (see stackedcalc.jl) +# +# All wrapped models must implement the ETACE interface: +# model(G, ps, st) -> (site_energies, st) +# site_grads(model, G, ps, st) -> edge gradients +# +# See also: stackedcalc.jl for StackedCalculator (combines multiple calculators) + +import AtomsCalculators +import AtomsBase: AbstractSystem, ChemicalSpecies +import EquivariantTensors as ET +using DecoratedParticles: PState +using StaticArrays +using Unitful +using LinearAlgebra: norm + +# Import from parent Models module to extend these functions +import ..Models: length_basis, energy_forces_virial_basis, potential_energy_basis + + +# ============================================================================ +# WrappedSiteCalculator - Unified wrapper for ETACE-pattern models +# ============================================================================ + +""" + WrappedSiteCalculator{M, PS, ST} + +Wraps any ETACE-pattern model (ETACE, ETPairModel, ETOneBody) and provides +the AtomsCalculators interface. + +All wrapped models must implement the ETACE interface: +- `model(G, ps, st)` → `(site_energies, st)` +- `site_grads(model, G, ps, st)` → edge gradients + +Mutable to allow parameter updates during training. + +# Example +```julia +# With ETACE model +calc = WrappedSiteCalculator(et_model, ps, st, 5.5) + +# With ETOneBody (upstream) +et_onebody = ETM.one_body(Dict(:Si => -0.846), x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) +calc = WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + +E = potential_energy(sys, calc) +F = forces(sys, calc) +``` + +# Fields +- `model` - ETACE-pattern model (ETACE, ETPairModel, or ETOneBody) +- `ps` - Model parameters (can be `nothing` for ETOneBody) +- `st` - Model state +- `rcut::Float64` - Cutoff radius for graph construction (Å) +- `co_ps` - Optional committee parameters for uncertainty quantification +""" +mutable struct WrappedSiteCalculator{M, PS, ST} + model::M + ps::PS + st::ST + rcut::Float64 + co_ps::Any +end + +# Constructor without committee parameters +function WrappedSiteCalculator(model, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut), nothing) +end + +cutoff_radius(calc::WrappedSiteCalculator) = calc.rcut * u"Å" + +function _wrapped_energy(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + Ei, _ = calc.model(G, calc.ps, calc.st) + return sum(Ei) +end + +function _wrapped_forces(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + ∂G = site_grads(calc.model, G, calc.ps, calc.st) + # Handle empty edge case (e.g., ETOneBody with small cutoff) + if isempty(∂G.edge_data) + return zeros(SVector{3, Float64}, length(sys)) + end + # forces_from_edge_grads returns +∇E, negate for forces + return -ET.Atoms.forces_from_edge_grads(sys, G, ∂G.edge_data) +end + +# Compute virial tensor from edge gradients +function _compute_virial(G::ET.ETGraph, ∂G) + # V = -∑ (∂E/∂𝐫ij) ⊗ 𝐫ij + V = zeros(SMatrix{3,3,Float64,9}) + for (edge, ∂edge) in zip(G.edge_data, ∂G.edge_data) + V -= ∂edge.𝐫 * edge.𝐫' + end + return V +end + +function _wrapped_virial(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + ∂G = site_grads(calc.model, G, calc.ps, calc.st) + # Handle empty edge case + if isempty(∂G.edge_data) + return zeros(SMatrix{3,3,Float64,9}) + end + return _compute_virial(G, ∂G) +end + +function _wrapped_energy_forces_virial(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + + # Energy from site energies (call model directly - ETACE interface) + Ei, _ = calc.model(G, calc.ps, calc.st) + E = sum(Ei) + + # Forces and virial from edge gradients + ∂G = site_grads(calc.model, G, calc.ps, calc.st) + + # Handle empty edge case (e.g., ETOneBody with small cutoff) + if isempty(∂G.edge_data) + F = zeros(SVector{3, Float64}, length(sys)) + V = zeros(SMatrix{3,3,Float64,9}) + else + F = -ET.Atoms.forces_from_edge_grads(sys, G, ∂G.edge_data) + V = _compute_virial(G, ∂G) + end + + return (energy=E, forces=F, virial=V) +end + +# AtomsCalculators interface for WrappedSiteCalculator +AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + return _wrapped_energy(calc, sys) * u"eV" +end + +AtomsCalculators.@generate_interface function AtomsCalculators.forces( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + return _wrapped_forces(calc, sys) .* u"eV/Å" +end + +AtomsCalculators.@generate_interface function AtomsCalculators.virial( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + return _wrapped_virial(calc, sys) * u"eV" +end + +function AtomsCalculators.energy_forces_virial( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + efv = _wrapped_energy_forces_virial(calc, sys) + return ( + energy = efv.energy * u"eV", + forces = efv.forces .* u"eV/Å", + virial = efv.virial * u"eV" + ) +end + + +# ============================================================================ +# ETACEPotential - Type alias for WrappedSiteCalculator{ETACE} +# ============================================================================ + +""" + ETACEPotential + +AtomsCalculators-compatible calculator wrapping an ETACE model. +This is a type alias for `WrappedSiteCalculator{<:ETACE, PS, ST}`. + +Access underlying components via: +- `calc.model` - The ETACE model +- `calc.ps` - Model parameters +- `calc.st` - Model state +- `calc.rcut` - Cutoff radius in Ångström +- `calc.co_ps` - Committee parameters (optional) + +# Example +```julia +calc = ETACEPotential(et_model, ps, st, 5.5) +E = potential_energy(sys, calc) +``` +""" +const ETACEPotential{MOD<:ETACE, PS, ST} = WrappedSiteCalculator{MOD, PS, ST} + +# Constructor: creates WrappedSiteCalculator with ETACE model directly +function ETACEPotential(model::ETACE, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut)) +end + +# ============================================================================ +# Training Assembly Interface +# ============================================================================ +# +# These functions compute the basis values for linear least squares fitting. +# The linear parameters are the readout weights W[1, k, s] where: +# k = basis function index (1:nbasis) +# s = species index (1:nspecies) +# +# Total parameters: nbasis * nspecies +# +# Energy basis: E = ∑_i ∑_k W[k, species[i]] * 𝔹[i, k] +# Force basis: F_atom = -∑ edges ∂E/∂r_edge, computed per basis function +# Virial basis: V = -∑ edges (∂E/∂r_edge) ⊗ r_edge, computed per basis function + +# Accessor helpers for ETACEPotential (which is WrappedSiteCalculator{ETACE}) +_etace(calc::ETACEPotential) = calc.model # Underlying ETACE model (direct) +_ps(calc::ETACEPotential) = calc.ps # Model parameters +_st(calc::ETACEPotential) = calc.st # Model state + +""" + length_basis(calc::ETACEPotential) + +Return the number of linear parameters in the model (nbasis * nspecies). +""" +function length_basis(calc::ETACEPotential) + etace = _etace(calc) + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + return nbasis * nspecies +end + +# ACEfit integration +import ACEfit +ACEfit.basis_size(calc::ETACEPotential) = length_basis(calc) + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::ETACEPotential) + +Compute the basis functions for energy, forces, and virial. +Returns a named tuple with: +- `energy::Vector{Float64}` - length = length_basis(calc) +- `forces::Matrix{SVector{3,Float64}}` - size = (natoms, length_basis) +- `virial::Vector{SMatrix{3,3,Float64}}` - length = length_basis(calc) + +The linear combination of basis values with parameters gives: + E = dot(energy, params) + F = forces * params + V = sum(params .* virial) +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::ETACEPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + etace = _etace(calc) + + # Get basis and jacobian + # 𝔹: (nnodes, nbasis) - basis values per site (Float64) + # ∂𝔹: (maxneigs, nnodes, nbasis) - directional derivatives (VState objects) + 𝔹, ∂𝔹 = site_basis_jacobian(etace, G, _ps(calc), _st(calc)) + + natoms = length(sys) + nnodes = size(𝔹, 1) + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + nparams = nbasis * nspecies + maxneigs = size(∂𝔹, 1) + + # Species indices for each node + iZ = etace.readout.selector.(G.node_data) + + # Initialize outputs + E_basis = zeros(nparams) + F_basis = zeros(SVector{3, Float64}, natoms, nparams) + V_basis = zeros(SMatrix{3, 3, Float64, 9}, nparams) + + # Pre-allocate work buffer for gradient (same element type as ∂𝔹) + # This avoids allocating a new matrix in each iteration + ∇Ei_buf = similar(∂𝔹, maxneigs, nnodes) + + # Pre-compute a zero element for masking (same type as ∂𝔹 elements) + zero_grad = zero(∂𝔹[1, 1, 1]) + + # Pre-compute edge vectors for virial (avoid repeated access) + edge_𝐫 = [edge.𝐫 for edge in G.edge_data] + + # Compute basis values for each parameter (k, s) pair + # Parameter index: p = (s-1) * nbasis + k + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + + # Energy basis: sum of 𝔹[i, k] for atoms of species s + for i in 1:nnodes + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + + # Fill gradient buffer: ∇Ei[:, i] = ∂𝔹[:, i, k] if iZ[i] == s, else zeros + # This avoids allocating W_unit and doing matrix-vector multiply + for i in 1:nnodes + if iZ[i] == s + @views ∇Ei_buf[:, i] .= ∂𝔹[:, i, k] + else + @views ∇Ei_buf[:, i] .= Ref(zero_grad) + end + end + + # Reshape for rev_reshape_embedding (needs 3D array) - this is a view, no allocation + ∇Ei_3d = reshape(∇Ei_buf, maxneigs, nnodes, 1) + + # Convert to edge-indexed format with 3D vectors + ∇E_edges = ET.rev_reshape_embedding(∇Ei_3d, G)[:] + + # Convert edge gradients to atomic forces (negate for forces) + F_basis[:, p] = -ET.Atoms.forces_from_edge_grads(sys, G, ∇E_edges) + + # Compute virial: V = -∑ (∂E/∂𝐫ij) ⊗ 𝐫ij + V = zero(SMatrix{3, 3, Float64, 9}) + for (e, ∂edge) in enumerate(∇E_edges) + V -= ∂edge.𝐫 * edge_𝐫[e]' + end + V_basis[p] = V + end + end + + return ( + energy = E_basis * u"eV", + forces = F_basis .* u"eV/Å", + virial = V_basis * u"eV" + ) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::ETACEPotential) + +Compute only the energy basis (faster when forces/virial not needed). +""" +function potential_energy_basis(sys::AbstractSystem, calc::ETACEPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + etace = _etace(calc) + + # Get basis values + 𝔹 = site_basis(etace, G, _ps(calc), _st(calc)) + + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + nparams = nbasis * nspecies + + # Species indices for each node + iZ = etace.readout.selector.(G.node_data) + + # Compute energy basis + E_basis = zeros(nparams) + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + for i in 1:length(G.node_data) + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + end + end + + return E_basis * u"eV" +end + +""" + get_linear_parameters(calc::ETACEPotential) + +Extract the linear parameters (readout weights) as a flat vector. +Parameters are ordered as: [W[1,:,1]; W[1,:,2]; ... ; W[1,:,nspecies]] +""" +function get_linear_parameters(calc::ETACEPotential) + return vec(_ps(calc).readout.W) +end + +""" + set_linear_parameters!(calc::ETACEPotential, θ::AbstractVector) + +Set the linear parameters (readout weights) from a flat vector. +""" +function set_linear_parameters!(calc::ETACEPotential, θ::AbstractVector) + etace = _etace(calc) + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + @assert length(θ) == nbasis * nspecies + + # Reshape and copy into ps (WrappedSiteCalculator is mutable) + ps = _ps(calc) + new_W = reshape(θ, 1, nbasis, nspecies) + calc.ps = merge(ps, (readout = merge(ps.readout, (W = new_W,)),)) + return calc +end + + +# ============================================================================ +# ETPairPotential - Type alias for WrappedSiteCalculator{ETPairModel} +# ============================================================================ + +""" + ETPairPotential + +AtomsCalculators-compatible calculator wrapping an ETPairModel. +This is a type alias for `WrappedSiteCalculator{<:ETPairModel, PS, ST}`. + +Supports training assembly functions: +- `length_basis(calc)` - Total linear parameters +- `energy_forces_virial_basis(sys, calc)` - Full EFV design row +- `potential_energy_basis(sys, calc)` - Energy design row +- `get_linear_parameters(calc)` / `set_linear_parameters!(calc, θ)` + +# Example +```julia +et_pair = convertpair(model) +ps, st = Lux.setup(rng, et_pair) +calc = ETPairPotential(et_pair, ps, st, 5.5) +E = potential_energy(sys, calc) +``` +""" +const ETPairPotential{MOD<:ETPairModel, PS, ST} = WrappedSiteCalculator{MOD, PS, ST} + +function ETPairPotential(model::ETPairModel, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut)) +end + +# ============================================================================ +# ETPairPotential Training Assembly +# ============================================================================ + +# Accessor helpers +_pair(calc::ETPairPotential) = calc.model +_ps(calc::ETPairPotential) = calc.ps +_st(calc::ETPairPotential) = calc.st + +""" + length_basis(calc::ETPairPotential) + +Return the number of linear parameters in the pair model (nbasis * nspecies). +""" +function length_basis(calc::ETPairPotential) + pair = _pair(calc) + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + return nbasis * nspecies +end + +ACEfit.basis_size(calc::ETPairPotential) = length_basis(calc) + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::ETPairPotential) + +Compute the basis functions for energy, forces, and virial for pair potential. +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::ETPairPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + pair = _pair(calc) + + # Get basis and jacobian + 𝔹, ∂𝔹 = site_basis_jacobian(pair, G, _ps(calc), _st(calc)) + + natoms = length(sys) + nnodes = size(𝔹, 1) + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + nparams = nbasis * nspecies + maxneigs = size(∂𝔹, 1) + + # Species indices for each node + iZ = pair.readout.selector.(G.node_data) + + # Initialize outputs + E_basis = zeros(nparams) + F_basis = zeros(SVector{3, Float64}, natoms, nparams) + V_basis = zeros(SMatrix{3, 3, Float64, 9}, nparams) + + # Pre-allocate work buffer + ∇Ei_buf = similar(∂𝔹, maxneigs, nnodes) + zero_grad = zero(∂𝔹[1, 1, 1]) + edge_𝐫 = [edge.𝐫 for edge in G.edge_data] + + # Compute basis values for each parameter (k, s) pair + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + + # Energy basis + for i in 1:nnodes + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + + # Fill gradient buffer + for i in 1:nnodes + if iZ[i] == s + @views ∇Ei_buf[:, i] .= ∂𝔹[:, i, k] + else + @views ∇Ei_buf[:, i] .= Ref(zero_grad) + end + end + + # Convert to edge format and compute forces/virial + ∇Ei_3d = reshape(∇Ei_buf, maxneigs, nnodes, 1) + ∇E_edges = ET.rev_reshape_embedding(∇Ei_3d, G)[:] + F_basis[:, p] = -ET.Atoms.forces_from_edge_grads(sys, G, ∇E_edges) + + V = zero(SMatrix{3, 3, Float64, 9}) + for (e, ∂edge) in enumerate(∇E_edges) + V -= ∂edge.𝐫 * edge_𝐫[e]' + end + V_basis[p] = V + end + end + + return ( + energy = E_basis * u"eV", + forces = F_basis .* u"eV/Å", + virial = V_basis * u"eV" + ) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::ETPairPotential) + +Compute only the energy basis for pair potential. +""" +function potential_energy_basis(sys::AbstractSystem, calc::ETPairPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + pair = _pair(calc) + + 𝔹 = site_basis(pair, G, _ps(calc), _st(calc)) + + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + nparams = nbasis * nspecies + + iZ = pair.readout.selector.(G.node_data) + + E_basis = zeros(nparams) + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + for i in 1:length(G.node_data) + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + end + end + + return E_basis * u"eV" +end + +""" + get_linear_parameters(calc::ETPairPotential) + +Extract the linear parameters (readout weights) as a flat vector. +""" +function get_linear_parameters(calc::ETPairPotential) + return vec(_ps(calc).readout.W) +end + +""" + set_linear_parameters!(calc::ETPairPotential, θ::AbstractVector) + +Set the linear parameters (readout weights) from a flat vector. +""" +function set_linear_parameters!(calc::ETPairPotential, θ::AbstractVector) + pair = _pair(calc) + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + @assert length(θ) == nbasis * nspecies + + ps = _ps(calc) + new_W = reshape(θ, 1, nbasis, nspecies) + calc.ps = merge(ps, (readout = merge(ps.readout, (W = new_W,)),)) + return calc +end + + +# ============================================================================ +# ETOneBodyPotential - Type alias for WrappedSiteCalculator{ETOneBody} +# ============================================================================ + +""" + ETOneBodyPotential + +AtomsCalculators-compatible calculator wrapping an ETOneBody model. +This is a type alias for `WrappedSiteCalculator{<:ETOneBody, PS, ST}`. + +ETOneBody has no learnable parameters, so training assembly returns empty results: +- `length_basis(calc)` returns 0 +- `energy_forces_virial_basis(sys, calc)` returns empty arrays +- Forces and virial are always zero (energy only depends on atom types) + +# Example +```julia +et_onebody = one_body(Dict(:Si => -0.846), x -> x.z) +_, st = Lux.setup(rng, et_onebody) +calc = ETOneBodyPotential(et_onebody, nothing, st, 3.0) +E = potential_energy(sys, calc) +``` +""" +const ETOneBodyPotential{MOD<:ETOneBody, PS, ST} = WrappedSiteCalculator{MOD, PS, ST} + +function ETOneBodyPotential(model::ETOneBody, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut)) +end + +# ============================================================================ +# ETOneBodyPotential Training Assembly (empty - no learnable parameters) +# ============================================================================ + +_onebody(calc::ETOneBodyPotential) = calc.model +_ps(calc::ETOneBodyPotential) = calc.ps +_st(calc::ETOneBodyPotential) = calc.st + +""" + length_basis(calc::ETOneBodyPotential) + +Return 0 - ETOneBody has no learnable linear parameters. +""" +length_basis(calc::ETOneBodyPotential) = 0 + +ACEfit.basis_size(calc::ETOneBodyPotential) = 0 + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::ETOneBodyPotential) + +Return empty arrays - ETOneBody has no learnable parameters. +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::ETOneBodyPotential) + natoms = length(sys) + return ( + energy = zeros(0) * u"eV", + forces = zeros(SVector{3, Float64}, natoms, 0) .* u"eV/Å", + virial = zeros(SMatrix{3, 3, Float64, 9}, 0) * u"eV" + ) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::ETOneBodyPotential) + +Return empty array - ETOneBody has no learnable parameters. +""" +potential_energy_basis(sys::AbstractSystem, calc::ETOneBodyPotential) = zeros(0) * u"eV" + +""" + get_linear_parameters(calc::ETOneBodyPotential) + +Return empty vector - ETOneBody has no learnable parameters. +""" +get_linear_parameters(calc::ETOneBodyPotential) = Float64[] + +""" + set_linear_parameters!(calc::ETOneBodyPotential, θ::AbstractVector) + +No-op for ETOneBody (no learnable parameters). +""" +function set_linear_parameters!(calc::ETOneBodyPotential, θ::AbstractVector) + @assert length(θ) == 0 "ETOneBody has no learnable parameters" + return calc +end + + +# ============================================================================ +# Full Model Conversion +# ============================================================================ + +using Random: AbstractRNG, default_rng +using Lux: setup + +""" + convert2et_full(model, ps, st; rng=default_rng()) -> StackedCalculator + +Convert a complete ACE model (E0 + Pair + Many-body) to an ETACE-based +StackedCalculator. This creates a calculator that combines: +1. ETOneBody - reference energies per species +2. ETPairModel - pair potential +3. ETACE - many-body ACE potential + +The returned StackedCalculator is fully compatible with AtomsCalculators +and can be used for energy, forces, and virial evaluation. + +# Arguments +- `model`: ACE model (from ACEpotentials.Models) +- `ps`: Model parameters (from Lux.setup) +- `st`: Model state (from Lux.setup) +- `rng`: Random number generator (default: `default_rng()`) + +# Returns +- `StackedCalculator` combining ETOneBody, ETPairModel, and ETACE + +# Example +```julia +model = ace_model(elements=[:Si], order=3, totaldegree=8) +ps, st = Lux.setup(rng, model) +# ... fit model ... +calc = convert2et_full(model, ps, st) +E = potential_energy(sys, calc) +``` +""" +function convert2et_full(model, ps, st; rng::AbstractRNG=default_rng()) + # Extract cutoff radius from pair basis + rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + + # 1. Convert E0/Vref to ETOneBody + E0s = model.Vref.E0 # Dict{Int, Float64} + zlist = ChemicalSpecies.(model.rbasis._i2z) + E0_dict = Dict(z => E0s[z.atomic_number] for z in zlist) + et_onebody = one_body(E0_dict, x -> x.z) + _, onebody_st = setup(rng, et_onebody) + # Use minimum cutoff for graph construction (ETOneBody needs no neighbors) + onebody_calc = WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + + # 2. Convert pair potential to ETPairModel + et_pair = convertpair(model) + et_pair_ps, et_pair_st = setup(rng, et_pair) + copy_pair_params!(et_pair_ps, ps, model) + pair_calc = WrappedSiteCalculator(et_pair, et_pair_ps, et_pair_st, rcut) + + # 3. Convert many-body to ETACE + et_ace = convert2et(model) + et_ace_ps, et_ace_st = setup(rng, et_ace) + copy_ace_params!(et_ace_ps, ps, model) + ace_calc = WrappedSiteCalculator(et_ace, et_ace_ps, et_ace_st, rcut) + + # 4. Stack all components + return StackedCalculator((onebody_calc, pair_calc, ace_calc)) +end + + +# ============================================================================ +# Parameter Copying Utilities +# ============================================================================ + +""" + copy_ace_params!(et_ps, ps, model) + +Copy many-body (ACE) parameters from ACE model format to ETACE format. +""" +function copy_ace_params!(et_ps, ps, model) + NZ = length(model.rbasis._i2z) + + # Copy radial basis parameters (Wnlq) + # ACE format: Wnlq[:, :, iz, jz] for species pair (iz, jz) + # ETACE format: rembed.post.W[:, :, idx] where idx = (i-1)*NZ + j + # (post is the SelectLinL layer in EmbedDP) + for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + et_ps.rembed.post.W[:, :, idx] .= ps.rbasis.Wnlq[:, :, i, j] + end + + # Copy readout (many-body) parameters + # ACE format: WB[:, s] for species s + # ETACE format: W[1, :, s] + for s in 1:NZ + et_ps.readout.W[1, :, s] .= ps.WB[:, s] + end +end + + +""" + copy_pair_params!(et_ps, ps, model) + +Copy pair potential parameters from ACE model format to ETPairModel format. +Based on parameter mapping from test/etmodels/test_etpair.jl. +""" +function copy_pair_params!(et_ps, ps, model) + NZ = length(model.pairbasis._i2z) + + # Copy pair radial basis parameters + # ACE format: pairbasis.Wnlq[:, :, i, j] for species pair (i, j) + # ETACE format: rembed.rbasis.post.W[:, :, idx] where idx = (i-1)*NZ + j + # (post is the SelectLinL layer in EmbedDP) + for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + et_ps.rembed.rbasis.post.W[:, :, idx] .= ps.pairbasis.Wnlq[:, :, i, j] + end + + # Copy pair readout parameters + # ACE format: Wpair[:, s] for species s + # ETACE format: readout.W[1, :, s] + for s in 1:NZ + et_ps.readout.W[1, :, s] .= ps.Wpair[:, s] + end +end + diff --git a/src/et_models/et_models.jl b/src/et_models/et_models.jl index 73fa48729..9aedb182e 100644 --- a/src/et_models/et_models.jl +++ b/src/et_models/et_models.jl @@ -1,5 +1,5 @@ -module ETModels +module ETModels # utility layers : these should likely be moved into ET or be removed # if more convenient implementations can be found. @@ -14,8 +14,11 @@ include("et_pair.jl") # converstion utilities: convert from 0.8 style ACE models to ET based models include("convert.jl") -# utilities to convert radial embeddings to splined versions -# for simplicity and performance and to freeze parameters +# utilities to convert radial embeddings to splined versions +# for simplicity and performance and to freeze parameters include("splinify.jl") +include("et_calculators.jl") +include("stackedcalc.jl") + end \ No newline at end of file diff --git a/src/et_models/et_pair.jl b/src/et_models/et_pair.jl index 1a3ce5f11..419321d95 100644 --- a/src/et_models/et_pair.jl +++ b/src/et_models/et_pair.jl @@ -1,8 +1,8 @@ # -# This is a temporary model implementation needed due to the fact that -# ETACEModel has Rnl, Ylm hard-coded. In the future it could be tested -# whether the pair model could simply be taken as another ACE model -# with a single embedding rather than several, This would need generalization +# This is a temporary model implementation needed due to the fact that +# ETACEModel has Rnl, Ylm hard-coded. In the future it could be tested +# whether the pair model could simply be taken as another ACE model +# with a single embedding rather than several, This would need generalization # of a fair few methods in both ACEpotentials and EquivariantTensors. # diff --git a/src/et_models/onebody.jl b/src/et_models/onebody.jl index 25f4732af..986e4971d 100644 --- a/src/et_models/onebody.jl +++ b/src/et_models/onebody.jl @@ -58,8 +58,12 @@ ___apply_onebody(selector, X::AbstractVector, E0s) = map(x -> E0s[selector(x)], X) -site_grads(l::ETOneBody, X::ET.ETGraph, ps, st) = - fill(VState(), (ET.maxneigs(X), ET.nnodes(X), )) +# ETOneBody energy only depends on atom types (categorical), not positions. +# Gradient w.r.t. positions is always zero. +# Return empty edge_data array since there are no position-dependent gradients. +function site_grads(l::ETOneBody, X::ET.ETGraph, ps, st) + return (; edge_data = VState[]) +end site_basis(l::ETOneBody, X::ET.ETGraph, ps, st) = fill(zero(eltype(st.E0s)), (ET.nnodes(X), 0)) diff --git a/src/et_models/stackedcalc.jl b/src/et_models/stackedcalc.jl new file mode 100644 index 000000000..cdcb737b4 --- /dev/null +++ b/src/et_models/stackedcalc.jl @@ -0,0 +1,219 @@ + +# StackedCalculator - Combines multiple AtomsCalculators +# +# Generic utility for combining multiple calculators by summing their +# energy, forces, and virial contributions. +# +# Uses @generated functions with Base.Cartesian for efficient +# compile-time loop unrolling when the number of calculators is known. + +import AtomsCalculators +import AtomsBase: AbstractSystem +using StaticArrays +using Unitful +using Base.Cartesian: @nexprs, @ntuple, @ncall + +""" + StackedCalculator{N, C<:Tuple} + +Combines multiple AtomsCalculators by summing their energy, forces, and virial. +Each calculator in the tuple must implement the AtomsCalculators interface. + +This allows combining site-based calculators (via WrappedSiteCalculator) with +calculators that don't have site decompositions (e.g., Coulomb, dispersion). + +The implementation uses compile-time loop unrolling for efficiency when +the number of calculators is small and known at compile time. + +# Example +```julia +# Wrap site energy models +E0_calc = WrappedSiteCalculator(E0Model(Dict(:Si => -0.846))) +ace_calc = WrappedSiteCalculator(WrappedETACE(et_model, ps, st, 5.5)) + +# Stack them (could also add Coulomb, dispersion, etc.) +calc = StackedCalculator((E0_calc, ace_calc)) + +E = potential_energy(sys, calc) +F = forces(sys, calc) +``` + +# Fields +- `calcs::Tuple` - Tuple of N calculators implementing AtomsCalculators interface +""" +struct StackedCalculator{N, C<:Tuple} + calcs::C +end + +# Constructor that infers N from the tuple length +StackedCalculator(calcs::C) where {C<:Tuple} = StackedCalculator{length(C.parameters), C}(calcs) + +# Get maximum cutoff from all calculators (for informational purposes) +@generated function cutoff_radius(calc::StackedCalculator{N}) where {N} + quote + rcuts = @ntuple $N i -> ustrip(u"Å", cutoff_radius(calc.calcs[i])) + return maximum(rcuts) * u"Å" + end +end + +# ============================================================================ +# Efficient implementations using @generated for compile-time unrolling +# ============================================================================ + +@generated function _stacked_energy(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> E_i = AtomsCalculators.potential_energy(sys, calc.calcs[i]) + return sum(@ntuple $N E) + end +end + +@generated function _stacked_forces(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> F_i = AtomsCalculators.forces(sys, calc.calcs[i]) + return reduce(.+, @ntuple $N F) + end +end + +@generated function _stacked_virial(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> V_i = AtomsCalculators.virial(sys, calc.calcs[i]) + return sum(@ntuple $N V) + end +end + +@generated function _stacked_efv(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> efv_i = AtomsCalculators.energy_forces_virial(sys, calc.calcs[i]) + return ( + energy = sum(@ntuple $N i -> efv_i.energy), + forces = reduce(.+, @ntuple $N i -> efv_i.forces), + virial = sum(@ntuple $N i -> efv_i.virial) + ) + end +end + +# ============================================================================ +# AtomsCalculators interface +# ============================================================================ + +AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_energy(sys, calc) +end + +AtomsCalculators.@generate_interface function AtomsCalculators.forces( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_forces(sys, calc) +end + +AtomsCalculators.@generate_interface function AtomsCalculators.virial( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_virial(sys, calc) +end + +function AtomsCalculators.energy_forces_virial( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_efv(sys, calc) +end + +# ============================================================================ +# Training Assembly Interface for StackedCalculator +# ============================================================================ + +import ACEfit + +""" + length_basis(calc::StackedCalculator) + +Return total number of linear parameters across all stacked calculators. +""" +function length_basis(calc::StackedCalculator) + return sum(length_basis(c) for c in calc.calcs) +end + +ACEfit.basis_size(calc::StackedCalculator) = length_basis(calc) + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::StackedCalculator) + +Compute concatenated basis for all stacked calculators. +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::StackedCalculator) + # Collect basis from each calculator + results = [energy_forces_virial_basis(sys, c) for c in calc.calcs] + + natoms = length(sys) + + # Concatenate results - energy is Vector of Quantity{Float64} + E_basis = vcat([_strip_energy_units(r.energy) for r in results]...) + + # For forces, need to hcat the matrices + # Strip units element by element for matrices of SVectors with units + F_parts = [_strip_force_units(r.forces) for r in results] + F_basis = isempty(F_parts) ? zeros(SVector{3, Float64}, natoms, 0) : hcat(F_parts...) + + # Virial is Vector of SMatrix with units + V_basis = vcat([_strip_virial_units(r.virial) for r in results]...) + + return ( + energy = E_basis * u"eV", + forces = F_basis .* u"eV/Å", + virial = V_basis * u"eV" + ) +end + +# Helper to strip units from energy (Vector of Quantity{Float64}) +function _strip_energy_units(E) + return map(e -> ustrip(e), E) +end + +# Helper to strip units from force matrices (Matrix of SVector with units) +function _strip_force_units(F) + # F is Matrix{SVector{3, Quantity}} + # We need to strip the units from the inner SVectors + return map(f -> SVector{3, Float64}(ustrip.(f)), F) +end + +# Helper to strip units from virial (Vector of SMatrix with units) +function _strip_virial_units(V) + # V is Vector{SMatrix{3,3, Quantity}} + return map(v -> SMatrix{3, 3, Float64, 9}(ustrip.(v)), V) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::StackedCalculator) + +Compute concatenated energy basis for all stacked calculators. +""" +function potential_energy_basis(sys::AbstractSystem, calc::StackedCalculator) + results = [potential_energy_basis(sys, c) for c in calc.calcs] + E_basis = vcat([ustrip.(u"eV", r) for r in results]...) + return E_basis * u"eV" +end + +""" + get_linear_parameters(calc::StackedCalculator) + +Get concatenated linear parameters from all stacked calculators. +""" +function get_linear_parameters(calc::StackedCalculator) + return vcat([get_linear_parameters(c) for c in calc.calcs]...) +end + +""" + set_linear_parameters!(calc::StackedCalculator, θ::AbstractVector) + +Set linear parameters for all stacked calculators from concatenated vector. +""" +function set_linear_parameters!(calc::StackedCalculator, θ::AbstractVector) + offset = 0 + for c in calc.calcs + n = length_basis(c) + if n > 0 + set_linear_parameters!(c, θ[offset+1:offset+n]) + end + offset += n + end + @assert offset == length(θ) "Parameter count mismatch" + return calc +end diff --git a/src/models/models.jl b/src/models/models.jl index 05527efb9..5bbeb88e2 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -18,7 +18,9 @@ import LuxCore: AbstractLuxLayer, initialparameters, initialstates -function length_basis end +function length_basis end +function energy_forces_virial_basis end +function potential_energy_basis end include("elements.jl") diff --git a/test/Project.toml b/test/Project.toml index 11f939c87..5113a70ff 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,13 @@ [deps] ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" +ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" +ACEpotentials = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" AtomsCalculatorsUtilities = "9855a07e-8816-4d1b-ac92-859c17475477" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DecoratedParticles = "023d0394-cb16-4d2d-a5c7-724bed42bbb6" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" @@ -15,7 +19,9 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" @@ -24,8 +30,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources] +ACEpotentials = {path = ".."} + [compat] +EquivariantTensors = "0.4.3" StaticArrays = "1" diff --git a/test/benchmark_comparison.jl b/test/benchmark_comparison.jl new file mode 100644 index 000000000..5ad9fea45 --- /dev/null +++ b/test/benchmark_comparison.jl @@ -0,0 +1,172 @@ +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators +using StaticArrays, Lux, Random, LuxCore, LinearAlgebra +using AtomsBase, AtomsBuilder, Unitful +using BenchmarkTools +using Printf + +# GPU detection (simplified from ET test utils) +dev = identity +has_cuda = false + +try + using CUDA + if CUDA.functional() + @info "Using CUDA" + CUDA.versioninfo() + global has_cuda = true + global dev = cu + else + @info "CUDA is not functional" + end +catch e + @info "Couldn't load CUDA: $e" +end + +if !has_cuda + @info "No GPU available. Using CPU only." +end + +rng = Random.MersenneTwister(1234) + +# Build models +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, model) + +# Kill the pair basis for fair comparison +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +# Create old ACE calculator +ace_calc = M.ACEPotential(model, ps, st) + +# Convert to ETACE +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(rng, et_model) + +# Copy parameters +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps.readout.W[1, :, 2] .= ps.WB[:, 2] + +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# GPU setup +has_gpu = has_cuda +if has_gpu + et_ps_32 = ET.float32(et_ps) + et_st_32 = ET.float32(et_st) + et_ps_gpu = dev(et_ps_32) + et_st_gpu = dev(et_st_32) +end + +# Function to create system of given size +function make_system(n_repeat) + sys = AtomsBuilder.bulk(:Si, cubic=true) * n_repeat + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +# Benchmark configurations (tuple for bulk multiplication) +configs = [ + (1, 1, 1), # 8 atoms + (2, 1, 1), # 16 atoms + (2, 2, 2), # 64 atoms + (3, 3, 2), # 144 atoms + (4, 4, 2), # 256 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("=" ^ 85) +println("BENCHMARK: ACE (CPU) vs ETACE (CPU) vs ETACE (GPU)") +println("=" ^ 85) +println() + +# Header +if has_gpu + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE GPU (ms) | CPU Speedup | GPU Speedup |") + println("|-------|---------|--------------|----------------|----------------|-------------|-------------|") +else + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") + println("|-------|---------|--------------|----------------|-------------|") +end + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.potential_energy(sys, ace_calc) + + # Warmup ETACE CPU + _ = sum(et_model(G, et_ps, et_st)[1]) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.potential_energy($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU (graph construction NOT included for fair comparison with GPU) + t_etace_cpu = @belapsed sum($et_model($G, $et_ps, $et_st)[1]) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + if has_gpu + # Convert graph to GPU (Float32) + G_32 = ET.float32(G) + G_gpu = dev(G_32) + + # Warmup GPU + _ = sum(et_model(G_gpu, et_ps_gpu, et_st_gpu)[1]) + + # Benchmark ETACE GPU (graph already on GPU) + t_etace_gpu = @belapsed sum($et_model($G_gpu, $et_ps_gpu, $et_st_gpu)[1]) samples=5 evals=3 + t_etace_gpu_ms = t_etace_gpu * 1000 + + gpu_speedup = t_ace_ms / t_etace_gpu_ms + + @printf("| %5d | %7d | %12.2f | %14.2f | %14.2f | %10.1fx | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, t_etace_gpu_ms, cpu_speedup, gpu_speedup) + else + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) + end +end + +println() +println("Notes:") +println("- ACE CPU: Original ACEpotentials model (pair basis zeroed for fair comparison)") +println("- ETACE CPU: EquivariantTensors backend on CPU (Float64)") +println("- ETACE GPU: EquivariantTensors backend on GPU (Float32)") +println("- CPU Speedup = ACE CPU / ETACE CPU") +println("- GPU Speedup = ACE CPU / ETACE GPU") +println("- Graph construction time NOT included (currently CPU-only)") diff --git a/test/benchmark_forces.jl b/test/benchmark_forces.jl new file mode 100644 index 000000000..afef3a864 --- /dev/null +++ b/test/benchmark_forces.jl @@ -0,0 +1,122 @@ +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators +using StaticArrays, Lux, Random, LuxCore, LinearAlgebra +using AtomsBase, AtomsBuilder, Unitful +using BenchmarkTools +using Printf + +rng = Random.MersenneTwister(1234) + +# Build models +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, model) + +# Kill the pair basis for fair comparison +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +# Create old ACE calculator +ace_calc = M.ACEPotential(model, ps, st) + +# Convert to ETACE +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(rng, et_model) + +# Copy parameters +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps.readout.W[1, :, 2] .= ps.WB[:, 2] + +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# Function to create system of given size +function make_system(n_repeat) + sys = AtomsBuilder.bulk(:Si, cubic=true) * n_repeat + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +# ETACE forces function (CPU only for now) +function etace_forces(et_model, G, sys, et_ps, et_st) + ∂G = ETM.site_grads(et_model, G, et_ps, et_st) + return -ET.Atoms.forces_from_edge_grads(sys, G, ∂G.edge_data) +end + +# Benchmark configurations (tuple for bulk multiplication) +configs = [ + (1, 1, 1), # 8 atoms + (2, 1, 1), # 16 atoms + (2, 2, 2), # 64 atoms + (3, 3, 2), # 144 atoms + (4, 4, 2), # 256 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("=" ^ 70) +println("BENCHMARK: Forces - ACE (CPU) vs ETACE (CPU)") +println("=" ^ 70) +println() + +# Header +println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") +println("|-------|---------|--------------|----------------|-------------|") + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.forces(sys, ace_calc) + + # Warmup ETACE CPU + _ = etace_forces(et_model, G, sys, et_ps, et_st) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.forces($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU (graph construction NOT included for fair comparison) + t_etace_cpu = @belapsed etace_forces($et_model, $G, $sys, $et_ps, $et_st) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) +end + +println() +println("Notes:") +println("- ACE CPU: Original ACEpotentials model (pair basis zeroed for fair comparison)") +println("- ETACE CPU: EquivariantTensors backend on CPU (Float64)") +println("- CPU Speedup = ACE CPU / ETACE CPU") +println("- Graph construction time NOT included") diff --git a/test/et_models/test_et_calculators.jl b/test/et_models/test_et_calculators.jl new file mode 100644 index 000000000..5b14a5b77 --- /dev/null +++ b/test/et_models/test_et_calculators.jl @@ -0,0 +1,776 @@ +# Tests for ETACEPotential calculator interface +# +# These tests verify: +# 1. Energy consistency between ETACE model and ETACEPotential +# 2. Force consistency against original ACE model +# 3. Virial consistency against original ACE model +# 4. AtomsCalculators interface compliance + +using Test, ACEbase, BenchmarkTools +using Polynomials4ML.Testing: print_tf, println_slim + +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators + +using AtomsBase, AtomsBuilder, Unitful +using Random, LuxCore, StaticArrays, LinearAlgebra + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## +# Build an ETACE model for testing + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 + +# Use same cutoff for all elements +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = LuxCore.setup(rng, model) + +# Kill pair basis for clarity (test only ACE part) +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +# Convert to ETACE model +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) + +# Match parameters using utility function +ETM.copy_ace_params!(et_ps, ps, model) + +# Get cutoff radius +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# Helper to generate random structures +function rand_struct() + sys = AtomsBuilder.bulk(:Si) * (2, 2, 1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +## + +@info("Testing ETACEPotential construction") + +# Create calculator from ETACE model +# ETACEPotential is now WrappedSiteCalculator{ETACE} (direct, no WrappedETACE) +et_calc = ETM.ETACEPotential(et_model, et_ps, et_st, rcut) + +# Access underlying ETACE directly via calc.model +@test et_calc.model === et_model +@test et_calc.rcut == rcut +@test et_calc.co_ps === nothing + +## + +@info("Testing energy consistency: ETACE model vs ETACEPotential") + +for ntest = 1:20 + local sys, G, E_model, E_calc + + sys = rand_struct() + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + + # Energy from direct model evaluation + Ei_model, _ = et_model(G, et_ps, et_st) + E_model = sum(Ei_model) + + # Energy from calculator + E_calc = AtomsCalculators.potential_energy(sys, et_calc) + + print_tf(@test abs(E_model - ustrip(E_calc)) < 1e-10) +end +println() + +## + +@info("Testing energy consistency: ETACE vs original ACE model") + +# Wrap original ACE model into calculator +calc_model = M.ACEPotential(model, ps, st) + +for ntest = 1:20 + local sys, E_old, E_new + + sys = rand_struct() + E_old = AtomsCalculators.potential_energy(sys, calc_model) + E_new = AtomsCalculators.potential_energy(sys, et_calc) + + print_tf(@test abs(ustrip(E_old) - ustrip(E_new)) < 1e-6) +end +println() + +## + +@info("Testing forces consistency: ETACE vs original ACE model") + +for ntest = 1:20 + local sys, F_old, F_new + + sys = rand_struct() + F_old = AtomsCalculators.forces(sys, calc_model) + F_new = AtomsCalculators.forces(sys, et_calc) + + # Compare force magnitudes (allow small numerical differences) + max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_old, F_new)) + print_tf(@test max_diff < 1e-6) +end +println() + +## + +@info("Testing virial consistency: ETACE vs original ACE model") + +for ntest = 1:20 + local sys, V_old, V_new + + sys = rand_struct() + efv_old = AtomsCalculators.energy_forces_virial(sys, calc_model) + efv_new = AtomsCalculators.energy_forces_virial(sys, et_calc) + + V_old = ustrip.(efv_old.virial) + V_new = ustrip.(efv_new.virial) + + # Compare virial tensors + print_tf(@test norm(V_old - V_new) / (norm(V_old) + 1e-10) < 1e-6) +end +println() + +## + +@info("Testing AtomsCalculators interface compliance") + +sys = rand_struct() + +# Test individual methods +E = AtomsCalculators.potential_energy(sys, et_calc) +F = AtomsCalculators.forces(sys, et_calc) +V = AtomsCalculators.virial(sys, et_calc) + +@test E isa typeof(1.0u"eV") +@test eltype(F) <: StaticArrays.SVector +@test V isa StaticArrays.SMatrix + +## + +@info("Testing combined energy_forces_virial efficiency") + +sys = rand_struct() + +# Combined evaluation +efv1 = AtomsCalculators.energy_forces_virial(sys, et_calc) + +# Separate evaluations +E = AtomsCalculators.potential_energy(sys, et_calc) +F = AtomsCalculators.forces(sys, et_calc) +V = AtomsCalculators.virial(sys, et_calc) + +@test ustrip(efv1.energy) ≈ ustrip(E) +@test all(ustrip.(efv1.forces) .≈ ustrip.(F)) +@test ustrip.(efv1.virial) ≈ ustrip.(V) + +## + +@info("Testing cutoff_radius function") + +@test ETM.cutoff_radius(et_calc) == rcut * u"Å" + +## + +@info("All Phase 1 tests passed!") + +# ============================================================================ +# Phase 2 Tests: WrappedSiteCalculator and StackedCalculator +# ============================================================================ + +@info("Testing Phase 2: WrappedSiteCalculator and StackedCalculator") + +## + +@info("Testing ETOneBody (upstream one-body model)") + +using Lux + +# Create ETOneBody model with reference energies (using upstream interface) +E0_Si = -0.846 +E0_O = -2.15 +et_onebody = ETM.one_body(Dict(:Si => E0_Si, :O => E0_O), x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) + +# Test site energies via direct model call +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +Ei_E0, _ = et_onebody(G, nothing, onebody_st) + +# Count Si and O atoms +n_Si = count(node -> node.z == AtomsBase.ChemicalSpecies(:Si), G.node_data) +n_O = count(node -> node.z == AtomsBase.ChemicalSpecies(:O), G.node_data) +expected_E0 = n_Si * E0_Si + n_O * E0_O + +@test length(Ei_E0) == length(sys) +@test sum(Ei_E0) ≈ expected_E0 + +# Test site gradients (should be empty for constant energies) +# Returns NamedTuple with empty edge_data, matching ETACE/ETPairModel interface +∂G_E0 = ETM.site_grads(et_onebody, G, nothing, onebody_st) +@test isempty(∂G_E0.edge_data) + +## + +@info("Testing WrappedSiteCalculator with ETOneBody") + +# Wrap ETOneBody in a calculator (using new unified interface) +E0_calc = ETM.WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) +@test ustrip(u"Å", ETM.cutoff_radius(E0_calc)) == 3.0 # minimum cutoff + +# Test ETOneBody calculator energy +sys = rand_struct() +E_E0_calc = AtomsCalculators.potential_energy(sys, E0_calc) +G = ET.Atoms.interaction_graph(sys, 3.0 * u"Å") +n_Si = count(node -> node.z == AtomsBase.ChemicalSpecies(:Si), G.node_data) +n_O = count(node -> node.z == AtomsBase.ChemicalSpecies(:O), G.node_data) +expected_E = (n_Si * E0_Si + n_O * E0_O) * u"eV" +@test ustrip(E_E0_calc) ≈ ustrip(expected_E) + +# Test ETOneBody calculator forces (should be zero) +F_E0_calc = AtomsCalculators.forces(sys, E0_calc) +@test all(norm(ustrip.(f)) < 1e-14 for f in F_E0_calc) + +## + +@info("Testing WrappedSiteCalculator with ETACE") + +# Wrap ETACE model in a calculator (unified interface) +ace_site_calc = ETM.WrappedSiteCalculator(et_model, et_ps, et_st, rcut) +@test ustrip(u"Å", ETM.cutoff_radius(ace_site_calc)) == rcut + +# Test ETACE calculator matches ETACEPotential +sys = rand_struct() +E_ace_site = AtomsCalculators.potential_energy(sys, ace_site_calc) +E_ace_pot = AtomsCalculators.potential_energy(sys, et_calc) +@test ustrip(E_ace_site) ≈ ustrip(E_ace_pot) + +F_ace_site = AtomsCalculators.forces(sys, ace_site_calc) +F_ace_pot = AtomsCalculators.forces(sys, et_calc) +max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_ace_site, F_ace_pot)) +@test max_diff < 1e-10 + +## + +@info("Testing StackedCalculator construction") + +# Create stacked calculator with E0 + ACE (both wrapped) +stacked = ETM.StackedCalculator((E0_calc, ace_site_calc)) + +@test ustrip(u"Å", ETM.cutoff_radius(stacked)) == rcut +@test length(stacked.calcs) == 2 + +## + +@info("Testing StackedCalculator energy consistency") + +for ntest = 1:10 + local sys, E_stacked, E_separate + + sys = rand_struct() + + # Energy from stacked calculator + E_stacked = AtomsCalculators.potential_energy(sys, stacked) + + # Energy from separate evaluations + E_E0 = AtomsCalculators.potential_energy(sys, E0_calc) + E_ace = AtomsCalculators.potential_energy(sys, ace_site_calc) + E_separate = E_E0 + E_ace + + print_tf(@test ustrip(E_stacked) ≈ ustrip(E_separate)) +end +println() + +## + +@info("Testing StackedCalculator forces consistency") + +for ntest = 1:10 + local sys, F_stacked, F_ace_only, max_diff + + sys = rand_struct() + + # Forces from stacked calculator + F_stacked = AtomsCalculators.forces(sys, stacked) + + # Forces from ACE-only (E0 has zero forces) + F_ace_only = AtomsCalculators.forces(sys, et_calc) + + # Should be identical since E0 contributes zero forces + max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_stacked, F_ace_only)) + print_tf(@test max_diff < 1e-10) +end +println() + +## + +@info("Testing StackedCalculator virial consistency") + +for ntest = 1:10 + local sys, efv_stacked, efv_ace_only + + sys = rand_struct() + + efv_stacked = AtomsCalculators.energy_forces_virial(sys, stacked) + efv_ace_only = AtomsCalculators.energy_forces_virial(sys, et_calc) + + # Virial should match (E0 has zero virial) + V_stacked = ustrip.(efv_stacked.virial) + V_ace_only = ustrip.(efv_ace_only.virial) + + print_tf(@test norm(V_stacked - V_ace_only) / (norm(V_ace_only) + 1e-10) < 1e-10) +end +println() + +## + +@info("Testing StackedCalculator with ETOneBody only") + +# Create stacked calculator with just ETOneBody (E0_calc is WrappedSiteCalculator{ETOneBody}) +E0_only_stacked = ETM.StackedCalculator((E0_calc,)) + +sys = rand_struct() +E = AtomsCalculators.potential_energy(sys, E0_only_stacked) +F = AtomsCalculators.forces(sys, E0_only_stacked) + +# Energy should match E0_calc +E_direct = AtomsCalculators.potential_energy(sys, E0_calc) +@test ustrip(E) ≈ ustrip(E_direct) + +# Forces should be zero +@test all(norm(ustrip.(f)) < 1e-14 for f in F) + +## + +@info("All Phase 2 tests passed!") + +## ============================================================================ +## Phase 5: Training Assembly Tests +## ============================================================================ + +@info("Testing Phase 5: Training assembly functions") + +## + +@info("Testing length_basis") +nparams = ETM.length_basis(et_calc) +nbasis = et_model.readout.in_dim +nspecies = et_model.readout.ncat +@test nparams == nbasis * nspecies + +## + +@info("Testing get/set_linear_parameters round-trip") +θ_orig = ETM.get_linear_parameters(et_calc) +@test length(θ_orig) == nparams + +# Modify and restore +θ_test = randn(nparams) +ETM.set_linear_parameters!(et_calc, θ_test) +θ_check = ETM.get_linear_parameters(et_calc) +@test θ_check ≈ θ_test + +# Restore original +ETM.set_linear_parameters!(et_calc, θ_orig) +@test ETM.get_linear_parameters(et_calc) ≈ θ_orig + +## + +@info("Testing potential_energy_basis") +sys = rand_struct() +E_basis = ETM.potential_energy_basis(sys, et_calc) +@test length(E_basis) == nparams +@test eltype(ustrip.(E_basis)) <: Real + +## + +@info("Testing energy_forces_virial_basis") +efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) +natoms = length(sys) + +@test length(efv_basis.energy) == nparams +@test size(efv_basis.forces) == (natoms, nparams) +@test length(efv_basis.virial) == nparams + +## + +@info("Testing linear combination gives correct energy") + +# E = dot(E_basis, θ) should match potential_energy +θ = ETM.get_linear_parameters(et_calc) +E_from_basis = dot(ustrip.(efv_basis.energy), θ) +E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + +print_tf(@test E_from_basis ≈ E_direct rtol=1e-10) +println() + +## + +@info("Testing linear combination gives correct forces") + +# F = efv_basis.forces * θ should match forces +F_from_basis = efv_basis.forces * θ +F_direct = AtomsCalculators.forces(sys, et_calc) + +max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) +print_tf(@test max_diff < 1e-10) +println() + +## + +@info("Testing linear combination gives correct virial") + +# V = sum(θ .* virial) should match virial +V_from_basis = sum(θ[i] * ustrip.(efv_basis.virial[i]) for i in 1:nparams) +V_direct = ustrip.(AtomsCalculators.virial(sys, et_calc)) + +virial_diff = maximum(abs.(V_from_basis - V_direct)) +print_tf(@test virial_diff < 1e-10) +println() + +## + +@info("Testing potential_energy_basis matches energy from efv_basis") +@test ustrip.(E_basis) ≈ ustrip.(efv_basis.energy) + +## + +@info("All Phase 5 basic tests passed!") + +## ============================================================================ +## Phase 5b: Extended Training Assembly Tests +## ============================================================================ + +@info("Testing Phase 5b: Extended training assembly tests") + +## + +@info("Testing ACEfit.basis_size integration") +import ACEfit +@test ACEfit.basis_size(et_calc) == ETM.length_basis(et_calc) + +## + +@info("Testing training assembly on multiple structures") + +# Generate multiple random structures +nstructs = 5 +test_systems = [rand_struct() for _ in 1:nstructs] + +all_ok = true +for (i, sys) in enumerate(test_systems) + local θ = ETM.get_linear_parameters(et_calc) + local efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) + + # Check energy + local E_from_basis = dot(ustrip.(efv_basis.energy), θ) + local E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + if !isapprox(E_from_basis, E_direct, rtol=1e-10) + @warn "Energy mismatch on structure $i" + all_ok = false + end + + # Check forces + local F_from_basis = efv_basis.forces * θ + local F_direct = AtomsCalculators.forces(sys, et_calc) + local max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) + if max_diff >= 1e-10 + @warn "Force mismatch on structure $i: max_diff = $max_diff" + all_ok = false + end + + # Check virial + local V_from_basis = sum(θ[k] * ustrip.(efv_basis.virial[k]) for k in 1:length(θ)) + local V_direct = ustrip.(AtomsCalculators.virial(sys, et_calc)) + local virial_diff = maximum(abs.(V_from_basis - V_direct)) + if virial_diff >= 1e-10 + @warn "Virial mismatch on structure $i: max_diff = $virial_diff" + all_ok = false + end +end +print_tf(@test all_ok) +println() + +## + +@info("Testing multi-species parameter ordering") + +# Create structures with varying species compositions +# Pure Si +sys_si = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_si, 0.1u"Å") + +# Pure O (use Si lattice but with O atoms) +sys_o = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_o, 0.1u"Å") +AtomsBuilder.randz!(sys_o, [:O => 1.0]) + +# Mixed 50/50 +sys_mixed = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_mixed, 0.1u"Å") +AtomsBuilder.randz!(sys_mixed, [:Si => 0.5, :O => 0.5]) + +# Mixed 25/75 +sys_mixed2 = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_mixed2, 0.1u"Å") +AtomsBuilder.randz!(sys_mixed2, [:Si => 0.25, :O => 0.75]) + +species_test_systems = [sys_si, sys_o, sys_mixed, sys_mixed2] +species_labels = ["Pure Si", "Pure O", "50/50 Si:O", "25/75 Si:O"] + +all_species_ok = true +for (label, sys) in zip(species_labels, species_test_systems) + local θ = ETM.get_linear_parameters(et_calc) + local efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) + + # Check energy consistency + local E_from_basis = dot(ustrip.(efv_basis.energy), θ) + local E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + + if !isapprox(E_from_basis, E_direct, rtol=1e-10) + @warn "Energy mismatch for $label: basis=$E_from_basis, direct=$E_direct" + all_species_ok = false + end + + # Check forces + local F_from_basis = efv_basis.forces * θ + local F_direct = AtomsCalculators.forces(sys, et_calc) + local max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) + + if max_diff >= 1e-10 + @warn "Force mismatch for $label: max_diff=$max_diff" + all_species_ok = false + end +end +print_tf(@test all_species_ok) +println() + +## + +@info("Testing species-specific basis contributions") + +# Verify that different species contribute to different parts of the basis +nbasis = et_model.readout.in_dim +nspecies = et_model.readout.ncat + +# For pure Si system, only Si basis should contribute +efv_si = ETM.energy_forces_virial_basis(sys_si, et_calc) +E_basis_si = ustrip.(efv_si.energy) + +# For pure O system, only O basis should contribute +efv_o = ETM.energy_forces_virial_basis(sys_o, et_calc) +E_basis_o = ustrip.(efv_o.energy) + +# Check that the patterns differ (different species activate different parameters) +# Si uses parameters 1:nbasis, O uses parameters (nbasis+1):(2*nbasis) +si_params = E_basis_si[1:nbasis] +o_params_for_si = E_basis_si[(nbasis+1):end] +o_params = E_basis_o[(nbasis+1):end] +si_params_for_o = E_basis_o[1:nbasis] + +# Pure Si should have zero contribution from O parameters +@test all(abs.(o_params_for_si) .< 1e-12) +# Pure O should have zero contribution from Si parameters +@test all(abs.(si_params_for_o) .< 1e-12) +# Pure Si should have nonzero Si parameters +@test any(abs.(si_params) .> 1e-12) +# Pure O should have nonzero O parameters +@test any(abs.(o_params) .> 1e-12) + +## + +@info("All Phase 5b extended tests passed!") + +## ============================================================================ +## Phase 5c: Training Assembly for ETPairPotential, ETOneBodyPotential, StackedCalculator +## ============================================================================ + +@info("Testing Phase 5c: Training assembly for pair, onebody, and stacked calculators") + +## + +@info("Testing ETOneBodyPotential training assembly (empty - no learnable params)") + +# Create ETOneBody calculator +E0s = model.Vref.E0 +zlist = ChemicalSpecies.(model.rbasis._i2z) +E0_dict = Dict(z => E0s[z.atomic_number] for z in zlist) +et_onebody = ETM.one_body(E0_dict, x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) +onebody_calc = ETM.WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + +# Test length_basis returns 0 +@test ETM.length_basis(onebody_calc) == 0 + +# Test energy_forces_virial_basis returns empty arrays +sys = rand_struct() +efv_onebody = ETM.energy_forces_virial_basis(sys, onebody_calc) +@test length(efv_onebody.energy) == 0 +@test size(efv_onebody.forces, 2) == 0 +@test length(efv_onebody.virial) == 0 + +# Test get/set_linear_parameters +@test length(ETM.get_linear_parameters(onebody_calc)) == 0 +ETM.set_linear_parameters!(onebody_calc, Float64[]) # Should not error + +# Test ACEfit.basis_size +@test ACEfit.basis_size(onebody_calc) == 0 + +## + +@info("Testing ETPairPotential training assembly") + +# Need a model with learnable pair basis for this test +# Create a new model with pair_learnable=true +elements_pair = (:Si, :O) +level_pair = M.TotalDegree() +max_level_pair = 10 +order_pair = 3 +maxl_pair = 4 + +rin0cuts_pair = M._default_rin0cuts(elements_pair) +rin0cuts_pair = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts_pair) + +model_pair = M.ace_model(; elements = elements_pair, order = order_pair, + Ytype = :solid, level = level_pair, max_level = max_level_pair, + maxl = maxl_pair, pair_maxn = max_level_pair, + rin0cuts = rin0cuts_pair, + pair_learnable = true, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) +ps_pair, st_pair = Lux.setup(rng, model_pair) + +# Convert pair potential +et_pair = ETM.convertpair(model_pair) +et_pair_ps, et_pair_st = Lux.setup(rng, et_pair) + +# Copy pair parameters using utility function +ETM.copy_pair_params!(et_pair_ps, ps_pair, model_pair) + +rcut_pair = maximum(a.rcut for a in model_pair.pairbasis.rin0cuts) +pair_calc = ETM.ETPairPotential(et_pair, et_pair_ps, et_pair_st, rcut_pair) + +# Test length_basis +pair_nbasis = et_pair.readout.in_dim +pair_nspecies = et_pair.readout.ncat +@test ETM.length_basis(pair_calc) == pair_nbasis * pair_nspecies + +# Test energy_forces_virial_basis +sys_pair = rand_struct() # Uses Si/O system from earlier +efv_pair = ETM.energy_forces_virial_basis(sys_pair, pair_calc) +natoms_pair = length(sys_pair) +nparams_pair = ETM.length_basis(pair_calc) + +@test length(efv_pair.energy) == nparams_pair +@test size(efv_pair.forces) == (natoms_pair, nparams_pair) +@test length(efv_pair.virial) == nparams_pair + +# Test linear combination gives correct energy +θ_pair = ETM.get_linear_parameters(pair_calc) +E_from_pair_basis = dot(ustrip.(efv_pair.energy), θ_pair) +E_pair_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys_pair, pair_calc)) +print_tf(@test E_from_pair_basis ≈ E_pair_direct rtol=1e-10) +println() + +# Test get/set round-trip +θ_pair_test = randn(nparams_pair) +ETM.set_linear_parameters!(pair_calc, θ_pair_test) +@test ETM.get_linear_parameters(pair_calc) ≈ θ_pair_test +ETM.set_linear_parameters!(pair_calc, θ_pair) # Restore + +# Test ACEfit.basis_size +@test ACEfit.basis_size(pair_calc) == nparams_pair + +## + +@info("Testing StackedCalculator training assembly") + +# Create a StackedCalculator with E0 + Pair + ManyBody (using convert2et_full) +stacked_calc = ETM.convert2et_full(model_pair, ps_pair, st_pair) + +# Verify structure: 3 components (ETOneBody, ETPairModel, ETACE) +@test length(stacked_calc.calcs) == 3 + +# Test length_basis is sum of components +n_onebody = ETM.length_basis(stacked_calc.calcs[1]) +n_pair = ETM.length_basis(stacked_calc.calcs[2]) +n_ace = ETM.length_basis(stacked_calc.calcs[3]) +n_total = ETM.length_basis(stacked_calc) + +@test n_onebody == 0 # ETOneBody has no learnable params +@test n_pair > 0 +@test n_ace > 0 +@test n_total == n_onebody + n_pair + n_ace + +# Test energy_forces_virial_basis +sys_stacked = rand_struct() +efv_stacked = ETM.energy_forces_virial_basis(sys_stacked, stacked_calc) +natoms_stacked = length(sys_stacked) + +@test length(efv_stacked.energy) == n_total +@test size(efv_stacked.forces) == (natoms_stacked, n_total) +@test length(efv_stacked.virial) == n_total + +# Test linear combination gives correct energy +θ_stacked = ETM.get_linear_parameters(stacked_calc) +@test length(θ_stacked) == n_total +E_from_stacked_basis = dot(ustrip.(efv_stacked.energy), θ_stacked) +E_stacked_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys_stacked, stacked_calc)) +print_tf(@test E_from_stacked_basis ≈ E_stacked_direct rtol=1e-10) +println() + +# Test linear combination gives correct forces +F_from_stacked_basis = efv_stacked.forces * θ_stacked +F_stacked_direct = AtomsCalculators.forces(sys_stacked, stacked_calc) +max_diff_stacked_F = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_stacked_basis, F_stacked_direct)) +print_tf(@test max_diff_stacked_F < 1e-10) +println() + +# Test linear combination gives correct virial +V_from_stacked_basis = sum(θ_stacked[k] * ustrip.(efv_stacked.virial[k]) for k in 1:n_total) +V_stacked_direct = ustrip.(AtomsCalculators.virial(sys_stacked, stacked_calc)) +virial_diff_stacked = maximum(abs.(V_from_stacked_basis - V_stacked_direct)) +print_tf(@test virial_diff_stacked < 1e-10) +println() + +# Test get/set_linear_parameters round-trip +θ_stacked_orig = copy(θ_stacked) +θ_stacked_test = randn(n_total) +ETM.set_linear_parameters!(stacked_calc, θ_stacked_test) +θ_stacked_check = ETM.get_linear_parameters(stacked_calc) +@test θ_stacked_check ≈ θ_stacked_test +ETM.set_linear_parameters!(stacked_calc, θ_stacked_orig) # Restore + +# Test potential_energy_basis consistency +E_basis_stacked = ETM.potential_energy_basis(sys_stacked, stacked_calc) +@test length(E_basis_stacked) == n_total +@test ustrip.(E_basis_stacked) ≈ ustrip.(efv_stacked.energy) rtol=1e-10 + +# Test ACEfit.basis_size +@test ACEfit.basis_size(stacked_calc) == n_total + +## + +@info("All Phase 5c tests passed!") diff --git a/test/et_models/test_et_silicon.jl b/test/et_models/test_et_silicon.jl new file mode 100644 index 000000000..204a52f99 --- /dev/null +++ b/test/et_models/test_et_silicon.jl @@ -0,0 +1,231 @@ +# Integration test for ETACE calculators +# +# This test verifies that ETACE calculators produce identical results +# to the many-body part of ACE models (excluding pair potential). +# +# Note: convert2et only supports LearnableRnlrzzBasis (not SplineRnlrzzBasis), +# so we use ace_model() directly instead of ace1_model(). +# ETACE implements only the many-body basis, not the pair potential. + +using Test +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +using ExtXYZ, AtomsBase, Unitful, StaticArrays +using AtomsCalculators +using LazyArtifacts +using LuxCore, Lux, Random, LinearAlgebra + +@info("ETACE Integration Test: Silicon dataset") + +## ----- setup ----- + +# Build model using ace_model (LearnableRnlrzzBasis, compatible with convert2et) +elements = (:Si,) +level = M.TotalDegree() +max_level = 12 +order = 3 +maxl = 6 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +rng = Random.MersenneTwister(1234) + +ace_model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, ace_model) + +# Create ACE calculator +model = M.ACEPotential(ace_model, ps, st) + +# Load dataset +data = ExtXYZ.load(artifact"Si_tiny_dataset" * "/Si_tiny.xyz") + +data_keys = [:energy_key => "dft_energy", + :force_key => "dft_force", + :virial_key => "dft_virial"] + +weights = Dict("default" => Dict("E"=>30.0, "F"=>1.0, "V"=>1.0), + "liq" => Dict("E"=>10.0, "F"=>0.66, "V"=>0.25)) + +## ----- Fit original ACE model ----- + +@info("Fitting original ACE model with QR solver") +acefit!(data, model; + data_keys..., + weights = weights, + solver = ACEfit.QR()) + +ace_err = ACEpotentials.compute_errors(data, model; data_keys..., weights=weights) +@info("Original ACE RMSE (set):", + E=ace_err["rmse"]["set"]["E"], + F=ace_err["rmse"]["set"]["F"], + V=ace_err["rmse"]["set"]["V"]) + +# Store for comparison +ace_rmse_E = ace_err["rmse"]["set"]["E"] +ace_rmse_F = ace_err["rmse"]["set"]["F"] +ace_rmse_V = ace_err["rmse"]["set"]["V"] + +## ----- Convert to ETACE and compare ----- + +@info("Converting to ETACE model") + +# Update ps from model after fitting +ps = model.ps +st = model.st + +# Convert to ETACE +et_model = ETM.convert2et(ace_model) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) + +# Copy radial basis parameters (single species case) +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] + +# Copy readout parameters +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] + +# Get cutoff +rcut = maximum(a.rcut for a in ace_model.pairbasis.rin0cuts) + +# Create ETACEPotential +et_calc = ETM.ETACEPotential(et_model, et_ps, et_st, rcut) + +# Create ACE model WITHOUT pair potential for fair comparison +# (ETACE only implements the many-body basis, not the pair potential) +ps_nopair = merge(ps, (Wpair = zeros(size(ps.Wpair)),)) +model_nopair = M.ACEPotential(ace_model, ps_nopair, st) + +## ----- Test energy consistency ----- + +@info("Testing energy consistency between ACE (no pair) and ETACE") + +# Skip isolated atom (index 1) - ETACE requires at least 2 atoms for graph construction +max_energy_diff = 0.0 +for (i, sys) in enumerate(data[2:min(11, length(data))]) + local E_ace = ustrip(u"eV", AtomsCalculators.potential_energy(sys, model_nopair)) + local E_etace = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + local diff = abs(E_ace - E_etace) + global max_energy_diff = max(max_energy_diff, diff) +end + +@info("Max energy difference: $max_energy_diff eV") +@test max_energy_diff < 1e-10 +println("Energy consistency: OK (max_diff = $max_energy_diff eV)") + +## ----- Test forces consistency ----- + +@info("Testing forces consistency between ACE (no pair) and ETACE") + +max_force_diff = 0.0 +for (i, sys) in enumerate(data[2:min(10, length(data))]) + F_ace = AtomsCalculators.forces(sys, model_nopair) + F_etace = AtomsCalculators.forces(sys, et_calc) + for (f1, f2) in zip(F_ace, F_etace) + diff = norm(ustrip.(f1) - ustrip.(f2)) + global max_force_diff = max(max_force_diff, diff) + end +end + +@info("Max force difference: $max_force_diff eV/Å") +@test max_force_diff < 1e-10 +println("Forces consistency: OK (max_diff = $max_force_diff eV/Å)") + +## ----- Test virial consistency ----- + +@info("Testing virial consistency between ACE (no pair) and ETACE") + +max_virial_diff = 0.0 +for (i, sys) in enumerate(data[2:min(10, length(data))]) + V_ace = AtomsCalculators.virial(sys, model_nopair) + V_etace = AtomsCalculators.virial(sys, et_calc) + diff = maximum(abs.(ustrip.(V_ace) - ustrip.(V_etace))) + global max_virial_diff = max(max_virial_diff, diff) +end + +@info("Max virial difference: $max_virial_diff eV") +@test max_virial_diff < 1e-9 +println("Virial consistency: OK (max_diff = $max_virial_diff eV)") + +## ----- Test training basis assembly ----- + +@info("Testing training basis assembly") + +# Pick a test structure +sys = data[5] +natoms = length(sys) +nparams = ETM.length_basis(et_calc) + +# Get basis +efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) + +# Verify shapes +@test length(efv_basis.energy) == nparams +@test size(efv_basis.forces) == (natoms, nparams) +@test length(efv_basis.virial) == nparams + +# Verify linear combination matches direct evaluation +θ = ETM.get_linear_parameters(et_calc) + +E_from_basis = dot(ustrip.(efv_basis.energy), θ) +E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) +@test isapprox(E_from_basis, E_direct, rtol=1e-10) + +F_from_basis = efv_basis.forces * θ +F_direct = AtomsCalculators.forces(sys, et_calc) +max_F_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) +@test max_F_diff < 1e-10 + +V_from_basis = sum(θ[k] * ustrip.(efv_basis.virial[k]) for k in 1:nparams) +V_direct = ustrip.(AtomsCalculators.virial(sys, et_calc)) +max_V_diff = maximum(abs.(V_from_basis - V_direct)) +@test max_V_diff < 1e-9 + +println("Training basis assembly: OK") + +## ----- Test StackedCalculator with ETOneBody ----- + +@info("Testing StackedCalculator with ETOneBody") + +# Create ETOneBody model with arbitrary E0 value for testing (upstream interface) +E0s = Dict(:Si => -158.54496821) # Si symbol => E0 +et_onebody = ETM.one_body(E0s, x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) +E0_calc = ETM.WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + +# Create wrapped ETACE (unified interface) +ace_calc = ETM.WrappedSiteCalculator(et_model, et_ps, et_st, rcut) + +# Stack them +stacked = ETM.StackedCalculator((E0_calc, ace_calc)) + +# Test on a few structures (skip isolated atom) +for (i, sys) in enumerate(data[2:5]) + E_E0 = AtomsCalculators.potential_energy(sys, E0_calc) + E_ace = AtomsCalculators.potential_energy(sys, ace_calc) + E_stacked = AtomsCalculators.potential_energy(sys, stacked) + + expected = ustrip(E_E0) + ustrip(E_ace) + actual = ustrip(E_stacked) + + @test isapprox(expected, actual, rtol=1e-10) +end + +println("StackedCalculator: OK") + +## ----- Summary ----- + +@info("All ETACE integration tests passed!") +@info("Summary:") +@info(" - Energy matches ACE (many-body only) to < 1e-10 eV") +@info(" - Forces match ACE (many-body only) to < 1e-10 eV/Å") +@info(" - Virial matches ACE (many-body only) to < 1e-9 eV") +@info(" - Training basis assembly verified") +@info(" - StackedCalculator composition verified") +@info("Note: ETACE implements only the many-body basis, not the pair potential.") diff --git a/test/etmodels/test_etonebody.jl b/test/etmodels/test_etonebody.jl index b74d13c9c..8e7b18f5d 100644 --- a/test/etmodels/test_etonebody.jl +++ b/test/etmodels/test_etonebody.jl @@ -100,14 +100,17 @@ println() ## -@info("Confirm correctness of gradient") +@info("Confirm correctness of gradient") sys = rand_struct() G = ET.Atoms.interaction_graph(sys, 5.0 * u"Å") ∂G1 = ETM.site_grads(et_V0, G, ps, st) -println_slim(@test size(∂G1) == (G.maxneigs, length(sys))) -println_slim(@test all( norm.(∂G1) .== 0 ) ) +# ETOneBody returns NamedTuple with empty edge_data array since there are no +# position-dependent gradients (energy only depends on atom types, not positions) +println_slim(@test ∂G1 isa NamedTuple) +println_slim(@test haskey(∂G1, :edge_data)) +println_slim(@test isempty(∂G1.edge_data)) ## diff --git a/test/models/test_committee.jl b/test/models/test_committee.jl index e8291e884..c45d1ea11 100644 --- a/test/models/test_committee.jl +++ b/test/models/test_committee.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) ## using Test, ACEbase, LinearAlgebra diff --git a/test/runtests.jl b/test/runtests.jl index 624c7b65f..af4b1d070 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,9 +17,10 @@ using ACEpotentials, Test, LazyArtifacts @testset "Weird bugs" begin include("test_bugs.jl") end # new ET backend tests - @testset "ET ACE" begin include("etmodels/test_etace.jl") end + @testset "ET ACE" begin include("etmodels/test_etace.jl") end @testset "ET OneBody" begin include("etmodels/test_etonebody.jl") end @testset "ET Pair" begin include("etmodels/test_etpair.jl") end + @testset "ET Calculators" begin include("et_models/test_et_calculators.jl") end # ACE1 compatibility tests # TODO: these tests need to be revived either by creating a JSON