feat: server multi models support (#799)

Author: takatost
Date: 2023-08-12 00:57:00 +08:00
Committed by: GitHub
Parent: d8b712b325
Commit: 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):
 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
+                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
     for root, _, files in os.walk("."):
         for file in files:

View File

@@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
 NOTION_CLIENT_SECRET=you-client-secret
 NOTION_CLIENT_ID=you-client-id
 NOTION_INTERNAL_SECRET=you-internal-secret
+
+# Hosted Model Credentials
+HOSTED_OPENAI_ENABLED=false
+HOSTED_OPENAI_API_KEY=
+HOSTED_OPENAI_API_BASE=
+HOSTED_OPENAI_API_ORGANIZATION=
+HOSTED_OPENAI_QUOTA_LIMIT=200
+HOSTED_OPENAI_PAID_ENABLED=false
+HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
+HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
+
+HOSTED_AZURE_OPENAI_ENABLED=false
+HOSTED_AZURE_OPENAI_API_KEY=
+HOSTED_AZURE_OPENAI_API_BASE=
+HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
+
+HOSTED_ANTHROPIC_ENABLED=false
+HOSTED_ANTHROPIC_API_BASE=
+HOSTED_ANTHROPIC_API_KEY=
+HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_PAID_ENABLED=false
+HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+
+STRIPE_API_KEY=
+STRIPE_WEBHOOK_SECRET=
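Note: these variables gate the cloud-hosted trial and paid quotas introduced by this commit. A minimal sketch of how such settings are typically consumed (get_env/get_bool_env are the config helpers used later in this diff; the snippet below is an illustration, not the committed code):

import os

def get_env(key, default=''):
    # Fall back to a default when the variable is unset.
    return os.environ.get(key, default)

def get_bool_env(key):
    # 'true' (any case) enables the flag; everything else disables it.
    return get_env(key, 'False').lower() == 'true'

if get_bool_env('HOSTED_OPENAI_ENABLED') and get_env('HOSTED_OPENAI_API_KEY'):
    quota = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT', '200'))
    print('hosted OpenAI enabled, trial quota:', quota)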

View File

@@ -16,8 +16,9 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS

+from core.model_providers.providers import hosted
 from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail
+    ext_database, ext_storage, ext_mail, ext_stripe
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
     register_blueprints(app)
     register_commands(app)
-    core.init_app(app)
+    hosted.init_app(app)

     return app
@@ -88,6 +89,7 @@ def initialize_extensions(app):
     ext_login.init_app(app)
     ext_mail.init_app(app)
     ext_sentry.init_app(app)
+    ext_stripe.init_app(app)


 def _create_tenant_for_account(account):
@@ -246,5 +248,18 @@ def threads():
     }

+
+@app.route('/db-pool-stat')
+def pool_stat():
+    engine = db.engine
+    return {
+        'pool_size': engine.pool.size(),
+        'checked_in_connections': engine.pool.checkedin(),
+        'checked_out_connections': engine.pool.checkedout(),
+        'overflow_connections': engine.pool.overflow(),
+        'connection_timeout': engine.pool.timeout(),
+        'recycle_time': db.engine.pool._recycle
+    }
+

 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5001)
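Note: extensions/ext_stripe.py is added by this commit but does not appear in this excerpt. A plausible minimal shape, wiring the Stripe SDK to the new STRIPE_API_KEY setting (an assumption, not the verbatim file):

import stripe

def init_app(app):
    # Hypothetical: configure the global Stripe client from Flask config.
    if app.config.get('STRIPE_API_KEY'):
        stripe.api_key = app.config.get('STRIPE_API_KEY')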

View File

@@ -1,5 +1,5 @@
 import datetime
-import logging
+import math
 import random
 import string
 import time
@@ -9,18 +9,18 @@ from flask import current_app
 from werkzeug.exceptions import NotFound

 from core.index.index import IndexBuilder
+from core.model_providers.providers.hosted import hosted_model_providers
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, DatasetQuery, Document
 from models.model import Account
 import secrets
 import base64
-from models.provider import Provider, ProviderName
-from services.provider_service import ProviderService
+from models.provider import Provider, ProviderType, ProviderQuotaType
@click.command('reset-password', help='Reset the account password.') @click.command('reset-password', help='Reset the account password.')
@@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():
 @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
 def sync_anthropic_hosted_providers():
+    if not hosted_model_providers.anthropic:
+        click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
+        return
+
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0

     page = 1
     while True:
         try:
-            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+            providers = db.session.query(Provider).filter(
+                Provider.provider_name == 'anthropic',
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == ProviderQuotaType.TRIAL.value,
+            ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break

         page += 1
-        for tenant in tenants:
+        for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
-                ProviderService.create_system_provider(
-                    tenant,
-                    ProviderName.ANTHROPIC.value,
-                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
-                    True
-                )
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                original_quota_limit = provider.quota_limit
+                new_quota_limit = hosted_model_providers.anthropic.quota_limit
+                division = math.ceil(new_quota_limit / 1000)
+
+                provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
+                    else original_quota_limit * division
+                provider.quota_used = division * provider.quota_used
+                db.session.commit()
+
                 count += 1
             except Exception as e:
                 click.echo(click.style(
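Note: the quota arithmetic above is easier to see with concrete numbers. This replay assumes the values this commit ships (old Anthropic trial limit 1,000, new limit 1,000,000):

import math

new_quota_limit = 1_000_000
division = math.ceil(new_quota_limit / 1000)  # 1000

# A tenant still on the old default gets the new limit outright; quota_used
# is scaled by the same factor so the consumed fraction stays constant.
original_quota_limit, quota_used = 1000, 200
quota_limit = new_quota_limit if original_quota_limit == 1000 else original_quota_limit * division
quota_used = division * quota_used
print(quota_limit, quota_used)  # 1000000 200000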

View File

@@ -41,6 +41,7 @@ DEFAULTS = {
     'SESSION_USE_SIGNER': 'True',
     'DEPLOY_ENV': 'PRODUCTION',
     'SQLALCHEMY_POOL_SIZE': 30,
+    'SQLALCHEMY_POOL_RECYCLE': 3600,
     'SQLALCHEMY_ECHO': 'False',
     'SENTRY_TRACES_SAMPLE_RATE': 1.0,
     'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
@@ -50,9 +51,16 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
-    'DEFAULT_LLM_PROVIDER': 'openai',
-    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
-    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
+    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_OPENAI_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
+    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
+    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+    'HOSTED_ANTHROPIC_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30
 }
@@ -182,7 +190,10 @@ class Config:
         }

         self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
-        self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
+        self.SQLALCHEMY_ENGINE_OPTIONS = {
+            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
+            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
+        }

         self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
@@ -194,20 +205,35 @@ class Config:
         self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')

         # hosted provider credentials
-        self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
-        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
-
-        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
-        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
+        self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
+        self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
+        self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
+        self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
+
+        self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
+        self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
+        self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
+        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+
+        self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
+        self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
+        self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
+        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
+        self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')

-        # For temp use only
-        # set default LLM provider, default is 'openai', support `azure_openai`
-        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
-
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
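Note: the new pool_recycle option tells SQLAlchemy to replace connections older than the given number of seconds before reuse, avoiding errors from connections the database server has already closed. In plain SQLAlchemy terms (placeholder DSN, illustrative only):

from sqlalchemy import create_engine

engine = create_engine(
    'postgresql://user:pass@localhost:5432/dify',  # placeholder credentials
    pool_size=30,       # SQLALCHEMY_POOL_SIZE default
    pool_recycle=3600,  # SQLALCHEMY_POOL_RECYCLE default, in seconds
)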

View File

@@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

 # Import workspace controllers
-from .workspace import workspace, members, model_providers, account, tool_providers
+from .workspace import workspace, members, providers, model_providers, account, tool_providers, models

 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

 # Import universal chat controllers
 from .universal_chat import chat, conversation, message, parameter, audio
+
+# Import webhook controllers
+from .webhook import stripe

View File

@@ -2,16 +2,17 @@
 import json
 from datetime import datetime

+import flask
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
-from werkzeug.exceptions import Unauthorized, Forbidden
+from werkzeug.exceptions import Forbidden

 from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
-from controllers.console.app.error import AppNotFoundError
+from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.model_params import ModelType
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -126,9 +127,9 @@ class AppListApi(Resource):
         if args['model_config'] is not None:
             # validate config
             model_configuration = AppModelConfigService.validate_configuration(
+                tenant_id=current_user.current_tenant_id,
                 account=current_user,
-                config=args['model_config'],
-                mode=args['mode']
+                config=args['model_config']
             )

             app = App(
@@ -164,6 +165,21 @@ class AppListApi(Resource):
         app = App(**model_config_template['app'])
         app_model_config = AppModelConfig(**model_config_template['model_config'])

+        default_model = ModelFactory.get_default_model(
+            tenant_id=current_user.current_tenant_id,
+            model_type=ModelType.TEXT_GENERATION
+        )
+
+        if default_model:
+            model_dict = app_model_config.model_dict
+            model_dict['provider'] = default_model.provider_name
+            model_dict['name'] = default_model.model_name
+            app_model_config.model = json.dumps(model_dict)
+        else:
+            raise ProviderNotInitializeError(
+                f"No Text Generation Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         app.name = args['name']
         app.mode = args['mode']
         app.icon = args['icon']
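Note: the injection above rewrites the template's model block to the tenant's default text-generation model. In isolation, with illustrative values (the provider and model names are hypothetical):

import json

model_dict = {'provider': 'openai', 'name': 'gpt-3.5-turbo', 'completion_params': {}}
model_dict['provider'] = 'anthropic'  # default_model.provider_name
model_dict['name'] = 'claude-2'       # default_model.model_name
model_json = json.dumps(model_dict)   # stored on app_model_config.model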

View File

@@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService

View File

@@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse
@@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()

+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user

         try:
@@ -51,7 +54,7 @@ class CompletionMessageApi(Resource):
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
@@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()

+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user

         try:
@@ -121,7 +127,7 @@ class ChatMessageApi(Resource):
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
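Note: under the new contract, callers may request a single blocking JSON response instead of the previously hard-coded SSE streaming; omitting response_mode (or passing anything other than 'blocking') keeps streaming. A sketch of such a request (the URL is illustrative, not confirmed by this excerpt):

import requests

resp = requests.post(
    'http://localhost:5001/console/api/apps/APP_ID/completion-messages',  # hypothetical path
    json={
        'inputs': {},
        'query': 'Hello',
        'model_config': {},          # elided for brevity
        'response_mode': 'blocking',
    },
)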

View File

@@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
-    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
+    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError

View File

@@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from libs.infinite_scroll_pagination import InfiniteScrollPagination

View File

@@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
+            tenant_id=current_user.current_tenant_id,
             account=current_user,
-            config=request.json,
-            mode=app_model.mode
+            config=request.json
         )

         new_app_model_config = AppModelConfig(

View File

@@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
         # validate args
         DocumentService.estimate_args_validate(args)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
         return response, 200

View File

@@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden

 import services
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
@@ -97,6 +100,15 @@ class DatasetListApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_user.current_tenant_id,
@@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
                 raise NotFound("File not found.")

             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif args['info_list']['data_source_type'] == 'notion_import':
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
-                                                                args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    args['info_list']['notion_info_list'],
+                                                                    args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')

         return response, 200
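Note: the same guard recurs throughout this diff: resolving the tenant's embedding model raises LLMBadRequestError when no provider is configured, and the controllers translate that into ProviderNotInitializeError. Isolated as a hypothetical helper (not part of the commit):

def ensure_embedding_model(tenant_id: str):
    try:
        ModelFactory.get_embedding_model(tenant_id=tenant_id)
    except LLMBadRequestError:
        raise ProviderNotInitializeError(
            'No Embedding Model available. Please configure a valid provider '
            'in the Settings -> Model Provider.')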

View File

@@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)

+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         args = parser.parse_args()

+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         # validate args
         DocumentService.document_create_args_validate(args)
@@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
         indexing_runner = IndexingRunner()

-        response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
+        try:
+            response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
+                                                              data_process_rule_dict)
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")

         return response
@@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                     raise NotFound("File not found.")

             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
+
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif dataset.data_source_type:
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(info_list,
-                                                                data_process_rule_dict)
+
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    info_list,
+                                                                    data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')

         return response

View File

@@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import TimestampField
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
@@ -102,6 +102,8 @@ class HitTestingApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except ValueError as e:
+            raise ValueError(str(e))
         except Exception as e:
             logging.exception("Hit testing failed.")
             raise InternalServerError(str(e))

View File

@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

View File

@@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

View File

@@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService

View File

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import InstalledApp
@@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model = installed_app.app app_model = installed_app.app
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
return { return {
'opening_statement': app_model_config.opening_statement, 'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, 'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict, 'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list 'user_input_form': app_model_config.user_input_form_list
} }

View File

@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

View File

@@ -12,9 +12,8 @@ from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.constant import llm_constant
 from core.conversation_message_task import PubHandler
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
-    LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
@@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource):
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('tools', type=list, required=True, location='json')
         args = parser.parse_args()
@@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource):
         # update app model config
         args['model_config'] = app_model_config.to_dict()
         args['model_config']['model']['name'] = args['model']
-
-        if not llm_constant.models[args['model']]:
-            raise ValueError("Model not exists.")
-
-        args['model_config']['model']['provider'] = llm_constant.models[args['model']]
+        args['model_config']['model']['provider'] = args['provider']
         args['model_config']['agent_mode']['tools'] = args['tools']

         if not args['model_config']['agent_mode']['tools']:
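Note: the caller now names the provider explicitly instead of relying on the removed llm_constant lookup table. An illustrative request body (the provider and model names are examples, not a complete request):

payload = {
    'query': "What's the weather like?",
    'provider': 'anthropic',  # new required field
    'model': 'claude-2',
    'tools': [],
}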

View File

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError

View File

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.console import api
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
@@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model = universal_app app_model = universal_app
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
return { return {
'opening_statement': app_model_config.opening_statement, 'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, 'speech_to_text': app_model_config.speech_to_text_dict,
} }

View File

@@ -0,0 +1,53 @@
+import logging
+
+import stripe
+from flask import request, current_app
+from flask_restful import Resource
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import only_edition_cloud
+from services.provider_checkout_service import ProviderCheckoutService
+
+
+class StripeWebhookApi(Resource):
+    @setup_required
+    @only_edition_cloud
+    def post(self):
+        payload = request.data
+        sig_header = request.headers.get('STRIPE_SIGNATURE')
+        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
+
+        try:
+            event = stripe.Webhook.construct_event(
+                payload, sig_header, webhook_secret
+            )
+        except ValueError as e:
+            # Invalid payload
+            return 'Invalid payload', 400
+        except stripe.error.SignatureVerificationError as e:
+            # Invalid signature
+            return 'Invalid signature', 400
+
+        # Handle the checkout.session.completed event
+        if event['type'] == 'checkout.session.completed':
+            logging.debug(event['data']['object']['id'])
+            logging.debug(event['data']['object']['amount_subtotal'])
+            logging.debug(event['data']['object']['currency'])
+            logging.debug(event['data']['object']['payment_intent'])
+            logging.debug(event['data']['object']['payment_status'])
+            logging.debug(event['data']['object']['metadata'])
+
+            # Fulfill the purchase...
+            provider_checkout_service = ProviderCheckoutService()
+
+            try:
+                provider_checkout_service.fulfill_provider_order(event)
+            except Exception as e:
+                logging.debug(str(e))
+                return 'success', 200
+
+        return 'success', 200
+
+
+api.add_resource(StripeWebhookApi, '/webhook/stripe')
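Note: construct_event() rejects any payload whose signature does not match STRIPE_WEBHOOK_SECRET. Stripe's documented scheme signs "{timestamp}.{payload}" with HMAC-SHA256 and sends the result in the Stripe-Signature header; a local test signature can be built like this (values illustrative, for testing only):

import time, hmac, hashlib, json

secret = 'whsec_test_secret'  # stand-in for STRIPE_WEBHOOK_SECRET
payload = json.dumps({'type': 'checkout.session.completed'})
ts = str(int(time.time()))
signature = hmac.new(secret.encode(), f'{ts}.{payload}'.encode(), hashlib.sha256).hexdigest()
stripe_signature_header = f't={ts},v1={signature}'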

View File

@@ -1,24 +1,18 @@
 # -*- coding:utf-8 -*-
-import base64
-import json
-import logging
-
-from flask import current_app
 from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, abort
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden

 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.provider.errors import ValidateFailedError
-from extensions.ext_database import db
-from libs import rsa
-from models.provider import Provider, ProviderType, ProviderName
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from services.provider_checkout_service import ProviderCheckoutService
 from services.provider_service import ProviderService
-class ProviderListApi(Resource):
+class ModelProviderListApi(Resource):
     @setup_required
     @login_required
@@ -26,156 +20,36 @@ class ProviderListApi(Resource):
     def get(self):
         tenant_id = current_user.current_tenant_id

-        """
-        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
-        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
-        rest is replaced by * and the last two bits are displayed in plaintext
-        If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
-        plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
-        """
-        ProviderService.init_supported_provider(current_user.current_tenant)
-        providers = Provider.query.filter_by(tenant_id=tenant_id).all()
-
-        provider_list = [
-            {
-                'provider_name': p.provider_name,
-                'provider_type': p.provider_type,
-                'is_valid': p.is_valid,
-                'last_used': p.last_used,
-                'is_enabled': p.is_enabled,
-                **({
-                    'quota_type': p.quota_type,
-                    'quota_limit': p.quota_limit,
-                    'quota_used': p.quota_used
-                } if p.provider_type == ProviderType.SYSTEM.value else {}),
-                'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name), only_custom=True)
-                if p.provider_type == ProviderType.CUSTOM.value else None
-            }
-            for p in providers
-        ]
+        provider_service = ProviderService()
+        provider_list = provider_service.get_provider_list(tenant_id)

         return provider_list


-class ProviderTokenApi(Resource):
+class ModelProviderValidateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
-        # The role of the current user in the ta table must be admin or owner
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            logging.log(logging.ERROR,
-                        f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
-            raise Forbidden()
-
+    def post(self, provider_name: str):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()

-        if args['token']:
-            try:
-                ProviderService.validate_provider_configs(
-                    tenant=current_user.current_tenant,
-                    provider_name=ProviderName(provider),
-                    configs=args['token']
-                )
-                token_is_valid = True
-            except ValidateFailedError as ex:
-                raise ValueError(str(ex))
-
-            base64_encrypted_token = ProviderService.get_encrypted_token(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
-            )
-        else:
-            base64_encrypted_token = None
-            token_is_valid = False
-
-        tenant = current_user.current_tenant
-
-        provider_model = db.session.query(Provider).filter(
-            Provider.tenant_id == tenant.id,
-            Provider.provider_name == provider,
-            Provider.provider_type == ProviderType.CUSTOM.value
-        ).first()
-
-        # Only allow updating token for CUSTOM provider type
-        if provider_model:
-            provider_model.encrypted_config = base64_encrypted_token
-            provider_model.is_valid = token_is_valid
-        else:
-            provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
-                                      provider_type=ProviderType.CUSTOM.value,
-                                      encrypted_config=base64_encrypted_token,
-                                      is_valid=token_is_valid)
-            db.session.add(provider_model)
-
-        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
-            other_providers = db.session.query(Provider).filter(
-                Provider.tenant_id == tenant.id,
-                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
-                Provider.provider_name != provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).all()
-
-            for other_provider in other_providers:
-                other_provider.is_valid = False
-
-        db.session.commit()
-
-        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
-
-        return {'result': 'success'}, 201
-
-
-class ProviderTokenValidateApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        # todo: remove this when the provider is supported
-        if provider in [ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
+        provider_service = ProviderService()

         result = True
         error = None

         try:
-            ProviderService.validate_provider_configs(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+            provider_service.custom_provider_config_validate(
+                provider_name=provider_name,
+                config=args['config']
             )
-        except ValidateFailedError as e:
+        except CredentialsValidateFailedError as ex:
             result = False
-            error = str(e)
+            error = str(ex)

         response = {'result': 'success' if result else 'error'}
@@ -185,91 +59,227 @@ class ProviderTokenValidateApi(Resource):
return response return response
class ProviderSystemApi(Resource): class ModelProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, provider): def post(self, provider_name: str):
if provider not in [p.value for p in ProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('is_enabled', type=bool, required=True, location='json')
args = parser.parse_args()
tenant = current_user.current_tenant_id
provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
provider_model.is_valid = args['is_enabled']
db.session.commit()
elif not provider_model:
if provider == ProviderName.OPENAI.value:
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
elif provider == ProviderName.ANTHROPIC.value:
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
else:
quota_limit = 0
ProviderService.create_system_provider(
tenant,
provider,
quota_limit,
args['is_enabled']
)
else:
abort(403)
return {'result': 'success'}
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden() raise Forbidden()
provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id, parser = reqparse.RequestParser()
Provider.provider_name == provider, parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
Provider.provider_type == ProviderType.SYSTEM.value).first() args = parser.parse_args()
system_model = None provider_service = ProviderService()
if provider_model:
system_model = { try:
'result': 'success', provider_service.save_custom_provider_config(
'provider': { tenant_id=current_user.current_tenant_id,
'provider_name': provider_model.provider_name, provider_name=provider_name,
'provider_type': provider_model.provider_type, config=args['config']
'is_valid': provider_model.is_valid, )
'last_used': provider_model.last_used, except CredentialsValidateFailedError as ex:
'is_enabled': provider_model.is_enabled, raise ValueError(str(ex))
'quota_type': provider_model.quota_type,
'quota_limit': provider_model.quota_limit, return {'result': 'success'}, 201
'quota_used': provider_model.quota_used
} @setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
provider_service = ProviderService()
provider_service.delete_custom_provider(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
)
return {'result': 'success'}, 204
class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
result = True
error = None
try:
provider_service.custom_provider_model_config_validate(
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class ModelProviderModelUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
try:
provider_service.add_or_save_custom_provider_model_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.delete_custom_provider_model(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type']
)
return {'result': 'success'}, 204
class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
choices=['system', 'custom'], location='json')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.switch_preferred_provider(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
preferred_provider_type=args['preferred_provider_type']
)
return {'result': 'success'}
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
provider_service = ProviderService()
try:
parameter_rules = provider_service.get_model_parameter_rules(
tenant_id=current_user.current_tenant_id,
model_provider_name=provider_name,
model_name=args['model_name'],
model_type='text-generation'
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"Current Text Generation Model is invalid. Please switch to the available model.")
rules = {
k: {
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default
            }
            for k, v in vars(parameter_rules).items()
        }

        return rules
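So the parameter-rules endpoint returns a flat map from rule name to its bounds. For illustration only, with invented values, a response might look like:

# Hypothetical response body from ModelProviderModelParameterRuleApi.get;
# actual rule names and bounds vary by provider and model.
rules = {
    'temperature': {'enabled': True, 'min': 0, 'max': 2, 'default': 1},
    'top_p': {'enabled': True, 'min': 0, 'max': 1, 'default': 1},
    'max_tokens': {'enabled': True, 'min': 10, 'max': 4096, 'default': 16},
}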
class ModelProviderPaymentCheckoutUrlApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider_name: str):
        provider_service = ProviderCheckoutService()
        provider_checkout = provider_service.create_checkout(
            tenant_id=current_user.current_tenant_id,
            provider_name=provider_name,
            account=current_user
        )

        return {
            'url': provider_checkout.get_checkout_url()
        }


-api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
-                 endpoint='current_providers_token')  # Deprecated
-api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
-                 endpoint='current_providers_token_validate')  # Deprecated
-
-api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
-                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
-api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
-                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
-api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
-api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
-                 endpoint='workspaces_current_providers_system')  # GET for getting provider quota, PUT for updating provider status
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
api.add_resource(ModelProviderModelValidateApi,
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
api.add_resource(ModelProviderModelUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/models')
api.add_resource(PreferredProviderTypeUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
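For reference, a console client could exercise the new endpoints as in the sketch below. The /console/api mount point and the authenticated session are assumptions; this diff only shows the relative paths registered above.

import requests

BASE = 'http://localhost:5001/console/api'  # assumed mount point, not shown in this diff
session = requests.Session()                # assumed to carry a logged-in console session

# Save a custom provider config (ModelProviderUpdateApi.post returns 201 on success).
resp = session.post(f'{BASE}/workspaces/current/model-providers/openai',
                    json={'config': {'openai_api_key': 'sk-...'}})
assert resp.status_code == 201

# Validate a single model config (ModelProviderModelValidateApi.post).
resp = session.post(f'{BASE}/workspaces/current/model-providers/openai/models/validate',
                    json={'model_name': 'gpt-3.5-turbo',
                          'model_type': 'text-generation',
                          'config': {'openai_api_key': 'sk-...'}})
print(resp.json())  # {'result': 'success'} or {'result': 'error', 'error': '...'}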
View File
@@ -0,0 +1,108 @@
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from models.provider import ProviderType
from services.provider_service import ProviderService
class DefaultModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
provider_service = ProviderService()
default_model = provider_service.get_default_model_of_model_type(
tenant_id=tenant_id,
model_type=args['model_type']
)
if not default_model:
return None
model_provider = ModelProviderFactory.get_preferred_model_provider(
tenant_id,
default_model.provider_name
)
if not model_provider:
return {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': default_model.provider_name
}
}
provider = model_provider.provider
rst = {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': provider.provider_name,
'provider_type': provider.provider_type
}
}
model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
if provider.provider_type == ProviderType.SYSTEM.value:
rst['model_provider']['quota_type'] = provider.quota_type
rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
rst['model_provider']['quota_limit'] = provider.quota_limit
rst['model_provider']['quota_used'] = provider.quota_used
return rst
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.update_default_model_of_model_type(
tenant_id=current_user.current_tenant_id,
model_type=args['model_type'],
provider_name=args['provider_name'],
model_name=args['model_name']
)
return {'result': 'success'}
class ValidModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, model_type):
ModelType.value_of(model_type)
provider_service = ProviderService()
valid_models = provider_service.get_valid_model_list(
tenant_id=current_user.current_tenant_id,
model_type=model_type
)
return valid_models
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
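A usage sketch for the default-model endpoint, with the payload fields taken from the reqparse arguments above (provider and model names invented):

# Hypothetical request body for DefaultModelApi.post; all three fields are required.
payload = {
    'model_type': 'text-generation',   # one of: text-generation, embeddings, speech2text
    'provider_name': 'anthropic',
    'model_name': 'claude-2',
}
# session.post(f'{BASE}/workspaces/current/default-model', json=payload)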
View File
@@ -0,0 +1,130 @@
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType
from services.provider_service import ProviderService
class ProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
"""
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
rest is replaced by * and the last two bits are displayed in plaintext
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
provider_service = ProviderService()
provider_info_list = provider_service.get_provider_list(tenant_id)
provider_list = [
{
'provider_name': p['provider_name'],
'provider_type': p['provider_type'],
'is_valid': p['is_valid'],
'last_used': p['last_used'],
'is_enabled': p['is_valid'],
**({
'quota_type': p['quota_type'],
'quota_limit': p['quota_limit'],
'quota_used': p['quota_used']
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
if p['config'] else None
}
for name, provider_info in provider_info_list.items()
for p in provider_info['providers']
]
return provider_list
class ProviderTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
provider_service = ProviderService()
try:
provider_service.save_custom_provider_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
result = True
error = None
try:
provider_service.custom_provider_config_validate(
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
endpoint='workspaces_current_providers_token') # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
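Note how the legacy token endpoint stays wire-compatible: a bare token string is rewritten into the config dict the new service expects. A small sketch of that equivalence:

# ProviderTokenApi.post wraps a bare OpenAI token before delegating to
# ProviderService.save_custom_provider_config, so these two are equivalent.
legacy_body = {'token': 'sk-...'}                    # POST /workspaces/current/providers/openai/token
new_body = {'config': {'openai_api_key': 'sk-...'}}  # POST /workspaces/current/model-providers/openai
assert {'openai_api_key': legacy_body['token']} == new_body['config']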
View File
@@ -30,7 +30,7 @@ tenant_fields = {
    'created_at': TimestampField,
    'role': fields.String,
    'providers': fields.List(fields.Nested(provider_fields)),
-    'in_trail': fields.Boolean,
    'in_trial': fields.Boolean,
    'trial_end_reason': fields.String,
}
View File
@@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
from models.model import App
@@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource):
    def get(self, app_model: App, end_user):
        """Retrieve app parameters."""
        app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')

        return {
            'opening_statement': app_model_config.opening_statement,
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
            'speech_to_text': app_model_config.speech_to_text_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list
        }
View File
@@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
    ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
    ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from models.model import App, AppModelConfig
from services.audio_service import AudioService
View File
@@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
    ProviderModelCurrentlyNotSupportError
from controllers.service_api.wraps import AppApiResource
from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService
View File
@@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
    DatasetNotInitedError
from controllers.service_api.wraps import DatasetApiResource
-from core.llm.error import ProviderTokenNotInitError
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile
View File
@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
from controllers.web import api
from controllers.web.wraps import WebApiResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
from models.model import App
@@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource):
    def get(self, app_model: App, end_user):
        """Retrieve app parameters."""
        app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')

        return {
            'opening_statement': app_model_config.opening_statement,
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
            'speech_to_text': app_model_config.speech_to_text_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list
        }
View File
@@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
    UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
View File
@@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService
View File
@@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
    AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
View File
@@ -1,36 +0,0 @@
import os
from typing import Optional
import langchain
from flask import Flask
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.prompt.prompt_template import OneLineFormatter
class HostedOpenAICredential(BaseModel):
api_key: str
class HostedAnthropicCredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None
anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
if app.config.get("ANTHROPIC_API_KEY"):
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
View File
@@ -1,20 +1,17 @@
-from typing import cast, List
from typing import List

-from langchain import OpenAI
-from langchain.base_language import BaseLanguageModel
-from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage

-from core.constant import llm_constant
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM


class CalcTokenMixin:

-    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
-        llm = cast(ChatOpenAI, llm)
-        return llm.get_num_tokens_from_messages(messages)
    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        return model_instance.get_num_tokens(to_prompt_messages(messages))

-    def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        """
        Got the rest tokens available for the model after excluding messages tokens and completion max tokens
@@ -22,10 +19,9 @@ class CalcTokenMixin:
        :param messages:
        :return:
        """
-        llm = cast(ChatOpenAI, llm)
-        llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
-        completion_max_tokens = llm.max_tokens
-        used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
        llm_max_tokens = model_instance.model_rules.max_tokens.max
        completion_max_tokens = model_instance.model_kwargs.max_tokens
        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens

        return rest_tokens
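The budget is plain subtraction over the model's context window. A worked example with invented numbers:

# rest_tokens = context window - reserved completion tokens - tokens already used.
llm_max_tokens = 4096         # model_rules.max_tokens.max for a hypothetical model
completion_max_tokens = 512   # model_kwargs.max_tokens configured for the app
used_tokens = 3700            # tokens consumed by the current messages
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
print(rest_tokens)  # -116: negative, so callers summarize older messages first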
View File
@@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
""" """
An Multi Dataset Retrieve Agent driven by Router. An Multi Dataset Retrieve Agent driven by Router.
""" """
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str): def should_use_agent(self, query: str):
""" """
View File
@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
        # summarize messages if rest_tokens < 0
        try:
-            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))
View File
@@ -3,20 +3,28 @@ from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
-from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM


class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

-    def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages
View File
@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
        # summarize messages if rest_tokens < 0
        try:
-            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))
View File
@@ -0,0 +1,162 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
return should use agent
Using the ReACT mode to determine whether an agent is needed is costly,
so it's better to just use an Agent for reasoning, which is cheaper.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.dataset_tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
return self.output_parser.parse(full_output)
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
dataset_tools=tools,
**kwargs,
)
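The quadruple braces built in create_prompt appear intended to survive two str.format passes (format_instructions is formatted once with tool_names, and the prompt template renders once more); each pass halves escaped braces. A minimal standalone illustration of that halving:

import re

args = str({'query': {'type': 'string'}})  # shape of a tool's args schema
escaped = re.sub("}", "}}}}", re.sub("{", "{{{{", args))
once = escaped.format()   # quadruple braces become doubled after one pass
twice = once.format()     # doubled braces become literal after a second pass
assert twice == args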
View File
@@ -14,7 +14,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.model_providers.models.llm.base import BaseLLM

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
@@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
        if prompts:
            messages = prompts[0].to_messages()

-            rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
            rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
            if rest_tokens < 0:
                full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
View File
@@ -3,7 +3,6 @@ import logging
from typing import Union, Optional

from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
-from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
@@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
    MULTI_FUNCTION_CALL = 'multi_function_call'
@@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel):
    strategy: PlanningStrategy
-    llm: BaseLanguageModel
    model_instance: BaseLLM
    tools: list[BaseTool]
-    summary_llm: BaseLanguageModel
    summary_model_instance: BaseLLM
-    dataset_llm: BaseLanguageModel
    memory: Optional[BaseChatMemory] = None
    callbacks: Callbacks = None
    max_iterations: int = 6
@@ -60,36 +61,49 @@ class AgentExecutor:
    def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
        if self.configuration.strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_llm,
                summary_llm=self.configuration.summary_model_instance.client,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
                summary_llm=self.configuration.summary_model_instance.client,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
                summary_llm=self.configuration.summary_model_instance.client,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                llm=self.configuration.dataset_llm,
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                verbose=True
            )
        else:
            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
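A rough sketch of how a caller wires the reworked configuration, using names from this commit; the ModelFactory call is inferred from the completion changes later in this diff, and the surrounding variables are assumed to be in scope:

# Hypothetical wiring: a single BaseLLM instance now drives the agent and the
# summarizer, replacing the former llm / summary_llm / dataset_llm trio.
model_instance = ModelFactory.get_text_generation_model_from_model_config(
    tenant_id=tenant_id,                       # assumed in scope
    model_config=app_model_config.model_dict,  # assumed in scope
    streaming=False
)

configuration = AgentConfiguration(
    strategy=PlanningStrategy.REACT_ROUTER,
    model_instance=model_instance,
    summary_model_instance=model_instance,
    tools=dataset_tools,                       # e.g. DatasetRetrieverTool instances
    memory=None
)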
View File
@@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.llm.base import BaseLLM


class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""
    raise_error: bool = True

-    def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
-        self.model_name = model_name
        self.model_instant = model_instant
        self.conversation_message_task = conversation_message_task
        self._agent_loops = []
        self._current_loop = None
@@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

            self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
                self._message_agent_thought, self.model_instant, self._current_loop
            )

            self._agent_loops.append(self._current_loop)
@@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            )

            self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
                self._message_agent_thought, self.model_instant, self._current_loop
            )

            self._agent_loops.append(self._current_loop)
View File
@@ -3,18 +3,20 @@ import time
from typing import Any, Dict, List, Union

from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
from langchain.schema import LLMResult, BaseMessage

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM


class LLMCallbackHandler(BaseCallbackHandler):
    raise_error: bool = True

-    def __init__(self, llm: BaseLanguageModel,
-                 conversation_message_task: ConversationMessageTask):
-        self.llm = llm
    def __init__(self, model_instance: BaseLLM,
                 conversation_message_task: ConversationMessageTask):
        self.model_instance = model_instance
        self.llm_message = LLMMessage()
        self.start_at = None
        self.conversation_message_task = conversation_message_task
@@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
        })

        self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
"text": prompts[0] "text": prompts[0]
}] }]
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter() end_at = time.perf_counter()
@@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
            self.conversation_message_task.append_message_text(response.generations[0][0].text)
            self.llm_message.completion = response.generations[0][0].text

-        self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])

        self.conversation_message_task.save_message(self.llm_message)
@@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
        if self.conversation_message_task.streaming:
            end_at = time.perf_counter()
            self.llm_message.latency = end_at - self.start_at
-            self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)]
            )
            self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
        else:
            logging.error(error)
View File
@@ -5,9 +5,7 @@ from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler

-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
-from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask
View File
@@ -2,27 +2,19 @@ import logging
import re
from typing import Optional, List, Union, Tuple

-from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chat_models.base import BaseChatModel
-from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, HumanMessage
from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
-    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
-from core.llm.error import LLMBadRequestError
from core.model_providers.error import LLMBadRequestError
-from core.llm.fake import FakeLLM
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
@@ -51,12 +43,10 @@ class Completion:
        inputs = conversation.inputs

-        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
-            mode=app.mode,
-            tenant_id=app.tenant_id,
-            app_model_config=app_model_config,
-            query=query,
-            inputs=inputs
-        )
        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
@@ -68,10 +58,17 @@ class Completion:
            is_override=is_override,
            inputs=inputs,
            query=query,
-            streaming=streaming
            streaming=streaming,
            model_instance=final_model_instance
        )

-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
@@ -80,6 +77,7 @@ class Completion:
        )

        # parse sensitive_word_avoidance_chain
        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
        if sensitive_word_avoidance_chain:
            query = sensitive_word_avoidance_chain.run(query)
@@ -102,15 +100,14 @@ class Completion:
# run the final llm # run the final llm
try: try:
cls.run_final_llm( cls.run_final_llm(
tenant_id=app.tenant_id, model_instance=final_model_instance,
mode=app.mode, mode=app.mode,
app_model_config=app_model_config, app_model_config=app_model_config,
query=query, query=query,
inputs=inputs, inputs=inputs,
agent_execute_result=agent_execute_result, agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
memory=memory, memory=memory
streaming=streaming
) )
except ConversationTaskStoppedException: except ConversationTaskStoppedException:
return return
@@ -121,31 +118,20 @@ class Completion:
return return
@classmethod @classmethod
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult], agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask, conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool): memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
# When no extra pre prompt is specified, # When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again # the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy != PlanningStrategy.ROUTER: and agent_execute_result.strategy != PlanningStrategy.ROUTER:
final_llm = FakeLLM(response=agent_execute_result.output, fake_response = agent_execute_result.output
origin_llm=agent_execute_result.configuration.llm,
streaming=streaming)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
response = final_llm.generate([[HumanMessage(content=query)]])
return response
final_llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt # get llm prompt
prompt, stop_words = cls.get_main_llm_prompt( prompt_messages, stop_words = cls.get_main_llm_prompt(
mode=mode, mode=mode,
llm=final_llm,
model=app_model_config.model_dict, model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt, pre_prompt=app_model_config.pre_prompt,
query=query, query=query,
@@ -154,25 +140,26 @@ class Completion:
memory=memory memory=memory
) )
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=final_llm, model_instance=model_instance,
model=app_model_config.model_dict, prompt_messages=prompt_messages,
prompt=prompt,
mode=mode
) )
response = final_llm.generate([prompt], stop_words) response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
return response return response
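For reference, the new call pattern replaces llm.generate([prompt], stop_words). A hedged usage sketch — the message content and stop word here are hypothetical, and run() returning an object with a .content attribute follows its usage elsewhere in this commit:

response = model_instance.run(
    messages=[PromptMessage(content="Hello")],  # hypothetical single-turn prompt
    stop=["\nHuman:"],                          # hypothetical stop word
    callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
    fake_response=None,  # set to a string to skip the provider call and replay the agent's answer
)
answer = response.content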
@classmethod @classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict, def get_main_llm_prompt(cls, mode: str, model: dict,
pre_prompt: str, query: str, inputs: dict, pre_prompt: str, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult], agent_execute_result: Optional[AgentExecuteResult],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]: Tuple[List[PromptMessage], Optional[List[str]]]:
if mode == 'completion': if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template( prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags. template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@@ -200,11 +187,7 @@ And answer according to the language of the user's question.
**prompt_inputs **prompt_inputs
) )
if isinstance(llm, BaseChatModel): return [PromptMessage(content=prompt_content)], None
# use chat llm as completion model
return [HumanMessage(content=prompt_content)], None
else:
return prompt_content, None
else: else:
messages: List[BaseMessage] = [] messages: List[BaseMessage] = []
@@ -249,12 +232,14 @@ And answer according to the language of the user's question.
inputs=human_inputs inputs=human_inputs
) )
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message]) if memory.model_instance.model_rules.max_tokens.max:
model_name = model['name'] curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = model.get("completion_params").get('max_tokens') max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = llm_constant.max_context_token_length[model_name] \ rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
- max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0)
rest_tokens = max(rest_tokens, 0) else:
rest_tokens = 2000
histories = cls.get_history_messages_from_memory(memory, rest_tokens) histories = cls.get_history_messages_from_memory(memory, rest_tokens)
human_message_prompt += "\n\n" if human_message_prompt else "" human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \ human_message_prompt += "Here is the chat histories between human and assistant, " \
@@ -274,17 +259,7 @@ And answer according to the language of the user's question.
for message in messages: for message in messages:
message.content = re.sub(r'<\|.*?\|>', '', message.content) message.content = re.sub(r'<\|.*?\|>', '', message.content)
return messages, ['\nHuman:', '</histories>'] return to_prompt_messages(messages), ['\nHuman:', '</histories>']
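Worked example of the history budget above (hypothetical numbers): with model_rules.max_tokens.max = 16384, completion max_tokens = 512, and a current human message of 300 tokens, rest_tokens = 16384 - 512 - 300 = 15572 tokens remain for chat histories; models that declare no hard window fall back to the flat 2000-token budget.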
@classmethod
def get_llm_callbacks(cls, llm: BaseLanguageModel,
streaming: bool,
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
return [llm_callback_handler, DifyStdOutCallbackHandler()]
@classmethod @classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -300,15 +275,15 @@ And answer according to the language of the user's question.
conversation: Conversation, conversation: Conversation,
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory: **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
# only for calc token in memory # only for calc token in memory
memory_llm = LLMBuilder.to_llm_from_model( memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=tenant_id, tenant_id=tenant_id,
model=app_model_config.model_dict model_config=app_model_config.model_dict
) )
# use llm config from conversation # use llm config from conversation
memory = ReadOnlyConversationTokenDBBufferSharedMemory( memory = ReadOnlyConversationTokenDBBufferSharedMemory(
conversation=conversation, conversation=conversation,
llm=memory_llm, model_instance=memory_model_instance,
max_token_limit=kwargs.get("max_token_limit", 2048), max_token_limit=kwargs.get("max_token_limit", 2048),
memory_key=kwargs.get("memory_key", "chat_history"), memory_key=kwargs.get("memory_key", "chat_history"),
return_messages=kwargs.get("return_messages", True), return_messages=kwargs.get("return_messages", True),
@@ -320,21 +295,20 @@ And answer according to the language of the user's question.
return memory return memory
@classmethod @classmethod
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig, def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int: query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model( model_limited_tokens = model_instance.model_rules.max_tokens.max
tenant_id=tenant_id, max_tokens = model_instance.get_model_kwargs().max_tokens
model=app_model_config.model_dict
)
model_name = app_model_config.model_dict.get("name") if model_limited_tokens is None:
model_limited_tokens = llm_constant.max_context_token_length[model_name] return -1
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
if max_tokens is None:
max_tokens = 0
# get prompt without memory and context # get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt( prompt_messages, _ = cls.get_main_llm_prompt(
mode=mode, mode=mode,
llm=llm,
model=app_model_config.model_dict, model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt, pre_prompt=app_model_config.pre_prompt,
query=query, query=query,
@@ -343,9 +317,7 @@ And answer according to the language of the user's question.
memory=None memory=None
) )
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \ prompt_tokens = model_instance.get_num_tokens(prompt_messages)
else llm.get_num_tokens_from_messages(prompt)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0: if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -354,36 +326,40 @@ And answer according to the language of the user's question.
return rest_tokens return rest_tokens
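To make the guard concrete (hypothetical numbers): a 4096-token model with max_tokens = 512 and a pre-prompt plus query measuring 3700 tokens gives rest_tokens = 4096 - 512 - 3700 = -116, so the LLMBadRequestError above is raised; when the model declares no window, the early return of -1 skips validation entirely.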
@classmethod @classmethod
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict, def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_name = model.get("name") model_limited_tokens = model_instance.model_rules.max_tokens.max
model_limited_tokens = llm_constant.max_context_token_length[model_name] max_tokens = model_instance.get_model_kwargs().max_tokens
max_tokens = model.get("completion_params").get('max_tokens')
if mode == 'completion' and isinstance(final_llm, BaseLLM): if model_limited_tokens is None:
prompt_tokens = final_llm.get_num_tokens(prompt) return
else:
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt) if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_limited_tokens: if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16) max_tokens = max(model_limited_tokens - prompt_tokens, 16)
final_llm.max_tokens = max_tokens
# update model instance max tokens
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
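A self-contained sketch of the clamp above, with hypothetical token counts:

model_limited_tokens = 4096  # model_rules.max_tokens.max
max_tokens = 512             # requested completion budget
prompt_tokens = 4000         # measured from prompt_messages
if prompt_tokens + max_tokens > model_limited_tokens:
    # shrink the completion budget so prompt + completion fit, floored at 16 tokens
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)
assert max_tokens == 96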
@classmethod @classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool): app_model_config: AppModelConfig, user: Account, streaming: bool):
llm = LLMBuilder.to_llm_from_model( final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
model=app_model_config.model_dict, model_config=app_model_config.model_dict,
streaming=streaming streaming=streaming
) )
# get llm prompt # get llm prompt
original_prompt, _ = cls.get_main_llm_prompt( old_prompt_messages, _ = cls.get_main_llm_prompt(
mode="completion", mode="completion",
llm=llm,
model=app_model_config.model_dict, model=app_model_config.model_dict,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
query=message.query, query=message.query,
@@ -395,10 +371,9 @@ And answer according to the language of the user's question.
original_completion = message.answer.strip() original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=original_prompt, original_completion=original_completion) prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
if isinstance(llm, BaseChatModel): prompt_messages = [PromptMessage(content=prompt)]
prompt = [HumanMessage(content=prompt)]
conversation_message_task = ConversationMessageTask( conversation_message_task = ConversationMessageTask(
task_id=task_id, task_id=task_id,
@@ -408,16 +383,16 @@ And answer according to the language of the user's question.
inputs=message.inputs, inputs=message.inputs,
query=message.query, query=message.query,
is_override=True if message.override_model_configs else False, is_override=True if message.override_model_configs else False,
streaming=streaming streaming=streaming,
model_instance=final_model_instance
) )
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=llm, model_instance=final_model_instance,
model=app_model_config.model_dict, prompt_messages=prompt_messages
prompt=prompt,
mode='completion'
) )
llm.generate([prompt]) final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)

View File

@@ -1,109 +0,0 @@
from _decimal import Decimal
models = {
'claude-instant-1': 'anthropic', # 100,000 tokens
'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens
'gpt-3.5-turbo-16k': 'openai', # 16384 tokens
'text-davinci-003': 'openai', # 4,097 tokens
'text-davinci-002': 'openai', # 4,097 tokens
'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
'whisper-1': 'openai'
}
max_context_token_length = {
'claude-instant-1': 100000,
'claude-2': 100000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
'text-davinci-002': 4097,
'text-curie-001': 2049,
'text-babbage-001': 2049,
'text-ada-001': 2049,
'text-embedding-ada-002': 8191,
}
models_by_mode = {
'chat': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
],
'completion': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
'text-davinci-003', # 4,097 tokens
'text-davinci-002' # 4,097 tokens
'text-curie-001', # 2,049 tokens
'text-babbage-001', # 2,049 tokens
'text-ada-001' # 2,049 tokens
],
'embedding': [
'text-embedding-ada-002' # 8191 tokens, 1536 dimensions
]
}
model_currency = 'USD'
model_prices = {
'claude-instant-1': {
'prompt': Decimal('0.00163'),
'completion': Decimal('0.00551'),
},
'claude-2': {
'prompt': Decimal('0.01102'),
'completion': Decimal('0.03268'),
},
'gpt-4': {
'prompt': Decimal('0.03'),
'completion': Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': Decimal('0.06'),
'completion': Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': Decimal('0.0015'),
'completion': Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': Decimal('0.003'),
'completion': Decimal('0.004')
},
'text-davinci-003': {
'prompt': Decimal('0.02'),
'completion': Decimal('0.02')
},
'text-curie-001': {
'prompt': Decimal('0.002'),
'completion': Decimal('0.002')
},
'text-babbage-001': {
'prompt': Decimal('0.0005'),
'completion': Decimal('0.0005')
},
'text-ada-001': {
'prompt': Decimal('0.0004'),
'completion': Decimal('0.0004')
},
'text-embedding-ada-002': {
'usage': Decimal('0.0001'),
}
}
agent_model_name = 'text-davinci-003'
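With this constants module deleted, the same lookups are answered per model instance throughout the commit; a hedged summary of the replacements (signatures as they appear in the surrounding diffs):

limit = model_instance.model_rules.max_tokens.max                       # was max_context_token_length[model_name]
prompt_unit = model_instance.get_token_price(1, MessageType.HUMAN)      # was model_prices[model_name]['prompt']
answer_unit = model_instance.get_token_price(1, MessageType.ASSISTANT)  # was model_prices[model_name]['completion']
currency = model_instance.get_currency()                                # was model_currency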

View File

@@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant from core.model_providers.model_factory import ModelFactory
from core.llm.llm_builder import LLMBuilder from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.llm.provider.llm_provider_service import LLMProviderService from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate from core.prompt.prompt_template import JinjaPromptTemplate
from events.message_event import message_was_created from events.message_event import message_was_created
@@ -16,12 +16,11 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.provider import ProviderType, Provider
class ConversationMessageTask: class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False): conversation: Optional[Conversation] = None, is_override: bool = False):
self.task_id = task_id self.task_id = task_id
@@ -38,9 +37,12 @@ class ConversationMessageTask:
self.conversation = conversation self.conversation = conversation
self.is_new_conversation = False self.is_new_conversation = False
self.model_instance = model_instance
self.message = None self.message = None
self.model_dict = self.app_model_config.model_dict self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name') self.model_name = self.model_dict.get('name')
self.mode = app.mode self.mode = app.mode
@@ -56,9 +58,6 @@ class ConversationMessageTask:
) )
def init(self): def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None override_model_configs = None
if self.is_override: if self.is_override:
override_model_configs = { override_model_configs = {
@@ -89,15 +88,19 @@ class ConversationMessageTask:
if self.app_model_config.pre_prompt: if self.app_model_config.pre_prompt:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs) system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name) model_instance = ModelFactory.get_text_generation_model(
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message]) tenant_id=self.tenant_id,
model_provider_name=self.provider_name,
model_name=self.model_name
)
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
if not self.conversation: if not self.conversation:
self.is_new_conversation = True self.is_new_conversation = True
self.conversation = Conversation( self.conversation = Conversation(
app_id=self.app_model_config.app_id, app_id=self.app_model_config.app_id,
app_model_config_id=self.app_model_config.id, app_model_config_id=self.app_model_config.id,
model_provider=self.model_dict.get('provider'), model_provider=self.provider_name,
model_id=self.model_name, model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode, mode=self.mode,
@@ -117,7 +120,7 @@ class ConversationMessageTask:
self.message = Message( self.message = Message(
app_id=self.app_model_config.app_id, app_id=self.app_model_config.app_id,
model_provider=self.model_dict.get('provider'), model_provider=self.provider_name,
model_id=self.model_name, model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id, conversation_id=self.conversation.id,
@@ -131,7 +134,7 @@ class ConversationMessageTask:
answer_unit_price=0, answer_unit_price=0,
provider_response_latency=0, provider_response_latency=0,
total_price=0, total_price=0,
currency=llm_constant.model_currency, currency=self.model_instance.get_currency(),
from_source=('console' if isinstance(self.user, Account) else 'api'), from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None), from_account_id=(self.user.id if isinstance(self.user, Account) else None),
@@ -145,12 +148,10 @@ class ConversationMessageTask:
self._pub_handler.pub_text(text) self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
model_name = self.app_model_config.model_dict.get('name')
message_tokens = llm_message.prompt_tokens message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens answer_tokens = llm_message.completion_tokens
message_unit_price = llm_constant.model_prices[model_name]['prompt'] message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = llm_constant.model_prices[model_name]['completion'] answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price) total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
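A worked example of the pricing above; calc_total_price's body is not shown in this hunk, so the per-token multiplication is an assumption:

from decimal import Decimal
message_tokens, answer_tokens = 1000, 250
message_unit_price = Decimal('0.0000015')  # hypothetical per-token prompt price
answer_unit_price = Decimal('0.000002')    # hypothetical per-token completion price
# assumed formula: tokens * unit price, summed across prompt and completion
total_price = message_tokens * message_unit_price + answer_tokens * answer_unit_price
assert total_price == Decimal('0.002')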
@@ -163,8 +164,6 @@ class ConversationMessageTask:
self.message.provider_response_latency = llm_message.latency self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price self.message.total_price = total_price
self.update_provider_quota()
db.session.commit() db.session.commit()
message_was_created.send( message_was_created.send(
@@ -176,20 +175,6 @@ class ConversationMessageTask:
if not by_stopped: if not by_stopped:
self.end() self.end()
def update_provider_quota(self):
llm_provider_service = LLMProviderService(
tenant_id=self.app.tenant_id,
provider_name=self.message.model_provider,
)
provider = llm_provider_service.get_provider_db_record()
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
def init_chain(self, chain_result: ChainResult): def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain( message_chain = MessageChain(
message_id=self.message.id, message_id=self.message.id,
@@ -229,10 +214,10 @@ class ConversationMessageTask:
return message_agent_thought return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str, def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
agent_loop: AgentLoop): agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt'] agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion'] agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens loop_answer_tokens = agent_loop.completion_tokens
@@ -253,7 +238,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = llm_constant.model_currency message_agent_thought.currency = agent_model_instant.get_currency()
db.session.flush() db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence
from langchain.schema import Document from langchain.schema import Document
from sqlalchemy import func from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
@@ -13,12 +13,10 @@ class DatesetDocumentStore:
self, self,
dataset: Dataset, dataset: Dataset,
user_id: str, user_id: str,
embedding_model_name: str,
document_id: Optional[str] = None, document_id: Optional[str] = None,
): ):
self._dataset = dataset self._dataset = dataset
self._user_id = user_id self._user_id = user_id
self._embedding_model_name = embedding_model_name
self._document_id = document_id self._document_id = document_id
@classmethod @classmethod
@@ -39,10 +37,6 @@ class DatesetDocumentStore:
def user_id(self) -> Any: def user_id(self) -> Any:
return self._user_id return self._user_id
@property
def embedding_model_name(self) -> Any:
return self._embedding_model_name
@property @property
def docs(self) -> Dict[str, Document]: def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter( document_segments = db.session.query(DocumentSegment).filter(
@@ -74,6 +68,10 @@ class DatesetDocumentStore:
if max_position is None: if max_position is None:
max_position = 0 max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id
)
for doc in docs: for doc in docs:
if not isinstance(doc, Document): if not isinstance(doc, Document):
raise ValueError("doc must be a Document") raise ValueError("doc must be a Document")
@@ -88,7 +86,7 @@ class DatesetDocumentStore:
) )
# calc embedding use tokens # calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content) tokens = embedding_model.get_num_tokens(doc.page_content)
if not segment_document: if not segment_document:
max_position += 1 max_position += 1

View File

@@ -4,14 +4,14 @@ from typing import List
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions from core.model_providers.models.embedding.base import BaseEmbedding
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from models.dataset import Embedding from models.dataset import Embedding
class CacheEmbedding(Embeddings): class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings): def __init__(self, embeddings: BaseEmbedding):
self._embeddings = embeddings self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings):
embedding_queue_texts = [] embedding_queue_texts = []
for text in texts: for text in texts:
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first() embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding: if embedding:
text_embeddings.append(embedding.get_embedding()) text_embeddings.append(embedding.get_embedding())
else: else:
embedding_queue_texts.append(text) embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts) if embedding_queue_texts:
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try: try:
embedding = Embedding(hash=hash) embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
embedding.set_embedding(embedding_results[i]) except Exception as ex:
db.session.add(embedding) raise self._embeddings.handle_exceptions(ex)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
text_embeddings.extend(embedding_results) i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
text_embeddings.extend(embedding_results)
return text_embeddings return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Embed query text.""" """Embed query text."""
# use doc embedding cache or store if not exists # use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first() embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding: if embedding:
return embedding.get_embedding() return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text) try:
embedding_results = self._embeddings.client.embed_query(text)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
try: try:
embedding = Embedding(hash=hash) embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results) embedding.set_embedding(embedding_results)
db.session.add(embedding) db.session.add(embedding)
db.session.commit() db.session.commit()
@@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings):
logging.exception('Failed to add embedding to db') logging.exception('Failed to add embedding to db')
return embedding_results return embedding_results
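The cache is now keyed by (model_name, hash) rather than by hash alone, so two models embedding the same text no longer share rows. A hedged usage sketch (wiring as in this commit; tenant_id is assumed to be in scope):

embedding_model = ModelFactory.get_embedding_model(tenant_id=tenant_id)
embeddings = CacheEmbedding(embedding_model)
v1 = embeddings.embed_query("hello world")  # miss: provider call, row stored under (model_name, hash)
v2 = embeddings.embed_query("hello world")  # hit: served from the Embedding table, no provider call
assert v1 == v2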

View File

@@ -1,13 +1,10 @@
import logging import logging
from langchain import PromptTemplate from langchain.schema import OutputParserException
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
from core.constant import llm_constant from core.model_providers.model_factory import ModelFactory
from core.llm.llm_builder import LLMBuilder from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.llm.streamable_open_ai import StreamableOpenAI from core.model_providers.models.entity.model_params import ModelKwargs
from core.llm.token_calculator import TokenCalculator
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT GENERATOR_QA_PROMPT
# gpt-3.5-turbo does not work well
generate_base_model = 'text-davinci-003'
class LLMGenerator: class LLMGenerator:
@classmethod @classmethod
@@ -28,29 +22,35 @@ class LLMGenerator:
query = query[:300] + "...[TRUNCATED]..." + query[-300:] query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query) prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id, tenant_id=tenant_id,
model_name='gpt-3.5-turbo', model_kwargs=ModelKwargs(
max_tokens=50, max_tokens=50
timeout=600 )
) )
if isinstance(llm, BaseChatModel): prompts = [PromptMessage(content=prompt)]
prompt = [HumanMessage(content=prompt)] response = model_instance.run(prompts)
answer = response.content
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip() return answer.strip()
@classmethod @classmethod
def generate_conversation_summary(cls, tenant_id: str, messages): def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200 max_tokens = 200
model = 'gpt-3.5-turbo'
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='') prompt_with_empty_context = prompt.format(context='')
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context) prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1 max_context_token_length = model_instance.model_rules.max_tokens.max
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = '' context = ''
for message in messages: for message in messages:
@@ -68,25 +68,16 @@ class LLMGenerator:
answer = message.answer answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0: if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text context += message_qa_text
if not context: if not context:
return '[message too long, no summary]' return '[message too long, no summary]'
prompt = prompt.format(context=context) prompt = prompt.format(context=context)
prompts = [PromptMessage(content=prompt)]
llm: StreamableOpenAI = LLMBuilder.to_llm( response = model_instance.run(prompts)
tenant_id=tenant_id, answer = response.content
model_name=model,
max_tokens=max_tokens
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip() return answer.strip()
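Concretely (hypothetical numbers): with a 4096-token window, a 180-token empty-context prompt, and max_tokens = 200, rest_tokens = 4096 - 180 - 200 - 1 = 3715, and message pairs are appended to the context greedily until the next pair would overflow that budget.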
@classmethod @classmethod
@@ -94,16 +85,13 @@ class LLMGenerator:
prompt = INTRODUCTION_GENERATE_PROMPT prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt) prompt = prompt.format(prompt=pre_prompt)
llm: StreamableOpenAI = LLMBuilder.to_llm( model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id, tenant_id=tenant_id
model_name=generate_base_model,
) )
if isinstance(llm, BaseChatModel): prompts = [PromptMessage(content=prompt)]
prompt = [HumanMessage(content=prompt)] response = model_instance.run(prompts)
answer = response.content
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip() return answer.strip()
@classmethod @classmethod
@@ -119,23 +107,19 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories) _input = prompt.format_prompt(histories=histories)
llm: StreamableOpenAI = LLMBuilder.to_llm( model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id, tenant_id=tenant_id,
model_name='gpt-3.5-turbo', model_kwargs=ModelKwargs(
temperature=0, max_tokens=256,
max_tokens=256 temperature=0
)
) )
if isinstance(llm, BaseChatModel): prompts = [PromptMessage(content=_input.to_string())]
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
try: try:
output = llm(query) output = model_instance.run(prompts)
if isinstance(output, BaseMessage): questions = output_parser.parse(output.content)
output = output.content
questions = output_parser.parse(output)
except Exception: except Exception:
logging.exception("Error generating suggested questions after answer") logging.exception("Error generating suggested questions after answer")
questions = [] questions = []
@@ -160,21 +144,19 @@ class LLMGenerator:
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve) _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
llm: StreamableOpenAI = LLMBuilder.to_llm( model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id, tenant_id=tenant_id,
model_name=generate_base_model, model_kwargs=ModelKwargs(
temperature=0, max_tokens=512,
max_tokens=512 temperature=0
)
) )
if isinstance(llm, BaseChatModel): prompts = [PromptMessage(content=_input.to_string())]
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
try: try:
output = llm(query) output = model_instance.run(prompts)
rule_config = output_parser.parse(output) rule_config = output_parser.parse(output.content)
except OutputParserException: except OutputParserException:
raise ValueError('Please give a valid input for intended audience or hoping to solve problems.') raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
except Exception: except Exception:
@@ -188,25 +170,21 @@ class LLMGenerator:
return rule_config return rule_config
@classmethod @classmethod
async def generate_qa_document(cls, llm: StreamableOpenAI, query): def generate_qa_document(cls, tenant_id: str, query):
prompt = GENERATOR_QA_PROMPT prompt = GENERATOR_QA_PROMPT
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=2000
)
)
if isinstance(llm, BaseChatModel): prompts = [
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] PromptMessage(content=prompt, type=MessageType.SYSTEM),
PromptMessage(content=query)
]
response = llm.generate([prompt]) response = model_instance.run(prompts)
answer = response.generations[0][0].text answer = response.content
return answer.strip()
@classmethod
def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
prompt = GENERATOR_QA_PROMPT
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip() return answer.strip()

View File

@@ -0,0 +1,20 @@
import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:]
def encrypt_token(tenant_id: str, token: str):
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(tenant_id: str, token: str):
return rsa.decrypt(base64.b64decode(token), tenant_id)
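obfuscated_token keeps the first six and last two characters and masks the rest, while encrypt_token/decrypt_token round-trip through the tenant's RSA key pair. A quick check of the masking (hypothetical key):

token = 'sk-abcdefghijklmnopq'  # 20 characters, hypothetical
masked = token[:6] + '*' * (len(token) - 8) + token[-2:]
assert masked == 'sk-abc************pq'  # 6 kept + 12 masked + 2 kept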

View File

@@ -1,10 +1,9 @@
from flask import current_app from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder from core.model_providers.model_factory import ModelFactory
from models.dataset import Dataset from models.dataset import Dataset
@@ -15,16 +14,11 @@ class IndexBuilder:
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None return None
model_credentials = LLMBuilder.get_model_credentials( embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
) )
embeddings = CacheEmbedding(OpenAIEmbeddings( embeddings = CacheEmbedding(embedding_model)
max_retries=1,
**model_credentials
))
return VectorIndex( return VectorIndex(
dataset=dataset, dataset=dataset,

View File

@@ -1,4 +1,3 @@
import concurrent
import datetime import datetime
import json import json
import logging import logging
@@ -6,7 +5,6 @@ import re
import threading import threading
import time import time
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, cast from typing import Optional, List, cast
from flask_login import current_user from flask_login import current_user
@@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore from core.docstore.dataset_docstore import DatesetDocumentStore
from core.generator.llm_generator import LLMGenerator from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder from core.index.index import IndexBuilder
from core.llm.error import ProviderTokenNotInitError from core.model_providers.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder from core.model_providers.model_factory import ModelFactory
from core.llm.streamable_open_ai import StreamableOpenAI from core.model_providers.models.entity.message import MessageType
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
@@ -35,9 +32,8 @@ from models.source import DataSourceBinding
class IndexingRunner: class IndexingRunner:
def __init__(self, embedding_model_name: str = "text-embedding-ada-002"): def __init__(self):
self.storage = storage self.storage = storage
self.embedding_model_name = embedding_model_name
def run(self, dataset_documents: List[DatasetDocument]): def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process.""" """Run the indexing process."""
@@ -227,11 +223,15 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow() dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit() db.session.commit()
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict, def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None) -> dict: doc_form: str = None) -> dict:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
tokens = 0 tokens = 0
preview_texts = [] preview_texts = []
total_segments = 0 total_segments = 0
@@ -253,44 +253,49 @@ class IndexingRunner:
splitter=splitter, splitter=splitter,
processing_rule=processing_rule processing_rule=processing_rule
) )
total_segments += len(documents) total_segments += len(documents)
for document in documents: for document in documents:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(document.page_content) preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
self.filter_string(document.page_content))
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model': if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0: if len(preview_texts) > 0:
# qa model document # qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm( response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
return { return {
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
"currency": TokenCalculator.get_currency(self.embedding_model_name), "currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts
} }
return { return {
"total_segments": total_segments, "total_segments": total_segments,
"tokens": tokens, "tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name), "currency": embedding_model.get_currency(),
"preview": preview_texts "preview": preview_texts
} }
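To ground the qa_model estimate above (hypothetical numbers): 50 source segments are projected to 50 * 20 = 1000 QA segments and 50 * 2000 = 100000 generation tokens, priced through text_generation_model.get_token_price(100000, MessageType.HUMAN); the non-QA path instead prices the measured embedding tokens via embedding_model.get_token_price(tokens).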
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# load data from notion # load data from notion
tokens = 0 tokens = 0
preview_texts = [] preview_texts = []
@@ -336,31 +341,31 @@ class IndexingRunner:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(document.page_content) preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) tokens += embedding_model.get_num_tokens(document.page_content)
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model': if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0: if len(preview_texts) > 0:
# qa model document # qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm( response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
return { return {
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
"currency": TokenCalculator.get_currency(self.embedding_model_name), "currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts
} }
return { return {
"total_segments": total_segments, "total_segments": total_segments,
"tokens": tokens, "tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name), "currency": embedding_model.get_currency(),
"preview": preview_texts "preview": preview_texts
} }
@@ -459,7 +464,6 @@ class IndexingRunner:
doc_store = DatesetDocumentStore( doc_store = DatesetDocumentStore(
dataset=dataset, dataset=dataset,
user_id=dataset_document.created_by, user_id=dataset_document.created_by,
embedding_model_name=self.embedding_model_name,
document_id=dataset_document.id document_id=dataset_document.id
) )
@@ -513,17 +517,12 @@ class IndexingRunner:
all_documents.extend(split_documents) all_documents.extend(split_documents)
# processing qa document # processing qa document
if document_form == 'qa_model': if document_form == 'qa_model':
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
for i in range(0, len(all_documents), 10): for i in range(0, len(all_documents), 10):
threads = [] threads = []
sub_documents = all_documents[i:i + 10] sub_documents = all_documents[i:i + 10]
for doc in sub_documents: for doc in sub_documents:
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents}) 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
threads.append(document_format_thread) threads.append(document_format_thread)
document_format_thread.start() document_format_thread.start()
for thread in threads: for thread in threads:
@@ -531,13 +530,13 @@ class IndexingRunner:
return all_qa_documents return all_qa_documents
return all_documents return all_documents
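QA formatting now fans out in batches of ten threads that append to a shared list; a self-contained sketch of the batching pattern, with the LLM call stubbed out:

import threading

all_documents = list(range(23))  # 23 hypothetical segments
all_qa_documents = []

def format_qa_document(tenant_id, document_node, all_qa_documents):
    all_qa_documents.append(document_node)  # stand-in for the generate-QA LLM call

for i in range(0, len(all_documents), 10):
    threads = []
    for doc in all_documents[i:i + 10]:
        t = threading.Thread(target=format_qa_document, kwargs={
            'tenant_id': 'hypothetical-tenant', 'document_node': doc,
            'all_qa_documents': all_qa_documents})
        threads.append(t)
        t.start()
    for t in threads:
        t.join()  # finish the batch before launching the next ten

assert len(all_qa_documents) == 23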
def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents): def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
format_documents = [] format_documents = []
if document_node.page_content is None or not document_node.page_content.strip(): if document_node.page_content is None or not document_node.page_content.strip():
return return
try: try:
# qa model document # qa model document
response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
qa_documents = [] qa_documents = []
for result in document_qa_list: for result in document_qa_list:
@@ -638,6 +637,10 @@ class IndexingRunner:
vector_index = IndexBuilder.get_index(dataset, 'high_quality') vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy') keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
# chunk nodes by chunk size # chunk nodes by chunk size
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
@@ -648,7 +651,7 @@ class IndexingRunner:
chunk_documents = documents[i:i + chunk_size] chunk_documents = documents[i:i + chunk_size]
tokens += sum( tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) embedding_model.get_num_tokens(document.page_content)
for document in chunk_documents for document in chunk_documents
) )

View File

@@ -1,148 +0,0 @@
from typing import Union, Optional, List
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType, ProviderName
class LLMBuilder:
"""
This class handles the following logic:
1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
OPENAI_API_TYPE=azure
OPENAI_API_VERSION=2022-12-01
OPENAI_API_BASE=https://your-resource-name.openai.azure.com
OPENAI_API_KEY=<your Azure OpenAI API key>
3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
llm_cls = None
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI
elif provider == ProviderName.ANTHROPIC.value:
llm_cls = StreamableChatAnthropic
elif mode == 'completion':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI
if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.")
model_kwargs = {
'model_name': model_name,
'temperature': kwargs.get('temperature', 0),
'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
'callbacks': kwargs.get('callbacks', None),
'streaming': kwargs.get('streaming', False),
}
model_kwargs.update(model_credentials)
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
return llm_cls(**model_kwargs)
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
return cls.to_llm(
tenant_id=tenant_id,
model_name=model_name,
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callbacks=callbacks
)
@classmethod
def get_mode_by_model(cls, model_name: str) -> str:
if not model_name:
raise ValueError(f"empty model name is not supported.")
if model_name in llm_constant.models_by_mode['chat']:
return "chat"
elif model_name in llm_constant.models_by_mode['completion']:
return "completion"
else:
raise ValueError(f"model name {model_name} is not supported.")
@classmethod
def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
"""
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Raises an exception if the model_name is not found or if the provider is not found.
"""
if not model_name:
raise Exception('model name not found')
#
# if model_name not in llm_constant.models:
# raise Exception('model {} not found'.format(model_name))
# model_provider = llm_constant.models[model_name]
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
return provider_service.get_credentials(model_name)
@classmethod
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
provider_name = llm_constant.models[model_name]
if provider_name == 'openai':
# get the default provider (openai / azure_openai) for the tenant
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = azure_openai_provider
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = azure_openai_provider
if not provider:
raise ProviderTokenNotInitError(
f"No valid {provider_name} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
provider_name = provider.provider_name
return provider_name

View File

@@ -1,15 +0,0 @@
import openai
from models.provider import ProviderName
class Moderation:
def __init__(self, provider: str, api_key: str):
self.provider = provider
self.api_key = api_key
if self.provider == ProviderName.OPENAI.value:
self.client = openai.Moderation
def moderate(self, text):
return self.client.create(input=text, api_key=self.api_key)
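A hedged usage sketch of the removed helper, assuming the pre-1.0 openai SDK this codebase targets; the API key is a placeholder.

from models.provider import ProviderName

moderation = Moderation(ProviderName.OPENAI.value, api_key='sk-...')
result = moderation.moderate('some user input')
# with the OpenAI moderation response shape, result['results'][0]['flagged']
# marks whether the input violated the content policy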

View File

@@ -1,138 +0,0 @@
import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
},
]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.get_provider_api_key(model_id=model_id)
def get_provider_name(self):
return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'anthropic_api_key': ''
}
if obfuscated:
if not config.get('anthropic_api_key'):
config = {
'anthropic_api_key': ''
}
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
return config
return config
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# check OpenAI / Azure OpenAI credential is valid
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if quota_used >= quota_limit:
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
f"please configure OpenAI or Azure OpenAI provider first.")
try:
if not isinstance(config, dict):
                raise ValueError('Config must be an object.')
if 'anthropic_api_key' not in config:
raise ValueError('anthropic_api_key must be provided.')
chat_llm = ChatAnthropic(
model='claude-instant-1',
anthropic_api_key=config['anthropic_api_key'],
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

View File

@@ -1,145 +0,0 @@
import json
import logging
from typing import Optional, Union
import openai
import requests
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
return []
def check_embedding_model(self, credentials: Optional[dict] = None):
credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
try:
result = openai.Embedding.create(input=['test'],
engine='text-embedding-ada-002',
timeout=60,
api_key=str(credentials.get('openai_api_key')),
api_base=str(credentials.get('openai_api_base')),
api_type='azure',
api_version=str(credentials.get('openai_api_version')))["data"][0][
"embedding"]
except openai.error.AuthenticationError as e:
raise AzureAuthenticationError(str(e))
except openai.error.APIConnectionError as e:
            raise AzureRequestFailedError(
                'Failed to request Azure OpenAI, please check your API Base Endpoint. The format is `https://xxx.openai.azure.com/`')
except openai.error.InvalidRequestError as e:
if e.http_status == 404:
raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
"deployment name is exists in Azure AI")
else:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
except openai.error.OpenAIError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
if not isinstance(result, list):
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Azure OpenAI as a dictionary.
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
config['chunk_size'] = 16
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
def get_provider_name(self):
return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
if obfuscated:
if not config.get('openai_api_key'):
config = {
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
return config
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
try:
if not isinstance(config, dict):
                raise ValueError('Config must be an object.')
if 'openai_api_version' not in config:
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
self.check_embedding_model(credentials=config)
except ValidateFailedError as e:
raise e
except AzureAuthenticationError:
raise ValidateFailedError('Validation failed, please check your API Key.')
except AzureRequestFailedError as ex:
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex:
logging.exception('Azure OpenAI Credentials validation failed')
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': config['openai_api_base'],
'openai_api_key': self.encrypt_token(config['openai_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
return config
class AzureAuthenticationError(Exception):
pass
class AzureRequestFailedError(Exception):
pass

View File

@@ -1,132 +0,0 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName
class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
provider = self.get_provider(only_custom)
if not provider:
raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if model_id and model_id == 'gpt-4':
raise ModelCurrentlyNotSupportError()
if quota_used >= quota_limit:
raise QuotaExceededError()
return self.get_hosted_credentials()
else:
return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is preferred; pass only_custom to restrict the lookup to CUSTOM providers.
"""
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is returned first.
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
)
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
providers = query.order_by(Provider.provider_type.asc()).all()
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
return provider
return None
def get_hosted_credentials(self) -> Union[str | dict]:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = ''
if obfuscated:
return self.obfuscated_token(config)
return config
def obfuscated_token(self, token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:]
def get_token_type(self):
return str
def get_encrypted_token(self, config: Union[dict | str]):
return self.encrypt_token(config)
def get_decrypted_token(self, token: str):
return self.decrypt_token(token)
def encrypt_token(self, token):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token):
return rsa.decrypt(base64.b64decode(token), self.tenant_id)
@abstractmethod
def get_provider_name(self):
raise NotImplementedError
@abstractmethod
def get_credentials(self, model_id: Optional[str] = None) -> dict:
raise NotImplementedError
@abstractmethod
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
raise NotImplementedError
@abstractmethod
def config_validate(self, config: str):
raise NotImplementedError
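To make the contract concrete, here is a minimal sketch of a subclass; EchoProvider is hypothetical and only shows which abstract methods must be supplied.

from typing import Optional

class EchoProvider(BaseProvider):
    """Hypothetical provider used only to illustrate the required overrides."""

    def get_provider_name(self):
        return ProviderName.OPENAI  # a real provider would add its own enum member

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        return {'api_key': self.get_provider_api_key(model_id=model_id)}

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        return [{'id': 'echo-1', 'name': 'echo-1'}]

    def config_validate(self, config: str):
        if not config:
            raise ValueError('config must not be empty')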

View File

@@ -1,2 +0,0 @@
class ValidateFailedError(Exception):
description = "Provider Validate failed"

View File

@@ -1,22 +0,0 @@
from typing import Optional
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
class HuggingfaceProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
# todo
return []
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
"""
return {
'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.HUGGINGFACEHUB

View File

@@ -1,53 +0,0 @@
from typing import Optional, Union
from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider
class LLMProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self.init_provider(tenant_id, provider_name)
def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
if provider_name == 'openai':
return OpenAIProvider(tenant_id)
elif provider_name == 'azure_openai':
return AzureProvider(tenant_id)
elif provider_name == 'anthropic':
return AnthropicProvider(tenant_id)
elif provider_name == 'huggingface':
return HuggingfaceProvider(tenant_id)
else:
raise Exception('provider {} not found'.format(provider_name))
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
return self.provider.get_models(model_id)
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
:param config:
:raises: ValidateFailedError
"""
return self.provider.config_validate(config)
def get_token_type(self):
return self.provider.get_token_type()
def get_encrypted_token(self, config: Union[dict | str]):
return self.provider.get_encrypted_token(config)

View File

@@ -1,55 +0,0 @@
import logging
from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
class OpenAIProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
response = openai.Model.list(**credentials)
return [{
'id': model['id'],
'name': model['id'],
} for model in response['data']]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the credentials for the given tenant_id and provider_name.
"""
return {
'openai_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.OPENAI
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
try:
Moderation(self.get_provider_name().value, config).moderate('test')
except (AuthenticationError, OpenAIError) as ex:
raise ValidateFailedError(str(ex))
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return hosted_llm_credentials.openai.api_key

View File

@@ -1,62 +0,0 @@
from typing import List, Optional, Any, Dict
from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
"""
Wrapper around Anthropic's large language model.
"""
default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
@root_validator()
def prepare_params(cls, values: Dict) -> Dict:
values['model_name'] = values.get('model')
values['max_tokens'] = values.get('max_tokens_to_sample')
return values
@handle_anthropic_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
params['model'] = params.get('model_name')
del params['model_name']
params['max_tokens_to_sample'] = params.get('max_tokens')
del params['max_tokens']
del params['frequency_penalty']
del params['presence_penalty']
return params
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"{self.HUMAN_PROMPT} {message.content}"
elif isinstance(message, AIMessage):
message_text = f"{self.AI_PROMPT} {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"<admin>{message.content}</admin>"
else:
raise ValueError(f"Got unknown type {message}")
return message_text

View File

@@ -1,41 +0,0 @@
import decimal
from typing import Optional
import tiktoken
from core.constant import llm_constant
class TokenCalculator:
@classmethod
def get_num_tokens(cls, model_name: str, text: str):
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(model_name)
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
@classmethod
def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
if model_name in llm_constant.models_by_mode['embedding']:
unit_price = llm_constant.model_prices[model_name]['usage']
elif text_type == 'prompt':
unit_price = llm_constant.model_prices[model_name]['prompt']
elif text_type == 'completion':
unit_price = llm_constant.model_prices[model_name]['completion']
else:
raise Exception('Invalid text type')
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
@classmethod
def get_currency(cls, model_name: str):
return llm_constant.model_currency
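A worked example of the pricing arithmetic above, assuming a hypothetical unit price of 0.002 USD per 1K completion tokens; the quantize() calls match get_token_price.

import decimal

tokens = 1234
unit_price = decimal.Decimal('0.002')   # assumed USD per 1K tokens
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)   # 1.234
total_price = (tokens_per_1k * unit_price).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0.0024680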

View File

@@ -1,26 +0,0 @@
import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.provider.base import BaseProvider
class Whisper:
def __init__(self, provider: BaseProvider):
self.provider = provider
if self.provider.get_provider_name() == ProviderName.OPENAI:
self.client = openai.Audio
self.credentials = provider.get_credentials()
@handle_openai_exceptions
def transcribe(self, file):
return self.client.transcribe(
model='whisper-1',
file=file,
api_key=self.credentials.get('openai_api_key'),
api_base=self.credentials.get('openai_api_base'),
api_type=self.credentials.get('openai_api_type'),
api_version=self.credentials.get('openai_api_version'),
)

View File

@@ -1,27 +0,0 @@
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except anthropic.APIConnectionError as e:
logging.exception("Failed to connect to Anthropic API.")
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
except anthropic.RateLimitError:
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
except anthropic.AuthenticationError as e:
raise LLMAuthorizationError(f"Anthropic: {e.message}")
except anthropic.BadRequestError as e:
raise LLMBadRequestError(f"Anthropic: {e.message}")
except anthropic.APIStatusError as e:
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
return wrapper

View File

@@ -1,31 +0,0 @@
import logging
from functools import wraps
import openai
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_openai_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper

View File

@@ -1,10 +1,10 @@
-from typing import Any, List, Dict, Union
+from typing import Any, List, Dict
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string, BaseMessage
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
+from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.models.llm.base import BaseLLM
 from extensions.ext_database import db
 from models.model import Conversation, Message
@@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
    conversation: Conversation
    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"
-   llm: BaseLanguageModel
+   model_instance: BaseLLM
    memory_key: str = "chat_history"
    max_token_limit: int = 2000
    message_limit: int = 10
@@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
        messages = list(reversed(messages))
-       chat_messages: List[BaseMessage] = []
+       chat_messages: List[PromptMessage] = []
        for message in messages:
-           chat_messages.append(HumanMessage(content=message.query))
-           chat_messages.append(AIMessage(content=message.answer))
+           chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
+           chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
        if not chat_messages:
-           return chat_messages
+           return []

        # prune the chat message if it exceeds the max token limit
-       curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+       curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.max_token_limit and chat_messages:
                pruned_memory.append(chat_messages.pop(0))
-               curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+               curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)

-       return chat_messages
+       return to_lc_messages(chat_messages)

    @property
    def memory_variables(self) -> List[str]:

View File

@@ -0,0 +1,293 @@
from typing import Optional
from langchain.callbacks.base import Callbacks
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
class ModelFactory:
@classmethod
def get_text_generation_model_from_model_config(cls, tenant_id: str,
model_config: dict,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
provider_name = model_config.get("provider")
model_name = model_config.get("name")
completion_params = model_config.get("completion_params", {})
return cls.get_text_generation_model(
tenant_id=tenant_id,
model_provider_name=provider_name,
model_name=model_name,
model_kwargs=ModelKwargs(
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1)
),
streaming=streaming,
callbacks=callbacks
)
@classmethod
def get_text_generation_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None,
model_kwargs: Optional[ModelKwargs] = None,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
"""
get text generation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:param model_kwargs:
:param streaming:
:param callbacks:
:return:
"""
is_default_model = False
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default System Reasoning Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
is_default_model = True
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init text generation model
model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
try:
model_instance = model_class(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming,
callbacks=callbacks
)
except LLMBadRequestError as e:
if is_default_model:
raise LLMBadRequestError(f"Default model {model_name} is not available. "
f"Please check your model provider credentials.")
else:
raise e
if is_default_model:
model_instance.deduct_quota = False
return model_instance
@classmethod
def get_embedding_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
"""
get embedding model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Embedding Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init embedding model
model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_speech2text_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
"""
get speech to text model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Speech-to-Text Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init speech to text model
model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseProviderModel]:
"""
get moderation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init moderation model
model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
"""
get default model of model type.
:param tenant_id:
:param model_type:
:return:
"""
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if not default_model:
model_provider_rules = ModelProviderFactory.get_provider_rules()
for model_provider_name, model_provider_rule in model_provider_rules.items():
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
continue
model_list = model_provider.get_supported_model_list(model_type)
if model_list:
model_info = model_list[0]
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=model_provider_name,
model_name=model_info['id']
)
db.session.add(default_model)
db.session.commit()
break
return default_model
@classmethod
def update_default_model(cls,
tenant_id: str,
model_type: ModelType,
provider_name: str,
model_name: str) -> TenantDefaultModel:
"""
update default model of model type.
:param tenant_id:
:param model_type:
:param provider_name:
:param model_name:
:return:
"""
        model_provider_names = ModelProviderFactory.get_provider_names()
        if provider_name not in model_provider_names:
raise ValueError(f'Invalid provider name: {provider_name}')
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
model_list = model_provider.get_supported_model_list(model_type)
model_ids = [model['id'] for model in model_list]
if model_name not in model_ids:
raise ValueError(f'Invalid model name: {model_name}')
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if default_model:
# update default model
default_model.provider_name = provider_name
default_model.model_name = model_name
db.session.commit()
else:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=provider_name,
model_name=model_name,
)
db.session.add(default_model)
db.session.commit()
return default_model
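A hedged usage sketch of the new factory API, assuming the module path core.model_providers.model_factory (not shown in this diff); the tenant id is a placeholder, and credential or quota problems surface as ProviderTokenNotInitError / LLMBadRequestError as shown in the methods above.

from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs

model_instance = ModelFactory.get_text_generation_model(
    tenant_id='tenant-id-placeholder',
    model_provider_name='openai',
    model_name='gpt-3.5-turbo',
    model_kwargs=ModelKwargs(
        temperature=0,
        max_tokens=256,
        top_p=1,
        presence_penalty=0,
        frequency_penalty=0
    ),
    streaming=False
)

# omitting provider/name falls back to the tenant's default embedding model
embedding_model = ModelFactory.get_embedding_model(tenant_id='tenant-id-placeholder')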

View File

@@ -0,0 +1,228 @@
from typing import Type
from sqlalchemy.exc import IntegrityError
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
DEFAULT_MODELS = {
ModelType.TEXT_GENERATION.value: {
'provider_name': 'openai',
'model_name': 'gpt-3.5-turbo',
},
ModelType.EMBEDDINGS.value: {
'provider_name': 'openai',
'model_name': 'text-embedding-ada-002',
},
ModelType.SPEECH_TO_TEXT.value: {
'provider_name': 'openai',
'model_name': 'whisper-1',
}
}
class ModelProviderFactory:
@classmethod
def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
if provider_name == 'openai':
from core.model_providers.providers.openai_provider import OpenAIProvider
return OpenAIProvider
elif provider_name == 'anthropic':
from core.model_providers.providers.anthropic_provider import AnthropicProvider
return AnthropicProvider
elif provider_name == 'minimax':
from core.model_providers.providers.minimax_provider import MinimaxProvider
return MinimaxProvider
elif provider_name == 'spark':
from core.model_providers.providers.spark_provider import SparkProvider
return SparkProvider
elif provider_name == 'tongyi':
from core.model_providers.providers.tongyi_provider import TongyiProvider
return TongyiProvider
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider
elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider
elif provider_name == 'replicate':
from core.model_providers.providers.replicate_provider import ReplicateProvider
return ReplicateProvider
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
else:
raise NotImplementedError
@classmethod
def get_provider_names(cls):
"""
Returns a list of provider names.
"""
return list(provider_rules.keys())
@classmethod
def get_provider_rules(cls):
"""
Returns a list of provider rules.
:return:
"""
return provider_rules
@classmethod
def get_provider_rule(cls, provider_name: str):
"""
Returns provider rule.
"""
return provider_rules[provider_name]
@classmethod
def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred model provider.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:return:
"""
# get preferred provider
preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
if not preferred_provider or not preferred_provider.is_valid:
return None
# init model provider
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
return model_provider_class(provider=preferred_provider)
@classmethod
def get_preferred_type_by_preferred_model_provider(cls,
tenant_id: str,
model_provider_name: str,
preferred_model_provider: TenantPreferredModelProvider):
"""
get preferred provider type by preferred model provider.
:param model_provider_name:
:param preferred_model_provider:
:return:
"""
if not preferred_model_provider:
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
support_provider_types = model_provider_rules['support_provider_types']
if ProviderType.CUSTOM.value in support_provider_types:
custom_provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.is_valid == True
).first()
if custom_provider:
return ProviderType.CUSTOM.value
model_provider = cls.get_model_provider_class(model_provider_name)
if ProviderType.SYSTEM.value in support_provider_types \
and model_provider.is_provider_type_system_supported():
return ProviderType.SYSTEM.value
elif ProviderType.CUSTOM.value in support_provider_types:
return ProviderType.CUSTOM.value
else:
return preferred_model_provider.preferred_provider_type
@classmethod
def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
# get preferred provider type
preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
# get providers by preferred provider type
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == preferred_provider_type
).all()
no_system_provider = False
if preferred_provider_type == ProviderType.SYSTEM.value:
quota_type_to_provider_dict = {}
for provider in providers:
quota_type_to_provider_dict[provider.quota_type] = provider
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
for quota_type_enum in ProviderQuotaType:
quota_type = quota_type_enum.value
if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
and quota_type in quota_type_to_provider_dict.keys():
provider = quota_type_to_provider_dict[quota_type]
if provider.is_valid and provider.quota_limit > provider.quota_used:
return provider
no_system_provider = True
if no_system_provider:
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
if providers:
return providers[0]
else:
try:
provider = Provider(
tenant_id=tenant_id,
provider_name=model_provider_name,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(provider)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
return provider
return None
@classmethod
def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider type of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name == model_provider_name
).first()
return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
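A short sketch of the resolution entry point; the tenant id is a placeholder, and a None result means no valid provider record exists for the tenant.

model_provider = ModelProviderFactory.get_preferred_model_provider(
    tenant_id='tenant-id-placeholder',
    model_provider_name='openai'
)
if model_provider is None:
    # callers such as ModelFactory raise ProviderTokenNotInitError here
    raise RuntimeError('openai provider credentials are not configured')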

View File

@@ -0,0 +1,22 @@
from abc import ABC
from typing import Any
from core.model_providers.providers.base import BaseModelProvider
class BaseProviderModel(ABC):
_client: Any
_model_provider: BaseModelProvider
def __init__(self, model_provider: BaseModelProvider, client: Any):
self._model_provider = model_provider
self._client = client
@property
def client(self):
return self._client
@property
def model_provider(self):
return self._model_provider

View File

@@ -0,0 +1,78 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenAIEmbeddings(
deployment=name,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
chunk_size=16,
max_retries=1,
**self.credentials
)
super().__init__(model_provider, client, name)
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to Azure OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("Azure OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex

View File

@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any
import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseEmbedding(BaseProviderModel):
name: str
type: ModelType = ModelType.EMBEDDINGS
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
return len(_get_token_ids_default_method(text))
def get_token_price(self, tokens: int):
return 0
def get_currency(self):
return 'USD'
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError

View File

@@ -0,0 +1,35 @@
import decimal
import logging
from langchain.embeddings import MiniMaxEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class MinimaxEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = MiniMaxEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex

View File

@@ -0,0 +1,72 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class OpenAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenAIEmbeddings(
max_retries=1,
**credentials
)
super().__init__(model_provider, client, name)
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(self.name)
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex

View File

@@ -0,0 +1,36 @@
import decimal
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class ReplicateEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ReplicateEmbeddings(
model=name + ':' + credentials.get('model_version'),
replicate_api_token=credentials.get('replicate_api_token')
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
# replicate only pay for prediction seconds
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")
else:
return ex

View File

@@ -0,0 +1,53 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel
class LLMRunResult(BaseModel):
content: str
prompt_tokens: int
completion_tokens: int
class MessageType(enum.Enum):
HUMAN = 'human'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
content: str = ''
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
return lc_messages
def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
return prompt_messages
def str_to_prompt_messages(texts: list[str]):
prompt_messages = []
for text in texts:
prompt_messages.append(PromptMessage(content=text))
return prompt_messages
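A round-trip sketch for the converters above; PromptMessage is the new internal representation, and to_lc_messages() yields the langchain message types.

from langchain.schema import HumanMessage

prompt_messages = [
    PromptMessage(content='Hello', type=MessageType.HUMAN),
    PromptMessage(content='Hi, how can I help?', type=MessageType.ASSISTANT),
]
lc_messages = to_lc_messages(prompt_messages)
assert isinstance(lc_messages[0], HumanMessage)
# and back again, losslessly for these message types
assert to_prompt_messages(lc_messages)[0].content == 'Hello'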

View File

@@ -0,0 +1,59 @@
import enum
from typing import Optional, TypeVar, Generic
from langchain.load.serializable import Serializable
from pydantic import BaseModel
class ModelMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
class ModelType(enum.Enum):
TEXT_GENERATION = 'text-generation'
EMBEDDINGS = 'embeddings'
SPEECH_TO_TEXT = 'speech2text'
IMAGE = 'image'
VIDEO = 'video'
MODERATION = 'moderation'
@staticmethod
def value_of(value):
for member in ModelType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ModelKwargs(BaseModel):
max_tokens: Optional[int]
temperature: Optional[float]
top_p: Optional[float]
presence_penalty: Optional[float]
frequency_penalty: Optional[float]
class KwargRuleType(enum.Enum):
STRING = 'string'
INTEGER = 'integer'
FLOAT = 'float'
T = TypeVar('T')
class KwargRule(Generic[T], BaseModel):
enabled: bool = True
min: Optional[T] = None
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
class ModelKwargsRules(BaseModel):
max_tokens: KwargRule = KwargRule[int](enabled=False)
temperature: KwargRule = KwargRule[float](enabled=False)
top_p: KwargRule = KwargRule[float](enabled=False)
presence_penalty: KwargRule = KwargRule[float](enabled=False)
frequency_penalty: KwargRule = KwargRule[float](enabled=False)
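A minimal sketch of declaring per-model kwargs rules; the limits below are illustrative, not any real provider's.

rules = ModelKwargsRules(
    max_tokens=KwargRule[int](enabled=True, min=10, max=4096, default=256),
    temperature=KwargRule[float](enabled=True, min=0.0, max=2.0, default=1.0),
)
# ModelKwargs fields are Optional, so a partial set is valid
kwargs = ModelKwargs(temperature=0.7, max_tokens=512)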

View File

@@ -0,0 +1,10 @@
from enum import Enum
class ProviderQuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
class ModelFeature(Enum):
AGENT_THOUGHT = 'agent_thought'

View File

@@ -0,0 +1,107 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import anthropic
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class AnthropicModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatAnthropic(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
default_request_timeout=60,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'claude-instant-1': {
'prompt': decimal.Decimal('1.63'),
'completion': decimal.Decimal('5.51'),
},
'claude-2': {
'prompt': decimal.Decimal('11.02'),
'completion': decimal.Decimal('32.68'),
},
}
        if message_type in (MessageType.HUMAN, MessageType.SYSTEM):
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1m * unit_price
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, anthropic.APIConnectionError):
logging.warning("Failed to connect to Anthropic API.")
return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
elif isinstance(ex, anthropic.RateLimitError):
return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
elif isinstance(ex, anthropic.AuthenticationError):
return LLMAuthorizationError(f"Anthropic: {ex.message}")
elif isinstance(ex, anthropic.BadRequestError):
return LLMBadRequestError(f"Anthropic: {ex.message}")
elif isinstance(ex, anthropic.APIStatusError):
return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
else:
return ex
@classmethod
def support_streaming(cls):
return True
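
Note that these Anthropic unit prices are USD per million tokens, unlike the per-1K pricing used by the OpenAI models later in this diff. A quick standalone check of the arithmetic (token count chosen arbitrarily):

    import decimal

    tokens = 1500
    unit_price = decimal.Decimal('11.02')  # claude-2 prompt price per 1M tokens
    tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(
        decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
    total = (tokens_per_1m * unit_price).quantize(
        decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
    print(total)  # 0.01653000 -> about $0.0165 for 1,500 prompt tokens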

View File

@@ -0,0 +1,177 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
if name == 'text-davinci-003':
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.name == 'text-davinci-003':
client = EnhanceAzureOpenAI(
deployment_name=self.name,
streaming=self.streaming,
request_timeout=60,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base'),
callbacks=self.callbacks,
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
client = EnhanceAzureChatOpenAI(
deployment_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
request_timeout=60,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base'),
callbacks=self.callbacks,
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-35-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-35-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
base_model_name = self.credentials.get("base_model_name")
        if message_type in (MessageType.HUMAN, MessageType.SYSTEM):
unit_price = model_unit_prices[base_model_name]['prompt']
else:
unit_price = model_unit_prices[base_model_name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name == 'text-davinci-003':
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to Azure OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("Azure OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True
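
One design point worth calling out: Azure deployment names are user-chosen, so pricing is keyed by the `base_model_name` stored in the credentials rather than by `self.name`. A hedged sketch (the credentials dict here is illustrative):

    credentials = {'base_model_name': 'gpt-35-turbo'}  # hypothetical values
    deployment_name = 'my-chat-deployment'             # what self.name would hold
    # get_token_price looks up prices by the base model, not the deployment:
    price_key = credentials.get('base_model_name')     # -> 'gpt-35-turbo'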

View File

@@ -0,0 +1,269 @@
from abc import abstractmethod
from typing import List, Optional, Any, Union
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM
class BaseLLM(BaseProviderModel):
model_mode: ModelMode = ModelMode.COMPLETION
name: str
model_kwargs: ModelKwargs
credentials: dict
streaming: bool = False
type: ModelType = ModelType.TEXT_GENERATION
deduct_quota: bool = True
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
self.name = name
self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
max_tokens=None,
temperature=None,
top_p=None,
presence_penalty=None,
frequency_penalty=None
)
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
self.streaming = streaming
if streaming:
default_callback = DifyStreamingStdOutCallbackHandler()
else:
default_callback = DifyStdOutCallbackHandler()
if not callbacks:
callbacks = [default_callback]
else:
callbacks.append(default_callback)
self.callbacks = callbacks
client = self._init_client()
super().__init__(model_provider, client)
@abstractmethod
def _init_client(self) -> Any:
raise NotImplementedError
def run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMRunResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
if self.deduct_quota:
self.model_provider.check_quota_over_limit()
if not callbacks:
callbacks = self.callbacks
else:
callbacks.extend(self.callbacks)
if 'fake_response' in kwargs and kwargs['fake_response']:
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
response=kwargs['fake_response'],
num_token_func=self.get_num_tokens,
streaming=self.streaming,
callbacks=callbacks
)
result = fake_llm.generate([prompts])
else:
try:
result = self._run(
messages=messages,
stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
**kwargs
)
except Exception as ex:
raise self.handle_exceptions(ex)
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
else:
completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming():
            # use FakeLLM to simulate streaming when the current model does not support it but streaming was requested
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
response=completion_content,
num_token_func=self.get_num_tokens,
streaming=self.streaming,
callbacks=callbacks
)
fake_llm.generate([prompts])
        if result.llm_output and result.llm_output.get('token_usage'):
prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
completion_tokens = result.llm_output['token_usage']['completion_tokens']
total_tokens = result.llm_output['token_usage']['total_tokens']
else:
prompt_tokens = self.get_num_tokens(messages)
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens
if self.deduct_quota:
self.model_provider.deduct_quota(total_tokens)
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
)
@abstractmethod
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_token_price(self, tokens: int, message_type: MessageType):
"""
        Get the price for the given number of tokens and message type.
:param tokens:
:param message_type:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_currency(self):
"""
        Get the pricing currency.
:return:
"""
raise NotImplementedError
def get_model_kwargs(self):
return self.model_kwargs
def set_model_kwargs(self, model_kwargs: ModelKwargs):
self.model_kwargs = model_kwargs
self._set_model_kwargs(model_kwargs)
@abstractmethod
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
"""
Handle llm run exceptions.
:param ex:
:return:
"""
raise NotImplementedError
def add_callbacks(self, callbacks: Callbacks):
"""
Add callbacks to client.
:param callbacks:
:return:
"""
if not self.client.callbacks:
self.client.callbacks = callbacks
else:
self.client.callbacks.extend(callbacks)
@classmethod
def support_streaming(cls):
return False
def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
if len(messages) == 0:
raise ValueError("prompt must not be empty.")
if not model_mode:
model_mode = self.model_mode
if model_mode == ModelMode.COMPLETION:
return messages[0].content
else:
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))
return chat_messages
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""
        Convert generic model kwargs into provider-specific kwargs according to the model's parameter rules.
:param model_rules:
:param model_kwargs:
:return:
"""
model_kwargs_input = {}
for key, value in model_kwargs.dict().items():
rule = getattr(model_rules, key)
if not rule.enabled:
continue
if rule.alias:
key = rule.alias
            if rule.default is not None and value is None:
                value = rule.default

            if value is not None:
                if rule.min is not None:
                    value = max(value, rule.min)

                if rule.max is not None:
                    value = min(value, rule.max)

            model_kwargs_input[key] = value
return model_kwargs_input
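
To make the rule application in `_to_model_kwargs_input` concrete, here is a small sketch (rule values are illustrative): a disabled rule drops the kwarg, a `None` value falls back to the rule's default, out-of-range values are clamped, and an alias renames the key:

    rules = ModelKwargsRules(
        max_tokens=KwargRule[int](min=10, max=100, default=16, alias='max_new_tokens'),
        temperature=KwargRule[float](min=0.0, max=1.0, default=0.7),
    )
    kwargs = ModelKwargs(max_tokens=500, temperature=None, top_p=0.9,
                         presence_penalty=None, frequency_penalty=None)
    # feeding these through the method above would yield:
    # {'max_new_tokens': 100, 'temperature': 0.7}
    # max_tokens is clamped to 100 and renamed; temperature falls back to its
    # default; top_p and the penalties are dropped because their rules are disabled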

View File

@@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import ChatGLM
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ChatGLMModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatGLM(
callbacks=self.callbacks,
endpoint_url=self.credentials.get('api_base'),
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return False

View File

@@ -0,0 +1,82 @@
import decimal
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class HuggingfaceHubModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
client = HuggingFaceEndpoint(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task='text2text-generation',
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks,
)
else:
client = HuggingFaceHub(
repo_id=self.name,
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks,
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
        # price calculation is not supported for Hugging Face Hub models
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
model=self.name,
model_kwargs={
'stream': False
},
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex

View File

@@ -0,0 +1,219 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from models.provider import ProviderType, ProviderQuotaType
COMPLETION_MODELS = [
'text-davinci-003', # 4,097 tokens
]
CHAT_MODELS = [
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
]
MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
class OpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
if name in COMPLETION_MODELS:
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.name in COMPLETION_MODELS:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
**self.credentials,
**provider_model_kwargs
)
else:
            # Fine-tuning is currently only available for the base models
            # davinci, curie, babbage, and ada. Since those are all completion
            # models, every fine-tuned model is a completion model as well;
            # only the fixed chat models are expected to reach this branch.
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
**self.credentials
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
if self.name == 'gpt-4' \
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 is currently not supported.")
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
        if message_type in (MessageType.HUMAN, MessageType.SYSTEM):
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name in COMPLETION_MODELS:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True
# def is_model_valid_or_raise(self):
# """
# check is a valid model.
#
# :return:
# """
# credentials = self._model_provider.get_credentials()
#
# try:
# result = openai.Model.retrieve(
# id=self.name,
# api_key=credentials.get('openai_api_key'),
# request_timeout=60
# )
#
# if 'id' not in result or result['id'] != self.name:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
# except openai.error.OpenAIError as e:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
# except Exception as e:
# logging.exception("OpenAI Model retrieve failed.")
# raise e
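
The completion/chat split above is driven purely by the model name. A tiny sketch of the effective selection (names taken from the lists at the top of this file):

    for name in ['text-davinci-003', 'gpt-3.5-turbo', 'gpt-4']:
        mode = ModelMode.COMPLETION if name in COMPLETION_MODELS else ModelMode.CHAT
        print(name, mode.value)
    # text-davinci-003 completion
    # gpt-3.5-turbo chat
    # gpt-4 chat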

View File

@@ -0,0 +1,103 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from replicate.exceptions import ReplicateError, ModelError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.error import LLMBadRequestError
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ReplicateModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return EnhanceReplicate(
model=self.name + ':' + self.credentials.get('model_version'),
input=provider_model_kwargs,
streaming=self.streaming,
replicate_api_token=self.credentials.get('replicate_api_token'),
callbacks=self.callbacks,
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
extra_kwargs = {}
        if isinstance(prompts, list):
            # langchain chat messages expose a string `type` ('system', 'human', 'ai')
            system_messages = [message for message in prompts if message.type == 'system']
            if system_messages:
                system_message = system_messages[0]
                extra_kwargs['system_prompt'] = system_message.content
                prompts = [message for message in prompts if message.type != 'system']

            prompts = get_buffer_string(prompts)
# The maximum length the generated tokens can have.
# Corresponds to the length of the input prompt + max_new_tokens.
if 'max_length' in self._client.input:
self._client.input['max_length'] = min(
self._client.input['max_length'] + self.get_num_tokens(messages),
self.model_rules.max_tokens.max
)
return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, list):
prompts = get_buffer_string(prompts)
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
        # Replicate bills by prediction time (seconds), so per-token pricing does not apply
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.input = provider_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True
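
The system-prompt handling in `_run` relies on langchain's string `type` attribute ('system', 'human', 'ai') and on `get_buffer_string` to flatten the remaining chat messages into a single prompt; a small self-contained sketch:

    from langchain.schema import SystemMessage, HumanMessage, get_buffer_string

    msgs = [SystemMessage(content='Be terse.'), HumanMessage(content='Hi')]
    system = next(m for m in msgs if m.type == 'system')
    rest = [m for m in msgs if m.type != 'system']
    print(system.content)           # Be terse.
    print(get_buffer_string(rest))  # Human: Hi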

View File

@@ -0,0 +1,73 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError
class SparkModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
        Run prediction with the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
        Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
contents = [message.content for message in messages]
return max(self._client.get_num_tokens("".join(contents)), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, SparkError):
return LLMBadRequestError(f"Spark: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True

Some files were not shown because too many files have changed in this diff.