Taste of Tech Topics

Taste of Tech Topics

Acroquest Technology株式会社のエンジニアが書く技術ブログ

FastAPI+StrawberryでGraphQLのAPIを実現する

はじめに

最近アクアリウムを始めました、菅野です。
プログラムと異なり、生体を扱う都合上の想定外を楽しみながら試行錯誤しております。

さて、皆さんはAPIサーバを構築する際に、どのAPI形式を用いていますか?
まだまだREST形式で実装することが多いかとは思いますが、
GraphQLを用いることも増えてきているのではないでしょうか?

今回は、そんなGraphQLをFastAPIと各種ライブラリを用いて簡単に実装する方法を紹介していこうと思います。

GraphQLとは

GraphQLは、Meta社(旧Facebook社)によって開発・公開されたAPI仕様です。クエリ形式で、処理やパラメータの内容を指定します。
RESTとの比較としては、

  • クライアント側で取得したい情報をクエリとして渡すことができるため、 利用しないデータを無駄に受け取らなくて済む。
  • 一つのエンドポイントに対し複数リソースのクエリを一度にリクエストできるので、APIコール数を削減できる

等の利点があります。

GraphQL | A query language for your API

構成

PythonAPI作成フレームワーク、FastAPIと、GraphQLライブラリStrawberryを組み合わせることで、簡単にGraphqlAPIを実現することができます。 今回は、上記ライブラリに加え、ORMライブラリSQLAlchemyを用いてSQLite3に接続するTODO管理アプリのバックエンドサーバを構築します。

利用ライブラリ

利用するPythonおよび主なライブラリとそのバージョンは以下になります。

ライブラリ バージョン
Python 3.9.6
FastAPI 0.85.0
Strawberry 0.131.1
SQLAlchemy 1.4.41

実装

ディレクトリ構成

ディレクトリ構成は以下のようになっています。
クリーンアーキテクチャの構成に従いディレクトリを分けております。

/
│  poetry.lock
│  pyproject.toml      
├─db
└─src
    │  app.py
    │  contexts.py
    │  database.py
    │  router.py
    │  resolvers.py
    │  __init__.py
    │  
    ├─domain
    │  ├─model
    │  │      task.py
    │  │          
    │  └─service
    │          task_service.py
    │              
    ├─infra
    │  └─repository
    │          models.py
    │          task_repository.py
    │              
    └─web
        └─task
                inputs.py
                types.py

クラス図

それぞれのコードの関係性は以下になっています。

各種コード説明

infra/repositoryディレクト

データベースに接続するためのデータモデル定義と、各種DB操作の実装をこちらに記載します。

models.py

DBテーブルで利用するカラム定義を記載します。
一般的に、SQLAlchemyで用いるDBデータモデルクラスはDeclarative Extensionsを用いて、
テーブルとドメインモデルクラスのマッピングを行います。
しかし、上記を利用すると、ドメインモデルクラスがSQLAlchemyに依存する形になってしまうため、
今回は利用せず、ドメインモデルクラスをSQLAlchemyに依存しないように分けて実装します。

Declarative Extensions — SQLAlchemy 1.4 Documentation

from domain.model.task import Task ,Status
from sqlalchemy import (Column, DateTime, Enum,
                        String, Table)
from sqlalchemy.orm import registry
from sqlalchemy.sql import func
from sqlalchemy_utils import UUIDType, ChoiceType
import uuid
import enum

mapper_registry = registry()

# Taskテーブルの定義
task = Table(
    'task',
    mapper_registry.metadata,
    Column('id', UUIDType(binary=False), primary_key=True, default=uuid.uuid4),
    Column('description', String(200)),
    Column('title', String(200)),
    Column('status', Enum(Status)),
    Column('updated_at', DateTime, default=func.now())
)
mapper_registry.map_imperatively(Task, task)

task_repository.py

DBセッション情報を受け取り、CRUD操作を行う処理を記載します。

from sqlalchemy.orm import Session

from domain.model.task import Task


class TaskRepository:
    def __init__(self, db: Session):
        self.__db = db

    def find_by_id(self, id_: str) -> Task:
        task = self.__db.query(Task).get(id_)

        return task

    def find_all(self) -> list[Task]:
        # tasks = list(cls._store.values())
        tasks = self.__db.query(Task).all()

        return tasks

    def save(self, task: Task) -> None:
        self.__db.add(task)
        self.__db.commit()

    def delete(self, task: Task) -> None:
        self.__db.delete(task)
        self.__db.commit()

domainディレクト

内部で行うビジネスロジック(今回ではAPI実行時の各種処理)とデータ型の定義を行います。

service/task_service.py

データアクセスをリポジトリクラスに委任することで接続先に依存しない実装になっています。
例えば、DB接続ではなくオンメモリでデータを保持するようになった場合はリポジトリクラスのみの実装を変更すればよくなります。

from datetime import datetime
from typing import Optional

from domain.model.task import Status, Task
from infra.repository.task_repository import TaskRepository


class TaskService:
    def __init__(self, repo: TaskRepository) -> None:
        self.__repo = repo

    @property
    def repo(self) -> type[TaskRepository]:
        return self.__repo

    def find(self, id: str) -> Task:
        task = self.repo.find_by_id(id)

        return task

    def find_all(self) -> list[Task]:
        tasks = self.repo.find_all()

        return tasks

    def create(self, *, title: str, description: Optional[str] = None) -> Task:
        task = Task(title=title, description=description)
        self.repo.save(task)

        return task

    def update(self, *, id: str, status: Status) -> Task:
        task = self.repo.find_by_id(id)
        task.status = status
        task.updated_at = datetime.utcnow()
        self.repo.save(task)

        return task

    def delete(self, id: str) -> Task:
        task = self.repo.find_by_id(id)
        self.repo.delete(task)

        return task

model/task.py

DB接続モデルクラスでも説明したように、SQLAlchemyに依存しないドメインモデルクラスとして、
利用するデータのモデルクラスを別途実装します。

import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional


class Status(Enum):
    TODO = 'todo'
    DOING = 'doing'
    DONE = 'done'


@dataclass
class Task:
    title: str
    id: uuid.UUID = field(default_factory=uuid.uuid4)
    description: Optional[str] = None
    status: Status = Status.TODO
    updated_at: datetime = field(default_factory=datetime.utcnow)

web/taskディレクト

DB接続、ドメインロジックの実装が終わったので、 いよいよ本投稿の要旨であるStrawberryを用いたGraphql実装部分の肝である、webクラスを実装していきます。

types.py

GraphQLでは、クエリを構造を定義する必要があります。 Strawberryではstrawberry.typeアノテーションを用いてコードベースでクエリスキーマを定義できます。 Python型とGraphQLスキーマ型との対応は以下のリンクを参照してください。

Schema basics | 🍓 Strawberry GraphQL

from datetime import datetime
from typing import Optional

import strawberry

from domain.model.task import Status, Task

StatusType = strawberry.enum(Status, name='Status')


@strawberry.type(name='Task')
class TaskType:
    id: strawberry.ID
    title: str
    description: Optional[str]
    status: StatusType
    updated_at: datetime

input.py

GraphQLリクエストの入力も同様にコードベースで定義します。 strawberry.inputアノテーションを付けたクラスが入力スキーマとして利用できるようになります。

from typing import Optional

import strawberry


@strawberry.input
class AddTaskInput:
    title: str
    description: Optional[str] = None


@strawberry.input
class UpdateTaskInput:
    id: strawberry.ID
    status: str

srcディレクトリ直下

FastAPI起動用のコードを記載します。

database.py DB接続情報と初回起動時のテーブル初期化処理を記載します。

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from infra.repository.models import mapper_registry


SQLALCHEMY_DATABASE_URI = 'sqlite:///../db/tasks.db'


class DatabaseContext:
    def initialize(self):

        engine = create_engine(
            SQLALCHEMY_DATABASE_URI, connect_args={'check_same_thread': False}, echo=True
        )

        self.SessionLocal = sessionmaker(
            autocommit=False, autoflush=False, bind=engine)
        Base = mapper_registry.generate_base()
        Base.metadata.create_all(bind=engine)


database_context = DatabaseContext()


def get_db():
    """
    Get database

    Yields:
        SessionLocal: Local session for database connection
    """
    db = database_context.SessionLocal()
    try:
        yield db
    finally:
        db.close()

context.py

コンテキスト情報として、各種サービスとリポジトリクラスの関係性を保持するように定義します。

from fastapi import Depends
from strawberry.fastapi import BaseContext

from database import get_db
from domain.service.task_service import TaskService
from infra.repository.task_repository import TaskRepository


def init_task_repository(db=Depends(get_db)):
    return TaskRepository(db)


def init_task_service(task_repository: TaskRepository = Depends(init_task_repository)):
    return TaskService(
        task_repository
    )


class TaskContext(BaseContext):
    def __init__(self, task: TaskService):
        self.__task: TaskService = task

    def get_task(self):
        return self.__task


class TaskServicesContext(BaseContext):
    def __init__(self, task: TaskService):
        self.__task: TaskService = task

    def get_task(self):
        return self.__task


async def get_context(
        task_service: TaskService = Depends(init_task_service)
) -> TaskContext:
    return TaskContext(
        task=task_service
    )

resolver.py

コンテキストからサービス情報を取得し、対応する処理を呼び出す実装を記載します。

import strawberry

from web.task.inputs import AddTaskInput, UpdateTaskInput
from web.task.types import TaskType
from strawberry.types import Info


def get_task(id: strawberry.ID, info: Info) -> TaskType:
    service = info.context.get_task()
    task = service.find(id)

    return task


def get_tasks(info: Info) -> list[TaskType]:

    service = info.context.get_task()
    tasks = service.find_all()

    return tasks


def add_task(task_input: AddTaskInput, info: Info) -> TaskType:
    service = info.context.get_task()
    task = service.create(**task_input.__dict__)

    return task


def update_task(task_input: UpdateTaskInput, info: Info) -> TaskType:
    service = info.context.get_task()
    task = service.update(**task_input.__dict__)

    return task


def delete_task(id: strawberry.ID, info: Info) -> TaskType:
    service = info.context.get_task()
    task = service.delete(id)

    return task

router.py

GraphQLのQuery, Mutationの形式を定義し、実行する処理をリゾルバとして渡すように記載します。

import strawberry
from strawberry.fastapi import GraphQLRouter
from resolvers import add_task, delete_task, get_task, get_tasks, update_task

from web.task.types import TaskType
from contexts import get_context


@strawberry.type
class Query:
    task: TaskType = strawberry.field(resolver=get_task)
    tasks: list[TaskType] = strawberry.field(resolver=get_tasks)


@strawberry.type
class Mutation:
    task_add: TaskType = strawberry.field(resolver=add_task)
    task_update: TaskType = strawberry.field(resolver=update_task)
    task_delete: TaskType = strawberry.field(resolver=delete_task)


schema = strawberry.Schema(query=Query, mutation=Mutation)
task_app = GraphQLRouter(schema, context_getter=get_context)

app.py

from fastapi import FastAPI
from router import task_app
import uvicorn

from database import database_context

api = FastAPI()


def register_controller():
    api.include_router(task_app, prefix='/task')


if __name__ == '__main__':
    database_context.initialize()
    register_controller()

    uvicorn.run(app=api, host='0.0.0.0', port=8000)

実装したrouter.pyの内容を登録し、FastAPIを起動する処理を記載します。

起動、API呼び出し

起動コマンドは以下のようになっています。

cd ./src
python -m app

起動に成功するとlocalhost:8000にAPIが立ち上がります。
/taskにGraphQLルータを追加したので、
localhost:8000/taskにブラウザでアクセスするとStrawberryのGraphQL UIページにアクセスできます。

中央のパネルにGraphQLクエリを入力して実行することができます。
左側のメニューからドキュメントを確認したり、クエリの探索的作成もできる便利なUIになっています。

データの投入クエリは以下のようになっています。

mutation addDataSample {
  taskAdd(taskInput:{ title:"test", description: "testTask"}){
    title
  }
}

クエリの実行結果は右側のパネルに表示されます。

投入したデータを一覧取得するクエリを実行してみましょう。

query listtasks {
  tasks {
    id
    title
    status
    description
    updatedAt
  }
}

無事、投入したデータを確認することができました。

さいごに

FastAPIとGraphQLライブラリStrawberryを用いて簡単にGraphQLAPIを実装する方法を紹介しました。
ライブラリに則った定義を記載するだけで簡単にGraphQLのAPIを実現できます。
手軽に実現できるため、今後APIを構築する際はGraphQLも選択肢の一つに入るのではないかと思います。
REST API、GraphQL双方の利点、欠点を踏まえながら最適な形式を選択していきたいものですね。

それでは!

Acroquest Technologyでは、キャリア採用を行っています。
  • ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
  • Elasticsearch等を使ったデータ収集/分析/可視化
  • マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
  • 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長
  少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。 www.wantedly.com

QuickSight のダッシュボードを 1-Click で埋め込んでみた

こんにちは、機械学習エンジニアの駿です。
データサイエンスチームYAMALEXの一員としても活動しています。

最近、ビジネスシーンや日常会話でのちょっとしたフレーズを英語では何と言う?という内容で、会社でYouTuberを始めました。
Short 動画を上げているので是非ご視聴ください。

最近のおすすめはこちら。 「明日も頑張ろう!」を英語で!

さて、今回は QuickSight の機能、 1-Click 埋め込みを試してみました。

自分のサイトで QuickSight のダッシュボードを共有したいときに、ワンクリックで埋め込み用 HTML コードのコピー&ペーストが可能です。
SDKAPI を使ったプログラム実装が不要で気楽に共有できるようになります。

Web サイトに埋め込むだけでなく、Cognito認証ありの場合の 1-Click 埋め込みの URL の挙動なども検証してみました。

Amazon QuickSight

Amazon QuickSightAWS が提供する BI ツールで、各種 AWS サービス、サードパーティクラウド、オンプレのデータに接続し可視化することができます。

以前 QuickSight の異常検出を試した記事も上げているので、興味がある方はこちらもご覧ください。

acro-engineer.hatenablog.com

QuickSight の外部サイト埋め込み

QuickSight には作成したダッシュボードを Web サイトや Wiki に埋め込むための URL を API を利用して生成する機能があります。

API を利用した埋め込みではユーザが Web サイトにアクセスした際に、アプリケーションサーバが QuickSight の API を呼び出し、埋め込み用 URL を取得します。
Web サイトは取得した URL を埋め込むことで、ユーザに QuickSight ダッシュボードが埋め込まれたページを表示することができます。

QuickSight の 1-Click 埋め込み

1-Click 埋め込みは2022年5月に GA となった機能です。
従来の外部サイト埋め込みと違い、コーディングや開発をすることなく、 Web サイトや Wikiダッシュボードを埋め込むことができます。
1-Click 埋め込みを活用することで、より素早くダッシュボードを共有し、ユーザにインサイトを提供することが可能になります。

それでは実際に Web サイトに 1-Click 埋め込みを使って取得した HTML コードを使ってダッシュボードを表示できるのか試してみます。

事前準備

1. QuickSight アカウント作成

QuickSight アカウントを作成します。

このとき、エンタープライズ版を選んでください。

スタンダードと比べて、料金は高くなりますが、スタンダード版には含まれない、 組み込みダッシュボードなどの機能を使うことができるようになります。
初めて利用する場合は無料枠もあります。

作成時に選びそびれてしまった/既にスタンダード版で作成したアカウントを持っている場合は、設定からエンタープライズ版に変更してください。

2. ダッシュボード作成

まず、 Web サイトに埋め込むためのダッシュボードを作成します。
今回はサンプルで用意されている分析 Web and Social Media Analytics analysis からダッシュボードを作成します。

分析を開いたら、右上の共有メニューから「ダッシュボードを公開」を選択することで、ダッシュボードを作ることができます。
「1-Click埋め込みテスト」という名前でダッシュボードを作成しました。


ダッシュボードの公開

新しいダッシュボードとして公開

3. CloudFront + S3 で静的ホスティング

次にダッシュボードを埋め込むための Web サイトを作成します。 S3 に静的ファイルを配置し、 CloudFront で配信することで実現します。

(1). 静的ファイルを配置

S3 バケットを作成し、次の HTML コードを index.html としてアップロードします。

<!DOCTYPE html>
  <html lang="en">
  <head>
    <meta charset="UTF-8">
    <title>1-Click埋め込み</title>
  </head>
<body>
  <h1>↓QuickSight 1-Click埋め込み↓</h1>

  <!-- ここにコピーしたHTMLコードをペーストします  -->
</body>
</html>

(2). CloudFront で配信

CloudFront のコンソールから「ディストリビューションを作成」を選択します。

オリジンドメインの入力を求められますが、クリックするとドロップダウンで候補が表示されるため、その中から先ほど作成したS3バケットドメインを選択します。

そのほか、S3 バケットアクセスを Origin access control settings (recommended) に設定、ビューワープロトコルポリシーを Redirect HTTP to HTTPS に設定しました。


使用するS3バケットドメインを選択

ブラウザで {ディストリビューションドメイン名} + "/index.html" にアクセスして下のような画面が表示されたら成功です。
まだ QuickSight のダッシュボードを埋め込んでいないため、タイトルのみが表示されています。


Webサイトにアクセスできた

4. QuickSight で CloudFront のドメインを許可

Web サイト上で QuickSight の埋め込みを表示するためには、 QuickSight 側で Web サイトのドメインを許可する必要があります。
QuickSight の埋め込みはアクセスした分課金されるため、他の人のサイトに勝手に張り付けられるなどして意図せずアクセスが増え、課金されてしまう、などの事故を防ぐことができます。

「QuickSight の管理」画面で「ドメインと埋め込み」を選択してください。
ドメイン」に先ほど作成した CloudFront のディストリビューションドメイン名を入力し、「追加」ボタンを押します。
下のリストに CloudFront のドメインが表示されれば成功です。


許可するドメインを追加できた

ダッシュボードの 1-Click 埋め込み

以上で事前準備が完了したので、ここからは実際に 1-Click 埋め込み機能を使って Web サイトに埋め込むための HTML コードを作成し、 Web サイトに埋め込みます。

1. URL 取得

右上の共有メニューから「ダッシュボードの共有」を選択します。

共有画面の上部に「埋め込みコードをコピー」というリンクがあります。
これをクリックすると自動でクリップボードに、埋め込むための HTML コードがコピーされます。


埋め込みコードをコピー

2. HTML に張り付け

先ほど作成した index.html<!-- ここにコピーしたHTMLコードをペーストします --> 部分にペーストしてください。

<iframe
  width="960"
  height="720"
  src="https://us-west-2.quicksight.aws.amazon.com/sn/embed/share/...">
</iframe>

貼り付けられたら保存して、 index.html を S3 にアップロードしましょう。

3. 表示

あとは CloudFront の URL にアクセスするだけです。

初回表示時は QuickSight へのログインが必要です。
ポップアップが表示されるので、QuickSight アカウントでログインしてください。
(ブラウザの設定によってはポップアップがブロックされている可能性があります。)

次回以降はログインされた状態が保持されます。


ログインを求められる

作成したダッシュボードを簡単に Web サイトに埋め込むことができました。


埋め込み完了

動作を確認してみる

1. 公開のダッシュボードを共有してみる

公開ダッシュボードを共有することで、ログイン不要でインターネット上の全ユーザがダッシュボードを埋め込んだ Web サイト上で分析を閲覧することができます。
なお、公開ダッシュボード機能は $250/月(2022/09/03 現在)かかるセッションキャパシティーを使用するため、試す際は費用にお気を付けください。

まず、非公開のダッシュボードではログインが必要なことを再度確認するため、プライベートブラウザで CloudFront のページにアクセスします。
(普通のブラウザでは先ほどすでにログインしているため、そのまま開けてしまうため、プライベートブラウザを使います。)
上と同じ、初回表示時のログインを促されました。いったん、ログインはせずにそのまま置いておきます。

それでは、公式のドキュメント に従って、インターネット上の全員にダッシュボードへのアクセスを許可します。


ダッシュボードを公開する

公開ダッシュボードに設定しても 1-Click 埋め込みの URL は変わりません。
先ほどログインを求められていたプライベートブラウザをリロードしてみると、今度はログインを求められずに表示することができました。
(キャプチャは先ほどログインして表示したダッシュボードと変わらないため、割愛します。)

2. ドメインを許可していないサイトで表示してみる

ドメインを許可していないサイトで表示しようとしても、表示できないことを確認します。

事前準備4. で許可したドメインの右にあるごみ箱マークを押して、ドメインの許可を削除します。


ドメインの削除

ドメインを削除した状態で Web サイトを開くと このページは開けません とエラーページが表示されました。
確かにドメインを許可した Web サイトでないとダッシュボードを表示できないことを確認できました。


ダッシュボードは表示されない

3. Cognito でログインした状態で表示してみる

最後に Cognito ユーザとしてログインしたサイト上に埋め込まれたダッシュボードは、追加でログインする必要があるのかどうかを検証します。
検証用に作成した Amplify でログイン機能を付けたサイトに、今まで通り 1-Click 埋め込みの URL を貼り付けて、公開しました。

Cognito ユーザでログインすると QuickSight の認証はどうなるのか確認してみます。


Cognito ユーザでログインする

ログインのポップアップが表示される

Cognito ユーザとしてログインしていても、 QuickSight ユーザとしてログインを求められることが分かりました。
1-Click 埋め込みで取得した URL には認証情報が入っていないため、当然ともいえるかもしれません。

もちろん、 QuickSight ユーザでログインすればダッシュボードを表示することができます。


ログインすれば表示できる

1-Click 埋め込みと API を利用した埋め込みの違い

  1. 有効期限

    API を利用した埋め込みでは URL に5分間の有効期限があります。
    そのため Web サイトにアクセスするたびに API を呼び出してURLを取得し、 Web サイトにその URL をプログラム的に埋め込む必要があります。

    その反面、 1-Click 埋め込みの URL は分析者が無効にするまで利用することができます。
    一度取得した URL を静的に HTML に埋め込むだけでいいので、 Web サイトの開発コストが少ないのが利点です。

  2. 認証方法

    1-Click 埋め込みの URL はユーザ固有のものではなく、同一の URL にアクセスすると、 QuickSight による認証にリダイレクトされる仕組みになっています。
    Web サイトのログインと QuickSight のログインが別で必要になってしまいますが、 QuickSight のログイン情報は次回以降も保持されるため、一度ログインしてしまえば再度ログインする必要はありません。

    API を利用した埋め込みでは特定のユーザ用の URL を生成する方法とユーザ認証のない URL を生成する方法があります。
    ユーザ固有の URL を生成する場合も、 Web サイトへのログイン情報をもとに URL を生成できるため、 QuickSight による認証は必要ありません。

まとめると下のようになります。
それぞれに向き不向きがあるため、ユースケースに合ったやり方を選択する必要があります。

方式 有効期限 認証方法
1-Click 埋め込み なし QuickSight による認証
API を利用した埋め込み 5分間 特定ユーザの認証情報を含んだ URL /ユーザ認証のない URL

まとめ

今回は開発不要で Web サイトに QuickSight のダッシュボードを簡単に埋め込むことができる、 1-Click 埋め込みを使ってみました。

QuickSight の API を学習する必要がなく、画面をクリックしていくだけで実現できました。
あまりに簡単だったので、ほんとにこれでいいの?と思ってしまいました。

それでは、
Let's keep up the good work tomorrow!

Acroquest Technologyでは、キャリア採用を行っています。
  • ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
  • Elasticsearch等を使ったデータ収集/分析/可視化
  • マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
  • 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長
  少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。 www.wantedly.com

第49回Elasticsearch勉強会で、ElasticsearchによるNLP(質問応答)の発表をしてきました

こんにちは、@shin0higuchiです😊
先日、第49回Elasticsearch勉強会を開催しました。

私からは、Elasticsearch 8.3 で実装された、PyTorchモデルによる質問応答機能を紹介しました。
発表のスライドはこちらです。

www.slideshare.net
以下、発表の内容について簡単に説明します。

概要

「質問応答」とは?

今回の発表のテーマである「質問応答」とは、機械学習タスクのひとつで、
一般に、利用者の質問に対して適切な回答を自動で返すことを指します。

活用先の例としては、チャットボットで製品に関する質問に回答させることなどが考えられます。
この場合、質問に対する回答は製品マニュアルに書いてあるはずですので、マニュアル内の適切な箇所を抜き出して回答するのが良いと言えます。
※チャットボットの口調などはまた別の話になるので、ここでは扱いません。

たとえばユーザが「ネットに繋がらない」という質問をしたときに、
「インターネットに繋がらない場合は機器背面のボタンを長押しして再起動をしてください」という回答を返せば適切と言えるでしょう。


Elasticsearchは、 7.6で教師あり機械学習を利用できるようになって以降、
PyTorchモデルのインポートが可能になるなど自然言語処理周りの機能追加が続いており、8.3で質問応答がサポートされた形です。


Elasticsearchで質問応答を実現する仕組みは下図のようなイメージです。
製品マニュアルのような大量のデータから、質問文に対する回答となる箇所を適切に抽出する内容となります。

学習済みモデルの取り込み

Elasticsearchに機械学習モデルをインポートするためには eland というPythonライブラリを利用します。
github.com

eland は次のコマンドで簡単にインストールすることができます。

python -m pip install eland


elandをインストールすると、eland_import_hub_model コマンドが使えるようになります。
HuggingFace社が、様々な学習済みモデルをダウンロード可能なプラットフォームを提供しており、
elandでは公開済みモデルのIDを指定するだけで簡単にElasticsearchに取り込むことが可能です。
huggingface.co

  • deepset/tinyroberta-squad2 というモデルを取り込む場合のコマンド例
eland_import_hub_model --cloud-id xxxx -u ml_import_user -p xxxx --hub-model-id deepset/tinyroberta-squad2 --task-type question_answering 

※ ml_import_user というユーザーで、インポートしています。適宜書き換えてください。

モデルの利用

モデルは、KibanaのUIから、もしくはElasticserchのAPIを通じて実行可能です。

たとえば、Inference API は、以下のような形で利用することができます。

curl -XPOST "http://<host>:9200/_ml/trained_models/question-answering-demo/_infer?timeout=60s&pretty" -H "Content-Type: application/json" -H "Authorization: ApiKey xxxxxxxx" -d'
{
  "inference_config": {
    "question_answering": {
      "question": "ネットに繋がらない"
    }
  },
  "docs": [
    {
      "text_field": "「スリープ設定」で設定した時間内に操作しないと液晶モニターが消灯します。いずれかのボタンを押すと、復帰します。\n        インターネットに繋がらない場合は機器背面のボタンを長押しして再起動をしてください。それでも繋がらない場合はサポートカウンターまでお問い合わせください。\n        暗い場所では、液晶モニターの明るさを維持するためにノイズが出ることがあります。印刷に影響はありません。\n        2枚以上の連続プリントまたは周囲温度が高いところでのプリントは時間がかかることがあります。\n        印刷時にプチプチという音がすることがありますが、インク・紙の走行によるものであり異常ではありません。\n        パソコンで作成したフォルダ名に特殊文字が入っている場合、そのフォルダ内の画像は表示できません。フォルダ名を変更してください。 "
    }
  ]
}'

レスポンスは以下のような形式で返ってきます。

{
  "inference_results" : [
    {
      "predicted_value" : " インターネットに繋がらない場合は機器背面のボタンを長押しして再起動をしてください",
      "start_offset" : 72,
      "end_offset" : 113,
      "prediction_probability" : 0.702419152359398672
    }
  ]
}


以上のように、質問と情報源を渡すことで、回答を得ることができます。
通常の全文検索とはまた異なる方向の情報検索が実現でき、活用の可能性を感じますね。

今回は学習済みのモデルをそのまま取り込んで利用しましたが、
日本語での精度を改善したい場合や、ドメイン特有の文脈を上手く扱いたい場合などは、
転移学習と呼ばれる手法で、既存モデルを自身のユースケースにフィットさせてください。

転移学習そのものの方法は、ここでは割愛しますが、
TorchScriptと呼ばれる形式に変換したモデルを、下記のように取り込むことが可能です。

import eland as ed
from elasticsearch import Elasticsearch
from eland.ml.pytorch import PyTorchModel

es = Elasticsearch("http://ml_admin:xxxxxxxx@<host>:9200")
ptm = PyTorchModel(es, 'question-answering-demo')
ptm.import_model(model_path='/path/to/model', config_path='/path/to/config', vocab_path='/path/to/vocabfile', config=config)

詳細は下記のリンクをご確認ください。
GitHub - elastic/eland: Python Client and Toolkit for DataFrames, Big Data, Machine Learning and ETL in Elasticsearch


今後の期待

今回紹介した機能は、実装されたばかりの機能なので今後の改善に期待したい箇所もあります。
たとえば、現状はモデル実行に多少時間がかかる場合があります。こちらは 8.4 で推論時のキャッシュ利用や、推論時のスレッド数を並列化する改善が入っているようなので、その効果に期待したいところです。
また、現状は情報源となるドキュメントをリクエストに含める必要がありますが、インデックス内のドキュメントをまとめて指定できるようになると実用性が飛躍的に上がりそうです。

今後 Elasticsearch における検索の幅をより一層広げてくれるであろう NLP機能、今後も目が離せませんね。


ということで今回の記事は以上になります。最後までお読みいただきありがとうございました。


Acroquest Technologyでは、キャリア採用を行っています。

  • ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
  • Elasticsearch等を使ったデータ収集/分析/可視化
  • マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
  • 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長

 
少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。
Kaggle Grandmasterと一緒に働きたエンジニアWanted! - Acroquest Technology株式会社のデータサイエンティストの採用 - Wantedlywww.wantedly.com

Amazon Aurora Serverless v2 で、PostGISを利用した位置情報検索の性能を試す

久しぶりにピアノの基礎練を始めたphonypianistです。 ハノンは指のトレーニング・リハビリには最適です。単純な音階なので、弾いてて楽しくはないですが😓 指を動かすと脳も活性化する話もありますが、その目的ならピアノでなくてもPCのキーボードをひたすら打っても良いのかも?🤔

さて、少し前に、Aurora Serverless v2が一般提供されて、LambdaからAuroraを使うのが、かなり実用的になりました。 aws.amazon.com

v1に比べると、おおよそ以下の点が改善されています

v1に引き続き、PostgreSQLにも対応しているため、PostGISも利用可能になっています。

今回、Aurora Serverless v2でPostGISを使った際に、どれくらいの性能が出るかを計測してみました。

概要

あらかじめ、ランダムに生成した位置情報のレコードをデータベースに入れておきます。 そして、ある地点から指定した半径以内にある点を取得する処理を行います。

検索イメージ

この取得処理にかかる処理時間を計測します。

計測条件

今回は、Lambdaからクエリを発行します。そのため、RDS Proxyを使用します。

Aurora Serverless v2検証構成

Aurora Serverless v2の最小ACUは2、最大ACUは16で設定しました。

データ量や発行するクエリは以下の通りとします。

  • 母体データ件数を100,000件~1,000,000件で変化させる。
  • クエリ実行でヒットする件数を100件~1000件で変化させる。
  • use_spheroid=true(回転楕円体を使った計算)でST_DWithin関数を用いて位置検索を行う(指定した点から半径xxxメートル以内のレコードを抽出)。

Lambdaから実際に発行するクエリは以下の通りです。数値部分は上記条件に合うように適宜変更します。

SELECT
    address,
    ST_AsGeoJson(geometry)
FROM
    points
WHERE
    ST_DWithin(geometry, ST_GeomFromText('POINT(139.6147861 35.5080426)', 4326), 1000, true);

計測結果

母体データ件数とヒットする件数を変化させて処理時間を計測した結果は以下となりました。(単位は秒)

↓母体件数\ヒット件数→ 100件 500件 1000件
100,000 1.69 1.74 1.70
500,000 2.04 2.00 2.11
1,000,000 2.52 2.63 2.78
5,000,000 7.12 7.12 7.12
10,000,000 12.20 11.72 11.68

グラフにすると以下のようになります。

位置情報検索の性能傾向

母体データ件数とほぼ比例して、処理時間も長くなっています。 検索でヒットする件数の影響はあまりなさそうです。

1000万件で10秒台なので、これくらいのデータ量なら、API Gateway+Lambda経由で実行しても同期処理が可能ですね。

まとめ

Aurora Serverless v2でも、(当然ではありますが)PostGISを使用できました。 Aurora Serverless v2でRDS Proxyに対応したこともあり、v1より便利にLambdaから使えるようになっています。 ぜひお試しください。

それでは!

Acroquest Technologyでは、キャリア採用を行っています。
  • ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
  • Elasticsearch等を使ったデータ収集/分析/可視化
  • マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
  • 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長
  少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。 www.wantedly.com

Athenaでデータの格納形式ごとのクエリ実行性能を比較してみた

こんにちは、唄うエンジニア、miyajima です。

仕事の中でAmazon Athenaを利用する機会があったため、今回はそのAmazon Athenaを使った性能比較を試してみました。


Amazon Athena とは - Amazon Athena

Amazon Athena はS3に保存されているファイルのデータをSQL形式のクエリで直接検索することができるサービスです。
大量のデータを高速に分析、検索することができ、また一度テーブル定義すれば、あとはS3にファイルを置くだけで検索できるようになるのでデータの追加変更も容易です。またサーバレスですから費用は検索に使用した分だけとなり、運用コストも抑えられます。

Athenaサポートしているデータ形式は、以下のように多数用意されています。

  • 一般的なデータファイル形式: CSV, TSV, JSON
  • Hadoopの分散処理に適用した形式: Apache Avro, Apache Parquet, ORC
  • その他、 Apache WebServer、CloudTrail、Logstashのログ形式など

サポートされる SerDes とデータ形式 - Amazon Athena

またGZIPなどでの圧縮したデータも扱えます。

そのため、それぞれのデータ形式によって性能差がどのくらいあるのかは把握しておきたいところです。
また、AthenaはS3上の指定したパス上の全ファイルを1つのテーブルとして扱うため、ファイル数やファイルサイズがどのようにパフォーマンスに影響するのかも気になりますね。
ということで、いくつかの観点に沿って実行時間を調べてみました。

比較内容

今回は以下の観点にそって性能比較を行いました。

観点 内容
ファイルサイズ/ファイル数 データファイルの分割数を 1 / 10 / 100 / 1,000 / 10,000 にした5パターンのデータセットを用意し、それぞれの検索速度を比較
ファイル形式 同一の情報量のデータをCSVJSON、Parquet形式で用意し、検索速度を比較する。
圧縮/非圧縮 同一のデータをgzip圧縮した場合としない場合で比較。この時ファイルサイズによる影響も比較する。
検索方法 各ファイル形式に対して全検索、LIKE検索、列指定検索を比較する。

また、今回はデータ形式を主眼に比較を行っており、パーティションは指定していません。できれば別の機会に、パーティション指定の有無による性能比較も試したいと思います。

今回の調査にあたり使わせていただいたデータは、以下のページで公開されているWebサーバのログファイルです。

AIT Log Data Set V1.1 | Zenodo

この中から、Webサーバのアクセスログ(14万行)を複数のフォーマットに変換してS3に登録し、検索してみました。
mail.cup.com-access.log(抜粋)

192.168.10.190 - - [29/Feb/2020:00:00:02 +0000] "GET /login.php HTTP/1.1" 200 2532 "-" "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:73.0) Gecko/20100101 Firefox/73.0"
192.168.10.4 - - [29/Feb/2020:00:00:09 +0000] "POST /services/ajax.php/kronolith/listTopTags HTTP/1.1" 200 402 "http://mail.cup.com/kronolith/" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Ubuntu Chromium/77.0.3865.90 HeadlessChrome/77.0.3865.90 Safari/537.36"
192.168.10.190 - - [29/Feb/2020:00:00:12 +0000] "POST /login.php HTTP/1.1" 302 601 "http://mail.cup.com/login.php" "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:73.0) Gecko/20100101 Firefox/73.0"
192.168.10.190 - - [29/Feb/2020:00:00:13 +0000] "GET /services/portal/ HTTP/1.1" 200 7696 "http://mail.cup.com/login.php" "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:73.0) Gecko/20100101 Firefox/73.0"
192.168.10.190 - - [29/Feb/2020:00:00:14 +0000] "GET /themes/default/graphics/head-bg.png HTTP/1.1" 200 380 "http://mail.cup.com/themes/default/screen.css" "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:73.0) Gecko/20100101 Firefox/73.0"
192.168.10.190 - - [29/Feb/2020:00:00:14 +0000] "GET /themes/default/graphics/logo.png HTTP/1.1" 200 2607 "http://mail.cup.com/themes/default/screen.css" "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:73.0) Gecko/20100101 Firefox/73.0"
::1 - - [29/Feb/2020:00:00:18 +0000] "OPTIONS * HTTP/1.0" 200 110 "-" "Apache/2.4.25 (Debian) OpenSSL/1.0.2u (internal dummy connection)"
192.168.10.190 - - [29/Feb/2020:00:00:19 +0000] "GET /mnemo/ HTTP/1.1" 200 5681 "http://mail.cup.com/services/portal/" "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:73.0) Gecko/20100101 Firefox/73.0"
192.168.10.190 - - [29/Feb/2020:00:00:22 +0000] "GET /services/portal/ HTTP/1.1" 200 7053 "http://mail.cup.com/nag/list.php" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Ubuntu Chromium/77.0.3865.90 HeadlessChrome/77.0.3865.90 Safari/537.36"
192.168.10.190 - - [29/Feb/2020:00:00:26 +0000] "GET /mnemo/ HTTP/1.1" 200 5179 "http://mail.cup.com/services/portal/" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Ubuntu Chromium/77.0.3865.90 HeadlessChrome/77.0.3865.90 Safari/537.36"
...

なお今回の性能比較の中で、全く同じ条件(データ形式、ファイルサイズ、ファイル数)でクエリ実行した場合でも、実行時間が平均値に対して20~40%程度のバラつきがありました。それもコールドスタートで発生するような「初回のみ遅い」という訳ではなく、連続で実行した時も途中の実行回だけ他よりも時間がかかる、というようなこともありました。
これは、AWS内部リソース割り当ての調整などで実行時間に差が発生しているのかもしれません。

これらのばらつきの影響を軽減するため、今回の測定では、5回測定した処理時間のうち最大値・最小値を除いた中央3値の平均を測定値とすることとしました(オリンピックでの芸術点の付け方と同じです)。

また、今回使用したデータは生の状態でも数百MB程度のものですが、GB、TB単位の大量のデータを扱う場合は、今回の誤差の影響は相対的に小さくなるのではないか、と思います。

比較結果

ファイルサイズ/ファイル数による比較

まず検証方法ですが、情報としては全く同じデータをJSONCSV、Parquetの形式で用意し、更にそのデータを1 / 10 / 100 / 1,000 / 10,000 のファイル数で分割したデータセットを作成します。
つまり、ここまでで以下の15パターンのデータセットが用意されました。

形式 ファイル数 1ファイルのデータ件数 1ファイルのサイズ(平均)
JSON
1
148,534
49,019.3 kB
JSON
10
14,854
4,901.9 kB
JSON
100
1,486
490.2 kB
JSON
1,000
149
49.0 kB
JSON
10,000
15
4.9 kB
CSV
1
148,534
37,124.2 kB
CSV
10
14,854
3,712.4 kB
CSV
100
1,486
371.2 kB
CSV
1,000
149
37.1 kB
CSV
10,000
15
3.7 kB
Parquet
1
148,534
-
Parquet
10
14,854
-
Parquet
100
1,486
-
Parquet
1,000
149
-
Parquet
10,000
15
-

※Parquetのデータはファイルを直接読み込んで得る形式ではなく、CSVデータをを元に変換しています。

そしてこれらのパターンそれぞれをAthenaの別々のテーブルに格納し、クエリを実行しました。

クエリ実行時間は、以下のようになりました。

ファイル分割数 1ファイルのデータ数 JSON CSV Parquet
1 148,534 2.17秒 2.19秒 1.90秒
10 14,854 2.15秒 2.10秒 1.83秒
100 1,486 2.11秒 2.34秒 1.89秒
1,000 149 2.03秒 2.65秒 2.26秒
10,000 15 3.09秒 3.04秒 2.49秒

ファイルサイズ・ファイル数での比較

今回は、そこまで影響は大きくはなかったですが、どの形式もサイズが小さいファイルが多くなると、処理時間が長くなる傾向が見えました。

圧縮/非圧縮の比較

今回はgzip圧縮した場合と非圧縮の場合、またファイル形式ははJSONCSVで試しました。
さらに、圧縮したデータの操作はファイルサイズや分割数にも影響する可能性があるため、ファイル分割数を上と同様にいくつかのパターンで実施しました。

ファイル分割数 1ファイル 100ファイル 10000ファイル
JSON:非圧縮 2.17秒 2.11秒 3.09秒
JSON:gzip 3.03秒 2.62秒 3.77秒
CSV:非圧縮 2.19秒 2.34秒 3.04秒
CSV:gzip 2.59秒 2.47秒 3.06秒

圧縮・非圧縮での比較

今回のデータではJSONCSVともに、検索にかかる時間はGZIP圧縮した方がわずかに長かったです。圧縮したデータを扱う方が効率は良くなるように感じますが、検証に用意出来たデータのサイズがそれほど大きくないため、あまりその効果が反映されなかったのかもしれません。

またデータ保存の観点では、データを圧縮した方がコスト効率がよいため、その点を踏まえて実際のデータで試した方がいでしょう。

ファイル形式と検索方法による比較

次はクエリによる比較です。 今回検証に使用したクエリは以下の3パターンです。

  • 全検索("SELECT * FROM <テーブル名>")
  • LIKE検索("SELECT * FROM <テーブル名> WHERE path LIKE '%.php%' ")
  • 列指定検索("SELECT host, path, status FROM <テーブル名>")

この3つのクエリで、JSON / CSV / Parquet の3種類の形式を、更にファイル分割数も替えて検索しました。

全検索

データ形式 1ファイル 100ファイル 10,000ファイル
JSON 2.17秒 2.11秒 3.09秒
CSV 2.19秒 2.34秒 3.04秒
Parquet 1.90秒 1.89秒 2.49秒

ファイル形式での比較(全検索)

LIKE検索

データ形式 1ファイル 100ファイル 10,000ファイル
JSON 1.82秒 1.53秒 2.69秒
CSV 1.94秒 1.64秒 2.71秒
Parquet 2.10秒 1.92秒 1.49秒

ファイル形式での比較(LIKE検索)

列指定検索

データ形式 1ファイル 100ファイル 10,000ファイル
JSON 1.80秒 1.54秒 2.60秒
CSV 1.46秒 1.35秒 2.59秒
Parquet 1.45秒 1.50秒 1.63秒

ファイル形式での比較(列指定検索)

CSVJSONは、どの検索方式でもファイル数が増えるにしたがって処理時間も増えているのが分かりますが、ParquetはLIKE検索、列検索の際、ファイル数が増えても処理時間が増えていません(というかLIKE検索は処理時間が減ってすらいます)。
Parquetは列指向データ形式といい、列単位でデータを取得、検索する際に最適化されたフォーマットです。この形式はデータの格納や検索を最適化するためのエンコーディングが複数用意されています。
エンコーディング方法の例:

エンコーディング方式 概要 効果の高いデータ列
Dictionary Encoding 出現する値を辞書に格納し、テーブルにはそのキーを格納する 出現する値の種類が限られているようなデータ
Run Length Encoding (RLE) 値と、その値を繰り返す回数のペアを格納する 同じ値が連続して出現するようなデータ
Delta Encoding 一定件数ごとに区切った中でのデータの前後の差分を抽出してその差分を保持する 一定のペースで変化する数値データ

参照: Encodings | Apache Parquet

そのため、これらの方式に適するデータを格納するようなスキーマを設計し、またそのカラムを検索、抽出の対象にできれば、よりParquetの特性を生かし、処理を高速化させる事が出来るのではないかと思います。

また今回はデータ量が数十MB程度しかありませんでしたので、それを強く実感できるほどではないのかもしれません。機会があれば、もっと多量のデータで検証してみたいと思います。

まとめ

今回の検証のまとめです。

  • 大量データを扱う場合は特に、処理時間、コストの両方の観点から、Parquet形式を積極的に使う。
  • ファイル数が細かく、多くなると、処理時間が長くなる傾向がある。
  • 今回のデータでは圧縮ファイルを使った場合の性能改善は見られなかった。更に大きなデータでの検証が必要。

それぞれで性能に差分は確認できましたが、より多くのデータ、複雑な条件になった時には、異なる傾向が見えるかもしれません。 機会があれば、別のデータセットで試してみたいと思います。

また、Athenaの性能改善を行う上で、AWS公式のパフォーマンスチューニングの情報は把握しておくとよいでしょう。

Amazon Athena のパフォーマンスチューニング Tips トップ 10 | Amazon Web Services ブログ

Acroquest Technologyでは、キャリア採用を行っています。
  • ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
  • Elasticsearch等を使ったデータ収集/分析/可視化
  • マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
  • 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長
  少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。 www.wantedly.com

エッジでLookout for Visionを実行すると爆速だった話

こんにちは、機械学習エンジニアの駿です。

先日庭の花壇にヒマワリの種をまいたのですが、早速芽が出てきました。
夏に黄色い花が咲くのが今から楽しみです。

今回はAWSの外観検査サービスである Amazon Lookout for Vision が、 AWS IoT Greengrass を使ってエッジでの推論が可能になったため、試してみました。

今までは Lookout for Visionクラウド側での推論(判定)しかできず、画像をネットワーク越しに送るため、オーバーヘッドが発生していました。
そのため、例えば工場の生産ラインで外観検査をしていたとすると、製品の撮影・検査をした後で、異常なモノを仕分けをしたりしますが、それに時間がかかってしまう、といった状況が発生することになってしまいます。

しかし、 Lookout for Vision のモデルを Greengrass でパッケージ化しエッジデバイスに配置して検査を行うことで、画像を送信するオーバーヘッドがなくなり、リアルタイムで処理ができるようになります。

Amazon Lookout for Vision

Lookout for Visionは画像が正常か異常を判定する、異常検知のサービスです。
教師画像は最低30枚と少ない枚数で始められますが、非常に高い精度で異常を判定することができます。

詳細については以下の記事があるので、そちらをご覧ください。

acro-engineer.hatenablog.com

AWS IoT Greengrass

Greengrass は Raspberry Piなどのエッジデバイス上にIoTアプリケーションを構築、デプロイ、管理するためのクラウドサービスです。
Lambda関数、Dockerコンテナなどをコンポーネントとしてパッケージ化し、エッジデバイスにデプロイ、実行することができます。

そんな Greengrass が2021年末に Lookout for Vision モデルのコンポーネントに対応し、
GPUを備えたエッジデバイスにデプロイすることで、クラウドに画像を送ることなく、オンプレミスかつリアルタイムの外観検査ができるようになりました。

構成

今回はエッジデバイスにJetsonNanoを使用し、 Greengrass サービスを通じて、 Lookout for Vision モデルコンポーネントと、 それを利用するための EdgeAgent コンポーネントをデプロイします。

今回の構成図(公式ドキュメントより)

手順

大まかな流れは下記のようになります。

  1. JetsonNano 上で Greengrass のサービスを起動する
  2. Lookout for Vision モデルを Greengrass 用にコンポーネント化する
  3. モデルコンポーネントAWSが提供している EdgeAgent を JetsonNano にデプロイする
  4. Python スクリプトからデプロイされたモデルを起動し、検査を実行する

早速、詳細の説明に入ります。

1. JetsonNano 上で Greengrass のサービスを起動する

公式ドキュメントでは動作環境としてJetson Xavierを推奨していますが、すぐに用意できなかったので、今回は JetsonNano を使ってみました。
ただし、性能を考慮した場合、実際の運用などでは Jetson Xavier などを利用するほうが安全だと思われます。

(1) JetsonNano に JetPack をインストール

エッジデバイス上で Lookout for Vision のモデルを使用するためには、GPUとCUDA、TensorRT などのライブラリが必要です。
JetPack を使うと、これらのライブラリをひとつひとつインストールする必要がなく、 既に整った環境を作ることができます。

JetPack4.5.1 のSDカードイメージをダウンロードし、公式ドキュメント に従ってセットアップを行います。

なお、Greengrass でデプロイする Lookout for Vision モデルが対応しているのが、バージョン4.4と4.5系のみのため、間違えて最新バージョンを入れないように注意が必要です。

(2) Greengrass サービスおよび Lookout for Vision モデルの起動に必要なライブラリのインストール

Greengrass サービスの起動に Java が必要になります。
今回はAWSが提供する Corretto をインストールしました。

また、Lookout for Vision のモデルの起動に Python3.8 もしくは 3.9 が必要です。
今回は 3.8 をインストールしました。

(3) クライアントアプリに必要なライブラリのインストール

公式ドキュメント に従って、 grpc をインストールし、サービス定義ファイルからクライアントインターフェイスを生成します。

また、画像読み込みのための Pillow もインストールします。

(4) Greengrass サービスのダウンロードおよび起動

Greengrass のコンソールから 「1つのCoreデバイスをセットアップ」を選択します。

「1つのCoreデバイスをセットアップ」

コアデバイス名など必要な情報を入力すると、インストーラのダウンロードコマンドおよび実行コマンドが生成されるので、エッジデバイス上でコマンドを実行します。

インストールが終わると、自動で Greengrass サービスの実行が始まります。

また、インストール中に Greengrass ユーザ ggc_user が作成されています。
そのままでは Greengrass でデプロイされたコンポーネントGPU にアクセスできないため、 ggc_user を video グループに追加します。

sudo usermod -a -G video ggc_user

Greengrass コンソールから追加したエッジデバイスが確認でき、ステータスが「正常」となっていたら Greengrass の準備は完了です。

JetsonNanoがCoreデバイスとして設定できた

(5) DLR のインストール

Lookout for Vision が libdlr.so を必要とするのですが、 pip でインストールできるDLRには .so ファイルが含まれていないようです。

ggc_user として whl からインストールすることで、 モデルコンポーネントが読み込めるようになります。
下記コマンドを ggc_user として実行してください。

curl -O https://neo-ai-dlr-release.s3-us-west-2.amazonaws.com/v1.10.0/jetpack4.5/dlr-1.10.0-py3-none-any.whl
python3.8 -m pip install dlr-1.10.0-py3-none-any.whl

2. Lookout for Vision モデルを Greengrass 用にコンポーネント化する

既に学習済みのモデルがあるものとします。

プロジェクトのページに遷移し左のメニューで「モデルのパッケージ」を選択します。
「モデルパッケージングジョブを作成」ボタンを押し、「モデルの選択」など必要な項目を埋めていきます。

「モデルパッケージングジョブを作成」

  • ターゲットハードウェア設定

    2022/05/07現在、プリセットの設定は Jetson Xavier 用しかないため、JetsonNano を使用する際は「ターゲットプラットフォーム」を選択して、設定を行います。

    項目
    オペレーティングシステム LINUX
    アーキテクチャ ARM64
    アクセラレーター NVIDIA
    コンパイラオプション {"gpu-code": "sm_53", "trt-ver": "7.1.3", "cuda-ver": "10.2"}

    コンパイラオプションは使用するデバイス、インストールしたライブラリのバージョンによって異なります。
    GPUコード、TensorRT バージョン、CUDA バージョンをそれぞれ指定します。
    上記値は JetsonNano+JetPack4.5.1の場合の値です。

    ターゲットハードウェアの設定

最後に「モデルパッケージングジョブを作成」ボタンを押して、パッケージングを開始します。
コンソール上で、「成功」ステータスになったら完了です。

Greengrass コンソールのコンポーネントページからも作成したモデルコンポーネントを確認することができます。

3. モデルコンポーネントAWSが提供している EdgeAgent を JetsonNano にデプロイする

Greengrass コンソールのデプロイページからデプロイを作成します。

今回はデプロイターゲットにコアデバイスを選択し、上で登録した JetsonNano にのみデプロイします。

コンポーネントの選択」画面で、パッケージした Lookout for Vision モデルコンポーネントを選択します。
このコンポーネントAWSが提供する aws.iot.lookoutvision.EdgeAgent に依存しているため、 そちらも自動でデプロイされます。

コンポーネントの選択」

そのほかの設定を変更する必要は必要ありません。
「デプロイ」を選択して、 JetsonNano にデプロイします。

デプロイのステータスが「完了」になったら成功です。

4. Pythonスクリプトからデプロイされたモデルを起動し、検査を実行する

エッジデバイス上で ggc_user としてログインします。

Pythonインタプリタを起動し、モデルの起動と検査を試してみます。

(1) モデルの起動

import grpc
from edge_agent_pb2_grpc import EdgeAgentStub
import edge_agent_pb2 as pb2
channel = grpc.insecure_channel("unix:///tmp/aws.iot.lookoutvision.EdgeAgent.sock")
stub = EdgeAgentStub(channel)
model_component_name = "lfv_component_aarm"

# まずはモデルが止まっていることを確認します。
model_description_response = stub.DescribeModel(pb2.DescribeModelRequest(model_component=model_component_name))
model_description_response.model_description.status == pb2.STOPPED
# -> True

# モデルを起動します
stub.StartModel(pb2.StartModelRequest(model_component=model_component_name))
# -> status: STARTING
# 起動するのを待ってから
model_description_response = stub.DescribeModel(ob2.DescribeModelRequest(model_component=model_component_name))
model_description_response.model_description.status == pb2.RUNNING
# -> True

(2) 検査実行

モデルが起動したら画像に対して検査を実行できます。

エッジデバイスに画像を用意して、Pillow で読み込んだものをモデルに送ります。

from PIL import Image

image = Image.open(image_path)
image = image.convert("RGB")
detect_anomalies_response = stub.DetectAnomalies(
    pb2.DetectAnomaliesRequest(
        model_component=model_component_name,
        bitmap=pb2.Bitmap(
            width=image.size[0],
            height=image.size[1],
            byte_data=bytes(image.tobytes())
        )
    )
)

is_anomalous = detect_anomalies_response.detect_anomaly_result.is_anomalous
confidence = detect_anomalies_response.detect_anomaly_result.confidence
print(f"Image is anomalous - {is_anomalous}")
print(f"confidence - {confidence:.2}")
# -> Image is anomalous - True
# -> confidence - 0.97

(3) モデルの停止

stub.StopModel(StopModelRequest(model_component=model_component_name))
model_description_response = stub.DescribeModel(ob2.DescribeModelRequest(model_component=model_component_name))
model_description_response.model_description.status == pb2.STOPPED
# -> True

channel.close()

結果

精度

前述の投稿で使用している Metal Nut 画像から正常を5枚、異常を5枚使って検査を実行したところ、すべて正しく判定することができました。

クラウド側で判定した場合と比較して、エッジデバイス上で判定しても同等の精度が出ることがわかりました。

実行時間

下記の関数を作り、画像1枚を推論するのにかかる時間を計測してみました。

def timeit():
    start_time = time.time()
    stub.DetectAnomalies(...)
    print(f"took {time.time() - start_time} seconds")

上記関数を10回実行した結果は下記のようになり、平均で0.25秒/枚で検査ができることになります。

# 所要時間(ms)
1 399.8
2 374.2
3 212.9
4 216.0
5 217.8
6 212.5
7 215.5
8 212.8
9 218.3
10 216.5
平均 249.6

awscli でクラウド側の Lookout for Vision モデルを使用した場合はネットワークの往復も含めて約3秒でした。

awscli の場合は画像変換の時間も含まれるため単純な比較はできませんが、エッジで実行することで10倍以上早くなっています。
Jetson Xavier などより計算力のあるデバイスを用いることで、さらにリアルタイム性のある外観検査ができるようになりますね。

オレゴンリージョンの Lookout for Vision を使用しているため、東京リージョンのものを用いるよりもさらに伝送時間がかかっていると思われます。)

料金

Greengrass を使ってエッジデバイスにデプロイした Lookout for Vision モデルを使用する場合、エッジ推論ユニットに基づいて月額料金がかかります。
1デバイス上で120検査/分までの検査は1エッジ推論ユニットとして扱われ、1エッジ推論ユニットは月100USD かかります。

個人で実施するにはちょっとお高めになっているので、工場など大規模で検査を行う必要があるユースケースを想定しているのがわかります。

まとめ

今回はGreengrassを使ってLookout for VisionのモデルをJetson Nanoにデプロイし、外観検査を行いました。
普段エッジデバイスを使うことがあまりないこともあり、最初の環境構築でバージョン不整合などでつまづいてしまいました。

実際に動かしてみて、エッジでの検査は画像をネットワーク越しに送信する必要がない分、10倍も早く実行できることがわかりました。
それでいてクラウドと同じモデルを使っているため、精度は同等です。

また、今回は Pythonインタプリタを使って検査を実行しましたが、クライアントアプリを作成して Greengrass コンポーネントとしてデプロイすることもできます。

今度は実際にカメラを繋いで、どれくらいのFPSが出せるのか、なども試してみたいです。

Acroquest Technologyでは、キャリア採用を行っています。
  • ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
  • Elasticsearch等を使ったデータ収集/分析/可視化
  • マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
  • 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長
  少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。 www.wantedly.com

モデル最適化ソフトウェアOpenVINOを用いた性能高速化とモデル比較の実験

皆さんこんにちは。
@tereka114です。

モデル最適化の選択肢の一つであるOpenVINOを試してみました。
モデル最適化とは、モデルの精度を殆ど落とさず、高速化する技術で、以下のような恩恵が得られることが知られています。

①特にGPU等を利用できない、RasberryPiのようなエッジデバイス上で機械学習モデルを動かすケースで性能を上げられる
②モデル最適化のエンジンが動く環境であれば、一度構築したモデルを複数の環境で実行させることができる

今回、その技術の一つであるOpenVINOとそれを用いた有名なモデルのベンチマークを紹介します。

OpenVINO

Intel社が開発したディープラーニングを高速に実行するためのソフトウェアです。
OpenVINOが学習済のモデルをハードウェアに合わせて最適化し、CPU、GPUなどのアクセラレータで高速で推論できるようにします。

www.intel.com

本記事では、インストールに関して詳細を扱いません。
インストール方法は次のリンク先を参考にしてください。

docs.openvino.ai

PyTorchからOpenVINOを動かす

PyTorchは実装のしやすさとPyTorch Image Models(timmライブラリ)を利用した学習モデルの多様性から私も含め、多くのデータサイエンティストが利用しています。
そのため、今回は、PyTorchで作られたモデルからOpenVINOを動かしてみます。

github.com

OpenVINOを利用する手順

PyTorchのモデルをOpenVINO形式に変換するには次のステップが必要になります。

1. PyTorchのモデルからONNX形式に変換する。
2. ONNXからOpenVINO形式に変換する。
3. OpenVINOモデルで推論する。

モデル変換・推論の流れ
PyTorchのモデルをONNXに変換する。

OpenVINOはPyTorchのモデルを直接変換できないため、まずはONNXに変換します。
PyTorchにONNX変換を行う関数が用意されているため、その関数を利用します。
ここで、モデルはevalを呼び出して推論モードにしておくことが必要です。
なぜならば、推論時の最適化を行う必要があるためDropoutやBatch Normalizationなど学習、推論で挙動が変わるものでは、期待する計算ができなくなります。

また、SwinTransformerには、ONNXに備わっていない演算があるため、その関数を外部からONNXに登録(roll関数)しています。

import torch
import torch.onnx as torch_onnx
import timm
import argparse
import torch
from torch.onnx.symbolic_helper import parse_args, _slice_helper
from sys import maxsize as maxsize


@parse_args('v', 'is', 'is')
def roll(g, input, shifts, dims):
    # Swin Transformerの計算に必要なOperatorを定義
    assert len(shifts) == len(dims)
    result = input
    for i in range(len(shifts)):
        shapes = []
        shape = _slice_helper(g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[maxsize])
        shapes.append(shape)
        shape = _slice_helper(g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]])
        shapes.append(shape)
        result = g.op("Concat", *shapes, axis_i=dims[i])
    return result

parser = argparse.ArgumentParser()
parser.add_argument("--model")
parser.add_argument("--output")
parser.add_argument("--size", type=int)

args = parser.parse_args()

# モデルの読み込み
torch.onnx.symbolic_registry.register_op('roll', roll, '', version=9)
net = timm.create_model(args.model, pretrained=True)
net.eval()
# モデル出力のための設定
model_onnx_path = args.output # 出力するモデルのファイル名
input_names = ["input"] # データを入力する際の名称
output_names = ["output"] # 出力データを取り出す際の名称

# ダミーインプットの作成
input_shape = (3, args.size, args.size) # 入力データの形式
batch_size = 1 # 入力データのバッチサイズ
dummy_input = torch.randn(batch_size, *input_shape) # ダミーインプット生成

# 変換実行
if "swin" in args.model:
    # Swin Transformer用に、ONNXのOpsetを固定
    output = torch_onnx.export(
        net, dummy_input, model_onnx_path,export_params=True, 
        verbose=False, input_names=input_names, output_names=output_names, opset_version=11)
else:
    output = torch_onnx.export(
        net, dummy_input, model_onnx_path,export_params=True, 
        verbose=False, input_names=input_names, output_names=output_names)

この実装を次のコマンドで動かします。

python pytorch_to_onnx.py --model resnet50 --output resnet50.onnx --size 224
ONNXからOpenVINO形式に変換する。

ONNXからOpenVINOへの変換はOpenVINOのモデル最適化コマンドを実行するのみです。
前段のResNet50のモデルを利用して、変換する場合は以下のコマンドです。

python /opt/intel/openvino_2021/deployment_tools/model_optimizer/mo.py --input_model resnet50.onnx
OpenVINOモデルで推論する。

最後にOpenVINOを動作させます。
事前に以下のコマンドでOpenVINOのPythonモジュールをインストールします。

pip install openvino

以下、先程までコンパイルしたモデルの推論の実装です。
PyTorch、ONNX、OpenVINOの推論速度を比較する実装も含まれています

import numpy as np
import time as tm
import timm

import torch
import onnxruntime
from openvino.inference_engine import IECore

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model")
parser.add_argument("--output")
parser.add_argument("--size", type=int)

args = parser.parse_args()
SIZE = int(args.size)

# Pytorchの準備
net = timm.create_model(args.model, pretrained=True)
net.eval()
# ONNXの準備
session = onnxruntime.InferenceSession(f"{args.output}.onnx")

# OpenVINOの準備
ie = IECore()
model_path = f'{args.output}.xml'
weight_path = f'{args.output}.bin'
net_openvino = ie.read_network(model=model_path, weights=weight_path)
exec_net = ie.load_network(network=net_openvino, device_name='CPU', num_requests=1)

# 時間計測用
time_onnx = 0
time_openvino = 0
time_pytorch = 0

# 予測結果比較用
out_onnx = []
out_pytorch = []
out_openvino = []

TIMES = 300
for i in range(TIMES):
    image = torch.rand(1, 3, SIZE, SIZE)
    with torch.no_grad():
        start_time = tm.time()
        out = net(image)
        out_pytorch.append(np.argmax(out[0]))
        time_pytorch += tm.time() - start_time

    start_time = tm.time()
    preds = session.run(["output"], {"input": image.cpu().numpy()})
    out_onnx.append(np.argmax(preds[0]))
    time_onnx += tm.time() - start_time

    start_time = tm.time()
    outputs = exec_net.infer(inputs={'input': image.cpu().numpy()})['output']
    out_openvino.append(np.argmax(outputs[0]))
    time_openvino += tm.time() - start_time

# 推論結果の整合性確認のため
print(np.sum(np.array(out_pytorch) == np.array(out_openvino)))
print(np.sum(np.array(out_pytorch) == np.array(out_onnx)))
print(np.sum(np.array(out_onnx) == np.array(out_openvino)))

# 計算結果
print('PyTorch: ', time_pytorch / TIMES)
print('ONNX: ', time_onnx / TIMES)
print('Open VINO: ', time_openvino / TIMES)

実行は次のコマンドです。

python infer.py --size 224 --model resnet50--output resnet50

性能実験

OpenVINOを利用すればどの程度高速化されるのか
画像認識の有名なモデルと先程の実装を用いて、性能を比較しました。

計測環境

現在一般的に利用されるモデルを中心に計測しました。
モデルの精度・性能の目安はPyTorchのモデル実装の宝庫であるtimmライブラリのリンクをご確認ください。

github.com

結果は次のとおりです。PyTorch(s),ONNX(s),OpenVINO(s)は1枚あたりの推論速度を示しています。
PyTorchと比較して、30-70%ほどの高速化を達成し、また、ONNXよりもほとんどの場合で高速化を達成できました。
また、本方式では、最終的な推論結果は変わりませんでした。

Model Image Size PyTorch(s) ONNX(s) OpenVINO(s) 高速化率(Pytorch) 高速化率(ONNX)
resnet50 224 0.271 0.137 0.112 58.67% 18.25%
resnet152 224 0.795 0.409 0.332 58.24% 18.83%
convnext_tiny 224 0.324 0.201 0.182 43.83% 9.45%
swin_tiny_patch4_window7_224 224 0.317 0.135 0.184 41.96% -36.30%
mobilenetv2_120d 224 0.065 0.031 0.028 56.92% 9.68%
mobilenetv3_large_100_miil 224 0.03 0.014 0.019 36.67% -35.71%
vit_tiny_patch16_224 224 0.088 0.053 0.057 35.23% -7.55%
vgg16 224 1.09 0.31 0.275 74.77% 11.29%
vgg19 224 1.309 0.388 0.339 74.10% 12.63%
tf_efficientnet_b0_ns 224 0.054 0.024 0.024 55.56% 0.00%
tf_efficientnet_b7_ns 224 0.448 0.216 0.18 59.82% 16.67%
モデル最適化

棒グラフはPyTorch(s)、ONNX(s)、OpenVINO(s)の値を表示しており、低い値であればよりよい性能であることを示しています。

最後に

CPUで処理を行った場合、OpenVINOの結果が最も早い場合が多かったです。
IoTデバイス上で動作させた場合に少し性能に満足できない場合にOpenVINOを適用すると良いかもしれません。
推論性能に困った場合の選択肢の一つに入れると良いと思います。

Acroquest Technologyでは、キャリア採用を行っています。


  • ディープラーニング等を使った自然言語/画像/音声/動画解析の研究開発
  • Elasticsearch等を使ったデータ収集/分析/可視化
  • マイクロサービス、DevOps、最新のOSSを利用する開発プロジェクト
  • 書籍・雑誌等の執筆や、社内外での技術の発信・共有によるエンジニアとしての成長

 
少しでも上記に興味を持たれた方は、是非以下のページをご覧ください。
Kaggle Grandmasterと話したいエンジニアWanted! - Acroquest Technology株式会社のデータサイエンティストの採用 - Wantedlywww.wantedly.com