22
33import importlib
44from dataclasses import dataclass
5+ from functools import lru_cache
56from pathlib import Path
7+ from types import ModuleType
68from typing import Any , Literal , Protocol , TypedDict
79
810import numpy as np
9- import torch
1011
1112from engine .mcts import MCTS
1213from game .actions import ACTION_SPACE
@@ -54,6 +55,15 @@ def run(self, output_names: list[str] | None, input_feed: dict[str, Any]) -> lis
5455 ...
5556
5657
58+ @lru_cache (maxsize = 1 )
59+ def _get_torch_module () -> ModuleType | None :
60+ """Import torch lazily so API startup does not hard-fail in lightweight runtimes."""
61+ try :
62+ return importlib .import_module ("torch" )
63+ except ModuleNotFoundError :
64+ return None
65+
66+
5767class InferenceService :
5868 """Checkpoint-backed inference service for Ataxx move selection."""
5969
@@ -104,10 +114,24 @@ def __init__(
104114 @staticmethod
105115 def _resolve_device (device : str ) -> str :
106116 if device == "auto" :
107- return "cuda" if torch .cuda .is_available () else "cpu"
117+ torch_module = _get_torch_module ()
118+ if torch_module is not None and bool (torch_module .cuda .is_available ()):
119+ return "cuda"
120+ return "cpu"
108121 return device
109122
123+ @staticmethod
124+ def _require_torch () -> ModuleType :
125+ torch_module = _get_torch_module ()
126+ if torch_module is None :
127+ raise ValueError (
128+ "Torch is required for checkpoint-backed inference. "
129+ "Use ONNX artifacts or install torch in this runtime."
130+ )
131+ return torch_module
132+
110133 def _load_system (self ) -> AtaxxZero :
134+ torch_module = self ._require_torch ()
111135 ckpt = self .checkpoint_path
112136 if ckpt .suffix == ".ckpt" :
113137 try :
@@ -118,7 +142,7 @@ def _load_system(self) -> AtaxxZero:
118142 "reentrena o usa carga parcial manual (strict=False)."
119143 ) from exc
120144
121- checkpoint = torch .load (str (ckpt ), map_location = self .device , weights_only = False )
145+ checkpoint = torch_module .load (str (ckpt ), map_location = self .device , weights_only = False )
122146 if not isinstance (checkpoint , dict ):
123147 raise ValueError ("Invalid .pt checkpoint format: expected dictionary." )
124148 state_dict_obj = checkpoint .get ("state_dict" )
@@ -237,15 +261,18 @@ def _fast_result(self, board: AtaxxBoard) -> InferenceResult:
237261
238262 if self .system is None :
239263 raise ValueError ("Fast inference unavailable: no torch checkpoint and ONNX failed." )
264+ torch_module = self ._require_torch ()
240265 mask_np = self ._legal_action_mask (board )
241266 obs = board .get_observation ()
242267
243- obs_tensor = torch .from_numpy (obs ).unsqueeze (0 ).to (self .device )
244- mask_tensor = torch .from_numpy (mask_np ).unsqueeze (0 ).to (self .device )
245- with torch .no_grad ():
268+ obs_tensor = torch_module .from_numpy (obs ).unsqueeze (0 ).to (self .device )
269+ mask_tensor = torch_module .from_numpy (mask_np ).unsqueeze (0 ).to (self .device )
270+ with torch_module .no_grad ():
246271 policy_logits , value_tensor = self .system .model (obs_tensor , action_mask = mask_tensor )
247272
248- policy = torch .softmax (policy_logits , dim = 1 ).squeeze (0 ).detach ().cpu ().numpy ()
273+ policy = (
274+ torch_module .softmax (policy_logits , dim = 1 ).squeeze (0 ).detach ().cpu ().numpy ()
275+ )
249276 if not np .all (np .isfinite (policy )):
250277 legal_indices = np .flatnonzero (mask_np > 0 )
251278 if legal_indices .size == 0 :
@@ -269,6 +296,7 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
269296 if self .system is None :
270297 # If no torch model is available, degrade gracefully to fast ONNX/Torch.
271298 return self ._fast_result (board )
299+ torch_module = self ._require_torch ()
272300 mcts = self ._ensure_mcts ()
273301 probs = mcts .run (board = board , add_dirichlet_noise = False , temperature = 0.0 )
274302 action_idx = int (np .argmax (probs ))
@@ -277,9 +305,9 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
277305 # Value still comes from raw net (current-player perspective), which is stable and cheap.
278306 mask_np = self ._legal_action_mask (board )
279307 obs = board .get_observation ()
280- obs_tensor = torch .from_numpy (obs ).unsqueeze (0 ).to (self .device )
281- mask_tensor = torch .from_numpy (mask_np ).unsqueeze (0 ).to (self .device )
282- with torch .no_grad ():
308+ obs_tensor = torch_module .from_numpy (obs ).unsqueeze (0 ).to (self .device )
309+ mask_tensor = torch_module .from_numpy (mask_np ).unsqueeze (0 ).to (self .device )
310+ with torch_module .no_grad ():
283311 _ , value_tensor = self .system .model (obs_tensor , action_mask = mask_tensor )
284312 value = float (value_tensor .item ())
285313 return InferenceResult (move = move , action_idx = action_idx , value = value , mode = "strong" )
0 commit comments