Автор: Эндрю Эйрес, Цин Лан, Навин Свами, Пиюш Гхай, Ижи Лю

Apache MXNet - это среда глубокого обучения с открытым исходным кодом, используемая для обучения и развертывания глубоких нейронных сетей. Он масштабируемый, что позволяет быстро обучать модели, и поддерживает гибкую модель программирования и несколько языков (C ++, Python, Julia, Clojure, JavaScript, R, Scala).

Сегодня сообщество Apache MXNet рада объявить о предварительной версии Java APIs для Inference. API MXNet для Java упрощают использование моделей глубокого обучения в системах и приложениях, построенных на основе популярного языка Java и среды выполнения. Эксперт по машинному обучению может обучить модели с помощью Python, точно настроить и сохранить модель, а можно загрузить модель и использовать ее в производстве для вывода с помощью API Java.

В этом посте мы опишем доступные Java API и проведем вас через быструю практическую настройку, чтобы начать работу с ними.

MXNet-Java API: что нового?

Простота использования была в центре нашего мыслительного процесса, когда мы разрабатывали API. С помощью всего нескольких строк кода можно загрузить современные модели, обученные с помощью Apache MXNet, и начать делать прогнозы. API-интерфейсы Java также имеют автоматизированную систему управления ресурсами, что упрощает управление объемом внутренней памяти без снижения производительности.

Вот доступные API-интерфейсы вывода на Java в MXNet:

  • Predictor API
  • Детектор объектов (одиночный и пакетный)

Predictor API предоставляет методы для выполнения логических выводов по предварительно обученной модели.

ObjectDetector - это оболочка поверх Predictor API. Он предоставляет методы для обнаружения отдельных объектов на изображении, а также их местоположения на изображении.

Быстрый старт с API Java

API Java в настоящее время поддерживают Java 8 и выше. Предварительная версия пакета MXNet Java доступна на Maven Central. Официальный выпуск будет вместе с предстоящим выпуском MXNet 1.4.

Шаг 0: Настройка среды

Следуйте пошаговому руководству здесь, чтобы настроить среду разработки для использования API Java.

Шаг 1. Загрузка предварительно обученной модели с помощью Predictor API

MXNet имеет обширную коллекцию предварительно обученных современных моделей глубокого обучения, доступных в MXNet Model Zoo. Существуют также предварительно обученные современные модели, реализованные с использованием Gluon (обязательный API для MXNet), доступные в Gluon Model Zoo. Позже в этом посте мы запустим пример вывода для модели Resnet50 Single Shot Detection, взятой из Model Zoo.

Чтобы загрузить предварительно обученную модель MXNet с использованием API Java, требуется путь к загруженной модели и описание входных данных модели. Чтобы указать ввод, ожидаемый моделью, мы создаем объект типа DataDesc. Входные данные обычно определяются на этапе обучения модели. DataDesc принимает имя входного слоя, форму входных данных, тип данных и порядок данных.

Shape inputShape = new Shape(new int[] {1,3,224,224});
DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW"); 
List<DataDesc> inputDescList = new ArrayList<DataDesc>();
inputDescList.add(inputDescriptor);

Легко указать, хотите ли вы запускать логический вывод на процессорах или графических процессорах (если у вас есть машина с поддержкой графического процессора), указав это в объекте Context.

List<Context> context = new ArrayList<>();
context.add(Context.cpu()); 
// For GPU, context.add(Context.gpu());

Примеры в этом руководстве были запущены на MacBook Pro с использованием контекста ЦП. Чтобы использовать MXNet с графическим процессором NVIDIA на Mac, следуйте инструкциям здесь.

String modelPathPrefix = "path-to-model";
Predictor predictor = new Predictor(modelPathPrefix, inputDescList, context);

modelPath - это папка, которая содержит файл символов (представляющий слои модели), файл параметров (обученные веса модели) и любые другие вспомогательные файлы, необходимые для модели.

Шаг 2: вывод с использованием предиктора

Класс Predictor имеет три функции прогнозирования, которые принимают входные данные и производят прогнозы в качестве выходных данных. Входные данные для функции прогнозирования могут быть либо NDArray, либо одномерным списком Java, либо одномерным массивом Java.

В MXNet NDArray - это основная структура данных для всех математических вычислений, связанных с моделями глубокого обучения.

Вот три примера вызова API Predictor:

List<NDArray> result = predictor.predictWithNDArray(inputNDArray);

or

List<List<Float>> result = predictor.predict(inputFloatList);

or

float[][] result = predictor.predict(inputFloatArray);

Шаг 3. Пришло время шоу!

Теперь, когда мы вкратце поговорили о том, как выглядит Predictor API, давайте взглянем на Пример обнаружения объектов, написанный на Java из репозитория MXNet. Мы будем использовать API детектора объектов для идентификации объектов и их местоположения на изображении.

Следуйте инструкциям здесь, чтобы загрузить файлы модели. Для работы с этим руководством нам потребуются три файла модели: файл символов, файл параметров и файл synset.txt. Файл символа содержит описание архитектуры модели, то есть слоев, используемых в модели. Файл Params содержит веса обученной модели, назначенные слоям. Файл synset содержит метки классов, используемые при обучении.

Давайте определим формы ввода, дескрипторы ввода и контекст, которые будут использоваться позже.

Shape inputShape = new Shape(new int[] {1,3,512,512});
DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW"); 
List<DataDesc> inputDescList = new ArrayList<DataDesc>();
inputDescList.add(inputDescriptor);
List<Context> context = new ArrayList<Context>();
context.add(Context.cpu());

Давайте рассмотрим описанные нами переменные.

  1. inputShape, определенный выше, представляет пакетный ввод с размером пакета 1, имеющий 3 канала (RGB) в изображении, а высота каждого изображения в пакете 512, а ширина каждого изображения - 512.
  2. Первый параметр объекта inputDescriptor - это имя входного слоя в файле символов модели, за которым следует inputShape, тип входных данных: Float32 и «NCHW», где N обозначает размер пакета, C обозначает каналы, H обозначает высоту, W обозначает ширину изображения.

Теперь давайте определим экземпляр ObjectDetector, API высокого уровня, который предоставляет метод для обнаружения отдельных объектов на изображении, их предполагаемых меток и местоположений на изображении. Он также содержит служебные методы для предварительной обработки входных изображений. Для создания экземпляра ObjectDetector нам понадобятся объекты modelPath, inputDescriptor и context.

String modelPathPrefix = "path-to-model/resnet50_ssd";
ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);

Нам нужно указать имя модели в качестве префикса после пути к модели, чтобы правильно загрузить ее с помощью Predictor API. например: если название модели - resnet50_ssd, а файлы модели загружены в папку /tmp/model/, тогда modelPathPrefix будет /tmp/model/resnet50_ssd.

Теперь давайте возьмем входное изображение и загрузим его, чтобы скормить его нашей модели. В качестве входных данных мы будем использовать изображение, взятое отсюда.

String inputImagePath = "path-to-downloaded-image";
BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);

Теперь мы можем выполнить логический вывод, вызвав:

int numberOfObjectsToDetect = 3; // returns top 3 objects in image
List<List<ObjectDetectorOutput>> output = objDet.imageObjectDetect(img, numberOfObjectsToDetect);

ProTip: imageObjectDetect метод возвращает вложенный список, который можно упростить, используя фрагмент кода.

ObjectDetectorOutput - служебный класс, который содержит предсказанную метку объекта, вероятность предсказанной метки, Xmax, Xmin, Ymax и Ymin. X и Y представляют значения пикселей для местоположения обнаруженного объекта на исходном изображении и могут использоваться для формирования ограничивающей рамки вокруг обнаруженного объекта.

Вот результат, полученный после преобразования с помощью ProTip:

Class: car
Probabilties: 0.98847263
Coord:312.21335,72.0291,456.01443,150.66176
Class: bicycle
Probabilties: 0.94833825
Coord:155.95807,149.96362,383.8369,418.94513
Class: dog
Probabilties: 0.8281818
Coord:83.82353,179.13998,206.63783,476.7875

Вот визуализация прогнозируемых результатов:

Шаг 4. Что дальше?

Вы можете попробовать запустить больше примеров в репозитории MXNet в Папке примеров Java.

Заключение

API-интерфейсы MXNet Java Inference позволяют разработчикам использовать предварительно обученные модели глубокого обучения с помощью Apache MXNet для начала работы с глубоким обучением.

Apache MXNe - проект с открытым исходным кодом. Если вам это нравится и вы хотите внести свой вклад, присоединяйтесь к проекту здесь.