Skip to content

Commit

Permalink
feat: Capability for Pydantic-based OpenAPI response schemas (#3795)
Browse files Browse the repository at this point in the history
  • Loading branch information
khvn26 authored Apr 25, 2024
1 parent fe9d05a commit 609deaa
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 55 deletions.
132 changes: 94 additions & 38 deletions api/api/openapi.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,94 @@
from copy import deepcopy
from functools import lru_cache

import jsonref
from drf_yasg.openapi import Schema
from flag_engine.environments.models import EnvironmentModel

SKIP_PROPERTIES = [
"amplitude_config",
"dynatrace_config",
"heap_config",
"mixpanel_config",
"rudderstack_config",
"segment_config",
"webhook_config",
]
SKIP_DEFINITIONS = ["IntegrationModel", "WebhookModel"]


@lru_cache()
def get_environment_document_response() -> Schema:
model_json_schema = EnvironmentModel.model_json_schema(mode="serialization")

# Restrict segment rule recursion to two levels.
segment_rule_schema = deepcopy(model_json_schema["$defs"]["SegmentRuleModel"])
del segment_rule_schema["properties"]["rules"]
model_json_schema["$defs"]["SegmentRuleInnerModel"] = segment_rule_schema
model_json_schema["$defs"]["SegmentRuleModel"]["properties"]["rules"]["items"][
"$ref"
] = "#/$defs/SegmentRuleInnerModel"

# Remove integrations.
for prop in SKIP_PROPERTIES:
del model_json_schema["properties"][prop]
for definition in SKIP_DEFINITIONS:
del model_json_schema["$defs"][definition]

return Schema(**jsonref.replace_refs(model_json_schema))
import inspect
from typing import Any

from drf_yasg.inspectors import SwaggerAutoSchema
from drf_yasg.openapi import SCHEMA_DEFINITIONS, Response, Schema
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema


class _GenerateJsonSchema(GenerateJsonSchema):
def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue:
"""Generates an OpenAPI 2.0-compatible JSON schema that matches a schema that allows null values.
(The catch is OpenAPI 2.0 does not allow them, but some clients are capable
to consume the `x-nullable` annotation.)
Args:
schema: The core schema.
Returns:
The generated JSON schema.
"""
anyof_schema_value = super().nullable_schema(schema)
elem = next(
any_of
for any_of in anyof_schema_value["anyOf"]
if any_of.get("type") != "null"
)
if type := elem.get("type"):
return {"type": type, "x-nullable": True}
# Assuming a reference here (which we can not annotate)
return elem


class PydanticResponseCapableSwaggerAutoSchema(SwaggerAutoSchema):
"""
A `SwaggerAutoSchema` subclass that allows to generate view response Swagger docs
from a Pydantic model.
Example usage:
```
@drf_yasg.utils.swagger_auto_schema(
responses={200: YourPydanticSchema},
auto_schema=PydanticResponseCapableSwaggerAutoSchema,
)
def your_view(): ...
```
To adapt Pydantic-generated schemas, the following is taken care of:
1. Pydantic-generated definitions are unwrapped and added to drf-yasg's global definitions.
2. Rather than using `anyOf`, nullable fields are annotated with `x-nullable`.
3. As there's no way to annotate a reference, all nested models are assumed to be `x-nullable`.
"""

def get_response_schemas(
self,
response_serializers: dict[str | int, Any],
) -> dict[str, Response]:
result = {}

definitions = self.components.with_scope(SCHEMA_DEFINITIONS)

for status_code in list(response_serializers):
if inspect.isclass(response_serializers[status_code]) and issubclass(
model_cls := response_serializers[status_code], BaseModel
):
model_json_schema = model_cls.model_json_schema(
mode="serialization",
schema_generator=_GenerateJsonSchema,
ref_template=f"#/{SCHEMA_DEFINITIONS}/{{model}}",
)

for ref_name, schema_kwargs in model_json_schema.pop("$defs").items():
definitions.setdefault(
ref_name,
maker=lambda: Schema(
**schema_kwargs,
# We can not annotate references with `x-nullable`,
# So just assume all nested models as nullable for now.
x_nullable=True,
),
)

result[str(status_code)] = Response(
description=model_cls.__name__,
schema=Schema(**model_json_schema),
)

del response_serializers[status_code]

return {**super().get_response_schemas(response_serializers), **result}
1 change: 1 addition & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@
EMAIL_USE_TLS = env.bool("EMAIL_USE_TLS", default=True)

SWAGGER_SETTINGS = {
"DEFAULT_AUTO_SCHEMA_CLASS": "api.openapi.PydanticResponseCapableSwaggerAutoSchema",
"SHOW_REQUEST_HEADERS": True,
"SECURITY_DEFINITIONS": {
"Private": {
Expand Down
14 changes: 14 additions & 0 deletions api/environments/sdk/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from flag_engine.environments.models import EnvironmentModel

from util.pydantic import exclude_model_fields

SDKEnvironmentDocumentModel = exclude_model_fields(
EnvironmentModel,
"amplitude_config",
"dynatrace_config",
"heap_config",
"mixpanel_config",
"rudderstack_config",
"segment_config",
"webhook_config",
)
4 changes: 2 additions & 2 deletions api/environments/sdk/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from rest_framework.response import Response
from rest_framework.views import APIView

from api.openapi import get_environment_document_response
from environments.authentication import EnvironmentKeyAuthentication
from environments.models import Environment
from environments.permissions.permissions import EnvironmentKeyPermissions
from environments.sdk.schemas import SDKEnvironmentDocumentModel


class SDKEnvironmentAPIView(APIView):
Expand All @@ -17,7 +17,7 @@ class SDKEnvironmentAPIView(APIView):
def get_authenticators(self):
return [EnvironmentKeyAuthentication(required_key_prefix="ser.")]

@swagger_auto_schema(responses={200: get_environment_document_response()})
@swagger_auto_schema(responses={200: SDKEnvironmentDocumentModel})
def get(self, request: HttpRequest) -> Response:
environment_document = Environment.get_environment_document(
request.environment.api_key
Expand Down
4 changes: 2 additions & 2 deletions api/environments/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from rest_framework.request import Request
from rest_framework.response import Response

from api.openapi import get_environment_document_response
from environments.permissions.permissions import (
EnvironmentAdminPermission,
EnvironmentPermissions,
NestedEnvironmentPermissions,
)
from environments.sdk.schemas import SDKEnvironmentDocumentModel
from features.versioning.tasks import enable_v2_versioning
from permissions.permissions_calculator import get_environment_permission_data
from permissions.serializers import (
Expand Down Expand Up @@ -199,7 +199,7 @@ def user_permissions(self, request, *args, **kwargs):
serializer = UserObjectPermissionsSerializer(instance=permission_data)
return Response(serializer.data)

@swagger_auto_schema(responses={200: get_environment_document_response()})
@swagger_auto_schema(responses={200: SDKEnvironmentDocumentModel})
@action(detail=True, methods=["GET"], url_path="document")
def get_document(self, request, api_key: str):
return Response(Environment.get_environment_document(api_key))
Expand Down
13 changes: 1 addition & 12 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ python-gnupg = "^0.5.1"
django-redis = "^5.4.0"
pygithub = "2.1.1"
hubspot-api-client = "^8.2.1"
jsonref = "^1.1.0"

[tool.poetry.group.auth-controller]
optional = true
Expand Down
71 changes: 71 additions & 0 deletions api/tests/unit/api/test_unit_openapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pydantic
from drf_yasg.openapi import (
SCHEMA_DEFINITIONS,
ReferenceResolver,
Response,
Schema,
)
from pytest_mock import MockerFixture

from api.openapi import PydanticResponseCapableSwaggerAutoSchema


def test_pydantic_response_capable_auto_schema__renders_expected(
mocker: MockerFixture,
) -> None:
# Given
class Nested(pydantic.BaseModel):
usual_str: str
optional_int: int | None = None

class ResponseModel(pydantic.BaseModel):
nested_once: Nested
nested_list: list[Nested]

auto_schema = PydanticResponseCapableSwaggerAutoSchema(
view=mocker.MagicMock(),
path=mocker.MagicMock(),
method=mocker.MagicMock(),
components=ReferenceResolver("definitions", force_init=True),
request=mocker.MagicMock(),
overrides=mocker.MagicMock(),
)

# When
response_schemas = auto_schema.get_response_schemas({200: ResponseModel})

# Then
assert response_schemas == {
"200": Response(
description="ResponseModel",
schema=Schema(
title="ResponseModel",
required=["nested_once", "nested_list"],
type="object",
properties={
"nested_list": {
"items": {"$ref": "#/definitions/Nested"},
"title": "Nested List",
"type": "array",
},
"nested_once": {"$ref": "#/definitions/Nested"},
},
),
),
}
nested_schema = auto_schema.components.with_scope(SCHEMA_DEFINITIONS).get("Nested")
assert nested_schema == Schema(
title="Nested",
required=["usual_str"],
type="object",
properties={
"optional_int": {
"default": None,
"title": "Optional Int",
"type": "integer",
"x-nullable": True,
},
"usual_str": {"title": "Usual Str", "type": "string"},
},
x_nullable=True,
)
21 changes: 21 additions & 0 deletions api/util/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pydantic import BaseModel, create_model


def exclude_model_fields(
model_cls: type[BaseModel],
*exclude_fields: str,
) -> type[BaseModel]:
"""
Create a copy of a model class without the fields
specified in `exclude_fields`.
"""
fields = {
field_name: (field.annotation, field)
for field_name, field in model_cls.model_fields.items()
if field_name not in exclude_fields
}
return create_model(
model_cls.__name__,
__config__=model_cls.model_config,
**fields,
)

0 comments on commit 609deaa

Please sign in to comment.