mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into 3-18-no-global-loading
This commit is contained in:
commit
088a481432
|
|
@ -6,8 +6,8 @@ runs:
|
|||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
|
||||
with:
|
||||
node-version-file: "./web/.nvmrc"
|
||||
node-version-file: web/.nvmrc
|
||||
cache: true
|
||||
cache-dependency-path: web/pnpm-lock.yaml
|
||||
run-install: |
|
||||
- cwd: ./web
|
||||
args: ['--frozen-lockfile']
|
||||
cwd: ./web
|
||||
|
|
|
|||
|
|
@ -2,6 +2,12 @@ name: Run Pytest
|
|||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
|
|
@ -11,6 +17,8 @@ jobs:
|
|||
test:
|
||||
name: API Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -24,6 +32,7 @@ jobs:
|
|||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
|
|
@ -79,21 +88,12 @@ jobs:
|
|||
api/tests/test_containers_integration_tests \
|
||||
api/tests/unit_tests
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
{
|
||||
echo ""
|
||||
echo "<details><summary>File-level coverage (click to expand)</summary>"
|
||||
echo ""
|
||||
echo '```'
|
||||
uv run --project api coverage report -m
|
||||
echo '```'
|
||||
echo "</details>"
|
||||
} >> $GITHUB_STEP_SUMMARY
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
flags: api
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ jobs:
|
|||
needs: check-changes
|
||||
if: needs.check-changes.outputs.api-changed == 'true'
|
||||
uses: ./.github/workflows/api-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
web-tests:
|
||||
name: Web Tests
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
|
|
@ -48,14 +50,15 @@ def setup_system_tool_oauth_client(provider, client_params):
|
|||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(ToolOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(ToolOAuthSystemClient).where(
|
||||
ToolOAuthSystemClient.provider == provider_name,
|
||||
ToolOAuthSystemClient.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
|
|
@ -97,14 +100,15 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
|||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(TriggerOAuthSystemClient).where(
|
||||
TriggerOAuthSystemClient.provider == provider_name,
|
||||
TriggerOAuthSystemClient.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
|
|
@ -139,14 +143,15 @@ def setup_datasource_oauth_client(provider, client_params):
|
|||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(DatasourceOauthParamConfig).where(
|
||||
DatasourceOauthParamConfig.provider == provider_name,
|
||||
DatasourceOauthParamConfig.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
|
|
@ -192,7 +197,9 @@ def transform_datasource_credentials(environment: str):
|
|||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
notion_credentials = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion")
|
||||
).all()
|
||||
if notion_credentials:
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for notion_credential in notion_credentials:
|
||||
|
|
@ -201,7 +208,7 @@ def transform_datasource_credentials(environment: str):
|
|||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
|
|
@ -250,7 +257,9 @@ def transform_datasource_credentials(environment: str):
|
|||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
firecrawl_credentials = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl")
|
||||
).all()
|
||||
if firecrawl_credentials:
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for firecrawl_credential in firecrawl_credentials:
|
||||
|
|
@ -259,7 +268,7 @@ def transform_datasource_credentials(environment: str):
|
|||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
|
|
@ -312,7 +321,9 @@ def transform_datasource_credentials(environment: str):
|
|||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
|
||||
jina_credentials = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader")
|
||||
).all()
|
||||
if jina_credentials:
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for jina_credential in jina_credentials:
|
||||
|
|
@ -321,7 +332,7 @@ def transform_datasource_credentials(environment: str):
|
|||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import json
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -740,14 +743,17 @@ def migrate_oss(
|
|||
else:
|
||||
try:
|
||||
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
|
||||
updated = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
|
||||
)
|
||||
updated = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
update(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.values(storage_type=dify_config.STORAGE_TYPE)
|
||||
),
|
||||
).rowcount
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import logging
|
|||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -41,7 +42,7 @@ def reset_encrypt_key_pair():
|
|||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
tenants = session.scalars(select(Tenant)).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
|
|
@ -49,8 +50,8 @@ def reset_encrypt_key_pair():
|
|||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
|
||||
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
|
|
@ -93,7 +94,7 @@ def convert_to_agent_apps():
|
|||
app_id = str(i.id)
|
||||
if app_id not in proceeded_app_ids:
|
||||
proceeded_app_ids.append(app_id)
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.scalar(select(App).where(App.id == app_id))
|
||||
if app is not None:
|
||||
apps.append(app)
|
||||
|
||||
|
|
@ -108,8 +109,8 @@ def convert_to_agent_apps():
|
|||
db.session.commit()
|
||||
|
||||
# update conversation mode to agent
|
||||
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
|
||||
{Conversation.mode: AppMode.AGENT_CHAT}
|
||||
db.session.execute(
|
||||
update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
|
@ -177,7 +178,7 @@ where sites.id is null limit 1000"""
|
|||
continue
|
||||
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.scalar(select(App).where(App.id == app_id))
|
||||
if not app:
|
||||
logger.info("App %s not found", app_id)
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -41,14 +41,13 @@ def migrate_annotation_vector_database():
|
|||
# get apps info
|
||||
per_page = 50
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
apps = session.scalars(
|
||||
select(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
|
|
@ -63,8 +62,8 @@ def migrate_annotation_vector_database():
|
|||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
app_annotation_setting = session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1)
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
|
|
@ -72,10 +71,10 @@ def migrate_annotation_vector_database():
|
|||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
dataset_collection_binding = session.scalar(
|
||||
select(DatasetCollectionBinding).where(
|
||||
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
|
||||
)
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
|
|
@ -205,11 +204,11 @@ def migrate_knowledge_vector_database():
|
|||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
dataset_collection_binding = db.session.execute(
|
||||
select(DatasetCollectionBinding).where(
|
||||
DatasetCollectionBinding.id == dataset.collection_binding_id
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
|
|
@ -334,7 +333,7 @@ def add_qdrant_index(field: str):
|
|||
create_count = 0
|
||||
|
||||
try:
|
||||
bindings = db.session.query(DatasetCollectionBinding).all()
|
||||
bindings = db.session.scalars(select(DatasetCollectionBinding)).all()
|
||||
if not bindings:
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
|
|
@ -421,10 +420,10 @@ def old_metadata_migration():
|
|||
if field.value == key:
|
||||
break
|
||||
else:
|
||||
dataset_metadata = (
|
||||
db.session.query(DatasetMetadata)
|
||||
dataset_metadata = db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset_metadata:
|
||||
dataset_metadata = DatasetMetadata(
|
||||
|
|
@ -436,7 +435,7 @@ def old_metadata_migration():
|
|||
)
|
||||
db.session.add(dataset_metadata)
|
||||
db.session.flush()
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding(
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
metadata_id=dataset_metadata.id,
|
||||
|
|
@ -445,14 +444,14 @@ def old_metadata_migration():
|
|||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
else:
|
||||
dataset_metadata_binding = (
|
||||
db.session.query(DatasetMetadataBinding) # type: ignore
|
||||
dataset_metadata_binding = db.session.scalar(
|
||||
select(DatasetMetadataBinding)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == document.dataset_id,
|
||||
DatasetMetadataBinding.document_id == document.id,
|
||||
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset_metadata_binding:
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from fields.raws import FilesContainedField
|
|||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
|
|
@ -335,7 +336,7 @@ class MessageFeedbackApi(Resource):
|
|||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
feedback.rating = FeedbackRating(args.rating)
|
||||
feedback.content = args.content
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
|
|
@ -347,9 +348,9 @@ class MessageFeedbackApi(Resource):
|
|||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating_value,
|
||||
rating=FeedbackRating(rating_value),
|
||||
content=args.content,
|
||||
from_source="admin",
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
db.session.add(feedback)
|
||||
|
|
|
|||
|
|
@ -298,6 +298,7 @@ class DatasetDocumentListApi(Resource):
|
|||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.where(DocumentSegment.dataset_id == str(dataset_id))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from fields.hit_testing_fields import hit_testing_record_fields
|
|||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class HitTestingPayload(BaseModel):
|
||||
query: str = Field(max_length=250)
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt
|
|||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
|
|
@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar
|
|||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_model = session.get(EndUser, user_id)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
|
|
@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]):
|
|||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
tenant_model = db.session.get(Tenant, tenant_id)
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import setup_required
|
||||
|
|
@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource):
|
|||
def post(self):
|
||||
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
account = db.session.query(Account).filter_by(email=args.owner_email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1))
|
||||
if account is None:
|
||||
return {"message": "owner account not found."}, 404
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
|||
if signature_base64 != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
|
||||
kwargs["user"] = db.session.get(EndUser, user_id)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource):
|
|||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from datetime import datetime
|
|||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -147,11 +148,11 @@ class HumanInputFormApi(Resource):
|
|||
|
||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
||||
app_model = db.session.get(App, form.app_id)
|
||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse
|
|||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
|
|
@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
|
|||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource):
|
|||
def get(self, app_model, end_user):
|
||||
"""Retrieve app site info."""
|
||||
# get site
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable
|
|||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole, MessageStatus
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
|
@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
type=file["type"],
|
||||
transfer_method=file["transfer_method"],
|
||||
url=file["remote_url"],
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
upload_file_id=file["related_id"],
|
||||
created_by_role=CreatorUserRole.ACCOUNT
|
||||
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
|
|
|
|||
|
|
@ -74,11 +74,22 @@ class AppGenerateResponseConverter(ABC):
|
|||
for resource in metadata["retriever_resources"]:
|
||||
updated_resources.append(
|
||||
{
|
||||
"dataset_id": resource.get("dataset_id"),
|
||||
"dataset_name": resource.get("dataset_name"),
|
||||
"document_id": resource.get("document_id"),
|
||||
"segment_id": resource.get("segment_id", ""),
|
||||
"position": resource["position"],
|
||||
"data_source_type": resource.get("data_source_type"),
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
"hit_count": resource.get("hit_count"),
|
||||
"word_count": resource.get("word_count"),
|
||||
"segment_position": resource.get("segment_position"),
|
||||
"index_node_hash": resource.get("index_node_hash"),
|
||||
"content": resource["content"],
|
||||
"page": resource.get("page"),
|
||||
"title": resource.get("title"),
|
||||
"files": resource.get("files"),
|
||||
"summary": resource.get("summary"),
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import (
|
|||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -419,7 +419,7 @@ class AppRunner:
|
|||
message_id=message_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
url=f"/files/tools/{tool_file.id}",
|
||||
upload_file_id=tool_file.id,
|
||||
created_by_role=(
|
||||
|
|
|
|||
|
|
@ -517,7 +517,7 @@ class WorkflowResponseConverter:
|
|||
snapshot = self._pop_snapshot(event.node_execution_id)
|
||||
|
||||
start_at = snapshot.start_at if snapshot else event.start_at
|
||||
finished_at = naive_utc_now()
|
||||
finished_at = event.finished_at or naive_utc_now()
|
||||
elapsed_time = (finished_at - start_at).total_seconds()
|
||||
|
||||
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
|
|||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
|
@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
message_id=message.id,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
|
|
|
|||
|
|
@ -456,6 +456,7 @@ class WorkflowBasedAppRunner:
|
|||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
|
|
@ -471,6 +472,7 @@ class WorkflowBasedAppRunner:
|
|||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
|
|
@ -487,6 +489,7 @@ class WorkflowBasedAppRunner:
|
|||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
|
|
|
|||
|
|
@ -335,6 +335,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
finished_at: datetime | None = None
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
|
@ -390,6 +391,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
|||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
finished_at: datetime | None = None
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
|
@ -414,6 +416,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
finished_at: datetime | None = None
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator
|
|||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import MessageFileBelongsTo
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
|
@ -233,7 +234,7 @@ class MessageCycleManager:
|
|||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_file.id,
|
||||
type=message_file.type,
|
||||
belongs_to=message_file.belongs_to or "user",
|
||||
belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER,
|
||||
url=url,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
|||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
|
|
@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
|||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
|
|
@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
|||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
|
|
@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
|||
*,
|
||||
error: str | None = None,
|
||||
update_outputs: bool = True,
|
||||
finished_at: datetime | None = None,
|
||||
) -> None:
|
||||
finished_at = naive_utc_now()
|
||||
actual_finished_at = finished_at or naive_utc_now()
|
||||
snapshot = self._node_snapshots.get(domain_execution.id)
|
||||
start_at = snapshot.created_at if snapshot else domain_execution.created_at
|
||||
domain_execution.status = status
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
|
||||
domain_execution.finished_at = actual_finished_at
|
||||
domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0)
|
||||
|
||||
if error:
|
||||
domain_execution.error = error
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from configs import dify_config
|
|||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import MessageFile, UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
|
@ -81,7 +82,7 @@ class DatasourceFileManager:
|
|||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=filepath,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CleanProcessor:
|
||||
@classmethod
|
||||
def clean(cls, text: str, process_rule: dict) -> str:
|
||||
def clean(cls, text: str, process_rule: dict[str, Any] | None) -> str:
|
||||
# default clean
|
||||
# remove invalid symbol
|
||||
text = re.sub(r"<\|", "<", text)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any
|
|||
import orjson
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
|
|
@ -15,6 +16,11 @@ from extensions.ext_storage import storage
|
|||
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
|
||||
|
||||
|
||||
class PreSegmentData(TypedDict):
|
||||
segment: DocumentSegment
|
||||
keywords: list[str]
|
||||
|
||||
|
||||
class KeywordTableConfig(BaseModel):
|
||||
max_keywords_per_chunk: int = 10
|
||||
|
||||
|
|
@ -128,7 +134,7 @@ class Jieba(BaseKeyword):
|
|||
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
|
||||
storage.delete(file_key)
|
||||
|
||||
def _save_dataset_keyword_table(self, keyword_table):
|
||||
def _save_dataset_keyword_table(self, keyword_table: dict[str, set[str]] | None):
|
||||
keyword_table_dict = {
|
||||
"__type__": "keyword_table",
|
||||
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
|
||||
|
|
@ -144,7 +150,7 @@ class Jieba(BaseKeyword):
|
|||
storage.delete(file_key)
|
||||
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
|
||||
|
||||
def _get_dataset_keyword_table(self) -> dict | None:
|
||||
def _get_dataset_keyword_table(self) -> dict[str, set[str]] | None:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||
|
|
@ -169,14 +175,16 @@ class Jieba(BaseKeyword):
|
|||
|
||||
return {}
|
||||
|
||||
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]):
|
||||
def _add_text_to_keyword_table(
|
||||
self, keyword_table: dict[str, set[str]], id: str, keywords: list[str]
|
||||
) -> dict[str, set[str]]:
|
||||
for keyword in keywords:
|
||||
if keyword not in keyword_table:
|
||||
keyword_table[keyword] = set()
|
||||
keyword_table[keyword].add(id)
|
||||
return keyword_table
|
||||
|
||||
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]):
|
||||
def _delete_ids_from_keyword_table(self, keyword_table: dict[str, set[str]], ids: list[str]) -> dict[str, set[str]]:
|
||||
# get set of ids that correspond to node
|
||||
node_idxs_to_delete = set(ids)
|
||||
|
||||
|
|
@ -193,7 +201,7 @@ class Jieba(BaseKeyword):
|
|||
|
||||
return keyword_table
|
||||
|
||||
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
|
||||
def _retrieve_ids_by_query(self, keyword_table: dict[str, set[str]], query: str, k: int = 4) -> list[str]:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keywords = keyword_table_handler.extract_keywords(query)
|
||||
|
||||
|
|
@ -228,7 +236,7 @@ class Jieba(BaseKeyword):
|
|||
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def multi_create_segment_keywords(self, pre_segment_data_list: list):
|
||||
def multi_create_segment_keywords(self, pre_segment_data_list: list[PreSegmentData]):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for pre_segment_data in pre_segment_data_list:
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class RetrievalService:
|
|||
reranking_mode: str = "reranking_model",
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
if not query and not attachment_ids:
|
||||
return []
|
||||
|
|
@ -250,8 +250,8 @@ class RetrievalService:
|
|||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
exceptions: list,
|
||||
all_documents: list[Document],
|
||||
exceptions: list[str],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
|
|
@ -279,9 +279,9 @@ class RetrievalService:
|
|||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
all_documents: list[Document],
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
exceptions: list[str],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
):
|
||||
|
|
@ -373,9 +373,9 @@ class RetrievalService:
|
|||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
all_documents: list[Document],
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
exceptions: list[str],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
|||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
|
|
@ -150,7 +151,7 @@ class PdfExtractor(BaseExtractor):
|
|||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=len(img_bytes),
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
|||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
|
|
@ -112,7 +113,7 @@ class WordExtractor(BaseExtractor):
|
|||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
|
|
@ -140,7 +141,7 @@ class WordExtractor(BaseExtractor):
|
|||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
|
|
@ -365,7 +366,7 @@ class WordExtractor(BaseExtractor):
|
|||
paragraph_content = []
|
||||
# State for legacy HYPERLINK fields
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts: list = []
|
||||
hyperlink_field_text_parts: list[str] = []
|
||||
is_collecting_field_text = False
|
||||
# Iterate through paragraph elements in document order
|
||||
for child in paragraph._element:
|
||||
|
|
|
|||
|
|
@ -591,7 +591,7 @@ class DatasetRetrieval:
|
|||
user_id: str,
|
||||
user_from: str,
|
||||
query: str,
|
||||
available_datasets: list,
|
||||
available_datasets: list[Dataset],
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
|
|
@ -633,15 +633,15 @@ class DatasetRetrieval:
|
|||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
if dataset:
|
||||
selected_dataset = db.session.scalar(dataset_stmt)
|
||||
if selected_dataset:
|
||||
results = []
|
||||
if dataset.provider == "external":
|
||||
if selected_dataset.provider == "external":
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
tenant_id=selected_dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
external_retrieval_parameters=selected_dataset.retrieval_model,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
|
|
@ -654,28 +654,28 @@ class DatasetRetrieval:
|
|||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
document.metadata["dataset_name"] = selected_dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
return []
|
||||
document_ids_filter = None
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
document_ids = metadata_filter_document_ids.get(selected_dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
return []
|
||||
retrieval_model_config: DefaultRetrievalModelDict = (
|
||||
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
|
||||
if dataset.retrieval_model
|
||||
cast(DefaultRetrievalModelDict, selected_dataset.retrieval_model)
|
||||
if selected_dataset.retrieval_model
|
||||
else default_retrieval_model
|
||||
)
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
if selected_dataset.indexing_technique == "economy":
|
||||
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
else:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
|
|
@ -694,7 +694,7 @@ class DatasetRetrieval:
|
|||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
dataset_id=selected_dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
|
|
@ -726,7 +726,7 @@ class DatasetRetrieval:
|
|||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
available_datasets: list[Dataset],
|
||||
query: str | None,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
|
|
@ -1028,7 +1028,7 @@ class DatasetRetrieval:
|
|||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
all_documents: list[Document],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
|
|
@ -1298,7 +1298,7 @@ class DatasetRetrieval:
|
|||
|
||||
def get_metadata_filter_condition(
|
||||
self,
|
||||
dataset_ids: list,
|
||||
dataset_ids: list[str],
|
||||
query: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
|
|
@ -1400,7 +1400,7 @@ class DatasetRetrieval:
|
|||
return output
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
self, dataset_ids: list[str], query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> list[dict[str, Any]] | None:
|
||||
# get all metadata field
|
||||
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
|
|
@ -1598,7 +1598,7 @@ class DatasetRetrieval:
|
|||
)
|
||||
|
||||
def _get_prompt_template(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
|
||||
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list[str], query: str
|
||||
):
|
||||
model_mode = ModelMode(mode)
|
||||
input_text = query
|
||||
|
|
@ -1690,7 +1690,7 @@ class DatasetRetrieval:
|
|||
def _multiple_retrieve_thread(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
available_datasets: list,
|
||||
available_datasets: list[Dataset],
|
||||
metadata_condition: MetadataCondition | None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||
all_documents: list[Document],
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
|
|||
from dify_graph.file import FileType
|
||||
from dify_graph.file.models import FileTransferMethod
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -352,7 +352,7 @@ class ToolEngine:
|
|||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
url=message.url,
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import Final
|
|||
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
|
||||
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
|
||||
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
|
||||
|
||||
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
|
||||
|
|
@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
|||
|
||||
# Get trigger data passed when workflow was triggered
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"event_name": self.node_data.event_name,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
|
|
|
|||
|
|
@ -245,6 +245,9 @@ _END_STATE = frozenset(
|
|||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
|
||||
Values in this enum are persisted as execution metadata and must stay in sync
|
||||
with every node that writes `NodeRunResult.metadata`.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
|
|
@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
|||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
TRIGGER_INFO = "trigger_info"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -159,6 +159,7 @@ class ErrorHandler:
|
|||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
|
|
@ -198,6 +199,7 @@ class ErrorHandler:
|
|||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
|
|
|
|||
|
|
@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
|
|||
from typing_extensions import override
|
||||
|
||||
from dify_graph.context import IExecutionContext
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
|
|
@ -65,6 +68,7 @@ class Worker(threading.Thread):
|
|||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
self._current_node_started_at: datetime | None = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
|
|
@ -104,18 +108,15 @@ class Worker(threading.Thread):
|
|||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._current_node_started_at = None
|
||||
self._execute_node(node)
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
self._event_queue.put(
|
||||
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
|
||||
)
|
||||
self._event_queue.put(error_event)
|
||||
finally:
|
||||
self._current_node_started_at = None
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
|
|
@ -136,6 +137,8 @@ class Worker(threading.Thread):
|
|||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
|
|
@ -149,6 +152,8 @@ class Worker(threading.Thread):
|
|||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
|
|
@ -177,3 +182,24 @@ class Worker(threading.Thread):
|
|||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
||||
def _build_fallback_failure_event(
|
||||
self, node: Node, error: Exception, *, started_at: datetime | None = None
|
||||
) -> NodeRunFailedEvent:
|
||||
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
|
||||
failure_time = naive_utc_now()
|
||||
error_message = str(error)
|
||||
return NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=error_message,
|
||||
start_at=started_at or failure_time,
|
||||
finished_at=failure_time,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
error_type=type(error).__name__,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -36,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
|||
|
||||
class NodeRunSucceededEvent(GraphNodeEventBase):
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunFailedEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
|
|
|
|||
|
|
@ -406,11 +406,13 @@ class Node(Generic[NodeDataT]):
|
|||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
finished_at = naive_utc_now()
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=str(e),
|
||||
)
|
||||
|
|
@ -568,6 +570,7 @@ class Node(Generic[NodeDataT]):
|
|||
return self._node_data
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
finished_at = naive_utc_now()
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
|
|
@ -575,6 +578,7 @@ class Node(Generic[NodeDataT]):
|
|||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=result.error,
|
||||
)
|
||||
|
|
@ -584,6 +588,7 @@ class Node(Generic[NodeDataT]):
|
|||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
case _:
|
||||
|
|
@ -606,6 +611,7 @@ class Node(Generic[NodeDataT]):
|
|||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
finished_at = naive_utc_now()
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
|
|
@ -613,6 +619,7 @@ class Node(Generic[NodeDataT]):
|
|||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
|
|
@ -621,6 +628,7 @@ class Node(Generic[NodeDataT]):
|
|||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
future_to_index: dict[
|
||||
Future[
|
||||
tuple[
|
||||
datetime,
|
||||
float,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, Variable],
|
||||
|
|
@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
try:
|
||||
result = future.result()
|
||||
(
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
|
|
@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
# The worker computes duration before we replay buffered events here,
|
||||
# so slow downstream consumers don't inflate per-iteration timing.
|
||||
iter_run_map[str(index)] = iteration_duration
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
||||
|
|
@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
index: int,
|
||||
item: object,
|
||||
execution_context: "IExecutionContext",
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with execution_context:
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
|
@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
return (
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import logging
|
|||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
|
|
@ -24,13 +25,11 @@ def handle(sender, **kwargs):
|
|||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
document = db.session.scalar(
|
||||
select(Document).where(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not document:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
|
|||
|
||||
if removed_dataset_ids:
|
||||
for dataset_id in removed_dataset_ids:
|
||||
db.session.query(AppDatasetJoin).where(
|
||||
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
|
||||
)
|
||||
|
||||
if added_dataset_ids:
|
||||
for dataset_id in added_dataset_ids:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
|
|
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
|
|||
|
||||
if removed_dataset_ids:
|
||||
for dataset_id in removed_dataset_ids:
|
||||
db.session.query(AppDatasetJoin).where(
|
||||
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
|
||||
)
|
||||
|
||||
if added_dataset_ids:
|
||||
for dataset_id in added_dataset_ids:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import json
|
|||
import flask_login
|
||||
from flask import Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -34,16 +35,15 @@ def load_user_from_request(request_from_flask_login):
|
|||
if admin_api_key and admin_api_key == auth_token:
|
||||
workspace_id = request.headers.get("X-WORKSPACE-ID")
|
||||
if workspace_id:
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == workspace_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role == "owner")
|
||||
.one_or_none()
|
||||
)
|
||||
).one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.query(Account).filter_by(id=ta.account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == ta.account_id))
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
|
@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
|
|||
end_user_id = decoded.get("end_user_id")
|
||||
if not end_user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
|
|
@ -80,7 +80,7 @@ def load_user_from_request(request_from_flask_login):
|
|||
decoded = PassportService().verify(auth_token)
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
if end_user_id:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
|
|
@ -90,11 +90,11 @@ def load_user_from_request(request_from_flask_login):
|
|||
server_code = request.view_args.get("server_code") if request.view_args else None
|
||||
if not server_code:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
app_mcp_server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
|
||||
if not app_mcp_server:
|
||||
raise NotFound("App MCP server not found.")
|
||||
end_user = (
|
||||
db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first()
|
||||
end_user = db.session.scalar(
|
||||
select(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").limit(1)
|
||||
)
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class OpenDALStorage(BaseStorage):
|
|||
kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)
|
||||
|
||||
if scheme == "fs":
|
||||
root = kwargs.get("root", "storage")
|
||||
root = kwargs.setdefault("root", "storage")
|
||||
Path(root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
|
||||
|
|
|
|||
|
|
@ -424,13 +424,11 @@ def _build_from_datasource_file(
|
|||
datasource_file_id = mapping.get("datasource_file_id")
|
||||
if not datasource_file_id:
|
||||
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
|
||||
datasource_file = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
datasource_file = db.session.scalar(
|
||||
select(UploadFile).where(
|
||||
UploadFile.id == datasource_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if datasource_file is None:
|
||||
|
|
|
|||
|
|
@ -158,6 +158,13 @@ class FeedbackFromSource(StrEnum):
|
|||
ADMIN = "admin"
|
||||
|
||||
|
||||
class FeedbackRating(StrEnum):
|
||||
"""MessageFeedback rating"""
|
||||
|
||||
LIKE = "like"
|
||||
DISLIKE = "dislike"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""How a conversation/message was invoked"""
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from core.tools.signature import sign_tool_file
|
|||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
|
|
@ -35,7 +36,10 @@ from .enums import (
|
|||
BannerStatus,
|
||||
ConversationStatus,
|
||||
CreatorUserRole,
|
||||
FeedbackFromSource,
|
||||
FeedbackRating,
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
|
|
@ -1164,7 +1168,7 @@ class Conversation(Base):
|
|||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
|
|
@ -1175,7 +1179,7 @@ class Conversation(Base):
|
|||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
|
|
@ -1190,7 +1194,7 @@ class Conversation(Base):
|
|||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
|
|
@ -1201,7 +1205,7 @@ class Conversation(Base):
|
|||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
|
|
@ -1724,8 +1728,8 @@ class MessageFeedback(TypeBase):
|
|||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
rating: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
|
||||
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
|
||||
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
|
|
@ -1778,7 +1782,9 @@ class MessageFile(TypeBase):
|
|||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
|
||||
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
|
||||
)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
@ -2108,7 +2114,7 @@ class UploadFile(Base):
|
|||
# The `server_default` serves as a fallback mechanism.
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
|
|
@ -2152,7 +2158,7 @@ class UploadFile(Base):
|
|||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
storage_type: str,
|
||||
storage_type: StorageType,
|
||||
key: str,
|
||||
name: str,
|
||||
size: int,
|
||||
|
|
|
|||
|
|
@ -22,14 +22,14 @@ from sqlalchemy import (
|
|||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
|
|
@ -936,8 +936,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
|
||||
datasource_info = execution_metadata["datasource_info"]
|
||||
extras["icon"] = datasource_info.get("icon")
|
||||
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
|
||||
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
|
||||
elif (
|
||||
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
|
||||
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
|
||||
):
|
||||
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
|
||||
provider_id = trigger_info.get("provider_id")
|
||||
if provider_id:
|
||||
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[pytest]
|
||||
pythonpath = .
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib --cov-branch --cov-report=xml
|
||||
env =
|
||||
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import math
|
|||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
import app
|
||||
from core.helper.marketplace import fetch_global_plugin_manifest
|
||||
|
|
@ -28,17 +29,15 @@ def check_upgradable_plugin_task():
|
|||
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
|
||||
click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green"))
|
||||
|
||||
strategies = (
|
||||
db.session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(
|
||||
strategies = db.session.scalars(
|
||||
select(TenantPluginAutoUpgradeStrategy).where(
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
|
||||
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
|
||||
TenantPluginAutoUpgradeStrategy.strategy_setting
|
||||
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
total_strategies = len(strategies)
|
||||
click.echo(click.style(f"Total strategies: {total_strategies}", fg="green"))
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import datetime
|
|||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
import app
|
||||
|
|
@ -19,14 +19,12 @@ def clean_embedding_cache_task():
|
|||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
|
||||
while True:
|
||||
try:
|
||||
embedding_ids = (
|
||||
db.session.query(Embedding.id)
|
||||
embedding_ids = db.session.scalars(
|
||||
select(Embedding.id)
|
||||
.where(Embedding.created_at < thirty_days_ago)
|
||||
.order_by(Embedding.created_at.desc())
|
||||
.limit(100)
|
||||
.all()
|
||||
)
|
||||
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
|
||||
).all()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
if embedding_ids:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import time
|
|||
from typing import TypedDict
|
||||
|
||||
import click
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
import app
|
||||
|
|
@ -51,7 +51,7 @@ def clean_unused_datasets_task():
|
|||
try:
|
||||
# Subquery for counting new documents
|
||||
document_subquery_new = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
select(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
|
|
@ -64,7 +64,7 @@ def clean_unused_datasets_task():
|
|||
|
||||
# Subquery for counting old documents
|
||||
document_subquery_old = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
select(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
|
|
@ -142,8 +142,8 @@ def clean_unused_datasets_task():
|
|||
index_processor.clean(dataset, None)
|
||||
|
||||
# Update document
|
||||
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
|
||||
{Document.enabled: False}
|
||||
db.session.execute(
|
||||
update(Document).where(Document.dataset_id == dataset.id).values(enabled=False)
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import func, select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
|
|
@ -20,7 +21,7 @@ def create_tidb_serverless_task():
|
|||
try:
|
||||
# check the number of idle tidb serverless
|
||||
idle_tidb_serverless_number = (
|
||||
db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
|
||||
db.session.scalar(select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)) or 0
|
||||
)
|
||||
if idle_tidb_serverless_number >= tidb_serverless_number:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -49,16 +49,18 @@ def mail_clean_document_notify_task():
|
|||
if plan != CloudPlan.SANDBOX:
|
||||
knowledge_details = []
|
||||
# check tenant
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
# check current owner
|
||||
current_owner_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
)
|
||||
if not current_owner_join:
|
||||
continue
|
||||
account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
|
||||
if not account:
|
||||
continue
|
||||
|
||||
|
|
@ -71,7 +73,7 @@ def mail_clean_document_notify_task():
|
|||
)
|
||||
|
||||
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
|
||||
if dataset:
|
||||
document_count = len(document_ids)
|
||||
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from flask import Response
|
|||
from sqlalchemy import or_
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import Account, App, Conversation, Message, MessageFeedback
|
||||
|
||||
|
||||
|
|
@ -100,7 +101,7 @@ class FeedbackService:
|
|||
"ai_response": message.answer[:500] + "..."
|
||||
if len(message.answer) > 500
|
||||
else message.answer, # Truncate long responses
|
||||
"feedback_rating": "👍" if feedback.rating == "like" else "👎",
|
||||
"feedback_rating": "👍" if feedback.rating == FeedbackRating.LIKE else "👎",
|
||||
"feedback_rating_raw": feedback.rating,
|
||||
"feedback_comment": feedback.content or "",
|
||||
"feedback_source": feedback.from_source,
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account
|
||||
|
|
@ -93,7 +94,7 @@ class FileService:
|
|||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_tenant_id or "",
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=filename,
|
||||
size=file_size,
|
||||
|
|
@ -152,7 +153,7 @@ class FileService:
|
|||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=text_name,
|
||||
size=len(text),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
|||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import (
|
||||
|
|
@ -172,7 +173,7 @@ class MessageService:
|
|||
app_model: App,
|
||||
message_id: str,
|
||||
user: Union[Account, EndUser] | None,
|
||||
rating: str | None,
|
||||
rating: FeedbackRating | None,
|
||||
content: str | None,
|
||||
):
|
||||
if not user:
|
||||
|
|
@ -197,7 +198,7 @@ class MessageService:
|
|||
message_id=message.id,
|
||||
rating=rating,
|
||||
content=content,
|
||||
from_source=("user" if isinstance(user, EndUser) else "admin"),
|
||||
from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN),
|
||||
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
|
||||
from_account_id=(user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from controllers.console.app import wraps
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, MessageFeedback
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
|
|
@ -77,8 +78,8 @@ class TestFeedbackExportApi:
|
|||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content=None,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
from_account_id=None,
|
||||
|
|
@ -90,8 +91,8 @@ class TestFeedbackExportApi:
|
|||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating="dislike",
|
||||
from_source="admin",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
content="The response was not helpful",
|
||||
from_end_user_id=None,
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
|
|
@ -277,8 +278,8 @@ class TestFeedbackExportApi:
|
|||
# Verify service was called with correct parameters
|
||||
mock_export_feedbacks.assert_called_once_with(
|
||||
app_id=mock_app_model.id,
|
||||
from_source="user",
|
||||
rating="dislike",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
has_comment=True,
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-12-31",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
|||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
|
|
@ -288,7 +289,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
|||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from dify_graph.variables.types import SegmentType
|
|||
from dify_graph.variables.variables import StringVariable
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.variable_factory import build_segment
|
||||
from libs import datetime_utils
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -347,7 +348,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
|||
# Create an upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_offload_{uuid.uuid4()}.json",
|
||||
name="test_offload.json",
|
||||
size=len(content_bytes),
|
||||
|
|
@ -450,7 +451,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
|||
# Create upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_integration_{uuid.uuid4()}.txt",
|
||||
name="test_integration.txt",
|
||||
size=len(content_bytes),
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy import delete
|
|||
|
||||
from core.db.session_factory import session_factory
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
|
|
@ -197,7 +198,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
|||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
|
|
@ -210,7 +211,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
|||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
|
|
@ -430,7 +431,7 @@ class TestDeleteDraftVariablesSessionCommit:
|
|||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
|
|
@ -443,7 +444,7 @@ class TestDeleteDraftVariablesSessionCommit:
|
|||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
|||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
|
|
@ -289,7 +290,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
|||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
|
|
@ -198,7 +199,7 @@ class DocumentStatusTestDataFactory:
|
|||
"""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models import Account
|
||||
from models.enums import MessageFileBelongsTo
|
||||
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.agent_service import AgentService
|
||||
|
|
@ -852,7 +853,7 @@ class TestAgentService:
|
|||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
url="http://example.com/file1.jpg",
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
|
|
@ -861,7 +862,7 @@ class TestAgentService:
|
|||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
url="http://example.com/file2.png",
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
|
||||
|
|
@ -83,7 +84,7 @@ def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, n
|
|||
"""Persist an upload file row referenced by document.data_source_info."""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from unittest import mock
|
|||
import pytest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, Conversation, Message
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
|
|
@ -47,8 +48,8 @@ class TestFeedbackService:
|
|||
app_id=app_id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="like",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="Great answer!",
|
||||
from_end_user_id="user-123",
|
||||
from_account_id=None,
|
||||
|
|
@ -61,8 +62,8 @@ class TestFeedbackService:
|
|||
app_id=app_id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="dislike",
|
||||
from_source="admin",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
content="Could be more detailed",
|
||||
from_end_user_id=None,
|
||||
from_account_id="admin-456",
|
||||
|
|
@ -179,8 +180,8 @@ class TestFeedbackService:
|
|||
# Test with filters
|
||||
result = FeedbackService.export_feedbacks(
|
||||
app_id=sample_data["app"].id,
|
||||
from_source="admin",
|
||||
rating="dislike",
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
has_comment=True,
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-12-31",
|
||||
|
|
@ -293,8 +294,8 @@ class TestFeedbackService:
|
|||
app_id=sample_data["app"].id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="回答不够详细,需要更多信息",
|
||||
from_end_user_id="user-123",
|
||||
from_account_id=None,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
|
|
@ -140,7 +141,7 @@ class TestFileService:
|
|||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"upload_files/test/{fake.uuid4()}.txt",
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import pytest
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
|
|
@ -172,8 +173,8 @@ class TestAppMessageExportServiceIntegration:
|
|||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="first",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
|
|
@ -181,8 +182,8 @@ class TestAppMessageExportServiceIntegration:
|
|||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="second",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
|
|
@ -190,8 +191,8 @@ class TestAppMessageExportServiceIntegration:
|
|||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="admin",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
content="should-be-filtered",
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import MessageFeedback
|
||||
from services.app_service import AppService
|
||||
from services.errors.message import (
|
||||
|
|
@ -405,7 +406,7 @@ class TestMessageService:
|
|||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create feedback
|
||||
rating = "like"
|
||||
rating = FeedbackRating.LIKE
|
||||
content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=rating, content=content
|
||||
|
|
@ -435,7 +436,11 @@ class TestMessageService:
|
|||
# Test creating feedback with no user
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=None,
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=fake.text(max_nb_chars=100),
|
||||
)
|
||||
|
||||
def test_create_feedback_update_existing(
|
||||
|
|
@ -452,14 +457,14 @@ class TestMessageService:
|
|||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create initial feedback
|
||||
initial_rating = "like"
|
||||
initial_rating = FeedbackRating.LIKE
|
||||
initial_content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
|
||||
)
|
||||
|
||||
# Update feedback
|
||||
updated_rating = "dislike"
|
||||
updated_rating = FeedbackRating.DISLIKE
|
||||
updated_content = fake.text(max_nb_chars=100)
|
||||
updated_feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
|
||||
|
|
@ -487,7 +492,11 @@ class TestMessageService:
|
|||
|
||||
# Create initial feedback
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=fake.text(max_nb_chars=100),
|
||||
)
|
||||
|
||||
# Delete feedback by setting rating to None
|
||||
|
|
@ -538,7 +547,7 @@ class TestMessageService:
|
|||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating="like" if i % 2 == 0 else "dislike",
|
||||
rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE,
|
||||
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
|
||||
)
|
||||
feedbacks.append(feedback)
|
||||
|
|
@ -568,7 +577,11 @@ class TestMessageService:
|
|||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=f"Feedback {i}",
|
||||
)
|
||||
|
||||
# Get feedbacks with pagination
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
|
|||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import DataSourceType, MessageChainType
|
||||
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
|
|
@ -166,7 +166,7 @@ class TestMessagesCleanServiceIntegration:
|
|||
name="Test conversation",
|
||||
inputs={},
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
|
|
@ -196,7 +196,7 @@ class TestMessagesCleanServiceIntegration:
|
|||
answer_unit_price=Decimal("0.002"),
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_account_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
|
@ -216,8 +216,8 @@ class TestMessagesCleanServiceIntegration:
|
|||
app_id=message.app_id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating="like",
|
||||
from_source="api",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(feedback)
|
||||
|
|
@ -249,7 +249,7 @@ class TestMessagesCleanServiceIntegration:
|
|||
type="image",
|
||||
transfer_method="local_file",
|
||||
url="http://example.com/test.jpg",
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
created_by_role="end_user",
|
||||
created_by=str(uuid.uuid4()),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -48,41 +48,42 @@ class TestToolTransformService:
|
|||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
credentials={"auth_type": "api_key_header", "api_key": "test_key"},
|
||||
provider_type="api",
|
||||
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
tools_str="[]",
|
||||
)
|
||||
elif provider_type == "builtin":
|
||||
provider = BuiltinToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon="🔧",
|
||||
icon_dark="🔧",
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
provider="test_provider",
|
||||
credential_type="api_key",
|
||||
credentials={"api_key": "test_key"},
|
||||
encrypted_credentials='{"api_key": "test_key"}',
|
||||
)
|
||||
elif provider_type == "workflow":
|
||||
provider = WorkflowToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
workflow_id="test_workflow_id",
|
||||
app_id="test_workflow_id",
|
||||
label="Test Workflow",
|
||||
version="1.0.0",
|
||||
parameter_configuration="[]",
|
||||
)
|
||||
elif provider_type == "mcp":
|
||||
provider = MCPToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
server_url="https://mcp.example.com",
|
||||
server_url_hash="test_server_url_hash",
|
||||
server_identifier="test_server",
|
||||
tools='[{"name": "test_tool", "description": "Test tool"}]',
|
||||
authed=True,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
@ -209,7 +210,7 @@ class TestBatchCleanDocumentTask:
|
|||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=account.current_tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
|
|
@ -203,7 +204,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
|
|
@ -254,7 +255,7 @@ class TestCleanDatasetTask:
|
|||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
|
|
@ -925,7 +926,7 @@ class TestCleanDatasetTask:
|
|||
special_filename = f"test_file_{special_content}.txt"
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{special_filename}",
|
||||
name=special_filename,
|
||||
size=1024,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from core.db.session_factory import session_factory
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -78,7 +79,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id:
|
|||
for i in range(count):
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test/file-{uuid.uuid4()}-{i}.json",
|
||||
name=f"file-{i}.json",
|
||||
size=1024 + i,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,56 @@
|
|||
from pathlib import Path
|
||||
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
|
||||
|
||||
class TestOpenDALFsDefaultRoot:
|
||||
"""Test that OpenDALStorage with scheme='fs' works correctly when no root is provided."""
|
||||
|
||||
def test_fs_without_root_uses_default(self, tmp_path, monkeypatch):
|
||||
"""When no root is specified, the default 'storage' should be used and passed to the Operator."""
|
||||
# Change to tmp_path so the default "storage" dir is created there
|
||||
monkeypatch.chdir(tmp_path)
|
||||
# Ensure no OPENDAL_FS_ROOT env var is set
|
||||
monkeypatch.delenv("OPENDAL_FS_ROOT", raising=False)
|
||||
|
||||
storage = OpenDALStorage(scheme="fs")
|
||||
|
||||
# The default directory should have been created
|
||||
assert (tmp_path / "storage").is_dir()
|
||||
# The storage should be functional
|
||||
storage.save("test_default_root.txt", b"hello")
|
||||
assert storage.exists("test_default_root.txt")
|
||||
assert storage.load_once("test_default_root.txt") == b"hello"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_default_root.txt")
|
||||
|
||||
def test_fs_with_explicit_root(self, tmp_path):
|
||||
"""When root is explicitly provided, it should be used."""
|
||||
custom_root = str(tmp_path / "custom_storage")
|
||||
storage = OpenDALStorage(scheme="fs", root=custom_root)
|
||||
|
||||
assert Path(custom_root).is_dir()
|
||||
storage.save("test_explicit_root.txt", b"world")
|
||||
assert storage.exists("test_explicit_root.txt")
|
||||
assert storage.load_once("test_explicit_root.txt") == b"world"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_explicit_root.txt")
|
||||
|
||||
def test_fs_with_env_var_root(self, tmp_path, monkeypatch):
|
||||
"""When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs."""
|
||||
env_root = str(tmp_path / "env_storage")
|
||||
monkeypatch.setenv("OPENDAL_FS_ROOT", env_root)
|
||||
# Ensure .env file doesn't interfere
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
storage = OpenDALStorage(scheme="fs")
|
||||
|
||||
assert Path(env_root).is_dir()
|
||||
storage.save("test_env_root.txt", b"env_data")
|
||||
assert storage.exists("test_env_root.txt")
|
||||
assert storage.load_once("test_env_root.txt") == b"env_data"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_env_root.txt")
|
||||
|
|
@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import (
|
|||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import ApiToken, UploadFile
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
|
|
@ -1121,7 +1122,7 @@ class TestDatasetIndexingEstimateApi:
|
|||
def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile:
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="key",
|
||||
name="name.txt",
|
||||
size=1,
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class TestGetUser:
|
|||
mock_user.id = "user123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
mock_session.get.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
|
|
@ -58,7 +58,7 @@ class TestGetUser:
|
|||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.query.assert_called_once()
|
||||
mock_session.get.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
|
|
@ -72,7 +72,8 @@ class TestGetUser:
|
|||
mock_user.session_id = "anonymous_session"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
# non-anonymous path uses session.get(); anonymous uses session.scalar()
|
||||
mock_session.get.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
|
|
@ -89,7 +90,7 @@ class TestGetUser:
|
|||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_session.get.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
|
||||
|
|
@ -103,18 +104,20 @@ class TestGetUser:
|
|||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.select")
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_use_default_session_id_when_user_id_none(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask
|
||||
):
|
||||
"""Test using default session ID when user_id is None"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
# When user_id is None, is_anonymous=True, so session.scalar() is used
|
||||
mock_session.scalar.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
|
|
@ -133,7 +136,7 @@ class TestGetUser:
|
|||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.side_effect = Exception("Database error")
|
||||
mock_session.get.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with app.app_context():
|
||||
|
|
@ -161,9 +164,9 @@ class TestGetUserTenant:
|
|||
# Act
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
|
|
@ -194,8 +197,8 @@ class TestGetUserTenant:
|
|||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
|
||||
mock_get.return_value = None
|
||||
with pytest.raises(ValueError, match="tenant not found"):
|
||||
protected_view()
|
||||
|
||||
|
|
@ -215,9 +218,9 @@ class TestGetUserTenant:
|
|||
# Act - use empty string for user_id to trigger default logic
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
|
|
|
|||
|
|
@ -249,8 +249,8 @@ class TestEnterpriseInnerApiUserAuth:
|
|||
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_user
|
||||
with patch("controllers.inner_api.wraps.db.session.get") as mock_get:
|
||||
mock_get.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class TestEnterpriseWorkspace:
|
|||
# Arrange
|
||||
mock_account = MagicMock()
|
||||
mock_account.email = "owner@example.com"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
|
|
@ -122,7 +122,7 @@ class TestEnterpriseWorkspace:
|
|||
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
|
||||
"""Test that post() returns 404 when the owner account does not exist"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from controllers.service_api.app.message import (
|
|||
MessageListQuery,
|
||||
MessageSuggestedApi,
|
||||
)
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
|
|
@ -310,7 +311,7 @@ class TestMessageService:
|
|||
app_model=Mock(spec=App),
|
||||
message_id=str(uuid.uuid4()),
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
rating=FeedbackRating.LIKE,
|
||||
content="Great response!",
|
||||
)
|
||||
|
||||
|
|
@ -326,7 +327,7 @@ class TestMessageService:
|
|||
app_model=Mock(spec=App),
|
||||
message_id="invalid_message_id",
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=None,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,14 +39,21 @@ class TestHitTestingPayload:
|
|||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all optional fields."""
|
||||
retrieval_model_data = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
"top_k": 5,
|
||||
}
|
||||
payload = HitTestingPayload(
|
||||
query="test query",
|
||||
retrieval_model={"top_k": 5},
|
||||
retrieval_model=retrieval_model_data,
|
||||
external_retrieval_model={"provider": "openai"},
|
||||
attachment_ids=["att_1", "att_2"],
|
||||
)
|
||||
assert payload.query == "test query"
|
||||
assert payload.retrieval_model == {"top_k": 5}
|
||||
assert payload.retrieval_model is not None
|
||||
assert payload.retrieval_model.top_k == 5
|
||||
assert payload.external_retrieval_model == {"provider": "openai"}
|
||||
assert payload.attachment_ids == ["att_1", "att_2"]
|
||||
|
||||
|
|
@ -134,7 +141,13 @@ class TestHitTestingApiPost:
|
|||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8}
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": True,
|
||||
"top_k": 10,
|
||||
"score_threshold": 0.8,
|
||||
}
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
|
|
@ -152,7 +165,11 @@ class TestHitTestingApiPost:
|
|||
|
||||
assert response["query"] == "complex query"
|
||||
call_kwargs = mock_hit_svc.retrieve.call_args
|
||||
assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model
|
||||
# retrieval_model is serialized via model_dump, verify key fields
|
||||
passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model")
|
||||
assert passed_retrieval_model is not None
|
||||
assert passed_retrieval_model["search_method"] == "semantic_search"
|
||||
assert passed_retrieval_model["top_k"] == 10
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
|
|
|
|||
|
|
@ -49,6 +49,17 @@ class _FakeSession:
|
|||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
def get(self, model, ident):
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
def scalar(self, stmt):
|
||||
# Extract the model name from the select statement's column_descriptions
|
||||
try:
|
||||
name = stmt.column_descriptions[0]["entity"].__name__
|
||||
except (AttributeError, IndexError, KeyError):
|
||||
return None
|
||||
return self._mapping.get(name)
|
||||
|
||||
|
||||
class _FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class TestAppSiteApi:
|
|||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
site_obj = _site()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
|
||||
mock_db.session.scalar.return_value = site_obj
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
|
@ -66,9 +66,9 @@ class TestAppSiteApi:
|
|||
@patch("controllers.web.site.db")
|
||||
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
|
|
@ -80,7 +80,7 @@ class TestAppSiteApi:
|
|||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
from models.account import TenantStatus
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
|
||||
mock_db.session.scalar.return_value = _site()
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.ARCHIVE,
|
||||
|
|
|
|||
|
|
@ -44,11 +44,22 @@ class TestAgentChatAppGenerateResponseConverterBlocking:
|
|||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
|
|
@ -107,11 +118,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
|
|||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "summary",
|
||||
"extra": "ignored",
|
||||
}
|
||||
|
|
@ -151,11 +173,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
|
|||
assert "usage" not in metadata
|
||||
assert metadata["retriever_resources"] == [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "summary",
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
|
|||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
|
|
@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
|
|||
assert response.data.process_data == {}
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_finish_response_prefers_event_finished_at(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Finished timestamps should come from the event, not delayed queue processing time."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
|
||||
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.common.workflow_response_converter.naive_utc_now",
|
||||
lambda: delayed_processing_time,
|
||||
)
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=BuiltinNodeTypes.CODE,
|
||||
node_execution_id="node-exec-1",
|
||||
start_at=start_at,
|
||||
finished_at=finished_at,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.elapsed_time == 2.0
|
||||
assert response.data.finished_at == int(finished_at.timestamp())
|
||||
|
||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||
"""Test that node retry response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
|
|
|||
|
|
@ -38,11 +38,22 @@ class TestCompletionAppGenerateResponseConverter:
|
|||
metadata = {
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "c",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "sum",
|
||||
"extra": "x",
|
||||
}
|
||||
|
|
@ -66,7 +77,12 @@ class TestCompletionAppGenerateResponseConverter:
|
|||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
assert result["metadata"]["retriever_resources"][0]["dataset_id"] == "dataset-1"
|
||||
assert result["metadata"]["retriever_resources"][0]["document_id"] == "document-1"
|
||||
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
|
||||
assert result["metadata"]["retriever_resources"][0]["data_source_type"] == "file"
|
||||
assert result["metadata"]["retriever_resources"][0]["segment_position"] == 3
|
||||
assert result["metadata"]["retriever_resources"][0]["index_node_hash"] == "abc1234"
|
||||
assert "extra" not in result["metadata"]["retriever_resources"][0]
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_not_dict(self):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.workflow.layers.persistence import (
|
||||
PersistenceWorkflowInfo,
|
||||
WorkflowPersistenceLayer,
|
||||
_NodeRuntimeSnapshot,
|
||||
)
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
|
||||
|
||||
def _build_layer() -> WorkflowPersistenceLayer:
|
||||
application_generate_entity = Mock()
|
||||
application_generate_entity.inputs = {}
|
||||
|
||||
return WorkflowPersistenceLayer(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
version="1",
|
||||
graph_data={},
|
||||
),
|
||||
workflow_execution_repository=Mock(),
|
||||
workflow_node_execution_repository=Mock(),
|
||||
)
|
||||
|
||||
|
||||
def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
layer = _build_layer()
|
||||
node_execution = Mock()
|
||||
node_execution.id = "node-exec-1"
|
||||
node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
node_execution.update_from_mapping = Mock()
|
||||
|
||||
layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
|
||||
node_id="node-id",
|
||||
title="LLM",
|
||||
predecessor_node_id=None,
|
||||
iteration_id="iter-1",
|
||||
loop_id=None,
|
||||
created_at=node_execution.created_at,
|
||||
)
|
||||
|
||||
event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
|
||||
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
|
||||
monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time)
|
||||
|
||||
layer._update_node_execution(
|
||||
node_execution,
|
||||
NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
finished_at=event_finished_at,
|
||||
)
|
||||
|
||||
assert node_execution.finished_at == event_finished_at
|
||||
assert node_execution.elapsed_time == 2.0
|
||||
|
|
@ -166,6 +166,7 @@ class TestDatasourceFileManager:
|
|||
# Setup
|
||||
mock_guess_ext.return_value = None # Cannot guess
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
import queue
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from dify_graph.graph_engine.worker import Worker
|
||||
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
|
||||
|
||||
|
||||
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
|
||||
fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=InMemoryReadyQueue(),
|
||||
event_queue=queue.Queue(),
|
||||
graph=MagicMock(),
|
||||
layers=[],
|
||||
)
|
||||
node = SimpleNamespace(
|
||||
execution_id="exec-1",
|
||||
id="node-1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
)
|
||||
|
||||
event = worker._build_fallback_failure_event(node, RuntimeError("boom"))
|
||||
|
||||
assert event.start_at == fixed_time
|
||||
assert event.finished_at == fixed_time
|
||||
assert event.error == "boom"
|
||||
assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert event.node_run_result.error == "boom"
|
||||
assert event.node_run_result.error_type == "RuntimeError"
|
||||
|
||||
|
||||
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
|
||||
start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
failure_time = start_at + timedelta(seconds=5)
|
||||
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
|
||||
|
||||
class FakeNode:
|
||||
execution_id = "exec-1"
|
||||
id = "node-1"
|
||||
node_type = BuiltinNodeTypes.LLM
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
|
||||
yield NodeRunStartedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
node_title="LLM",
|
||||
start_at=start_at,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=MagicMock(),
|
||||
event_queue=MagicMock(),
|
||||
graph=MagicMock(nodes={"node-1": FakeNode()}),
|
||||
layers=[],
|
||||
)
|
||||
|
||||
worker._ready_queue.get.side_effect = ["node-1"]
|
||||
|
||||
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
|
||||
captured_events.append(event)
|
||||
if len(captured_events) == 1:
|
||||
raise RuntimeError("queue boom")
|
||||
worker.stop()
|
||||
|
||||
worker._event_queue.put.side_effect = put_side_effect
|
||||
|
||||
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
|
||||
worker.run()
|
||||
|
||||
fallback_event = captured_events[-1]
|
||||
|
||||
assert isinstance(fallback_event, NodeRunFailedEvent)
|
||||
assert fallback_event.start_at == start_at
|
||||
assert fallback_event.finished_at == failure_time
|
||||
assert fallback_event.error == "queue boom"
|
||||
assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
|
||||
|
||||
def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None:
|
||||
parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
child_start = parent_start + timedelta(seconds=3)
|
||||
failure_time = parent_start + timedelta(seconds=5)
|
||||
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
|
||||
|
||||
class FakeIterationNode:
|
||||
execution_id = "iteration-exec"
|
||||
id = "iteration-node"
|
||||
node_type = BuiltinNodeTypes.ITERATION
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
|
||||
yield NodeRunStartedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
node_title="Iteration",
|
||||
start_at=parent_start,
|
||||
)
|
||||
yield NodeRunStartedEvent(
|
||||
id="child-exec",
|
||||
node_id="child-node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
node_title="LLM",
|
||||
start_at=child_start,
|
||||
in_iteration_id=self.id,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=MagicMock(),
|
||||
event_queue=MagicMock(),
|
||||
graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}),
|
||||
layers=[],
|
||||
)
|
||||
|
||||
worker._ready_queue.get.side_effect = ["iteration-node"]
|
||||
|
||||
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
|
||||
captured_events.append(event)
|
||||
if len(captured_events) == 2:
|
||||
raise RuntimeError("queue boom")
|
||||
worker.stop()
|
||||
|
||||
worker._event_queue.put.side_effect = put_side_effect
|
||||
|
||||
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
|
||||
worker.run()
|
||||
|
||||
fallback_event = captured_events[-1]
|
||||
|
||||
assert isinstance(fallback_event, NodeRunFailedEvent)
|
||||
assert fallback_event.start_at == parent_start
|
||||
assert fallback_event.finished_at == failure_time
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.graph_events import NodeRunSucceededEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from dify_graph.nodes.iteration.iteration_node import IterationNode
|
||||
|
||||
|
||||
def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
|
||||
node = IterationNode.__new__(IterationNode)
|
||||
node._node_data = IterationNodeData(
|
||||
title="Parallel Iteration",
|
||||
iterator_selector=["start", "items"],
|
||||
output_selector=["iteration", "output"],
|
||||
is_parallel=True,
|
||||
parallel_nums=2,
|
||||
error_handle_mode=ErrorHandleMode.TERMINATED,
|
||||
)
|
||||
node._capture_execution_context = lambda: nullcontext()
|
||||
node._sync_conversation_variables_from_snapshot = lambda snapshot: None
|
||||
node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)
|
||||
|
||||
def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
|
||||
return (
|
||||
0.1 + (index * 0.1),
|
||||
[
|
||||
NodeRunSucceededEvent(
|
||||
id=f"exec-{index}",
|
||||
node_id=f"llm-{index}",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
start_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
),
|
||||
],
|
||||
f"output-{item}",
|
||||
{},
|
||||
LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel
|
||||
|
||||
outputs: list[object] = []
|
||||
iter_run_map: dict[str, float] = {}
|
||||
usage_accumulator = [LLMUsage.empty_usage()]
|
||||
|
||||
generator = node._execute_parallel_iterations(
|
||||
iterator_list_value=["a", "b"],
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
|
||||
for _ in generator:
|
||||
# Simulate a slow consumer replaying buffered events.
|
||||
time.sleep(0.02)
|
||||
|
||||
assert outputs == ["output-a", "output-b"]
|
||||
assert iter_run_map["0"] == pytest.approx(0.1)
|
||||
assert iter_run_map["1"] == pytest.approx(0.2)
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
from collections.abc import Mapping
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
|
||||
init_params = build_test_graph_init_params(
|
||||
graph_config=graph_config,
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", files=[]),
|
||||
user_inputs={"payload": "value"},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return init_params, runtime_state
|
||||
|
||||
|
||||
def _build_node_config() -> NodeConfigDict:
|
||||
return NodeConfigDictAdapter.validate_python(
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": TRIGGER_PLUGIN_NODE_TYPE,
|
||||
"title": "Trigger Event",
|
||||
"plugin_id": "plugin-id",
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"subscription_id": "subscription-id",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
"event_parameters": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
|
||||
init_params, runtime_state = _build_context(graph_config={})
|
||||
node = TriggerEventNode(
|
||||
id="node-1",
|
||||
config=_build_node_config(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events.base import NodeRunResult
|
||||
|
||||
|
||||
def test_node_run_result_accepts_trigger_info_metadata() -> None:
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, EndUser, Message
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -820,14 +821,14 @@ class TestMessageServiceFeedback:
|
|||
app_model=app,
|
||||
message_id="msg-123",
|
||||
user=user,
|
||||
rating="like",
|
||||
rating=FeedbackRating.LIKE,
|
||||
content="Good answer",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.rating == "like"
|
||||
assert result.rating == FeedbackRating.LIKE
|
||||
assert result.content == "Good answer"
|
||||
assert result.from_source == "user"
|
||||
assert result.from_source == FeedbackFromSource.USER
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
|
@ -852,13 +853,13 @@ class TestMessageServiceFeedback:
|
|||
app_model=app,
|
||||
message_id="msg-123",
|
||||
user=user,
|
||||
rating="dislike",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
content="Bad answer",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == feedback
|
||||
assert feedback.rating == "dislike"
|
||||
assert feedback.rating == FeedbackRating.DISLIKE
|
||||
assert feedback.content == "Bad answer"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
coverage:
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
target: auto
|
||||
|
||||
flags:
|
||||
web:
|
||||
paths:
|
||||
- "web/"
|
||||
carryforward: true
|
||||
|
||||
api:
|
||||
paths:
|
||||
- "api/"
|
||||
carryforward: true
|
||||
|
|
@ -7,17 +7,21 @@
|
|||
*/
|
||||
import type { InstalledApp } from '@/models/explore'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import SideBar from '@/app/components/explore/sidebar'
|
||||
import { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
const { mockToastAdd } = vi.hoisted(() => ({
|
||||
mockToastAdd: vi.fn(),
|
||||
}))
|
||||
|
||||
let mockMediaType: string = MediaType.pc
|
||||
const mockSegments = ['apps']
|
||||
const mockPush = vi.fn()
|
||||
const mockUninstall = vi.fn()
|
||||
const mockUpdatePinStatus = vi.fn()
|
||||
let mockInstalledApps: InstalledApp[] = []
|
||||
let mockIsUninstallPending = false
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useSelectedLayoutSegments: () => mockSegments,
|
||||
|
|
@ -42,12 +46,22 @@ vi.mock('@/service/use-explore', () => ({
|
|||
}),
|
||||
useUninstallApp: () => ({
|
||||
mutateAsync: mockUninstall,
|
||||
isPending: mockIsUninstallPending,
|
||||
}),
|
||||
useUpdateAppPinStatus: () => ({
|
||||
mutateAsync: mockUpdatePinStatus,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockToastAdd,
|
||||
close: vi.fn(),
|
||||
update: vi.fn(),
|
||||
promise: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
const createInstalledApp = (overrides: Partial<InstalledApp> = {}): InstalledApp => ({
|
||||
id: overrides.id ?? 'app-1',
|
||||
uninstallable: overrides.uninstallable ?? false,
|
||||
|
|
@ -74,7 +88,7 @@ describe('Sidebar Lifecycle Flow', () => {
|
|||
vi.clearAllMocks()
|
||||
mockMediaType = MediaType.pc
|
||||
mockInstalledApps = []
|
||||
vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
|
||||
mockIsUninstallPending = false
|
||||
})
|
||||
|
||||
describe('Pin / Unpin / Delete Flow', () => {
|
||||
|
|
@ -91,7 +105,7 @@ describe('Sidebar Lifecycle Flow', () => {
|
|||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: true })
|
||||
expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'success',
|
||||
}))
|
||||
})
|
||||
|
|
@ -110,7 +124,7 @@ describe('Sidebar Lifecycle Flow', () => {
|
|||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: false })
|
||||
expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'success',
|
||||
}))
|
||||
})
|
||||
|
|
@ -136,9 +150,9 @@ describe('Sidebar Lifecycle Flow', () => {
|
|||
// Step 4: Uninstall API called and success toast shown
|
||||
await waitFor(() => {
|
||||
expect(mockUninstall).toHaveBeenCalledWith('app-1')
|
||||
expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'success',
|
||||
message: 'common.api.remove',
|
||||
title: 'common.api.remove',
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue