Skip to content

Commit 5705ff7

Browse files
committed
Implement a few more ops in the JIT
1 parent 865ad00 commit 5705ff7

File tree

2 files changed

+180
-58
lines changed

2 files changed

+180
-58
lines changed

jit/src/instructions.rs

Lines changed: 75 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ use std::collections::HashMap;
88

99
use super::{JitCompileError, JitSig, JitType};
1010

11+
#[repr(u16)]
12+
enum CustomTrapCode {
13+
/// Raised when shifting by a negative number
14+
NegativeShiftCount = 0,
15+
}
16+
1117
#[derive(Clone)]
1218
struct Local {
1319
var: Variable,
@@ -340,64 +346,75 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
340346
// the rhs is popped off first
341347
let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
342348
let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
343-
match (a.ty, b.ty) {
344-
(JitType::Int, JitType::Int) => match op {
345-
BinaryOperator::Add => {
346-
let (out, carry) = self.builder.ins().iadd_ifcout(a.val, b.val);
347-
self.builder.ins().trapif(
348-
IntCC::Overflow,
349-
carry,
350-
TrapCode::IntegerOverflow,
351-
);
352-
self.stack.push(JitValue {
353-
val: out,
354-
ty: JitType::Int,
355-
});
356-
Ok(())
357-
}
358-
BinaryOperator::Subtract => {
359-
let out = self.compile_sub(a.val, b.val);
360-
self.stack.push(JitValue {
361-
val: out,
362-
ty: JitType::Int,
363-
});
364-
Ok(())
365-
}
366-
_ => Err(JitCompileError::NotSupported),
367-
},
368-
(JitType::Float, JitType::Float) => match op {
369-
BinaryOperator::Add => {
370-
self.stack.push(JitValue {
371-
val: self.builder.ins().fadd(a.val, b.val),
372-
ty: JitType::Float,
373-
});
374-
Ok(())
375-
}
376-
BinaryOperator::Subtract => {
377-
self.stack.push(JitValue {
378-
val: self.builder.ins().fsub(a.val, b.val),
379-
ty: JitType::Float,
380-
});
381-
Ok(())
382-
}
383-
BinaryOperator::Multiply => {
384-
self.stack.push(JitValue {
385-
val: self.builder.ins().fmul(a.val, b.val),
386-
ty: JitType::Float,
387-
});
388-
Ok(())
389-
}
390-
BinaryOperator::Divide => {
391-
self.stack.push(JitValue {
392-
val: self.builder.ins().fdiv(a.val, b.val),
393-
ty: JitType::Float,
394-
});
395-
Ok(())
396-
}
397-
_ => Err(JitCompileError::NotSupported),
398-
},
399-
_ => Err(JitCompileError::NotSupported),
400-
}
349+
let (val, ty) = match (op, a.ty, b.ty) {
350+
(BinaryOperator::Add, JitType::Int, JitType::Int) => {
351+
let (out, carry) = self.builder.ins().iadd_ifcout(a.val, b.val);
352+
self.builder.ins().trapif(
353+
IntCC::Overflow,
354+
carry,
355+
TrapCode::IntegerOverflow,
356+
);
357+
(out, JitType::Int)
358+
}
359+
(BinaryOperator::Subtract, JitType::Int, JitType::Int) => {
360+
(self.compile_sub(a.val, b.val), JitType::Int)
361+
}
362+
(BinaryOperator::FloorDivide, JitType::Int, JitType::Int) => {
363+
(self.builder.ins().sdiv(a.val, b.val), JitType::Int)
364+
}
365+
(BinaryOperator::Modulo, JitType::Int, JitType::Int) => {
366+
(self.builder.ins().srem(a.val, b.val), JitType::Int)
367+
}
368+
(
369+
BinaryOperator::Lshift | BinaryOperator::Rshift,
370+
JitType::Int,
371+
JitType::Int,
372+
) => {
373+
// Shifts throw an exception if we have a negative shift count
374+
// Remove all bits except the sign bit, and trap if its 1 (i.e. negative).
375+
let sign = self.builder.ins().ushr_imm(b.val, 63);
376+
self.builder.ins().trapnz(
377+
sign,
378+
TrapCode::User(CustomTrapCode::NegativeShiftCount as u16),
379+
);
380+
381+
let out = if *op == BinaryOperator::Lshift {
382+
self.builder.ins().ishl(a.val, b.val)
383+
} else {
384+
self.builder.ins().sshr(a.val, b.val)
385+
};
386+
387+
(out, JitType::Int)
388+
}
389+
(BinaryOperator::And, JitType::Int, JitType::Int) => {
390+
(self.builder.ins().band(a.val, b.val), JitType::Int)
391+
}
392+
(BinaryOperator::Or, JitType::Int, JitType::Int) => {
393+
(self.builder.ins().bor(a.val, b.val), JitType::Int)
394+
}
395+
(BinaryOperator::Xor, JitType::Int, JitType::Int) => {
396+
(self.builder.ins().bxor(a.val, b.val), JitType::Int)
397+
}
398+
399+
// Floats
400+
(BinaryOperator::Add, JitType::Float, JitType::Float) => {
401+
(self.builder.ins().fadd(a.val, b.val), JitType::Float)
402+
}
403+
(BinaryOperator::Subtract, JitType::Float, JitType::Float) => {
404+
(self.builder.ins().fsub(a.val, b.val), JitType::Float)
405+
}
406+
(BinaryOperator::Multiply, JitType::Float, JitType::Float) => {
407+
(self.builder.ins().fmul(a.val, b.val), JitType::Float)
408+
}
409+
(BinaryOperator::Divide, JitType::Float, JitType::Float) => {
410+
(self.builder.ins().fdiv(a.val, b.val), JitType::Float)
411+
}
412+
_ => return Err(JitCompileError::NotSupported),
413+
};
414+
415+
self.stack.push(JitValue { val, ty });
416+
417+
Ok(())
401418
}
402419
Instruction::SetupLoop { .. } | Instruction::PopBlock => {
403420
// TODO: block support

jit/tests/int_tests.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,111 @@ fn test_sub() {
2323
assert_eq!(sub(-3, -10), Ok(7));
2424
}
2525

26+
#[test]
27+
fn test_floor_div() {
28+
let floor_div = jit_function! { floor_div(a:i64, b:i64) -> i64 => r##"
29+
def floor_div(a: int, b: int):
30+
return a // b
31+
"## };
32+
33+
assert_eq!(floor_div(5, 10), Ok(0));
34+
assert_eq!(floor_div(5, 2), Ok(2));
35+
assert_eq!(floor_div(12, 10), Ok(1));
36+
assert_eq!(floor_div(7, 10), Ok(0));
37+
assert_eq!(floor_div(-3, -1), Ok(3));
38+
assert_eq!(floor_div(-3, 1), Ok(-3));
39+
}
40+
41+
#[test]
42+
fn test_mod() {
43+
let modulo = jit_function! { modulo(a:i64, b:i64) -> i64 => r##"
44+
def modulo(a: int, b: int):
45+
return a % b
46+
"## };
47+
48+
assert_eq!(modulo(5, 10), Ok(5));
49+
assert_eq!(modulo(5, 2), Ok(1));
50+
assert_eq!(modulo(12, 10), Ok(2));
51+
assert_eq!(modulo(7, 10), Ok(7));
52+
assert_eq!(modulo(-3, 1), Ok(0));
53+
assert_eq!(modulo(-5, 10), Ok(-5));
54+
}
55+
56+
#[test]
57+
fn test_lshift() {
58+
let lshift = jit_function! { lshift(a:i64, b:i64) -> i64 => r##"
59+
def lshift(a: int, b: int):
60+
return a << b
61+
"## };
62+
63+
assert_eq!(lshift(5, 10), Ok(5120));
64+
assert_eq!(lshift(5, 2), Ok(20));
65+
assert_eq!(lshift(12, 10), Ok(12288));
66+
assert_eq!(lshift(7, 10), Ok(7168));
67+
assert_eq!(lshift(-3, 1), Ok(-6));
68+
assert_eq!(lshift(-10, 2), Ok(-40));
69+
}
70+
71+
#[test]
72+
fn test_rshift() {
73+
let rshift = jit_function! { rshift(a:i64, b:i64) -> i64 => r##"
74+
def rshift(a: int, b: int):
75+
return a >> b
76+
"## };
77+
78+
assert_eq!(rshift(5120, 10), Ok(5));
79+
assert_eq!(rshift(20, 2), Ok(5));
80+
assert_eq!(rshift(12288, 10), Ok(12));
81+
assert_eq!(rshift(7168, 10), Ok(7));
82+
assert_eq!(rshift(-3, 1), Ok(-2));
83+
assert_eq!(rshift(-10, 2), Ok(-3));
84+
}
85+
86+
#[test]
87+
fn test_and() {
88+
let bitand = jit_function! { bitand(a:i64, b:i64) -> i64 => r##"
89+
def bitand(a: int, b: int):
90+
return a & b
91+
"## };
92+
93+
assert_eq!(bitand(5120, 10), Ok(0));
94+
assert_eq!(bitand(20, 16), Ok(16));
95+
assert_eq!(bitand(12488, 4249), Ok(4232));
96+
assert_eq!(bitand(7168, 2), Ok(0));
97+
assert_eq!(bitand(-3, 1), Ok(1));
98+
assert_eq!(bitand(-10, 2), Ok(2));
99+
}
100+
101+
#[test]
102+
fn test_or() {
103+
let bitor = jit_function! { bitor(a:i64, b:i64) -> i64 => r##"
104+
def bitor(a: int, b: int):
105+
return a | b
106+
"## };
107+
108+
assert_eq!(bitor(5120, 10), Ok(5130));
109+
assert_eq!(bitor(20, 16), Ok(20));
110+
assert_eq!(bitor(12488, 4249), Ok(12505));
111+
assert_eq!(bitor(7168, 2), Ok(7170));
112+
assert_eq!(bitor(-3, 1), Ok(-3));
113+
assert_eq!(bitor(-10, 2), Ok(-10));
114+
}
115+
116+
#[test]
117+
fn test_xor() {
118+
let bitxor = jit_function! { bitxor(a:i64, b:i64) -> i64 => r##"
119+
def bitxor(a: int, b: int):
120+
return a ^ b
121+
"## };
122+
123+
assert_eq!(bitxor(5120, 10), Ok(5130));
124+
assert_eq!(bitxor(20, 16), Ok(4));
125+
assert_eq!(bitxor(12488, 4249), Ok(8273));
126+
assert_eq!(bitxor(7168, 2), Ok(7170));
127+
assert_eq!(bitxor(-3, 1), Ok(-4));
128+
assert_eq!(bitxor(-10, 2), Ok(-12));
129+
}
130+
26131
#[test]
27132
fn test_eq() {
28133
let eq = jit_function! { eq(a:i64, b:i64) -> i64 => r##"

0 commit comments

Comments
 (0)