Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions build/deps/gen/deps/dep_capnp_cpp.MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
http = use_extension("@//:build/exts/http.bzl", "http")
http.archive(
name = "capnp-cpp",
sha256 = "0884582731fb8b9a6ef7a0d58056535d0480533256958d3eaae189d706bc43aa",
strip_prefix = "capnproto-capnproto-0776402/c++",
sha256 = "cbf7dcef02deb3a3addcfefacb76672c46a48c953024860bf80fceabc255d41d",
strip_prefix = "capnproto-capnproto-c1bce20/c++",
type = "tgz",
url = "https://github.com/capnproto/capnproto/tarball/07764022ba75c6924a250d5be0bf2e83250602a0",
url = "https://github.com/capnproto/capnproto/tarball/c1bce2095a8dd76851fe3c1c61550f79b69d671d",
)
use_repo(http, "capnp-cpp")
77 changes: 17 additions & 60 deletions src/workerd/api/basics.c++
Original file line number Diff line number Diff line change
Expand Up @@ -572,52 +572,6 @@ class AbortTriggerRpcClient final {
rpc::AbortTrigger::Client client;
};

namespace {
// The jsrpc handler that receives aborts from the remote and triggers them locally
class AbortTriggerRpcServer final: public rpc::AbortTrigger::Server {
public:
AbortTriggerRpcServer(kj::Own<kj::PromiseFulfiller<void>> fulfiller,
kj::Own<AbortSignal::PendingReason>&& pendingReason)
: fulfiller(kj::mv(fulfiller)),
pendingReason(kj::mv(pendingReason)) {}

kj::Promise<void> abort(AbortContext abortCtx) override {
auto params = abortCtx.getParams();
auto reason = params.getReason().getV8Serialized();

pendingReason->getWrapped() = kj::heapArray(reason.asBytes());
fulfiller->fulfill();
return kj::READY_NOW;
}

kj::Promise<void> release(ReleaseContext releaseCtx) override {
released = true;
return kj::READY_NOW;
}

~AbortTriggerRpcServer() noexcept(false) {
if (pendingReason->getWrapped() != nullptr) {
// Already triggered
return;
}

if (!released) {
pendingReason->getWrapped() = JSG_KJ_EXCEPTION(FAILED, DOMAbortError,
"An AbortSignal received over RPC was implicitly aborted because the connection back to "
"its trigger was lost.");
}

// Always fulfill the promise in case the AbortSignal was waiting
fulfiller->fulfill();
}

private:
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
kj::Own<AbortSignal::PendingReason> pendingReason;
bool released = false;
};
} // namespace

AbortSignal::AbortSignal(kj::Maybe<kj::Exception> exception,
jsg::Optional<jsg::JsRef<jsg::JsValue>> maybeReason,
Flag flag)
Expand Down Expand Up @@ -858,15 +812,19 @@ void AbortSignal::serialize(jsg::Lock& js, jsg::Serializer& serializer) {
return;
}

auto streamCap = externalHandler
->writeStream([&](rpc::JsValue::External::Builder builder) mutable {
builder.setAbortTrigger();
}).castAs<rpc::AbortTrigger>();
auto pipeline = externalHandler->getExternalPusher()
.pushAbortSignalRequest(capnp::MessageSize{2, 0})
.sendForPipeline();

externalHandler->write(
[signal = pipeline.getSignal()](rpc::JsValue::External::Builder builder) mutable {
builder.setAbortSignal(kj::mv(signal));
});

auto& ioContext = IoContext::current();
// Keep track of every AbortSignal cloned from this one.
// If this->triggerAbort(...) is called, each rpcClient will be informed.
rpcClients.add(ioContext.addObject(kj::heap<AbortTriggerRpcClient>(kj::mv(streamCap))));
rpcClients.add(ioContext.addObject(kj::heap<AbortTriggerRpcClient>(pipeline.getTrigger())));
}

jsg::Ref<AbortSignal> AbortSignal::deserialize(
Expand All @@ -890,20 +848,19 @@ jsg::Ref<AbortSignal> AbortSignal::deserialize(
return js.alloc<AbortSignal>(/* exception */ kj::none, /* maybeReason */ kj::none, flag);
}

auto reader = externalHandler->read();
KJ_REQUIRE(reader.isAbortTrigger(), "external table slot type does't match serialization tag");

// The AbortSignalImpl will receive any remote triggerAbort requests and fulfill the promise with the reason for abort

auto signal = js.alloc<AbortSignal>(/* exception */ kj::none, /* maybeReason */ kj::none, flag);

auto paf = kj::newPromiseAndFulfiller<void>();
auto pendingReason = IoContext::current().addObject(kj::refcounted<PendingReason>());
auto& ioctx = IoContext::current();

auto reader = externalHandler->read();
KJ_REQUIRE(reader.isAbortSignal(), "external table slot type does't match serialization tag");

auto resolvedSignal = ioctx.getExternalPusher()->unwrapAbortSignal(reader.getAbortSignal());

externalHandler->setLastStream(
kj::heap<AbortTriggerRpcServer>(kj::mv(paf.fulfiller), kj::addRef(*pendingReason)));
signal->rpcAbortPromise = IoContext::current().addObject(kj::heap(kj::mv(paf.promise)));
signal->pendingReason = kj::mv(pendingReason);
signal->rpcAbortPromise = ioctx.addObject(kj::heap(kj::mv(resolvedSignal.signal)));
signal->pendingReason = ioctx.addObject(kj::mv(resolvedSignal.reason));

return signal;
}
Expand Down
5 changes: 2 additions & 3 deletions src/workerd/api/basics.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// TODO(cleanup): Rename to events.h?

#include <workerd/io/compatibility-date.capnp.h>
#include <workerd/io/external-pusher.h>
#include <workerd/io/io-own.h>
#include <workerd/io/worker-interface.capnp.h>
#include <workerd/jsg/jsg.h>
Expand Down Expand Up @@ -571,9 +572,7 @@ class AbortSignal final: public EventTarget {
jsg::Optional<jsg::JsRef<jsg::JsValue>> maybeReason = kj::none,
Flag flag = Flag::NONE);

using PendingReason = kj::RefcountedWrapper<
kj::OneOf<kj::Array<kj::byte> /* v8Serialized */, kj::Exception /* if capability is dropped */
>>;
using PendingReason = ExternalPusherImpl::PendingAbortReason;

// The AbortSignal explicitly does not expose a constructor(). It is
// illegal for user code to create an AbortSignal directly.
Expand Down
116 changes: 11 additions & 105 deletions src/workerd/api/streams/readable.c++
Original file line number Diff line number Diff line change
Expand Up @@ -519,89 +519,6 @@ jsg::Optional<uint32_t> ByteLengthQueuingStrategy::size(

namespace {

// HACK: We need as async pipe, like kj::newOneWayPipe(), except supporting explicit end(). So we
// wrap the two ends of the pipe in special adapters that track whether end() was called.
class ExplicitEndOutputPipeAdapter final: public capnp::ExplicitEndOutputStream {
public:
ExplicitEndOutputPipeAdapter(
kj::Own<kj::AsyncOutputStream> inner, kj::Own<kj::RefcountedWrapper<bool>> ended)
: inner(kj::mv(inner)),
ended(kj::mv(ended)) {}

kj::Promise<void> write(kj::ArrayPtr<const byte> buffer) override {
return KJ_REQUIRE_NONNULL(inner)->write(buffer);
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
return KJ_REQUIRE_NONNULL(inner)->write(pieces);
}

kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount) override {
return KJ_REQUIRE_NONNULL(inner)->tryPumpFrom(input, amount);
}

kj::Promise<void> whenWriteDisconnected() override {
return KJ_REQUIRE_NONNULL(inner)->whenWriteDisconnected();
}

kj::Promise<void> end() override {
// Signal to the other side that end() was actually called.
ended->getWrapped() = true;
inner = kj::none;
return kj::READY_NOW;
}

private:
kj::Maybe<kj::Own<kj::AsyncOutputStream>> inner;
kj::Own<kj::RefcountedWrapper<bool>> ended;
};

class ExplicitEndInputPipeAdapter final: public kj::AsyncInputStream {
public:
ExplicitEndInputPipeAdapter(kj::Own<kj::AsyncInputStream> inner,
kj::Own<kj::RefcountedWrapper<bool>> ended,
kj::Maybe<uint64_t> expectedLength)
: inner(kj::mv(inner)),
ended(kj::mv(ended)),
expectedLength(expectedLength) {}

kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
size_t result = co_await inner->tryRead(buffer, minBytes, maxBytes);

KJ_IF_SOME(l, expectedLength) {
KJ_ASSERT(result <= l);
l -= result;
if (l == 0) {
// If we got all the bytes we expected, we treat this as a successful end, because the
// underlying KJ pipe is not actually going to wait for the other side to drop. This is
// consistent with the behavior of Content-Length in HTTP anyway.
ended->getWrapped() = true;
}
}

if (result < minBytes) {
// Verify that end() was called.
if (!ended->getWrapped()) {
JSG_FAIL_REQUIRE(Error, "ReadableStream received over RPC disconnected prematurely.");
}
}
co_return result;
}

kj::Maybe<uint64_t> tryGetLength() override {
return inner->tryGetLength();
}

kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
return inner->pumpTo(output, amount);
}

private:
kj::Own<kj::AsyncInputStream> inner;
kj::Own<kj::RefcountedWrapper<bool>> ended;
kj::Maybe<uint64_t> expectedLength;
};

// Wrapper around ReadableStreamSource that prevents deferred proxying. We need this for RPC
// streams because although they are "system streams", they become disconnected when the IoContext
// is destroyed, due to the JsRpcCustomEvent being canceled.
Expand Down Expand Up @@ -676,18 +593,21 @@ void ReadableStream::serialize(jsg::Lock& js, jsg::Serializer& serializer) {
StreamEncoding encoding = controller.getPreferredEncoding();
auto expectedLength = controller.tryGetLength(encoding);

auto streamCap = externalHandler->writeStream(
[encoding, expectedLength](rpc::JsValue::External::Builder builder) mutable {
auto req = externalHandler->getExternalPusher().pushByteStreamRequest(capnp::MessageSize{2, 0});
KJ_IF_SOME(el, expectedLength) {
req.setLengthPlusOne(el + 1);
}
auto pipeline = req.sendForPipeline();

externalHandler->write([encoding, expectedLength, source = pipeline.getSource()](
rpc::JsValue::External::Builder builder) mutable {
auto rs = builder.initReadableStream();
rs.setStream(kj::mv(source));
rs.setEncoding(encoding);
KJ_IF_SOME(l, expectedLength) {
rs.getExpectedLength().setKnown(l);
}
});

kj::Own<capnp::ExplicitEndOutputStream> kjStream =
ioctx.getByteStreamFactory().capnpToKjExplicitEnd(
kj::mv(streamCap).castAs<capnp::ByteStream>());
ioctx.getByteStreamFactory().capnpToKjExplicitEnd(pipeline.getSink());

auto sink = newSystemStream(kj::mv(kjStream), encoding, ioctx);

Expand Down Expand Up @@ -718,21 +638,7 @@ jsg::Ref<ReadableStream> ReadableStream::deserialize(

auto& ioctx = IoContext::current();

kj::Maybe<uint64_t> expectedLength;
auto el = rs.getExpectedLength();
if (el.isKnown()) {
expectedLength = el.getKnown();
}

auto pipe = kj::newOneWayPipe(expectedLength);

auto endedFlag = kj::refcounted<kj::RefcountedWrapper<bool>>(false);

auto out = kj::heap<ExplicitEndOutputPipeAdapter>(kj::mv(pipe.out), kj::addRef(*endedFlag));
auto in =
kj::heap<ExplicitEndInputPipeAdapter>(kj::mv(pipe.in), kj::mv(endedFlag), expectedLength);

externalHandler->setLastStream(ioctx.getByteStreamFactory().kjToCapnp(kj::mv(out)));
auto in = ioctx.getExternalPusher()->unwrapStream(rs.getStream());

return js.alloc<ReadableStream>(ioctx,
kj::heap<NoDeferredProxyReadableStream>(newSystemStream(kj::mv(in), encoding, ioctx), ioctx));
Expand Down
30 changes: 16 additions & 14 deletions src/workerd/api/tests/abortsignal-test.wd-test
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
using Workerd = import "/workerd/workerd.capnp";

const services :List(Workerd.Service) = [
( name = "abortsignal-test",
worker = (
modules = [
(name = "worker", esModule = embed "abortsignal-test.js")
],
compatibilityDate = "2025-01-01",
compatibilityFlags = ["nodejs_compat", "enable_abortsignal_rpc", "experimental"],
bindings = [
(name = "RpcRemoteEnd", service = (name = "abortsignal-test", entrypoint = "RpcRemoteEnd")),
]
)
),
];

const unitTests :Workerd.Config = (
services = [
( name = "abortsignal-test",
worker = (
modules = [
(name = "worker", esModule = embed "abortsignal-test.js")
],
compatibilityDate = "2025-01-01",
compatibilityFlags = ["nodejs_compat", "enable_abortsignal_rpc", "experimental"],
bindings = [
(name = "RpcRemoteEnd", service = (name = "abortsignal-test", entrypoint = "RpcRemoteEnd")),
]
)
),
],
services = .services,
v8Flags = ["--expose-gc"]
);
29 changes: 29 additions & 0 deletions src/workerd/api/tests/js-rpc-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2066,3 +2066,32 @@ export let sendServiceStubOverRpc = {
}
},
};

// Make sure that calls are delivered in e-order, even in the presence of pushed externals.
export let eOrderTest = {
async test(controller, env, ctx) {
let abortController = new AbortController();
let abortSignal = abortController.signal;

let readableController;
let readableStream = new ReadableStream({
start(c) {
readableController = c;
},
});

let stub = await env.MyService.makeCounter(0);

let promises = [];
promises.push(stub.increment(1));
promises.push(stub.increment(1));
promises.push(stub.increment(1, abortSignal));
promises.push(stub.increment(1));
promises.push(stub.increment(1, readableStream));
promises.push(stub.increment(1));

let results = await Promise.all(promises);

assert.deepEqual(results, [1, 2, 3, 4, 5, 6]);
},
};
Loading
Loading