web-dev-qa-db-fra.com

Pourquoi cette multiplication de matrice naïve est-elle plus rapide que les R de base?

Dans R, la multiplication matricielle est très optimisée, c'est-à-dire qu'il s'agit simplement d'un appel à BLAS/LAPACK. Cependant, je suis surpris que ce code C++ très naïf pour la multiplication matrice-vecteur semble fiable 30% plus rapide.

 library(Rcpp)

 # Simple C++ code for matrix multiplication
 mm_code = 
 "NumericVector my_mm(NumericMatrix m, NumericVector v){
   int nRow = m.rows();
   int nCol = m.cols();
   NumericVector ans(nRow);
   double v_j;
   for(int j = 0; j < nCol; j++){
     v_j = v[j];
     for(int i = 0; i < nRow; i++){
       ans[i] += m(i,j) * v_j;
     }
   }
   return(ans);
 }
 "
 # Compiling
 my_mm = cppFunction(code = mm_code)

 # Simulating data to use
 nRow = 10^4
 nCol = 10^4

 m = matrix(rnorm(nRow * nCol), nrow = nRow)
 v = rnorm(nCol)

 system.time(my_ans <- my_mm(m, v))
#>    user  system elapsed 
#>   0.103   0.001   0.103 
 system.time(r_ans <- m %*% v)
#>   user  system elapsed 
#>  0.154   0.001   0.154 

 # Double checking answer is correct
 max(abs(my_ans - r_ans))
 #> [1] 0

Les R de base %*% Effectuent-ils un certain type de vérification des données que je saute?

MODIFIER:

Après avoir compris ce qui se passe (merci SO!), Il convient de noter que c'est le pire des cas pour les R %*%, C'est-à-dire matrice par vecteur. Par exemple, @RalfStubner a souligné que l'utilisation d'une implémentation RcppArmadillo d'une multiplication matrice-vecteur est encore plus rapide que l'implémentation naïve que j'ai démontrée, impliquant beaucoup plus rapide que la base R, mais est pratiquement identique à la base R %*% Pour multiplication matrice-matrice (lorsque les deux matrices sont grandes et carrées):

 arma_code <- 
   "arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
 return m * m2;
 };"
 arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

 nRow = 10^3 
 nCol = 10^3

 mat1 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)
 mat2 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)

 system.time(arma_mm(mat1, mat2))
#>   user  system elapsed 
#>   0.798   0.008   0.814 
 system.time(mat1 %*% mat2)
#>   user  system elapsed 
#>   0.807   0.005   0.822  

Le courant de R (v3.5.0) %*% Est donc presque optimal pour matrice-matrice, mais pourrait être considérablement accéléré pour matrice-vecteur si vous êtes d'accord de sauter la vérification.

29
Cliff AB

Un coup d'œil rapide dans names.c ( ici en particulier ) vous indique do_matprod, la fonction C appelée par %*% et qui se trouve dans le fichier array.c. (Fait intéressant, il s'avère que crossprod et tcrossprod envoient également à la même fonction). Voici un lien vers le code de do_matprod.

En parcourant la fonction, vous pouvez voir qu'elle prend en charge un certain nombre de choses que votre implémentation naïve ne fait pas, notamment:

  1. Conserve les noms de ligne et de colonne, là où cela a du sens.
  2. Permet l'envoi vers d'autres méthodes S4 lorsque les deux objets sont opérés par un appel à %*% sont des classes pour lesquelles de telles méthodes ont été fournies. (C'est ce qui se passe dans cette partie de la fonction.)
  3. Gère les matrices réelles et complexes.
  4. Implémente une série de règles pour gérer la multiplication d'une matrice et d'une matrice, d'un vecteur et d'une matrice, d'une matrice et d'un vecteur, et d'un vecteur et d'un vecteur. (Rappelons que sous multiplication croisée dans R, un vecteur sur le LHS est traité comme un vecteur ligne, tandis que sur le RHS, il est traité comme un vecteur colonne; c'est le code qui le fait.)

Vers la fin de la fonction , il distribue à matprod ou ou cmatprod . Fait intéressant (du moins pour moi), dans le cas de matrices réelles, si l'une ou l'autre matrice peut contenir NaN ou Inf valeurs, puis matprod envoie ( ici ) à une fonction appelée simple_matprod qui est à peu près aussi simple et direct que le vôtre. Sinon, il est envoyé à l'une des deux routines BLAS Fortran qui, vraisemblablement, sont plus rapides, si des éléments matriciels uniformément "bien comportés" peuvent être garantis.

27
Josh O'Brien

La réponse de Josh explique pourquoi la multiplication matricielle de R n'est pas aussi rapide que cette approche naïve. J'étais curieux de voir combien on pouvait gagner en utilisant RcppArmadillo. Le code est assez simple:

arma_code <- 
  "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
       return m * v;
   };"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

Référence:

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10

Donc, RcppArmadillo nous donne une syntaxe plus agréable et de meilleures performances.

La curiosité a pris le dessus sur moi. Voici une solution pour utiliser directement BLAS:

blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")

Référence:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10

Armadillo et BLAS (OpenBLAS dans mon cas) sont presque les mêmes. Et le code BLAS est aussi ce que fait R au final. Donc 2/3 de ce que fait R est la vérification des erreurs, etc.

7
Ralf Stubner