Skip to content

Commit 3932538

Browse files
committed
Add LLVMExecutableCode class
1 parent 98fa65b commit 3932538

File tree

11 files changed

+318
-2
lines changed

11 files changed

+318
-2
lines changed

src/dev/engine/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ target_sources(scratchcpp
33
executioncontext.cpp
44
executioncontext_p.cpp
55
executioncontext_p.h
6+
internal/llvmexecutablecode.cpp
7+
internal/llvmexecutablecode.h
68
internal/llvmexecutioncontext.cpp
79
internal/llvmexecutioncontext.h
810
)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include <scratchcpp/dev/executioncontext.h>
4+
#include <llvm/Support/Error.h>
5+
#include <iostream>
6+
7+
#include "llvmexecutablecode.h"
8+
#include "llvmexecutioncontext.h"
9+
10+
using namespace libscratchcpp;
11+
12+
LLVMExecutableCode::LLVMExecutableCode(std::unique_ptr<llvm::Module> module) :
13+
m_ctx(std::make_unique<llvm::LLVMContext>()),
14+
m_jit(llvm::orc::LLJITBuilder().create())
15+
{
16+
if (!m_jit) {
17+
llvm::errs() << "error: failed to create JIT: " << toString(m_jit.takeError()) << "\n";
18+
return;
19+
}
20+
21+
if (!module)
22+
return;
23+
24+
std::string name = module->getName().str();
25+
auto err = m_jit->get()->addIRModule(llvm::orc::ThreadSafeModule(std::move(module), std::move(m_ctx)));
26+
27+
if (err) {
28+
llvm::errs() << "error: failed to add module '" << name << "' to JIT: " << toString(std::move(err)) << "\n";
29+
return;
30+
}
31+
32+
// Lookup functions
33+
size_t i = 0;
34+
35+
while (true) {
36+
auto func = m_jit->get()->lookup("f" + std::to_string(i));
37+
38+
if (func)
39+
m_functions.push_back((FunctionType)(func->getValue()));
40+
else {
41+
// Ignore error
42+
llvm::consumeError(func.takeError());
43+
break;
44+
}
45+
46+
i++;
47+
}
48+
}
49+
50+
void LLVMExecutableCode::run(ExecutionContext *context)
51+
{
52+
LLVMExecutionContext *ctx = getContext(context);
53+
54+
if (ctx->pos() < m_functions.size())
55+
ctx->setPos(m_functions[ctx->pos()](context->target()));
56+
}
57+
58+
void LLVMExecutableCode::kill(ExecutionContext *context)
59+
{
60+
getContext(context)->setPos(m_functions.size());
61+
}
62+
63+
void LLVMExecutableCode::reset(ExecutionContext *context)
64+
{
65+
getContext(context)->setPos(0);
66+
}
67+
68+
bool LLVMExecutableCode::isFinished(ExecutionContext *context) const
69+
{
70+
return getContext(context)->pos() >= m_functions.size();
71+
}
72+
73+
void LLVMExecutableCode::promise()
74+
{
75+
}
76+
77+
void LLVMExecutableCode::resolvePromise()
78+
{
79+
}
80+
81+
std::shared_ptr<ExecutionContext> LLVMExecutableCode::createExecutionContext(Target *target) const
82+
{
83+
return std::make_shared<LLVMExecutionContext>(target);
84+
}
85+
86+
LLVMExecutionContext *LLVMExecutableCode::getContext(ExecutionContext *context)
87+
{
88+
assert(dynamic_cast<LLVMExecutionContext *>(context));
89+
return static_cast<LLVMExecutionContext *>(context);
90+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#pragma once
4+
5+
#include <scratchcpp/dev/executablecode.h>
6+
#include <llvm/IR/LLVMContext.h>
7+
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
8+
9+
namespace libscratchcpp
10+
{
11+
12+
class Target;
13+
class LLVMExecutionContext;
14+
15+
class LLVMExecutableCode : public ExecutableCode
16+
{
17+
public:
18+
LLVMExecutableCode(std::unique_ptr<llvm::Module> module);
19+
20+
void run(ExecutionContext *context) override;
21+
void kill(libscratchcpp::ExecutionContext *context) override;
22+
void reset(ExecutionContext *context) override;
23+
24+
bool isFinished(ExecutionContext *context) const override;
25+
26+
void promise() override;
27+
void resolvePromise() override;
28+
29+
std::shared_ptr<ExecutionContext> createExecutionContext(Target *target) const override;
30+
31+
private:
32+
using FunctionType = size_t (*)(Target *);
33+
34+
static LLVMExecutionContext *getContext(ExecutionContext *context);
35+
36+
std::unique_ptr<llvm::LLVMContext> m_ctx;
37+
llvm::Expected<std::unique_ptr<llvm::orc::LLJIT>> m_jit;
38+
39+
std::vector<FunctionType> m_functions;
40+
};
41+
42+
} // namespace libscratchcpp

test/dev/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(blocks)
22
add_subdirectory(executioncontext)
3+
add_subdirectory(llvm)

test/dev/llvm/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
1+
add_library(
2+
llvm_test_functions SHARED
3+
testfunctions.cpp
4+
testfunctions.h
5+
testmock.h
6+
)
7+
8+
target_link_libraries(
9+
llvm_test_functions
10+
GTest::gtest_main
11+
GTest::gmock_main
12+
scratchcpp
13+
)
14+
115
add_executable(
216
llvm_test
17+
main.cpp
318
llvmexecutioncontext_test.cpp
19+
llvmexecutablecode_test.cpp
420
)
521

622
target_link_libraries(
@@ -9,6 +25,8 @@ target_link_libraries(
925
GTest::gmock_main
1026
scratchcpp
1127
scratchcpp_mocks
28+
llvm_test_functions
29+
LLVM
1230
)
1331

1432
gtest_discover_tests(llvm_test)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#include <scratchcpp/target.h>
2+
#include <dev/engine/internal/llvmexecutablecode.h>
3+
#include <dev/engine/internal/llvmexecutioncontext.h>
4+
#include <llvm/Support/TargetSelect.h>
5+
#include <llvm/IR/IRBuilder.h>
6+
#include <gmock/gmock.h>
7+
8+
#include "testmock.h"
9+
#include "testfunctions.h"
10+
11+
using namespace libscratchcpp;
12+
13+
class LLVMExecutableCodeTest : public testing::Test
14+
{
15+
public:
16+
void SetUp() override
17+
{
18+
m_module = std::make_unique<llvm::Module>("test", m_ctx);
19+
m_builder = std::make_unique<llvm::IRBuilder<>>(m_ctx);
20+
test_function(nullptr, nullptr); // force dependency
21+
22+
llvm::InitializeNativeTarget();
23+
llvm::InitializeNativeTargetAsmPrinter();
24+
llvm::InitializeNativeTargetAsmParser();
25+
}
26+
27+
llvm::Function *beginFunction(size_t index)
28+
{
29+
// size_t f#(Target *)
30+
llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder->getInt64Ty(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false);
31+
llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f" + std::to_string(index), m_module.get());
32+
33+
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func);
34+
m_builder->SetInsertPoint(entry);
35+
return func;
36+
}
37+
38+
void endFunction(size_t index)
39+
{
40+
// Return next function index
41+
m_builder->CreateRet(m_builder->getInt64(index + 1));
42+
}
43+
44+
void addTestFunction(llvm::Function *mainFunc)
45+
{
46+
auto ptrType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0);
47+
auto func = m_module->getOrInsertFunction("test_function", llvm::FunctionType::get(m_builder->getVoidTy(), { ptrType, ptrType }, false));
48+
49+
llvm::Constant *mockInt = llvm::ConstantInt::get(llvm::Type::getInt64Ty(m_ctx), (uintptr_t)&m_mock, false);
50+
llvm::Constant *mockPtr = llvm::ConstantExpr::getIntToPtr(mockInt, ptrType);
51+
52+
m_builder->CreateCall(func, { mockPtr, mainFunc->getArg(0) });
53+
}
54+
55+
llvm::LLVMContext m_ctx;
56+
std::unique_ptr<llvm::Module> m_module;
57+
std::unique_ptr<llvm::IRBuilder<>> m_builder;
58+
Target m_target;
59+
TestMock m_mock;
60+
};
61+
62+
TEST_F(LLVMExecutableCodeTest, NoFunctions)
63+
{
64+
LLVMExecutionContext ctx(&m_target);
65+
LLVMExecutableCode code(std::move(m_module));
66+
ASSERT_TRUE(code.isFinished(&ctx));
67+
68+
code.run(&ctx);
69+
ASSERT_TRUE(code.isFinished(&ctx));
70+
71+
code.kill(&ctx);
72+
ASSERT_TRUE(code.isFinished(&ctx));
73+
74+
code.reset(&ctx);
75+
ASSERT_TRUE(code.isFinished(&ctx));
76+
}
77+
78+
TEST_F(LLVMExecutableCodeTest, SingleFunction)
79+
{
80+
auto f = beginFunction(0);
81+
addTestFunction(f);
82+
endFunction(0);
83+
84+
LLVMExecutionContext ctx(&m_target);
85+
LLVMExecutableCode code(std::move(m_module));
86+
ASSERT_FALSE(code.isFinished(&ctx));
87+
88+
EXPECT_CALL(m_mock, f(&m_target));
89+
code.run(&ctx);
90+
ASSERT_TRUE(code.isFinished(&ctx));
91+
92+
code.kill(&ctx);
93+
ASSERT_TRUE(code.isFinished(&ctx));
94+
95+
code.reset(&ctx);
96+
ASSERT_FALSE(code.isFinished(&ctx));
97+
98+
code.kill(&ctx);
99+
ASSERT_TRUE(code.isFinished(&ctx));
100+
101+
EXPECT_CALL(m_mock, f).Times(0);
102+
code.run(&ctx);
103+
ASSERT_TRUE(code.isFinished(&ctx));
104+
}
105+
106+
TEST_F(LLVMExecutableCodeTest, MultipleFunctions)
107+
{
108+
static const int count = 5;
109+
110+
for (int i = 0; i < count; i++) {
111+
auto f = beginFunction(i);
112+
addTestFunction(f);
113+
endFunction(i);
114+
}
115+
116+
LLVMExecutionContext ctx(&m_target);
117+
LLVMExecutableCode code(std::move(m_module));
118+
ASSERT_FALSE(code.isFinished(&ctx));
119+
120+
for (int i = 0; i < count; i++) {
121+
ASSERT_FALSE(code.isFinished(&ctx));
122+
EXPECT_CALL(m_mock, f(&m_target));
123+
code.run(&ctx);
124+
}
125+
126+
ASSERT_TRUE(code.isFinished(&ctx));
127+
}

test/dev/llvm/llvmexecutioncontext_test.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <scratchcpp/target.h>
22
#include <dev/engine/internal/llvmexecutioncontext.h>
3-
4-
#include "../../common.h"
3+
#include <gtest/gtest.h>
54

65
using namespace libscratchcpp;
76

test/dev/llvm/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#include "../../common.h"

test/dev/llvm/testfunctions.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#include "testmock.h"
2+
3+
using namespace libscratchcpp;
4+
5+
extern "C" void test_function(TestMock *mock, Target *target)
6+
{
7+
if (mock)
8+
mock->f(target);
9+
}

test/dev/llvm/testfunctions.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
namespace libscratchcpp
4+
{
5+
6+
class TestMock;
7+
class Target;
8+
9+
extern "C" void test_function(TestMock *mock, Target *target);
10+
11+
} // namespace libscratchcpp

0 commit comments

Comments
 (0)