mirror of https://github.com/langgenius/dify.git
fix: use RetrievalModel type for retrieval_model field in HitTestingPayload (#33750)
This commit is contained in:
parent
c93289e93c
commit
2b8823f38d
|
|
@ -24,6 +24,7 @@ from fields.hit_testing_fields import hit_testing_record_fields
|
|||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class HitTestingPayload(BaseModel):
|
||||
query: str = Field(max_length=250)
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -39,14 +39,21 @@ class TestHitTestingPayload:
|
|||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all optional fields."""
|
||||
retrieval_model_data = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
"top_k": 5,
|
||||
}
|
||||
payload = HitTestingPayload(
|
||||
query="test query",
|
||||
retrieval_model={"top_k": 5},
|
||||
retrieval_model=retrieval_model_data,
|
||||
external_retrieval_model={"provider": "openai"},
|
||||
attachment_ids=["att_1", "att_2"],
|
||||
)
|
||||
assert payload.query == "test query"
|
||||
assert payload.retrieval_model == {"top_k": 5}
|
||||
assert payload.retrieval_model is not None
|
||||
assert payload.retrieval_model.top_k == 5
|
||||
assert payload.external_retrieval_model == {"provider": "openai"}
|
||||
assert payload.attachment_ids == ["att_1", "att_2"]
|
||||
|
||||
|
|
@ -134,7 +141,13 @@ class TestHitTestingApiPost:
|
|||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8}
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": True,
|
||||
"top_k": 10,
|
||||
"score_threshold": 0.8,
|
||||
}
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
|
|
@ -152,7 +165,11 @@ class TestHitTestingApiPost:
|
|||
|
||||
assert response["query"] == "complex query"
|
||||
call_kwargs = mock_hit_svc.retrieve.call_args
|
||||
assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model
|
||||
# retrieval_model is serialized via model_dump, verify key fields
|
||||
passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model")
|
||||
assert passed_retrieval_model is not None
|
||||
assert passed_retrieval_model["search_method"] == "semantic_search"
|
||||
assert passed_retrieval_model["top_k"] == 10
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
|
|
|
|||
Loading…
Reference in New Issue