Roland Paterson-Jones: 1 RFC: Jmp/jnz optimisations 6 files changed, 411 insertions(+), 5 deletions(-)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~mpu/qbe/patches/53939/mbox | git am -3Learn more about email & git
NOTE: this patch applies on top of GVN/GCM RFC7 [https://lists.sr.ht/~mpu/qbe/patches/53935] (only because there is common code - there are no hard dependencies) Optimisations: 1. Forward jmp/jnz targets through empty blocks - fwdempty() Reduces unnecessary CFG complexity and improves GCM block selection. 2. "if-convert" tiny if-then[-else] graphlets - ifconvert() If-conversion is a standard compiler technique that replaces conditional branches with conditional move (aka "select") instructions. This is generally beneficial in cases where branch prediction is poor - i.e. there is no dominant execution path. On the other hand, it increases the code path and hence can have minor detrimental performance impact in cases where branch prediction is accurate (dominant execution path). The implementation uses architecture-neutral masking at the QBE IL level in order to be portable. Architecture-specific use of conditional move instructions, for example, would lead to tighter code. Currently tuned to address only tiny (<= 2 instruction) if-else blocks. Testing - standard QBE, cproc, hare, harec, roland, coremark Benchmark: coremark is ~20% faster than GVN/GCM RFC7 [caveat - this is mainly due to significant benefit of if-conversion in coremark's heavily-used crcu8() function] --- Makefile | 2 +- all.h | 13 +- gcm.c | 2 +- gvn.c | 4 +- ifopt.c | 376 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ main.c | 19 +++ 6 files changed, 411 insertions(+), 5 deletions(-) create mode 100644 ifopt.c diff --git a/Makefile b/Makefile index b6dd6b9..83303ad 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ PREFIX = /usr/local BINDIR = $(PREFIX)/bin COMMOBJ = main.o util.o parse.o abi.o cfg.o mem.o ssa.o alias.o load.o \ - gvn.o gcm.o simpl.o live.o spill.o rega.o emit.o + gvn.o gcm.o ifopt.o simpl.o live.o spill.o rega.o emit.o AMD64OBJ = amd64/targ.o amd64/sysv.o amd64/isel.o amd64/emit.o ARM64OBJ = arm64/targ.o arm64/abi.o arm64/isel.o arm64/emit.o RV64OBJ = rv64/targ.o rv64/abi.o rv64/isel.o rv64/emit.o diff --git a/all.h b/all.h index a9a5776..39f3701 100644 --- a/all.h +++ b/all.h @@ -353,6 +353,9 @@ struct Tmp { } width; int visit; uint gcmbid; + int ifckl; + Ref ifct; + Ref ifcf; }; struct Con { @@ -561,18 +564,26 @@ int storesz(Ins *); void loadopt(Fn *); /* ssa.c */ -void adduse(Tmp *tmp, int ty, Blk *b, ...); +void adduse(Tmp *, int, Blk *, ...); void filluse(Fn *); void ssa(Fn *); void ssacheck(Fn *); /* gvn.c */ +void replaceuses(Fn *, Ref, Ref); +extern Ref con01[2]; +int iswu1(Fn *, Ref); void gvn(Fn *); /* gcm.c */ int isbad4gcm(Fn *, Blk *b, Ins *); +int istrapping(Ins *); void gcm(Fn *); +/* ifopt.c */ +void fwdempty(Fn *); +void ifconvert(Fn *); + /* simpl.c */ void simpl(Fn *); diff --git a/gcm.c b/gcm.c index c417118..f6779dc 100644 --- a/gcm.c +++ b/gcm.c @@ -63,7 +63,7 @@ isbad4gcm(Fn *fn, Blk *b, Ins *i) } /* ins can trap at runtime */ -static int +int istrapping(Ins *i) { if (i->cls == Ks || i->cls == Kd) diff --git a/gvn.c b/gvn.c index fd5b8ac..ba27be8 100644 --- a/gvn.c +++ b/gvn.c @@ -13,7 +13,7 @@ isdead(Fn *fn, Blk *b) { } /* literal constants 0, 1 */ -static Ref con01[2]; +Ref con01[2]; static int iscon(Con *c, int cls, int64_t v) @@ -452,7 +452,7 @@ replaceuse(Fn *fn, Use* u, Ref r1, Ref r2) } } -static void +void replaceuses(Fn *fn, Ref r1, Ref r2) { Tmp *t1; diff --git a/ifopt.c b/ifopt.c new file mode 100644 index 0000000..1f13295 --- /dev/null +++ b/ifopt.c @@ -0,0 +1,376 @@ +#include "all.h" + +static int +issimplesucc(Blk *b, Blk *s) +{ + if (s != b) + if (s->npred == 1) + if (s->phi == 0) + return 1; + return 0; +} + +static int +issimplejmpsucc(Blk *b, Blk *s) +{ + return issimplesucc(b, s) && s->jmp.type == Jjmp; +} + +static int +isnopins(Blk *b) +{ + Ins *i; + for (i = b->ins; i < &b->ins[b->nins]; i++) + if (i->op != Onop && i->op != Odbgloc) + return 0; + return 1; +} + +static int +isnopsucc(Blk *b, Blk *s) +{ + return issimplejmpsucc(b, s) && isnopins(s); +} + +/* remove marked-dead blks */ +static void +killblks(Fn *fn) +{ + Blk **pb; + + for (pb = &fn->start; *pb;) + if (fn->rpo[(*pb)->id]) + pb = &(*pb)->link; + else + *pb = (*pb)->link; +} + +static void +replacepred(Blk **blks, uint nblk, Blk *to, Blk *from) +{ + uint n; + for(n=0; n<nblk; n++) + if (blks[n] == from) { + blks[n] = to; + break; + } + assert(n != nblk); +} + +static void +replacepreds(Blk *s, Blk *to, Blk *from) +{ + Phi *p; + + if (!s) + return; + assert(s->npred); + replacepred(s->pred, s->npred, to, from); + for (p = s->phi; p; p = p->link) { + assert(p->narg == s->npred); + replacepred(p->blk, p->narg, to, from); + } +} + +/* collapse jmp/jnz thru empty blks */ +/* needs rpo pred, breaks cfg use */ +void +fwdempty(Fn *fn) +{ + uint bid; + Blk *b; + int s1nop, s2nop; + + if (debug['J']) + fputs("\n> Forwarding jmp/jnz through empty blocks:\n", stderr); + + for (bid = 0; bid < fn->nblk; ) { + b = fn->rpo[bid]; + if (b == 0) + goto Skip; + /* jmp thru empty blk */ + if (b->jmp.type == Jjmp) + if (b->s1->jmp.type == Jjmp) + if (isnopsucc(b, b->s1)) { + if (debug['J']) + fprintf(stderr, " forwarding @%s->@%s->@%s to @%s->@%s\n", b->name, b->s1->name, b->s1->s1->name, b->name, b->s1->s1->name); + fn->rpo[b->s1->id] = 0; /* mark dead */ + replacepreds(b->s1->s1, b, b->s1); + b->s1 = b->s1->s1; + continue; + } + if (b->jmp.type != Jjnz) + goto Skip; + /* jnz to common target */ + if (b->s1 == b->s2) { + if (debug['J']) + fprintf(stderr, " collapsing @%s -> @%s, @%s to @%s -> @%s\n", b->name, b->s1->name, b->s2->name, b->name, b->s1->name); + b->jmp.type = Jjmp; + b->jmp.arg = R; + b->s2 = 0; + goto Skip; + } + s1nop = isnopsucc(b, b->s1); + s2nop = isnopsucc(b, b->s2); + /* jnz - both sides thru empty blks, jumping to same blk */ + if (s1nop && s2nop) + if (b->s1->s1 == b->s2->s1) + if (!b->s1->s1->phi) { + if (debug['J']) + fprintf(stderr, " forwarding @%s -> @%s, @%s to @%s -> @%s\n", b->name, b->s1->name, b->s2->name, b->name, b->s1->s1->name); + fn->rpo[b->s1->id] = 0; /* mark dead */ + fn->rpo[b->s2->id] = 0; /* mark dead */ + edgedel(b->s2, &b->s2->s1); + replacepreds(b->s1->s1, b, b->s1); + b->jmp.type = Jjmp; + b->jmp.arg = R; + b->s1 = b->s1->s1; + b->s2 = 0; + goto Skip; + } + /* jnz - left side thru empty blk */ + if (s1nop) + if (!b->s1->s1->phi) { + if (debug['J']) + fprintf(stderr, " forwarding @%s -> @%s to @%s -> @%s\n", b->name, b->s1->name, b->name, b->s1->s1->name); + replacepreds(b->s1->s1, b, b->s1); + fn->rpo[b->s1->id] = 0; /* mark dead */ + b->s1 = b->s1->s1; + continue; + } + /* jnz - right side thru empty blk */ + if (s2nop) + if (!b->s2->s1->phi) { + if (debug['J']) + fprintf(stderr, " forwarding @%s -> @%s to @%s -> @%s\n", b->name, b->s2->name, b->name, b->s2->s1->name); + replacepreds(b->s2->s1, b, b->s2); + fn->rpo[b->s2->id] = 0; /* mark dead */ + b->s2 = b->s2->s1; + continue; + } + Skip:; + bid++; + } + + killblks(fn); + + if (debug['J']) { + fprintf(stderr, "\n> After forwarding jmp/jnz through empty blocks:\n\n"); + printfn(fn, stderr); + } +} + +/* (otherwise) isolated if-then[-else] graphlet */ +static int +issimpleif(Blk *b, Blk **ppredt, Blk **ppredf, Blk **pjoin) +{ + int simples1, simples2; + + if (b->jmp.type != Jjnz) + return 0; + simples1 = issimplejmpsucc(b, b->s1); + simples2 = issimplejmpsucc(b, b->s2); + + /* diamond */ + if (simples1) + if (simples2) + if (b->s1->s1 == b->s2->s1) + if (b->s1->s1->npred == 2) { + *ppredt = b->s1; + *ppredf = b->s2; + *pjoin = b->s1->s1; + return 1; + } + /* left triangle */ + if (simples1) + if (b->s1->s1 == b->s2) + if (b->s2->npred == 2) { + *ppredt = b->s1; + *ppredf = b; + *pjoin = b->s2; + return 1; + } + /* right triangle */ + if (simples2) + if (b->s1 == b->s2->s1) + if (b->s1->npred == 2) { + *ppredt = b; + *ppredf = b->s2; + *pjoin = b->s1; + return 1; + } + + return 0; +} + +#define MAX_HOIST_NINS 2 +/* small enough and no "fixed" instructions */ +static int +ishoistable(Blk *b) +{ + uint n; + Ins *i; + n = 0; + for (i = b->ins; i < &b->ins[b->nins]; i++) { + if (i->op == Onop || i->op == Odbgloc) + continue; + n++; + if (optab[i->op].ispinned || istrapping(i)) + return 0; + } + return n <= MAX_HOIST_NINS; +} + +#define MAX_IFCONV_NPHIS 2 +/* phis are all integer and not too many of them */ +static int +canifconvphis(Blk *b, uint *pnphis, int *needkl) +{ + Phi *p; + + *pnphis = 0; + *needkl = 0; + for (p = b->phi; p; p = p->link) { + if (KBASE(p->cls) != 0) + return 0; + if (p->cls == Kl) + *needkl = 1; + (*pnphis)++; + } + + return *pnphis <= MAX_IFCONV_NPHIS; +} + +static void +setdef(Fn *fn, uint bid, Ref r, Ins *i) +{ + Tmp *t; + + assert(rtype(r) == RTmp); + t = &fn->tmp[r.val]; + t->bid = bid; + t->def = i; +} + +/* If-convert small if-then[-else] using bitmasks. */ +/* needs rpo pred use; breaks cfg use */ +void +ifconvert(Fn *fn) +{ + uint bid; + Blk *b, *predt, *predf, *join; + uint n, nins, nphis; + int needkl; + Ins *ins; + Tmp *t; + Ref boolv, maskt, maskf, val0, val1; + Phi *p; + + if (debug['K']) + fputs("\n> If-conversion:\n", stderr); + + for (bid = 0; bid < fn->nblk; bid++) { + b = fn->rpo[bid]; + if (!b) + continue; /* already dead */ + if (!issimpleif(b, &predt, &predf, &join)) + continue; + if (predt == predf) + continue; + if (predt != b && !ishoistable(predt)) + continue; + if (predf != b && !ishoistable(predf)) + continue; + if (!canifconvphis(join, &nphis, &needkl)) + continue; + if (debug['K']) + fprintf(stderr, " $%s If-converting @%s - true-pred @%s false-pred @%s join @%s - %u phis\n", fn->name, b->name, predt->name, predf->name, join->name, nphis); + /* Note iswu1() needs up-to-date t->bid, t->def */ + /* TODO - can handle this case with insertion of extra cmp */ + if (!iswu1(fn, b->jmp.arg)) { + if (debug['K']) + fprintf(stderr, " but bailing cos not wu1\n"); + continue; + } + + nins = b->nins; + if (predt != b) + nins += predt->nins; + if (predf != b) + nins += predf->nins; + nins += needkl + 2 + 3*nphis; /*worst case*/ + + ins = alloc(nins * sizeof ins[0]); + n = 0; + memcpy(&ins[n], b->ins, b->nins * sizeof ins[0]); + n += b->nins; + if (predt != b) { + memcpy(&ins[n], predt->ins, predt->nins * sizeof ins[0]); + n += predt->nins; + } + if (predf != b) { + memcpy(&ins[n], predf->ins, predf->nins * sizeof ins[0]); + n += predf->nins; + } + assert(rtype(b->jmp.arg) == RTmp); + t = &fn->tmp[b->jmp.arg.val]; + if (nphis != 0 && (req(t->ifct, R) || (needkl && !t->ifckl))) { + /* extend boolean val to Kl - not always necessary */ + if (needkl) { + boolv = newtmp("ifc", Kl, fn); + ins[n++] = (Ins) {.op = Oextuw, .cls = Kl, .to = boolv, .arg = {b->jmp.arg}}; + } else + boolv = b->jmp.arg; + /* create a mask for the "true" branch */ + maskt = newtmp("ifc", (needkl ? Kl : Kw), fn); + ins[n++] = (Ins) {.op = Oneg, .cls = (needkl ? Kl : Kw), .to = maskt, .arg = {boolv}}; + setdef(fn, bid, maskt, &ins[n-1]); + /* create a mask for the "false" branch */ + maskf = newtmp("ifc", (needkl ? Kl : Kw), fn); + ins[n++] = (Ins) {.op = Osub, .cls = (needkl ? Kl : Kw), .to = maskf, .arg = {boolv, con01[1]}}; + setdef(fn, bid, maskf, &ins[n-1]); + t = &fn->tmp[b->jmp.arg.val]; /* might have moved */ + t->ifckl = needkl; + t->ifct = maskt; + t->ifcf = maskf; + } else { + maskt = t->ifct; + maskf = t->ifcf; + } + + /* if-convert the phis */ + for (p = join->phi; p; p = p->link) { + assert(p->narg == 2); + val0 = newtmp("ifc", p->cls, fn); + ins[n++] = (Ins) {.op = Oand, .cls = p->cls, .to = val0, .arg = {(p->blk[0] == predt ? maskt : maskf), p->arg[0]}}; + setdef(fn, bid, val0, &ins[n-1]); + val1 = newtmp("ifc", p->cls, fn); + ins[n++] = (Ins) {.op = Oand, .cls = p->cls, .to = val1, .arg = {(p->blk[1] == predt ? maskt : maskf), p->arg[1]}}; + setdef(fn, bid, val1, &ins[n-1]); + ins[n++] = (Ins) {.op = Oor, .cls = p->cls, .to = p->to, .arg = {val0, val1}}; + setdef(fn, bid, p->to, &ins[n-1]); + } + + /* Fix up the CFG */ + if (predt != b) + fn->rpo[predt->id] = 0; /* mark dead */ + if (predf != b) + fn->rpo[predf->id] = 0; /* mark dead */ + b->ins = ins; + b->nins = n; + b->jmp.type = Jjmp; + b->jmp.arg = R; + b->s1 = join; + b->s2 = 0; + join->npred = 1; + join->pred[0] = b; + join->phi = 0; + } + + killblks(fn); + + if (debug['K']) { + fprintf(stderr, "\n> After if-conversion:\n\n"); + printfn(fn, stderr); + } +} diff --git a/main.c b/main.c index 0e59f4d..2f23728 100644 --- a/main.c +++ b/main.c @@ -74,12 +74,31 @@ func(Fn *fn) coalesce(fn); filluse(fn); ssacheck(fn); + fwdempty(fn); + fillrpo(fn); + fillpreds(fn); + filluse(fn); + filldom(fn); + ssacheck(fn); gvn(fn); fillpreds(fn); filluse(fn); filldom(fn); gcm(fn); + fillrpo(fn); + fillpreds(fn); + filluse(fn); + fwdempty(fn); + fillrpo(fn); + fillpreds(fn); + filluse(fn); + filldom(fn); + ssacheck(fn); + ifconvert(fn); + fillrpo(fn); + fillpreds(fn); filluse(fn); + filldom(fn); ssacheck(fn); T.abi1(fn); simpl(fn); -- 2.34.1
WARNING - buggy - see https://lists.sr.ht/~mpu/qbe/%3CCAAS8gYBG=N4EqBs50Y4m_TojRRxdDaGYoVZ0kD07m7cewLFK5w@mail.gmail.com%3E