diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 1909ddebafd7..fccbeecf45aa 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -61,7 +61,9 @@ def validate(self, attrs): return attrs def get_project(self, validated_data: dict = None) -> Project: - return validated_data.get("project") + return validated_data.get("project") or Project.objects.get( + id=self.context["view"].kwargs["project_pk"] + ) def create(self, validated_data): project = validated_data["project"] diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index 74fa7ae22127..1b1da0035ced 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -181,6 +181,29 @@ def test_audit_log_created_when_segment_updated(project, segment, client): ) +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_can_patch_segment(project, segment, client): + # Given + segment = Segment.objects.create(name="Test segment", project=project) + url = reverse( + "api-v1:projects:project-segments-detail", + args=[project.id, segment.id], + ) + data = { + "name": "New segment name", + "rules": [{"type": "ALL", "rules": [], "conditions": []}], + } + + # When + res = client.patch(url, data=json.dumps(data), content_type="application/json") + + # Then + assert res.status_code == status.HTTP_200_OK + + @pytest.mark.parametrize( "client", [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")],