diff --git a/api/organisations/exceptions.py b/api/organisations/exceptions.py index f447e2d739a0..3c401516d16c 100644 --- a/api/organisations/exceptions.py +++ b/api/organisations/exceptions.py @@ -1,11 +1,6 @@ from rest_framework.exceptions import APIException -class OrganisationHasNoSubscription(APIException): +class OrganisationHasNoPaidSubscription(APIException): status_code = 400 default_detail = "Organisation has no subscription" - - -class SubscriptionNotFound(APIException): - status_code = 404 - default_detail = "Subscription Not found" diff --git a/api/organisations/migrations/0048_add_default_subscription_to_orphaned_organisations.py b/api/organisations/migrations/0048_add_default_subscription_to_orphaned_organisations.py new file mode 100644 index 000000000000..de2471360221 --- /dev/null +++ b/api/organisations/migrations/0048_add_default_subscription_to_orphaned_organisations.py @@ -0,0 +1,32 @@ +# Generated by Django 3.2.23 on 2023-11-16 16:01 + +from django.db import migrations + + +def create_default_subscription(apps, schema_editor): + Organisation = apps.get_model("organisations", "Organisation") + Subscription = apps.get_model("organisations", "Subscription") + + organisations_without_subscription = Organisation.objects.filter( + subscription__isnull=True + ) + + subscriptions_to_create = [] + for organisation in organisations_without_subscription: + subscriptions_to_create.append(Subscription(organisation=organisation)) + + Subscription.objects.bulk_create(subscriptions_to_create) + + +class Migration(migrations.Migration): + + dependencies = [ + ('organisations', '0047_organisation_force_2fa'), + ] + + operations = [ + migrations.RunPython( + create_default_subscription, + reverse_code=migrations.RunPython.noop, + ) + ] diff --git a/api/organisations/models.py b/api/organisations/models.py index aaec83e2eef0..349dbdee3116 100644 --- a/api/organisations/models.py +++ b/api/organisations/models.py @@ -94,7 +94,9 @@ def get_unique_slug(self): def num_seats(self): return self.users.count() - def has_subscription(self) -> bool: + def has_paid_subscription(self) -> bool: + # Includes subscriptions that are canceled. + # See is_paid for active paid subscriptions only. return hasattr(self, "subscription") and bool(self.subscription.subscription_id) def has_subscription_information_cache(self) -> bool: @@ -104,10 +106,12 @@ def has_subscription_information_cache(self) -> bool: @property def is_paid(self): - return self.has_subscription() and self.subscription.cancellation_date is None + return ( + self.has_paid_subscription() and self.subscription.cancellation_date is None + ) def over_plan_seats_limit(self, additional_seats: int = 0): - if self.has_subscription(): + if self.has_paid_subscription(): susbcription_metadata = self.subscription.get_subscription_metadata() return self.num_seats + additional_seats > susbcription_metadata.seats @@ -127,7 +131,7 @@ def is_auto_seat_upgrade_available(self) -> bool: @hook(BEFORE_DELETE) def cancel_subscription(self): - if self.has_subscription(): + if self.has_paid_subscription(): self.subscription.cancel() @hook(AFTER_CREATE) @@ -186,6 +190,8 @@ class Meta: class Subscription(LifecycleModelMixin, SoftDeleteExportableModel): + # Even though it is not enforced at the database level, + # every organisation has a subscription. organisation = models.OneToOneField( Organisation, on_delete=models.CASCADE, related_name="subscription" ) diff --git a/api/organisations/tests/test_models.py b/api/organisations/tests/test_models.py index c525c29c4d7b..5f938fb6092f 100644 --- a/api/organisations/tests/test_models.py +++ b/api/organisations/tests/test_models.py @@ -40,7 +40,7 @@ def test_can_create_organisation_with_and_without_webhook_notification_email(sel self.assertTrue(organisation_1.name) self.assertTrue(organisation_2.name) - def test_has_subscription_true(self): + def test_has_paid_subscription_true(self): # Given organisation = Organisation.objects.create(name="Test org") Subscription.objects.filter(organisation=organisation).update( @@ -51,14 +51,14 @@ def test_has_subscription_true(self): organisation.refresh_from_db() # Then - assert organisation.has_subscription() + assert organisation.has_paid_subscription() - def test_has_subscription_missing_subscription_id(self): + def test_has_paid_subscription_missing_subscription_id(self): # Given organisation = Organisation.objects.create(name="Test org") # Then - assert not organisation.has_subscription() + assert not organisation.has_paid_subscription() @mock.patch("organisations.models.cancel_chargebee_subscription") def test_cancel_subscription_cancels_chargebee_subscription( diff --git a/api/organisations/views.py b/api/organisations/views.py index 517072fda6d9..40842fde7726 100644 --- a/api/organisations/views.py +++ b/api/organisations/views.py @@ -19,10 +19,7 @@ from rest_framework.response import Response from rest_framework.throttling import ScopedRateThrottle -from organisations.exceptions import ( - OrganisationHasNoSubscription, - SubscriptionNotFound, -) +from organisations.exceptions import OrganisationHasNoPaidSubscription from organisations.models import ( Organisation, OrganisationRole, @@ -183,19 +180,15 @@ def update_subscription(self, request, pk): ) def get_subscription_metadata(self, request, pk): organisation = self.get_object() - if not organisation.has_subscription(): - raise SubscriptionNotFound() - subscription_details = organisation.subscription.get_subscription_metadata() serializer = self.get_serializer(instance=subscription_details) - return Response(serializer.data) @action(detail=True, methods=["GET"], url_path="portal-url") def get_portal_url(self, request, pk): organisation = self.get_object() - if not organisation.has_subscription(): - raise OrganisationHasNoSubscription() + if not organisation.has_paid_subscription(): + raise OrganisationHasNoPaidSubscription() redirect_url = get_current_site(request) serializer = self.get_serializer( data={"url": organisation.subscription.get_portal_url(redirect_url)} @@ -210,8 +203,8 @@ def get_portal_url(self, request, pk): ) def get_hosted_page_url_for_subscription_upgrade(self, request, pk): organisation = self.get_object() - if not organisation.has_subscription(): - raise OrganisationHasNoSubscription() + if not organisation.has_paid_subscription(): + raise OrganisationHasNoPaidSubscription() serializer = self.get_serializer( data={ "subscription_id": organisation.subscription.subscription_id, diff --git a/api/sales_dashboard/templates/sales_dashboard/home.html b/api/sales_dashboard/templates/sales_dashboard/home.html index 24a37f5a991d..35c44627b3ce 100644 --- a/api/sales_dashboard/templates/sales_dashboard/home.html +++ b/api/sales_dashboard/templates/sales_dashboard/home.html @@ -75,7 +75,7 @@