-
Notifications
You must be signed in to change notification settings - Fork 429
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Capability for Pydantic-based OpenAPI response schemas (#3795)
- Loading branch information
Showing
9 changed files
with
206 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,94 @@ | ||
from copy import deepcopy | ||
from functools import lru_cache | ||
|
||
import jsonref | ||
from drf_yasg.openapi import Schema | ||
from flag_engine.environments.models import EnvironmentModel | ||
|
||
SKIP_PROPERTIES = [ | ||
"amplitude_config", | ||
"dynatrace_config", | ||
"heap_config", | ||
"mixpanel_config", | ||
"rudderstack_config", | ||
"segment_config", | ||
"webhook_config", | ||
] | ||
SKIP_DEFINITIONS = ["IntegrationModel", "WebhookModel"] | ||
|
||
|
||
@lru_cache() | ||
def get_environment_document_response() -> Schema: | ||
model_json_schema = EnvironmentModel.model_json_schema(mode="serialization") | ||
|
||
# Restrict segment rule recursion to two levels. | ||
segment_rule_schema = deepcopy(model_json_schema["$defs"]["SegmentRuleModel"]) | ||
del segment_rule_schema["properties"]["rules"] | ||
model_json_schema["$defs"]["SegmentRuleInnerModel"] = segment_rule_schema | ||
model_json_schema["$defs"]["SegmentRuleModel"]["properties"]["rules"]["items"][ | ||
"$ref" | ||
] = "#/$defs/SegmentRuleInnerModel" | ||
|
||
# Remove integrations. | ||
for prop in SKIP_PROPERTIES: | ||
del model_json_schema["properties"][prop] | ||
for definition in SKIP_DEFINITIONS: | ||
del model_json_schema["$defs"][definition] | ||
|
||
return Schema(**jsonref.replace_refs(model_json_schema)) | ||
import inspect | ||
from typing import Any | ||
|
||
from drf_yasg.inspectors import SwaggerAutoSchema | ||
from drf_yasg.openapi import SCHEMA_DEFINITIONS, Response, Schema | ||
from pydantic import BaseModel | ||
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue | ||
from pydantic_core import core_schema | ||
|
||
|
||
class _GenerateJsonSchema(GenerateJsonSchema): | ||
def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue: | ||
"""Generates an OpenAPI 2.0-compatible JSON schema that matches a schema that allows null values. | ||
(The catch is OpenAPI 2.0 does not allow them, but some clients are capable | ||
to consume the `x-nullable` annotation.) | ||
Args: | ||
schema: The core schema. | ||
Returns: | ||
The generated JSON schema. | ||
""" | ||
anyof_schema_value = super().nullable_schema(schema) | ||
elem = next( | ||
any_of | ||
for any_of in anyof_schema_value["anyOf"] | ||
if any_of.get("type") != "null" | ||
) | ||
if type := elem.get("type"): | ||
return {"type": type, "x-nullable": True} | ||
# Assuming a reference here (which we can not annotate) | ||
return elem | ||
|
||
|
||
class PydanticResponseCapableSwaggerAutoSchema(SwaggerAutoSchema): | ||
""" | ||
A `SwaggerAutoSchema` subclass that allows to generate view response Swagger docs | ||
from a Pydantic model. | ||
Example usage: | ||
``` | ||
@drf_yasg.utils.swagger_auto_schema( | ||
responses={200: YourPydanticSchema}, | ||
auto_schema=PydanticResponseCapableSwaggerAutoSchema, | ||
) | ||
def your_view(): ... | ||
``` | ||
To adapt Pydantic-generated schemas, the following is taken care of: | ||
1. Pydantic-generated definitions are unwrapped and added to drf-yasg's global definitions. | ||
2. Rather than using `anyOf`, nullable fields are annotated with `x-nullable`. | ||
3. As there's no way to annotate a reference, all nested models are assumed to be `x-nullable`. | ||
""" | ||
|
||
def get_response_schemas( | ||
self, | ||
response_serializers: dict[str | int, Any], | ||
) -> dict[str, Response]: | ||
result = {} | ||
|
||
definitions = self.components.with_scope(SCHEMA_DEFINITIONS) | ||
|
||
for status_code in list(response_serializers): | ||
if inspect.isclass(response_serializers[status_code]) and issubclass( | ||
model_cls := response_serializers[status_code], BaseModel | ||
): | ||
model_json_schema = model_cls.model_json_schema( | ||
mode="serialization", | ||
schema_generator=_GenerateJsonSchema, | ||
ref_template=f"#/{SCHEMA_DEFINITIONS}/{{model}}", | ||
) | ||
|
||
for ref_name, schema_kwargs in model_json_schema.pop("$defs").items(): | ||
definitions.setdefault( | ||
ref_name, | ||
maker=lambda: Schema( | ||
**schema_kwargs, | ||
# We can not annotate references with `x-nullable`, | ||
# So just assume all nested models as nullable for now. | ||
x_nullable=True, | ||
), | ||
) | ||
|
||
result[str(status_code)] = Response( | ||
description=model_cls.__name__, | ||
schema=Schema(**model_json_schema), | ||
) | ||
|
||
del response_serializers[status_code] | ||
|
||
return {**super().get_response_schemas(response_serializers), **result} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from flag_engine.environments.models import EnvironmentModel | ||
|
||
from util.pydantic import exclude_model_fields | ||
|
||
SDKEnvironmentDocumentModel = exclude_model_fields( | ||
EnvironmentModel, | ||
"amplitude_config", | ||
"dynatrace_config", | ||
"heap_config", | ||
"mixpanel_config", | ||
"rudderstack_config", | ||
"segment_config", | ||
"webhook_config", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import pydantic | ||
from drf_yasg.openapi import ( | ||
SCHEMA_DEFINITIONS, | ||
ReferenceResolver, | ||
Response, | ||
Schema, | ||
) | ||
from pytest_mock import MockerFixture | ||
|
||
from api.openapi import PydanticResponseCapableSwaggerAutoSchema | ||
|
||
|
||
def test_pydantic_response_capable_auto_schema__renders_expected( | ||
mocker: MockerFixture, | ||
) -> None: | ||
# Given | ||
class Nested(pydantic.BaseModel): | ||
usual_str: str | ||
optional_int: int | None = None | ||
|
||
class ResponseModel(pydantic.BaseModel): | ||
nested_once: Nested | ||
nested_list: list[Nested] | ||
|
||
auto_schema = PydanticResponseCapableSwaggerAutoSchema( | ||
view=mocker.MagicMock(), | ||
path=mocker.MagicMock(), | ||
method=mocker.MagicMock(), | ||
components=ReferenceResolver("definitions", force_init=True), | ||
request=mocker.MagicMock(), | ||
overrides=mocker.MagicMock(), | ||
) | ||
|
||
# When | ||
response_schemas = auto_schema.get_response_schemas({200: ResponseModel}) | ||
|
||
# Then | ||
assert response_schemas == { | ||
"200": Response( | ||
description="ResponseModel", | ||
schema=Schema( | ||
title="ResponseModel", | ||
required=["nested_once", "nested_list"], | ||
type="object", | ||
properties={ | ||
"nested_list": { | ||
"items": {"$ref": "#/definitions/Nested"}, | ||
"title": "Nested List", | ||
"type": "array", | ||
}, | ||
"nested_once": {"$ref": "#/definitions/Nested"}, | ||
}, | ||
), | ||
), | ||
} | ||
nested_schema = auto_schema.components.with_scope(SCHEMA_DEFINITIONS).get("Nested") | ||
assert nested_schema == Schema( | ||
title="Nested", | ||
required=["usual_str"], | ||
type="object", | ||
properties={ | ||
"optional_int": { | ||
"default": None, | ||
"title": "Optional Int", | ||
"type": "integer", | ||
"x-nullable": True, | ||
}, | ||
"usual_str": {"title": "Usual Str", "type": "string"}, | ||
}, | ||
x_nullable=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from pydantic import BaseModel, create_model | ||
|
||
|
||
def exclude_model_fields( | ||
model_cls: type[BaseModel], | ||
*exclude_fields: str, | ||
) -> type[BaseModel]: | ||
""" | ||
Create a copy of a model class without the fields | ||
specified in `exclude_fields`. | ||
""" | ||
fields = { | ||
field_name: (field.annotation, field) | ||
for field_name, field in model_cls.model_fields.items() | ||
if field_name not in exclude_fields | ||
} | ||
return create_model( | ||
model_cls.__name__, | ||
__config__=model_cls.model_config, | ||
**fields, | ||
) |