Merge branch 'main' into 3-18-no-global-loading

Stephen Zhou 2026-03-20 10:57:18 +08:00 committed by GitHub
commit 088a481432
297 changed files with 10091 additions and 2130 deletions

View File

@ -6,8 +6,8 @@ runs:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
with:
node-version-file: "./web/.nvmrc"
node-version-file: web/.nvmrc
cache: true
cache-dependency-path: web/pnpm-lock.yaml
run-install: |
- cwd: ./web
args: ['--frozen-lockfile']
cwd: ./web

View File

@ -2,6 +2,12 @@ name: Run Pytest
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}
@ -11,6 +17,8 @@ jobs:
test:
name: API Tests
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
@ -24,6 +32,7 @@ jobs:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
@ -79,21 +88,12 @@ jobs:
api/tests/test_containers_integration_tests \
api/tests/unit_tests
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
files: ./coverage.xml
disable_search: true
flags: api
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
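For reference, the coverage extraction that the removed summary step performed, expanded from its `python -c` one-liner into plain Python (the `coverage.json` path and keys come straight from the old step):

```python
import json

# Read the JSON report written by coverage.py and pull the display percentage,
# exactly as the deleted "Coverage Summary" step did inline.
with open("coverage.json") as f:
    data = json.load(f)

total = data["totals"]["percent_covered_display"]
print(f"Total Coverage: {total}%")
```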

View File

@ -56,6 +56,7 @@ jobs:
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
secrets: inherit
web-tests:
name: Web Tests

View File

@ -1,9 +1,11 @@
import json
import logging
from typing import Any
from typing import Any, cast
import click
from pydantic import TypeAdapter
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from configs import dify_config
from core.helper import encrypter
@ -48,14 +50,15 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(ToolOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(ToolOAuthSystemClient).where(
ToolOAuthSystemClient.provider == provider_name,
ToolOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@ -97,14 +100,15 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(TriggerOAuthSystemClient).where(
TriggerOAuthSystemClient.provider == provider_name,
TriggerOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@ -139,14 +143,15 @@ def setup_datasource_oauth_client(provider, client_params):
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(DatasourceOauthParamConfig).where(
DatasourceOauthParamConfig.provider == provider_name,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@ -192,7 +197,9 @@ def transform_datasource_credentials(environment: str):
# deal notion credentials
deal_notion_count = 0
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
notion_credentials = db.session.scalars(
select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion")
).all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for notion_credential in notion_credentials:
@ -201,7 +208,7 @@ def transform_datasource_credentials(environment: str):
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
@ -250,7 +257,9 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
firecrawl_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl")
).all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for firecrawl_credential in firecrawl_credentials:
@ -259,7 +268,7 @@ def transform_datasource_credentials(environment: str):
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
@ -312,7 +321,9 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal jina credentials
deal_jina_count = 0
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
jina_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader")
).all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for jina_credential in jina_credentials:
@ -321,7 +332,7 @@ def transform_datasource_credentials(environment: str):
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
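The hunks above replace SQLAlchemy's legacy `Query.filter_by(...).delete()` with an explicit `delete()` statement executed on the session; the matched-row count is read off the returned `CursorResult`, and the `cast(CursorResult, ...)` exists only to satisfy the type checker, since `Session.execute` is typed as returning a generic `Result`. A minimal sketch of the two styles against a toy model (all names here are illustrative):

```python
from sqlalchemy import create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Client(Base):  # stand-in for ToolOAuthSystemClient and friends
    __tablename__ = "clients"
    id: Mapped[int] = mapped_column(primary_key=True)
    provider: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Client(provider="notion"), Client(provider="jina")])
    session.flush()
    # Legacy 1.x style: Query.delete() returned the matched row count.
    # deleted = session.query(Client).filter_by(provider="notion").delete()
    # 2.0 style: execute a delete() statement; rowcount lives on the result.
    result = session.execute(delete(Client).where(Client.provider == "notion"))
    print(result.rowcount)  # 1
```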

View File

@ -1,7 +1,10 @@
import json
from typing import cast
import click
import sqlalchemy as sa
from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from configs import dify_config
from extensions.ext_database import db
@ -740,14 +743,17 @@ def migrate_oss(
else:
try:
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
updated = (
db.session.query(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
)
updated = cast(
CursorResult,
db.session.execute(
update(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.values(storage_type=dify_config.STORAGE_TYPE)
),
).rowcount
db.session.commit()
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
except Exception as e:

View File

@ -2,6 +2,7 @@ import logging
import click
import sqlalchemy as sa
from sqlalchemy import delete, select, update
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@ -41,7 +42,7 @@ def reset_encrypt_key_pair():
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.query(Tenant).all()
tenants = session.scalars(select(Tenant)).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
@ -49,8 +50,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
click.echo(
click.style(
@ -93,7 +94,7 @@ def convert_to_agent_apps():
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.query(App).where(App.id == app_id).first()
app = db.session.scalar(select(App).where(App.id == app_id))
if app is not None:
apps.append(app)
@ -108,8 +109,8 @@ def convert_to_agent_apps():
db.session.commit()
# update conversation mode to agent
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT}
db.session.execute(
update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT)
)
db.session.commit()
@ -177,7 +178,7 @@ where sites.id is null limit 1000"""
continue
try:
app = db.session.query(App).where(App.id == app_id).first()
app = db.session.scalar(select(App).where(App.id == app_id))
if not app:
logger.info("App %s not found", app_id)
continue
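These hunks apply the same 2.0-style migration to reads: `session.scalar(select(...))` for a single row (legacy `.first()`) and `session.scalars(select(...)).all()` for lists (legacy `.all()`). A self-contained sketch with a stand-in model:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class App(Base):  # illustrative stand-in for the real model
    __tablename__ = "apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(App(id=1, status="normal"))
    session.flush()
    app = session.scalar(select(App).where(App.id == 1))  # one row or None
    apps = session.scalars(select(App).where(App.status == "normal")).all()
    print(app is apps[0])  # True: both resolve through the identity map
```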

View File

@ -41,14 +41,13 @@ def migrate_annotation_vector_database():
# get apps info
per_page = 50
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = (
session.query(App)
apps = session.scalars(
select(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
).all()
if not apps:
break
except SQLAlchemyError:
@ -63,8 +62,8 @@ def migrate_annotation_vector_database():
try:
click.echo(f"Creating app annotation index: {app.id}")
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1)
)
if not app_annotation_setting:
@ -72,10 +71,10 @@ def migrate_annotation_vector_database():
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
dataset_collection_binding = session.scalar(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
)
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
@ -205,11 +204,11 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
dataset_collection_binding = db.session.execute(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == dataset.collection_binding_id
)
).scalar_one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
@ -334,7 +333,7 @@ def add_qdrant_index(field: str):
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
bindings = db.session.scalars(select(DatasetCollectionBinding)).all()
if not bindings:
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
@ -421,10 +420,10 @@ def old_metadata_migration():
if field.value == key:
break
else:
dataset_metadata = (
db.session.query(DatasetMetadata)
dataset_metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first()
.limit(1)
)
if not dataset_metadata:
dataset_metadata = DatasetMetadata(
@ -436,7 +435,7 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata)
db.session.flush()
dataset_metadata_binding = DatasetMetadataBinding(
dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding(
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
metadata_id=dataset_metadata.id,
@ -445,14 +444,14 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata_binding)
else:
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
dataset_metadata_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.first()
.limit(1)
)
if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding(

View File

@ -30,6 +30,7 @@ from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@ -335,7 +336,7 @@ class MessageFeedbackApi(Resource):
if not args.rating and feedback:
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.rating = FeedbackRating(args.rating)
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
@ -347,9 +348,9 @@ class MessageFeedbackApi(Resource):
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
rating=FeedbackRating(rating_value),
content=args.content,
from_source="admin",
from_source=FeedbackFromSource.ADMIN,
from_account_id=current_user.id,
)
db.session.add(feedback)

View File

@ -298,6 +298,7 @@ class DatasetDocumentListApi(Resource):
if sort == "hit_count":
sub_query = (
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)

View File

@ -24,6 +24,7 @@ from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
retrieval_model: RetrievalModel | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None

View File

@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource):
app_model=app_model,
message_id=message_id,
user=current_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_model = None
if is_anonymous:
user_model = (
session.query(EndUser)
user_model = session.scalar(
select(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
.limit(1)
)
else:
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
user_model = session.get(EndUser, user_id)
if not user_model:
user_model = EndUser(
@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]):
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
tenant_model = db.session.get(Tenant, tenant_id)
if not tenant_model:
raise ValueError("tenant not found")
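Both lookups above are by primary key, so the rewrite uses `Session.get()`, which can be answered from the identity map without a round trip and returns `None` for a missing key. A small sketch with an illustrative model:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Tenant(Base):  # illustrative stand-in
    __tablename__ = "tenants"
    id: Mapped[int] = mapped_column(primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Tenant(id=7))
    session.flush()
    print(session.get(Tenant, 7))  # already in the identity map: no extra SELECT
    print(session.get(Tenant, 8))  # None when the primary key is absent
```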

View File

@ -2,6 +2,7 @@ import json
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from controllers.common.schema import register_schema_models
from controllers.console.wraps import setup_required
@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource):
def post(self):
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
account = db.session.query(Account).filter_by(email=args.owner_email).first()
account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1))
if account is None:
return {"message": "owner account not found."}, 404

View File

@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
if signature_base64 != token:
return view(*args, **kwargs)
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
kwargs["user"] = db.session.get(EndUser, user_id)
return view(*args, **kwargs)

View File

@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource):
app_model=app_model,
message_id=message_id,
user=end_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@ -8,6 +8,7 @@ from datetime import datetime
from flask import Response, request
from flask_restx import Resource, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -147,11 +148,11 @@ class HumanInputFormApi(Resource):
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
"""Resolve App/Site for the form's app and validate tenant status."""
app_model = db.session.query(App).where(App.id == form.app_id).first()
app_model = db.session.get(App, form.app_id)
if app_model is None or app_model.tenant_id != form.tenant_id:
raise NotFoundError("Form not found")
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if site is None:
raise Forbidden()

View File

@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import uuid_value
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
app_model=app_model,
message_id=message_id,
user=end_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@ -1,6 +1,7 @@
from typing import cast
from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user):
"""Retrieve app site info."""
# get site
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise Forbidden()

View File

@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole, MessageStatus
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
upload_file_id=file["related_id"],
created_by_role=CreatorUserRole.ACCOUNT
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}

View File

@ -74,11 +74,22 @@ class AppGenerateResponseConverter(ABC):
for resource in metadata["retriever_resources"]:
updated_resources.append(
{
"dataset_id": resource.get("dataset_id"),
"dataset_name": resource.get("dataset_name"),
"document_id": resource.get("document_id"),
"segment_id": resource.get("segment_id", ""),
"position": resource["position"],
"data_source_type": resource.get("data_source_type"),
"document_name": resource["document_name"],
"score": resource["score"],
"hit_count": resource.get("hit_count"),
"word_count": resource.get("word_count"),
"segment_position": resource.get("segment_position"),
"index_node_hash": resource.get("index_node_hash"),
"content": resource["content"],
"page": resource.get("page"),
"title": resource.get("title"),
"files": resource.get("files"),
"summary": resource.get("summary"),
}
)
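The expanded mapping distinguishes required retriever-resource fields (`resource[...]`, which raises `KeyError` when absent) from optional ones (`resource.get(...)`, which degrades to `None`). A two-line illustration with made-up data:

```python
resource = {"position": 1, "document_name": "doc.md", "score": 0.92, "content": "..."}

print(resource["position"])        # required field: KeyError if the key is missing
print(resource.get("dataset_id"))  # optional field: None when absent
```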

View File

@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import (
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
@ -419,7 +419,7 @@ class AppRunner:
message_id=message_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
url=f"/files/tools/{tool_file.id}",
upload_file_id=tool_file.id,
created_by_role=(

View File

@ -517,7 +517,7 @@ class WorkflowResponseConverter:
snapshot = self._pop_snapshot(event.node_execution_id)
start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now()
finished_at = event.finished_at or naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)

View File

@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),

View File

@ -456,6 +456,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -471,6 +472,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
@ -487,6 +489,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,

View File

@ -335,6 +335,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
@ -390,6 +391,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
@ -414,6 +416,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)

View File

@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.enums import MessageFileBelongsTo
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
@ -233,7 +234,7 @@ class MessageCycleManager:
task_id=self._application_generate_entity.task_id,
id=message_file.id,
type=message_file.type,
belongs_to=message_file.belongs_to or "user",
belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER,
url=url,
)

View File

@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event.finished_at,
)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result,
WorkflowNodeExecutionStatus.FAILED,
error=event.error,
finished_at=event.finished_at,
)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result,
WorkflowNodeExecutionStatus.EXCEPTION,
error=event.error,
finished_at=event.finished_at,
)
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
*,
error: str | None = None,
update_outputs: bool = True,
finished_at: datetime | None = None,
) -> None:
finished_at = naive_utc_now()
actual_finished_at = finished_at or naive_utc_now()
snapshot = self._node_snapshots.get(domain_execution.id)
start_at = snapshot.created_at if snapshot else domain_execution.created_at
domain_execution.status = status
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
domain_execution.finished_at = actual_finished_at
domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0)
if error:
domain_execution.error = error
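The persistence layer now prefers the `finished_at` carried on the event and falls back to the wall clock only when the field is absent, so elapsed time reflects when the node actually finished rather than when the event was persisted. A sketch of that fallback (helper names mirror the code above; the tuple return is illustrative):

```python
from datetime import UTC, datetime

def naive_utc_now() -> datetime:
    return datetime.now(UTC).replace(tzinfo=None)

def close_out(start_at: datetime, finished_at: datetime | None) -> tuple[datetime, float]:
    # Prefer the event-supplied timestamp; fall back to "now" only for
    # events emitted before finished_at existed on the event types.
    actual = finished_at or naive_utc_now()
    return actual, max((actual - start_at).total_seconds(), 0.0)

print(close_out(naive_utc_now(), None))
```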

View File

@ -15,6 +15,7 @@ from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
@ -81,7 +82,7 @@ class DatasourceFileManager:
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=filepath,
name=present_filename,
size=len(file_binary),

View File

@ -1,9 +1,10 @@
import re
from typing import Any
class CleanProcessor:
@classmethod
def clean(cls, text: str, process_rule: dict) -> str:
def clean(cls, text: str, process_rule: dict[str, Any] | None) -> str:
# default clean
# remove invalid symbol
text = re.sub(r"<\|", "<", text)

View File

@ -4,6 +4,7 @@ from typing import Any
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@ -15,6 +16,11 @@ from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class PreSegmentData(TypedDict):
segment: DocumentSegment
keywords: list[str]
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
@ -128,7 +134,7 @@ class Jieba(BaseKeyword):
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
storage.delete(file_key)
def _save_dataset_keyword_table(self, keyword_table):
def _save_dataset_keyword_table(self, keyword_table: dict[str, set[str]] | None):
keyword_table_dict = {
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
@ -144,7 +150,7 @@ class Jieba(BaseKeyword):
storage.delete(file_key)
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> dict | None:
def _get_dataset_keyword_table(self) -> dict[str, set[str]] | None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
@ -169,14 +175,16 @@ class Jieba(BaseKeyword):
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]):
def _add_text_to_keyword_table(
self, keyword_table: dict[str, set[str]], id: str, keywords: list[str]
) -> dict[str, set[str]]:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]):
def _delete_ids_from_keyword_table(self, keyword_table: dict[str, set[str]], ids: list[str]) -> dict[str, set[str]]:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
@ -193,7 +201,7 @@ class Jieba(BaseKeyword):
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
def _retrieve_ids_by_query(self, keyword_table: dict[str, set[str]], query: str, k: int = 4) -> list[str]:
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
@ -228,7 +236,7 @@ class Jieba(BaseKeyword):
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
def multi_create_segment_keywords(self, pre_segment_data_list: list[PreSegmentData]):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
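The new `PreSegmentData` TypedDict and the `dict[str, set[str]]` annotations make the previously untyped keyword-table plumbing checkable. A runnable sketch of the shape, using a stand-in for the ORM model:

```python
from typing import TypedDict

class DocumentSegment:  # stand-in for the ORM model
    def __init__(self, index_node_id: str) -> None:
        self.index_node_id = index_node_id

class PreSegmentData(TypedDict):
    segment: DocumentSegment
    keywords: list[str]

def add_all(table: dict[str, set[str]], items: list[PreSegmentData]) -> dict[str, set[str]]:
    # Mirrors _add_text_to_keyword_table: map each keyword to its node-id set.
    for item in items:
        for keyword in item["keywords"]:
            table.setdefault(keyword, set()).add(item["segment"].index_node_id)
    return table

print(add_all({}, [{"segment": DocumentSegment("n1"), "keywords": ["alpha", "beta"]}]))
```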

View File

@ -103,7 +103,7 @@ class RetrievalService:
reranking_mode: str = "reranking_model",
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
attachment_ids: list[str] | None = None,
):
if not query and not attachment_ids:
return []
@ -250,8 +250,8 @@ class RetrievalService:
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
exceptions: list,
all_documents: list[Document],
exceptions: list[str],
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():
@ -279,9 +279,9 @@ class RetrievalService:
top_k: int,
score_threshold: float | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
all_documents: list[Document],
retrieval_method: RetrievalMethod,
exceptions: list,
exceptions: list[str],
document_ids_filter: list[str] | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
):
@ -373,9 +373,9 @@ class RetrievalService:
top_k: int,
score_threshold: float | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
all_documents: list[Document],
retrieval_method: str,
exceptions: list,
exceptions: list[str],
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():

View File

@ -15,6 +15,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
@ -150,7 +151,7 @@ class PdfExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self._tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=len(img_bytes),

View File

@ -21,6 +21,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
@ -112,7 +113,7 @@ class WordExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=0,
@ -140,7 +141,7 @@ class WordExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=0,
@ -365,7 +366,7 @@ class WordExtractor(BaseExtractor):
paragraph_content = []
# State for legacy HYPERLINK fields
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
hyperlink_field_text_parts: list[str] = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:

View File

@ -591,7 +591,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
query: str,
available_datasets: list,
available_datasets: list[Dataset],
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@ -633,15 +633,15 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset:
selected_dataset = db.session.scalar(dataset_stmt)
if selected_dataset:
results = []
if dataset.provider == "external":
if selected_dataset.provider == "external":
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
tenant_id=selected_dataset.tenant_id,
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
external_retrieval_parameters=selected_dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
@ -654,28 +654,28 @@ class DatasetRetrieval:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
document.metadata["dataset_name"] = selected_dataset.name
results.append(document)
else:
if metadata_condition and not metadata_filter_document_ids:
return []
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
document_ids = metadata_filter_document_ids.get(selected_dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
return []
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
cast(DefaultRetrievalModelDict, selected_dataset.retrieval_model)
if selected_dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
if selected_dataset.indexing_technique == "economy":
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
@ -694,7 +694,7 @@ class DatasetRetrieval:
with measure_time() as timer:
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
dataset_id=selected_dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
@ -726,7 +726,7 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
available_datasets: list[Dataset],
query: str | None,
top_k: int,
score_threshold: float,
@ -1028,7 +1028,7 @@ class DatasetRetrieval:
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
all_documents: list[Document],
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
@ -1298,7 +1298,7 @@ class DatasetRetrieval:
def get_metadata_filter_condition(
self,
dataset_ids: list,
dataset_ids: list[str],
query: str,
tenant_id: str,
user_id: str,
@ -1400,7 +1400,7 @@ class DatasetRetrieval:
return output
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
self, dataset_ids: list[str], query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> list[dict[str, Any]] | None:
# get all metadata field
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
@ -1598,7 +1598,7 @@ class DatasetRetrieval:
)
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list[str], query: str
):
model_mode = ModelMode(mode)
input_text = query
@ -1690,7 +1690,7 @@ class DatasetRetrieval:
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
available_datasets: list[Dataset],
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],

View File

@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
from dify_graph.file import FileType
from dify_graph.file.models import FileTransferMethod
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import Message, MessageFile
logger = logging.getLogger(__name__)
@ -352,7 +352,7 @@ class ToolEngine:
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(

View File

@ -3,7 +3,6 @@ from typing import Final
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
{

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, cast
from typing import Any
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
# Get trigger data passed when workflow was triggered
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self.node_data.provider_id,
"event_name": self.node_data.event_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,

View File

@ -245,6 +245,9 @@ _END_STATE = frozenset(
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
Values in this enum are persisted as execution metadata and must stay in sync
with every node that writes `NodeRunResult.metadata`.
"""
TOTAL_TOKENS = "total_tokens"
@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
TRIGGER_INFO = "trigger_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node

View File

@ -159,6 +159,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
@ -198,6 +199,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,

View File

@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from dify_graph.context import IExecutionContext
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from libs.datetime_utils import naive_utc_now
from .ready_queue import ReadyQueue
@ -65,6 +68,7 @@ class Worker(threading.Thread):
self._stop_event = threading.Event()
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
self._current_node_started_at: datetime | None = None
def stop(self) -> None:
"""Signal the worker to stop processing."""
@ -104,18 +108,15 @@ class Worker(threading.Thread):
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
try:
self._current_node_started_at = None
self._execute_node(node)
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),
self._event_queue.put(
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
)
self._event_queue.put(error_event)
finally:
self._current_node_started_at = None
def _execute_node(self, node: Node) -> None:
"""
@ -136,6 +137,8 @@ class Worker(threading.Thread):
try:
node_events = node.run()
for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
@ -149,6 +152,8 @@ class Worker(threading.Thread):
try:
node_events = node.run()
for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
@ -177,3 +182,24 @@ class Worker(threading.Thread):
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _build_fallback_failure_event(
self, node: Node, error: Exception, *, started_at: datetime | None = None
) -> NodeRunFailedEvent:
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
failure_time = naive_utc_now()
error_message = str(error)
return NodeRunFailedEvent(
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=error_message,
start_at=started_at or failure_time,
finished_at=failure_time,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
error_type=type(error).__name__,
),
)

View File

@ -36,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
class NodeRunSucceededEvent(GraphNodeEventBase):
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunFailedEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunExceptionEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunRetryEvent(NodeRunStartedEvent):

View File

@ -406,11 +406,13 @@ class Node(Generic[NodeDataT]):
error=str(e),
error_type="WorkflowNodeError",
)
finished_at = naive_utc_now()
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
error=str(e),
)
@ -568,6 +570,7 @@ class Node(Generic[NodeDataT]):
return self._node_data
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
finished_at = naive_utc_now()
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
@ -575,6 +578,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
error=result.error,
)
@ -584,6 +588,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
)
case _:
@ -606,6 +611,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
finished_at = naive_utc_now()
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
@ -613,6 +619,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result,
)
case WorkflowNodeExecutionStatus.FAILED:
@ -621,6 +628,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result,
error=event.node_run_result.error,
)

View File

@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
future_to_index: dict[
Future[
tuple[
datetime,
float,
list[GraphNodeEventBase],
object | None,
dict[str, Variable],
@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
try:
result = future.result()
(
iter_start_at,
iteration_duration,
events,
output_value,
conversation_snapshot,
@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Yield all events from this iteration
yield from events
# Update tokens and timing
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
# The worker computes duration before we replay buffered events here,
# so slow downstream consumers don't inflate per-iteration timing.
iter_run_map[str(index)] = iteration_duration
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
index: int,
item: object,
execution_context: "IExecutionContext",
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
conversation_snapshot = self._extract_conversation_variable_snapshot(
variable_pool=graph_engine.graph_runtime_state.variable_pool
)
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
return (
iter_start_at,
iteration_duration,
events,
output_value,
conversation_snapshot,
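As the new comment notes, each parallel iteration now reports its own duration instead of its start time, so replaying buffered events slowly can no longer inflate per-iteration timing. A toy illustration (using monotonic time for brevity, where the code above uses naive UTC datetimes):

```python
import time
from collections.abc import Callable

def run_iteration(work: Callable[[], list[str]]) -> tuple[float, list[str]]:
    start = time.monotonic()
    events = list(work())                # events are buffered while the iteration runs
    duration = time.monotonic() - start  # measured before anyone replays them
    return duration, events

duration, events = run_iteration(lambda: ["started", "succeeded"])
print(f"{duration:.6f}s", events)  # consumers can replay `events` at any pace
```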

View File

@ -3,6 +3,7 @@ import logging
import time
import click
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@ -24,13 +25,11 @@ def handle(sender, **kwargs):
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document)
.where(
document = db.session.scalar(
select(Document).where(
Document.id == document_id,
Document.dataset_id == dataset_id,
)
.first()
)
if not document:

View File

@ -1,6 +1,6 @@
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy import delete, select
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
db.session.execute(
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
)
if added_dataset_ids:
for dataset_id in added_dataset_ids:

View File

@ -1,6 +1,6 @@
from typing import cast
from sqlalchemy import select
from sqlalchemy import delete, select
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from dify_graph.nodes import BuiltinNodeTypes
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
db.session.execute(
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
)
if added_dataset_ids:
for dataset_id in added_dataset_ids:

View File

@ -3,6 +3,7 @@ import json
import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@ -34,16 +35,15 @@ def load_user_from_request(request_from_flask_login):
if admin_api_key and admin_api_key == auth_token:
workspace_id = request.headers.get("X-WORKSPACE-ID")
if workspace_id:
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == workspace_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role == "owner")
.one_or_none()
)
).one_or_none()
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter_by(id=ta.account_id).first()
account = db.session.scalar(select(Account).where(Account.id == ta.account_id))
if account:
account.current_tenant = tenant
return account
@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound("End user not found.")
return end_user
@ -80,7 +80,7 @@ def load_user_from_request(request_from_flask_login):
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound("End user not found.")
return end_user
@ -90,11 +90,11 @@ def load_user_from_request(request_from_flask_login):
server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code:
raise Unauthorized("Invalid Authorization token.")
app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
app_mcp_server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
if not app_mcp_server:
raise NotFound("App MCP server not found.")
end_user = (
db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").limit(1)
)
if not end_user:
raise NotFound("End user not found.")
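Selecting two entities, as in the admin-key branch above, yields `Row` objects that unpack positionally, which is why the result is destructured into `tenant, ta`. A compact runnable sketch with toy models:

```python
from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Tenant(Base):
    __tablename__ = "tenants"
    id: Mapped[int] = mapped_column(primary_key=True)

class TenantAccountJoin(Base):
    __tablename__ = "tenant_account_joins"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[int] = mapped_column(ForeignKey("tenants.id"))
    role: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Tenant(id=1))
    session.add(TenantAccountJoin(tenant_id=1, role="owner"))
    session.flush()
    row = session.execute(
        select(Tenant, TenantAccountJoin)
        .where(TenantAccountJoin.tenant_id == Tenant.id, TenantAccountJoin.role == "owner")
    ).one_or_none()
    if row is not None:
        tenant, join = row  # Row unpacks into the two selected entities
        print(tenant.id, join.role)
```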

View File

@ -32,7 +32,7 @@ class OpenDALStorage(BaseStorage):
kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)
if scheme == "fs":
root = kwargs.get("root", "storage")
root = kwargs.setdefault("root", "storage")
Path(root).mkdir(parents=True, exist_ok=True)
retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
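The one-line fix swaps `dict.get` for `dict.setdefault`: with `get`, a missing `root` key meant the default directory was created on disk but never written back into `kwargs`, so the operator presumably never received it. The semantic difference in isolation:

```python
kwargs: dict[str, str] = {}

root = kwargs.get("root", "storage")  # returns "storage", but kwargs is still {}
print(kwargs)                         # {}

root = kwargs.setdefault("root", "storage")  # returns "storage" AND records it
print(kwargs)                                # {'root': 'storage'}
```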

View File

@ -424,13 +424,11 @@ def _build_from_datasource_file(
datasource_file_id = mapping.get("datasource_file_id")
if not datasource_file_id:
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
datasource_file = (
db.session.query(UploadFile)
.where(
datasource_file = db.session.scalar(
select(UploadFile).where(
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
.first()
)
if datasource_file is None:

View File

@ -158,6 +158,13 @@ class FeedbackFromSource(StrEnum):
ADMIN = "admin"
class FeedbackRating(StrEnum):
"""MessageFeedback rating"""
LIKE = "like"
DISLIKE = "dislike"
class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""

View File

@ -23,6 +23,7 @@ from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
from libs.uuid_utils import uuidv7
@ -35,7 +36,10 @@ from .enums import (
BannerStatus,
ConversationStatus,
CreatorUserRole,
FeedbackFromSource,
FeedbackRating,
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
)
from .provider_ids import GenericProviderID
@ -1164,7 +1168,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
or 0
@ -1175,7 +1179,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
or 0
@ -1190,7 +1194,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
or 0
@ -1201,7 +1205,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
or 0
@ -1724,8 +1728,8 @@ class MessageFeedback(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
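These column changes are backward compatible with existing rows because StrEnum members are str subclasses whose values match the strings previously stored ("like", "dislike", "user", "admin"); the assumption here is that EnumText round-trips columns through the enum's string value. A self-contained check of the standard-library behavior this relies on:

from enum import StrEnum

class FeedbackRating(StrEnum):
    LIKE = "like"
    DISLIKE = "dislike"

assert FeedbackRating.LIKE == "like"                  # str subclass: compares equal to stored text
assert FeedbackRating("like") is FeedbackRating.LIKE  # stored text maps back to the member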
@ -1778,7 +1782,9 @@ class MessageFile(TypeBase):
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
)
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
@ -2108,7 +2114,7 @@ class UploadFile(Base):
# The `server_default` serves as a fallback mechanism.
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@ -2152,7 +2158,7 @@ class UploadFile(Base):
self,
*,
tenant_id: str,
storage_type: str,
storage_type: StorageType,
key: str,
name: str,
size: int,

View File

@ -22,14 +22,14 @@ from sqlalchemy import (
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
@ -936,8 +936,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
datasource_info = execution_metadata["datasource_info"]
extras["icon"] = datasource_info.get("icon")
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
elif (
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
):
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
provider_id = trigger_info.get("provider_id")
if provider_id:
extras["icon"] = TriggerManager.get_trigger_plugin_icon(

View File

@ -1,6 +1,6 @@
[pytest]
pythonpath = .
addopts = --cov=./api --cov-report=json --import-mode=importlib
addopts = --cov=./api --cov-report=json --import-mode=importlib --cov-branch --cov-report=xml
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com

View File

@ -3,6 +3,7 @@ import math
import time
import click
from sqlalchemy import select
import app
from core.helper.marketplace import fetch_global_plugin_manifest
@ -28,17 +29,15 @@ def check_upgradable_plugin_task():
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green"))
strategies = (
db.session.query(TenantPluginAutoUpgradeStrategy)
.where(
strategies = db.session.scalars(
select(TenantPluginAutoUpgradeStrategy).where(
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
TenantPluginAutoUpgradeStrategy.strategy_setting
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
)
.all()
)
).all()
total_strategies = len(strategies)
click.echo(click.style(f"Total strategies: {total_strategies}", fg="green"))

View File

@ -2,7 +2,7 @@ import datetime
import time
import click
from sqlalchemy import text
from sqlalchemy import select, text
from sqlalchemy.exc import SQLAlchemyError
import app
@ -19,14 +19,12 @@ def clean_embedding_cache_task():
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
while True:
try:
embedding_ids = (
db.session.query(Embedding.id)
embedding_ids = db.session.scalars(
select(Embedding.id)
.where(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc())
.limit(100)
.all()
)
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
).all()
except SQLAlchemyError:
raise
if embedding_ids:
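The dropped unpacking line is a direct consequence of the API switch: execute() returns Row tuples, while scalars() yields the first column of each row, so select(Embedding.id) already produces a flat list of ids.

rows = db.session.execute(select(Embedding.id).limit(100)).all()  # [(id,), (id,), ...]
ids = db.session.scalars(select(Embedding.id).limit(100)).all()   # [id, id, ...]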

View File

@ -3,7 +3,7 @@ import time
from typing import TypedDict
import click
from sqlalchemy import func, select
from sqlalchemy import func, select, update
from sqlalchemy.exc import SQLAlchemyError
import app
@ -51,7 +51,7 @@ def clean_unused_datasets_task():
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
select(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
@ -64,7 +64,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
select(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
@ -142,8 +142,8 @@ def clean_unused_datasets_task():
index_processor.clean(dataset, None)
# Update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
{Document.enabled: False}
db.session.execute(
update(Document).where(Document.dataset_id == dataset.id).values(enabled=False)
)
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))

View File

@ -1,6 +1,7 @@
import time
import click
from sqlalchemy import func, select
import app
from configs import dify_config
@ -20,7 +21,7 @@ def create_tidb_serverless_task():
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
db.session.scalar(select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)) or 0
)
if idle_tidb_serverless_number >= tidb_serverless_number:
break
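On the count rewrite above: Query.count() wraps the original query in a subselect (SELECT count(*) FROM (...)), whereas select(func.count(...)) queries the aggregate directly. Session.scalar() is typed as Optional, hence the defensive `or 0` to keep the comparison operating on an int:

idle_tidb_serverless_number = db.session.scalar(
    select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)
) or 0  # count() always yields a row; `or 0` only guards the Optional typing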

View File

@ -49,16 +49,18 @@ def mail_clean_document_notify_task():
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
# check current owner
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if not current_owner_join:
continue
account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
if not account:
continue
@ -71,7 +73,7 @@ def mail_clean_document_notify_task():
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

View File

@ -7,6 +7,7 @@ from flask import Response
from sqlalchemy import or_
from extensions.ext_database import db
from models.enums import FeedbackRating
from models.model import Account, App, Conversation, Message, MessageFeedback
@ -100,7 +101,7 @@ class FeedbackService:
"ai_response": message.answer[:500] + "..."
if len(message.answer) > 500
else message.answer, # Truncate long responses
"feedback_rating": "👍" if feedback.rating == "like" else "👎",
"feedback_rating": "👍" if feedback.rating == FeedbackRating.LIKE else "👎",
"feedback_rating_raw": feedback.rating,
"feedback_comment": feedback.content or "",
"feedback_source": feedback.from_source,

View File

@ -23,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from dify_graph.file import helpers as file_helpers
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from models import Account
@ -93,7 +94,7 @@ class FileService:
# save file to db
upload_file = UploadFile(
tenant_id=current_tenant_id or "",
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=filename,
size=file_size,
@ -152,7 +153,7 @@ class FileService:
# save file to db
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=text_name,
size=len(text),

View File

@ -16,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
@ -172,7 +173,7 @@ class MessageService:
app_model: App,
message_id: str,
user: Union[Account, EndUser] | None,
rating: str | None,
rating: FeedbackRating | None,
content: str | None,
):
if not user:
@ -197,7 +198,7 @@ class MessageService:
message_id=message.id,
rating=rating,
content=content,
from_source=("user" if isinstance(user, EndUser) else "admin"),
from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN),
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
from_account_id=(user.id if isinstance(user, Account) else None),
)

View File

@ -14,6 +14,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, MessageFeedback
from services.feedback_service import FeedbackService
@ -77,8 +78,8 @@ class TestFeedbackExportApi:
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content=None,
from_end_user_id=str(uuid.uuid4()),
from_account_id=None,
@ -90,8 +91,8 @@ class TestFeedbackExportApi:
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="dislike",
from_source="admin",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.ADMIN,
content="The response was not helpful",
from_end_user_id=None,
from_account_id=str(uuid.uuid4()),
@ -277,8 +278,8 @@ class TestFeedbackExportApi:
# Verify service was called with correct parameters
mock_export_feedbacks.assert_called_once_with(
app_id=mock_app_model.id,
from_source="user",
rating="dislike",
from_source=FeedbackFromSource.USER,
rating=FeedbackRating.DISLIKE,
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",

View File

@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=storage_key,
name="test_file.txt",
size=1024,
@ -288,7 +289,7 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="other_tenant_key",
name="other_file.txt",
size=1024,

View File

@ -13,6 +13,7 @@ from dify_graph.variables.types import SegmentType
from dify_graph.variables.variables import StringVariable
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from factories.variable_factory import build_segment
from libs import datetime_utils
from models.enums import CreatorUserRole
@ -347,7 +348,7 @@ class TestDraftVariableLoader(unittest.TestCase):
# Create an upload file record
upload_file = UploadFile(
tenant_id=self._test_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_offload_{uuid.uuid4()}.json",
name="test_offload.json",
size=len(content_bytes),
@ -450,7 +451,7 @@ class TestDraftVariableLoader(unittest.TestCase):
# Create upload file record
upload_file = UploadFile(
tenant_id=self._test_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_integration_{uuid.uuid4()}.txt",
name="test_integration.txt",
size=len(content_bytes),

View File

@ -6,6 +6,7 @@ from sqlalchemy import delete
from core.db.session_factory import session_factory
from dify_graph.variables.segments import StringSegment
from extensions.storage.storage_type import StorageType
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
@ -197,7 +198,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file1.json",
name="file1.json",
size=1024,
@ -210,7 +211,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file2.json",
name="file2.json",
size=2048,
@ -430,7 +431,7 @@ class TestDeleteDraftVariablesSessionCommit:
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file1.json",
name="file1.json",
size=1024,
@ -443,7 +444,7 @@ class TestDeleteDraftVariablesSessionCommit:
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file2.json",
name="file2.json",
size=2048,

View File

@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=storage_key,
name="test_file.txt",
size=1024,
@ -289,7 +290,7 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="other_tenant_key",
name="other_file.txt",
size=1024,

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
@ -198,7 +199,7 @@ class DocumentStatusTestDataFactory:
"""
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"uploads/{uuid4()}",
name=name,
size=128,

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models.enums import MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
@ -852,7 +853,7 @@ class TestAgentService:
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file1.jpg",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
@ -861,7 +862,7 @@ class TestAgentService:
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file2.png",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import pytest
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
@ -83,7 +84,7 @@ def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, n
"""Persist an upload file row referenced by document.data_source_info."""
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"uploads/{uuid4()}",
name=name,
size=128,

View File

@ -8,6 +8,7 @@ from unittest import mock
import pytest
from extensions.ext_database import db
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
@ -47,8 +48,8 @@ class TestFeedbackService:
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content="Great answer!",
from_end_user_id="user-123",
from_account_id=None,
@ -61,8 +62,8 @@ class TestFeedbackService:
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="admin",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.ADMIN,
content="Could be more detailed",
from_end_user_id=None,
from_account_id="admin-456",
@ -179,8 +180,8 @@ class TestFeedbackService:
# Test with filters
result = FeedbackService.export_feedbacks(
app_id=sample_data["app"].id,
from_source="admin",
rating="dislike",
from_source=FeedbackFromSource.ADMIN,
rating=FeedbackRating.DISLIKE,
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",
@ -293,8 +294,8 @@ class TestFeedbackService:
app_id=sample_data["app"].id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="user",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.USER,
content="回答不够详细,需要更多信息",
from_end_user_id="user-123",
from_account_id=None,

View File

@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from extensions.storage.storage_type import StorageType
from models import Account, Tenant
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@ -140,7 +141,7 @@ class TestFileService:
upload_file = UploadFile(
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"upload_files/test/{fake.uuid4()}.txt",
name="test_file.txt",
size=1024,

View File

@ -7,6 +7,7 @@ import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import (
App,
AppAnnotationHitHistory,
@ -172,8 +173,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content="first",
from_end_user_id=conversation.from_end_user_id,
)
@ -181,8 +182,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.USER,
content="second",
from_end_user_id=conversation.from_end_user_id,
)
@ -190,8 +191,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.ADMIN,
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import FeedbackRating
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
@ -405,7 +406,7 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create feedback
rating = "like"
rating = FeedbackRating.LIKE
content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=rating, content=content
@ -435,7 +436,11 @@ class TestMessageService:
# Test creating feedback with no user
with pytest.raises(ValueError, match="user cannot be None"):
MessageService.create_feedback(
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
app_model=app,
message_id=message.id,
user=None,
rating=FeedbackRating.LIKE,
content=fake.text(max_nb_chars=100),
)
def test_create_feedback_update_existing(
@ -452,14 +457,14 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial feedback
initial_rating = "like"
initial_rating = FeedbackRating.LIKE
initial_content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
)
# Update feedback
updated_rating = "dislike"
updated_rating = FeedbackRating.DISLIKE
updated_content = fake.text(max_nb_chars=100)
updated_feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
@ -487,7 +492,11 @@ class TestMessageService:
# Create initial feedback
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
app_model=app,
message_id=message.id,
user=account,
rating=FeedbackRating.LIKE,
content=fake.text(max_nb_chars=100),
)
# Delete feedback by setting rating to None
@ -538,7 +547,7 @@ class TestMessageService:
app_model=app,
message_id=message.id,
user=account,
rating="like" if i % 2 == 0 else "dislike",
rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE,
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
)
feedbacks.append(feedback)
@ -568,7 +577,11 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
app_model=app,
message_id=message.id,
user=account,
rating=FeedbackRating.LIKE,
content=f"Feedback {i}",
)
# Get feedbacks with pagination

View File

@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType, MessageChainType
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
from models.model import (
App,
AppAnnotationHitHistory,
@ -166,7 +166,7 @@ class TestMessagesCleanServiceIntegration:
name="Test conversation",
inputs={},
status="normal",
from_source="api",
from_source=FeedbackFromSource.USER,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(conversation)
@ -196,7 +196,7 @@ class TestMessagesCleanServiceIntegration:
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
from_source="api",
from_source=FeedbackFromSource.USER,
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)
@ -216,8 +216,8 @@ class TestMessagesCleanServiceIntegration:
app_id=message.app_id,
conversation_id=message.conversation_id,
message_id=message.id,
rating="like",
from_source="api",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(feedback)
@ -249,7 +249,7 @@ class TestMessagesCleanServiceIntegration:
type="image",
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)

View File

@ -48,41 +48,42 @@ class TestToolTransformService:
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
credentials={"auth_type": "api_key_header", "api_key": "test_key"},
provider_type="api",
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
tools_str="[]",
)
elif provider_type == "builtin":
provider = BuiltinToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon="🔧",
icon_dark="🔧",
tenant_id="test_tenant_id",
user_id="test_user_id",
provider="test_provider",
credential_type="api_key",
credentials={"api_key": "test_key"},
encrypted_credentials='{"api_key": "test_key"}',
)
elif provider_type == "workflow":
provider = WorkflowToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
workflow_id="test_workflow_id",
app_id="test_workflow_id",
label="Test Workflow",
version="1.0.0",
parameter_configuration="[]",
)
elif provider_type == "mcp":
provider = MCPToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
icon='{"background": "#FF6B6B", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
server_url="https://mcp.example.com",
server_url_hash="test_server_url_hash",
server_identifier="test_server",
tools='[{"name": "test_tool", "description": "Test tool"}]',
authed=True,

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -209,7 +210,7 @@ class TestBatchCleanDocumentTask:
upload_file = UploadFile(
tenant_id=account.current_tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{fake.file_name()}",
name=fake.file_name(),
size=1024,

View File

@ -19,6 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -203,7 +204,7 @@ class TestBatchCreateSegmentToIndexTask:
upload_file = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{fake.file_name()}",
name=fake.file_name(),
size=1024,

View File

@ -18,6 +18,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -254,7 +255,7 @@ class TestCleanDatasetTask:
upload_file = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{fake.file_name()}",
name=fake.file_name(),
size=1024,
@ -925,7 +926,7 @@ class TestCleanDatasetTask:
special_filename = f"test_file_{special_content}.txt"
upload_file = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{special_filename}",
name=special_filename,
size=1024,

View File

@ -6,6 +6,7 @@ import pytest
from core.db.session_factory import session_factory
from dify_graph.variables.segments import StringSegment
from dify_graph.variables.types import SegmentType
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Tenant
from models.enums import CreatorUserRole
@ -78,7 +79,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id:
for i in range(count):
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test/file-{uuid.uuid4()}-{i}.json",
name=f"file-{i}.json",
size=1024 + i,

View File

@ -0,0 +1,56 @@
from pathlib import Path
from extensions.storage.opendal_storage import OpenDALStorage
class TestOpenDALFsDefaultRoot:
"""Test that OpenDALStorage with scheme='fs' works correctly when no root is provided."""
def test_fs_without_root_uses_default(self, tmp_path, monkeypatch):
"""When no root is specified, the default 'storage' should be used and passed to the Operator."""
# Change to tmp_path so the default "storage" dir is created there
monkeypatch.chdir(tmp_path)
# Ensure no OPENDAL_FS_ROOT env var is set
monkeypatch.delenv("OPENDAL_FS_ROOT", raising=False)
storage = OpenDALStorage(scheme="fs")
# The default directory should have been created
assert (tmp_path / "storage").is_dir()
# The storage should be functional
storage.save("test_default_root.txt", b"hello")
assert storage.exists("test_default_root.txt")
assert storage.load_once("test_default_root.txt") == b"hello"
# Cleanup
storage.delete("test_default_root.txt")
def test_fs_with_explicit_root(self, tmp_path):
"""When root is explicitly provided, it should be used."""
custom_root = str(tmp_path / "custom_storage")
storage = OpenDALStorage(scheme="fs", root=custom_root)
assert Path(custom_root).is_dir()
storage.save("test_explicit_root.txt", b"world")
assert storage.exists("test_explicit_root.txt")
assert storage.load_once("test_explicit_root.txt") == b"world"
# Cleanup
storage.delete("test_explicit_root.txt")
def test_fs_with_env_var_root(self, tmp_path, monkeypatch):
"""When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs."""
env_root = str(tmp_path / "env_storage")
monkeypatch.setenv("OPENDAL_FS_ROOT", env_root)
# Ensure .env file doesn't interfere
monkeypatch.chdir(tmp_path)
storage = OpenDALStorage(scheme="fs")
assert Path(env_root).is_dir()
storage.save("test_env_root.txt", b"env_data")
assert storage.exists("test_env_root.txt")
assert storage.load_once("test_env_root.txt") == b"env_data"
# Cleanup
storage.delete("test_env_root.txt")

View File

@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import (
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.provider_manager import ProviderManager
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService
@ -1121,7 +1122,7 @@ class TestDatasetIndexingEstimateApi:
def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="key",
name="name.txt",
size=1,

View File

@ -50,7 +50,7 @@ class TestGetUser:
mock_user.id = "user123"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
mock_session.get.return_value = mock_user
# Act
with app.app_context():
@ -58,7 +58,7 @@ class TestGetUser:
# Assert
assert result == mock_user
mock_session.query.assert_called_once()
mock_session.get.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@ -72,7 +72,8 @@ class TestGetUser:
mock_user.session_id = "anonymous_session"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# non-anonymous path uses session.get(); anonymous uses session.scalar()
mock_session.get.return_value = mock_user
# Act
with app.app_context():
@ -89,7 +90,7 @@ class TestGetUser:
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = None
mock_session.get.return_value = None
mock_new_user = MagicMock()
mock_enduser_class.return_value = mock_new_user
@ -103,18 +104,20 @@ class TestGetUser:
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_use_default_session_id_when_user_id_none(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask
):
"""Test using default session ID when user_id is None"""
# Arrange
mock_user = MagicMock()
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# When user_id is None, is_anonymous=True, so session.scalar() is used
mock_session.scalar.return_value = mock_user
# Act
with app.app_context():
@ -133,7 +136,7 @@ class TestGetUser:
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.side_effect = Exception("Database error")
mock_session.get.side_effect = Exception("Database error")
# Act & Assert
with app.app_context():
@ -161,9 +164,9 @@ class TestGetUserTenant:
# Act
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()
@ -194,8 +197,8 @@ class TestGetUserTenant:
# Act & Assert
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
mock_get.return_value = None
with pytest.raises(ValueError, match="tenant not found"):
protected_view()
@ -215,9 +218,9 @@ class TestGetUserTenant:
# Act - use empty string for user_id to trigger default logic
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()

View File

@ -249,8 +249,8 @@ class TestEnterpriseInnerApiUserAuth:
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
):
with patch.object(dify_config, "INNER_API", True):
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = mock_user
with patch("controllers.inner_api.wraps.db.session.get") as mock_get:
mock_get.return_value = mock_user
result = protected_view()
# Assert

View File

@ -91,7 +91,7 @@ class TestEnterpriseWorkspace:
# Arrange
mock_account = MagicMock()
mock_account.email = "owner@example.com"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
mock_db.session.scalar.return_value = mock_account
now = datetime(2025, 1, 1, 12, 0, 0)
mock_tenant = MagicMock()
@ -122,7 +122,7 @@ class TestEnterpriseWorkspace:
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
"""Test that post() returns 404 when the owner account does not exist"""
# Arrange
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
# Act
unwrapped_post = inspect.unwrap(api_instance.post)

View File

@ -31,6 +31,7 @@ from controllers.service_api.app.message import (
MessageListQuery,
MessageSuggestedApi,
)
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import (
@ -310,7 +311,7 @@ class TestMessageService:
app_model=Mock(spec=App),
message_id=str(uuid.uuid4()),
user=Mock(spec=EndUser),
rating="like",
rating=FeedbackRating.LIKE,
content="Great response!",
)
@ -326,7 +327,7 @@ class TestMessageService:
app_model=Mock(spec=App),
message_id="invalid_message_id",
user=Mock(spec=EndUser),
rating="like",
rating=FeedbackRating.LIKE,
content=None,
)

View File

@ -39,14 +39,21 @@ class TestHitTestingPayload:
def test_payload_with_all_fields(self):
"""Test payload with all optional fields."""
retrieval_model_data = {
"search_method": "semantic_search",
"reranking_enable": False,
"score_threshold_enabled": False,
"top_k": 5,
}
payload = HitTestingPayload(
query="test query",
retrieval_model={"top_k": 5},
retrieval_model=retrieval_model_data,
external_retrieval_model={"provider": "openai"},
attachment_ids=["att_1", "att_2"],
)
assert payload.query == "test query"
assert payload.retrieval_model == {"top_k": 5}
assert payload.retrieval_model is not None
assert payload.retrieval_model.top_k == 5
assert payload.external_retrieval_model == {"provider": "openai"}
assert payload.attachment_ids == ["att_1", "att_2"]
@ -134,7 +141,13 @@ class TestHitTestingApiPost:
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8}
retrieval_model = {
"search_method": "semantic_search",
"reranking_enable": False,
"score_threshold_enabled": True,
"top_k": 10,
"score_threshold": 0.8,
}
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
@ -152,7 +165,11 @@ class TestHitTestingApiPost:
assert response["query"] == "complex query"
call_kwargs = mock_hit_svc.retrieve.call_args
assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model
# retrieval_model is serialized via model_dump; verify key fields
passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model")
assert passed_retrieval_model is not None
assert passed_retrieval_model["search_method"] == "semantic_search"
assert passed_retrieval_model["top_k"] == 10
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")

View File

@ -49,6 +49,17 @@ class _FakeSession:
assert self._model_name is not None
return self._mapping.get(self._model_name)
def get(self, model, ident):
return self._mapping.get(model.__name__)
def scalar(self, stmt):
# Extract the model name from the select statement's column_descriptions
try:
name = stmt.column_descriptions[0]["entity"].__name__
except (AttributeError, IndexError, KeyError):
return None
return self._mapping.get(name)
class _FakeDB:
"""Minimal db stub exposing engine and session."""

View File

@ -50,7 +50,7 @@ class TestAppSiteApi:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
site_obj = _site()
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
mock_db.session.scalar.return_value = site_obj
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
end_user = SimpleNamespace(id="eu-1")
@ -66,9 +66,9 @@ class TestAppSiteApi:
@patch("controllers.web.site.db")
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
end_user = SimpleNamespace(id="eu-1")
with app.test_request_context("/site"):
@ -80,7 +80,7 @@ class TestAppSiteApi:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
from models.account import TenantStatus
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
mock_db.session.scalar.return_value = _site()
tenant = SimpleNamespace(
id="tenant-1",
status=TenantStatus.ARCHIVE,

View File

@ -44,11 +44,22 @@ class TestAgentChatAppGenerateResponseConverterBlocking:
metadata={
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s1",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "content",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
}
],
"annotation_reply": {"id": "a"},
@ -107,11 +118,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
metadata={
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s1",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "content",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
"summary": "summary",
"extra": "ignored",
}
@ -151,11 +173,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
assert "usage" not in metadata
assert metadata["retriever_resources"] == [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s1",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "content",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
"summary": "summary",
}
]

View File

@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
import uuid
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from unittest.mock import Mock
@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
assert response.data.process_data == {}
assert response.data.process_data_truncated is False
def test_workflow_node_finish_response_prefers_event_finished_at(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Finished timestamps should come from the event, not delayed queue processing time."""
converter = self.create_workflow_response_converter()
start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
monkeypatch.setattr(
"core.app.apps.common.workflow_response_converter.naive_utc_now",
lambda: delayed_processing_time,
)
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
event = QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=BuiltinNodeTypes.CODE,
node_execution_id="node-exec-1",
start_at=start_at,
finished_at=finished_at,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
assert response is not None
assert response.data.elapsed_time == 2.0
assert response.data.finished_at == int(finished_at.timestamp())
def test_workflow_node_retry_response_uses_truncated_process_data(self):
"""Test that node retry response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()

View File

@ -38,11 +38,22 @@ class TestCompletionAppGenerateResponseConverter:
metadata = {
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "c",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
"summary": "sum",
"extra": "x",
}
@ -66,7 +77,12 @@ class TestCompletionAppGenerateResponseConverter:
assert "annotation_reply" not in result["metadata"]
assert "usage" not in result["metadata"]
assert result["metadata"]["retriever_resources"][0]["dataset_id"] == "dataset-1"
assert result["metadata"]["retriever_resources"][0]["document_id"] == "document-1"
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
assert result["metadata"]["retriever_resources"][0]["data_source_type"] == "file"
assert result["metadata"]["retriever_resources"][0]["segment_position"] == 3
assert result["metadata"]["retriever_resources"][0]["index_node_hash"] == "abc1234"
assert "extra" not in result["metadata"]["retriever_resources"][0]
def test_convert_blocking_simple_response_metadata_not_dict(self):

View File

@ -0,0 +1,60 @@
from datetime import UTC, datetime
from unittest.mock import Mock
import pytest
from core.app.workflow.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
_NodeRuntimeSnapshot,
)
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
from dify_graph.node_events import NodeRunResult
def _build_layer() -> WorkflowPersistenceLayer:
application_generate_entity = Mock()
application_generate_entity.inputs = {}
return WorkflowPersistenceLayer(
application_generate_entity=application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
version="1",
graph_data={},
),
workflow_execution_repository=Mock(),
workflow_node_execution_repository=Mock(),
)
def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
layer = _build_layer()
node_execution = Mock()
node_execution.id = "node-exec-1"
node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
node_execution.update_from_mapping = Mock()
layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
node_id="node-id",
title="LLM",
predecessor_node_id=None,
iteration_id="iter-1",
loop_id=None,
created_at=node_execution.created_at,
)
event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time)
layer._update_node_execution(
node_execution,
NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event_finished_at,
)
assert node_execution.finished_at == event_finished_at
assert node_execution.elapsed_time == 2.0

View File

@ -166,6 +166,7 @@ class TestDatasourceFileManager:
# Setup
mock_guess_ext.return_value = None # Cannot guess
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_config.STORAGE_TYPE = "local"
# Execute
upload_file = DatasourceFileManager.create_file_by_raw(

View File

@ -0,0 +1,145 @@
import queue
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
from dify_graph.graph_engine.worker import Worker
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)
worker = Worker(
ready_queue=InMemoryReadyQueue(),
event_queue=queue.Queue(),
graph=MagicMock(),
layers=[],
)
node = SimpleNamespace(
execution_id="exec-1",
id="node-1",
node_type=BuiltinNodeTypes.LLM,
)
event = worker._build_fallback_failure_event(node, RuntimeError("boom"))
assert event.start_at == fixed_time
assert event.finished_at == fixed_time
assert event.error == "boom"
assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
assert event.node_run_result.error == "boom"
assert event.node_run_result.error_type == "RuntimeError"
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
failure_time = start_at + timedelta(seconds=5)
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
class FakeNode:
execution_id = "exec-1"
id = "node-1"
node_type = BuiltinNodeTypes.LLM
def ensure_execution_id(self) -> str:
return self.execution_id
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
yield NodeRunStartedEvent(
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
node_title="LLM",
start_at=start_at,
)
worker = Worker(
ready_queue=MagicMock(),
event_queue=MagicMock(),
graph=MagicMock(nodes={"node-1": FakeNode()}),
layers=[],
)
worker._ready_queue.get.side_effect = ["node-1"]
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
captured_events.append(event)
if len(captured_events) == 1:
raise RuntimeError("queue boom")
worker.stop()
worker._event_queue.put.side_effect = put_side_effect
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
worker.run()
fallback_event = captured_events[-1]
assert isinstance(fallback_event, NodeRunFailedEvent)
assert fallback_event.start_at == start_at
assert fallback_event.finished_at == failure_time
assert fallback_event.error == "queue boom"
assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None:
parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
child_start = parent_start + timedelta(seconds=3)
failure_time = parent_start + timedelta(seconds=5)
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
class FakeIterationNode:
execution_id = "iteration-exec"
id = "iteration-node"
node_type = BuiltinNodeTypes.ITERATION
def ensure_execution_id(self) -> str:
return self.execution_id
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
yield NodeRunStartedEvent(
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
node_title="Iteration",
start_at=parent_start,
)
yield NodeRunStartedEvent(
id="child-exec",
node_id="child-node",
node_type=BuiltinNodeTypes.LLM,
node_title="LLM",
start_at=child_start,
in_iteration_id=self.id,
)
worker = Worker(
ready_queue=MagicMock(),
event_queue=MagicMock(),
graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}),
layers=[],
)
worker._ready_queue.get.side_effect = ["iteration-node"]
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
captured_events.append(event)
if len(captured_events) == 2:
raise RuntimeError("queue boom")
worker.stop()
worker._event_queue.put.side_effect = put_side_effect
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
worker.run()
fallback_event = captured_events[-1]
assert isinstance(fallback_event, NodeRunFailedEvent)
assert fallback_event.start_at == parent_start
assert fallback_event.finished_at == failure_time

View File

@ -0,0 +1,63 @@
import time
from contextlib import nullcontext
from datetime import UTC, datetime
import pytest
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_events import NodeRunSucceededEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.nodes.iteration.iteration_node import IterationNode
def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Parallel Iteration",
iterator_selector=["start", "items"],
output_selector=["iteration", "output"],
is_parallel=True,
parallel_nums=2,
error_handle_mode=ErrorHandleMode.TERMINATED,
)
node._capture_execution_context = lambda: nullcontext()
node._sync_conversation_variables_from_snapshot = lambda snapshot: None
node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)
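# Each fake iteration reports a fixed worker-measured duration (0.1s for index 0, 0.2s for index 1), regardless of consumer speed.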
def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
return (
0.1 + (index * 0.1),
[
NodeRunSucceededEvent(
id=f"exec-{index}",
node_id=f"llm-{index}",
node_type=BuiltinNodeTypes.LLM,
start_at=datetime.now(UTC).replace(tzinfo=None),
),
],
f"output-{item}",
{},
LLMUsage.empty_usage(),
)
node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel
outputs: list[object] = []
iter_run_map: dict[str, float] = {}
usage_accumulator = [LLMUsage.empty_usage()]
generator = node._execute_parallel_iterations(
iterator_list_value=["a", "b"],
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
for _ in generator:
# Simulate a slow consumer replaying buffered events.
time.sleep(0.02)
assert outputs == ["output-a", "output-b"]
assert iter_run_map["0"] == pytest.approx(0.1)
assert iter_run_map["1"] == pytest.approx(0.2)

View File

@ -0,0 +1,63 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
init_params = build_test_graph_init_params(
graph_config=graph_config,
user_from="account",
invoke_from="debugger",
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable(user_id="user", files=[]),
user_inputs={"payload": "value"},
),
start_at=0.0,
)
return init_params, runtime_state
def _build_node_config() -> NodeConfigDict:
return NodeConfigDictAdapter.validate_python(
{
"id": "node-1",
"data": {
"type": TRIGGER_PLUGIN_NODE_TYPE,
"title": "Trigger Event",
"plugin_id": "plugin-id",
"provider_id": "provider-id",
"event_name": "event-name",
"subscription_id": "subscription-id",
"plugin_unique_identifier": "plugin-unique-identifier",
"event_parameters": {},
},
}
)
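# _run should surface the configured trigger identity via TRIGGER_INFO metadata.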
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
id="node-1",
config=_build_node_config(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
"provider_id": "provider-id",
"event_name": "event-name",
"plugin_unique_identifier": "plugin-unique-identifier",
}

View File

@ -0,0 +1,19 @@
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.node_events.base import NodeRunResult
def test_node_run_result_accepts_trigger_info_metadata() -> None:
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": "provider-id",
"event_name": "event-name",
}
},
)
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
"provider_id": "provider-id",
"event_name": "event-name",
}

View File

@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, EndUser, Message
from services.errors.message import (
FirstMessageNotExistsError,
@ -820,14 +821,14 @@ class TestMessageServiceFeedback:
app_model=app,
message_id="msg-123",
user=user,
rating="like",
rating=FeedbackRating.LIKE,
content="Good answer",
)
# Assert
assert result.rating == "like"
assert result.rating == FeedbackRating.LIKE
assert result.content == "Good answer"
assert result.from_source == "user"
assert result.from_source == FeedbackFromSource.USER
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
@ -852,13 +853,13 @@ class TestMessageServiceFeedback:
app_model=app,
message_id="msg-123",
user=user,
rating="dislike",
rating=FeedbackRating.DISLIKE,
content="Bad answer",
)
# Assert
assert result == feedback
assert feedback.rating == "dislike"
assert feedback.rating == FeedbackRating.DISLIKE
assert feedback.content == "Bad answer"
mock_db.session.commit.assert_called_once()

codecov.yml Normal file
View File

@ -0,0 +1,16 @@
coverage:
status:
project:
default:
target: auto
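# "auto" uses the base commit's coverage as the pass threshold instead of a fixed percentage.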
flags:
web:
paths:
- "web/"
carryforward: true
api:
paths:
- "api/"
carryforward: true
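# carryforward reuses a flag's last uploaded coverage when a commit's CI does not upload it again.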

View File

@ -7,17 +7,21 @@
*/
import type { InstalledApp } from '@/models/explore'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import Toast from '@/app/components/base/toast'
import SideBar from '@/app/components/explore/sidebar'
import { MediaType } from '@/hooks/use-breakpoints'
import { AppModeEnum } from '@/types/app'
const { mockToastAdd } = vi.hoisted(() => ({
mockToastAdd: vi.fn(),
}))
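// vi.hoisted runs before the hoisted vi.mock factories, so mockToastAdd already exists when the toast module mock below is evaluated.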
let mockMediaType: string = MediaType.pc
const mockSegments = ['apps']
const mockPush = vi.fn()
const mockUninstall = vi.fn()
const mockUpdatePinStatus = vi.fn()
let mockInstalledApps: InstalledApp[] = []
let mockIsUninstallPending = false
vi.mock('next/navigation', () => ({
useSelectedLayoutSegments: () => mockSegments,
@ -42,12 +46,22 @@ vi.mock('@/service/use-explore', () => ({
}),
useUninstallApp: () => ({
mutateAsync: mockUninstall,
isPending: mockIsUninstallPending,
}),
useUpdateAppPinStatus: () => ({
mutateAsync: mockUpdatePinStatus,
}),
}))
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
add: mockToastAdd,
close: vi.fn(),
update: vi.fn(),
promise: vi.fn(),
},
}))
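// The sidebar now reports pin/uninstall results via toast.add from the shared ui/toast module instead of Toast.notify.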
const createInstalledApp = (overrides: Partial<InstalledApp> = {}): InstalledApp => ({
id: overrides.id ?? 'app-1',
uninstallable: overrides.uninstallable ?? false,
@ -74,7 +88,7 @@ describe('Sidebar Lifecycle Flow', () => {
vi.clearAllMocks()
mockMediaType = MediaType.pc
mockInstalledApps = []
vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
mockIsUninstallPending = false
})
describe('Pin / Unpin / Delete Flow', () => {
@ -91,7 +105,7 @@ describe('Sidebar Lifecycle Flow', () => {
await waitFor(() => {
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: true })
expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
}))
})
@ -110,7 +124,7 @@ describe('Sidebar Lifecycle Flow', () => {
await waitFor(() => {
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: false })
expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
}))
})
@ -136,9 +150,9 @@ describe('Sidebar Lifecycle Flow', () => {
// Step 4: Uninstall API called and success toast shown
await waitFor(() => {
expect(mockUninstall).toHaveBeenCalledWith('app-1')
expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
message: 'common.api.remove',
title: 'common.api.remove',
}))
})
})

Some files were not shown because too many files have changed in this diff.