mirror of https://github.com/langgenius/dify.git
Add model provider factory coverage tests
This commit is contained in:
parent
1d0beda427
commit
f8538ae36a
|
|
@ -111,6 +111,52 @@ def test_model_provider_factory_requires_runtime() -> None:
|
|||
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
    """get_providers() returns the runtime's providers as an equal but distinct list."""
    openai_provider = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
    )
    runtime_providers = [openai_provider]
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(runtime_providers))

    returned = factory.get_providers()

    # Same contents as what the runtime holds, but not the same list object.
    assert list(returned) == runtime_providers
    assert returned is not runtime_providers
|
||||
|
||||
|
||||
def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup() -> None:
    """get_provider_schema() resolves the short provider name to the runtime's entity."""
    openai_schema = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
    )
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([openai_schema]))

    looked_up = factory.get_provider_schema("openai")

    # The factory must hand back the very object the runtime exposes.
    assert looked_up is openai_schema
|
||||
|
||||
|
||||
def test_model_provider_factory_raises_for_unknown_provider() -> None:
    """Looking up a provider the runtime does not know must raise ValueError."""
    known_provider = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
    )
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([known_provider]))

    # "anthropic" is absent from the runtime, so the lookup must fail loudly.
    with pytest.raises(ValueError, match="Invalid provider: anthropic"):
        factory.get_model_provider("anthropic")
|
||||
|
||||
|
||||
def test_model_provider_factory_get_models_filters_provider_and_model_type() -> None:
|
||||
providers = [
|
||||
_build_provider(
|
||||
|
|
@ -135,6 +181,47 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
|
|||
assert [model.model for model in results[0].models] == ["gpt-4o-mini"]
|
||||
|
||||
|
||||
def test_model_provider_factory_get_models_skips_providers_without_requested_model_type() -> None:
    """Filtering by model type drops providers that do not support that type."""
    llm_only = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
        models=[_build_model("gpt-4o-mini", ModelType.LLM)],
    )
    tts_only = _build_provider(
        provider="langgenius/elevenlabs/elevenlabs",
        provider_name="elevenlabs",
        supported_model_types=[ModelType.TTS],
        models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
    )
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([llm_only, tts_only]))

    matches = factory.get_models(model_type=ModelType.TTS)

    # Only the TTS-capable provider survives the filter, with its TTS model intact.
    assert len(matches) == 1
    assert matches[0].provider == "langgenius/elevenlabs/elevenlabs"
    assert [entry.model for entry in matches[0].models] == ["eleven_multilingual_v2"]
|
||||
|
||||
|
||||
def test_model_provider_factory_get_models_without_model_type_keeps_all_provider_models() -> None:
    """Filtering only by provider keeps every model that provider declares."""
    multi_type_provider = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM, ModelType.TTS],
        models=[
            _build_model("gpt-4o-mini", ModelType.LLM),
            _build_model("tts-1", ModelType.TTS),
        ],
    )
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([multi_type_provider]))

    matches = factory.get_models(provider="openai")

    # No model_type filter: both the LLM and the TTS model remain listed.
    assert len(matches) == 1
    assert [entry.model for entry in matches[0].models] == ["gpt-4o-mini", "tts-1"]
|
||||
|
||||
|
||||
def test_model_provider_factory_validates_provider_credentials() -> None:
|
||||
runtime = _FakeModelRuntime(
|
||||
[
|
||||
|
|
@ -169,6 +256,23 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
    """Validating provider credentials without a credential schema must raise."""
    schemaless_provider = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
    )
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([schemaless_provider]))

    # The provider was built without a provider_credential_schema, so
    # validation has nothing to check against and must refuse.
    with pytest.raises(ValueError, match="Provider openai does not have provider_credential_schema"):
        factory.provider_credentials_validate(provider="openai", credentials={"api_key": "secret"})
|
||||
|
||||
|
||||
def test_model_provider_factory_validates_model_credentials() -> None:
|
||||
runtime = _FakeModelRuntime(
|
||||
[
|
||||
|
|
@ -208,6 +312,28 @@ def test_model_provider_factory_validates_model_credentials() -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
    """Validating model credentials without a model credential schema must raise."""
    schemaless_provider = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
    )
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([schemaless_provider]))

    # No model_credential_schema on the provider, so per-model credential
    # validation has no rules to apply and must refuse.
    with pytest.raises(ValueError, match="Provider openai does not have model_credential_schema"):
        factory.model_credentials_validate(
            provider="openai",
            model_type=ModelType.LLM,
            model="gpt-4o-mini",
            credentials={"api_key": "secret"},
        )
|
||||
|
||||
|
||||
def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider() -> None:
|
||||
runtime = _FakeModelRuntime(
|
||||
[
|
||||
|
|
@ -275,3 +401,20 @@ def test_model_provider_factory_builds_model_type_instances(
|
|||
instance = factory.get_model_type_instance("openai", model_type)
|
||||
|
||||
assert isinstance(instance, expected_type)
|
||||
|
||||
|
||||
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
    """Requesting a model-type instance for an unknown type must raise ValueError."""
    llm_provider = _build_provider(
        provider="langgenius/openai/openai",
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
    )
    factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([llm_provider]))

    # Deliberately pass a bogus model type; the factory has no instance class for it.
    with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
        factory.get_model_type_instance("openai", "unsupported")  # type: ignore[arg-type]
|
||||
|
|
|
|||
Loading…
Reference in New Issue