diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 0279725ff2..c6a270e470 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.engine import db +from models.enums import CredentialSourceType from models.provider import ( LoadBalancingModelConfig, Provider, @@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) except Exception: @@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "provider", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: @@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="custom_model", + credential_source=CredentialSourceType.CUSTOM_MODEL, session=session, ) except Exception: @@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "custom_model", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() @@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel): provider_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "custom_model" + if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL ] load_balancing_enabled = model_setting.load_balancing_enabled @@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel): custom_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "provider" + if config.credential_source_type != CredentialSourceType.PROVIDER ] load_balancing_enabled = model_setting.load_balancing_enabled diff --git a/api/models/provider.py b/api/models/provider.py index 7cefdbaba5..4e114bb034 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -13,6 +13,7 @@ from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db +from .enums import CredentialSourceType, PaymentStatus from .types import EnumText, LongText, StringUUID @@ -237,7 +238,9 @@ class ProviderOrder(TypeBase): quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) currency: Mapped[str | None] = mapped_column(String(40)) total_amount: Mapped[int | None] = mapped_column(sa.Integer) - payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'")) + payment_status: Mapped[PaymentStatus] = mapped_column( + EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'") + ) paid_at: Mapped[datetime | None] = mapped_column(DateTime) pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) refunded_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -300,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) - credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None) + credential_source_type: Mapped[CredentialSourceType | None] = mapped_column( + EnumText(CredentialSourceType, length=40), nullable=True, default=None + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 2133dc5b3a..bf3b6db3ed 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -103,9 +104,9 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True if config_from == "predefined-model": - credential_source_type = "provider" + credential_source_type = CredentialSourceType.PROVIDER else: - credential_source_type = "custom_model" + credential_source_type = CredentialSourceType.CUSTOM_MODEL # Get load balancing configurations load_balancing_configs = ( @@ -421,7 +422,11 @@ class ModelLoadBalancingService: raise ValueError("Invalid load balancing config name") if credential_id: - credential_source = "provider" if config_from == "predefined-model" else "custom_model" + credential_source = ( + CredentialSourceType.PROVIDER + if config_from == "predefined-model" + else CredentialSourceType.CUSTOM_MODEL + ) assert credential_record is not None load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 5ebefcd8d2..75473fc89a 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) +from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-base", name="LB Base", credentials={}, - credential_source_type="provider", + credential_source_type=CredentialSourceType.PROVIDER, ) ], ), @@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-custom", name="LB Custom", credentials={}, - credential_source_type="custom_model", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) ], ), @@ -826,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None: configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) @@ -844,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) diff --git a/api/tests/unit_tests/models/test_provider_models.py b/api/tests/unit_tests/models/test_provider_models.py index ec84a61c8e..f628e54a4d 100644 --- a/api/tests/unit_tests/models/test_provider_models.py +++ b/api/tests/unit_tests/models/test_provider_models.py @@ -19,6 +19,7 @@ from uuid import uuid4 import pytest +from models.enums import CredentialSourceType, PaymentStatus from models.provider import ( LoadBalancingModelConfig, Provider, @@ -158,7 +159,7 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == provider_name - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_used == 0 @@ -172,10 +173,10 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="anthropic", - provider_type="system", + provider_type=ProviderType.SYSTEM, is_valid=True, credential_id=credential_id, - quota_type="paid", + quota_type=ProviderQuotaType.PAID, quota_limit=10000, quota_used=500, ) @@ -183,10 +184,10 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == "anthropic" - assert provider.provider_type == "system" + assert provider.provider_type == ProviderType.SYSTEM assert provider.is_valid is True assert provider.credential_id == credential_id - assert provider.quota_type == "paid" + assert provider.quota_type == ProviderQuotaType.PAID assert provider.quota_limit == 10000 assert provider.quota_used == 500 @@ -199,7 +200,7 @@ class TestProviderModel: ) # Assert - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_type == "" assert provider.quota_limit is None @@ -213,7 +214,7 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="openai", - provider_type="custom", + provider_type=ProviderType.CUSTOM, ) # Act @@ -253,7 +254,7 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, is_valid=True, ) @@ -266,13 +267,13 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - quota_type="trial", + quota_type=ProviderQuotaType.TRIAL, quota_limit=1000, quota_used=250, ) # Assert - assert provider.quota_type == "trial" + assert provider.quota_type == ProviderQuotaType.TRIAL assert provider.quota_limit == 1000 assert provider.quota_used == 250 remaining = provider.quota_limit - provider.quota_used @@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=tenant_id, provider_name="openai", - preferred_provider_type="custom", + preferred_provider_type=ProviderType.CUSTOM, ) # Assert assert preferred.tenant_id == tenant_id assert preferred.provider_name == "openai" - assert preferred.preferred_provider_type == "custom" + assert preferred.preferred_provider_type == ProviderType.CUSTOM def test_tenant_preferred_provider_system_type(self): """Test tenant preferred provider with system type.""" @@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=str(uuid4()), provider_name="anthropic", - preferred_provider_type="system", + preferred_provider_type=ProviderType.SYSTEM, ) # Assert - assert preferred.preferred_provider_type == "system" + assert preferred.preferred_provider_type == ProviderType.SYSTEM class TestProviderOrder: @@ -470,7 +471,7 @@ class TestProviderOrder: quantity=1, currency=None, total_amount=None, - payment_status="wait_pay", + payment_status=PaymentStatus.WAIT_PAY, paid_at=None, pay_failed_at=None, refunded_at=None, @@ -481,7 +482,7 @@ class TestProviderOrder: assert order.provider_name == "openai" assert order.account_id == account_id assert order.payment_product_id == "prod_123" - assert order.payment_status == "wait_pay" + assert order.payment_status == PaymentStatus.WAIT_PAY assert order.quantity == 1 def test_provider_order_with_payment_details(self): @@ -502,7 +503,7 @@ class TestProviderOrder: quantity=5, currency="USD", total_amount=9999, - payment_status="paid", + payment_status=PaymentStatus.PAID, paid_at=paid_time, pay_failed_at=None, refunded_at=None, @@ -514,7 +515,7 @@ class TestProviderOrder: assert order.quantity == 5 assert order.currency == "USD" assert order.total_amount == 9999 - assert order.payment_status == "paid" + assert order.payment_status == PaymentStatus.PAID assert order.paid_at == paid_time def test_provider_order_payment_statuses(self): @@ -536,23 +537,23 @@ class TestProviderOrder: } # Act & Assert - Wait pay status - wait_order = ProviderOrder(**base_params, payment_status="wait_pay") - assert wait_order.payment_status == "wait_pay" + wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY) + assert wait_order.payment_status == PaymentStatus.WAIT_PAY # Act & Assert - Paid status - paid_order = ProviderOrder(**base_params, payment_status="paid") - assert paid_order.payment_status == "paid" + paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID) + assert paid_order.payment_status == PaymentStatus.PAID # Act & Assert - Failed status failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)} - failed_order = ProviderOrder(**failed_params, payment_status="failed") - assert failed_order.payment_status == "failed" + failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED) + assert failed_order.payment_status == PaymentStatus.FAILED assert failed_order.pay_failed_at is not None # Act & Assert - Refunded status refunded_params = {**base_params, "refunded_at": datetime.now(UTC)} - refunded_order = ProviderOrder(**refunded_params, payment_status="refunded") - assert refunded_order.payment_status == "refunded" + refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED) + assert refunded_order.payment_status == PaymentStatus.REFUNDED assert refunded_order.refunded_at is not None @@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig: name="Secondary API Key", encrypted_config='{"api_key": "encrypted_value"}', credential_id=credential_id, - credential_source_type="custom", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) # Assert assert config.encrypted_config == '{"api_key": "encrypted_value"}' assert config.credential_id == credential_id - assert config.credential_source_type == "custom" + assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL def test_load_balancing_config_disabled(self): """Test disabled load balancing config."""