/* How Does JIT Compilation Work?
 *
 * This blog post is also a valid C89 program which you can compile and run.
 * Compile it with: cc -x c -o jit post.txt
 * Then run it with: ./jit
 *
 * Warning: it will only work on AMD64 systems; see the comments around
 * translate_one_jit() for an explanation of why.
 *
 * JIT compilation, or Just-In-Time compilation, is a technique for more
 * efficiently interpreting programs by deferring compiling them into machine
 * code until they are just about to be executed, or even by deferring
 * compiling individual parts of the program until they are ready to be
 * executed. It sounds clever, but how does it actually work, and why do that
 * instead of running an optimizing compiler ahead of time? To give a full
 * answer, a motivating example might be in order.
 *
 * Let's envision a stack-based language somewhat like dc(1), whose programs
 * are a sequence of these instructions:
 *   + - * /  pop two numbers, operate, push the result (the usual)
 *   lit      push the next thing from the instruction stream
 *   if       pop a number; if it's not zero, go to the address which is
 *            next in the instruction stream
 *   swap     swap the two elements atop the stack
 *   dup      duplicate the top element of the stack
 *   done     stop running and return the top of the stack
 * and let's assume that we've already handled all the parsing and other
 * frontend unpleasantness, so we have a program in our language represented
 * as follows.
 */

#define _GNU_SOURCE /* MAP_ANON */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/mman.h>
#include <time.h>

typedef enum {
    OPC_ADD = 0,
    OPC_SUB = 1,
    OPC_MUL = 2,
    OPC_DIV = 3,
    OPC_LIT = 4,
    OPC_IF = 5,
    OPC_SWAP = 6,
    OPC_DUP = 7,
    OPC_DONE = 8
} Opcode;

/* A program is therefore an array of integers, which are opcodes and their
 * optional arguments. Here's an example program, which expects two numbers
 * on the stack when it starts and which adds them by repeatedly decrementing
 * the first & incrementing the second - extremely inefficient, but suitable
 * as an example. */
const int example[] = {
    OPC_LIT, 1,
    OPC_SUB,
    OPC_SWAP,
    OPC_LIT, 1,
    OPC_ADD,
    OPC_SWAP,
    OPC_DUP,
    OPC_IF, 0,
    OPC_SWAP,
    OPC_ADD,
    OPC_DONE,
};

/* Interpretation
 *
 * The most straightforward way to execute these programs is to simply
 * interpret them. We can do that by writing a function that traverses each
 * opcode in turn, simulating what the stack should look like, and does the
 * requested operations. We'll need a way to represent our stack of
 * integers... */
typedef struct {
    int vals[256];
    int top;
} Stack;

/* Implementations of these are at the bottom - how they work is not
 * important right now. */
void stack_init(Stack *s);
void stack_pushargs(Stack *s, const int *args, int narg);
void stack_push(Stack *s, int v);
int stack_pop(Stack *s);
void stack_dump(Stack *s);
int arith(int op, Stack *s);

/* And now here's our interpretation function itself. Arithmetic operations,
 * which share a lot of code, are broken out into a separate function. It
 * takes an array of arguments, which are put on the stack at the start of
 * program execution.
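 *
 * Before the code, a quick hand-trace (mine, purely for illustration) of the
 * example program above on the arguments (2, 3), with the stack written
 * top-first:
 *
 *   start          2 3
 *   lit 1, sub     1 3      (decrement the first number)
 *   swap           3 1
 *   lit 1, add     4 1      (increment the second number)
 *   swap           1 4
 *   dup            1 1 4
 *   if 0           1 4      (1 is nonzero, so jump back to the start)
 *
 * The second trip around leaves 0 5 on the stack, the if falls through, and
 * the final swap, add, done sequence returns 0 + 5 = 5.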
 */
int interpret(const int *program, const int *args, int nargs)
{
    Stack s;
    int pc = 0;

    stack_init(&s);
    stack_pushargs(&s, args, nargs);

    while (1) {
        switch (program[pc]) {
        case OPC_ADD:
        case OPC_SUB:
        case OPC_MUL:
        case OPC_DIV:
            stack_push(&s, arith(program[pc], &s));
            pc++;
            break;
        case OPC_LIT:
            stack_push(&s, program[pc + 1]);
            pc += 2;
            break;
        case OPC_IF:
            if (stack_pop(&s) != 0)
                pc = program[pc + 1];
            else
                pc += 2;
            break;
        case OPC_SWAP: {
            int a = stack_pop(&s);
            int b = stack_pop(&s);
            stack_push(&s, a);
            stack_push(&s, b);
            pc++;
            break;
        }
        case OPC_DUP: {
            int a = stack_pop(&s);
            stack_push(&s, a);
            stack_push(&s, a);
            pc++;
            break;
        }
        case OPC_DONE:
            return stack_pop(&s);
        }
    }
}

int arith(int op, Stack *s)
{
    int a = stack_pop(s);
    int b = stack_pop(s);
    switch (op) {
    case OPC_ADD: return b + a;
    case OPC_SUB: return b - a;
    case OPC_MUL: return b * a;
    case OPC_DIV: return b / a;
    default: assert(0 && "unknown arith");
    }
}

/* Alright! Now that that's done, what actually *happens* when we run
 * interpret()? Let's focus in on the while loop, since that's where the real
 * work happens. On x86-64, that loop turns into machine code looking
 * basically like this:
 *   cmp $0x8, %rcx
 *   ja ouch
 *   mov (%rdi, %rcx, 8), %rax
 *   jmp *%rax
 * i.e., the body of the loop turns into a computed jump to an address looked
 * up in a table, which is how switch statements usually compile. The first
 * two instructions are necessary to ensure that we don't index off the end
 * of that jump table, then we load the address out of the table and jump to
 * it.
 *
 * However... this is actually just a matter of luck and me having chosen a
 * design that compiles to compact code. A real instruction encoding would
 * almost certainly not have the opcodes fit into a neat linear jump table
 * like that, and we'd instead end up with code looking more like this:
 *   cmp $0x42, %rcx
 *   je do_add
 *   cmp $0x11, %rcx
 *   je do_sub
 *   ...
 * and so on. This could rapidly get pretty slow, and even if we do have a
 * compact set of instructions like this example does, there's an extra
 * conditional branch on every pass through the loop that it'd be great to
 * avoid if we can - but how?
 *
 * Enter: Compilation!
 *
 * Here's the central idea: when we're interpreting bytecode, we have to
 * repeatedly do the "translation" from bytecode into local machine code,
 * over and over again. What if we could instead do that translation once,
 * and then just run the translated code over and over? That idea is usually
 * called "compilation". There are a lot of different ways to do it, of
 * different complexities, and we'll go through a couple of them here.
 *
 * One way of doing the translation step once is to produce what's called
 * "threaded code". To do this, we translate the existing bytecode to a
 * sequence of machine code addresses, using pre-written blocks of code that
 * implement each of the bytecode instructions. Our "interpreter" then is a
 * function that repeatedly does jumps through this sequence of addresses
 * until the program ends. This is called "threaded" because control returns
 * to the interpreter function after each bytecode instruction is
 * interpreted, so the control flow is conceptually still "threaded" through
 * the interpreter.
 *
 * Here's how that looks. Since we need a separate machine code address to
 * use for each bytecode instruction, we'll need a bunch of little helper
 * functions. These have implementations below as an appendix.
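 *
 * (A quick aside before the code: there's also a well-known halfway house
 * between the switch loop and threaded code, namely GCC's "labels as values"
 * extension, which lets the dispatch loop jump through a table of label
 * addresses directly instead of going through a switch. It isn't C89, so
 * here it is only as a sketch inside this comment:
 *
 *   static void *dispatch[] = { &&do_add, &&do_sub, ... };
 *   #define NEXT() goto *dispatch[program[pc]]
 *
 *   do_add: ... do some adding ...; pc++; NEXT();
 *
 * Each opcode gets its own indirect branch this way, which tends to predict
 * better than the single shared branch of a switch, but every jump is still
 * indirect - which, as we're about to see, is exactly the problem.)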
 */
typedef int (*threaded_op)(const intptr_t *prog, int pc, Stack *s);

int threaded_op_add(const intptr_t *prog, int pc, Stack *s);
int threaded_op_sub(const intptr_t *prog, int pc, Stack *s);
int threaded_op_mul(const intptr_t *prog, int pc, Stack *s);
int threaded_op_div(const intptr_t *prog, int pc, Stack *s);
int threaded_op_swap(const intptr_t *prog, int pc, Stack *s);
int threaded_op_dup(const intptr_t *prog, int pc, Stack *s);
int threaded_op_done(const intptr_t *prog, int pc, Stack *s);
int threaded_op_lit(const intptr_t *prog, int pc, Stack *s);
int threaded_op_if(const intptr_t *prog, int pc, Stack *s);

/* This translates a single bytecode instruction into a single machine code
 * function address, and returns the number of arguments the bytecode
 * instruction needs to take *from the bytecode* in the nargs out
 * parameter. */
threaded_op translate_one_threaded(int op, int *nargs)
{
    *nargs = 0;
    switch (op) {
    case OPC_ADD: return threaded_op_add;
    case OPC_SUB: return threaded_op_sub;
    case OPC_MUL: return threaded_op_mul;
    case OPC_DIV: return threaded_op_div;
    case OPC_SWAP: return threaded_op_swap;
    case OPC_DUP: return threaded_op_dup;
    case OPC_DONE: return threaded_op_done;
    default: break;
    }
    *nargs = 1;
    switch (op) {
    case OPC_LIT: return threaded_op_lit;
    case OPC_IF: return threaded_op_if;
    default: assert(0 && "puzzling opcode");
    }
}

/* Translates program, which is bytecode, into progbuf, which is a sequence
 * of function pointers, or addresses of machine code implementations of
 * those bytecode functions. This function simply assumes that progbuf is big
 * enough to hold the translated program.
 *
 * For example, if we were given this bytecode:
 *   OPC_LIT, 2, OPC_LIT, 3, OPC_ADD, OPC_DONE
 * the resulting progbuf would look like:
 *   &threaded_op_lit, 2,
 *   &threaded_op_lit, 3,
 *   &threaded_op_add,
 *   &threaded_op_done
 * which is why progbuf is an intptr_t - so that its elements will be wide
 * enough to hold function pointers, but can also be treated as integers when
 * we need to. */
void translate_threaded(const int *program, intptr_t *progbuf)
{
    int i = 0, j;
    threaded_op opfunc;
    do {
        int nargs;
        opfunc = translate_one_threaded(program[i], &nargs);
        progbuf[i] = (intptr_t)opfunc;
        for (j = 0; j < nargs; j++)
            progbuf[i + j + 1] = program[i + j + 1];
        i += nargs + 1;
    } while (opfunc != threaded_op_done);
}

int threaded(const int *program, const int *args, int nargs)
{
    intptr_t progbuf[256];
    Stack s;
    int pc = 0;

    translate_threaded(program, progbuf);
    stack_init(&s);
    stack_pushargs(&s, args, nargs);

    /* All we have to do now is repeatedly call function pointers out of
     * progbuf until the pc goes negative, which indicates that the
     * program is done. */
    while (pc >= 0) {
        threaded_op op = (threaded_op)progbuf[pc];
        pc = op(progbuf, pc, &s);
    }
    return stack_pop(&s);
}

/* When we start executing the threaded code, our control flow then looks
 * like this:
 *
 *   threaded():          jmp progbuf[pc]
 *   threaded_op_add():   ... do some adding ...
 *   threaded_op_add():   ret
 *   threaded():          jmp progbuf[pc]
 *   threaded_op_lit():   ... do some pushing ...
 *   ... and so on.
 *
 * This is cool, and already shows a pretty nice gain in code compactness
 * (and speed) compared to the interpreted solution, but there is a problem:
 * pipelining. In fact, this approach can actually be slower than normal
 * interpretation, depending on the instruction format. Why is that?
 *
 * Since each of these jumps is "indirect" (to an address stored in memory)
 * and dependent on the value of pc, it's difficult for the processor to
 * execute it efficiently.
 * The processor wants to speculatively execute instructions ahead of where
 * the current machine program counter is, but when it reaches the
 * `jmp progbuf[pc]` it has to stop, because it has no way to predict what
 * the value of progbuf[pc] will be[fn1].
 *
 * How can we deal with that problem, and produce more efficient code? The
 * solution is to generate not simply a table of function addresses, but
 * actually a chunk of native *code*. In fact, by doing that, we can replace
 * all the previous indirect jumps with direct jumps, which produces what is
 * called "direct threaded" code (you'll also see this particular flavor,
 * where each op becomes a call instruction, referred to as "subroutine
 * threaded" code).
 *
 * How does that look? Well, let's use the example program from before again:
 *   OPC_LIT, 2, OPC_LIT, 3, OPC_ADD, OPC_DONE
 * and remember that the indirect threaded version looked like this, as an
 * array of intptr_t:
 *   &threaded_op_lit, 2,
 *   &threaded_op_lit, 3,
 *   &threaded_op_add,
 *   &threaded_op_done,
 * What we would like to generate instead is something like *this*, as an
 * array of *machine instructions*:
 *   call threaded_op_lit(2)
 *   call threaded_op_lit(3)
 *   call threaded_op_add
 *   call threaded_op_done
 * Since those calls are all "direct" calls to definite addresses, the
 * processor will be able to speculatively execute past them effectively and
 * we will experience pipelined glory.
 *
 * How do we do that? Well, we generate the machine code when we're
 * compiling, obviously, and then we run the resulting machine code. Call
 * instructions are easy enough to emit, although we're going to need a bit
 * of assembly magic to pass function parameters in. We can start merrily
 * implementing that... until we try to implement OPC_IF, which is quite
 * difficult to phrase this way. The problem is that in this model,
 * threaded_op_if is a function which can return to one of two different
 * places. It's not too hard to achieve this technically, but we have to
 * write threaded_op_if in assembly - the C compiler doesn't support doing
 * this at all. If we're writing the op functions in assembly already,
 * though, why don't we just emit the op function bodies directly and skip
 * all of the calls?
 *
 * This is more or less how modern bytecode machines actually work: they
 * translate the bytecode directly into host machine code that they can
 * execute, using the host's program counter for control flow and so on. For
 * that same example program from above, which was:
 *   OPC_LIT, 2, OPC_LIT, 3, OPC_ADD, OPC_DONE
 * that would look like these machine instructions, modulo some x86
 * nastiness:
 *   push $2
 *   push $3
 *   pop %r8
 *   pop %r9
 *   add %r8, %r9
 *   push %r9
 *   ret
 * Sounds fun. How do we actually do that? Given a program written in the
 * bytecode we defined above, we need to produce a buffer full of machine
 * instructions which we can then execute.
 */

/* First, we need a set of helper functions that we use to emit stuff into
 * the JIT buffer - either individual bytes, or reusable instructions. We'll
 * use both the machine program counter and its stack.
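 *
 * (To make the raw bytes below less magical, a quick worked example of the
 * AMD64 encoding rules - my annotation, not load-bearing: "push %r9" encodes
 * as the two bytes 41 51. 0x51 is "push" with the register number folded
 * into the opcode (0x50 + 1), and 0x41 is a REX.B prefix supplying the extra
 * register bit that selects r9 (register 9) rather than %rcx (register 1).
 * The 0x48 and 0x4d bytes below are the same kind of REX prefix with other
 * bits set: 0x48 is REX.W, selecting 64-bit operands, and 0x4d is
 * REX.W+R+B. Intel's Software Developer's Manual is the authoritative
 * reference for all of these.)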
 */
uint8_t *emit1(uint8_t *pb, uint8_t b0)
{
    *(pb++) = b0;
    return pb;
}

uint8_t *emit2(uint8_t *pb, uint8_t b0, uint8_t b1)
{
    *(pb++) = b0;
    *(pb++) = b1;
    return pb;
}

uint8_t *emit3(uint8_t *pb, uint8_t b0, uint8_t b1, uint8_t b2)
{
    *(pb++) = b0;
    *(pb++) = b1;
    *(pb++) = b2;
    return pb;
}

uint8_t *emit4(uint8_t *pb, uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3)
{
    pb = emit2(pb, b0, b1);
    pb = emit2(pb, b2, b3);
    return pb;
}

uint8_t *emitw4(uint8_t *pb, intptr_t v)
{
    pb = emit2(pb, v & 0xff, (v >> 8) & 0xff);
    pb = emit2(pb, (v >> 16) & 0xff, (v >> 24) & 0xff);
    return pb;
}

uint8_t *emitw8(uint8_t *pb, intptr_t v)
{
    pb = emit2(pb, v & 0xff, (v >> 8) & 0xff);
    pb = emit2(pb, (v >> 16) & 0xff, (v >> 24) & 0xff);
    pb = emit2(pb, (v >> 32) & 0xff, (v >> 40) & 0xff);
    pb = emit2(pb, (v >> 48) & 0xff, (v >> 56) & 0xff);
    return pb;
}

uint8_t *emit_pop2(uint8_t *pb)
{
    /* pop %r8; pop %r9 */
    pb = emit2(pb, 0x41, 0x58);
    pb = emit2(pb, 0x41, 0x59);
    return pb;
}

uint8_t *emit_push1(uint8_t *pb)
{
    /* push %r9 */
    return emit2(pb, 0x41, 0x51);
}

/* Now, here's the guts of the JIT. This translates a single instruction from
 * program[*pc] to some bytes starting at progbuf[*bc], and also fills in a
 * slot in the jump table. The jump table is used to emit target addresses
 * for OPC_IF; since not all our instructions are the same length, either in
 * the program or in the generated progbuf, we need some way to map OPC_IF's
 * operand (an index into the program) to the address of the corresponding
 * generated machine code, from which we compute the relative displacement
 * for the machine jump instruction.
 *
 * Important note: as implemented here, this technique doesn't support OPC_IF
 * with a forward jump target, because the jtab entry for the target won't
 * yet have been filled in. We could deal with that by JITing in two passes
 * instead of one, with the second pass responsible for filling in jump
 * targets that weren't resolved on the first pass, but it significantly
 * complicates things for no pedagogical benefit. In a real JIT, we'd need to
 * do that, of course.
 *
 * A couple of conventions for reading the generated code: it uses the host
 * stack for data, and uses %r8 and %r9 as scratch registers to implement the
 * various stack operations, so when you see pop2 and push1, those operate on
 * the host stack, %r8, and %r9.
 *
 * This function returns whether translation is done (that is, whether we
 * encountered OPC_DONE). */
int translate_one_jit(const int *program, int *pc, uint8_t *progbuf, int *bc,
                      uint8_t **jtab)
{
    int np = 1;
    uint8_t *pb = progbuf + *bc;
    int done = 0;

    jtab[*pc] = progbuf + *bc;
    switch (program[*pc]) {
    case OPC_ADD:
        pb = emit_pop2(pb);
        pb = emit3(pb, 0x4d, 0x01, 0xc1);       /* add %r8, %r9 */
        pb = emit_push1(pb);
        break;
    case OPC_SUB:
        pb = emit_pop2(pb);
        pb = emit3(pb, 0x4d, 0x29, 0xc1);       /* sub %r8, %r9 */
        pb = emit_push1(pb);
        break;
    case OPC_MUL:
        pb = emit_pop2(pb);
        pb = emit4(pb, 0x4d, 0x0f, 0xaf, 0xc8); /* imul %r8, %r9 */
        pb = emit_push1(pb);
        break;
    case OPC_DIV:
        /* This is grotesque because of x86, see [fn2].
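         * What makes it grotesque: IDIV divides the 128-bit value in
         * %rdx:%rax by its operand, leaving the quotient in %rax and the
         * remainder in %rdx. So the dividend has to be popped into %rax and
         * sign-extended into %rdx (the cqo below) before we can divide.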
         */
        pb = emit2(pb, 0x41, 0x58);             /* pop %r8 */
        pb = emit1(pb, 0x58);                   /* pop %rax */
        pb = emit2(pb, 0x48, 0x99);             /* cqo: sign-extend %rax
                                                 * into %rdx */
        pb = emit3(pb, 0x49, 0xf7, 0xf8);       /* idiv %r8 */
        pb = emit1(pb, 0x50);                   /* push %rax */
        break;
    case OPC_SWAP:
        pb = emit_pop2(pb);
        pb = emit2(pb, 0x41, 0x50);             /* push %r8 */
        pb = emit2(pb, 0x41, 0x51);             /* push %r9 */
        break;
    case OPC_DUP:
        pb = emit2(pb, 0x41, 0x59);             /* pop %r9 */
        pb = emit_push1(pb);                    /* push %r9 */
        pb = emit_push1(pb);                    /* push %r9 */
        break;
    case OPC_DONE:
        pb = emit1(pb, 0x58);                   /* pop %rax */
        pb = emit1(pb, 0xc3);                   /* ret */
        done = 1;
        break;
    case OPC_LIT:
        np = 2;
        /* x86 has no 'push immediate' instruction for large
         * immediates, so we need a scratch register. */
        pb = emit2(pb, 0x48, 0xb8);             /* movabs $imm, %rax */
        pb = emitw8(pb, program[*pc + 1]);
        pb = emit1(pb, 0x50);                   /* push %rax */
        break;
    case OPC_IF:
        np = 2;
        pb = emit2(pb, 0x41, 0x59);             /* pop %r9 */
        pb = emit3(pb, 0x4d, 0x85, 0xc9);       /* test %r9, %r9 */
        pb = emit2(pb, 0x0f, 0x85);             /* jne */
        /* the Jcc instructions take a relative address, which is
         * interpreted relative to the *end* of the jump instruction
         * itself - hence the extra -4, accounting for the four
         * displacement bytes we're about to emit. */
        pb = emitw4(pb, jtab[program[*pc + 1]] - pb - 4);
        break;
    }
    *bc = pb - progbuf;
    *pc += np;
    return !done;
}

/* This function emits a preamble for the JITed code, which is responsible
 * for pushing all of the supplied arguments onto the machine stack for
 * use. */
void jit_setup(uint8_t *progbuf, int *bc, const int *args, int nargs)
{
    uint8_t *pb = progbuf + *bc;
    while (--nargs >= 0) {
        pb = emit2(pb, 0x48, 0xb8);             /* movabs $imm, %rax */
        pb = emitw8(pb, args[nargs]);
        pb = emit1(pb, 0x50);                   /* push %rax */
    }
    *bc = pb - progbuf;
}

void translate_jit(const int *program, uint8_t *progbuf, const int *args,
                   int nargs)
{
    int pc = 0, bc = 0;
    uint8_t *jtab[512];

    jit_setup(progbuf, &bc, args, nargs);
    while (translate_one_jit(program, &pc, progbuf, &bc, jtab))
        ;
}

uint8_t *progbuf_new();
void progbuf_make_ready(uint8_t *progbuf);

int jit(const int *program, const int *args, int nargs)
{
    uint8_t *progbuf = progbuf_new();
    int (*translated)(void);

    translate_jit(program, progbuf, args, nargs);
    progbuf_make_ready(progbuf);
    translated = (int (*)(void))progbuf;
    return translated();
}

/* Whew! JIT victory. Now let's do a little bit of microbenchmarking: */
uint64_t delta_ns(struct timespec ta, struct timespec tb)
{
    return (tb.tv_sec - ta.tv_sec) * 1000000000L + tb.tv_nsec - ta.tv_nsec;
}

int main()
{
    const int args[] = { 400000, 900000 };
    const int nargs = sizeof(args) / sizeof(args[0]);
    struct timespec st, tt, jt, et;
    int ir, tr, jr;

    clock_gettime(CLOCK_MONOTONIC, &st);
    ir = interpret(example, args, nargs);
    clock_gettime(CLOCK_MONOTONIC, &tt);
    tr = threaded(example, args, nargs);
    clock_gettime(CLOCK_MONOTONIC, &jt);
    jr = jit(example, args, nargs);
    clock_gettime(CLOCK_MONOTONIC, &et);

    printf("int: %d (%lu ns)\n", ir, (unsigned long)delta_ns(st, tt));
    printf("thr: %d (%lu ns)\n", tr, (unsigned long)delta_ns(tt, jt));
    printf("jit: %d (%lu ns)\n", jr, (unsigned long)delta_ns(jt, et));
    return 0;
}

/* ... and there we are! A working JIT for a small bytecode language.
 *
 * We could, of course, do the translation part ahead of time, rather than
 * right as we're executing the bytecode - in that case, it'd be
 * ahead-of-time rather than just-in-time compilation. That has a lot of
 * advantages, like not needing the translator present at runtime, but it
 * does mean the translator has a lot less information to work with.
 * For example, a translator that is executing at runtime can actually
 * "know", by instrumenting the translated machine code, which sections are
 * hot versus cold and make intelligent optimization decisions - effectively
 * live profile-guided optimization. Modern JavaScript engines rely heavily
 * on this technique to make dynamic tradeoffs between compile time and
 * runtime, but that's material for another post.
 *
 * That's it for now, I think, folks - thanks for reading, as always :) */

/* Stack manipulation functions, for the interpreter and the threaded code
 * version - nothing super spicy here. */
void stack_init(Stack *s)
{
    s->top = 0;
}

void stack_pushargs(Stack *s, const int *args, int nargs)
{
    /* Note that the arguments are pushed in reverse order, a pattern that
     * we'll reuse later - the natural argument order for a stack machine
     * has the function's 0th argument at the top of the stack when it is
     * called. */
    while (--nargs >= 0)
        stack_push(s, args[nargs]);
}

void stack_push(Stack *s, int v)
{
    assert(s->top < sizeof(s->vals) / sizeof(s->vals[0]));
    s->vals[s->top++] = v;
}

int stack_pop(Stack *s)
{
    assert(s->top > 0);
    return s->vals[--s->top];
}

void stack_dump(Stack *s)
{
    int i;
    for (i = 0; i < s->top; i++)
        printf("%d ", s->vals[i]);
    printf("\n");
}

/* Appendix: operations for the threaded version. These all have to have the
 * same type signature so they can be called via function pointers, but are
 * otherwise pretty boring. They all return the new value of the program
 * counter. */
int threaded_op_add(const intptr_t *prog, int pc, Stack *s)
{
    int a = stack_pop(s);
    int b = stack_pop(s);
    stack_push(s, b + a);
    return pc + 1;
}

int threaded_op_sub(const intptr_t *prog, int pc, Stack *s)
{
    int a = stack_pop(s);
    int b = stack_pop(s);
    stack_push(s, b - a);
    return pc + 1;
}

int threaded_op_mul(const intptr_t *prog, int pc, Stack *s)
{
    int a = stack_pop(s);
    int b = stack_pop(s);
    stack_push(s, b * a);
    return pc + 1;
}

int threaded_op_div(const intptr_t *prog, int pc, Stack *s)
{
    int a = stack_pop(s);
    int b = stack_pop(s);
    stack_push(s, b / a);
    return pc + 1;
}

int threaded_op_swap(const intptr_t *prog, int pc, Stack *s)
{
    int a = stack_pop(s);
    int b = stack_pop(s);
    stack_push(s, a);
    stack_push(s, b);
    return pc + 1;
}

int threaded_op_dup(const intptr_t *prog, int pc, Stack *s)
{
    int a = stack_pop(s);
    stack_push(s, a);
    stack_push(s, a);
    return pc + 1;
}

int threaded_op_done(const intptr_t *prog, int pc, Stack *s)
{
    return -1;
}

int threaded_op_lit(const intptr_t *prog, int pc, Stack *s)
{
    stack_push(s, prog[pc + 1]);
    return pc + 2;
}

int threaded_op_if(const intptr_t *prog, int pc, Stack *s)
{
    if (stack_pop(s) != 0)
        return prog[pc + 1];
    else
        return pc + 2;
}

/* Appendix: two wrapper functions to abstract away a nasty detail of doing
 * JIT. 30 years ago, we could have simply allocated the program buffer on
 * the stack and jumped directly into it. However, modern systems enforce an
 * invariant called "W^X", meaning that any given chunk of memory can be
 * either writeable or executable but not both. W^X is imposed at page
 * granularity, so for our JIT, we allocate a dedicated page of memory to
 * use. That page starts off writeable, but progbuf_make_ready() flips it to
 * be executable instead. */
uint8_t *progbuf_new()
{
    return mmap(NULL, 4096, PROT_READ | PROT_WRITE,
                MAP_PRIVATE | MAP_ANON, -1, 0);
}

void progbuf_make_ready(uint8_t *progbuf)
{
    mprotect(progbuf, 4096, PROT_READ | PROT_EXEC);
}
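/* (A portability note on the above: we get away without any explicit cache
 * maintenance here only because AMD64 keeps the instruction cache coherent
 * with ordinary stores. On an architecture like ARM, a real JIT would also
 * need to flush the freshly written code to the instruction cache before
 * jumping into it - with GCC or Clang, for example, via the
 * __builtin___clear_cache() builtin - so a sketch of progbuf_make_ready()
 * there, untested and purely illustrative, might read:
 *
 *   void progbuf_make_ready(uint8_t *progbuf) {
 *       mprotect(progbuf, 4096, PROT_READ | PROT_EXEC);
 *       __builtin___clear_cache((char *)progbuf, (char *)progbuf + 4096);
 *   }
 *
 * Likewise, production code would check mmap's return value against
 * MAP_FAILED instead of assuming the allocation succeeded.) */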
/* Footnotes!
 *
 * [fn1]: this actually isn't *as* bad as it sounds, since the branch
 * predictors in modern CPUs are very clever, but it's still pretty darn bad.
 *
 * [fn2]: the IDIV instruction on x86 and AMD64 always operates on the
 * %rdx:%rax register pair, which is neither very orthogonal nor very cash
 * money, and requires this special case (and the cqo) here.
 */