diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index 4caf9ec98b..a2a7f689a5 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -39,7 +39,7 @@ class QuotaCharge: actual_amount: Actual amount consumed. Defaults to the reserved amount. If less than reserved, the difference is refunded automatically. """ - if self._committed or not self.charge_id: + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: return try: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 2bddd3ab11..6fe61a1a52 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -59,7 +59,7 @@ class BillingService: Returns: {"reservation_id": "uuid", "available": int, "reserved": int} """ - payload = { + payload: dict = { "tenant_id": tenant_id, "feature_key": feature_key, "request_id": request_id, @@ -78,7 +78,7 @@ class BillingService: Returns: {"available": int, "reserved": int, "refunded": int} """ - payload = { + payload: dict = { "tenant_id": tenant_id, "feature_key": feature_key, "reservation_id": reservation_id, diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5b1a4790f5..eea9673710 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -39,9 +39,15 @@ class TestAppGenerateService: patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service - mock_billing_service.update_tenant_feature_plan_usage.return_value = { - "result": "success", - "history_id": "test_history_id", + mock_billing_service.quota_reserve.return_value = { + "reservation_id": "test-reservation-id", + "available": 100, + "reserved": 1, + } + mock_billing_service.quota_commit.return_value = { + "available": 99, + "reserved": 0, + "refunded": 0, } # Setup default mock returns for workflow service @@ -478,8 +484,10 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - # Verify billing service was called to consume quota - mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() + # Verify billing two-phase quota (reserve + commit) + billing = mock_external_service_dependencies["billing_service"] + billing.quota_reserve.assert_called_once() + billing.quota_commit.assert_called_once() def test_generate_with_invalid_app_mode( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index c2b430c551..68ee6ae9d6 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -447,8 +447,8 @@ class TestGenerateBilling: def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() - consume_mock = mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + reserve_mock = mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.reserve", return_value=quota_charge, ) mocker.patch( @@ -467,7 +467,8 @@ class TestGenerateBilling: invoke_from=InvokeFrom.SERVICE_API, streaming=False, ) - consume_mock.assert_called_once_with("tenant-id") + reserve_mock.assert_called_once_with("tenant-id") + quota_charge.commit.assert_called_once() def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): from services.errors.app import QuotaExceededError @@ -475,7 +476,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaType.WORKFLOW.reserve", side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), ) @@ -492,7 +493,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaType.WORKFLOW.reserve", return_value=quota_charge, ) mocker.patch( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index 639e091041..ab83c8020f 100644 --- a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -146,6 +146,9 @@ class TestAsyncWorkflowService: mocks["team_task"].delay.return_value = task_result mocks["sandbox_task"].delay.return_value = task_result + quota_charge_mock = MagicMock() + mocks["quota_workflow"].reserve.return_value = quota_charge_mock + class DummyAccount: def __init__(self, user_id: str): self.id = user_id @@ -163,7 +166,8 @@ class TestAsyncWorkflowService: assert result.status == "queued" assert result.queue == queue_name - mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") + mocks["quota_workflow"].reserve.assert_called_once_with("tenant-123") + quota_charge_mock.commit.assert_called_once() assert session.commit.call_count == 2 created_log = mocks["repo"].create.call_args[0][0] @@ -250,7 +254,7 @@ class TestAsyncWorkflowService: mocks = async_workflow_trigger_mocks mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM mocks["get_workflow"].return_value = workflow - mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + mocks["quota_workflow"].reserve.side_effect = QuotaExceededError( feature="workflow", tenant_id="tenant-123", required=1, diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index eecb3c7672..135c2e9962 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -426,7 +426,7 @@ class TestBillingServiceUsageCalculation: # Assert assert result == expected_response - mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id}) + mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id}) def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request): """Test updating tenant feature usage with positive delta (adding credits)."""