hayakawaです。
今回はJavaでディープラーニングを実装できるOSSであるDeep Java Library(DJL)を使ってみました。
ディープラーニングで何かをやるとしたら、現状ではPythonで開発するケースが多いですが、
システム全体としてはJavaで開発をしたいが、ディープラーニングの処理を利用したい場合、
これまでだと良いソリューションがなく、別サービスとしてAPI呼び出しをしたり、
JavaからPythonのプロセスを呼び出したりするようなことが必要でした。
ただこれだと、パフォーマンスが求められる場合、なかなか厳しいものがあります。
このOSSはそのようなケースで、Java上でのディープラーニング処理も実行できるようにすることを目指しているようです。
aws.amazon.com
Pythonで学習・モデル作成を行い、それをJavaで推論する、ということもできるようですが、
今回は、基本的な判定処理と学習処理を、Javaのサンプルコードから読み解いてみます。
判定処理
判定処理は、物体の種類と位置を検知する物体検知タスクのサンプルを見てみました。
プリトレーニングモデルを用いて、短いコードで判定を実現しています。
サンプル(ObjectDetetion.java)
public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
BufferedImage img = BufferedImageUtils.fromFile(imageFile);
Map<String, String> criteria = new ConcurrentHashMap<>();
criteria.put("size", "512");
criteria.put("backbone", "resnet50");
criteria.put("flavor", "v1");
criteria.put("dataset", "voc");
try (ZooModel<BufferedImage, DetectedObjects> model =
MxModelZoo.SSD.loadModel(criteria, new ProgressBar())) {
try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(img);
saveBoundingBoxImage(img, detection);
return detection;
}
}
}
(1)モデルの読込み
モデルの読込みは
MxModelZoo.SSD.loadModel(criteria, new ProgressBar()))
というコードでやっています。
最初の MxModelZoo.SSD というところで、物体検知用のアルゴリズムであるSSDを指定し、
さらにSSDのバックボーンとなる画像判定モデルを何にするかを、
上のcriteriaというMapに詰めて指定して渡しています。
今回は"VOC"(PASCAL VOCのデータセット)で学習されたResNet50 v1のモデルを指定しています。
学習済みモデルは、名前の通りMXNet上の物がサポートされているようです。
AWSさんですし、MXNetになりますよね。
MxModelZooで指定できるモデルは以下のページに表で一覧されています。
djl/mxnet/mxnet-model-zoo at master · awslabs/djl · GitHub
表の各列の意味は次の通りです。
列 |
説明 |
Application |
画像分類(Image Classification)やポーズ検知(Pose Estimation)などのタスクの種類。 |
Model Family |
SSDなどのモデルの分類名。この名前をMxModelZoo.~のところに指定します。 |
CriteriaとPossible values |
サンプルコードであったcriteriaに指定できる条件と値の組み合わせです。 |
今回やっている物体検知(Object Detection)は現在SSDのみをサポートしており、
バックボーンはVGGやMobileNetが使えるようです。
(2)判定器の生成
Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()
で、新しい判定器を生成しています。
(3)画像の判定
メソッドの先頭の
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
BufferedImage img = BufferedImageUtils.fromFile(imageFile);
で読み込んだ画像を、(2)で生成した判定器に
predictor.predict(img);
で渡し、さらに結果をsaveBoundingBoxImage()というprivateメソッドに渡して、
画像上に検知結果の名称と枠線を描画させています。
得られた画像が次のものです。
リポジトリ内の djl/examples/src/test/resources/ ディレクトリに、
試し斬りに使える画像が置いてあるので、
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
のところを書き換えて試してみましょう。
学習処理
学習の方は、伝統的なMNISTのサンプルがあったので試してみました。
サンプル(TrainMnist.java)
public static ExampleTrainingResult runExample(String[] args)
throws IOException, ParseException {
Arguments arguments = Arguments.parseArgs(args);
Block block =
new Mlp(
Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
Mnist.NUM_CLASSES,
new int[] {128, 64});
try (Model model = Model.newInstance()) {
model.setBlock(block);
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);
DefaultTrainingConfig config = setupTrainingConfig(arguments);
config.addTrainingListeners(
TrainingListener.Defaults.logging(
TrainMnist.class.getSimpleName(),
arguments.getBatchSize(),
(int) trainingSet.getNumIterations(),
(int) validateSet.getNumIterations(),
arguments.getOutputDir()));
ExampleTrainingResult result;
try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());
Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
trainer.initialize(inputShape);
TrainingUtils.fit(
trainer,
arguments.getEpoch(),
trainingSet,
validateSet,
arguments.getOutputDir(),
"mlp");
result = new ExampleTrainingResult(trainer);
}
model.save(Paths.get(arguments.getOutputDir()), "mlp");
return result;
}
}
(1) レイヤ構造を定義
ニューラルネットワークのレイヤ構造を、MLPというクラスに生成させています。
MLPクラスのコードはこちらです。
お手軽に多層パーセプトロンを作ってくれるようです。中は次のようになっていました。
public Mlp(int width, int height) {
add(Blocks.batchFlattenBlock(width * (long) height))
.add(new Linear.Builder().setOutChannels(128).build())
.add(Activation.reluBlock())
.add(new Linear.Builder().setOutChannels(64).build())
.add(Activation.reluBlock())
.add(new Linear.Builder().setOutChannels(10).build());
}
width×heightの入力を受け取り、各層で128個、64個、10個の出力をする層を重ねたネットワークを構築しているようです。
層の内容を変えたければこのクラスでやっているようにadd()メソッドで積み重ねて作れます。
(2) モデルの生成
(1)のレイヤ定義を使って
try (Model model = Model.newInstance()) {
model.setBlock(block);
で初期状態のモデルを生成しています。
(3) データの準備
getDataset()というprivateメソッドを呼び出して、MNIST用のデータを取得します。
Argument(プログラム引数)を渡しているのは、引数で指定したエポック数とバッチサイズに応じたデータ数を引っ張ってくるためのようです。
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);
getDataset()の内部では、 Mnist.builder() という組み込みのMNISTデータロード用クラスにデータを作らせていました。
(4) 学習の設定
setupTrainingConfig()というprivateメソッドの内部で、
return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.setBatchSize(arguments.getBatchSize())
.optDevices(Device.getDevices(arguments.getMaxGpus()));
という処理で学習用の設定インスタンスを生成しています。
プログラム引数のバッチサイズや利用してよいGPU数などをセットし、Softmax関数、クロスエントロピー誤差を指定しています。
(5) 学習器の初期化
(4)で生成した環境設定で、学習器を生成しています。
Trainer trainer = model.newTrainer(config))
(6) 学習開始
学習器やデータを渡して、実際に学習を開始します。
TrainingUtils.fit(
trainer,
arguments.getEpoch(),
trainingSet,
validateSet,
arguments.getOutputDir(),
"mlp");
(7) モデルの保存
model.save()を呼ぶとファイルに学習済みモデルを保存できます。
model.save(Paths.get(arguments.getOutputDir()), "mlp");