From 320dbf618c2c0da280ec5f286ba2fb006ea50116 Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Thu, 19 Oct 2023 13:56:52 +0100 Subject: [PATCH 1/3] feat: engine segment evaluation --- api/segments/models.py | 17 ++++++++++++++--- api/util/mappers/engine.py | 24 ++++++++++++++++++++---- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/api/segments/models.py b/api/segments/models.py index a49e73b94ebc..017458221304 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -11,6 +11,7 @@ ) from django.core.exceptions import ValidationError from django.db import models +from flag_engine.segments.evaluator import evaluate_identity_in_segment from flag_engine.utils.semver import is_semver, remove_semver_suffix from audit.constants import SEGMENT_CREATED_MESSAGE, SEGMENT_UPDATED_MESSAGE @@ -20,6 +21,11 @@ ) from features.models import Feature from projects.models import Project +from util.mappers.engine import ( + map_identity_to_engine, + map_segment_to_engine, + map_traits_to_trait_models, +) if typing.TYPE_CHECKING: from environments.identities.models import Identity @@ -109,9 +115,14 @@ def id_exists_in_rules_data(rules_data: typing.List[dict]) -> bool: def does_identity_match( self, identity: "Identity", traits: typing.List["Trait"] = None ) -> bool: - rules = self.rules.all() - return rules.count() > 0 and all( - rule.does_identity_match(identity, traits) for rule in rules + segment_model = map_segment_to_engine(self) + identity_model = map_identity_to_engine(identity) + trait_models = map_traits_to_trait_models(traits) if traits else None + + return evaluate_identity_in_segment( + identity=identity_model, + segment=segment_model, + override_traits=trait_models, ) def get_create_log_message(self, history_instance) -> typing.Optional[str]: diff --git a/api/util/mappers/engine.py b/api/util/mappers/engine.py index 8ba196127c0b..094eb03db889 100644 --- a/api/util/mappers/engine.py +++ b/api/util/mappers/engine.py @@ -50,6 +50,25 @@ ) +def map_traits_to_trait_models(traits: Iterable["Trait"]) -> list[TraitModel]: + return [ + TraitModel(trait_key=trait.trait_key, trait_value=trait.trait_value) + for trait in traits + ] + + +def map_segment_to_engine(segment: "Segment") -> SegmentModel: + segment_rules = segment.rules.all() + + return SegmentModel( + id=segment.pk, + name=segment.name, + rules=[ + map_segment_rule_to_engine(segment_rule) for segment_rule in segment_rules + ], + ) + + def map_segment_rule_to_engine( segment_rule: "SegmentRule", ) -> SegmentRuleModel: @@ -354,10 +373,7 @@ def map_identity_to_engine(identity: "Identity") -> IdentityModel: ) for feature_state in identity_feature_states ] - identity_trait_models = [ - TraitModel(trait_key=trait.trait_key, trait_value=trait.trait_value) - for trait in identity_traits - ] + identity_trait_models = map_traits_to_trait_models(identity_traits) return IdentityModel( # Attributes: From 8a9733e39b4cdfa5ea589af04eab21630f03f061 Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Mon, 20 Nov 2023 23:45:15 +0000 Subject: [PATCH 2/3] remove dead code, switch to engine constants --- api/conftest.py | 3 +- .../tests/test_dynamodb_identity_wrapper.py | 3 +- .../identities/tests/test_views.py | 18 +- .../feature_segments/tests/test_models.py | 3 +- api/segments/models.py | 249 ++---------------- api/segments/serializers.py | 3 +- api/segments/tests/test_models.py | 3 +- ...test_environments_views_sdk_environment.py | 3 +- .../test_unit_environments_views.py | 3 +- .../test_unit_import_export_export.py | 4 +- .../unit/projects/test_unit_projects_admin.py | 3 +- api/tests/unit/segments/test_conditions.py | 157 ----------- .../segments/test_unit_segments_models.py | 29 +- .../unit/segments/test_unit_segments_views.py | 3 +- api/util/mappers/engine.py | 43 ++- 15 files changed, 80 insertions(+), 447 deletions(-) delete mode 100644 api/tests/unit/segments/test_conditions.py diff --git a/api/conftest.py b/api/conftest.py index 65224405cd4c..7c9a2270e261 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -3,6 +3,7 @@ import pytest from django.contrib.contenttypes.models import ContentType from django.core.cache import cache +from flag_engine.segments.constants import EQUAL from rest_framework.authtoken.models import Token from rest_framework.test import APIClient @@ -45,7 +46,7 @@ ) from projects.permissions import VIEW_PROJECT from projects.tags.models import Tag -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule from task_processor.task_run_method import TaskRunMethod from users.models import FFAdminUser, UserPermissionGroup diff --git a/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py b/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py index 262b83f1d7c7..f1ca7411506b 100644 --- a/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py +++ b/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py @@ -5,12 +5,13 @@ from core.constants import INTEGER from django.core.exceptions import ObjectDoesNotExist from flag_engine.identities.builders import build_identity_model +from flag_engine.segments.constants import IN from rest_framework.exceptions import NotFound from environments.dynamodb import DynamoIdentityWrapper from environments.identities.models import Identity from environments.identities.traits.models import Trait -from segments.models import IN, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule from util.mappers import ( map_environment_to_environment_document, map_identity_to_identity_document, diff --git a/api/environments/identities/tests/test_views.py b/api/environments/identities/tests/test_views.py index 6086c707b002..146de7fe4e93 100644 --- a/api/environments/identities/tests/test_views.py +++ b/api/environments/identities/tests/test_views.py @@ -4,9 +4,10 @@ from unittest.case import TestCase import pytest -from core.constants import FLAGSMITH_UPDATED_AT_HEADER +from core.constants import FLAGSMITH_UPDATED_AT_HEADER, STRING from django.test import override_settings from django.urls import reverse +from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import status from rest_framework.test import APIClient, APITestCase @@ -20,7 +21,6 @@ from integrations.amplitude.models import AmplitudeConfiguration from organisations.models import Organisation, OrganisationRole from projects.models import Project -from segments import models from segments.models import Condition, Segment, SegmentRule from util.tests import Helper @@ -370,7 +370,7 @@ def test_identities_endpoint_returns_traits(self, mock_amplitude_wrapper): trait = Trait.objects.create( identity=self.identity, trait_key="trait_key", - value_type="STRING", + value_type=STRING, string_value="trait_value", ) @@ -422,7 +422,7 @@ def test_identities_endpoint_returns_value_for_segment_if_identity_in_segment( Trait.objects.create( identity=self.identity, trait_key=trait_key, - value_type="STRING", + value_type=STRING, string_value=trait_value, ) segment = Segment.objects.create(name="Test Segment", project=self.project) @@ -476,7 +476,7 @@ def test_identities_endpoint_returns_value_for_segment_if_identity_in_segment_an Trait.objects.create( identity=self.identity, trait_key=trait_key, - value_type="STRING", + value_type=STRING, string_value=trait_value, ) segment = Segment.objects.create(name="Test Segment", project=self.project) @@ -528,7 +528,7 @@ def test_identities_endpoint_returns_value_for_segment_if_rule_type_percentage_s [segment.id, self.identity.id] ) Condition.objects.create( - operator=models.PERCENTAGE_SPLIT, + operator=PERCENTAGE_SPLIT, value=(identity_percentage_value + (1 - identity_percentage_value) / 2) * 100.0, rule=segment_rule, @@ -575,7 +575,7 @@ def test_identities_endpoint_returns_default_value_if_rule_type_percentage_split [segment.id, self.identity.id] ) Condition.objects.create( - operator=models.PERCENTAGE_SPLIT, + operator=PERCENTAGE_SPLIT, value=identity_percentage_value / 2, rule=segment_rule, ) @@ -628,13 +628,13 @@ def test_post_identify_deletes_a_trait_if_trait_value_is_none(self): trait_1 = Trait.objects.create( identity=self.identity, trait_key="trait_key_1", - value_type="STRING", + value_type=STRING, string_value="trait_value", ) trait_2 = Trait.objects.create( identity=self.identity, trait_key="trait_key_2", - value_type="STRING", + value_type=STRING, string_value="trait_value", ) diff --git a/api/features/feature_segments/tests/test_models.py b/api/features/feature_segments/tests/test_models.py index 2ff952d1a663..34222c6f3f3b 100644 --- a/api/features/feature_segments/tests/test_models.py +++ b/api/features/feature_segments/tests/test_models.py @@ -1,6 +1,7 @@ import pytest from core.constants import STRING from django.test import TestCase +from flag_engine.segments.constants import EQUAL from environments.identities.models import Identity from environments.identities.traits.models import Trait @@ -8,7 +9,7 @@ from features.models import Feature, FeatureSegment from organisations.models import Organisation from projects.models import Project -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule @pytest.mark.django_db diff --git a/api/segments/models.py b/api/segments/models.py index 017458221304..3bdc7d8399a1 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -2,8 +2,6 @@ import typing from copy import deepcopy -import semver -from core.constants import BOOLEAN, FLOAT, INTEGER from core.models import ( AbstractBaseExportableModel, SoftDeleteExportableModel, @@ -11,14 +9,11 @@ ) from django.core.exceptions import ValidationError from django.db import models +from flag_engine.segments import constants from flag_engine.segments.evaluator import evaluate_identity_in_segment -from flag_engine.utils.semver import is_semver, remove_semver_suffix from audit.constants import SEGMENT_CREATED_MESSAGE, SEGMENT_UPDATED_MESSAGE from audit.related_object_type import RelatedObjectType -from environments.identities.helpers import ( - get_hashed_percentage_for_object_ids, -) from features.models import Feature from projects.models import Project from util.mappers.engine import ( @@ -34,30 +29,6 @@ logger = logging.getLogger(__name__) -try: - import re2 as re - - logger.info("Using re2 library for regex.") -except ImportError: - logger.warning("Unable to import re2. Falling back to re.") - import re - -# Condition Types -EQUAL = "EQUAL" -GREATER_THAN = "GREATER_THAN" -LESS_THAN = "LESS_THAN" -LESS_THAN_INCLUSIVE = "LESS_THAN_INCLUSIVE" -CONTAINS = "CONTAINS" -GREATER_THAN_INCLUSIVE = "GREATER_THAN_INCLUSIVE" -NOT_CONTAINS = "NOT_CONTAINS" -NOT_EQUAL = "NOT_EQUAL" -REGEX = "REGEX" -PERCENTAGE_SPLIT = "PERCENTAGE_SPLIT" -MODULO = "MODULO" -IS_SET = "IS_SET" -IS_NOT_SET = "IS_NOT_SET" -IN = "IN" - class Segment( SoftDeleteExportableModel, @@ -116,7 +87,11 @@ def does_identity_match( self, identity: "Identity", traits: typing.List["Trait"] = None ) -> bool: segment_model = map_segment_to_engine(self) - identity_model = map_identity_to_engine(identity) + identity_model = map_identity_to_engine( + identity, + with_overrides=False, + with_traits=not traits, + ) trait_models = map_traits_to_trait_models(traits) if traits else None return evaluate_identity_in_segment( @@ -166,34 +141,6 @@ def __str__(self): str(self.segment) if self.segment else str(self.rule), ) - def does_identity_match( - self, identity: "Identity", traits: typing.List["Trait"] = None - ) -> bool: - matches_conditions = False - conditions = self.conditions.all() - - if conditions.count() == 0: - matches_conditions = True - elif self.type == self.ALL_RULE: - matches_conditions = all( - condition.does_identity_match(identity, traits) - for condition in conditions - ) - elif self.type == self.ANY_RULE: - matches_conditions = any( - condition.does_identity_match(identity, traits) - for condition in conditions - ) - elif self.type == self.NONE_RULE: - matches_conditions = not any( - condition.does_identity_match(identity, traits) - for condition in conditions - ) - - return matches_conditions and all( - rule.does_identity_match(identity, traits) for rule in self.rules.all() - ) - def get_segment(self): """ rules can be a child of a parent rule instead of a segment, this method iterates back up the tree to find the @@ -214,20 +161,20 @@ class Condition( related_object_type = RelatedObjectType.SEGMENT CONDITION_TYPES = ( - (EQUAL, "Exactly Matches"), - (GREATER_THAN, "Greater than"), - (LESS_THAN, "Less than"), - (CONTAINS, "Contains"), - (GREATER_THAN_INCLUSIVE, "Greater than or equal to"), - (LESS_THAN_INCLUSIVE, "Less than or equal to"), - (NOT_CONTAINS, "Does not contain"), - (NOT_EQUAL, "Does not match"), - (REGEX, "Matches regex"), - (PERCENTAGE_SPLIT, "Percentage split"), - (MODULO, "Modulo Operation"), - (IS_SET, "Is set"), - (IS_NOT_SET, "Is not set"), - (IN, "In"), + (constants.EQUAL, "Exactly Matches"), + (constants.GREATER_THAN, "Greater than"), + (constants.LESS_THAN, "Less than"), + (constants.CONTAINS, "Contains"), + (constants.GREATER_THAN_INCLUSIVE, "Greater than or equal to"), + (constants.LESS_THAN_INCLUSIVE, "Less than or equal to"), + (constants.NOT_CONTAINS, "Does not contain"), + (constants.NOT_EQUAL, "Does not match"), + (constants.REGEX, "Matches regex"), + (constants.PERCENTAGE_SPLIT, "Percentage split"), + (constants.MODULO, "Modulo Operation"), + (constants.IS_SET, "Is set"), + (constants.IS_NOT_SET, "Is not set"), + (constants.IN, "In"), ) operator = models.CharField(choices=CONDITION_TYPES, max_length=500) @@ -252,162 +199,6 @@ def __str__(self): self.value, ) - def does_identity_match( # noqa: C901 - self, identity: "Identity", traits: typing.List["Trait"] = None - ) -> bool: - if self.operator == PERCENTAGE_SPLIT: - return self._check_percentage_split_operator(identity) - - # we allow passing in traits to handle when they aren't - # persisted for certain organisations - traits = identity.identity_traits.all() if traits is None else traits - matching_trait = next( - filter(lambda t: t.trait_key == self.property, traits), None - ) - if matching_trait is None: - return self.operator == IS_NOT_SET - - if self.operator in (IS_SET, IS_NOT_SET): - return self.operator == IS_SET - elif self.operator == MODULO: - if matching_trait.value_type in [INTEGER, FLOAT]: - return self._check_modulo_operator(matching_trait.trait_value) - elif self.operator == IN: - return str(matching_trait.trait_value) in self.value.split(",") - elif matching_trait.value_type == INTEGER: - return self.check_integer_value(matching_trait.integer_value) - elif matching_trait.value_type == FLOAT: - return self.check_float_value(matching_trait.float_value) - elif matching_trait.value_type == BOOLEAN: - return self.check_boolean_value(matching_trait.boolean_value) - elif is_semver(self.value): - return self.check_semver_value(matching_trait.string_value) - - return self.check_string_value(matching_trait.string_value) - - def _check_percentage_split_operator(self, identity): - try: - float_value = float(self.value) / 100.0 - except ValueError: - return False - - segment = self.rule.get_segment() - return ( - get_hashed_percentage_for_object_ids( - object_ids=[segment.id, identity.get_hash_key()] - ) - <= float_value - ) - - def _check_modulo_operator(self, value: typing.Union[int, float]) -> bool: - try: - divisor, remainder = self.value.split("|") - divisor = float(divisor) - remainder = float(remainder) - except ValueError: - return False - - return value % divisor == remainder - - def check_integer_value(self, value: int) -> bool: - try: - int_value = int(str(self.value)) - except ValueError: - return False - - if self.operator == EQUAL: - return value == int_value - elif self.operator == GREATER_THAN: - return value > int_value - elif self.operator == GREATER_THAN_INCLUSIVE: - return value >= int_value - elif self.operator == LESS_THAN: - return value < int_value - elif self.operator == LESS_THAN_INCLUSIVE: - return value <= int_value - elif self.operator == NOT_EQUAL: - return value != int_value - - return False - - def check_float_value(self, value: float) -> bool: - try: - float_value = float(str(self.value)) - except ValueError: - return False - - if self.operator == EQUAL: - return value == float_value - elif self.operator == GREATER_THAN: - return value > float_value - elif self.operator == GREATER_THAN_INCLUSIVE: - return value >= float_value - elif self.operator == LESS_THAN: - return value < float_value - elif self.operator == LESS_THAN_INCLUSIVE: - return value <= float_value - elif self.operator == NOT_EQUAL: - return value != float_value - - return False - - def check_boolean_value(self, value: bool) -> bool: - if self.value in ("False", "false", "0"): - bool_value = False - elif self.value in ("True", "true", "1"): - bool_value = True - else: - return False - - if self.operator == EQUAL: - return value == bool_value - elif self.operator == NOT_EQUAL: - return value != bool_value - - return False - - def check_semver_value(self, value: str) -> bool: - try: - condition_version_info = semver.VersionInfo.parse( - remove_semver_suffix(self.value) - ) - except ValueError: - return False - - if self.operator == EQUAL: - return value == condition_version_info - elif self.operator == GREATER_THAN: - return value > condition_version_info - elif self.operator == GREATER_THAN_INCLUSIVE: - return value >= condition_version_info - elif self.operator == LESS_THAN: - return value < condition_version_info - elif self.operator == LESS_THAN_INCLUSIVE: - return value <= condition_version_info - elif self.operator == NOT_EQUAL: - return value != condition_version_info - - return False - - def check_string_value(self, value: str) -> bool: - try: - str_value = str(self.value) - except ValueError: - return False - - if self.operator == EQUAL: - return value == str_value - elif self.operator == NOT_EQUAL: - return value != str_value - elif self.operator == CONTAINS: - return str_value in value - elif self.operator == NOT_CONTAINS: - return str_value not in value - elif self.operator == REGEX: - return re.compile(str(self.value)).match(value) is not None - - return False - def get_update_log_message(self, history_instance) -> typing.Optional[str]: return f"Condition updated on segment '{self._get_segment().name}'." diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 29a99c292e25..28966e65fd7c 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -1,12 +1,13 @@ import typing +from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.serializers import ListSerializer from rest_framework_recursive.fields import RecursiveField from projects.models import Project -from segments.models import PERCENTAGE_SPLIT, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule class ConditionSerializer(serializers.ModelSerializer): diff --git a/api/segments/tests/test_models.py b/api/segments/tests/test_models.py index 7c49ad1e263e..176931a0f22e 100644 --- a/api/segments/tests/test_models.py +++ b/api/segments/tests/test_models.py @@ -1,12 +1,13 @@ from unittest import TestCase import pytest +from flag_engine.segments.constants import PERCENTAGE_SPLIT from environments.identities.models import Identity from environments.models import Environment from organisations.models import Organisation from projects.models import Project -from segments.models import PERCENTAGE_SPLIT, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule @pytest.mark.django_db diff --git a/api/tests/unit/environments/test_environments_views_sdk_environment.py b/api/tests/unit/environments/test_environments_views_sdk_environment.py index 9dc3983b1466..5857fa66b895 100644 --- a/api/tests/unit/environments/test_environments_views_sdk_environment.py +++ b/api/tests/unit/environments/test_environments_views_sdk_environment.py @@ -1,11 +1,12 @@ from core.constants import FLAGSMITH_UPDATED_AT_HEADER from django.urls import reverse +from flag_engine.segments.constants import EQUAL from rest_framework import status from rest_framework.test import APIClient from environments.models import Environment, EnvironmentAPIKey from features.models import Feature -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule def test_get_environment_document( diff --git a/api/tests/unit/environments/test_unit_environments_views.py b/api/tests/unit/environments/test_unit_environments_views.py index 0cbba069fbe3..303f38f33c56 100644 --- a/api/tests/unit/environments/test_unit_environments_views.py +++ b/api/tests/unit/environments/test_unit_environments_views.py @@ -6,6 +6,7 @@ from core.constants import STRING from django.contrib.contenttypes.models import ContentType from django.urls import reverse +from flag_engine.segments.constants import EQUAL from pytest_lazyfixture import lazy_fixture from rest_framework import status from rest_framework.test import APIClient @@ -25,7 +26,7 @@ UserProjectPermission, ) from projects.permissions import CREATE_ENVIRONMENT, VIEW_PROJECT -from segments.models import EQUAL, Condition, SegmentRule +from segments.models import Condition, SegmentRule from users.models import FFAdminUser from util.tests import Helper diff --git a/api/tests/unit/import_export/test_unit_import_export_export.py b/api/tests/unit/import_export/test_unit_import_export_export.py index 595bf3e172f2..b73a80d801d3 100644 --- a/api/tests/unit/import_export/test_unit_import_export_export.py +++ b/api/tests/unit/import_export/test_unit_import_export_export.py @@ -7,7 +7,7 @@ from django.contrib.contenttypes.models import ContentType from django.core.management import call_command from django.core.serializers.json import DjangoJSONEncoder -from flag_engine.segments.constants import ALL_RULE +from flag_engine.segments.constants import ALL_RULE, EQUAL from moto import mock_s3 from environments.models import Environment, EnvironmentAPIKey, Webhook @@ -42,7 +42,7 @@ from organisations.models import Organisation, OrganisationWebhook from projects.models import Project from projects.tags.models import Tag -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule def test_export_organisation(db): diff --git a/api/tests/unit/projects/test_unit_projects_admin.py b/api/tests/unit/projects/test_unit_projects_admin.py index 0be0f847cad3..e1f735f576e1 100644 --- a/api/tests/unit/projects/test_unit_projects_admin.py +++ b/api/tests/unit/projects/test_unit_projects_admin.py @@ -3,12 +3,13 @@ import pytest from django.contrib.admin import AdminSite +from flag_engine.segments.constants import EQUAL from environments.models import Environment from features.models import Feature, FeatureSegment, FeatureState from projects.admin import ProjectAdmin from projects.models import Project -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule if typing.TYPE_CHECKING: from django.contrib.auth.models import AbstractUser diff --git a/api/tests/unit/segments/test_conditions.py b/api/tests/unit/segments/test_conditions.py deleted file mode 100644 index 48191203996d..000000000000 --- a/api/tests/unit/segments/test_conditions.py +++ /dev/null @@ -1,157 +0,0 @@ -import pytest -from core.constants import INTEGER, STRING - -from environments.identities.traits.models import Trait -from segments.models import ( - EQUAL, - GREATER_THAN, - GREATER_THAN_INCLUSIVE, - IN, - IS_NOT_SET, - IS_SET, - LESS_THAN, - LESS_THAN_INCLUSIVE, - MODULO, - NOT_EQUAL, - Condition, -) - - -@pytest.mark.parametrize( - "operator, trait_value, condition_value, result", - [ - (EQUAL, "1.0.0", "1.0.0:semver", True), - (EQUAL, "1.0.0", "1.0.1:semver", False), - (NOT_EQUAL, "1.0.0", "1.0.0:semver", False), - (NOT_EQUAL, "1.0.0", "1.0.1:semver", True), - (GREATER_THAN, "1.0.1", "1.0.0:semver", True), - (GREATER_THAN, "1.0.0", "1.0.0-beta:semver", True), - (GREATER_THAN, "1.0.1", "1.2.0:semver", False), - (GREATER_THAN, "1.0.1", "1.0.1:semver", False), - (GREATER_THAN, "1.2.4", "1.2.3-pre.2+build.4:semver", True), - (LESS_THAN, "1.0.0", "1.0.1:semver", True), - (LESS_THAN, "1.0.0", "1.0.0:semver", False), - (LESS_THAN, "1.0.1", "1.0.0:semver", False), - (LESS_THAN, "1.0.0-rc.2", "1.0.0-rc.3:semver", True), - (GREATER_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", True), - (GREATER_THAN_INCLUSIVE, "1.0.1", "1.2.0:semver", False), - (GREATER_THAN_INCLUSIVE, "1.0.1", "1.0.1:semver", True), - (LESS_THAN_INCLUSIVE, "1.0.0", "1.0.1:semver", True), - (LESS_THAN_INCLUSIVE, "1.0.0", "1.0.0:semver", True), - (LESS_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", False), - ], -) -def test_does_identity_match_for_semver_values( - identity, operator, trait_value, condition_value, result -): - # Given - condition = Condition(operator=operator, property="version", value=condition_value) - traits = [ - Trait( - trait_key="version", - string_value=trait_value, - identity=identity, - ) - ] - # Then - assert condition.does_identity_match(identity, traits) is result - - -@pytest.mark.parametrize( - "trait_value, condition_value, result", - [ - (1, "2|0", False), - (2, "2|0", True), - (3, "2|0", False), - (34.2, "4|3", False), - (35.0, "4|3", True), - ("dummy", "3|0", False), - ("1.0.0", "3|0", False), - (False, "1|3", False), - ], -) -def test_does_identity_match_for_modulo_operator( - identity, trait_value, condition_value, result -): - condition = Condition(operator=MODULO, property="user_id", value=condition_value) - - trait_value_data = Trait.generate_trait_value_data(trait_value) - traits = [Trait(trait_key="user_id", identity=identity, **trait_value_data)] - - assert condition.does_identity_match(identity, traits) is result - - -def test_does_identity_match_is_set_true(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_SET, property=trait_key) - traits = [Trait(trait_key=trait_key, identity=identity)] - - # Then - assert condition.does_identity_match(identity, traits) is True - - -def test_does_identity_match_is_set_false(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_SET, property=trait_key) - traits = [] - - # Then - assert condition.does_identity_match(identity, traits) is False - - -def test_does_identity_match_is_not_set_true(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_NOT_SET, property=trait_key) - traits = [Trait(trait_key=trait_key, identity=identity)] - - # Then - assert condition.does_identity_match(identity, traits) is False - - -def test_does_identity_match_is_not_set_false(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_NOT_SET, property=trait_key) - traits = [] - - # Then - assert condition.does_identity_match(identity, traits) is True - - -@pytest.mark.parametrize( - "condition_value, trait_value_type, trait_string_value, trait_integer_value, expected_result", - ( - ("", STRING, "foo", None, False), - ("foo,bar", STRING, "foo", None, True), - ("foo", STRING, "foo", None, True), - ("1,2,3,4", INTEGER, None, 1, True), - ("", INTEGER, None, 1, False), - ("1", INTEGER, None, 1, True), - ), -) -def test_does_identity_match_in( - identity, - condition_value, - trait_value_type, - trait_string_value, - trait_integer_value, - expected_result, -): - # Given - trait_key = "some_property" - condition = Condition(operator=IN, property=trait_key, value=condition_value) - traits = [ - Trait( - trait_key=trait_key, - identity=identity, - value_type=trait_value_type, - string_value=trait_string_value, - integer_value=trait_integer_value, - ) - ] - - # Then - assert condition.does_identity_match(identity, traits) is expected_result diff --git a/api/tests/unit/segments/test_unit_segments_models.py b/api/tests/unit/segments/test_unit_segments_models.py index 738bd0477c85..54163a12ffae 100644 --- a/api/tests/unit/segments/test_unit_segments_models.py +++ b/api/tests/unit/segments/test_unit_segments_models.py @@ -1,32 +1,7 @@ import pytest +from flag_engine.segments.constants import EQUAL -from segments.models import ( - EQUAL, - PERCENTAGE_SPLIT, - Condition, - Segment, - SegmentRule, -) - - -def test_percentage_split_calculation_divides_value_by_100_before_comparison( - mocker, segment, segment_rule, identity -): - # Given - mock_get_hashed_percentage_for_object_ids = mocker.patch( - "segments.models.get_hashed_percentage_for_object_ids" - ) - - condition = Condition.objects.create( - rule=segment_rule, operator=PERCENTAGE_SPLIT, value=10 - ) - mock_get_hashed_percentage_for_object_ids.return_value = 0.2 - - # When - result = condition.does_identity_match(identity) - - # Then - assert not result +from segments.models import Condition, Segment, SegmentRule def test_condition_get_create_log_message_for_condition_created_with_segment( diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index ebdbe7c545c3..e0126b386a1e 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -4,6 +4,7 @@ import pytest from django.contrib.auth import get_user_model from django.urls import reverse +from flag_engine.segments.constants import EQUAL from pytest_lazyfixture import lazy_fixture from rest_framework import status @@ -11,7 +12,7 @@ from audit.related_object_type import RelatedObjectType from environments.models import Environment from features.models import Feature -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule from util.mappers import map_identity_to_identity_document User = get_user_model() diff --git a/api/util/mappers/engine.py b/api/util/mappers/engine.py index 094eb03db889..6a69c83d4f46 100644 --- a/api/util/mappers/engine.py +++ b/api/util/mappers/engine.py @@ -57,9 +57,13 @@ def map_traits_to_trait_models(traits: Iterable["Trait"]) -> list[TraitModel]: ] -def map_segment_to_engine(segment: "Segment") -> SegmentModel: +def map_segment_to_engine( + segment: "Segment", +) -> SegmentModel: segment_rules = segment.rules.all() + # No reading from ORM past this point! + return SegmentModel( id=segment.pk, name=segment.name, @@ -188,7 +192,7 @@ def map_environment_to_engine( int, Iterable["SegmentRule"], ] = {segment.pk: segment.rules.all() for segment in project_segments} - project_segment_feature_states_by_segment_id = _get_project_segment_feature_states( + project_segment_feature_states_by_segment_id = _get_segment_feature_states( project_segments, environment.pk, ) @@ -351,19 +355,30 @@ def map_environment_api_key_to_engine( ) -def map_identity_to_engine(identity: "Identity") -> IdentityModel: +def map_identity_to_engine( + identity: "Identity", + *, + with_overrides: bool = True, + with_traits: bool = True, +) -> IdentityModel: environment_api_key = identity.environment.api_key # Read relationships - grab all the data needed from the ORM here. - identity_feature_states: List["FeatureState"] = _get_prioritised_feature_states( - identity.identity_features.all(), - ) - multivariate_feature_state_values_by_feature_state_id = { - feature_state.pk: feature_state.multivariate_feature_state_values.all() - for feature_state in identity_feature_states - } + if with_overrides: + identity_feature_states: List["FeatureState"] = _get_prioritised_feature_states( + identity.identity_features.all(), + ) + multivariate_feature_state_values_by_feature_state_id = { + feature_state.pk: feature_state.multivariate_feature_state_values.all() + for feature_state in identity_feature_states + } + else: + identity_feature_states = [] + multivariate_feature_state_values_by_feature_state_id = {} - identity_traits: Iterable["Trait"] = identity.identity_traits.all() + identity_traits: Iterable["Trait"] = ( + identity.identity_traits.all() if with_traits else [] + ) # Prepare relationships. identity_feature_state_models = [ @@ -406,12 +421,12 @@ def _get_prioritised_feature_states( return list(prioritised_feature_state_by_feature_id.values()) -def _get_project_segment_feature_states( - project_segments: Iterable["Segment"], +def _get_segment_feature_states( + segments: Iterable["Segment"], environment_id: int, ) -> Dict[int, List["FeatureState"]]: feature_states_by_segment_id = {} - for segment in project_segments: + for segment in segments: segment_feature_states = feature_states_by_segment_id.setdefault(segment.pk, []) for feature_segment in segment.feature_segments.all(): if feature_segment.environment_id != environment_id: From 3b11d538a5821bf27f2a24ae344ef00b3437549c Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Tue, 21 Nov 2023 10:12:54 +0000 Subject: [PATCH 3/3] replace missing import --- .../identities/tests/test_models.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/api/environments/identities/tests/test_models.py b/api/environments/identities/tests/test_models.py index 3ec3bad9f5a6..9e0691b22f12 100644 --- a/api/environments/identities/tests/test_models.py +++ b/api/environments/identities/tests/test_models.py @@ -1,6 +1,13 @@ import pytest from core.constants import FLOAT from django.utils import timezone +from flag_engine.segments.constants import ( + EQUAL, + GREATER_THAN, + GREATER_THAN_INCLUSIVE, + LESS_THAN_INCLUSIVE, + NOT_EQUAL, +) from rest_framework.test import APITestCase from environments.identities.models import Identity @@ -15,16 +22,7 @@ from features.value_types import BOOLEAN, INTEGER, STRING from organisations.models import Organisation from projects.models import Project -from segments.models import ( - EQUAL, - GREATER_THAN, - GREATER_THAN_INCLUSIVE, - LESS_THAN_INCLUSIVE, - NOT_EQUAL, - Condition, - Segment, - SegmentRule, -) +from segments.models import Condition, Segment, SegmentRule from .helpers import ( create_trait_for_identity,