diff --git a/api/environments/identities/models.py b/api/environments/identities/models.py index a95f4a178f61..80611c521860 100644 --- a/api/environments/identities/models.py +++ b/api/environments/identities/models.py @@ -1,8 +1,10 @@ import typing +from itertools import chain from django.db import models from django.db.models import Prefetch, Q from django.utils import timezone +from flag_engine.identities.traits.types import TraitValue from flag_engine.segments.evaluator import evaluate_identity_in_segment from environments.identities.managers import IdentityManager @@ -204,28 +206,32 @@ def generate_traits(self, trait_data_items, persist=False): :return: list of TraitModels """ trait_models = [] + trait_models_to_persist = [] - # Remove traits having Null(None) values - trait_data_items = filter( - lambda trait: trait["trait_value"] is not None, trait_data_items - ) for trait_data_item in trait_data_items: + # exclude traits with null values + if (trait_value := trait_data_item["trait_value"]) is None: + continue + trait_key = trait_data_item["trait_key"] - trait_value = trait_data_item["trait_value"] - trait_models.append( - Trait( - trait_key=trait_key, - identity=self, - **Trait.generate_trait_value_data(trait_value), - ) + trait = Trait( + trait_key=trait_key, + identity=self, + **Trait.generate_trait_value_data(trait_value), ) + trait_models.append(trait) + if not trait_data_item.get("transient"): + trait_models_to_persist.append(trait) if persist: - Trait.objects.bulk_create(trait_models) + Trait.objects.bulk_create(trait_models_to_persist) return trait_models - def update_traits(self, trait_data_items): + def update_traits( + self, + trait_data_items: list[dict[str, TraitValue]], + ) -> list[Trait]: """ Given a list of traits, update any that already exist and create any new ones. Return the full list of traits for the given identity after these changes. @@ -235,38 +241,59 @@ def update_traits(self, trait_data_items): """ current_traits = {t.trait_key: t for t in self.identity_traits.all()} - keys_to_delete = [] + keys_to_delete = set() new_traits = [] updated_traits = [] + transient_traits = [] for trait_data_item in trait_data_items: trait_key = trait_data_item["trait_key"] trait_value = trait_data_item["trait_value"] + transient = trait_data_item.get("transient") + + if transient: + transient_traits.append( + Trait( + **Trait.generate_trait_value_data(trait_value), + trait_key=trait_key, + identity=self, + ) + ) + continue if trait_value is None: # build a list of trait keys to delete having been nulled by the # input data - keys_to_delete.append(trait_key) + keys_to_delete.add(trait_key) continue - trait_value_data = Trait.generate_trait_value_data(trait_value) - if trait_key in current_traits: current_trait = current_traits[trait_key] # Don't update the trait if the value hasn't changed if current_trait.trait_value == trait_value: continue - for attr, value in trait_value_data.items(): + for attr, value in Trait.generate_trait_value_data(trait_value).items(): setattr(current_trait, attr, value) updated_traits.append(current_trait) - else: - new_traits.append( - Trait(**trait_value_data, trait_key=trait_key, identity=self) + continue + + new_traits.append( + Trait( + **Trait.generate_trait_value_data(trait_value), + trait_key=trait_key, + identity=self, ) + ) # delete the traits that had their keys set to None + # (except the transient ones) if keys_to_delete: + current_traits = { + trait_key: trait + for trait_key, trait in current_traits.items() + if trait_key not in keys_to_delete + } self.identity_traits.filter(trait_key__in=keys_to_delete).delete() Trait.objects.bulk_update(updated_traits, fields=Trait.BULK_UPDATE_FIELDS) @@ -278,5 +305,15 @@ def update_traits(self, trait_data_items): Trait.objects.bulk_create(new_traits, ignore_conflicts=True) # return the full list of traits for this identity by refreshing from the db - # TODO: handle this in the above logic to avoid a second hit to the DB - return self.identity_traits.all() + # override persisted traits by transient traits in case of key collisions + return [ + *{ + trait.trait_key: trait + for trait in chain( + current_traits.values(), + updated_traits, + new_traits, + transient_traits, + ) + }.values() + ] diff --git a/api/environments/identities/serializers.py b/api/environments/identities/serializers.py index b180cbdc7b84..bd4b1ffaed1f 100644 --- a/api/environments/identities/serializers.py +++ b/api/environments/identities/serializers.py @@ -65,6 +65,7 @@ class _TraitSerializer(serializers.Serializer): class SDKIdentitiesQuerySerializer(serializers.Serializer): identifier = serializers.CharField(required=True) + transient = serializers.BooleanField(default=False) class IdentityAllFeatureStatesFeatureSerializer(serializers.Serializer): diff --git a/api/environments/identities/traits/serializers.py b/api/environments/identities/traits/serializers.py index 8999beae7efd..47b938506a8e 100644 --- a/api/environments/identities/traits/serializers.py +++ b/api/environments/identities/traits/serializers.py @@ -22,10 +22,11 @@ def get_trait_value(obj): class TraitSerializerBasic(serializers.ModelSerializer): trait_value = TraitValueField(allow_null=True) + transient = serializers.BooleanField(default=False, write_only=True) class Meta: model = Trait - fields = ("id", "trait_key", "trait_value") + fields = ("id", "trait_key", "trait_value", "transient") read_only_fields = ("id",) diff --git a/api/environments/identities/views.py b/api/environments/identities/views.py index 491031247d20..f403b8001dc1 100644 --- a/api/environments/identities/views.py +++ b/api/environments/identities/views.py @@ -5,6 +5,7 @@ from core.request_origin import RequestOrigin from django.conf import settings from django.db.models import Q +from django.utils import timezone from django.utils.decorators import method_decorator from django.views.decorators.cache import cache_page from drf_yasg.utils import swagger_auto_schema @@ -173,11 +174,18 @@ def get(self, request): {"detail": "Missing identifier"} ) # TODO: add 400 status - will this break the clients? - identity, _ = Identity.objects.get_or_create_for_sdk( - identifier=identifier, - environment=request.environment, - integrations=IDENTITY_INTEGRATIONS, - ) + if request.query_params.get("transient"): + identity = Identity( + created_date=timezone.now(), + identifier=identifier, + environment=request.environment, + ) + else: + identity, _ = Identity.objects.get_or_create_for_sdk( + identifier=identifier, + environment=request.environment, + integrations=IDENTITY_INTEGRATIONS, + ) self.identity = identity if settings.EDGE_API_URL and request.environment.project.enable_dynamo_db: diff --git a/api/environments/sdk/serializers.py b/api/environments/sdk/serializers.py index 09eac1877bfe..8ee5d254b94b 100644 --- a/api/environments/sdk/serializers.py +++ b/api/environments/sdk/serializers.py @@ -2,6 +2,7 @@ from collections import defaultdict from core.constants import BOOLEAN, FLOAT, INTEGER, STRING +from django.utils import timezone from rest_framework import serializers from environments.identities.models import Identity @@ -125,6 +126,7 @@ class IdentifyWithTraitsSerializer( HideSensitiveFieldsSerializerMixin, serializers.Serializer ): identifier = serializers.CharField(write_only=True, required=True) + transient = serializers.BooleanField(write_only=True, default=False) traits = TraitSerializerBasic(required=False, many=True) flags = SDKFeatureStateSerializer(read_only=True, many=True) @@ -136,23 +138,34 @@ def save(self, **kwargs): (optionally store traits if flag set on org) """ environment = self.context["environment"] - identity, created = Identity.objects.get_or_create( - identifier=self.validated_data["identifier"], environment=environment - ) + transient = self.validated_data["transient"] trait_data_items = self.validated_data.get("traits", []) - if not created and environment.project.organisation.persist_trait_data: - # if this is an update and we're persisting traits, then we need to - # partially update any traits and return the full list - trait_models = identity.update_traits(trait_data_items) + if transient: + identity = Identity( + created_date=timezone.now(), + identifier=self.validated_data["identifier"], + environment=environment, + ) + trait_models = identity.generate_traits(trait_data_items, persist=False) + else: - # generate traits for the identity and store them if configured to do so - trait_models = identity.generate_traits( - trait_data_items, - persist=environment.project.organisation.persist_trait_data, + identity, created = Identity.objects.get_or_create( + identifier=self.validated_data["identifier"], environment=environment ) + if not created and environment.project.organisation.persist_trait_data: + # if this is an update and we're persisting traits, then we need to + # partially update any traits and return the full list + trait_models = identity.update_traits(trait_data_items) + else: + # generate traits for the identity and store them if configured to do so + trait_models = identity.generate_traits( + trait_data_items, + persist=environment.project.organisation.persist_trait_data, + ) + all_feature_states = identity.get_all_feature_states( traits=trait_models, additional_filters=self.context.get("feature_states_additional_filters"), diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 3348ac264c12..5173596ab746 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -342,6 +342,8 @@ def segment_featurestate( feature_segment: int, ) -> int: data = { + "enabled": True, + "feature_state_value": {"type": "unicode", "string_value": "segment override"}, "feature": feature, "environment": environment, "feature_segment": feature_segment, diff --git a/api/tests/integration/environments/identities/test_integration_identities.py b/api/tests/integration/environments/identities/test_integration_identities.py index 6e5bdee3194d..1e6a3fa81f82 100644 --- a/api/tests/integration/environments/identities/test_integration_identities.py +++ b/api/tests/integration/environments/identities/test_integration_identities.py @@ -4,6 +4,7 @@ import pytest from django.urls import reverse from rest_framework import status +from rest_framework.test import APIClient from features.feature_types import MULTIVARIATE from tests.integration.helpers import ( @@ -221,3 +222,103 @@ def test_get_feature_states_for_identity_only_makes_one_query_to_get_mv_feature_ second_identity_response_json = second_identity_response.json() assert len(second_identity_response_json["flags"]) == 3 + + +def test_get_feature_states_for_identity__transient_identity__segment_match_expected( + sdk_client: APIClient, + feature: int, + segment: int, + segment_condition_property: str, + segment_condition_value: str, + segment_featurestate: int, +) -> None: + # Given + url = reverse("api-v1:sdk-identities") + + # When + # flags are requested for a new transient identity + # that matches the segment + response = sdk_client.post( + url, + data=json.dumps( + { + "identifier": "unseen", + "traits": [ + { + "trait_key": segment_condition_property, + "trait_value": segment_condition_value, + } + ], + "transient": True, + } + ), + content_type="application/json", + ) + + # Then + assert response.status_code == status.HTTP_200_OK + response_json = response.json() + assert ( + flag_data := next( + ( + flag + for flag in response_json["flags"] + if flag["feature"]["id"] == feature + ), + None, + ) + ) + assert flag_data["enabled"] is True + assert flag_data["feature_state_value"] == "segment override" + + +def test_get_feature_states_for_identity__transient_trait__segment_match_expected( + sdk_client: APIClient, + feature: int, + segment: int, + segment_condition_property: str, + segment_condition_value: str, + segment_featurestate: int, +) -> None: + # Given + url = reverse("api-v1:sdk-identities") + + # When + # flags are requested for a new transient identity + # that matches the segment + response = sdk_client.post( + url, + data=json.dumps( + { + "identifier": "unseen", + "traits": [ + { + "trait_key": segment_condition_property, + "trait_value": segment_condition_value, + "transient": True, + }, + { + "trait_key": "persistent", + "trait_value": "trait value", + }, + ], + } + ), + content_type="application/json", + ) + + # Then + assert response.status_code == status.HTTP_200_OK + response_json = response.json() + assert ( + flag_data := next( + ( + flag + for flag in response_json["flags"] + if flag["feature"]["id"] == feature + ), + None, + ) + ) + assert flag_data["enabled"] is True + assert flag_data["feature_state_value"] == "segment override" diff --git a/api/tests/unit/environments/identities/helpers.py b/api/tests/unit/environments/identities/helpers.py index 09b97e189acf..ac0fdd01b0e3 100644 --- a/api/tests/unit/environments/identities/helpers.py +++ b/api/tests/unit/environments/identities/helpers.py @@ -5,9 +5,15 @@ def generate_trait_data_item( - trait_key: str = "trait_key", trait_value: typing.Any = "trait_value" + trait_key: str = "trait_key", + trait_value: typing.Any = "trait_value", + transient: bool = False, ): - return {"trait_key": trait_key, "trait_value": trait_value} + return { + "trait_key": trait_key, + "trait_value": trait_value, + "transient": transient, + } def create_trait_for_identity( diff --git a/api/tests/unit/environments/identities/test_unit_identities_models.py b/api/tests/unit/environments/identities/test_unit_identities_models.py index 33b0de616ce8..8a4c00aa736b 100644 --- a/api/tests/unit/environments/identities/test_unit_identities_models.py +++ b/api/tests/unit/environments/identities/test_unit_identities_models.py @@ -604,17 +604,21 @@ def test_generate_traits_with_persistence(environment: Environment) -> None: generate_trait_data_item("string_trait", "string_value"), generate_trait_data_item("integer_trait", 1), generate_trait_data_item("boolean_value", True), + generate_trait_data_item("transient_trait", "string_value", transient=True), ] # When trait_models = identity.generate_traits(trait_data_items, persist=True) # Then - # the response from the method has 3 traits - assert len(trait_models) == 3 + # the response from the method has 4 traits + assert len(trait_models) == 4 - # and the database matches it + # and 3 were persisted assert Trait.objects.filter(identity=identity).count() == 3 + assert not Trait.objects.filter( + identity=identity, trait_key="transient_trait" + ).exists() def test_generate_traits_without_persistence(environment: Environment) -> None: @@ -662,19 +666,32 @@ def test_update_traits(environment: Environment) -> None: new_trait_1_value = 5 trait_3_key = "trait_3" trait_3_value = 3 + + # and one transient trait that is evaluated against but not persisted + transient_trait_key = "transient" + transient_trait_value = "not persisted" + trait_data_items = [ generate_trait_data_item( trait_key=trait_1.trait_key, trait_value=new_trait_1_value ), generate_trait_data_item(trait_key=trait_3_key, trait_value=trait_3_value), + generate_trait_data_item( + trait_key=transient_trait_key, + trait_value=transient_trait_value, + transient=True, + ), ] # When updated_traits = identity.update_traits(trait_data_items) # Then - # 3 traits are returned - assert len(updated_traits) == 3 + # 3 traits are persisted + assert identity.identity_traits.count() == 3 + + # 4 traits are returned + assert len(updated_traits) == 4 # and the first trait has it's value updated correctly updated_trait_1 = get_trait_from_list_by_key(trait_1_key, updated_traits) @@ -688,6 +705,10 @@ def test_update_traits(environment: Environment) -> None: updated_trait_3 = get_trait_from_list_by_key(trait_3_key, updated_traits) assert updated_trait_3.trait_value == trait_3_value + # and the transient trait is returned among others + transient_trait = get_trait_from_list_by_key(transient_trait_key, updated_traits) + assert transient_trait.trait_value == transient_trait_value + def test_update_traits_deletes_when_nulled_out(environment: Environment) -> None: """ diff --git a/api/tests/unit/environments/identities/test_unit_identities_views.py b/api/tests/unit/environments/identities/test_unit_identities_views.py index d378a581146f..55052511be85 100644 --- a/api/tests/unit/environments/identities/test_unit_identities_views.py +++ b/api/tests/unit/environments/identities/test_unit_identities_views.py @@ -1033,6 +1033,23 @@ def test_get_identities_with_hide_sensitive_data( assert response.json()["traits"] == [] +def test_get_identities__transient__no_persistence( + environment: Environment, + api_client: APIClient, +) -> None: + # Given + identifier = "transient" + api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) + url = reverse("api-v1:sdk-identities") + f"?identifier={identifier}&transient=true" + + # When + response = api_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK + assert not Identity.objects.filter(identifier=identifier).count() + + def test_post_identities_with_hide_sensitive_data( environment, feature, identity, api_client ): @@ -1126,6 +1143,59 @@ def test_post_identities__server_key_only_feature__server_key_auth__return_expec assert response.json()["flags"] +def test_post_identities__transient__no_persistence( + environment: Environment, + api_client: APIClient, +) -> None: + # Given + identifier = "transient" + trait_key = "trait_key" + + api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) + url = reverse("api-v1:sdk-identities") + data = { + "identifier": identifier, + "traits": [{"trait_key": trait_key, "trait_value": "bar"}], + "transient": True, + } + + # When + response = api_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + assert not Identity.objects.filter(identifier=identifier).exists() + assert not Trait.objects.filter(trait_key=trait_key).exists() + + +def test_post_identities__transient_traits__no_persistence( + environment: Environment, + api_client: APIClient, +) -> None: + # Given + identifier = "transient" + trait_key = "trait_key" + + api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) + url = reverse("api-v1:sdk-identities") + data = { + "identifier": identifier, + "traits": [{"trait_key": trait_key, "trait_value": "bar", "transient": True}], + } + + # When + response = api_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + assert Identity.objects.filter(identifier=identifier).exists() + assert not Trait.objects.filter(trait_key=trait_key).exists() + + def test_user_with_view_identities_permission_can_retrieve_identity( environment, identity, diff --git a/api/tests/unit/environments/identities/traits/test_unit_traits_serializers.py b/api/tests/unit/environments/identities/traits/test_unit_traits_serializers.py index c6773d23c44b..7b5437b3462e 100644 --- a/api/tests/unit/environments/identities/traits/test_unit_traits_serializers.py +++ b/api/tests/unit/environments/identities/traits/test_unit_traits_serializers.py @@ -48,7 +48,7 @@ def test_bulk_create_update_serializer_save_many( mocked_request = mocker.MagicMock(environment=identity.environment) # When - with django_assert_num_queries(6): + with django_assert_num_queries(5): serializer = SDKBulkCreateUpdateTraitSerializer( data=data, many=True, diff --git a/api/tests/unit/environments/sdk/test_unit_sdk_serializers.py b/api/tests/unit/environments/sdk/test_unit_sdk_serializers.py index 32d1825a52f6..ac687533623c 100644 --- a/api/tests/unit/environments/sdk/test_unit_sdk_serializers.py +++ b/api/tests/unit/environments/sdk/test_unit_sdk_serializers.py @@ -4,6 +4,7 @@ from pytest_mock import MockerFixture from environments.identities.models import Identity +from environments.identities.traits.models import Trait from environments.models import Environment from environments.sdk.serializers import IdentifyWithTraitsSerializer from features.models import Feature @@ -84,3 +85,33 @@ def test_identify_with_traits_serializer__additional_filters_in_context__filters # Then assert "flags" not in serializer.data + + +def test_identify_with_traits_serializer__transient__identity_and_traits_not_persisted( + mocker: MockerFixture, + environment: Environment, +) -> None: + # Given + identity_identifier = "completely_new_identity" + data = { + "identifier": identity_identifier, + "traits": [{"trait_key": "trait_key", "trait_value": "trait_value"}], + "transient": True, + } + request_mock = mocker.MagicMock() + + serializer = IdentifyWithTraitsSerializer( + data=data, + context={ + "environment": environment, + "request": request_mock, + }, + ) + + # When + assert serializer.is_valid() + serializer.save() + + # Then + assert not Identity.objects.filter(identifier=identity_identifier).exists() + assert not Trait.objects.filter(identity__identifier=identity_identifier).exists()