一般の逆数m乗根近似補正公式とAVX512における逆数平方根近似補正

はじめに

AVX-512(より正確にはAVX-512ER)には、vrcp28pdという逆数近似命令の他に、vrsqrt28pdという逆数平方根近似命令がある。やっぱり28bitなので、一度の補正でフル精度出るはずなのでそれを試してみる。ついでに(意味なく)一般のm乗根の逆数の近似値が与えられた時の補正公式を導出してみる。

公式の導出

ある数$a$が与えられた時、$a^{-1/m}$に収束するような数列を作りたい。具体的にはニュートン法

\[x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}\]

において、$x_n$が$a^{-1/m}$に収束するようにしたい。ただし、$a$の逆数$m$乗根近似の補正をするという目的から、右辺に$a$に関する除算が出現してはならないという条件がある。もちろん$a$の$m$乗根も出現してはならない。

まず、ニュートン法の条件として、目的の値になったら$f(x)=0$となる必要がある。従って、$f(a^{-1/m})=0$である。これを満たす関数として、こんな関数形を考える。

\[f(x) = (1-a x^m)g(x)\]

これで条件$f(a^{-1/m})=0$は満たされる。次に、これでニュートン法を構築した時に、$a$に関する除算や$m$乗根が現れないという条件を考える。ニュートン法の公式を見るに、$f’(x)$に$a$が現れなければその条件を満たす。微分を計算してみると、

\[f' = -m a x^{m-1}g + (1- a x^m) g'\]

$a$に関して整理すると、

\[f' = -m a x^{m-1}(mg + xg') + g'\]

これが$a$依存性を持たなければ良いのだから、

\[mg + xg' = 0\]

が満たされれば良い。これは簡単に求積できて

\[g(x) = c x^{-m}\]

である。cは積分定数だが、ニュートン法に影響を与えないので$c=1$にしてしまおう。最終的に

\[f(x) = \frac{1-ax^m}{x^m}\]

となり、ニュートン法のスキームは

\[x_{n+1} = \frac{x_n(m+1 - a x^m)}{m}\]

と、加減乗算のみで構成できた。無論、$m=1$を代入すれば逆数近似の補正公式に一致する。

逆数平方根近似

というわけで逆数平方根近似を試す。まずは無補正。コードはほとんど逆数近似の記事と同じだが、__m512dの型でやると計算をいちいち組み込みを使う必要があって面倒なので

typedef double v8df __attribute__((vector_size(64)));

として、v8dfで計算することにしよう。こんな感じ。

double out[8];

int
main(){
  std::mt19937 mt(1);
  std::uniform_real_distribution<double> ud(0.0,1.0);
  for(int i=0;i<10;i++){
    double a = ud(mt);
    v8df z = _mm512_set1_pd(a);
    v8df zrsqrt = _mm512_rsqrt28_pd(z);
    double arsqrt = sqrt(1.0/a);
    _mm512_storeu_pd(out, zrsqrt);
    bitdump(arsqrt);
    bitdump(out[0]);
    printf("%d\n",bitcomp(arsqrt,out[0]));
    printf("\n");
  }
}

実行結果はこんな感じ。

1.0014105748636146
0 01111111111 0000000001011100011100011000010011101000011001001001
1.0014105745515209
0 01111111111 0000000001011100011100011000001110010001001111100000
29

(snip)

1.0870137260928034
0 01111111111 0001011001000110100010000001001111011111001001110100
1.0870137258219525
0 01111111111 0001011001000110100010000001001010110101010110011000
31

フル精度で計算した結果が「1.0014105748636146」で、逆数平方根近似した値が「1.0014105745515209」で、仮数部が29ビット合っている、という意味。もちろん、前の記事に書いたとおり、単純に一致しているビットを数えているだけなので、本来の精度はもっと高い。とりあえず結果を眺めてみると名前の通り最低28ビット出てるらしいことはわかる。

逆数平方根近似+補正

補正公式に$m=2$を代入すれば逆数平方根近似の補正公式

\(x_{n+1} = \frac{x_n(3 - a x_n^2)}{2}\) を得るので、それで補正するだけ。こんな感じ。

int
main(){
  std::mt19937 mt(1);
  std::uniform_real_distribution<double> ud(0.0,1.0);
  v8df v3 = _mm512_set1_pd(3.0);
  v8df vh = _mm512_set1_pd(0.5);
  for(int i=0;i<10;i++){
    double a = ud(mt);
    v8df z = _mm512_set1_pd(a);
    v8df zrsqrt = _mm512_rsqrt28_pd(z);
    v8df zrsqrt2 = zrsqrt * (v3 - z * zrsqrt * zrsqrt)*vh;
    double arsqrt = sqrt(1.0/a);
    _mm512_storeu_pd(out, zrsqrt2);
    bitdump(arsqrt);
    bitdump(out[0]);
    printf("%d\n",bitcomp(arsqrt,out[0]));
    printf("\n");
  }

実行結果。

1.0014105748636146
0 01111111111 0000000001011100011100011000010011101000011001001001
1.0014105748636146
0 01111111111 0000000001011100011100011000010011101000011001001001
52

(snip)

1.0870137260928034
0 01111111111 0001011001000110100010000001001111011111001001110100
1.0870137260928034
0 01111111111 0001011001000110100010000001001111011111001001110100
52

ちゃんとフル精度出てるみたいですね。

まとめ

無意味に一般の$m$乗根補正公式を導出し1、それを使ってAVX-512ERで追加されたvrsqrt28pd命令と、その精度補正を試してみた。少なくともインテルコンパイラで試した範囲では、積極的に逆数近似、逆数平方根近似を使っているみたいですね。やっぱり一度の補正でフル精度出る、というのはコンパイラにとっても使いやすいんじゃないでしょうか。

  1. 本当はどこかで逆数平方根近似補正公式を先に見つけて、それと逆数近似公式からあたりを付けて「導出」したフリをしただけだけど。