Автор: Эндрю Эйрес, Цин Лан, Навин Свами, Пиюш Гхай, Ижи Лю
Apache MXNet - это среда глубокого обучения с открытым исходным кодом, используемая для обучения и развертывания глубоких нейронных сетей. Он масштабируемый, что позволяет быстро обучать модели, и поддерживает гибкую модель программирования и несколько языков (C ++, Python, Julia, Clojure, JavaScript, R, Scala).
Сегодня сообщество Apache MXNet рада объявить о предварительной версии Java APIs для Inference. API MXNet для Java упрощают использование моделей глубокого обучения в системах и приложениях, построенных на основе популярного языка Java и среды выполнения. Эксперт по машинному обучению может обучить модели с помощью Python, точно настроить и сохранить модель, а можно загрузить модель и использовать ее в производстве для вывода с помощью API Java.
В этом посте мы опишем доступные Java API и проведем вас через быструю практическую настройку, чтобы начать работу с ними.
MXNet-Java API: что нового?
Простота использования была в центре нашего мыслительного процесса, когда мы разрабатывали API. С помощью всего нескольких строк кода можно загрузить современные модели, обученные с помощью Apache MXNet, и начать делать прогнозы. API-интерфейсы Java также имеют автоматизированную систему управления ресурсами, что упрощает управление объемом внутренней памяти без снижения производительности.
Вот доступные API-интерфейсы вывода на Java в MXNet:
- Predictor API
- Детектор объектов (одиночный и пакетный)
Predictor API предоставляет методы для выполнения логических выводов по предварительно обученной модели.
ObjectDetector - это оболочка поверх Predictor API. Он предоставляет методы для обнаружения отдельных объектов на изображении, а также их местоположения на изображении.
Быстрый старт с API Java
API Java в настоящее время поддерживают Java 8 и выше. Предварительная версия пакета MXNet Java доступна на Maven Central. Официальный выпуск будет вместе с предстоящим выпуском MXNet 1.4.
Шаг 0: Настройка среды
Следуйте пошаговому руководству здесь, чтобы настроить среду разработки для использования API Java.
Шаг 1. Загрузка предварительно обученной модели с помощью Predictor API
MXNet имеет обширную коллекцию предварительно обученных современных моделей глубокого обучения, доступных в MXNet Model Zoo. Существуют также предварительно обученные современные модели, реализованные с использованием Gluon (обязательный API для MXNet), доступные в Gluon Model Zoo. Позже в этом посте мы запустим пример вывода для модели Resnet50 Single Shot Detection
, взятой из Model Zoo.
Чтобы загрузить предварительно обученную модель MXNet с использованием API Java, требуется путь к загруженной модели и описание входных данных модели. Чтобы указать ввод, ожидаемый моделью, мы создаем объект типа DataDesc
. Входные данные обычно определяются на этапе обучения модели. DataDesc принимает имя входного слоя, форму входных данных, тип данных и порядок данных.
Shape inputShape = new Shape(new int[] {1,3,224,224}); DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW"); List<DataDesc> inputDescList = new ArrayList<DataDesc>(); inputDescList.add(inputDescriptor);
Легко указать, хотите ли вы запускать логический вывод на процессорах или графических процессорах (если у вас есть машина с поддержкой графического процессора), указав это в объекте Context
.
List<Context> context = new ArrayList<>(); context.add(Context.cpu()); // For GPU, context.add(Context.gpu());
Примеры в этом руководстве были запущены на MacBook Pro с использованием контекста ЦП. Чтобы использовать MXNet с графическим процессором NVIDIA на Mac, следуйте инструкциям здесь.
String modelPathPrefix = "path-to-model"; Predictor predictor = new Predictor(modelPathPrefix, inputDescList, context);
modelPath
- это папка, которая содержит файл символов (представляющий слои модели), файл параметров (обученные веса модели) и любые другие вспомогательные файлы, необходимые для модели.
Шаг 2: вывод с использованием предиктора
Класс Predictor
имеет три функции прогнозирования, которые принимают входные данные и производят прогнозы в качестве выходных данных. Входные данные для функции прогнозирования могут быть либо NDArray, либо одномерным списком Java, либо одномерным массивом Java.
В MXNet NDArray - это основная структура данных для всех математических вычислений, связанных с моделями глубокого обучения.
Вот три примера вызова API Predictor:
List<NDArray> result = predictor.predictWithNDArray(inputNDArray);
or
List<List<Float>> result = predictor.predict(inputFloatList);
or
float[][] result = predictor.predict(inputFloatArray);
Шаг 3. Пришло время шоу!
Теперь, когда мы вкратце поговорили о том, как выглядит Predictor API, давайте взглянем на Пример обнаружения объектов, написанный на Java из репозитория MXNet. Мы будем использовать API детектора объектов для идентификации объектов и их местоположения на изображении.
Следуйте инструкциям здесь, чтобы загрузить файлы модели. Для работы с этим руководством нам потребуются три файла модели: файл символов, файл параметров и файл synset.txt. Файл символа содержит описание архитектуры модели, то есть слоев, используемых в модели. Файл Params содержит веса обученной модели, назначенные слоям. Файл synset содержит метки классов, используемые при обучении.
Давайте определим формы ввода, дескрипторы ввода и контекст, которые будут использоваться позже.
Shape inputShape = new Shape(new int[] {1,3,512,512}); DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW"); List<DataDesc> inputDescList = new ArrayList<DataDesc>(); inputDescList.add(inputDescriptor); List<Context> context = new ArrayList<Context>(); context.add(Context.cpu());
Давайте рассмотрим описанные нами переменные.
inputShape
, определенный выше, представляет пакетный ввод с размером пакета 1, имеющий 3 канала (RGB) в изображении, а высота каждого изображения в пакете 512, а ширина каждого изображения - 512.- Первый параметр объекта
inputDescriptor
- это имя входного слоя в файле символов модели, за которым следуетinputShape
, тип входных данных: Float32 и «NCHW», где N обозначает размер пакета, C обозначает каналы, H обозначает высоту, W обозначает ширину изображения.
Теперь давайте определим экземпляр ObjectDetector
, API высокого уровня, который предоставляет метод для обнаружения отдельных объектов на изображении, их предполагаемых меток и местоположений на изображении. Он также содержит служебные методы для предварительной обработки входных изображений. Для создания экземпляра ObjectDetector нам понадобятся объекты modelPath
, inputDescriptor
и context
.
String modelPathPrefix = "path-to-model/resnet50_ssd"; ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
Нам нужно указать имя модели в качестве префикса после пути к модели, чтобы правильно загрузить ее с помощью Predictor API. например: если название модели -
resnet50_ssd
, а файлы модели загружены в папку/tmp/model/
, тогдаmodelPathPrefix
будет/tmp/model/resnet50_ssd
.
Теперь давайте возьмем входное изображение и загрузим его, чтобы скормить его нашей модели. В качестве входных данных мы будем использовать изображение, взятое отсюда.
String inputImagePath = "path-to-downloaded-image";
BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
Теперь мы можем выполнить логический вывод, вызвав:
int numberOfObjectsToDetect = 3; // returns top 3 objects in image List<List<ObjectDetectorOutput>> output = objDet.imageObjectDetect(img, numberOfObjectsToDetect);
ProTip:
imageObjectDetect
метод возвращает вложенный список, который можно упростить, используя фрагмент кода.
ObjectDetectorOutput
- служебный класс, который содержит предсказанную метку объекта, вероятность предсказанной метки, Xmax
, Xmin
, Ymax
и Ymin
. X и Y представляют значения пикселей для местоположения обнаруженного объекта на исходном изображении и могут использоваться для формирования ограничивающей рамки вокруг обнаруженного объекта.
Вот результат, полученный после преобразования с помощью ProTip:
Class: car Probabilties: 0.98847263 Coord:312.21335,72.0291,456.01443,150.66176 Class: bicycle Probabilties: 0.94833825 Coord:155.95807,149.96362,383.8369,418.94513 Class: dog Probabilties: 0.8281818 Coord:83.82353,179.13998,206.63783,476.7875
Вот визуализация прогнозируемых результатов:
Шаг 4. Что дальше?
Вы можете попробовать запустить больше примеров в репозитории MXNet в Папке примеров Java.
Заключение
API-интерфейсы MXNet Java Inference позволяют разработчикам использовать предварительно обученные модели глубокого обучения с помощью Apache MXNet для начала работы с глубоким обучением.
Apache MXNe - проект с открытым исходным кодом. Если вам это нравится и вы хотите внести свой вклад, присоединяйтесь к проекту здесь.