diff --git a/api/api/openapi.py b/api/api/openapi.py index b52b7c39d8b3..0a1a56e9a64f 100644 --- a/api/api/openapi.py +++ b/api/api/openapi.py @@ -1,38 +1,94 @@ -from copy import deepcopy -from functools import lru_cache - -import jsonref -from drf_yasg.openapi import Schema -from flag_engine.environments.models import EnvironmentModel - -SKIP_PROPERTIES = [ - "amplitude_config", - "dynatrace_config", - "heap_config", - "mixpanel_config", - "rudderstack_config", - "segment_config", - "webhook_config", -] -SKIP_DEFINITIONS = ["IntegrationModel", "WebhookModel"] - - -@lru_cache() -def get_environment_document_response() -> Schema: - model_json_schema = EnvironmentModel.model_json_schema(mode="serialization") - - # Restrict segment rule recursion to two levels. - segment_rule_schema = deepcopy(model_json_schema["$defs"]["SegmentRuleModel"]) - del segment_rule_schema["properties"]["rules"] - model_json_schema["$defs"]["SegmentRuleInnerModel"] = segment_rule_schema - model_json_schema["$defs"]["SegmentRuleModel"]["properties"]["rules"]["items"][ - "$ref" - ] = "#/$defs/SegmentRuleInnerModel" - - # Remove integrations. - for prop in SKIP_PROPERTIES: - del model_json_schema["properties"][prop] - for definition in SKIP_DEFINITIONS: - del model_json_schema["$defs"][definition] - - return Schema(**jsonref.replace_refs(model_json_schema)) +import inspect +from typing import Any + +from drf_yasg.inspectors import SwaggerAutoSchema +from drf_yasg.openapi import SCHEMA_DEFINITIONS, Response, Schema +from pydantic import BaseModel +from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue +from pydantic_core import core_schema + + +class _GenerateJsonSchema(GenerateJsonSchema): + def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue: + """Generates an OpenAPI 2.0-compatible JSON schema that matches a schema that allows null values. + + (The catch is OpenAPI 2.0 does not allow them, but some clients are capable + to consume the `x-nullable` annotation.) + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + anyof_schema_value = super().nullable_schema(schema) + elem = next( + any_of + for any_of in anyof_schema_value["anyOf"] + if any_of.get("type") != "null" + ) + if type := elem.get("type"): + return {"type": type, "x-nullable": True} + # Assuming a reference here (which we can not annotate) + return elem + + +class PydanticResponseCapableSwaggerAutoSchema(SwaggerAutoSchema): + """ + A `SwaggerAutoSchema` subclass that allows to generate view response Swagger docs + from a Pydantic model. + + Example usage: + + ``` + @drf_yasg.utils.swagger_auto_schema( + responses={200: YourPydanticSchema}, + auto_schema=PydanticResponseCapableSwaggerAutoSchema, + ) + def your_view(): ... + ``` + + To adapt Pydantic-generated schemas, the following is taken care of: + + 1. Pydantic-generated definitions are unwrapped and added to drf-yasg's global definitions. + 2. Rather than using `anyOf`, nullable fields are annotated with `x-nullable`. + 3. As there's no way to annotate a reference, all nested models are assumed to be `x-nullable`. + """ + + def get_response_schemas( + self, + response_serializers: dict[str | int, Any], + ) -> dict[str, Response]: + result = {} + + definitions = self.components.with_scope(SCHEMA_DEFINITIONS) + + for status_code in list(response_serializers): + if inspect.isclass(response_serializers[status_code]) and issubclass( + model_cls := response_serializers[status_code], BaseModel + ): + model_json_schema = model_cls.model_json_schema( + mode="serialization", + schema_generator=_GenerateJsonSchema, + ref_template=f"#/{SCHEMA_DEFINITIONS}/{{model}}", + ) + + for ref_name, schema_kwargs in model_json_schema.pop("$defs").items(): + definitions.setdefault( + ref_name, + maker=lambda: Schema( + **schema_kwargs, + # We can not annotate references with `x-nullable`, + # So just assume all nested models as nullable for now. + x_nullable=True, + ), + ) + + result[str(status_code)] = Response( + description=model_cls.__name__, + schema=Schema(**model_json_schema), + ) + + del response_serializers[status_code] + + return {**super().get_response_schemas(response_serializers), **result} diff --git a/api/app/settings/common.py b/api/app/settings/common.py index 1955bdaf7103..67adec31d27f 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -485,6 +485,7 @@ EMAIL_USE_TLS = env.bool("EMAIL_USE_TLS", default=True) SWAGGER_SETTINGS = { + "DEFAULT_AUTO_SCHEMA_CLASS": "api.openapi.PydanticResponseCapableSwaggerAutoSchema", "SHOW_REQUEST_HEADERS": True, "SECURITY_DEFINITIONS": { "Private": { diff --git a/api/environments/sdk/schemas.py b/api/environments/sdk/schemas.py new file mode 100644 index 000000000000..9741f4130df2 --- /dev/null +++ b/api/environments/sdk/schemas.py @@ -0,0 +1,14 @@ +from flag_engine.environments.models import EnvironmentModel + +from util.pydantic import exclude_model_fields + +SDKEnvironmentDocumentModel = exclude_model_fields( + EnvironmentModel, + "amplitude_config", + "dynatrace_config", + "heap_config", + "mixpanel_config", + "rudderstack_config", + "segment_config", + "webhook_config", +) diff --git a/api/environments/sdk/views.py b/api/environments/sdk/views.py index 052bf2cb5840..a798709490c1 100644 --- a/api/environments/sdk/views.py +++ b/api/environments/sdk/views.py @@ -4,10 +4,10 @@ from rest_framework.response import Response from rest_framework.views import APIView -from api.openapi import get_environment_document_response from environments.authentication import EnvironmentKeyAuthentication from environments.models import Environment from environments.permissions.permissions import EnvironmentKeyPermissions +from environments.sdk.schemas import SDKEnvironmentDocumentModel class SDKEnvironmentAPIView(APIView): @@ -17,7 +17,7 @@ class SDKEnvironmentAPIView(APIView): def get_authenticators(self): return [EnvironmentKeyAuthentication(required_key_prefix="ser.")] - @swagger_auto_schema(responses={200: get_environment_document_response()}) + @swagger_auto_schema(responses={200: SDKEnvironmentDocumentModel}) def get(self, request: HttpRequest) -> Response: environment_document = Environment.get_environment_document( request.environment.api_key diff --git a/api/environments/views.py b/api/environments/views.py index 739413703736..14815b13fdc5 100644 --- a/api/environments/views.py +++ b/api/environments/views.py @@ -11,12 +11,12 @@ from rest_framework.request import Request from rest_framework.response import Response -from api.openapi import get_environment_document_response from environments.permissions.permissions import ( EnvironmentAdminPermission, EnvironmentPermissions, NestedEnvironmentPermissions, ) +from environments.sdk.schemas import SDKEnvironmentDocumentModel from features.versioning.tasks import enable_v2_versioning from permissions.permissions_calculator import get_environment_permission_data from permissions.serializers import ( @@ -199,7 +199,7 @@ def user_permissions(self, request, *args, **kwargs): serializer = UserObjectPermissionsSerializer(instance=permission_data) return Response(serializer.data) - @swagger_auto_schema(responses={200: get_environment_document_response()}) + @swagger_auto_schema(responses={200: SDKEnvironmentDocumentModel}) @action(detail=True, methods=["GET"], url_path="document") def get_document(self, request, api_key: str): return Response(Environment.get_environment_document(api_key)) diff --git a/api/poetry.lock b/api/poetry.lock index 0c5068709953..55f1a9b0e7ed 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2263,17 +2263,6 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] -[[package]] -name = "jsonref" -version = "1.1.0" -description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python." -optional = false -python-versions = ">=3.7" -files = [ - {file = "jsonref-1.1.0-py3-none-any.whl", hash = "sha256:590dc7773df6c21cbf948b5dac07a72a251db28b0238ceecce0a2abfa8ec30a9"}, - {file = "jsonref-1.1.0.tar.gz", hash = "sha256:32fe8e1d85af0fdefbebce950af85590b22b60f9e95443176adbde4e1ecea552"}, -] - [[package]] name = "jsonschema" version = "4.17.3" @@ -4846,4 +4835,4 @@ requests = ">=2.7,<3.0" [metadata] lock-version = "2.0" python-versions = "~3.12" -content-hash = "40ddfbb4a4248c7aed384e0d0cb256b8153d3758b19137b3848c459186183692" +content-hash = "6bb4ffb389ab5fa15347895d977dfca1703ef1b31efc93f386daf410f1a8aeb3" diff --git a/api/pyproject.toml b/api/pyproject.toml index 9ff0e1c8288f..24165221f342 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -106,7 +106,6 @@ python-gnupg = "^0.5.1" django-redis = "^5.4.0" pygithub = "2.1.1" hubspot-api-client = "^8.2.1" -jsonref = "^1.1.0" [tool.poetry.group.auth-controller] optional = true diff --git a/api/tests/unit/api/test_unit_openapi.py b/api/tests/unit/api/test_unit_openapi.py new file mode 100644 index 000000000000..17e1d3a5b788 --- /dev/null +++ b/api/tests/unit/api/test_unit_openapi.py @@ -0,0 +1,71 @@ +import pydantic +from drf_yasg.openapi import ( + SCHEMA_DEFINITIONS, + ReferenceResolver, + Response, + Schema, +) +from pytest_mock import MockerFixture + +from api.openapi import PydanticResponseCapableSwaggerAutoSchema + + +def test_pydantic_response_capable_auto_schema__renders_expected( + mocker: MockerFixture, +) -> None: + # Given + class Nested(pydantic.BaseModel): + usual_str: str + optional_int: int | None = None + + class ResponseModel(pydantic.BaseModel): + nested_once: Nested + nested_list: list[Nested] + + auto_schema = PydanticResponseCapableSwaggerAutoSchema( + view=mocker.MagicMock(), + path=mocker.MagicMock(), + method=mocker.MagicMock(), + components=ReferenceResolver("definitions", force_init=True), + request=mocker.MagicMock(), + overrides=mocker.MagicMock(), + ) + + # When + response_schemas = auto_schema.get_response_schemas({200: ResponseModel}) + + # Then + assert response_schemas == { + "200": Response( + description="ResponseModel", + schema=Schema( + title="ResponseModel", + required=["nested_once", "nested_list"], + type="object", + properties={ + "nested_list": { + "items": {"$ref": "#/definitions/Nested"}, + "title": "Nested List", + "type": "array", + }, + "nested_once": {"$ref": "#/definitions/Nested"}, + }, + ), + ), + } + nested_schema = auto_schema.components.with_scope(SCHEMA_DEFINITIONS).get("Nested") + assert nested_schema == Schema( + title="Nested", + required=["usual_str"], + type="object", + properties={ + "optional_int": { + "default": None, + "title": "Optional Int", + "type": "integer", + "x-nullable": True, + }, + "usual_str": {"title": "Usual Str", "type": "string"}, + }, + x_nullable=True, + ) diff --git a/api/util/pydantic.py b/api/util/pydantic.py new file mode 100644 index 000000000000..d526c76ee95b --- /dev/null +++ b/api/util/pydantic.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel, create_model + + +def exclude_model_fields( + model_cls: type[BaseModel], + *exclude_fields: str, +) -> type[BaseModel]: + """ + Create a copy of a model class without the fields + specified in `exclude_fields`. + """ + fields = { + field_name: (field.annotation, field) + for field_name, field in model_cls.model_fields.items() + if field_name not in exclude_fields + } + return create_model( + model_cls.__name__, + __config__=model_cls.model_config, + **fields, + )