Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/exec/env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,4 +272,11 @@ namespace experimental::execution

namespace exec = experimental::execution;

namespace STDEXEC::__detail
{
template <class Sender, class Attrs>
extern __mtype<exec::__write_attrs::__sender<__demangle_t<Sender>, Attrs>>
__demangle_v<exec::__write_attrs::__sender<Sender, Attrs>>;
}

STDEXEC_PRAGMA_POP()
27 changes: 9 additions & 18 deletions include/nvexec/stream/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,6 @@ namespace nv::execution
static_cast<BaseEnv&&>(base_env));
}

template <class BaseEnv>
using make_stream_env_t = stream_env_t<BaseEnv>;

template <class BaseEnv>
using make_terminal_stream_env_t = terminal_stream_env_t<BaseEnv>;

Expand Down Expand Up @@ -657,7 +654,7 @@ namespace nv::execution
{
using operation_state_concept = STDEXEC::operation_state_tag;
using outer_env_t = env_of_t<OuterReceiver>;
using env_t = make_stream_env_t<outer_env_t>;
using env_t = stream_env_t<outer_env_t>;

static constexpr bool borrows_stream = borrows_stream_h<outer_env_t>();

Expand All @@ -670,19 +667,11 @@ namespace nv::execution
[[nodiscard]]
auto get_stream_provider() const -> stream_provider*
{
stream_provider* provider{};

if constexpr (borrows_stream)
{
outer_env_t const & env = get_env(rcvr_);
provider = ::nvexec::_strm::get_stream_provider(env);
return _strm::get_stream_provider(get_env(rcvr_));
}
else
{
provider = &const_cast<stream_provider&>(stream_provider_);
}

return provider;
return &stream_provider_;
}

[[nodiscard]]
Expand Down Expand Up @@ -771,10 +760,10 @@ namespace nv::execution
}
}

context ctx_;
void* temp_storage_{nullptr};
OuterReceiver rcvr_;
stream_provider stream_provider_;
context ctx_;
void* temp_storage_{nullptr};
OuterReceiver rcvr_;
mutable stream_provider stream_provider_;
};

template <class OpState, class Env = decltype(__declval<OpState&>().make_env())>
Expand Down Expand Up @@ -804,6 +793,8 @@ namespace nv::execution
[[nodiscard]]
auto get_env() const noexcept -> Env
{
static_assert(__same_as<Env, decltype(opstate_.make_env())>,
"Env must be the type returned by OpState::make_env()");
return opstate_.make_env();
}

Expand Down
62 changes: 33 additions & 29 deletions include/nvexec/stream/let_xxx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,18 @@ namespace nv::execution::_strm
// correctly report the scheduler and domain on which the sender's operation will be
// started.
inline constexpr auto _mk_sch_env =
[]<class CvSender, class Receiver, class SetTag>(CvSender&& sndr, Receiver&& rcvr, SetTag)
[]<class SetTag, class CvSender, class Env>(SetTag, CvSender&& sndr, Env&& env)
{
using cv_fn = __copy_cvref_fn<CvSender>;
return __mk_secondary_env_t<SetTag>()(cv_fn{}, sndr, STDEXEC::get_env(rcvr));
using cv_fn_t = __copy_cvref_fn<CvSender>;
return __mk_secondary_env_t<SetTag>()(cv_fn_t{}, sndr, env);
};

template <class CvSender, class Receiver, class SetTag>
using _sch_env_t = __result_of<_mk_sch_env, CvSender, Receiver, SetTag>;
template <class SetTag, class CvSender, class Env>
using _sch_env_t = __result_of<_mk_sch_env, SetTag, CvSender, Env>;

// The environment of the propagate_receiver used to connect the secondary sender.
template <class SetTag, class CvSender, class _Env>
using _env2_t = __join_env_t<_sch_env_t<SetTag, CvSender, _Env> const &, stream_env_t<_Env>>;

inline constexpr auto _mk_env2 =
[]<class SchEnv, class Receiver>(SchEnv const & sch_env,
Expand All @@ -86,38 +90,34 @@ namespace nv::execution::_strm
return __env::__join(sch_env, opstate.make_env());
};

template <class CvSender, class Receiver, class SetTag>
using _env2_t = __result_of<_mk_env2,
_sch_env_t<CvSender, Receiver, SetTag> const &,
_strm::opstate_base<Receiver> const &>;

template <class CvSender, class Receiver, class Fun, class SetTag>
using _propagate_receiver_t = propagate_receiver<_opstate<CvSender, Receiver, Fun, SetTag>,
_env2_t<CvSender, Receiver, SetTag>>;
_env2_t<SetTag, CvSender, env_of_t<Receiver>>>;

template <class Sender, class Receiver, class Fun, class SetTag>
template <class CvSender, class Receiver, class Fun, class SetTag>
using _mk_opstate_fn = __mcompose<
__mbind_back_q<connect_result_t, _propagate_receiver_t<Sender, Receiver, Fun, SetTag>>,
__mbind_back_q<connect_result_t, _propagate_receiver_t<CvSender, Receiver, Fun, SetTag>>,
_mk_result_sender_fn<Fun>>;

template <class SetTag, class Sig>
template <class Sig, class SetTag>
struct _tfx_signal_fn
{
template <class, class...>
template <class, class, class...>
using __f = completion_signatures<Sig>;
};

template <class SetTag, class... Args>
struct _tfx_signal_fn<SetTag, SetTag(Args...)>
template <class... Args, class SetTag>
struct _tfx_signal_fn<SetTag(Args...), SetTag>
{
template <class Fun, class... StreamEnv>
template <class Fun, class SchEnv, class... Env>
using __f = __transform_completion_signatures_t<
__completion_signatures_of_t<__minvoke<_mk_result_sender_fn<Fun>, Args...>, StreamEnv...>,
__completion_signatures_of_t<__minvoke<_mk_result_sender_fn<Fun>, Args...>,
__join_env_t<SchEnv const &, stream_env_t<Env>>...>,
completion_signatures<set_error_t(cudaError_t)>>;
};

template <class Sig, class Fun, class SetTag, class... StreamEnv>
using _tfx_signal_t = __minvoke<_tfx_signal_fn<SetTag, Sig>, Fun, StreamEnv...>;
template <class Sig, class SetTag, class Fun, class SchEnv, class... Env>
using _tfx_signal_t = __minvoke<_tfx_signal_fn<Sig, SetTag>, Fun, SchEnv, Env...>;

template <class Sender, class Receiver, class Fun, class SetTag, class... Tuples>
struct _receiver : public stream_receiver_base
Expand Down Expand Up @@ -201,7 +201,7 @@ namespace nv::execution::_strm
template <class CvSender, class Receiver, class Fun, class SetTag>
struct _opstate : _opstate_base_t<CvSender, Receiver, Fun, SetTag>
{
using _env2_t = _sch_env_t<CvSender, Receiver, SetTag>;
using _env2_t = _sch_env_t<SetTag, CvSender, env_of_t<Receiver>>;
using _receiver_t = _let::_receiver_t<CvSender, Receiver, Fun, SetTag>;
using _result_tuples_t = _receiver_t::_result_tuples_t;
using _mk_opstate_fn_t = _mk_opstate_fn<CvSender, Receiver, Fun, SetTag>;
Expand All @@ -216,7 +216,7 @@ namespace nv::execution::_strm
static_cast<Receiver&&>(rcvr),
static_cast<Fun&&>(fun),
get_completion_scheduler<set_value_t>(get_env(sndr), get_env(rcvr)),
_mk_sch_env(sndr, rcvr, SetTag{}))
_mk_sch_env(SetTag(), sndr, STDEXEC::get_env(rcvr)))
{}

explicit _opstate(CvSender&& sndr, Receiver&& rcvr, Fun fun, _sch_t sch, _env2_t env2)
Expand All @@ -232,7 +232,7 @@ namespace nv::execution::_strm
STDEXEC_IMMOVABLE(_opstate);

[[nodiscard]]
auto make_env() const noexcept -> _let::_env2_t<CvSender, Receiver, SetTag>
auto make_env() const noexcept -> _let::_env2_t<SetTag, CvSender, env_of_t<Receiver>>
{
return _let::_mk_env2(env2_, *this);
}
Expand Down Expand Up @@ -262,11 +262,15 @@ namespace nv::execution::_strm
template <class Self, class Receiver>
using _receiver_t = _let::_receiver_t<__copy_cvref_t<Self, Sender>, Receiver, Fun, SetTag>;

template <class CvSender, class... StreamEnv>
template <class CvSender, class Env>
using _completions_t =
__mapply<__mtransform<__mbind_back_q<_let::_tfx_signal_t, Fun, SetTag, StreamEnv...>,
__mapply<__mtransform<__mbind_back_q<_let::_tfx_signal_t,
SetTag,
Fun,
_let::_sch_env_t<SetTag, CvSender, Env>,
Env>,
__mtry_q<__concat_completion_signatures_t>>,
__completion_signatures_of_t<CvSender, StreamEnv...>>;
__completion_signatures_of_t<CvSender, stream_env_t<Env>>>;

public:
explicit let_sender(Sender sndr, Fun fun, SetTag)
Expand All @@ -281,10 +285,10 @@ namespace nv::execution::_strm
return {&sndr_};
}

template <class Self, class... Env>
template <class Self, class Env>
static consteval auto get_completion_signatures()
{
return _completions_t<__copy_cvref_t<Self, Sender>, stream_env_t<Env>...>{};
return _completions_t<__copy_cvref_t<Self, Sender>, stream_env_t<Env>>{};
}

template <class Self, receiver Receiver>
Expand Down
19 changes: 16 additions & 3 deletions include/nvexec/stream/sync_wait.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,27 @@ namespace nv::execution::_strm
};
} // namespace _sync_wait

struct BECAUSE_THE_SENDER_IS_NOT_ABLE_TO_PROVIDE_A_COMPLETION_STREAM_SCHEDULER;

template <>
struct apply_sender_for<sync_wait_t>
{
template <stream_completing_sender<STDEXEC::env<>> Sender>
template <class Sender>
auto operator()(Sender&& sndr) const
{
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), STDEXEC::env{});
return _sync_wait::sync_wait_t{}(sched.ctx_, static_cast<Sender&&>(sndr));
if constexpr (!stream_completing_sender<Sender, STDEXEC::env<>>)
{
static_assert(__ok<STDEXEC::__mexception<
STDEXEC::_WHAT_(CANNOT_DISPATCH_THIS_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER),
STDEXEC::_WHY_(BECAUSE_THE_SENDER_IS_NOT_ABLE_TO_PROVIDE_A_COMPLETION_STREAM_SCHEDULER),
STDEXEC::_WHERE_(_IN_ALGORITHM_, STDEXEC::sync_wait_t),
STDEXEC::_WITH_PRETTY_SENDER_<Sender>>>);
}
else
{
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), STDEXEC::env{});
return _sync_wait::sync_wait_t{}(sched.ctx_, static_cast<Sender&&>(sndr));
}
}
};
} // namespace nv::execution::_strm
Expand Down
5 changes: 2 additions & 3 deletions include/nvexec/stream/then.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ STDEXEC_PRAGMA_IGNORE_EDG(cuda_compile)

namespace nv::execution::_strm
{

namespace _then
{
template <class... Args, class Fun>
Expand Down Expand Up @@ -157,15 +156,15 @@ namespace nv::execution::_strm
set_value_t,
Fun,
__copy_cvref_t<Self, Sender>,
Env...>,
stream_env_t<Env>...>,
completion_signatures<set_error_t(cudaError_t)>>;

template <class... Args>
using _set_value_t = __set_value_from_t<Fun, Args...>;

template <class Self, class... Env>
using _completions_t = __transform_completion_signatures_t<
__completion_signatures_of_t<__copy_cvref_t<Self, Sender>, Env...>,
__completion_signatures_of_t<__copy_cvref_t<Self, Sender>, stream_env_t<Env>...>,
__error_completions_t<Self, Env...>,
_set_value_t,
_set_error_t>;
Expand Down
12 changes: 6 additions & 6 deletions include/nvexec/stream_context.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ namespace nv::execution
struct stream_scheduler;

template <class StreamScheduler>
struct stream_scheduler_env
{ // NOLINT(bugprone-crtp-constructor-accessibility)
struct stream_scheduler_env // NOLINT(bugprone-crtp-constructor-accessibility)
{
STDEXEC_ATTRIBUTE(nodiscard)
static auto query(get_forward_progress_guarantee_t) noexcept -> forward_progress_guarantee
{
Expand Down Expand Up @@ -125,27 +125,27 @@ namespace nv::execution
}
};

attrs env_;
attrs attrs_;

public:
using completion_signatures =
STDEXEC::completion_signatures<set_value_t(), set_error_t(cudaError_t)>;

STDEXEC_ATTRIBUTE(host, device)
explicit sender(context ctx) noexcept
: env_{ctx}
: attrs_{ctx}
{}

template <class Receiver>
auto connect(Receiver rcvr) const & noexcept -> opstate<Receiver>
{
return opstate<Receiver>(static_cast<Receiver&&>(rcvr), env_.ctx_);
return opstate<Receiver>(static_cast<Receiver&&>(rcvr), attrs_.ctx_);
}

STDEXEC_ATTRIBUTE(nodiscard)
auto get_env() const noexcept -> attrs const &
{
return (env_);
return attrs_;
}
};

Expand Down
Loading
Loading