diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 205a27ffdb..766280dbc7 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -111,6 +111,52 @@ def test_model_provider_factory_requires_runtime() -> None: ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] +def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + result = factory.get_providers() + + assert list(result) == providers + assert result is not providers + + +def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup() -> None: + provider = _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + result = factory.get_provider_schema("openai") + + assert result is provider + + +def test_model_provider_factory_raises_for_unknown_provider() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Invalid provider: anthropic"): + factory.get_model_provider("anthropic") + + def test_model_provider_factory_get_models_filters_provider_and_model_type() -> None: providers = [ _build_provider( @@ -135,6 +181,47 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() -> assert [model.model for model in results[0].models] == ["gpt-4o-mini"] +def test_model_provider_factory_get_models_skips_providers_without_requested_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + models=[_build_model("gpt-4o-mini", ModelType.LLM)], + ), + _build_provider( + provider="langgenius/elevenlabs/elevenlabs", + provider_name="elevenlabs", + supported_model_types=[ModelType.TTS], + models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(model_type=ModelType.TTS) + + assert len(results) == 1 + assert results[0].provider == "langgenius/elevenlabs/elevenlabs" + assert [model.model for model in results[0].models] == ["eleven_multilingual_v2"] + + +def test_model_provider_factory_get_models_without_model_type_keeps_all_provider_models() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai") + + assert len(results) == 1 + assert [model.model for model in results[0].models] == ["gpt-4o-mini", "tts-1"] + + def test_model_provider_factory_validates_provider_credentials() -> None: runtime = _FakeModelRuntime( [ @@ -169,6 +256,23 @@ def test_model_provider_factory_validates_provider_credentials() -> None: ) +def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"api_key": "secret"}) + + def test_model_provider_factory_validates_model_credentials() -> None: runtime = _FakeModelRuntime( [ @@ -208,6 +312,28 @@ def test_model_provider_factory_validates_model_credentials() -> None: ) +def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider() -> None: runtime = _FakeModelRuntime( [ @@ -275,3 +401,20 @@ def test_model_provider_factory_builds_model_type_instances( instance = factory.get_model_type_instance("openai", model_type) assert isinstance(instance, expected_type) + + +def test_model_provider_factory_rejects_unsupported_model_type() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Unsupported model type: unsupported"): + factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]