Skip to content

Commit 4d97bbb

Browse files
committed
Add coroutine-aware Read and Write function to FileDescriptor
1 parent faabcc1 commit 4d97bbb

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

toolbelt/fd.cc

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,96 @@ void CloseAllFds(std::function<bool(int)> predicate) {
88
int e = getrlimit(RLIMIT_NOFILE, &lim);
99
if (e == 0) {
1010
for (int fd = 0; fd < lim.rlim_cur; fd++) {
11-
if (fcntl(fd, F_GETFD) == 0 && predicate(fd) ) {
11+
if (fcntl(fd, F_GETFD) == 0 && predicate(fd)) {
1212
(void)close(fd);
1313
}
1414
}
1515
}
1616
}
1717

18-
}
18+
absl::StatusOr<ssize_t> FileDescriptor::Read(void *buffer, size_t length,
19+
co::Coroutine *c) {
20+
char *buf = reinterpret_cast<char *>(buffer);
21+
size_t total = 0;
22+
while (total < length) {
23+
if (c != nullptr) {
24+
// For a coroutine we need to get it to wait for the fd to be ready.
25+
// This is a coroutine yield point.
26+
int fd = c->Wait(data_->fd, POLLIN);
27+
if (fd != Fd()) {
28+
return absl::InternalError(
29+
absl::StrFormat("Unexpected file descriptor from Wait: %d", fd));
30+
}
31+
}
32+
ssize_t n = ::read(Fd(), buf + total, length - total);
33+
if (n == 0) {
34+
break;
35+
}
36+
if (n == -1) {
37+
if (errno == EAGAIN || errno == EWOULDBLOCK) {
38+
if (c == nullptr) {
39+
return absl::InternalError("Operation would block");
40+
}
41+
// If we are nonblocking yield the coroutine now. When we
42+
// are resumed we can write to the socket again.
43+
if (!data_->non_blocking_) {
44+
int fd = c->Wait(Fd(), POLLIN);
45+
if (fd != Fd()) {
46+
return absl::InternalError(absl::StrFormat(
47+
"Unexpected file descriptor from Wait: %d", fd));
48+
}
49+
}
50+
continue;
51+
}
52+
return absl::InternalError(
53+
absl::StrFormat("Read failed: %s", strerror(errno)));
54+
}
55+
total += n;
56+
}
57+
return total;
58+
}
59+
60+
absl::StatusOr<ssize_t> FileDescriptor::Write(const void *buffer, size_t length,
61+
co::Coroutine *c) {
62+
const char *buf = reinterpret_cast<const char *>(buffer);
63+
64+
size_t total = 0;
65+
while (total < length) {
66+
if (c != nullptr) {
67+
// For a coroutine we need to get it to wait for the fd to be ready.
68+
// This is a coroutine yield point.
69+
int fd = c->Wait(data_->fd, POLLOUT);
70+
if (fd != Fd()) {
71+
return absl::InternalError(
72+
absl::StrFormat("Unexpected file descriptor from Wait: %d", fd));
73+
}
74+
}
75+
ssize_t n = ::write(Fd(), buf + total, length - total);
76+
if (n == 0) {
77+
break;
78+
}
79+
if (n == -1) {
80+
if (errno == EAGAIN || errno == EWOULDBLOCK) {
81+
if (c == nullptr) {
82+
return absl::InternalError("Operation would block");
83+
}
84+
// If we are nonblocking yield the coroutine now. When we
85+
// are resumed we can write to the socket again.
86+
if (!data_->non_blocking_) {
87+
int fd = c->Wait(Fd(), POLLOUT);
88+
if (fd != Fd()) {
89+
return absl::InternalError(absl::StrFormat(
90+
"Unexpected file descriptor from Wait: %d", fd));
91+
}
92+
}
93+
continue;
94+
}
95+
return absl::InternalError(
96+
absl::StrFormat("Write failed: %s", strerror(errno)));
97+
}
98+
total += n;
99+
}
100+
return total;
101+
}
102+
103+
} // namespace toolbelt

toolbelt/fd.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#define __TOOLBELT_FD_H
77

88
#include "absl/status/status.h"
9+
#include "absl/status/statusor.h"
910
#include "absl/strings/str_format.h"
1011
#include <cassert>
1112
#include <cerrno>
@@ -17,6 +18,7 @@
1718
#include <sys/resource.h>
1819
#include <sys/stat.h>
1920
#include <unistd.h>
21+
#include "coroutine.h"
2022

2123
namespace toolbelt {
2224

@@ -160,6 +162,7 @@ class FileDescriptor {
160162
return absl::InternalError(absl::StrFormat(
161163
"Failed to set nonblocking mode on fd: %s", strerror(errno)));
162164
}
165+
data_->non_blocking_ = true;
163166
return absl::OkStatus();
164167
}
165168

@@ -180,6 +183,9 @@ class FileDescriptor {
180183
return absl::OkStatus();
181184
}
182185

186+
absl::StatusOr<ssize_t> Read(void* buffer, size_t length, co::Coroutine* c = nullptr);
187+
absl::StatusOr<ssize_t> Write(const void* buffer, size_t length,
188+
co::Coroutine* c = nullptr);
183189
private:
184190
// Reference counted OS fd, shared among all FileDescriptors with the
185191
// same OS fd, provided you don't create two FileDescriptors with the
@@ -193,6 +199,7 @@ class FileDescriptor {
193199
}
194200
}
195201
int fd = -1; // OS file descriptor.
202+
bool non_blocking_ = false;
196203
};
197204

198205
// The actual shared data. If nullptr the FileDescriptor is invalid.

0 commit comments

Comments
 (0)