diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 29d11b73..638df76a 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -74,6 +74,9 @@ class Dispatcher { template RetT Call(KeyT key, ArgsT... args) const { auto kernel = this->GetKernel(key); tls_autocast_context.Autocast(key, args...); +#ifdef PROFILE_MODE + SetProfileContext(key.second, key.first); +#endif return kernel.Call(std::forward(args)...); } diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index 6ccad9e4..b4bdafd8 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -25,7 +25,6 @@ std::vector> Broadcast(const std::vector> ReduceAddCoalesced(const std::vector>> &grads, Device destination) { std::vector> outputs; - auto kernel = Dispatcher::Instance().GetKernel({destination.type(), "AccumulateGrad"}); std::vector>> to_destination_grads; for (int i = 0; i < grads[0].size(); ++i) { outputs.emplace_back(std::make_shared(grads[0][i]->Dims(), grads[0][i]->Dtype(), destination)); @@ -37,6 +36,9 @@ std::vector> ReduceAddCoalesced(const std::vector(grads[i][j]->To(destination))); } } + // NOTE(zbl): To ensure Profiler works correctly, there should not be any other kernel calls + // between GetKernel and kernel.Call, otherwise ProfileContext would be tainted + auto kernel = Dispatcher::Instance().GetKernel({destination.type(), "AccumulateGrad"}); for (int i = 0; i < grads.size(); ++i) { for (int j = 0; j < grads[i].size(); ++j) { kernel.Call(to_destination_grads[i][j], static_cast(1.0), outputs[j]);