diff --git a/apps/codecov-api/billing/constants.py b/apps/codecov-api/billing/constants.py index bc2f5b911..7b8d4bf84 100644 --- a/apps/codecov-api/billing/constants.py +++ b/apps/codecov-api/billing/constants.py @@ -26,3 +26,7 @@ class StripeWebhookEvents: REMOVED_INVOICE_STATUSES = ["draft", "void"] + +# Task signatures for subscription schedules (shared by services.billing and management commands) +CANCELLATION_TASK_SIGNATURE = "cancel_task" +WEBHOOK_CANCELLATION_TASK_SIGNATURE = "webhook_handler_task" diff --git a/apps/codecov-api/billing/management/commands/apply_subscription_schedules.py b/apps/codecov-api/billing/management/commands/apply_subscription_schedules.py new file mode 100644 index 000000000..3eb0bd16e --- /dev/null +++ b/apps/codecov-api/billing/management/commands/apply_subscription_schedules.py @@ -0,0 +1,521 @@ +import argparse +import logging +from datetime import UTC, datetime + +import stripe +from django.conf import settings +from django.core.management.base import BaseCommand + +from billing.constants import ( + CANCELLATION_TASK_SIGNATURE, + WEBHOOK_CANCELLATION_TASK_SIGNATURE, +) +from codecov_auth.models import Owner, Plan + +log = logging.getLogger(__name__) + +if settings.STRIPE_API_KEY: + stripe.api_key = settings.STRIPE_API_KEY + stripe.api_version = "2024-12-18.acacia" + + +def valid_date(date_string: str) -> datetime: + """Validate and parse ISO format date string.""" + try: + return datetime.fromisoformat(date_string).replace(tzinfo=UTC) + except ValueError: + raise argparse.ArgumentTypeError( + f"Invalid date format: '{date_string}'. Use YYYY-MM-DD" + ) + + +def _create_end_date_schedule( + owner, + end_date, + phase1_plan_id, + phase1_quantity, + phase2_plan_id=None, + phase2_quantity=None, + subscription=None, + task_signature=CANCELLATION_TASK_SIGNATURE, +): + """ + Create a subscription schedule that ends at end_date with end_behavior="cancel". + + Phase 1 runs from current period start to current period end and phase 2 runs + from current period end to end_date. If phase2_plan_id/phase2_quantity are + omitted, phase 2 uses the same plan/quantity as phase 1. If provided, phase 2 + can use a different plan (e.g. command scheduling a transition to target plan + then cancel at end_date). + + If subscription is not provided, it is retrieved from Stripe. + """ + if subscription is None: + subscription = stripe.Subscription.retrieve(owner.stripe_subscription_id) + current_period_start = subscription["current_period_start"] + current_period_end = subscription["current_period_end"] + + if phase2_plan_id is None: + phase2_plan_id = phase1_plan_id + if phase2_quantity is None: + phase2_quantity = phase1_quantity + + end_date_ts = int(end_date.timestamp()) + metadata = { + "task_signature": task_signature, + "end_date": end_date.strftime("%Y-%m-%d"), + "script_version": "1.0", + } + + new_schedule = stripe.SubscriptionSchedule.create( + from_subscription=owner.stripe_subscription_id, + metadata=metadata, + ) + stripe.SubscriptionSchedule.modify( + new_schedule.id, + end_behavior="cancel", + phases=[ + { + "start_date": current_period_start, + "end_date": current_period_end, + "items": [ + { + "plan": phase1_plan_id, + "price": phase1_plan_id, + "quantity": phase1_quantity, + } + ], + "proration_behavior": "none", + }, + { + "start_date": current_period_end, + "end_date": end_date_ts, + "items": [ + { + "plan": phase2_plan_id, + "price": phase2_plan_id, + "quantity": phase2_quantity, + } + ], + "proration_behavior": "none", + }, + ], + ) + return new_schedule + + +class Command(BaseCommand): + help = "Apply subscription schedules to subscriptions in bulk. This command is idempotent - it will skip subscriptions that already have a schedule." + + def add_arguments(self, parser): + parser.add_argument( + "--end-date", + type=valid_date, + required=True, + help="The date to end the subscription schedule at on UTC time(format: YYYY-MM-DD), not-inclusive", + ) + parser.add_argument( + "--target-plan", + type=str, + default=None, + help="The plan name to schedule the subscription to transition to (e.g., 'users-pr-inappy'). If omitted, uses the owner's current plan.", + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Preview changes without applying them to Stripe", + ) + + # Filtering options + parser.add_argument( + "--renew-date-gte", + type=valid_date, + help="Only process owners with a renew date greater than or equal to this date (format: YYYY-MM-DD)", + ) + parser.add_argument( + "--renew-date-lte", + type=valid_date, + help="Only process owners with a renew date less than or equal to this date (format: YYYY-MM-DD)", + ) + parser.add_argument( + "--owner-ids", + type=str, + help="Comma-separated list of owner IDs to process", + ) + parser.add_argument( + "--current-plan", + type=str, + help="Only process owners currently on this plan (e.g., 'users-pr-inappy')", + ) + parser.add_argument( + "--exclude-owner-ids", + type=str, + help="Comma-separated list of owner IDs to exclude from processing", + ) + + # Pagination (cursor-based for stable batches across runs even if owners are added/removed) + parser.add_argument( + "--limit", + type=int, + default=100, + help="Maximum number of subscriptions to process (default: 100)", + ) + parser.add_argument( + "--after-ownerid", + type=int, + default=None, + metavar="OWNERID", + help="Process only owners with ownerid greater than this (for next batch, use last processed ownerid from previous run)", + ) + + def handle(self, *args, **options): + dry_run = options["dry_run"] + target_plan_name = options["target_plan"] + end_date = options["end_date"] + limit = options["limit"] + after_ownerid = options["after_ownerid"] + + if not settings.STRIPE_API_KEY: + self.stdout.write( + self.style.ERROR("STRIPE_API_KEY is not configured. Cannot proceed.") + ) + return + target_plan = None + if target_plan_name: + try: + target_plan = Plan.objects.get(name=target_plan_name) + if not target_plan.stripe_id: + self.stdout.write( + self.style.ERROR( + f"Plan '{target_plan_name}' does not have a Stripe price ID configured." + ) + ) + return + except Plan.DoesNotExist: + self.stdout.write( + self.style.ERROR(f"Plan '{target_plan_name}' does not exist.") + ) + return + + if dry_run: + self.stdout.write( + self.style.WARNING("DRY RUN MODE - No changes will be made to Stripe") + ) + + # Build queryset of owners to process + owners = Owner.objects.filter( + stripe_subscription_id__isnull=False, + stripe_customer_id__isnull=False, + ) + + # Then filter by renewal date, owner IDs, exclude owner IDs, and current plan if needed + if options["renew_date_gte"]: + owners = owners.filter(renew_date__gte=options["renew_date_gte"]) + + if options["renew_date_lte"]: + owners = owners.filter(renew_date__lte=options["renew_date_lte"]) + + if options["owner_ids"]: + raw = options["owner_ids"] + owner_ids = [int(x.strip()) for x in raw.split(",") if x.strip()] + if not owner_ids: + raise ValueError( + f"--owner-ids must be comma-separated integers, got: {raw!r}" + ) + owners = owners.filter(ownerid__in=owner_ids) + + if options["exclude_owner_ids"]: + exclude_ids = [ + int(x.strip()) for x in options["exclude_owner_ids"].split(",") + ] + owners = owners.exclude(ownerid__in=exclude_ids) + + if options["current_plan"]: + owners = owners.filter(plan=options["current_plan"]) + + owners = owners.order_by("ownerid") + if after_ownerid is not None: + owners = owners.filter(ownerid__gt=after_ownerid) + + total_matching = owners.count() + if total_matching == 0: + self.stdout.write(self.style.WARNING("No matching owners found")) + return + owners = list(owners[:limit]) + + # Output summary + self.stdout.write(f"Total matching owners: {total_matching}") + self.stdout.write( + f"Processing {len(owners)} owners (after_ownerid={after_ownerid}, limit={limit})" + ) + if target_plan: + self.stdout.write( + f"Target plan: {target_plan_name} (Stripe ID: {target_plan.stripe_id})" + ) + else: + self.stdout.write("Target plan: (use current plan per owner)") + if options["current_plan"]: + self.stdout.write(f"Filtering by current plan: {options['current_plan']}") + self.stdout.write("-" * 60) + + stats = { + "processed": 0, + "skipped_same_plan": 0, + "skipped_has_schedule_with_target_date": 0, + "skipped_no_plan": 0, + "scheduled": 0, + "errors": 0, + "errored_owners": [], + } + + for owner in owners: + try: + result = self.process_owner( + owner=owner, + target_plan=target_plan, + end_date=end_date, + dry_run=dry_run, + ) + stats[result] += 1 + stats["processed"] += 1 + except stripe.StripeError as e: + stats["errors"] += 1 + stats["errored_owners"].append(owner.ownerid) + self.stdout.write( + self.style.ERROR( + f" Stripe error for owner {owner.ownerid}: {e.user_message}" + ) + ) + log.warning( + f"Stripe error processing owner {owner.ownerid}", + extra={"error": str(e), "ownerid": owner.ownerid}, + ) + except Exception as e: + stats["errors"] += 1 + stats["errored_owners"].append(owner.ownerid) + self.stdout.write( + self.style.ERROR(f" Error processing owner {owner.ownerid}: {e}") + ) + log.exception( + f"Unexpected error processing owner {owner.ownerid}", + extra={"ownerid": owner.ownerid}, + ) + + self.stdout.write("-" * 60) + self.stdout.write(self.style.SUCCESS("Job completed!")) + self.stdout.write(f" Processed: {stats['processed']}") + self.stdout.write(f" Scheduled: {stats['scheduled']}") + self.stdout.write( + f" Skipped (already on target plan): {stats['skipped_same_plan']}" + ) + self.stdout.write( + f" Skipped (already has schedule with target date): {stats['skipped_has_schedule_with_target_date']}" + ) + self.stdout.write( + f" Skipped (no Plan for Stripe price): {stats['skipped_no_plan']}" + ) + self.stdout.write(f" Error owners count: {stats['errors']}") + self.stdout.write( + f" Errored owners: {', '.join(str(o) for o in stats['errored_owners'])}" + ) + if owners and total_matching > len(owners): + last_ownerid = owners[-1].ownerid + self.stdout.write( + self.style.SUCCESS(f" Next batch: --after-ownerid={last_ownerid}") + ) + + def process_owner( + self, + owner: Owner, + target_plan: Plan | None, + end_date: datetime, + dry_run: bool, + ) -> str: + """ + Process a single owner's subscription schedule. + + Returns a string indicating the result: + - 'scheduled': A new schedule was created + - 'skipped_has_schedule_with_target_date': Skipped because schedule already exists with target date + - 'skipped_same_plan': Skipped because already on plan with target end date + - 'skipped_no_plan': Skipped because no Plan record for subscription's Stripe price (when --target-plan omitted) + """ + self.stdout.write(f"Processing owner {owner.ownerid} ({owner.username})...") + + subscription = stripe.Subscription.retrieve(owner.stripe_subscription_id) + + current_plan_id = subscription["items"]["data"][0]["plan"]["id"] + current_quantity = subscription["items"]["data"][0]["quantity"] + current_period_end = subscription["current_period_end"] + # Stripe returns Unix timestamp; normalize to date string for comparison with end_date + current_end_date_str = datetime.fromtimestamp( + current_period_end, tz=UTC + ).strftime("%Y-%m-%d") + + self.stdout.write( + f" Current plan: {current_plan_id}, quantity: {current_quantity}" + ) + + # If no target plan passed, use current plan (look up Plan by Stripe price ID) + if target_plan is None: + target_plan = Plan.objects.filter(stripe_id=current_plan_id).first() + if not target_plan: + self.stdout.write( + self.style.WARNING( + f" SKIPPED: No Plan found for Stripe price {current_plan_id}" + ) + ) + return "skipped_no_plan" + + # Skip if already has target end date and is on the same plan + if ( + current_plan_id == target_plan.stripe_id + and current_end_date_str == end_date.strftime("%Y-%m-%d") + ): + self.stdout.write( + self.style.WARNING( + " SKIPPED: Already on target plan with same end date" + ) + ) + return "skipped_same_plan" + + # Stripe phase start_date/end_date expect Unix timestamps + end_date_ts = int(end_date.timestamp()) + + # Subscription already has a schedule, either for End or downgrade + if subscription.schedule: + existing_schedule = stripe.SubscriptionSchedule.retrieve( + subscription.schedule + ) + + # End schedule has already been added by this script or webhook handler + if existing_schedule.metadata.get("task_signature") in [ + CANCELLATION_TASK_SIGNATURE, + WEBHOOK_CANCELLATION_TASK_SIGNATURE, + ]: + if existing_schedule.metadata.get("end_date") == end_date.strftime( + "%Y-%m-%d" + ): + self.stdout.write( + self.style.WARNING( + f" SKIPPED: Subscription already has schedule {subscription.schedule} with target date" + ) + ) + return "skipped_has_schedule_with_target_date" + else: + updated_phases = existing_schedule.phases.copy() + new_phase = { + "start_date": existing_schedule.phases[-1]["start_date"], + "end_date": end_date_ts, + "items": [ + { + "plan": target_plan.stripe_id, + "price": target_plan.stripe_id, + "quantity": current_quantity, + } + ], + "proration_behavior": "none", + } + updated_phases = updated_phases[:-1] + [new_phase] + + if dry_run: + self.stdout.write( + self.style.SUCCESS( + f" [DRY RUN] Would update schedule for id: {existing_schedule.id} to end on {end_date.strftime('%Y-%m-%d')} with {target_plan.name} and {current_quantity} seats" + ) + ) + else: + stripe.SubscriptionSchedule.modify( + existing_schedule.id, + phases=updated_phases, + metadata={ + "task_signature": CANCELLATION_TASK_SIGNATURE, + "end_date": end_date.strftime("%Y-%m-%d"), + "script_version": "1.0", + }, + ) + self.stdout.write( + self.style.SUCCESS( + f" Updated schedule for id: {existing_schedule.id} to end on {end_date.strftime('%Y-%m-%d')} with {target_plan.name} and {current_quantity} seats" + ) + ) + return "scheduled" + # Add an extra End phase to the existing schedule + else: + if dry_run: + self.stdout.write( + self.style.SUCCESS( + f" [DRY RUN] Would add schedule phase to id: {existing_schedule.id} to end on {end_date.strftime('%Y-%m-%d')} with {target_plan.name} and {current_quantity} seats" + ) + ) + else: + new_phase = { + "start_date": existing_schedule.phases[-1]["end_date"], + "end_date": end_date_ts, + "items": [ + { + "plan": target_plan.stripe_id, + "price": target_plan.stripe_id, + "quantity": current_quantity, + } + ], + "proration_behavior": "none", + } + stripe.SubscriptionSchedule.modify( + existing_schedule.id, + end_behavior="cancel", + phases=existing_schedule.phases + [new_phase], + metadata={ + "task_signature": CANCELLATION_TASK_SIGNATURE, + "end_date": end_date.strftime("%Y-%m-%d"), + "script_version": "1.0", + }, + ) + self.stdout.write( + self.style.SUCCESS( + f" Added schedule phase to id:{existing_schedule.id} to end on {end_date.strftime('%Y-%m-%d')} with {target_plan.name} and {current_quantity} seats" + ) + ) + return "scheduled" + + if dry_run: + self.stdout.write( + self.style.SUCCESS( + f" [DRY RUN] Would create schedule for id: {owner.stripe_subscription_id} to end on {end_date.strftime('%Y-%m-%d')} with {target_plan.name} and {current_quantity} seats" + ) + ) + return "scheduled" + + # Create the subscription schedule + schedule = _create_end_date_schedule( + owner=owner, + end_date=end_date, + phase1_plan_id=current_plan_id, + phase1_quantity=current_quantity, + phase2_plan_id=target_plan.stripe_id, + phase2_quantity=current_quantity, + subscription=subscription, + task_signature=CANCELLATION_TASK_SIGNATURE, + ) + + self.stdout.write( + self.style.SUCCESS( + f" Created schedule {schedule.id} to end at {end_date.strftime('%Y-%m-%d')} with {target_plan.name} and {current_quantity} seats" + ) + ) + + log.info( + "Created subscription schedule", + extra={ + "ownerid": owner.ownerid, + "subscription_id": owner.stripe_subscription_id, + "schedule_id": schedule.id, + "target_plan": target_plan.name, + "target_quantity": current_quantity, + "current_plan_id": current_plan_id, + "current_quantity": current_quantity, + }, + ) + + return "scheduled" diff --git a/apps/codecov-api/billing/tests/test_apply_subscription_schedules.py b/apps/codecov-api/billing/tests/test_apply_subscription_schedules.py new file mode 100644 index 000000000..a09c1e4d1 --- /dev/null +++ b/apps/codecov-api/billing/tests/test_apply_subscription_schedules.py @@ -0,0 +1,406 @@ +from datetime import UTC, datetime +from io import StringIO +from unittest.mock import MagicMock, patch + +import pytest +from django.core.management import call_command + +from shared.django_apps.codecov_auth.tests.factories import OwnerFactory, PlanFactory + + +@pytest.fixture +def target_plan(db): + """Plan used as --target-plan.""" + return PlanFactory( + name="users-pr-inappy", + stripe_id="price_target_123", + ) + + +@pytest.fixture +def current_plan_in_stripe(db): + """Plan with stripe_id matching mock subscription default (price_current_456). Used when testing --target-plan omitted.""" + return PlanFactory( + name="users-current", + stripe_id="price_current_456", + ) + + +@pytest.fixture +def owner_with_subscription(db): + """Owner with Stripe subscription (no schedule).""" + return OwnerFactory( + stripe_subscription_id="sub_123", + stripe_customer_id="cus_123", + ) + + +class _MockSubscription: + """Mock Stripe subscription that supports both dict and attribute access (command uses subscription['items'] and subscription.schedule).""" + + def __init__(self, data): + self._data = data + self.schedule = data.get("schedule") + + def __getitem__(self, key): + return self._data[key] + + +def _make_mock_subscription( + *, + plan_id="price_current_456", + quantity=5, + current_period_start=1000000, + current_period_end=2000000, + schedule=None, +): + """Build a mock Stripe subscription (supports subscription['key'] and subscription.schedule).""" + data = { + "id": "sub_123", + "items": { + "data": [ + { + "id": "si_123", + "plan": {"id": plan_id}, + "quantity": quantity, + } + ] + }, + "current_period_start": current_period_start, + "current_period_end": current_period_end, + "schedule": schedule, + } + return _MockSubscription(data) + + +def _make_mock_schedule(phases=None, metadata=None): + """Build a mock Stripe SubscriptionSchedule (object with id, phases, metadata).""" + s = MagicMock() + s.id = "sub_sched_123" + s.phases = phases or [ + {"start_date": 1000000, "end_date": 2000000}, + ] + s.metadata = metadata or {} + return s + + +@pytest.mark.django_db +@patch("stripe.SubscriptionSchedule.modify") +@patch("stripe.SubscriptionSchedule.create") +@patch("stripe.Subscription.retrieve") +def test_apply_subscription_schedules_creates_schedule( + mock_sub_retrieve, + mock_schedule_create, + mock_schedule_modify, + owner_with_subscription, + target_plan, +): + """Creating a new schedule calls create then modify with two phases.""" + mock_sub_retrieve.return_value = _make_mock_subscription( + plan_id="price_current_456", + quantity=5, + current_period_start=1000000, + current_period_end=2000000, + schedule=None, + ) + mock_schedule_create.return_value = MagicMock(id="sub_sched_new") + + out = StringIO() + err = StringIO() + end_date = "2025-12-31" + + call_command( + "apply_subscription_schedules", + "--target-plan=users-pr-inappy", + f"--end-date={end_date}", + "--owner-ids", + str(owner_with_subscription.ownerid), + "--limit=10", + stdout=out, + stderr=err, + ) + + mock_sub_retrieve.assert_called_once_with( + owner_with_subscription.stripe_subscription_id + ) + mock_schedule_create.assert_called_once() + assert ( + mock_schedule_create.call_args.kwargs["from_subscription"] + == owner_with_subscription.stripe_subscription_id + ) + assert ( + mock_schedule_create.call_args.kwargs["metadata"]["task_signature"] + == "cancel_task" + ) + assert mock_schedule_create.call_args.kwargs["metadata"]["end_date"] == end_date + + mock_schedule_modify.assert_called_once() + modify_kw = mock_schedule_modify.call_args.kwargs + assert modify_kw["end_behavior"] == "cancel" + assert len(modify_kw["phases"]) == 2 + assert modify_kw["phases"][0]["items"][0]["plan"] == "price_current_456" + assert modify_kw["phases"][1]["items"][0]["plan"] == target_plan.stripe_id + + assert "Created schedule" in out.getvalue() + assert "scheduled" in out.getvalue().lower() or "Scheduled" in out.getvalue() + + +@pytest.mark.django_db +@pytest.mark.usefixtures("target_plan") +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.Subscription.retrieve" +) +def test_apply_subscription_schedules_dry_run_no_stripe_writes( + mock_sub_retrieve, + owner_with_subscription, +): + """Dry run only retrieves subscription; no create/modify.""" + mock_sub_retrieve.return_value = _make_mock_subscription(schedule=None) + + out = StringIO() + call_command( + "apply_subscription_schedules", + "--target-plan=users-pr-inappy", + "--end-date=2025-12-31", + "--owner-ids", + str(owner_with_subscription.ownerid), + "--dry-run", + stdout=out, + stderr=StringIO(), + ) + + mock_sub_retrieve.assert_called_once() + assert "DRY RUN" in out.getvalue() + assert "Would" in out.getvalue() or "Created" not in out.getvalue() + + +@pytest.mark.django_db +@pytest.mark.usefixtures("target_plan") +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.SubscriptionSchedule.retrieve" +) +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.Subscription.retrieve" +) +def test_apply_subscription_schedules_skips_when_has_schedule_with_same_end_date( + mock_sub_retrieve, + mock_schedule_retrieve, + owner_with_subscription, +): + """When schedule already exists with same end date, command skips.""" + mock_sub_retrieve.return_value = _make_mock_subscription(schedule="sub_sched_123") + mock_schedule_retrieve.return_value = _make_mock_schedule( + phases=[{"start_date": 1000000, "end_date": 2000000}], + metadata={ + "task_signature": "cancel_task", + "end_date": "2025-12-31", + }, + ) + + out = StringIO() + call_command( + "apply_subscription_schedules", + "--target-plan=users-pr-inappy", + "--end-date=2025-12-31", + "--owner-ids", + str(owner_with_subscription.ownerid), + stdout=out, + stderr=StringIO(), + ) + + assert "SKIPPED" in out.getvalue() + assert "already has schedule" in out.getvalue().lower() + + +@pytest.mark.django_db +@pytest.mark.usefixtures("target_plan") +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.SubscriptionSchedule.modify" +) +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.SubscriptionSchedule.retrieve" +) +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.Subscription.retrieve" +) +def test_apply_subscription_schedules_updates_existing_cancellation_schedule( + mock_sub_retrieve, + mock_schedule_retrieve, + _mock_schedule_modify, # not used but needed so owner_with_subscription is the fixture in correct order + owner_with_subscription, +): + """When schedule exists with different end date, command updates last phase.""" + mock_sub_retrieve.return_value = _make_mock_subscription(schedule="sub_sched_123") + existing = _make_mock_schedule( + phases=[ + {"start_date": 1000000, "end_date": 2000000}, + {"start_date": 2000000, "end_date": 3000000}, + ], + metadata={"task_signature": "cancel_task", "end_date": "2025-06-30"}, + ) + existing.phases = [ + {"start_date": 1000000, "end_date": 2000000}, + {"start_date": 2000000, "end_date": 3000000}, + ] + mock_schedule_retrieve.return_value = existing + + out = StringIO() + call_command( + "apply_subscription_schedules", + "--target-plan=users-pr-inappy", + "--end-date=2025-12-31", + "--owner-ids", + str(owner_with_subscription.ownerid), + stdout=out, + stderr=StringIO(), + ) + + # Command calls stripe.SubscriptionSchedule.modify(schedule_id, ...) + _mock_schedule_modify.assert_called_once() + assert _mock_schedule_modify.call_args.args[0] == existing.id + phases = _mock_schedule_modify.call_args.kwargs["phases"] + assert len(phases) == 2 + # Command passes end_date as Unix timestamp to Stripe + expected_end_ts = int( + datetime.fromisoformat("2025-12-31").replace(tzinfo=UTC).timestamp() + ) + assert phases[1]["end_date"] == expected_end_ts + assert "Updated schedule" in out.getvalue() + + +@pytest.mark.django_db +@pytest.mark.usefixtures("target_plan") +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.SubscriptionSchedule.modify" +) +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.SubscriptionSchedule.retrieve" +) +@patch( + "billing.management.commands.apply_subscription_schedules.stripe.Subscription.retrieve" +) +def test_apply_subscription_schedules_adds_end_phase_to_existing_schedule( + mock_sub_retrieve, + mock_schedule_retrieve, + _mock_schedule_modify, + owner_with_subscription, + target_plan, +): + """When schedule exists without our task signature, command appends end phase.""" + mock_sub_retrieve.return_value = _make_mock_subscription( + schedule="sub_sched_123", + quantity=3, + ) + existing_phases = [ + {"start_date": 1000000, "end_date": 2000000}, + ] + existing = _make_mock_schedule( + phases=existing_phases, + metadata={"task_signature": "other_scheduler"}, + ) + existing.phases = existing_phases + mock_schedule_retrieve.return_value = existing + + out = StringIO() + call_command( + "apply_subscription_schedules", + "--target-plan=users-pr-inappy", + "--end-date=2025-12-31", + owner_ids=str(owner_with_subscription.ownerid), + stdout=out, + stderr=StringIO(), + ) + + # Command calls stripe.SubscriptionSchedule.modify(schedule_id, ...) + _mock_schedule_modify.assert_called_once() + assert _mock_schedule_modify.call_args.args[0] == existing.id + phases = _mock_schedule_modify.call_args.kwargs["phases"] + # Original phase(s) + new end phase + assert len(phases) == 2 + assert phases[0]["start_date"] == 1000000 + assert phases[0]["end_date"] == 2000000 + # Command passes end_date as Unix timestamp to Stripe + expected_end_ts = int( + datetime.fromisoformat("2025-12-31").replace(tzinfo=UTC).timestamp() + ) + assert phases[1]["end_date"] == expected_end_ts + assert phases[1]["items"][0]["plan"] == target_plan.stripe_id + assert phases[1]["items"][0]["quantity"] == 3 + assert ( + _mock_schedule_modify.call_args.kwargs["metadata"]["task_signature"] + == "cancel_task" + ) + assert "Added schedule phase" in out.getvalue() + + +@pytest.mark.django_db +@patch("stripe.SubscriptionSchedule.modify") +@patch("stripe.SubscriptionSchedule.create") +@patch("stripe.Subscription.retrieve") +def test_apply_subscription_schedules_uses_current_plan_when_target_plan_omitted( + mock_sub_retrieve, + mock_schedule_create, + mock_schedule_modify, + owner_with_subscription, + current_plan_in_stripe, +): + """When --target-plan is omitted, command uses subscription's current plan (looked up by Stripe price ID).""" + mock_sub_retrieve.return_value = _make_mock_subscription( + plan_id="price_current_456", + quantity=5, + current_period_start=1000000, + current_period_end=2000000, + schedule=None, + ) + mock_schedule_create.return_value = MagicMock(id="sub_sched_new") + + out = StringIO() + call_command( + "apply_subscription_schedules", + "--end-date=2025-12-31", + "--owner-ids", + str(owner_with_subscription.ownerid), + "--limit=10", + stdout=out, + stderr=StringIO(), + ) + + mock_schedule_create.assert_called_once() + modify_kw = mock_schedule_modify.call_args.kwargs + assert ( + modify_kw["phases"][1]["items"][0]["plan"] == current_plan_in_stripe.stripe_id + ) + assert "Created schedule" in out.getvalue() + assert ( + "use current plan" in out.getvalue().lower() + or "Target plan: (use current" in out.getvalue() + ) + + +@pytest.mark.django_db +@patch("stripe.Subscription.retrieve") +def test_apply_subscription_schedules_skips_when_no_plan_for_stripe_price( + mock_sub_retrieve, + owner_with_subscription, +): + """When --target-plan is omitted and no Plan exists for subscription's Stripe price, command skips with skipped_no_plan.""" + # Mock subscription has a price ID that does not match any Plan in DB + mock_sub_retrieve.return_value = _make_mock_subscription( + plan_id="price_unknown_nonexistent", + schedule=None, + ) + + out = StringIO() + call_command( + "apply_subscription_schedules", + "--end-date=2025-12-31", + "--owner-ids", + str(owner_with_subscription.ownerid), + stdout=out, + stderr=StringIO(), + ) + + assert "SKIPPED" in out.getvalue() + assert "No Plan found" in out.getvalue() + assert "price_unknown_nonexistent" in out.getvalue() + assert "Skipped (no Plan for Stripe price)" in out.getvalue() diff --git a/apps/codecov-api/services/billing.py b/apps/codecov-api/services/billing.py index 662eb3d03..92dcb8780 100644 --- a/apps/codecov-api/services/billing.py +++ b/apps/codecov-api/services/billing.py @@ -7,7 +7,14 @@ from dateutil.relativedelta import relativedelta from django.conf import settings -from billing.constants import REMOVED_INVOICE_STATUSES +from billing.constants import ( + CANCELLATION_TASK_SIGNATURE, + REMOVED_INVOICE_STATUSES, + WEBHOOK_CANCELLATION_TASK_SIGNATURE, +) +from billing.management.commands.apply_subscription_schedules import ( + _create_end_date_schedule, +) from codecov_auth.models import Owner, Plan from shared.plan.constants import PlanBillingRate, TierName from shared.plan.service import PlanService @@ -346,13 +353,35 @@ def modify_subscription(self, owner: Owner, desired_plan: dict): # Divide logic bw immediate updates and scheduled updates # Immediate updates: when user upgrades seats or plan - # If the user is not in a schedule, update immediately - # If the user is in a schedule, update the existing schedule + # Update immediately + # If the user is in a schedule, release the existing schedule but recreate a new schedule + # for scheduled end date if a phase existed for it in existing schedule # Scheduled updates: when the user decreases seats or plan # If the user is not in a schedule, create a schedule # If the user is in a schedule, update the existing schedule if is_upgrading: + previous_scheduled_end_date = None if subscription_schedule_id: + existing_schedule = stripe.SubscriptionSchedule.retrieve( + subscription_schedule_id + ) + if existing_schedule.metadata.get("task_signature") in [ + CANCELLATION_TASK_SIGNATURE, + WEBHOOK_CANCELLATION_TASK_SIGNATURE, + ]: + if existing_schedule.phases: + last_phase = existing_schedule.phases[-1] + end_ts = getattr(last_phase, "end_date", None) + if end_ts is not None: + previous_scheduled_end_date = datetime.fromtimestamp( + end_ts, tz=UTC + ) + elif metadata_end_date := existing_schedule.metadata.get( + "end_date" + ): + previous_scheduled_end_date = datetime.fromisoformat( + metadata_end_date + ).replace(tzinfo=UTC) log.info( f"Releasing Stripe schedule for owner {owner.ownerid} to {desired_plan['value']} with {desired_plan['quantity']} seats by user #{self.requesting_user.ownerid}" ) @@ -361,25 +390,62 @@ def modify_subscription(self, owner: Owner, desired_plan: dict): f"Updating Stripe subscription for owner {owner.ownerid} to {desired_plan['value']} by user #{self.requesting_user.ownerid}" ) - subscription = stripe.Subscription.modify( - owner.stripe_subscription_id, - cancel_at_period_end=False, - items=[ - { - "id": subscription["items"]["data"][0]["id"], - "plan": desired_plan_info.stripe_id, - "quantity": desired_plan["quantity"], - } - ], - metadata=self._get_checkout_session_and_subscription_metadata(owner), - proration_behavior=proration_behavior, - # TODO: we need to include this arg, but it means we need to remove some of the existing args - # on the .modify() call https://docs.stripe.com/billing/subscriptions/pending-updates-reference - # payment_behavior="pending_if_incomplete", - ) + try: + subscription = stripe.Subscription.modify( + owner.stripe_subscription_id, + cancel_at_period_end=False, + items=[ + { + "id": subscription["items"]["data"][0]["id"], + "plan": desired_plan_info.stripe_id, + "quantity": desired_plan["quantity"], + } + ], + metadata=self._get_checkout_session_and_subscription_metadata( + owner + ), + proration_behavior=proration_behavior, + # TODO: we need to include this arg, but it means we need to remove some of the existing args + # on the .modify() call https://docs.stripe.com/billing/subscriptions/pending-updates-reference + # payment_behavior="pending_if_incomplete", + ) + except stripe.StripeError: + # Upgrade payment failed but we already released the schedule so add back an end-date schedule so the user doesn't lose it + if previous_scheduled_end_date is not None: + subscription = stripe.Subscription.retrieve( + owner.stripe_subscription_id + ) + current_plan_id = subscription["items"]["data"][0]["plan"]["id"] + current_quantity = subscription["items"]["data"][0]["quantity"] + _create_end_date_schedule( + owner=owner, + end_date=previous_scheduled_end_date, + phase1_plan_id=current_plan_id, + phase1_quantity=current_quantity, + subscription=subscription, + task_signature=WEBHOOK_CANCELLATION_TASK_SIGNATURE, + ) + log.info( + f"Restored end-date schedule for owner {owner.ownerid} after upgrade failure" + ) + raise + log.info( f"Stripe subscription upgrade attempted for owner {owner.ownerid} by user #{self.requesting_user.ownerid}" ) + + # If the user had a previous scheduled end date, we need to create a new schedule to replace the end date schedule + # that the upgrade is releasing + if previous_scheduled_end_date is not None: + _create_end_date_schedule( + owner=owner, + end_date=previous_scheduled_end_date, + phase1_plan_id=desired_plan_info.stripe_id, + phase1_quantity=desired_plan["quantity"], + subscription=subscription, + task_signature=WEBHOOK_CANCELLATION_TASK_SIGNATURE, + ) + indication_of_payment_failure = getattr( subscription, "pending_update", None ) diff --git a/apps/codecov-api/services/tests/test_billing.py b/apps/codecov-api/services/tests/test_billing.py index 85b71842d..9a5fe5549 100644 --- a/apps/codecov-api/services/tests/test_billing.py +++ b/apps/codecov-api/services/tests/test_billing.py @@ -9,6 +9,10 @@ from stripe import InvalidRequestError from stripe.api_resources import PaymentIntent, SetupIntent +from billing.constants import ( + CANCELLATION_TASK_SIGNATURE, + WEBHOOK_CANCELLATION_TASK_SIGNATURE, +) from billing.tests.mocks import mock_all_plans_and_tiers from codecov_auth.models import Plan, Service from services.billing import AbstractPaymentService, BillingService, StripeService @@ -1180,6 +1184,7 @@ def test_modify_subscription_with_schedule_modifies_schedule_when_user_count_dec assert owner.plan == original_plan assert owner.plan_user_count == original_user_count + @patch("services.billing.stripe.SubscriptionSchedule.retrieve") @patch("services.billing.stripe.Subscription.modify") @patch("services.billing.stripe.Subscription.retrieve") @patch("services.billing.stripe.SubscriptionSchedule.release") @@ -1188,6 +1193,7 @@ def test_modify_subscription_with_schedule_modifies_schedule_when_user_count_inc schedule_release_mock, retrieve_subscription_mock, subscription_modify_mock, + schedule_retrieve_mock, ): original_user_count = 17 original_plan = PlanName.CODECOV_PRO_MONTHLY.value @@ -1217,6 +1223,12 @@ def test_modify_subscription_with_schedule_modifies_schedule_when_user_count_inc retrieve_subscription_mock.return_value = MockSubscription(subscription_params) subscription_modify_mock.return_value = MockSubscription(subscription_params) + # Existing schedule without cancellation task signature + existing_schedule = MagicMock() + existing_schedule.metadata = {} + existing_schedule.phases = [] + schedule_retrieve_mock.return_value = existing_schedule + desired_plan_name = PlanName.CODECOV_PRO_MONTHLY.value desired_user_count = 26 desired_plan = {"value": desired_plan_name, "quantity": desired_user_count} @@ -1272,6 +1284,7 @@ def test_modify_subscription_with_schedule_modifies_schedule_when_plan_downgrade assert owner.plan == original_plan assert owner.plan_user_count == original_user_count + @patch("services.billing.stripe.SubscriptionSchedule.retrieve") @patch("services.billing.stripe.Subscription.modify") @patch("services.billing.stripe.Subscription.retrieve") @patch("services.billing.stripe.SubscriptionSchedule.release") @@ -1280,6 +1293,7 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_upgrades( schedule_release_mock, retrieve_subscription_mock, subscription_modify_mock, + schedule_retrieve_mock, ): original_user_count = 15 original_plan = PlanName.CODECOV_PRO_MONTHLY.value @@ -1305,6 +1319,12 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_upgrades( retrieve_subscription_mock.return_value = MockSubscription(subscription_params) subscription_modify_mock.return_value = MockSubscription(subscription_params) + # Existing schedule without cancellation task signature (no end date recreation needed) + existing_schedule = MagicMock() + existing_schedule.metadata = {} + existing_schedule.phases = [] + schedule_retrieve_mock.return_value = existing_schedule + desired_plan_name = PlanName.CODECOV_PRO_YEARLY.value desired_user_count = 15 desired_plan = {"value": desired_plan_name, "quantity": desired_user_count} @@ -1319,6 +1339,214 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_upgrades( assert owner.plan == desired_plan_name assert owner.plan_user_count == desired_user_count + @patch("services.billing._create_end_date_schedule") + @patch("services.billing.stripe.SubscriptionSchedule.retrieve") + @patch("services.billing.stripe.Subscription.modify") + @patch("services.billing.stripe.Subscription.retrieve") + @patch("services.billing.stripe.SubscriptionSchedule.release") + def test_modify_subscription_with_schedule_recreates_schedule_with_end_date_and_cancel_when_plan_upgrades( + self, + schedule_release_mock, + retrieve_subscription_mock, + subscription_modify_mock, + schedule_retrieve_mock, + create_end_date_schedule_mock, + ): + """Upgrading with an existing cancellation schedule releases it, upgrades, then creates a new schedule with same end_date and end_behavior cancel.""" + original_user_count = 15 + original_plan = PlanName.CODECOV_PRO_MONTHLY.value + stripe_subscription_id = "33043sdf" + owner = OwnerFactory( + plan=original_plan, + plan_user_count=original_user_count, + stripe_subscription_id=stripe_subscription_id, + ) + + schedule_id = "sub_sched_1K77Y5GlVGuVgOrkJrLjRn2e" + current_subscription_start_date = 1639628096 + current_subscription_end_date = 1644107871 + subscription_params = { + "schedule_id": schedule_id, + "start_date": current_subscription_start_date, + "end_date": current_subscription_end_date, + "quantity": original_user_count, + "name": original_plan, + "id": 111, + } + + retrieve_subscription_mock.return_value = MockSubscription(subscription_params) + subscription_modify_mock.return_value = MockSubscription(subscription_params) + + # Existing schedule had our cancellation task and an end_date (recreate after upgrade) + existing_schedule = MagicMock() + existing_schedule.metadata = { + "task_signature": CANCELLATION_TASK_SIGNATURE, + "end_date": "2025-12-31T00:00:00+00:00", + } + existing_schedule.phases = [] + schedule_retrieve_mock.return_value = existing_schedule + + desired_plan_name = PlanName.CODECOV_PRO_YEARLY.value + desired_user_count = 15 + desired_plan = {"value": desired_plan_name, "quantity": desired_user_count} + + self.stripe.modify_subscription(owner, desired_plan) + + schedule_release_mock.assert_called_once_with(schedule_id) + self._assert_subscription_modify( + subscription_modify_mock, owner, subscription_params, desired_plan + ) + + # Verify _create_end_date_schedule was called with correct params + create_end_date_schedule_mock.assert_called_once() + call_kwargs = create_end_date_schedule_mock.call_args.kwargs + assert call_kwargs["owner"] == owner + assert ( + call_kwargs["phase1_plan_id"] + == Plan.objects.get(name=desired_plan_name).stripe_id + ) + assert call_kwargs["phase1_quantity"] == desired_user_count + assert call_kwargs["task_signature"] == WEBHOOK_CANCELLATION_TASK_SIGNATURE + + owner.refresh_from_db() + assert owner.plan == desired_plan_name + assert owner.plan_user_count == desired_user_count + + @patch("services.billing._create_end_date_schedule") + @patch("services.billing.stripe.SubscriptionSchedule.retrieve") + @patch("services.billing.stripe.Subscription.modify") + @patch("services.billing.stripe.Subscription.retrieve") + @patch("services.billing.stripe.SubscriptionSchedule.release") + def test_modify_subscription_with_schedule_restores_end_date_schedule_on_upgrade_failure( + self, + schedule_release_mock, + retrieve_subscription_mock, + subscription_modify_mock, + schedule_retrieve_mock, + create_end_date_schedule_mock, + ): + """When upgrade fails after releasing a cancellation schedule, the end_date schedule should be restored.""" + original_user_count = 15 + original_plan = PlanName.CODECOV_PRO_MONTHLY.value + stripe_subscription_id = "33043sdf" + owner = OwnerFactory( + plan=original_plan, + plan_user_count=original_user_count, + stripe_subscription_id=stripe_subscription_id, + ) + + schedule_id = "sub_sched_1K77Y5GlVGuVgOrkJrLjRn2e" + current_subscription_start_date = 1639628096 + current_subscription_end_date = 1644107871 + subscription_params = { + "schedule_id": schedule_id, + "start_date": current_subscription_start_date, + "end_date": current_subscription_end_date, + "quantity": original_user_count, + "name": original_plan, + "id": 111, + } + + retrieve_subscription_mock.return_value = MockSubscription(subscription_params) + + # Simulate upgrade failure with StripeError + subscription_modify_mock.side_effect = stripe.StripeError("Payment failed") + + # Existing schedule had our cancellation task and an end_date + existing_schedule = MagicMock() + existing_schedule.metadata = { + "task_signature": CANCELLATION_TASK_SIGNATURE, + "end_date": "2025-12-31T00:00:00+00:00", + } + existing_schedule.phases = [] + schedule_retrieve_mock.return_value = existing_schedule + + desired_plan_name = PlanName.CODECOV_PRO_YEARLY.value + desired_user_count = 15 + desired_plan = {"value": desired_plan_name, "quantity": desired_user_count} + + with self.assertRaises(stripe.StripeError): + self.stripe.modify_subscription(owner, desired_plan) + + # Schedule was released before the failure + schedule_release_mock.assert_called_once_with(schedule_id) + + # End date schedule should be restored after failure + create_end_date_schedule_mock.assert_called_once() + call_kwargs = create_end_date_schedule_mock.call_args.kwargs + assert call_kwargs["owner"] == owner + assert ( + call_kwargs["phase1_plan_id"] == original_plan + ) # Original plan, not desired + assert call_kwargs["phase1_quantity"] == original_user_count + assert call_kwargs["task_signature"] == WEBHOOK_CANCELLATION_TASK_SIGNATURE + + # Owner should not be updated since upgrade failed + owner.refresh_from_db() + assert owner.plan == original_plan + assert owner.plan_user_count == original_user_count + + @patch("services.billing._create_end_date_schedule") + @patch("services.billing.stripe.SubscriptionSchedule.retrieve") + @patch("services.billing.stripe.Subscription.modify") + @patch("services.billing.stripe.Subscription.retrieve") + @patch("services.billing.stripe.SubscriptionSchedule.release") + def test_modify_subscription_with_webhook_cancellation_task_signature( + self, + schedule_release_mock, + retrieve_subscription_mock, + subscription_modify_mock, + schedule_retrieve_mock, + create_end_date_schedule_mock, + ): + """Test that WEBHOOK_CANCELLATION_TASK_SIGNATURE is also recognized for schedule recreation.""" + original_user_count = 15 + original_plan = PlanName.CODECOV_PRO_MONTHLY.value + stripe_subscription_id = "33043sdf" + owner = OwnerFactory( + plan=original_plan, + plan_user_count=original_user_count, + stripe_subscription_id=stripe_subscription_id, + ) + + schedule_id = "sub_sched_1K77Y5GlVGuVgOrkJrLjRn2e" + current_subscription_start_date = 1639628096 + current_subscription_end_date = 1644107871 + subscription_params = { + "schedule_id": schedule_id, + "start_date": current_subscription_start_date, + "end_date": current_subscription_end_date, + "quantity": original_user_count, + "name": original_plan, + "id": 111, + } + + retrieve_subscription_mock.return_value = MockSubscription(subscription_params) + subscription_modify_mock.return_value = MockSubscription(subscription_params) + + # Existing schedule had WEBHOOK_CANCELLATION_TASK_SIGNATURE + existing_schedule = MagicMock() + existing_schedule.metadata = { + "task_signature": WEBHOOK_CANCELLATION_TASK_SIGNATURE, + "end_date": "2025-12-31T00:00:00+00:00", + } + existing_schedule.phases = [] + schedule_retrieve_mock.return_value = existing_schedule + + desired_plan_name = PlanName.CODECOV_PRO_YEARLY.value + desired_user_count = 15 + desired_plan = {"value": desired_plan_name, "quantity": desired_user_count} + + self.stripe.modify_subscription(owner, desired_plan) + + schedule_release_mock.assert_called_once_with(schedule_id) + create_end_date_schedule_mock.assert_called_once() + + owner.refresh_from_db() + assert owner.plan == desired_plan_name + assert owner.plan_user_count == desired_user_count + + @patch("services.billing.stripe.SubscriptionSchedule.retrieve") @patch("services.billing.stripe.Subscription.modify") @patch("services.billing.stripe.Subscription.retrieve") @patch("services.billing.stripe.SubscriptionSchedule.release") @@ -1327,6 +1555,7 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_upgrades_ schedule_release_mock, retrieve_subscription_mock, subscription_modify_mock, + schedule_retrieve_mock, ): original_user_count = 15 original_plan = PlanName.CODECOV_PRO_MONTHLY.value @@ -1352,6 +1581,12 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_upgrades_ retrieve_subscription_mock.return_value = MockSubscription(subscription_params) subscription_modify_mock.return_value = MockSubscription(subscription_params) + # Existing schedule without cancellation task signature + existing_schedule = MagicMock() + existing_schedule.metadata = {} + existing_schedule.phases = [] + schedule_retrieve_mock.return_value = existing_schedule + desired_plan_name = PlanName.CODECOV_PRO_YEARLY.value desired_user_count = 10 desired_plan = {"value": desired_plan_name, "quantity": desired_user_count} @@ -1365,6 +1600,7 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_upgrades_ assert owner.plan == desired_plan_name assert owner.plan_user_count == desired_user_count + @patch("services.billing.stripe.SubscriptionSchedule.retrieve") @patch("services.billing.stripe.Subscription.modify") @patch("services.billing.stripe.Subscription.retrieve") @patch("services.billing.stripe.SubscriptionSchedule.release") @@ -1373,6 +1609,7 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_downgrade schedule_release_mock, retrieve_subscription_mock, subscription_modify_mock, + schedule_retrieve_mock, ): original_user_count = 15 original_plan = PlanName.CODECOV_PRO_YEARLY.value @@ -1398,6 +1635,12 @@ def test_modify_subscription_with_schedule_releases_schedule_when_plan_downgrade retrieve_subscription_mock.return_value = MockSubscription(subscription_params) subscription_modify_mock.return_value = MockSubscription(subscription_params) + # Existing schedule without cancellation task signature + existing_schedule = MagicMock() + existing_schedule.metadata = {} + existing_schedule.phases = [] + schedule_retrieve_mock.return_value = existing_schedule + desired_plan_name = PlanName.CODECOV_PRO_MONTHLY.value desired_user_count = 20 desired_plan = {"value": desired_plan_name, "quantity": desired_user_count}