AVX-512でFizzBuzz

はじめに

なんか最近FizzBuzzが流行ってるみたいなんで、便乗してAVX-512を使ってFizzBuzzやってみます。

方針

  • AVX-512の命令を無駄に使って16個ずつ処理する
  • Fizz,Buzz,FizzBuzzを出力するかわりに、-1,-2,-3をメモリのバッファに書き込む

Fizzの処理

AVX2までは、YMMなどのレジスタの最上位ビットを使ってマスク処理をしていましたが、AVX-512にはk0からk7まで、opmask registerというレジスタがあり、これを使ってマスク処理できます。AVX-512Fでは16bitです。

例えば3の倍数のインデックスだけマスク処理したい場合、

__mmask16 s3 = (1 << 2) + (1 << 5) + (1 << 8) + (1 << 11) + (1 << 14);

というような16bitのマスクを作って、そこだけblendすれば良いことになります。

例えば

int s[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
__m512i a = _mm512_load_epi32(s);

で、ZMMレジスタに1から16の数字が乗ります。また、

__m512i fizz = _mm512_set1_epi32(-1);

で、-1が16個乗ったZMMレジスタを作ることができます。

これを、マスクを使ってブレンドします。

__m512i r = _mm512_mask_blend_epi32(s3, a, fizz);

すると、rには「1,2,-1,4,5,-1,7,8,-1,10,11,-1,13,14,-1,16」という数字列が入るので、これをメモリに書き戻してやればよいことになります。これで16個同時にFizzの処理が完了です。

さて、次のループでは、先程のマスクレジスタの値を変えてやる必要があります。先程のレジスタのビット列は

0100100100100100 // 16〜1のマスク

でした。いま、1〜16の16個処理が終わり、次のループは17から始まるため、それに合わせてマスクレジスタを変更する必要があります。具体的には16は3で割ると1余るので1つ右にシフトします。ただし、3回に一度、最上位bitを立ててやる必要があります。こんな感じです。

0100100100100100  // 16〜1のマスク
0010010010010010  // 32〜17のマスク
1001001001001001  // 48〜33のマスク
...

なんかビット演算を使ってエレガントにできそうな気がしますが、手抜きでこんな感じにしましょう。

s3 = (s3 >> 1) | ((i % 3 == 1) << 15);

iはループカウンタです。

全く同様にしてBuzzもできます。

FizzBuzzの処理

さて、3の倍数かつ5の倍数の場合にはFizzBuzzを表示するのでした。ここでは当該インデックスの場所に-3を書き込むことにしましょう。

いま、3の倍数マスクレジスタがs3に、5の倍数マスクレジスタがs5に入ってるとすると、15の倍数レジスタはその論理積で与えられます。

__mmask16 s15 = _mm512_kand(s3, s5);

AVX-512では、opmask register同士の簡単な演算が定義されています。この場合はkandwが呼ばれます。こうしてできたs15を使ってブレンドをすればFizzBuzzができたことになります。

ソース

ソースコードはこんな感じになりました。

#include <x86intrin.h>
#include <cstdio>

const int N = 65536;
int data[N] = {};

int
main(void) {
  int s[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  __mmask16 s3 = (1 << 2) + (1 << 5) + (1 << 8) + (1 << 11) + (1 << 14);
  __mmask16 s5 = (1 << 4) + (1 << 9) + (1 << 14);
  __m512i a = _mm512_load_epi32(s);
  __m512i d = _mm512_set1_epi32(16);
  __m512i fizz = _mm512_set1_epi32(-1);
  __m512i buzz = _mm512_set1_epi32(-2);
  __m512i fizzbuzz = _mm512_set1_epi32(-3);
  for (int i = 0; i < (N >> 4); i++) {
    __m512i r = _mm512_mask_blend_epi32(s3, a, fizz);
    r = _mm512_mask_blend_epi32(s5, r, buzz);
    __mmask16 s15 = _mm512_kand(s3, s5);
    r = _mm512_mask_blend_epi32(s15, r, fizzbuzz);
    _mm512_store_epi32(data + i * 16, r);
    a = a + d;
    s3 = (s3 >> 1) | ((i % 3 == 1) << 15);
    s5 = (s5 >> 1) | ((i % 5 == 3) << 15);
  }
//結果の表示
  for (int i = 0; i < N; i++) {
    if (data[i] == -3) printf("FizzBuzz\n");
    else if (data[i] == -1) printf("Fizz\n");
    else if (data[i] == -2) printf("Buzz\n");
    else printf("%d\n", data[i]);
  }
}

「結果の表示のところで事実上FizzBuzzやってるのと同じじゃね?」とか言わないのが大人です。

アセンブリも見てみましょうか。

..B1.3:
        movl      $1431655766, %eax                             #24.28
        imull     %esi                                          #24.28
        movl      $1717986919, %eax                             #25.28
        lea       (%rdx,%rdx,2), %r9d                           #24.28
        imull     %esi                                          #25.28
        kmovw     %r8d, %k1                                     #18.17
        negl      %r9d                                          #24.28
        kmovw     %ecx, %k2                                     #19.9
        vpblendmd %zmm2, %zmm4, %zmm5{%k1}                      #18.17
        vpaddq    %zmm3, %zmm4, %zmm4                           #23.13
        kandw     %k2, %k1, %k3                                 #20.21
        vpblendmd %zmm1, %zmm5, %zmm6{%k2}                      #19.9
        sarl      $1, %edx                                      #25.28
        addl      %esi, %r9d                                    #24.28
        xorl      %r10d, %r10d                                  #24.28
        cmpl      $1, %r9d                                      #24.5
        vpblendmd %zmm0, %zmm6, %zmm7{%k3}                      #21.9
        sete      %r10b                                         #24.5
        lea       (%rdx,%rdx,4), %r11d                          #25.28
        negl      %r11d                                         #25.28
        xorl      %eax, %eax                                    #25.28
        addl      %esi, %r11d                                   #25.28
        incl      %esi                                          #17.33
        cmpl      $3, %r11d                                     #25.5
        vmovups   %zmm7, data(%rdi)                             #22.24
        sete      %al                                           #25.5
        addq      $64, %rdi                                     #17.33
        sarl      $1, %r8d                                      #24.17
        shll      $15, %r10d                                    #24.39
        sarl      $1, %ecx                                      #25.17
        orl       %r10d, %r8d                                   #24.39
        shll      $15, %eax                                     #25.39
        orl       %eax, %ecx                                    #25.39
        cmpl      $4096, %esi                                   #17.23
        jl        ..B1.3        # Prob 99%                      #17.23

vpblendmdとかkandwとか、AVX-512の命令が使われていることがわかります。kmovwはできれば消したいところですが・・・・

まとめ

AVX-512の、主にopmask registerによるマスク処理を使ってFizzBuzzを書いてみました。SIMD化した風になってますが、普通に書くより早くなってるかどうかは知りません(試してないけど、多分遅い)。

参考

参考というか、これ見てこの記事を書きました。そっちを見たほうがよほど参考になる。