|
1 | 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
2 | 2 | # Licensed under the MIT License. |
3 | 3 | import json |
4 | | -from typing import List, Union, Type |
5 | 4 |
|
6 | 5 | from aiohttp.web import RouteTableDef, Request, Response |
7 | 6 |
|
8 | | -from microsoft_agents.activity import ( |
9 | | - AgentsModel, |
10 | | - Activity, |
11 | | - AttachmentData, |
12 | | - ConversationParameters, |
13 | | - Transcript, |
14 | | -) |
15 | 7 | from microsoft_agents.hosting.core import ChannelApiHandlerProtocol |
| 8 | +from microsoft_agents.hosting.core.http import ChannelServiceRoutes |
16 | 9 |
|
17 | 10 |
|
18 | | -async def deserialize_from_body( |
19 | | - request: Request, target_model: Type[AgentsModel] |
20 | | -) -> Activity: |
21 | | - if "application/json" in request.headers["Content-Type"]: |
22 | | - body = await request.json() |
23 | | - else: |
24 | | - return Response(status=415) |
| 11 | +class AiohttpRequestAdapter: |
| 12 | + """Adapter for aiohttp requests to use with ChannelServiceRoutes.""" |
25 | 13 |
|
26 | | - return target_model.model_validate(body) |
| 14 | + def __init__(self, request: Request): |
| 15 | + self._request = request |
27 | 16 |
|
| 17 | + @property |
| 18 | + def method(self) -> str: |
| 19 | + return self._request.method |
28 | 20 |
|
29 | | -def get_serialized_response( |
30 | | - model_or_list: Union[AgentsModel, List[AgentsModel]], |
31 | | -) -> Response: |
32 | | - if isinstance(model_or_list, AgentsModel): |
33 | | - json_obj = model_or_list.model_dump( |
34 | | - mode="json", exclude_unset=True, by_alias=True |
35 | | - ) |
36 | | - else: |
37 | | - json_obj = [ |
38 | | - model.model_dump(mode="json", exclude_unset=True, by_alias=True) |
39 | | - for model in model_or_list |
40 | | - ] |
| 21 | + @property |
| 22 | + def headers(self): |
| 23 | + return self._request.headers |
| 24 | + |
| 25 | + async def json(self): |
| 26 | + return await self._request.json() |
| 27 | + |
| 28 | + def get_claims_identity(self): |
| 29 | + return self._request.get("claims_identity") |
41 | 30 |
|
42 | | - return Response(body=json.dumps(json_obj), content_type="application/json") |
| 31 | + def get_path_param(self, name: str) -> str: |
| 32 | + return self._request.match_info[name] |
43 | 33 |
|
44 | 34 |
|
45 | 35 | def channel_service_route_table( |
46 | 36 | handler: ChannelApiHandlerProtocol, base_url: str = "" |
47 | 37 | ) -> RouteTableDef: |
48 | | - # pylint: disable=unused-variable |
| 38 | + """Create aiohttp route table for Channel Service API. |
| 39 | +
|
| 40 | + Args: |
| 41 | + handler: The handler that implements the Channel API protocol. |
| 42 | + base_url: Optional base URL prefix for all routes. |
| 43 | +
|
| 44 | + Returns: |
| 45 | + RouteTableDef with all channel service routes. |
| 46 | + """ |
49 | 47 | routes = RouteTableDef() |
| 48 | + service_routes = ChannelServiceRoutes(handler, base_url) |
| 49 | + |
| 50 | + def json_response(data: dict) -> Response: |
| 51 | + return Response(body=json.dumps(data), content_type="application/json") |
50 | 52 |
|
51 | 53 | @routes.post(base_url + "/v3/conversations/{conversation_id}/activities") |
52 | 54 | async def send_to_conversation(request: Request): |
53 | | - activity = await deserialize_from_body(request, Activity) |
54 | | - result = await handler.on_send_to_conversation( |
55 | | - request.get("claims_identity"), |
56 | | - request.match_info["conversation_id"], |
57 | | - activity, |
| 55 | + result = await service_routes.send_to_conversation( |
| 56 | + AiohttpRequestAdapter(request) |
58 | 57 | ) |
59 | | - |
60 | | - return get_serialized_response(result) |
| 58 | + return json_response(result) |
61 | 59 |
|
62 | 60 | @routes.post( |
63 | 61 | base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" |
64 | 62 | ) |
65 | 63 | async def reply_to_activity(request: Request): |
66 | | - activity = await deserialize_from_body(request, Activity) |
67 | | - result = await handler.on_reply_to_activity( |
68 | | - request.get("claims_identity"), |
69 | | - request.match_info["conversation_id"], |
70 | | - request.match_info["activity_id"], |
71 | | - activity, |
72 | | - ) |
73 | | - |
74 | | - return get_serialized_response(result) |
| 64 | + result = await service_routes.reply_to_activity(AiohttpRequestAdapter(request)) |
| 65 | + return json_response(result) |
75 | 66 |
|
76 | 67 | @routes.put( |
77 | 68 | base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" |
78 | 69 | ) |
79 | 70 | async def update_activity(request: Request): |
80 | | - activity = await deserialize_from_body(request, Activity) |
81 | | - result = await handler.on_update_activity( |
82 | | - request.get("claims_identity"), |
83 | | - request.match_info["conversation_id"], |
84 | | - request.match_info["activity_id"], |
85 | | - activity, |
86 | | - ) |
87 | | - |
88 | | - return get_serialized_response(result) |
| 71 | + result = await service_routes.update_activity(AiohttpRequestAdapter(request)) |
| 72 | + return json_response(result) |
89 | 73 |
|
90 | 74 | @routes.delete( |
91 | 75 | base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" |
92 | 76 | ) |
93 | 77 | async def delete_activity(request: Request): |
94 | | - await handler.on_delete_activity( |
95 | | - request.get("claims_identity"), |
96 | | - request.match_info["conversation_id"], |
97 | | - request.match_info["activity_id"], |
98 | | - ) |
99 | | - |
| 78 | + await service_routes.delete_activity(AiohttpRequestAdapter(request)) |
100 | 79 | return Response() |
101 | 80 |
|
102 | 81 | @routes.get( |
103 | 82 | base_url |
104 | 83 | + "/v3/conversations/{conversation_id}/activities/{activity_id}/members" |
105 | 84 | ) |
106 | 85 | async def get_activity_members(request: Request): |
107 | | - result = await handler.on_get_activity_members( |
108 | | - request.get("claims_identity"), |
109 | | - request.match_info["conversation_id"], |
110 | | - request.match_info["activity_id"], |
| 86 | + result = await service_routes.get_activity_members( |
| 87 | + AiohttpRequestAdapter(request) |
111 | 88 | ) |
112 | | - |
113 | | - return get_serialized_response(result) |
| 89 | + return json_response(result) |
114 | 90 |
|
115 | 91 | @routes.post(base_url + "/") |
116 | 92 | async def create_conversation(request: Request): |
117 | | - conversation_parameters = deserialize_from_body(request, ConversationParameters) |
118 | | - result = await handler.on_create_conversation( |
119 | | - request.get("claims_identity"), conversation_parameters |
| 93 | + result = await service_routes.create_conversation( |
| 94 | + AiohttpRequestAdapter(request) |
120 | 95 | ) |
121 | | - |
122 | | - return get_serialized_response(result) |
| 96 | + return json_response(result) |
123 | 97 |
|
124 | 98 | @routes.get(base_url + "/") |
125 | 99 | async def get_conversation(request: Request): |
126 | | - # TODO: continuation token? conversation_id? |
127 | | - result = await handler.on_get_conversations( |
128 | | - request.get("claims_identity"), None |
129 | | - ) |
130 | | - |
131 | | - return get_serialized_response(result) |
| 100 | + result = await service_routes.get_conversations(AiohttpRequestAdapter(request)) |
| 101 | + return json_response(result) |
132 | 102 |
|
133 | 103 | @routes.get(base_url + "/v3/conversations/{conversation_id}/members") |
134 | 104 | async def get_conversation_members(request: Request): |
135 | | - result = await handler.on_get_conversation_members( |
136 | | - request.get("claims_identity"), |
137 | | - request.match_info["conversation_id"], |
| 105 | + result = await service_routes.get_conversation_members( |
| 106 | + AiohttpRequestAdapter(request) |
138 | 107 | ) |
139 | | - |
140 | | - return get_serialized_response(result) |
| 108 | + return json_response(result) |
141 | 109 |
|
142 | 110 | @routes.get(base_url + "/v3/conversations/{conversation_id}/members/{member_id}") |
143 | 111 | async def get_conversation_member(request: Request): |
144 | | - result = await handler.on_get_conversation_member( |
145 | | - request.get("claims_identity"), |
146 | | - request.match_info["member_id"], |
147 | | - request.match_info["conversation_id"], |
| 112 | + result = await service_routes.get_conversation_member( |
| 113 | + AiohttpRequestAdapter(request) |
148 | 114 | ) |
149 | | - |
150 | | - return get_serialized_response(result) |
| 115 | + return json_response(result) |
151 | 116 |
|
152 | 117 | @routes.get(base_url + "/v3/conversations/{conversation_id}/pagedmembers") |
153 | 118 | async def get_conversation_paged_members(request: Request): |
154 | | - # TODO: continuation token? page size? |
155 | | - result = await handler.on_get_conversation_paged_members( |
156 | | - request.get("claims_identity"), |
157 | | - request.match_info["conversation_id"], |
| 119 | + result = await service_routes.get_conversation_paged_members( |
| 120 | + AiohttpRequestAdapter(request) |
158 | 121 | ) |
159 | | - |
160 | | - return get_serialized_response(result) |
| 122 | + return json_response(result) |
161 | 123 |
|
162 | 124 | @routes.delete(base_url + "/v3/conversations/{conversation_id}/members/{member_id}") |
163 | 125 | async def delete_conversation_member(request: Request): |
164 | | - result = await handler.on_delete_conversation_member( |
165 | | - request.get("claims_identity"), |
166 | | - request.match_info["conversation_id"], |
167 | | - request.match_info["member_id"], |
| 126 | + result = await service_routes.delete_conversation_member( |
| 127 | + AiohttpRequestAdapter(request) |
168 | 128 | ) |
169 | | - |
170 | | - return get_serialized_response(result) |
| 129 | + return json_response(result) |
171 | 130 |
|
172 | 131 | @routes.post(base_url + "/v3/conversations/{conversation_id}/activities/history") |
173 | 132 | async def send_conversation_history(request: Request): |
174 | | - transcript = deserialize_from_body(request, Transcript) |
175 | | - result = await handler.on_send_conversation_history( |
176 | | - request.get("claims_identity"), |
177 | | - request.match_info["conversation_id"], |
178 | | - transcript, |
| 133 | + result = await service_routes.send_conversation_history( |
| 134 | + AiohttpRequestAdapter(request) |
179 | 135 | ) |
180 | | - |
181 | | - return get_serialized_response(result) |
| 136 | + return json_response(result) |
182 | 137 |
|
183 | 138 | @routes.post(base_url + "/v3/conversations/{conversation_id}/attachments") |
184 | 139 | async def upload_attachment(request: Request): |
185 | | - attachment_data = deserialize_from_body(request, AttachmentData) |
186 | | - result = await handler.on_upload_attachment( |
187 | | - request.get("claims_identity"), |
188 | | - request.match_info["conversation_id"], |
189 | | - attachment_data, |
190 | | - ) |
191 | | - |
192 | | - return get_serialized_response(result) |
| 140 | + result = await service_routes.upload_attachment(AiohttpRequestAdapter(request)) |
| 141 | + return json_response(result) |
193 | 142 |
|
194 | 143 | return routes |
0 commit comments