LJの力計算のSIMD化ステップ・バイ・ステップ その5

はじめに

せっかくがんばってSIMD化したのに、速度が出なかった。このまま引き下がるわけにはいかない。

コードは https://github.com/kaityo256/lj_simdstep においてある。

何が悪かった?

やっぱり、ペア毎にループを回すのはダメだ。今早いコードは「i粒子でソートしたものをSIMD化した」コードだが、スカラで最速のコードは「i粒子でソートしたものをソフトウェアパイプライニングしたもの」なのだから「i粒子でソートしたものをソフトウェアパイプライニングしたものをSIMD化」するしかない。

ただし、ループ長が短いので、前回のようなごまかしは効かない。ちゃんとループの中で先読み計算したものを、ループの外で利用した上で、残りのゴミを計算しないといけない。面倒だがやるしかない。

コード

というわけでやってみた。SIMD化の元となるコードは、「i粒子でソートしたものをソフトウェアパイプライニングしたもの」で、以下のようなコードである。

void
force_sorted_swp(void) {
  const int pn = particle_number;
  for (int i = 0; i < pn; i++) {
    const double qx_key = q[i][X];
    const double qy_key = q[i][Y];
    const double qz_key = q[i][Z];
    double pfx = 0;
    double pfy = 0;
    double pfz = 0;
    const int kp = pointer[i];
    int ja = sorted_list[kp];
    double dxa = q[ja][X] - qx_key;
    double dya = q[ja][Y] - qy_key;
    double dza = q[ja][Z] - qz_key;
    double df = 0.0;
    double dxb = 0.0, dyb = 0.0, dzb = 0.0;
    int jb = 0;

    const int np = number_of_partners[i];
    for (int k = kp; k < np + kp; k++) {
      const double dx = dxa;
      const double dy = dya;
      const double dz = dza;
      double r2 = (dx * dx + dy * dy + dz * dz);
      const int j = ja;
      ja = sorted_list[k + 1];
      dxa = q[ja][X] - qx_key;
      dya = q[ja][Y] - qy_key;
      dza = q[ja][Z] - qz_key;
      if (r2 > CL2)continue;
      pfx += df * dxb;
      pfy += df * dyb;
      pfz += df * dzb;
      p[jb][X] -= df * dxb;
      p[jb][Y] -= df * dyb;
      p[jb][Z] -= df * dzb;
      const double r6 = r2 * r2 * r2;
      df = ((24.0 * r6 - 48.0) / (r6 * r6 * r2)) * dt;
      jb = j;
      dxb = dx;
      dyb = dy;
      dzb = dz;
    }
    p[jb][X] -= df * dxb;
    p[jb][Y] -= df * dyb;
    p[jb][Z] -= df * dzb;
    p[i][X] += pfx + df * dxb;
    p[i][Y] += pfy + df * dyb;
    p[i][Z] += pfz + df * dzb;
  }
}

これもいい加減長いが、さらに4倍展開してSIMD化する。その作業はこれまでたどった道なので書かない。結果がこちら。

void
force_sorted_swp_intrin(void) {
  const int pn = particle_number;
  const v4df vzero = _mm256_set_pd(0, 0, 0, 0);
  const v4df vcl2 = _mm256_set_pd(CL2, CL2, CL2, CL2);
  const v4df vc24 = _mm256_set_pd(24 * dt, 24 * dt, 24 * dt, 24 * dt);
  const v4df vc48 = _mm256_set_pd(48 * dt, 48 * dt, 48 * dt, 48 * dt);
  for (int i = 0; i < pn; i++) {
    const v4df vqi = _mm256_load_pd((double*)(q + i));
    v4df vpf = _mm256_set_pd(0.0, 0.0, 0.0, 0.0);
    const int kp = pointer[i];
    int ja_1 = sorted_list[kp];
    int ja_2 = sorted_list[kp + 1];
    int ja_3 = sorted_list[kp + 2];
    int ja_4 = sorted_list[kp + 3];
    v4df vqj_1 = _mm256_load_pd((double*)(q + ja_1));
    v4df vdqa_1 = vqj_1 - vqi;
    v4df vqj_2 = _mm256_load_pd((double*)(q + ja_2));
    v4df vdqa_2 = vqj_2 - vqi;
    v4df vqj_3 = _mm256_load_pd((double*)(q + ja_3));
    v4df vdqa_3 = vqj_3 - vqi;
    v4df vqj_4 = _mm256_load_pd((double*)(q + ja_4));
    v4df vdqa_4 = vqj_4 - vqi;

    v4df vdf = _mm256_set_pd(0.0, 0.0, 0.0, 0.0);

    v4df vdqb_1 = _mm256_set_pd(0.0, 0.0, 0.0, 0.0);
    v4df vdqb_2 = _mm256_set_pd(0.0, 0.0, 0.0, 0.0);
    v4df vdqb_3 = _mm256_set_pd(0.0, 0.0, 0.0, 0.0);
    v4df vdqb_4 = _mm256_set_pd(0.0, 0.0, 0.0, 0.0);

    int jb_1 = 0, jb_2 = 0, jb_3 = 0, jb_4 = 0;
    const int np = number_of_partners[i];
    for (int k = 0; k < (np / 4) * 4; k += 4) {
      const int j_1 = ja_1;
      const int j_2 = ja_2;
      const int j_3 = ja_3;
      const int j_4 = ja_4;
      v4df vdq_1 = vdqa_1;
      v4df vdq_2 = vdqa_2;
      v4df vdq_3 = vdqa_3;
      v4df vdq_4 = vdqa_4;

      ja_1 = sorted_list[kp + k + 4];
      ja_2 = sorted_list[kp + k + 5];
      ja_3 = sorted_list[kp + k + 6];
      ja_4 = sorted_list[kp + k + 7];

      v4df vr2s_1 = vdq_1 * vdq_1;
      v4df vr2t_1 = _mm256_permute4x64_pd(vr2s_1, 201);
      v4df vr2u_1 = _mm256_permute4x64_pd(vr2s_1, 210);
      v4df vr2_1 = vr2s_1 + vr2t_1 + vr2u_1;

      v4df vr2s_2 = vdq_2 * vdq_2;
      v4df vr2t_2 = _mm256_permute4x64_pd(vr2s_2, 201);
      v4df vr2u_2 = _mm256_permute4x64_pd(vr2s_2, 210);
      v4df vr2_2 = vr2s_2 + vr2t_2 + vr2u_2;

      v4df vr2s_3 = vdq_3 * vdq_3;
      v4df vr2t_3 = _mm256_permute4x64_pd(vr2s_3, 201);
      v4df vr2u_3 = _mm256_permute4x64_pd(vr2s_3, 210);
      v4df vr2_3 = vr2s_3 + vr2t_3 + vr2u_3;

      v4df vr2s_4 = vdq_4 * vdq_4;
      v4df vr2t_4 = _mm256_permute4x64_pd(vr2s_4, 201);
      v4df vr2u_4 = _mm256_permute4x64_pd(vr2s_4, 210);
      v4df vr2_4 = vr2s_4 + vr2t_4 + vr2u_4;

      v4df vdf_1 = _mm256_permute4x64_pd(vdf, 0);
      v4df vdf_2 = _mm256_permute4x64_pd(vdf, 85);
      v4df vdf_3 = _mm256_permute4x64_pd(vdf, 170);
      v4df vdf_4 = _mm256_permute4x64_pd(vdf, 255);

      vqj_1 = _mm256_load_pd((double*)(q + ja_1));
      vdqa_1 = vqj_1 - vqi;
      vpf += vdf_1 * vdqb_1;

      v4df vpjb_1 = _mm256_load_pd((double*)(p + jb_1));
      vpjb_1 -= vdf_1 * vdqb_1;
      _mm256_store_pd((double*)(p + jb_1), vpjb_1);

      vqj_2 = _mm256_load_pd((double*)(q + ja_2));
      vdqa_2 = vqj_2 - vqi;
      vpf += vdf_2 * vdqb_2;

      v4df vpjb_2 = _mm256_load_pd((double*)(p + jb_2));
      vpjb_2 -= vdf_2 * vdqb_2;
      _mm256_store_pd((double*)(p + jb_2), vpjb_2);

      vqj_3 = _mm256_load_pd((double*)(q + ja_3));
      vdqa_3 = vqj_3 - vqi;
      vpf += vdf_3 * vdqb_3;

      v4df vpjb_3 = _mm256_load_pd((double*)(p + jb_3));
      vpjb_3 -= vdf_3 * vdqb_3;
      _mm256_store_pd((double*)(p + jb_3), vpjb_3);

      vqj_4 = _mm256_load_pd((double*)(q + ja_4));
      vdqa_4 = vqj_4 - vqi;
      vpf += vdf_4 * vdqb_4;

      v4df vpjb_4 = _mm256_load_pd((double*)(p + jb_4));
      vpjb_4 -= vdf_4 * vdqb_4;
      _mm256_store_pd((double*)(p + jb_4), vpjb_4);

      v4df vr2_13 = _mm256_unpacklo_pd(vr2_1, vr2_3);
      v4df vr2_24 = _mm256_unpacklo_pd(vr2_2, vr2_4);
      v4df vr2 = _mm256_shuffle_pd(vr2_13, vr2_24, 12);
      v4df vr6 = vr2 * vr2 * vr2;
      vdf = (vc24 * vr6 - vc48) / (vr6 * vr6 * vr2);
      v4df mask = vcl2 - vr2;
      vdf = _mm256_blendv_pd(vdf, vzero, mask);

      jb_1 = j_1;
      jb_2 = j_2;
      jb_3 = j_3;
      jb_4 = j_4;
      vdqb_1 = vdq_1;
      vdqb_2 = vdq_2;
      vdqb_3 = vdq_3;
      vdqb_4 = vdq_4;
    }
    v4df vdf_1 = _mm256_permute4x64_pd(vdf, 0);
    v4df vdf_2 = _mm256_permute4x64_pd(vdf, 85);
    v4df vdf_3 = _mm256_permute4x64_pd(vdf, 170);
    v4df vdf_4 = _mm256_permute4x64_pd(vdf, 255);

    v4df vpjb_1 = _mm256_load_pd((double*)(p + jb_1));
    vpjb_1 -= vdf_1 * vdqb_1;
    _mm256_store_pd((double*)(p + jb_1), vpjb_1);

    v4df vpjb_2 = _mm256_load_pd((double*)(p + jb_2));
    vpjb_2 -= vdf_2 * vdqb_2;
    _mm256_store_pd((double*)(p + jb_2), vpjb_2);

    v4df vpjb_3 = _mm256_load_pd((double*)(p + jb_3));
    vpjb_3 -= vdf_3 * vdqb_3;
    _mm256_store_pd((double*)(p + jb_3), vpjb_3);

    v4df vpjb_4 = _mm256_load_pd((double*)(p + jb_4));
    vpjb_4 -= vdf_4 * vdqb_4;
    _mm256_store_pd((double*)(p + jb_4), vpjb_4);

    v4df vpi = _mm256_load_pd((double*)(p + i));
    vpf += vdf_1 * vdqb_1;
    vpf += vdf_2 * vdqb_2;
    vpf += vdf_3 * vdqb_3;
    vpf += vdf_4 * vdqb_4;
    vpi += vpf;
    _mm256_store_pd((double*)(p + i), vpi);
    const double qix = q[i][X];
    const double qiy = q[i][Y];
    const double qiz = q[i][Z];
    double pfx = 0.0;
    double pfy = 0.0;
    double pfz = 0.0;
    for (int k = (np / 4) * 4; k < np; k++) {
      const int j = sorted_list[k + kp];
      double dx = q[j][X] - qix;
      double dy = q[j][Y] - qiy;
      double dz = q[j][Z] - qiz;
      double r2 = (dx * dx + dy * dy + dz * dz);
      double r6 = r2 * r2 * r2;
      double df = ((24.0 * r6 - 48.0) / (r6 * r6 * r2)) * dt;
      if (r2 > CL2) df = 0.0;
      pfx += df * dx;
      pfy += df * dy;
      pfz += df * dz;
      p[j][X] -= df * dx;
      p[j][Y] -= df * dy;
      p[j][Z] -= df * dz;
    }
    p[i][X] += pfx;
    p[i][Y] += pfy;
    p[i][Z] += pfz;
  }
}

うん、ちょっと気が滅入るコード量。普通の気分なら絶対やらなかったな。

結果

というわけで結果。計算しなおしたから、これまでと微妙に数値が変わるけど気にしないこと。

計算手法 実行時間[sec] 備考
pair 8.175502 ペアごとにループをまわしたもの
pair_swp 6.994964 上記をSWPしたもの
pair_swp_intrin 4.273858 上記を4倍展開してSIMD化したもの
sorted 7.072854 i粒子でソートしたもの
sorted_intrin 3.880328 i粒子でソートしたものをSIMD化したもの
sorted_swp 7.072854 i粒子でソートしたものをSWPしたもの
sorted_swp_intrin 3.278851 i粒子でソートしたものをSWPしたものをSIMD化したもの

まぁ若干早くなりましたな。上記は高密度(数密度1.0)の結果だが、低密度(数密度0.5)の結果はこんな感じ。

計算手法 実行時間[sec] 備考
pair 1.692347 ペアごとにループをまわしたもの
pair_swp 1.747897 上記をSWPしたもの
pair_swp_intrin 1.197281 上記を4倍展開してSIMD化したもの
sorted 1.523837 i粒子でソートしたもの
sorted_intrin 1.232721 i粒子でソートしたものをSIMD化したもの
sorted_swp 1.069575 i粒子でソートしたものをSWPしたもの
sorted_swp_intrin 0.965486 i粒子でソートしたものをSWPしたものをSIMD化したもの

まぁ、少しだけSWP(ソフトウェアパイプライニング)の効果はあったかな。

コードは以下に。

https://github.com/kaityo256/lj_simdstep/tree/master/step5

まとめ

というわけでSIMD化してみた。少なくとも馬鹿SIMD化は、データ構造とループ構造が決まっていれば単純作業なので難しくはない。ただし、SIMD化後の性能向上のかなりの部分は「SIMD化前のループ構造」で決まることがわかった。あ、

SIMD化の性能はSIMD化前に決まっている

ってなんか名言っぽくないですか?

おしまい・・・のはずが続いた