Skip to content

Commit ec2a2b3

Browse files
committed
lazy-load torch for api runtime startup
1 parent 3ce9057 commit ec2a2b3

1 file changed

Lines changed: 38 additions & 10 deletions

File tree

src/inference/service.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
import importlib
44
from dataclasses import dataclass
5+
from functools import lru_cache
56
from pathlib import Path
7+
from types import ModuleType
68
from typing import Any, Literal, Protocol, TypedDict
79

810
import numpy as np
9-
import torch
1011

1112
from engine.mcts import MCTS
1213
from game.actions import ACTION_SPACE
@@ -54,6 +55,15 @@ def run(self, output_names: list[str] | None, input_feed: dict[str, Any]) -> lis
5455
...
5556

5657

58+
@lru_cache(maxsize=1)
59+
def _get_torch_module() -> ModuleType | None:
60+
"""Import torch lazily so API startup does not hard-fail in lightweight runtimes."""
61+
try:
62+
return importlib.import_module("torch")
63+
except ModuleNotFoundError:
64+
return None
65+
66+
5767
class InferenceService:
5868
"""Checkpoint-backed inference service for Ataxx move selection."""
5969

@@ -104,10 +114,24 @@ def __init__(
104114
@staticmethod
105115
def _resolve_device(device: str) -> str:
106116
if device == "auto":
107-
return "cuda" if torch.cuda.is_available() else "cpu"
117+
torch_module = _get_torch_module()
118+
if torch_module is not None and bool(torch_module.cuda.is_available()):
119+
return "cuda"
120+
return "cpu"
108121
return device
109122

123+
@staticmethod
def _require_torch() -> ModuleType:
    """Return the lazily imported torch module, raising when it is unavailable.

    Raises:
        ValueError: if torch cannot be imported in this runtime.
    """
    module = _get_torch_module()
    if module is not None:
        return module
    raise ValueError(
        "Torch is required for checkpoint-backed inference. "
        "Use ONNX artifacts or install torch in this runtime."
    )
132+
110133
def _load_system(self) -> AtaxxZero:
134+
torch_module = self._require_torch()
111135
ckpt = self.checkpoint_path
112136
if ckpt.suffix == ".ckpt":
113137
try:
@@ -118,7 +142,7 @@ def _load_system(self) -> AtaxxZero:
118142
"reentrena o usa carga parcial manual (strict=False)."
119143
) from exc
120144

121-
checkpoint = torch.load(str(ckpt), map_location=self.device, weights_only=False)
145+
checkpoint = torch_module.load(str(ckpt), map_location=self.device, weights_only=False)
122146
if not isinstance(checkpoint, dict):
123147
raise ValueError("Invalid .pt checkpoint format: expected dictionary.")
124148
state_dict_obj = checkpoint.get("state_dict")
@@ -237,15 +261,18 @@ def _fast_result(self, board: AtaxxBoard) -> InferenceResult:
237261

238262
if self.system is None:
239263
raise ValueError("Fast inference unavailable: no torch checkpoint and ONNX failed.")
264+
torch_module = self._require_torch()
240265
mask_np = self._legal_action_mask(board)
241266
obs = board.get_observation()
242267

243-
obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(self.device)
244-
mask_tensor = torch.from_numpy(mask_np).unsqueeze(0).to(self.device)
245-
with torch.no_grad():
268+
obs_tensor = torch_module.from_numpy(obs).unsqueeze(0).to(self.device)
269+
mask_tensor = torch_module.from_numpy(mask_np).unsqueeze(0).to(self.device)
270+
with torch_module.no_grad():
246271
policy_logits, value_tensor = self.system.model(obs_tensor, action_mask=mask_tensor)
247272

248-
policy = torch.softmax(policy_logits, dim=1).squeeze(0).detach().cpu().numpy()
273+
policy = (
274+
torch_module.softmax(policy_logits, dim=1).squeeze(0).detach().cpu().numpy()
275+
)
249276
if not np.all(np.isfinite(policy)):
250277
legal_indices = np.flatnonzero(mask_np > 0)
251278
if legal_indices.size == 0:
@@ -269,6 +296,7 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
269296
if self.system is None:
270297
# If no torch model is available, degrade gracefully to fast ONNX/Torch.
271298
return self._fast_result(board)
299+
torch_module = self._require_torch()
272300
mcts = self._ensure_mcts()
273301
probs = mcts.run(board=board, add_dirichlet_noise=False, temperature=0.0)
274302
action_idx = int(np.argmax(probs))
@@ -277,9 +305,9 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
277305
# Value still comes from raw net (current-player perspective), which is stable and cheap.
278306
mask_np = self._legal_action_mask(board)
279307
obs = board.get_observation()
280-
obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(self.device)
281-
mask_tensor = torch.from_numpy(mask_np).unsqueeze(0).to(self.device)
282-
with torch.no_grad():
308+
obs_tensor = torch_module.from_numpy(obs).unsqueeze(0).to(self.device)
309+
mask_tensor = torch_module.from_numpy(mask_np).unsqueeze(0).to(self.device)
310+
with torch_module.no_grad():
283311
_, value_tensor = self.system.model(obs_tensor, action_mask=mask_tensor)
284312
value = float(value_tensor.item())
285313
return InferenceResult(move=move, action_idx=action_idx, value=value, mode="strong")

0 commit comments

Comments
 (0)