diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 65949f8..55a98c6 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -698,7 +698,8 @@ def predict( text = X_test["text"] categorical_variables = X_test["categorical_variables"] - self.pytorch_model.eval().cpu() + self.pytorch_model.eval() + device = next(self.pytorch_model.parameters()).device tokenize_output = self.tokenizer.tokenize( text.tolist(), @@ -711,15 +712,17 @@ def predict( f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method." ) - encoded_text = tokenize_output.input_ids # (batch_size, seq_len) - attention_mask = tokenize_output.attention_mask # (batch_size, seq_len) + encoded_text = tokenize_output.input_ids.to(device) # (batch_size, seq_len) + attention_mask = tokenize_output.attention_mask.to(device) # (batch_size, seq_len) if categorical_variables is not None: categorical_vars = torch.tensor( - categorical_variables, dtype=torch.float32 + categorical_variables, dtype=torch.float32, device=device ) # (batch_size, num_categorical_features) else: - categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32) + categorical_vars = torch.empty( + (encoded_text.shape[0], 0), dtype=torch.float32, device=device + ) model_output = self.pytorch_model( encoded_text,