mirror of https://github.com/langgenius/dify.git
test: unit test case for controllers.common module (#32056)
This commit is contained in:
parent
46098b2be6
commit
2cc0de9c1b
|
|
@ -0,0 +1,70 @@
|
|||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
RemoteFileUploadError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
|
||||
|
||||
class TestFilenameNotExistsError:
|
||||
def test_defaults(self):
|
||||
error = FilenameNotExistsError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.description == "The specified filename does not exist."
|
||||
|
||||
|
||||
class TestRemoteFileUploadError:
|
||||
def test_defaults(self):
|
||||
error = RemoteFileUploadError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.description == "Error uploading remote file."
|
||||
|
||||
|
||||
class TestFileTooLargeError:
|
||||
def test_defaults(self):
|
||||
error = FileTooLargeError()
|
||||
|
||||
assert error.code == 413
|
||||
assert error.error_code == "file_too_large"
|
||||
assert error.description == "File size exceeded. {message}"
|
||||
|
||||
|
||||
class TestUnsupportedFileTypeError:
|
||||
def test_defaults(self):
|
||||
error = UnsupportedFileTypeError()
|
||||
|
||||
assert error.code == 415
|
||||
assert error.error_code == "unsupported_file_type"
|
||||
assert error.description == "File type not allowed."
|
||||
|
||||
|
||||
class TestBlockedFileExtensionError:
|
||||
def test_defaults(self):
|
||||
error = BlockedFileExtensionError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "file_extension_blocked"
|
||||
assert error.description == "The file extension is blocked for security reasons."
|
||||
|
||||
|
||||
class TestTooManyFilesError:
|
||||
def test_defaults(self):
|
||||
error = TooManyFilesError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "too_many_files"
|
||||
assert error.description == "Only one file is allowed."
|
||||
|
||||
|
||||
class TestNoFileUploadedError:
|
||||
def test_defaults(self):
|
||||
error = NoFileUploadedError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "no_file_uploaded"
|
||||
assert error.description == "Please upload your file."
|
||||
|
|
@ -1,22 +1,95 @@
|
|||
from flask import Response
|
||||
|
||||
from controllers.common.file_response import enforce_download_for_html, is_html_content
|
||||
from controllers.common.file_response import (
|
||||
_normalize_mime_type,
|
||||
enforce_download_for_html,
|
||||
is_html_content,
|
||||
)
|
||||
|
||||
|
||||
class TestFileResponseHelpers:
|
||||
def test_is_html_content_detects_mime_type(self):
|
||||
class TestNormalizeMimeType:
|
||||
def test_returns_empty_string_for_none(self):
|
||||
assert _normalize_mime_type(None) == ""
|
||||
|
||||
def test_returns_empty_string_for_empty_string(self):
|
||||
assert _normalize_mime_type("") == ""
|
||||
|
||||
def test_normalizes_mime_type(self):
|
||||
assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html"
|
||||
|
||||
|
||||
class TestIsHtmlContent:
|
||||
def test_detects_html_via_mime_type(self):
|
||||
mime_type = "text/html; charset=UTF-8"
|
||||
|
||||
result = is_html_content(mime_type, filename="file.txt", extension="txt")
|
||||
result = is_html_content(
|
||||
mime_type=mime_type,
|
||||
filename="file.txt",
|
||||
extension="txt",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_html_content_detects_extension(self):
|
||||
result = is_html_content("text/plain", filename="report.html", extension=None)
|
||||
def test_detects_html_via_extension_argument(self):
|
||||
result = is_html_content(
|
||||
mime_type="text/plain",
|
||||
filename=None,
|
||||
extension="html",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_enforce_download_for_html_sets_headers(self):
|
||||
def test_detects_html_via_filename_extension(self):
|
||||
result = is_html_content(
|
||||
mime_type="text/plain",
|
||||
filename="report.html",
|
||||
extension=None,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_returns_false_when_no_html_detected_anywhere(self):
|
||||
"""
|
||||
Missing negative test:
|
||||
- MIME type is not HTML
|
||||
- filename has no HTML extension
|
||||
- extension argument is not HTML
|
||||
"""
|
||||
result = is_html_content(
|
||||
mime_type="application/json",
|
||||
filename="data.json",
|
||||
extension="json",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_returns_false_when_all_inputs_are_none(self):
|
||||
result = is_html_content(
|
||||
mime_type=None,
|
||||
filename=None,
|
||||
extension=None,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestEnforceDownloadForHtml:
|
||||
def test_sets_attachment_when_filename_missing(self):
|
||||
response = Response("payload", mimetype="text/html")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
response,
|
||||
mime_type="text/html",
|
||||
filename=None,
|
||||
extension="html",
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
assert response.headers["Content-Disposition"] == "attachment"
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_sets_headers_when_filename_present(self):
|
||||
response = Response("payload", mimetype="text/html")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
|
|
@ -27,11 +100,12 @@ class TestFileResponseHelpers:
|
|||
)
|
||||
|
||||
assert updated is True
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Disposition"].startswith("attachment")
|
||||
assert "unsafe.html" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_enforce_download_for_html_no_change_for_non_html(self):
|
||||
def test_does_not_modify_response_for_non_html_content(self):
|
||||
response = Response("payload", mimetype="text/plain")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,188 @@
|
|||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from controllers.common import helpers
|
||||
from controllers.common.helpers import FileInfo, guess_file_info_from_response
|
||||
|
||||
|
||||
def make_response(
|
||||
url="https://example.com/file.txt",
|
||||
headers=None,
|
||||
content=None,
|
||||
):
|
||||
return httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", url),
|
||||
headers=headers or {},
|
||||
content=content or b"",
|
||||
)
|
||||
|
||||
|
||||
class TestGuessFileInfoFromResponse:
|
||||
def test_filename_from_url(self):
|
||||
response = make_response(
|
||||
url="https://example.com/test.pdf",
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.filename == "test.pdf"
|
||||
assert info.extension == ".pdf"
|
||||
assert info.mimetype == "application/pdf"
|
||||
|
||||
def test_filename_from_content_disposition(self):
|
||||
headers = {
|
||||
"Content-Disposition": "attachment; filename=myfile.csv",
|
||||
"Content-Type": "text/csv",
|
||||
}
|
||||
response = make_response(
|
||||
url="https://example.com/",
|
||||
headers=headers,
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.filename == "myfile.csv"
|
||||
assert info.extension == ".csv"
|
||||
assert info.mimetype == "text/csv"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("magic_available", "expected_ext"),
|
||||
[
|
||||
(True, "txt"),
|
||||
(False, "bin"),
|
||||
],
|
||||
)
|
||||
def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
|
||||
if magic_available:
|
||||
if helpers.magic is None:
|
||||
pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
|
||||
else:
|
||||
monkeypatch.setattr(helpers, "magic", None)
|
||||
|
||||
response = make_response(
|
||||
url="https://example.com/",
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
name, ext = info.filename.split(".")
|
||||
UUID(name)
|
||||
assert ext == expected_ext
|
||||
|
||||
def test_mimetype_from_header_when_unknown(self):
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = make_response(
|
||||
url="https://example.com/file.unknown",
|
||||
headers=headers,
|
||||
content=b'{"a": 1}',
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.mimetype == "application/json"
|
||||
|
||||
def test_extension_added_when_missing(self):
|
||||
headers = {"Content-Type": "image/png"}
|
||||
response = make_response(
|
||||
url="https://example.com/image",
|
||||
headers=headers,
|
||||
content=b"fakepngdata",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.extension == ".png"
|
||||
assert info.filename.endswith(".png")
|
||||
|
||||
def test_content_length_used_as_size(self):
|
||||
headers = {
|
||||
"Content-Length": "1234",
|
||||
"Content-Type": "text/plain",
|
||||
}
|
||||
response = make_response(
|
||||
url="https://example.com/a.txt",
|
||||
headers=headers,
|
||||
content=b"a" * 1234,
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.size == 1234
|
||||
|
||||
def test_size_minus_one_when_header_missing(self):
|
||||
response = make_response(url="https://example.com/a.txt")
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.size == -1
|
||||
|
||||
def test_fallback_to_bin_extension(self):
|
||||
headers = {"Content-Type": "application/octet-stream"}
|
||||
response = make_response(
|
||||
url="https://example.com/download",
|
||||
headers=headers,
|
||||
content=b"\x00\x01\x02\x03",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.extension == ".bin"
|
||||
assert info.filename.endswith(".bin")
|
||||
|
||||
def test_return_type(self):
|
||||
response = make_response()
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert isinstance(info, FileInfo)
|
||||
|
||||
|
||||
class TestMagicImportWarnings:
|
||||
@pytest.mark.parametrize(
|
||||
("platform_name", "expected_message"),
|
||||
[
|
||||
("Windows", "pip install python-magic-bin"),
|
||||
("Darwin", "brew install libmagic"),
|
||||
("Linux", "sudo apt-get install libmagic1"),
|
||||
("Other", "install `libmagic`"),
|
||||
],
|
||||
)
|
||||
def test_magic_import_warning_per_platform(
|
||||
self,
|
||||
monkeypatch,
|
||||
platform_name,
|
||||
expected_message,
|
||||
):
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
# Force ImportError when "magic" is imported
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "magic":
|
||||
raise ImportError("No module named magic")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
monkeypatch.setattr("platform.system", lambda: platform_name)
|
||||
|
||||
# Remove helpers so it imports fresh
|
||||
import sys
|
||||
|
||||
original_helpers = sys.modules.get(helpers.__name__)
|
||||
sys.modules.pop(helpers.__name__, None)
|
||||
|
||||
try:
|
||||
with pytest.warns(UserWarning, match="To use python-magic") as warning:
|
||||
imported_helpers = importlib.import_module(helpers.__name__)
|
||||
assert expected_message in str(warning[0].message)
|
||||
finally:
|
||||
if original_helpers is not None:
|
||||
sys.modules[helpers.__name__] = original_helpers
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
import sys
|
||||
from enum import StrEnum
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class ProductModel(BaseModel):
|
||||
id: int
|
||||
price: float
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_console_ns():
|
||||
"""Mock the console_ns to avoid circular imports during test collection."""
|
||||
mock_ns = MagicMock(spec=Namespace)
|
||||
mock_ns.models = {}
|
||||
|
||||
# Inject mock before importing schema module
|
||||
with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=mock_ns)}):
|
||||
yield mock_ns
|
||||
|
||||
|
||||
def test_default_ref_template_value():
|
||||
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0
|
||||
|
||||
assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}"
|
||||
|
||||
|
||||
def test_register_schema_model_calls_namespace_schema_model():
|
||||
from controllers.common.schema import register_schema_model
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_model(namespace, UserModel)
|
||||
|
||||
namespace.schema_model.assert_called_once()
|
||||
|
||||
model_name, schema = namespace.schema_model.call_args.args
|
||||
|
||||
assert model_name == "UserModel"
|
||||
assert isinstance(schema, dict)
|
||||
assert "properties" in schema
|
||||
|
||||
|
||||
def test_register_schema_model_passes_schema_from_pydantic():
|
||||
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_model(namespace, UserModel)
|
||||
|
||||
schema = namespace.schema_model.call_args.args[1]
|
||||
|
||||
expected_schema = UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
|
||||
assert schema == expected_schema
|
||||
|
||||
|
||||
def test_register_schema_models_registers_multiple_models():
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_models(namespace, UserModel, ProductModel)
|
||||
|
||||
assert namespace.schema_model.call_count == 2
|
||||
|
||||
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
|
||||
assert called_names == ["UserModel", "ProductModel"]
|
||||
|
||||
|
||||
def test_register_schema_models_calls_register_schema_model(monkeypatch):
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_register(ns, model):
|
||||
calls.append((ns, model))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"controllers.common.schema.register_schema_model",
|
||||
fake_register,
|
||||
)
|
||||
|
||||
register_schema_models(namespace, UserModel, ProductModel)
|
||||
|
||||
assert calls == [
|
||||
(namespace, UserModel),
|
||||
(namespace, ProductModel),
|
||||
]
|
||||
|
||||
|
||||
class StatusEnum(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class PriorityEnum(StrEnum):
|
||||
HIGH = "high"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
def test_get_or_create_model_returns_existing_model(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
existing_model = MagicMock()
|
||||
mock_console_ns.models = {"TestModel": existing_model}
|
||||
|
||||
result = get_or_create_model("TestModel", {"key": "value"})
|
||||
|
||||
assert result == existing_model
|
||||
mock_console_ns.model.assert_not_called()
|
||||
|
||||
|
||||
def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
mock_console_ns.models = {}
|
||||
new_model = MagicMock()
|
||||
mock_console_ns.model.return_value = new_model
|
||||
field_def = {"name": {"type": "string"}}
|
||||
|
||||
result = get_or_create_model("NewModel", field_def)
|
||||
|
||||
assert result == new_model
|
||||
mock_console_ns.model.assert_called_once_with("NewModel", field_def)
|
||||
|
||||
|
||||
def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
existing_model = MagicMock()
|
||||
mock_console_ns.models = {"ExistingModel": existing_model}
|
||||
|
||||
result = get_or_create_model("ExistingModel", {"key": "value"})
|
||||
|
||||
assert result == existing_model
|
||||
mock_console_ns.model.assert_not_called()
|
||||
|
||||
|
||||
def test_register_enum_models_registers_single_enum():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum)
|
||||
|
||||
namespace.schema_model.assert_called_once()
|
||||
|
||||
model_name, schema = namespace.schema_model.call_args.args
|
||||
|
||||
assert model_name == "StatusEnum"
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
|
||||
def test_register_enum_models_registers_multiple_enums():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum, PriorityEnum)
|
||||
|
||||
assert namespace.schema_model.call_count == 2
|
||||
|
||||
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
|
||||
assert called_names == ["StatusEnum", "PriorityEnum"]
|
||||
|
||||
|
||||
def test_register_enum_models_uses_correct_ref_template():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum)
|
||||
|
||||
schema = namespace.schema_model.call_args.args[1]
|
||||
|
||||
# Verify the schema contains enum values
|
||||
assert "enum" in schema or "anyOf" in schema
|
||||
Loading…
Reference in New Issue