diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index df2ad74dc..147c4ef5b 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -526,8 +526,8 @@ def free_lora_adapter():
         self.n_tokens = 0
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
         self.scores: npt.NDArray[np.single] = np.ndarray(
-            (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
-        )
+            (n_ctx, self._n_vocab), dtype=np.single
+        ) if self._logits_all else None
 
         self._mirostat_mu = ctypes.c_float(
             2.0 * 5.0
@@ -638,6 +638,10 @@ def _input_ids(self) -> npt.NDArray[np.intc]:
 
     @property
     def _scores(self) -> npt.NDArray[np.single]:
+        if not self._logits_all:
+            raise RuntimeError(
+                "Llama model must be created with logits_all=True to call this method"
+            )
         return self.scores[: self.n_tokens, :]
 
     @property
@@ -646,6 +650,10 @@ def eval_tokens(self) -> Deque[int]:
 
     @property
     def eval_logits(self) -> Deque[List[float]]:
+        if not self._logits_all:
+            raise RuntimeError(
+                "Llama model must be created with logits_all=True to call this method"
+            )
         return deque(
             self.scores[: self.n_tokens, :].tolist(),
             maxlen=self._n_ctx if self._logits_all else 1,
@@ -2434,10 +2442,11 @@ def save_state(self) -> LlamaState:
         )
 
     def load_state(self, state: LlamaState) -> None:
-        # Only filling in up to `n_tokens` and then zero-ing out the rest
-        self.scores[: state.n_tokens, :] = state.scores.copy()
-        rest = self.scores[state.n_tokens :, :]
-        rest[rest > 0] = 0.0
+        if self._logits_all:
+            # Only filling in up to `n_tokens` and then zero-ing out the rest
+            self.scores[: state.n_tokens, :] = state.scores.copy()
+            rest = self.scores[state.n_tokens :, :]
+            rest[rest > 0] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
         self._seed = state.seed