From a6e8e43883c5ff25c53cd6968c9e098e4f3f2a74 Mon Sep 17 00:00:00 2001 From: Saumya Talwani <68903741+saumyatalwani@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:51:56 +0530 Subject: [PATCH] test: add tests for some files in services module (#32583) --- .../services/test_api_token_service.py | 466 ++++++++++++++++ .../services/test_app_model_config_service.py | 88 +++ .../services/test_async_workflow_service.py | 507 ++++++++++++++++++ .../services/test_attachment_service.py | 73 +++ .../test_code_based_extension_service.py | 89 +++ .../test_conversation_variable_updater.py | 75 +++ .../services/test_credit_pool_service.py | 157 ++++++ 7 files changed, 1455 insertions(+) create mode 100644 api/tests/unit_tests/services/test_api_token_service.py create mode 100644 api/tests/unit_tests/services/test_app_model_config_service.py create mode 100644 api/tests/unit_tests/services/test_async_workflow_service.py create mode 100644 api/tests/unit_tests/services/test_attachment_service.py create mode 100644 api/tests/unit_tests/services/test_code_based_extension_service.py create mode 100644 api/tests/unit_tests/services/test_conversation_variable_updater.py create mode 100644 api/tests/unit_tests/services/test_credit_pool_service.py diff --git a/api/tests/unit_tests/services/test_api_token_service.py b/api/tests/unit_tests/services/test_api_token_service.py new file mode 100644 index 0000000000..ad4de93b25 --- /dev/null +++ b/api/tests/unit_tests/services/test_api_token_service.py @@ -0,0 +1,466 @@ +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Unauthorized + +import services.api_token_service as api_token_service_module +from services.api_token_service import ApiTokenCache, CachedApiToken + + +@pytest.fixture +def mock_db_session(): + """Fixture providing common DB session mocking for query_token_from_db tests.""" + fake_engine = MagicMock() + + session = MagicMock() + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + with ( + patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)), + patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class, + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + yield { + "session": session, + "mock_session_class": mock_session_class, + "mock_cache_set": mock_cache_set, + "mock_record_usage": mock_record_usage, + "fake_engine": fake_engine, + } + + +class TestQueryTokenFromDb: + def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session): + """Test DB lookup success path caches token and records usage.""" + # Arrange + auth_token = "token-123" + scope = "app" + api_token = MagicMock() + + mock_db_session["session"].scalar.return_value = api_token + + # Act + result = api_token_service_module.query_token_from_db(auth_token, scope) + + # Assert + assert result == api_token + mock_db_session["mock_session_class"].assert_called_once_with( + mock_db_session["fake_engine"], expire_on_commit=False + ) + mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token) + mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope) + + def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session): + """Test DB lookup miss path caches null marker and raises Unauthorized.""" + # Arrange + auth_token = "missing-token" + scope = "app" + + mock_db_session["session"].scalar.return_value = None + + # Act / Assert + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.query_token_from_db(auth_token, scope) + + mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None) + mock_db_session["mock_record_usage"].assert_not_called() + + +class TestRecordTokenUsage: + def test_should_write_active_key_with_iso_timestamp_and_ttl(self): + """Test record_token_usage writes usage timestamp with one-hour TTL.""" + # Arrange + auth_token = "token-123" + scope = "dataset" + fixed_time = datetime(2026, 2, 24, 12, 0, 0) + expected_key = ApiTokenCache.make_active_key(auth_token, scope) + + with ( + patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time), + patch.object(api_token_service_module, "redis_client") as mock_redis, + ): + # Act + api_token_service_module.record_token_usage(auth_token, scope) + + # Assert + mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600) + + def test_should_not_raise_when_redis_write_fails(self): + """Test record_token_usage swallows Redis errors.""" + # Arrange + with patch.object(api_token_service_module, "redis_client") as mock_redis: + mock_redis.set.side_effect = Exception("redis unavailable") + + # Act / Assert + api_token_service_module.record_token_usage("token-123", "app") + + +class TestFetchTokenWithSingleFlight: + def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self): + """Test single-flight returns cache when another request already populated it.""" + # Arrange + auth_token = "token-123" + scope = "app" + cached_token = CachedApiToken( + id="id-1", + app_id="app-1", + tenant_id="tenant-1", + type="app", + token=auth_token, + last_used_at=None, + created_at=None, + ) + + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get, + patch.object(api_token_service_module, "query_token_from_db") as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == cached_token + mock_redis.lock.assert_called_once_with( + f"api_token_query_lock:{scope}:{auth_token}", + timeout=10, + blocking_timeout=5, + ) + lock.acquire.assert_called_once_with(blocking=True) + lock.release.assert_called_once() + mock_cache_get.assert_called_once_with(auth_token, scope) + mock_query_db.assert_not_called() + + def test_should_query_db_when_lock_acquired_and_cache_missed(self): + """Test single-flight queries DB when cache remains empty after lock acquisition.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None), + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_query_db.assert_called_once_with(auth_token, scope) + lock.release.assert_called_once() + + def test_should_query_db_directly_when_lock_not_acquired(self): + """Test lock timeout branch falls back to direct DB query.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.return_value = False + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get") as mock_cache_get, + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_cache_get.assert_not_called() + mock_query_db.assert_called_once_with(auth_token, scope) + lock.release.assert_not_called() + + def test_should_reraise_unauthorized_from_db_query(self): + """Test Unauthorized from DB query is propagated unchanged.""" + # Arrange + auth_token = "token-123" + scope = "app" + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None), + patch.object( + api_token_service_module, + "query_token_from_db", + side_effect=Unauthorized("Access token is invalid"), + ), + ): + mock_redis.lock.return_value = lock + + # Act / Assert + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + lock.release.assert_called_once() + + def test_should_fallback_to_db_query_when_lock_raises_exception(self): + """Test Redis lock errors fall back to direct DB query.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.side_effect = RuntimeError("redis lock error") + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_query_db.assert_called_once_with(auth_token, scope) + + +class TestApiTokenCacheTenantBranches: + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis): + """Test scoped delete removes cache key and tenant index membership.""" + # Arrange + token = "token-123" + scope = "app" + cache_key = ApiTokenCache._make_cache_key(token, scope) + cached_token = CachedApiToken( + id="id-1", + app_id="app-1", + tenant_id="tenant-1", + type="app", + token=token, + last_used_at=None, + created_at=None, + ) + mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8") + + with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index: + # Act + result = ApiTokenCache.delete(token, scope) + + # Assert + assert result is True + mock_redis.delete.assert_called_once_with(cache_key) + mock_remove_index.assert_called_once_with("tenant-1", cache_key) + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis): + """Test tenant invalidation deletes indexed cache entries and index key.""" + # Arrange + tenant_id = "tenant-1" + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + mock_redis.smembers.return_value = { + b"api_token:app:token-1", + b"api_token:any:token-2", + } + + # Act + result = ApiTokenCache.invalidate_by_tenant(tenant_id) + + # Assert + assert result is True + mock_redis.smembers.assert_called_once_with(index_key) + mock_redis.delete.assert_any_call("api_token:app:token-1") + mock_redis.delete.assert_any_call("api_token:any:token-2") + mock_redis.delete.assert_any_call(index_key) + + +class TestApiTokenCacheCoreBranches: + def test_cached_api_token_repr_should_include_id_and_type(self): + """Test CachedApiToken __repr__ includes key identity fields.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + + assert repr(token) == "" + + def test_serialize_token_should_handle_cached_api_token_instances(self): + """Test serialization path when input is already a CachedApiToken.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + + serialized = ApiTokenCache._serialize_token(token) + + assert isinstance(serialized, bytes) + assert b'"id":"id-123"' in serialized + assert b'"token":"token-123"' in serialized + + def test_deserialize_token_should_return_none_for_null_markers(self): + """Test null cache marker deserializes to None.""" + assert ApiTokenCache._deserialize_token("null") is None + assert ApiTokenCache._deserialize_token(b"null") is None + + def test_deserialize_token_should_return_none_for_invalid_payload(self): + """Test invalid serialized payload returns None.""" + assert ApiTokenCache._deserialize_token("not-json") is None + + @patch("services.api_token_service.redis_client") + def test_get_should_return_none_on_cache_miss(self, mock_redis): + """Test cache miss branch in ApiTokenCache.get.""" + mock_redis.get.return_value = None + + result = ApiTokenCache.get("token-123", "app") + + assert result is None + mock_redis.get.assert_called_once_with("api_token:app:token-123") + + @patch("services.api_token_service.redis_client") + def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis): + """Test cache hit branch in ApiTokenCache.get.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + mock_redis.get.return_value = token.model_dump_json().encode("utf-8") + + result = ApiTokenCache.get("token-123", "app") + + assert isinstance(result, CachedApiToken) + assert result.id == "id-123" + + @patch("services.api_token_service.redis_client") + def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): + """Test tenant index update exits early for missing tenant id.""" + ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123") + + mock_redis.sadd.assert_not_called() + mock_redis.expire.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis): + """Test tenant index update handles Redis write errors gracefully.""" + mock_redis.sadd.side_effect = Exception("redis down") + + ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123") + + mock_redis.sadd.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): + """Test tenant index removal exits early for missing tenant id.""" + ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123") + + mock_redis.srem.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis): + """Test tenant index removal handles Redis errors gracefully.""" + mock_redis.srem.side_effect = Exception("redis down") + + ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123") + + mock_redis.srem.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis): + """Test set returns False when Redis setex fails.""" + mock_redis.setex.side_effect = Exception("redis write failed") + api_token = MagicMock() + api_token.id = "id-123" + api_token.app_id = "app-123" + api_token.tenant_id = "tenant-123" + api_token.type = "app" + api_token.token = "token-123" + api_token.last_used_at = None + api_token.created_at = None + + result = ApiTokenCache.set("token-123", "app", api_token) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis): + """Test delete(scope=None) returns False when scan_iter raises.""" + mock_redis.scan_iter.side_effect = Exception("scan failed") + + result = ApiTokenCache.delete("token-123", None) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis): + """Test scoped delete still succeeds when tenant lookup from cache fails.""" + token = "token-123" + scope = "app" + cache_key = ApiTokenCache._make_cache_key(token, scope) + mock_redis.get.side_effect = Exception("get failed") + + result = ApiTokenCache.delete(token, scope) + + assert result is True + mock_redis.delete.assert_called_once_with(cache_key) + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis): + """Test scoped delete returns False when delete operation fails.""" + token = "token-123" + scope = "app" + mock_redis.get.return_value = None + mock_redis.delete.side_effect = Exception("delete failed") + + result = ApiTokenCache.delete(token, scope) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis): + """Test tenant invalidation returns True when tenant index is empty.""" + mock_redis.smembers.return_value = set() + + result = ApiTokenCache.invalidate_by_tenant("tenant-123") + + assert result is True + mock_redis.delete.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis): + """Test tenant invalidation returns False when Redis operation fails.""" + mock_redis.smembers.side_effect = Exception("redis failed") + + result = ApiTokenCache.invalidate_by_tenant("tenant-123") + + assert result is False diff --git a/api/tests/unit_tests/services/test_app_model_config_service.py b/api/tests/unit_tests/services/test_app_model_config_service.py new file mode 100644 index 0000000000..d4b4bf14a3 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_model_config_service.py @@ -0,0 +1,88 @@ +from unittest.mock import patch + +import pytest + +from models.model import AppMode +from services.app_model_config_service import AppModelConfigService + + +@pytest.fixture +def mock_config_managers(): + """Fixture that patches all app config manager validate methods. + + Returns a dictionary containing the mocked config_validate methods for each manager. + """ + with ( + patch("services.app_model_config_service.ChatAppConfigManager.config_validate") as mock_chat_validate, + patch("services.app_model_config_service.AgentChatAppConfigManager.config_validate") as mock_agent_validate, + patch( + "services.app_model_config_service.CompletionAppConfigManager.config_validate" + ) as mock_completion_validate, + ): + mock_chat_validate.return_value = {"manager": "chat"} + mock_agent_validate.return_value = {"manager": "agent"} + mock_completion_validate.return_value = {"manager": "completion"} + + yield { + "chat": mock_chat_validate, + "agent": mock_agent_validate, + "completion": mock_completion_validate, + } + + +class TestAppModelConfigService: + @pytest.mark.parametrize( + ("app_mode", "selected_manager"), + [ + (AppMode.CHAT, "chat"), + (AppMode.AGENT_CHAT, "agent"), + (AppMode.COMPLETION, "completion"), + ], + ) + def test_should_route_validation_to_correct_manager_based_on_app_mode( + self, app_mode, selected_manager, mock_config_managers + ): + """Test configuration validation is delegated to the expected manager for each supported app mode.""" + tenant_id = "tenant-123" + config = {"temperature": 0.5} + + mock_chat_validate = mock_config_managers["chat"] + mock_agent_validate = mock_config_managers["agent"] + mock_completion_validate = mock_config_managers["completion"] + + result = AppModelConfigService.validate_configuration(tenant_id=tenant_id, config=config, app_mode=app_mode) + + assert result == {"manager": selected_manager} + + if selected_manager == "chat": + mock_chat_validate.assert_called_once_with(tenant_id, config) + mock_agent_validate.assert_not_called() + mock_completion_validate.assert_not_called() + elif selected_manager == "agent": + mock_agent_validate.assert_called_once_with(tenant_id, config) + mock_chat_validate.assert_not_called() + mock_completion_validate.assert_not_called() + else: + mock_completion_validate.assert_called_once_with(tenant_id, config) + mock_chat_validate.assert_not_called() + mock_agent_validate.assert_not_called() + + def test_should_raise_value_error_when_app_mode_is_not_supported(self, mock_config_managers): + """Test unsupported app modes raise ValueError with the invalid mode in the message.""" + tenant_id = "tenant-123" + config = {"temperature": 0.5} + + mock_chat_validate = mock_config_managers["chat"] + mock_agent_validate = mock_config_managers["agent"] + mock_completion_validate = mock_config_managers["completion"] + + with pytest.raises(ValueError, match=f"Invalid app mode: {AppMode.WORKFLOW}"): + AppModelConfigService.validate_configuration( + tenant_id=tenant_id, + config=config, + app_mode=AppMode.WORKFLOW, + ) + + mock_chat_validate.assert_not_called() + mock_agent_validate.assert_not_called() + mock_completion_validate.assert_not_called() diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py new file mode 100644 index 0000000000..639e091041 --- /dev/null +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -0,0 +1,507 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import services.async_workflow_service as async_workflow_service_module +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus +from services.async_workflow_service import AsyncWorkflowService +from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.workflow.entities import AsyncTriggerResponse, TriggerData +from services.workflow.queue_dispatcher import QueuePriority + + +class AsyncWorkflowServiceTestDataFactory: + """Factory helpers for async workflow service unit tests.""" + + @staticmethod + def create_trigger_data( + app_id: str = "app-123", + tenant_id: str = "tenant-123", + workflow_id: str | None = "workflow-123", + root_node_id: str = "root-node-123", + ) -> TriggerData: + """Create valid trigger data for async workflow execution tests.""" + return TriggerData( + app_id=app_id, + tenant_id=tenant_id, + workflow_id=workflow_id, + root_node_id=root_node_id, + inputs={"name": "dify"}, + files=[], + trigger_type=AppTriggerType.UNKNOWN, + trigger_from=WorkflowRunTriggeredFrom.APP_RUN, + trigger_metadata=None, + ) + + @staticmethod + def create_trigger_log_with_data(trigger_data: TriggerData, retry_count: int = 0) -> MagicMock: + """Create a mock trigger log with serialized trigger data.""" + trigger_log = MagicMock() + trigger_log.id = "trigger-log-123" + trigger_log.trigger_data = trigger_data.model_dump_json() + trigger_log.retry_count = retry_count + trigger_log.error = "previous-error" + trigger_log.status = WorkflowTriggerStatus.FAILED + trigger_log.to_dict.return_value = {"id": trigger_log.id} + return trigger_log + + +class TestAsyncWorkflowService: + @pytest.fixture + def async_workflow_trigger_mocks(self): + """Shared fixture for async workflow trigger tests. + + Yields mocks for: + - repo: SQLAlchemyWorkflowTriggerLogRepository + - dispatcher_manager_class: QueueDispatcherManager class + - dispatcher: dispatcher instance + - quota_workflow: QuotaType.WORKFLOW + - get_workflow: AsyncWorkflowService._get_workflow method + - professional_task: execute_workflow_professional + - team_task: execute_workflow_team + - sandbox_task: execute_workflow_sandbox + """ + mock_repo = MagicMock() + + def _create_side_effect(new_log): + new_log.id = "trigger-log-123" + return new_log + + mock_repo.create.side_effect = _create_side_effect + + mock_dispatcher = MagicMock() + quota_workflow = MagicMock() + mock_get_workflow = MagicMock() + + mock_professional_task = MagicMock() + mock_team_task = MagicMock() + mock_sandbox_task = MagicMock() + + with ( + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + patch.object(async_workflow_service_module, "QueueDispatcherManager") as mock_dispatcher_manager_class, + patch.object(async_workflow_service_module, "WorkflowService"), + patch.object( + async_workflow_service_module.AsyncWorkflowService, + "_get_workflow", + ) as mock_get_workflow, + patch.object( + async_workflow_service_module, + "QuotaType", + new=SimpleNamespace(WORKFLOW=quota_workflow), + ), + patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, + patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, + patch.object(async_workflow_service_module, "execute_workflow_sandbox") as mock_sandbox_task, + ): + # Configure dispatcher_manager to return our mock_dispatcher + mock_dispatcher_manager_class.return_value.get_dispatcher.return_value = mock_dispatcher + + yield { + "repo": mock_repo, + "dispatcher_manager_class": mock_dispatcher_manager_class, + "dispatcher": mock_dispatcher, + "quota_workflow": quota_workflow, + "get_workflow": mock_get_workflow, + "professional_task": mock_professional_task, + "team_task": mock_team_task, + "sandbox_task": mock_sandbox_task, + } + + @pytest.mark.parametrize( + ("queue_name", "selected_task_attr"), + [ + (QueuePriority.PROFESSIONAL, "execute_workflow_professional"), + (QueuePriority.TEAM, "execute_workflow_team"), + (QueuePriority.SANDBOX, "execute_workflow_sandbox"), + ], + ) + def test_should_dispatch_to_matching_celery_task_when_triggering_workflow( + self, queue_name, selected_task_attr, async_workflow_trigger_mocks + ): + """Test queue-based task routing and successful async trigger response.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = queue_name + mocks["get_workflow"].return_value = workflow + + task_result = MagicMock() + task_result.id = "task-123" + mocks["professional_task"].delay.return_value = task_result + mocks["team_task"].delay.return_value = task_result + mocks["sandbox_task"].delay.return_value = task_result + + class DummyAccount: + def __init__(self, user_id: str): + self.id = user_id + + with patch.object(async_workflow_service_module, "Account", DummyAccount): + user = DummyAccount("account-123") + + # Act + result = AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data) + + # Assert + assert isinstance(result, AsyncTriggerResponse) + assert result.workflow_trigger_log_id == "trigger-log-123" + assert result.task_id == "task-123" + assert result.status == "queued" + assert result.queue == queue_name + + mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") + assert session.commit.call_count == 2 + + created_log = mocks["repo"].create.call_args[0][0] + assert created_log.status == WorkflowTriggerStatus.QUEUED + assert created_log.queue_name == queue_name + assert created_log.created_by_role == CreatorUserRole.ACCOUNT + assert created_log.created_by == "account-123" + assert created_log.trigger_data == trigger_data.model_dump_json() + assert created_log.inputs == json.dumps(dict(trigger_data.inputs)) + assert created_log.celery_task_id == "task-123" + + task_mocks = { + "execute_workflow_professional": mocks["professional_task"], + "execute_workflow_team": mocks["team_task"], + "execute_workflow_sandbox": mocks["sandbox_task"], + } + for task_attr, task_mock in task_mocks.items(): + if task_attr == selected_task_attr: + task_mock.delay.assert_called_once_with({"workflow_trigger_log_id": "trigger-log-123"}) + else: + task_mock.delay.assert_not_called() + + def test_should_set_end_user_role_when_triggered_by_end_user(self, async_workflow_trigger_mocks): + """Test that non-account users are tracked as END_USER in trigger logs.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = QueuePriority.SANDBOX + mocks["get_workflow"].return_value = workflow + + task_result = MagicMock(id="task-123") + mocks["sandbox_task"].delay.return_value = task_result + + user = SimpleNamespace(id="end-user-123") + + # Act + AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data) + + # Assert + created_log = mocks["repo"].create.call_args[0][0] + assert created_log.created_by_role == CreatorUserRole.END_USER + assert created_log.created_by == "end-user-123" + + def test_should_raise_workflow_not_found_when_app_does_not_exist(self): + """Test trigger failure when app lookup returns no result.""" + # Arrange + session = MagicMock() + session.scalar.return_value = None + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data(app_id="missing-app") + + with ( + patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository"), + patch.object(async_workflow_service_module, "QueueDispatcherManager"), + patch.object(async_workflow_service_module, "WorkflowService"), + ): + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="App not found: missing-app"): + AsyncWorkflowService.trigger_workflow_async( + session=session, + user=SimpleNamespace(id="user-123"), + trigger_data=trigger_data, + ) + + def test_should_mark_log_rate_limited_and_raise_when_quota_exceeded(self, async_workflow_trigger_mocks): + """Test quota-exceeded path updates trigger log and raises WorkflowQuotaLimitError.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM + mocks["get_workflow"].return_value = workflow + mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + feature="workflow", + tenant_id="tenant-123", + required=1, + ) + + # Act / Assert + with pytest.raises( + WorkflowQuotaLimitError, + match="Workflow execution quota limit reached for tenant tenant-123", + ): + AsyncWorkflowService.trigger_workflow_async( + session=session, + user=SimpleNamespace(id="user-123"), + trigger_data=trigger_data, + ) + + assert session.commit.call_count == 2 + updated_log = mocks["repo"].update.call_args[0][0] + assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED + assert "Quota limit reached" in updated_log.error + mocks["professional_task"].delay.assert_not_called() + mocks["team_task"].delay.assert_not_called() + mocks["sandbox_task"].delay.assert_not_called() + + def test_should_raise_when_reinvoke_target_log_does_not_exist(self): + """Test reinvoke_trigger error path when original trigger log is missing.""" + # Arrange + session = MagicMock() + repo = MagicMock() + repo.get_by_id.return_value = None + + with patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo): + # Act / Assert + with pytest.raises(ValueError, match="Trigger log not found: missing-log"): + AsyncWorkflowService.reinvoke_trigger( + session=session, + user=SimpleNamespace(id="user-123"), + workflow_trigger_log_id="missing-log", + ) + + def test_should_update_original_log_and_requeue_when_reinvoking(self): + """Test reinvoke flow updates original log state and triggers a new async run.""" + # Arrange + session = MagicMock() + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + trigger_log = AsyncWorkflowServiceTestDataFactory.create_trigger_log_with_data(trigger_data, retry_count=1) + repo = MagicMock() + repo.get_by_id.return_value = trigger_log + + expected_response = AsyncTriggerResponse( + workflow_trigger_log_id="new-trigger-log-456", + task_id="task-456", + status="queued", + queue=QueuePriority.TEAM, + ) + + with ( + patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo), + patch.object( + async_workflow_service_module.AsyncWorkflowService, + "trigger_workflow_async", + return_value=expected_response, + ) as mock_trigger_workflow_async, + ): + user = SimpleNamespace(id="user-123") + + # Act + response = AsyncWorkflowService.reinvoke_trigger( + session=session, + user=user, + workflow_trigger_log_id="trigger-log-123", + ) + + # Assert + assert response == expected_response + assert trigger_log.status == WorkflowTriggerStatus.RETRYING + assert trigger_log.retry_count == 2 + assert trigger_log.error is None + assert trigger_log.triggered_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + called_trigger_data = mock_trigger_workflow_async.call_args[0][2] + assert isinstance(called_trigger_data, TriggerData) + assert called_trigger_data.app_id == "app-123" + + @pytest.mark.parametrize( + ("repo_result", "expected"), + [ + (None, None), + (MagicMock(), {"id": "trigger-log-123"}), + ], + ) + def test_should_return_trigger_log_dict_or_none(self, repo_result, expected): + """Test get_trigger_log returns serialized log data or None.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + fake_engine = MagicMock() + mock_repo.get_by_id.return_value = repo_result + if repo_result: + repo_result.to_dict.return_value = expected + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)), + patch.object( + async_workflow_service_module, "Session", return_value=mock_session_context + ) as mock_session_class, + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_trigger_log("trigger-log-123", tenant_id="tenant-123") + + # Assert + assert result == expected + mock_session_class.assert_called_once_with(fake_engine) + mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123") + + def test_should_return_recent_logs_as_dict_list(self): + """Test get_recent_logs converts repository models into dictionaries.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + log1 = MagicMock() + log1.to_dict.return_value = {"id": "log-1"} + log2 = MagicMock() + log2.to_dict.return_value = {"id": "log-2"} + mock_repo.get_recent_logs.return_value = [log1, log2] + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), + patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_recent_logs( + tenant_id="tenant-123", + app_id="app-123", + hours=12, + limit=50, + offset=10, + ) + + # Assert + assert result == [{"id": "log-1"}, {"id": "log-2"}] + mock_repo.get_recent_logs.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + hours=12, + limit=50, + offset=10, + ) + + def test_should_return_failed_logs_for_retry_as_dict_list(self): + """Test get_failed_logs_for_retry serializes repository logs into dicts.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + log = MagicMock() + log.to_dict.return_value = {"id": "failed-log-1"} + mock_repo.get_failed_for_retry.return_value = [log] + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), + patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_failed_logs_for_retry(tenant_id="tenant-123", max_retry_count=4, limit=20) + + # Assert + assert result == [{"id": "failed-log-1"}] + mock_repo.get_failed_for_retry.assert_called_once_with(tenant_id="tenant-123", max_retry_count=4, limit=20) + + +class TestAsyncWorkflowServiceGetWorkflow: + def test_should_return_specific_workflow_when_workflow_id_exists(self): + """Test _get_workflow returns published workflow by id when provided.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + workflow = MagicMock() + workflow_service.get_published_workflow_by_id.return_value = workflow + + # Act + result = AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-123") + + # Assert + assert result == workflow + workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123") + workflow_service.get_published_workflow.assert_not_called() + + def test_should_raise_when_specific_workflow_id_not_found(self): + """Test _get_workflow raises WorkflowNotFoundError for unknown workflow id.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + workflow_service.get_published_workflow_by_id.return_value = None + + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="Published workflow not found: workflow-404"): + AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-404") + + def test_should_return_default_published_workflow_when_workflow_id_not_provided(self): + """Test _get_workflow returns default published workflow when no id is provided.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + workflow = MagicMock() + workflow_service.get_published_workflow.return_value = workflow + + # Act + result = AsyncWorkflowService._get_workflow(workflow_service, app_model) + + # Assert + assert result == workflow + workflow_service.get_published_workflow.assert_called_once_with(app_model) + workflow_service.get_published_workflow_by_id.assert_not_called() + + def test_should_raise_when_default_published_workflow_not_found(self): + """Test _get_workflow raises WorkflowNotFoundError when app has no published workflow.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + workflow_service.get_published_workflow.return_value = None + + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="No published workflow found for app: app-123"): + AsyncWorkflowService._get_workflow(workflow_service, app_model) diff --git a/api/tests/unit_tests/services/test_attachment_service.py b/api/tests/unit_tests/services/test_attachment_service.py new file mode 100644 index 0000000000..88be20bc41 --- /dev/null +++ b/api/tests/unit_tests/services/test_attachment_service.py @@ -0,0 +1,73 @@ +import base64 +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +import services.attachment_service as attachment_service_module +from models.model import UploadFile +from services.attachment_service import AttachmentService + + +class TestAttachmentService: + def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self): + """Test that AttachmentService keeps the provided sessionmaker instance.""" + session_factory = sessionmaker() + + service = AttachmentService(session_factory=session_factory) + + assert service._session_maker is session_factory + + def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self): + """Test that AttachmentService builds a sessionmaker bound to the provided engine.""" + engine = create_engine("sqlite:///:memory:") + + service = AttachmentService(session_factory=engine) + session = service._session_maker() + try: + assert session.bind == engine + finally: + session.close() + engine.dispose() + + @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) + def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory): + """Test that invalid session_factory types are rejected.""" + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + AttachmentService(session_factory=invalid_session_factory) + + def test_should_return_base64_encoded_blob_when_file_exists(self): + """Test that existing files are loaded from storage and returned as base64.""" + service = AttachmentService(session_factory=sessionmaker()) + upload_file = MagicMock(spec=UploadFile) + upload_file.key = "upload-file-key" + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = upload_file + service._session_maker = MagicMock(return_value=session) + + with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: + result = service.get_file_base64("file-123") + + assert result == base64.b64encode(b"binary-content").decode() + service._session_maker.assert_called_once_with(expire_on_commit=False) + session.query.assert_called_once_with(UploadFile) + mock_load.assert_called_once_with("upload-file-key") + + def test_should_raise_not_found_when_file_does_not_exist(self): + """Test that missing files raise NotFound and never call storage.""" + service = AttachmentService(session_factory=sessionmaker()) + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + service._session_maker = MagicMock(return_value=session) + + with patch.object(attachment_service_module.storage, "load_once") as mock_load: + with pytest.raises(NotFound, match="File not found"): + service.get_file_base64("missing-file") + + service._session_maker.assert_called_once_with(expire_on_commit=False) + session.query.assert_called_once_with(UploadFile) + mock_load.assert_not_called() diff --git a/api/tests/unit_tests/services/test_code_based_extension_service.py b/api/tests/unit_tests/services/test_code_based_extension_service.py new file mode 100644 index 0000000000..f6538a140a --- /dev/null +++ b/api/tests/unit_tests/services/test_code_based_extension_service.py @@ -0,0 +1,89 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from services.code_based_extension_service import CodeBasedExtensionService + + +class TestCodeBasedExtensionService: + def test_should_return_only_non_builtin_extensions_with_public_fields(self, monkeypatch: pytest.MonkeyPatch): + """Test service returns only non-builtin extensions with name/label/form_schema fields.""" + moderation_extension = SimpleNamespace( + name="custom-moderation", + label={"en-US": "Custom Moderation"}, + form_schema=[{"variable": "api_key"}], + builtin=False, + extension_class=object, + position=20, + ) + builtin_extension = SimpleNamespace( + name="builtin-moderation", + label={"en-US": "Builtin Moderation"}, + form_schema=[{"variable": "token"}], + builtin=True, + extension_class=object, + position=1, + ) + retrieval_extension = SimpleNamespace( + name="custom-retrieval", + label={"en-US": "Custom Retrieval"}, + form_schema=None, + builtin=False, + extension_class=object, + position=30, + ) + module_extensions_mock = MagicMock(return_value=[moderation_extension, builtin_extension, retrieval_extension]) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + result = CodeBasedExtensionService.get_code_based_extension("external_data_tool") + + assert result == [ + { + "name": "custom-moderation", + "label": {"en-US": "Custom Moderation"}, + "form_schema": [{"variable": "api_key"}], + }, + { + "name": "custom-retrieval", + "label": {"en-US": "Custom Retrieval"}, + "form_schema": None, + }, + ] + assert set(result[0].keys()) == {"name", "label", "form_schema"} + module_extensions_mock.assert_called_once_with("external_data_tool") + + def test_should_return_empty_list_when_all_extensions_are_builtin(self, monkeypatch: pytest.MonkeyPatch): + """Test builtin extensions are filtered out completely.""" + builtin_extension = SimpleNamespace( + name="builtin-moderation", + label={"en-US": "Builtin Moderation"}, + form_schema=[{"variable": "token"}], + builtin=True, + ) + module_extensions_mock = MagicMock(return_value=[builtin_extension]) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + result = CodeBasedExtensionService.get_code_based_extension("moderation") + + assert result == [] + module_extensions_mock.assert_called_once_with("moderation") + + def test_should_propagate_error_when_module_extensions_lookup_fails(self, monkeypatch: pytest.MonkeyPatch): + """Test ValueError from extension lookup bubbles up unchanged.""" + module_extensions_mock = MagicMock(side_effect=ValueError("Extension Module invalid-module not found")) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + with pytest.raises(ValueError, match="Extension Module invalid-module not found"): + CodeBasedExtensionService.get_code_based_extension("invalid-module") + + module_extensions_mock.assert_called_once_with("invalid-module") diff --git a/api/tests/unit_tests/services/test_conversation_variable_updater.py b/api/tests/unit_tests/services/test_conversation_variable_updater.py new file mode 100644 index 0000000000..157424a2a7 --- /dev/null +++ b/api/tests/unit_tests/services/test_conversation_variable_updater.py @@ -0,0 +1,75 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.variables.variables import StringVariable +from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater + + +class TestConversationVariableUpdater: + def test_should_update_conversation_variable_data_and_commit(self): + """Test update persists serialized variable data when the row exists.""" + conversation_id = "conv-123" + variable = StringVariable( + id="var-123", + name="topic", + value="new value", + ) + expected_json = variable.model_dump_json() + + row = SimpleNamespace(data="old value") + session = MagicMock() + session.scalar.return_value = row + + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + session_maker = MagicMock(return_value=session_context) + updater = ConversationVariableUpdater(session_maker) + + updater.update(conversation_id=conversation_id, variable=variable) + + session_maker.assert_called_once_with() + session.scalar.assert_called_once() + stmt = session.scalar.call_args.args[0] + compiled_params = stmt.compile().params + assert variable.id in compiled_params.values() + assert conversation_id in compiled_params.values() + assert row.data == expected_json + session.commit.assert_called_once() + + def test_should_raise_not_found_error_when_conversation_variable_missing(self): + """Test update raises ConversationVariableNotFoundError when no matching row exists.""" + conversation_id = "conv-404" + variable = StringVariable( + id="var-404", + name="topic", + value="value", + ) + + session = MagicMock() + session.scalar.return_value = None + + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + session_maker = MagicMock(return_value=session_context) + updater = ConversationVariableUpdater(session_maker) + + with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): + updater.update(conversation_id=conversation_id, variable=variable) + + session.commit.assert_not_called() + + def test_should_do_nothing_when_flush_is_called(self): + """Test flush currently behaves as a no-op and returns None.""" + session_maker = MagicMock() + updater = ConversationVariableUpdater(session_maker) + + result = updater.flush() + + assert result is None + session_maker.assert_not_called() diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py new file mode 100644 index 0000000000..9ef314cb9e --- /dev/null +++ b/api/tests/unit_tests/services/test_credit_pool_service.py @@ -0,0 +1,157 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import services.credit_pool_service as credit_pool_service_module +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from services.credit_pool_service import CreditPoolService + + +@pytest.fixture +def mock_credit_deduction_setup(): + """Fixture providing common setup for credit deduction tests.""" + pool = SimpleNamespace(remaining_credits=50) + fake_engine = MagicMock() + session = MagicMock() + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool) + mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine)) + mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context) + + return { + "pool": pool, + "fake_engine": fake_engine, + "session": session, + "session_context": session_context, + "patches": (mock_get_pool, mock_db, mock_session), + } + + +class TestCreditPoolService: + def test_should_create_default_pool_with_trial_type_and_configured_quota(self): + """Test create_default_pool persists a trial pool using configured hosted credits.""" + tenant_id = "tenant-123" + hosted_pool_credits = 5000 + + with ( + patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits), + patch.object(credit_pool_service_module, "db") as mock_db, + ): + pool = CreditPoolService.create_default_pool(tenant_id) + + assert isinstance(pool, TenantCreditPool) + assert pool.tenant_id == tenant_id + assert pool.pool_type == "trial" + assert pool.quota_limit == hosted_pool_credits + assert pool.quota_used == 0 + mock_db.session.add.assert_called_once_with(pool) + mock_db.session.commit.assert_called_once() + + def test_should_return_first_pool_from_query_when_get_pool_called(self): + """Test get_pool queries by tenant and pool_type and returns first result.""" + tenant_id = "tenant-123" + pool_type = "enterprise" + expected_pool = MagicMock(spec=TenantCreditPool) + + with patch.object(credit_pool_service_module, "db") as mock_db: + query = mock_db.session.query.return_value + filtered_query = query.filter_by.return_value + filtered_query.first.return_value = expected_pool + + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type) + + assert result == expected_pool + mock_db.session.query.assert_called_once_with(TenantCreditPool) + query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type) + filtered_query.first.assert_called_once() + + def test_should_return_false_when_pool_not_found_in_check_credits_available(self): + """Test check_credits_available returns False when tenant has no pool.""" + with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool: + result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10) + + assert result is False + mock_get_pool.assert_called_once_with("tenant-123", "trial") + + def test_should_return_true_when_remaining_credits_cover_required_amount(self): + """Test check_credits_available returns True when remaining credits are sufficient.""" + pool = SimpleNamespace(remaining_credits=100) + + with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool: + result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) + + assert result is True + mock_get_pool.assert_called_once_with("tenant-123", "trial") + + def test_should_return_false_when_remaining_credits_are_insufficient(self): + """Test check_credits_available returns False when required credits exceed remaining credits.""" + pool = SimpleNamespace(remaining_credits=30) + + with patch.object(CreditPoolService, "get_pool", return_value=pool): + result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) + + assert result is False + + def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self): + """Test check_and_deduct_credits raises when tenant credit pool does not exist.""" + with patch.object(CreditPoolService, "get_pool", return_value=None): + with pytest.raises(QuotaExceededError, match="Credit pool not found"): + CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) + + def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self): + """Test check_and_deduct_credits raises when remaining credits are zero or negative.""" + pool = SimpleNamespace(remaining_credits=0) + + with patch.object(CreditPoolService, "get_pool", return_value=pool): + with pytest.raises(QuotaExceededError, match="No credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) + + def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup): + """Test check_and_deduct_credits updates quota_used by the actual deducted amount.""" + tenant_id = "tenant-123" + pool_type = "trial" + credits_required = 200 + remaining_credits = 120 + expected_deducted_credits = 120 + + mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits + patches = mock_credit_deduction_setup["patches"] + session = mock_credit_deduction_setup["session"] + + with patches[0], patches[1], patches[2]: + result = CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=credits_required, + pool_type=pool_type, + ) + + assert result == expected_deducted_credits + session.execute.assert_called_once() + session.commit.assert_called_once() + + stmt = session.execute.call_args.args[0] + compiled_params = stmt.compile().params + assert tenant_id in compiled_params.values() + assert pool_type in compiled_params.values() + assert expected_deducted_credits in compiled_params.values() + + def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup): + """Test check_and_deduct_credits translates DB update failures to QuotaExceededError.""" + mock_credit_deduction_setup["pool"].remaining_credits = 50 + mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure") + session = mock_credit_deduction_setup["session"] + + patches = mock_credit_deduction_setup["patches"] + mock_logger = patch.object(credit_pool_service_module, "logger") + + with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj: + with pytest.raises(QuotaExceededError, match="Failed to deduct credits"): + CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) + + session.commit.assert_not_called() + mock_logger_obj.exception.assert_called_once()