Loop to Constant Computation

May 3, 2020
int test(int c, int n) {
  while (c < n) {
    c = c + 3;
  }
  return c;
}

In this article we will see how LLVM turns the loop above into a constant computation (i.e. one that doesn't loop), why the optimization works, and we will take a sneak peek at the passes that achieve that. The ASM outputs are x86_64, but the optimizations we discuss are target-independent.



What Did Other Compilers Do?

I think it is useful to see how different compilers optimize the same input, especially when the code is small and the ASM output is mostly comprehensible. The results below are at -O1 (note that -O1 does not mean the same thing in all compilers), plus some discussion of higher optimization levels.

GCC 9.3 Godbolt snippet

test:
        mov     eax, edi
        cmp     edi, esi
        jge     .L2
.L3:
        add     eax, 3
        cmp     esi, eax
        jg      .L3
.L2:
        ret

GCC at -O1 doesn't do anything fascinating. The only thing that happened is that the loop was rotated, i.e. effectively converted to a do-while loop (observe that the check happens at the "bottom" of the loop). The output is the same at -O2.
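To make the rotation concrete, here is a C sketch of what GCC's -O1 output does, based on my reading of the assembly above: a guard that checks whether the loop runs at all, then a do-while loop whose test sits at the bottom. The names test_loop (the original loop, renamed) and test_rotated are made up for illustration; this is not something GCC emits.

```c
#include <assert.h> /* for the usage checks */

/* The original loop, renamed to test_loop for this sketch. */
int test_loop(int c, int n) {
  while (c < n) {
    c = c + 3;
  }
  return c;
}

/* Hypothetical C rendering of GCC's -O1 output: a guard plus a bottom-tested loop. */
int test_rotated(int c, int n) {
  if (c < n) {        /* guard: cmp edi, esi / jge .L2 skips the loop */
    do {
      c = c + 3;      /* add eax, 3 */
    } while (c < n);  /* cmp esi, eax / jg .L3 jumps back while n > c */
  }
  return c;
}
```

Both functions return the same value for every input; the rotated form just trades the test at the top of each iteration for one guard plus a test at the bottom.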

At -O3 (Godbolt snippet) it went crazy. I didn't try to decode all the output, but the idea seems to be that first it does a bunch of (runtime) checks to see if it can unroll the loop (i.e. if there is an upper bound) and if so it branches there:

.L3:
        lea     eax, [rdi+3]
        cmp     esi, eax
        jle     .L1
        lea     eax, [rdi+6]
        cmp     esi, eax
        jle     .L1
        lea     eax, [rdi+9]
        cmp     esi, eax
        jle     .L1
        lea     eax, [rdi+12]
        cmp     esi, eax
        jle     .L1
        lea     eax, [rdi+15]
        cmp     esi, eax
        jle     .L1
        lea     eax, [rdi+18]
        cmp     esi, eax
        jle     .L1
        lea     eax, [rdi+21]
        add     edi, 24
        cmp     esi, eax
        cmovg   eax, edi
        ret

And if not, it uses a vectorized version of the loop:

.L4:
    movdqa  xmm0, xmm1
    add     eax, 1
    paddd   xmm1, xmm3
    paddd   xmm0, xmm2
    cmp     ecx, eax
    jne     .L4

Still, the computation is not constant.

MSVC v19.0 Godbolt snippet

test:
        cmp     ecx, edx
        jge     $LN2@test
        sub     edx, ecx
        mov     eax, -1431655765              ; aaaaaaabH
        dec     edx
        mul     edx
        shr     edx, 1
        lea     eax, DWORD PTR [rcx+rdx]
        lea     ecx, DWORD PTR [rdx*2+3]
        add     ecx, eax
$LN2@test:
        mov     eax, ecx
        ret     0

MSVC produces the same output at -O1, -O2 and -Os, and it's a good one: it has effectively converted the loop to a constant computation, similar to what we'll later see LLVM do.
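It helps to read the MSVC assembly back as C. The sketch below is my hand translation, assuming the Windows x64 calling convention (c arrives in ecx, n in edx); the function name test_msvc_like is made up. The mul by -1431655765 (0xAAAAAAAB as an unsigned value) followed by shr edx, 1 is how MSVC divides by 3, so the sketch simply writes / 3.

```c
#include <assert.h> /* for the usage checks */

/* Hypothetical C reading of MSVC's output above. */
int test_msvc_like(int c, int n) {
  if (c >= n)                                  /* cmp ecx, edx / jge $LN2@test */
    return c;
  /* sub edx, ecx / dec edx: n - c - 1, in 32-bit unsigned arithmetic */
  unsigned d = (unsigned)n - (unsigned)c - 1u;
  /* mov eax, 0xAAAAAAAB / mul edx / shr edx, 1: the unsigned division d / 3 */
  unsigned q = d / 3u;
  /* lea eax, [rcx+rdx]; lea ecx, [rdx*2+3]; add ecx, eax:
     c + q + (2*q + 3) == c + 3 * (q + 1) */
  return (int)((unsigned)c + 3u * q + 3u);
}
```

Note that this counts the iterations as (n - c - 1) / 3 + 1 rather than with a ceiling; for n > c the two are equal.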

ICC 19.0.1 Godbolt snippet

test:
        jmp       ..B1.9        # Prob 100%                     
..B1.3:                         # Preds ..B1.9
        add       edi, 3                                        
..B1.9:                         # Preds ..B1.1 ..B1.3
        cmp       edi, esi                                      
        jl        ..B1.3        # Prob 82%                      
        mov       eax, edi                                      
        ret                                                     

ICC's -O1 output is similar to GCC's and pretty basic. It has also converted the loop to a do-while loop, but in a different way: there's no "guard". In GCC's output, there's a guard (imagine an if that wraps the do-while loop) that checks whether the loop will be entered at least once.

ICC has converted it to a do-while loop but jumps straight to the comparison. In C, it's like this:

int test(int c, int n) {
  goto cond;
  do {
    c = c + 3;
cond:
    ;  // a label must be followed by a statement (before C23)
  } while (c < n);
  return c;
}

At -O2 (Godbolt snippet) the output is better. It uses the same ideas as MSVC above and what we'll see in LLVM to convert the loop to a constant computation.

test:
        cmp       edi, esi                                      
        jge       ..B1.3        # Prob 50%                      
        mov       eax, 1431655766                               
        lea       ecx, DWORD PTR [1+rdi]                        
        sub       esi, ecx                                      
        add       esi, 3                                        
        imul      esi                                           
        sar       esi, 31                                       
        sub       edx, esi                                      
        lea       esi, DWORD PTR [3+rdx+rdx*2]                  
        lea       edi, DWORD PTR [-3+rdi+rsi]                   
..B1.3:                         # Preds ..B1.2 ..B1.1
        mov       eax, edi                                      
        ret                                                     

Don't be fooled by these DWORD PTR [...]; they're not dereferences. lea only computes the address expression and never accesses memory; the size annotation is just part of how the operand is printed, and whether it appears depends on the assembler syntax. Before moving on, I should mention that the -O3 output is the same.

What Does LLVM Do?

Clang 10 at -O1 outputs the following (Godbolt snippet), and the output is the same at all the other levels. It's a constant computation too, and it is shorter than both ICC's and MSVC's output.

test:                                   # @test
      cmp     esi, edi
      cmovl   esi, edi
      sub     esi, edi
      add     esi, 2
      mov     eax, 2863311531
      imul    rax, rsi
      shr     rax, 33
      lea     eax, [rax + 2*rax]
      add     eax, edi
      ret

Interestingly, Clang has produced this same output since version 3.4.1 (Godbolt snippet).

Can We Convert this Loop to a Constant Computation?

Let's first consider a simpler version of this loop:

int test(int n) {
  int c = 0;
  while (c < n) {
    c = c + 1;
  }
  return c;
}

For this one, it's pretty obvious that we can compute the result in constant time. Note that we care about the value of c at the end of the loop (i.e. its exit value), and that is obviously n if n > 0 and 0 otherwise.

One other way to express that is that the exit value of c is
exit_value_of_c = max(n, 0);
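A quick way to convince ourselves of this closed form is to compare it with the loop itself. A minimal sketch; the names step1_loop and step1_closed are made up for illustration.

```c
#include <assert.h> /* for the usage checks */

/* The step-1 loop, starting from 0. */
int step1_loop(int n) {
  int c = 0;
  while (c < n) {
    c = c + 1;
  }
  return c;
}

/* Closed form: the exit value of c is max(n, 0). */
int step1_closed(int n) {
  return n > 0 ? n : 0;
}
```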

Now, let's do a small modification. Let's change the step from 1 to 2:

int test(int n) {
  int c = 0;
  while (c < n) {
    c = c + 2;
  }
  return c;
}

It seems that we can still compute it in constant time, but now there are more cases to consider, since c may or may not land exactly on n. To put it differently, in the previous version, starting from 0 and going up by only 1 (and assuming n > 0), we were sure to land on n, which made the exit value of c easy to figure out.

However now, if n is a multiple of 2, we'll end up on it, otherwise not. And what happens in either case?

To simplify the problem, note that the most important thing for computing the exit value of c is the number of iterations. If someone gives us the number of iterations and we know the step (assuming it is a constant), computing the exit value of c is always easy, no matter what the step is: every iteration adds the step to the previous value of c, so the final value is initial_value_of_c + number_of_iterations * step.

So, we have to find a way to generally compute the number of iterations for a given initial value of c, a step and an n, in constant time. If we can do that, then we plug the number in the above formula and we're done.

Back to the code with a step of 2; let's focus on the number of iterations from now on. If n is a multiple of 2, then the number of iterations is n / 2. For example, for n == 4, we go 0 -> enter the loop, 2 -> enter the loop, 4 -> don't enter the loop.

In general, for a non-negative multiple of 2, we do exactly n / 2 iterations.

Now, what happens when n is not a multiple of 2?

In this case, n == 2k + 1; in simple words, n is some multiple of 2 plus 1. This matters because, when n was a multiple of 2, we knew we would land on it and therefore not enter the loop for that value (e.g. 0, 2, 4: for 4 we don't enter the loop).

In the same manner here, starting from 0 and incrementing by 2, we will land on 2k. But since 2k < 2k + 1 = n, that extra + 1 lets us into the loop one more time (and only one, since 2k + 2 > n). So, in this case, we do exactly n/2 + 1 iterations, with n/2 rounded down (i.e. k + 1).

So, to sum up: if n is a multiple of 2, we do n/2 iterations; otherwise, we do n/2 + 1 (again with n/2 rounded down). We can express both cases with one function, ceil(). ceil(x) gives us the ceiling of x, e.g. for x = 0.5, ceil(x) = 1, so number_of_iterations = ceil(n / 2). It accounts for that +1 iteration in the case where n is not a multiple of 2.

Ok, we found a way to compute the number of iterations when the step is 2. How about other steps? The idea is the same. If n is a multiple of the step, then we do n / step iterations. Otherwise, n = step*k + x, i.e. n is a multiple of the step plus something. But that something is at most step - 1, otherwise we would reach the next multiple.

So, this "plus something" is enough to give us one more iteration, but only one. Which means that the same formula generalizes: number_of_iterations = ceil(n / step).

There's one last thing to consider: this works only if initial_value_of_c == 0, but it's quite easy to generalize. If the initial value is not 0 but, let's say, v, we're looping from v to n. That is the same as looping from 0 to n - v.

For example, looping from 4 to 7 is like looping from 0 to 3. It's like moving the "window" of iterations (which we can do since we only care about the number of iterations). With this last consideration, in general:
number_of_iterations = ceil((n - initial_value_of_c) / step)
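The case analysis above can be checked by brute force. A minimal sketch, assuming step > 0 and v <= n (otherwise the loop doesn't run): count_iterations just runs the loop, while iterations_case_split applies "(n - v) divided by step, plus one if there is a remainder", which is what ceil((n - v) / step) computes, using only integer operations. Both function names are made up for illustration.

```c
#include <assert.h> /* for the usage checks */

/* Brute force: count how many times the loop body runs. */
int count_iterations(int v, int n, int step) {
  int count = 0;
  for (int c = v; c < n; c += step)
    count++;
  return count;
}

/* Case split from the text, assuming step > 0 and v <= n. */
int iterations_case_split(int v, int n, int step) {
  int d = n - v;          /* shift the window so that it starts at 0 */
  int q = d / step;       /* iterations if d is an exact multiple of step */
  return d % step == 0 ? q : q + 1;  /* a remainder buys exactly one more */
}
```

For example, iterations_case_split(4, 7, 2) and iterations_case_split(0, 3, 2) both give 2, matching the "window" argument.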

That's all great and constant-time, but ceil() is a floating-point function, and we would like to use only integer computations (note that the ASM that LLVM outputs has only integer computations).

ceil(a / b) with Integer Computations

Let's say that we want to compute ceil(a / b), and assume for a moment that a is not a multiple of b. In this case ceil(a / b) = a / b + 1, where a / b is integer division, but we'll think about it differently.

Integer division, as it is specified in C (which is a whole other story, but we'll simplify things for now), truncates the result. For example, 3 / 2 is 1.5 with FP division but 1 with integer division.

The most important thing to realize here is that in a / b, if a is not a multiple of b, the result of the division is like "cutting" a to the previous multiple of b and then doing the division.

For example, 5/2 = 4/2 and 8/3 = 6/3.

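Graphically, it looks like this:

[Figure: a number line with vertical bars at the multiples of b; each stretch from one bar up to the next is shaded in its own color (e.g. blue, then red).]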

The vertical bars are multiples of some b. All numbers in the same colored area give the same result when divided by b. This is important because, when a is not a multiple, it lies inside some colored area and not on a vertical bar. As we said, ceil(a / b) = a / b + 1, i.e. do the division and then add 1. But this is the same as "moving" a into the next colored area, since any number in the next area gets cut down to the next vertical bar, which effectively gives +1 in the division. For example, for any a in the blue area (except its vertical bar), I can take the ceiling by moving it anywhere in the red area (including the red vertical bar) and then doing the division.

"Movement" in this case intuitively means addition. I want to find a number that moves any number in one area to anywhere in the next. I'd like to add the same number to every a, because that makes the formula unconditional (no branches), which is good for computers.

Let's think in multiples of 5: |----|----|----| ...
The vertical bars are multiples of 5 and the dashes are the four numbers between one multiple and the next. We want a number t that moves any dash into the next area, i.e. onto the next vertical bar or one of the dashes right after it.

|----|----|----| ...

For example, take the red dash, say the number 5k + 1 right after the bar 5k. I want to move it anywhere in the green area, i.e. onto the next bar 5k + 5 or one of the dashes 5k + 6 to 5k + 9 (and note that I don't want to move it any further, otherwise I would get a / b + 2, not + 1). I can achieve that by adding 4 (which gets me to the vertical bar), 5, 6, 7 or 8.

This number has to work for all the dashes. For the dash right after the red one (5k + 2), all of these numbers work except 8, which moves it to the vertical bar after the green area (5k + 10), and we don't want that. So, t must be between 4 and 7 inclusive.

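If we do the same reasoning for the remaining dashes (5k + 3 and 5k + 4), we'll find that only t = 4 and t = 5 work for every dash. The multiples themselves must not move to the next area (we'll come back to them in a moment), which rules out 5, since it would move 5k to 5k + 5. So, for multiples of 5, t = 4.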

If you follow the same reasoning for multiples of any number b, you'll find that in general t = b - 1. So we have found our "magic" number: adding it to a non-multiple a and then dividing gives us ceil(a / b).

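Finally, what happens when a is a multiple of b? In that case ceil(a / b) is just a / b, so a must stay in its own area, and indeed adding t = b - 1 takes it at most to the last dash before the next bar. Multiples of b are the only numbers that t can't move to the next area, and this is exactly what we want.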

All in all, in general ceil(a / b) = (a + b - 1) / b (for a >= 0 and b > 0, which is the case we care about here).
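The formula translates directly into an integer-only helper. A minimal sketch, assuming a >= 0, b > 0, and that a + b - 1 doesn't overflow int; the name ceil_div is made up for illustration.

```c
#include <assert.h> /* for the usage checks */

/* ceil(a / b) using only integer operations.
   Assumes a >= 0, b > 0 and that a + b - 1 does not overflow int. */
int ceil_div(int a, int b) {
  /* Adding b - 1 moves any non-multiple into the next "area",
     but leaves a multiple of b in its own area. */
  return (a + b - 1) / b;
}
```

For example, ceil_div(8, 3) computes (8 + 2) / 3 = 10 / 3 = 3, while ceil_div(6, 3) computes 8 / 3 = 2.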

To sum up

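number_of_iterations = ceil((n - initial_value_of_c) / step)
                     = (n - initial_value_of_c + step - 1) / step

exit_value_of_c = initial_value_of_c + number_of_iterations * step

These hold when the loop is entered at least once (initial_value_of_c < n); otherwise the exit value is just initial_value_of_c.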
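Putting the two formulas together gives a constant-time replacement for the loop. A minimal sketch, assuming step > 0 and that n - c + step - 1 doesn't overflow int; the name exit_value is made up, and the case where the loop never runs is handled with an explicit check.

```c
#include <assert.h> /* for the usage checks */

/* Closed-form exit value of
     while (c < n) c = c + step;
   assuming step > 0 and no overflow in n - c + step - 1. */
int exit_value(int c, int n, int step) {
  if (c >= n)                                   /* the loop body never runs */
    return c;
  int iterations = (n - c + step - 1) / step;   /* ceil((n - c) / step) */
  return c + iterations * step;
}
```

For example, exit_value(0, 10, 3) computes (10 + 2) / 3 = 4 iterations and returns 0 + 4 * 3 = 12, which is what the original loop returns.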

A Look at LLVM IR

In general, if you want to look at how a compiler optimizes a piece of code, it's good to get familiar with its Intermediate Representations. The reason is that the final assembly is the result of multiple passes, some in the middle-end and some in the back-end. More importantly, the back-end often obscures the code, which makes the optimization and the compiler's reasoning harder to understand.

In this example, we'll see that the middle-end is the one that optimized the loop to a constant computation, and that computation happens to contain a division. The back-end then transforms the division using the standard trick of multiplication and shift. But no matter how "standard" this trick is, one might not know it, and they'll see weird multiplications and magic numbers in the output ASM that seem to come out of nowhere. That, in turn, will make it harder for them to understand the actual optimization, which in essence had nothing to do with the division.
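To see that such magic numbers really are just a division, here is a small C sketch of the unsigned divide-by-3 that Clang's output above performs with 2863311531 (0xAAAAAAAB): multiply in 64 bits and shift right by 33. The name div3_magic is made up. The reason it works is that 0xAAAAAAAB = (2^33 + 1) / 3, so the result is x / 3 plus an error below 1/6, which never pushes it past the next integer for a 32-bit x.

```c
#include <assert.h> /* for the usage checks */
#include <stdint.h>

/* Unsigned 32-bit division by 3 via multiply-and-shift,
   mirroring Clang's: mov eax, 2863311531 / imul rax, rsi / shr rax, 33 */
uint32_t div3_magic(uint32_t x) {
  return (uint32_t)(((uint64_t)x * 0xAAAAAAABu) >> 33);
}
```

The checks below compare it against a plain x / 3 over a range of inputs, including the largest ones; they test the claim rather than prove it.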

Don't get me wrong, I love assembly and we should always look at it because this is what is actually executed. But it is not always educational.

Here we'll focus on LLVM, and I'll assume familiarity with its basic environment, LLVM IR, etc. One standard path I follow to figure out which pass did a particular transformation is the following:

  • First of all, compile the C/C++ source code with Clang and tell it to output LLVM IR, usually with: -g0 -emit-llvm -S (note: I use a Clang compiled from source, which by default outputs IR with better value names. If you have a release build, you can get somewhat better naming by passing -fno-discard-value-names).
  • Remove attributes (especially optnone) and other irrelevant stuff from the output and pass it to opt with -sroa. This converts the code from memory form (i.e. local variables in allocas, accessed with loads / stores) to SSA form (i.e. virtual registers and PHI nodes). It is way more readable that way, plus most optimizations can't work without it.
  • Run the output through opt again, now with an optimization level, e.g. -O1, and also with -print-after-all. This prints the IR after every pass, so you can identify which pass did what.

In this case, we'll see that most of the job was done by -indvars, i.e. IndVarSimplify (Godbolt snippet). Some other passes, namely -simplifycfg and -instcombine, simplify the code further, but the most important changes come from induction variable simplification.


define i32 @test(i32 %c, i32 %n) {
  %0 = icmp sgt i32 %n, %c
  %smax = select i1 %0, i32 %n, i32 %c
  %1 = add i32 %smax, 2
  %2 = sub i32 %1, %c
  %3 = udiv i32 %2, 3
  %4 = mul nuw i32 %3, 3
  br label %while.cond

while.cond:                                       ; preds = %while.body, %entry
  br i1 false, label %while.body, label %while.end

while.body:                                       ; preds = %while.cond
  br label %while.cond

while.end:                                        ; preds = %while.cond
  %5 = add i32 %c, %4
  ret i32 %5
}

Let's actually run -simplifycfg (Godbolt snippet) to remove the dead blocks left over from the while loop.


define i32 @test(i32 %c, i32 %n) {
  %0 = icmp sgt i32 %n, %c
  %smax = select i1 %0, i32 %n, i32 %c
  %1 = add i32 %smax, 2
  %2 = sub i32 %1, %c
  %3 = udiv i32 %2, 3
  %4 = mul nuw i32 %3, 3
  %5 = add i32 %c, %4
  ret i32 %5
}

If we assume for a second that %smax is %n, then LLVM has generated exactly what we computed above.

%smax is only there to handle the case where we never enter the loop, in which the initial value of c (i.e. %c in the LLVM IR) has to be returned. This is a smart trick from LLVM: it uses the max of n and c in place of n. If c is bigger (or equal), plugging it into the formula we computed above cancels it with the - c, so number_of_iterations = (step - 1) / step, which is always 0, and what we return is the initial value.
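Spelling the simplified IR out in C makes the correspondence with our formula easy to see. This is a hypothetical transcription (the name test_from_ir is made up), using unsigned arithmetic to mirror the IR's wrapping add/sub and its udiv. The sketch is only meant for inputs where max(n, c) - c + 2 fits in 32 bits, so the checks stick to small values.

```c
#include <assert.h> /* for the usage checks */

/* Hypothetical C transcription of the IR after -indvars and -simplifycfg. */
int test_from_ir(int c, int n) {
  int smax = n > c ? n : c;               /* %0, %smax: max(n, c) */
  unsigned t = (unsigned)smax + 2u;       /* %1 = %smax + 2 (step - 1) */
  unsigned d = t - (unsigned)c;           /* %2 = %1 - %c */
  unsigned iterations = d / 3u;           /* %3 = udiv %2, 3 */
  unsigned increment = iterations * 3u;   /* %4 = %3 * 3 */
  return (int)((unsigned)c + increment);  /* %5 = %c + %4 */
}
```

When c >= n, smax is c, d becomes 2 (that is, step - 1), iterations is 0, and the function returns c unchanged, which is exactly the trick described above.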

A Sneak Peek Into the LLVM Source Code

At the time of writing, the whole job is done by rewriteExitValues(), which is called from IndVarSimplify::run(). This function uses LCSSA form to find the exit values of the loop (only those that are actually used outside of the loop, of course), then uses SCEV to analyze them and rewrite their values.

It's fascinating that SCEV has captured the whole expression we came up with as a single SCEV expression. The rest of the code basically turns this expression into IR and places it in the preheader (here, entry).