web-dev-qa-db-fra.com

Comparer deux valeurs du formulaire (a + sqrt (b)) aussi rapidement que possible?

Dans le cadre d'un programme que j'écris, je dois comparer deux valeurs sous la forme a + sqrt(b)a et b sont des entiers non signés. Comme cela fait partie d'une boucle serrée, j'aimerais que cette comparaison s'exécute le plus rapidement possible. (Si cela importe, j'exécute le code sur des machines x86-64 et les entiers non signés ne dépassent pas 10 ^ 6. De plus, je sais pertinemment que a1<a2.)

En tant que fonction autonome, c'est ce que j'essaie d'optimiser. Mes nombres sont des entiers suffisamment petits pour que double (ou même float) puissent les représenter exactement, mais une erreur d'arrondi dans les résultats sqrt ne doit pas changer le résultat.

// known pre-condition: a1 < a2  in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);  // computed mathematically exactly
}

Cas de test: is_smaller(900000, 1000000, 900001, 998002) devrait retourner true, mais comme indiqué dans les commentaires par @wim le calculer avec sqrtf() retournerait false. Il en serait de même de (int)sqrt() Pour être tronqué en entier.

a1+sqrt(b1) = 90100 et a2+sqrt(b2) = 901000.00050050037512481206. Le flotteur le plus proche est exactement 90100.


Comme la fonction sqrt() est généralement assez chère même sur les x86-64 modernes lorsqu'elle est entièrement insérée en tant qu'instruction sqrtsd, j'essaie d'éviter d'appeler sqrt() dans la mesure où possible.

La suppression de sqrt par quadrature peut également éviter tout risque d'erreurs d'arrondi en rendant tous les calculs exacts.

Si à la place la fonction était quelque chose comme ça ...

bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}

... alors je pourrais simplement faire return x-a1>=0 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;

Mais maintenant qu'il y a deux termes sqrt(...), je ne peux pas faire la même manipulation algébrique.

Je pourrais quadriller les valeurs deux fois , en utilisant cette formule:

      a1 + sqrt(b1) = a2 + sqrt(b2)
<==>  a1 - a2 = sqrt(b2) - sqrt(b1)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==>  (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2

La division non signée par 4 est bon marché parce que c'est juste un décalage de bits, mais comme je mets les nombres au carré deux fois, je devrai utiliser des entiers de 128 bits et je devrai introduire quelques vérifications >=0 (Parce que je compare l'inégalité au lieu de l'égalité).

Il semble qu'il puisse y avoir un moyen de le faire plus rapidement, en appliquant une meilleure algèbre à ce problème. Existe-t-il un moyen de le faire plus rapidement?

45
Bernard

Voici une version sans sqrt, mais je ne suis pas sûr qu'elle soit plus rapide qu'une version qui n'en a qu'un sqrt (cela peut dépendre de la distribution des valeurs).

Voici le calcul (comment supprimer les deux sqrts):

ad = a2-a1
bd = b2-b1

a1+sqrt(b1) < a2+sqrt(b2)              // subtract a1
   sqrt(b1) < ad+sqrt(b2)              // square it
        b1  < ad^2+2*ad*sqrt(b2)+b2    // arrange
   ad^2+bd  > -2*ad*sqrt(b2)

Ici, le côté droit est toujours négatif. Si le côté gauche est positif, alors nous devons retourner vrai.

Si le côté gauche est négatif, alors nous pouvons quadriller l'inégalité:

ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2

L'élément clé à noter ici est que si a2>=a1+1000, Alors is_smaller Renvoie toujours true (car la valeur maximale de sqrt(b1) est 1000). Si a2<=a1+1000, Alors ad est un petit nombre, donc ad^4 Tiendra toujours en 64 bits (il n'est pas nécessaire d'avoir une arithmétique de 128 bits). Voici le code:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    if (ad>1000) {
        return true;
    }

    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;

    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

EDIT: Comme Peter Cordes l'a remarqué, le premier if n'est pas nécessaire, comme le second if le gère, donc le code devient plus petit et plus rapide:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}
19
geza

Je suis fatigué et j'ai probablement fait une erreur; mais je suis sûr que si je l'ai fait, quelqu'un le fera remarquer ..

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a1-a2;   // May be negative

    if(a_diff < 0) {
        if(b1 < b2) {
            return true;
        }
        temp = a_diff+sqrt(b1);
        if(temp < 0) {
            return true;
        }
        return temp*temp < b2;
    } else {
        if(b1 >= b2) {
            return false;
        }
    }
//  return a_diff+sqrt(b1) < sqrt(b2);

    temp = a_diff+sqrt(b1);
    return temp*temp < b2;
}

Si tu sais a1 < a2 alors cela pourrait devenir:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a2-a1;    // Will be positive

    if(b1 > b2) {
        return false;
    }
    if(b1 >= a_diff*a_diff) {
        return false;
    }
    temp = a_diff+sqrt(b2);
    return b1 < temp*temp;
}
4
Brendan

Il existe également une méthode newton pour calculer les sqrts entiers comme décrit ici Une autre approche serait de ne pas calculer la racine carrée, mais de rechercher le plancher (sqrt (n)) via la recherche binaire ... il n'y a "que" 1000 nombres carrés complets inférieurs à 10 ^ 6. Cela a probablement de mauvaises performances, mais serait une approche intéressante. Je n'ai mesuré aucun de ceux-ci, mais voici des exemples:

#include <iostream>
#include <array>
#include <algorithm>        // std::lower_bound
#include <cassert>          


bool is_smaller_sqrt(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt(b1) < a2 + sqrt(b2);
}

static std::array<int, 1001> squares;

template <typename C>
void squares_init(C& c)
{
    for (int i = 0; i < c.size(); ++i)
        c[i] = i*i;
}

inline bool greater(const int& l, const int& r)
{
    return r < l;
}

inline bool is_smaller_bsearch(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    // return a1 + sqrt(b1) < a2 + sqrt(b2)

    // find floor(sqrt(b1)) - binary search withing 1000 elems
    auto it_b1 = std::lower_bound(crbegin(squares), crend(squares), b1, greater).base();

    // find floor(sqrt(b2)) - binary search withing 1000 elems
    auto it_b2 = std::lower_bound(crbegin(squares), crend(squares), b2, greater).base();

    return (a2 - a1) > (it_b1 - it_b2);
}

unsigned int sqrt32(unsigned long n)
{
    unsigned int c = 0x8000;
    unsigned int g = 0x8000;

    for (;;) {
        if (g*g > n) {
            g ^= c;
        }

        c >>= 1;

        if (c == 0) {
            return g;
        }

        g |= c;
    }
}

bool is_smaller_sqrt32(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt32(b1) < a2 + sqrt32(b2);
}

int main()
{
    squares_init(squares);

    // now can use is_smaller
    assert(is_smaller_sqrt(1, 4, 3, 1) == is_smaller_sqrt32(1, 4, 3, 1));
    assert(is_smaller_sqrt(1, 2, 3, 3) == is_smaller_sqrt32(1, 2, 3, 3));
    assert(is_smaller_sqrt(1000, 4, 1001, 1) == is_smaller_sqrt32(1000, 4, 1001, 1));
    assert(is_smaller_sqrt(1, 300, 3, 200) == is_smaller_sqrt32(1, 300, 3, 200));
}
2
StPiere

Je ne sais pas si les manipulations algébriques, en combinaison avec l'arithmétique entière, conduisent nécessairement à la solution la plus rapide. Dans ce cas, vous aurez besoin de nombreuses multiplications scalaires (ce qui n'est pas très rapide) et/ou la prédiction de branche peut échouer, ce qui peut dégrader les performances. Évidemment, vous devrez comparer pour voir quelle solution est la plus rapide dans votre cas particulier.

Une méthode pour rendre le sqrt un peu plus rapide consiste à ajouter l'option -fno-math-errno À gcc ou clang. Dans ce cas, le compilateur n'a pas à vérifier les entrées négatives. Avec icc ce paramètre par défaut.

Une meilleure amélioration des performances est possible en utilisant l'instruction vectorisée sqrtsqrtpd, au lieu de l'instruction scalaire sqrt instruction sqrtsd. Peter Cordes a montré que clang est capable de vectoriser automatiquement ce code, de sorte qu'il génère ce sqrtpd.

Cependant, le succès de la vectorisation automatique dépend fortement des bons paramètres du compilateur et du compilateur utilisé (clang, gcc, icc, etc.). Avec -march=nehalem, Ou plus, clang ne vectorise pas.

Des résultats de vectorisation plus fiables sont possibles avec le code intrinsèque suivant, voir ci-dessous. Pour la portabilité, nous supposons uniquement la prise en charge de SSE2, qui est la ligne de base x86-64.

/* gcc -m64 -O3 -fno-math-errno smaller.c                      */
/* Adding e.g. -march=nehalem or -march=skylake might further  */
/* improve the generated code                                  */
/* Note that SSE2 in guaranteed to exist with x86-64           */
#include<immintrin.h>
#include<math.h>
#include<stdio.h>
#include<stdint.h>

int is_smaller_v5(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    uint64_t a64    =  (((uint64_t)a2)<<32) | ((uint64_t)a1); /* Avoid too much port 5 pressure by combining 2 32 bit integers in one 64 bit integer */
    uint64_t b64    =  (((uint64_t)b2)<<32) | ((uint64_t)b1); 
    __m128i ax      = _mm_cvtsi64_si128(a64);         /* Move integer from gpr to xmm register                  */
    __m128i bx      = _mm_cvtsi64_si128(b64);         
    __m128d a       = _mm_cvtepi32_pd(ax);            /* Convert 2 integers to double                           */
    __m128d b       = _mm_cvtepi32_pd(bx);            /* We don't need _mm_cvtepu32_pd since a,b < 1e6          */
    __m128d sqrt_b  = _mm_sqrt_pd(b);                 /* Vectorized sqrt: compute 2 sqrt-s with 1 instruction   */
    __m128d sum     = _mm_add_pd(a, sqrt_b);
    __m128d sum_lo  = sum;                            /* a1 + sqrt(b1) in the lower 64 bits                     */
    __m128d sum_hi  =  _mm_unpackhi_pd(sum, sum);     /* a2 + sqrt(b2) in the lower 64 bits                     */
    return _mm_comilt_sd(sum_lo, sum_hi);
}


int is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);
}


int main(){
    unsigned a1; unsigned b1; unsigned a2; unsigned b2;
    a1 = 11; b1 = 10; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 11; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 11; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 10; b2 = 11;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));

    return 0;
}


Voir ce lien Godbolt pour l'assembly généré.

Dans un test de débit simple sur Intel Skylake, avec les options du compilateur gcc -m64 -O3 -fno-math-errno -march=nehalem, J'ai trouvé un débit de is_smaller_v5() qui était 2,6 fois meilleur que l'original is_smaller(): 6,8 cycles de processeur vs 18 cycles de CPU, avec overhead de boucle inclus. Cependant, dans un test de latence (trop?) Simple, où les entrées a1, a2, b1, b2 Dépendaient du résultat de la précédente is_smaller(_v5), je n'ai vu aucune amélioration. (39,7 cycles vs 39 cycles).

2
wim

Peut-être pas mieux que les autres réponses, mais utilise une idée différente (et une masse de pré-analyse).

// Compute approximate integer square root of input in the range [0,10^6].
// Uses a piecewise linear approximation to sqrt() with bounded error in each piece:
//   0 <= x <= 784 : x/28
//   784 < x <= 7056 : 21 + x/112
//   7056 < x <= 28224 : 56 + x/252
//   28224 < x <= 78400 : 105 + x/448
//   78400 < x <= 176400 : 168 + x/700
//   176400 < x <= 345744 : 245 + x/1008
//   345744 < x <= 614656 : 336 + x/1372
//   614656 < x <= 1000000 : (784000+x)/1784
// It is the case that sqrt(x) - 7.9992711366390365897... <= pseudosqrt(x) <= sqrt(x).
unsigned pseudosqrt(unsigned x) {
    return 
        x <= 78400 ? 
            x <= 7056 ?
                x <= 764 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}

// known pre-conditions: a1 < a2, 
//                  0 <= b1 <= 1000000
//                  0 <= b2 <= 1000000
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
// Try three refinements:
// 1: a1 + sqrt(b1) <= a1 + 1000, 
//    so is a1 + 1000 < a2 ?  
//    Convert to a2 - a1 > 1000 .
// 2: a1 + sqrt(b1) <= a1 + pseudosqrt(b1) + 8 and
//    a2 + pseudosqrt(b2) <= a2 + sqrt(b2), 
//    so is  a1 + pseudosqrt(b1) + 8 < a2 + pseudosqrt(b2) ?
//    Convert to a2 - a1 > pseudosqrt(b1) - pseudosqrt(b2) + 8 .
// 3: Actually do the work.
//    Convert to a2 - a1 > sqrt(b1) - sqrt(b2)
// Use short circuit evaluation to stop when resolved.
    unsigned ad = a2 - a1;
    return (ad > 1000)
           || (ad > pseudosqrt(b1) - pseudosqrt(b2) + 8)
           || ((int) ad > (int)(sqrt(b1) - sqrt(b2)));
}

(Je n'ai pas de compilateur à portée de main, donc il contient probablement une faute de frappe ou deux.)

1
Eric Towers