From e85d20031e8c62be70ea390e7042596d35b338a9 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Wed, 11 Mar 2026 18:29:53 +0800 Subject: [PATCH] feat: notification (#32192) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/__init__.py | 2 + api/controllers/console/admin.py | 170 +++++++++++++++++++++++- api/controllers/console/notification.py | 90 +++++++++++++ api/services/billing_service.py | 75 +++++++++++ 4 files changed, 336 insertions(+), 1 deletion(-) create mode 100644 api/controllers/console/notification.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 902d67174b..d624b10b22 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -39,6 +39,7 @@ from . import ( feature, human_input_form, init_validate, + notification, ping, setup, spec, @@ -184,6 +185,7 @@ __all__ = [ "model_config", "model_providers", "models", + "notification", "oauth", "oauth_server", "ops_trace", diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 03b602f6e8..6c3a6a8c1f 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,3 +1,5 @@ +import csv +import io from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar @@ -6,7 +8,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select -from werkzeug.exceptions import NotFound, Unauthorized +from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from configs import dify_config from constants.languages import supported_language @@ -16,6 +18,7 @@ from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp +from services.billing_service import BillingService P = ParamSpec("P") R = TypeVar("R") @@ -277,3 +280,168 @@ class DeleteExploreBannerApi(Resource): db.session.commit() return {"result": "success"}, 204 + + +class LangContentPayload(BaseModel): + lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'") + title: str = Field(...) + subtitle: str | None = Field(default=None) + body: str = Field(...) + title_pic_url: str | None = Field(default=None) + + +class UpsertNotificationPayload(BaseModel): + notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update") + contents: list[LangContentPayload] = Field(..., min_length=1) + start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z") + end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z") + frequency: str = Field(default="once", description="'once' | 'every_page_load'") + status: str = Field(default="active", description="'active' | 'inactive'") + + +class BatchAddNotificationAccountsPayload(BaseModel): + notification_id: str = Field(...) + user_email: list[str] = Field(..., description="List of account email addresses") + + +console_ns.schema_model( + UpsertNotificationPayload.__name__, + UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + BatchAddNotificationAccountsPayload.__name__, + BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +@console_ns.route("/admin/upsert_notification") +class UpsertNotificationApi(Resource): + @console_ns.doc("upsert_notification") + @console_ns.doc( + description=( + "Create or update an in-product notification. " + "Supply notification_id to update an existing one; omit it to create a new one. " + "Pass at least one language variant in contents (zh / en / jp)." + ) + ) + @console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__]) + @console_ns.response(200, "Notification upserted successfully") + @only_edition_cloud + @admin_required + def post(self): + payload = UpsertNotificationPayload.model_validate(console_ns.payload) + result = BillingService.upsert_notification( + contents=[c.model_dump() for c in payload.contents], + frequency=payload.frequency, + status=payload.status, + notification_id=payload.notification_id, + start_time=payload.start_time, + end_time=payload.end_time, + ) + return {"result": "success", "notification_id": result.get("notificationId")}, 200 + + +@console_ns.route("/admin/batch_add_notification_accounts") +class BatchAddNotificationAccountsApi(Resource): + @console_ns.doc("batch_add_notification_accounts") + @console_ns.doc( + description=( + "Register target accounts for a notification by email address. " + 'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. ' + "File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) " + "plus a 'notification_id' field. " + "Emails that do not match any account are silently skipped." + ) + ) + @console_ns.response(200, "Accounts added successfully") + @only_edition_cloud + @admin_required + def post(self): + from models.account import Account + + if "file" in request.files: + notification_id = request.form.get("notification_id", "").strip() + if not notification_id: + raise BadRequest("notification_id is required.") + emails = self._parse_emails_from_file() + else: + payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload) + notification_id = payload.notification_id + emails = payload.user_email + + if not emails: + raise BadRequest("No valid email addresses provided.") + + # Resolve emails → account IDs in chunks to avoid large IN-clause + account_ids: list[str] = [] + chunk_size = 500 + for i in range(0, len(emails), chunk_size): + chunk = emails[i : i + chunk_size] + rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all() + account_ids.extend(str(row.id) for row in rows) + + if not account_ids: + raise BadRequest("None of the provided emails matched an existing account.") + + # Send to dify-saas in batches of 1000 + total_count = 0 + batch_size = 1000 + for i in range(0, len(account_ids), batch_size): + batch = account_ids[i : i + batch_size] + result = BillingService.batch_add_notification_accounts( + notification_id=notification_id, + account_ids=batch, + ) + total_count += result.get("count", 0) + + return { + "result": "success", + "emails_provided": len(emails), + "accounts_matched": len(account_ids), + "count": total_count, + }, 200 + + @staticmethod + def _parse_emails_from_file() -> list[str]: + """Parse email addresses from an uploaded CSV or TXT file.""" + file = request.files["file"] + if not file.filename: + raise BadRequest("Uploaded file has no filename.") + + filename_lower = file.filename.lower() + if not filename_lower.endswith((".csv", ".txt")): + raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.") + + try: + content = file.read().decode("utf-8") + except UnicodeDecodeError: + try: + file.seek(0) + content = file.read().decode("gbk") + except UnicodeDecodeError: + raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.") + + emails: list[str] = [] + if filename_lower.endswith(".csv"): + reader = csv.reader(io.StringIO(content)) + for row in reader: + for cell in row: + cell = cell.strip() + if cell: + emails.append(cell) + else: + for line in content.splitlines(): + line = line.strip() + if line: + emails.append(line) + + # Deduplicate while preserving order + seen: set[str] = set() + unique_emails: list[str] = [] + for email in emails: + if email.lower() not in seen: + seen.add(email.lower()) + unique_emails.append(email) + + return unique_emails diff --git a/api/controllers/console/notification.py b/api/controllers/console/notification.py new file mode 100644 index 0000000000..53e4aa3d86 --- /dev/null +++ b/api/controllers/console/notification.py @@ -0,0 +1,90 @@ +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field + +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required +from libs.login import current_account_with_tenant, login_required +from services.billing_service import BillingService + +# Notification content is stored under three lang tags. +_FALLBACK_LANG = "en-US" + + +def _pick_lang_content(contents: dict, lang: str) -> dict: + """Return the single LangContent for *lang*, falling back to English.""" + return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) + + +class DismissNotificationPayload(BaseModel): + notification_id: str = Field(...) + + +@console_ns.route("/notification") +class NotificationApi(Resource): + @console_ns.doc("get_notification") + @console_ns.doc( + description=( + "Return the active in-product notification for the current user " + "in their interface language (falls back to English if unavailable). " + "The notification is NOT marked as seen here; call POST /notification/dismiss " + "when the user explicitly closes the modal." + ), + responses={ + 200: "Success — inspect should_show to decide whether to render the modal", + 401: "Unauthorized", + }, + ) + @setup_required + @login_required + @account_initialization_required + @only_edition_cloud + def get(self): + current_user, _ = current_account_with_tenant() + + result = BillingService.get_account_notification(str(current_user.id)) + + # Proto JSON uses camelCase field names (Kratos default marshaling). + if not result.get("shouldShow"): + return {"should_show": False, "notifications": []}, 200 + + lang = current_user.interface_language or _FALLBACK_LANG + + notifications = [] + for notification in result.get("notifications") or []: + contents: dict = notification.get("contents") or {} + lang_content = _pick_lang_content(contents, lang) + notifications.append( + { + "notification_id": notification.get("notificationId"), + "frequency": notification.get("frequency"), + "lang": lang_content.get("lang", lang), + "title": lang_content.get("title", ""), + "subtitle": lang_content.get("subtitle", ""), + "body": lang_content.get("body", ""), + "title_pic_url": lang_content.get("titlePicUrl", ""), + } + ) + + return {"should_show": bool(notifications), "notifications": notifications}, 200 + + +@console_ns.route("/notification/dismiss") +class NotificationDismissApi(Resource): + @console_ns.doc("dismiss_notification") + @console_ns.doc( + description="Mark a notification as dismissed for the current user.", + responses={200: "Success", 401: "Unauthorized"}, + ) + @setup_required + @login_required + @account_initialization_required + @only_edition_cloud + def post(self): + current_user, _ = current_account_with_tenant() + payload = DismissNotificationPayload.model_validate(request.get_json()) + BillingService.dismiss_notification( + notification_id=payload.notification_id, + account_id=str(current_user.id), + ) + return {"result": "success"}, 200 diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 946b8cdfdb..5ab47c799a 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -393,3 +393,78 @@ class BillingService: for item in data: tenant_whitelist.append(item["tenant_id"]) return tenant_whitelist + + @classmethod + def get_account_notification(cls, account_id: str) -> dict: + """Return the active in-product notification for account_id, if any. + + Calling this endpoint also marks the notification as seen; subsequent + calls will return should_show=false when frequency='once'. + + Response shape (mirrors GetAccountNotificationReply): + { + "should_show": bool, + "notification": { # present only when should_show=true + "notification_id": str, + "contents": { # lang -> LangContent + "en": {"lang": "en", "title": ..., "subtitle": ..., "body": ..., "title_pic_url": ...}, + ... + }, + "frequency": "once" | "every_page_load" + } + } + """ + return cls._send_request("GET", "/notifications/active", params={"account_id": account_id}) + + @classmethod + def upsert_notification( + cls, + contents: list[dict], + frequency: str = "once", + status: str = "active", + notification_id: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + ) -> dict: + """Create or update a notification. + + contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str} + start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional. + Returns {"notification_id": str}. + """ + payload: dict = { + "contents": contents, + "frequency": frequency, + "status": status, + } + if notification_id: + payload["notification_id"] = notification_id + if start_time: + payload["start_time"] = start_time + if end_time: + payload["end_time"] = end_time + return cls._send_request("POST", "/notifications", json=payload) + + @classmethod + def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict: + """Register target account IDs for a notification (max 1000 per call). + + Returns {"count": int}. + """ + return cls._send_request( + "POST", + f"/notifications/{notification_id}/accounts", + json={"account_ids": account_ids}, + ) + + @classmethod + def dismiss_notification(cls, notification_id: str, account_id: str) -> dict: + """Mark a notification as dismissed for an account. + + Returns {"success": bool}. + """ + return cls._send_request( + "POST", + f"/notifications/{notification_id}/dismiss", + json={"account_id": account_id}, + )