Skip to content

Commit f27f2b7

Browse files
committed
fix: Incorrect call to spglib fixed. Now shows correct international shorthand
1 parent 69322ab commit f27f2b7

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/ui/app/model_inference.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
class XRDModelInference:
1313
"""Handles loading and inference for the XRD analysis model"""
1414

15+
# Build a lookup table mapping space group number (1-230) to the
16+
# corresponding Hall number. spglib.get_spacegroup_type() is indexed
17+
# by Hall number (1-530), NOT by space group number. We pick the
18+
# first (standard-setting) Hall number for each space group.
19+
_sg_to_hall: Dict[int, int] = {}
20+
for _hall in range(1, 531):
21+
_sg_type = spglib.get_spacegroup_type(_hall)
22+
_sg_num = _sg_type.number if hasattr(_sg_type, 'number') else _sg_type['number']
23+
if _sg_num not in _sg_to_hall:
24+
_sg_to_hall[_sg_num] = _hall
25+
1526
def __init__(self):
1627
self.model = None
1728
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -253,16 +264,23 @@ def _process_model_output(self, output) -> Dict:
253264
}
254265

255266
def _get_space_group_symbol(self, sg_number: int) -> str:
256-
"""Get space group symbol from number using spglib"""
267+
"""Get space group symbol from number using spglib.
268+
269+
spglib.get_spacegroup_type() is indexed by *Hall number* (1-530),
270+
not by space group number (1-230). We use the precomputed
271+
``_sg_to_hall`` mapping to translate first.
272+
"""
257273
if sg_number < 1 or sg_number > 230:
258274
return f"SG{sg_number}"
259275

260276
try:
261-
# Get space group type information from spglib
262-
sg_type = spglib.get_spacegroup_type(sg_number)
277+
hall_number = self._sg_to_hall.get(sg_number)
278+
if hall_number is None:
279+
return f"SG{sg_number}"
280+
sg_type = spglib.get_spacegroup_type(hall_number)
263281
if sg_type is not None:
264-
# Use the international short symbol (Hermann-Mauguin notation)
265-
return sg_type['international_short']
282+
symbol = sg_type.international_short if hasattr(sg_type, 'international_short') else sg_type['international_short']
283+
return symbol
266284
return f"SG{sg_number}"
267285
except Exception:
268286
return f"SG{sg_number}"

0 commit comments

Comments
 (0)