mirror of https://github.com/langgenius/dify.git
Merge d1e477e372 into 27c4faad4f
This commit is contained in:
commit
b28f91945d
|
|
@ -1,7 +1,10 @@
|
|||
"""Testcontainers integration tests for OAuth controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import (
|
||||
OAuthCallback,
|
||||
|
|
@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError
|
|||
|
||||
class TestGetOAuthProviders:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("github_config", "google_config", "expected_github", "expected_google"),
|
||||
|
|
@ -64,10 +65,8 @@ class TestOAuthLogin:
|
|||
return OAuthLogin()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider(self):
|
||||
|
|
@ -131,10 +130,8 @@ class TestOAuthCallback:
|
|||
return OAuthCallback()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def oauth_setup(self):
|
||||
|
|
@ -190,15 +187,8 @@ class TestOAuthCallback:
|
|||
(KeyError("Missing key"), "OAuth process failed"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_should_handle_oauth_exceptions(
|
||||
self, mock_get_providers, mock_db, resource, app, exception, expected_error
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
|
||||
def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
|
||||
# Import the real requests module to create a proper exception
|
||||
import httpx
|
||||
|
||||
|
|
@ -258,7 +248,6 @@ class TestOAuthCallback:
|
|||
)
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
|
|
@ -269,7 +258,6 @@ class TestOAuthCallback:
|
|||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
resource,
|
||||
|
|
@ -278,10 +266,6 @@ class TestOAuthCallback:
|
|||
account_status,
|
||||
expected_redirect,
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
mock_db.session.commit = MagicMock()
|
||||
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
|
@ -306,14 +290,12 @@ class TestOAuthCallback:
|
|||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
def test_should_activate_pending_account(
|
||||
self,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
|
|
@ -338,12 +320,10 @@ class TestOAuthCallback:
|
|||
|
||||
assert mock_account.status == AccountStatus.ACTIVE
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
|
|
@ -352,7 +332,6 @@ class TestOAuthCallback:
|
|||
mock_redirect,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
|
|
@ -425,15 +404,10 @@ class TestAccountGeneration:
|
|||
return account
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.oauth.Session")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_get_account_by_openid_or_email(
|
||||
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
|
||||
self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
|
||||
):
|
||||
# Mock db.engine for Session creation
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Test OpenID found
|
||||
mock_account_model.get_by_openid.return_value = mock_account
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
|
|
@ -443,27 +417,39 @@ class TestAccountGeneration:
|
|||
|
||||
# Test fallback to email lookup
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
|
||||
mock_get_account.assert_called_once()
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
|
||||
mock_session = MagicMock()
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
def test_get_account_by_email_with_case_fallback_uses_real_db(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers
|
||||
):
|
||||
"""Test case-insensitive email lookup against real PostgreSQL."""
|
||||
from uuid import uuid4
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
from models.account import Account
|
||||
|
||||
assert result == expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
test_email = f"Case-{uuid4()}@Test.com"
|
||||
account = Account(
|
||||
email=test_email,
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback(test_email, session=db_session_with_containers)
|
||||
assert result is not None
|
||||
|
||||
result_lower = AccountService.get_account_by_email_with_case_fallback(
|
||||
test_email.lower(), session=db_session_with_containers
|
||||
)
|
||||
assert result_lower is not None
|
||||
assert result_lower.id == account.id
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_register", "existing_account", "should_create"),
|
||||
|
|
@ -478,10 +464,8 @@ class TestAccountGeneration:
|
|||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_handle_account_generation_scenarios(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
|
|
@ -519,10 +503,8 @@ class TestAccountGeneration:
|
|||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_register_with_lowercase_email(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
Loading…
Reference in New Issue