hayakawaです。
今回はJavaでディープラーニングを実装できるOSSであるDeep Java Library(DJL)を使ってみました。
ディープラーニングで何かをやるとしたら、現状ではPythonで開発するケースが多いですが、
システム全体としてはJavaで開発をしたいが、ディープラーニングの処理を利用したい場合、
これまでだと良いソリューションがなく、別サービスとしてAPI呼び出しをしたり、
JavaからPythonのプロセスを呼び出したりするようなことが必要でした。
ただこれだと、パフォーマンスが求められる場合、なかなか厳しいものがあります。
このOSSはそのようなケースで、Java上でのディープラーニング処理も実行できるようにすることを目指しているようです。
Pythonで学習・モデル作成を行い、それをJavaで推論する、ということもできるようですが、
今回は、基本的な判定処理と学習処理を、Javaのサンプルコードから読み解いてみます。
判定処理
判定処理は、物体の種類と位置を検知する物体検知タスクのサンプルを見てみました。
プリトレーニングモデルを用いて、短いコードで判定を実現しています。
public static DetectedObjects predict() throws IOException, ModelException, TranslateException { Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg"); BufferedImage img = BufferedImageUtils.fromFile(imageFile); Map<String, String> criteria = new ConcurrentHashMap<>(); criteria.put("size", "512"); criteria.put("backbone", "resnet50"); criteria.put("flavor", "v1"); criteria.put("dataset", "voc"); try (ZooModel<BufferedImage, DetectedObjects> model = MxModelZoo.SSD.loadModel(criteria, new ProgressBar())) { //(1) try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()) { //(2) DetectedObjects detection = predictor.predict(img); //(3) saveBoundingBoxImage(img, detection); return detection; } } }
(1)モデルの読込み
モデルの読込みは
MxModelZoo.SSD.loadModel(criteria, new ProgressBar()))
というコードでやっています。
最初の MxModelZoo.SSD というところで、物体検知用のアルゴリズムであるSSDを指定し、
さらにSSDのバックボーンとなる画像判定モデルを何にするかを、
上のcriteriaというMapに詰めて指定して渡しています。
今回は"VOC"(PASCAL VOCのデータセット)で学習されたResNet50 v1のモデルを指定しています。
学習済みモデルは、名前の通りMXNet上の物がサポートされているようです。
AWSさんですし、MXNetになりますよね。
MxModelZooで指定できるモデルは以下のページに表で一覧されています。
djl/mxnet/mxnet-model-zoo at master · awslabs/djl · GitHub
表の各列の意味は次の通りです。
列 | 説明 |
Application | 画像分類(Image Classification)やポーズ検知(Pose Estimation)などのタスクの種類。 |
Model Family | SSDなどのモデルの分類名。この名前をMxModelZoo.~のところに指定します。 |
CriteriaとPossible values | サンプルコードであったcriteriaに指定できる条件と値の組み合わせです。 |
今回やっている物体検知(Object Detection)は現在SSDのみをサポートしており、
バックボーンはVGGやMobileNetが使えるようです。
(2)判定器の生成
Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()
で、新しい判定器を生成しています。
(3)画像の判定
メソッドの先頭の
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
BufferedImage img = BufferedImageUtils.fromFile(imageFile);
で読み込んだ画像を、(2)で生成した判定器に
predictor.predict(img);
で渡し、さらに結果をsaveBoundingBoxImage()というprivateメソッドに渡して、
画像上に検知結果の名称と枠線を描画させています。
得られた画像が次のものです。
リポジトリ内の djl/examples/src/test/resources/ ディレクトリに、
試し斬りに使える画像が置いてあるので、
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
のところを書き換えて試してみましょう。
学習処理
学習の方は、伝統的なMNISTのサンプルがあったので試してみました。
public static ExampleTrainingResult runExample(String[] args) throws IOException, ParseException { Arguments arguments = Arguments.parseArgs(args); // Construct neural network Block block = new Mlp( Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES, new int[] {128, 64}); //(1) try (Model model = Model.newInstance()) { model.setBlock(block); //(2) // get training and validation dataset RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments); RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments); //(3) // setup training configuration DefaultTrainingConfig config = setupTrainingConfig(arguments); //(4) config.addTrainingListeners( TrainingListener.Defaults.logging( TrainMnist.class.getSimpleName(), arguments.getBatchSize(), (int) trainingSet.getNumIterations(), (int) validateSet.getNumIterations(), arguments.getOutputDir())); ExampleTrainingResult result; try (Trainer trainer = model.newTrainer(config)) { //(5) trainer.setMetrics(new Metrics()); /* * MNIST is 28x28 grayscale image and pre processed into 28 * 28 NDArray. * 1st axis is batch axis, we can use 1 for initialization. */ Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH); // initialize trainer with proper input shape trainer.initialize(inputShape); TrainingUtils.fit( //(6) trainer, arguments.getEpoch(), trainingSet, validateSet, arguments.getOutputDir(), "mlp"); result = new ExampleTrainingResult(trainer); } model.save(Paths.get(arguments.getOutputDir()), "mlp"); //(7) return result; } }
(1) レイヤ構造を定義
ニューラルネットワークのレイヤ構造を、MLPというクラスに生成させています。
MLPクラスのコードはこちらです。
お手軽に多層パーセプトロンを作ってくれるようです。中は次のようになっていました。
public Mlp(int width, int height) { add(Blocks.batchFlattenBlock(width * (long) height)) .add(new Linear.Builder().setOutChannels(128).build()) .add(Activation.reluBlock()) .add(new Linear.Builder().setOutChannels(64).build()) .add(Activation.reluBlock()) .add(new Linear.Builder().setOutChannels(10).build()); }
width×heightの入力を受け取り、各層で128個、64個、10個の出力をする層を重ねたネットワークを構築しているようです。
層の内容を変えたければこのクラスでやっているようにadd()メソッドで積み重ねて作れます。
(2) モデルの生成
(1)のレイヤ定義を使って
try (Model model = Model.newInstance()) {
model.setBlock(block);
で初期状態のモデルを生成しています。
(3) データの準備
getDataset()というprivateメソッドを呼び出して、MNIST用のデータを取得します。
Argument(プログラム引数)を渡しているのは、引数で指定したエポック数とバッチサイズに応じたデータ数を引っ張ってくるためのようです。
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments); RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);
getDataset()の内部では、 Mnist.builder() という組み込みのMNISTデータロード用クラスにデータを作らせていました。
(4) 学習の設定
setupTrainingConfig()というprivateメソッドの内部で、
return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) .setBatchSize(arguments.getBatchSize()) .optDevices(Device.getDevices(arguments.getMaxGpus()));
という処理で学習用の設定インスタンスを生成しています。
プログラム引数のバッチサイズや利用してよいGPU数などをセットし、Softmax関数、クロスエントロピー誤差を指定しています。
(5) 学習器の初期化
(4)で生成した環境設定で、学習器を生成しています。
Trainer trainer = model.newTrainer(config))
(6) 学習開始
学習器やデータを渡して、実際に学習を開始します。
TrainingUtils.fit(
trainer,
arguments.getEpoch(),
trainingSet,
validateSet,
arguments.getOutputDir(),
"mlp");
(7) モデルの保存
model.save()を呼ぶとファイルに学習済みモデルを保存できます。
model.save(Paths.get(arguments.getOutputDir()), "mlp");
感想
基本的な判定と学習を見てみましたが、APIの構成内容がオーソドックスで名前にクセも無いため、
他のディープラーニング用フレームワークを知っていれば習得は速そうです。
またGPUは勝手に見つけて勝手に使ってくれるらしく楽です。
OpenCVなどを用いる画像処理ライブラリは、DJL側でラッパーを用意しており、
基本的な操作であれば独自にOpenCVを触らないでよさそうです。
上で紹介した以外のサンプルも同じディレクトリに配置されており、
基本的にはサンプルを真似して使えば一通りのことはできそうだと感じました。
Java上のディープラーニングライブラリとしては、良い候補になりそうです。
Acroquest Technologyでは、キャリア採用を行っています。
- ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
- Elasticsearch等を使ったデータ収集/分析/可視化
- マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
- 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長
少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。
Kaggle Masterと働きたい尖ったエンジニアWanted! - Acroquest Technology株式会社のデータサイエンティストの求人 - Wantedlywww.wantedly.com