mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-14 20:41:21 +08:00
feat: 合并dify1.1.3版本
# Conflicts: # README.md # api/.env.example # api/controllers/console/__init__.py # api/controllers/console/apikey.py # api/controllers/console/explore/completion.py # api/controllers/console/explore/workflow.py # api/controllers/service_api/app/workflow.py # api/controllers/service_api/wraps.py # api/controllers/web/workflow.py # api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py # api/core/model_runtime/model_providers/bedrock/llm/llm.py # api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml # api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py # api/models/model.py # api/poetry.lock # api/pyproject.toml # web/.env.example # web/Dockerfile # web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx # web/app/components/app/overview/appCard.tsx # web/app/components/base/chat/chat-with-history/chat-wrapper.tsx # web/app/components/base/chat/embedded-chatbot/index.tsx # web/app/components/base/mermaid/index.tsx # web/app/components/develop/index.tsx # web/app/components/develop/secret-key/secret-key-modal.tsx # web/app/components/explore/app-list/index.tsx # web/app/components/explore/item-operation/index.tsx # web/app/components/explore/sidebar/app-nav-item/index.tsx # web/app/components/explore/sidebar/index.tsx # web/app/components/header/account-setting/index.tsx # web/app/components/header/index.tsx # web/app/components/share/text-generation/index.tsx # web/app/components/tools/provider/detail.tsx # web/app/layout.tsx # web/package.json # web/service/base.ts # web/yarn.lock
This commit is contained in:
@@ -10,6 +10,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@@ -27,7 +28,6 @@ from models.account import (
|
||||
AccountStatus,
|
||||
Tenant,
|
||||
TenantAccountJoin,
|
||||
TenantAccountJoinRole,
|
||||
TenantAccountRole,
|
||||
TenantStatus,
|
||||
)
|
||||
@@ -77,6 +77,7 @@ class AccountService:
|
||||
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
|
||||
)
|
||||
LOGIN_MAX_ERROR_LIMITS = 5
|
||||
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
|
||||
|
||||
@staticmethod
|
||||
def _get_refresh_token_key(refresh_token: str) -> str:
|
||||
@@ -100,7 +101,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
account = Account.query.filter_by(id=user_id).first()
|
||||
account = db.session.query(Account).filter_by(id=user_id).first()
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@@ -145,7 +146,7 @@ class AccountService:
|
||||
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
@@ -503,6 +504,32 @@ class AccountService:
|
||||
key = f"login_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
def add_forgot_password_error_rate_limit(email: str) -> None:
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
count = 0
|
||||
count = int(count) + 1
|
||||
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
|
||||
|
||||
@staticmethod
|
||||
def is_forgot_password_error_rate_limit(email: str) -> bool:
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
return False
|
||||
|
||||
count = int(count)
|
||||
if count > AccountService.FORGOT_PASSWORD_MAX_ERROR_LIMITS:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def reset_forgot_password_error_rate_limit(email: str):
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
def is_email_send_ip_limit(ip_address: str):
|
||||
minute_key = f"email_send_ip_limit_minute:{ip_address}"
|
||||
@@ -597,8 +624,8 @@ class TenantService:
|
||||
@staticmethod
|
||||
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
|
||||
"""Create tenant member"""
|
||||
if role == TenantAccountJoinRole.OWNER.value:
|
||||
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
|
||||
if role == TenantAccountRole.OWNER.value:
|
||||
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
|
||||
logging.error(f"Tenant {tenant.id} has already an owner.")
|
||||
raise Exception("Tenant already has an owner.")
|
||||
|
||||
@@ -706,10 +733,10 @@ class TenantService:
|
||||
return updated_accounts
|
||||
|
||||
@staticmethod
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool:
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountRole]) -> bool:
|
||||
"""Check if user has any of the given roles for a tenant"""
|
||||
if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
|
||||
raise ValueError("all roles must be TenantAccountJoinRole")
|
||||
if not all(isinstance(role, TenantAccountRole) for role in roles):
|
||||
raise ValueError("all roles must be TenantAccountRole")
|
||||
|
||||
return (
|
||||
db.session.query(TenantAccountJoin)
|
||||
@@ -721,7 +748,7 @@ class TenantService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
|
||||
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]:
|
||||
"""Get the role of the current account for a given tenant"""
|
||||
join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
@@ -758,9 +785,11 @@ class TenantService:
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
"""Remove member from tenant"""
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
|
||||
if operator.id == account.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
TenantService.check_member_permission(tenant, operator, account, "remove")
|
||||
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
if not ta:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
@@ -897,11 +926,14 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def invite_new_member(
|
||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
|
||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
|
||||
) -> str:
|
||||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
|
||||
"""Invite new member"""
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
assert inviter is not None, "Inviter must be provided."
|
||||
with Session(db.engine) as session:
|
||||
account = session.query(Account).filter_by(email=email).first()
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import pytz
|
||||
from flask_login import current_user # type: ignore
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
|
||||
from core.plugin.manager.agent import PluginAgentManager
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@@ -16,7 +20,10 @@ class AgentService:
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
conversation: Optional[Conversation] = (
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
conversation: Conversation | None = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
@@ -59,6 +66,10 @@ class AgentService:
|
||||
|
||||
timezone = pytz.timezone(current_user.timezone)
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("App model config not found")
|
||||
|
||||
result = {
|
||||
"meta": {
|
||||
"status": "success",
|
||||
@@ -66,16 +77,16 @@ class AgentService:
|
||||
"start_time": message.created_at.astimezone(timezone).isoformat(),
|
||||
"elapsed_time": message.provider_response_latency,
|
||||
"total_tokens": message.answer_tokens + message.message_tokens,
|
||||
"agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"),
|
||||
"agent_mode": app_model_config.agent_mode_dict.get("strategy", "react"),
|
||||
"iterations": len(agent_thoughts),
|
||||
},
|
||||
"iterations": [],
|
||||
"files": message.message_files,
|
||||
}
|
||||
|
||||
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
|
||||
agent_config = AgentConfigManager.convert(app_model_config.to_dict())
|
||||
if not agent_config:
|
||||
return result
|
||||
raise ValueError("Agent config not found")
|
||||
|
||||
agent_tools = agent_config.tools or []
|
||||
|
||||
@@ -89,7 +100,7 @@ class AgentService:
|
||||
tool_labels = agent_thought.tool_labels
|
||||
tool_meta = agent_thought.tool_meta
|
||||
tool_inputs = agent_thought.tool_inputs_dict
|
||||
tool_outputs = agent_thought.tool_outputs_dict
|
||||
tool_outputs = agent_thought.tool_outputs_dict or {}
|
||||
tool_calls = []
|
||||
for tool in tools:
|
||||
tool_name = tool
|
||||
@@ -144,3 +155,22 @@ class AgentService:
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def list_agent_providers(cls, user_id: str, tenant_id: str):
|
||||
"""
|
||||
List agent providers
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
return manager.fetch_agent_strategy_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
Get agent provider
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
try:
|
||||
return manager.fetch_agent_strategy_provider(tenant_id, provider_name)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
@@ -7,22 +8,34 @@ from uuid import uuid4
|
||||
|
||||
import yaml # type: ignore
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from events.app_event import app_model_config_was_updated, app_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 180 # 3 minutes
|
||||
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
|
||||
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.1.5"
|
||||
|
||||
|
||||
@@ -42,11 +55,16 @@ class Import(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
app_id: Optional[str] = None
|
||||
app_mode: Optional[str] = None
|
||||
current_dsl_version: str = CURRENT_DSL_VERSION
|
||||
imported_dsl_version: str = ""
|
||||
error: str = ""
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
@@ -76,6 +94,11 @@ class PendingData(BaseModel):
|
||||
app_id: str | None
|
||||
|
||||
|
||||
class CheckDependenciesPendingData(BaseModel):
|
||||
dependencies: list[PluginDependency]
|
||||
app_id: str | None
|
||||
|
||||
|
||||
class AppDslService:
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
@@ -113,7 +136,6 @@ class AppDslService:
|
||||
error="yaml_url is required when import_mode is yaml-url",
|
||||
)
|
||||
try:
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
parsed_url = urlparse(yaml_url)
|
||||
if (
|
||||
parsed_url.scheme == "https"
|
||||
@@ -126,7 +148,7 @@ class AppDslService:
|
||||
response.raise_for_status()
|
||||
content = response.content.decode()
|
||||
|
||||
if len(content) > max_size:
|
||||
if len(content) > DSL_MAX_SIZE:
|
||||
return Import(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
@@ -199,7 +221,7 @@ class AppDslService:
|
||||
error="App not found",
|
||||
)
|
||||
|
||||
if app.mode not in [AppMode.WORKFLOW.value, AppMode.ADVANCED_CHAT.value]:
|
||||
if app.mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
|
||||
return Import(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
@@ -208,7 +230,7 @@ class AppDslService:
|
||||
|
||||
# If major version mismatch, store import info in Redis
|
||||
if status == ImportStatus.PENDING:
|
||||
panding_data = PendingData(
|
||||
pending_data = PendingData(
|
||||
import_mode=import_mode,
|
||||
yaml_content=content,
|
||||
name=name,
|
||||
@@ -221,7 +243,7 @@ class AppDslService:
|
||||
redis_client.setex(
|
||||
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
|
||||
IMPORT_INFO_REDIS_EXPIRY,
|
||||
panding_data.model_dump_json(),
|
||||
pending_data.model_dump_json(),
|
||||
)
|
||||
|
||||
return Import(
|
||||
@@ -231,6 +253,22 @@ class AppDslService:
|
||||
imported_dsl_version=imported_version,
|
||||
)
|
||||
|
||||
# Extract dependencies
|
||||
dependencies = data.get("dependencies", [])
|
||||
check_dependencies_pending_data = None
|
||||
if dependencies:
|
||||
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
|
||||
elif imported_version <= "0.1.5":
|
||||
if "workflow" in data:
|
||||
graph = data.get("workflow", {}).get("graph", {})
|
||||
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
|
||||
else:
|
||||
dependencies_list = self._extract_dependencies_from_model_config(data.get("model_config", {}))
|
||||
|
||||
check_dependencies_pending_data = DependenciesAnalysisService.generate_latest_dependencies(
|
||||
dependencies_list
|
||||
)
|
||||
|
||||
# Create or update app
|
||||
app = self._create_or_update_app(
|
||||
app=app,
|
||||
@@ -241,12 +279,14 @@ class AppDslService:
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
dependencies=check_dependencies_pending_data,
|
||||
)
|
||||
|
||||
return Import(
|
||||
id=import_id,
|
||||
status=status,
|
||||
app_id=app.id,
|
||||
app_mode=app.mode,
|
||||
imported_dsl_version=imported_version,
|
||||
)
|
||||
|
||||
@@ -313,6 +353,7 @@ class AppDslService:
|
||||
id=import_id,
|
||||
status=ImportStatus.COMPLETED,
|
||||
app_id=app.id,
|
||||
app_mode=app.mode,
|
||||
current_dsl_version=CURRENT_DSL_VERSION,
|
||||
imported_dsl_version=data.get("version", "0.1.0"),
|
||||
)
|
||||
@@ -325,6 +366,29 @@ class AppDslService:
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def check_dependencies(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
) -> CheckDependenciesResult:
|
||||
"""Check dependencies"""
|
||||
# Get dependencies from Redis
|
||||
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
|
||||
dependencies = redis_client.get(redis_key)
|
||||
if not dependencies:
|
||||
return CheckDependenciesResult()
|
||||
|
||||
# Extract dependencies
|
||||
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
|
||||
|
||||
# Get leaked dependencies
|
||||
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
|
||||
tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
|
||||
)
|
||||
return CheckDependenciesResult(
|
||||
leaked_dependencies=leaked_dependencies,
|
||||
)
|
||||
|
||||
def _create_or_update_app(
|
||||
self,
|
||||
*,
|
||||
@@ -336,6 +400,7 @@ class AppDslService:
|
||||
icon_type: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_background: Optional[str] = None,
|
||||
dependencies: Optional[list[PluginDependency]] = None,
|
||||
) -> App:
|
||||
"""Create a new app or update an existing one."""
|
||||
app_data = data.get("app", {})
|
||||
@@ -384,6 +449,14 @@ class AppDslService:
|
||||
self._session.commit()
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
# save dependencies
|
||||
if dependencies:
|
||||
redis_client.setex(
|
||||
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
|
||||
IMPORT_INFO_REDIS_EXPIRY,
|
||||
CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
|
||||
)
|
||||
|
||||
# Initialize app based on mode
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_data = data.get("workflow")
|
||||
@@ -479,6 +552,13 @@ class AppDslService:
|
||||
raise ValueError("Missing draft workflow configuration, please check.")
|
||||
|
||||
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
|
||||
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
||||
export_data["dependencies"] = [
|
||||
jsonable_encoder(d.model_dump())
|
||||
for d in DependenciesAnalysisService.generate_dependencies(
|
||||
tenant_id=app_model.tenant_id, dependencies=dependencies
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
|
||||
@@ -492,3 +572,154 @@ class AppDslService:
|
||||
raise ValueError("Missing app configuration, please check.")
|
||||
|
||||
export_data["model_config"] = app_model_config.to_dict()
|
||||
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
|
||||
export_data["dependencies"] = [
|
||||
jsonable_encoder(d.model_dump())
|
||||
for d in DependenciesAnalysisService.generate_dependencies(
|
||||
tenant_id=app_model.tenant_id, dependencies=dependencies
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from workflow
|
||||
:param workflow: Workflow instance
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
graph = workflow.graph_dict
|
||||
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from workflow graph
|
||||
:param graph: Workflow graph
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
dependencies = []
|
||||
for node in graph.get("nodes", []):
|
||||
try:
|
||||
typ = node.get("data", {}).get("type")
|
||||
match typ:
|
||||
case NodeType.TOOL.value:
|
||||
tool_entity = ToolNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
|
||||
)
|
||||
case NodeType.LLM.value:
|
||||
llm_entity = LLMNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
|
||||
)
|
||||
case NodeType.QUESTION_CLASSIFIER.value:
|
||||
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
question_classifier_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.PARAMETER_EXTRACTOR.value:
|
||||
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
parameter_extractor_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
|
||||
if knowledge_retrieval_entity.retrieval_mode == "multiple":
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config:
|
||||
if (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
|
||||
== "reranking_model"
|
||||
):
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
|
||||
),
|
||||
)
|
||||
elif (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
|
||||
== "weighted_score"
|
||||
):
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config.weights:
|
||||
vector_setting = (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
|
||||
)
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
vector_setting.embedding_provider_name
|
||||
),
|
||||
)
|
||||
elif knowledge_retrieval_entity.retrieval_mode == "single":
|
||||
model_config = knowledge_retrieval_entity.single_retrieval_config
|
||||
if model_config:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
model_config.model.provider
|
||||
),
|
||||
)
|
||||
case _:
|
||||
# TODO: Handle default case or unknown node types
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting node dependency", exc_info=e)
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from model config
|
||||
:param model_config: model config dict
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
dependencies = []
|
||||
|
||||
try:
|
||||
# completion model
|
||||
model_dict = model_config.get("model", {})
|
||||
if model_dict:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
|
||||
)
|
||||
|
||||
# reranking model
|
||||
dataset_configs = model_config.get("dataset_configs", {})
|
||||
if dataset_configs:
|
||||
for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
|
||||
if dataset_config.get("reranking_model"):
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
dataset_config.get("reranking_model", {})
|
||||
.get("reranking_provider_name", {})
|
||||
.get("provider")
|
||||
)
|
||||
)
|
||||
|
||||
# tools
|
||||
agent_configs = model_config.get("agent_mode", {})
|
||||
if agent_configs:
|
||||
for agent_config in agent_configs.get("tools", []):
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting model config dependency", exc_info=e)
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
|
||||
"""
|
||||
Returns the leaked dependencies in current workspace
|
||||
"""
|
||||
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
|
||||
if not dependencies:
|
||||
return []
|
||||
|
||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
|
||||
|
||||
@@ -11,13 +11,17 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting import RateLimit
|
||||
from libs.helper import RateLimiter
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class AppGenerateService:
|
||||
system_rate_limiter = RateLimiter("app_daily_rate_limiter", dify_config.APP_DAILY_RATE_LIMIT, 86400)
|
||||
|
||||
@classmethod
|
||||
def generate(
|
||||
cls,
|
||||
@@ -36,6 +40,19 @@ class AppGenerateService:
|
||||
:param streaming: streaming
|
||||
:return:
|
||||
"""
|
||||
# system level rate limiter
|
||||
if dify_config.BILLING_ENABLED:
|
||||
# check if it's free plan
|
||||
limit_info = BillingService.get_info(app_model.tenant_id)
|
||||
if limit_info["subscription"]["plan"] == "sandbox":
|
||||
if cls.system_rate_limiter.is_rate_limited(app_model.tenant_id):
|
||||
raise InvokeRateLimitError(
|
||||
"Rate limit exceeded, please upgrade your plan "
|
||||
f"or your RPD was {dify_config.APP_DAILY_RATE_LIMIT} requests/day"
|
||||
)
|
||||
cls.system_rate_limiter.increment_rate_limit(app_model.tenant_id)
|
||||
|
||||
# app level rate limiter
|
||||
max_active_request = AppGenerateService._get_max_active_requests(app_model)
|
||||
rate_limit = RateLimit(app_model.id, max_active_request)
|
||||
request_id = RateLimit.gen_request_key()
|
||||
@@ -43,66 +60,62 @@ class AppGenerateService:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
return rate_limit.generate(
|
||||
generator=CompletionAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
CompletionAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
generator = AgentChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
)
|
||||
return rate_limit.generate(
|
||||
generator=generator,
|
||||
request_id=request_id,
|
||||
AgentChatAppGenerator.convert_to_event_stream(
|
||||
AgentChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
return rate_limit.generate(
|
||||
generator=ChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
ChatAppGenerator.convert_to_event_stream(
|
||||
ChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return rate_limit.generate(
|
||||
generator=AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
generator = WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
workflow_thread_pool_id=None,
|
||||
)
|
||||
return rate_limit.generate(
|
||||
generator=generator,
|
||||
request_id=request_id,
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
workflow_thread_pool_id=None,
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
@@ -126,18 +139,36 @@ class AppGenerateService:
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_loop_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.sql import text # Extend: App Center - Recommended list sorted
|
||||
from configs import dify_config
|
||||
from constants.model_template import default_app_templates
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.features.rate_limiting import RateLimit
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
@@ -53,9 +52,13 @@ class AppService:
|
||||
# stop Extend: App Center - Recommended list sorted by usage frequency
|
||||
|
||||
if args["mode"] == "workflow":
|
||||
filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value]))
|
||||
filters.append(App.mode == AppMode.WORKFLOW.value)
|
||||
elif args["mode"] == "completion":
|
||||
filters.append(App.mode == AppMode.COMPLETION.value)
|
||||
elif args["mode"] == "chat":
|
||||
filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value]))
|
||||
filters.append(App.mode == AppMode.CHAT.value)
|
||||
elif args["mode"] == "advanced-chat":
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT.value)
|
||||
elif args["mode"] == "channel":
|
||||
@@ -254,7 +257,6 @@ class AppService:
|
||||
"""
|
||||
app.name = args.get("name")
|
||||
app.description = args.get("description", "")
|
||||
app.max_active_requests = args.get("max_active_requests")
|
||||
app.icon_type = args.get("icon_type", "emoji")
|
||||
app.icon = args.get("icon")
|
||||
app.icon_background = args.get("icon_background")
|
||||
@@ -269,9 +271,6 @@ class AppService:
|
||||
# ======= stop: Extend: App Center - Recommended list sorted =======
|
||||
db.session.commit()
|
||||
|
||||
if app.max_active_requests is not None:
|
||||
rate_limit = RateLimit(app.id, app.max_active_requests)
|
||||
rate_limit.flush_cache(use_local_value=True)
|
||||
return app
|
||||
|
||||
def update_app_name(self, app: App, name: str) -> App:
|
||||
|
||||
@@ -5,6 +5,7 @@ import httpx
|
||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import RateLimiter
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
@@ -12,6 +13,8 @@ class BillingService:
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
|
||||
compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str):
|
||||
params = {"tenant_id": tenant_id}
|
||||
@@ -19,6 +22,17 @@ class BillingService:
|
||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||
return billing_info
|
||||
|
||||
@classmethod
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str):
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
|
||||
|
||||
return {
|
||||
"limit": knowledge_rate_limit.get("limit", 10),
|
||||
"subscription_plan": knowledge_rate_limit.get("subscription_plan", "sandbox"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
|
||||
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
|
||||
@@ -91,3 +105,29 @@ class BillingService:
|
||||
"""Update account deletion feedback."""
|
||||
json = {"email": email, "feedback": feedback}
|
||||
return cls._send_request("POST", "/account/delete-feedback", json=json)
|
||||
|
||||
@classmethod
|
||||
def get_compliance_download_link(
|
||||
cls,
|
||||
doc_name: str,
|
||||
account_id: str,
|
||||
tenant_id: str,
|
||||
ip: str,
|
||||
device_info: str,
|
||||
):
|
||||
limiter_key = f"{account_id}:{tenant_id}"
|
||||
if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
|
||||
from controllers.console.error import CompilanceRateLimitError
|
||||
|
||||
raise CompilanceRateLimitError()
|
||||
|
||||
json = {
|
||||
"doc_name": doc_name,
|
||||
"account_id": account_id,
|
||||
"tenant_id": tenant_id,
|
||||
"ip_address": ip,
|
||||
"device_info": device_info,
|
||||
}
|
||||
res = cls._send_request("POST", "/compliance/download", json=json)
|
||||
cls.compliance_download_rate_limiter.increment_rate_limit(limiter_key)
|
||||
return res
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import click
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
from models.model import App, Conversation, Message
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
from services.billing_service import BillingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClearFreePlanTenantExpiredLogs:
|
||||
@classmethod
|
||||
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
|
||||
with flask_app.app_context():
|
||||
apps = db.session.query(App).filter(App.tenant_id == tenant_id).all()
|
||||
app_ids = [app.id for app in apps]
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
messages = (
|
||||
session.query(Message)
|
||||
.filter(
|
||||
Message.app_id.in_(app_ids),
|
||||
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
if len(messages) == 0:
|
||||
break
|
||||
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/messages/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
[message.to_dict() for message in messages],
|
||||
),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
message_ids = [message.id for message in messages]
|
||||
|
||||
# delete messages
|
||||
session.query(Message).filter(
|
||||
Message.id.in_(message_ids),
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] Processed {len(message_ids)} messages for tenant {tenant_id} "
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
conversations = (
|
||||
session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.app_id.in_(app_ids),
|
||||
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(conversations) == 0:
|
||||
break
|
||||
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/conversations/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
[conversation.to_dict() for conversation in conversations],
|
||||
),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
conversation_ids = [conversation.id for conversation in conversations]
|
||||
session.query(Conversation).filter(
|
||||
Conversation.id.in_(conversation_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] Processed {len(conversation_ids)}"
|
||||
f" conversations for tenant {tenant_id}"
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_node_executions = (
|
||||
session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == tenant_id,
|
||||
WorkflowNodeExecution.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(workflow_node_executions) == 0:
|
||||
break
|
||||
|
||||
# save workflow node executions
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(workflow_node_executions),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
workflow_node_execution_ids = [
|
||||
workflow_node_execution.id for workflow_node_execution in workflow_node_executions
|
||||
]
|
||||
|
||||
# delete workflow node executions
|
||||
session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id.in_(workflow_node_execution_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
|
||||
f" workflow node executions for tenant {tenant_id}"
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_runs = (
|
||||
session.query(WorkflowRun)
|
||||
.filter(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(workflow_runs) == 0:
|
||||
break
|
||||
|
||||
# save workflow runs
|
||||
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
[workflow_run.to_dict() for workflow_run in workflow_runs],
|
||||
),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]
|
||||
|
||||
# delete workflow runs
|
||||
session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id.in_(workflow_run_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def process(cls, days: int, batch: int, tenant_ids: list[str]):
|
||||
"""
|
||||
Clear free plan tenant expired logs.
|
||||
"""
|
||||
|
||||
click.echo(click.style("Clearing free plan tenant expired logs", fg="white"))
|
||||
ended_at = datetime.datetime.now()
|
||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||
current_time = started_at
|
||||
|
||||
with Session(db.engine) as session:
|
||||
total_tenant_count = session.query(Tenant.id).count()
|
||||
|
||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||
|
||||
handled_tenant_count = 0
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
try:
|
||||
if (
|
||||
not dify_config.BILLING_ENABLED
|
||||
or BillingService.get_info(tenant_id)["subscription"]["plan"] == "sandbox"
|
||||
):
|
||||
# only process sandbox tenant
|
||||
cls.process_tenant(flask_app, tenant_id, days, batch)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tenant {tenant_id}")
|
||||
finally:
|
||||
nonlocal handled_tenant_count
|
||||
handled_tenant_count += 1
|
||||
if handled_tenant_count % 100 == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] "
|
||||
f"Processed {handled_tenant_count} tenants "
|
||||
f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
|
||||
f"{handled_tenant_count}/{total_tenant_count}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
futures = []
|
||||
|
||||
if tenant_ids:
|
||||
for tenant_id in tenant_ids:
|
||||
futures.append(
|
||||
thread_pool.submit(
|
||||
process_tenant,
|
||||
current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
tenant_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
while current_time < ended_at:
|
||||
click.echo(
|
||||
click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white")
|
||||
)
|
||||
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
|
||||
interval = datetime.timedelta(days=1)
|
||||
# Process tenants in this batch
|
||||
with Session(db.engine) as session:
|
||||
# Calculate tenant count in next batch with current interval
|
||||
# Try different intervals until we find one with a reasonable tenant count
|
||||
test_intervals = [
|
||||
datetime.timedelta(days=1),
|
||||
datetime.timedelta(hours=12),
|
||||
datetime.timedelta(hours=6),
|
||||
datetime.timedelta(hours=3),
|
||||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
.filter(Tenant.created_at.between(current_time, current_time + test_interval))
|
||||
.count()
|
||||
)
|
||||
if tenant_count <= 100:
|
||||
interval = test_interval
|
||||
break
|
||||
else:
|
||||
# If all intervals have too many tenants, use minimum interval
|
||||
interval = datetime.timedelta(hours=1)
|
||||
|
||||
# Adjust interval to target ~100 tenants per batch
|
||||
if tenant_count > 0:
|
||||
# Scale interval based on ratio to target count
|
||||
interval = min(
|
||||
datetime.timedelta(days=1), # Max 1 day
|
||||
max(
|
||||
datetime.timedelta(hours=1), # Min 1 hour
|
||||
interval * (100 / tenant_count), # Scale to target 100
|
||||
),
|
||||
)
|
||||
|
||||
batch_end = min(current_time + interval, ended_at)
|
||||
|
||||
rs = (
|
||||
session.query(Tenant.id)
|
||||
.filter(Tenant.created_at.between(current_time, batch_end))
|
||||
.order_by(Tenant.created_at)
|
||||
)
|
||||
|
||||
tenants = []
|
||||
for row in rs:
|
||||
tenant_id = str(row.id)
|
||||
try:
|
||||
tenants.append(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
futures.append(
|
||||
thread_pool.submit(
|
||||
process_tenant,
|
||||
current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
current_time = batch_end
|
||||
|
||||
# wait for all threads to finish
|
||||
for future in futures:
|
||||
future.result()
|
||||
+150
-45
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -9,12 +10,15 @@ from typing import Any, Optional
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from events.dataset_event import dataset_was_deleted
|
||||
@@ -42,7 +46,6 @@ from models.source import DataSourceOauthBinding
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
ChildChunkUpdateArgs,
|
||||
KnowledgeConfig,
|
||||
MetaDataConfig,
|
||||
RerankingModel,
|
||||
RetrievalModel,
|
||||
SegmentUpdateArgs,
|
||||
@@ -243,7 +246,7 @@ class DatasetService:
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(f"The dataset in unavailable, due to: {ex.description}")
|
||||
raise ValueError(ex.description)
|
||||
|
||||
@staticmethod
|
||||
def update_dataset(dataset_id, data, user):
|
||||
@@ -268,7 +271,15 @@ class DatasetService:
|
||||
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
|
||||
if not external_knowledge_api_id:
|
||||
raise ValueError("External knowledge api id is required.")
|
||||
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id).first()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
external_knowledge_binding = (
|
||||
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
|
||||
)
|
||||
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("External knowledge binding not found.")
|
||||
|
||||
if (
|
||||
external_knowledge_binding.external_knowledge_id != external_knowledge_id
|
||||
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
|
||||
@@ -316,25 +327,76 @@ class DatasetService:
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
else:
|
||||
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||
# Skip embedding model checks if not provided in the update request
|
||||
if (
|
||||
data["embedding_model_provider"] != dataset.embedding_model_provider
|
||||
or data["embedding_model"] != dataset.embedding_model
|
||||
"embedding_model_provider" not in data
|
||||
or "embedding_model" not in data
|
||||
or not data.get("embedding_model_provider")
|
||||
or not data.get("embedding_model")
|
||||
):
|
||||
action = "update"
|
||||
# If the dataset already has embedding model settings, use those
|
||||
if dataset.embedding_model_provider and dataset.embedding_model:
|
||||
# Keep existing values
|
||||
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
|
||||
filtered_data["embedding_model"] = dataset.embedding_model
|
||||
# If collection_binding_id exists, keep it too
|
||||
if dataset.collection_binding_id:
|
||||
filtered_data["collection_binding_id"] = dataset.collection_binding_id
|
||||
# Otherwise, don't try to update embedding model settings at all
|
||||
# Remove these fields from filtered_data if they exist but are None/empty
|
||||
if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]:
|
||||
del filtered_data["embedding_model_provider"]
|
||||
if "embedding_model" in filtered_data and not filtered_data["embedding_model"]:
|
||||
del filtered_data["embedding_model"]
|
||||
else:
|
||||
skip_embedding_update = False
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=data["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=data["embedding_model"],
|
||||
)
|
||||
filtered_data["embedding_model"] = embedding_model.model
|
||||
filtered_data["embedding_model_provider"] = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||
# Handle existing model provider
|
||||
plugin_model_provider = dataset.embedding_model_provider
|
||||
plugin_model_provider_str = None
|
||||
if plugin_model_provider:
|
||||
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
|
||||
|
||||
# Handle new model provider from request
|
||||
new_plugin_model_provider = data["embedding_model_provider"]
|
||||
new_plugin_model_provider_str = None
|
||||
if new_plugin_model_provider:
|
||||
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
|
||||
|
||||
# Only update embedding model if both values are provided and different from current
|
||||
if (
|
||||
plugin_model_provider_str != new_plugin_model_provider_str
|
||||
or data["embedding_model"] != dataset.embedding_model
|
||||
):
|
||||
action = "update"
|
||||
model_manager = ModelManager()
|
||||
try:
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=data["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=data["embedding_model"],
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
# If we can't get the embedding model, skip updating it
|
||||
# and keep the existing settings if available
|
||||
if dataset.embedding_model_provider and dataset.embedding_model:
|
||||
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
|
||||
filtered_data["embedding_model"] = dataset.embedding_model
|
||||
if dataset.collection_binding_id:
|
||||
filtered_data["collection_binding_id"] = dataset.collection_binding_id
|
||||
# Skip the rest of the embedding model update
|
||||
skip_embedding_update = True
|
||||
if not skip_embedding_update:
|
||||
filtered_data["embedding_model"] = embedding_model.model
|
||||
filtered_data["embedding_model_provider"] = embedding_model.provider
|
||||
dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
)
|
||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -582,9 +644,45 @@ class DocumentService:
|
||||
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_ids),
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all()
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@@ -667,8 +765,13 @@ class DocumentService:
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
raise ValueError("No permission.")
|
||||
|
||||
document.name = name
|
||||
if dataset.built_in_field_enabled:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata[BuiltInField.document_name.value] = name
|
||||
document.doc_metadata = doc_metadata
|
||||
|
||||
document.name = name
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@@ -888,16 +991,13 @@ class DocumentService:
|
||||
).first()
|
||||
if document:
|
||||
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
|
||||
document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.created_from = created_from
|
||||
document.doc_form = knowledge_config.doc_form
|
||||
document.doc_language = knowledge_config.doc_language
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.batch = batch
|
||||
document.indexing_status = "waiting"
|
||||
if knowledge_config.metadata:
|
||||
document.doc_type = knowledge_config.metadata.doc_type
|
||||
document.metadata = knowledge_config.metadata.doc_metadata
|
||||
db.session.add(document)
|
||||
documents.append(document)
|
||||
duplicate_document_ids.append(document.id)
|
||||
@@ -914,7 +1014,6 @@ class DocumentService:
|
||||
account,
|
||||
file_name,
|
||||
batch,
|
||||
knowledge_config.metadata,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
@@ -958,6 +1057,8 @@ class DocumentService:
|
||||
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
|
||||
"type": page.type,
|
||||
}
|
||||
# Truncate page name to 255 characters to prevent DB field length errors
|
||||
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
|
||||
document = DocumentService.build_document(
|
||||
dataset,
|
||||
dataset_process_rule.id, # type: ignore
|
||||
@@ -968,9 +1069,8 @@ class DocumentService:
|
||||
created_from,
|
||||
position,
|
||||
account,
|
||||
page.page_name,
|
||||
truncated_page_name,
|
||||
batch,
|
||||
knowledge_config.metadata,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
@@ -1011,7 +1111,6 @@ class DocumentService:
|
||||
account,
|
||||
document_name,
|
||||
batch,
|
||||
knowledge_config.metadata,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
@@ -1049,7 +1148,6 @@ class DocumentService:
|
||||
account: Account,
|
||||
name: str,
|
||||
batch: str,
|
||||
metadata: Optional[MetaDataConfig] = None,
|
||||
):
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
@@ -1065,9 +1163,17 @@ class DocumentService:
|
||||
doc_form=document_form,
|
||||
doc_language=document_language,
|
||||
)
|
||||
if metadata is not None:
|
||||
document.doc_metadata = metadata.doc_metadata
|
||||
document.doc_type = metadata.doc_type
|
||||
doc_metadata = {}
|
||||
if dataset.built_in_field_enabled:
|
||||
doc_metadata = {
|
||||
BuiltInField.document_name: name,
|
||||
BuiltInField.uploader: account.name,
|
||||
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.source: data_source_type,
|
||||
}
|
||||
if doc_metadata:
|
||||
document.doc_metadata = doc_metadata
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
@@ -1180,10 +1286,6 @@ class DocumentService:
|
||||
# update document name
|
||||
if document_data.name:
|
||||
document.name = document_data.name
|
||||
# update doc_type and doc_metadata if provided
|
||||
if document_data.metadata is not None:
|
||||
document.doc_metadata = document_data.metadata.doc_type
|
||||
document.doc_type = document_data.metadata.doc_type
|
||||
# update document to be waiting
|
||||
document.indexing_status = "waiting"
|
||||
document.completed_at = None
|
||||
@@ -1468,7 +1570,7 @@ class SegmentService:
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
lock_name = "add_segment_lock_document_id_{}".format(document.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
max_position = (
|
||||
@@ -1545,9 +1647,12 @@ class SegmentService:
|
||||
if dataset.indexing_technique == "high_quality" and embedding_model:
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[content + segment_item["answer"]]
|
||||
)[0]
|
||||
else:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
@@ -1695,9 +1800,9 @@ class SegmentService:
|
||||
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
|
||||
else:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
@@ -1850,7 +1955,7 @@ class SegmentService:
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segmment_ids.append(segment.id)
|
||||
@@ -1942,7 +2047,7 @@ class SegmentService:
|
||||
child_chunk.content = child_chunk_update_args.content
|
||||
child_chunk.word_count = len(child_chunk.content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
child_chunk.type = "customized"
|
||||
update_child_chunks.append(child_chunk)
|
||||
else:
|
||||
@@ -1999,7 +2104,7 @@ class SegmentService:
|
||||
child_chunk.content = content
|
||||
child_chunk.word_count = len(content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
child_chunk.type = "customized"
|
||||
db.session.add(child_chunk)
|
||||
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
|
||||
|
||||
@@ -84,13 +84,31 @@ class RerankingModel(BaseModel):
|
||||
reranking_model_name: Optional[str] = None
|
||||
|
||||
|
||||
class WeightVectorSetting(BaseModel):
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class WeightKeywordSetting(BaseModel):
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightModel(BaseModel):
|
||||
weight_type: str
|
||||
vector_setting: Optional[WeightVectorSetting] = None
|
||||
keyword_setting: Optional[WeightKeywordSetting] = None
|
||||
|
||||
|
||||
class RetrievalModel(BaseModel):
|
||||
search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
|
||||
reranking_enable: bool
|
||||
reranking_model: Optional[RerankingModel] = None
|
||||
reranking_mode: Optional[str] = None
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
score_threshold: Optional[float] = None
|
||||
weights: Optional[WeightModel] = None
|
||||
|
||||
|
||||
class MetaDataConfig(BaseModel):
|
||||
@@ -110,7 +128,6 @@ class KnowledgeConfig(BaseModel):
|
||||
embedding_model: Optional[str] = None
|
||||
embedding_model_provider: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
metadata: Optional[MetaDataConfig] = None
|
||||
|
||||
|
||||
class SegmentUpdateArgs(BaseModel):
|
||||
@@ -124,3 +141,36 @@ class SegmentUpdateArgs(BaseModel):
|
||||
class ChildChunkUpdateArgs(BaseModel):
|
||||
id: Optional[str] = None
|
||||
content: str
|
||||
|
||||
|
||||
class MetadataArgs(BaseModel):
|
||||
type: Literal["string", "number", "time"]
|
||||
name: str
|
||||
|
||||
|
||||
class MetadataUpdateArgs(BaseModel):
|
||||
name: str
|
||||
value: Optional[str | int | float] = None
|
||||
|
||||
|
||||
class MetadataValueUpdateArgs(BaseModel):
|
||||
fields: list[MetadataUpdateArgs]
|
||||
|
||||
|
||||
class MetadataDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
value: Optional[str | int | float] = None
|
||||
|
||||
|
||||
class DocumentMetadataOperation(BaseModel):
|
||||
document_id: str
|
||||
metadata_list: list[MetadataDetail]
|
||||
|
||||
|
||||
class MetadataOperationData(BaseModel):
|
||||
"""
|
||||
Metadata operation data
|
||||
"""
|
||||
|
||||
operation_data: list[DocumentMetadataOperation]
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.entities.model_entities import (
|
||||
ModelWithProviderEntity,
|
||||
ProviderModelWithStatusEntity,
|
||||
)
|
||||
from core.entities.provider_entities import QuotaConfiguration
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
@@ -18,7 +18,7 @@ from core.model_runtime.entities.provider_entities import (
|
||||
ProviderHelpEntity,
|
||||
SimpleProviderEntity,
|
||||
)
|
||||
from models.provider import ProviderQuotaType, ProviderType
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
@@ -53,6 +53,7 @@ class ProviderResponse(BaseModel):
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
@@ -74,7 +75,9 @@ class ProviderResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
@@ -91,6 +94,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
Model class for provider with models response.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
@@ -101,7 +105,9 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
@@ -118,10 +124,14 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
Simple provider entity response.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
@@ -146,13 +156,14 @@ class DefaultModelResponse(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||
class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
|
||||
# FIXME type error ignore here
|
||||
provider: SimpleProviderEntityResponse # type: ignore
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, model: ModelWithProviderEntity) -> None:
|
||||
super().__init__(**model.model_dump())
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
|
||||
dump_model = model.model_dump()
|
||||
dump_model["provider"]["tenant_id"] = tenant_id
|
||||
super().__init__(**dump_model)
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
class WorkflowInUseError(ValueError):
|
||||
"""Raised when attempting to delete a workflow that's in use by an app"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DraftWorkflowDeletionError(ValueError):
|
||||
"""Raised when attempting to delete a draft workflow"""
|
||||
|
||||
pass
|
||||
@@ -8,6 +8,7 @@ import validators
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import (
|
||||
Dataset,
|
||||
@@ -245,7 +246,11 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def fetch_external_knowledge_retrieval(
|
||||
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_parameters: dict,
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
) -> list:
|
||||
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
|
||||
dataset_id=dataset_id, tenant_id=tenant_id
|
||||
@@ -272,6 +277,7 @@ class ExternalDatasetService:
|
||||
},
|
||||
"query": query,
|
||||
"knowledge_id": external_knowledge_binding.external_knowledge_id,
|
||||
"metadata_condition": metadata_condition.model_dump() if metadata_condition else None,
|
||||
}
|
||||
|
||||
response = ExternalDatasetService.process_external_api(
|
||||
|
||||
@@ -43,6 +43,7 @@ class FeatureModel(BaseModel):
|
||||
members: LimitationModel = LimitationModel(size=0, limit=1)
|
||||
apps: LimitationModel = LimitationModel(size=0, limit=10)
|
||||
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
|
||||
knowledge_rate_limit: int = 10
|
||||
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
|
||||
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
|
||||
docs_processing: str = "standard"
|
||||
@@ -54,12 +55,20 @@ class FeatureModel(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class KnowledgeRateLimitModel(BaseModel):
|
||||
enabled: bool = False
|
||||
limit: int = 10
|
||||
subscription_plan: str = ""
|
||||
|
||||
|
||||
class SystemFeatureModel(BaseModel):
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ""
|
||||
sso_enforced_for_web: bool = False
|
||||
sso_enforced_for_web_protocol: str = ""
|
||||
enable_web_sso_switch_component: bool = False
|
||||
enable_marketplace: bool = False
|
||||
max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
|
||||
enable_email_code_login: bool = False
|
||||
enable_email_password_login: bool = True
|
||||
enable_social_oauth_login: bool = False
|
||||
@@ -85,6 +94,16 @@ class FeatureService:
|
||||
|
||||
return features
|
||||
|
||||
@classmethod
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str):
|
||||
knowledge_rate_limit = KnowledgeRateLimitModel()
|
||||
if dify_config.BILLING_ENABLED and tenant_id:
|
||||
knowledge_rate_limit.enabled = True
|
||||
limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
|
||||
knowledge_rate_limit.limit = limit_info.get("limit", 10)
|
||||
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox")
|
||||
return knowledge_rate_limit
|
||||
|
||||
@classmethod
|
||||
def get_system_features(cls) -> SystemFeatureModel:
|
||||
system_features = SystemFeatureModel()
|
||||
@@ -96,6 +115,9 @@ class FeatureService:
|
||||
|
||||
cls._fulfill_params_from_enterprise(system_features)
|
||||
|
||||
if dify_config.MARKETPLACE_ENABLED:
|
||||
system_features.enable_marketplace = True
|
||||
|
||||
return system_features
|
||||
|
||||
@classmethod
|
||||
@@ -160,6 +182,9 @@ class FeatureService:
|
||||
if "model_load_balancing_enabled" in billing_info:
|
||||
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
|
||||
|
||||
if "knowledge_rate_limit" in billing_info:
|
||||
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_enterprise(cls, features):
|
||||
enterprise_info = EnterpriseService.get_info()
|
||||
|
||||
@@ -47,7 +47,7 @@ class HitTestingService:
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 2),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
|
||||
@@ -15,7 +15,6 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
LastMessageNotExistsError,
|
||||
@@ -46,6 +45,8 @@ class MessageService:
|
||||
app_model=app_model, user=user, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
fetch_limit = limit + 1
|
||||
|
||||
if first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
@@ -64,7 +65,7 @@ class MessageService:
|
||||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.limit(fetch_limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
@@ -72,25 +73,14 @@ class MessageService:
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.limit(fetch_limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
if len(history_messages) > limit:
|
||||
has_more = True
|
||||
history_messages = history_messages[:-1]
|
||||
|
||||
if order == "asc":
|
||||
history_messages = list(reversed(history_messages))
|
||||
@@ -112,6 +102,8 @@ class MessageService:
|
||||
|
||||
base_query = db.session.query(Message)
|
||||
|
||||
fetch_limit = limit + 1
|
||||
|
||||
if conversation_id is not None:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, user=user, conversation_id=conversation_id
|
||||
@@ -131,21 +123,16 @@ class MessageService:
|
||||
history_messages = (
|
||||
base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.limit(fetch_limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
history_messages = base_query.order_by(Message.created_at.desc()).limit(fetch_limit).all()
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = base_query.filter(
|
||||
Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
if len(history_messages) > limit:
|
||||
has_more = True
|
||||
history_messages = history_messages[:-1]
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
|
||||
|
||||
@@ -222,12 +209,6 @@ class MessageService:
|
||||
app_model=app_model, conversation_id=message.conversation_id, user=user
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
if conversation.status != "normal":
|
||||
raise ConversationCompletedError()
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
|
||||
from services.dataset_service import DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
|
||||
|
||||
class MetadataService:
|
||||
@staticmethod
|
||||
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
|
||||
# check if metadata name already exists
|
||||
if DatasetMetadata.query.filter_by(
|
||||
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
|
||||
).first():
|
||||
raise ValueError("Metadata name already exists.")
|
||||
for field in BuiltInField:
|
||||
if field.value == metadata_args.name:
|
||||
raise ValueError("Metadata name already exists in Built-in fields.")
|
||||
metadata = DatasetMetadata(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
type=metadata_args.type,
|
||||
name=metadata_args.name,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(metadata)
|
||||
db.session.commit()
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
# check if metadata name already exists
|
||||
if DatasetMetadata.query.filter_by(
|
||||
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
|
||||
).first():
|
||||
raise ValueError("Metadata name already exists.")
|
||||
for field in BuiltInField:
|
||||
if field.value == name:
|
||||
raise ValueError("Metadata name already exists in Built-in fields.")
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
old_name = metadata.name
|
||||
metadata.name = name
|
||||
metadata.updated_by = current_user.id
|
||||
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
|
||||
# update related documents
|
||||
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
|
||||
if dataset_metadata_bindings:
|
||||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
value = doc_metadata.pop(old_name, None)
|
||||
doc_metadata[name] = value
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return metadata # type: ignore
|
||||
except Exception:
|
||||
logging.exception("Update metadata name failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def delete_metadata(dataset_id: str, metadata_id: str):
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
db.session.delete(metadata)
|
||||
|
||||
# deal related documents
|
||||
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
|
||||
if dataset_metadata_bindings:
|
||||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(metadata.name, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return metadata
|
||||
except Exception:
|
||||
logging.exception("Delete metadata failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def get_built_in_fields():
|
||||
return [
|
||||
{"name": BuiltInField.document_name.value, "type": "string"},
|
||||
{"name": BuiltInField.uploader.value, "type": "string"},
|
||||
{"name": BuiltInField.upload_date.value, "type": "time"},
|
||||
{"name": BuiltInField.last_update_date.value, "type": "time"},
|
||||
{"name": BuiltInField.source.value, "type": "string"},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def enable_built_in_field(dataset: Dataset):
|
||||
if dataset.built_in_field_enabled:
|
||||
return
|
||||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
if documents:
|
||||
for document in documents:
|
||||
if not document.doc_metadata:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata[BuiltInField.document_name.value] = document.name
|
||||
doc_metadata[BuiltInField.uploader.value] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Enable built-in field failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def disable_built_in_field(dataset: Dataset):
|
||||
if not dataset.built_in_field_enabled:
|
||||
return
|
||||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
document_ids = []
|
||||
if documents:
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(BuiltInField.document_name.value, None)
|
||||
doc_metadata.pop(BuiltInField.uploader.value, None)
|
||||
doc_metadata.pop(BuiltInField.upload_date.value, None)
|
||||
doc_metadata.pop(BuiltInField.last_update_date.value, None)
|
||||
doc_metadata.pop(BuiltInField.source.value, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
document_ids.append(document.id)
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Disable built-in field failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
|
||||
for operation in metadata_args.operation_data:
|
||||
lock_key = f"document_metadata_lock_{operation.document_id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
|
||||
document = DocumentService.get_document(dataset.id, operation.document_id)
|
||||
if document is None:
|
||||
raise ValueError("Document not found.")
|
||||
doc_metadata = {}
|
||||
for metadata_value in operation.metadata_list:
|
||||
doc_metadata[metadata_value.name] = metadata_value.value
|
||||
if dataset.built_in_field_enabled:
|
||||
doc_metadata[BuiltInField.document_name.value] = document.name
|
||||
doc_metadata[BuiltInField.uploader.value] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
# deal metadata binding
|
||||
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
|
||||
for metadata_value in operation.metadata_list:
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=operation.document_id,
|
||||
metadata_id=metadata_value.id,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Update documents metadata failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
|
||||
if dataset_id:
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
if redis_client.get(lock_key):
|
||||
raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
|
||||
redis_client.set(lock_key, 1, ex=3600)
|
||||
if document_id:
|
||||
lock_key = f"document_metadata_lock_{document_id}"
|
||||
if redis_client.get(lock_key):
|
||||
raise ValueError("Another document metadata operation is running, please wait a moment.")
|
||||
redis_client.set(lock_key, 1, ex=3600)
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_metadatas(dataset: Dataset):
|
||||
return {
|
||||
"doc_metadata": [
|
||||
{
|
||||
"id": item.get("id"),
|
||||
"name": item.get("name"),
|
||||
"type": item.get("type"),
|
||||
"count": DatasetMetadataBinding.query.filter_by(
|
||||
metadata_id=item.get("id"), dataset_id=dataset.id
|
||||
).count(),
|
||||
}
|
||||
for item in dataset.doc_metadata or []
|
||||
if item.get("id") != "built-in"
|
||||
],
|
||||
"built_in_field_enabled": dataset.built_in_field_enabled,
|
||||
}
|
||||
@@ -14,7 +14,7 @@ from core.model_runtime.entities.provider_entities import (
|
||||
ModelCredentialSchema,
|
||||
ProviderCredentialSchema,
|
||||
)
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_database import db
|
||||
from models.provider import LoadBalancingModelConfig
|
||||
@@ -527,6 +527,7 @@ class ModelLoadBalancingService:
|
||||
credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
|
||||
|
||||
if validate:
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
if isinstance(credential_schemas, ModelCredentialSchema):
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=provider_configuration.provider.provider,
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
from typing import Optional
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider import ProviderType
|
||||
from services.entities.model_provider_entities import (
|
||||
@@ -54,6 +47,7 @@ class ModelProviderService:
|
||||
continue
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_configuration.provider.provider,
|
||||
label=provider_configuration.provider.label,
|
||||
description=provider_configuration.provider.description,
|
||||
@@ -97,10 +91,11 @@ class ModelProviderService:
|
||||
|
||||
# Get provider available models
|
||||
return [
|
||||
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
|
||||
ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model)
|
||||
for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str):
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
|
||||
"""
|
||||
get provider credentials.
|
||||
"""
|
||||
@@ -168,7 +163,7 @@ class ModelProviderService:
|
||||
# Remove custom provider credentials.
|
||||
provider_configuration.delete_custom_credentials()
|
||||
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]:
|
||||
"""
|
||||
get model credentials.
|
||||
|
||||
@@ -302,6 +297,7 @@ class ModelProviderService:
|
||||
|
||||
providers_with_models.append(
|
||||
ProviderWithModelsResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
label=first_model.provider.label,
|
||||
icon_small=first_model.provider.icon_small,
|
||||
@@ -343,18 +339,17 @@ class ModelProviderService:
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get model instance of LLM
|
||||
model_type_instance = provider_configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# fetch credentials
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
|
||||
|
||||
if not credentials:
|
||||
return []
|
||||
|
||||
# Call get_parameter_rules method of model instance to get model parameter rules
|
||||
return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials))
|
||||
model_schema = provider_configuration.get_model_schema(
|
||||
model_type=ModelType.LLM, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
return model_schema.parameter_rules if model_schema else []
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
|
||||
"""
|
||||
@@ -365,13 +360,15 @@ class ModelProviderService:
|
||||
:return:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
|
||||
|
||||
try:
|
||||
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
|
||||
return (
|
||||
DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
@@ -383,7 +380,7 @@ class ModelProviderService:
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
logger.debug(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
@@ -402,55 +399,21 @@ class ModelProviderService:
|
||||
)
|
||||
|
||||
def get_model_provider_icon(
|
||||
self, provider: str, icon_type: str, lang: str
|
||||
self, tenant_id: str, provider: str, icon_type: str, lang: str
|
||||
) -> tuple[Optional[bytes], Optional[str]]:
|
||||
"""
|
||||
get model provider icon.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param icon_type: icon type (icon_small or icon_large)
|
||||
:param lang: language (zh_Hans or en_US)
|
||||
:return:
|
||||
"""
|
||||
provider_instance = model_provider_factory.get_provider_instance(provider)
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
file_name: str | None = None
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang)
|
||||
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_small.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_small.en_US
|
||||
else:
|
||||
if not provider_schema.icon_large:
|
||||
raise ValueError(f"Provider {provider} does not have large icon.")
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
if not file_name:
|
||||
return None, None
|
||||
|
||||
root_path = current_app.root_path
|
||||
provider_instance_path = os.path.dirname(
|
||||
os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/"))
|
||||
)
|
||||
file_path = os.path.join(provider_instance_path, "_assets")
|
||||
file_path = os.path.join(file_path, file_name)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return None, None
|
||||
|
||||
mimetype, _ = mimetypes.guess_type(file_path)
|
||||
mimetype = mimetype or "application/octet-stream"
|
||||
|
||||
# read binary from file
|
||||
byte_data = Path(file_path).read_bytes()
|
||||
return byte_data, mimetype
|
||||
return byte_data, mime_type
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
|
||||
"""
|
||||
@@ -516,48 +479,3 @@ class ModelProviderService:
|
||||
|
||||
# Enable model
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider: str):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/apply"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider})
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
if response.json()["code"] != "success":
|
||||
raise ValueError(f"error: {response.json()['message']}")
|
||||
|
||||
rst = response.json()
|
||||
|
||||
if rst["type"] == "redirect":
|
||||
return {"type": rst["type"], "redirect_url": rst["redirect_url"]}
|
||||
else:
|
||||
return {"type": rst["type"], "result": "success"}
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/qualification-verify"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
json_data = {"workspace_id": tenant_id, "provider_name": provider}
|
||||
if token:
|
||||
json_data["token"] = token
|
||||
response = requests.post(api_url, headers=headers, json=json_data)
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
rst = response.json()
|
||||
if rst["code"] != "success":
|
||||
raise ValueError(f"error: {rst['message']}")
|
||||
|
||||
data = rst["data"]
|
||||
if data["qualified"] is True:
|
||||
return {"result": "success", "provider_name": provider, "flag": True}
|
||||
else:
|
||||
return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]}
|
||||
|
||||
@@ -24,10 +24,10 @@ class OpsService:
|
||||
return None
|
||||
|
||||
# decrypt_token and obfuscated_token
|
||||
tenant = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not tenant:
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
return None
|
||||
tenant_id = tenant.tenant_id
|
||||
tenant_id = app.tenant_id
|
||||
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
|
||||
tenant_id, tracing_provider, trace_config_data.tracing_config
|
||||
)
|
||||
@@ -117,10 +117,10 @@ class OpsService:
|
||||
return None
|
||||
|
||||
# get tenant id
|
||||
tenant = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not tenant:
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
return None
|
||||
tenant_id = tenant.tenant_id
|
||||
tenant_id = app.tenant_id
|
||||
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
|
||||
if project_url:
|
||||
tracing_config["project_url"] = project_url
|
||||
@@ -157,10 +157,10 @@ class OpsService:
|
||||
return None
|
||||
|
||||
# get tenant id
|
||||
tenant = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not tenant:
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
return None
|
||||
tenant_id = tenant.tenant_id
|
||||
tenant_id = app.tenant_id
|
||||
tracing_config = OpsTraceManager.encrypt_tracing_config(
|
||||
tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config
|
||||
)
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from models.engine import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls) -> None:
|
||||
cls.migrate_db_records("providers", "provider_name") # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name")
|
||||
cls.migrate_db_records("provider_orders", "provider_name")
|
||||
cls.migrate_db_records("tenant_default_models", "provider_name")
|
||||
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
|
||||
cls.migrate_db_records("provider_model_settings", "provider_name")
|
||||
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
|
||||
cls.migrate_datasets()
|
||||
cls.migrate_db_records("embeddings", "provider_name") # large table
|
||||
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider")
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls) -> None:
|
||||
table_name = "datasets"
|
||||
provider_column_name = "embedding_model_provider"
|
||||
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
while True:
|
||||
sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
|
||||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
retrieval_model = i.retrieval_model
|
||||
print(type(retrieval_model))
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
retrieval_model_changed = False
|
||||
if retrieval_model:
|
||||
if (
|
||||
"reranking_model" in retrieval_model
|
||||
and "reranking_provider_name" in retrieval_model["reranking_model"]
|
||||
and retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
):
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating {table_name} {record_id} "
|
||||
f"(reranking_provider_name: "
|
||||
f"{retrieval_model['reranking_model']['reranking_provider_name']})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
|
||||
)
|
||||
retrieval_model_changed = True
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||
params = {"record_id": record_id}
|
||||
update_retrieval_model_sql = ""
|
||||
if retrieval_model and retrieval_model_changed:
|
||||
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
||||
params["retrieval_model"] = json.dumps(retrieval_model)
|
||||
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||
{update_retrieval_model_sql}
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), params)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||
)
|
||||
continue
|
||||
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
while True:
|
||||
sql = f"""select id, {provider_column_name} as provider_name from {table_name}
|
||||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), {"record_id": record_id})
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||
)
|
||||
continue
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
||||
@@ -0,0 +1,121 @@
|
||||
from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
|
||||
|
||||
class DependenciesAnalysisService:
|
||||
@classmethod
|
||||
def analyze_tool_dependency(cls, tool_id: str) -> str:
|
||||
"""
|
||||
Analyze the dependency of a tool.
|
||||
|
||||
Convert the tool id to the plugin_id
|
||||
"""
|
||||
try:
|
||||
return ToolProviderID(tool_id).plugin_id
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def analyze_model_provider_dependency(cls, model_provider_id: str) -> str:
|
||||
"""
|
||||
Analyze the dependency of a model provider.
|
||||
|
||||
Convert the model provider id to the plugin_id
|
||||
"""
|
||||
try:
|
||||
return ModelProviderID(model_provider_id).plugin_id
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dependencies: list[PluginDependency]) -> list[PluginDependency]:
|
||||
"""
|
||||
Check dependencies, returns the leaked dependencies in current workspace
|
||||
"""
|
||||
required_plugin_unique_identifiers = []
|
||||
for dependency in dependencies:
|
||||
required_plugin_unique_identifiers.append(dependency.value.plugin_unique_identifier)
|
||||
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
# get leaked dependencies
|
||||
missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers)
|
||||
missing_plugin_unique_identifiers = {plugin.plugin_unique_identifier: plugin for plugin in missing_plugins}
|
||||
|
||||
leaked_dependencies = []
|
||||
for dependency in dependencies:
|
||||
unique_identifier = dependency.value.plugin_unique_identifier
|
||||
if unique_identifier in missing_plugin_unique_identifiers:
|
||||
leaked_dependencies.append(
|
||||
PluginDependency(
|
||||
type=dependency.type,
|
||||
value=dependency.value,
|
||||
current_identifier=missing_plugin_unique_identifiers[unique_identifier].current_identifier,
|
||||
)
|
||||
)
|
||||
|
||||
return leaked_dependencies
|
||||
|
||||
@classmethod
|
||||
def generate_dependencies(cls, tenant_id: str, dependencies: list[str]) -> list[PluginDependency]:
|
||||
"""
|
||||
Generate dependencies through the list of plugin ids
|
||||
"""
|
||||
dependencies = list(set(dependencies))
|
||||
manager = PluginInstallationManager()
|
||||
plugins = manager.fetch_plugin_installation_by_ids(tenant_id, dependencies)
|
||||
result = []
|
||||
for plugin in plugins:
|
||||
if plugin.source == PluginInstallationSource.Github:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Github,
|
||||
value=PluginDependency.Github(
|
||||
repo=plugin.meta["repo"],
|
||||
version=plugin.meta["version"],
|
||||
package=plugin.meta["package"],
|
||||
github_plugin_unique_identifier=plugin.plugin_unique_identifier,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Marketplace:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Marketplace,
|
||||
value=PluginDependency.Marketplace(
|
||||
marketplace_plugin_unique_identifier=plugin.plugin_unique_identifier
|
||||
),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Package:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Package,
|
||||
value=PluginDependency.Package(plugin_unique_identifier=plugin.plugin_unique_identifier),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Remote:
|
||||
raise ValueError(
|
||||
f"You used a remote plugin: {plugin.plugin_unique_identifier} in the app, please remove it first"
|
||||
" if you want to export the DSL."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin source: {plugin.source}")
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def generate_latest_dependencies(cls, dependencies: list[str]) -> list[PluginDependency]:
|
||||
"""
|
||||
Generate the latest version of dependencies
|
||||
"""
|
||||
dependencies = list(set(dependencies))
|
||||
deps = marketplace.batch_fetch_plugin_manifests(dependencies)
|
||||
return [
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Marketplace,
|
||||
value=PluginDependency.Marketplace(marketplace_plugin_unique_identifier=dep.latest_package_identifier),
|
||||
)
|
||||
for dep in deps
|
||||
]
|
||||
@@ -0,0 +1,66 @@
|
||||
from core.plugin.manager.endpoint import PluginEndpointManager
|
||||
|
||||
|
||||
class EndpointService:
|
||||
@classmethod
|
||||
def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
|
||||
return PluginEndpointManager().create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_endpoints(cls, tenant_id: str, user_id: str, page: int, page_size: int):
|
||||
return PluginEndpointManager().list_endpoints(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_endpoints_for_single_plugin(cls, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
|
||||
return PluginEndpointManager().list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
|
||||
return PluginEndpointManager().update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointManager().delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointManager().enable_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def disable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointManager().disable_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
@@ -0,0 +1,501 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import click
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models.account import Tenant
|
||||
from models.engine import db
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.tools import BuiltinToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int) -> None:
|
||||
"""
|
||||
Migrate plugin.
|
||||
"""
|
||||
from threading import Lock
|
||||
|
||||
click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
|
||||
ended_at = datetime.datetime.now()
|
||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||
current_time = started_at
|
||||
|
||||
with Session(db.engine) as session:
|
||||
total_tenant_count = session.query(Tenant.id).count()
|
||||
|
||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||
|
||||
handled_tenant_count = 0
|
||||
file_lock = Lock()
|
||||
counter_lock = Lock()
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
with flask_app.app_context():
|
||||
nonlocal handled_tenant_count
|
||||
try:
|
||||
plugins = cls.extract_installed_plugin_ids(tenant_id)
|
||||
# Use lock when writing to file
|
||||
with file_lock:
|
||||
with open(filepath, "a") as f:
|
||||
f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
|
||||
|
||||
# Use lock when updating counter
|
||||
with counter_lock:
|
||||
nonlocal handled_tenant_count
|
||||
handled_tenant_count += 1
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] "
|
||||
f"Processed {handled_tenant_count} tenants "
|
||||
f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
|
||||
f"{handled_tenant_count}/{total_tenant_count}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tenant {tenant_id}")
|
||||
|
||||
futures = []
|
||||
|
||||
while current_time < ended_at:
|
||||
click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
|
||||
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
|
||||
interval = datetime.timedelta(days=1)
|
||||
# Process tenants in this batch
|
||||
with Session(db.engine) as session:
|
||||
# Calculate tenant count in next batch with current interval
|
||||
# Try different intervals until we find one with a reasonable tenant count
|
||||
test_intervals = [
|
||||
datetime.timedelta(days=1),
|
||||
datetime.timedelta(hours=12),
|
||||
datetime.timedelta(hours=6),
|
||||
datetime.timedelta(hours=3),
|
||||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
.filter(Tenant.created_at.between(current_time, current_time + test_interval))
|
||||
.count()
|
||||
)
|
||||
if tenant_count <= 100:
|
||||
interval = test_interval
|
||||
break
|
||||
else:
|
||||
# If all intervals have too many tenants, use minimum interval
|
||||
interval = datetime.timedelta(hours=1)
|
||||
|
||||
# Adjust interval to target ~100 tenants per batch
|
||||
if tenant_count > 0:
|
||||
# Scale interval based on ratio to target count
|
||||
interval = min(
|
||||
datetime.timedelta(days=1), # Max 1 day
|
||||
max(
|
||||
datetime.timedelta(hours=1), # Min 1 hour
|
||||
interval * (100 / tenant_count), # Scale to target 100
|
||||
),
|
||||
)
|
||||
|
||||
batch_end = min(current_time + interval, ended_at)
|
||||
|
||||
rs = (
|
||||
session.query(Tenant.id)
|
||||
.filter(Tenant.created_at.between(current_time, batch_end))
|
||||
.order_by(Tenant.created_at)
|
||||
)
|
||||
|
||||
tenants = []
|
||||
for row in rs:
|
||||
tenant_id = str(row.id)
|
||||
try:
|
||||
tenants.append(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
futures.append(
|
||||
thread_pool.submit(
|
||||
process_tenant,
|
||||
current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
current_time = batch_end
|
||||
|
||||
# wait for all threads to finish
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
@classmethod
|
||||
def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract installed plugin ids.
|
||||
"""
|
||||
tools = cls.extract_tool_tables(tenant_id)
|
||||
models = cls.extract_model_tables(tenant_id)
|
||||
workflows = cls.extract_workflow_tables(tenant_id)
|
||||
apps = cls.extract_app_tables(tenant_id)
|
||||
|
||||
return list({*tools, *models, *workflows, *apps})
|
||||
|
||||
@classmethod
|
||||
def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract model tables.
|
||||
|
||||
"""
|
||||
models: list[str] = []
|
||||
table_pairs = [
|
||||
("providers", "provider_name"),
|
||||
("provider_models", "provider_name"),
|
||||
("provider_orders", "provider_name"),
|
||||
("tenant_default_models", "provider_name"),
|
||||
("tenant_preferred_model_providers", "provider_name"),
|
||||
("provider_model_settings", "provider_name"),
|
||||
("load_balancing_model_configs", "provider_name"),
|
||||
]
|
||||
|
||||
for table, column in table_pairs:
|
||||
models.extend(cls.extract_model_table(tenant_id, table, column))
|
||||
|
||||
# duplicate models
|
||||
models = list(set(models))
|
||||
|
||||
return models
|
||||
|
||||
@classmethod
|
||||
def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract model table.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.execute(
|
||||
db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
|
||||
)
|
||||
result = []
|
||||
for row in rs:
|
||||
provider_name = str(row[0])
|
||||
result.append(ModelProviderID(provider_name).plugin_id)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract tool tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
result.append(ToolProviderID(row.provider).plugin_id)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract workflow tables, only ToolNode is required.
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
graph = row.graph_dict
|
||||
# get nodes
|
||||
nodes = graph.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
data = node.get("data", {})
|
||||
if data.get("type") == "tool":
|
||||
provider_name = data.get("provider_name")
|
||||
provider_type = data.get("provider_type")
|
||||
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
|
||||
result.append(ToolProviderID(provider_name).plugin_id)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract app tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
apps = session.query(App).filter(App.tenant_id == tenant_id).all()
|
||||
if not apps:
|
||||
return []
|
||||
|
||||
agent_app_model_config_ids = [
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
|
||||
]
|
||||
|
||||
rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
agent_config = row.agent_mode_dict
|
||||
if "tools" in agent_config and isinstance(agent_config["tools"], list):
|
||||
for tool in agent_config["tools"]:
|
||||
if isinstance(tool, dict):
|
||||
try:
|
||||
tool_entity = AgentToolEntity(**tool)
|
||||
if (
|
||||
tool_entity.provider_type == ToolProviderType.BUILT_IN.value
|
||||
and tool_entity.provider_id not in excluded_providers
|
||||
):
|
||||
result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tool {tool}")
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
|
||||
"""
|
||||
Fetch plugin unique identifier using plugin id.
|
||||
"""
|
||||
plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
|
||||
if not plugin_manifest:
|
||||
return None
|
||||
|
||||
return plugin_manifest[0].latest_package_identifier
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
|
||||
plugins: dict[str, str] = {}
|
||||
plugin_ids = []
|
||||
plugin_not_exist = []
|
||||
logger.info(f"Extracting unique plugins from {extracted_plugins}")
|
||||
with open(extracted_plugins) as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
new_plugin_ids = data.get("plugins", [])
|
||||
for plugin_id in new_plugin_ids:
|
||||
if plugin_id not in plugin_ids:
|
||||
plugin_ids.append(plugin_id)
|
||||
|
||||
def fetch_plugin(plugin_id):
|
||||
try:
|
||||
unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
|
||||
if unique_identifier:
|
||||
plugins[plugin_id] = unique_identifier
|
||||
else:
|
||||
plugin_not_exist.append(plugin_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to fetch plugin unique identifier for {plugin_id}")
|
||||
plugin_not_exist.append(plugin_id)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))
|
||||
|
||||
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
|
||||
|
||||
@classmethod
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
plugins = cls.extract_unique_plugins(extracted_plugins)
|
||||
not_installed = []
|
||||
plugin_install_failed = []
|
||||
|
||||
# use a fake tenant id to install all the plugins
|
||||
fake_tenant_id = uuid4().hex
|
||||
logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}")
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
|
||||
if response.get("failed"):
|
||||
plugin_install_failed.extend(response.get("failed", []))
|
||||
|
||||
def install(tenant_id: str, plugin_ids: list[str]) -> None:
|
||||
logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
|
||||
# fetch plugin already installed
|
||||
installed_plugins = manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
# at most 64 plugins one batch
|
||||
for i in range(0, len(plugin_ids), 64):
|
||||
batch_plugin_ids = plugin_ids[i : i + 64]
|
||||
batch_plugin_identifiers = [
|
||||
plugins["plugins"][plugin_id]
|
||||
for plugin_id in batch_plugin_ids
|
||||
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
|
||||
]
|
||||
manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
batch_plugin_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
|
||||
with open(extracted_plugins) as f:
|
||||
"""
|
||||
Read line by line, and install plugins for each tenant.
|
||||
"""
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
tenant_id = data.get("tenant_id")
|
||||
plugin_ids = data.get("plugins", [])
|
||||
current_not_installed = {
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_not_exist": [],
|
||||
}
|
||||
# get plugin unique identifier
|
||||
for plugin_id in plugin_ids:
|
||||
unique_identifier = plugins.get(plugin_id)
|
||||
if unique_identifier:
|
||||
current_not_installed["plugin_not_exist"].append(plugin_id)
|
||||
|
||||
if current_not_installed["plugin_not_exist"]:
|
||||
not_installed.append(current_not_installed)
|
||||
|
||||
thread_pool.submit(install, tenant_id, plugin_ids)
|
||||
|
||||
thread_pool.shutdown(wait=True)
|
||||
|
||||
logger.info("Uninstall plugins")
|
||||
|
||||
# get installation
|
||||
try:
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
while installation:
|
||||
for plugin in installation:
|
||||
manager.uninstall(fake_tenant_id, plugin.installation_id)
|
||||
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")
|
||||
|
||||
Path(output_file).write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"not_installed": not_installed,
|
||||
"plugin_install_failed": plugin_install_failed,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def handle_plugin_instance_install(
|
||||
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Install plugins for a tenant.
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
# download all the plugins and upload
|
||||
thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
futures = []
|
||||
|
||||
for plugin_id, plugin_identifier in plugin_identifiers_map.items():
|
||||
|
||||
def download_and_upload(tenant_id, plugin_id, plugin_identifier):
|
||||
plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
|
||||
if not plugin_package:
|
||||
raise Exception(f"Failed to download plugin {plugin_identifier}")
|
||||
|
||||
# upload
|
||||
manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
|
||||
|
||||
futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))
|
||||
|
||||
# Wait for all downloads to complete
|
||||
for future in futures:
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
thread_pool.shutdown(wait=True)
|
||||
success = []
|
||||
failed = []
|
||||
|
||||
reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
|
||||
|
||||
# at most 8 plugins one batch
|
||||
for i in range(0, len(plugin_identifiers_map), 8):
|
||||
batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
|
||||
batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
|
||||
|
||||
try:
|
||||
response = manager.install_from_identifiers(
|
||||
tenant_id=tenant_id,
|
||||
identifiers=batch_plugin_identifiers,
|
||||
source=PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
# add to failed
|
||||
failed.extend(batch_plugin_identifiers)
|
||||
continue
|
||||
|
||||
if response.all_installed:
|
||||
success.extend(batch_plugin_identifiers)
|
||||
continue
|
||||
|
||||
task_id = response.task_id
|
||||
done = False
|
||||
while not done:
|
||||
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
|
||||
for plugin in status.plugins:
|
||||
if plugin.status == PluginInstallTaskStatus.Success:
|
||||
success.append(reverse_map[plugin.plugin_unique_identifier])
|
||||
else:
|
||||
failed.append(reverse_map[plugin.plugin_unique_identifier])
|
||||
logger.error(
|
||||
f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}"
|
||||
)
|
||||
|
||||
done = True
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
return {"success": success, "failed": failed}
|
||||
@@ -0,0 +1,34 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
|
||||
@staticmethod
|
||||
def change_permission(
|
||||
tenant_id: str,
|
||||
install_permission: TenantPluginPermission.InstallPermission,
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with Session(db.engine) as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
)
|
||||
if not permission:
|
||||
permission = TenantPluginPermission(
|
||||
tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission
|
||||
)
|
||||
|
||||
session.add(permission)
|
||||
else:
|
||||
permission.install_permission = install_permission
|
||||
permission.debug_permission = debug_permission
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
@@ -0,0 +1,364 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_type
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import marketplace
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.helper.marketplace import download_plugin_pkg
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
PluginInstallationSource,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
|
||||
from core.plugin.manager.asset import PluginAssetManager
|
||||
from core.plugin.manager.debugging import PluginDebuggingManager
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginService:
|
||||
class LatestPluginCache(BaseModel):
|
||||
plugin_id: str
|
||||
version: str
|
||||
unique_identifier: str
|
||||
|
||||
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
|
||||
REDIS_TTL = 60 * 5 # 5 minutes
|
||||
|
||||
@staticmethod
|
||||
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
|
||||
"""
|
||||
Fetch the latest plugin version
|
||||
"""
|
||||
result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
|
||||
|
||||
try:
|
||||
cache_not_exists = []
|
||||
|
||||
# Try to get from Redis first
|
||||
for plugin_id in plugin_ids:
|
||||
cached_data = redis_client.get(f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}")
|
||||
if cached_data:
|
||||
result[plugin_id] = PluginService.LatestPluginCache.model_validate_json(cached_data)
|
||||
else:
|
||||
cache_not_exists.append(plugin_id)
|
||||
|
||||
if cache_not_exists:
|
||||
manifests = {
|
||||
manifest.plugin_id: manifest
|
||||
for manifest in marketplace.batch_fetch_plugin_manifests(cache_not_exists)
|
||||
}
|
||||
|
||||
for plugin_id, manifest in manifests.items():
|
||||
latest_plugin = PluginService.LatestPluginCache(
|
||||
plugin_id=plugin_id,
|
||||
version=manifest.latest_version,
|
||||
unique_identifier=manifest.latest_package_identifier,
|
||||
)
|
||||
|
||||
# Store in Redis
|
||||
redis_client.setex(
|
||||
f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}",
|
||||
PluginService.REDIS_TTL,
|
||||
latest_plugin.model_dump_json(),
|
||||
)
|
||||
|
||||
result[plugin_id] = latest_plugin
|
||||
|
||||
# pop plugin_id from cache_not_exists
|
||||
cache_not_exists.remove(plugin_id)
|
||||
|
||||
for plugin_id in cache_not_exists:
|
||||
result[plugin_id] = None
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception("failed to fetch latest plugin version")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_debugging_key(tenant_id: str) -> str:
|
||||
"""
|
||||
get the debugging key of the tenant
|
||||
"""
|
||||
manager = PluginDebuggingManager()
|
||||
return manager.get_debugging_key(tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def list(tenant_id: str) -> list[PluginEntity]:
|
||||
"""
|
||||
list all plugins of the tenant
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
plugin_ids = [plugin.plugin_id for plugin in plugins if plugin.source == PluginInstallationSource.Marketplace]
|
||||
try:
|
||||
manifests = PluginService.fetch_latest_plugin_version(plugin_ids)
|
||||
except Exception:
|
||||
manifests = {}
|
||||
logger.exception("failed to fetch plugin manifests")
|
||||
|
||||
for plugin in plugins:
|
||||
if plugin.source == PluginInstallationSource.Marketplace:
|
||||
if plugin.plugin_id in manifests:
|
||||
latest_plugin_cache = manifests[plugin.plugin_id]
|
||||
if latest_plugin_cache:
|
||||
# set latest_version
|
||||
plugin.latest_version = latest_plugin_cache.version
|
||||
plugin.latest_unique_identifier = latest_plugin_cache.unique_identifier
|
||||
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
|
||||
"""
|
||||
List plugin installations from ids
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
|
||||
|
||||
@staticmethod
|
||||
def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
get the asset file of the plugin
|
||||
"""
|
||||
manager = PluginAssetManager()
|
||||
# guess mime type
|
||||
mime_type, _ = guess_type(asset_file)
|
||||
return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"
|
||||
|
||||
@staticmethod
|
||||
def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
|
||||
"""
|
||||
check if the plugin unique identifier is already installed by other tenant
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_by_identifier(tenant_id, plugin_unique_identifier)
|
||||
|
||||
@staticmethod
|
||||
def fetch_plugin_manifest(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||
"""
|
||||
Fetch plugin manifest
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||
|
||||
@staticmethod
|
||||
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
|
||||
"""
|
||||
Fetch plugin installation tasks
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
|
||||
|
||||
@staticmethod
|
||||
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_install_task(tenant_id: str, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.delete_plugin_installation_task(tenant_id, task_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_all_install_task_items(
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete all plugin installation task items
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.delete_all_plugin_installation_task_items(tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_install_task_item(tenant_id: str, task_id: str, identifier: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task item
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.delete_plugin_installation_task_item(tenant_id, task_id, identifier)
|
||||
|
||||
@staticmethod
|
||||
def upgrade_plugin_with_marketplace(
|
||||
tenant_id: str, original_plugin_unique_identifier: str, new_plugin_unique_identifier: str
|
||||
):
|
||||
"""
|
||||
Upgrade plugin with marketplace
|
||||
"""
|
||||
if original_plugin_unique_identifier == new_plugin_unique_identifier:
|
||||
raise ValueError("you should not upgrade plugin with the same plugin")
|
||||
|
||||
# check if plugin pkg is already downloaded
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
try:
|
||||
manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
|
||||
# already downloaded, skip, and record install event
|
||||
marketplace.record_install_plugin_event(new_plugin_unique_identifier)
|
||||
except Exception:
|
||||
# plugin not installed, download and upload pkg
|
||||
pkg = download_plugin_pkg(new_plugin_unique_identifier)
|
||||
manager.upload_pkg(tenant_id, pkg, verify_signature=False)
|
||||
|
||||
return manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
PluginInstallationSource.Marketplace,
|
||||
{
|
||||
"plugin_unique_identifier": new_plugin_unique_identifier,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upgrade_plugin_with_github(
|
||||
tenant_id: str,
|
||||
original_plugin_unique_identifier: str,
|
||||
new_plugin_unique_identifier: str,
|
||||
repo: str,
|
||||
version: str,
|
||||
package: str,
|
||||
):
|
||||
"""
|
||||
Upgrade plugin with github
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
PluginInstallationSource.Github,
|
||||
{
|
||||
"repo": repo,
|
||||
"version": version,
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginUploadResponse:
|
||||
"""
|
||||
Upload plugin package files
|
||||
|
||||
returns: plugin_unique_identifier
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upload_pkg(tenant_id, pkg, verify_signature)
|
||||
|
||||
@staticmethod
|
||||
def upload_pkg_from_github(
|
||||
tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False
|
||||
) -> PluginUploadResponse:
|
||||
"""
|
||||
Install plugin from github release package files,
|
||||
returns plugin_unique_identifier
|
||||
"""
|
||||
pkg = download_with_size_limit(
|
||||
f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
|
||||
)
|
||||
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upload_pkg(
|
||||
tenant_id,
|
||||
pkg,
|
||||
verify_signature,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upload_bundle(
|
||||
tenant_id: str, bundle: bytes, verify_signature: bool = False
|
||||
) -> Sequence[PluginBundleDependency]:
|
||||
"""
|
||||
Upload a plugin bundle and return the dependencies.
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upload_bundle(tenant_id, bundle, verify_signature)
|
||||
|
||||
@staticmethod
|
||||
def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
|
||||
manager = PluginInstallationManager()
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
plugin_unique_identifiers,
|
||||
PluginInstallationSource.Package,
|
||||
[{}],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
|
||||
"""
|
||||
Install plugin from github release package files,
|
||||
returns plugin_unique_identifier
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
[plugin_unique_identifier],
|
||||
PluginInstallationSource.Github,
|
||||
[
|
||||
{
|
||||
"repo": repo,
|
||||
"version": version,
|
||||
"package": package,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def install_from_marketplace_pkg(
|
||||
tenant_id: str, plugin_unique_identifiers: Sequence[str], verify_signature: bool = False
|
||||
):
|
||||
"""
|
||||
Install plugin from marketplace package files,
|
||||
returns installation task id
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
# check if already downloaded
|
||||
for plugin_unique_identifier in plugin_unique_identifiers:
|
||||
try:
|
||||
manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||
# already downloaded, skip
|
||||
except Exception:
|
||||
# plugin not installed, download and upload pkg
|
||||
pkg = download_plugin_pkg(plugin_unique_identifier)
|
||||
manager.upload_pkg(tenant_id, pkg, verify_signature)
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
plugin_unique_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
[
|
||||
{
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
}
|
||||
for plugin_unique_identifier in plugin_unique_identifiers
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
manager = PluginInstallationManager()
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
|
||||
@staticmethod
|
||||
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
"""
|
||||
Check if the tools exist
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.check_tools_existence(tenant_id, provider_ids)
|
||||
@@ -20,7 +20,7 @@ class TagService:
|
||||
)
|
||||
if keyword:
|
||||
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.group_by(Tag.id)
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name)
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
ToolCredentialsOption,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
@@ -41,28 +41,28 @@ class ApiToolManageService:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
credentials_schema = [
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
required=True,
|
||||
default="none",
|
||||
options=[
|
||||
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
],
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||
default="api_key",
|
||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||
default="",
|
||||
@@ -84,17 +84,14 @@ class ApiToolManageService:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(
|
||||
schema: str, extra_info: Optional[dict] = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
return tool_bundles
|
||||
return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@@ -167,8 +164,14 @@ class ApiToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
||||
|
||||
db.session.add(db_provider)
|
||||
@@ -198,18 +201,18 @@ class ApiToolManageService:
|
||||
|
||||
# try to parse schema, avoid SSRF attack
|
||||
ApiToolManageService.parser_api_schema(schema)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("parse api schema error")
|
||||
raise ValueError("invalid schema, please check the url you provided")
|
||||
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider = (
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@@ -225,8 +228,9 @@ class ApiToolManageService:
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool_bundle,
|
||||
tenant_id=tenant_id,
|
||||
labels=labels,
|
||||
)
|
||||
for tool_bundle in provider.tools
|
||||
@@ -293,16 +297,21 @@ class ApiToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
|
||||
credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
provider.credentials_str = json.dumps(credentials)
|
||||
|
||||
db.session.add(provider)
|
||||
@@ -363,7 +372,7 @@ class ApiToolManageService:
|
||||
|
||||
try:
|
||||
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValueError("invalid schema")
|
||||
|
||||
# get tool bundle
|
||||
@@ -407,8 +416,13 @@ class ApiToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
for name, value in credentials.items():
|
||||
@@ -419,20 +433,20 @@ class ApiToolManageService:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
runtime_tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
result = runtime_tool.validate_credentials(credentials, parameters)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
@@ -441,7 +455,7 @@ class ApiToolManageService:
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider in db_providers:
|
||||
# convert provider controller to user provider
|
||||
@@ -453,13 +467,13 @@ class ApiToolManageService:
|
||||
user_provider.labels = labels
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
tools = provider_controller.get_tools(tenant_id=tenant_id)
|
||||
|
||||
for tool in tools or []:
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,19 +2,19 @@ import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@@ -24,36 +24,38 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BuiltinToolManageService:
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
|
||||
:param user_id: the id of the user
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider: the name of the provider
|
||||
|
||||
:return: the list of tools
|
||||
"""
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ToolConfigurationManager(
|
||||
tenant_id=tenant_id, provider_controller=provider_controller
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
|
||||
credentials = tool_provider_configurations.decrypt(credentials)
|
||||
|
||||
result: list[UserTool] = []
|
||||
result: list[ToolApiEntity] = []
|
||||
for tool in tools or []:
|
||||
result.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
@@ -64,14 +66,47 @@ class BuiltinToolManageService:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name):
|
||||
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get builtin tool provider info
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt(credentials)
|
||||
|
||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=builtin_provider,
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
entity.original_credentials = {}
|
||||
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:param provider_name: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder(provider.get_credentials_schema())
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
@@ -81,31 +116,38 @@ class BuiltinToolManageService:
|
||||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
provider = session.scalar(stmt)
|
||||
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
||||
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# get original credentials if exists
|
||||
if provider is not None:
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
# validate credentials
|
||||
provider_controller.validate_credentials(credentials)
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
# encrypt credentials
|
||||
credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
except (
|
||||
PluginDaemonClientSideError,
|
||||
ToolProviderNotFoundError,
|
||||
ToolNotFoundError,
|
||||
ToolProviderCredentialValidationError,
|
||||
) as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
if provider is None:
|
||||
@@ -117,14 +159,14 @@ class BuiltinToolManageService:
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
)
|
||||
|
||||
session.add(provider)
|
||||
|
||||
db.session.add(provider)
|
||||
else:
|
||||
provider.encrypted_credentials = json.dumps(credentials)
|
||||
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@@ -132,21 +174,19 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
||||
|
||||
if provider is None:
|
||||
if provider_obj is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
credentials = tool_configuration.decrypt(provider_obj.credentials)
|
||||
credentials = tool_configuration.mask_tool_credentials(credentials)
|
||||
return credentials
|
||||
|
||||
@@ -155,24 +195,22 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
||||
|
||||
if provider is None:
|
||||
if provider_obj is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.delete(provider_obj)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return {"result": "success"}
|
||||
@@ -182,67 +220,111 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
|
||||
icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
|
||||
icon_bytes = Path(icon_path).read_bytes()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
with db.session.no_autoflush:
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
# find provider
|
||||
find_provider = lambda provider: next(
|
||||
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
|
||||
)
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
# find provider
|
||||
def find_provider(provider):
|
||||
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
if provider_controller.identity is None:
|
||||
continue
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.identity.name),
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = GenericProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider_obj = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == full_provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider_obj = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider_obj is None:
|
||||
return None
|
||||
|
||||
provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
|
||||
return provider_obj
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderTypeLiteral
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
@@ -9,17 +9,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
providers = ToolManager.user_list_providers(user_id, tenant_id, typ)
|
||||
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
|
||||
|
||||
# add icon
|
||||
for provider in providers:
|
||||
ToolTransformService.repack_provider(provider)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
|
||||
@@ -2,47 +2,61 @@ import json
|
||||
import logging
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
@staticmethod
|
||||
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
@classmethod
|
||||
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
|
||||
url_prefix = (
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
|
||||
)
|
||||
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
|
||||
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
|
||||
url_prefix = (
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
|
||||
)
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
return url_prefix + "builtin/" + provider_name + "/icon"
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
||||
try:
|
||||
return cast(dict, json.loads(icon))
|
||||
except:
|
||||
if isinstance(icon, str):
|
||||
return cast(dict, json.loads(icon))
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
@@ -52,55 +66,52 @@ class ToolTransformService:
|
||||
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
provider.icon = cast(
|
||||
str,
|
||||
ToolTransformService.get_tool_provider_icon_url(
|
||||
elif isinstance(provider, ToolProviderApiEntity):
|
||||
if provider.plugin_id:
|
||||
if isinstance(provider.icon, str):
|
||||
provider.icon = ToolTransformService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.icon
|
||||
)
|
||||
else:
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
decrypt_credentials: bool = True,
|
||||
) -> UserToolProvider:
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
if provider_controller.identity is None:
|
||||
raise ValueError("provider identity is None")
|
||||
|
||||
result = UserToolProvider(
|
||||
id=provider_controller.identity.name,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
pt_BR=provider_controller.identity.description.pt_BR,
|
||||
ja_JP=provider_controller.identity.description.ja_JP,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
pt_BR=provider_controller.identity.label.pt_BR,
|
||||
ja_JP=provider_controller.identity.label.ja_JP,
|
||||
),
|
||||
result = ToolProviderApiEntity(
|
||||
id=provider_controller.entity.identity.name,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
masked_credentials={},
|
||||
is_team_authorization=False,
|
||||
plugin_id=None,
|
||||
tools=[],
|
||||
labels=provider_controller.tool_labels,
|
||||
)
|
||||
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
result.plugin_id = provider_controller.plugin_id
|
||||
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
||||
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
|
||||
|
||||
for name, value in schema.items():
|
||||
assert result.masked_credentials is not None, "masked credentials is None"
|
||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type))
|
||||
if result.masked_credentials:
|
||||
result.masked_credentials[name] = ""
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
@@ -113,12 +124,15 @@ class ToolTransformService:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
result.original_credentials = decrypted_credentials
|
||||
@@ -151,41 +165,35 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController, labels: Optional[list[str]] = None
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
if provider_controller.identity is None:
|
||||
raise ValueError("provider identity is None")
|
||||
|
||||
return UserToolProvider(
|
||||
return ToolProviderApiEntity(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
),
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
plugin_id=None,
|
||||
plugin_unique_identifier=None,
|
||||
tools=[],
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def api_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: Optional[list[str]] = None,
|
||||
) -> UserToolProvider:
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
@@ -193,12 +201,16 @@ class ToolTransformService:
|
||||
if db_provider.user is None:
|
||||
raise ValueError(f"user is None for api provider {db_provider.id}")
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
except Exception as e:
|
||||
user = db_provider.user
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
username = user.name
|
||||
except Exception:
|
||||
logger.exception(f"failed to get user name for api provider {db_provider.id}")
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
result = ToolProviderApiEntity(
|
||||
id=db_provider.id,
|
||||
author=username,
|
||||
name=db_provider.name,
|
||||
@@ -212,6 +224,8 @@ class ToolTransformService:
|
||||
zh_Hans=db_provider.name,
|
||||
),
|
||||
type=ToolProviderType.API,
|
||||
plugin_id=None,
|
||||
plugin_unique_identifier=None,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
@@ -220,39 +234,42 @@ class ToolTransformService:
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
def convert_tool_entity_to_api_entity(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: Optional[dict] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
labels: Optional[list[str]] = None,
|
||||
) -> UserTool:
|
||||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolApiEntity:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials or {},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
parameters = tool.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters()
|
||||
# override parameters
|
||||
@@ -268,23 +285,21 @@ class ToolTransformService:
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
if tool.identity is None:
|
||||
raise ValueError("tool identity is None")
|
||||
|
||||
return UserTool(
|
||||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human if tool.description else "", # type: ignore
|
||||
return ToolApiEntity(
|
||||
author=tool.entity.identity.author,
|
||||
name=tool.entity.identity.name,
|
||||
label=tool.entity.identity.label,
|
||||
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
|
||||
output_schema=tool.entity.output_schema,
|
||||
parameters=current_parameters,
|
||||
labels=labels,
|
||||
labels=labels or [],
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id or "",
|
||||
label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""),
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
parameters=tool.parameters,
|
||||
labels=labels,
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
@@ -36,7 +36,7 @@ class WorkflowToolManageService:
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: Optional[list[str]] = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
@@ -54,11 +54,12 @@ class WorkflowToolManageService:
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
|
||||
workflow = app.workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
@@ -101,7 +102,7 @@ class WorkflowToolManageService:
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: Optional[list[str]] = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update a workflow tool.
|
||||
@@ -133,7 +134,7 @@ class WorkflowToolManageService:
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: Optional[WorkflowToolProvider] = (
|
||||
workflow_tool_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
@@ -142,14 +143,14 @@ class WorkflowToolManageService:
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: Optional[App] = (
|
||||
app: App | None = (
|
||||
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
workflow: Optional[Workflow] = app.workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
@@ -178,7 +179,7 @@ class WorkflowToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
List workflow tools.
|
||||
:param user_id: the user id
|
||||
@@ -187,11 +188,11 @@ class WorkflowToolManageService:
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
tools = []
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
try:
|
||||
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
|
||||
except:
|
||||
except Exception:
|
||||
# skip deleted tools
|
||||
pass
|
||||
|
||||
@@ -203,12 +204,13 @@ class WorkflowToolManageService:
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.repack_provider(user_tool_provider)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
continue
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, []))
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(tenant_id)[0],
|
||||
labels=labels.get(tool.provider_id, []),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
|
||||
@@ -239,42 +241,12 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
workflow_app: Optional[App] = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
@@ -285,26 +257,38 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
:return: the tool
|
||||
"""
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
raise ValueError("Tool not found")
|
||||
|
||||
workflow_app: Optional[App] = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
workflow_app: App | None = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
workflow = workflow_app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {db_tool.id} not found")
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
@@ -314,15 +298,17 @@ class WorkflowToolManageService:
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
"tool": ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
|
||||
"synced": workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
@@ -330,7 +316,7 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
@@ -340,8 +326,14 @@ class WorkflowToolManageService:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))]
|
||||
return [
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1,30 +1,46 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import App, EndUser, WorkflowAppLog, WorkflowRun
|
||||
from models.enums import CreatedByRole
|
||||
from models.workflow import WorkflowRunStatus
|
||||
|
||||
|
||||
class WorkflowAppService:
|
||||
def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination:
|
||||
def get_paginate_workflow_app_logs(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
keyword: str | None = None,
|
||||
status: WorkflowRunStatus | None = None,
|
||||
created_at_before: datetime | None = None,
|
||||
created_at_after: datetime | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
) -> dict:
|
||||
"""
|
||||
Get paginate workflow app logs
|
||||
:param app: app model
|
||||
:param args: request args
|
||||
:return:
|
||||
Get paginate workflow app logs using SQLAlchemy 2.0 style
|
||||
:param session: SQLAlchemy session
|
||||
:param app_model: app model
|
||||
:param keyword: search keyword
|
||||
:param status: filter by status
|
||||
:param created_at_before: filter logs created before this timestamp
|
||||
:param created_at_after: filter logs created after this timestamp
|
||||
:param page: page number
|
||||
:param limit: items per page
|
||||
:return: Pagination object
|
||||
"""
|
||||
query = db.select(WorkflowAppLog).where(
|
||||
# Build base statement using SQLAlchemy 2.0 style
|
||||
stmt = select(WorkflowAppLog).where(
|
||||
WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
|
||||
)
|
||||
|
||||
status = WorkflowRunStatus.value_of(args.get("status", "")) if args.get("status") else None
|
||||
keyword = args["keyword"]
|
||||
if keyword or status:
|
||||
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
|
||||
stmt = stmt.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
|
||||
|
||||
if keyword:
|
||||
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
|
||||
@@ -40,20 +56,40 @@ class WorkflowAppService:
|
||||
if keyword_uuid:
|
||||
keyword_conditions.append(WorkflowRun.id == keyword_uuid)
|
||||
|
||||
query = query.outerjoin(
|
||||
stmt = stmt.outerjoin(
|
||||
EndUser,
|
||||
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER),
|
||||
).filter(or_(*keyword_conditions))
|
||||
).where(or_(*keyword_conditions))
|
||||
|
||||
if status:
|
||||
# join with workflow_run and filter by status
|
||||
query = query.filter(WorkflowRun.status == status.value)
|
||||
stmt = stmt.where(WorkflowRun.status == status)
|
||||
|
||||
query = query.order_by(WorkflowAppLog.created_at.desc())
|
||||
# Add time-based filtering
|
||||
if created_at_before:
|
||||
stmt = stmt.where(WorkflowAppLog.created_at <= created_at_before)
|
||||
|
||||
pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
if created_at_after:
|
||||
stmt = stmt.where(WorkflowAppLog.created_at >= created_at_after)
|
||||
|
||||
return pagination
|
||||
stmt = stmt.order_by(WorkflowAppLog.created_at.desc())
|
||||
|
||||
# Get total count using the same filters
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = session.scalar(count_stmt) or 0
|
||||
|
||||
# Apply pagination limits
|
||||
offset_stmt = stmt.offset((page - 1) * limit).limit(limit)
|
||||
|
||||
# Execute query and get items
|
||||
items = list(session.scalars(offset_stmt).all())
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"has_more": total > page * limit,
|
||||
"data": items,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _safe_parse_uuid(value: str):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import contexts
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
@@ -119,6 +121,9 @@ class WorkflowRunService:
|
||||
"""
|
||||
workflow_run = self.get_workflow_run(app_model, run_id)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
if not workflow_run:
|
||||
return []
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
@@ -13,11 +14,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.variables import Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import ErrorStrategy
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.event.types import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||
@@ -35,6 +37,8 @@ from models.workflow import (
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
"""
|
||||
@@ -78,22 +82,38 @@ class WorkflowService:
|
||||
|
||||
return workflow
|
||||
|
||||
def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]:
|
||||
def get_all_published_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
page: int,
|
||||
limit: int,
|
||||
user_id: str | None,
|
||||
named_only: bool = False,
|
||||
) -> tuple[Sequence[Workflow], bool]:
|
||||
"""
|
||||
Get published workflow with pagination
|
||||
"""
|
||||
if not app_model.workflow_id:
|
||||
return [], False
|
||||
|
||||
workflows = (
|
||||
db.session.query(Workflow)
|
||||
.filter(Workflow.app_id == app_model.id)
|
||||
.order_by(desc(Workflow.version))
|
||||
.offset((page - 1) * limit)
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == app_model.id)
|
||||
.order_by(Workflow.version.desc())
|
||||
.limit(limit + 1)
|
||||
.all()
|
||||
.offset((page - 1) * limit)
|
||||
)
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Workflow.created_by == user_id)
|
||||
|
||||
if named_only:
|
||||
stmt = stmt.where(Workflow.marked_name != "")
|
||||
|
||||
workflows = session.scalars(stmt).all()
|
||||
|
||||
has_more = len(workflows) > limit
|
||||
if has_more:
|
||||
workflows = workflows[:-1]
|
||||
@@ -156,23 +176,26 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow:
|
||||
"""
|
||||
Publish workflow from draft
|
||||
|
||||
:param app_model: App instance
|
||||
:param account: Account instance
|
||||
:param draft_workflow: Workflow instance
|
||||
"""
|
||||
if not draft_workflow:
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
def publish_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
account: Account,
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Workflow:
|
||||
draft_workflow_stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
draft_workflow = session.scalar(draft_workflow_stmt)
|
||||
if not draft_workflow:
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow(
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=draft_workflow.type,
|
||||
@@ -182,15 +205,12 @@ class WorkflowService:
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
db.session.add(workflow)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
session.add(workflow)
|
||||
|
||||
# trigger app workflow events
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
@@ -246,14 +266,69 @@ class WorkflowService:
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
node_instance, generator = WorkflowEntry.single_step_run(
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
)
|
||||
node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=app_model.tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
workflow_node_execution.app_id = app_model.id
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.workflow_id = draft_workflow.id
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.run_free_node(
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_inputs=user_inputs,
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_node_run_result(
|
||||
self,
|
||||
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
|
||||
start_at: float,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Handle node run result
|
||||
|
||||
:param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
|
||||
:param start_at: float
|
||||
:param tenant_id: str
|
||||
:param node_id: str
|
||||
"""
|
||||
try:
|
||||
node_instance, generator = getter()
|
||||
|
||||
node_run_result: NodeRunResult | None = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunCompletedEvent):
|
||||
@@ -303,9 +378,7 @@ class WorkflowService:
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = app_model.tenant_id
|
||||
workflow_node_execution.app_id = app_model.id
|
||||
workflow_node_execution.workflow_id = draft_workflow.id
|
||||
workflow_node_execution.tenant_id = tenant_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
|
||||
workflow_node_execution.index = 1
|
||||
workflow_node_execution.node_id = node_id
|
||||
@@ -313,7 +386,6 @@ class WorkflowService:
|
||||
workflow_node_execution.title = node_instance.node_data.title
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
if run_succeeded and node_run_result:
|
||||
@@ -342,9 +414,6 @@ class WorkflowService:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
|
||||
@@ -386,3 +455,65 @@ class WorkflowService:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode: {app_model.mode}")
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
) -> Optional[Workflow]:
|
||||
"""
|
||||
Update workflow attributes
|
||||
|
||||
:param session: SQLAlchemy database session
|
||||
:param workflow_id: Workflow ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param account_id: Account ID (for permission check)
|
||||
:param data: Dictionary containing fields to update
|
||||
:return: Updated workflow or None if not found
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
return None
|
||||
|
||||
allowed_fields = ["marked_name", "marked_comment"]
|
||||
|
||||
for field, value in data.items():
|
||||
if field in allowed_fields:
|
||||
setattr(workflow, field, value)
|
||||
|
||||
workflow.updated_by = account_id
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
return workflow
|
||||
|
||||
def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
|
||||
"""
|
||||
Delete a workflow
|
||||
|
||||
:param session: SQLAlchemy database session
|
||||
:param workflow_id: Workflow ID
|
||||
:param tenant_id: Tenant ID
|
||||
:return: True if successful
|
||||
:raises: ValueError if workflow not found
|
||||
:raises: WorkflowInUseError if workflow is in use
|
||||
:raises: DraftWorkflowDeletionError if workflow is a draft version
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# Check if workflow is a draft version
|
||||
if workflow.version == "draft":
|
||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||
|
||||
# Check if this workflow is currently referenced by an app
|
||||
stmt = select(App).where(App.workflow_id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
if app:
|
||||
# Cannot delete a workflow that's currently in use by an app
|
||||
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
|
||||
|
||||
session.delete(workflow)
|
||||
return True
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask_login import current_user # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
from services.account_service_extend import TenantExtendService
|
||||
from services.feature_service import FeatureService
|
||||
@@ -21,7 +21,6 @@ class WorkspaceService:
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"in_trail": True,
|
||||
"trial_end_reason": None,
|
||||
"role": "normal",
|
||||
}
|
||||
@@ -45,9 +44,7 @@ class WorkspaceService:
|
||||
|
||||
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
|
||||
|
||||
if can_replace_logo and TenantService.has_roles(
|
||||
tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]
|
||||
):
|
||||
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
|
||||
base_url = dify_config.FILES_URL
|
||||
replace_webapp_logo = (
|
||||
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"
|
||||
|
||||
Reference in New Issue
Block a user