-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathcloud_server.py
More file actions
108 lines (91 loc) · 3.72 KB
/
cloud_server.py
File metadata and controls
108 lines (91 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import threading
import time
import numpy as np
from queue import Queue
from datetime import datetime
import yaml
import munch
import grpc
from concurrent import futures
from loguru import logger
from database.database import DataBase
from edge.info import TASK_STATE
from grpc_server.rpc_server import MessageTransmissionServicer
from model_management.object_detection import Object_Detection
from grpc_server import message_transmission_pb2, message_transmission_pb2_grpc
class CloudServer:
    """Cloud-side worker that runs the large detection model and serves gRPC.

    On construction it loads the large object-detection model, prepares the
    result database (tables for ``config.edge_ids``) and starts a daemon
    thread that consumes tasks from ``local_queue``. ``start_server`` then
    exposes the gRPC ``MessageTransmission`` service, which is handed the
    same queue so it can enqueue work for the local worker.
    """

    # Defaults used when the config does not override them; these match the
    # values that were previously hard-coded in ``start_server``.
    DEFAULT_GRPC_PORT = 50051
    DEFAULT_GRPC_MAX_WORKERS = 1

    def __init__(self, config):
        """Build the server from the ``server`` section of the config.

        Args:
            config: attribute-access config (a ``Munch`` when launched from
                this module). Reads ``server_id``, ``database``, ``edge_ids``,
                ``local_queue_maxsize`` and ``wait_thresh``; optionally
                ``grpc_port`` and ``grpc_max_workers`` (see ``start_server``).
        """
        self.config = config
        self.server_id = config.server_id
        self.large_object_detection = Object_Detection(config, type='large inference')
        # create database and tables
        self.database = DataBase(self.config.database)
        self.database.use_database()
        self.database.create_table(self.config.edge_ids)
        # start the thread for local process
        self.local_queue = Queue(config.local_queue_maxsize)
        self.local_processor = threading.Thread(target=self.cloud_local, daemon=True)
        self.local_processor.start()

    def cloud_local(self):
        """Consume tasks from ``local_queue`` forever (daemon-thread target).

        Each task is handled by ``_process_task``. Exceptions are logged and
        the loop keeps going: an uncaught exception would otherwise end this
        thread silently and leave every later task stuck in the queue.
        """
        while True:
            task = self.local_queue.get(block=True)
            try:
                self._process_task(task)
            except Exception:
                # Worker-thread boundary: log with traceback, keep serving.
                logger.exception(
                    "cloud_local failed to process task (frame_index={})",
                    getattr(task, 'frame_index', '?'))

    def _process_task(self, task):
        """Run large-model inference on one task and record the outcome.

        A task whose age (``time.time() - task.start_time``) has reached
        ``config.wait_thresh`` seconds is marked TIMEOUT without inference.
        Otherwise the detections are attached to the task via ``add_result``
        and the task is marked FINISHED. In both cases the result is written
        with ``update_table``.
        """
        if time.time() - task.start_time >= self.config.wait_thresh:
            task.end_time = time.time()
            task.state = TASK_STATE.TIMEOUT
            self.update_table(task)
            return

        task.frame_cloud = task.frame_edge
        frame = task.frame_cloud
        high_boxes, high_class, high_score = self.large_object_detection.large_inference(frame)

        # Rescale the large model's boxes from the received frame's size to
        # the raw frame's size. Only the height ratio is used, so this assumes
        # a uniform (aspect-preserving) resize upstream -- TODO confirm.
        scale = task.raw_shape[0] / frame.shape[0]
        # Explicit None/len check: plain truthiness raises on a multi-element
        # ndarray.
        if high_boxes is not None and len(high_boxes) > 0:
            high_boxes = (np.array(high_boxes) * scale).tolist()
        task.add_result(high_boxes, high_class, high_score)

        task.end_time = time.time()
        # upload the result to database
        task.state = TASK_STATE.FINISHED
        self.update_table(task)

    def update_table(self, task):
        """Write a task's result and final state to the database.

        Calls ``database.update_data(task.edge_id, data)`` with
        ``data = (end_time, str(result), state, frame_index)``. Here ``result``
        is a dict with ``labels``, ``boxes`` and ``scores``, and ``state`` is
        "Finished", "Timeout" or "" for any other state. The row is presumably
        selected by ``frame_index``; confirm against ``DataBase.update_data``.
        """
        if task.state == TASK_STATE.FINISHED:
            state = "Finished"
        elif task.state == TASK_STATE.TIMEOUT:
            state = "Timeout"
        else:
            state = ""
        detection_boxes, detection_class, detection_score = task.get_result()
        result = {
            'labels': detection_class,
            'boxes': detection_boxes,
            'scores': detection_score,
        }
        # NOTE(review): the result is stored as a Python repr (str(dict)),
        # not JSON. Readers must parse it accordingly.
        data = (
            task.end_time,
            str(result),
            state,
            task.frame_index,
        )
        self.database.update_data(task.edge_id, data)

    def start_server(self):
        """Start the gRPC MessageTransmission server and block until it stops.

        The listen port and thread-pool size come from the optional config
        keys ``grpc_port`` / ``grpc_max_workers``. They default to 50051 and 1,
        the values previously hard-coded here.
        """
        logger.info("cloud server is starting")
        max_workers = getattr(self.config, 'grpc_max_workers', self.DEFAULT_GRPC_MAX_WORKERS)
        port = getattr(self.config, 'grpc_port', self.DEFAULT_GRPC_PORT)
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
        message_transmission_pb2_grpc.add_MessageTransmissionServicer_to_server(
            MessageTransmissionServicer(self.local_queue, self.server_id, self.large_object_detection), server)
        server.add_insecure_port(f'[::]:{port}')
        server.start()
        logger.info("cloud server is listening")
        server.wait_for_termination()
def main():
    """Parse CLI arguments, load the YAML config and run the cloud server.

    Reads the file given by ``--yaml_path`` and builds a ``CloudServer`` from
    its top-level ``server`` section. Blocks until the gRPC server terminates.
    Exits with a usage error if the file is empty or has no ``server`` section.
    """
    parser = argparse.ArgumentParser(description="configuration description")
    parser.add_argument("--yaml_path", default="./config/config.yaml", help="input the path of *.yaml")
    args = parser.parse_args()
    # Explicit encoding so the config parses the same on every platform.
    with open(args.yaml_path, 'r', encoding='utf-8') as f:
        raw_config = yaml.safe_load(f)
    # An empty file yields None; fail with a clear message instead of an
    # AttributeError deep inside munch.
    if not isinstance(raw_config, dict) or 'server' not in raw_config:
        parser.error(f"{args.yaml_path} must contain a top-level 'server' section")
    # provide class-like access for dict
    config = munch.munchify(raw_config)
    cloud_server = CloudServer(config.server)
    cloud_server.start_server()


if __name__ == '__main__':
    main()