web-dev-qa-db-fra.com

Le modèle pré-entraîné Tensorflow-Lite ne fonctionne pas dans la démo Android

La démo Android Tensorflow-Lite fonctionne avec le modèle original fourni: mobilenet_quant_v1_224.tflite. Voir: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite

Ils proposent également d’autres modèles allégés pré-entraînés ici: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md

Cependant, j'ai téléchargé certains modèles plus petits à partir du lien ci-dessus, par exemple mobilenet_v1_0.25_224.tflite, et j'ai remplacé le modèle d'origine par ce modèle dans l'application de démonstration en modifiant simplement le MODEL_PATH = "mobilenet_v1_0.25_224.tflite"; dans le ImageClassifier.Java. L'application se bloque avec:

12-11 12: 52: 34.222 17713-17729 /? E/AndroidRuntime: EXCEPTION FATALE: Processus CameraBackground: Android.example.com.tflitecamerademo, PID: 17713 Exception Java.lang.IllegalArgument: Impossible d'obtenir les dimensions en entrée. 0-ème entrée devrait avoir 602112 octets, mais trouvé 150528 octets. à org.tensorflow.lite.NativeInterpreterWrapper.getInputDims (Méthode native) à org.tensorflow.lite.NativeInterpreterWrapper.run (NativeInterpreterWrapper.Java:82) à l’agrément .tensorflow.lite.Interpreter.run (Interpreter.Java:93) à com.example.Android.tflitecam.imageClassify.classifyFrame (ImageClassifier.Java:108) à com.example.Android.tflitecamerademo.Camera2Rid.prendre : 663) à l'adresse com.example.Android.tflitecamerademo.Camera2BasicFragment.access 900 $ (Camera2BasicFragment.Java:69) à com.example.Android.tflitecamerademo.Camera2BasicFragment $ 5.run (Camera2BasicFragment.Java58). handleCallback (Handler.Java:751) sur Android.os.Handler.dispatchMessage (Handler.Java:95) sur Android.os.Looper.loop (Looper.Java:154) sur Android.os.HandlerThread.run (HandlerThread.Java : 61)

Cela semble être dû au fait que la dimension en entrée requise par le modèle est quatre fois plus grande que la taille de l'image. J'ai donc modifié DIM_BATCH_SIZE = 1 en DIM_BATCH_SIZE = 4. Maintenant l'erreur est:

EXCEPTION FATALE: Processus arrière-plan de la caméra: Android.example.com.tflitecamerademo, PID: 18241 Java.lang.IllegalArgumentException: impossible de convertir un tenseur TensorFlowLite de type FLOAT32 en objet Java de type [[B (compatible avec le type UINT de TensorFlowLite)] à org.tensorflow.lite.Tensor.copyTo (Tensor.Java:36) à org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs (Interpreter.Java:122) à org.tensorflow.lite.Interpreter.Jun:93 ) dans le répertoire accès 900 $ (Camera2BasicFragment.Java:69) à l'adresse com.example.Android.tflitecamerademo.Camera2BasicFragment $ 5.run (Camera2BasicFragment.Java:558) à Android.os.Handler.handleCallback (Handler.Java:75) à Android.os.Handler .dispatchMessage (Handler.Java:95) à Android.os.Loop er.loop (Looper.Java:154) sur Android.os.HandlerThread.run (HandlerThread.Java:61)

Ma question est de savoir comment faire fonctionner un modèle tflite à MobileNet réduit avec la démo Android TF-lite.

(J'ai en fait essayé d'autres choses, comme convertir un graphique figé TF en modèle TF-lite à l'aide de l'outil fourni, même en utilisant exactement le même exemple de code que dans https://github.com/tensorflow/tensorflow/blob/master/tensorflow /contrib/lite/toco/g3doc/cmdline_examples.md , mais le modèle converti tflite ne peut toujours pas fonctionner dans la démo Android.)

8
Seedling

Le fichier ImageClassifier.Java inclus dans la démo Android de Tensorflow-Lite s’attend à un modèle quantized. À l'heure actuelle, un seul des modèles Mobilenets est fourni sous forme quantifiée: Mobilenet 1.0 224 Quant.

Pour utiliser les autres modèles float, permutez ImageClassifier.Java à partir de la source de démonstration Tensorflow for Poets TF-Lite. Ceci est écrit pour les modèles float . https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/Android/tflite/app/src/main/Java /com/example/Android/tflitecamerademo/ImageClassifier.Java

Faites un diff et vous verrez qu'il existe plusieurs différences importantes dans la mise en œuvre.

Une autre option à envisager consiste à convertir les modèles float en modèles quantifiés à l'aide de TOCO: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md

4
Ash Eldritch

J'avais aussi les mêmes erreurs que Seedling… .. J'ai créé un nouveau wrapper de classificateur d'images pour le modèle Mobilenet Float… .. Cela fonctionne bien maintenant. Vous pouvez directement ajouter cette classe dans la démo du classificateur d'image et l'utiliser pour créer un classificateur dans Camera2BasicFragment.

classifier = new ImageClassifierFloatMobileNet(getActivity());

ci-dessous, l'emballage de la classe de classificateur d'Image pour le modèle Mobilenet Float

    /**
 * This classifier works with the Float MobileNet model.
 */
public class ImageClassifierFloatMobileNet extends ImageClassifier {

  /**
   * An array to hold inference results, to be feed into Tensorflow Lite as outputs.
   * This isn't part of the super class, because we need a primitive array here.
   */
  private float[][] labelProbArray = null;

  private static final int IMAGE_MEAN = 128;
  private static final float IMAGE_STD = 128.0f;

  /**
   * Initializes an {@code ImageClassifier}.
   *
   * @param activity
   */
  public ImageClassifierFloatMobileNet(Activity activity) throws IOException {
    super(activity);
    labelProbArray = new float[1][getNumLabels()];
  }

  @Override
  protected String getModelPath() {
    // you can download this file from
    // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_Android_quant_2017_11_08.Zip
//    return "mobilenet_quant_v1_224.tflite";
    return "retrained.tflite";
  }

  @Override
  protected String getLabelPath() {
//    return "labels_mobilenet_quant_v1_224.txt";
    return "retrained_labels.txt";
  }

  @Override
  public int getImageSizeX() {
    return 224;
  }

  @Override
  public int getImageSizeY() {
    return 224;
  }

  @Override
  protected int getNumBytesPerChannel() {
    // the Float model uses a 4 bytes
    return 4;
  }

  @Override
  protected void addPixelValue(int val) {
    imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
    imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
    imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
  }

  @Override
  protected float getProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void setProbability(int labelIndex, Number value) {
    labelProbArray[0][labelIndex] = value.byteValue();
  }

  @Override
  protected float getNormalizedProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void runInference() {
    tflite.run(imgData, labelProbArray);
  }
}
1
vikoo