【Python】FastAPI + Pydanticでlist[Enum]の値を受け渡しする

2024-12-04

やること

  • FastAPIでREST APIを作る
  • PydanticでAPIスキーマを定義する
  • リクエストのクエリパラメータにEnumのリストを入れる
  • スキーマはリクエストを投げる側でも使用する

環境

  • Python 3.12.7
  • fastapi 0.115.5
  • pydantic 2.10.2

素直にそのまま持ち回ろうとすると謎エラー出る

まずは素直に書いてみる

from enum import Enum
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel

router = APIRouter()


class Category(Enum):
    BOOKS = "100"
    FOODS = "200"
    FASHION = "300"


class RequestSchema(BaseModel):
    categories: list[Category] = Query(...)


@router.get("")
def func(query: Annotated[RequestSchema, Depends(RequestSchema)]):
    print(f"{query=}")
    return "Passed"
params = RequestSchema(
    categories=[Category.BOOKS],
)

resuests.get(api_url, params=params.model_dump())

これでGETリクエスト投げてみると422 Unprocessable Entityで以下のエラーメッセージが返ってくる

{
    "detail": [
        {
            "type": "missing",
            "loc": [
                "body"
            ],
            "msg": "Field required",
            "input": null
        }
    ]
}

"body"て何だよ今クエリパラメータしか必要としてないぞ

試行錯誤の結果

どうもクエリパラメータでlist型を扱いたい場合は、QueryFieldで囲ってあげなきゃダメっぽい……

つまり

from fastapi import Query
from pydantic import BaseModel, Field

class RequestSchema(BaseModel):
    categories: list[Category] = Field(
        Query(...)
    )

という感じ。

Swagger UIでも以下のような表示になってくれる

ちなみにOptionalにする場合は以下のようにするとよい

class RequestSchema(BaseModel):
    categories: list[Category] = Field(
        Query(default=[]),
    )

参考: https://fastapi.tiangolo.com/ja/tutorial/query-params-str-validations/#_8

リクエスト側で使用するための設定

リクエストを投げる際には、Enum値を自前でvaluenameにしてあげる設定が必要

以下のコードだと正しい形式でリクエストを投げられない

class RequestSchema(BaseModel):
    categories: list[Category] = Field(
        Query(...)
    )

params = RequestSchema(
    categories=[Category.BOOKS],
)

resuests.get(api_url, params=params.model_dump())

これを実行すると、クエリパラメータが?categories=<Category.BOOKS: '100'>のような形式になってしまう

なので、Enumのvalueが渡されるようにひと工夫入れる必要がある

※ここで愚直にmodel_config = ConfigDict(use_enum_values=True)とか設定してしまうと以下のような罠にはまるので注意

params = RequestSchema(
    categories=[Category.BOOKS],
)

print(params.categories)

# Output: ['100']

params.categoriesの型はlist[Category]のはずなのに、list[str]型の値が取得されてしまう……

これじゃ何のためにEnumで定義したんじゃとなりますね


という訳で、別の手段でmodel_dump()した時にEnumのvalueが使われるようにします

from pydantic import field_serializer

...
@field_serializer("categories")
    def serialize_categories(self, categories: list[Category]):
        return [c.value for c in categories]

上記のような定義をパラメータごとに作ってあげれば大丈夫

パラメータごとってのがチョット面倒だけど

参考: https://docs.pydantic.dev/latest/concepts/validators/#field-validators

完成系: list[Enum]のクエリパラメータ値をvalueで受け渡し

バック側
from enum import Enum
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field, field_serializer

router = APIRouter()


class Category(Enum):
    BOOKS = "100"
    FOODS = "200"
    FASHION = "300"


class RequestSchema(BaseModel):
    categories: list[Category] = Field(
        Query(..., title="Categories"),
    )

    @field_serializer("categories")
    def serialize_categories(self, categories: list[Category]):
        return [c.value for c in categories]


@router.get("")
def func(query: Annotated[RequestSchema, Depends(RequestSchema)]):
    print(f"{query=}")
    return "Passed"
フロント側
params = RequestSchema(
    categories=[Category.BOOKS, Category.FOODS],
)
requests.get(api_url, params=params.model_dump())
実行ログ
query=RequestSchema(categories=[<Category.BOOKS: '100'>, <Category.FOODS: '200'>])
INFO:     127.0.0.1:51595 - "GET /v1/temp?categories=100&categories=200 HTTP/1.1" 200 OK

上手くEnum値を受け取れました🎉

おまけ: list[Enum]のクエリパラメータ値をnameで受け渡し

Enumのnameで受け渡ししたい場合は普通に面倒

頑張ればスキーマ上の定義をEnumのままにできるけど、一旦str型で持ち回った方が圧倒的に楽で間違いない

from enum import Enum
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field, field_serializer, field_validator

router = APIRouter()


class Color(Enum):
    VIOLET = ("#5a4498", "バイオレット")
    INDIGO = ("#043c78", "インディゴ")
    MAGENTA = ("#e4007f", "マゼンタ")


class RequestSchema(BaseModel):
    colors: list[str] = Field(
        Query(..., title="Colors"),
    )

    @field_serializer("colors")
    def serialize_colors(self, colors: list[Color]):
        return [c.name for c in colors]

    @field_validator('colors')
    @classmethod
    def validate_colors(cls, v):
        # カラー名が有効かどうかを検証
        valid_colors = [color.name for color in Color]
        for color in v:
            if color not in valid_colors:
                raise ValueError(f"Invalid color: {color}. Must be one of {valid_colors}")
        return v

    def to_colors_enum(self) -> list[Color]:
        return [Color[c_name] for c_name in self.colors]


@router.get("")
def func(query: Annotated[RequestSchema, Depends(RequestSchema)]):
    print(f"{query=}")
    colors = query.to_colors_enum()
    print(f"{colors=}")
    return "Passed"
スポンサーリンク