diff --git a/include/expresso/core/io/connection.h b/include/expresso/core/io/connection.h new file mode 100644 index 0000000..a7cc480 --- /dev/null +++ b/include/expresso/core/io/connection.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace expresso { + +namespace core { + +namespace io { + +struct PendingFile { + int fd; + off_t offset; + off_t remaining; +}; + +enum class ConnectionPhase { + READING = 0, + PROCESSING = 1, + WRITING = 2, + DONE = 3, +}; + +struct Connection { + int fd; + + std::vector readBuffer; + + std::string writeBuffer; + size_t writeOffset; + + std::optional pendingFile; + + std::atomic phase; + + explicit Connection(int fd) + : fd(fd), + writeOffset(0), + phase(static_cast(ConnectionPhase::READING)) {} + + Connection(const Connection&) = delete; + Connection& operator=(const Connection&) = delete; +}; + +} // namespace io + +} // namespace core + +} // namespace expresso diff --git a/include/expresso/core/io/ipoller.h b/include/expresso/core/io/ipoller.h new file mode 100644 index 0000000..162ff2c --- /dev/null +++ b/include/expresso/core/io/ipoller.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +namespace expresso { + +namespace core { + +namespace io { + +enum class EventType { + READ = 1, + WRITE = 2, +}; + +struct Event { + int fd; + EventType type; +}; + +class IPoller { +public: + virtual ~IPoller() = default; + + virtual void add(int fd, EventType type) noexcept(false) = 0; + virtual void modify(int fd, EventType type) noexcept(false) = 0; + virtual void remove(int fd) noexcept(false) = 0; + virtual std::vector wait(int timeoutMs = -1) noexcept(false) = 0; +}; + +} // namespace io + +} // namespace core + +} // namespace expresso diff --git a/include/expresso/core/io/poller.h b/include/expresso/core/io/poller.h new file mode 100644 index 0000000..90b5895 --- /dev/null +++ b/include/expresso/core/io/poller.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include + +namespace expresso { + +namespace core { + +namespace io { + +std::unique_ptr makePoller() noexcept(false); + +} // namespace io + +} // namespace core + +} // namespace expresso diff --git a/include/expresso/core/server.h b/include/expresso/core/server.h index b40cd18..b8731f3 100644 --- a/include/expresso/core/server.h +++ b/include/expresso/core/server.h @@ -1,11 +1,18 @@ #pragma once #include +#include +#include +#include +#include #include #include #include +#include +#include +#include #include #include #include @@ -17,19 +24,36 @@ namespace core { class Server : public Router { private: - int socket; + int serverSocket; size_t maxConnections; struct sockaddr_in address; + std::unique_ptr poller; + std::map> connections; + + int wakeupPipe[2]; + std::queue pendingWrites; + std::mutex pendingWritesMutex; + + nexus::pool threadPool; + mochios::enums::method getMethodFromString( const std::string& method) noexcept(false); void setupMiddlewares(); - void acceptConnections(); - void handleConnection(int clientSocket) noexcept(false); + void runEventLoop(); + void acceptNewConnections(); + void readFromConnection( + std::shared_ptr conn); + void writeToConnection( + std::shared_ptr conn); + void processConnection( + std::shared_ptr conn); + void closeConnection( + std::shared_ptr conn); - expresso::messages::Request makeRequest(std::string& request) noexcept(false); - nexus::pool threadPool; + expresso::messages::Request makeRequest( + const std::string& request) noexcept(false); public: Server(size_t maxConnections = SOMAXCONN, @@ -41,4 +65,4 @@ class Server : public Router { } // namespace core -} // namespace expresso \ No newline at end of file +} // namespace expresso diff --git a/include/expresso/helpers/response.h b/include/expresso/helpers/response.h index 24b63f7..020593e 100644 --- a/include/expresso/helpers/response.h +++ b/include/expresso/helpers/response.h @@ -1,10 +1,11 @@ #pragma once +#include +#include #include #include #include -#include #include namespace expresso { @@ -17,10 +18,10 @@ std::string getAvailableFile(const std::string& path); const std::string generateETag(const std::string& data); -bool sendChunkedData(const int& socket, const std::string& data); +std::string makeChunkedData(const std::string& data); -bool sendFileInChunks(const int& socket, const std::string& path); +std::string readFileAsChunked(const std::string& path); } // namespace helpers -} // namespace expresso \ No newline at end of file +} // namespace expresso diff --git a/include/expresso/messages/response.h b/include/expresso/messages/response.h index d776a12..f17e6aa 100644 --- a/include/expresso/messages/response.h +++ b/include/expresso/messages/response.h @@ -1,11 +1,13 @@ #pragma once +#include #include +#include #include +#include #include #include -#include #include #include #include @@ -19,16 +21,25 @@ namespace expresso { namespace messages { class Response : public mochios::messages::Response { +public: + struct FileTransfer { + int fd = -1; + off_t offset = 0; + off_t length = 0; + }; + private: bool hasEnded; int socket; std::string message; + std::string writeBuffer; std::vector cookies; - void sendToClient(); - void sendHeaders(); + FileTransfer fileTransfer; + + void buildHeaders(); public: Response(int clientSocket); @@ -50,9 +61,13 @@ class Response : public mochios::messages::Response { void sendInvalidRange(); void end(); + const std::string& getWriteBuffer() const; + std::string takeWriteBuffer(); + bool hasPendingFile() const; + FileTransfer takeFileTransfer(); const void print() const override; }; } // namespace messages -} // namespace expresso \ No newline at end of file +} // namespace expresso diff --git a/src/core/io/epoll_poller.cpp b/src/core/io/epoll_poller.cpp new file mode 100644 index 0000000..e44d0f4 --- /dev/null +++ b/src/core/io/epoll_poller.cpp @@ -0,0 +1,91 @@ +#ifdef __linux__ + +#include +#include + +#include + +#include +#include + +namespace expresso { + +namespace core { + +namespace io { + +class EpollPoller : public IPoller { +private: + int epollFd; + static constexpr int MAX_EVENTS = 64; + +public: + EpollPoller() { + this->epollFd = epoll_create1(0); + if (this->epollFd < 0) { + logger::error( + "Failed to create epoll fd", + "expresso::core::io::EpollPoller::EpollPoller()"); + } + } + + ~EpollPoller() { + if (this->epollFd >= 0) { + close(this->epollFd); + } + } + + void add(int fd, EventType type) noexcept(false) override { + struct epoll_event ev; + ev.data.fd = fd; + ev.events = (type == EventType::READ ? EPOLLIN : EPOLLOUT) | EPOLLET; + if (epoll_ctl(this->epollFd, EPOLL_CTL_ADD, fd, &ev) < 0) { + logger::error( + "epoll_ctl ADD failed for fd " + std::to_string(fd), + "expresso::core::io::EpollPoller::add()"); + } + } + + void modify(int fd, EventType type) noexcept(false) override { + struct epoll_event ev; + ev.data.fd = fd; + ev.events = (type == EventType::READ ? EPOLLIN : EPOLLOUT) | EPOLLET; + if (epoll_ctl(this->epollFd, EPOLL_CTL_MOD, fd, &ev) < 0) { + logger::error( + "epoll_ctl MOD failed for fd " + std::to_string(fd), + "expresso::core::io::EpollPoller::modify()"); + } + } + + void remove(int fd) noexcept(false) override { + epoll_ctl(this->epollFd, EPOLL_CTL_DEL, fd, nullptr); + } + + std::vector wait(int timeoutMs = -1) noexcept(false) override { + struct epoll_event events[MAX_EVENTS]; + int n = epoll_wait(this->epollFd, events, MAX_EVENTS, timeoutMs); + if (n < 0) { + return {}; + } + std::vector result; + result.reserve(n); + for (int i = 0; i < n; i++) { + EventType type = + (events[i].events & EPOLLIN) ? EventType::READ : EventType::WRITE; + result.push_back({events[i].data.fd, type}); + } + return result; + } +}; + +std::unique_ptr makePoller() noexcept(false) { + return std::make_unique(); +} + +} // namespace io + +} // namespace core + +} // namespace expresso + +#endif // __linux__ diff --git a/src/core/io/kqueue_poller.cpp b/src/core/io/kqueue_poller.cpp new file mode 100644 index 0000000..d860af9 --- /dev/null +++ b/src/core/io/kqueue_poller.cpp @@ -0,0 +1,104 @@ +#if defined(__APPLE__) || defined(__FreeBSD__) + +#include +#include +#include +#include + +#include + +#include +#include + +namespace expresso { + +namespace core { + +namespace io { + +class KqueuePoller : public IPoller { +private: + int kqFd; + static constexpr int MAX_EVENTS = 64; + +public: + KqueuePoller() { + this->kqFd = kqueue(); + if (this->kqFd < 0) { + logger::error( + "Failed to create kqueue fd", + "expresso::core::io::KqueuePoller::KqueuePoller()"); + } + } + + ~KqueuePoller() { + if (this->kqFd >= 0) { + close(this->kqFd); + } + } + + void add(int fd, EventType type) noexcept(false) override { + struct kevent ev; + int filter = (type == EventType::READ) ? EVFILT_READ : EVFILT_WRITE; + EV_SET(&ev, fd, filter, EV_ADD | EV_CLEAR, 0, 0, nullptr); + if (kevent(this->kqFd, &ev, 1, nullptr, 0, nullptr) < 0) { + logger::error( + "kevent ADD failed for fd " + std::to_string(fd), + "expresso::core::io::KqueuePoller::add()"); + } + } + + void modify(int fd, EventType type) noexcept(false) override { + struct kevent evs[2]; + int oldFilter = + (type == EventType::READ) ? EVFILT_WRITE : EVFILT_READ; + int newFilter = + (type == EventType::READ) ? EVFILT_READ : EVFILT_WRITE; + EV_SET(&evs[0], fd, oldFilter, EV_DELETE, 0, 0, nullptr); + EV_SET(&evs[1], fd, newFilter, EV_ADD | EV_CLEAR, 0, 0, nullptr); + kevent(this->kqFd, evs, 2, nullptr, 0, nullptr); + } + + void remove(int fd) noexcept(false) override { + struct kevent evs[2]; + EV_SET(&evs[0], fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + EV_SET(&evs[1], fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + kevent(this->kqFd, evs, 2, nullptr, 0, nullptr); + } + + std::vector wait(int timeoutMs = -1) noexcept(false) override { + struct kevent events[MAX_EVENTS]; + struct timespec ts; + struct timespec* tsPtr = nullptr; + if (timeoutMs >= 0) { + ts.tv_sec = timeoutMs / 1000; + ts.tv_nsec = (timeoutMs % 1000) * 1000000L; + tsPtr = &ts; + } + int n = kevent(this->kqFd, nullptr, 0, events, MAX_EVENTS, tsPtr); + if (n < 0) { + return {}; + } + std::vector result; + result.reserve(n); + for (int i = 0; i < n; i++) { + EventType type = (events[i].filter == EVFILT_READ) + ? EventType::READ + : EventType::WRITE; + result.push_back({static_cast(events[i].ident), type}); + } + return result; + } +}; + +std::unique_ptr makePoller() noexcept(false) { + return std::make_unique(); +} + +} // namespace io + +} // namespace core + +} // namespace expresso + +#endif // __APPLE__ || __FreeBSD__ diff --git a/src/core/server.cpp b/src/core/server.cpp index 578955a..da89052 100644 --- a/src/core/server.cpp +++ b/src/core/server.cpp @@ -1,5 +1,27 @@ +#include +#include +#include +#include +#include +#include + +#ifdef __linux__ +#include +#endif + +#include +#include + #include +static void setNonBlocking(int fd) { + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0) { + flags = 0; + } + fcntl(fd, F_SETFL, flags | O_NONBLOCK); +} + expresso::core::Server::Server(size_t maxConnections, size_t maxThreads) : maxConnections(maxConnections), threadPool(maxThreads) { signal(SIGPIPE, SIG_IGN); @@ -11,46 +33,66 @@ expresso::core::Server::Server(size_t maxConnections, size_t maxThreads) "." + std::to_string(EXPRESSO_VERSION_MINOR) + "." + std::to_string(EXPRESSO_VERSION_PATCH)); - this->socket = brewtils::sys::socket(AF_INET, SOCK_STREAM, 0); - if (this->socket < 0) { + this->serverSocket = brewtils::sys::socket(AF_INET, SOCK_STREAM, 0); + if (this->serverSocket < 0) { logger::error("Socket not created!", "expresso::core::Server::Server()"); } + int opt = 1; + setsockopt(this->serverSocket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + setNonBlocking(this->serverSocket); + this->address.sin_family = AF_INET; this->address.sin_addr.s_addr = INADDR_ANY; + + if (pipe(this->wakeupPipe) < 0) { + logger::error("Failed to create wakeup pipe", + "expresso::core::Server::Server()"); + } + setNonBlocking(this->wakeupPipe[0]); + setNonBlocking(this->wakeupPipe[1]); + + this->poller = expresso::core::io::makePoller(); this->setupMiddlewares(); return; } expresso::core::Server::~Server() { - if (this->socket > 0) { - close(this->socket); + if (this->serverSocket > 0) { + close(this->serverSocket); } + close(this->wakeupPipe[0]); + close(this->wakeupPipe[1]); return; } -void expresso::core::Server::listen(int port, std::function callback) { +void expresso::core::Server::listen(int port, + std::function callback) { this->address.sin_port = htons(port); pid_t pid = getpid(); logger::info("Server started with PID " + std::to_string(pid)); - if (brewtils::sys::bind(this->socket, (struct sockaddr*)&this->address, + if (brewtils::sys::bind(this->serverSocket, + (struct sockaddr*)&this->address, sizeof(this->address)) < 0) { logger::error("Unable to bind socket!", - "void expresso::core::Server::run(int port)"); + "void expresso::core::Server::listen(int port)"); } - if (brewtils::sys::listen(this->socket, this->maxConnections) < 0) { + if (brewtils::sys::listen(this->serverSocket, this->maxConnections) < 0) { logger::error("Unable to listen on socket!", - "void expresso::core::Server::run(int port)"); + "void expresso::core::Server::listen(int port)"); } + this->poller->add(this->serverSocket, expresso::core::io::EventType::READ); + this->poller->add(this->wakeupPipe[0], expresso::core::io::EventType::READ); + if (callback != nullptr) { callback(); } - this->acceptConnections(); + this->runEventLoop(); return; } @@ -65,8 +107,9 @@ mochios::enums::method expresso::core::Server::getMethodFromString( else if (method == "OPTIONS") return mochios::enums::method::OPTIONS; else if (method == "HEAD") return mochios::enums::method::HEAD; else - logger::error("Unsupported HTTP method: " + method, - "expresso::core::Server::getMethodFromString(std::string &method) noexcept(false)"); + logger::error( + "Unsupported HTTP method: " + method, + "expresso::core::Server::getMethodFromString(const std::string&)"); } void expresso::core::Server::setupMiddlewares() { @@ -74,94 +117,245 @@ void expresso::core::Server::setupMiddlewares() { this->use(std::make_unique()); } -void expresso::core::Server::acceptConnections() { +void expresso::core::Server::runEventLoop() { + while (true) { + std::vector events = this->poller->wait(-1); + + for (const expresso::core::io::Event& event : events) { + if (event.fd == this->serverSocket) { + this->acceptNewConnections(); + } else if (event.fd == this->wakeupPipe[0]) { + char buf[64]; + while (::read(this->wakeupPipe[0], buf, sizeof(buf)) > 0) {} + + std::queue toWrite; + { + std::lock_guard lock(this->pendingWritesMutex); + std::swap(toWrite, this->pendingWrites); + } + while (!toWrite.empty()) { + int fd = toWrite.front(); + toWrite.pop(); + auto it = this->connections.find(fd); + if (it != this->connections.end()) { + this->poller->modify(fd, expresso::core::io::EventType::WRITE); + } + } + } else if (event.type == expresso::core::io::EventType::READ) { + auto it = this->connections.find(event.fd); + if (it != this->connections.end()) { + this->readFromConnection(it->second); + } + } else if (event.type == expresso::core::io::EventType::WRITE) { + auto it = this->connections.find(event.fd); + if (it != this->connections.end()) { + this->writeToConnection(it->second); + } + } + } + } +} + +void expresso::core::Server::acceptNewConnections() { while (true) { struct sockaddr_in clientAddress; socklen_t clientAddressLength = sizeof(clientAddress); - int clientSocket = accept(this->socket, (struct sockaddr*)&clientAddress, + int clientSocket = accept(this->serverSocket, + (struct sockaddr*)&clientAddress, &clientAddressLength); if (clientSocket < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; + } logger::error("Client connection not accepted!", - "void expresso::core::Server::acceptConnections()"); - return; + "expresso::core::Server::acceptNewConnections()"); + break; } - this->threadPool.enqueue( - [this, clientSocket]() { this->handleConnection(clientSocket); }); + setNonBlocking(clientSocket); + auto conn = + std::make_shared(clientSocket); + this->connections[clientSocket] = conn; + this->poller->add(clientSocket, expresso::core::io::EventType::READ); } - - return; } -void expresso::core::Server::handleConnection( - int clientSocket) noexcept(false) { +void expresso::core::Server::readFromConnection( + std::shared_ptr conn) { + if (conn->phase != static_cast( + expresso::core::io::ConnectionPhase::READING)) { + return; + } + constexpr size_t bufferSize = 4096; - std::vector charRequest; - charRequest.resize(bufferSize, '\0'); - size_t totalBytesRead = 0; + char buf[bufferSize]; while (true) { - // sanity check - if (charRequest.size() - totalBytesRead < bufferSize) { - charRequest.resize(charRequest.size() + bufferSize, '\0'); + ssize_t n = ::recv(conn->fd, buf, bufferSize, 0); + if (n < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; + } + this->closeConnection(conn); + return; } - - ssize_t bytesRead = brewtils::sys::recv( - clientSocket, - charRequest.data() + totalBytesRead, - bufferSize, - 0); - - // miraculous happening - if (bytesRead < 0) { - close(clientSocket); - logger::error("Failed to receive data from client!", - "void expresso::core::Server::handleConnection(int clientSocket) noexcept(false)"); + if (n == 0) { + this->closeConnection(conn); return; } + conn->readBuffer.insert(conn->readBuffer.end(), buf, buf + n); + } - // if client closed connection - if (bytesRead == 0) { - break; - } + std::string data(conn->readBuffer.data(), conn->readBuffer.size()); - totalBytesRead += bytesRead; + size_t headersEnd = data.find("\r\n\r\n"); + if (headersEnd == std::string::npos) { + return; + } - // if end of headers reached - if (std::string(charRequest.data(), totalBytesRead).find("\r\n\r\n") != - std::string::npos) { - break; + // Ensure the full body has arrived for requests that carry one + size_t contentLength = 0; + size_t clPos = data.find("content-length:"); + if (clPos == std::string::npos) { + clPos = data.find("Content-Length:"); + } + if (clPos != std::string::npos) { + size_t valueStart = data.find(':', clPos) + 1; + while (valueStart < data.size() && data[valueStart] == ' ') { + ++valueStart; + } + size_t valueEnd = data.find('\r', valueStart); + if (valueEnd != std::string::npos) { + try { + contentLength = + std::stoull(data.substr(valueStart, valueEnd - valueStart)); + } catch (...) { + } } } - charRequest.resize(totalBytesRead); - std::string request(charRequest.data(), totalBytesRead); - if (totalBytesRead == 0 || request.empty()) { - close(clientSocket); + size_t bodyStart = headersEnd + 4; + size_t bodyReceived = + conn->readBuffer.size() > bodyStart + ? conn->readBuffer.size() - bodyStart + : 0; + if (bodyReceived < contentLength) { return; } - expresso::messages::Response* res = - new expresso::messages::Response(clientSocket); + conn->phase = + static_cast(expresso::core::io::ConnectionPhase::PROCESSING); + this->threadPool.enqueue( + [this, conn]() { this->processConnection(conn); }); +} + +void expresso::core::Server::processConnection( + std::shared_ptr conn) { + expresso::messages::Response res(conn->fd); try { - expresso::messages::Request req = this->makeRequest(request); - req.res = res; - this->handleRequest(req, *res); - delete res; + std::string requestStr(conn->readBuffer.data(), + conn->readBuffer.size()); + expresso::messages::Request req = this->makeRequest(requestStr); + req.res = &res; + this->handleRequest(req, res); } catch (const std::exception& e) { - logger::error( - e.what(), - "void expresso::core::Server::handleConnection(int clientSocket)"); - res->status(expresso::enums::STATUS_CODE::BAD_REQUEST).send("Bad Request"); + logger::error(e.what(), + "expresso::core::Server::processConnection()"); + res.status(expresso::enums::STATUS_CODE::BAD_REQUEST) + .send("Bad Request"); } - close(clientSocket); - return; + res.end(); + conn->writeBuffer = res.takeWriteBuffer(); + conn->writeOffset = 0; + + if (res.hasPendingFile()) { + auto ft = res.takeFileTransfer(); + conn->pendingFile = expresso::core::io::PendingFile{ + ft.fd, ft.offset, ft.length}; + } + + conn->phase = + static_cast(expresso::core::io::ConnectionPhase::WRITING); + + { + std::lock_guard lock(this->pendingWritesMutex); + this->pendingWrites.push(conn->fd); + } + char byte = 1; + ::write(this->wakeupPipe[1], &byte, 1); } -expresso::messages::Request -expresso::core::Server::makeRequest(std::string& request) noexcept(false) { +void expresso::core::Server::writeToConnection( + std::shared_ptr conn) { + while (conn->writeOffset < conn->writeBuffer.size()) { + ssize_t n = ::send(conn->fd, + conn->writeBuffer.data() + conn->writeOffset, + conn->writeBuffer.size() - conn->writeOffset, + 0); + if (n < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + this->closeConnection(conn); + return; + } + conn->writeOffset += static_cast(n); + } + + if (conn->pendingFile.has_value()) { + auto& pf = conn->pendingFile.value(); + while (pf.remaining > 0) { +#ifdef __linux__ + ssize_t n = ::sendfile(conn->fd, pf.fd, &pf.offset, + static_cast(pf.remaining)); + if (n < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + this->closeConnection(conn); + return; + } + pf.remaining -= static_cast(n); +#elif defined(__APPLE__) || defined(__FreeBSD__) + off_t len = pf.remaining; + int r = ::sendfile(pf.fd, conn->fd, pf.offset, &len, nullptr, 0); + if (len > 0) { + pf.offset += len; + pf.remaining -= len; + } + if (r < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + this->closeConnection(conn); + return; + } +#endif + } + ::close(pf.fd); + conn->pendingFile.reset(); + } + + this->closeConnection(conn); +} + +void expresso::core::Server::closeConnection( + std::shared_ptr conn) { + if (conn->pendingFile.has_value()) { + ::close(conn->pendingFile->fd); + conn->pendingFile.reset(); + } + this->poller->remove(conn->fd); + close(conn->fd); + conn->phase = + static_cast(expresso::core::io::ConnectionPhase::DONE); + this->connections.erase(conn->fd); +} + +expresso::messages::Request expresso::core::Server::makeRequest( + const std::string& request) noexcept(false) { std::string line; std::istringstream stream(request); std::getline(stream, line); @@ -172,14 +366,13 @@ expresso::core::Server::makeRequest(std::string& request) noexcept(false) { req.method = this->getMethodFromString(method); req.httpVersion = parts[2]; if (req.httpVersion.substr(0, 5) != "HTTP/") { - logger::error("Invalid HTTP version: " + req.httpVersion, - "expresso::core::Server::makeRequest(std::string &request) " - "noexcept(false)"); + logger::error( + "Invalid HTTP version: " + req.httpVersion, + "expresso::core::Server::makeRequest(const std::string&)"); } req.tempPath = req.path.substr(1, req.path.size()); - // Headers while (std::getline(stream, line) && line != "\r") { size_t separator = line.find(':', 0); if (separator != std::string::npos) { @@ -206,7 +399,6 @@ expresso::core::Server::makeRequest(std::string& request) noexcept(false) { } } - // Queries size_t start = req.tempPath.find('?', 0); if (start != std::string::npos) { start += 1; @@ -221,16 +413,17 @@ expresso::core::Server::makeRequest(std::string& request) noexcept(false) { } std::string value = req.tempPath.substr(start, end - start); - req.queries[brewtils::url::decode(key)] = brewtils::url::decode(value); + req.queries[brewtils::url::decode(key)] = + brewtils::url::decode(value); start = end + 1; } } } req.path = req.tempPath; - // Fixing the tempPath req.tempPath = req.tempPath.substr(0, req.tempPath.find('?', 0)); - if (req.tempPath[req.tempPath.size() - 1] == '/') { + if (!req.tempPath.empty() && + req.tempPath[req.tempPath.size() - 1] == '/') { req.tempPath = req.tempPath.substr(0, req.tempPath.size() - 1); } @@ -239,20 +432,21 @@ expresso::core::Server::makeRequest(std::string& request) noexcept(false) { return req; } - // Setting the body std::string contentType = req.headers["content-type"]; std::string body = request.substr(request.find("\r\n\r\n") + 4); - if (contentType == "text/plain" || contentType == "application/javascript") { + if (contentType == "text/plain" || + contentType == "application/javascript") { req.body = json::object(body); } else if (contentType == "application/json") { json::parser parser; req.body = parser.loads(body); } else if (contentType == "application/x-www-form-urlencoded") { - std::vector parts = brewtils::string::split(body, "&"); + std::vector bodyParts = + brewtils::string::split(body, "&"); std::string key; std::string value; req.body = json::object(std::map()); - for (const std::string& str : parts) { + for (const std::string& str : bodyParts) { key = brewtils::url::decode(brewtils::string::split(str, "=")[0]); value = brewtils::url::decode(brewtils::string::split(str, "=")[1]); req.body[key] = json::object(value); @@ -261,14 +455,15 @@ expresso::core::Server::makeRequest(std::string& request) noexcept(false) { "multipart/form-data") { std::string delimiter = brewtils::string::split( brewtils::string::split(contentType, ";")[1], "=")[1]; - std::vector parts = brewtils::string::split(body, delimiter); + std::vector bodyParts = + brewtils::string::split(body, delimiter); std::vector data; std::string key; std::string value; req.body = json::object(std::map()); - for (const std::string& str : parts) { - data = brewtils::string::split(str, - "Content-Disposition: form-data; name=\""); + for (const std::string& str : bodyParts) { + data = brewtils::string::split( + str, "Content-Disposition: form-data; name=\""); if (data.size() == 2) { key = brewtils::string::split(data[1], "\r\n")[0]; key = key.substr(0, key.size() - 1); @@ -280,4 +475,4 @@ expresso::core::Server::makeRequest(std::string& request) noexcept(false) { } return req; -} \ No newline at end of file +} diff --git a/src/helpers/response.cpp b/src/helpers/response.cpp index d519832..cd1cbe7 100644 --- a/src/helpers/response.cpp +++ b/src/helpers/response.cpp @@ -1,3 +1,5 @@ +#include + #include std::string expresso::helpers::getAvailableFile(const std::string& path) { @@ -26,52 +28,35 @@ const std::string expresso::helpers::generateETag(const std::string& data) { return etag.str(); } -bool expresso::helpers::sendChunkedData(const int& socket, - const std::string& data) { - std::ostringstream dataSizeHex; - dataSizeHex << std::hex << data.length(); - std::string dataSize = dataSizeHex.str() + "\r\n"; - - if (brewtils::sys::send(socket, dataSize.c_str(), dataSize.length(), 0) == - -1) { - return false; - } - if (brewtils::sys::send(socket, data.c_str(), data.length(), 0) == -1) { - return false; - } - if (brewtils::sys::send(socket, "\r\n", 2, 0) == -1) { - return false; - } - return true; +std::string expresso::helpers::makeChunkedData(const std::string& data) { + std::ostringstream result; + result << std::hex << data.length() << "\r\n" << data << "\r\n"; + return result.str(); } -bool expresso::helpers::sendFileInChunks(const int& socket, - const std::string& path) { +std::string expresso::helpers::readFileAsChunked(const std::string& path) { + std::string result; std::fstream file(path, std::ios::in | std::ios::binary); char buffer[expresso::helpers::CHUNK_SIZE]; try { - std::streamsize bytesRead = 0; while (true) { file.read(buffer, expresso::helpers::CHUNK_SIZE); - bytesRead = file.gcount(); + std::streamsize bytesRead = file.gcount(); if (bytesRead == 0) { break; } - - if (!expresso::helpers::sendChunkedData(socket, - std::string(buffer, bytesRead))) { - return false; - } + result += expresso::helpers::makeChunkedData( + std::string(buffer, static_cast(bytesRead))); } } catch (const std::exception& e) { - logger::error(e.what(), "bool sendFileInChunks(const int &socket, const " - "std::string &path)"); - return false; + logger::error(e.what(), + "expresso::helpers::readFileAsChunked()"); } + if (file.is_open()) { file.close(); } - return true; -} \ No newline at end of file + return result; +} diff --git a/src/messages/response.cpp b/src/messages/response.cpp index 295929f..76d5cbe 100644 --- a/src/messages/response.cpp +++ b/src/messages/response.cpp @@ -1,8 +1,11 @@ +#include +#include + #include #include expresso::messages::Response::Response(int clientSocket) - : hasEnded(false), socket(clientSocket), message("") { + : hasEnded(false), socket(clientSocket), message(""), writeBuffer("") { this->set("connection", "close"); this->statusCode = expresso::enums::STATUS_CODE::OK; @@ -11,7 +14,11 @@ expresso::messages::Response::Response(int clientSocket) expresso::messages::Response::~Response() { if (!this->hasEnded) { - this->sendToClient(); + this->end(); + } + + if (this->fileTransfer.fd >= 0) { + ::close(this->fileTransfer.fd); } for (expresso::messages::Cookie* cookie : this->cookies) { @@ -21,6 +28,21 @@ expresso::messages::Response::~Response() { return; } +void expresso::messages::Response::buildHeaders() { + std::string header = + "HTTP/1.1 " + std::to_string(this->statusCode) + "\r\n"; + for (const std::pair& it : this->headers) { + header += it.first + ": " + it.second + "\r\n"; + } + for (Cookie* cookie : this->cookies) { + header += "set-cookie: " + cookie->serialize() + "\r\n"; + } + header += "\r\n"; + this->writeBuffer += header; + + return; +} + void expresso::messages::Response::set(std::string headerName, std::string headerValue) { this->headers[brewtils::string::lower(headerName)] = headerValue; @@ -112,43 +134,28 @@ void expresso::messages::Response::sendFile(const std::string& path, std::to_string(end) + "/" + std::to_string(fileSize)); } - this->sendHeaders(); - std::ifstream file(availableFile, std::ios::binary); - try { - file.seekg(start, std::ios::beg); - char buffer[expresso::helpers::CHUNK_SIZE]; - while (file.read(buffer, expresso::helpers::CHUNK_SIZE)) { - if (brewtils::sys::send(this->socket, buffer, - expresso::helpers::CHUNK_SIZE, 0) == -1) { - this->hasEnded = true; - break; - } - } - if (file.gcount() > 0) { - if (brewtils::sys::send(this->socket, buffer, file.gcount(), 0) == -1) { - this->hasEnded = true; - } - } - } catch (const std::exception& e) { - logger::error(e.what(), "void expresso::messages::Response::sendFile(const " - "std::string &path, int64_t start, int64_t end)"); - } - if (file.is_open()) { - file.close(); + int fd = ::open(availableFile.c_str(), O_RDONLY); + if (fd < 0) { + return this->sendNotFound(); } + this->buildHeaders(); + this->hasEnded = true; + this->fileTransfer = {fd, static_cast(start), + static_cast(end - start + 1)}; + return; } -void expresso::messages::Response::sendFiles(const std::set& paths, - const std::string& zipFileName) { +void expresso::messages::Response::sendFiles( + const std::set& paths, const std::string& zipFileName) { this->headers.erase("content-length"); this->set("transfer-encoding", "chunked"); this->set("content-type", brewtils::os::file::getMimeType(zipFileName)); - this->set("content-disposition", "inline; filename=\"" + zipFileName + "\""); + this->set("content-disposition", + "inline; filename=\"" + zipFileName + "\""); this->set("accept-ranges", "bytes"); - this->sendHeaders(); zippuccino::Zipper zipper; for (const std::string& path : paths) { @@ -156,34 +163,28 @@ void expresso::messages::Response::sendFiles(const std::set& paths, } zipper.zip(); + std::string chunkedBody; try { while (!zipper.isFinished()) { - if (!expresso::helpers::sendChunkedData(this->socket, - zipper.getHeader())) { - this->hasEnded = true; - return; - } - std::string currentFile = zipper.getCurrentFile(); - if (!expresso::helpers::sendFileInChunks(this->socket, currentFile)) { - this->hasEnded = true; - return; - } - } - - if (!expresso::helpers::sendChunkedData(this->socket, zipper.getFooter())) { - this->hasEnded = true; - return; - } - if (brewtils::sys::send(this->socket, "0\r\n\r\n", 5, 0) == -1) { - this->hasEnded = true; + chunkedBody += + expresso::helpers::makeChunkedData(zipper.getHeader()); + chunkedBody += + expresso::helpers::readFileAsChunked(zipper.getCurrentFile()); } + chunkedBody += + expresso::helpers::makeChunkedData(zipper.getFooter()); + chunkedBody += "0\r\n\r\n"; } catch (const std::exception& e) { - logger::error(e.what(), - "void expresso::messages::Response::sendFiles(const " - "std::set &paths, const std::string " - "&zipFileName)"); + logger::error( + e.what(), + "void expresso::messages::Response::sendFiles(const " + "std::set &paths, const std::string &zipFileName)"); } + this->buildHeaders(); + this->writeBuffer += chunkedBody; + this->hasEnded = true; + return; } @@ -204,12 +205,34 @@ void expresso::messages::Response::sendInvalidRange() { void expresso::messages::Response::end() { if (!this->hasEnded) { - this->sendToClient(); + this->set("content-length", std::to_string(this->message.length())); + this->buildHeaders(); + this->writeBuffer += this->message; + this->hasEnded = true; } return; } +const std::string& expresso::messages::Response::getWriteBuffer() const { + return this->writeBuffer; +} + +std::string expresso::messages::Response::takeWriteBuffer() { + return std::move(this->writeBuffer); +} + +bool expresso::messages::Response::hasPendingFile() const { + return this->fileTransfer.fd >= 0; +} + +expresso::messages::Response::FileTransfer +expresso::messages::Response::takeFileTransfer() { + FileTransfer ft = this->fileTransfer; + this->fileTransfer.fd = -1; + return ft; +} + const void expresso::messages::Response::print() const { logger::info("Response: "); logger::info(" statusCode: " + std::to_string(this->statusCode)); @@ -222,32 +245,3 @@ const void expresso::messages::Response::print() const { return; } - -void expresso::messages::Response::sendToClient() { - this->set("content-length", std::to_string(this->message.length())); - this->sendHeaders(); - if (this->hasEnded) { - return; - } - brewtils::sys::send(this->socket, this->message.c_str(), - this->message.length(), 0); - this->hasEnded = true; - - return; -} - -void expresso::messages::Response::sendHeaders() { - std::string header = "HTTP/1.1 " + std::to_string(this->statusCode) + "\r\n"; - for (std::pair it : this->headers) { - header += it.first + ": " + it.second + "\r\n"; - } - for (Cookie* cookie : this->cookies) { - header += "set-cookie: " + cookie->serialize() + "\r\n"; - } - header += "\r\n"; - if (brewtils::sys::send(this->socket, header.c_str(), header.length(), 0) == - -1) { - this->hasEnded = true; - } - return; -} \ No newline at end of file