refactor: use EnumText in provider models (#33634)

This commit is contained in:
tmimmanuel 2026-03-18 04:27:40 +00:00 committed by GitHub
parent 3454224ff9
commit 04c0bf61fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 43 deletions

View File

@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.engine import db from models.engine import db
from models.enums import CredentialSourceType
from models.provider import ( from models.provider import (
LoadBalancingModelConfig, LoadBalancingModelConfig,
Provider, Provider,
@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential( self._update_load_balancing_configs_with_credential(
credential_id=credential_id, credential_id=credential_id,
credential_record=credential_record, credential_record=credential_record,
credential_source="provider", credential_source=CredentialSourceType.PROVIDER,
session=session, session=session,
) )
except Exception: except Exception:
@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider", LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
) )
lb_configs_using_credential = session.execute(lb_stmt).scalars().all() lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
try: try:
@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential( self._update_load_balancing_configs_with_credential(
credential_id=credential_id, credential_id=credential_id,
credential_record=credential_record, credential_record=credential_record,
credential_source="custom_model", credential_source=CredentialSourceType.CUSTOM_MODEL,
session=session, session=session,
) )
except Exception: except Exception:
@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model", LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
) )
lb_configs_using_credential = session.execute(lb_stmt).scalars().all() lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
provider_model_lb_configs = [ provider_model_lb_configs = [
config config
for config in model_setting.load_balancing_configs for config in model_setting.load_balancing_configs
if config.credential_source_type != "custom_model" if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
] ]
load_balancing_enabled = model_setting.load_balancing_enabled load_balancing_enabled = model_setting.load_balancing_enabled
@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
custom_model_lb_configs = [ custom_model_lb_configs = [
config config
for config in model_setting.load_balancing_configs for config in model_setting.load_balancing_configs
if config.credential_source_type != "provider" if config.credential_source_type != CredentialSourceType.PROVIDER
] ]
load_balancing_enabled = model_setting.load_balancing_enabled load_balancing_enabled = model_setting.load_balancing_enabled

View File

@ -13,6 +13,7 @@ from libs.uuid_utils import uuidv7
from .base import TypeBase from .base import TypeBase
from .engine import db from .engine import db
from .enums import CredentialSourceType, PaymentStatus
from .types import EnumText, LongText, StringUUID from .types import EnumText, LongText, StringUUID
@ -237,7 +238,9 @@ class ProviderOrder(TypeBase):
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
currency: Mapped[str | None] = mapped_column(String(40)) currency: Mapped[str | None] = mapped_column(String(40))
total_amount: Mapped[int | None] = mapped_column(sa.Integer) total_amount: Mapped[int | None] = mapped_column(sa.Integer)
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'")) payment_status: Mapped[PaymentStatus] = mapped_column(
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
)
paid_at: Mapped[datetime | None] = mapped_column(DateTime) paid_at: Mapped[datetime | None] = mapped_column(DateTime)
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
refunded_at: Mapped[datetime | None] = mapped_column(DateTime) refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
@ -300,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None) credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
EnumText(CredentialSourceType, length=40), nullable=True, default=None
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False DateTime, nullable=False, server_default=func.current_timestamp(), init=False

View File

@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.enums import CredentialSourceType
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -103,9 +104,9 @@ class ModelLoadBalancingService:
is_load_balancing_enabled = True is_load_balancing_enabled = True
if config_from == "predefined-model": if config_from == "predefined-model":
credential_source_type = "provider" credential_source_type = CredentialSourceType.PROVIDER
else: else:
credential_source_type = "custom_model" credential_source_type = CredentialSourceType.CUSTOM_MODEL
# Get load balancing configurations # Get load balancing configurations
load_balancing_configs = ( load_balancing_configs = (
@ -421,7 +422,11 @@ class ModelLoadBalancingService:
raise ValueError("Invalid load balancing config name") raise ValueError("Invalid load balancing config name")
if credential_id: if credential_id:
credential_source = "provider" if config_from == "predefined-model" else "custom_model" credential_source = (
CredentialSourceType.PROVIDER
if config_from == "predefined-model"
else CredentialSourceType.CUSTOM_MODEL
)
assert credential_record is not None assert credential_record is not None
load_balancing_model_config = LoadBalancingModelConfig( load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id, tenant_id=tenant_id,

View File

@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
ProviderCredentialSchema, ProviderCredentialSchema,
ProviderEntity, ProviderEntity,
) )
from models.enums import CredentialSourceType
from models.provider import ProviderType from models.provider import ProviderType
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
id="lb-base", id="lb-base",
name="LB Base", name="LB Base",
credentials={}, credentials={},
credential_source_type="provider", credential_source_type=CredentialSourceType.PROVIDER,
) )
], ],
), ),
@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
id="lb-custom", id="lb-custom",
name="LB Custom", name="LB Custom",
credentials={}, credentials={},
credential_source_type="custom_model", credential_source_type=CredentialSourceType.CUSTOM_MODEL,
) )
], ],
), ),
@ -826,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None:
configuration._update_load_balancing_configs_with_credential( configuration._update_load_balancing_configs_with_credential(
credential_id="cred-1", credential_id="cred-1",
credential_record=credential_record, credential_record=credential_record,
credential_source="provider", credential_source=CredentialSourceType.PROVIDER,
session=session, session=session,
) )
@ -844,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non
configuration._update_load_balancing_configs_with_credential( configuration._update_load_balancing_configs_with_credential(
credential_id="cred-1", credential_id="cred-1",
credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"),
credential_source="provider", credential_source=CredentialSourceType.PROVIDER,
session=session, session=session,
) )

View File

@ -19,6 +19,7 @@ from uuid import uuid4
import pytest import pytest
from models.enums import CredentialSourceType, PaymentStatus
from models.provider import ( from models.provider import (
LoadBalancingModelConfig, LoadBalancingModelConfig,
Provider, Provider,
@ -158,7 +159,7 @@ class TestProviderModel:
# Assert # Assert
assert provider.tenant_id == tenant_id assert provider.tenant_id == tenant_id
assert provider.provider_name == provider_name assert provider.provider_name == provider_name
assert provider.provider_type == "custom" assert provider.provider_type == ProviderType.CUSTOM
assert provider.is_valid is False assert provider.is_valid is False
assert provider.quota_used == 0 assert provider.quota_used == 0
@ -172,10 +173,10 @@ class TestProviderModel:
provider = Provider( provider = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_name="anthropic", provider_name="anthropic",
provider_type="system", provider_type=ProviderType.SYSTEM,
is_valid=True, is_valid=True,
credential_id=credential_id, credential_id=credential_id,
quota_type="paid", quota_type=ProviderQuotaType.PAID,
quota_limit=10000, quota_limit=10000,
quota_used=500, quota_used=500,
) )
@ -183,10 +184,10 @@ class TestProviderModel:
# Assert # Assert
assert provider.tenant_id == tenant_id assert provider.tenant_id == tenant_id
assert provider.provider_name == "anthropic" assert provider.provider_name == "anthropic"
assert provider.provider_type == "system" assert provider.provider_type == ProviderType.SYSTEM
assert provider.is_valid is True assert provider.is_valid is True
assert provider.credential_id == credential_id assert provider.credential_id == credential_id
assert provider.quota_type == "paid" assert provider.quota_type == ProviderQuotaType.PAID
assert provider.quota_limit == 10000 assert provider.quota_limit == 10000
assert provider.quota_used == 500 assert provider.quota_used == 500
@ -199,7 +200,7 @@ class TestProviderModel:
) )
# Assert # Assert
assert provider.provider_type == "custom" assert provider.provider_type == ProviderType.CUSTOM
assert provider.is_valid is False assert provider.is_valid is False
assert provider.quota_type == "" assert provider.quota_type == ""
assert provider.quota_limit is None assert provider.quota_limit is None
@ -213,7 +214,7 @@ class TestProviderModel:
provider = Provider( provider = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_name="openai", provider_name="openai",
provider_type="custom", provider_type=ProviderType.CUSTOM,
) )
# Act # Act
@ -253,7 +254,7 @@ class TestProviderModel:
provider = Provider( provider = Provider(
tenant_id=str(uuid4()), tenant_id=str(uuid4()),
provider_name="openai", provider_name="openai",
provider_type=ProviderType.SYSTEM.value, provider_type=ProviderType.SYSTEM,
is_valid=True, is_valid=True,
) )
@ -266,13 +267,13 @@ class TestProviderModel:
provider = Provider( provider = Provider(
tenant_id=str(uuid4()), tenant_id=str(uuid4()),
provider_name="openai", provider_name="openai",
quota_type="trial", quota_type=ProviderQuotaType.TRIAL,
quota_limit=1000, quota_limit=1000,
quota_used=250, quota_used=250,
) )
# Assert # Assert
assert provider.quota_type == "trial" assert provider.quota_type == ProviderQuotaType.TRIAL
assert provider.quota_limit == 1000 assert provider.quota_limit == 1000
assert provider.quota_used == 250 assert provider.quota_used == 250
remaining = provider.quota_limit - provider.quota_used remaining = provider.quota_limit - provider.quota_used
@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider:
preferred = TenantPreferredModelProvider( preferred = TenantPreferredModelProvider(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_name="openai", provider_name="openai",
preferred_provider_type="custom", preferred_provider_type=ProviderType.CUSTOM,
) )
# Assert # Assert
assert preferred.tenant_id == tenant_id assert preferred.tenant_id == tenant_id
assert preferred.provider_name == "openai" assert preferred.provider_name == "openai"
assert preferred.preferred_provider_type == "custom" assert preferred.preferred_provider_type == ProviderType.CUSTOM
def test_tenant_preferred_provider_system_type(self): def test_tenant_preferred_provider_system_type(self):
"""Test tenant preferred provider with system type.""" """Test tenant preferred provider with system type."""
@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider:
preferred = TenantPreferredModelProvider( preferred = TenantPreferredModelProvider(
tenant_id=str(uuid4()), tenant_id=str(uuid4()),
provider_name="anthropic", provider_name="anthropic",
preferred_provider_type="system", preferred_provider_type=ProviderType.SYSTEM,
) )
# Assert # Assert
assert preferred.preferred_provider_type == "system" assert preferred.preferred_provider_type == ProviderType.SYSTEM
class TestProviderOrder: class TestProviderOrder:
@ -470,7 +471,7 @@ class TestProviderOrder:
quantity=1, quantity=1,
currency=None, currency=None,
total_amount=None, total_amount=None,
payment_status="wait_pay", payment_status=PaymentStatus.WAIT_PAY,
paid_at=None, paid_at=None,
pay_failed_at=None, pay_failed_at=None,
refunded_at=None, refunded_at=None,
@ -481,7 +482,7 @@ class TestProviderOrder:
assert order.provider_name == "openai" assert order.provider_name == "openai"
assert order.account_id == account_id assert order.account_id == account_id
assert order.payment_product_id == "prod_123" assert order.payment_product_id == "prod_123"
assert order.payment_status == "wait_pay" assert order.payment_status == PaymentStatus.WAIT_PAY
assert order.quantity == 1 assert order.quantity == 1
def test_provider_order_with_payment_details(self): def test_provider_order_with_payment_details(self):
@ -502,7 +503,7 @@ class TestProviderOrder:
quantity=5, quantity=5,
currency="USD", currency="USD",
total_amount=9999, total_amount=9999,
payment_status="paid", payment_status=PaymentStatus.PAID,
paid_at=paid_time, paid_at=paid_time,
pay_failed_at=None, pay_failed_at=None,
refunded_at=None, refunded_at=None,
@ -514,7 +515,7 @@ class TestProviderOrder:
assert order.quantity == 5 assert order.quantity == 5
assert order.currency == "USD" assert order.currency == "USD"
assert order.total_amount == 9999 assert order.total_amount == 9999
assert order.payment_status == "paid" assert order.payment_status == PaymentStatus.PAID
assert order.paid_at == paid_time assert order.paid_at == paid_time
def test_provider_order_payment_statuses(self): def test_provider_order_payment_statuses(self):
@ -536,23 +537,23 @@ class TestProviderOrder:
} }
# Act & Assert - Wait pay status # Act & Assert - Wait pay status
wait_order = ProviderOrder(**base_params, payment_status="wait_pay") wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY)
assert wait_order.payment_status == "wait_pay" assert wait_order.payment_status == PaymentStatus.WAIT_PAY
# Act & Assert - Paid status # Act & Assert - Paid status
paid_order = ProviderOrder(**base_params, payment_status="paid") paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID)
assert paid_order.payment_status == "paid" assert paid_order.payment_status == PaymentStatus.PAID
# Act & Assert - Failed status # Act & Assert - Failed status
failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)} failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)}
failed_order = ProviderOrder(**failed_params, payment_status="failed") failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED)
assert failed_order.payment_status == "failed" assert failed_order.payment_status == PaymentStatus.FAILED
assert failed_order.pay_failed_at is not None assert failed_order.pay_failed_at is not None
# Act & Assert - Refunded status # Act & Assert - Refunded status
refunded_params = {**base_params, "refunded_at": datetime.now(UTC)} refunded_params = {**base_params, "refunded_at": datetime.now(UTC)}
refunded_order = ProviderOrder(**refunded_params, payment_status="refunded") refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED)
assert refunded_order.payment_status == "refunded" assert refunded_order.payment_status == PaymentStatus.REFUNDED
assert refunded_order.refunded_at is not None assert refunded_order.refunded_at is not None
@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig:
name="Secondary API Key", name="Secondary API Key",
encrypted_config='{"api_key": "encrypted_value"}', encrypted_config='{"api_key": "encrypted_value"}',
credential_id=credential_id, credential_id=credential_id,
credential_source_type="custom", credential_source_type=CredentialSourceType.CUSTOM_MODEL,
) )
# Assert # Assert
assert config.encrypted_config == '{"api_key": "encrypted_value"}' assert config.encrypted_config == '{"api_key": "encrypted_value"}'
assert config.credential_id == credential_id assert config.credential_id == credential_id
assert config.credential_source_type == "custom" assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL
def test_load_balancing_config_disabled(self): def test_load_balancing_config_disabled(self):
"""Test disabled load balancing config.""" """Test disabled load balancing config."""