/* This program is part of a faithfully-rounded, division-free l2
   vector norm with avoidance of spurious underflow and overflow.
   
   Author: Christoph Lauter,

           Université Pierre et Marie Curie Paris 6, UPMC, LIP6,
	   PEQUAN team.

   implementing the algorithm described in the paper

   Graillat, Lauter, Tang, Yamanaka and Oishi: Efficient calculations
   of faithfully rounded l2-norms of n-vectors.

   This program is

   Copyright (C) 2014 Université Pierre et Marie Curie Paris 6, UPMC,
   LIP6, PEQUAN team.

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2.1 of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Lesser General Public License for more details.

   You should have received a copy of the GNU Lesser General Public
   License along with the GNU C Library; if not, see
   <http://www.gnu.org/licenses/>. 

*/

#include "faithfulnorm.h"

#include <immintrin.h>
#include <stdint.h>

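/* Compile me with 

   gcc -std=c99 -O2 -mavx -c faithfulnormnooverflow.c 

   AVX support is required: the code below uses the 256-bit
   _mm256_* intrinsics, for which -msse3 alone is not sufficient.

*/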



/* Get a declaration of sqrt.

   The prototype is written out by hand; <math.h> is not among the
   headers this file includes directly. */
double sqrt(double);

/* Get a declaration of fabs (hand-written prototype as well) */
double fabs(double);


/* Scalar (non-vectorized) twoAdd

   Error-free transformation of a sum: on return *u holds the
   rounded sum RN(a + b) and *v holds the rounding error of that
   addition, so that *u + *v = a + b exactly.

   No assumption is made on the relative magnitudes of a and b.

*/
STATIC INLINE void twoAdd(double * RESTRICT u, double * RESTRICT v, double a, double b) {
  double sum, bInSum, aInSum, bLost, aLost, err;

  sum = a + b;

  /* Recover which part of each operand actually made it into sum */
  bInSum = sum - a;
  aInSum = sum - bInSum;

  /* What is missing from each operand is the rounding error */
  bLost = b - bInSum;
  aLost = a - aInSum;
  err = aLost + bLost;

  *u = sum;
  *v = err;
}

/* Scalar (non-vectorized) fastTwoAdd

   Cheaper variant of twoAdd: on return *u holds RN(a + b) and *v
   the rounding error, so that *u + *v = a + b exactly.

   PRECONDITION: abs(a) >= abs(b). The function does not check
   this; the caller has to guarantee it.

*/
STATIC INLINE void fastTwoAdd(double * RESTRICT u, double * RESTRICT v, double a, double b) {
  double sum, bInSum, err;

  sum = a + b;
  bInSum = sum - a;   /* the part of b that was absorbed into sum */
  err = b - bInSum;   /* the part of b that was rounded away */

  *u = sum;
  *v = err;
}

/* Scalar (non-vectorized) double-double addition

   Adds the two double-double numbers ah + al and bh + bl and stores
   an approximation of the sum as the double-double *u + *v.

   The high parts are added with twoAdd; the low parts and the
   rounding error of the high sum are then folded together in plain
   double arithmetic and renormalized against the high sum with
   fastTwoAdd.

*/
STATIC INLINE void doubleDoubleAdd(double * RESTRICT u, double * RESTRICT v, 
                                   double ah, double al, double bh, double bl) {
  double hiSum, hiErr, lowOrder;

  twoAdd(&hiSum, &hiErr, ah, bh);
  lowOrder = hiErr + (al + bl);
  fastTwoAdd(u, v, hiSum, lowOrder);
}

/* Scalar (non-vectorized) twoSquare

   Error-free squaring: on return *u holds RN(a^2) and *v the
   rounding error, so that *u + *v = a^2 exactly.

   No fused multiply-add is used. Instead a is split into a high
   and a low half by multiplying with the splitter constant
   2^27 + 1, and the error term is rebuilt from the partial
   products of the two halves.

*/
STATIC INLINE void twoSquare(double * RESTRICT u, double * RESTRICT v, double a) {
  STATIC CONST double splitter = 134217729.0;   /* 2^27 + 1 */
  double scaled, hi, lo, square, err;

  /* Split a into hi + lo */
  scaled = a*splitter;
  hi = (a-scaled)+scaled;
  lo = a-hi;

  /* Rounded square and its rounding error */
  square = a*a;
  err = (((hi*hi-square)+(hi*lo))+(lo*hi))+(lo*lo);

  *u = square;
  *v = err;
}

STATIC INLINE double scaledSquareRoot(int m, double S) {
  STATIC CONST double sqrtGammaToM[] = { 5.26013590154837350724098988288012866555033980282317e210, /* gamma^-2/2 = (2^-350)^-2 */
                                         2.29349861599007151161082089530208694079656498916828e105, /* gamma^-1/2 = (2^-350)^-1 */
                                         1.0,                                                      /* gamma^0/2  = (2^-350)^0  */
                                         4.3601508761683463371878950054372442606045150559341e-106, /* gamma^1/2  = (2^-350)^1  */
                                         1.9010915662951598235150724058351031092648712063735e-211  /* gamma^2/2  = (2^-350)^2  */ };
  double Z, y;

  Z = sqrt(S);
  y = sqrtGammaToM[m - (-2)] * Z;

  return y;
}

STATIC INLINE double faithfulFinal(int k, double U, double u, double V, double v) {
  STATIC CONST double bound1    = 3.08469742733169167070816004443201143863168634177267e-179; /* beta_lo^2 / eps^3 = 2^-593 */
  STATIC CONST double bound2    = 1.4396524142538228424993723224595141948383030778566e163;   /* beta_hi^2 * eps^2 = 2^542  */
  STATIC CONST double gamma     = 1.9010915662951598235150724058351031092648712063735e-211;  /* gamma             = 2^-700 */
  STATIC CONST double gammaToM1 = 5.26013590154837350724098988288012866555033980282317e210;  /* gamma^-1          = 2^700  */
  int m;
  double S, s;
  
  if ((k == 2) || ((U >= bound1) || (V <= bound2))) /* First two branches of the article pseudo-code */
    return scaledSquareRoot(k,U); 
  
  if (fabs(v) <= bound2) v = 0.0;
  U *= gammaToM1;
  u *= gammaToM1;
  V *= gamma;
  v *= gamma;
  m = k + 1;
  
  doubleDoubleAdd(&S,&s,U,u,V,v);
  
  return scaledSquareRoot(m, S);
}

/* Our faithful norm, AVX version (four doubles per 256-bit vector)

   Returns an approximation of sqrt(x[0]^2 + ... + x[n-1]^2).

   Each element is sorted by magnitude into one of three classes:

     A: abs(x) >= 2^326             squared after scaling by 2^-700
     B: 2^-374 <= abs(x) < 2^326    squared as is
     C: abs(x) < 2^-374             squared after scaling by 2^700

   Every vector lane keeps two double-double accumulators: "Mid" for
   class B and "Ext", which is shared by classes A and C. As long as
   a lane has seen no class A value, its Ext accumulator collects
   class C squares. The first class A value in a lane clears that
   lane's Ext accumulator; from then on it collects class A squares
   and class C values arriving in that lane are dropped.

   x  input vector, read with unaligned loads
   n  number of elements; any value is accepted. If n is not a
      multiple of 4, the last chunk is padded with zeros, which
      contribute nothing to either accumulator. For n == 0 the
      result is 0.0.

   NOTE(review): both bin comparisons are ordered, so a NaN element
   falls into no class and is silently ignored - confirm that this
   is intended.

*/
double faithfulNorm(CONST double * RESTRICT x, unsigned int n) {
  unsigned int i, j, chunk, nFullChunks, nTail, nChunks;
  dblcast dbc;
  __m256d xChunk, xAbs, xA, xSMid, xC, absMask, boundA, boundB;
  __m256d maskA, maskB, maskC, maskT, scaleA, scaleC, scaleExt, xExt, xSExt, two700, twoM700, maskCP;
  __m256d splitC, bigOccured, bigOccuredNew;
  __m256d sMid, rMid, sExt, rExt, discardExt;
  __m256d xSMidP, xSMidh, xSMidl, xSMidMP, pMidh, pMidl, sMidh, sMidl, uMid1, uMid2, uMid3, uMid4, tMid;
  __m256d xSExtP, xSExth, xSExtl, xSExtMP, pExth, pExtl, sExth, sExtl, uExt1, uExt2, uExt3, uExt4, tExt;
  __m256d sA, rA, sB, rB, sC, rC;
  double sfA, rfA, sfB, rfB, sfC, rfC;
  double tail[4];
  double sLanes[4], rLanes[4];

  /* Nothing to accumulate */
  if (n == 0u) return 0.0;

  /* Split the vector into full chunks of four and an optional
     shorter tail. Counting chunks (instead of testing i <= n-4)
     keeps the unsigned arithmetic from wrapping around for n < 4.
  */
  nFullChunks = n / 4u;
  nTail = n % 4u;
  nChunks = nFullChunks + ((nTail != 0u) ? 1u : 0u);

  /* Load a mask that strips the sign of a double */
  dbc.i = 0x7fffffffffffffffull;
  absMask = _mm256_set1_pd(dbc.d);

  /* Load two bounds, 2^326 and 2^-374 */
  dbc.i = 0x5450000000000000ull;
  boundA = _mm256_set1_pd(dbc.d);   /* 2^326 */
  
  dbc.i = 0x2890000000000000ull;
  boundB = _mm256_set1_pd(dbc.d);   /* 2^-374 */

  /* Load two scaling factors 2^-700 and 2^700 */
  dbc.i = 0x1430000000000000ull;
  twoM700 = _mm256_set1_pd(dbc.d);   /* 2^-700 */
  
  dbc.i = 0x6bb0000000000000ull;
  two700 = _mm256_set1_pd(dbc.d);   /* 2^700 */

  /* Load the double precision splitter constant */
  dbc.i = 0x41a0000002000000ull;
  splitC = _mm256_set1_pd(dbc.d);   /* 2^27 + 1 */
  
  /* Initialize the accumulators to zero */
  sMid = _mm256_setzero_pd();
  rMid = _mm256_setzero_pd();
  sExt = _mm256_setzero_pd();
  rExt = _mm256_setzero_pd();

  /* A per-lane flag if we ever saw a large (A) bin value */
  bigOccured = _mm256_setzero_pd();
  
  /* Main loop over the vector, four elements at a time */
  for (chunk=0; chunk<nChunks; chunk++) {
    i = 4u * chunk;
    if (chunk < nFullChunks) {
      xChunk = _mm256_loadu_pd(&(x[i]));
    } else {
      /* Last, incomplete chunk: pad with zeros. A zero lane goes to
         class C (or is dropped) and adds 0 to the accumulators. */
      tail[0] = 0.0;
      tail[1] = 0.0;
      tail[2] = 0.0;
      tail[3] = 0.0;
      for (j=0; j<nTail; j++) tail[j] = x[i + j];
      xChunk = _mm256_loadu_pd(tail);
    }
    xAbs = _mm256_and_pd(xChunk, absMask);

    /* Classify each lane: maskA = class A, maskB = class B,
       maskC = class C */
    maskA = _mm256_cmp_pd(xAbs, boundA, _CMP_GE_OS);
    maskC = _mm256_cmp_pd(xAbs, boundB, _CMP_LT_OS);
    maskT = _mm256_cmp_pd(xAbs, boundB, _CMP_GE_OS);
    maskB = _mm256_andnot_pd(maskA, maskT);
    xA = _mm256_and_pd(maskA, xAbs);
    xSMid = _mm256_and_pd(maskB, xAbs);

    bigOccuredNew = _mm256_or_pd(bigOccured, maskA);
    
    /* Class C values only count in lanes that never saw class A */
    maskCP = _mm256_andnot_pd(bigOccuredNew, maskC);

    xC = _mm256_and_pd(maskCP, xAbs);
    xExt = _mm256_or_pd(xA, xC);

    /* Scale class A down by 2^-700 and class C up by 2^700 */
    scaleA = _mm256_and_pd(maskA, twoM700);
    scaleC = _mm256_and_pd(maskCP, two700);
    scaleExt = _mm256_or_pd(scaleA, scaleC);
    xSExt = _mm256_mul_pd(xExt, scaleExt);
   
    /* Lanes that see their first class A value right now: their Ext
       accumulator still holds class C data and must be cleared */
    discardExt = _mm256_xor_pd(bigOccuredNew, bigOccured);

    bigOccured = bigOccuredNew;

    sExt = _mm256_andnot_pd(discardExt, sExt);
    rExt = _mm256_andnot_pd(discardExt, rExt);

    /* Mid accumulator: (pMidh, pMidl) = xSMid^2 as in twoSquare,
       then (sMid, rMid) += (pMidh, pMidl) as in doubleDoubleAdd */
    xSMidP = _mm256_mul_pd(xSMid, splitC);
    xSMidh = _mm256_add_pd(_mm256_sub_pd(xSMid, xSMidP), xSMidP);
    xSMidl = _mm256_sub_pd(xSMid, xSMidh);
    xSMidMP = _mm256_mul_pd(xSMidh, xSMidl);
    pMidh = _mm256_mul_pd(xSMid, xSMid);
    pMidl = _mm256_add_pd(_mm256_add_pd(_mm256_add_pd(_mm256_sub_pd(_mm256_mul_pd(xSMidh, xSMidh), 
                                                                    pMidh),
                                                      xSMidMP),
                                        xSMidMP),
                          _mm256_mul_pd(xSMidl, xSMidl));
    sMidh = _mm256_add_pd(sMid, pMidh);
    uMid1 = _mm256_sub_pd(sMidh, sMid);
    uMid2 = _mm256_sub_pd(sMidh, uMid1);
    uMid3 = _mm256_sub_pd(pMidh, uMid1);
    uMid4 = _mm256_sub_pd(sMid, uMid2);
    sMidl = _mm256_add_pd(uMid3, uMid4);
    tMid = _mm256_add_pd(rMid, _mm256_add_pd(pMidl, sMidl));

    sMid = _mm256_add_pd(sMidh, tMid);
    rMid = _mm256_sub_pd(tMid, _mm256_sub_pd(sMid, sMidh));


    /* Ext accumulator: same sequence on the scaled A/C values */
    xSExtP = _mm256_mul_pd(xSExt, splitC);
    xSExth = _mm256_add_pd(_mm256_sub_pd(xSExt, xSExtP), xSExtP);
    xSExtl = _mm256_sub_pd(xSExt, xSExth);
    xSExtMP = _mm256_mul_pd(xSExth, xSExtl);
    pExth = _mm256_mul_pd(xSExt, xSExt);
    pExtl = _mm256_add_pd(_mm256_add_pd(_mm256_add_pd(_mm256_sub_pd(_mm256_mul_pd(xSExth, xSExth), 
                                                                    pExth),
                                                      xSExtMP),
                                        xSExtMP),
                          _mm256_mul_pd(xSExtl, xSExtl));
    sExth = _mm256_add_pd(sExt, pExth);
    uExt1 = _mm256_sub_pd(sExth, sExt);
    uExt2 = _mm256_sub_pd(sExth, uExt1);
    uExt3 = _mm256_sub_pd(pExth, uExt1);
    uExt4 = _mm256_sub_pd(sExt, uExt2);
    sExtl = _mm256_add_pd(uExt3, uExt4);
    tExt = _mm256_add_pd(rExt, _mm256_add_pd(pExtl, sExtl));

    sExt = _mm256_add_pd(sExth, tExt);
    rExt = _mm256_sub_pd(tExt, _mm256_sub_pd(sExt, sExth));
  }

  sB = sMid;
  rB = rMid;

  /* A lane's Ext accumulator belongs to class A if the lane ever saw
     a class A value and to class C otherwise */
  sA = _mm256_and_pd(sExt, bigOccured);
  rA = _mm256_and_pd(rExt, bigOccured);

  sC = _mm256_andnot_pd(bigOccured, sExt);
  rC = _mm256_andnot_pd(bigOccured, rExt);
  
  /* After the main loop, we have to compute an approximation to

     sqrt(2^1400 * (sA1 + rA1) + (sB1 + rB1) + 2^-1400 * (sC1 + rC1) + 
          2^1400 * (sA2 + rA2) + (sB2 + rB2) + 2^-1400 * (sC2 + rC2) +
          2^1400 * (sA3 + rA3) + (sB3 + rB3) + 2^-1400 * (sC3 + rC3) +
          2^1400 * (sA4 + rA4) + (sB4 + rB4) + 2^-1400 * (sC4 + rC4)
         )

     where q1, q2, q3 and q4 are the four lanes of the AVX vector q.

     We start by extracting the different components (by storing the
     vectors to small arrays) and continue with adding up (with
     twoAdds and fastTwoAdds) the components of each of the three
     packages A, B and C.
  */
  
  _mm256_storeu_pd(sLanes, sA);
  _mm256_storeu_pd(rLanes, rA);
  doubleDoubleAdd(&sfA, &rfA, sLanes[0], rLanes[0], sLanes[1], rLanes[1]);
  doubleDoubleAdd(&sfA, &rfA, sfA, rfA, sLanes[2], rLanes[2]);
  doubleDoubleAdd(&sfA, &rfA, sfA, rfA, sLanes[3], rLanes[3]);

  _mm256_storeu_pd(sLanes, sB);
  _mm256_storeu_pd(rLanes, rB);
  doubleDoubleAdd(&sfB, &rfB, sLanes[0], rLanes[0], sLanes[1], rLanes[1]);
  doubleDoubleAdd(&sfB, &rfB, sfB, rfB, sLanes[2], rLanes[2]);
  doubleDoubleAdd(&sfB, &rfB, sfB, rfB, sLanes[3], rLanes[3]);

  _mm256_storeu_pd(sLanes, sC);
  _mm256_storeu_pd(rLanes, rC);
  doubleDoubleAdd(&sfC, &rfC, sLanes[0], rLanes[0], sLanes[1], rLanes[1]);
  doubleDoubleAdd(&sfC, &rfC, sfC, rfC, sLanes[2], rLanes[2]);
  doubleDoubleAdd(&sfC, &rfC, sfC, rfC, sLanes[3], rLanes[3]);

  /* The largest non-empty class dominates; the next smaller one is
     passed along as a possible correction */
  if (sfA != 0.0) return faithfulFinal(-2,sfA,rfA,sfB,rfB);
  if (sfB != 0.0) return faithfulFinal(0,sfB,rfB,sfC,rfC);
  return faithfulFinal(2,sfC,rfC,0.0,0.0);
}

/* A naive 2-norm

   Returns sqrt(x[0]^2 + ... + x[n-1]^2), accumulating the squares in
   a single double with no scaling and no error compensation. The
   squares can therefore overflow or underflow even if the norm
   itself is representable. For n == 0 the result is 0.0.

   x  input vector
   n  number of elements

*/
double naiveNorm(CONST double * RESTRICT x, unsigned int n) {
  double acc, t, y;
  unsigned int i;   /* same type as n: no signed/unsigned comparison,
                       no signed overflow for n > INT_MAX */

  acc = 0.0;
  for (i=0; i<n; i++) {
    t = x[i];
    acc += t * t;
  }

  y = sqrt(acc);

  return y;
}

/* A 2-norm implemented like NETLIB does it

   Keeps a running maximum "scale" of the absolute values seen so
   far and the sum "ssq" of the squares of the elements divided by
   that maximum, so that no element is squared at its full
   magnitude. Whenever a new maximum appears, ssq is rescaled to it.
   Zero elements are skipped.

   Returns scale * sqrt(ssq); for n == 0 or an all-zero vector this
   is 0.0.

   x  input vector
   n  number of elements

*/
double netlibNorm(CONST double * RESTRICT x, unsigned int n) {
  double scale, ssq, absxi, t;
  unsigned int i;   /* same type as n: no signed/unsigned comparison,
                       no signed overflow for n > INT_MAX */

  scale = 0.0;
  ssq = 1.0;

  for (i=0;i<n;i++) {
    if (x[i] != 0.0) {
      absxi = fabs(x[i]);
      if (scale < absxi) {
        /* New maximum: rescale the sum collected so far */
        t = scale / absxi;
        ssq = 1.0 + ssq * (t * t);
        scale = absxi;
      } else {
        t = absxi / scale;
        ssq = ssq + t * t;
      }
    }
  }
  
  return scale * sqrt(ssq);
}

/* A dummy function with the same signature as the norm functions

   Ignores both arguments and always returns 0.0.

*/
double dummyFunc(CONST double * RESTRICT x, unsigned int n) {
  /* Both parameters are unused on purpose */
  (void) x;
  (void) n;

  return 0.0;
}

