-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_handler.py
More file actions
116 lines (98 loc) · 4.65 KB
/
model_handler.py
File metadata and controls
116 lines (98 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import random
from typing import Tuple
import os
# Predefined classes for classification (using standard ImageNet classes as example)
CLASSES = [
"tench", "goldfish", "great_white_shark", "tiger_shark", "hammerhead",
"electric_ray", "stingray", "cock", "hen", "ostrich",
"brambling", "goldfinch", "house_finch", "junco", "indigo_bunting",
"robin", "bulbul", "jay", "magpie", "chickadee",
"water_ouzel", "kite", "bald_eagle", "vulture", "great_grey_owl"
]
class ImageClassifier:
def __init__(self):
# In a real application, you would load your trained PyTorch model here
# Example: self.model = torch.load("path/to/your/model.pth")
# For this template, we'll initialize a placeholder for a model
self.classes = CLASSES
self.model = self._load_model()
def _load_model(self):
"""
Load the PyTorch model during initialization
In a real application, this would load your trained model
"""
print("Loading PyTorch model...")
return None
# try:
# # For this template, we'll simulate loading a pretrained model
# # In a real application, you would use:
# # model = torch.load("path/to/your/model.pth", map_location='cpu')
# # model.eval() # Set model to evaluation mode
# model = torchvision.models.resnet18(pretrained=True)
# model.eval()
# print("Model loaded successfully!")
# return model
# except Exception as e:
# print(f"Error loading model: {e}")
# print("Using simulated model for demonstration purposes")
# return None
def preprocess_image(self, image_path: str) -> torch.Tensor:
"""
Preprocess the input image to prepare it for the model
"""
image = Image.open(image_path)
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0) # Create a mini-batch as expected by the model
return input_batch
def predict(self, image_path: str) -> Tuple[str, float]:
"""
Perform prediction on the given image using the loaded model
"""
if self.model is not None:
try:
# Preprocess the image
input_batch = self.preprocess_image(image_path)
# Perform inference
with torch.no_grad(): # Disable gradient computation for inference
output = self.model(input_batch)
# Convert to probabilities using softmax
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# Get the top prediction
top_prob, top_catid = torch.topk(probabilities, 1)
# Map to class label
predicted_class_idx = top_catid.item()
confidence = top_prob.item()
# Handle case where model outputs more classes than we have names for
if predicted_class_idx < len(self.classes):
predicted_class = self.classes[predicted_class_idx]
else:
# If model predicts a class not in our predefined list, pick a random one
predicted_class = self.classes[random.randint(0, len(self.classes)-1)]
confidence = round(random.uniform(0.5, 1.0), 3)
return predicted_class, confidence
except Exception as e:
print(f"Error during prediction: {e}")
# Fallback to simulated prediction
return self._simulate_prediction()
else:
# If model failed to load, use simulation
return self._simulate_prediction()
def _simulate_prediction(self) -> Tuple[str, float]:
"""
Simulate model prediction when actual model is not available
"""
# In a real application, this shouldn't happen as model loading should succeed
predicted_idx = random.randint(0, len(self.classes) - 1)
predicted_class = self.classes[predicted_idx]
confidence = round(random.uniform(0.5, 1.0), 3)
return predicted_class, confidence