Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support blank identifiers, assume transient #4449

Merged
merged 10 commits into from
Aug 9, 2024
10 changes: 7 additions & 3 deletions api/environments/identities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from django.db import models
from django.db.models import Prefetch, Q
from django.utils import timezone
from flag_engine.identities.traits.types import TraitValue
from flag_engine.segments.evaluator import evaluate_identity_in_segment

from environments.identities.managers import IdentityManager
from environments.identities.traits.models import Trait
from environments.models import Environment
from environments.sdk.types import SDKTraitData
from features.models import FeatureState
from features.multivariate.models import MultivariateFeatureStateValue
from segments.models import Segment
Expand Down Expand Up @@ -196,7 +196,11 @@ def get_all_user_traits(self):
def __str__(self):
return "Account %s" % self.identifier

def generate_traits(self, trait_data_items, persist=False):
def generate_traits(
self,
trait_data_items: list[SDKTraitData],
persist=False,
) -> list[Trait]:
"""
Given a list of trait data items, validated by TraitSerializerFull, generate
a list of TraitModel objects for the given identity.
Expand Down Expand Up @@ -232,7 +236,7 @@ def generate_traits(self, trait_data_items, persist=False):

def update_traits(
self,
trait_data_items: list[dict[str, TraitValue]],
trait_data_items: list[SDKTraitData],
) -> list[Trait]:
"""
Given a list of traits, update any that already exist and create any new ones.
Expand Down
1 change: 1 addition & 0 deletions api/environments/identities/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class _TraitSerializer(serializers.Serializer):
help_text="Can be of type string, boolean, float or integer."
)

identifier = serializers.CharField()
flags = serializers.ListField(child=SDKFeatureStateSerializer())
traits = serializers.ListSerializer(child=_TraitSerializer())

Expand Down
1 change: 1 addition & 0 deletions api/environments/identities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def _get_all_feature_states_for_user_response(
serializer = serializer_class(
{
"flags": all_feature_states,
"identifier": identity.identifier,
"traits": identity.identity_traits.all(),
},
context=self.get_serializer_context(),
Expand Down
61 changes: 36 additions & 25 deletions api/environments/sdk/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict

from core.constants import BOOLEAN, FLOAT, INTEGER, STRING
from django.utils import timezone
from rest_framework import serializers

from environments.identities.models import Identity
Expand All @@ -12,6 +11,12 @@
from environments.identities.traits.fields import TraitValueField
from environments.identities.traits.models import Trait
from environments.identities.traits.serializers import TraitSerializerBasic
from environments.sdk.services import (
get_identified_transient_identity_and_traits,
get_persisted_identity_and_traits,
get_transient_identity_and_traits,
)
from environments.sdk.types import SDKTraitData
from features.serializers import (
FeatureStateSerializerFull,
SDKFeatureStateSerializer,
Expand Down Expand Up @@ -125,7 +130,11 @@ def create(self, validated_data):
class IdentifyWithTraitsSerializer(
HideSensitiveFieldsSerializerMixin, serializers.Serializer
):
identifier = serializers.CharField(write_only=True, required=True)
identifier = serializers.CharField(
required=False,
allow_blank=True,
allow_null=True,
)
transient = serializers.BooleanField(write_only=True, default=False)
traits = TraitSerializerBasic(required=False, many=True)
flags = SDKFeatureStateSerializer(read_only=True, many=True)
Expand All @@ -137,44 +146,46 @@ def save(self, **kwargs):
Create the identity with the associated traits
(optionally store traits if flag set on org)
"""
identifier = self.validated_data.get("identifier")
environment = self.context["environment"]

transient = self.validated_data["transient"]
trait_data_items = self.validated_data.get("traits", [])
sdk_trait_data: list[SDKTraitData] = self.validated_data.get("traits", [])

if transient:
identity = Identity(
created_date=timezone.now(),
identifier=self.validated_data["identifier"],
if not identifier:
# We have a fully transient identity that should never be persisted.
identity, traits = get_transient_identity_and_traits(
environment=environment,
sdk_trait_data=sdk_trait_data,
)
trait_models = identity.generate_traits(trait_data_items, persist=False)

else:
identity, created = Identity.objects.get_or_create(
identifier=self.validated_data["identifier"], environment=environment
elif transient:
# Get presently stored traits and identity overrides
# but don't persist incoming data.
identity, traits = get_identified_transient_identity_and_traits(
environment=environment,
identifier=identifier,
sdk_trait_data=sdk_trait_data,
)

if not created and environment.project.organisation.persist_trait_data:
# if this is an update and we're persisting traits, then we need to
# partially update any traits and return the full list
trait_models = identity.update_traits(trait_data_items)
else:
# generate traits for the identity and store them if configured to do so
trait_models = identity.generate_traits(
trait_data_items,
persist=environment.project.organisation.persist_trait_data,
)
else:
# Persist the identity in accordance with non-local settings
# and individual trait transiency.
identity, traits = get_persisted_identity_and_traits(
environment=environment,
identifier=identifier,
sdk_trait_data=sdk_trait_data,
)

all_feature_states = identity.get_all_feature_states(
traits=trait_models,
traits=traits,
additional_filters=self.context.get("feature_states_additional_filters"),
)
identify_integrations(identity, all_feature_states, trait_models)
identify_integrations(identity, all_feature_states, traits)

return {
"identity": identity,
"traits": trait_models,
"identifier": identity.identifier,
"traits": traits,
"flags": all_feature_states,
}

Expand Down
86 changes: 86 additions & 0 deletions api/environments/sdk/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import uuid
from itertools import chain
from typing import TypeAlias

from django.utils import timezone

from environments.identities.models import Identity
from environments.identities.traits.models import Trait
from environments.models import Environment
from environments.sdk.types import SDKTraitData

IdentityAndTraits: TypeAlias = tuple[Identity, list[Trait]]


def _get_transient_identity(
environment: Environment,
identifier: str,
) -> Identity:
return Identity(
created_date=timezone.now(),
environment=environment,
identifier=identifier,
)


def get_transient_identity_and_traits(
environment: Environment,
sdk_trait_data: list[SDKTraitData],
) -> IdentityAndTraits:
return (
(
identity := _get_transient_identity(
environment=environment,
identifier=str(uuid.uuid4()),
)
),
identity.generate_traits(sdk_trait_data, persist=False),
)


def get_identified_transient_identity_and_traits(
environment: Environment,
identifier: str,
sdk_trait_data: list[SDKTraitData],
) -> IdentityAndTraits:
if identity := Identity.objects.filter(
environment=environment,
identifier=identifier,
).first():
for sdk_trait_data_item in sdk_trait_data:
sdk_trait_data_item["transient"] = True
return identity, identity.update_traits(sdk_trait_data)
return (
identity := _get_transient_identity(
environment=environment,
identifier=identifier,
)
), identity.generate_traits(sdk_trait_data, persist=False)


def get_persisted_identity_and_traits(
environment: Environment,
identifier: str,
sdk_trait_data: list[SDKTraitData],
) -> IdentityAndTraits:
identity, created = Identity.objects.get_or_create(
environment=environment,
identifier=identifier,
)
persist_trait_data = environment.project.organisation.persist_trait_data
if created:
return identity, identity.generate_traits(
sdk_trait_data,
persist=persist_trait_data,
)
if persist_trait_data:
return identity, identity.update_traits(sdk_trait_data)
return identity, list(
{
trait.trait_key: trait
for trait in chain(
identity.identity_traits.all(),
identity.generate_traits(sdk_trait_data, persist=False),
)
}.values()
)
9 changes: 9 additions & 0 deletions api/environments/sdk/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import typing

from typing_extensions import NotRequired


class SDKTraitData(typing.TypedDict):
trait_key: str
trait_value: typing.Any
transient: NotRequired[bool]
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
from typing import Any, Generator
from unittest import mock

import pytest
from django.urls import reverse
from pytest_lazyfixture import lazy_fixture
from pytest_mock import MockerFixture
from rest_framework import status
from rest_framework.test import APIClient

Expand Down Expand Up @@ -224,13 +227,65 @@ def test_get_feature_states_for_identity_only_makes_one_query_to_get_mv_feature_
assert len(second_identity_response_json["flags"]) == 3


def test_get_feature_states_for_identity__transient_identity__segment_match_expected(
@pytest.fixture
def existing_identity_identifier_data(
identity_identifier: str,
identity: int,
) -> dict[str, Any]:
return {"identifier": identity_identifier}


@pytest.fixture
def transient_random_identifier(
mocker: MockerFixture,
) -> Generator[str, None, None]:
uuid_mock = mocker.patch("environments.sdk.services.uuid", autospec=True)
uuid_mock.uuid4.return_value = identifier = "1199c22c-4dcb-4505-9857-5db5f258469c"
yield identifier


@pytest.mark.parametrize(
"transient_data",
[
pytest.param({"transient": True}, id="with-transient-true"),
pytest.param({"transient": False}, id="with-transient-false"),
pytest.param({}, id="missing-transient"),
],
)
@pytest.mark.parametrize(
"identifier_data,expected_identifier",
[
pytest.param(
lazy_fixture("existing_identity_identifier_data"),
lazy_fixture("identity_identifier"),
id="existing-identifier",
),
pytest.param({"identifier": "unseen"}, "unseen", id="new-identifier"),
pytest.param(
{"identifier": ""},
lazy_fixture("transient_random_identifier"),
id="blank-identifier",
),
pytest.param(
{"identifier": None},
lazy_fixture("transient_random_identifier"),
id="null-identifier",
),
pytest.param(
{}, lazy_fixture("transient_random_identifier"), id="missing-identifier"
),
],
)
def test_get_feature_states_for_identity__segment_match_expected(
sdk_client: APIClient,
feature: int,
segment: int,
segment_condition_property: str,
segment_condition_value: str,
segment_featurestate: int,
identifier_data: dict[str, Any],
transient_data: dict[str, Any],
expected_identifier: str,
) -> None:
# Given
url = reverse("api-v1:sdk-identities")
Expand All @@ -242,14 +297,14 @@ def test_get_feature_states_for_identity__transient_identity__segment_match_expe
url,
data=json.dumps(
{
"identifier": "unseen",
**identifier_data,
**transient_data,
"traits": [
{
"trait_key": segment_condition_property,
"trait_value": segment_condition_value,
}
],
"transient": True,
}
),
content_type="application/json",
Expand All @@ -258,6 +313,7 @@ def test_get_feature_states_for_identity__transient_identity__segment_match_expe
# Then
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["identifier"] == expected_identifier
assert (
flag_data := next(
(
Expand Down
Loading
Loading