feat: 1.12.1初步合并

Merge upstream release cd03e0a (hotfix/1.12.1-fix.0) into main

# Conflicts:
#	api/.env.example
#	api/controllers/service_api/app/annotation.py
#	api/controllers/service_api/app/completion.py
#	api/controllers/service_api/app/conversation.py
#	api/controllers/service_api/app/message.py
#	api/core/file/file_manager.py
#	api/core/rag/datasource/retrieval_service.py
#	api/extensions/ext_celery.py
#	api/libs/gmpy2_pkcs10aep_cipher.py
#	api/uv.lock
#	web/pnpm-lock.yaml
#	web/service/client.ts
This commit is contained in:
npc0-hue
2026-02-09 09:51:18 +08:00
334 changed files with 31035 additions and 5335 deletions
@@ -480,4 +480,4 @@ const useButtonState = () => {
### Related Skills ### Related Skills
- `frontend-testing` - For testing refactored components - `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification - `web/docs/test.md` - Testing specification
+2 -2
View File
@@ -7,7 +7,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`). > **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill ## When to Apply This Skill
@@ -309,7 +309,7 @@ For more detailed information, refer to:
### Primary Specification (MUST follow) ### Primary Specification (MUST follow)
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document. - **`web/docs/test.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase ### Reference Examples in Codebase
@@ -4,7 +4,7 @@ This guide defines the workflow for generating tests, especially for complex com
## Scope Clarification ## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals. This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals.
| Scope | Rule | | Scope | Rule |
|-------|------| |-------|------|
+3
View File
@@ -9,6 +9,9 @@
# CODEOWNERS file # CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola /.github/CODEOWNERS @laipz8200 @crazywoola
# Agents
/.agents/skills/ @hyoban
# Docs # Docs
/docs/ @crazywoola /docs/ @crazywoola
+1
View File
@@ -72,6 +72,7 @@ jobs:
OPENDAL_FS_ROOT: /tmp/dify-storage OPENDAL_FS_ROOT: /tmp/dify-storage
run: | run: |
uv run --project api pytest \ uv run --project api pytest \
-n auto \
--timeout "${PYTEST_TIMEOUT:-180}" \ --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \ api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \ api/tests/integration_tests/tools \
+2 -6
View File
@@ -47,13 +47,9 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports run: uv run --directory api --dev lint-imports
- name: Run Basedpyright Checks - name: Run Type Checks
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check run: make type-check
- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
+2 -31
View File
@@ -7,7 +7,7 @@ Dify is an open-source platform for developing LLM applications with an intuitiv
The codebase is split into: The codebase is split into:
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design - **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19 - **Frontend Web** (`/web`): Next.js application using TypeScript and React
- **Docker deployment** (`/docker`): Containerized deployment configurations - **Docker deployment** (`/docker`): Containerized deployment configurations
## Backend Workflow ## Backend Workflow
@@ -18,36 +18,7 @@ The codebase is split into:
## Frontend Workflow ## Frontend Workflow
```bash - Read `web/AGENTS.md` for details
cd web
pnpm lint:fix
pnpm type-check:tsgo
pnpm test
```
### Frontend Linting
ESLint is used for frontend code quality. Available commands:
```bash
# Lint all files (report only)
pnpm lint
# Lint and auto-fix issues
pnpm lint:fix
# Lint specific files or directories
pnpm lint:fix app/components/base/button/
pnpm lint:fix app/components/base/button/index.tsx
# Lint quietly (errors only, no warnings)
pnpm lint:quiet
# Check code complexity
pnpm lint:complexity
```
**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
## Testing & Quality Practices ## Testing & Quality Practices
+1 -1
View File
@@ -77,7 +77,7 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly. For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there. **Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend #### Backend
+7 -5
View File
@@ -68,9 +68,11 @@ lint:
@echo "✅ Linting complete" @echo "✅ Linting complete"
type-check: type-check:
@echo "📝 Running type check with basedpyright..." @echo "📝 Running type checks (basedpyright + mypy + ty)..."
@uv run --directory api --dev basedpyright @./dev/basedpyright-check $(PATH_TO_CHECK)
@echo "✅ Type check complete" @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@cd api && uv run ty check
@echo "✅ Type checks complete"
test: test:
@echo "🧪 Running backend unit tests..." @echo "🧪 Running backend unit tests..."
@@ -78,7 +80,7 @@ test:
echo "Target: $(TARGET_TESTS)"; \ echo "Target: $(TARGET_TESTS)"; \
uv run --project api --dev pytest $(TARGET_TESTS); \ uv run --project api --dev pytest $(TARGET_TESTS); \
else \ else \
uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \ PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
fi fi
@echo "✅ Tests complete" @echo "✅ Tests complete"
@@ -130,7 +132,7 @@ help:
@echo " make format - Format code with ruff" @echo " make format - Format code with ruff"
@echo " make check - Check code with ruff" @echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright" @echo " make type-check - Run type checks (basedpyright, mypy, ty)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo "" @echo ""
@echo "Docker Build Targets:" @echo "Docker Build Targets:"
+1
View File
@@ -617,6 +617,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003 PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640 PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration # Marketplace configuration
+55
View File
@@ -227,6 +227,9 @@ ignore_imports =
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.llm.node -> models.dataset core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
@@ -300,6 +303,58 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> services core.workflow.nodes.agent.agent_node -> services
core.workflow.nodes.tool.tool_node -> services core.workflow.nodes.tool.tool_node -> services
[importlinter:contract:model-runtime-no-internal-imports]
name = Model Runtime Internal Imports
type = forbidden
source_modules =
core.model_runtime
forbidden_modules =
configs
controllers
extensions
models
services
tasks
core.agent
core.app
core.base
core.callback_handler
core.datasource
core.db
core.entities
core.errors
core.extension
core.external_data_tool
core.file
core.helper
core.hosting_configuration
core.indexing_runner
core.llm_generator
core.logging
core.mcp
core.memory
core.model_manager
core.moderation
core.ops
core.plugin
core.prompt
core.provider_manager
core.rag
core.repositories
core.schemas
core.tools
core.trigger
core.variables
core.workflow
ignore_imports =
core.model_runtime.model_providers.__base.ai_model -> configs
core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
core.model_runtime.model_providers.__base.large_language_model -> configs
core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
core.model_runtime.model_providers.model_provider_factory -> configs
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
[importlinter:contract:rsc] [importlinter:contract:rsc]
name = RSC name = RSC
type = layers type = layers
+13 -1
View File
@@ -53,6 +53,7 @@ select = [
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module "S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage, "S311", # suspicious-non-cryptographic-random-usage,
"TID", # flake8-tidy-imports
] ]
@@ -88,6 +89,7 @@ ignore = [
"SIM113", # enumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
"TID252", # allow relative imports from parent modules
] ]
[lint.per-file-ignores] [lint.per-file-ignores]
@@ -109,10 +111,20 @@ ignore = [
"S110", # allow ignoring exceptions in tests code (currently) "S110", # allow ignoring exceptions in tests code (currently)
] ]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes] [lint.pyflakes]
allowed-unused-imports = [ allowed-unused-imports = [
"_pytest.monkeypatch",
"tests.integration_tests", "tests.integration_tests",
"tests.unit_tests", "tests.unit_tests",
] ]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
msg = "Use Pydantic payload/query models instead of reqparse."
+9 -1
View File
@@ -1,4 +1,12 @@
from __future__ import annotations
import sys import sys
from typing import TYPE_CHECKING, cast
if TYPE_CHECKING:
from celery import Celery
celery: Celery
def is_db_command() -> bool: def is_db_command() -> bool:
@@ -23,7 +31,7 @@ else:
from app_factory import create_app from app_factory import create_app
app = create_app() app = create_app()
celery = app.extensions["celery"] celery = cast("Celery", app.extensions["celery"])
if __name__ == "__main__": if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001) app.run(host="0.0.0.0", port=5001)
+1 -1
View File
@@ -149,7 +149,7 @@ def initialize_extensions(app: DifyApp):
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2)) logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
def create_migrations_app(): def create_migrations_app() -> DifyApp:
app = create_flask_app_with_configs() app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate from extensions import ext_database, ext_migrate
+105 -98
View File
@@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool):
all_ids_in_tables = [] all_ids_in_tables = []
for ids_table in ids_tables: for ids_table in ids_tables:
query = "" query = ""
if ids_table["type"] == "uuid": match ids_table["type"]:
click.echo( case "uuid":
click.style( click.echo(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white" click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
) )
) c = ids_table["column"]
query = ( query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" with db.engine.begin() as conn:
) rs = conn.execute(sa.text(query))
with db.engine.begin() as conn: for i in rs:
rs = conn.execute(sa.text(query)) all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
for i in rs: case "text":
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) t = ids_table["table"]
elif ids_table["type"] == "text": click.echo(
click.echo( click.style(
click.style( f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}", fg="white",
fg="white", )
) )
) query = (
query = ( f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " f"FROM {ids_table['table']}"
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
elif ids_table["type"] == "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
) )
) with db.engine.begin() as conn:
query = ( rs = conn.execute(sa.text(query))
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " for i in rs:
f"FROM {ids_table['table']}" for j in i[0]:
) all_ids_in_tables.append({"table": ids_table["table"], "id": j})
with db.engine.begin() as conn: case "json":
rs = conn.execute(sa.text(query)) click.echo(
for i in rs: click.style(
for j in i[0]: (
all_ids_in_tables.append({"table": ids_table["table"], "id": j}) f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case _:
pass
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
except Exception as e: except Exception as e:
@@ -1737,59 +1741,18 @@ def file_usage(
if src_filter != src: if src_filter != src:
continue continue
if ids_table["type"] == "uuid": match ids_table["type"]:
# Direct UUID match case "uuid":
query = ( # Direct UUID match
f"SELECT {ids_table['pk_column']}, {ids_table['column']} " query = (
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
) f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
with db.engine.begin() as conn: )
rs = conn.execute(sa.text(query)) with db.engine.begin() as conn:
for row in rs: rs = conn.execute(sa.text(query))
record_id = str(row[0]) for row in rs:
ref_file_id = str(row[1]) record_id = str(row[0])
if ref_file_id not in file_key_map: ref_file_id = str(row[1])
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map: if ref_file_id not in file_key_map:
continue continue
storage_key = file_key_map[ref_file_id] storage_key = file_key_map[ref_file_id]
@@ -1812,6 +1775,50 @@ def file_usage(
) )
total_count += 1 total_count += 1
case "text" | "json":
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
case _:
pass
# Output results # Output results
if output_json: if output_json:
result = { result = {
+5
View File
@@ -243,6 +243,11 @@ class PluginConfig(BaseSettings):
default=15728640 * 12, default=15728640 * 12,
) )
PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching plugin model schemas in Redis",
default=60 * 60,
)
class MarketplaceConfig(BaseSettings): class MarketplaceConfig(BaseSettings):
""" """
-7
View File
@@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING: if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController from core.trigger.provider import PluginTriggerProviderController
@@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock") ContextVar("plugin_model_providers_lock")
) )
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers")) RecyclableContextVar(ContextVar("datasource_plugin_providers"))
) )
+6 -8
View File
@@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource):
def post(self): def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload) payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner( banner = ExporleBanner(
content=content, content={
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
},
link=payload.link, link=payload.link,
sort=payload.sort, sort=payload.sort,
language=payload.language, language=payload.language,
+49 -41
View File
@@ -1,10 +1,11 @@
from typing import Any, Literal from typing import Any, Literal
from flask import abort, make_response, request from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@@ -16,9 +17,11 @@ from controllers.console.wraps import (
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import ( from fields.annotation_fields import (
annotation_fields, Annotation,
annotation_hit_history_fields, AnnotationExportList,
build_annotation_model, AnnotationHitHistory,
AnnotationHitHistoryList,
AnnotationList,
) )
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
@@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload) reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery) reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload) reg(AnnotationFilePayload)
register_schema_models(
console_ns,
Annotation,
AnnotationList,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
@@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id) app_id = str(app_id)
args = AnnotationReplyPayload.model_validate(console_ns.payload) args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable": match action:
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) case "enable":
elif action == "disable": result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
result = AppAnnotationService.disable_app_annotation(app_id) case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200 return result, 200
@@ -201,33 +213,33 @@ class AnnotationApi(Resource):
app_id = str(app_id) app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = { annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
"data": marshal(annotation_list, annotation_fields), response = AnnotationList(
"has_more": len(annotation_list) == limit, data=annotation_models,
"limit": limit, has_more=len(annotation_list) == limit,
"total": total, limit=limit,
"page": page, total=total,
} page=page,
return response, 200 )
return response.model_dump(mode="json"), 200
@console_ns.doc("create_annotation") @console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) @console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required @edit_permission_required
def post(self, app_id): def post(self, app_id):
app_id = str(app_id) app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload) args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True) data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required @setup_required
@login_required @login_required
@@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Annotations exported successfully", "Annotations exported successfully",
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}), console_ns.models[AnnotationExportList.__name__],
) )
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
def get(self, app_id): def get(self, app_id):
app_id = str(app_id) app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)} annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
# Create response with secure headers for CSV export # Create response with secure headers for CSV export
response = make_response(response_data, 200) response = make_response(response_data, 200)
@@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation") @console_ns.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation") @console_ns.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) @console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
@console_ns.response(204, "Annotation deleted successfully") @console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required @edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
@@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation = AppAnnotationService.update_app_annotation_directly( annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id args.model_dump(exclude_none=True), app_id, annotation_id
) )
return annotation return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required @setup_required
@login_required @login_required
@@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Hit histories retrieved successfully", "Hit histories retrieved successfully",
console_ns.model( console_ns.models[AnnotationHitHistoryList.__name__],
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
) )
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit app_id, annotation_id, page, limit
) )
response = { history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields), annotation_hit_history_list, from_attributes=True
"has_more": len(annotation_hit_history_list) == limit, )
"limit": limit, response = AnnotationHitHistoryList(
"total": total, data=history_models,
"page": page, has_more=len(annotation_hit_history_list) == limit,
} limit=limit,
return response total=total,
page=page,
)
return response.model_dump(mode="json")
+7 -9
View File
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
@@ -33,7 +34,6 @@ from services.errors.audio import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel): class TextToSpeechPayload(BaseModel):
@@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code") language: str = Field(..., description="Language code")
console_ns.schema_model( class AudioTranscriptResponse(BaseModel):
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) text: str = Field(description="Transcribed text from audio")
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__, register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text") @console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Audio transcription successful", "Audio transcription successful",
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), console_ns.models[AudioTranscriptResponse.__name__],
) )
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large") @console_ns.response(413, "Audio file too large")
+13 -10
View File
@@ -508,16 +508,19 @@ class ChatConversationApi(Resource):
case "created_at" | "-created_at" | _: case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc) query = query.where(Conversation.created_at <= end_datetime_utc)
if args.annotation_status == "annotated": match args.annotation_status:
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore case "annotated":
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
) MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
elif args.annotation_status == "not_annotated": )
query = ( case "not_annotated":
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) query = (
.group_by(Conversation.id) query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.having(func.count(MessageAnnotation.id) == 0) .group_by(Conversation.id)
) .having(func.count(MessageAnnotation.id) == 0)
)
case "all":
pass
if app_model.mode == AppMode.ADVANCED_CHAT: if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
+22 -37
View File
@@ -1,5 +1,4 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -12,10 +11,12 @@ from controllers.console.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
@@ -26,28 +27,13 @@ from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel): class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID") flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context") node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text") current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)") language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation") instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output") ideal_output: str = Field(default="", description="Expected ideal output")
@@ -64,6 +50,7 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload) reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload) reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload) reg(InstructionTemplatePayload)
reg(ModelConfig)
@console_ns.route("/rule-generate") @console_ns.route("/rule-generate")
@@ -82,12 +69,7 @@ class RuleGenerateApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
)
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
@@ -118,9 +100,7 @@ class RuleCodeGenerateApi(Resource):
try: try:
code_result = LLMGenerator.generate_code( code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args.instruction, args=args,
model_config=args.model_config_data,
code_language=args.code_language,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@@ -152,8 +132,7 @@ class RuleStructuredOutputGenerateApi(Resource):
try: try:
structured_output = LLMGenerator.generate_structured_output( structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args.instruction, args=args,
model_config=args.model_config_data,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@@ -204,23 +183,29 @@ class InstructionGenerateApi(Resource):
case "llm": case "llm":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_tenant_id, current_tenant_id,
instruction=args.instruction, args=RuleGeneratePayload(
model_config=args.model_config_data, instruction=args.instruction,
no_variable=True, model_config=args.model_config_data,
no_variable=True,
),
) )
case "agent": case "agent":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_tenant_id, current_tenant_id,
instruction=args.instruction, args=RuleGeneratePayload(
model_config=args.model_config_data, instruction=args.instruction,
no_variable=True, model_config=args.model_config_data,
no_variable=True,
),
) )
case "code": case "code":
return LLMGenerator.generate_code( return LLMGenerator.generate_code(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args.instruction, args=RuleCodeGeneratePayload(
model_config=args.model_config_data, instruction=args.instruction,
code_language=args.language, model_config=args.model_config_data,
code_language=args.language,
),
) )
case _: case _:
return {"error": f"invalid node type: {node_type}"} return {"error": f"invalid node type: {node_type}"}
+19 -12
View File
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
@@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService from services.message_service import MessageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel): class ChatMessagesQuery(BaseModel):
@@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
raise ValueError("has_comment must be a boolean value") raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]): class AnnotationCountResponse(BaseModel):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) count: int = Field(description="Number of annotations")
reg(ChatMessagesQuery) class SuggestedQuestionsResponse(BaseModel):
reg(MessageFeedbackPayload) data: list[str] = Field(description="Suggested question")
reg(FeedbackExportQuery)
register_schema_models(
console_ns,
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger # Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models # Register in dependency order: base models first, then dependent models
@@ -231,7 +240,7 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model) @marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
@@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Annotation count retrieved successfully", "Annotation count retrieved successfully",
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), console_ns.models[AnnotationCountResponse.__name__],
) )
@get_app_model @get_app_model
@setup_required @setup_required
@@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Suggested questions retrieved successfully", "Suggested questions retrieved successfully",
console_ns.model( console_ns.models[SuggestedQuestionsResponse.__name__],
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
) )
@console_ns.response(404, "Message or conversation not found") @console_ns.response(404, "Message or conversation not found")
@setup_required @setup_required
@@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = FeedbackExportQuery.model_validate(request.args.to_dict())
# Import the service function # Import the service function
from services.feedback_service import FeedbackService from services.feedback_service import FeedbackService
@@ -2,9 +2,11 @@ import logging
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource, fields from flask_restx import Resource
from pydantic import BaseModel, Field
from configs import dify_config from configs import dify_config
from controllers.common.schema import register_schema_models
from libs.login import login_required from libs.login import login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
@@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OAuthDataSourceResponse(BaseModel):
data: str = Field(description="Authorization URL or 'internal' for internal setup")
class OAuthDataSourceBindingResponse(BaseModel):
result: str = Field(description="Operation result")
class OAuthDataSourceSyncResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
OAuthDataSourceResponse,
OAuthDataSourceBindingResponse,
OAuthDataSourceSyncResponse,
)
def get_oauth_providers(): def get_oauth_providers():
with current_app.app_context(): with current_app.app_context():
notion_oauth = NotionOAuth( notion_oauth = NotionOAuth(
@@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Authorization URL or internal setup success", "Authorization URL or internal setup success",
console_ns.model( console_ns.models[OAuthDataSourceResponse.__name__],
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
) )
@console_ns.response(400, "Invalid provider") @console_ns.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Data source binding success", "Data source binding success",
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), console_ns.models[OAuthDataSourceBindingResponse.__name__],
) )
@console_ns.response(400, "Invalid provider or code") @console_ns.response(400, "Invalid provider or code")
def get(self, provider: str): def get(self, provider: str):
@@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Data source sync success", "Data source sync success",
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), console_ns.models[OAuthDataSourceSyncResponse.__name__],
) )
@console_ns.response(400, "Invalid provider or sync failed") @console_ns.response(400, "Invalid provider or sync failed")
@setup_required @setup_required
+30 -20
View File
@@ -2,10 +2,11 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restx import Resource, fields from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailCodeError, EmailCodeError,
@@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
return valid_password(value) return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): class ForgotPasswordEmailResponse(BaseModel):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")
code: str | None = Field(default=None, description="Error code if account not found")
class ForgotPasswordCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether code is valid")
email: EmailStr = Field(description="Email address")
token: str = Field(description="New reset token")
class ForgotPasswordResetResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
ForgotPasswordSendPayload,
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordEmailResponse,
ForgotPasswordCheckResponse,
ForgotPasswordResetResponse,
)
@console_ns.route("/forgot-password") @console_ns.route("/forgot-password")
@@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Email sent successfully", "Email sent successfully",
console_ns.model( console_ns.models[ForgotPasswordEmailResponse.__name__],
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
) )
@console_ns.response(400, "Invalid email or rate limit exceeded") @console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required @setup_required
@@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Code verified successfully", "Code verified successfully",
console_ns.model( console_ns.models[ForgotPasswordCheckResponse.__name__],
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
) )
@console_ns.response(400, "Invalid code or token") @console_ns.response(400, "Invalid code or token")
@setup_required @setup_required
@@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Password reset successfully", "Password reset successfully",
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), console_ns.models[ForgotPasswordResetResponse.__name__],
) )
@console_ns.response(400, "Invalid token or password mismatch") @console_ns.response(400, "Invalid token or password mismatch")
@setup_required @setup_required
+33 -33
View File
@@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
grant_type = OAuthGrantType(payload.grant_type) grant_type = OAuthGrantType(payload.grant_type)
except ValueError: except ValueError:
raise BadRequest("invalid grant_type") raise BadRequest("invalid grant_type")
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE: if payload.client_secret != oauth_provider_app.client_secret:
if not payload.code: raise BadRequest("client_secret is invalid")
raise BadRequest("code is required")
if payload.client_secret != oauth_provider_app.client_secret: if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("client_secret is invalid") raise BadRequest("redirect_uri is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris: access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
raise BadRequest("redirect_uri is invalid") grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
case OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token( access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
) )
return jsonable_encoder( return jsonable_encoder(
{ {
"access_token": access_token, "access_token": access_token,
"token_type": "Bearer", "token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token, "refresh_token": refresh_token,
} }
) )
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
@console_ns.route("/oauth/provider/account") @console_ns.route("/oauth/provider/account")
+20 -20
View File
@@ -1,6 +1,6 @@
import json import json
from collections.abc import Generator from collections.abc import Generator
from typing import Any, cast from typing import Any, Literal, cast
from flask import request from flask import request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
@@ -157,9 +157,8 @@ class DataSourceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, binding_id, action): def patch(self, binding_id, action: Literal["enable", "disable"]):
binding_id = str(binding_id) binding_id = str(binding_id)
action = str(action)
with Session(db.engine) as session: with Session(db.engine) as session:
data_source_binding = session.execute( data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id) select(DataSourceOauthBinding).filter_by(id=binding_id)
@@ -167,23 +166,24 @@ class DataSourceApi(Resource):
if data_source_binding is None: if data_source_binding is None:
raise NotFound("Data source binding not found.") raise NotFound("Data source binding not found.")
# enable binding # enable binding
if action == "enable": match action:
if data_source_binding.disabled: case "enable":
data_source_binding.disabled = False if data_source_binding.disabled:
data_source_binding.updated_at = naive_utc_now() data_source_binding.disabled = False
db.session.add(data_source_binding) data_source_binding.updated_at = naive_utc_now()
db.session.commit() db.session.add(data_source_binding)
else: db.session.commit()
raise ValueError("Data source is not disabled.") else:
# disable binding raise ValueError("Data source is not disabled.")
if action == "disable": # disable binding
if not data_source_binding.disabled: case "disable":
data_source_binding.disabled = True if not data_source_binding.disabled:
data_source_binding.updated_at = naive_utc_now() data_source_binding.disabled = True
db.session.add(data_source_binding) data_source_binding.updated_at = naive_utc_now()
db.session.commit() db.session.add(data_source_binding)
else: db.session.commit()
raise ValueError("Data source is disabled.") else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200 return {"result": "success"}, 200
+9 -1
View File
@@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None embedding_model: str | None = None
embedding_model_provider: str | None = None embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None external_knowledge_id: str | None = None
@@ -288,7 +289,14 @@ class DatasetListApi(Resource):
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
query = ConsoleDatasetListQuery.model_validate(request.args.to_dict()) # Convert query parameters to dict, handling list parameters correctly
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
if "ids" in request.args:
query_params["ids"] = request.args.getlist("ids")
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = ConsoleDatasetListQuery.model_validate(query_params)
# provider = request.args.get("provider", default="vendor") # provider = request.args.get("provider", default="vendor")
if query.ids: if query.ids:
datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
@@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService from services.file_service import FileService
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import ( from ..app.error import (
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel):
name: str name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel): class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive.""" """Request payload for bulk downloading documents as a zip archive."""
@@ -125,6 +130,7 @@ register_schema_models(
RetrievalModel, RetrievalModel,
DocumentRetryPayload, DocumentRetryPayload,
DocumentRenamePayload, DocumentRenamePayload,
GenerateSummaryPayload,
DocumentBatchDownloadZipPayload, DocumentBatchDownloadZipPayload,
) )
@@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items documents = paginated_documents.items
DocumentService.enrich_documents_with_summary_index_status(
documents=documents,
dataset=dataset,
tenant_id=current_tenant_id,
)
if fetch: if fetch:
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
@@ -563,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if document.indexing_status in {"completed", "error"}: if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
match document.data_source_type:
case "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if document.data_source_type == "upload_file": if file_detail is None:
if not data_source_info: raise NotFound("File not found.")
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if file_detail is None: extract_setting = ExtractSetting(
raise NotFound("File not found.") datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
case "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
case "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
extract_setting = ExtractSetting( case _:
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form raise ValueError("Data source type not support")
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not support")
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
@@ -797,6 +809,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status, "display_status": document.display_status,
"doc_form": document.doc_form, "doc_form": document.doc_form,
"doc_language": document.doc_language, "doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
} }
else: else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -832,6 +845,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status, "display_status": document.display_status,
"doc_form": document.doc_form, "doc_form": document.doc_form,
"doc_language": document.doc_language, "doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
} }
return response, 200 return response, 200
@@ -939,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
if action == "pause": match action:
if document.indexing_status != "indexing": case "pause":
raise InvalidActionError("Document not in indexing state.") if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id document.paused_by = current_user.id
document.paused_at = naive_utc_now() document.paused_at = naive_utc_now()
document.is_paused = True document.is_paused = True
db.session.commit() db.session.commit()
elif action == "resume": case "resume":
if document.indexing_status not in {"paused", "error"}: if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.") raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None document.paused_by = None
document.paused_at = None document.paused_at = None
document.is_paused = False document.is_paused = False
db.session.commit() db.session.commit()
return {"result": "success"}, 200 return {"result": "success"}, 200
@@ -1255,3 +1270,149 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data, "input_data": log.input_data,
"datasource_node_id": log.datasource_node_id, "datasource_node_id": log.datasource_node_id,
}, 200 }, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
from werkzeug.exceptions import BadRequest
raise BadRequest("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
# Verify all documents exist and belong to the dataset
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Update need_summary to True for documents that don't have it set
# This handles the case where documents were created when summary_index_setting was disabled
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info("Skipping summary generation for qa_model document %s", document.id)
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
"Dispatched summary generation task for document %s in dataset %s",
document.id,
dataset_id,
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get summary status detail from service
from services.summary_index_service import SummaryIndexService
result = SummaryIndexService.get_document_summary_status_detail(
document_id=document_id,
dataset_id=dataset_id,
)
return result, 200
@@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields))
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel): class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100) limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list) status: list[str] = Field(default_factory=list)
@@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None keywords: list[str] | None = None
regenerate_child_chunks: bool = False regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel): class BatchImportPayload(BaseModel):
@@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields))
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = { response = {
"data": marshal(segments.items, segment_fields), "data": segments_with_summary,
"limit": limit, "limit": limit,
"total": segments.total, "total": segments.total,
"total_pages": segments.pages, "total_pages": segments.pages,
@@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True) payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document) SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset) segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {}) payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True) payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document) SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment( segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
) )
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@login_required @login_required
@@ -1,6 +1,13 @@
from flask_restx import Resource from flask_restx import Resource, fields
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from libs.login import login_required from libs.login import login_required
from .. import console_ns from .. import console_ns
@@ -14,13 +21,45 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload) register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase): class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval") @console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__]) @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully") @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@setup_required @setup_required
+5 -4
View File
@@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable": match action:
MetadataService.enable_built_in_field(dataset) case "enable":
elif action == "disable": MetadataService.enable_built_in_field(dataset)
MetadataService.disable_built_in_field(dataset) case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200 return {"result": "success"}, 200
@@ -1,10 +1,9 @@
import json import json
import logging import logging
from typing import Any, Literal, cast from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db from extensions.ext_database import db
from factories import variable_factory from factories import variable_factory
from libs import helper from libs import helper
from libs.helper import TimestampField from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant, current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
@@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel): class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100) limit: int = Field(default=20, ge=1, le=100)
@@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models( register_schema_models(
console_ns, console_ns,
DraftWorkflowSyncPayload, DraftWorkflowSyncPayload,
@@ -135,6 +138,7 @@ register_schema_models(
NodeIdQuery, NodeIdQuery,
WorkflowRunQuery, WorkflowRunQuery,
DatasourceVariablesPayload, DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
) )
@@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
return recommended_plugins return recommended_plugins
+3 -3
View File
@@ -9,7 +9,7 @@ import services
from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse from controllers.common.fields import Site as SiteResponse
from controllers.common.schema import get_or_create_model from controllers.common.schema import get_or_create_model
from controllers.console import api, console_ns from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,
@@ -51,7 +51,7 @@ from fields.app_fields import (
tag_fields, tag_fields,
) )
from fields.dataset_fields import dataset_fields from fields.dataset_fields import dataset_fields
from fields.member_fields import build_simple_account_model from fields.member_fields import simple_account_fields
from fields.workflow_fields import ( from fields.workflow_fields import (
conversation_variable_fields, conversation_variable_fields,
pipeline_variable_fields, pipeline_variable_fields,
@@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = build_simple_account_model(console_ns) simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
+40 -53
View File
@@ -1,87 +1,74 @@
import os import os
from typing import Literal
from flask import session from flask import session
from flask_restx import Resource, fields
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from controllers.fastopenapi import console_router
from extensions.ext_database import db from extensions.ext_database import db
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel): class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30) password: str = Field(..., max_length=30, description="Initialization password")
console_ns.schema_model( class InitStatusResponse(BaseModel):
InitValidatePayload.__name__, status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
class InitValidateResponse(BaseModel):
result: str = Field(description="Operation result", examples=["success"])
@console_router.get(
"/init",
response_model=InitStatusResponse,
tags=["console"],
) )
def get_init_status() -> InitStatusResponse:
"""Get initialization validation status."""
init_status = get_init_validate_status()
if init_status:
return InitStatusResponse(status="finished")
return InitStatusResponse(status="not_started")
@console_ns.route("/init") @console_router.post(
class InitValidateAPI(Resource): "/init",
@console_ns.doc("get_init_status") response_model=InitValidateResponse,
@console_ns.doc(description="Get initialization validation status") tags=["console"],
@console_ns.response( status_code=201,
200, )
"Success", @only_edition_self_hosted
model=console_ns.model( def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
"InitStatusResponse", """Validate initialization password."""
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, tenant_count = TenantService.get_tenant_count()
), if tenant_count > 0:
) raise AlreadySetupError()
def get(self):
"""Get initialization validation status"""
init_status = get_init_validate_status()
if init_status:
return {"status": "finished"}
return {"status": "not_started"}
@console_ns.doc("validate_init_password") if payload.password != os.environ.get("INIT_PASSWORD"):
@console_ns.doc(description="Validate initialization password for self-hosted edition") session["is_init_validated"] = False
@console_ns.expect(console_ns.models[InitValidatePayload.__name__]) raise InitValidateFailedError()
@console_ns.response(
201,
"Success",
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Validate initialization password"""
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
payload = InitValidatePayload.model_validate(console_ns.payload) session["is_init_validated"] = True
input_password = payload.password return InitValidateResponse(result="success")
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
session["is_init_validated"] = True
return {"result": "success"}, 201
def get_init_validate_status(): def get_init_validate_status() -> bool:
if dify_config.EDITION == "SELF_HOSTED": if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"): if os.environ.get("INIT_PASSWORD"):
if session.get("is_init_validated"): if session.get("is_init_validated"):
return True return True
with Session(db.engine) as db_session: with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none() return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
return True return True
+58 -69
View File
@@ -1,7 +1,6 @@
import urllib.parse import urllib.parse
import httpx import httpx
from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import services import services
@@ -11,7 +10,7 @@ from controllers.common.errors import (
RemoteFileUploadError, RemoteFileUploadError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.common.schema import register_schema_models from controllers.fastopenapi import console_router
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
@@ -19,84 +18,74 @@ from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from services.file_service import FileService from services.file_service import FileService
from . import console_ns
register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource):
@console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
return info.model_dump(mode="json")
class RemoteFileUploadPayload(BaseModel): class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch") url: str = Field(..., description="URL to fetch")
console_ns.schema_model( @console_router.get(
RemoteFileUploadPayload.__name__, "/remote-files/<path:url>",
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"), response_model=RemoteFileInfo,
tags=["console"],
) )
def get_remote_file_info(url: str) -> RemoteFileInfo:
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
@console_ns.route("/remote-files/upload") @console_router.post(
class RemoteFileUploadApi(Resource): "/remote-files/upload",
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) response_model=FileWithSignedUrl,
@console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__]) tags=["console"],
def post(self): status_code=201,
args = RemoteFileUploadPayload.model_validate(console_ns.payload) )
url = args.url def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
url = payload.url
try: try:
resp = ssrf_proxy.head(url=url) resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK: if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK: if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e: except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp) file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try: try:
user, _ = current_account_with_tenant() user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file( upload_file = FileService(db.engine).upload_file(
filename=file_info.filename, filename=file_info.filename,
content=content, content=content,
mimetype=file_info.mimetype, mimetype=file_info.mimetype,
user=user, user=user,
source_url=url, source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
payload = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
) )
return payload.model_dump(mode="json"), 201 except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
+12 -2
View File
@@ -1,18 +1,28 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_restx import Resource, marshal_with from flask_restx import Namespace, Resource, fields, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.recommended_app_service_extend import RecommendedAppService from services.recommended_app_service_extend import RecommendedAppService
from services.tag_service import TagService from services.tag_service import TagService
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel): class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50) name: str = Field(description="Tag name", min_length=1, max_length=50)
+24 -17
View File
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailAlreadyInUseError, EmailAlreadyInUseError,
@@ -37,7 +38,7 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload) reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload) reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload) reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = { integrate_fields = {
"provider": fields.String, "provider": fields.String,
@@ -236,11 +243,11 @@ class AccountProfileApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
return current_user return _serialize_account(current_user)
@console_ns.route("/account/name") @console_ns.route("/account/name")
@@ -249,14 +256,14 @@ class AccountNameApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload) args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name) updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account return _serialize_account(updated_account)
@console_ns.route("/account/avatar") @console_ns.route("/account/avatar")
@@ -265,7 +272,7 @@ class AccountAvatarApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} payload = console_ns.payload or {}
@@ -273,7 +280,7 @@ class AccountAvatarApi(Resource):
updated_account = AccountService.update_account(current_user, avatar=args.avatar) updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account return _serialize_account(updated_account)
@console_ns.route("/account/interface-language") @console_ns.route("/account/interface-language")
@@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} payload = console_ns.payload or {}
@@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource):
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account return _serialize_account(updated_account)
@console_ns.route("/account/interface-theme") @console_ns.route("/account/interface-theme")
@@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} payload = console_ns.payload or {}
@@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource):
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account return _serialize_account(updated_account)
@console_ns.route("/account/timezone") @console_ns.route("/account/timezone")
@@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} payload = console_ns.payload or {}
@@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource):
updated_account = AccountService.update_account(current_user, timezone=args.timezone) updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account return _serialize_account(updated_account)
@console_ns.route("/account/password") @console_ns.route("/account/password")
@@ -333,7 +340,7 @@ class AccountPasswordApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} payload = console_ns.payload or {}
@@ -344,7 +351,7 @@ class AccountPasswordApi(Resource):
except ServiceCurrentPasswordIncorrectError: except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError() raise CurrentPasswordIncorrectError()
return {"result": "success"} return _serialize_account(current_user)
@console_ns.route("/account/integrates") @console_ns.route("/account/integrates")
@@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self): def post(self):
payload = console_ns.payload or {} payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload) args = ChangeEmailResetPayload.model_validate(payload)
@@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource):
email=normalized_new_email, email=normalized_new_email,
) )
return updated_account return _serialize_account(updated_account)
@console_ns.route("/account/change-email/check-email-unique") @console_ns.route("/account/change-email/check-email-unique")
+52 -17
View File
@@ -1,9 +1,10 @@
from typing import Any from typing import Any
from flask import request from flask import request
from flask_restx import Resource, fields from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str plugin_id: str
class EndpointCreateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class PluginEndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class EndpointDeleteResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointUpdateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointEnableResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointDisableResponse(BaseModel):
success: bool = Field(description="Operation success")
def reg(cls: type[BaseModel]): def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload) register_schema_models(
reg(EndpointIdPayload) console_ns,
reg(EndpointUpdatePayload) EndpointCreatePayload,
reg(EndpointListQuery) EndpointIdPayload,
reg(EndpointListForPluginQuery) EndpointUpdatePayload,
EndpointListQuery,
EndpointListForPluginQuery,
EndpointCreateResponse,
EndpointListResponse,
PluginEndpointListResponse,
EndpointDeleteResponse,
EndpointUpdateResponse,
EndpointEnableResponse,
EndpointDisableResponse,
)
@console_ns.route("/workspaces/current/endpoints/create") @console_ns.route("/workspaces/current/endpoints/create")
@@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint created successfully", "Endpoint created successfully",
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.models[EndpointCreateResponse.__name__],
) )
@console_ns.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@@ -91,9 +130,7 @@ class EndpointListApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
console_ns.model( console_ns.models[EndpointListResponse.__name__],
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
) )
@setup_required @setup_required
@login_required @login_required
@@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
console_ns.model( console_ns.models[PluginEndpointListResponse.__name__],
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
) )
@setup_required @setup_required
@login_required @login_required
@@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint deleted successfully", "Endpoint deleted successfully",
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.models[EndpointDeleteResponse.__name__],
) )
@console_ns.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint updated successfully", "Endpoint updated successfully",
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.models[EndpointUpdateResponse.__name__],
) )
@console_ns.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint enabled successfully", "Endpoint enabled successfully",
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.models[EndpointEnableResponse.__name__],
) )
@console_ns.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint disabled successfully", "Endpoint disabled successfully",
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.models[EndpointDisableResponse.__name__],
) )
@console_ns.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
+13 -14
View File
@@ -1,12 +1,12 @@
from urllib import parse from urllib import parse
from flask import abort, request from flask import abort, request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, TypeAdapter
import services import services
from configs import dify_config from configs import dify_config
from controllers.common.schema import get_or_create_model, register_enum_models from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
CannotTransferOwnerToSelfError, CannotTransferOwnerToSelfError,
@@ -25,7 +25,7 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_with_role_fields, account_with_role_list_fields from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole from models.account import Account, TenantAccountRole
@@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload) reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload) reg(OwnerTransferPayload)
register_enum_models(console_ns, TenantAccountRole) register_enum_models(console_ns, TenantAccountRole)
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
@console_ns.route("/workspaces/current/members") @console_ns.route("/workspaces/current/members")
@@ -84,13 +79,15 @@ class MemberListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_model) @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self): def get(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant) members = TenantService.get_tenant_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200 member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/invite-email") @console_ns.route("/workspaces/current/members/invite-email")
@@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_model) @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self): def get(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant) members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200 member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
File diff suppressed because it is too large Load Diff
+39 -37
View File
@@ -1,16 +1,16 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_restx import Namespace, Resource, fields from flask_restx import Resource
from flask_restx.api import HTTPStatus from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, TypeAdapter
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model from fields.annotation_fields import Annotation, AnnotationList
from models.model import ApiToken, App # extend - 密钥额度限制,新增ApiToken from models.model import ApiToken, App # extend - 密钥额度限制,新增ApiToken
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
@@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel):
embedding_model_name: str = Field(description="Embedding model name") embedding_model_name: str = Field(description="Embedding model name")
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) register_schema_models(
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
)
@service_api_ns.route("/apps/annotation-reply/<string:action>") @service_api_ns.route("/apps/annotation-reply/<string:action>")
@@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_model: App, action: Literal["enable", "disable"]): def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature.""" """Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
if action == "enable": match action:
result = AppAnnotationService.enable_app_annotation(args, app_model.id) case "enable":
elif action == "disable": result = AppAnnotationService.enable_app_annotation(args, app_model.id)
result = AppAnnotationService.disable_app_annotation(app_model.id) case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200 return result, 200
@@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
# Define annotation list response model
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
"has_more": fields.Boolean,
"limit": fields.Integer,
"total": fields.Integer,
"page": fields.Integer,
}
def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
@service_api_ns.route("/apps/annotations") @service_api_ns.route("/apps/annotations")
class AnnotationListApi(Resource): class AnnotationListApi(Resource):
@service_api_ns.doc("list_annotations") @service_api_ns.doc("list_annotations")
@@ -109,8 +95,12 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@service_api_ns.response(
200,
"Annotations retrieved successfully",
service_api_ns.models[AnnotationList.__name__],
)
@validate_app_token @validate_app_token
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""List annotations for the application.""" """List annotations for the application."""
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
@@ -118,13 +108,15 @@ class AnnotationListApi(Resource):
keyword = request.args.get("keyword", default="", type=str) keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
return { annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
"data": annotation_list, response = AnnotationList(
"has_more": len(annotation_list) == limit, data=annotation_models,
"limit": limit, has_more=len(annotation_list) == limit,
"total": total, limit=limit,
"page": page, total=total,
} page=page,
)
return response.model_dump(mode="json")
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation") @service_api_ns.doc("create_annotation")
@@ -135,13 +127,18 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@service_api_ns.response(
HTTPStatus.CREATED,
"Annotation created successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token @validate_app_token
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App): def post(self, app_model: App):
"""Create a new annotation.""" """Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation, 201 response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>") @service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
@@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource):
404: "Annotation not found", 404: "Annotation not found",
} }
) )
@service_api_ns.response(
200,
"Annotation updated successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token @validate_app_token
@edit_permission_required @edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str): def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation.""" """Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@service_api_ns.doc("delete_annotation") @service_api_ns.doc("delete_annotation")
@service_api_ns.doc(description="Delete an annotation") @service_api_ns.doc(description="Delete an annotation")
@@ -30,6 +30,7 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import UUIDStrOrEmpty
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
@@ -53,7 +54,7 @@ class ChatRequestPayload(BaseModel):
query: str query: str
files: list[dict[str, Any]] | None = None files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID") conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev") retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@@ -1,5 +1,4 @@
from typing import Any, Literal from typing import Any, Literal
from uuid import UUID
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -23,12 +22,13 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model, build_conversation_variable_model,
) )
from libs.helper import UUIDStrOrEmpty
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错 from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
class ConversationListQuery(BaseModel): class ConversationListQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations" default="-updated_at", description="Sort order for conversations"
@@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel): class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field( variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255 default=None, description="Filter variables by name", min_length=1, max_length=255
+3 -3
View File
@@ -1,6 +1,5 @@
import logging import logging
from typing import Literal from typing import Literal
from uuid import UUID
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
from services.errors.message import ( from services.errors.message import (
FirstMessageNotExistsError, FirstMessageNotExistsError,
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel): class MessageListQuery(BaseModel):
conversation_id: UUID conversation_id: UUIDStrOrEmpty
first_id: UUID | None = None first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
+12 -9
View File
@@ -17,7 +17,7 @@ from controllers.service_api.wraps import (
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields from fields.tag_fields import DataSetTag
from libs.login import current_user from libs.login import current_user
from models.account import Account from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
@@ -46,6 +46,7 @@ class DatasetCreatePayload(BaseModel):
retrieval_model: RetrievalModel | None = None retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None embedding_model: str | None = None
embedding_model_provider: str | None = None embedding_model_provider: str | None = None
summary_index_setting: dict | None = None
class DatasetUpdatePayload(BaseModel): class DatasetUpdatePayload(BaseModel):
@@ -113,6 +114,7 @@ register_schema_models(
TagBindingPayload, TagBindingPayload,
TagUnbindingPayload, TagUnbindingPayload,
DatasetListQuery, DatasetListQuery,
DataSetTag,
) )
@@ -217,6 +219,7 @@ class DatasetListApi(DatasetApiResource):
embedding_model_provider=payload.embedding_model_provider, embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=payload.embedding_model, embedding_model_name=payload.embedding_model,
retrieval_model=payload.retrieval_model, retrieval_model=payload.retrieval_model,
summary_index_setting=payload.summary_index_setting,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@@ -478,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _): def get(self, _):
"""Get all knowledge type tags.""" """Get all knowledge type tags."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
cid = current_user.current_tenant_id cid = current_user.current_tenant_id
assert cid is not None assert cid is not None
tags = TagService.get_tags("knowledge", cid) tags = TagService.get_tags("knowledge", cid)
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return tags, 200 return [tag.model_dump(mode="json") for tag in tag_models], 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag") @service_api_ns.doc("create_dataset_tag")
@@ -498,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions", 403: "Forbidden - insufficient permissions",
} }
) )
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def post(self, _): def post(self, _):
"""Add a knowledge type tag.""" """Add a knowledge type tag."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
@@ -508,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200 return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@@ -521,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions", 403: "Forbidden - insufficient permissions",
} }
) )
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def patch(self, _): def patch(self, _):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -534,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200 return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
Segmentation, Segmentation,
) )
from services.file_service import FileService from services.file_service import FileService
from services.summary_index_service import SummaryIndexService
class DocumentTextCreatePayload(BaseModel): class DocumentTextCreatePayload(BaseModel):
@@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource):
) )
documents = paginated_documents.items documents = paginated_documents.items
DocumentService.enrich_documents_with_summary_index_status(
documents=documents,
dataset=dataset,
tenant_id=tenant_id,
)
response = { response = {
"data": marshal(documents, document_fields), "data": marshal(documents, document_fields),
"has_more": len(documents) == query_params.limit, "has_more": len(documents) == query_params.limit,
@@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource):
if metadata not in self.METADATA_CHOICES: if metadata not in self.METADATA_CHOICES:
raise InvalidMetadataError(f"Invalid metadata value: {metadata}") raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
# Calculate summary_index_status if needed
summary_index_status = None
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
if has_summary_index and document.need_summary is True:
summary_index_status = SummaryIndexService.get_document_summary_index_status(
document_id=document_id,
dataset_id=dataset_id,
tenant_id=tenant_id,
)
if metadata == "only": if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without": elif metadata == "without":
@@ -646,6 +663,8 @@ class DocumentApi(DatasetApiResource):
"display_status": document.display_status, "display_status": document.display_status,
"doc_form": document.doc_form, "doc_form": document.doc_form,
"doc_language": document.doc_language, "doc_language": document.doc_language,
"summary_index_status": summary_index_status,
"need_summary": document.need_summary if document.need_summary is not None else False,
} }
else: else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -681,6 +700,8 @@ class DocumentApi(DatasetApiResource):
"display_status": document.display_status, "display_status": document.display_status,
"doc_form": document.doc_form, "doc_form": document.doc_form,
"doc_language": document.doc_language, "doc_language": document.doc_language,
"summary_index_status": summary_index_status,
"need_summary": document.need_summary if document.need_summary is not None else False,
} }
return response return response
@@ -1,7 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
register_schema_model(service_api_ns, HitTestingPayload)
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve") @service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
404: "Dataset not found", 404: "Dataset not found",
} }
) )
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Perform hit testing on a dataset. """Perform hit testing on a dataset.
@@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable": match action:
MetadataService.enable_built_in_field(dataset) case "enable":
elif action == "disable": MetadataService.enable_built_in_field(dataset)
MetadataService.disable_built_in_field(dataset) case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200 return {"result": "success"}, 200
+8 -8
View File
@@ -126,14 +126,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# If caller needs end-user context, attach EndUser to current_user # If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg: if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: user_id = None
user_id = request.args.get("user") match fetch_user_arg.fetch_from:
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: case WhereisUserArg.QUERY:
user_id = request.get_json().get("user") user_id = request.args.get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: case WhereisUserArg.JSON:
user_id = request.form.get("user") user_id = request.get_json().get("user")
else: case WhereisUserArg.FORM:
user_id = None user_id = request.form.get("user")
if not user_id and fetch_user_arg.required: if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.") raise ValueError("Arg user must be provided.")
@@ -14,16 +14,17 @@ class AgentConfigManager:
agent_dict = config.get("agent_mode", {}) agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot") agent_strategy = agent_dict.get("strategy", "cot")
if agent_strategy == "function_call": match agent_strategy:
strategy = AgentEntity.Strategy.FUNCTION_CALLING case "function_call":
elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING strategy = AgentEntity.Strategy.FUNCTION_CALLING
else: case "cot" | "react":
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
case _:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = [] agent_tools = []
for tool in agent_dict.get("tools", []): for tool in agent_dict.get("tools", []):
@@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC):
"document_name": resource["document_name"], "document_name": resource["document_name"],
"score": resource["score"], "score": resource["score"],
"content": resource["content"], "content": resource["content"],
"summary": resource.get("summary"),
} }
) )
metadata["retriever_resources"] = updated_resources metadata["retriever_resources"] = updated_resources
@@ -250,7 +250,7 @@ class WorkflowResponseConverter:
data=WorkflowFinishStreamResponse.Data( data=WorkflowFinishStreamResponse.Data(
id=run_id, id=run_id,
workflow_id=workflow_id, workflow_id=workflow_id,
status=status.value, status=status,
outputs=encoded_outputs, outputs=encoded_outputs,
error=error, error=error,
elapsed_time=elapsed_time, elapsed_time=elapsed_time,
@@ -340,13 +340,13 @@ class WorkflowResponseConverter:
metadata = self._merge_metadata(event.execution_metadata, snapshot) metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent): if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED.value status = WorkflowNodeExecutionStatus.SUCCEEDED
error_message = event.error error_message = event.error
elif isinstance(event, QueueNodeFailedEvent): elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED.value status = WorkflowNodeExecutionStatus.FAILED
error_message = event.error error_message = event.error
else: else:
status = WorkflowNodeExecutionStatus.EXCEPTION.value status = WorkflowNodeExecutionStatus.EXCEPTION
error_message = event.error error_message = event.error
return NodeFinishStreamResponse( return NodeFinishStreamResponse(
@@ -413,7 +413,7 @@ class WorkflowResponseConverter:
process_data_truncated=process_data_truncated, process_data_truncated=process_data_truncated,
outputs=outputs, outputs=outputs,
outputs_truncated=outputs_truncated, outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY.value, status=WorkflowNodeExecutionStatus.RETRY,
error=event.error, error=event.error,
elapsed_time=elapsed_time, elapsed_time=elapsed_time,
execution_metadata=metadata, execution_metadata=metadata,
@@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required") raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"] inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"] start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"] datasource_type = DatasourceProviderType(args["datasource_type"])
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list( datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
) )
@@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator):
tenant_id: str, tenant_id: str,
dataset_id: str, dataset_id: str,
built_in_field_enabled: bool, built_in_field_enabled: bool,
datasource_type: str, datasource_type: DatasourceProviderType,
datasource_info: Mapping[str, Any], datasource_info: Mapping[str, Any],
created_from: str, created_from: str,
position: int, position: int,
@@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator):
batch: str, batch: str,
document_form: str, document_form: str,
): ):
if datasource_type == "local_file": match datasource_type:
name = datasource_info.get("name", "untitled") case DatasourceProviderType.LOCAL_FILE:
elif datasource_type == "online_document": name = datasource_info.get("name", "untitled")
name = datasource_info.get("page", {}).get("page_name", "untitled") case DatasourceProviderType.ONLINE_DOCUMENT:
elif datasource_type == "website_crawl": name = datasource_info.get("page", {}).get("page_name", "untitled")
name = datasource_info.get("title", "untitled") case DatasourceProviderType.WEBSITE_CRAWL:
elif datasource_type == "online_drive": name = datasource_info.get("title", "untitled")
name = datasource_info.get("name", "untitled") case DatasourceProviderType.ONLINE_DRIVE:
else: name = datasource_info.get("name", "untitled")
raise ValueError(f"Unsupported datasource type: {datasource_type}") case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document( document = Document(
tenant_id=tenant_id, tenant_id=tenant_id,
dataset_id=dataset_id, dataset_id=dataset_id,
@@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator):
def _format_datasource_info_list( def _format_datasource_info_list(
self, self,
datasource_type: str, datasource_type: DatasourceProviderType,
datasource_info_list: list[Mapping[str, Any]], datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
@@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator):
""" """
Format datasource info list. Format datasource info list.
""" """
if datasource_type == "online_drive": if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
all_files: list[Mapping[str, Any]] = [] all_files: list[Mapping[str, Any]] = []
datasource_node_data = None datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", []) datasource_nodes = workflow.graph_dict.get("nodes", [])
+5 -5
View File
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel): class AnnotationReplyAccount(BaseModel):
@@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str id: str
workflow_id: str workflow_id: str
status: str status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None outputs: Mapping[str, Any] | None = None
error: str | None = None error: str | None = None
elapsed_time: float elapsed_time: float
@@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse):
process_data_truncated: bool = False process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True outputs_truncated: bool = True
status: str status: WorkflowNodeExecutionStatus
error: str | None = None error: str | None = None
elapsed_time: float elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse):
process_data_truncated: bool = False process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False outputs_truncated: bool = False
status: str status: WorkflowNodeExecutionStatus
error: str | None = None error: str | None = None
elapsed_time: float elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str id: str
workflow_id: str workflow_id: str
status: str status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None outputs: Mapping[str, Any] | None = None
error: str | None = None error: str | None = None
elapsed_time: float elapsed_time: float
+1
View File
@@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel): class PreviewDetail(BaseModel):
content: str content: str
summary: str | None = None
child_chunks: list[str] | None = None child_chunks: list[str] | None = None
@@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC):
@classmethod @classmethod
def get_default_config(cls) -> DefaultConfig: def get_default_config(cls) -> DefaultConfig:
return { variables: list[VariableConfig] = [
"type": "code", {"variable": "arg1", "value_selector": []},
"config": { {"variable": "arg2", "value_selector": []},
"variables": [ ]
{"variable": "arg1", "value_selector": []}, outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}}
{"variable": "arg2", "value_selector": []},
], config: CodeConfig = {
"code_language": cls.get_language(), "variables": variables,
"code": cls.get_default_code(), "code_language": cls.get_language(),
"outputs": {"result": {"type": "string", "children": None}}, "code": cls.get_default_code(),
}, "outputs": outputs,
} }
return {"type": "code", "config": config}
+68 -57
View File
@@ -311,14 +311,18 @@ class IndexingRunner:
qa_preview_texts: list[QAPreviewDetail] = [] qa_preview_texts: list[QAPreviewDetail] = []
total_segments = 0 total_segments = 0
# doc_form represents the segmentation method (general, parent-child, QA)
index_type = doc_form index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
# one extract_setting is one source document
for extract_setting in extract_settings: for extract_setting in extract_settings:
# extract # extract
processing_rule = DatasetProcessRule( processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
) )
# Extract document content
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
# Cleaning and segmentation
documents = index_processor.transform( documents = index_processor.transform(
text_docs, text_docs,
current_user=None, current_user=None,
@@ -361,75 +365,82 @@ class IndexingRunner:
if doc_form and doc_form == "qa_model": if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[]) return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
# Generate summary preview
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
preview_texts = index_processor.generate_summary_preview(
tenant_id, preview_texts, summary_index_setting, doc_language
)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract( def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
) -> list[Document]: ) -> list[Document]:
# load file
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
return []
data_source_info = dataset_document.data_source_info_dict data_source_info = dataset_document.data_source_info_dict
text_docs = [] text_docs = []
if dataset_document.data_source_type == "upload_file": match dataset_document.data_source_type:
if not data_source_info or "upload_file_id" not in data_source_info: case "upload_file":
raise ValueError("no upload file found") if not data_source_info or "upload_file_id" not in data_source_info:
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) raise ValueError("no upload file found")
file_detail = db.session.scalars(stmt).one_or_none() stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail: if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
or "notion_page_id" not in data_source_info
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, datasource_type=DatasourceType.NOTION,
upload_file=file_detail, notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form, document_model=dataset_document.doc_form,
) )
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import": case "website_crawl":
if ( if (
not data_source_info not data_source_info
or "notion_workspace_id" not in data_source_info or "provider" not in data_source_info
or "notion_page_id" not in data_source_info or "url" not in data_source_info
): or "job_id" not in data_source_info
raise ValueError("no notion import info found") ):
extract_setting = ExtractSetting( raise ValueError("no website import info found")
datasource_type=DatasourceType.NOTION, extract_setting = ExtractSetting(
notion_info=NotionInfo.model_validate( datasource_type=DatasourceType.WEBSITE,
{ website_info=WebsiteInfo.model_validate(
"credential_id": data_source_info.get("credential_id"), {
"notion_workspace_id": data_source_info["notion_workspace_id"], "provider": data_source_info["provider"],
"notion_obj_id": data_source_info["notion_page_id"], "job_id": data_source_info["job_id"],
"notion_page_type": data_source_info["type"], "tenant_id": dataset_document.tenant_id,
"document": dataset_document, "url": data_source_info["url"],
"tenant_id": dataset_document.tenant_id, "mode": data_source_info["mode"],
} "only_main_content": data_source_info["only_main_content"],
), }
document_model=dataset_document.doc_form, ),
) document_model=dataset_document.doc_form,
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) )
elif dataset_document.data_source_type == "website_crawl": text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
if ( case _:
not data_source_info return []
or "provider" not in data_source_info
or "url" not in data_source_info
or "job_id" not in data_source_info
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
# update document status to splitting # update document status to splitting
self._update_document_index_status( self._update_document_index_status(
document_id=dataset_document.id, document_id=dataset_document.id,
+20
View File
@@ -0,0 +1,20 @@
"""Shared payload models for LLM generator helpers and controllers."""
from pydantic import BaseModel, Field
from core.app.app_config.entities import ModelConfig
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
+46 -37
View File
@@ -6,6 +6,8 @@ from typing import Protocol, cast
import json_repair import json_repair
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import ( from core.llm_generator.prompts import (
@@ -151,19 +153,19 @@ class LLMGenerator:
return questions return questions
@classmethod @classmethod
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
output_parser = RuleConfigGeneratorOutputParser() output_parser = RuleConfigGeneratorOutputParser()
error = "" error = ""
error_step = "" error_step = ""
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
model_parameters = model_config.get("completion_params", {}) model_parameters = args.model_config_data.completion_params
if no_variable: if args.no_variable:
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
prompt_generate = prompt_template.format( prompt_generate = prompt_template.format(
inputs={ inputs={
"TASK_DESCRIPTION": instruction, "TASK_DESCRIPTION": args.instruction,
}, },
remove_template_variables=False, remove_template_variables=False,
) )
@@ -175,8 +177,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance( model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=ModelType.LLM, model_type=ModelType.LLM,
provider=model_config.get("provider", ""), provider=args.model_config_data.provider,
model=model_config.get("name", ""), model=args.model_config_data.name,
) )
try: try:
@@ -190,7 +192,7 @@ class LLMGenerator:
error = str(e) error = str(e)
error_step = "generate rule config" error_step = "generate rule config"
except Exception as e: except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e) rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -209,7 +211,7 @@ class LLMGenerator:
# format the prompt_generate_prompt # format the prompt_generate_prompt
prompt_generate_prompt = prompt_template.format( prompt_generate_prompt = prompt_template.format(
inputs={ inputs={
"TASK_DESCRIPTION": instruction, "TASK_DESCRIPTION": args.instruction,
}, },
remove_template_variables=False, remove_template_variables=False,
) )
@@ -220,8 +222,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance( model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=ModelType.LLM, model_type=ModelType.LLM,
provider=model_config.get("provider", ""), provider=args.model_config_data.provider,
model=model_config.get("name", ""), model=args.model_config_data.name,
) )
try: try:
@@ -250,7 +252,7 @@ class LLMGenerator:
# the second step to generate the task_parameter and task_statement # the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format( statement_generate_prompt = statement_template.format(
inputs={ inputs={
"TASK_DESCRIPTION": instruction, "TASK_DESCRIPTION": args.instruction,
"INPUT_TEXT": prompt_content.message.get_text_content(), "INPUT_TEXT": prompt_content.message.get_text_content(),
}, },
remove_template_variables=False, remove_template_variables=False,
@@ -276,7 +278,7 @@ class LLMGenerator:
error_step = "generate conversation opener" error_step = "generate conversation opener"
except Exception as e: except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e) rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -284,16 +286,20 @@ class LLMGenerator:
return rule_config return rule_config
@classmethod @classmethod
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): def generate_code(
if code_language == "python": cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
):
if args.code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else: else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format( prompt = prompt_template.format(
inputs={ inputs={
"INSTRUCTION": instruction, "INSTRUCTION": args.instruction,
"CODE_LANGUAGE": code_language, "CODE_LANGUAGE": args.code_language,
}, },
remove_template_variables=False, remove_template_variables=False,
) )
@@ -302,28 +308,28 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance( model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=ModelType.LLM, model_type=ModelType.LLM,
provider=model_config.get("provider", ""), provider=args.model_config_data.provider,
model=model_config.get("name", ""), model=args.model_config_data.name,
) )
prompt_messages = [UserPromptMessage(content=prompt)] prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = model_config.get("completion_params", {}) model_parameters = args.model_config_data.completion_params
try: try:
response: LLMResult = model_instance.invoke_llm( response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
) )
generated_code = response.message.get_text_content() generated_code = response.message.get_text_content()
return {"code": generated_code, "language": code_language, "error": ""} return {"code": generated_code, "language": args.code_language, "error": ""}
except InvokeError as e: except InvokeError as e:
error = str(e) error = str(e)
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
) )
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod @classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str): def generate_qa_document(cls, tenant_id: str, query, document_language: str):
@@ -353,20 +359,20 @@ class LLMGenerator:
return answer.strip() return answer.strip()
@classmethod @classmethod
def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager() model_manager = ModelManager()
model_instance = model_manager.get_model_instance( model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=ModelType.LLM, model_type=ModelType.LLM,
provider=model_config.get("provider", ""), provider=args.model_config_data.provider,
model=model_config.get("name", ""), model=args.model_config_data.name,
) )
prompt_messages = [ prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=instruction), UserPromptMessage(content=args.instruction),
] ]
model_parameters = model_config.get("model_parameters", {}) model_parameters = args.model_config_data.completion_params
try: try:
response: LLMResult = model_instance.invoke_llm( response: LLMResult = model_instance.invoke_llm(
@@ -390,12 +396,17 @@ class LLMGenerator:
error = str(e) error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e: except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@staticmethod @staticmethod
def instruction_modify_legacy( def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None tenant_id: str,
flow_id: str,
current: str,
instruction: str,
model_config: ModelConfig,
ideal_output: str | None,
): ):
last_run: Message | None = ( last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
@@ -434,7 +445,7 @@ class LLMGenerator:
node_id: str, node_id: str,
current: str, current: str,
instruction: str, instruction: str,
model_config: dict, model_config: ModelConfig,
ideal_output: str | None, ideal_output: str | None,
workflow_service: WorkflowServiceInterface, workflow_service: WorkflowServiceInterface,
): ):
@@ -505,7 +516,7 @@ class LLMGenerator:
@staticmethod @staticmethod
def __instruction_modify_common( def __instruction_modify_common(
tenant_id: str, tenant_id: str,
model_config: dict, model_config: ModelConfig,
last_run: dict | None, last_run: dict | None,
current: str | None, current: str | None,
error_message: str | None, error_message: str | None,
@@ -526,8 +537,8 @@ class LLMGenerator:
model_instance = ModelManager().get_model_instance( model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=ModelType.LLM, model_type=ModelType.LLM,
provider=model_config.get("provider", ""), provider=model_config.provider,
model=model_config.get("name", ""), model=model_config.name,
) )
match node_type: match node_type:
case "llm" | "agent": case "llm" | "agent":
@@ -570,7 +581,5 @@ class LLMGenerator:
error = str(e) error = str(e)
return {"error": f"Failed to generate code. Error: {error}"} return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e: except Exception as e:
logger.exception( logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
return {"error": f"An unexpected error occurred: {str(e)}"} return {"error": f"An unexpected error occurred: {str(e)}"}
+19
View File
@@ -434,3 +434,22 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT.""" You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
DEFAULT_GENERATOR_SUMMARY_PROMPT = (
"""Summarize the following content. Extract only the key information and main points. """
"""Remove redundant details.
Requirements:
1. Write a concise summary in plain text
2. You must write in {language}. No language other than {language} should be used.
3. Focus on important facts, concepts, and details
4. If images are included, describe their key information
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
6. Write directly without extra words
7. If there is not enough content to generate a meaningful summary,
return an empty string without any explanation or prompt
Output only the summary text. Start summarizing now:
"""
)
+1 -1
View File
@@ -347,7 +347,7 @@ class BaseSession(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
responder = RequestResponder( responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id, request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None, request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request, request=validated_request,
+1 -1
View File
@@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.MAX_TOKENS: { DefaultParameterName.MAX_TOKENS: {
"label": { "label": {
"en_US": "Max Tokens", "en_US": "Max Tokens",
"zh_Hans": "最大标记", "zh_Hans": "最大 Token 数",
}, },
"type": "int", "type": "int",
"help": { "help": {
@@ -1,10 +1,11 @@
import decimal import decimal
import hashlib import hashlib
from threading import Lock import logging
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, ValidationError
from redis import RedisError
import contexts from configs import dify_config
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
@@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError, InvokeServerUnavailableError,
) )
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class AIModel(BaseModel): class AIModel(BaseModel):
@@ -144,34 +148,60 @@ class AIModel(BaseModel):
plugin_model_manager = PluginModelClient() plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else [] sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try: try:
contexts.plugin_model_schemas.get() cached_schema_json = redis_client.get(cache_key)
except LookupError: except (RedisError, RuntimeError) as exc:
contexts.plugin_model_schemas.set({}) logger.warning(
contexts.plugin_model_schema_lock.set(Lock()) "Failed to read plugin model schema cache for model %s: %s",
model,
with contexts.plugin_model_schema_lock.get(): str(exc),
if cache_key in contexts.plugin_model_schemas.get(): exc_info=True,
return contexts.plugin_model_schemas.get()[cache_key]
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
) )
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if schema: schema = plugin_model_manager.get_model_schema(
contexts.plugin_model_schemas.get()[cache_key] = schema tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
return schema if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
""" """
@@ -92,6 +92,10 @@ def _build_llm_result_from_first_chunk(
Build a single `LLMResult` from the first returned chunk. Build a single `LLMResult` from the first returned chunk.
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
Note:
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
""" """
content = "" content = ""
content_list: list[PromptMessageContentUnionTypes] = [] content_list: list[PromptMessageContentUnionTypes] = []
@@ -99,18 +103,25 @@ def _build_llm_result_from_first_chunk(
system_fingerprint: str | None = None system_fingerprint: str | None = None
tools_calls: list[AssistantPromptMessage.ToolCall] = [] tools_calls: list[AssistantPromptMessage.ToolCall] = []
first_chunk = next(chunks, None) try:
if first_chunk is not None: first_chunk = next(chunks, None)
if isinstance(first_chunk.delta.message.content, str): if first_chunk is not None:
content += first_chunk.delta.message.content if isinstance(first_chunk.delta.message.content, str):
elif isinstance(first_chunk.delta.message.content, list): content += first_chunk.delta.message.content
content_list.extend(first_chunk.delta.message.content) elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
if first_chunk.delta.message.tool_calls: if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
usage = first_chunk.delta.usage or LLMUsage.empty_usage() usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint system_fingerprint = first_chunk.system_fingerprint
finally:
try:
for _ in chunks:
pass
except Exception:
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
return LLMResult( return LLMResult(
model=model, model=model,
@@ -283,7 +294,7 @@ class LargeLanguageModel(AIModel):
# TODO # TODO
raise self._transform_invoke_error(e) raise self._transform_invoke_error(e)
if stream and isinstance(result, Generator): if stream and not isinstance(result, LLMResult):
return self._invoke_result_generator( return self._invoke_result_generator(
model=model, model=model,
result=result, result=result,
@@ -5,7 +5,11 @@ import logging
from collections.abc import Sequence from collections.abc import Sequence
from threading import Lock from threading import Lock
from pydantic import ValidationError
from redis import RedisError
import contexts import contexts
from configs import dify_config
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -175,34 +180,60 @@ class ModelProviderFactory:
""" """
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else [] sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try: try:
contexts.plugin_model_schemas.get() cached_schema_json = redis_client.get(cache_key)
except LookupError: except (RedisError, RuntimeError) as exc:
contexts.plugin_model_schemas.set({}) logger.warning(
contexts.plugin_model_schema_lock.set(Lock()) "Failed to read plugin model schema cache for model %s: %s",
model,
with contexts.plugin_model_schema_lock.get(): str(exc),
if cache_key in contexts.plugin_model_schemas.get(): exc_info=True,
return contexts.plugin_model_schemas.get()[cache_key]
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
) )
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if schema: schema = self.plugin_model_manager.get_model_schema(
contexts.plugin_model_schemas.get()[cache_key] = schema tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)
return schema if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_models( def get_models(
self, self,
@@ -283,6 +314,8 @@ class ModelProviderFactory:
elif model_type == ModelType.TTS: elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params) return TTSModel.model_validate(init_params)
raise ValueError(f"Unsupported model type: {model_type}")
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
""" """
Get provider icon Get provider icon
+125 -29
View File
@@ -23,7 +23,13 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file from core.tools.signature import sign_upload_file
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import (
ChildChunk,
Dataset,
DocumentSegment,
DocumentSegmentSummary,
SegmentAttachmentBinding,
)
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import UploadFile from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
@@ -378,15 +384,15 @@ class RetrievalService:
.all() .all()
} }
records = []
include_segment_ids = set()
segment_child_map = {}
valid_dataset_documents = {} valid_dataset_documents = {}
image_doc_ids: list[Any] = [] image_doc_ids: list[Any] = []
child_index_node_ids = [] child_index_node_ids = []
index_node_ids = [] index_node_ids = []
doc_to_document_map = {} doc_to_document_map = {}
summary_segment_ids = set() # Track segments retrieved via summary
summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score
# First pass: collect all document IDs and identify summary documents
for document in documents: for document in documents:
document_id = document.metadata.get("document_id") document_id = document.metadata.get("document_id")
if document_id not in dataset_documents: if document_id not in dataset_documents:
@@ -397,16 +403,39 @@ class RetrievalService:
continue continue
valid_dataset_documents[document_id] = dataset_document valid_dataset_documents[document_id] = dataset_document
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
# Check if this is a summary document
is_summary = document.metadata.get("is_summary", False)
if is_summary:
# For summary documents, find the original chunk via original_chunk_id
original_chunk_id = document.metadata.get("original_chunk_id")
if original_chunk_id:
summary_segment_ids.add(original_chunk_id)
# Save summary's score for later use
summary_score = document.metadata.get("score")
if summary_score is not None:
try:
summary_score_float = float(summary_score)
# If the same segment has multiple summary hits, take the highest score
if original_chunk_id not in summary_score_map:
summary_score_map[original_chunk_id] = summary_score_float
else:
summary_score_map[original_chunk_id] = max(
summary_score_map[original_chunk_id], summary_score_float
)
except (ValueError, TypeError):
# Skip invalid score values
pass
continue # Skip adding to other lists for summary documents
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE: if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id) image_doc_ids.append(doc_id)
else: else:
child_index_node_ids.append(doc_id) child_index_node_ids.append(doc_id)
else: else:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE: if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id) image_doc_ids.append(doc_id)
else: else:
@@ -419,9 +448,10 @@ class RetrievalService:
segment_ids = [] segment_ids = []
index_node_segments: list[DocumentSegment] = [] index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = []
attachment_map = {} attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[Any, Any] = {} child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map = {} doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
with session_factory.create_session() as session: with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session) attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
@@ -436,6 +466,7 @@ class RetrievalService:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else: else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all() child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
@@ -459,6 +490,7 @@ class RetrievalService:
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments: for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids: if segment_ids:
document_segment_stmt = select(DocumentSegment).where( document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
@@ -470,6 +502,40 @@ class RetrievalService:
if index_node_segments: if index_node_segments:
segments.extend(index_node_segments) segments.extend(index_node_segments)
# Handle summary documents: query segments by original_chunk_id
if summary_segment_ids:
summary_segment_ids_list = list(summary_segment_ids)
summary_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(summary_segment_ids_list),
)
summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore
segments.extend(summary_segments)
# Add summary segment IDs to segment_ids for summary query
for seg in summary_segments:
if seg.id not in segment_ids:
segment_ids.append(seg.id)
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
)
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
for segment in segments: for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
@@ -478,45 +544,68 @@ class RetrievalService:
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids: if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id) include_segment_ids.add(segment.id)
# Check if this segment was retrieved via summary
# Use summary score as base score if available, otherwise 0.0
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos: if child_chunks or attachment_infos:
child_chunk_details = [] child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks: for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id] child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = { child_chunk_detail = {
"id": child_chunk.id, "id": child_chunk.id,
"content": child_chunk.content, "content": child_chunk.content,
"position": child_chunk.position, "position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0, "score": child_score,
} }
child_chunk_details.append(child_chunk_detail) child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) max_score = max(max_score, child_score)
for attachment_info in attachment_infos: for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]] file_document = doc_to_document_map.get(attachment_info["id"])
max_score = max( if file_document:
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 max_score = max(max_score, file_document.metadata.get("score", 0.0))
)
map_detail = { map_detail = {
"max_score": max_score, "max_score": max_score,
"child_chunks": child_chunk_details, "child_chunks": child_chunk_details,
} }
segment_child_map[segment.id] = map_detail segment_child_map[segment.id] = map_detail
record = { else:
# No child chunks or attachments, use summary score if available
summary_score = summary_score_map.get(segment.id)
if summary_score is not None:
segment_child_map[segment.id] = {
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
"segment": segment, "segment": segment,
} }
records.append(record) records.append(record)
else: else:
if segment.id not in include_segment_ids: if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id) include_segment_ids.add(segment.id)
max_score = 0.0
document = doc_to_document_map.get(segment.index_node_id) # Check if this segment was retrieved via summary
if document: # Use summary score if available (summary retrieval takes priority)
max_score = max(max_score, document.metadata.get("score", 0.0)) max_score = summary_score_map.get(segment.id, 0.0)
# If not retrieved via summary, use original segment's score
if segment.id not in summary_score_map:
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
# Also consider attachment scores
for attachment_info in attachment_infos: for attachment_info in attachment_infos:
file_document = doc_to_document_map.get(attachment_info["id"]) file_doc = doc_to_document_map.get(attachment_info["id"])
if file_document: if file_doc:
max_score = max(max_score, file_document.metadata.get("score", 0.0)) max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = { record = {
"segment": segment, "segment": segment,
"score": max_score, "score": max_score,
@@ -557,9 +646,16 @@ class RetrievalService:
else None else None
) )
# Extract summary if this segment was retrieved via summary
summary_content = segment_summary_map.get(segment.id)
# Create RetrievalSegments object # Create RetrievalSegments object
retrieval_segment = RetrievalSegments( retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files segment=segment,
child_chunks=child_chunks,
score=score,
files=files,
summary=summary_content,
) )
result.append(retrieval_segment) result.append(retrieval_segment)
@@ -391,46 +391,78 @@ class QdrantVector(BaseVector):
return docs return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25. """Return docs most similar by full-text search.
Searches each keyword separately and merges results to ensure documents
matching ANY keyword are returned (OR logic). Results are capped at top_k.
Args:
query: Search query text. Multi-word queries are split into keywords,
with each keyword searched separately. Limited to 10 keywords.
**kwargs: Additional search parameters (top_k, document_ids_filter)
Returns: Returns:
List of documents most similar to the query text and distance for each. List of up to top_k unique documents matching any query keyword.
""" """
from qdrant_client.http import models from qdrant_client.http import models
scroll_filter = models.Filter( # Build base must conditions (AND logic) for metadata filters
must=[ base_must_conditions: list = [
models.FieldCondition( models.FieldCondition(
key="group_id", key="group_id",
match=models.MatchValue(value=self._group_id), match=models.MatchValue(value=self._group_id),
), ),
models.FieldCondition( ]
key="page_content",
match=models.MatchText(text=query),
),
]
)
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter: if document_ids_filter:
if scroll_filter.must: base_must_conditions.append(
scroll_filter.must.append( models.FieldCondition(
models.FieldCondition( key="metadata.document_id",
key="metadata.document_id", match=models.MatchAny(any=document_ids_filter),
match=models.MatchAny(any=document_ids_filter),
)
) )
response = self._client.scroll( )
collection_name=self._collection_name,
scroll_filter=scroll_filter, # Split query into keywords, deduplicate and limit to prevent DoS
limit=kwargs.get("top_k", 2), keywords = list(dict.fromkeys(kw.strip() for kw in query.strip().split() if kw.strip()))[:10]
with_payload=True,
with_vectors=True, if not keywords:
) return []
results = response[0]
documents = [] top_k = kwargs.get("top_k", 2)
for result in results: seen_ids: set[str | int] = set()
if result: documents: list[Document] = []
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document) # Search each keyword separately and merge results.
# This ensures each keyword gets its own search, preventing one keyword's
# results from completely overshadowing another's due to scroll ordering.
for keyword in keywords:
scroll_filter = models.Filter(
must=[
*base_must_conditions,
models.FieldCondition(
key="page_content",
match=models.MatchText(text=keyword),
),
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=top_k,
with_payload=True,
with_vectors=True,
)
results = response[0]
for result in results:
if result and result.id not in seen_ids:
seen_ids.add(result.id)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
if len(documents) >= top_k:
return documents
return documents return documents
+1
View File
@@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None score: float | None = None
files: list[dict[str, str | int]] | None = None files: list[dict[str, str | int]] | None = None
summary: str | None = None # Summary content if retrieved via summary index
@@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel):
doc_metadata: dict[str, Any] | None = None doc_metadata: dict[str, Any] | None = None
title: str | None = None title: str | None = None
files: list[dict[str, Any]] | None = None files: list[dict[str, Any]] | None = None
summary: str | None = None
+6 -3
View File
@@ -1,4 +1,7 @@
"""Abstract interface for document loader implementations.""" """Word (.docx) document extractor used for RAG ingestion.
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
import logging import logging
import mimetypes import mimetypes
@@ -8,7 +11,6 @@ import tempfile
import uuid import uuid
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx
from docx import Document as DocxDocument from docx import Document as DocxDocument
from docx.oxml.ns import qn from docx.oxml.ns import qn
from docx.text.run import Run from docx.text.run import Run
@@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor):
# If the file is a web path, download it to a temporary file, and use that # If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
response = httpx.get(self.file_path, timeout=None) response = ssrf_proxy.get(self.file_path)
if response.status_code != 200: if response.status_code != 200:
response.close() response.close()
@@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor):
self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115 self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115
try: try:
self.temp_file.write(response.content) self.temp_file.write(response.content)
self.temp_file.flush()
finally: finally:
response.close() response.close()
self.file_path = self.temp_file.name self.file_path = self.temp_file.name
@@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx import httpx
from configs import dify_config from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
@@ -45,6 +46,27 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
Args:
tenant_id: Tenant ID
preview_texts: List of preview details to generate summaries for
summary_index_setting: Summary index configuration
doc_language: Optional document language to ensure summary is generated in the correct language
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def load( def load(
self, self,
@@ -1,9 +1,27 @@
"""Paragraph index processor.""" """Paragraph index processor."""
import logging
import re
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any, cast
logger = logging.getLogger(__name__)
from core.entities.knowledge_entities import PreviewDetail
from core.file import File, FileTransferMethod, FileType, file_manager
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
@@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper from libs import helper
from models import UploadFile
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetProcessRule from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
class ParagraphIndexProcessor(BaseIndexProcessor): class ParagraphIndexProcessor(BaseIndexProcessor):
@@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents) keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
if node_ids: if node_ids:
@@ -227,3 +273,347 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
} }
else: else:
raise ValueError("Chunks is not a list") raise ValueError("Chunks is not a list")
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
"""
import concurrent.futures
from flask import current_app
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item."""
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [executor.submit(process, preview) for preview in preview_texts]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)
return preview_texts
@staticmethod
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
document_language: str | None = None,
) -> tuple[str, LLMUsage]:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
and supports vision models by including images from the segment attachments or text content.
Args:
tenant_id: Tenant ID
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
document_language: Optional document language (e.g., "Chinese", "English")
to ensure summary is generated in the correct language
Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
"""
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
model_name = summary_index_setting.get("model_name")
model_provider_name = summary_index_setting.get("model_provider_name")
summary_prompt = summary_index_setting.get("summary_prompt")
if not model_name or not model_provider_name:
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
# Import default summary prompt
is_default_prompt = False
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
is_default_prompt = True
# Format prompt with document language only for default prompt
# Custom prompts are used as-is to avoid interfering with user-defined templates
# If document_language is provided, use it; otherwise, use "the same language as the input content"
# This is especially important for image-only chunks where text is empty or minimal
if is_default_prompt:
language_for_prompt = document_language or "the same language as the input content"
try:
summary_prompt = summary_prompt.format(language=language_for_prompt)
except KeyError:
# If default prompt doesn't have {language} placeholder, use it as-is
pass
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)
model_instance = ModelInstance(provider_model_bundle, model_name)
# Get model schema to check if vision is supported
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features
# Extract images if model supports vision
image_files = []
if supports_vision:
# First, try to get images from SegmentAttachmentBinding (preferred method)
if segment_id:
image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)
# If no images from attachments, fall back to extracting from text
if not image_files:
image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)
# Build prompt messages
prompt_messages = []
if image_files:
# If we have images, create a UserPromptMessage with both text and images
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
# Add images first
for file in image_files:
try:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
)
prompt_message_contents.append(file_content)
except Exception as e:
logger.warning("Failed to convert image file to prompt message content: %s", str(e))
continue
# Add text content
if prompt_message_contents: # Only add text if we successfully added images
prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
# If image conversion failed, fall back to text-only
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
else:
# No images, use simple text prompt
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
result = model_instance.invoke_llm(
prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False
)
# Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator
if not isinstance(result, LLMResult):
raise ValueError("Expected LLMResult when stream=False")
summary_content = getattr(result.message, "content", "")
usage = result.usage
# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
return summary_content, usage
@staticmethod
def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
"""
Extract images from markdown text and convert them to File objects.
Args:
tenant_id: Tenant ID
text: Text content that may contain markdown image links
Returns:
List of File objects representing images found in the text
"""
# Extract markdown images using regex pattern
pattern = r"!\[.*?\]\((.*?)\)"
images = re.findall(pattern, text)
if not images:
return []
upload_file_id_list = []
for image in images:
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
# Tool files are handled differently, skip for now
continue
if not upload_file_id_list:
return []
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = (
db.session.query(UploadFile)
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
.all()
)
# Create File objects from UploadFile records
file_objects = []
for upload_file in upload_files:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue
mapping = {
"upload_file_id": upload_file.id,
"transfer_method": FileTransferMethod.LOCAL_FILE.value,
"type": FileType.IMAGE.value,
}
try:
file_obj = build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue
return file_objects
@staticmethod
def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
"""
Extract images from SegmentAttachmentBinding table (preferred method).
This matches how DatasetRetrieval gets segment attachments.
Args:
tenant_id: Tenant ID
segment_id: Segment ID to fetch attachments for
Returns:
List of File objects representing images found in segment attachments
"""
from sqlalchemy import select
# Query attachments from SegmentAttachmentBinding table
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment_id,
SegmentAttachmentBinding.tenant_id == tenant_id,
)
).all()
if not attachments_with_bindings:
return []
file_objects = []
for _, upload_file in attachments_with_bindings:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue
return file_objects
@@ -1,11 +1,14 @@
"""Paragraph index processor.""" """Paragraph index processor."""
import json import json
import logging
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from configs import dify_config from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
@@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class ParentChildIndexProcessor(BaseIndexProcessor): class ParentChildIndexProcessor(BaseIndexProcessor):
@@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids # node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
@@ -326,3 +356,97 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
"preview": preview, "preview": preview,
"total_segments": len(parent_childs.parent_child_chunks), "total_segments": len(parent_childs.parent_child_chunks),
} }
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
Note: For parent-child structure, we only generate summaries for parent chunks.
"""
import concurrent.futures
from flask import current_app
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item (parent chunk)."""
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [executor.submit(process, preview) for preview in preview_texts]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)
return preview_texts
@@ -11,6 +11,8 @@ import pandas as pd
from flask import Flask, current_app from flask import Flask, current_app
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
@@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper from libs import helper
from models.account import Account from models.account import Account
from models.dataset import Dataset from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor):
vector.create_multimodal(multimodal_documents) vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Note: qa_model doesn't generate summaries, but we clean them for completeness
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
vector = Vector(dataset) vector = Vector(dataset)
if node_ids: if node_ids:
vector.delete_by_ids(node_ids) vector.delete_by_ids(node_ids)
@@ -212,6 +240,21 @@ class QAIndexProcessor(BaseIndexProcessor):
"total_segments": len(qa_chunks.qa_chunks), "total_segments": len(qa_chunks.qa_chunks),
} }
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.
Note: QA model uses question-answer pairs, which don't require summary generation.
"""
# QA model doesn't generate summaries, return as-is
return preview_texts
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = [] format_documents = []
if document_node.page_content is None or not document_node.page_content.strip(): if document_node.page_content is None or not document_node.page_content.strip():
+18 -11
View File
@@ -236,20 +236,24 @@ class DatasetRetrieval:
if records: if records:
for record in records: for record in records:
segment = record.segment segment = record.segment
# Build content: if summary exists, add it before the segment content
if segment.answer: if segment.answer:
document_context_list.append( segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
else: else:
document_context_list.append( segment_content = segment.get_sign_content()
DocumentContext(
content=segment.get_sign_content(), # If summary exists, prepend it to the content
score=record.score, if record.summary:
) final_content = f"{record.summary}\n{segment_content}"
else:
final_content = segment_content
document_context_list.append(
DocumentContext(
content=final_content,
score=record.score,
) )
)
if vision_enabled: if vision_enabled:
attachments_with_bindings = db.session.execute( attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile) select(SegmentAttachmentBinding, UploadFile)
@@ -316,6 +320,9 @@ class DatasetRetrieval:
source.content = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source.content = segment.content source.content = segment.content
# Add summary if this segment was retrieved via summary
if hasattr(record, "summary") and record.summary:
source.summary = record.summary
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list: if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
+1
View File
@@ -35,6 +35,7 @@ class SchemaRegistry:
registry.load_all_versions() registry.load_all_versions()
cls._default_instance = registry cls._default_instance = registry
return cls._default_instance
return cls._default_instance return cls._default_instance
+16 -22
View File
@@ -226,16 +226,13 @@ class ToolManager:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
return cast( return builtin_tool.fork_tool_runtime(
BuiltinTool, runtime=ToolRuntime(
builtin_tool.fork_tool_runtime( tenant_id=tenant_id,
runtime=ToolRuntime( credentials={},
tenant_id=tenant_id, invoke_from=invoke_from,
credentials={}, tool_invoke_from=tool_invoke_from,
invoke_from=invoke_from, )
tool_invoke_from=tool_invoke_from,
)
),
) )
builtin_provider = None builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController): if isinstance(provider_controller, PluginToolProviderController):
@@ -337,18 +334,15 @@ class ToolManager:
decrypted_credentials = refreshed_credentials.credentials decrypted_credentials = refreshed_credentials.credentials
cache.delete() cache.delete()
return cast( return builtin_tool.fork_tool_runtime(
BuiltinTool, runtime=ToolRuntime(
builtin_tool.fork_tool_runtime( tenant_id=tenant_id,
runtime=ToolRuntime( credentials=dict(decrypted_credentials),
tenant_id=tenant_id, credential_type=CredentialType.of(builtin_provider.credential_type),
credentials=dict(decrypted_credentials), runtime_parameters={},
credential_type=CredentialType.of(builtin_provider.credential_type), invoke_from=invoke_from,
runtime_parameters={}, tool_invoke_from=tool_invoke_from,
invoke_from=invoke_from, )
tool_invoke_from=tool_invoke_from,
)
),
) )
elif provider_type == ToolProviderType.API: elif provider_type == ToolProviderType.API:
@@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if records: if records:
for record in records: for record in records:
segment = record.segment segment = record.segment
# Build content: if summary exists, add it before the segment content
if segment.answer: if segment.answer:
document_context_list.append( segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
else: else:
document_context_list.append( segment_content = segment.get_sign_content()
DocumentContext(
content=segment.get_sign_content(), # If summary exists, prepend it to the content
score=record.score, if record.summary:
) final_content = f"{record.summary}\n{segment_content}"
else:
final_content = segment_content
document_context_list.append(
DocumentContext(
content=final_content,
score=record.score,
) )
)
if self.return_resource: if self.return_resource:
for record in records: for record in records:
@@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
source.content = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source.content = segment.content source.content = segment.content
# Add summary if this segment was retrieved via summary
if hasattr(record, "summary") and record.summary:
source.summary = record.summary
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list: if self.return_resource and retrieval_resource_list:
@@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils: class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod @classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
""" """
+5 -5
View File
@@ -23,8 +23,8 @@ class TriggerDebugEventBus:
""" """
# LUA_SELECT: Atomic poll or register for event # LUA_SELECT: Atomic poll or register for event
# KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id} # KEYS[1] = trigger_debug_inbox:{<tenant_id>}:<address_id>
# KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:... # KEYS[2] = trigger_debug_waiting_pool:{<tenant_id>}:...
# ARGV[1] = address_id # ARGV[1] = address_id
LUA_SELECT = ( LUA_SELECT = (
"local v=redis.call('GET',KEYS[1]);" "local v=redis.call('GET',KEYS[1]);"
@@ -35,7 +35,7 @@ class TriggerDebugEventBus:
) )
# LUA_DISPATCH: Dispatch event to all waiting addresses # LUA_DISPATCH: Dispatch event to all waiting addresses
# KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:... # KEYS[1] = trigger_debug_waiting_pool:{<tenant_id>}:...
# ARGV[1] = tenant_id # ARGV[1] = tenant_id
# ARGV[2] = event_json # ARGV[2] = event_json
LUA_DISPATCH = ( LUA_DISPATCH = (
@@ -43,7 +43,7 @@ class TriggerDebugEventBus:
"if #a==0 then return 0 end;" "if #a==0 then return 0 end;"
"redis.call('DEL',KEYS[1]);" "redis.call('DEL',KEYS[1]);"
"for i=1,#a do " "for i=1,#a do "
f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});" f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
"end;" "end;"
"return #a" "return #a"
) )
@@ -108,7 +108,7 @@ class TriggerDebugEventBus:
Event object if available, None otherwise Event object if available, None otherwise
""" """
address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest() address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}" address: str = f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}"
try: try:
event_data = redis_client.eval( event_data = redis_client.eval(
+2 -2
View File
@@ -42,7 +42,7 @@ def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str:
app_id: App ID app_id: App ID
node_id: Node ID node_id: Node ID
""" """
return f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}" return f"{TriggerDebugPoolKey.WEBHOOK}:{{{tenant_id}}}:{app_id}:{node_id}"
class PluginTriggerDebugEvent(BaseDebugEvent): class PluginTriggerDebugEvent(BaseDebugEvent):
@@ -64,4 +64,4 @@ def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str
provider_id: Provider ID provider_id: Provider ID
subscription_id: Subscription ID subscription_id: Subscription ID
""" """
return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}" return f"{TriggerDebugPoolKey.PLUGIN}:{{{tenant_id}}}:{str(provider_id)}:{subscription_id}:{name}"
+15 -16
View File
@@ -5,15 +5,20 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final from typing import Protocol, cast, final
from pydantic import TypeAdapter
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict from libs.typing import is_str
from .edge import Edge from .edge import Edge
from .validation import get_graph_validator from .validation import get_graph_validator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
class NodeFactory(Protocol): class NodeFactory(Protocol):
""" """
@@ -23,7 +28,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety. allowing for different node creation strategies while maintaining type safety.
""" """
def create_node(self, node_config: dict[str, object]) -> Node: def create_node(self, node_config: NodeConfigDict) -> Node:
""" """
Create a Node instance from node configuration data. Create a Node instance from node configuration data.
@@ -63,28 +68,24 @@ class Graph:
self.root_node = root_node self.root_node = root_node
@classmethod @classmethod
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]: def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
""" """
Parse node configurations and build a mapping of node IDs to configs. Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries :param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config :return: mapping of node ID to node config
""" """
node_configs_map: dict[str, dict[str, object]] = {} node_configs_map: dict[str, NodeConfigDict] = {}
for node_config in node_configs: for node_config in node_configs:
node_id = node_config.get("id") node_configs_map[node_config["id"]] = node_config
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
return node_configs_map return node_configs_map
@classmethod @classmethod
def _find_root_node_id( def _find_root_node_id(
cls, cls,
node_configs_map: Mapping[str, Mapping[str, object]], node_configs_map: Mapping[str, NodeConfigDict],
edge_configs: Sequence[Mapping[str, object]], edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None, root_node_id: str | None = None,
) -> str: ) -> str:
@@ -113,10 +114,8 @@ class Graph:
# Prefer START node if available # Prefer START node if available
start_node_id = None start_node_id = None
for nid in root_candidates: for nid in root_candidates:
node_data = node_configs_map[nid].get("data") node_data = node_configs_map[nid]["data"]
if not is_str_dict(node_data): node_type = node_data["type"]
continue
node_type = node_data.get("type")
if not isinstance(node_type, str): if not isinstance(node_type, str):
continue continue
if NodeType(node_type).is_start_node: if NodeType(node_type).is_start_node:
@@ -176,7 +175,7 @@ class Graph:
@classmethod @classmethod
def _create_node_instances( def _create_node_instances(
cls, cls,
node_configs_map: dict[str, dict[str, object]], node_configs_map: dict[str, NodeConfigDict],
node_factory: NodeFactory, node_factory: NodeFactory,
) -> dict[str, Node]: ) -> dict[str, Node]:
""" """
@@ -303,7 +302,7 @@ class Graph:
node_configs = graph_config.get("nodes", []) node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs) edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs) node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs: if not node_configs:
raise ValueError("Graph must have at least one node") raise ValueError("Graph must have at least one node")
@@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool from .worker_management import WorkerPool
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -90,7 +89,7 @@ class GraphEngine:
self._graph_execution.workflow_id = workflow_id self._graph_execution.workflow_id = workflow_id
# === Execution Queues === # === Execution Queues ===
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue) self._ready_queue = self._graph_runtime_state.ready_queue
# Queue for events generated during execution # Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
@@ -15,10 +15,10 @@ from uuid import uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .path import Path from .path import Path
from .session import ResponseSession from .session import ResponseSession
@@ -75,7 +75,7 @@ class ResponseStreamCoordinator:
Ensures ordered streaming of responses based on upstream node outputs and constants. Ensures ordered streaming of responses based on upstream node outputs and constants.
""" """
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None: def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
""" """
Initialize coordinator with variable pool. Initialize coordinator with variable pool.
@@ -10,10 +10,10 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.runtime.graph_runtime_state import NodeProtocol
@dataclass @dataclass
@@ -29,21 +29,26 @@ class ResponseSession:
index: int = 0 # Current position in the template segments index: int = 0 # Current position in the template segments
@classmethod @classmethod
def from_node(cls, node: Node) -> ResponseSession: def from_node(cls, node: NodeProtocol) -> ResponseSession:
""" """
Create a ResponseSession from an AnswerNode or EndNode. Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
Args: Args:
node: Must be either an AnswerNode or EndNode instance node: Node from the materialized workflow graph.
Returns: Returns:
ResponseSession configured with the node's streaming template ResponseSession configured with the node's streaming template
Raises: Raises:
TypeError: If node is not an AnswerNode or EndNode TypeError: If node is not a supported response node type.
""" """
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode): if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
return cls( return cls(
node_id=node.id, node_id=node.id,
template=node.get_streaming_template(), template=node.get_streaming_template(),
+33 -31
View File
@@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]):
result[parameter_name] = None result[parameter_name] = None
continue continue
agent_input = node_data.agent_parameters[parameter_name] agent_input = node_data.agent_parameters[parameter_name]
if agent_input.type == "variable": match agent_input.type:
variable = variable_pool.get(agent_input.value) # type: ignore case "variable":
if variable is None: variable = variable_pool.get(agent_input.value) # type: ignore
raise AgentVariableNotFoundError(str(agent_input.value)) if variable is None:
parameter_value = variable.value raise AgentVariableNotFoundError(str(agent_input.value))
elif agent_input.type in {"mixed", "constant"}: parameter_value = variable.value
# variable_pool.convert_template expects a string template, case "mixed" | "constant":
# but if passing a dict, convert to JSON string first before rendering # variable_pool.convert_template expects a string template,
try: # but if passing a dict, convert to JSON string first before rendering
if not isinstance(agent_input.value, str): try:
parameter_value = json.dumps(agent_input.value, ensure_ascii=False) if not isinstance(agent_input.value, str):
else: parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value) parameter_value = str(agent_input.value)
except TypeError: segment_group = variable_pool.convert_template(parameter_value)
parameter_value = str(agent_input.value) parameter_value = segment_group.log if for_log else segment_group.text
segment_group = variable_pool.convert_template(parameter_value) # variable_pool.convert_template returns a string,
parameter_value = segment_group.log if for_log else segment_group.text # so we need to convert it back to a dictionary
# variable_pool.convert_template returns a string, try:
# so we need to convert it back to a dictionary if not isinstance(agent_input.value, str):
try: parameter_value = json.loads(parameter_value)
if not isinstance(agent_input.value, str): except json.JSONDecodeError:
parameter_value = json.loads(parameter_value) parameter_value = parameter_value
except json.JSONDecodeError: case _:
parameter_value = parameter_value raise AgentInputTypeError(agent_input.type)
else:
raise AgentInputTypeError(agent_input.type)
value = parameter_value value = parameter_value
if parameter.type == "array[tools]": if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value) value = cast(list[dict[str, Any]], value)
@@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]):
result: dict[str, Any] = {} result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters: for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name] input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]: match input.type:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() case "mixed" | "constant":
for selector in selectors: selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
result[selector.variable] = selector.value_selector for selector in selectors:
elif input.type == "variable": result[selector.variable] = selector.value_selector
result[parameter_name] = input.value case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()} result = {node_id + "." + key: value for key, value in result.items()}
+2 -2
View File
@@ -1,4 +1,4 @@
from typing import Annotated, Literal, Self from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel from pydantic import AfterValidator, BaseModel
@@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel): class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)] type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, Self] | None = None children: dict[str, "CodeNodeData.Output"] | None = None
class Dependency(BaseModel): class Dependency(BaseModel):
name: str name: str
@@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
if datasource_type is None: if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set") raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
datasource_runtime = DatasourceManager.get_datasource_runtime( datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "", datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type), datasource_type=datasource_type,
) )
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
@@ -268,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
if typed_node_data.datasource_parameters: if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters: for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name] input = typed_node_data.datasource_parameters[parameter_name]
if input.type == "mixed": match input.type:
assert isinstance(input.value, str) case "mixed":
selectors = VariableTemplateParser(input.value).extract_variable_selectors() assert isinstance(input.value, str)
for selector in selectors: selectors = VariableTemplateParser(input.value).extract_variable_selectors()
result[selector.variable] = selector.value_selector for selector in selectors:
elif input.type == "variable": result[selector.variable] = selector.value_selector
result[parameter_name] = input.value case "variable":
elif input.type == "constant": result[parameter_name] = input.value
pass case "constant":
pass
case None:
pass
result = {node_id + "." + key: value for key, value in result.items()} result = {node_id + "." + key: value for key, value in result.items()}
@@ -306,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
for message in message_stream: for message in message_stream:
if message.type in { match message.type:
DatasourceMessage.MessageType.IMAGE_LINK, case (
DatasourceMessage.MessageType.BINARY_LINK, DatasourceMessage.MessageType.IMAGE_LINK
DatasourceMessage.MessageType.IMAGE, | DatasourceMessage.MessageType.BINARY_LINK
}: | DatasourceMessage.MessageType.IMAGE
assert isinstance(message.message, DatasourceMessage.TextMessage) ):
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0] datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session: with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt) datasource_file = session.scalar(stmt)
if datasource_file is None: if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist") raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = { mapping = {
"tool_file_id": datasource_file_id, "tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method, "transfer_method": transfer_method,
"url": url, "url": url,
} }
file = file_factory.build_from_mapping( file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
) )
) files.append(file)
elif message.type == DatasourceMessage.MessageType.TEXT: case DatasourceMessage.MessageType.BLOB:
assert isinstance(message.message, DatasourceMessage.TextMessage) # get tool file id
text += message.message.text assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent( assert message.meta
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
elif message.type == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
case DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent( yield StreamChunkEvent(
selector=[self._node_id, variable_name], selector=[self._node_id, "text"],
chunk=variable_value, chunk=message.message.text,
is_final=False, is_final=False,
) )
else: case DatasourceMessage.MessageType.JSON:
variables[variable_name] = variable_value assert isinstance(message.message, DatasourceMessage.JsonMessage)
elif message.type == DatasourceMessage.MessageType.FILE: json.append(message.message.json_object)
assert message.meta is not None case DatasourceMessage.MessageType.LINK:
files.append(message.meta["file"]) assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
case DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
case DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case (
DatasourceMessage.MessageType.BLOB_CHUNK
| DatasourceMessage.MessageType.LOG
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
):
pass
# mark the end of the stream # mark the end of the stream
yield StreamChunkEvent( yield StreamChunkEvent(
selector=[self._node_id, "text"], selector=[self._node_id, "text"],
@@ -2,7 +2,7 @@ import base64
import json import json
import secrets import secrets
import string import string
from collections.abc import Mapping from collections.abc import Callable, Mapping
from copy import deepcopy from copy import deepcopy
from typing import Any, Literal from typing import Any, Literal
from urllib.parse import urlencode, urlparse from urllib.parse import urlencode, urlparse
@@ -11,9 +11,9 @@ import httpx
from json_repair import repair_json from json_repair import repair_json
from configs import dify_config from configs import dify_config
from core.file import file_manager
from core.file.enums import FileTransferMethod from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
@@ -79,8 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout, timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool, variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol = ssrf_proxy, http_client: HttpClientProtocol | None = None,
file_manager: FileManagerProtocol = file_manager, file_manager: FileManagerProtocol | None = None,
): ):
# If authorization API key is present, convert the API key using the variable pool # If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key": if node_data.authorization.type == "api-key":
@@ -107,8 +107,8 @@ class Executor:
self.data = None self.data = None
self.json = None self.json = None
self.max_retries = max_retries self.max_retries = max_retries
self._http_client = http_client self._http_client = http_client or ssrf_proxy
self._file_manager = file_manager self._file_manager = file_manager or default_file_manager
# init template # init template
self.variable_pool = variable_pool self.variable_pool = variable_pool
@@ -336,7 +336,7 @@ class Executor:
""" """
do http request depending on api bundle do http request depending on api bundle
""" """
_METHOD_MAP = { _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
"get": self._http_client.get, "get": self._http_client.get,
"head": self._http_client.head, "head": self._http_client.head,
"post": self._http_client.post, "post": self._http_client.post,
@@ -348,7 +348,7 @@ class Executor:
if method_lc not in _METHOD_MAP: if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}") raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = { request_args: dict[str, Any] = {
"data": self.data, "data": self.data,
"files": self.files, "files": self.files,
"json": self.json, "json": self.json,
@@ -361,14 +361,13 @@ class Executor:
} }
# request_args = {k: v for k, v in request_args.items() if v is not None} # request_args = {k: v for k, v in request_args.items() if v is not None}
try: try:
response: httpx.Response = _METHOD_MAP[method_lc]( response = _METHOD_MAP[method_lc](
url=self.url, url=self.url,
**request_args, **request_args,
max_retries=self.max_retries, max_retries=self.max_retries,
) )
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e: except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response return response
def invoke(self) -> Response: def invoke(self) -> Response:
+7 -6
View File
@@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from configs import dify_config from configs import dify_config
from core.file import File, FileTransferMethod, file_manager from core.file import File, FileTransferMethod
from core.helper import ssrf_proxy from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params: "GraphInitParams", graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState", graph_runtime_state: "GraphRuntimeState",
*, *,
http_client: HttpClientProtocol = ssrf_proxy, http_client: HttpClientProtocol | None = None,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol = file_manager, file_manager: FileManagerProtocol | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
id=id, id=id,
@@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
) )
self._http_client = http_client self._http_client = http_client or ssrf_proxy
self._tool_file_manager_factory = tool_file_manager_factory self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager self._file_manager = file_manager or default_file_manager
@classmethod @classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return outputs return outputs
# Check if all non-None outputs are lists # Check if all non-None outputs are lists
non_none_outputs = [output for output in outputs if output is not None] non_none_outputs: list[object] = [output for output in outputs if output is not None]
if not non_none_outputs: if not non_none_outputs:
return outputs return outputs
@@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
type: str = "knowledge-index" type: str = "knowledge-index"
chunk_structure: str chunk_structure: str
index_chunk_variable_selector: list[str] index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None
@@ -1,9 +1,11 @@
import concurrent.futures
import datetime import datetime
import logging import logging
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from flask import current_app
from sqlalchemy import func, select from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
from .entities import KnowledgeIndexNodeData from .entities import KnowledgeIndexNodeData
from .exc import ( from .exc import (
@@ -67,7 +71,29 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
# index knowledge # index knowledge
try: try:
if is_preview: if is_preview:
outputs = self._get_preview_output(node_data.chunk_structure, chunks) # Preview mode: generate summaries for chunks directly without saving to database
# Format preview and generate summaries on-the-fly
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
# or fallback to dataset if not available in node_data
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
# Try to get document language if document_id is available
doc_language = None
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document and document.doc_language:
doc_language = document.doc_language
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
doc_language=doc_language,
)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables, inputs=variables,
@@ -148,6 +174,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
) )
.scalar() .scalar()
) )
# Update need_summary based on dataset's summary_index_setting
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
document.need_summary = True
else:
document.need_summary = False
db.session.add(document) db.session.add(document)
# update document segment status # update document segment status
db.session.query(DocumentSegment).where( db.session.query(DocumentSegment).where(
@@ -163,6 +194,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
db.session.commit() db.session.commit()
# Generate summary index if enabled
self._handle_summary_index_generation(dataset, document, variable_pool)
return { return {
"dataset_id": ds_id_value, "dataset_id": ds_id_value,
"dataset_name": dataset_name_value, "dataset_name": dataset_name_value,
@@ -173,9 +207,308 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
"display_status": "completed", "display_status": "completed",
} }
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]: def _handle_summary_index_generation(
self,
dataset: Dataset,
document: Document,
variable_pool: VariablePool,
) -> None:
"""
Handle summary index generation based on mode (debug/preview or production).
Args:
dataset: Dataset containing the document
document: Document to generate summaries for
variable_pool: Variable pool to check invoke_from
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return
# Check if summary index is enabled
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
return
# Skip qa_model documents
if document.doc_form == "qa_model":
return
# Determine if in preview/debug mode
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
if is_preview:
try:
# Query segments that need summary generation
query = db.session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
segments = query.all()
if not segments:
logger.info("No segments found for document %s", document.id)
return
# Filter segments based on mode
segments_to_process = []
for segment in segments:
# Skip if summary already exists
existing_summary = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
.first()
)
if existing_summary:
continue
# For parent-child mode, all segments are parent chunks, so process all
segments_to_process.append(segment)
if not segments_to_process:
logger.info("No segments need summary generation for document %s", document.id)
return
# Use ThreadPoolExecutor for concurrent generation
flask_app = current_app._get_current_object() # type: ignore
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
def process_segment(segment: DocumentSegment) -> None:
"""Process a single segment in a thread with Flask app context."""
with flask_app.app_context():
try:
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
except Exception:
logger.exception(
"Failed to generate summary for segment %s",
segment.id,
)
# Continue processing other segments
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
# Wait for all tasks to complete
concurrent.futures.wait(futures)
logger.info(
"Successfully generated summary index for %s segments in document %s",
len(segments_to_process),
document.id,
)
except Exception:
logger.exception("Failed to generate summary index for document %s", document.id)
# Don't fail the entire indexing process if summary generation fails
else:
# Production mode: asynchronous generation
logger.info(
"Queuing summary index generation task for document %s (production mode)",
document.id,
)
try:
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info("Summary index generation task queued for document %s", document.id)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document.id,
)
# Don't fail the entire indexing process if task queuing fails
def _get_preview_output_with_summaries(
self,
chunk_structure: str,
chunks: Any,
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
doc_language: str | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
This method generates summaries on-the-fly without saving to database.
Args:
chunk_structure: Chunk structure type
chunks: Chunks to generate preview for
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
doc_language: Optional document language to ensure summary is generated in the correct language
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
return index_processor.format_preview(chunks) preview_output = index_processor.format_preview(chunks)
# Check if summary index is enabled
if indexing_technique != "high_quality":
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):
return preview_output
# Generate summaries for chunks
if "preview" in preview_output and isinstance(preview_output["preview"], list):
chunk_count = len(preview_output["preview"])
logger.info(
"Generating summaries for %s chunks in preview mode (dataset: %s)",
chunk_count,
dataset.id,
)
# Use ParagraphIndexProcessor's generate_summary method
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get Flask app for application context in worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def generate_summary_for_chunk(preview_item: dict) -> None:
"""Generate summary for a single chunk."""
if "content" in preview_item:
# Set Flask application context in worker thread
if flask_app:
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
futures = [
executor.submit(generate_summary_for_chunk, preview_item)
for preview_item in preview_output["preview"]
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode, if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise KnowledgeIndexNodeError(error_summary)
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
logger.info(
"Completed summary generation for preview chunks: %s/%s succeeded",
completed_count,
len(preview_output["preview"]),
)
return preview_output
def _get_preview_output(
self,
chunk_structure: str,
chunks: Any,
dataset: Dataset | None = None,
variable_pool: VariablePool | None = None,
) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# If dataset is provided, try to enrich preview with summaries
if dataset and variable_pool:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document:
# Query summaries for this document
summaries = (
db.session.query(DocumentSegmentSummary)
.filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
.all()
)
if summaries:
# Create a map of segment content to summary for matching
# Use content matching as chunks in preview might not be indexed yet
summary_by_content = {}
for summary in summaries:
segment = (
db.session.query(DocumentSegment)
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
.first()
)
if segment:
# Normalize content for matching (strip whitespace)
normalized_content = segment.content.strip()
summary_by_content[normalized_content] = summary.summary_content
# Enrich preview with summaries by content matching
if "preview" in preview_output and isinstance(preview_output["preview"], list):
matched_count = 0
for preview_item in preview_output["preview"]:
if "content" in preview_item:
# Normalize content for matching
normalized_chunk_content = preview_item["content"].strip()
if normalized_chunk_content in summary_by_content:
preview_item["summary"] = summary_by_content[normalized_chunk_content]
matched_count += 1
if matched_count > 0:
logger.info(
"Enriched preview with %s existing summaries (dataset: %s, document: %s)",
matched_count,
dataset.id,
document.id,
)
return preview_output
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None: if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required") raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": match node_data.multiple_retrieval_config.reranking_mode:
if node_data.multiple_retrieval_config.reranking_model: case "reranking_model":
reranking_model = { if node_data.multiple_retrieval_config.reranking_model:
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, reranking_model = {
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
} "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
else: }
else:
reranking_model = None
weights = None
case "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None reranking_model = None
weights = None vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": weights = {
if node_data.multiple_retrieval_config.weights is None: "vector_setting": {
raise ValueError("weights is required") "vector_weight": vector_setting.vector_weight,
reranking_model = None "embedding_provider_name": vector_setting.embedding_provider_name,
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting "embedding_model_name": vector_setting.embedding_model_name,
weights = { },
"vector_setting": { "keyword_setting": {
"vector_weight": vector_setting.vector_weight, "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
"embedding_provider_name": vector_setting.embedding_provider_name, },
"embedding_model_name": vector_setting.embedding_model_name, }
}, case _:
"keyword_setting": { reranking_model = None
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight weights = None
},
}
else:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve( all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id, app_id=self.app_id,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
@@ -419,6 +420,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else: else:
source["content"] = segment.get_sign_content() source["content"] = segment.get_sign_content()
# Add summary if available
if record.summary:
source["summary"] = record.summary
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
if retrieval_resource_list: if retrieval_resource_list:
retrieval_resource_list = sorted( retrieval_resource_list = sorted(
@@ -450,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
) )
filters: list[Any] = [] filters: list[Any] = []
metadata_condition = None metadata_condition = None
if node_data.metadata_filtering_mode == "disabled": match node_data.metadata_filtering_mode:
return None, None, usage case "disabled":
elif node_data.metadata_filtering_mode == "automatic": return None, None, usage
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( case "automatic":
dataset_ids, query, node_data automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
) dataset_ids, query, node_data
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
) )
elif node_data.metadata_filtering_mode == "manual": usage = self._merge_usage(usage, automatic_usage)
if node_data.metadata_filtering_conditions: if automatic_metadata_filters:
conditions = [] conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore for sequence, filter in enumerate(automatic_metadata_filters):
metadata_name = condition.name DatasetRetrieval.process_metadata_filter_func(
expected_value = condition.value sequence,
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): filter.get("condition", ""),
if isinstance(expected_value, str): filter.get("metadata_name", ""),
expected_value = self.graph_runtime_state.variable_pool.convert_template( filter.get("value"),
expected_value filters,
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
) )
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
) )
filters = DatasetRetrieval.process_metadata_filter_func( case "manual":
sequence, if node_data.metadata_filtering_conditions:
condition.comparison_operator, conditions = []
metadata_name, for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
expected_value, metadata_name = condition.name
filters, expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
) )
metadata_condition = MetadataCondition( case _:
logical_operator=node_data.metadata_filtering_conditions.logical_operator, raise ValueError("Invalid metadata filtering mode")
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters: if filters:
if ( if (
node_data.metadata_filtering_conditions node_data.metadata_filtering_conditions
@@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "name": case "name":
return lambda x: x.filename or "" return lambda x: x.filename or ""
case "type": case "type":
return lambda x: x.type return lambda x: str(x.type)
case "extension": case "extension":
return lambda x: x.extension or "" return lambda x: x.extension or ""
case "mime_type": case "mime_type":
return lambda x: x.mime_type or "" return lambda x: x.mime_type or ""
case "transfer_method": case "transfer_method":
return lambda x: x.transfer_method return lambda x: str(x.transfer_method)
case "url": case "url":
return lambda x: x.remote_url or "" return lambda x: x.remote_url or ""
case "related_id": case "related_id":
@@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key) extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_string_func(key=key) extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str): elif key == "size" and isinstance(value, str):
extract_func = _get_file_extract_number_func(key=key) extract_number = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
else: else:
raise InvalidKeyError(f"Invalid key: {key}") raise InvalidKeyError(f"Invalid key: {key}")
+13 -13
View File
@@ -685,6 +685,8 @@ class LLMNode(Node[LLMNodeData]):
if "content" not in item: if "content" not in item:
raise InvalidContextStructureError(f"Invalid context structure: {item}") raise InvalidContextStructureError(f"Invalid context structure: {item}")
if item.get("summary"):
context_str += item["summary"] + "\n"
context_str += item["content"] + "\n" context_str += item["content"] + "\n"
retriever_resource = self._convert_to_original_retriever_resource(item) retriever_resource = self._convert_to_original_retriever_resource(item)
@@ -746,6 +748,7 @@ class LLMNode(Node[LLMNodeData]):
page=metadata.get("page"), page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"), doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"), files=context_dict.get("files"),
summary=context_dict.get("summary"),
) )
return source return source
@@ -849,18 +852,16 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt # Insert histories into the prompt
prompt_content = prompt_messages[0].content prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list # For issue #11247 - Check if prompt content is a string or a list
prompt_content_type = type(prompt_content) if isinstance(prompt_content, str):
if prompt_content_type == str:
prompt_content = str(prompt_content) prompt_content = str(prompt_content)
if "#histories#" in prompt_content: if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text) prompt_content = prompt_content.replace("#histories#", memory_text)
else: else:
prompt_content = memory_text + "\n" + prompt_content prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content prompt_messages[0].content = prompt_content
elif prompt_content_type == list: elif isinstance(prompt_content, list):
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content: for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT: if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data: if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text) content_item.data = content_item.data.replace("#histories#", memory_text)
else: else:
@@ -870,13 +871,12 @@ class LLMNode(Node[LLMNodeData]):
# Add current query to the prompt message # Add current query to the prompt message
if sys_query: if sys_query:
if prompt_content_type == str: if isinstance(prompt_content, str):
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content prompt_messages[0].content = prompt_content
elif prompt_content_type == list: elif isinstance(prompt_content, list):
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content: for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT: if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data content_item.data = sys_query + "\n" + content_item.data
else: else:
raise ValueError("Invalid prompt content type") raise ValueError("Invalid prompt content type")
@@ -1030,14 +1030,14 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config: if typed_node_data.prompt_config:
enable_jinja = False enable_jinja = False
if isinstance(prompt_template, list): if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type == "jinja2":
enable_jinja = True
else:
for prompt in prompt_template: for prompt in prompt_template:
if prompt.edition_type == "jinja2": if prompt.edition_type == "jinja2":
enable_jinja = True enable_jinja = True
break break
else:
if prompt_template.edition_type == "jinja2":
enable_jinja = True
if enable_jinja: if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
+7 -7
View File
@@ -1,4 +1,4 @@
from typing import Protocol from typing import Any, Protocol
import httpx import httpx
@@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol):
@property @property
def request_error(self) -> type[Exception]: ... def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
class FileManagerProtocol(Protocol): class FileManagerProtocol(Protocol):
+2 -2
View File
@@ -54,8 +54,8 @@ class ToolNodeData(BaseNodeData, ToolEntity):
for val in value: for val in value:
if not isinstance(val, str): if not isinstance(val, str):
raise ValueError("value must be a list of strings") raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool | dict): elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
raise ValueError("value must be a string, int, float, bool or dict") raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ return typ
tool_parameters: dict[str, ToolInput] tool_parameters: dict[str, ToolInput]
+11 -10
View File
@@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]):
result = {} result = {}
for parameter_name in typed_node_data.tool_parameters: for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name] input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed": match input.type:
assert isinstance(input.value, str) case "mixed":
selectors = VariableTemplateParser(input.value).extract_variable_selectors() assert isinstance(input.value, str)
for selector in selectors: selectors = VariableTemplateParser(input.value).extract_variable_selectors()
result[selector.variable] = selector.value_selector for selector in selectors:
elif input.type == "variable": result[selector.variable] = selector.value_selector
selector_key = ".".join(input.value) case "variable":
result[f"#{selector_key}#"] = input.value selector_key = ".".join(input.value)
elif input.type == "constant": result[f"#{selector_key}#"] = input.value
pass case "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()} result = {node_id + "." + key: value for key, value in result.items()}

Some files were not shown because too many files have changed in this diff Show More