mirror of https://github.com/langgenius/dify.git
refactor: use EnumText in provider models (#33634)
This commit is contained in:
parent
3454224ff9
commit
04c0bf61fa
|
|
@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.engine import db
|
from models.engine import db
|
||||||
|
from models.enums import CredentialSourceType
|
||||||
from models.provider import (
|
from models.provider import (
|
||||||
LoadBalancingModelConfig,
|
LoadBalancingModelConfig,
|
||||||
Provider,
|
Provider,
|
||||||
|
|
@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
self._update_load_balancing_configs_with_credential(
|
self._update_load_balancing_configs_with_credential(
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
credential_record=credential_record,
|
credential_record=credential_record,
|
||||||
credential_source="provider",
|
credential_source=CredentialSourceType.PROVIDER,
|
||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||||
LoadBalancingModelConfig.credential_id == credential_id,
|
LoadBalancingModelConfig.credential_id == credential_id,
|
||||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
|
||||||
)
|
)
|
||||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||||
try:
|
try:
|
||||||
|
|
@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
self._update_load_balancing_configs_with_credential(
|
self._update_load_balancing_configs_with_credential(
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
credential_record=credential_record,
|
credential_record=credential_record,
|
||||||
credential_source="custom_model",
|
credential_source=CredentialSourceType.CUSTOM_MODEL,
|
||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||||
LoadBalancingModelConfig.credential_id == credential_id,
|
LoadBalancingModelConfig.credential_id == credential_id,
|
||||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
|
||||||
)
|
)
|
||||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||||
|
|
||||||
|
|
@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
provider_model_lb_configs = [
|
provider_model_lb_configs = [
|
||||||
config
|
config
|
||||||
for config in model_setting.load_balancing_configs
|
for config in model_setting.load_balancing_configs
|
||||||
if config.credential_source_type != "custom_model"
|
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
|
||||||
]
|
]
|
||||||
|
|
||||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||||
|
|
@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
custom_model_lb_configs = [
|
custom_model_lb_configs = [
|
||||||
config
|
config
|
||||||
for config in model_setting.load_balancing_configs
|
for config in model_setting.load_balancing_configs
|
||||||
if config.credential_source_type != "provider"
|
if config.credential_source_type != CredentialSourceType.PROVIDER
|
||||||
]
|
]
|
||||||
|
|
||||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from libs.uuid_utils import uuidv7
|
||||||
|
|
||||||
from .base import TypeBase
|
from .base import TypeBase
|
||||||
from .engine import db
|
from .engine import db
|
||||||
|
from .enums import CredentialSourceType, PaymentStatus
|
||||||
from .types import EnumText, LongText, StringUUID
|
from .types import EnumText, LongText, StringUUID
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -237,7 +238,9 @@ class ProviderOrder(TypeBase):
|
||||||
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
||||||
currency: Mapped[str | None] = mapped_column(String(40))
|
currency: Mapped[str | None] = mapped_column(String(40))
|
||||||
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
|
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
|
||||||
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
|
payment_status: Mapped[PaymentStatus] = mapped_column(
|
||||||
|
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
|
||||||
|
)
|
||||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
|
|
@ -300,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
|
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
|
||||||
|
EnumText(CredentialSourceType, length=40), nullable=True, default=None
|
||||||
|
)
|
||||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
||||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from models.enums import CredentialSourceType
|
||||||
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
|
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -103,9 +104,9 @@ class ModelLoadBalancingService:
|
||||||
is_load_balancing_enabled = True
|
is_load_balancing_enabled = True
|
||||||
|
|
||||||
if config_from == "predefined-model":
|
if config_from == "predefined-model":
|
||||||
credential_source_type = "provider"
|
credential_source_type = CredentialSourceType.PROVIDER
|
||||||
else:
|
else:
|
||||||
credential_source_type = "custom_model"
|
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
||||||
|
|
||||||
# Get load balancing configurations
|
# Get load balancing configurations
|
||||||
load_balancing_configs = (
|
load_balancing_configs = (
|
||||||
|
|
@ -421,7 +422,11 @@ class ModelLoadBalancingService:
|
||||||
raise ValueError("Invalid load balancing config name")
|
raise ValueError("Invalid load balancing config name")
|
||||||
|
|
||||||
if credential_id:
|
if credential_id:
|
||||||
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
|
credential_source = (
|
||||||
|
CredentialSourceType.PROVIDER
|
||||||
|
if config_from == "predefined-model"
|
||||||
|
else CredentialSourceType.CUSTOM_MODEL
|
||||||
|
)
|
||||||
assert credential_record is not None
|
assert credential_record is not None
|
||||||
load_balancing_model_config = LoadBalancingModelConfig(
|
load_balancing_model_config = LoadBalancingModelConfig(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
||||||
ProviderCredentialSchema,
|
ProviderCredentialSchema,
|
||||||
ProviderEntity,
|
ProviderEntity,
|
||||||
)
|
)
|
||||||
|
from models.enums import CredentialSourceType
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
|
|
||||||
|
|
@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
|
||||||
id="lb-base",
|
id="lb-base",
|
||||||
name="LB Base",
|
name="LB Base",
|
||||||
credentials={},
|
credentials={},
|
||||||
credential_source_type="provider",
|
credential_source_type=CredentialSourceType.PROVIDER,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
|
@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
|
||||||
id="lb-custom",
|
id="lb-custom",
|
||||||
name="LB Custom",
|
name="LB Custom",
|
||||||
credentials={},
|
credentials={},
|
||||||
credential_source_type="custom_model",
|
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
|
@ -826,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None:
|
||||||
configuration._update_load_balancing_configs_with_credential(
|
configuration._update_load_balancing_configs_with_credential(
|
||||||
credential_id="cred-1",
|
credential_id="cred-1",
|
||||||
credential_record=credential_record,
|
credential_record=credential_record,
|
||||||
credential_source="provider",
|
credential_source=CredentialSourceType.PROVIDER,
|
||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -844,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non
|
||||||
configuration._update_load_balancing_configs_with_credential(
|
configuration._update_load_balancing_configs_with_credential(
|
||||||
credential_id="cred-1",
|
credential_id="cred-1",
|
||||||
credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"),
|
credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"),
|
||||||
credential_source="provider",
|
credential_source=CredentialSourceType.PROVIDER,
|
||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from models.enums import CredentialSourceType, PaymentStatus
|
||||||
from models.provider import (
|
from models.provider import (
|
||||||
LoadBalancingModelConfig,
|
LoadBalancingModelConfig,
|
||||||
Provider,
|
Provider,
|
||||||
|
|
@ -158,7 +159,7 @@ class TestProviderModel:
|
||||||
# Assert
|
# Assert
|
||||||
assert provider.tenant_id == tenant_id
|
assert provider.tenant_id == tenant_id
|
||||||
assert provider.provider_name == provider_name
|
assert provider.provider_name == provider_name
|
||||||
assert provider.provider_type == "custom"
|
assert provider.provider_type == ProviderType.CUSTOM
|
||||||
assert provider.is_valid is False
|
assert provider.is_valid is False
|
||||||
assert provider.quota_used == 0
|
assert provider.quota_used == 0
|
||||||
|
|
||||||
|
|
@ -172,10 +173,10 @@ class TestProviderModel:
|
||||||
provider = Provider(
|
provider = Provider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_name="anthropic",
|
provider_name="anthropic",
|
||||||
provider_type="system",
|
provider_type=ProviderType.SYSTEM,
|
||||||
is_valid=True,
|
is_valid=True,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
quota_type="paid",
|
quota_type=ProviderQuotaType.PAID,
|
||||||
quota_limit=10000,
|
quota_limit=10000,
|
||||||
quota_used=500,
|
quota_used=500,
|
||||||
)
|
)
|
||||||
|
|
@ -183,10 +184,10 @@ class TestProviderModel:
|
||||||
# Assert
|
# Assert
|
||||||
assert provider.tenant_id == tenant_id
|
assert provider.tenant_id == tenant_id
|
||||||
assert provider.provider_name == "anthropic"
|
assert provider.provider_name == "anthropic"
|
||||||
assert provider.provider_type == "system"
|
assert provider.provider_type == ProviderType.SYSTEM
|
||||||
assert provider.is_valid is True
|
assert provider.is_valid is True
|
||||||
assert provider.credential_id == credential_id
|
assert provider.credential_id == credential_id
|
||||||
assert provider.quota_type == "paid"
|
assert provider.quota_type == ProviderQuotaType.PAID
|
||||||
assert provider.quota_limit == 10000
|
assert provider.quota_limit == 10000
|
||||||
assert provider.quota_used == 500
|
assert provider.quota_used == 500
|
||||||
|
|
||||||
|
|
@ -199,7 +200,7 @@ class TestProviderModel:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert provider.provider_type == "custom"
|
assert provider.provider_type == ProviderType.CUSTOM
|
||||||
assert provider.is_valid is False
|
assert provider.is_valid is False
|
||||||
assert provider.quota_type == ""
|
assert provider.quota_type == ""
|
||||||
assert provider.quota_limit is None
|
assert provider.quota_limit is None
|
||||||
|
|
@ -213,7 +214,7 @@ class TestProviderModel:
|
||||||
provider = Provider(
|
provider = Provider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_name="openai",
|
provider_name="openai",
|
||||||
provider_type="custom",
|
provider_type=ProviderType.CUSTOM,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
|
@ -253,7 +254,7 @@ class TestProviderModel:
|
||||||
provider = Provider(
|
provider = Provider(
|
||||||
tenant_id=str(uuid4()),
|
tenant_id=str(uuid4()),
|
||||||
provider_name="openai",
|
provider_name="openai",
|
||||||
provider_type=ProviderType.SYSTEM.value,
|
provider_type=ProviderType.SYSTEM,
|
||||||
is_valid=True,
|
is_valid=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -266,13 +267,13 @@ class TestProviderModel:
|
||||||
provider = Provider(
|
provider = Provider(
|
||||||
tenant_id=str(uuid4()),
|
tenant_id=str(uuid4()),
|
||||||
provider_name="openai",
|
provider_name="openai",
|
||||||
quota_type="trial",
|
quota_type=ProviderQuotaType.TRIAL,
|
||||||
quota_limit=1000,
|
quota_limit=1000,
|
||||||
quota_used=250,
|
quota_used=250,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert provider.quota_type == "trial"
|
assert provider.quota_type == ProviderQuotaType.TRIAL
|
||||||
assert provider.quota_limit == 1000
|
assert provider.quota_limit == 1000
|
||||||
assert provider.quota_used == 250
|
assert provider.quota_used == 250
|
||||||
remaining = provider.quota_limit - provider.quota_used
|
remaining = provider.quota_limit - provider.quota_used
|
||||||
|
|
@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider:
|
||||||
preferred = TenantPreferredModelProvider(
|
preferred = TenantPreferredModelProvider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_name="openai",
|
provider_name="openai",
|
||||||
preferred_provider_type="custom",
|
preferred_provider_type=ProviderType.CUSTOM,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert preferred.tenant_id == tenant_id
|
assert preferred.tenant_id == tenant_id
|
||||||
assert preferred.provider_name == "openai"
|
assert preferred.provider_name == "openai"
|
||||||
assert preferred.preferred_provider_type == "custom"
|
assert preferred.preferred_provider_type == ProviderType.CUSTOM
|
||||||
|
|
||||||
def test_tenant_preferred_provider_system_type(self):
|
def test_tenant_preferred_provider_system_type(self):
|
||||||
"""Test tenant preferred provider with system type."""
|
"""Test tenant preferred provider with system type."""
|
||||||
|
|
@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider:
|
||||||
preferred = TenantPreferredModelProvider(
|
preferred = TenantPreferredModelProvider(
|
||||||
tenant_id=str(uuid4()),
|
tenant_id=str(uuid4()),
|
||||||
provider_name="anthropic",
|
provider_name="anthropic",
|
||||||
preferred_provider_type="system",
|
preferred_provider_type=ProviderType.SYSTEM,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert preferred.preferred_provider_type == "system"
|
assert preferred.preferred_provider_type == ProviderType.SYSTEM
|
||||||
|
|
||||||
|
|
||||||
class TestProviderOrder:
|
class TestProviderOrder:
|
||||||
|
|
@ -470,7 +471,7 @@ class TestProviderOrder:
|
||||||
quantity=1,
|
quantity=1,
|
||||||
currency=None,
|
currency=None,
|
||||||
total_amount=None,
|
total_amount=None,
|
||||||
payment_status="wait_pay",
|
payment_status=PaymentStatus.WAIT_PAY,
|
||||||
paid_at=None,
|
paid_at=None,
|
||||||
pay_failed_at=None,
|
pay_failed_at=None,
|
||||||
refunded_at=None,
|
refunded_at=None,
|
||||||
|
|
@ -481,7 +482,7 @@ class TestProviderOrder:
|
||||||
assert order.provider_name == "openai"
|
assert order.provider_name == "openai"
|
||||||
assert order.account_id == account_id
|
assert order.account_id == account_id
|
||||||
assert order.payment_product_id == "prod_123"
|
assert order.payment_product_id == "prod_123"
|
||||||
assert order.payment_status == "wait_pay"
|
assert order.payment_status == PaymentStatus.WAIT_PAY
|
||||||
assert order.quantity == 1
|
assert order.quantity == 1
|
||||||
|
|
||||||
def test_provider_order_with_payment_details(self):
|
def test_provider_order_with_payment_details(self):
|
||||||
|
|
@ -502,7 +503,7 @@ class TestProviderOrder:
|
||||||
quantity=5,
|
quantity=5,
|
||||||
currency="USD",
|
currency="USD",
|
||||||
total_amount=9999,
|
total_amount=9999,
|
||||||
payment_status="paid",
|
payment_status=PaymentStatus.PAID,
|
||||||
paid_at=paid_time,
|
paid_at=paid_time,
|
||||||
pay_failed_at=None,
|
pay_failed_at=None,
|
||||||
refunded_at=None,
|
refunded_at=None,
|
||||||
|
|
@ -514,7 +515,7 @@ class TestProviderOrder:
|
||||||
assert order.quantity == 5
|
assert order.quantity == 5
|
||||||
assert order.currency == "USD"
|
assert order.currency == "USD"
|
||||||
assert order.total_amount == 9999
|
assert order.total_amount == 9999
|
||||||
assert order.payment_status == "paid"
|
assert order.payment_status == PaymentStatus.PAID
|
||||||
assert order.paid_at == paid_time
|
assert order.paid_at == paid_time
|
||||||
|
|
||||||
def test_provider_order_payment_statuses(self):
|
def test_provider_order_payment_statuses(self):
|
||||||
|
|
@ -536,23 +537,23 @@ class TestProviderOrder:
|
||||||
}
|
}
|
||||||
|
|
||||||
# Act & Assert - Wait pay status
|
# Act & Assert - Wait pay status
|
||||||
wait_order = ProviderOrder(**base_params, payment_status="wait_pay")
|
wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY)
|
||||||
assert wait_order.payment_status == "wait_pay"
|
assert wait_order.payment_status == PaymentStatus.WAIT_PAY
|
||||||
|
|
||||||
# Act & Assert - Paid status
|
# Act & Assert - Paid status
|
||||||
paid_order = ProviderOrder(**base_params, payment_status="paid")
|
paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID)
|
||||||
assert paid_order.payment_status == "paid"
|
assert paid_order.payment_status == PaymentStatus.PAID
|
||||||
|
|
||||||
# Act & Assert - Failed status
|
# Act & Assert - Failed status
|
||||||
failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)}
|
failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)}
|
||||||
failed_order = ProviderOrder(**failed_params, payment_status="failed")
|
failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED)
|
||||||
assert failed_order.payment_status == "failed"
|
assert failed_order.payment_status == PaymentStatus.FAILED
|
||||||
assert failed_order.pay_failed_at is not None
|
assert failed_order.pay_failed_at is not None
|
||||||
|
|
||||||
# Act & Assert - Refunded status
|
# Act & Assert - Refunded status
|
||||||
refunded_params = {**base_params, "refunded_at": datetime.now(UTC)}
|
refunded_params = {**base_params, "refunded_at": datetime.now(UTC)}
|
||||||
refunded_order = ProviderOrder(**refunded_params, payment_status="refunded")
|
refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED)
|
||||||
assert refunded_order.payment_status == "refunded"
|
assert refunded_order.payment_status == PaymentStatus.REFUNDED
|
||||||
assert refunded_order.refunded_at is not None
|
assert refunded_order.refunded_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig:
|
||||||
name="Secondary API Key",
|
name="Secondary API Key",
|
||||||
encrypted_config='{"api_key": "encrypted_value"}',
|
encrypted_config='{"api_key": "encrypted_value"}',
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
credential_source_type="custom",
|
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert config.encrypted_config == '{"api_key": "encrypted_value"}'
|
assert config.encrypted_config == '{"api_key": "encrypted_value"}'
|
||||||
assert config.credential_id == credential_id
|
assert config.credential_id == credential_id
|
||||||
assert config.credential_source_type == "custom"
|
assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL
|
||||||
|
|
||||||
def test_load_balancing_config_disabled(self):
|
def test_load_balancing_config_disabled(self):
|
||||||
"""Test disabled load balancing config."""
|
"""Test disabled load balancing config."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue