Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support transient identities and traits #4325

Merged
merged 11 commits into from
Jul 17, 2024
40 changes: 34 additions & 6 deletions api/environments/identities/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import typing
from itertools import chain

from django.db import models
from django.db.models import Prefetch, Q
from django.utils import timezone
from flag_engine.identities.traits.types import TraitValue
from flag_engine.segments.evaluator import evaluate_identity_in_segment

from environments.identities.managers import IdentityManager
Expand Down Expand Up @@ -225,7 +227,10 @@ def generate_traits(self, trait_data_items, persist=False):

return trait_models

def update_traits(self, trait_data_items):
def update_traits(
self,
trait_data_items: list[dict[str, TraitValue]],
) -> list[Trait]:
"""
Given a list of traits, update any that already exist and create any new ones.
Return the full list of traits for the given identity after these changes.
Expand All @@ -235,23 +240,30 @@ def update_traits(self, trait_data_items):
"""
current_traits = {t.trait_key: t for t in self.identity_traits.all()}

keys_to_delete = []
keys_to_delete = set()
new_traits = []
updated_traits = []
transient_traits = []

for trait_data_item in trait_data_items:
trait_key = trait_data_item["trait_key"]
trait_value = trait_data_item["trait_value"]
transient = trait_data_item.get("transient")

if trait_value is None:
# build a list of trait keys to delete having been nulled by the
# input data
keys_to_delete.append(trait_key)
keys_to_delete.add(trait_key)
continue

trait_value_data = Trait.generate_trait_value_data(trait_value)

if trait_key in current_traits:
if transient:
transient_traits.append(
Trait(**trait_value_data, trait_key=trait_key, identity=self)
)

elif trait_key in current_traits:
current_trait = current_traits[trait_key]
# Don't update the trait if the value hasn't changed
if current_trait.trait_value == trait_value:
Expand All @@ -260,13 +272,19 @@ def update_traits(self, trait_data_items):
for attr, value in trait_value_data.items():
setattr(current_trait, attr, value)
updated_traits.append(current_trait)

else:
new_traits.append(
Trait(**trait_value_data, trait_key=trait_key, identity=self)
)

# delete the traits that had their keys set to None
if keys_to_delete:
current_traits = {
trait_key: trait
for trait_key, trait in current_traits.items()
if trait_key not in keys_to_delete
}
self.identity_traits.filter(trait_key__in=keys_to_delete).delete()

Trait.objects.bulk_update(updated_traits, fields=Trait.BULK_UPDATE_FIELDS)
Expand All @@ -278,5 +296,15 @@ def update_traits(self, trait_data_items):
Trait.objects.bulk_create(new_traits, ignore_conflicts=True)

# return the full list of traits for this identity by refreshing from the db
# TODO: handle this in the above logic to avoid a second hit to the DB
return self.identity_traits.all()
# override persisted traits by transient traits in case of key collisions
return [
*{
trait.trait_key: trait
for trait in chain(
current_traits.values(),
updated_traits,
new_traits,
transient_traits,
)
}.values()
]
1 change: 1 addition & 0 deletions api/environments/identities/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class _TraitSerializer(serializers.Serializer):

class SDKIdentitiesQuerySerializer(serializers.Serializer):
identifier = serializers.CharField(required=True)
transient = serializers.BooleanField(default=False)


class IdentityAllFeatureStatesFeatureSerializer(serializers.Serializer):
Expand Down
3 changes: 2 additions & 1 deletion api/environments/identities/traits/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def get_trait_value(obj):

class TraitSerializerBasic(serializers.ModelSerializer):
trait_value = TraitValueField(allow_null=True)
transient = serializers.BooleanField(default=False)

class Meta:
model = Trait
fields = ("id", "trait_key", "trait_value")
fields = ("id", "trait_key", "trait_value", "transient")
read_only_fields = ("id",)


Expand Down
18 changes: 13 additions & 5 deletions api/environments/identities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from core.request_origin import RequestOrigin
from django.conf import settings
from django.db.models import Q
from django.utils import timezone
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from drf_yasg.utils import swagger_auto_schema
Expand Down Expand Up @@ -173,11 +174,18 @@ def get(self, request):
{"detail": "Missing identifier"}
) # TODO: add 400 status - will this break the clients?

identity, _ = Identity.objects.get_or_create_for_sdk(
identifier=identifier,
environment=request.environment,
integrations=IDENTITY_INTEGRATIONS,
)
if request.query_params.get("transient"):
identity = Identity(
created_date=timezone.now(),
identifier=identifier,
environment=request.environment,
)
else:
identity, _ = Identity.objects.get_or_create_for_sdk(
identifier=identifier,
environment=request.environment,
integrations=IDENTITY_INTEGRATIONS,
)
self.identity = identity

if settings.EDGE_API_URL and request.environment.project.enable_dynamo_db:
Expand Down
35 changes: 24 additions & 11 deletions api/environments/sdk/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict

from core.constants import BOOLEAN, FLOAT, INTEGER, STRING
from django.utils import timezone
from rest_framework import serializers

from environments.identities.models import Identity
Expand Down Expand Up @@ -125,6 +126,7 @@ class IdentifyWithTraitsSerializer(
HideSensitiveFieldsSerializerMixin, serializers.Serializer
):
identifier = serializers.CharField(write_only=True, required=True)
transient = serializers.BooleanField(write_only=True, default=False)
traits = TraitSerializerBasic(required=False, many=True)
flags = SDKFeatureStateSerializer(read_only=True, many=True)

Expand All @@ -136,23 +138,34 @@ def save(self, **kwargs):
(optionally store traits if flag set on org)
"""
environment = self.context["environment"]
identity, created = Identity.objects.get_or_create(
identifier=self.validated_data["identifier"], environment=environment
)

transient = self.validated_data["transient"]
trait_data_items = self.validated_data.get("traits", [])

if not created and environment.project.organisation.persist_trait_data:
# if this is an update and we're persisting traits, then we need to
# partially update any traits and return the full list
trait_models = identity.update_traits(trait_data_items)
if transient:
identity = Identity(
created_date=timezone.now(),
identifier=self.validated_data["identifier"],
environment=environment,
)
trait_models = identity.generate_traits(trait_data_items, persist=False)

else:
# generate traits for the identity and store them if configured to do so
trait_models = identity.generate_traits(
trait_data_items,
persist=environment.project.organisation.persist_trait_data,
identity, created = Identity.objects.get_or_create(
identifier=self.validated_data["identifier"], environment=environment
)

if not created and environment.project.organisation.persist_trait_data:
# if this is an update and we're persisting traits, then we need to
# partially update any traits and return the full list
trait_models = identity.update_traits(trait_data_items)
else:
# generate traits for the identity and store them if configured to do so
trait_models = identity.generate_traits(
trait_data_items,
persist=environment.project.organisation.persist_trait_data,
)

all_feature_states = identity.get_all_feature_states(
traits=trait_models,
additional_filters=self.context.get("feature_states_additional_filters"),
Expand Down
10 changes: 8 additions & 2 deletions api/tests/unit/environments/identities/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@


def generate_trait_data_item(
trait_key: str = "trait_key", trait_value: typing.Any = "trait_value"
trait_key: str = "trait_key",
trait_value: typing.Any = "trait_value",
transient: bool = False,
):
return {"trait_key": trait_key, "trait_value": trait_value}
return {
"trait_key": trait_key,
"trait_value": trait_value,
"transient": transient,
}


def create_trait_for_identity(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,19 +662,33 @@ def test_update_traits(environment: Environment) -> None:
new_trait_1_value = 5
trait_3_key = "trait_3"
trait_3_value = 3

# and one transient trait that is evaluated against but not persisted
transient_trait_key = "transient"
transient_trait_value = "not persisted"

trait_data_items = [
generate_trait_data_item(
trait_key=trait_1.trait_key, trait_value=new_trait_1_value
),
generate_trait_data_item(trait_key=trait_3_key, trait_value=trait_3_value),
generate_trait_data_item(trait_key=trait_3_key, trait_value=trait_3_value),
generate_trait_data_item(
trait_key=transient_trait_key,
trait_value=transient_trait_value,
transient=True,
),
]

# When
updated_traits = identity.update_traits(trait_data_items)

# Then
# 3 traits are returned
assert len(updated_traits) == 3
# 3 traits are persisted
assert identity.identity_traits.count() == 3

# 4 traits are returned
assert len(updated_traits) == 4

# and the first trait has it's value updated correctly
updated_trait_1 = get_trait_from_list_by_key(trait_1_key, updated_traits)
Expand All @@ -688,6 +702,10 @@ def test_update_traits(environment: Environment) -> None:
updated_trait_3 = get_trait_from_list_by_key(trait_3_key, updated_traits)
assert updated_trait_3.trait_value == trait_3_value

# and the transient trait is returned among others
transient_trait = get_trait_from_list_by_key(transient_trait_key, updated_traits)
assert transient_trait.trait_value == transient_trait_value


def test_update_traits_deletes_when_nulled_out(environment: Environment) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,23 @@ def test_get_identities_with_hide_sensitive_data(
assert response.json()["traits"] == []


def test_get_identities__transient__no_persistence(
environment: Environment,
api_client: APIClient,
) -> None:
# Given
identifier = "transient"
api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key)
url = reverse("api-v1:sdk-identities") + f"?identifier={identifier}&transient=true"

# When
response = api_client.get(url)

# Then
assert response.status_code == status.HTTP_200_OK
assert not Identity.objects.filter(identifier=identifier).count()


def test_post_identities_with_hide_sensitive_data(
environment, feature, identity, api_client
):
Expand Down Expand Up @@ -1126,6 +1143,59 @@ def test_post_identities__server_key_only_feature__server_key_auth__return_expec
assert response.json()["flags"]


def test_post_identities__transient__no_persistence(
environment: Environment,
api_client: APIClient,
) -> None:
# Given
identifier = "transient"
trait_key = "trait_key"

api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key)
url = reverse("api-v1:sdk-identities")
data = {
"identifier": identifier,
"traits": [{"trait_key": "foo", "trait_value": "bar"}],
"transient": True,
}

# When
response = api_client.post(
url, data=json.dumps(data), content_type="application/json"
)

# Then
assert response.status_code == status.HTTP_200_OK
assert not Identity.objects.filter(identifier=identifier).count()
assert not Trait.objects.filter(trait_key=trait_key).count()


def test_post_identities__transient_traits__no_persistence(
environment: Environment,
api_client: APIClient,
) -> None:
# Given
identifier = "transient"
trait_key = "trait_key"

api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key)
url = reverse("api-v1:sdk-identities")
data = {
"identifier": identifier,
"traits": [{"trait_key": "foo", "trait_value": "bar", "transient": True}],
}

# When
response = api_client.post(
url, data=json.dumps(data), content_type="application/json"
)

# Then
assert response.status_code == status.HTTP_200_OK
assert Identity.objects.filter(identifier=identifier).count()
assert not Trait.objects.filter(trait_key=trait_key).count()


def test_user_with_view_identities_permission_can_retrieve_identity(
environment,
identity,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_bulk_create_update_serializer_save_many(
mocked_request = mocker.MagicMock(environment=identity.environment)

# When
with django_assert_num_queries(6):
with django_assert_num_queries(5):
serializer = SDKBulkCreateUpdateTraitSerializer(
data=data,
many=True,
Expand Down
Loading
Loading