Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions infini_train/include/dispatcher.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <map>
#include <string>
#include <type_traits>
#include <utility>

Expand All @@ -16,13 +17,14 @@ namespace infini_train {

// Holds a kernel entry point as an untyped `void *` (func_ptr_), together with the
// device type and name supplied at construction.
// NOTE(review): this text is a rendered diff. Several statements appear twice in a
// row (apparently the previous and the revised form of the same line), and the
// "Expand All" line stands in for unchanged lines that are not shown here.
class KernelFunction {
public:
// Single-argument form: stores only the function pointer, cast to void *.
template <typename FuncT> explicit KernelFunction(FuncT &&func) : func_ptr_(reinterpret_cast<void *>(func)) {}
// Three-argument form: also records the device type and the kernel name
// (`name` is taken by value and moved into name_).
template <typename FuncT>
KernelFunction(Device::DeviceType device, std::string name, FuncT &&func)
: device_(device), name_(std::move(name)), func_ptr_(reinterpret_cast<void *>(func)) {}

// TODO(dcj): support auto-deduction of return type and parameter types
// Calls the stored kernel as a `RetT (*)(ArgsT...)`. The caller names RetT
// explicitly; the cast from func_ptr_ to that pointer type is presumably in the
// lines hidden by the "Expand All" marker below (confirm against the real file).
template <typename RetT, class... ArgsT> RetT Call(ArgsT... args) const {
#ifdef PROFILE_MODE
// With PROFILE_MODE defined, the call is bracketed by Profiler StartRecord/EndRecord.
const auto &ctx = GetProfileContext();
Profiler::Instance().StartRecord(ctx.name, ctx.device);
Profiler::Instance().StartRecord(name_, device_);
#endif

using FuncT = RetT (*)(ArgsT...);
Expand All @@ -31,19 +33,21 @@ class KernelFunction {
// Separate void / non-void paths so EndRecord runs after the kernel call and
// before control leaves the function.
if constexpr (std::is_void_v<RetT>) {
fn(std::forward<ArgsT>(args)...);
#ifdef PROFILE_MODE
Profiler::Instance().EndRecord(ctx.name, ctx.device);
Profiler::Instance().EndRecord(name_, device_);
#endif
return;
} else {
// Hold the result so the profiler record can be closed before returning it.
RetT ret = fn(std::forward<ArgsT>(args)...);
#ifdef PROFILE_MODE
Profiler::Instance().EndRecord(ctx.name, ctx.device);
Profiler::Instance().EndRecord(name_, device_);
#endif
return ret;
}
}

private:
// Device type passed to the three-argument constructor.
Device::DeviceType device_;
// Kernel name passed to the three-argument constructor.
std::string name_;
// Type-erased kernel entry point; nullptr until a constructor sets it.
void *func_ptr_ = nullptr;
};

Expand All @@ -59,16 +63,13 @@ class Dispatcher {
// Looks up the kernel stored under `key` and returns a reference to it.
// The CHECK reports "Kernel not found" (with key.second and the integer value of
// key.first) when the key is absent from key_to_kernel_map_.
const KernelFunction &GetKernel(KeyT key) const {
CHECK(key_to_kernel_map_.contains(key))
<< "Kernel not found: " << key.second << " on device: " << static_cast<int>(key.first);
// NOTE(review): in this diff view the PROFILE_MODE block below appears to be the
// removed side of the change — confirm against the real file.
#ifdef PROFILE_MODE
SetProfileContext(key.second, key.first);
#endif
return key_to_kernel_map_.at(key);
}

// Adds `kernel` to key_to_kernel_map_ under `key`.
// The CHECK reports "Kernel already registered" when the key is already present.
template <typename FuncT> void Register(const KeyT &key, FuncT &&kernel) {
CHECK(!key_to_kernel_map_.contains(key))
<< "Kernel already registered: " << key.second << " on device: " << static_cast<int>(key.first);
// NOTE(review): the two insert lines below appear to be the old and new forms of
// the same statement in this diff view, not two successive inserts — confirm.
// The try_emplace form builds the mapped value in place from
// (key.first, key.second, kernel).
key_to_kernel_map_.emplace(key, kernel);
key_to_kernel_map_.try_emplace(key, key.first, key.second, std::forward<FuncT>(kernel));
}

template <typename RetT, class... ArgsT> RetT Call(KeyT key, ArgsT... args) const {
Expand Down
16 changes: 0 additions & 16 deletions infini_train/include/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,6 @@ class Event;

// Per-thread integer, initially 0. Presumably tracks nesting of profiled regions;
// its writers are not visible in this view — confirm.
inline thread_local int g_profiling_depth = 0;

// Name/device pair that identifies a kernel to the profiler.
struct ProfileContext {
// Kernel name.
std::string name;
// Device type the kernel is associated with.
Device::DeviceType device;
};

// Per-thread ProfileContext instance.
inline thread_local ProfileContext g_profile_context;

inline void SetProfileContext(const std::string &name, Device::DeviceType device) {
if (g_profiling_depth == 0) {
g_profile_context.name = name;
g_profile_context.device = device;
}
}

// Read-only access to this thread's profile context.
inline const ProfileContext &GetProfileContext() {
    const ProfileContext &current = g_profile_context;
    return current;
}

struct KernelProfileInfo {
int64_t host_total_us = 0;
int64_t device_total_us = 0;
Expand Down
Loading