AVX-512で非ゼロ要素をパック

はじめに

配列に入ったデータのうち、非ゼロ要素だけ前に詰めたい。

つまり、こんなデータが与えられた時に、

0 0 1 2 0 0 0 3 0 0 0 0 0 0 0 4 0 5 6 7 0 0 8 9 0 0 10 0 0 11 12 13 

以下を出力したい。

1 2 3 4 5 6 7 8 9 10 11 12 13 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 

データがdataSIZE個入ってるとして、resultにパックしたデータを作る関数のシリアル版はこんな感じになるだろう。

void
pack_serial(void) {
  int pos = 0;
  for (int i = 0; i < SIZE; i++) {
    if (data[i] != 0) {
      result[pos] = data[i];
      pos++;
    }
  }
}

これをAVX-512を使って実装してみる。高速化というか、AVX-512の命令を使ってみたいだけです。

方針

せっかく512bit使えるので、16個同時に持ってきて、それを一度にパックしたい。その為には_mm512_permutevar_epi32を使えば良い。で、16個のデータを詰めて並び替える方法は2 ** 16 = 65536通りあるので、事前にテーブルを作ってしまおう。そのテーブルをoffset16とする。そのオフセットのインデックスを作るのには、非ゼロの要素のみビットが立ったマスクを作れば良い。これは_mm512_test_epi32_maskでできる。

まとめると、

  1. _mm512_loadu_si512 でデータを16個取ってくる
  2. _mm512_test_epi32_maskで非ゼロ要素のマスクを作成
  3. _mm512_loadu_si512でオフセットテーブルからオフセットを持ってくる
  4. _mm512_permutevar_epi32でオフセットを利用して並び替える
  5. _mm512_store_si512でデータをstore
  6. _mm_popcnt_u32でマスクのビットが立っている数だけ、ストア位置をずらす

という感じ。

例えば

0 0 1 2 0 0 0 3 0 0 0 0 0 0 0 4

という16個のデータを持ってきたら、ここから

0011000100000001

というマスクを作って、そこから

2 3 7 15 0 0 0 0 0 0 0 0 0 0 0 0 

というオフセットをテーブルから引いてきて、このオフセットを使って並び替えると、

1 2 3 4 0 0 0 0 0 0 0 0 0 0 0 0 

となる。これを繰り返せば良い。

これをそのまま実装するとこんな感じ。

void
pack_512(void) {
  int pos = 0;
  for (int i = 0; i < SIZE / 16; i++) {
    __m512i vdata = _mm512_loadu_si512(data + i * 16);
    __mmask16 vmask = _mm512_test_epi32_mask(vdata, vdata);
    __m512i voffset = _mm512_load_si512((__m512i const *)(offset16 + vmask * 16));
    __m512i vout = _mm512_permutevar_epi32(voffset, vdata);
    _mm512_store_si512((__m512i *)(result2 + pos), vout);
    pos += _mm_popcnt_u32(vmask);
  }
  for (int i = pos; i < pos + 16 && i < SIZE; i++) {
    result2[i] = 0;
  }
}

最後にゴミ掃除をつけてあるが、素直な実装なので難しいところは無いと思う。

3/28 追記: これ、データをアラインしてあっても、_mm512_store_si51の引数のresult2 + posがアラインされてるとは限らないから、状況によってSIGSEGVが出ますね。代わりに_mm512_storeu_si512を使わないとダメです。256bit版も同様。

256ビット版

上記だと、65536要素のテーブルを引いてるのがちょっとアレな気がするので、16個データを取ってはくるんだけど、それを8個ずつにわけてパックすることを考える。これだとテーブルは256要素で良くなる。

zmmにロードしたデータの下位256bitを_mm512_castsi512_si256で、上位256bitを_mm512_extracti64x4_epi64で取ったら、やることは512bit版とほぼ同じ。マスクの上位8bitと下位8bitの分離も必要。実装はこんな感じになる。

void
pack_256(void) {
  int pos = 0;
  for (int i = 0; i < SIZE / 16; i++) {
    __m512i vdata = _mm512_loadu_si512(data + i * 16);
    __mmask16 vmask = _mm512_test_epi32_mask(vdata, vdata);
    __m256i vlow = _mm512_castsi512_si256(vdata);
    __m256i vhigh = _mm512_extracti64x4_epi64(vdata, 1);
    int mask_low = vmask & 255;
    int mask_high = vmask >> 8;
    __m256i voffset = _mm256_load_si256((__m256i const *)(offset8 + mask_low * 8));
    __m256i vout_low = _mm256_permutevar8x32_epi32(vlow, voffset);
    voffset = _mm256_load_si256((__m256i const *)(offset8 + mask_high * 8));
    __m256i vout_high = _mm256_permutevar8x32_epi32(vhigh, voffset);
    _mm256_store_si256((__m256i *)(result3 + pos), vout_low);
    pos += _mm_popcnt_u32(mask_low);
    _mm256_store_si256((__m256i *)(result3 + pos), vout_high);
    pos += _mm_popcnt_u32(mask_high);
  }
  for (int i = pos; i < pos + 16 && i < SIZE; i++) {
    result3[i] = 0;
  }
}

512 bit + vpcompressd 版(2/10 追記)

コメント欄にて、vpcompressdという、そのものずばりの命令を教えていただいた。マスクレジスタの立ってるところだけ詰めた配列を作ってくれる。対応する組み込み関数は_mm512_mask_compressstoreu_epi32で、これを使うとpackルーチンはこう書ける。

void
pack_512c(void) {
  int pos = 0;
  for (int i = 0; i < SIZE / 16; i++) {
    __m512i vdata = _mm512_loadu_si512(data + i * 16);
    __mmask16 vmask = _mm512_test_epi32_mask(vdata, vdata);
    _mm512_mask_compressstoreu_epi32(result4 + pos, vmask, vdata);
    pos += _mm_popcnt_u32(vmask);
  }
}

テーブル引き不要。出力先をゼロクリアしておけばゴミ処理も不要。簡単すぎる・・・

結果 (2/10 追記)

131072個の配列のうち、半分だけ非ゼロ要素を作って、それを1000回パックするのにかかった時間を計測した。

  • Intel(R) Xeon Phi(TM) CPU 7250 @ 1.40GHz
  • icpc (ICC) 17.0.1 20161005
  • コンパイルオプション -std=c++11 -O3
方法 時間 [ms]
シリアル 1107
512 bit 291
256 bit 124
512 bit + vpcompressd 86

65536要素テーブル引き一回より、256要素テーブル引き二回やった方が倍近く早いけれど、vpcompressd使った方がさらに早い。

まとめ

適当にやったのでどこがボトルネックなのかも調べていないが、まぁシリアル版よりは10倍近く早くなったからよかった。SIMD化云々というよりは、単にメモリアクセスが減ったのが高速化の要因な気もする。あと、マスクレジスタ便利な気がする。

手抜きだけど、一応ソースは以下においておきます(vpcompressd版も追加)。

https://gist.github.com/kaityo256/c5e7a02eef60e98fe8b5b08638476825