Skip to content

Commit

Permalink
feat: Support transient identities and traits (#4325)
Browse files Browse the repository at this point in the history
  • Loading branch information
khvn26 authored Jul 17, 2024
1 parent 56e6390 commit 27f6539
Show file tree
Hide file tree
Showing 12 changed files with 339 additions and 48 deletions.
83 changes: 60 additions & 23 deletions api/environments/identities/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import typing
from itertools import chain

from django.db import models
from django.db.models import Prefetch, Q
from django.utils import timezone
from flag_engine.identities.traits.types import TraitValue
from flag_engine.segments.evaluator import evaluate_identity_in_segment

from environments.identities.managers import IdentityManager
Expand Down Expand Up @@ -204,28 +206,32 @@ def generate_traits(self, trait_data_items, persist=False):
:return: list of TraitModels
"""
trait_models = []
trait_models_to_persist = []

# Remove traits having Null(None) values
trait_data_items = filter(
lambda trait: trait["trait_value"] is not None, trait_data_items
)
for trait_data_item in trait_data_items:
# exclude traits with null values
if (trait_value := trait_data_item["trait_value"]) is None:
continue

trait_key = trait_data_item["trait_key"]
trait_value = trait_data_item["trait_value"]
trait_models.append(
Trait(
trait_key=trait_key,
identity=self,
**Trait.generate_trait_value_data(trait_value),
)
trait = Trait(
trait_key=trait_key,
identity=self,
**Trait.generate_trait_value_data(trait_value),
)
trait_models.append(trait)
if not trait_data_item.get("transient"):
trait_models_to_persist.append(trait)

if persist:
Trait.objects.bulk_create(trait_models)
Trait.objects.bulk_create(trait_models_to_persist)

return trait_models

def update_traits(self, trait_data_items):
def update_traits(
self,
trait_data_items: list[dict[str, TraitValue]],
) -> list[Trait]:
"""
Given a list of traits, update any that already exist and create any new ones.
Return the full list of traits for the given identity after these changes.
Expand All @@ -235,38 +241,59 @@ def update_traits(self, trait_data_items):
"""
current_traits = {t.trait_key: t for t in self.identity_traits.all()}

keys_to_delete = []
keys_to_delete = set()
new_traits = []
updated_traits = []
transient_traits = []

for trait_data_item in trait_data_items:
trait_key = trait_data_item["trait_key"]
trait_value = trait_data_item["trait_value"]
transient = trait_data_item.get("transient")

if transient:
transient_traits.append(
Trait(
**Trait.generate_trait_value_data(trait_value),
trait_key=trait_key,
identity=self,
)
)
continue

if trait_value is None:
# build a list of trait keys to delete having been nulled by the
# input data
keys_to_delete.append(trait_key)
keys_to_delete.add(trait_key)
continue

trait_value_data = Trait.generate_trait_value_data(trait_value)

if trait_key in current_traits:
current_trait = current_traits[trait_key]
# Don't update the trait if the value hasn't changed
if current_trait.trait_value == trait_value:
continue

for attr, value in trait_value_data.items():
for attr, value in Trait.generate_trait_value_data(trait_value).items():
setattr(current_trait, attr, value)
updated_traits.append(current_trait)
else:
new_traits.append(
Trait(**trait_value_data, trait_key=trait_key, identity=self)
continue

new_traits.append(
Trait(
**Trait.generate_trait_value_data(trait_value),
trait_key=trait_key,
identity=self,
)
)

# delete the traits that had their keys set to None
# (except the transient ones)
if keys_to_delete:
current_traits = {
trait_key: trait
for trait_key, trait in current_traits.items()
if trait_key not in keys_to_delete
}
self.identity_traits.filter(trait_key__in=keys_to_delete).delete()

Trait.objects.bulk_update(updated_traits, fields=Trait.BULK_UPDATE_FIELDS)
Expand All @@ -278,5 +305,15 @@ def update_traits(self, trait_data_items):
Trait.objects.bulk_create(new_traits, ignore_conflicts=True)

# return the full list of traits for this identity by refreshing from the db
# TODO: handle this in the above logic to avoid a second hit to the DB
return self.identity_traits.all()
# override persisted traits by transient traits in case of key collisions
return [
*{
trait.trait_key: trait
for trait in chain(
current_traits.values(),
updated_traits,
new_traits,
transient_traits,
)
}.values()
]
1 change: 1 addition & 0 deletions api/environments/identities/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class _TraitSerializer(serializers.Serializer):

class SDKIdentitiesQuerySerializer(serializers.Serializer):
identifier = serializers.CharField(required=True)
transient = serializers.BooleanField(default=False)


class IdentityAllFeatureStatesFeatureSerializer(serializers.Serializer):
Expand Down
3 changes: 2 additions & 1 deletion api/environments/identities/traits/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def get_trait_value(obj):

class TraitSerializerBasic(serializers.ModelSerializer):
trait_value = TraitValueField(allow_null=True)
transient = serializers.BooleanField(default=False, write_only=True)

class Meta:
model = Trait
fields = ("id", "trait_key", "trait_value")
fields = ("id", "trait_key", "trait_value", "transient")
read_only_fields = ("id",)


Expand Down
18 changes: 13 additions & 5 deletions api/environments/identities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from core.request_origin import RequestOrigin
from django.conf import settings
from django.db.models import Q
from django.utils import timezone
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from drf_yasg.utils import swagger_auto_schema
Expand Down Expand Up @@ -173,11 +174,18 @@ def get(self, request):
{"detail": "Missing identifier"}
) # TODO: add 400 status - will this break the clients?

identity, _ = Identity.objects.get_or_create_for_sdk(
identifier=identifier,
environment=request.environment,
integrations=IDENTITY_INTEGRATIONS,
)
if request.query_params.get("transient"):
identity = Identity(
created_date=timezone.now(),
identifier=identifier,
environment=request.environment,
)
else:
identity, _ = Identity.objects.get_or_create_for_sdk(
identifier=identifier,
environment=request.environment,
integrations=IDENTITY_INTEGRATIONS,
)
self.identity = identity

if settings.EDGE_API_URL and request.environment.project.enable_dynamo_db:
Expand Down
35 changes: 24 additions & 11 deletions api/environments/sdk/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict

from core.constants import BOOLEAN, FLOAT, INTEGER, STRING
from django.utils import timezone
from rest_framework import serializers

from environments.identities.models import Identity
Expand Down Expand Up @@ -125,6 +126,7 @@ class IdentifyWithTraitsSerializer(
HideSensitiveFieldsSerializerMixin, serializers.Serializer
):
identifier = serializers.CharField(write_only=True, required=True)
transient = serializers.BooleanField(write_only=True, default=False)
traits = TraitSerializerBasic(required=False, many=True)
flags = SDKFeatureStateSerializer(read_only=True, many=True)

Expand All @@ -136,23 +138,34 @@ def save(self, **kwargs):
(optionally store traits if flag set on org)
"""
environment = self.context["environment"]
identity, created = Identity.objects.get_or_create(
identifier=self.validated_data["identifier"], environment=environment
)

transient = self.validated_data["transient"]
trait_data_items = self.validated_data.get("traits", [])

if not created and environment.project.organisation.persist_trait_data:
# if this is an update and we're persisting traits, then we need to
# partially update any traits and return the full list
trait_models = identity.update_traits(trait_data_items)
if transient:
identity = Identity(
created_date=timezone.now(),
identifier=self.validated_data["identifier"],
environment=environment,
)
trait_models = identity.generate_traits(trait_data_items, persist=False)

else:
# generate traits for the identity and store them if configured to do so
trait_models = identity.generate_traits(
trait_data_items,
persist=environment.project.organisation.persist_trait_data,
identity, created = Identity.objects.get_or_create(
identifier=self.validated_data["identifier"], environment=environment
)

if not created and environment.project.organisation.persist_trait_data:
# if this is an update and we're persisting traits, then we need to
# partially update any traits and return the full list
trait_models = identity.update_traits(trait_data_items)
else:
# generate traits for the identity and store them if configured to do so
trait_models = identity.generate_traits(
trait_data_items,
persist=environment.project.organisation.persist_trait_data,
)

all_feature_states = identity.get_all_feature_states(
traits=trait_models,
additional_filters=self.context.get("feature_states_additional_filters"),
Expand Down
2 changes: 2 additions & 0 deletions api/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def segment_featurestate(
feature_segment: int,
) -> int:
data = {
"enabled": True,
"feature_state_value": {"type": "unicode", "string_value": "segment override"},
"feature": feature,
"environment": environment,
"feature_segment": feature_segment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient

from features.feature_types import MULTIVARIATE
from tests.integration.helpers import (
Expand Down Expand Up @@ -221,3 +222,103 @@ def test_get_feature_states_for_identity_only_makes_one_query_to_get_mv_feature_

second_identity_response_json = second_identity_response.json()
assert len(second_identity_response_json["flags"]) == 3


def test_get_feature_states_for_identity__transient_identity__segment_match_expected(
sdk_client: APIClient,
feature: int,
segment: int,
segment_condition_property: str,
segment_condition_value: str,
segment_featurestate: int,
) -> None:
# Given
url = reverse("api-v1:sdk-identities")

# When
# flags are requested for a new transient identity
# that matches the segment
response = sdk_client.post(
url,
data=json.dumps(
{
"identifier": "unseen",
"traits": [
{
"trait_key": segment_condition_property,
"trait_value": segment_condition_value,
}
],
"transient": True,
}
),
content_type="application/json",
)

# Then
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert (
flag_data := next(
(
flag
for flag in response_json["flags"]
if flag["feature"]["id"] == feature
),
None,
)
)
assert flag_data["enabled"] is True
assert flag_data["feature_state_value"] == "segment override"


def test_get_feature_states_for_identity__transient_trait__segment_match_expected(
sdk_client: APIClient,
feature: int,
segment: int,
segment_condition_property: str,
segment_condition_value: str,
segment_featurestate: int,
) -> None:
# Given
url = reverse("api-v1:sdk-identities")

# When
# flags are requested for a new transient identity
# that matches the segment
response = sdk_client.post(
url,
data=json.dumps(
{
"identifier": "unseen",
"traits": [
{
"trait_key": segment_condition_property,
"trait_value": segment_condition_value,
"transient": True,
},
{
"trait_key": "persistent",
"trait_value": "trait value",
},
],
}
),
content_type="application/json",
)

# Then
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert (
flag_data := next(
(
flag
for flag in response_json["flags"]
if flag["feature"]["id"] == feature
),
None,
)
)
assert flag_data["enabled"] is True
assert flag_data["feature_state_value"] == "segment override"
10 changes: 8 additions & 2 deletions api/tests/unit/environments/identities/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@


def generate_trait_data_item(
trait_key: str = "trait_key", trait_value: typing.Any = "trait_value"
trait_key: str = "trait_key",
trait_value: typing.Any = "trait_value",
transient: bool = False,
):
return {"trait_key": trait_key, "trait_value": trait_value}
return {
"trait_key": trait_key,
"trait_value": trait_value,
"transient": transient,
}


def create_trait_for_identity(
Expand Down
Loading

0 comments on commit 27f6539

Please sign in to comment.