Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support --exclusive-start-key option for ensure_identity_traits_blanks #4941

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,27 +1,42 @@
import json
from argparse import ArgumentParser
from typing import Any

import structlog
from django.core.management import BaseCommand
from structlog import get_logger
from structlog.stdlib import BoundLogger

from environments.dynamodb import DynamoIdentityWrapper

identity_wrapper = DynamoIdentityWrapper()

logger: structlog.BoundLogger = structlog.get_logger()

LOG_COUNT_EVERY = 100_000


class Command(BaseCommand):
def handle(self, *args: Any, **options: Any) -> None:
def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument(
"--exclusive-start-key",
dest="exclusive_start_key",
type=str,
default="",
help="Exclusive start key in valid JSON",
)

def handle(self, *args: Any, exclusive_start_key: str, **options: Any) -> None:
total_count = identity_wrapper.table.item_count
scanned_count = 0
fixed_count = 0
scanned_count = scanned_percentage = fixed_count = 0

log: structlog.BoundLogger = logger.bind(total_count=total_count)

kwargs = {}
if exclusive_start_key:
kwargs["ExclusiveStartKey"] = json.loads(exclusive_start_key)

log: BoundLogger = get_logger(total_count=total_count)
log.info("started")

for identity_document in identity_wrapper.query_get_all_items():
for identity_document in identity_wrapper.scan_iter_all_items(**kwargs):
should_write_identity_document = False

if identity_traits_data := identity_document.get("identity_traits"):
Expand Down
41 changes: 36 additions & 5 deletions api/environments/dynamodb/wrappers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@
import boto3
import boto3.dynamodb.types
from botocore.config import Config
from sentry_sdk import set_context # TODO @kgustyr: Replace with OTel

if typing.TYPE_CHECKING:
from mypy_boto3_dynamodb.service_resource import Table
from mypy_boto3_dynamodb.type_defs import (
QueryOutputTableTypeDef,
ScanOutputTableTypeDef,
TableAttributeValueTypeDef,
)

DynamoDBOutput = QueryOutputTableTypeDef | ScanOutputTableTypeDef

P = typing.ParamSpec("P")

# Avoid `decimal.Rounded` when reading large numbers
# See https://github.com/boto/boto3/issues/2500
Expand Down Expand Up @@ -40,14 +49,20 @@ def get_table(self) -> typing.Optional["Table"]:
def is_enabled(self) -> bool:
return self.table is not None

def query_get_all_items(self, **kwargs: dict) -> typing.Generator[dict, None, None]:
if kwargs:
response_getter = partial(self.table.query, **kwargs)
else:
response_getter = partial(self.table.scan)
def _iter_all_items(
self,
response_getter_method: "typing.Callable[[P], DynamoDBOutput]",
**kwargs: "P.kwargs",
) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]:
response_getter = partial(response_getter_method, **kwargs)
set_context(
"dynamodb",
{"table_name": self.table_name, **kwargs},
)

while True:
query_response = response_getter()

for item in query_response["Items"]:
yield item

Expand All @@ -56,3 +71,19 @@ def query_get_all_items(self, **kwargs: dict) -> typing.Generator[dict, None, No
break

response_getter.keywords["ExclusiveStartKey"] = last_evaluated_key
set_context(
"dynamodb",
{"table_name": self.table_name, **response_getter.keywords},
)

def scan_iter_all_items(
    self,
    **kwargs: typing.Any,
) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]:
    """Return a generator over the items obtained via the table's ``scan`` method.

    All keyword arguments are forwarded unchanged, together with
    ``self.table.scan``, to ``self._iter_all_items``; the generator it
    returns is handed back as-is.
    """
    scan_method = self.table.scan
    items = self._iter_all_items(scan_method, **kwargs)
    return items

def query_iter_all_items(
    self,
    **kwargs: typing.Any,
) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]:
    """Return a generator over the items obtained via the table's ``query`` method.

    All keyword arguments are forwarded unchanged, together with
    ``self.table.query``, to ``self._iter_all_items``; the generator it
    returns is handed back as-is.
    """
    query_method = self.table.query
    items = self._iter_all_items(query_method, **kwargs)
    return items
4 changes: 2 additions & 2 deletions api/environments/dynamodb/wrappers/environment_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_identity_overrides_by_environment_id(
) -> typing.List[dict[str, Any]]:
try:
return list(
self.query_get_all_items(
self.query_iter_all_items(
KeyConditionExpression=Key(ENVIRONMENTS_V2_PARTITION_KEY).eq(
str(environment_id),
)
Expand Down Expand Up @@ -122,7 +122,7 @@ def delete_environment(self, environment_id: int):
"ProjectionExpression": "document_key",
}
with self.table.batch_writer() as writer:
for item in self.query_get_all_items(**query_kwargs):
for item in self.query_iter_all_items(**query_kwargs):
writer.delete_item(
Key={
ENVIRONMENTS_V2_PARTITION_KEY: environment_id,
Expand Down
2 changes: 1 addition & 1 deletion api/tests/integration/edge_api/identities/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def identity_overrides_v2(
edge_identity.save(admin_user)
return [
item["document_key"]
for item in dynamodb_wrapper_v2.query_get_all_items(
for item in dynamodb_wrapper_v2.query_iter_all_items(
KeyConditionExpression=Key("environment_id").eq(
str(dynamo_enabled_environment)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_delete_identity(
KeyConditionExpression=Key("identity_uuid").eq(identity_uuid),
)["Count"]
assert not list(
dynamodb_wrapper_v2.query_get_all_items(
dynamodb_wrapper_v2.query_iter_all_items(
KeyConditionExpression=Key("environment_id").eq(
str(dynamo_enabled_environment)
)
Expand Down
22 changes: 22 additions & 0 deletions api/tests/unit/edge_api/test_unit_edge_api_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,25 @@ def test_ensure_identity_traits_blanks__logs_expected(
"total_count": 11,
},
]


def test_ensure_identity_traits_blanks__exclusive_start_key__calls_expected(
    flagsmith_identities_table: "Table",
    mocker: "MockerFixture",
) -> None:
    """Check that the exclusive start key option reaches the table scan.

    The JSON string given as ``exclusive_start_key`` should be decoded by the
    command and passed as the ``ExclusiveStartKey`` keyword argument to the
    identity wrapper's ``scan_iter_all_items``.
    """
    # Given
    exclusive_start_key = '{"composite_key":"test_hello"}'
    expected_kwargs = {"ExclusiveStartKey": {"composite_key": "test_hello"}}

    # Replace the command module's `identity_wrapper` with a mock so the
    # calls the command makes on it can be inspected.
    identity_wrapper_mock = mocker.patch(
        "edge_api.management.commands.ensure_identity_traits_blanks.identity_wrapper"
    )

    # When
    call_command(
        "ensure_identity_traits_blanks",
        exclusive_start_key=exclusive_start_key,
    )

    # Then
    # The command iterates with `scan_iter_all_items`; asserting on any other
    # attribute name would target an auto-created mock attribute that is never
    # called.
    identity_wrapper_mock.scan_iter_all_items.assert_called_once_with(**expected_kwargs)
Loading