web-dev-qa-db-fra.com

Fonction GBM R: Obtenir une importance variable séparément pour chaque classe

J'utilise la fonction gbm de R (package gbm) pour adapter les modèles de renforcement de gradient stochastique à la classification multiclass. J'essaie simplement d'obtenir l'importance de chaque prédicteur séparément pour chaque classe, comme dans cette image du livre { livre de Hastie (Les éléments de l'apprentissage statistique) } (p. 382). 

enter image description here

Cependant, la fonction summary.gbm ne renvoie que l’importance global des prédicteurs (leur importance moyenne sur toutes les classes). 

Est-ce que quelqu'un sait comment obtenir les valeurs d'importance relative? 

18
Antoine

Je pense que la réponse courte est que, à la page 379, Hastie mentionne qu’il utilise MART , qui ne semble être disponible que pour Splus.

Je conviens que le paquet gbm ne semble pas permettre de voir l’influence relative séparée. Si cela vous intéresse pour un problème de classe multicouche, vous pouvez probablement obtenir quelque chose de très similaire en construisant un gbm un pour tous pour vos classes, puis en obtenant les mesures d’importance de chacun de ces modèles.

Donc disons que vos classes sont a, b, c, & d. Vous modélisez un vs le reste et obtenez l'importance de ce modèle. Ensuite, vous modélisez b par rapport au reste et obtenez l’importance de ce modèle. Etc.

11
Tchotchke

Espérons que cette fonction vous aide. Pour l'exemple, j'ai utilisé des données du package ElemStatLearn. La fonction détermine la nature des classes d'une colonne, scinde les données en ces classes, exécute la fonction gbm () sur chaque classe et trace les diagrammes à barres de ces modèles.

# install.packages("ElemStatLearn"); install.packages("gbm")
library(ElemStatLearn)
library(gbm)

set.seed(137531)

# formula: the formula to pass to gbm()
# data: the data set to use
# column: the class column to use
classPlots <- function (formula, data, column) {

    class_column <- as.character(data[,column])
    class_values <- names(table(class_column))
    class_indexes <- sapply(class_values, function(x) which(class_column == x))
    split_data <- lapply(class_indexes, function(x) marketing[x,])
    object <- lapply(split_data, function(x) gbm(formula, data = x))
    rel.inf <- lapply(object, function(x) summary.gbm(x, plotit=FALSE))

    nobjs <- length(class_values)
    for( i in 1:nobjs ) {
        tmp <- rel.inf[[i]]
        tmp.names <- row.names(tmp)
        tmp <- tmp$rel.inf
        names(tmp) <- tmp.names

        barplot(tmp, horiz=TRUE, col='red',
                xlab="Relative importance", main=paste0("Class = ", class_values[i]))
    }
    rel.inf
}

par(mfrow=c(1,2))
classPlots(Income ~ Marital + Age, data = marketing, column = 2)

`

 output

7
nathanesau

J'ai fait quelques recherches sur la manière dont le paquetage gbm calcule l'importance et il est basé sur l'erreur ErrorReduction qui est contenue dans l'élément tree du résultat et est accessible via pretty.gbm.trees(). L'influence relative est obtenue en prenant la somme de cette réduction d'erreur sur tous les arbres pour chaque variable. Pour un problème multiclass, il existe en réalité des arborescences n.trees*num.classes dans le modèle. Donc, s'il y a 3 classes, vous pouvez calculer la somme de ErrorReduction pour chaque variable sur chaque troisième arbre afin d'obtenir l'importance pour une classe. J'ai écrit les fonctions suivantes pour implémenter ceci puis tracer les résultats:

Obtenez une importance variable par classe

RelInf_ByClass <- function(object, n.trees, n.classes, Scale = TRUE){
  library(dplyr)
  library(purrr)
  library(gbm)
  Ext_ErrRed<- function(ptree){
    ErrRed <- ptree %>% filter(SplitVar != -1) %>% group_by(SplitVar) %>% 
      summarise(Sum_ErrRed = sum(ErrorReduction))
  }
  trees_ErrRed <- map(1:n.trees, ~pretty.gbm.tree(object, .)) %>% 
    map(Ext_ErrRed)

  trees_by_class <- split(trees_ErrRed, rep(1:n.classes, n.trees/n.classes)) %>% 
    map(~bind_rows(.) %>% group_by(SplitVar) %>% 
          summarise(rel_inf = sum(Sum_ErrRed)))
  varnames <- data.frame(Num = 0:(length(object$var.names)-1),
                         Name = object$var.names)
  classnames <- data.frame(Num = 1:object$num.classes, 
                           Name = object$classes)
  out <- trees_by_class %>% bind_rows(.id = "Class") %>%  
    mutate(Class = classnames$Name[match(Class,classnames$Num)],
    SplitVar = varnames$Name[match(SplitVar,varnames$Num)]) %>%
    group_by(Class) 
  if(Scale == FALSE){
    return(out)
    } else {
    out <- out %>% mutate(Scaled_inf = rel_inf/max(rel_inf)*100)
    }
}

Importance variable de la parcelle par classe

Dans mon utilisation réelle de ce logiciel, j'ai plus de 40 fonctionnalités, je vous donne donc la possibilité de spécifier le nombre de fonctionnalités à tracer. Je ne pouvais pas non plus utiliser de facettage si je voulais que les parcelles soient triées séparément pour chaque classe, c'est pourquoi j'ai utilisé gridExtra

plot_imp_byclass <- function(df, n) {
  library(ggplot2)
  library(gridExtra)
  plot_imp_class <- function(df){
    df %>% arrange(rel_inf) %>% 
      mutate(SplitVar = factor(SplitVar, levels = .$SplitVar)) %>% 
      ggplot(aes(SplitVar, rel_inf))+
      geom_segment(aes(x = SplitVar, 
                       xend = SplitVar, 
                       y = 0, 
                       yend = rel_inf))+
      geom_point(size=3, col = "cyan") + 
      coord_flip()+
      labs(title = df$Class[[1]], x = "Variable", y = "Importance")+
      theme_classic()+
      theme(plot.title = element_text(hjust = 0.5))
  }

  df %>% top_n(n, rel_inf) %>% split(.$Class) %>% 
    map(plot_imp_class) %>% map(ggplotGrob) %>% 
    {grid.arrange(grobs = .)}
}

L'essayer

gbm_iris <- gbm(Species~., data = iris)
imp_byclass <- RelInf_ByClass(gbm_iris, length(gbm_iris$trees), 
                              gbm_iris$num.classes, Scale = F)
plot_imp_byclass(imp_byclass, 4)

Semble donner les mêmes résultats que la fonction relative.influence intégrée si vous additionnez les résultats sur toutes les classes. 

relative.influence(gbm_iris)
# n.trees not given. Using 100 trees.
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
# 0.00000     51.88684   2226.88017    868.71085 

imp_byclass %>% group_by(SplitVar) %>% summarise(Overall_rel_inf = sum(rel_inf))
# A tibble: 3 x 2
# SplitVar     Overall_rel_inf
# <fct>                  <dbl>
#   1 Petal.Length          2227. 
# 2 Petal.Width            869. 
# 3 Sepal.Width             51.9
2
see24