diff --git a/http/server/HttpService.cpp b/http/server/HttpService.cpp index acab0fa08..aafbd9019 100644 --- a/http/server/HttpService.cpp +++ b/http/server/HttpService.cpp @@ -5,13 +5,105 @@ namespace hv { +static bool has_wildcard(const char* path) { + return path && strchr(path, '*'); +} + +static void split_path(const std::string& path, std::vector& segments) { + if (path.empty()) return; + size_t i = 0; + while (i < path.size() && path[i] == '/') ++i; + for (;;) { + size_t j = i; + while (j < path.size() && path[j] != '/') ++j; + if (i < path.size() || (j == path.size() && !segments.empty())) { + segments.emplace_back(path.substr(i, j - i)); + } + if (j == path.size()) break; + i = j + 1; + if (i == path.size()) { + segments.emplace_back(""); + break; + } + } +} + +static std::string parse_param_name(const std::string& segment) { + if (segment.empty()) return ""; + if (segment[0] == ':') { + return segment.substr(1); + } + if (segment[0] == '{') { + if (segment.size() >= 2 && segment.back() == '}') { + return segment.substr(1, segment.size() - 2); + } + return segment.substr(1); + } + return ""; +} + +static std::shared_ptr match_route_trie( + const std::shared_ptr& node, + const std::vector& segments, + size_t index, + std::map& params) { + if (!node) return NULL; + if (index == segments.size()) { + return node->method_handlers ? node : NULL; + } + + auto iter = node->children.find(segments[index]); + if (iter != node->children.end()) { + auto found = match_route_trie(iter->second, segments, index + 1, params); + if (found) return found; + } + + if (node->param_child) { + params[node->param_child->param_name] = segments[index]; + auto found = match_route_trie(node->param_child, segments, index + 1, params); + if (found) return found; + params.erase(node->param_child->param_name); + } + + return NULL; +} + +static int get_method_handler(std::shared_ptr method_handlers, http_method method, http_handler** handler) { + for (auto iter = method_handlers->begin(); iter != method_handlers->end(); ++iter) { + if (iter->method == method) { + if (handler) *handler = &iter->handler; + return 0; + } + } + if (handler) *handler = NULL; + return HTTP_STATUS_METHOD_NOT_ALLOWED; +} + +static bool match_wildcard_route(const std::string& pattern, const std::string& path) { + const char* kp = pattern.c_str(); + const char* vp = path.c_str(); + while (*kp && *vp) { + if (kp[0] == '*') { + return hv_strendswith(vp, kp + 1); + } + if (*kp != *vp) { + return false; + } + ++kp; + ++vp; + } + return *kp == '\0' && *vp == '\0'; +} + void HttpService::AddRoute(const char* path, http_method method, const http_handler& handler) { std::shared_ptr method_handlers = NULL; + bool is_new_path = false; auto iter = pathHandlers.find(path); if (iter == pathHandlers.end()) { // add path method_handlers = std::make_shared(); pathHandlers[path] = method_handlers; + is_new_path = true; } else { method_handlers = iter->second; @@ -25,6 +117,36 @@ void HttpService::AddRoute(const char* path, http_method method, const http_hand } // add method_handlers->push_back(http_method_handler(method, handler)); + + if (!is_new_path) return; + + std::string str_path(path); + if (has_wildcard(path)) { + wildcardHandlers.emplace_back(str_path, method_handlers); + return; + } + + std::vector segments; + split_path(str_path, segments); + auto node = routeTrie; + for (size_t i = 0; i < segments.size(); ++i) { + const auto& segment = segments[i]; + std::string param_name = parse_param_name(segment); + if (!param_name.empty()) { + if (!node->param_child) { + node->param_child = std::make_shared(); + } + node = node->param_child; + node->param_name = param_name; + } else { + auto child_iter = node->children.find(segment); + if (child_iter == node->children.end()) { + node->children[segment] = std::make_shared(); + } + node = node->children[segment]; + } + } + node->method_handlers = method_handlers; } int HttpService::GetRoute(const char* url, http_method method, http_handler** handler) { @@ -67,63 +189,31 @@ int HttpService::GetRoute(HttpRequest* req, http_handler** handler) { while (*e && *e != '?') ++e; std::string path = std::string(s, e); - const char *kp, *ks, *vp, *vs; - bool match; - for (auto iter = pathHandlers.begin(); iter != pathHandlers.end(); ++iter) { - kp = iter->first.c_str(); - vp = path.c_str(); - match = false; - std::map params; - - while (*kp && *vp) { - if (kp[0] == '*') { - // wildcard * - match = hv_strendswith(vp, kp+1); - break; - } else if (*kp != *vp) { - match = false; - break; - } else if (kp[0] == '/' && (kp[1] == ':' || kp[1] == '{')) { - // RESTful /:field/ - // RESTful /{field}/ - kp += 2; - ks = kp; - while (*kp && *kp != '/') {++kp;} - vp += 1; - vs = vp; - while (*vp && *vp != '/') {++vp;} - int klen = kp - ks; - if (*(ks-1) == '{' && *(kp-1) == '}') { - --klen; - } - params[std::string(ks, klen)] = std::string(vs, vp-vs); - continue; - } else { - ++kp; - ++vp; + std::vector segments; + split_path(path, segments); + std::map params; + auto route_node = match_route_trie(routeTrie, segments, 0, params); + if (route_node) { + int ret = get_method_handler(route_node->method_handlers, req->method, handler); + if (ret == 0) { + for (auto& param : params) { + req->query_params[param.first] = param.second; } } + return ret; + } - match = match ? match : (*kp == '\0' && *vp == '\0'); - - if (match) { - auto method_handlers = iter->second; - for (auto iter = method_handlers->begin(); iter != method_handlers->end(); ++iter) { - if (iter->method == req->method) { - for (auto& param : params) { - // RESTful /:field/ => req->query_params[field] - req->query_params[param.first] = param.second; - } - if (handler) *handler = &iter->handler; - return 0; - } - } + bool method_not_allowed = false; + for (const auto& route : wildcardHandlers) { + if (!match_wildcard_route(route.first, path)) continue; + int ret = get_method_handler(route.second, req->method, handler); + if (ret == 0) return 0; + method_not_allowed = true; + } - if (params.size() == 0) { - if (handler) *handler = NULL; - return HTTP_STATUS_METHOD_NOT_ALLOWED; - } - } + if (method_not_allowed) { + if (handler) *handler = NULL; + return HTTP_STATUS_METHOD_NOT_ALLOWED; } if (handler) *handler = NULL; return HTTP_STATUS_NOT_FOUND; diff --git a/http/server/HttpService.h b/http/server/HttpService.h index 16ef7a3c2..e9373a135 100644 --- a/http/server/HttpService.h +++ b/http/server/HttpService.h @@ -106,6 +106,15 @@ typedef std::list // path => http_method_handlers typedef std::unordered_map> http_path_handlers; +struct route_trie_node { + std::unordered_map> children; + std::shared_ptr param_child; + std::string param_name; + std::shared_ptr method_handlers; +}; + +typedef std::vector>> http_wildcard_handlers; + namespace hv { struct HV_EXPORT HttpService { @@ -121,6 +130,8 @@ struct HV_EXPORT HttpService { /* API handlers */ std::string base_url; http_path_handlers pathHandlers; + std::shared_ptr routeTrie; + http_wildcard_handlers wildcardHandlers; /* Static file service */ http_handler staticHandler; @@ -181,6 +192,8 @@ struct HV_EXPORT HttpService { enable_access_log = 1; enable_forward_proxy = 0; + + routeTrie = std::make_shared(); } void AddRoute(const char* path, http_method method, const http_handler& handler); diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt index 2eec34790..016e7990b 100644 --- a/unittest/CMakeLists.txt +++ b/unittest/CMakeLists.txt @@ -72,6 +72,10 @@ add_executable(objectpool_test objectpool_test.cpp) target_include_directories(objectpool_test PRIVATE .. ../base ../cpputil) target_link_libraries(objectpool_test -lpthread) +add_executable(http_service_route_test http_service_route_test.cpp) +target_include_directories(http_service_route_test PRIVATE ..) +target_link_libraries(http_service_route_test ${HV_LIBRARIES}) + # ------protocol------ add_executable(nslookup nslookup_test.c ../protocol/dns.c ../base/hsocket.c ../base/htime.c) target_include_directories(nslookup PRIVATE .. ../base ../protocol) @@ -111,6 +115,7 @@ add_custom_target(unittest DEPENDS synchronized_test threadpool_test objectpool_test + http_service_route_test nslookup ping ftp diff --git a/unittest/http_service_route_test.cpp b/unittest/http_service_route_test.cpp new file mode 100644 index 000000000..2b7e8f922 --- /dev/null +++ b/unittest/http_service_route_test.cpp @@ -0,0 +1,55 @@ +#include + +#include "hv/HttpService.h" + +using namespace hv; + +static std::string call_route(HttpService& router, const char* path, http_method method, int* status = NULL) { + HttpRequest req; + req.path = path; + req.method = method; + http_handler* handler = NULL; + int ret = router.GetRoute(&req, &handler); + if (status) *status = ret; + if (ret != 0 || !handler || !handler->sync_handler) { + return ""; + } + HttpResponse resp; + ret = handler->sync_handler(&req, &resp); + assert(ret == 200); + return resp.body; +} + +int main(int argc, char** argv) { + HttpService router; + router.GET("/status", [](HttpRequest* req, HttpResponse* resp) { + (void)req; + resp->body = "EXACT:/status"; + return 200; + }); + router.GET("/user/:id", [](HttpRequest* req, HttpResponse* resp) { + resp->body = req->query_params["id"]; + return 200; + }); + router.Any("*", [](HttpRequest* req, HttpResponse* resp) { + (void)req; + resp->body = "FALLBACK"; + return 200; + }); + + int status = 0; + assert(call_route(router, "/status", HTTP_GET, &status) == "EXACT:/status"); + assert(status == 0); + + assert(call_route(router, "/user/123", HTTP_GET, &status) == "123"); + assert(status == 0); + + assert(call_route(router, "/missing", HTTP_GET, &status) == "FALLBACK"); + assert(status == 0); + + status = 0; + call_route(router, "/status", HTTP_OPTIONS, &status); + assert(status == HTTP_STATUS_METHOD_NOT_ALLOWED); + + return 0; +}