Skip to content
26 changes: 13 additions & 13 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,14 @@ def compute_input_stats(
f"numb_fparam > 0 but no fparam data is provided "
f"for system {ii}."
)
cat_data = np.concatenate(
[frame["fparam"] for frame in sampled], axis=0
)
cat_data = np.reshape(cat_data, [-1, self.numb_fparam])
xp_fp = array_api_compat.array_namespace(sampled[0]["fparam"])
cat_data = xp_fp.concat([frame["fparam"] for frame in sampled], axis=0)
cat_data = xp_fp.reshape(cat_data, (-1, self.numb_fparam))
fparam_stats = [
StatItem(
number=cat_data.shape[0],
sum=np.sum(cat_data[:, ii]),
squared_sum=np.sum(cat_data[:, ii] ** 2),
sum=float(xp_fp.sum(cat_data[:, ii])),
squared_sum=float(xp_fp.sum(cat_data[:, ii] ** 2)),
)
for ii in range(self.numb_fparam)
]
Expand Down Expand Up @@ -335,22 +334,23 @@ def compute_input_stats(
f"numb_aparam > 0 but no aparam data is provided "
f"for system {ii}."
)
xp_ap = array_api_compat.array_namespace(sampled[0]["aparam"])
sys_sumv = []
sys_sumv2 = []
sys_sumn = []
for ss_ in [frame["aparam"] for frame in sampled]:
ss = np.reshape(ss_, [-1, self.numb_aparam])
sys_sumv.append(np.sum(ss, axis=0))
sys_sumv2.append(np.sum(ss * ss, axis=0))
ss = xp_ap.reshape(ss_, (-1, self.numb_aparam))
sys_sumv.append(xp_ap.sum(ss, axis=0))
sys_sumv2.append(xp_ap.sum(ss * ss, axis=0))
sys_sumn.append(ss.shape[0])
sumv = np.sum(np.stack(sys_sumv), axis=0)
sumv2 = np.sum(np.stack(sys_sumv2), axis=0)
sumv = xp_ap.sum(xp_ap.stack(sys_sumv), axis=0)
sumv2 = xp_ap.sum(xp_ap.stack(sys_sumv2), axis=0)
sumn = sum(sys_sumn)
aparam_stats = [
StatItem(
number=sumn,
sum=sumv[ii],
squared_sum=sumv2[ii],
sum=float(sumv[ii]),
squared_sum=float(sumv2[ii]),
)
for ii in range(self.numb_aparam)
]
Expand Down
37 changes: 34 additions & 3 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def model_call_from_call_lower(
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
coord_corr_for_virial: Array | None = None,
) -> dict[str, Array]:
"""Return model prediction from lower interface.

Expand Down Expand Up @@ -119,14 +120,33 @@ def model_call_from_call_lower(
distinguish_types=False,
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
if coord_corr_for_virial is not None:
xp = array_api_compat.array_namespace(coord_corr_for_virial)
# mapping: nf x nall -> nf x nall x 1, then tile to nf x nall x 3
mapping_idx = xp.tile(
xp.reshape(mapping, (nframes, -1, 1)),
(1, 1, 3),
)
extended_coord_corr = xp.take_along_axis(
coord_corr_for_virial,
mapping_idx,
axis=1,
)
else:
extended_coord_corr = None
call_lower_kwargs: dict[str, Any] = {
"fparam": fp,
"aparam": ap,
"do_atomic_virial": do_atomic_virial,
}
if extended_coord_corr is not None:
call_lower_kwargs["extended_coord_corr"] = extended_coord_corr
model_predict_lower = call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
**call_lower_kwargs,
)
model_predict = communicate_extended_output(
model_predict_lower,
Expand Down Expand Up @@ -237,6 +257,7 @@ def call_common(
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
coord_corr_for_virial: Array | None = None,
) -> dict[str, Array]:
"""Return model prediction.

Expand All @@ -255,6 +276,9 @@ def call_common(
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
coord_corr_for_virial
The coordinates correction for virial.
shape: nf x (nloc x 3)

Returns
-------
Expand All @@ -279,6 +303,7 @@ def call_common(
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
coord_corr_for_virial=coord_corr_for_virial,
)
model_predict = self._output_type_cast(model_predict, input_prec)
return model_predict
Expand All @@ -292,6 +317,7 @@ def call_common_lower(
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
extended_coord_corr: Array | None = None,
) -> dict[str, Array]:
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand All @@ -314,6 +340,9 @@ def call_common_lower(
atomic parameter. nf x nloc x nda
do_atomic_virial
whether calculate atomic virial
extended_coord_corr
coordinates correction for virial in extended region.
nf x (nall x 3)

Returns
-------
Expand Down Expand Up @@ -341,6 +370,7 @@ def call_common_lower(
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
extended_coord_corr=extended_coord_corr,
)
model_predict = self._output_type_cast(model_predict, input_prec)
return model_predict
Expand All @@ -354,6 +384,7 @@ def forward_common_atomic(
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
extended_coord_corr: Array | None = None,
) -> dict[str, Array]:
atomic_ret = self.atomic_model.forward_common_atomic(
extended_coord,
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def get_spin_model(data: dict) -> SpinModel:
data : dict
The data to construct the model.
"""
data = copy.deepcopy(data)
# include virtual spin and placeholder types
data["type_map"] += [item + "_spin" for item in data["type_map"]]
spin = Spin(
Expand Down
Loading
Loading