diff --git a/api/environments/dynamodb/__init__.py b/api/environments/dynamodb/__init__.py index 0dbfce3d2107..f36c2a219915 100644 --- a/api/environments/dynamodb/__init__.py +++ b/api/environments/dynamodb/__init__.py @@ -1,5 +1,13 @@ -from .dynamodb_wrapper import ( # noqa +from .dynamodb_wrapper import ( DynamoEnvironmentAPIKeyWrapper, + DynamoEnvironmentV2Wrapper, DynamoEnvironmentWrapper, DynamoIdentityWrapper, ) + +__all__ = ( + "DynamoEnvironmentAPIKeyWrapper", + "DynamoEnvironmentV2Wrapper", + "DynamoEnvironmentWrapper", + "DynamoIdentityWrapper", +) diff --git a/api/environments/dynamodb/services.py b/api/environments/dynamodb/services.py index ffa3abb1b11d..10af48ce855c 100644 --- a/api/environments/dynamodb/services.py +++ b/api/environments/dynamodb/services.py @@ -15,12 +15,12 @@ logger = logging.getLogger(__name__) -def migrate_environments_to_v2(project_id: int) -> None: +def migrate_environments_to_v2(project_id: int) -> IdentityOverridesV2Changeset | None: dynamo_wrapper_v2 = DynamoEnvironmentV2Wrapper() identity_wrapper = DynamoIdentityWrapper() if not (dynamo_wrapper_v2.is_enabled and identity_wrapper.is_enabled): - return + return None logger.info("Migrating environments to v2 for project %d", project_id) @@ -43,6 +43,7 @@ def migrate_environments_to_v2(project_id: int) -> None: dynamo_wrapper_v2.update_identity_overrides(changeset) logger.info("Finished migrating environments to v2 for project %d", project_id) + return changeset def _iter_paginated_overrides( diff --git a/api/environments/models.py b/api/environments/models.py index 05b4974d11f4..657810aaa9cd 100644 --- a/api/environments/models.py +++ b/api/environments/models.py @@ -37,6 +37,7 @@ ) from environments.dynamodb import ( DynamoEnvironmentAPIKeyWrapper, + DynamoEnvironmentV2Wrapper, DynamoEnvironmentWrapper, ) from environments.exceptions import EnvironmentHeaderNotPresentError @@ -44,6 +45,7 @@ from features.models import Feature, FeatureSegment, FeatureState from features.versioning.exceptions import FeatureVersioningError from metadata.models import Metadata +from projects.models import IdentityOverridesV2MigrationStatus, Project from segments.models import Segment from util.mappers import map_environment_to_environment_document from webhooks.models import AbstractBaseExportableWebhookModel @@ -57,6 +59,7 @@ # Intialize the dynamo environment wrapper(s) globaly environment_wrapper = DynamoEnvironmentWrapper() +environment_v2_wrapper = DynamoEnvironmentV2Wrapper() environment_api_key_wrapper = DynamoEnvironmentAPIKeyWrapper() @@ -234,7 +237,7 @@ def write_environments_to_dynamodb( # grab the first project and verify that each environment is for the same # project (which should always be the case). Since we're working with fairly # small querysets here, this shouldn't have a noticeable impact on performance. - project = getattr(environments[0], "project", None) + project: Project | None = getattr(environments[0], "project", None) for environment in environments[1:]: if not environment.project == project: raise RuntimeError("Environments must all belong to the same project.") @@ -244,6 +247,13 @@ def write_environments_to_dynamodb( environment_wrapper.write_environments(environments) + if ( + project.identity_overrides_v2_migration_status + == IdentityOverridesV2MigrationStatus.COMPLETE + and environment_v2_wrapper.is_enabled + ): + environment_v2_wrapper.write_environments(environments) + def get_feature_state( self, feature_id: int, filter_kwargs: dict = None ) -> typing.Optional[FeatureState]: diff --git a/api/projects/tasks.py b/api/projects/tasks.py index 743d741c1016..57abc6ada4c9 100644 --- a/api/projects/tasks.py +++ b/api/projects/tasks.py @@ -4,21 +4,21 @@ @register_task_handler() -def write_environments_to_dynamodb(project_id: int): +def write_environments_to_dynamodb(project_id: int) -> None: from environments.models import Environment Environment.write_environments_to_dynamodb(project_id=project_id) @register_task_handler() -def migrate_project_environments_to_v2(project_id: int): +def migrate_project_environments_to_v2(project_id: int) -> None: from environments.dynamodb.services import migrate_environments_to_v2 from projects.models import IdentityOverridesV2MigrationStatus, Project with transaction.atomic(): project = Project.objects.select_for_update().get(id=project_id) - migrate_environments_to_v2(project_id=project_id) - project.identity_overrides_v2_migration_status = ( - IdentityOverridesV2MigrationStatus.COMPLETE - ) - project.save() + if migrate_environments_to_v2(project_id=project_id): + project.identity_overrides_v2_migration_status = ( + IdentityOverridesV2MigrationStatus.COMPLETE + ) + project.save() diff --git a/api/tests/unit/environments/conftest.py b/api/tests/unit/environments/conftest.py index a4d7d20b9638..d9383b131a49 100644 --- a/api/tests/unit/environments/conftest.py +++ b/api/tests/unit/environments/conftest.py @@ -1,6 +1,14 @@ +from unittest.mock import Mock + import pytest +from pytest_mock import MockerFixture @pytest.fixture() -def mock_dynamo_env_wrapper(mocker): +def mock_dynamo_env_wrapper(mocker: MockerFixture) -> Mock: return mocker.patch("environments.models.environment_wrapper") + + +@pytest.fixture() +def mock_dynamo_env_v2_wrapper(mocker: MockerFixture) -> Mock: + return mocker.patch("environments.models.environment_v2_wrapper") diff --git a/api/tests/unit/environments/test_unit_environments_models.py b/api/tests/unit/environments/test_unit_environments_models.py index 6a00da4c6896..6ebf6077e2d5 100644 --- a/api/tests/unit/environments/test_unit_environments_models.py +++ b/api/tests/unit/environments/test_unit_environments_models.py @@ -2,7 +2,7 @@ from copy import copy from datetime import timedelta from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest from core.constants import STRING @@ -25,7 +25,7 @@ from features.multivariate.models import MultivariateFeatureOption from features.versioning.models import EnvironmentFeatureVersion from organisations.models import Organisation, OrganisationRole -from projects.models import Project +from projects.models import IdentityOverridesV2MigrationStatus, Project from segments.models import Segment from util.mappers import map_environment_to_environment_document @@ -484,6 +484,82 @@ def test_write_environments_to_dynamodb_with_environment_and_project( ) +def test_write_environments_to_dynamodb__project_environments_v2_migrated__call_expected( + dynamo_enabled_project: Project, + dynamo_enabled_project_environment_one: Environment, + dynamo_enabled_project_environment_two: Environment, + mock_dynamo_env_wrapper: Mock, + mock_dynamo_env_v2_wrapper: Mock, +) -> None: + # Given + dynamo_enabled_project.identity_overrides_v2_migration_status = ( + IdentityOverridesV2MigrationStatus.COMPLETE + ) + dynamo_enabled_project.save() + mock_dynamo_env_v2_wrapper.is_enabled = True + + # When + Environment.write_environments_to_dynamodb(project_id=dynamo_enabled_project.id) + + # Then + args, kwargs = mock_dynamo_env_v2_wrapper.write_environments.call_args + assert kwargs == {} + assert len(args) == 1 + assert_queryset_equal( + args[0], Environment.objects.filter(project=dynamo_enabled_project) + ) + + +def test_write_environments_to_dynamodb__project_environments_v2_migrated__wrapper_disabled__wrapper_not_called( + dynamo_enabled_project: Project, + dynamo_enabled_project_environment_one: Environment, + dynamo_enabled_project_environment_two: Environment, + mock_dynamo_env_wrapper: Mock, + mock_dynamo_env_v2_wrapper: Mock, +) -> None: + # Given + mock_dynamo_env_v2_wrapper.is_enabled = False + dynamo_enabled_project.identity_overrides_v2_migration_status = ( + IdentityOverridesV2MigrationStatus.COMPLETE + ) + dynamo_enabled_project.save() + + # When + Environment.write_environments_to_dynamodb(project_id=dynamo_enabled_project.id) + + # Then + mock_dynamo_env_v2_wrapper.write_environments.assert_not_called() + + +@pytest.mark.parametrize( + "identity_overrides_v2_migration_status", + ( + IdentityOverridesV2MigrationStatus.NOT_STARTED, + IdentityOverridesV2MigrationStatus.IN_PROGRESS, + ), +) +def test_write_environments_to_dynamodb__project_environments_v2_not_migrated__wrapper_not_called( + dynamo_enabled_project: Project, + dynamo_enabled_project_environment_one: Environment, + dynamo_enabled_project_environment_two: Environment, + mock_dynamo_env_wrapper: Mock, + mock_dynamo_env_v2_wrapper: Mock, + identity_overrides_v2_migration_status: str, +) -> None: + # Given + dynamo_enabled_project.identity_overrides_v2_migration_status = ( + identity_overrides_v2_migration_status + ) + dynamo_enabled_project.save() + mock_dynamo_env_v2_wrapper.is_enabled = True + + # When + Environment.write_environments_to_dynamodb(project_id=dynamo_enabled_project.id) + + # Then + mock_dynamo_env_v2_wrapper.write_environments.assert_not_called() + + @pytest.mark.parametrize( "value, identity_id, identifier", ( diff --git a/api/tests/unit/projects/test_tasks.py b/api/tests/unit/projects/test_tasks.py index 1812f5b54c2d..8c099a5cb52a 100644 --- a/api/tests/unit/projects/test_tasks.py +++ b/api/tests/unit/projects/test_tasks.py @@ -1,6 +1,7 @@ import pytest from pytest_mock import MockerFixture +from environments.dynamodb.types import IdentityOverridesV2Changeset from projects.models import IdentityOverridesV2MigrationStatus, Project from projects.tasks import migrate_project_environments_to_v2 @@ -16,9 +17,24 @@ def project_v2_migration_in_progress( return project +@pytest.mark.parametrize( + "migrate_environments_to_v2_return_value, expected_status", + ( + ( + IdentityOverridesV2Changeset(to_put=[], to_delete=[]), + IdentityOverridesV2MigrationStatus.COMPLETE, + ), + ( + None, + IdentityOverridesV2MigrationStatus.IN_PROGRESS, + ), + ), +) def test_migrate_project_environments_to_v2__calls_expected( mocker: MockerFixture, project_v2_migration_in_progress: Project, + migrate_environments_to_v2_return_value: IdentityOverridesV2Changeset | None, + expected_status: str, ): # Given mocked_migrate_environments_to_v2 = mocker.patch( @@ -26,6 +42,9 @@ def test_migrate_project_environments_to_v2__calls_expected( autospec=True, return_value=None, ) + mocked_migrate_environments_to_v2.return_value = ( + migrate_environments_to_v2_return_value + ) # When migrate_project_environments_to_v2(project_id=project_v2_migration_in_progress.id) @@ -35,8 +54,9 @@ def test_migrate_project_environments_to_v2__calls_expected( mocked_migrate_environments_to_v2.assert_called_once_with( project_id=project_v2_migration_in_progress.id, ) - assert project_v2_migration_in_progress.identity_overrides_v2_migration_status == ( - IdentityOverridesV2MigrationStatus.COMPLETE + assert ( + project_v2_migration_in_progress.identity_overrides_v2_migration_status + == expected_status )