|
12 | 12 | class XRDModelInference: |
13 | 13 | """Handles loading and inference for the XRD analysis model""" |
14 | 14 |
|
| 15 | + # Build a lookup table mapping space group number (1-230) to the |
| 16 | + # corresponding Hall number. spglib.get_spacegroup_type() is indexed |
| 17 | + # by Hall number (1-530), NOT by space group number. We pick the |
| 18 | + # first (standard-setting) Hall number for each space group. |
| 19 | + _sg_to_hall: Dict[int, int] = {} |
| 20 | + for _hall in range(1, 531): |
| 21 | + _sg_type = spglib.get_spacegroup_type(_hall) |
| 22 | + _sg_num = _sg_type.number if hasattr(_sg_type, 'number') else _sg_type['number'] |
| 23 | + if _sg_num not in _sg_to_hall: |
| 24 | + _sg_to_hall[_sg_num] = _hall |
| 25 | + |
15 | 26 | def __init__(self): |
16 | 27 | self.model = None |
17 | 28 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
@@ -253,16 +264,23 @@ def _process_model_output(self, output) -> Dict: |
253 | 264 | } |
254 | 265 |
|
255 | 266 | def _get_space_group_symbol(self, sg_number: int) -> str: |
256 | | - """Get space group symbol from number using spglib""" |
| 267 | + """Get space group symbol from number using spglib. |
| 268 | +
|
| 269 | + spglib.get_spacegroup_type() is indexed by *Hall number* (1-530), |
| 270 | + not by space group number (1-230). We use the precomputed |
| 271 | + ``_sg_to_hall`` mapping to translate first. |
| 272 | + """ |
257 | 273 | if sg_number < 1 or sg_number > 230: |
258 | 274 | return f"SG{sg_number}" |
259 | 275 |
|
260 | 276 | try: |
261 | | - # Get space group type information from spglib |
262 | | - sg_type = spglib.get_spacegroup_type(sg_number) |
| 277 | + hall_number = self._sg_to_hall.get(sg_number) |
| 278 | + if hall_number is None: |
| 279 | + return f"SG{sg_number}" |
| 280 | + sg_type = spglib.get_spacegroup_type(hall_number) |
263 | 281 | if sg_type is not None: |
264 | | - # Use the international short symbol (Hermann-Mauguin notation) |
265 | | - return sg_type['international_short'] |
| 282 | + symbol = sg_type.international_short if hasattr(sg_type, 'international_short') else sg_type['international_short'] |
| 283 | + return symbol |
266 | 284 | return f"SG{sg_number}" |
267 | 285 | except Exception: |
268 | 286 | return f"SG{sg_number}" |
|
0 commit comments