AI w aplikacji, czyli TensorFlow na Androidzie
Obecnie prawie każdy producent elektroniki twierdzi, że w jego sprzęcie mieszka potężny dżin zwany sztuczną inteligencją. Po smartfonach przychodzi pomału kolej na sprzęty AGD, takie jak pralka czy mikrofalówka. Choć oczywiście “AI” to marketingowy buzzword, bo póki co nie pogadamy z naszą lodówką zamawiając na wieczór schłodzone Martini (wstrząśnięte, nie zmieszane), to myślę, że warto przyjrzeć się bliżej technologiom pozwalającym na użycie sieci neuronowych na urządzeniach mobilnych.
Krzysztof Joachimiak. Machine Learning Engineer i Software Developer w StethoMe®. Zaczynał jako Android Developer, po czym przeskoczył do swojej ulubionej działki IT, czyli uczenia maszynowego. Stara się trzymać rękę na pulsie jeśli chodzi o rozwój tak zwanej – nieraz trochę na wyrost — sztucznej inteligencji. Ostatnio związany z poznańskim startupem z branży medycznej, StethoMe®.
Spis treści
Zdalnie czy lokalnie?
Jeśli piszemy aplikację na system Android, która ma korzystać z dobrodziejstw sztucznej inteligencji, stajemy przed wyborem — czy nasz model ma być uruchamiany na telefonie/tablecie, czy też ma funkcjonować jako web service odpytywany przez nasze urządzenie mobilne. Głównymi kryteriami przy podejmowaniu decyzji jest wielkość samego modelu, czas na wykonanie potrzebnych obliczeń, oraz to, czy model ma działać offline. W większości przypadków zdecydujemy się zapewne na zdalne API. Co jednak zrobić w przypadku, jeśli chcielibyśmy skorzystać z sieci neuronowej, która ma być uruchamiana bezpośrednio na urządzeniu z Androidem? Proponuję użyć dobrze znaną bibliotekę TensorFlow od Google.
Zanim zaczniemy
- Przygotuj odpowiednie środowisko. Osobiście pracowałem na Linuksie (Ubuntu 16.04) korzystając z Pythona 2.7.
- Zainstaluj TensorFlow.
- Zainstaluj najnowszą wersję narzędzia Bazel.
- Sklonuj repozytorium TensorFlow.
- Zbuduj narzędzia służące do transformowania i optymalizacji sieci
- bazel build sciezka_do_tensorflow/tools/graph_transforms:transform_graph
- bazel build sciezka_do_tensorflow/tools/graph_transforms:summarize_graph
- bazel build sciezka_do_tensorflow/contrib/lite/toco:toco
Eksportujemy naszą sieć
W tym tutorialu chodzi mi o pokazanie samego procesu eksportowania sieci, nie będę więc powielać treści mnóstwa tutoriali do TensorFlow i celowo pominę proces treningu sieci. Załóżmy, że będzie to banalnie prosta architektura, reprezentująca regresję logistyczną.
Podczas tworzenia modelu musimy pamiętać, żeby nazwać węzeł wejściowy oraz wyjściowy grafu; w przeciwnym wypadku będziemy zmuszeni do jego ręcznego przeglądania i znalezienia nazw nadanych automatycznie przez TensorFlow.
`` python export_graph.py import tensorflow as tf import os OUTPUT_DIR = "net" OUTPUT_NAME = "model" OUTPUT_SIZE = 1 # wielkość wyjścia if __name__ == '__main__': # Definicja sieci X = tf.placeholder(tf.float32, (None, 10), name="input") network = tf.layers.dense(X, OUTPUT_SIZE) network = tf.nn.sigmoid(network) # Użycie funkcji liniowej to dobry sposób na nadanie nazwy ostatniemu węzłowi, jeśli # np. wczytujemy skądś sieć. Tworzymy wówczas dodatkowy węzeł network = tf.identity(network, name= "output") init = tf.global_variables_initializer() with tf.Session() as sess: # Inicjalizacja sieci losowymi wagami sess.run(init) # Zapisujemy graf tf.train.write_graph(sess.graph_def, '.',os.path.join(OUTPUT_DIR, "{}.pbtxt".format(OUTPUT_NAME))) # Zapisujemy wagi modelu saver = tf.train.Saver() saver.save(sess, os.path.join(OUTPUT_DIR, OUTPUT_NAME))
Podczas tworzenia modelu w TensorFlow warto zwrócić uwagę na dwie rzeczy:
- Jeśli wykorzystujemy warstwy konwolucyjne, bezpieczniej jest użyć domyślnej kolejności wymiarów z kanałami na końcu. Tylko ta wersja jest obecnie wspierana przez implementację TensoFlow działającą na CPU.
- Normalizacja batcha wymaga ustawiania odpowiedniej flagi na etapie predykcji (flaga training).
Po uruchomieniu tego skryptu, otrzymamy folder zawierający pięć plików, w tym <OUTPUT_NAME>.pbtx
, zawierający strukturę grafu obliczeniowego, oraz cztery pliki wyprodukowane przez Savera, zawierające wagi modelu.
Następnym krokiem jest połączenie struktury z wagami w jednym pliku. W tym celu możemy użyć narzędzia freeze_graph, które należy zbudować podobnie jak pozostałe narzędzia wymienione przeze mnie w punkcie 5., lub użyć poniższego skryptu:
``python freeze.py import sys from tensorflow.python.tools import freeze_graph import os INPUT_DIR = sys.argv[1] OUTPUT_DIR = sys.argv[2] MODEL_NAME = sys.argv[3] OUTPUT_NODE = sys.argv[4] freeze_graph.freeze_graph(input_graph=os.path.join(INPUT_DIR, "{}.pbtxt".format(MODEL_NAME)), input_saver="", input_binary=False, input_checkpoint=os.path.join(INPUT_DIR, MODEL_NAME), output_node_names=OUTPUT_NODE, restore_op_name="save/restore_all", filename_tensor_name="save/Const:0", output_graph=os.path.join(OUTPUT_DIR,"{}.pb".format(MODEL_NAME)), clear_devices=True, initializer_nodes="") ``
Jeśli zdecydujemy się na uruchomienie użycia skryptu, powinniśmy zastosować następującą komendę:
python freeze.py model model net output
W tym miejscu będziemy już mieć plik model.pb, który zawiera pełen model sieci — strukturę grafu obliczeniowego wraz z wagami. To jednak nie jest jeszcze ostatni krok. Powinniśmy dodatkowo zoptymalizować nasz graf, aby obliczenia były mniej zasobożerne (co zresztą jest szczególnie ważne na urządzeniu mobilnym!).
Optymalizacja sieci
W tym celu używamy uprzednio zbudowanego narzędzia graph_transform. Jako argumenty musimy przekazać nazwę pliku wejściowego (model.pb) oraz wyjściowego — dla odróżnienia nazwijmy go model_opt.pb. graph_transform oferuje szereg różnego rodzaju transformacji; ja polecam użyć przede wszystkim fold_constants, remove_device (bez tego graf będzie żądał odpalenia na konkretnym rodzaju urządzenia, np. na karcie graficznej, czego nie oferuje biblioteka TesnorFlow na Andoridzie). Pełną listę dostępnych transformacji można znaleźć tutaj. Wśród nich na pewno warto zwrócić uwagę na fold_batch_norm oraz fold_old_batch_norm używaną do optymalizacji obliczeń w warstwie normalizacji batcha.
```bash input_file=model.pb output_file=model_opt.pb /home/some/path/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=$input_file --out_graph=$output_file --inputs='input' # nazwy węzłów wejściowych --outputs='output' # nazwy węzłów wyjściowych --transforms=' remove_device fold_constants(ignore_errors=true) ```
Jeden skrypt
Wszystkie powyższe operacje możemy szybko powtórzyć, korzystając z poniższego skryptu:
temp=output # !!!! Ustaw ścieżkę do lokalnego repozytorium TF !!!! tensorflow_path=/home/user/Desktop/tensorflow # Input path=$1 model_name=$2 input_node=$3 output_node=$4 mkdir $temp python freeze.py $path $temp $model_name $output_node # https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/#add_default_attributes input_file=$temp/$model_name.pb output_file=$temp/$model_name"_opt.pb" echo $input_file $output_file $tensorflow_path/bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=$input_file --out_graph=$output_file --inputs=$input_node --outputs=$output_node --transforms=' remove_device fold_constants(ignore_errors=true) fold_batch_norms fold_old_batch_norms strip_unused_nodes remove_attribute(attribute_name=_class)' $tensorflow_path/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=$output_file
Na koniec działania, skrypt wyświetli nam podsumowanie dla danej sieci, w tym m.in. liczbę jej parametrów.
TF Mobile czy TF Lite
Przez pewien czas istniały równocześnie dwa standardy: starszy TF Mobile oraz nowszy TF Lite. Świadomie trzymałem się starszego, gdyż nowszy przez długi czas nie posiadał zaimplementowanych wielu operacji, np.: funkcji ELU. Obecnie TF Mobile jest porzucany na rzecz TF Lite i na początku 2019 ten pierwszy standard stanie się przestarzały.
TF Mobile
Otwieramy nasz projekt z aplikacją androidową i umieszczamy nasz plik model_opt.pb w folderze assets.
W pliku build.gradle dodajemy najnowszą wersję biblioteki TensorFlow
implementation 'org.tensorflow:tensorflow-android:+'
Następnie stwórzmy sobie klasę pomocniczą:
```java package com.stethome.androidtf; import android.content.Context; import org.tensorflow.contrib.android.TensorFlowInferenceInterface; public class TensorflowModel { // Model private static final String MODEL_FILE = "file:///android_asset/model_opt.pb"; private static final String INPUT_NODE = "input"; private static final String OUTPUT_NODE = "output"; private static final long[] INPUT_SIZE = {1, 10}; // Tensorflow interface private TensorFlowInferenceInterface inferenceInterface; public TensorflowModel(Context context) { this.inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE); } public float[] predict(float[] input) { this.inferenceInterface.feed(INPUT_NODE, input, INPUT_SIZE); this.inferenceInterface.run(new String[]{OUTPUT_NODE}); // Result float[] result = new float[1]; // wielkość wyjścia inferenceInterface.fetch(OUTPUT_NODE, result); return result; }
Jej użycie w aplikacji będzie wyglądać następująco:
```java package com.stethome.androidtf; import android.support.v7.app.AppCompatActivity; import android.os.Bundle; import android.util.Log; import android.util.TimingLogger; import java.util.Arrays; import java.util.concurrent.TimeUnit; public class MainActivity extends AppCompatActivity { private TensorflowModel tfModel; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); tfModel = new TensorflowModel(this); // Dummy data float[] input = new float[10]; // Predykcje float[] output = tfModel.predict(input); Log.d("OUTPUT", Arrays.toString(output)); Log.d("OUTPUT SIZE", Integer.toString(output.length)); } } ```
TF Lite
W celu skonwertowania naszego modelu do nowego standardu, użyjmy kolejnego narzędzia z repozytorium TF:
```bash input_file=model_opt.pb output_file=model.lite input_node_name=input output_node_name=output /home/some/path/tensorflow/bazel-bin/tensorflow/contrib/lite/toco/toco --input_file=$input_file --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=$output_file --inference_type=FLOAT --input_type=FLOAT --input_arrays=$input_node_name --output_arrays=$output_node_name --input_shapes=1,10 # ```
Dodajemy plik model.lite do folderu assets.
W pliku build.gradle dodajemy najnowszą wersję biblioteki TensorFlow
implementation ‘org.tensorflow:tensorflow-lite:+’
Ładujemy plik z modelem:
private MappedByteBuffer loadModel(Activity activity,String MODEL_FILE) throws IOException { AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_FILE); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long length = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, length); }
Uruchamiamy interpreter TF Lite:
import org.tensorflow.lite.Interpreter; String MODEL="model.lite"; Interpreter tflite; try { tflite=new Interpreter(loadModelFile(MainActivity.this,modelFile)); } catch (IOException e) { e.printStackTrace(); }
A następnie wykonujemy predykcję:
float[][] input=new float[1][10]; float[][] output=new float[][]{{0}}; tflite.run(inp,out);
Podsumowanie
Jak widać, wykorzystanie sieci neuronowej do uruchamiania bezpośrednio na urządzeniu z Androidem nie jest skomplikowane. Myślę, że TensorFlow jest najwygodniejszym do tego narzędziem.