diff --git a/.agents/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md index 7006c382c..140e0ef43 100644 --- a/.agents/skills/component-refactoring/SKILL.md +++ b/.agents/skills/component-refactoring/SKILL.md @@ -480,4 +480,4 @@ const useButtonState = () => { ### Related Skills - `frontend-testing` - For testing refactored components -- `web/testing/testing.md` - Testing specification +- `web/docs/test.md` - Testing specification diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 0716c81ef..280fcb634 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -7,7 +7,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. -> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`). +> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`). ## When to Apply This Skill @@ -309,7 +309,7 @@ For more detailed information, refer to: ### Primary Specification (MUST follow) -- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document. +- **`web/docs/test.md`** - The canonical testing specification. This skill is derived from this document. ### Reference Examples in Codebase diff --git a/.agents/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md index 009c3e013..bc4ed8285 100644 --- a/.agents/skills/frontend-testing/references/workflow.md +++ b/.agents/skills/frontend-testing/references/workflow.md @@ -4,7 +4,7 @@ This guide defines the workflow for generating tests, especially for complex com ## Scope Clarification -This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals. +This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals. | Scope | Rule | |-------|------| diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 106c26bbe..36fa39b5d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,9 @@ # CODEOWNERS file /.github/CODEOWNERS @laipz8200 @crazywoola +# Agents +/.agents/skills/ @hyoban + # Docs /docs/ @crazywoola diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 7c21bec7f..7616db32b 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -72,6 +72,7 @@ jobs: OPENDAL_FS_ROOT: /tmp/dify-storage run: | uv run --project api pytest \ + -n auto \ --timeout "${PYTEST_TIMEOUT:-180}" \ api/tests/integration_tests/workflow \ api/tests/integration_tests/tools \ diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index fdc05d1d6..cbd6edf94 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -47,13 +47,9 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: uv run --directory api --dev lint-imports - - name: Run Basedpyright Checks + - name: Run Type Checks if: steps.changed-files.outputs.any_changed == 'true' - run: dev/basedpyright-check - - - name: Run Mypy Type Checks - if: steps.changed-files.outputs.any_changed == 'true' - run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + run: make type-check - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/AGENTS.md~upstream_main b/AGENTS.md~upstream_main index 7d96ac3a6..51fa6e452 100644 --- a/AGENTS.md~upstream_main +++ b/AGENTS.md~upstream_main @@ -7,7 +7,7 @@ Dify is an open-source platform for developing LLM applications with an intuitiv The codebase is split into: - **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design -- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19 +- **Frontend Web** (`/web`): Next.js application using TypeScript and React - **Docker deployment** (`/docker`): Containerized deployment configurations ## Backend Workflow @@ -18,36 +18,7 @@ The codebase is split into: ## Frontend Workflow -```bash -cd web -pnpm lint:fix -pnpm type-check:tsgo -pnpm test -``` - -### Frontend Linting - -ESLint is used for frontend code quality. Available commands: - -```bash -# Lint all files (report only) -pnpm lint - -# Lint and auto-fix issues -pnpm lint:fix - -# Lint specific files or directories -pnpm lint:fix app/components/base/button/ -pnpm lint:fix app/components/base/button/index.tsx - -# Lint quietly (errors only, no warnings) -pnpm lint:quiet - -# Check code complexity -pnpm lint:complexity -``` - -**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files. +- Read `web/AGENTS.md` for details ## Testing & Quality Practices diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 20a7d6c6f..d7f007af6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -77,7 +77,7 @@ How we prioritize: For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly. -**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there. +**Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there. #### Backend diff --git a/Makefile b/Makefile index e92a7b131..984e8676e 100644 --- a/Makefile +++ b/Makefile @@ -68,9 +68,11 @@ lint: @echo "✅ Linting complete" type-check: - @echo "📝 Running type check with basedpyright..." - @uv run --directory api --dev basedpyright - @echo "✅ Type check complete" + @echo "📝 Running type checks (basedpyright + mypy + ty)..." + @./dev/basedpyright-check $(PATH_TO_CHECK) + @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + @cd api && uv run ty check + @echo "✅ Type checks complete" test: @echo "🧪 Running backend unit tests..." @@ -78,7 +80,7 @@ test: echo "Target: $(TARGET_TESTS)"; \ uv run --project api --dev pytest $(TARGET_TESTS); \ else \ - uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \ + PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \ fi @echo "✅ Tests complete" @@ -130,7 +132,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" - @echo " make type-check - Run type checking with basedpyright" + @echo " make type-check - Run type checks (basedpyright, mypy, ty)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/api/.env.example b/api/.env.example index d9f03ec09..495e0547c 100644 --- a/api/.env.example +++ b/api/.env.example @@ -617,6 +617,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002 PLUGIN_REMOTE_INSTALL_PORT=5003 PLUGIN_REMOTE_INSTALL_HOST=localhost PLUGIN_MAX_PACKAGE_SIZE=15728640 +PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration diff --git a/api/.importlinter b/api/.importlinter index 2b4a3a5bd..9dad25456 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -227,6 +227,9 @@ ignore_imports = core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset + core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service + core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.llm.node -> models.dataset core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer @@ -300,6 +303,58 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> services core.workflow.nodes.tool.tool_node -> services +[importlinter:contract:model-runtime-no-internal-imports] +name = Model Runtime Internal Imports +type = forbidden +source_modules = + core.model_runtime +forbidden_modules = + configs + controllers + extensions + models + services + tasks + core.agent + core.app + core.base + core.callback_handler + core.datasource + core.db + core.entities + core.errors + core.extension + core.external_data_tool + core.file + core.helper + core.hosting_configuration + core.indexing_runner + core.llm_generator + core.logging + core.mcp + core.memory + core.model_manager + core.moderation + core.ops + core.plugin + core.prompt + core.provider_manager + core.rag + core.repositories + core.schemas + core.tools + core.trigger + core.variables + core.workflow +ignore_imports = + core.model_runtime.model_providers.__base.ai_model -> configs + core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis + core.model_runtime.model_providers.__base.large_language_model -> configs + core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type + core.model_runtime.model_providers.model_provider_factory -> configs + core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis + core.model_runtime.model_providers.model_provider_factory -> models.provider_ids + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/.ruff.toml b/api/.ruff.toml index 8db0cbcb2..3301452ad 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -53,6 +53,7 @@ select = [ "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S302", # suspicious-marshal-usage, disallow use of `marshal` module "S311", # suspicious-non-cryptographic-random-usage, + "TID", # flake8-tidy-imports ] @@ -88,6 +89,7 @@ ignore = [ "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false + "TID252", # allow relative imports from parent modules ] [lint.per-file-ignores] @@ -109,10 +111,20 @@ ignore = [ "S110", # allow ignoring exceptions in tests code (currently) ] +"controllers/console/explore/trial.py" = ["TID251"] +"controllers/console/human_input_form.py" = ["TID251"] +"controllers/web/human_input_form.py" = ["TID251"] [lint.pyflakes] allowed-unused-imports = [ - "_pytest.monkeypatch", "tests.integration_tests", "tests.unit_tests", ] + +[lint.flake8-tidy-imports] + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"] +msg = "Use Pydantic payload/query models instead of reqparse." + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] +msg = "Use Pydantic payload/query models instead of reqparse." diff --git a/api/app.py b/api/app.py index 99f70f32d..c018c8a04 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,12 @@ +from __future__ import annotations + import sys +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from celery import Celery + + celery: Celery def is_db_command() -> bool: @@ -23,7 +31,7 @@ else: from app_factory import create_app app = create_app() - celery = app.extensions["celery"] + celery = cast("Celery", app.extensions["celery"]) if __name__ == "__main__": app.run(host="0.0.0.0", port=5001) diff --git a/api/app_factory.py b/api/app_factory.py index 07859a375..dcbc82168 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -149,7 +149,7 @@ def initialize_extensions(app: DifyApp): logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2)) -def create_migrations_app(): +def create_migrations_app() -> DifyApp: app = create_flask_app_with_configs() from extensions import ext_database, ext_migrate diff --git a/api/commands.py b/api/commands.py index 842c64b56..6d5beb0e6 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool): all_ids_in_tables = [] for ids_table in ids_tables: query = "" - if ids_table["type"] == "uuid": - click.echo( - click.style( - f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white" + match ids_table["type"]: + case "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) ) - ) - query = ( - f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) - elif ids_table["type"] == "text": - click.echo( - click.style( - f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}", - fg="white", + c = ids_table["column"] + query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + case "text": + t = ids_table["table"] + click.echo( + click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", + fg="white", + ) ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - elif ids_table["type"] == "json": - click.echo( - click.style( - ( - f"- Listing file-id-like JSON string in column {ids_table['column']} " - f"in table {ids_table['table']}" - ), - fg="white", + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case _: + pass click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) except Exception as e: @@ -1737,59 +1741,18 @@ def file_usage( if src_filter != src: continue - if ids_table["type"] == "uuid": - # Direct UUID match - query = ( - f"SELECT {ids_table['pk_column']}, {ids_table['column']} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - ref_file_id = str(row[1]) - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - - elif ids_table["type"] in ("text", "json"): - # Extract UUIDs from text/json content - column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] - query = ( - f"SELECT {ids_table['pk_column']}, {column_cast} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - content = str(row[1]) - - # Find all UUIDs in the content - import re - - uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) - matches = uuid_pattern.findall(content) - - for ref_file_id in matches: + match ids_table["type"]: + case "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) if ref_file_id not in file_key_map: continue storage_key = file_key_map[ref_file_id] @@ -1812,6 +1775,50 @@ def file_usage( ) total_count += 1 + case "text" | "json": + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + case _: + pass + # Output results if output_json: result = { diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 786094f29..d97e9a044 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -243,6 +243,11 @@ class PluginConfig(BaseSettings): default=15728640 * 12, ) + PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field( + description="TTL in seconds for caching plugin model schemas in Redis", + default=60 * 60, + ) + class MarketplaceConfig(BaseSettings): """ diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 7c16bc231..c52dcf8a5 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( ContextVar("plugin_model_providers_lock") ) -plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock")) - -plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( - ContextVar("plugin_model_schemas") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index e1ee2c24b..03b602f6e 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource): def post(self): payload = InsertExploreBannerPayload.model_validate(console_ns.payload) - content = { - "category": payload.category, - "title": payload.title, - "description": payload.description, - "img-src": payload.img_src, - } - banner = ExporleBanner( - content=content, + content={ + "category": payload.category, + "title": payload.title, + "description": payload.description, + "img-src": payload.img_src, + }, link=payload.link, sort=payload.sort, language=payload.language, diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 6a4c1528b..9931bb5dd 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,10 +1,11 @@ from typing import Any, Literal from flask import abort, make_response, request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from controllers.common.errors import NoFileUploadedError, TooManyFilesError +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -16,9 +17,11 @@ from controllers.console.wraps import ( ) from extensions.ext_redis import redis_client from fields.annotation_fields import ( - annotation_fields, - annotation_hit_history_fields, - build_annotation_model, + Annotation, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, + AnnotationList, ) from libs.helper import uuid_value from libs.login import login_required @@ -89,6 +92,14 @@ reg(CreateAnnotationPayload) reg(UpdateAnnotationPayload) reg(AnnotationReplyStatusQuery) reg(AnnotationFilePayload) +register_schema_models( + console_ns, + Annotation, + AnnotationList, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, +) @console_ns.route("/apps//annotation-reply/") @@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) args = AnnotationReplyPayload.model_validate(console_ns.payload) - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -201,33 +213,33 @@ class AnnotationApi(Resource): app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response, 200 + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json"), 200 @console_ns.doc("create_annotation") @console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) - @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) + @console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - @marshal_with(annotation_fields) @edit_permission_required def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) data = args.model_dump(exclude_none=True) annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -264,7 +276,7 @@ class AnnotationExportApi(Resource): @console_ns.response( 200, "Annotations exported successfully", - console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}), + console_ns.models[AnnotationExportList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -274,7 +286,8 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response_data = {"data": marshal(annotation_list, annotation_fields)} + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") # Create response with secure headers for CSV export response = make_response(response_data, 200) @@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.doc("update_delete_annotation") @console_ns.doc(description="Update or delete an annotation") @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) - @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) + @console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__]) @console_ns.response(204, "Annotation deleted successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - @marshal_with(annotation_fields) def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) @@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource): annotation = AppAnnotationService.update_app_annotation_directly( args.model_dump(exclude_none=True), app_id, annotation_id ) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource): @console_ns.response( 200, "Hit histories retrieved successfully", - console_ns.model( - "AnnotationHitHistoryList", - { - "data": fields.List( - fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields)) - ) - }, - ), + console_ns.models[AnnotationHitHistoryList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource): annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( app_id, annotation_id, page, limit ) - response = { - "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), - "has_more": len(annotation_hit_history_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response + history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( + annotation_hit_history_list, from_attributes=True + ) + response = AnnotationHitHistoryList( + data=history_models, + has_more=len(annotation_hit_history_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d344ede46..941db325b 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -33,7 +34,6 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class TextToSpeechPayload(BaseModel): @@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel): language: str = Field(..., description="Language code") -console_ns.schema_model( - TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechVoiceQuery.__name__, - TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AudioTranscriptResponse(BaseModel): + text: str = Field(description="Transcribed text from audio") + + +register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery) @console_ns.route("/apps//audio-to-text") @@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource): @console_ns.response( 200, "Audio transcription successful", - console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + console_ns.models[AudioTranscriptResponse.__name__], ) @console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(413, "Audio file too large") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 55fdcb51e..82cc957d0 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -508,16 +508,19 @@ class ChatConversationApi(Resource): case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at <= end_datetime_utc) - if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ) - elif args.annotation_status == "not_annotated": - query = ( - query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) - .group_by(Conversation.id) - .having(func.count(MessageAnnotation.id) == 0) - ) + match args.annotation_status: + case "annotated": + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + case "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) + case "all": + pass if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index b4fc44767..1ac55b5e8 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Any from flask_restx import Resource from pydantic import BaseModel, Field @@ -12,10 +11,12 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.wraps import account_initialization_required, setup_required +from core.app.app_config.entities import ModelConfig from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -26,28 +27,13 @@ from services.workflow_service import WorkflowService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -class RuleGeneratePayload(BaseModel): - instruction: str = Field(..., description="Rule generation instruction") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") - no_variable: bool = Field(default=False, description="Whether to exclude variables") - - -class RuleCodeGeneratePayload(RuleGeneratePayload): - code_language: str = Field(default="javascript", description="Programming language for code generation") - - -class RuleStructuredOutputPayload(BaseModel): - instruction: str = Field(..., description="Structured output generation instruction") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") - - class InstructionGeneratePayload(BaseModel): flow_id: str = Field(..., description="Workflow/Flow ID") node_id: str = Field(default="", description="Node ID for workflow context") current: str = Field(default="", description="Current instruction text") language: str = Field(default="javascript", description="Programming language (javascript/python)") instruction: str = Field(..., description="Instruction for generation") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") ideal_output: str = Field(default="", description="Expected ideal output") @@ -64,6 +50,7 @@ reg(RuleCodeGeneratePayload) reg(RuleStructuredOutputPayload) reg(InstructionGeneratePayload) reg(InstructionTemplatePayload) +reg(ModelConfig) @console_ns.route("/rule-generate") @@ -82,12 +69,7 @@ class RuleGenerateApi(Resource): _, current_tenant_id = current_account_with_tenant() try: - rules = LLMGenerator.generate_rule_config( - tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=args.no_variable, - ) + rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -118,9 +100,7 @@ class RuleCodeGenerateApi(Resource): try: code_result = LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.code_language, + args=args, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -152,8 +132,7 @@ class RuleStructuredOutputGenerateApi(Resource): try: structured_output = LLMGenerator.generate_structured_output( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, + args=args, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -204,23 +183,29 @@ class InstructionGenerateApi(Resource): case "llm": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, + args=RuleGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + ), ) case "agent": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, + args=RuleGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + ), ) case "code": return LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.language, + args=RuleCodeGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.language, + ), ) case _: return {"error": f"invalid node type: {node_type}"} diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 12ada8b79..0be3e0ec4 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, select from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, @@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft from services.message_service import MessageService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ChatMessagesQuery(BaseModel): @@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel): raise ValueError("has_comment must be a boolean value") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class AnnotationCountResponse(BaseModel): + count: int = Field(description="Number of annotations") -reg(ChatMessagesQuery) -reg(MessageFeedbackPayload) -reg(FeedbackExportQuery) +class SuggestedQuestionsResponse(BaseModel): + data: list[str] = Field(description="Suggested question") + + +register_schema_models( + console_ns, + ChatMessagesQuery, + MessageFeedbackPayload, + FeedbackExportQuery, + AnnotationCountResponse, + SuggestedQuestionsResponse, +) # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -231,7 +240,7 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): - args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatMessagesQuery.model_validate(request.args.to_dict()) conversation = ( db.session.query(Conversation) @@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource): @console_ns.response( 200, "Annotation count retrieved successfully", - console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + console_ns.models[AnnotationCountResponse.__name__], ) @get_app_model @setup_required @@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource): @console_ns.response( 200, "Suggested questions retrieved successfully", - console_ns.model( - "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))} - ), + console_ns.models[SuggestedQuestionsResponse.__name__], ) @console_ns.response(404, "Message or conversation not found") @setup_required @@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource): @login_required @account_initialization_required def get(self, app_model): - args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FeedbackExportQuery.model_validate(request.args.to_dict()) # Import the service function from services.feedback_service import FeedbackService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0dd7d33ae..3a3278ec9 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,9 +2,11 @@ import logging import httpx from flask import current_app, redirect, request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field from configs import dify_config +from controllers.common.schema import register_schema_models from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required, logger = logging.getLogger(__name__) +class OAuthDataSourceResponse(BaseModel): + data: str = Field(description="Authorization URL or 'internal' for internal setup") + + +class OAuthDataSourceBindingResponse(BaseModel): + result: str = Field(description="Operation result") + + +class OAuthDataSourceSyncResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + OAuthDataSourceResponse, + OAuthDataSourceBindingResponse, + OAuthDataSourceSyncResponse, +) + + def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( @@ -34,10 +56,7 @@ class OAuthDataSource(Resource): @console_ns.response( 200, "Authorization URL or internal setup success", - console_ns.model( - "OAuthDataSourceResponse", - {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, - ), + console_ns.models[OAuthDataSourceResponse.__name__], ) @console_ns.response(400, "Invalid provider") @console_ns.response(403, "Admin privileges required") @@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource): @console_ns.response( 200, "Data source binding success", - console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceBindingResponse.__name__], ) @console_ns.response(400, "Invalid provider or code") def get(self, provider: str): @@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource): @console_ns.response( 200, "Data source sync success", - console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceSyncResponse.__name__], ) @console_ns.response(400, "Invalid provider or sync failed") @setup_required diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 394f205d9..1ed931b0d 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,10 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailCodeError, @@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel): return valid_password(value) -for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ForgotPasswordEmailResponse(BaseModel): + result: str = Field(description="Operation result") + data: str | None = Field(default=None, description="Reset token") + code: str | None = Field(default=None, description="Error code if account not found") + + +class ForgotPasswordCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether code is valid") + email: EmailStr = Field(description="Email address") + token: str = Field(description="New reset token") + + +class ForgotPasswordResetResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + ForgotPasswordSendPayload, + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordEmailResponse, + ForgotPasswordCheckResponse, + ForgotPasswordResetResponse, +) @console_ns.route("/forgot-password") @@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): @console_ns.response( 200, "Email sent successfully", - console_ns.model( - "ForgotPasswordEmailResponse", - { - "result": fields.String(description="Operation result"), - "data": fields.String(description="Reset token"), - "code": fields.String(description="Error code if account not found"), - }, - ), + console_ns.models[ForgotPasswordEmailResponse.__name__], ) @console_ns.response(400, "Invalid email or rate limit exceeded") @setup_required @@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource): @console_ns.response( 200, "Code verified successfully", - console_ns.model( - "ForgotPasswordCheckResponse", - { - "is_valid": fields.Boolean(description="Whether code is valid"), - "email": fields.String(description="Email address"), - "token": fields.String(description="New reset token"), - }, - ), + console_ns.models[ForgotPasswordCheckResponse.__name__], ) @console_ns.response(400, "Invalid code or token") @setup_required @@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource): @console_ns.response( 200, "Password reset successfully", - console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[ForgotPasswordResetResponse.__name__], ) @console_ns.response(400, "Invalid token or password mismatch") @setup_required diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6162d88a0..38ea5d2da 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource): grant_type = OAuthGrantType(payload.grant_type) except ValueError: raise BadRequest("invalid grant_type") + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + if not payload.code: + raise BadRequest("code is required") - if grant_type == OAuthGrantType.AUTHORIZATION_CODE: - if not payload.code: - raise BadRequest("code is required") + if payload.client_secret != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") - if payload.client_secret != oauth_provider_app.client_secret: - raise BadRequest("client_secret is invalid") + if payload.redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") - if payload.redirect_uri not in oauth_provider_app.redirect_uris: - raise BadRequest("redirect_uri is invalid") + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=payload.code, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + case OAuthGrantType.REFRESH_TOKEN: + if not payload.refresh_token: + raise BadRequest("refresh_token is required") - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, code=payload.code, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) - elif grant_type == OAuthGrantType.REFRESH_TOKEN: - if not payload.refresh_token: - raise BadRequest("refresh_token is required") - - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) @console_ns.route("/oauth/provider/account") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01e9bf77c..daef4e005 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator -from typing import Any, cast +from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -157,9 +157,8 @@ class DataSourceApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, binding_id, action): + def patch(self, binding_id, action: Literal["enable", "disable"]): binding_id = str(binding_id) - action = str(action) with Session(db.engine) as session: data_source_binding = session.execute( select(DataSourceOauthBinding).filter_by(id=binding_id) @@ -167,23 +166,24 @@ class DataSourceApi(Resource): if data_source_binding is None: raise NotFound("Data source binding not found.") # enable binding - if action == "enable": - if data_source_binding.disabled: - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is not disabled.") - # disable binding - if action == "disable": - if not data_source_binding.disabled: - data_source_binding.disabled = True - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is disabled.") + match action: + case "enable": + if data_source_binding.disabled: + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is not disabled.") + # disable binding + case "disable": + if not data_source_binding.disabled: + data_source_binding.disabled = True + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is disabled.") return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8fbbc51e2..30e4ed111 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None + summary_index_setting: dict[str, Any] | None = None partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None @@ -288,7 +289,14 @@ class DatasetListApi(Resource): @enterprise_license_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - query = ConsoleDatasetListQuery.model_validate(request.args.to_dict()) + # Convert query parameters to dict, handling list parameters correctly + query_params: dict[str, str | list[str]] = dict(request.args.to_dict()) + # Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value) + if "ids" in request.args: + query_params["ids"] = request.args.getlist("ids") + if "tag_ids" in request.args: + query_params["tag_ids"] = request.args.getlist("tag_ids") + query = ConsoleDatasetListQuery.model_validate(query_params) # provider = request.args.get("provider", default="vendor") if query.ids: datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 57fb9abf2..bf097d374 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService +from tasks.generate_summary_index_task import generate_summary_index_task from ..app.error import ( ProviderModelCurrentlyNotSupportError, @@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel): name: str +class GenerateSummaryPayload(BaseModel): + document_list: list[str] + + class DocumentBatchDownloadZipPayload(BaseModel): """Request payload for bulk downloading documents as a zip archive.""" @@ -125,6 +130,7 @@ register_schema_models( RetrievalModel, DocumentRetryPayload, DocumentRenamePayload, + GenerateSummaryPayload, DocumentBatchDownloadZipPayload, ) @@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource): paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items + + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=current_tenant_id, + ) + if fetch: for document in documents: completed_segments = ( @@ -563,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict + match document.data_source_type: + case "upload_file": + if not data_source_info: + continue + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) + .first() + ) - if document.data_source_type == "upload_file": - if not data_source_info: - continue - file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() - ) + if file_detail is None: + raise NotFound("File not found.") - if file_detail is None: - raise NotFound("File not found.") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form + ) + extract_settings.append(extract_setting) + case "notion_import": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_tenant_id, + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) + case "website_crawl": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) - extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form - ) - extract_settings.append(extract_setting) - - elif document.data_source_type == "notion_import": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "tenant_id": current_tenant_id, - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - elif document.data_source_type == "website_crawl": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - - else: - raise ValueError("Data source type not support") + case _: + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: response = indexing_runner.indexing_estimate( @@ -797,6 +809,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -832,6 +845,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response, 200 @@ -939,23 +953,24 @@ class DocumentProcessingApi(DocumentResource): if not current_user.is_dataset_editor: raise Forbidden() - if action == "pause": - if document.indexing_status != "indexing": - raise InvalidActionError("Document not in indexing state.") + match action: + case "pause": + if document.indexing_status != "indexing": + raise InvalidActionError("Document not in indexing state.") - document.paused_by = current_user.id - document.paused_at = naive_utc_now() - document.is_paused = True - db.session.commit() + document.paused_by = current_user.id + document.paused_at = naive_utc_now() + document.is_paused = True + db.session.commit() - elif action == "resume": - if document.indexing_status not in {"paused", "error"}: - raise InvalidActionError("Document not in paused or error state.") + case "resume": + if document.indexing_status not in {"paused", "error"}: + raise InvalidActionError("Document not in paused or error state.") - document.paused_by = None - document.paused_at = None - document.is_paused = False - db.session.commit() + document.paused_by = None + document.paused_at = None + document.is_paused = False + db.session.commit() return {"result": "success"}, 200 @@ -1255,3 +1270,149 @@ class DocumentPipelineExecutionLogApi(DocumentResource): "input_data": log.input_data, "datasource_node_id": log.datasource_node_id, }, 200 + + +@console_ns.route("/datasets//documents/generate-summary") +class DocumentGenerateSummaryApi(Resource): + @console_ns.doc("generate_summary_for_documents") + @console_ns.doc(description="Generate summary index for documents") + @console_ns.doc(params={"dataset_id": "Dataset ID"}) + @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__]) + @console_ns.response(200, "Summary generation started successfully") + @console_ns.response(400, "Invalid request or dataset configuration") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def post(self, dataset_id): + """ + Generate summary index for specified documents. + + This endpoint checks if the dataset configuration supports summary generation + (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true), + then asynchronously generates summary indexes for the provided documents. + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + if not current_user.is_dataset_editor: + raise Forbidden() + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Validate request payload + payload = GenerateSummaryPayload.model_validate(console_ns.payload or {}) + document_list = payload.document_list + + if not document_list: + from werkzeug.exceptions import BadRequest + + raise BadRequest("document_list cannot be empty.") + + # Check if dataset configuration supports summary generation + if dataset.indexing_technique != "high_quality": + raise ValueError( + f"Summary generation is only available for 'high_quality' indexing technique. " + f"Current indexing technique: {dataset.indexing_technique}" + ) + + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.") + + # Verify all documents exist and belong to the dataset + documents = DocumentService.get_documents_by_ids(dataset_id, document_list) + + if len(documents) != len(document_list): + found_ids = {doc.id for doc in documents} + missing_ids = set(document_list) - found_ids + raise NotFound(f"Some documents not found: {list(missing_ids)}") + + # Update need_summary to True for documents that don't have it set + # This handles the case where documents were created when summary_index_setting was disabled + documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"] + + if documents_to_update: + document_ids_to_update = [str(doc.id) for doc in documents_to_update] + DocumentService.update_documents_need_summary( + dataset_id=dataset_id, + document_ids=document_ids_to_update, + need_summary=True, + ) + + # Dispatch async tasks for each document + for document in documents: + # Skip qa_model documents as they don't generate summaries + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + continue + + # Dispatch async task + generate_summary_index_task.delay(dataset_id, document.id) + logger.info( + "Dispatched summary generation task for document %s in dataset %s", + document.id, + dataset_id, + ) + + return {"result": "success"}, 200 + + +@console_ns.route("/datasets//documents//summary-status") +class DocumentSummaryStatusApi(DocumentResource): + @console_ns.doc("get_document_summary_status") + @console_ns.doc(description="Get summary index generation status for a document") + @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @console_ns.response(200, "Summary status retrieved successfully") + @console_ns.response(404, "Document not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + """ + Get summary index generation status for a document. + + Returns: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + document_id = str(document_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Get summary status detail from service + from services.summary_index_service import SummaryIndexService + + result = SummaryIndexService.get_document_summary_status_detail( + document_id=document_id, + dataset_id=dataset_id, + ) + + return result, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 08e1ddd3e..23a668112 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task +def _get_segment_with_summary(segment, dataset_id): + """Helper function to marshal segment and add summary information.""" + from services.summary_index_service import SummaryIndexService + + segment_dict = dict(marshal(segment, segment_fields)) + # Query summary for this segment (only enabled summaries) + summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) + segment_dict["summary"] = summary.summary_content if summary else None + return segment_dict + + class SegmentListQuery(BaseModel): limit: int = Field(default=20, ge=1, le=100) status: list[str] = Field(default_factory=list) @@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class BatchImportPayload(BaseModel): @@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource): segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) + # Query summaries for all segments in this page (batch query for efficiency) + segment_ids = [segment.id for segment in segments.items] + summaries = {} + if segment_ids: + from services.summary_index_service import SummaryIndexService + + summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + # Only include enabled summaries (already filtered by service) + summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()} + + # Add summary to each segment + segments_with_summary = [] + for segment in segments.items: + segment_dict = dict(marshal(segment, segment_fields)) + segment_dict["summary"] = summaries.get(segment.id) + segments_with_summary.append(segment_dict) + response = { - "data": marshal(segments.items, segment_fields), + "data": segments_with_summary, "limit": limit, "total": segments.total, "total_pages": segments.pages, @@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource): payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) segment = SegmentService.create_segment(payload_dict, document, dataset) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @console_ns.route("/datasets//documents//segments/") @@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): payload = SegmentUpdatePayload.model_validate(console_ns.payload or {}) payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) + + # Update segment (summary update with change detection is handled in SegmentService.update_segment) segment = SegmentService.update_segment( SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset ) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @setup_required @login_required diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 932cb4fcc..e62be13c2 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,6 +1,13 @@ -from flask_restx import Resource +from flask_restx import Resource, fields from controllers.common.schema import register_schema_model +from fields.hit_testing_fields import ( + child_chunk_fields, + document_fields, + files_fields, + hit_testing_record_fields, + segment_fields, +) from libs.login import login_required from .. import console_ns @@ -14,13 +21,45 @@ from ..wraps import ( register_schema_model(console_ns, HitTestingPayload) +def _get_or_create_model(model_name: str, field_def): + """Get or create a flask_restx model to avoid dict type issues in Swagger.""" + existing = console_ns.models.get(model_name) + if existing is None: + existing = console_ns.model(model_name, field_def) + return existing + + +# Register models for flask_restx to avoid dict type issues in Swagger +document_model = _get_or_create_model("HitTestingDocument", document_fields) + +segment_fields_copy = segment_fields.copy() +segment_fields_copy["document"] = fields.Nested(document_model) +segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy) + +child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields) +files_model = _get_or_create_model("HitTestingFile", files_fields) + +hit_testing_record_fields_copy = hit_testing_record_fields.copy() +hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model) +hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model)) +hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model)) +hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy) + +# Response model for hit testing API +hit_testing_response_fields = { + "query": fields.String, + "records": fields.List(fields.Nested(hit_testing_record_model)), +} +hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields) + + @console_ns.route("/datasets//hit-testing") class HitTestingApi(Resource, DatasetsHitTestingBase): @console_ns.doc("test_dataset_retrieval") @console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.expect(console_ns.models[HitTestingPayload.__name__]) - @console_ns.response(200, "Hit testing completed successfully") + @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model) @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @setup_required diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 05fc4cd71..2e69ddc5a 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index d34fd5088..29b6b64b9 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,10 +1,9 @@ import json import logging from typing import Any, Literal, cast -from uuid import UUID from flask import abort, request -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper -from libs.helper import TimestampField +from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline @@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel): class WorkflowRunQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) @@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel): start_node_title: str +class RagPipelineRecommendedPluginQuery(BaseModel): + type: str = "all" + + register_schema_models( console_ns, DraftWorkflowSyncPayload, @@ -135,6 +138,7 @@ register_schema_models( NodeIdQuery, WorkflowRunQuery, DatasourceVariablesPayload, + RagPipelineRecommendedPluginQuery, ) @@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, location="args", required=False, default="all") - args = parser.parse_args() - type = args["type"] + query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict()) rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) + recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type) return recommended_plugins diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1eb0cdb01..cd523b481 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -9,7 +9,7 @@ import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse from controllers.common.schema import get_or_create_model -from controllers.console import api, console_ns +from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -51,7 +51,7 @@ from fields.app_fields import ( tag_fields, ) from fields.dataset_fields import dataset_fields -from fields.member_fields import build_simple_account_model +from fields.member_fields import simple_account_fields from fields.workflow_fields import ( conversation_variable_fields, pipeline_variable_fields, @@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model)) app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) -simple_account_model = build_simple_account_model(console_ns) +simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields) conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 2bebe79ea..f086bf186 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,87 +1,74 @@ import os +from typing import Literal from flask import session -from flask_restx import Resource, fields from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from controllers.fastopenapi import console_router from extensions.ext_database import db from models.model import DifySetup from services.account_service import TenantService -from . import console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class InitValidatePayload(BaseModel): - password: str = Field(..., max_length=30) + password: str = Field(..., max_length=30, description="Initialization password") -console_ns.schema_model( - InitValidatePayload.__name__, - InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +class InitStatusResponse(BaseModel): + status: Literal["finished", "not_started"] = Field(..., description="Initialization status") + + +class InitValidateResponse(BaseModel): + result: str = Field(description="Operation result", examples=["success"]) + + +@console_router.get( + "/init", + response_model=InitStatusResponse, + tags=["console"], ) +def get_init_status() -> InitStatusResponse: + """Get initialization validation status.""" + init_status = get_init_validate_status() + if init_status: + return InitStatusResponse(status="finished") + return InitStatusResponse(status="not_started") -@console_ns.route("/init") -class InitValidateAPI(Resource): - @console_ns.doc("get_init_status") - @console_ns.doc(description="Get initialization validation status") - @console_ns.response( - 200, - "Success", - model=console_ns.model( - "InitStatusResponse", - {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, - ), - ) - def get(self): - """Get initialization validation status""" - init_status = get_init_validate_status() - if init_status: - return {"status": "finished"} - return {"status": "not_started"} +@console_router.post( + "/init", + response_model=InitValidateResponse, + tags=["console"], + status_code=201, +) +@only_edition_self_hosted +def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse: + """Validate initialization password.""" + tenant_count = TenantService.get_tenant_count() + if tenant_count > 0: + raise AlreadySetupError() - @console_ns.doc("validate_init_password") - @console_ns.doc(description="Validate initialization password for self-hosted edition") - @console_ns.expect(console_ns.models[InitValidatePayload.__name__]) - @console_ns.response( - 201, - "Success", - model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), - ) - @console_ns.response(400, "Already setup or validation failed") - @only_edition_self_hosted - def post(self): - """Validate initialization password""" - # is tenant created - tenant_count = TenantService.get_tenant_count() - if tenant_count > 0: - raise AlreadySetupError() + if payload.password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False + raise InitValidateFailedError() - payload = InitValidatePayload.model_validate(console_ns.payload) - input_password = payload.password - - if input_password != os.environ.get("INIT_PASSWORD"): - session["is_init_validated"] = False - raise InitValidateFailedError() - - session["is_init_validated"] = True - return {"result": "success"}, 201 + session["is_init_validated"] = True + return InitValidateResponse(result="success") -def get_init_validate_status(): +def get_init_validate_status() -> bool: if dify_config.EDITION == "SELF_HOSTED": if os.environ.get("INIT_PASSWORD"): if session.get("is_init_validated"): return True with Session(db.engine) as db_session: - return db_session.execute(select(DifySetup)).scalar_one_or_none() + return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None return True diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 70c7b80ff..88a9ce3a7 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from flask_restx import Resource from pydantic import BaseModel, Field import services @@ -11,7 +10,7 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) -from controllers.common.schema import register_schema_models +from controllers.fastopenapi import console_router from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db @@ -19,84 +18,74 @@ from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant from services.file_service import FileService -from . import console_ns - -register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl) - - -@console_ns.route("/remote-files/") -class RemoteFileInfoApi(Resource): - @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__]) - def get(self, url): - decoded_url = urllib.parse.unquote(url) - resp = ssrf_proxy.head(decoded_url) - if resp.status_code != httpx.codes.OK: - # failed back to get method - resp = ssrf_proxy.get(decoded_url, timeout=3) - resp.raise_for_status() - info = RemoteFileInfo( - file_type=resp.headers.get("Content-Type", "application/octet-stream"), - file_length=int(resp.headers.get("Content-Length", 0)), - ) - return info.model_dump(mode="json") - class RemoteFileUploadPayload(BaseModel): url: str = Field(..., description="URL to fetch") -console_ns.schema_model( - RemoteFileUploadPayload.__name__, - RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"), +@console_router.get( + "/remote-files/", + response_model=RemoteFileInfo, + tags=["console"], ) +def get_remote_file_info(url: str) -> RemoteFileInfo: + decoded_url = urllib.parse.unquote(url) + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", 0)), + ) -@console_ns.route("/remote-files/upload") -class RemoteFileUploadApi(Resource): - @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) - @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__]) - def post(self): - args = RemoteFileUploadPayload.model_validate(console_ns.payload) - url = args.url +@console_router.post( + "/remote-files/upload", + response_model=FileWithSignedUrl, + tags=["console"], + status_code=201, +) +def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl: + url = payload.url - try: - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - if resp.status_code != httpx.codes.OK: - raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") - except httpx.RequestError as e: - raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") - file_info = helpers.guess_file_info_from_response(resp) + file_info = helpers.guess_file_info_from_response(resp) - if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): - raise FileTooLargeError + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + raise FileTooLargeError - content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content - try: - user, _ = current_account_with_tenant() - upload_file = FileService(db.engine).upload_file( - filename=file_info.filename, - content=content, - mimetype=file_info.mimetype, - user=user, - source_url=url, - ) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - - payload = FileWithSignedUrl( - id=upload_file.id, - name=upload_file.name, - size=upload_file.size, - extension=upload_file.extension, - url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - mime_type=upload_file.mime_type, - created_by=upload_file.created_by, - created_at=int(upload_file.created_at.timestamp()), + try: + user, _ = current_account_with_tenant() + upload_file = FileService(db.engine).upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=user, + source_url=url, ) - return payload.model_dump(mode="json"), 201 + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index b598d6382..4d656b55b 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,18 +1,28 @@ from typing import Literal from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Namespace, Resource, fields, marshal_with from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required -from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required from services.recommended_app_service_extend import RecommendedAppService from services.tag_service import TagService +dataset_tag_fields = { + "id": fields.String, + "name": fields.String, + "type": fields.String, + "binding_count": fields.String, +} + + +def build_dataset_tag_fields(api_or_ns: Namespace): + return api_or_ns.model("DataSetTag", dataset_tag_fields) + class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 38c66525b..708df6264 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -37,7 +38,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_fields +from fields.member_fields import Account as AccountResponse from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required @@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload) reg(ChangeEmailValidityPayload) reg(ChangeEmailResetPayload) reg(CheckEmailUniquePayload) +register_schema_models(console_ns, AccountResponse) + + +def _serialize_account(account) -> dict: + return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") + integrate_fields = { "provider": fields.String, @@ -236,11 +243,11 @@ class AccountProfileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) @enterprise_license_required def get(self): current_user, _ = current_account_with_tenant() - return current_user + return _serialize_account(current_user) @console_ns.route("/account/name") @@ -249,14 +256,14 @@ class AccountNameApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) updated_account = AccountService.update_account(current_user, name=args.name) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/avatar") @@ -265,7 +272,7 @@ class AccountAvatarApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -273,7 +280,7 @@ class AccountAvatarApi(Resource): updated_account = AccountService.update_account(current_user, avatar=args.avatar) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-language") @@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource): updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-theme") @@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource): updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/timezone") @@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource): updated_account = AccountService.update_account(current_user, timezone=args.timezone) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/password") @@ -333,7 +340,7 @@ class AccountPasswordApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -344,7 +351,7 @@ class AccountPasswordApi(Resource): except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() - return {"result": "success"} + return _serialize_account(current_user) @console_ns.route("/account/integrates") @@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) @@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource): email=normalized_new_email, ) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/change-email/check-email-unique") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index bfd9fc6c2..1897cbdca 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,9 +1,10 @@ from typing import Any from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery): plugin_id: str +class EndpointCreateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class PluginEndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class EndpointDeleteResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointUpdateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointEnableResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointDisableResponse(BaseModel): + success: bool = Field(description="Operation success") + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -reg(EndpointCreatePayload) -reg(EndpointIdPayload) -reg(EndpointUpdatePayload) -reg(EndpointListQuery) -reg(EndpointListForPluginQuery) +register_schema_models( + console_ns, + EndpointCreatePayload, + EndpointIdPayload, + EndpointUpdatePayload, + EndpointListQuery, + EndpointListForPluginQuery, + EndpointCreateResponse, + EndpointListResponse, + PluginEndpointListResponse, + EndpointDeleteResponse, + EndpointUpdateResponse, + EndpointEnableResponse, + EndpointDisableResponse, +) @console_ns.route("/workspaces/current/endpoints/create") @@ -57,7 +96,7 @@ class EndpointCreateApi(Resource): @console_ns.response( 200, "Endpoint created successfully", - console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointCreateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -91,9 +130,7 @@ class EndpointListApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[EndpointListResponse.__name__], ) @setup_required @login_required @@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[PluginEndpointListResponse.__name__], ) @setup_required @login_required @@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource): @console_ns.response( 200, "Endpoint deleted successfully", - console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDeleteResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource): @console_ns.response( 200, "Endpoint updated successfully", - console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointUpdateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -221,7 +256,7 @@ class EndpointEnableApi(Resource): @console_ns.response( 200, "Endpoint enabled successfully", - console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointEnableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -248,7 +283,7 @@ class EndpointDisableApi(Resource): @console_ns.response( 200, "Endpoint disabled successfully", - console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDisableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 271cdce3c..dd302b90d 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,12 +1,12 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter import services from configs import dify_config -from controllers.common.schema import get_or_create_model, register_enum_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, @@ -25,7 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_with_role_fields, account_with_role_list_fields +from fields.member_fields import AccountWithRole, AccountWithRoleList from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole @@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload) reg(OwnerTransferCheckPayload) reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) - -account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields) - -account_with_role_list_fields_copy = account_with_role_list_fields.copy() -account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model)) -account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy) +register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) @console_ns.route("/workspaces/current/members") @@ -84,13 +79,15 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/invite-email") @@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e9e7b7271..5bfa89584 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,16 +1,16 @@ import io import logging +from typing import Any, Literal from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_restx import ( - Resource, - reqparse, -) +from flask_restx import Resource +from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -26,8 +26,9 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from libs.helper import StrLen, alphanumeric, uuid_value +from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -52,24 +53,209 @@ def is_valid_url(url: str) -> bool: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] except (ValueError, TypeError): - # ValueError: Invalid URL format - # TypeError: url is not a string return False -parser_tool = reqparse.RequestParser().add_argument( - "type", - type=str, - choices=["builtin", "model", "api", "workflow", "mcp"], - required=False, - nullable=True, - location="args", +class ToolProviderListQuery(BaseModel): + type: Literal["builtin", "model", "api", "workflow", "mcp"] | None = None + + +class BuiltinToolCredentialDeletePayload(BaseModel): + credential_id: str + + +class BuiltinToolAddPayload(BaseModel): + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + type: CredentialType + + +class BuiltinToolUpdatePayload(BaseModel): + credential_id: str + credentials: dict[str, Any] | None = None + name: str | None = Field(default=None, max_length=30) + + +class ApiToolProviderBasePayload(BaseModel): + credentials: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + provider: str + icon: dict[str, Any] + privacy_policy: str | None = None + labels: list[str] | None = None + custom_disclaimer: str = "" + + +class ApiToolProviderAddPayload(ApiToolProviderBasePayload): + pass + + +class ApiToolProviderUpdatePayload(ApiToolProviderBasePayload): + original_provider: str + + +class UrlQuery(BaseModel): + url: HttpUrl + + +class ProviderQuery(BaseModel): + provider: str + + +class ApiToolProviderDeletePayload(BaseModel): + provider: str + + +class ApiToolSchemaPayload(BaseModel): + schema_: str = Field(alias="schema") + + +class ApiToolTestPayload(BaseModel): + tool_name: str + provider_name: str | None = None + credentials: dict[str, Any] + parameters: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + + +class WorkflowToolBasePayload(BaseModel): + name: str + label: str + description: str + icon: dict[str, Any] + parameters: list[WorkflowToolParameterConfiguration] = Field(default_factory=list) + privacy_policy: str | None = "" + labels: list[str] | None = None + + @field_validator("name") + @classmethod + def validate_name(cls, value: str) -> str: + return alphanumeric(value) + + +class WorkflowToolCreatePayload(WorkflowToolBasePayload): + workflow_app_id: str + + @field_validator("workflow_app_id") + @classmethod + def validate_workflow_app_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolUpdatePayload(WorkflowToolBasePayload): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolDeletePayload(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolGetQuery(BaseModel): + workflow_tool_id: str | None = None + workflow_app_id: str | None = None + + @field_validator("workflow_tool_id", "workflow_app_id") + @classmethod + def validate_ids(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + @model_validator(mode="after") + def ensure_one(self) -> "WorkflowToolGetQuery": + if not self.workflow_tool_id and not self.workflow_app_id: + raise ValueError("workflow_tool_id or workflow_app_id is required") + return self + + +class WorkflowToolListQuery(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class BuiltinProviderDefaultCredentialPayload(BaseModel): + id: str + + +class ToolOAuthCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = True + + +class MCPProviderBasePayload(BaseModel): + server_url: str + name: str + icon: str + icon_type: str + icon_background: str = "" + server_identifier: str + configuration: dict[str, Any] | None = Field(default_factory=dict) + headers: dict[str, Any] | None = Field(default_factory=dict) + authentication: dict[str, Any] | None = Field(default_factory=dict) + + +class MCPProviderCreatePayload(MCPProviderBasePayload): + pass + + +class MCPProviderUpdatePayload(MCPProviderBasePayload): + provider_id: str + + +class MCPProviderDeletePayload(BaseModel): + provider_id: str + + +class MCPAuthPayload(BaseModel): + provider_id: str + authorization_code: str | None = None + + +class MCPCallbackQuery(BaseModel): + code: str + state: str + + +register_schema_models( + console_ns, + BuiltinToolCredentialDeletePayload, + BuiltinToolAddPayload, + BuiltinToolUpdatePayload, + ApiToolProviderAddPayload, + ApiToolProviderUpdatePayload, + ApiToolProviderDeletePayload, + ApiToolSchemaPayload, + ApiToolTestPayload, + WorkflowToolCreatePayload, + WorkflowToolUpdatePayload, + WorkflowToolDeletePayload, + BuiltinProviderDefaultCredentialPayload, + ToolOAuthCustomClientPayload, + MCPProviderCreatePayload, + MCPProviderUpdatePayload, + MCPProviderDeletePayload, + MCPAuthPayload, ) @console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): - @console_ns.expect(parser_tool) @setup_required @login_required @account_initialization_required @@ -78,9 +264,10 @@ class ToolProviderListApi(Resource): user_id = user.id - args = parser_tool.parse_args() + raw_args = request.args.to_dict() + query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -110,14 +297,9 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) -parser_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): - @console_ns.expect(parser_delete) + @console_ns.expect(console_ns.models[BuiltinToolCredentialDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -125,26 +307,18 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): _, tenant_id = current_account_with_tenant() - args = parser_delete.parse_args() + payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.delete_builtin_tool_provider( tenant_id, provider, - args["credential_id"], + payload.credential_id, ) -parser_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") - .add_argument("type", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): - @console_ns.expect(parser_add) + @console_ns.expect(console_ns.models[BuiltinToolAddPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -153,32 +327,21 @@ class ToolBuiltinProviderAddApi(Resource): user_id = user.id - args = parser_add.parse_args() - - if args["type"] not in CredentialType.values(): - raise ValueError(f"Invalid credential type: {args['type']}") + payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credentials=args["credentials"], - name=args["name"], - api_type=CredentialType.of(args["type"]), + credentials=payload.credentials, + name=payload.name, + api_type=CredentialType.of(payload.type), ) -parser_update = ( - reqparse.RequestParser() - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): - @console_ns.expect(parser_update) + @console_ns.expect(console_ns.models[BuiltinToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -187,15 +350,15 @@ class ToolBuiltinProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_update.parse_args() + payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {}) result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credential_id=args["credential_id"], - credentials=args.get("credentials", None), - name=args.get("name", ""), + credential_id=payload.credential_id, + credentials=payload.credentials, + name=payload.name or "", ) return result @@ -225,22 +388,9 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -parser_api_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) - .add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): - @console_ns.expect(parser_api_add) + @console_ns.expect(console_ns.models[ApiToolProviderAddPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -250,28 +400,24 @@ class ToolApiProviderAddApi(Resource): user_id = user.id - args = parser_api_add.parse_args() + payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args["provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args.get("privacy_policy", ""), - args.get("custom_disclaimer", ""), - args.get("labels", []), + payload.provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy or "", + payload.custom_disclaimer or "", + payload.labels or [], ) -parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): - @console_ns.expect(parser_remote) @setup_required @login_required @account_initialization_required @@ -280,23 +426,18 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - args = parser_remote.parse_args() + raw_args = request.args.to_dict() + query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( user_id, tenant_id, - args["url"], + str(query.url), ) -parser_tools = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): - @console_ns.expect(parser_tools) @setup_required @login_required @account_initialization_required @@ -305,34 +446,21 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - args = parser_tools.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, - args["provider"], + query.provider, ) ) -parser_api_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("original_provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") - .add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): - @console_ns.expect(parser_api_update) + @console_ns.expect(console_ns.models[ApiToolProviderUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -342,31 +470,26 @@ class ToolApiProviderUpdateApi(Resource): user_id = user.id - args = parser_api_update.parse_args() + payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args["provider"], - args["original_provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args["privacy_policy"], - args["custom_disclaimer"], - args.get("labels", []), + payload.provider, + payload.original_provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy, + payload.custom_disclaimer, + payload.labels or [], ) -parser_api_delete = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): - @console_ns.expect(parser_api_delete) + @console_ns.expect(console_ns.models[ApiToolProviderDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -376,21 +499,17 @@ class ToolApiProviderDeleteApi(Resource): user_id = user.id - args = parser_api_delete.parse_args() + payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args["provider"], + payload.provider, ) -parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): - @console_ns.expect(parser_get) @setup_required @login_required @account_initialization_required @@ -399,12 +518,13 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - args = parser_get.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args["provider"], + query.provider, ) @@ -423,72 +543,43 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) -parser_schema = reqparse.RequestParser().add_argument( - "schema", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): - @console_ns.expect(parser_schema) + @console_ns.expect(console_ns.models[ApiToolSchemaPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_schema.parse_args() + payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=args["schema"], + schema=payload.schema_, ) -parser_pre = ( - reqparse.RequestParser() - .add_argument("tool_name", type=str, required=True, nullable=False, location="json") - .add_argument("provider_name", type=str, required=False, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): - @console_ns.expect(parser_pre) + @console_ns.expect(console_ns.models[ApiToolTestPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_pre.parse_args() + payload = ApiToolTestPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( current_tenant_id, - args["provider_name"] or "", - args["tool_name"], - args["credentials"], - args["parameters"], - args["schema_type"], - args["schema"], + payload.provider_name or "", + payload.tool_name, + payload.credentials, + payload.parameters, + payload.schema_type, + payload.schema_, ) -parser_create = ( - reqparse.RequestParser() - .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[WorkflowToolCreatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -498,38 +589,25 @@ class ToolWorkflowProviderCreateApi(Resource): user_id = user.id - args = parser_create.parse_args() + payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.create_workflow_tool( user_id=user_id, tenant_id=tenant_id, - workflow_app_id=args["workflow_app_id"], - name=args["name"], - label=args["label"], - icon=args["icon"], - description=args["description"], - parameters=args["parameters"], - privacy_policy=args["privacy_policy"], - labels=args["labels"], + workflow_app_id=payload.workflow_app_id, + name=payload.name, + label=payload.label, + icon=payload.icon, + description=payload.description, + parameters=payload.parameters, + privacy_policy=payload.privacy_policy or "", + labels=payload.labels or [], ) -parser_workflow_update = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): - @console_ns.expect(parser_workflow_update) + @console_ns.expect(console_ns.models[WorkflowToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -538,33 +616,25 @@ class ToolWorkflowProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_workflow_update.parse_args() - - if not args["workflow_tool_id"]: - raise ValueError("incorrect workflow_tool_id") + payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], - args["name"], - args["label"], - args["icon"], - args["description"], - args["parameters"], - args["privacy_policy"], - args.get("labels", []), + payload.workflow_tool_id, + payload.name, + payload.label, + payload.icon, + payload.description, + payload.parameters, + payload.privacy_policy or "", + payload.labels or [], ) -parser_workflow_delete = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): - @console_ns.expect(parser_workflow_delete) + @console_ns.expect(console_ns.models[WorkflowToolDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -574,25 +644,17 @@ class ToolWorkflowProviderDeleteApi(Resource): user_id = user.id - args = parser_workflow_delete.parse_args() + payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], + payload.workflow_tool_id, ) -parser_wf_get = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") - .add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): - @console_ns.expect(parser_wf_get) @setup_required @login_required @account_initialization_required @@ -601,19 +663,20 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - args = parser_wf_get.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolGetQuery.model_validate(raw_args) - if args.get("workflow_tool_id"): + if query.workflow_tool_id: tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) - elif args.get("workflow_app_id"): + elif query.workflow_app_id: tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args["workflow_app_id"], + query.workflow_app_id, ) else: raise ValueError("incorrect workflow_tool_id or workflow_app_id") @@ -621,14 +684,8 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) -parser_wf_tools = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): - @console_ns.expect(parser_wf_tools) @setup_required @login_required @account_initialization_required @@ -637,13 +694,14 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - args = parser_wf_tools.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) ) @@ -810,49 +868,39 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_default_cred = reqparse.RequestParser().add_argument( - "id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class ToolBuiltinProviderSetDefaultApi(Resource): - @console_ns.expect(parser_default_cred) + @console_ns.expect(console_ns.models[BuiltinProviderDefaultCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider): current_user, current_tenant_id = current_account_with_tenant() - args = parser_default_cred.parse_args() + payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( - tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id ) -parser_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): - @console_ns.expect(parser_custom) + @console_ns.expect(console_ns.models[ToolOAuthCustomClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - args = parser_custom.parse_args() + payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {}) _, tenant_id = current_account_with_tenant() return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, provider=provider, - client_params=args.get("client_params", {}), - enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + client_params=payload.client_params or {}, + enable_oauth_custom_client=payload.enable_oauth_custom_client + if payload.enable_oauth_custom_client is not None + else True, ) @setup_required @@ -904,49 +952,19 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) -parser_mcp = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_put = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json") - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_delete = reqparse.RequestParser().add_argument( - "provider_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): - @console_ns.expect(parser_mcp) + @console_ns.expect(console_ns.models[MCPProviderCreatePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_mcp.parse_args() + payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {}) user, tenant_id = current_account_with_tenant() # Parse and validate models - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None # 1) Create provider in a short transaction (no network I/O inside) with session_factory.create_session() as session, session.begin(): @@ -954,13 +972,13 @@ class ToolProviderMCPApi(Resource): result = service.create_provider( tenant_id=tenant_id, user_id=user.id, - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, ) @@ -969,8 +987,8 @@ class ToolProviderMCPApi(Resource): # Perform network I/O outside any DB session to avoid holding locks. try: reconnect = MCPToolManageService.reconnect_with_url( - server_url=args["server_url"], - headers=args.get("headers") or {}, + server_url=payload.server_url, + headers=payload.headers or {}, timeout=configuration.timeout, sse_read_timeout=configuration.sse_read_timeout, ) @@ -988,14 +1006,14 @@ class ToolProviderMCPApi(Resource): return jsonable_encoder(result) - @console_ns.expect(parser_mcp_put) + @console_ns.expect(console_ns.models[MCPProviderUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required def put(self): - args = parser_mcp_put.parse_args() - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {}) + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None _, current_tenant_id = current_account_with_tenant() # Step 1: Get provider data for URL validation (short-lived session, no network I/O) @@ -1003,14 +1021,14 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) validation_data = service.get_provider_for_url_validation( - tenant_id=current_tenant_id, provider_id=args["provider_id"] + tenant_id=current_tenant_id, provider_id=payload.provider_id ) # Step 2: Perform URL validation with network I/O OUTSIDE of any database session # This prevents holding database locks during potentially slow network operations validation_result = MCPToolManageService.validate_server_url_standalone( tenant_id=current_tenant_id, - new_server_url=args["server_url"], + new_server_url=payload.server_url, validation_data=validation_data, ) @@ -1019,14 +1037,14 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + provider_id=payload.provider_id, + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, validation_result=validation_result, @@ -1034,37 +1052,30 @@ class ToolProviderMCPApi(Resource): return {"result": "success"} - @console_ns.expect(parser_mcp_delete) + @console_ns.expect(console_ns.models[MCPProviderDeletePayload.__name__]) @setup_required @login_required @account_initialization_required def delete(self): - args = parser_mcp_delete.parse_args() + payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) + service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) return {"result": "success"} -parser_auth = ( - reqparse.RequestParser() - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("authorization_code", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): - @console_ns.expect(parser_auth) + @console_ns.expect(console_ns.models[MCPAuthPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_auth.parse_args() - provider_id = args["provider_id"] + payload = MCPAuthPayload.model_validate(console_ns.payload or {}) + provider_id = payload.provider_id _, tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): @@ -1102,7 +1113,7 @@ class ToolMCPAuthApi(Resource): # Pass the extracted OAuth metadata hints to auth() auth_result = auth( provider_entity, - args.get("authorization_code"), + payload.authorization_code, resource_metadata_url=e.resource_metadata_url, scope_hint=e.scope_hint, ) @@ -1167,20 +1178,13 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) -parser_cb = ( - reqparse.RequestParser() - .add_argument("code", type=str, required=True, nullable=False, location="args") - .add_argument("state", type=str, required=True, nullable=False, location="args") -) - - @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): - @console_ns.expect(parser_cb) def get(self): - args = parser_cb.parse_args() - state_key = args["state"] - authorization_code = args["code"] + raw_args = request.args.to_dict() + query = MCPCallbackQuery.model_validate(raw_args) + state_key = query.state + authorization_code = query.code # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 6b3c1313e..54a1aea70 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,16 +1,16 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields +from flask_restx import Resource from flask_restx.api import HTTPStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import annotation_fields, build_annotation_model +from fields.annotation_fields import Annotation, AnnotationList from models.model import ApiToken, App # extend - 密钥额度限制,新增ApiToken from services.annotation_service import AppAnnotationService @@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel): embedding_model_name: str = Field(description="Embedding model name") -register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) +register_schema_models( + service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList +) @service_api_ns.route("/apps/annotation-reply/") @@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_model.id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args, app_model.id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 @@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 -# Define annotation list response model -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), - "has_more": fields.Boolean, - "limit": fields.Integer, - "total": fields.Integer, - "page": fields.Integer, -} - - -def build_annotation_list_model(api_or_ns: Namespace): - """Build the annotation list model for the API or Namespace.""" - copied_annotation_list_fields = annotation_list_fields.copy() - copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) - return api_or_ns.model("AnnotationList", copied_annotation_list_fields) - - @service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): @service_api_ns.doc("list_annotations") @@ -109,8 +95,12 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + 200, + "Annotations retrieved successfully", + service_api_ns.models[AnnotationList.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token """List annotations for the application.""" page = request.args.get("page", default=1, type=int) @@ -118,13 +108,15 @@ class AnnotationListApi(Resource): keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - return { - "data": annotation_list, - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("create_annotation") @@ -135,13 +127,18 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + HTTPStatus.CREATED, + "Annotation created successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): """Create a new annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation, 201 + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json"), HTTPStatus.CREATED @service_api_ns.route("/apps/annotations/") @@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource): 404: "Annotation not found", } ) + @service_api_ns.response( + 200, + "Annotation updated successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token @edit_permission_required - @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) - return annotation + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json") @service_api_ns.doc("delete_annotation") @service_api_ns.doc(description="Delete an annotation") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index ad34e54c5..ba8294a1b 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -30,6 +30,7 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.helper import UUIDStrOrEmpty from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken from services.app_generate_service import AppGenerateService from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken @@ -53,7 +54,7 @@ class ChatRequestPayload(BaseModel): query: str files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None - conversation_id: str | None = Field(default=None, description="Conversation UUID") + conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID") retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 62c8c7ec7..4c8787e06 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,4 @@ from typing import Any, Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -23,12 +22,13 @@ from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_model, ) +from libs.helper import UUIDStrOrEmpty from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错 from services.conversation_service import ConversationService class ConversationListQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( default="-updated_at", description="Sort order for conversations" @@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel): class ConversationVariablesQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") variable_name: str | None = Field( default=None, description="Filter variables by name", min_length=1, max_length=255 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index c11f0d316..fc16d9f6b 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,6 +1,5 @@ import logging from typing import Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem +from libs.helper import UUIDStrOrEmpty from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken from services.errors.message import ( FirstMessageNotExistsError, @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) class MessageListQuery(BaseModel): - conversation_id: UUID - first_id: UUID | None = None + conversation_id: UUIDStrOrEmpty + first_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 28864a140..db5cabe8a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -17,7 +17,7 @@ from controllers.service_api.wraps import ( from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import build_dataset_tag_fields +from fields.tag_fields import DataSetTag from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum @@ -46,6 +46,7 @@ class DatasetCreatePayload(BaseModel): retrieval_model: RetrievalModel | None = None embedding_model: str | None = None embedding_model_provider: str | None = None + summary_index_setting: dict | None = None class DatasetUpdatePayload(BaseModel): @@ -113,6 +114,7 @@ register_schema_models( TagBindingPayload, TagUnbindingPayload, DatasetListQuery, + DataSetTag, ) @@ -217,6 +219,7 @@ class DatasetListApi(DatasetApiResource): embedding_model_provider=payload.embedding_model_provider, embedding_model_name=payload.embedding_model, retrieval_model=payload.retrieval_model, + summary_index_setting=payload.summary_index_setting, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -478,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _): """Get all knowledge type tags.""" assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None tags = TagService.get_tags("knowledge", cid) - - return tags, 200 + tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True) + return [tag.model_dump(mode="json") for tag in tag_models], 200 @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.doc("create_dataset_tag") @@ -498,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def post(self, _): """Add a knowledge type tag.""" assert isinstance(current_user, Account) @@ -508,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource): payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @@ -521,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def patch(self, _): assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -534,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c85c1cf81..a01524f1b 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( Segmentation, ) from services.file_service import FileService +from services.summary_index_service import SummaryIndexService class DocumentTextCreatePayload(BaseModel): @@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource): ) documents = paginated_documents.items + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=tenant_id, + ) + response = { "data": marshal(documents, document_fields), "has_more": len(documents) == query_params.limit, @@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource): if metadata not in self.METADATA_CHOICES: raise InvalidMetadataError(f"Invalid metadata value: {metadata}") + # Calculate summary_index_status if needed + summary_index_status = None + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + if has_summary_index and document.need_summary is True: + summary_index_status = SummaryIndexService.get_document_summary_index_status( + document_id=document_id, + dataset_id=dataset_id, + tenant_id=tenant_id, + ) + if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": @@ -646,6 +663,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -681,6 +700,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 8dbb69090..97a70f5d0 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,7 +1,10 @@ -from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.common.schema import register_schema_model +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +register_schema_model(service_api_ns, HitTestingPayload) + @service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): 404: "Dataset not found", } ) + @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Perform hit testing on a dataset. diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index b8d950800..692342a38 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 8186ad95f..6e01cfb71 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -126,14 +126,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe # If caller needs end-user context, attach EndUser to current_user if fetch_user_arg: - if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get("user") - else: - user_id = None + user_id = None + match fetch_user_arg.fetch_from: + case WhereisUserArg.QUERY: + user_id = request.args.get("user") + case WhereisUserArg.JSON: + user_id = request.get_json().get("user") + case WhereisUserArg.FORM: + user_id = request.form.get("user") if not user_id and fetch_user_arg.required: raise ValueError("Arg user must be provided.") diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index c1f336fdd..9b981dfc0 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -14,16 +14,17 @@ class AgentConfigManager: agent_dict = config.get("agent_mode", {}) agent_strategy = agent_dict.get("strategy", "cot") - if agent_strategy == "function_call": - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy in {"cot", "react"}: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to detect default strategy - if config["model"]["provider"] == "openai": + match agent_strategy: + case "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: + case "cot" | "react": strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + case _: + # old configs, try to detect default strategy + if config["model"]["provider"] == "openai": + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] for tool in agent_dict.get("tools", []): diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 74c6d2eca..d1e2f16b6 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC): "document_name": resource["document_name"], "score": resource["score"], "content": resource["content"], + "summary": resource.get("summary"), } ) metadata["retriever_resources"] = updated_resources diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d3..cefff7be9 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -250,7 +250,7 @@ class WorkflowResponseConverter: data=WorkflowFinishStreamResponse.Data( id=run_id, workflow_id=workflow_id, - status=status.value, + status=status, outputs=encoded_outputs, error=error, elapsed_time=elapsed_time, @@ -340,13 +340,13 @@ class WorkflowResponseConverter: metadata = self._merge_metadata(event.execution_metadata, snapshot) if isinstance(event, QueueNodeSucceededEvent): - status = WorkflowNodeExecutionStatus.SUCCEEDED.value + status = WorkflowNodeExecutionStatus.SUCCEEDED error_message = event.error elif isinstance(event, QueueNodeFailedEvent): - status = WorkflowNodeExecutionStatus.FAILED.value + status = WorkflowNodeExecutionStatus.FAILED error_message = event.error else: - status = WorkflowNodeExecutionStatus.EXCEPTION.value + status = WorkflowNodeExecutionStatus.EXCEPTION error_message = event.error return NodeFinishStreamResponse( @@ -413,7 +413,7 @@ class WorkflowResponseConverter: process_data_truncated=process_data_truncated, outputs=outputs, outputs_truncated=outputs_truncated, - status=WorkflowNodeExecutionStatus.RETRY.value, + status=WorkflowNodeExecutionStatus.RETRY, error=event.error, elapsed_time=elapsed_time, execution_metadata=metadata, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index ea4441b5d..eca96cb07 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("Pipeline dataset is required") inputs: Mapping[str, Any] = args["inputs"] start_node_id: str = args["start_node_id"] - datasource_type: str = args["datasource_type"] + datasource_type = DatasourceProviderType(args["datasource_type"]) datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list( datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user ) @@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator): tenant_id: str, dataset_id: str, built_in_field_enabled: bool, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info: Mapping[str, Any], created_from: str, position: int, @@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator): batch: str, document_form: str, ): - if datasource_type == "local_file": - name = datasource_info.get("name", "untitled") - elif datasource_type == "online_document": - name = datasource_info.get("page", {}).get("page_name", "untitled") - elif datasource_type == "website_crawl": - name = datasource_info.get("title", "untitled") - elif datasource_type == "online_drive": - name = datasource_info.get("name", "untitled") - else: - raise ValueError(f"Unsupported datasource type: {datasource_type}") - + match datasource_type: + case DatasourceProviderType.LOCAL_FILE: + name = datasource_info.get("name", "untitled") + case DatasourceProviderType.ONLINE_DOCUMENT: + name = datasource_info.get("page", {}).get("page_name", "untitled") + case DatasourceProviderType.WEBSITE_CRAWL: + name = datasource_info.get("title", "untitled") + case DatasourceProviderType.ONLINE_DRIVE: + name = datasource_info.get("name", "untitled") + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") document = Document( tenant_id=tenant_id, dataset_id=dataset_id, @@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator): def _format_datasource_info_list( self, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info_list: list[Mapping[str, Any]], pipeline: Pipeline, workflow: Workflow, @@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator): """ Format datasource info list. """ - if datasource_type == "online_drive": + if datasource_type == DatasourceProviderType.ONLINE_DRIVE: all_files: list[Mapping[str, Any]] = [] datasource_node_data = None datasource_nodes = workflow.graph_dict.get("nodes", []) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 79a5e657b..26fb17cce 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class AnnotationReplyAccount(BaseModel): @@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float @@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = True - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = False - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index d4093b524..b1ba3c3e2 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): content: str + summary: str | None = None child_chunks: list[str] | None = None diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index e93e1e441..f4cce0b33 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC): @classmethod def get_default_config(cls) -> DefaultConfig: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": cls.get_language(), - "code": cls.get_default_code(), - "outputs": {"result": {"type": "string", "children": None}}, - }, + variables: list[VariableConfig] = [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ] + outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}} + + config: CodeConfig = { + "variables": variables, + "code_language": cls.get_language(), + "code": cls.get_default_code(), + "outputs": outputs, } + return {"type": "code", "config": config} diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f1b50f360..4e3ad7bb7 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -311,14 +311,18 @@ class IndexingRunner: qa_preview_texts: list[QAPreviewDetail] = [] total_segments = 0 + # doc_form represents the segmentation method (general, parent-child, QA) index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() + # one extract_setting is one source document for extract_setting in extract_settings: # extract processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) + # Extract document content text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + # Cleaning and segmentation documents = index_processor.transform( text_docs, current_user=None, @@ -361,75 +365,82 @@ class IndexingRunner: if doc_form and doc_form == "qa_model": return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[]) + + # Generate summary preview + summary_index_setting = tmp_processing_rule.get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable") and preview_texts: + preview_texts = index_processor.generate_summary_preview( + tenant_id, preview_texts, summary_index_setting, doc_language + ) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: - # load file - if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: - return [] - data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == "upload_file": - if not data_source_info or "upload_file_id" not in data_source_info: - raise ValueError("no upload file found") - stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) - file_detail = db.session.scalars(stmt).one_or_none() + match dataset_document.data_source_type: + case "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: + raise ValueError("no upload file found") + stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) + file_detail = db.session.scalars(stmt).one_or_none() - if file_detail: + if file_detail: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, + upload_file=file_detail, + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): + raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, - upload_file=file_detail, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_workspace_id" not in data_source_info - or "notion_page_id" not in data_source_info - ): - raise ValueError("no notion import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "website_crawl": - if ( - not data_source_info - or "provider" not in data_source_info - or "url" not in data_source_info - or "job_id" not in data_source_info - ): - raise ValueError("no website import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): + raise ValueError("no website import info found") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case _: + return [] # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, diff --git a/api/core/llm_generator/entities.py b/api/core/llm_generator/entities.py new file mode 100644 index 000000000..3bb8d2c89 --- /dev/null +++ b/api/core/llm_generator/entities.py @@ -0,0 +1,20 @@ +"""Shared payload models for LLM generator helpers and controllers.""" + +from pydantic import BaseModel, Field + +from core.app.app_config.entities import ModelConfig + + +class RuleGeneratePayload(BaseModel): + instruction: str = Field(..., description="Rule generation instruction") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") + no_variable: bool = Field(default=False, description="Whether to exclude variables") + + +class RuleCodeGeneratePayload(RuleGeneratePayload): + code_language: str = Field(default="javascript", description="Programming language for code generation") + + +class RuleStructuredOutputPayload(BaseModel): + instruction: str = Field(..., description="Structured output generation instruction") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index be1e306d4..5b2c64026 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -6,6 +6,8 @@ from typing import Protocol, cast import json_repair +from core.app.app_config.entities import ModelConfig +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.prompts import ( @@ -151,19 +153,19 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): + def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload): output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} - model_parameters = model_config.get("completion_params", {}) - if no_variable: + model_parameters = args.model_config_data.completion_params + if args.no_variable: prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, }, remove_template_variables=False, ) @@ -175,8 +177,8 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) try: @@ -190,7 +192,7 @@ class LLMGenerator: error = str(e) error_step = "generate rule config" except Exception as e: - logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -209,7 +211,7 @@ class LLMGenerator: # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, }, remove_template_variables=False, ) @@ -220,8 +222,8 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) try: @@ -250,7 +252,7 @@ class LLMGenerator: # the second step to generate the task_parameter and task_statement statement_generate_prompt = statement_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, "INPUT_TEXT": prompt_content.message.get_text_content(), }, remove_template_variables=False, @@ -276,7 +278,7 @@ class LLMGenerator: error_step = "generate conversation opener" except Exception as e: - logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -284,16 +286,20 @@ class LLMGenerator: return rule_config @classmethod - def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): - if code_language == "python": + def generate_code( + cls, + tenant_id: str, + args: RuleCodeGeneratePayload, + ): + if args.code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) prompt = prompt_template.format( inputs={ - "INSTRUCTION": instruction, - "CODE_LANGUAGE": code_language, + "INSTRUCTION": args.instruction, + "CODE_LANGUAGE": args.code_language, }, remove_template_variables=False, ) @@ -302,28 +308,28 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = model_config.get("completion_params", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) generated_code = response.message.get_text_content() - return {"code": generated_code, "language": code_language, "error": ""} + return {"code": generated_code, "language": args.code_language, "error": ""} except InvokeError as e: error = str(e) - return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} + return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"} except Exception as e: logger.exception( - "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language + "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language ) - return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"} @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): @@ -353,20 +359,20 @@ class LLMGenerator: return answer.strip() @classmethod - def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): + def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), - UserPromptMessage(content=instruction), + UserPromptMessage(content=args.instruction), ] - model_parameters = model_config.get("model_parameters", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( @@ -390,12 +396,17 @@ class LLMGenerator: error = str(e) return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) + logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name) return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} @staticmethod def instruction_modify_legacy( - tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None + tenant_id: str, + flow_id: str, + current: str, + instruction: str, + model_config: ModelConfig, + ideal_output: str | None, ): last_run: Message | None = ( db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() @@ -434,7 +445,7 @@ class LLMGenerator: node_id: str, current: str, instruction: str, - model_config: dict, + model_config: ModelConfig, ideal_output: str | None, workflow_service: WorkflowServiceInterface, ): @@ -505,7 +516,7 @@ class LLMGenerator: @staticmethod def __instruction_modify_common( tenant_id: str, - model_config: dict, + model_config: ModelConfig, last_run: dict | None, current: str | None, error_message: str | None, @@ -526,8 +537,8 @@ class LLMGenerator: model_instance = ModelManager().get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=model_config.provider, + model=model_config.name, ) match node_type: case "llm" | "agent": @@ -570,7 +581,5 @@ class LLMGenerator: error = str(e) return {"error": f"Failed to generate code. Error: {error}"} except Exception as e: - logger.exception( - "Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True - ) + logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True) return {"error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ec2b7f2d4..ee9a016c9 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -434,3 +434,22 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex You should edit the prompt according to the IDEAL OUTPUT.""" INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" + +DEFAULT_GENERATOR_SUMMARY_PROMPT = ( + """Summarize the following content. Extract only the key information and main points. """ + """Remove redundant details. + +Requirements: +1. Write a concise summary in plain text +2. You must write in {language}. No language other than {language} should be used. +3. Focus on important facts, concepts, and details +4. If images are included, describe their key information +5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions" +6. Write directly without extra words +7. If there is not enough content to generate a meaningful summary, + return an empty string without any explanation or prompt + +Output only the summary text. Start summarizing now: + +""" +) diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 84a6fd0d1..e1a40593e 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -347,7 +347,7 @@ class BaseSession( message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) ) - responder = RequestResponder( + responder = RequestResponder[ReceiveRequestT, SendResultT]( request_id=message.message.root.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, request=validated_request, diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 76969fea7..51c9c5125 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.MAX_TOKENS: { "label": { "en_US": "Max Tokens", - "zh_Hans": "最大标记", + "zh_Hans": "最大 Token 数", }, "type": "int", "help": { diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 45f0335c2..c3e50eadd 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,10 +1,11 @@ import decimal import hashlib -from threading import Lock +import logging -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from redis import RedisError -import contexts +from configs import dify_config from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) class AIModel(BaseModel): @@ -144,34 +148,60 @@ class AIModel(BaseModel): plugin_model_manager = PluginModelClient() cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model_type=self.model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 7a0757f21..bbbdec61d 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -92,6 +92,10 @@ def _build_llm_result_from_first_chunk( Build a single `LLMResult` from the first returned chunk. This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. + + Note: + This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying + streaming resources are released (e.g., HTTP connections owned by the plugin runtime). """ content = "" content_list: list[PromptMessageContentUnionTypes] = [] @@ -99,18 +103,25 @@ def _build_llm_result_from_first_chunk( system_fingerprint: str | None = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] - first_chunk = next(chunks, None) - if first_chunk is not None: - if isinstance(first_chunk.delta.message.content, str): - content += first_chunk.delta.message.content - elif isinstance(first_chunk.delta.message.content, list): - content_list.extend(first_chunk.delta.message.content) + try: + first_chunk = next(chunks, None) + if first_chunk is not None: + if isinstance(first_chunk.delta.message.content, str): + content += first_chunk.delta.message.content + elif isinstance(first_chunk.delta.message.content, list): + content_list.extend(first_chunk.delta.message.content) - if first_chunk.delta.message.tool_calls: - _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) + if first_chunk.delta.message.tool_calls: + _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) - usage = first_chunk.delta.usage or LLMUsage.empty_usage() - system_fingerprint = first_chunk.system_fingerprint + usage = first_chunk.delta.usage or LLMUsage.empty_usage() + system_fingerprint = first_chunk.system_fingerprint + finally: + try: + for _ in chunks: + pass + except Exception: + logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True) return LLMResult( model=model, @@ -283,7 +294,7 @@ class LargeLanguageModel(AIModel): # TODO raise self._transform_invoke_error(e) - if stream and isinstance(result, Generator): + if stream and not isinstance(result, LLMResult): return self._invoke_result_generator( model=model, result=result, diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 28f162a92..9cfc6889a 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,7 +5,11 @@ import logging from collections.abc import Sequence from threading import Lock +from pydantic import ValidationError +from redis import RedisError + import contexts +from configs import dify_config from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) @@ -175,34 +180,60 @@ class ModelProviderFactory: """ plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = self.plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def get_models( self, @@ -283,6 +314,8 @@ class ModelProviderFactory: elif model_type == ModelType.TTS: return TTSModel.model_validate(init_params) + raise ValueError(f"Unsupported model type: {model_type}") + def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ Get provider icon diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index e32b1ddb3..abd672fe6 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -23,7 +23,13 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding +from models.dataset import ( + ChildChunk, + Dataset, + DocumentSegment, + DocumentSegmentSummary, + SegmentAttachmentBinding, +) from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService @@ -378,15 +384,15 @@ class RetrievalService: .all() } - records = [] - include_segment_ids = set() - segment_child_map = {} - valid_dataset_documents = {} image_doc_ids: list[Any] = [] child_index_node_ids = [] index_node_ids = [] doc_to_document_map = {} + summary_segment_ids = set() # Track segments retrieved via summary + summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score + + # First pass: collect all document IDs and identify summary documents for document in documents: document_id = document.metadata.get("document_id") if document_id not in dataset_documents: @@ -397,16 +403,39 @@ class RetrievalService: continue valid_dataset_documents[document_id] = dataset_document + doc_id = document.metadata.get("doc_id") or "" + doc_to_document_map[doc_id] = document + + # Check if this is a summary document + is_summary = document.metadata.get("is_summary", False) + if is_summary: + # For summary documents, find the original chunk via original_chunk_id + original_chunk_id = document.metadata.get("original_chunk_id") + if original_chunk_id: + summary_segment_ids.add(original_chunk_id) + # Save summary's score for later use + summary_score = document.metadata.get("score") + if summary_score is not None: + try: + summary_score_float = float(summary_score) + # If the same segment has multiple summary hits, take the highest score + if original_chunk_id not in summary_score_map: + summary_score_map[original_chunk_id] = summary_score_float + else: + summary_score_map[original_chunk_id] = max( + summary_score_map[original_chunk_id], summary_score_float + ) + except (ValueError, TypeError): + # Skip invalid score values + pass + continue # Skip adding to other lists for summary documents + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: child_index_node_ids.append(doc_id) else: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: @@ -419,9 +448,10 @@ class RetrievalService: segment_ids = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map = {} - child_chunk_map: dict[Any, Any] = {} - doc_segment_map = {} + attachment_map: dict[str, list[dict[str, Any]]] = {} + child_chunk_map: dict[str, list[ChildChunk]] = {} + doc_segment_map: dict[str, list[str]] = {} + segment_summary_map: dict[str, str] = {} # Map segment_id to summary content with session_factory.create_session() as session: attachments = cls.get_segment_attachment_infos(image_doc_ids, session) @@ -436,6 +466,7 @@ class RetrievalService: doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) else: doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_index_nodes = session.execute(child_chunk_stmt).scalars().all() @@ -459,6 +490,7 @@ class RetrievalService: index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore for index_node_segment in index_node_segments: doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] + if segment_ids: document_segment_stmt = select(DocumentSegment).where( DocumentSegment.enabled == True, @@ -470,6 +502,40 @@ class RetrievalService: if index_node_segments: segments.extend(index_node_segments) + # Handle summary documents: query segments by original_chunk_id + if summary_segment_ids: + summary_segment_ids_list = list(summary_segment_ids) + summary_segment_stmt = select(DocumentSegment).where( + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id.in_(summary_segment_ids_list), + ) + summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore + segments.extend(summary_segments) + # Add summary segment IDs to segment_ids for summary query + for seg in summary_segments: + if seg.id not in segment_ids: + segment_ids.append(seg.id) + + # Batch query summaries for segments retrieved via summary (only enabled summaries) + if summary_segment_ids: + summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)), + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries + ) + .all() + ) + for summary in summaries: + if summary.summary_content: + segment_summary_map[summary.chunk_id] = summary.summary_content + + include_segment_ids = set() + segment_child_map: dict[str, dict[str, Any]] = {} + records: list[dict[str, Any]] = [] + for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) @@ -478,45 +544,68 @@ class RetrievalService: if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) + # Check if this segment was retrieved via summary + # Use summary score as base score if available, otherwise 0.0 + max_score = summary_score_map.get(segment.id, 0.0) + if child_chunks or attachment_infos: child_chunk_details = [] - max_score = 0.0 for child_chunk in child_chunks: - document = doc_to_document_map[child_chunk.index_node_id] + child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) + if child_document: + child_score = child_document.metadata.get("score", 0.0) + else: + child_score = 0.0 child_chunk_detail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, - "score": document.metadata.get("score", 0.0) if document else 0.0, + "score": child_score, } child_chunk_details.append(child_chunk_detail) - max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) + max_score = max(max_score, child_score) for attachment_info in attachment_infos: - file_document = doc_to_document_map[attachment_info["id"]] - max_score = max( - max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 - ) + file_document = doc_to_document_map.get(attachment_info["id"]) + if file_document: + max_score = max(max_score, file_document.metadata.get("score", 0.0)) map_detail = { "max_score": max_score, "child_chunks": child_chunk_details, } segment_child_map[segment.id] = map_detail - record = { + else: + # No child chunks or attachments, use summary score if available + summary_score = summary_score_map.get(segment.id) + if summary_score is not None: + segment_child_map[segment.id] = { + "max_score": summary_score, + "child_chunks": [], + } + record: dict[str, Any] = { "segment": segment, } records.append(record) else: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) - max_score = 0.0 - document = doc_to_document_map.get(segment.index_node_id) - if document: - max_score = max(max_score, document.metadata.get("score", 0.0)) + + # Check if this segment was retrieved via summary + # Use summary score if available (summary retrieval takes priority) + max_score = summary_score_map.get(segment.id, 0.0) + + # If not retrieved via summary, use original segment's score + if segment.id not in summary_score_map: + segment_document = doc_to_document_map.get(segment.index_node_id) + if segment_document: + max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + + # Also consider attachment scores for attachment_info in attachment_infos: - file_document = doc_to_document_map.get(attachment_info["id"]) - if file_document: - max_score = max(max_score, file_document.metadata.get("score", 0.0)) + file_doc = doc_to_document_map.get(attachment_info["id"]) + if file_doc: + max_score = max(max_score, file_doc.metadata.get("score", 0.0)) + record = { "segment": segment, "score": max_score, @@ -557,9 +646,16 @@ class RetrievalService: else None ) + # Extract summary if this segment was retrieved via summary + summary_content = segment_summary_map.get(segment.id) + # Create RetrievalSegments object retrieval_segment = RetrievalSegments( - segment=segment, child_chunks=child_chunks, score=score, files=files + segment=segment, + child_chunks=child_chunks, + score=score, + files=files, + summary=summary_content, ) result.append(retrieval_segment) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f8c62b908..4a4a458f2 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -391,46 +391,78 @@ class QdrantVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - """Return docs most similar by bm25. + """Return docs most similar by full-text search. + + Searches each keyword separately and merges results to ensure documents + matching ANY keyword are returned (OR logic). Results are capped at top_k. + + Args: + query: Search query text. Multi-word queries are split into keywords, + with each keyword searched separately. Limited to 10 keywords. + **kwargs: Additional search parameters (top_k, document_ids_filter) + Returns: - List of documents most similar to the query text and distance for each. + List of up to top_k unique documents matching any query keyword. """ from qdrant_client.http import models - scroll_filter = models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self._group_id), - ), - models.FieldCondition( - key="page_content", - match=models.MatchText(text=query), - ), - ] - ) + # Build base must conditions (AND logic) for metadata filters + base_must_conditions: list = [ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ] + document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: - if scroll_filter.must: - scroll_filter.must.append( - models.FieldCondition( - key="metadata.document_id", - match=models.MatchAny(any=document_ids_filter), - ) + base_must_conditions.append( + models.FieldCondition( + key="metadata.document_id", + match=models.MatchAny(any=document_ids_filter), ) - response = self._client.scroll( - collection_name=self._collection_name, - scroll_filter=scroll_filter, - limit=kwargs.get("top_k", 2), - with_payload=True, - with_vectors=True, - ) - results = response[0] - documents = [] - for result in results: - if result: - document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) - documents.append(document) + ) + + # Split query into keywords, deduplicate and limit to prevent DoS + keywords = list(dict.fromkeys(kw.strip() for kw in query.strip().split() if kw.strip()))[:10] + + if not keywords: + return [] + + top_k = kwargs.get("top_k", 2) + seen_ids: set[str | int] = set() + documents: list[Document] = [] + + # Search each keyword separately and merge results. + # This ensures each keyword gets its own search, preventing one keyword's + # results from completely overshadowing another's due to scroll ordering. + for keyword in keywords: + scroll_filter = models.Filter( + must=[ + *base_must_conditions, + models.FieldCondition( + key="page_content", + match=models.MatchText(text=keyword), + ), + ] + ) + + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=top_k, + with_payload=True, + with_vectors=True, + ) + results = response[0] + + for result in results: + if result and result.id not in seen_ids: + seen_ids.add(result.id) + document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) + documents.append(document) + if len(documents) >= top_k: + return documents return documents diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index b54a37b49..f6834ab87 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel): child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None files: list[dict[str, str | int]] | None = None + summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index 9f66cd9a0..aec5c353f 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel): doc_metadata: dict[str, Any] | None = None title: str | None = None files: list[dict[str, Any]] | None = None + summary: str | None = None diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 511f5a698..1ddbfc586 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,7 @@ -"""Abstract interface for document loader implementations.""" +"""Word (.docx) document extractor used for RAG ingestion. + +Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). +""" import logging import mimetypes @@ -8,7 +11,6 @@ import tempfile import uuid from urllib.parse import urlparse -import httpx from docx import Document as DocxDocument from docx.oxml.ns import qn from docx.text.run import Run @@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor): # If the file is a web path, download it to a temporary file, and use that if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): - response = httpx.get(self.file_path, timeout=None) + response = ssrf_proxy.get(self.file_path) if response.status_code != 200: response.close() @@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor): self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115 try: self.temp_file.write(response.content) + self.temp_file.flush() finally: response.close() self.file_path = self.temp_file.name diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index e36b54eed..6e76321ea 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse import httpx from configs import dify_config +from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType @@ -45,6 +46,27 @@ class BaseIndexProcessor(ABC): def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError + @abstractmethod + def generate_summary_preview( + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, + ) -> list[PreviewDetail]: + """ + For each segment in preview_texts, generate a summary using LLM and attach it to the segment. + The summary can be stored in a new attribute, e.g., summary. + This method should be implemented by subclasses. + + Args: + tenant_id: Tenant ID + preview_texts: List of preview details to generate summaries for + summary_index_setting: Summary index configuration + doc_language: Optional document language to ensure summary is generated in the correct language + """ + raise NotImplementedError + @abstractmethod def load( self, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index cf68cff7d..41d7656f8 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,9 +1,27 @@ """Paragraph index processor.""" +import logging +import re import uuid from collections.abc import Mapping -from typing import Any +from typing import Any, cast +logger = logging.getLogger(__name__) + +from core.entities.knowledge_entities import PreviewDetail +from core.file import File, FileTransferMethod, FileType, file_manager +from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService @@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.workflow.nodes.llm import llm_utils +from extensions.ext_database import db +from factories.file_factory import build_from_mapping from libs import helper +from models import UploadFile from models.account import Account -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService class ParagraphIndexProcessor(BaseIndexProcessor): @@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword.add_texts(documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: @@ -227,3 +273,347 @@ class ParagraphIndexProcessor(BaseIndexProcessor): } else: raise ValueError("Chunks is not a list") + + def generate_summary_preview( + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, + ) -> list[PreviewDetail]: + """ + For each segment, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. + """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item.""" + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts + + @staticmethod + def generate_summary( + tenant_id: str, + text: str, + summary_index_setting: dict | None = None, + segment_id: str | None = None, + document_language: str | None = None, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, + and supports vision models by including images from the segment attachments or text content. + + Args: + tenant_id: Tenant ID + text: Text content to summarize + summary_index_setting: Summary index configuration + segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + document_language: Optional document language (e.g., "Chinese", "English") + to ensure summary is generated in the correct language + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + """ + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("summary_index_setting is required and must be enabled to generate summary.") + + model_name = summary_index_setting.get("model_name") + model_provider_name = summary_index_setting.get("model_provider_name") + summary_prompt = summary_index_setting.get("summary_prompt") + + if not model_name or not model_provider_name: + raise ValueError("model_name and model_provider_name are required in summary_index_setting") + + # Import default summary prompt + is_default_prompt = False + if not summary_prompt: + summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + is_default_prompt = True + + # Format prompt with document language only for default prompt + # Custom prompts are used as-is to avoid interfering with user-defined templates + # If document_language is provided, use it; otherwise, use "the same language as the input content" + # This is especially important for image-only chunks where text is empty or minimal + if is_default_prompt: + language_for_prompt = document_language or "the same language as the input content" + try: + summary_prompt = summary_prompt.format(language=language_for_prompt) + except KeyError: + # If default prompt doesn't have {language} placeholder, use it as-is + pass + + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id, model_provider_name, ModelType.LLM + ) + model_instance = ModelInstance(provider_model_bundle, model_name) + + # Get model schema to check if vision is supported + model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials) + supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features + + # Extract images if model supports vision + image_files = [] + if supports_vision: + # First, try to get images from SegmentAttachmentBinding (preferred method) + if segment_id: + image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id) + + # If no images from attachments, fall back to extracting from text + if not image_files: + image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text) + + # Build prompt messages + prompt_messages = [] + + if image_files: + # If we have images, create a UserPromptMessage with both text and images + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + + # Add images first + for file in image_files: + try: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW + ) + prompt_message_contents.append(file_content) + except Exception as e: + logger.warning("Failed to convert image file to prompt message content: %s", str(e)) + continue + + # Add text content + if prompt_message_contents: # Only add text if we successfully added images + prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}")) + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + # If image conversion failed, fall back to text-only + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + else: + # No images, use simple text prompt + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + + result = model_instance.invoke_llm( + prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False + ) + + # Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator + if not isinstance(result, LLMResult): + raise ValueError("Expected LLMResult when stream=False") + + summary_content = getattr(result.message, "content", "") + usage = result.usage + + # Deduct quota for summary generation (same as workflow nodes) + try: + llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + except Exception as e: + # Log but don't fail summary generation if quota deduction fails + logger.warning("Failed to deduct quota for summary generation: %s", str(e)) + + return summary_content, usage + + @staticmethod + def _extract_images_from_text(tenant_id: str, text: str) -> list[File]: + """ + Extract images from markdown text and convert them to File objects. + + Args: + tenant_id: Tenant ID + text: Text content that may contain markdown image links + + Returns: + List of File objects representing images found in the text + """ + # Extract markdown images using regex pattern + pattern = r"!\[.*?\]\((.*?)\)" + images = re.findall(pattern, text) + + if not images: + return [] + + upload_file_id_list = [] + + for image in images: + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + # Tool files are handled differently, skip for now + continue + + if not upload_file_id_list: + return [] + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = ( + db.session.query(UploadFile) + .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) + .all() + ) + + # Create File objects from UploadFile records + file_objects = [] + for upload_file in upload_files: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + mapping = { + "upload_file_id": upload_file.id, + "transfer_method": FileTransferMethod.LOCAL_FILE.value, + "type": FileType.IMAGE.value, + } + + try: + file_obj = build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + return file_objects + + @staticmethod + def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]: + """ + Extract images from SegmentAttachmentBinding table (preferred method). + This matches how DatasetRetrieval gets segment attachments. + + Args: + tenant_id: Tenant ID + segment_id: Segment ID to fetch attachments for + + Returns: + List of File objects representing images found in segment attachments + """ + from sqlalchemy import select + + # Query attachments from SegmentAttachmentBinding table + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment_id, + SegmentAttachmentBinding.tenant_id == tenant_id, + ) + ).all() + + if not attachments_with_bindings: + return [] + + file_objects = [] + for _, upload_file in attachments_with_bindings: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + try: + # Create File object directly (similar to DatasetRetrieval) + file_obj = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + return file_objects diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 0366f3259..0ea77405e 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,11 +1,14 @@ """Paragraph index processor.""" import json +import logging import uuid from collections.abc import Mapping from typing import Any from configs import dify_config +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) class ParentChildIndexProcessor(BaseIndexProcessor): @@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor): def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") @@ -326,3 +356,97 @@ class ParentChildIndexProcessor(BaseIndexProcessor): "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + + def generate_summary_preview( + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, + ) -> list[PreviewDetail]: + """ + For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. + + Note: For parent-child structure, we only generate summaries for parent chunks. + """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item (parent chunk).""" + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + document_language=doc_language, + ) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + document_language=doc_language, + ) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 1183d5fbd..40d9caaa6 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -11,6 +11,8 @@ import pandas as pd from flask import Flask, current_app from werkzeug.datastructures import FileStorage +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.account import Account -from models.dataset import Dataset +from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) @@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor): vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Note: qa_model doesn't generate summaries, but we clean them for completeness + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -212,6 +240,21 @@ class QAIndexProcessor(BaseIndexProcessor): "total_segments": len(qa_chunks.qa_chunks), } + def generate_summary_preview( + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, + ) -> list[PreviewDetail]: + """ + QA model doesn't generate summaries, so this method returns preview_texts unchanged. + + Note: QA model uses question-answer pairs, which don't require summary generation. + """ + # QA model doesn't generate summaries, return as-is + return preview_texts + def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f8f85d141..541c241ae 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -236,20 +236,24 @@ class DatasetRetrieval: if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if vision_enabled: attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -316,6 +320,9 @@ class DatasetRetrieval: source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if hit_callback and retrieval_resource_list: retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index b4ecfe47f..b87fba4ea 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -35,6 +35,7 @@ class SchemaRegistry: registry.load_all_versions() cls._default_instance = registry + return cls._default_instance return cls._default_instance diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 383be199d..2a4094554 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -226,16 +226,13 @@ class ToolManager: raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") if not provider_controller.need_credentials: - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) builtin_provider = None if isinstance(provider_controller, PluginToolProviderController): @@ -337,18 +334,15 @@ class ToolManager: decrypted_credentials = refreshed_credentials.credentials cache.delete() - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials=dict(decrypted_credentials), - credential_type=CredentialType.of(builtin_provider.credential_type), - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=dict(decrypted_credentials), + credential_type=CredentialType.of(builtin_provider.credential_type), + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) elif provider_type == ToolProviderType.API: diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f96510fb4..057ec41f6 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if self.return_resource: for record in records: @@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if self.return_resource and retrieval_resource_list: diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 188da0c32..6d75df360 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity class WorkflowToolConfigurationUtils: - @classmethod - def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): - for configuration in configurations: - WorkflowToolParameterConfiguration.model_validate(configuration) - @classmethod def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ diff --git a/api/core/trigger/debug/event_bus.py b/api/core/trigger/debug/event_bus.py index 9d10e1a0e..e3fb6a13d 100644 --- a/api/core/trigger/debug/event_bus.py +++ b/api/core/trigger/debug/event_bus.py @@ -23,8 +23,8 @@ class TriggerDebugEventBus: """ # LUA_SELECT: Atomic poll or register for event - # KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id} - # KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:... + # KEYS[1] = trigger_debug_inbox:{}: + # KEYS[2] = trigger_debug_waiting_pool:{}:... # ARGV[1] = address_id LUA_SELECT = ( "local v=redis.call('GET',KEYS[1]);" @@ -35,7 +35,7 @@ class TriggerDebugEventBus: ) # LUA_DISPATCH: Dispatch event to all waiting addresses - # KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:... + # KEYS[1] = trigger_debug_waiting_pool:{}:... # ARGV[1] = tenant_id # ARGV[2] = event_json LUA_DISPATCH = ( @@ -43,7 +43,7 @@ class TriggerDebugEventBus: "if #a==0 then return 0 end;" "redis.call('DEL',KEYS[1]);" "for i=1,#a do " - f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});" + f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});" "end;" "return #a" ) @@ -108,7 +108,7 @@ class TriggerDebugEventBus: Event object if available, None otherwise """ address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest() - address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}" + address: str = f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}" try: event_data = redis_client.eval( diff --git a/api/core/trigger/debug/events.py b/api/core/trigger/debug/events.py index 9f7bab5e4..9aec342ed 100644 --- a/api/core/trigger/debug/events.py +++ b/api/core/trigger/debug/events.py @@ -42,7 +42,7 @@ def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str: app_id: App ID node_id: Node ID """ - return f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}" + return f"{TriggerDebugPoolKey.WEBHOOK}:{{{tenant_id}}}:{app_id}:{node_id}" class PluginTriggerDebugEvent(BaseDebugEvent): @@ -64,4 +64,4 @@ def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str provider_id: Provider ID subscription_id: Subscription ID """ - return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}" + return f"{TriggerDebugPoolKey.PLUGIN}:{{{tenant_id}}}:{str(provider_id)}:{subscription_id}:{name}" diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 31bf6f3b2..52bbbb20c 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -5,15 +5,20 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Protocol, cast, final +from pydantic import TypeAdapter + +from core.workflow.entities.graph_config import NodeConfigDict from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node -from libs.typing import is_str, is_str_dict +from libs.typing import is_str from .edge import Edge from .validation import get_graph_validator logger = logging.getLogger(__name__) +_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) + class NodeFactory(Protocol): """ @@ -23,7 +28,7 @@ class NodeFactory(Protocol): allowing for different node creation strategies while maintaining type safety. """ - def create_node(self, node_config: dict[str, object]) -> Node: + def create_node(self, node_config: NodeConfigDict) -> Node: """ Create a Node instance from node configuration data. @@ -63,28 +68,24 @@ class Graph: self.root_node = root_node @classmethod - def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]: + def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: """ Parse node configurations and build a mapping of node IDs to configs. :param node_configs: list of node configuration dictionaries :return: mapping of node ID to node config """ - node_configs_map: dict[str, dict[str, object]] = {} + node_configs_map: dict[str, NodeConfigDict] = {} for node_config in node_configs: - node_id = node_config.get("id") - if not node_id or not isinstance(node_id, str): - continue - - node_configs_map[node_id] = node_config + node_configs_map[node_config["id"]] = node_config return node_configs_map @classmethod def _find_root_node_id( cls, - node_configs_map: Mapping[str, Mapping[str, object]], + node_configs_map: Mapping[str, NodeConfigDict], edge_configs: Sequence[Mapping[str, object]], root_node_id: str | None = None, ) -> str: @@ -113,10 +114,8 @@ class Graph: # Prefer START node if available start_node_id = None for nid in root_candidates: - node_data = node_configs_map[nid].get("data") - if not is_str_dict(node_data): - continue - node_type = node_data.get("type") + node_data = node_configs_map[nid]["data"] + node_type = node_data["type"] if not isinstance(node_type, str): continue if NodeType(node_type).is_start_node: @@ -176,7 +175,7 @@ class Graph: @classmethod def _create_node_instances( cls, - node_configs_map: dict[str, dict[str, object]], + node_configs_map: dict[str, NodeConfigDict], node_factory: NodeFactory, ) -> dict[str, Node]: """ @@ -303,7 +302,7 @@ class Graph: node_configs = graph_config.get("nodes", []) edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) + node_configs = _ListNodeConfigDict.validate_python(node_configs) if not node_configs: raise ValueError("Graph must have at least one node") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 0b359a239..2b76b563f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel -from .ready_queue import ReadyQueue from .worker_management import WorkerPool if TYPE_CHECKING: @@ -90,7 +89,7 @@ class GraphEngine: self._graph_execution.workflow_id = workflow_id # === Execution Queues === - self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue) + self._ready_queue = self._graph_runtime_state.ready_queue # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 98e0ea91e..e82ba2943 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -15,10 +15,10 @@ from uuid import uuid4 from pydantic import BaseModel, Field from core.workflow.enums import NodeExecutionType, NodeState -from core.workflow.graph import Graph from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent from core.workflow.nodes.base.template import TextSegment, VariableSegment from core.workflow.runtime import VariablePool +from core.workflow.runtime.graph_runtime_state import GraphProtocol from .path import Path from .session import ResponseSession @@ -75,7 +75,7 @@ class ResponseStreamCoordinator: Ensures ordered streaming of responses based on upstream node outputs and constants. """ - def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None: + def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: """ Initialize coordinator with variable pool. diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8ceaa428c..5e4fada7d 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -10,10 +10,10 @@ from __future__ import annotations from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.knowledge_index import KnowledgeIndexNode +from core.workflow.runtime.graph_runtime_state import NodeProtocol @dataclass @@ -29,21 +29,26 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> ResponseSession: + def from_node(cls, node: NodeProtocol) -> ResponseSession: """ - Create a ResponseSession from an AnswerNode or EndNode. + Create a ResponseSession from a response-capable node. + + The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer, + but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides: + - `id: str` + - `get_streaming_template() -> Template` Args: - node: Must be either an AnswerNode or EndNode instance + node: Node from the materialized workflow graph. Returns: ResponseSession configured with the node's streaming template Raises: - TypeError: If node is not an AnswerNode or EndNode + TypeError: If node is not a supported response node type. """ if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode): - raise TypeError + raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode") return cls( node_id=node.id, template=node.get_streaming_template(), diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5a365f769..e195aebe6 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]): result[parameter_name] = None continue agent_input = node_data.agent_parameters[parameter_name] - if agent_input.type == "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - elif agent_input.type in {"mixed", "constant"}: - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + # variable_pool.convert_template expects a string template, + # but if passing a dict, convert to JSON string first before rendering + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - else: - raise AgentInputTypeError(agent_input.type) + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + # variable_pool.convert_template returns a string, + # so we need to convert it back to a dictionary + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) @@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]): result: dict[str, Any] = {} for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] - if input.type in ["mixed", "constant"]: - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value result = {node_id + "." + key: value for key, value in result.items()} diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 10a1c897e..802601119 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Self +from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel @@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, Self] | None = None + children: dict[str, "CodeNodeData.Output"] | None = None class Dependency(BaseModel): name: str diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 925561cf7..a732a7041 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]): if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") + datasource_type = DatasourceProviderType.value_of(datasource_type) + datasource_runtime = DatasourceManager.get_datasource_runtime( provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, - datasource_type=DatasourceProviderType.value_of(datasource_type), + datasource_type=datasource_type, ) datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) @@ -268,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]): if typed_node_data.datasource_parameters: for parameter_name in typed_node_data.datasource_parameters: input = typed_node_data.datasource_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + case "constant": + pass + case None: + pass result = {node_id + "." + key: value for key, value in result.items()} @@ -306,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]): variables: dict[str, Any] = {} for message in message_stream: - if message.type in { - DatasourceMessage.MessageType.IMAGE_LINK, - DatasourceMessage.MessageType.BINARY_LINK, - DatasourceMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceMessage.TextMessage) + match message.type: + case ( + DatasourceMessage.MessageType.IMAGE_LINK + | DatasourceMessage.MessageType.BINARY_LINK + | DatasourceMessage.MessageType.IMAGE + ): + assert isinstance(message.message, DatasourceMessage.TextMessage) - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE + url = message.message.text + transfer_method = FileTransferMethod.TOOL_FILE - datasource_file_id = str(url).split("/")[-1].split(".")[0] + datasource_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceMessage.TextMessage) - assert message.meta - - datasource_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"datasource file {datasource_file_id} not exists") - - mapping = { - "tool_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( + mapping = { + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( mapping=mapping, tenant_id=self.tenant_id, ) - ) - elif message.type == DatasourceMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceMessage.JsonMessage) - json.append(message.message.json_object) - elif message.type == DatasourceMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value + files.append(file) + case DatasourceMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceMessage.TextMessage) + assert message.meta + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "tool_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + case DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + text += message.message.text yield StreamChunkEvent( - selector=[self._node_id, variable_name], - chunk=variable_value, + selector=[self._node_id, "text"], + chunk=message.message.text, is_final=False, ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) + case DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) + json.append(message.message.json_object) + case DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=stream_text, + is_final=False, + ) + case DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[self._node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + case DatasourceMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + case ( + DatasourceMessage.MessageType.BLOB_CHUNK + | DatasourceMessage.MessageType.LOG + | DatasourceMessage.MessageType.RETRIEVER_RESOURCES + ): + pass + # mark the end of the stream yield StreamChunkEvent( selector=[self._node_id, "text"], diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 429f8411a..7de821656 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -2,7 +2,7 @@ import base64 import json import secrets import string -from collections.abc import Mapping +from collections.abc import Callable, Mapping from copy import deepcopy from typing import Any, Literal from urllib.parse import urlencode, urlparse @@ -11,9 +11,9 @@ import httpx from json_repair import repair_json from configs import dify_config -from core.file import file_manager from core.file.enums import FileTransferMethod -from core.helper import ssrf_proxy +from core.file.file_manager import file_manager as default_file_manager +from core.helper.ssrf_proxy import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.runtime import VariablePool @@ -79,8 +79,8 @@ class Executor: timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, - http_client: HttpClientProtocol = ssrf_proxy, - file_manager: FileManagerProtocol = file_manager, + http_client: HttpClientProtocol | None = None, + file_manager: FileManagerProtocol | None = None, ): # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": @@ -107,8 +107,8 @@ class Executor: self.data = None self.json = None self.max_retries = max_retries - self._http_client = http_client - self._file_manager = file_manager + self._http_client = http_client or ssrf_proxy + self._file_manager = file_manager or default_file_manager # init template self.variable_pool = variable_pool @@ -336,7 +336,7 @@ class Executor: """ do http request depending on api bundle """ - _METHOD_MAP = { + _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { "get": self._http_client.get, "head": self._http_client.head, "post": self._http_client.post, @@ -348,7 +348,7 @@ class Executor: if method_lc not in _METHOD_MAP: raise InvalidHttpMethodError(f"Invalid http method {self.method}") - request_args = { + request_args: dict[str, Any] = { "data": self.data, "files": self.files, "json": self.json, @@ -361,14 +361,13 @@ class Executor: } # request_args = {k: v for k, v in request_args.items() if v is not None} try: - response: httpx.Response = _METHOD_MAP[method_lc]( + response = _METHOD_MAP[method_lc]( url=self.url, **request_args, max_retries=self.max_retries, ) except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e: raise HttpRequestNodeError(str(e)) from e - # FIXME: fix type ignore, this maybe httpx type issue return response def invoke(self) -> Response: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 964e53e03..480482375 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any from configs import dify_config -from core.file import File, FileTransferMethod, file_manager -from core.helper import ssrf_proxy +from core.file import File, FileTransferMethod +from core.file.file_manager import file_manager as default_file_manager +from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.variables.segments import ArrayFileSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - http_client: HttpClientProtocol = ssrf_proxy, + http_client: HttpClientProtocol | None = None, tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, - file_manager: FileManagerProtocol = file_manager, + file_manager: FileManagerProtocol | None = None, ) -> None: super().__init__( id=id, @@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._http_client = http_client + self._http_client = http_client or ssrf_proxy self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager + self._file_manager = file_manager or default_file_manager @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index c19182549..25a881ea7 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return outputs # Check if all non-None outputs are lists - non_none_outputs = [output for output in outputs if output is not None] + non_none_outputs: list[object] = [output for output in outputs if output is not None] if not non_none_outputs: return outputs diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 3daca90b9..bfeb9b5b7 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData): type: str = "knowledge-index" chunk_structure: str index_chunk_variable_selector: list[str] + indexing_technique: str | None = None + summary_index_setting: dict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 17ca4bef7..2aff953bc 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,9 +1,11 @@ +import concurrent.futures import datetime import logging import time from collections.abc import Mapping from typing import Any +from flask import current_app from sqlalchemy import func, select from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.runtime import VariablePool from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary +from services.summary_index_service import SummaryIndexService +from tasks.generate_summary_index_task import generate_summary_index_task from .entities import KnowledgeIndexNodeData from .exc import ( @@ -67,7 +71,29 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): # index knowledge try: if is_preview: - outputs = self._get_preview_output(node_data.chunk_structure, chunks) + # Preview mode: generate summaries for chunks directly without saving to database + # Format preview and generate summaries on-the-fly + # Get indexing_technique and summary_index_setting from node_data (workflow graph config) + # or fallback to dataset if not available in node_data + indexing_technique = node_data.indexing_technique or dataset.indexing_technique + summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + + # Try to get document language if document_id is available + doc_language = None + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document and document.doc_language: + doc_language = document.doc_language + + outputs = self._get_preview_output_with_summaries( + node_data.chunk_structure, + chunks, + dataset=dataset, + indexing_technique=indexing_technique, + summary_index_setting=summary_index_setting, + doc_language=doc_language, + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -148,6 +174,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): ) .scalar() ) + # Update need_summary based on dataset's summary_index_setting + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + document.need_summary = True + else: + document.need_summary = False db.session.add(document) # update document segment status db.session.query(DocumentSegment).where( @@ -163,6 +194,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): db.session.commit() + # Generate summary index if enabled + self._handle_summary_index_generation(dataset, document, variable_pool) + return { "dataset_id": ds_id_value, "dataset_name": dataset_name_value, @@ -173,9 +207,308 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): "display_status": "completed", } - def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]: + def _handle_summary_index_generation( + self, + dataset: Dataset, + document: Document, + variable_pool: VariablePool, + ) -> None: + """ + Handle summary index generation based on mode (debug/preview or production). + + Args: + dataset: Dataset containing the document + document: Document to generate summaries for + variable_pool: Variable pool to check invoke_from + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + # Check if summary index is enabled + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + return + + # Skip qa_model documents + if document.doc_form == "qa_model": + return + + # Determine if in preview/debug mode + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER + + if is_preview: + try: + # Query segments that need summary generation + query = db.session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return + + # Filter segments based on mode + segments_to_process = [] + for segment in segments: + # Skip if summary already exists + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed") + .first() + ) + if existing_summary: + continue + + # For parent-child mode, all segments are parent chunks, so process all + segments_to_process.append(segment) + + if not segments_to_process: + logger.info("No segments need summary generation for document %s", document.id) + return + + # Use ThreadPoolExecutor for concurrent generation + flask_app = current_app._get_current_object() # type: ignore + max_workers = min(10, len(segments_to_process)) # Limit to 10 workers + + def process_segment(segment: DocumentSegment) -> None: + """Process a single segment in a thread with Flask app context.""" + with flask_app.app_context(): + try: + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting) + except Exception: + logger.exception( + "Failed to generate summary for segment %s", + segment.id, + ) + # Continue processing other segments + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(process_segment, segment) for segment in segments_to_process] + # Wait for all tasks to complete + concurrent.futures.wait(futures) + + logger.info( + "Successfully generated summary index for %s segments in document %s", + len(segments_to_process), + document.id, + ) + except Exception: + logger.exception("Failed to generate summary index for document %s", document.id) + # Don't fail the entire indexing process if summary generation fails + else: + # Production mode: asynchronous generation + logger.info( + "Queuing summary index generation task for document %s (production mode)", + document.id, + ) + try: + generate_summary_index_task.delay(dataset.id, document.id, None) + logger.info("Summary index generation task queued for document %s", document.id) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document.id, + ) + # Don't fail the entire indexing process if task queuing fails + + def _get_preview_output_with_summaries( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset, + indexing_technique: str | None = None, + summary_index_setting: dict | None = None, + doc_language: str | None = None, + ) -> Mapping[str, Any]: + """ + Generate preview output with summaries for chunks in preview mode. + This method generates summaries on-the-fly without saving to database. + + Args: + chunk_structure: Chunk structure type + chunks: Chunks to generate preview for + dataset: Dataset object (for tenant_id) + indexing_technique: Indexing technique from node config or dataset + summary_index_setting: Summary index setting from node config or dataset + doc_language: Optional document language to ensure summary is generated in the correct language + """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - return index_processor.format_preview(chunks) + preview_output = index_processor.format_preview(chunks) + + # Check if summary index is enabled + if indexing_technique != "high_quality": + return preview_output + + if not summary_index_setting or not summary_index_setting.get("enable"): + return preview_output + + # Generate summaries for chunks + if "preview" in preview_output and isinstance(preview_output["preview"], list): + chunk_count = len(preview_output["preview"]) + logger.info( + "Generating summaries for %s chunks in preview mode (dataset: %s)", + chunk_count, + dataset.id, + ) + # Use ParagraphIndexProcessor's generate_summary method + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + # Get Flask app for application context in worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def generate_summary_for_chunk(preview_item: dict) -> None: + """Generate summary for a single chunk.""" + if "content" in preview_item: + # Set Flask application context in worker thread + if flask_app: + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=preview_item["content"], + summary_index_setting=summary_index_setting, + document_language=doc_language, + ) + if summary: + preview_item["summary"] = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=preview_item["content"], + summary_index_setting=summary_index_setting, + document_language=doc_language, + ) + if summary: + preview_item["summary"] = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_output["preview"])) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor: + futures = [ + executor.submit(generate_summary_for_chunk, preview_item) + for preview_item in preview_output["preview"] + ] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode, if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise KnowledgeIndexNodeError(error_summary) + + completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None) + logger.info( + "Completed summary generation for preview chunks: %s/%s succeeded", + completed_count, + len(preview_output["preview"]), + ) + + return preview_output + + def _get_preview_output( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset | None = None, + variable_pool: VariablePool | None = None, + ) -> Mapping[str, Any]: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + preview_output = index_processor.format_preview(chunks) + + # If dataset is provided, try to enrich preview with summaries + if dataset and variable_pool: + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document: + # Query summaries for this document + summaries = ( + db.session.query(DocumentSegmentSummary) + .filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + .all() + ) + + if summaries: + # Create a map of segment content to summary for matching + # Use content matching as chunks in preview might not be indexed yet + summary_by_content = {} + for summary in summaries: + segment = ( + db.session.query(DocumentSegment) + .filter_by(id=summary.chunk_id, dataset_id=dataset.id) + .first() + ) + if segment: + # Normalize content for matching (strip whitespace) + normalized_content = segment.content.strip() + summary_by_content[normalized_content] = summary.summary_content + + # Enrich preview with summaries by content matching + if "preview" in preview_output and isinstance(preview_output["preview"], list): + matched_count = 0 + for preview_item in preview_output["preview"]: + if "content" in preview_item: + # Normalize content for matching + normalized_chunk_content = preview_item["content"].strip() + if normalized_chunk_content in summary_by_content: + preview_item["summary"] = summary_by_content[normalized_chunk_content] + matched_count += 1 + + if matched_count > 0: + logger.info( + "Enriched preview with %s existing summaries (dataset: %s, document: %s)", + matched_count, + dataset.id, + document.id, + ) + + return preview_output @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8670a71aa..0827494a4 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": - if node_data.multiple_retrieval_config.reranking_model: - reranking_model = { - "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, - "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, - } - else: + match node_data.multiple_retrieval_config.reranking_mode: + case "reranking_model": + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None + weights = None + case "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None - weights = None - elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": - if node_data.multiple_retrieval_config.weights is None: - raise ValueError("weights is required") - reranking_model = None - vector_setting = node_data.multiple_retrieval_config.weights.vector_setting - weights = { - "vector_setting": { - "vector_weight": vector_setting.vector_weight, - "embedding_provider_name": vector_setting.embedding_provider_name, - "embedding_model_name": vector_setting.embedding_model_name, - }, - "keyword_setting": { - "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - }, - } - else: - reranking_model = None - weights = None + vector_setting = node_data.multiple_retrieval_config.weights.vector_setting + weights = { + "vector_setting": { + "vector_weight": vector_setting.vector_weight, + "embedding_provider_name": vector_setting.embedding_provider_name, + "embedding_model_name": vector_setting.embedding_model_name, + }, + "keyword_setting": { + "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight + }, + } + case _: + reranking_model = None + weights = None all_documents = dataset_retrieval.multiple_retrieve( app_id=self.app_id, tenant_id=self.tenant_id, @@ -419,6 +420,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: source["content"] = segment.get_sign_content() + # Add summary if available + if record.summary: + source["summary"] = record.summary retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( @@ -450,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) filters: list[Any] = [] metadata_condition = None - if node_data.metadata_filtering_mode == "disabled": - return None, None, usage - elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( - dataset_ids, query, node_data - ) - usage = self._merge_usage(usage, automatic_usage) - if automatic_metadata_filters: - conditions = [] - for sequence, filter in enumerate(automatic_metadata_filters): - DatasetRetrieval.process_metadata_filter_func( - sequence, - filter.get("condition", ""), - filter.get("metadata_name", ""), - filter.get("value"), - filters, - ) - conditions.append( - Condition( - name=filter.get("metadata_name"), # type: ignore - comparison_operator=filter.get("condition"), # type: ignore - value=filter.get("value"), - ) - ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator - if node_data.metadata_filtering_conditions - else "or", - conditions=conditions, + match node_data.metadata_filtering_mode: + case "disabled": + return None, None, usage + case "automatic": + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data ) - elif node_data.metadata_filtering_mode == "manual": - if node_data.metadata_filtering_conditions: - conditions = [] - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore - metadata_name = condition.name - expected_value = condition.value - if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: - expected_value = expected_value.value - elif expected_value.value_type == "string": - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() - else: - raise ValueError("Invalid expected metadata value type") - conditions.append( - Condition( - name=metadata_name, - comparison_operator=condition.comparison_operator, - value=expected_value, + usage = self._merge_usage(usage, automatic_usage) + if automatic_metadata_filters: + conditions = [] + for sequence, filter in enumerate(automatic_metadata_filters): + DatasetRetrieval.process_metadata_filter_func( + sequence, + filter.get("condition", ""), + filter.get("metadata_name", ""), + filter.get("value"), + filters, ) + conditions.append( + Condition( + name=filter.get("metadata_name"), # type: ignore + comparison_operator=filter.get("condition"), # type: ignore + value=filter.get("value"), + ) + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator + if node_data.metadata_filtering_conditions + else "or", + conditions=conditions, ) - filters = DatasetRetrieval.process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + case "manual": + if node_data.metadata_filtering_conditions: + conditions = [] + for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore + metadata_name = condition.name + expected_value = condition.value + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): + if isinstance(expected_value, str): + expected_value = self.graph_runtime_state.variable_pool.convert_template( + expected_value + ).value[0] + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() + else: + raise ValueError("Invalid expected metadata value type") + conditions.append( + Condition( + name=metadata_name, + comparison_operator=condition.comparison_operator, + value=expected_value, + ) + ) + filters = DatasetRetrieval.process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator, + conditions=conditions, ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator, - conditions=conditions, - ) - else: - raise ValueError("Invalid metadata filtering mode") + case _: + raise ValueError("Invalid metadata filtering mode") if filters: if ( node_data.metadata_filtering_conditions diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 813d898b9..235f5b9c5 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: case "name": return lambda x: x.filename or "" case "type": - return lambda x: x.type + return lambda x: str(x.type) case "extension": return lambda x: x.extension or "" case "mime_type": return lambda x: x.mime_type or "" case "transfer_method": - return lambda x: x.transfer_method + return lambda x: str(x.transfer_method) case "url": return lambda x: x.remote_url or "" case "related_id": @@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) elif key == "size" and isinstance(value, str): - extract_func = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) + extract_number = _get_file_extract_number_func(key=key) + return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) else: raise InvalidKeyError(f"Invalid key: {key}") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index dfb55dcd8..beccf7934 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -685,6 +685,8 @@ class LLMNode(Node[LLMNodeData]): if "content" not in item: raise InvalidContextStructureError(f"Invalid context structure: {item}") + if item.get("summary"): + context_str += item["summary"] + "\n" context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) @@ -746,6 +748,7 @@ class LLMNode(Node[LLMNodeData]): page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), files=context_dict.get("files"), + summary=context_dict.get("summary"), ) return source @@ -849,18 +852,16 @@ class LLMNode(Node[LLMNodeData]): # Insert histories into the prompt prompt_content = prompt_messages[0].content # For issue #11247 - Check if prompt content is a string or a list - prompt_content_type = type(prompt_content) - if prompt_content_type == str: + if isinstance(prompt_content, str): prompt_content = str(prompt_content) if "#histories#" in prompt_content: prompt_content = prompt_content.replace("#histories#", memory_text) else: prompt_content = memory_text + "\n" + prompt_content prompt_messages[0].content = prompt_content - elif prompt_content_type == list: - prompt_content = prompt_content if isinstance(prompt_content, list) else [] + elif isinstance(prompt_content, list): for content_item in prompt_content: - if content_item.type == PromptMessageContentType.TEXT: + if isinstance(content_item, TextPromptMessageContent): if "#histories#" in content_item.data: content_item.data = content_item.data.replace("#histories#", memory_text) else: @@ -870,13 +871,12 @@ class LLMNode(Node[LLMNodeData]): # Add current query to the prompt message if sys_query: - if prompt_content_type == str: + if isinstance(prompt_content, str): prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) prompt_messages[0].content = prompt_content - elif prompt_content_type == list: - prompt_content = prompt_content if isinstance(prompt_content, list) else [] + elif isinstance(prompt_content, list): for content_item in prompt_content: - if content_item.type == PromptMessageContentType.TEXT: + if isinstance(content_item, TextPromptMessageContent): content_item.data = sys_query + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") @@ -1030,14 +1030,14 @@ class LLMNode(Node[LLMNodeData]): if typed_node_data.prompt_config: enable_jinja = False - if isinstance(prompt_template, list): + if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + if prompt_template.edition_type == "jinja2": + enable_jinja = True + else: for prompt in prompt_template: if prompt.edition_type == "jinja2": enable_jinja = True break - else: - if prompt_template.edition_type == "jinja2": - enable_jinja = True if enable_jinja: for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py index e7dcf62fc..2ad39e0ab 100644 --- a/api/core/workflow/nodes/protocols.py +++ b/api/core/workflow/nodes/protocols.py @@ -1,4 +1,4 @@ -from typing import Protocol +from typing import Any, Protocol import httpx @@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol): @property def request_error(self) -> type[Exception]: ... - def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... class FileManagerProtocol(Protocol): diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index c1cfbb1ed..8fe33c240 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -54,8 +54,8 @@ class ToolNodeData(BaseNodeData, ToolEntity): for val in value: if not isinstance(val, str): raise ValueError("value must be a list of strings") - elif typ == "constant" and not isinstance(value, str | int | float | bool | dict): - raise ValueError("value must be a string, int, float, bool or dict") + elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))): + raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}") return typ tool_parameters: dict[str, ToolInput] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 68ac60e4f..60d76db9b 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]): result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + selector_key = ".".join(input.value) + result[f"#{selector_key}#"] = input.value + case "constant": + pass result = {node_id + "." + key: value for key, value in result.items()} diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 401cecc16..acf0ee683 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -6,12 +6,13 @@ import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import Any, Protocol +from typing import Any, ClassVar, Protocol from pydantic.json import pydantic_encoder from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.runtime.variable_pool import VariablePool @@ -103,14 +104,33 @@ class ResponseStreamCoordinatorProtocol(Protocol): ... +class NodeProtocol(Protocol): + """Structural interface for graph nodes.""" + + id: str + state: NodeState + execution_type: NodeExecutionType + node_type: ClassVar[NodeType] + + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... + + +class EdgeProtocol(Protocol): + id: str + state: NodeState + tail: str + head: str + source_handle: str + + class GraphProtocol(Protocol): """Structural interface required from graph instances attached to the runtime state.""" - nodes: Mapping[str, object] - edges: Mapping[str, object] - root_node: object + nodes: Mapping[str, NodeProtocol] + edges: Mapping[str, EdgeProtocol] + root_node: NodeProtocol - def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... + def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... @dataclass(slots=True) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 43f15f6fd..4b1845cda 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -144,11 +144,11 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - node_config = dict(workflow.get_node_config_by_id(node_id)) - node_config_data = node_config.get("data", {}) + node_config = workflow.get_node_config_by_id(node_id) + node_config_data = node_config["data"] # Get node type - node_type = NodeType(node_config_data.get("type")) + node_type = NodeType(node_config_data["type"]) # init graph init params and runtime state graph_init_params = GraphInitParams( diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 216712359..fd4104a0f 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -103,6 +103,8 @@ def init_app(app: DifyApp) -> Celery: "tasks.async_workflow_tasks", # trigger workers "tasks.trigger_processing_tasks", # async trigger processing "tasks.extend.update_account_money_when_workflow_node_execution_created_extend", # 二开部分 - workflow计费任务 + "tasks.generate_summary_index_task", # summary index generation + "tasks.regenerate_summary_index_task", # summary index regeneration ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py index e6c1bc6be..ab4d23a07 100644 --- a/api/extensions/ext_fastopenapi.py +++ b/api/extensions/ext_fastopenapi.py @@ -27,10 +27,13 @@ def init_app(app: DifyApp) -> None: ) # Ensure route decorators are evaluated. + import controllers.console.init_validate as init_validate_module import controllers.console.ping as ping_module - from controllers.console import setup + from controllers.console import remote_files, setup + _ = init_validate_module _ = ping_module + _ = remote_files _ = setup router.include_router(console_router, prefix="/console/api") diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index c1608f58a..18eed4e48 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage): """ content = self.load_once(filename) - with Path(target_filepath).open("wb") as f: - f.write(content) + Path(target_filepath).write_bytes(content) logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index e69306dcb..a64695072 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,36 +1,69 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -annotation_fields = { - "id": fields.String, - "question": fields.String, - "answer": fields.Raw(attribute="content"), - "hit_count": fields.Integer, - "created_at": TimestampField, - # 'account': fields.Nested(simple_account_fields, allow_null=True) -} +from pydantic import BaseModel, ConfigDict, Field, field_validator -def build_annotation_model(api_or_ns: Namespace): - """Build the annotation model for the API or Namespace.""" - return api_or_ns.model("Annotation", annotation_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), -} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -annotation_hit_history_fields = { - "id": fields.String, - "source": fields.String, - "score": fields.Float, - "question": fields.String, - "created_at": TimestampField, - "match": fields.String(attribute="annotation_question"), - "response": fields.String(attribute="annotation_content"), -} -annotation_hit_history_list_fields = { - "data": fields.List(fields.Nested(annotation_hit_history_fields)), -} +class Annotation(ResponseModel): + id: str + question: str | None = None + answer: str | None = Field(default=None, validation_alias="content") + hit_count: int | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationList(ResponseModel): + data: list[Annotation] + has_more: bool + limit: int + total: int + page: int + + +class AnnotationExportList(ResponseModel): + data: list[Annotation] + + +class AnnotationHitHistory(ResponseModel): + id: str + source: str | None = None + score: float | None = None + question: str | None = None + created_at: int | None = None + match: str | None = Field(default=None, validation_alias="annotation_question") + response: str | None = Field(default=None, validation_alias="annotation_content") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationHitHistoryList(ResponseModel): + data: list[AnnotationHitHistory] + has_more: bool + limit: int + total: int + page: int diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 1e5ec7d20..ff6578098 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -39,6 +39,14 @@ dataset_retrieval_model_fields = { "score_threshold_enabled": fields.Boolean, "score_threshold": fields.Float, } + +dataset_summary_index_fields = { + "enable": fields.Boolean, + "model_name": fields.String, + "model_provider_name": fields.String, + "summary_prompt": fields.String, +} + external_retrieval_model_fields = { "top_k": fields.Integer, "score_threshold": fields.Float, @@ -83,6 +91,7 @@ dataset_detail_fields = { "embedding_model_provider": fields.String, "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "summary_index_setting": fields.Nested(dataset_summary_index_fields), "tags": fields.List(fields.Nested(tag_fields)), "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 9be59f745..35a2a04f3 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -33,6 +33,11 @@ document_fields = { "hit_count": fields.Integer, "doc_form": fields.String, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + # Whether this document needs summary index generation + "need_summary": fields.Boolean, } document_with_segments_fields = { @@ -60,6 +65,10 @@ document_with_segments_fields = { "completed_segments": fields.Integer, "total_segments": fields.Integer, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + "need_summary": fields.Boolean, # Whether this document needs summary index generation } dataset_and_document_fields = { diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 5389b0213..effe7bfb2 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,7 @@ -from flask_restx import Namespace, fields +from __future__ import annotations + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict simple_end_user_fields = { "id": fields.String, @@ -8,5 +11,18 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleEndUser", simple_end_user_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleEndUser(ResponseModel): + id: str + type: str + is_anonymous: bool + session_id: str | None = None diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index e70f9fa72..0b5499283 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -58,4 +58,5 @@ hit_testing_record_fields = { "score": fields.Float, "tsne_position": fields.Raw, "files": fields.List(fields.Nested(files_fields)), + "summary": fields.String, # Summary content if retrieved via summary index } diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 25160927e..11d9a1a2f 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,6 +1,11 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import AvatarUrlField, TimestampField +from datetime import datetime + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict, computed_field, field_validator + +from core.file import helpers as file_helpers simple_account_fields = { "id": fields.String, @@ -9,36 +14,78 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleAccount", simple_account_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -account_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "is_password_set": fields.Boolean, - "interface_language": fields.String, - "interface_theme": fields.String, - "timezone": fields.String, - "last_login_at": TimestampField, - "last_login_ip": fields.String, - "created_at": TimestampField, -} +def _build_avatar_url(avatar: str | None) -> str | None: + if avatar is None: + return None + if avatar.startswith(("http://", "https://")): + return avatar + return file_helpers.get_signed_file_url(avatar) -account_with_role_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "last_login_at": TimestampField, - "last_active_at": TimestampField, - "created_at": TimestampField, - "role": fields.String, - "status": fields.String, -} -account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleAccount(ResponseModel): + id: str + name: str + email: str + + +class _AccountAvatar(ResponseModel): + avatar: str | None = None + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def avatar_url(self) -> str | None: + return _build_avatar_url(self.avatar) + + +class Account(_AccountAvatar): + id: str + name: str + email: str + is_password_set: bool + interface_language: str | None = None + interface_theme: str | None = None + timezone: str | None = None + last_login_at: int | None = None + last_login_ip: str | None = None + created_at: int | None = None + + @field_validator("last_login_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRole(_AccountAvatar): + id: str + name: str + email: str + last_login_at: int | None = None + last_active_at: int | None = None + created_at: int | None = None + role: str + status: str + + @field_validator("last_login_at", "last_active_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRoleList(ResponseModel): + accounts: list[AccountWithRole] diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index c81e482f7..e6c3b42f9 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -36,6 +36,7 @@ class RetrieverResource(ResponseModel): segment_position: int | None = None index_node_hash: str | None = None content: str | None = None + summary: str | None = None created_at: int | None = None @field_validator("created_at", mode="before") diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 56d6b6837..2ce9fb154 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -49,4 +49,5 @@ segment_fields = { "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "attachments": fields.List(fields.Nested(attachment_fields)), + "summary": fields.String, # Summary content for the segment } diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index e359a4408..7cb64e5ca 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,12 +1,20 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} +from pydantic import BaseModel, ConfigDict -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class DataSetTag(ResponseModel): + id: str + name: str + type: str + binding_count: str | None = None diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index ae7035632..d0e762f62 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,7 +1,7 @@ from flask_restx import Namespace, fields -from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields -from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields from fields.workflow_run_fields import ( build_workflow_run_for_archived_log_model, build_workflow_run_for_log_model, @@ -25,17 +25,9 @@ workflow_app_log_partial_fields = { def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_app_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowAppLogPartial", copied_fields) @@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = { def build_workflow_archived_log_partial_model(api_or_ns: Namespace): """Build the workflow archived log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_archived_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index bb7fa25c0..ef26699fb 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -136,7 +136,7 @@ class PKCS1OAepCipher: # Step 3a (OS2IP) em_int = bytes_to_long(em) # Step 3b (RSAEP) - m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute] + m_int: int = gmpy2.powmod(em_int, self._key.e, self._key.n) # type: ignore[attr-defined] # Step 3c (I2OSP) c = long_to_bytes(m_int, k) return c @@ -169,7 +169,7 @@ class PKCS1OAepCipher: ct_int = bytes_to_long(ciphertext) # Step 2b (RSADP) # m_int = self._key._decrypt(ct_int) - m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute] + m_int: int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # type: ignore[attr-defined] # Complete step 2c (I2OSP) em = long_to_bytes(m_int, k) # Step 3a diff --git a/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py new file mode 100644 index 000000000..c6c72859d --- /dev/null +++ b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py @@ -0,0 +1,107 @@ +"""add summary index feature + +Revision ID: 788d3099ae3a +Revises: 9d77545f524e +Create Date: 2026-01-27 18:15:45.277928 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# revision identifiers, used by Alembic. +revision = '788d3099ae3a' +down_revision = '9d77545f524e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey') + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + else: + # MySQL: Use compatible syntax + op.create_table( + 'document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'), + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('need_summary') + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('summary_index_setting') + + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.drop_index('document_segment_summaries_status_idx') + batch_op.drop_index('document_segment_summaries_document_id_idx') + batch_op.drop_index('document_segment_summaries_dataset_id_idx') + batch_op.drop_index('document_segment_summaries_chunk_id_idx') + + op.drop_table('document_segment_summaries') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 62f11b8c7..e7da2961b 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -72,6 +72,7 @@ class Dataset(Base): keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10")) collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(AdjustedJSON, nullable=True) + summary_index_setting = mapped_column(AdjustedJSON, nullable=True) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) icon_info = mapped_column(AdjustedJSON, nullable=True) runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'")) @@ -419,6 +420,7 @@ class Document(Base): doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) + need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base): segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DocumentSegmentSummary(Base): + __tablename__ = "document_segment_summaries" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), + sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"), + sa.Index("document_segment_summaries_document_id_idx", "document_id"), + sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"), + sa.Index("document_segment_summaries_status_idx", "status"), + ) + + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # corresponds to DocumentSegment.id or parent chunk id + chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + summary_content: Mapped[str] = mapped_column(LongText, nullable=True) + summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) + summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'")) + error: Mapped[str] = mapped_column(LongText, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + def __repr__(self): + return f"" diff --git a/api/models/model.py b/api/models/model.py index 0d796bfbf..cbc19bf6f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -695,16 +695,22 @@ class AccountTrialAppRecord(Base): return user -class ExporleBanner(Base): +class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) - content = mapped_column(sa.JSON, nullable=False) - link = mapped_column(String(255), nullable=False) - sort = mapped_column(sa.Integer, nullable=False) - status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying")) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) + content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) + link: Mapped[str] = mapped_column(String(255), nullable=False) + sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) + status: Mapped[str] = mapped_column( + sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" + ) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + language: Mapped[str] = mapped_column( + String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US" + ) class OAuthProviderApp(TypeBase): diff --git a/api/models/workflow.py b/api/models/workflow.py index 330fc9c0f..a14e9d0b4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -30,6 +30,7 @@ from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) +from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause from core.workflow.enums import NodeType from extensions.ext_storage import Storage @@ -230,7 +231,7 @@ class Workflow(Base): # bug # - `_get_graph_and_variable_pool_for_single_node_run`. return json.loads(self.graph) if self.graph else {} - def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: + def get_node_config_by_id(self, node_id: str) -> NodeConfigDict: """Extract a node configuration from the workflow graph by node ID. A node configuration is a dictionary containing the node's properties, including the node's id, title, and its data as a dict. @@ -248,8 +249,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - assert isinstance(node_config, dict) - return node_config + return NodeConfigDictAdapter.validate_python(node_config) @staticmethod def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: diff --git a/api/pyproject.toml b/api/pyproject.toml index c88616e0f..a13b42903 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.11.4" +version = "1.12.1" requires-python = ">=3.11,<3.13" dependencies = [ @@ -86,7 +86,7 @@ dependencies = [ "sseclient-py~=1.8.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", - "flask-restx~=1.3.0", + "flask-restx~=1.3.2", "packaging~=23.2", "croniter>=6.0.0", "weaviate-client==4.17.0", @@ -125,7 +125,7 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=38.2.0", "lxml-stubs~=0.5.1", - "ty~=0.0.1a19", + "ty>=0.0.14", "basedpyright~=1.31.0", "ruff~=0.14.0", "pytest~=8.3.2", @@ -154,7 +154,7 @@ dev = [ "types-openpyxl~=3.1.5", "types-pexpect~=4.9.0", "types-protobuf~=5.29.1", - "types-psutil~=7.0.0", + "types-psutil~=7.2.2", "types-psycopg2~=2.9.21", "types-pygments~=2.19.0", "types-pymysql~=1.1.0", @@ -184,6 +184,7 @@ dev = [ # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", + "pytest-xdist>=3.8.0", ] ############################################################ diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 56e9cc6a0..8ebc87a67 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -158,7 +158,7 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) ) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) - return annotations.items, annotations.total + return annotations.items, annotations.total or 0 @classmethod def export_annotation_list_by_app_id(cls, app_id: str): @@ -524,7 +524,7 @@ class AppAnnotationService: annotation_hit_histories = db.paginate( select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False ) - return annotation_hit_histories.items, annotation_hit_histories.total + return annotation_hit_histories.items, annotation_hit_histories.total or 0 @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index be9a0e927..1ea6c4e1c 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config +from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.file import helpers as file_helpers from core.helper.name_generator import generate_incremental_name @@ -89,6 +90,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t from tasks.document_indexing_update_task import document_indexing_update_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.regenerate_summary_index_task import regenerate_summary_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -211,6 +213,7 @@ class DatasetService: embedding_model_provider: str | None = None, embedding_model_name: str | None = None, retrieval_model: RetrievalModel | None = None, + summary_index_setting: dict | None = None, ): # check if dataset name already exists if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): @@ -253,6 +256,8 @@ class DatasetService: dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting db.session.add(dataset) db.session.flush() @@ -476,6 +481,11 @@ class DatasetService: if external_retrieval_model: dataset.retrieval_model = external_retrieval_model + # Update summary index setting if provided + summary_index_setting = data.get("summary_index_setting", None) + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting + # Update basic dataset properties dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", dataset.description) @@ -564,6 +574,9 @@ class DatasetService: # update Retrieval model if data.get("retrieval_model"): filtered_data["retrieval_model"] = data["retrieval_model"] + # update summary index setting + if data.get("summary_index_setting"): + filtered_data["summary_index_setting"] = data.get("summary_index_setting") # update icon info if data.get("icon_info"): filtered_data["icon_info"] = data.get("icon_info") @@ -572,12 +585,27 @@ class DatasetService: db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) db.session.commit() + # Reload dataset to get updated values + db.session.refresh(dataset) + # update pipeline knowledge base node data DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id) # Trigger vector index task if indexing technique changed if action: deal_dataset_vector_index_task.delay(dataset.id, action) + # If embedding_model changed, also regenerate summary vectors + if action == "update": + regenerate_summary_index_task.delay( + dataset.id, + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + # Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries. + # The new setting will only apply to: + # 1. New documents added after the setting change + # 2. Manual summary generation requests return dataset @@ -616,6 +644,7 @@ class DatasetService: knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue] knowledge_index_node_data["keyword_number"] = dataset.keyword_number + knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting node["data"] = knowledge_index_node_data updated = True except Exception: @@ -854,6 +883,54 @@ class DatasetService: ) filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod + def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool: + """ + Check if summary_index_setting model (model_name or model_provider_name) has changed. + + Args: + dataset: Current dataset object + data: Update data dictionary + + Returns: + bool: True if summary model changed, False otherwise + """ + # Check if summary_index_setting is being updated + if "summary_index_setting" not in data or data.get("summary_index_setting") is None: + return False + + new_summary_setting = data.get("summary_index_setting") + old_summary_setting = dataset.summary_index_setting + + # If new setting is disabled, no need to regenerate + if not new_summary_setting or not new_summary_setting.get("enable"): + return False + + # If old setting doesn't exist, no need to regenerate (no existing summaries to regenerate) + # Note: This task only regenerates existing summaries, not generates new ones + if not old_summary_setting: + return False + + # Compare model_name and model_provider_name + old_model_name = old_summary_setting.get("model_name") + old_model_provider = old_summary_setting.get("model_provider_name") + new_model_name = new_summary_setting.get("model_name") + new_model_provider = new_summary_setting.get("model_provider_name") + + # Check if model changed + if old_model_name != new_model_name or old_model_provider != new_model_provider: + logger.info( + "Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s", + dataset.id, + old_model_provider, + old_model_name, + new_model_provider, + new_model_name, + ) + return True + + return False + @staticmethod def update_rag_pipeline_dataset_settings( session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False @@ -889,6 +966,9 @@ class DatasetService: else: raise ValueError("Invalid index method") dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) else: if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: @@ -994,6 +1074,9 @@ class DatasetService: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) session.commit() if action: @@ -1306,6 +1389,46 @@ class DocumentService: ).all() return documents + @staticmethod + def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int: + """ + Update need_summary field for multiple documents. + + This method handles the case where documents were created when summary_index_setting was disabled, + and need to be updated when summary_index_setting is later enabled. + + Args: + dataset_id: Dataset ID + document_ids: List of document IDs to update + need_summary: Value to set for need_summary field (default: True) + + Returns: + Number of documents updated + """ + if not document_ids: + return 0 + + document_id_list: list[str] = [str(document_id) for document_id in document_ids] + + with session_factory.create_session() as session: + updated_count = ( + session.query(Document) + .filter( + Document.id.in_(document_id_list), + Document.dataset_id == dataset_id, + Document.doc_form != "qa_model", # Skip qa_model documents + ) + .update({Document.need_summary: need_summary}, synchronize_session=False) + ) + session.commit() + logger.info( + "Updated need_summary to %s for %d documents in dataset %s", + need_summary, + updated_count, + dataset_id, + ) + return updated_count + @staticmethod def get_document_download_url(document: Document) -> str: """ @@ -1314,6 +1437,50 @@ class DocumentService: upload_file = DocumentService._get_upload_file_for_upload_file_document(document) return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) + @staticmethod + def enrich_documents_with_summary_index_status( + documents: Sequence[Document], + dataset: Dataset, + tenant_id: str, + ) -> None: + """ + Enrich documents with summary_index_status based on dataset summary index settings. + + This method calculates and sets the summary_index_status for each document that needs summary. + Documents that don't need summary or when summary index is disabled will have status set to None. + + Args: + documents: List of Document instances to enrich + dataset: Dataset instance containing summary_index_setting + tenant_id: Tenant ID for summary status lookup + """ + # Check if dataset has summary index enabled + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + + # Filter documents that need summary calculation + documents_need_summary = [doc for doc in documents if doc.need_summary is True] + document_ids_need_summary = [str(doc.id) for doc in documents_need_summary] + + # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled) + summary_status_map: dict[str, str | None] = {} + if has_summary_index and document_ids_need_summary: + from services.summary_index_service import SummaryIndexService + + summary_status_map = SummaryIndexService.get_documents_summary_index_status( + document_ids=document_ids_need_summary, + dataset_id=dataset.id, + tenant_id=tenant_id, + ) + + # Add summary_index_status to each document + for document in documents: + if has_summary_index and document.need_summary is True: + # Get status from map, default to None (not queued yet) + document.summary_index_status = summary_status_map.get(str(document.id)) # type: ignore[attr-defined] + else: + # Return null if summary index is not enabled or document doesn't need summary + document.summary_index_status = None # type: ignore[attr-defined] + @staticmethod def prepare_document_batch_download_zip( *, @@ -1964,6 +2131,8 @@ class DocumentService: DuplicateDocumentIndexingTaskProxy( dataset.tenant_id, dataset.id, duplicate_document_ids ).delay() + # Note: Summary index generation is triggered in document_indexing_task after indexing completes + # to ensure segments are available. See tasks/document_indexing_task.py except LockNotOwnedError: pass @@ -2268,6 +2437,11 @@ class DocumentService: name: str, batch: str, ): + # Set need_summary based on dataset's summary_index_setting + need_summary = False + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + need_summary = True + document = Document( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -2281,6 +2455,7 @@ class DocumentService: created_by=account.id, doc_form=document_form, doc_language=document_language, + need_summary=need_summary, ) doc_metadata = {} if dataset.built_in_field_enabled: @@ -2505,6 +2680,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + summary_index_setting=knowledge_config.summary_index_setting, is_multimodal=knowledge_config.is_multimodal, ) @@ -2686,6 +2862,14 @@ class DocumentService: if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") + # valid summary index setting + summary_index_setting = args["process_rule"].get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable"): + if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]: + raise ValueError("Summary index model name is required") + if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]: + raise ValueError("Summary index model provider name is required") + @staticmethod def batch_update_document_status( dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user @@ -2794,14 +2978,15 @@ class DocumentService: """ now = naive_utc_now() - if action == "enable": - return DocumentService._prepare_enable_update(document, now) - elif action == "disable": - return DocumentService._prepare_disable_update(document, user, now) - elif action == "archive": - return DocumentService._prepare_archive_update(document, user, now) - elif action == "un_archive": - return DocumentService._prepare_unarchive_update(document, now) + match action: + case "enable": + return DocumentService._prepare_enable_update(document, now) + case "disable": + return DocumentService._prepare_disable_update(document, user, now) + case "archive": + return DocumentService._prepare_archive_update(document, user, now) + case "un_archive": + return DocumentService._prepare_unarchive_update(document, now) return None @@ -3154,6 +3339,35 @@ class SegmentService: if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) + # update summary index if summary is provided and has changed + if args.summary is not None: + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + if dataset.indexing_technique == "high_quality": + # Query existing summary from database + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + # Check if summary has changed + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, update it + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -3228,6 +3442,73 @@ class SegmentService: elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) + # Handle summary index when content changed + if dataset.indexing_technique == "high_quality": + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + if args.summary is None: + # User didn't provide summary, auto-regenerate if segment previously had summary + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + # Segment previously had summary, regenerate it with new content + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info("Auto-regenerated summary for segment %s after content change", segment.id) + except Exception: + logger.exception("Failed to auto-regenerate summary for segment %s", segment.id) + # Don't fail the entire update if summary regeneration fails + else: + # User provided summary, check if it has changed + # Manual summary updates are allowed even if summary_index_setting doesn't exist + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, use user-provided summary + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + logger.info("Updated summary for segment %s with user-provided content", segment.id) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails + else: + # Summary hasn't changed, regenerate based on new content + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info( + "Regenerated summary for segment %s after content change (summary unchanged)", + segment.id, + ) + except Exception: + logger.exception("Failed to regenerate summary for segment %s", segment.id) + # Don't fail the entire update if summary regeneration fails # update multimodel vector index VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: @@ -3342,56 +3623,57 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - if action == "enable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == False, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + match action: + case "enable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - elif action == "disable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == True, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.disabled_by = current_user.id - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + case "disable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk( @@ -3616,6 +3898,39 @@ class SegmentService: ) return result if isinstance(result, DocumentSegment) else None + @classmethod + def get_segments_by_document_and_dataset( + cls, + document_id: str, + dataset_id: str, + status: str | None = None, + enabled: bool | None = None, + ) -> Sequence[DocumentSegment]: + """ + Get segments for a document in a dataset with optional filtering. + + Args: + document_id: Document ID + dataset_id: Dataset ID + status: Optional status filter (e.g., "completed") + enabled: Optional enabled filter (True/False) + + Returns: + Sequence of DocumentSegment instances + """ + query = select(DocumentSegment).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) + + if status is not None: + query = query.where(DocumentSegment.status == status) + + if enabled is not None: + query = query.where(DocumentSegment.enabled == enabled) + + return db.session.scalars(query).all() + class DatasetCollectionBindingService: @classmethod diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 7959734e8..8dc5b9350 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel): data_source: DataSource | None = None process_rule: ProcessRule | None = None retrieval_model: RetrievalModel | None = None + summary_index_setting: dict | None = None doc_form: str = "text_model" doc_language: str = "English" embedding_model: str | None = None @@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel): regenerate_child_chunks: bool = False enabled: bool | None = None attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index cbb0efcc2..041ae4edb 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel): embedding_model: str = "" keyword_number: int | None = 10 retrieval_model: RetrievalSetting + # add summary index setting + summary_index_setting: dict | None = None @field_validator("embedding_model_provider", mode="before") @classmethod diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c1c6e204f..be1ce834f 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -343,6 +343,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() @@ -477,6 +480,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 8ea365e90..d0dfbc107 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -174,6 +174,10 @@ class RagPipelineTransformService: else: dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Copy summary_index_setting from dataset to knowledge_index node configuration + if dataset.summary_index_setting: + knowledge_configuration.summary_index_setting = dataset.summary_index_setting + knowledge_configuration_dict.update(knowledge_configuration.model_dump()) node["data"] = knowledge_configuration_dict return node diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py new file mode 100644 index 000000000..7c03ceed5 --- /dev/null +++ b/api/services/summary_index_service.py @@ -0,0 +1,1441 @@ +"""Summary index service for generating and managing document segment summaries.""" + +import logging +import time +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.orm import Session + +from core.db.session_factory import session_factory +from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument + +logger = logging.getLogger(__name__) + + +class SummaryIndexService: + """Service for generating and managing summary indexes.""" + + @staticmethod + def generate_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for a single segment. + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + + Raises: + ValueError: If summary_index_setting is invalid or generation fails + """ + # Reuse the existing generate_summary method from ParagraphIndexProcessor + # Use lazy import to avoid circular import + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + # Get document language to ensure summary is generated in the correct language + # This is especially important for image-only chunks where text is empty or minimal + document_language = None + if segment.document and segment.document.doc_language: + document_language = segment.document.doc_language + + summary_content, usage = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=segment.content, + summary_index_setting=summary_index_setting, + segment_id=segment.id, + document_language=document_language, + ) + + if not summary_content: + raise ValueError("Generated summary is empty") + + return summary_content, usage + + @staticmethod + def create_summary_record( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + status: str = "generating", + ) -> DocumentSegmentSummary: + """ + Create or update a DocumentSegmentSummary record. + If a summary record already exists for this segment, it will be updated instead of creating a new one. + + Args: + segment: DocumentSegment to create summary for + dataset: Dataset containing the segment + summary_content: Generated summary content + status: Summary status (default: "generating") + + Returns: + Created or updated DocumentSegmentSummary instance + """ + with session_factory.create_session() as session: + # Check if summary record already exists + existing_summary = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if existing_summary: + # Update existing record + existing_summary.summary_content = summary_content + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + # Re-enable if it was disabled + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + session.flush() + return existing_summary + else: + # Create new record (enabled by default) + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + status=status, + enabled=True, # Explicitly set enabled to True + ) + session.add(summary_record) + session.flush() + return summary_record + + @staticmethod + def vectorize_summary( + summary_record: DocumentSegmentSummary, + segment: DocumentSegment, + dataset: Dataset, + session: Session | None = None, + ) -> None: + """ + Vectorize summary and store in vector database. + + Args: + summary_record: DocumentSegmentSummary record + segment: Original DocumentSegment + dataset: Dataset containing the segment + session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. + If not provided, creates a new session and commits automatically. + """ + if dataset.indexing_technique != "high_quality": + logger.warning( + "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", + dataset.id, + ) + return + + # Get summary_record_id for later session queries + summary_record_id = summary_record.id + # Save the original session parameter for use in error handling + original_session = session + logger.debug( + "Starting vectorization for segment %s, summary_record_id=%s, using_provided_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + + # Reuse existing index_node_id if available (like segment does), otherwise generate new one + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + # Reuse existing index_node_id (like segment behavior) + summary_index_node_id = old_summary_node_id + logger.debug("Reusing existing index_node_id %s for segment %s", summary_index_node_id, segment.id) + else: + # Generate new index node ID only for new summaries + summary_index_node_id = str(uuid.uuid4()) + logger.debug("Generated new index_node_id %s for segment %s", summary_index_node_id, segment.id) + + # Always regenerate hash (in case summary content changed) + summary_content = summary_record.summary_content + if not summary_content or not summary_content.strip(): + raise ValueError(f"Summary content is empty for segment {segment.id}, cannot vectorize") + summary_hash = helper.generate_text_hash(summary_content) + + # Delete old vector only if we're reusing the same index_node_id (to overwrite) + # If index_node_id changed, the old vector should have been deleted elsewhere + if old_summary_node_id and old_summary_node_id == summary_index_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s. Continuing with new vectorization.", + segment.id, + str(e), + ) + + # Calculate embedding tokens for summary (for logging and statistics) + embedding_tokens = 0 + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) + embedding_tokens = tokens_list[0] if tokens_list else 0 + except Exception as e: + logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) + + # Create document with summary content and metadata + summary_document = Document( + page_content=summary_content, + metadata={ + "doc_id": summary_index_node_id, + "doc_hash": summary_hash, + "dataset_id": dataset.id, + "document_id": segment.document_id, + "original_chunk_id": segment.id, # Key: link to original chunk + "doc_type": DocType.TEXT, + "is_summary": True, # Identifier for summary documents + }, + ) + + # Vectorize and store with retry mechanism for connection errors + max_retries = 3 + retry_delay = 2.0 + + for attempt in range(max_retries): + try: + logger.debug( + "Attempting to vectorize summary for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + vector = Vector(dataset) + # Use duplicate_check=False to ensure re-vectorization even if old vector still exists + # The old vector should have been deleted above, but if deletion failed, + # we still want to re-vectorize (upsert will overwrite) + vector.add_texts([summary_document], duplicate_check=False) + logger.debug( + "Successfully added summary vector to database for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + + # Log embedding token usage + if embedding_tokens > 0: + logger.info( + "Summary embedding for segment %s used %s tokens", + segment.id, + embedding_tokens, + ) + + # Success - update summary record with index node info + # Use provided session if available, otherwise create a new one + use_provided_session = session is not None + if not use_provided_session: + logger.debug("Creating new session for vectorization of segment %s", segment.id) + session_context = session_factory.create_session() + session = session_context.__enter__() + else: + logger.debug("Using provided session for vectorization of segment %s", segment.id) + session_context = None # Don't use context manager for provided session + + # At this point, session is guaranteed to be not None + # Type narrowing: session is definitely not None after the if/else above + if session is None: + raise RuntimeError("Session should not be None at this point") + + try: + # Declare summary_record_in_session variable + summary_record_in_session: DocumentSegmentSummary | None + + # If using provided session, merge the summary_record into it + if use_provided_session: + # Merge the summary_record into the provided session + logger.debug( + "Merging summary_record (id=%s) into provided session for segment %s", + summary_record_id, + segment.id, + ) + summary_record_in_session = session.merge(summary_record) + logger.debug( + "Successfully merged summary_record for segment %s, merged_id=%s", + segment.id, + summary_record_in_session.id, + ) + else: + # Query the summary record in the new session + logger.debug( + "Querying summary_record by id=%s for segment %s in new session", + summary_record_id, + segment.id, + ) + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + + if not summary_record_in_session: + # Record not found - try to find by chunk_id and dataset_id instead + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if not summary_record_in_session: + # Still not found - create a new one using the parameter data + logger.warning( + "Summary record not found in database for segment %s (id=%s), creating new one. " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + summary_record_in_session = DocumentSegmentSummary( + id=summary_record_id, # Use the same ID if available + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + summary_index_node_id=summary_index_node_id, + summary_index_node_hash=summary_hash, + tokens=embedding_tokens, + status="completed", + enabled=True, + ) + session.add(summary_record_in_session) + logger.info( + "Created new summary record (id=%s) for segment %s after vectorization", + summary_record_id, + segment.id, + ) + else: + # Found by chunk_id - update it + logger.info( + "Found summary record for segment %s by chunk_id " + "(id mismatch: expected %s, found %s). " + "This may indicate the record was created in a different session.", + segment.id, + summary_record_id, + summary_record_in_session.id, + ) + else: + logger.debug( + "Found summary_record (id=%s) for segment %s in new session", + summary_record_id, + segment.id, + ) + + # At this point, summary_record_in_session is guaranteed to be not None + if summary_record_in_session is None: + raise RuntimeError("summary_record_in_session should not be None at this point") + + # Update all fields including summary_content + # Always use the summary_content from the parameter (which is the latest from outer session) + # rather than relying on what's in the database, in case outer session hasn't committed yet + summary_record_in_session.summary_index_node_id = summary_index_node_id + summary_record_in_session.summary_index_node_hash = summary_hash + summary_record_in_session.tokens = embedding_tokens # Save embedding tokens + summary_record_in_session.status = "completed" + # Ensure summary_content is preserved (use the latest from summary_record parameter) + # This is critical: use the parameter value, not the database value + summary_record_in_session.summary_content = summary_content + # Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + session.add(summary_record_in_session) + + # Only commit if we created the session ourselves + if not use_provided_session: + logger.debug("Committing session for segment %s (self-created session)", segment.id) + session.commit() + logger.debug("Successfully committed session for segment %s", segment.id) + else: + # When using provided session, flush to ensure changes are written to database + # This prevents refresh() from overwriting our changes + logger.debug( + "Flushing session for segment %s (using provided session, caller will commit)", + segment.id, + ) + session.flush() + logger.debug("Successfully flushed session for segment %s", segment.id) + # If using provided session, let the caller handle commit + + logger.info( + "Successfully vectorized summary for segment %s, index_node_id=%s, index_node_hash=%s, " + "tokens=%s, summary_record_id=%s, use_provided_session=%s", + segment.id, + summary_index_node_id, + summary_hash, + embedding_tokens, + summary_record_in_session.id, + use_provided_session, + ) + # Update the original object for consistency + summary_record.summary_index_node_id = summary_index_node_id + summary_record.summary_index_node_hash = summary_hash + summary_record.tokens = embedding_tokens + summary_record.status = "completed" + summary_record.summary_content = summary_content + if summary_record_in_session.updated_at: + summary_record.updated_at = summary_record_in_session.updated_at + finally: + # Only close session if we created it ourselves + if not use_provided_session and session_context: + session_context.__exit__(None, None, None) + # Success, exit function + return + + except (ConnectionError, Exception) as e: + error_str = str(e).lower() + # Check if it's a connection-related error that might be transient + is_connection_error = any( + keyword in error_str + for keyword in [ + "connection", + "disconnected", + "timeout", + "network", + "could not connect", + "server disconnected", + "weaviate", + ] + ) + + if is_connection_error and attempt < max_retries - 1: + # Retry for connection errors + wait_time = retry_delay * (2**attempt) # Exponential backoff + logger.warning( + "Vectorization attempt %s/%s failed for segment %s (connection error): %s. " + "Retrying in %.1f seconds...", + attempt + 1, + max_retries, + segment.id, + str(e), + wait_time, + ) + time.sleep(wait_time) + continue + else: + # Final attempt failed or non-connection error - log and update status + logger.error( + "Failed to vectorize summary for segment %s after %s attempts: %s. " + "summary_record_id=%s, index_node_id=%s, use_provided_session=%s", + segment.id, + attempt + 1, + str(e), + summary_record_id, + summary_index_node_id, + session is not None, + exc_info=True, + ) + # Update error status in session + # Use the original_session saved at function start (the function parameter) + logger.debug( + "Updating error status for segment %s, summary_record_id=%s, has_original_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + # Always create a new session for error handling to avoid issues with closed sessions + # Even if original_session was provided, we create a new one for safety + with session_factory.create_session() as error_session: + # Try to find the record by id first + # Note: Using assignment only (no type annotation) to avoid redeclaration error + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + if not summary_record_in_session: + # Try to find by chunk_id and dataset_id + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(e)}" + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + error_session.add(summary_record_in_session) + error_session.commit() + logger.info( + "Updated error status in new session for segment %s, record_id=%s", + segment.id, + summary_record_in_session.id, + ) + # Update the original object for consistency + summary_record.status = "error" + summary_record.error = summary_record_in_session.error + summary_record.updated_at = summary_record_in_session.updated_at + else: + logger.warning( + "Could not update error status: summary record not found for segment %s (id=%s). " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + raise + + @staticmethod + def batch_create_summary_records( + segments: list[DocumentSegment], + dataset: Dataset, + status: str = "not_started", + ) -> None: + """ + Batch create summary records for segments with specified status. + If a record already exists, update its status. + + Args: + segments: List of DocumentSegment instances + dataset: Dataset containing the segments + status: Initial status for the records (default: "not_started") + """ + segment_ids = [segment.id for segment in segments] + if not segment_ids: + return + + with session_factory.create_session() as session: + # Query existing summary records + existing_summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .all() + ) + existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} + + # Create or update records + for segment in segments: + existing_summary = existing_summary_map.get(segment.id) + if existing_summary: + # Update existing record + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + else: + # Create new record + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=None, # Will be filled later + status=status, + enabled=True, + ) + session.add(summary_record) + + # Commit the batch created records + session.commit() + + @staticmethod + def update_summary_record_error( + segment: DocumentSegment, + dataset: Dataset, + error: str, + ) -> None: + """ + Update summary record with error status. + + Args: + segment: DocumentSegment + dataset: Dataset containing the segment + error: Error message + """ + with session_factory.create_session() as session: + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + summary_record.status = "error" + summary_record.error = error + session.add(summary_record) + session.commit() + else: + logger.warning("Summary record not found for segment %s when updating error", segment.id) + + @staticmethod + def generate_and_vectorize_summary( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> DocumentSegmentSummary: + """ + Generate summary for a segment and vectorize it. + Assumes summary record already exists (created by batch_create_summary_records). + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Created DocumentSegmentSummary instance + + Raises: + ValueError: If summary generation fails + """ + with session_factory.create_session() as session: + try: + # Get or refresh summary record in this session + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if not summary_record_in_session: + # If not found, create one + logger.warning("Summary record not found for segment %s, creating one", segment.id) + summary_record_in_session = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content="", + status="generating", + enabled=True, + ) + session.add(summary_record_in_session) + session.flush() + + # Update status to "generating" + summary_record_in_session.status = "generating" + summary_record_in_session.error = None # type: ignore[assignment] + session.add(summary_record_in_session) + # Don't flush here - wait until after vectorization succeeds + + # Generate summary (returns summary_content and llm_usage) + summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment( + segment, dataset, summary_index_setting + ) + + # Update summary content + summary_record_in_session.summary_content = summary_content + session.add(summary_record_in_session) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Log LLM usage for summary generation + if llm_usage and llm_usage.total_tokens > 0: + logger.info( + "Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)", + segment.id, + llm_usage.total_tokens, + llm_usage.prompt_tokens, + llm_usage.completion_tokens, + ) + + # Vectorize summary (will delete old vector if exists before creating new one) + # Pass the session-managed record to vectorize_summary + # vectorize_summary will update status to "completed" and tokens in its own session + # vectorize_summary will also ensure summary_content is preserved + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record_in_session, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record_in_session) + # Commit the session + # (summary_record_in_session should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully generated and vectorized summary for segment %s", segment.id) + return summary_record_in_session + except Exception as vectorize_error: + # If vectorization fails, update status to error in current session + logger.exception("Failed to vectorize summary for segment %s", segment.id) + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}" + session.add(summary_record_in_session) + session.commit() + raise + + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = str(e) + session.add(summary_record_in_session) + session.commit() + raise + + @staticmethod + def generate_summaries_for_document( + dataset: Dataset, + document: DatasetDocument, + summary_index_setting: dict, + segment_ids: list[str] | None = None, + only_parent_chunks: bool = False, + ) -> list[DocumentSegmentSummary]: + """ + Generate summaries for all segments in a document including vectorization. + + Args: + dataset: Dataset containing the document + document: DatasetDocument to generate summaries for + summary_index_setting: Summary index configuration + segment_ids: Optional list of specific segment IDs to process + only_parent_chunks: If True, only process parent chunks (for parent-child mode) + + Returns: + List of created DocumentSegmentSummary instances + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", + dataset.id, + dataset.indexing_technique, + ) + return [] + + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info("Summary index is disabled for dataset %s", dataset.id) + return [] + + # Skip qa_model documents + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + return [] + + logger.info( + "Starting summary generation for document %s in dataset %s, segment_ids: %s, only_parent_chunks: %s", + document.id, + dataset.id, + len(segment_ids) if segment_ids else "all", + only_parent_chunks, + ) + + with session_factory.create_session() as session: + # Query segments (only enabled segments) + query = session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, # Only generate summaries for enabled segments + ) + + if segment_ids: + query = query.filter(DocumentSegment.id.in_(segment_ids)) + + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return [] + + # Batch create summary records with "not_started" status before processing + # This ensures all records exist upfront, allowing status tracking + SummaryIndexService.batch_create_summary_records( + segments=segments, + dataset=dataset, + status="not_started", + ) + + summary_records = [] + + for segment in segments: + # For parent-child mode, only process parent chunks + # In parent-child mode, all DocumentSegments are parent chunks, + # so we process all of them. Child chunks are stored in ChildChunk table + # and are not DocumentSegments, so they won't be in the segments list. + # This check is mainly for clarity and future-proofing. + if only_parent_chunks: + # In parent-child mode, all segments in the query are parent chunks + # Child chunks are not DocumentSegments, so they won't appear here + # We can process all segments + pass + + try: + summary_record = SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + summary_records.append(summary_record) + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + SummaryIndexService.update_summary_record_error( + segment=segment, + dataset=dataset, + error=str(e), + ) + # Continue with other segments + continue + + logger.info( + "Completed summary generation for document %s: %s summaries generated and vectorized", + document.id, + len(summary_records), + ) + return summary_records + + @staticmethod + def disable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + disabled_by: str | None = None, + ) -> None: + """ + Disable summary records and remove vectors from vector database for segments. + Unlike delete, this preserves the summary records but marks them as disabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to disable summaries for. If None, disable all. + disabled_by: User ID who disabled the summaries + """ + from libs.datetime_utils import naive_utc_now + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=True, # Only disable enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Disabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Remove from vector database (but keep records) + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + try: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + except Exception as e: + logger.warning("Failed to remove summary vectors: %s", str(e)) + + # Disable summary records (don't delete) + now = naive_utc_now() + for summary in summaries: + summary.enabled = False + summary.disabled_at = now + summary.disabled_by = disabled_by + session.add(summary) + + session.commit() + logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def enable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Enable summary records and re-add vectors to vector database for segments. + + Note: This method enables summaries based on chunk status, not summary_index_setting.enable. + The summary_index_setting.enable flag only controls automatic generation, + not whether existing summaries can be used. + Summary.enabled should always be kept in sync with chunk.enabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to enable summaries for. If None, enable all. + """ + # Only enable summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=False, # Only enable disabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Enabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Re-vectorize and re-add to vector database + enabled_count = 0 + for summary in summaries: + # Get the original segment + segment = ( + session.query(DocumentSegment) + .filter_by( + id=summary.chunk_id, + dataset_id=dataset.id, + ) + .first() + ) + + # Summary.enabled stays in sync with chunk.enabled, + # only enable summary if the associated chunk is enabled. + if not segment or not segment.enabled or segment.status != "completed": + continue + + if not summary.summary_content: + continue + + try: + # Re-vectorize summary (this will update status and tokens in its own session) + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary) + + # Enable summary record + summary.enabled = True + summary.disabled_at = None + summary.disabled_by = None + session.add(summary) + enabled_count += 1 + except Exception: + logger.exception("Failed to re-vectorize summary %s", summary.id) + # Keep it disabled if vectorization fails + continue + + session.commit() + logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id) + + @staticmethod + def delete_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Delete summary records and vectors for segments (used only for actual deletion scenarios). + For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to delete summaries for. If None, delete all. + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + # Delete from vector database + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + + # Delete summary records + for summary in summaries: + session.delete(summary) + + session.commit() + logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def update_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + ) -> DocumentSegmentSummary | None: + """ + Update summary for a segment and re-vectorize it. + + Args: + segment: DocumentSegment to update summary for + dataset: Dataset containing the segment + summary_content: New summary content + + Returns: + Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality + """ + # Only update summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return None + + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + + # Skip qa_model documents + if segment.document and segment.document.doc_form == "qa_model": + return None + + with session_factory.create_session() as session: + try: + # Check if summary_content is empty (whitespace-only strings are considered empty) + if not summary_content or not summary_content.strip(): + # If summary is empty, only delete existing summary vector and record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record: + # Delete old vector if exists + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Delete summary record since summary is empty + session.delete(summary_record) + session.commit() + logger.info("Deleted summary for segment %s (empty content provided)", segment.id) + return None + else: + # No existing summary record, nothing to do + logger.info("No summary record found for segment %s, nothing to delete", segment.id) + return None + + # Find existing summary record + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + # Update existing summary + old_summary_node_id = summary_record.summary_index_node_id + + # Update summary content + summary_record.summary_content = summary_content + summary_record.status = "generating" + summary_record.error = None # type: ignore[assignment] # Clear any previous errors + session.add(summary_record) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Delete old vector if exists (before vectorization) + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # vectorize_summary will also ensure summary_content is preserved + # Note: vectorize_summary may take time due to embedding API calls, but we need to complete it + # to ensure the summary is properly indexed + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record) + # Now commit the session (summary_record should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Don't raise the exception - just log it and return the record with error status + # This allows the segment update to complete even if vectorization fails + summary_record.status = "error" + summary_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + # The caller can check the status if needed + return summary_record + else: + # Create new summary record if doesn't exist + summary_record = SummaryIndexService.create_summary_record( + segment, dataset, summary_content, status="generating" + ) + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # Note: summary_record was created in a different session, + # so we need to merge it into current session + try: + # Merge the record into current session first (since it was created in a different session) + summary_record = session.merge(summary_record) + # Pass the session to vectorize_summary - it will update the merged record + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh to get updated status and tokens from database + session.refresh(summary_record) + # Commit the session to persist the changes + session.commit() + logger.info("Successfully created and vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Merge the record into current session first + error_record = session.merge(summary_record) + error_record.status = "error" + error_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + return error_record + + except Exception as e: + logger.exception("Failed to update summary for segment %s", segment.id) + # Update summary record with error status if it exists + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record: + summary_record.status = "error" + summary_record.error = str(e) + session.add(summary_record) + session.commit() + raise + + @staticmethod + def get_segment_summary(segment_id: str, dataset_id: str) -> DocumentSegmentSummary | None: + """ + Get summary for a single segment. + + Args: + segment_id: Segment ID (chunk_id) + dataset_id: Dataset ID + + Returns: + DocumentSegmentSummary instance if found, None otherwise + """ + with session_factory.create_session() as session: + return ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .first() + ) + + @staticmethod + def get_segments_summaries(segment_ids: list[str], dataset_id: str) -> dict[str, DocumentSegmentSummary]: + """ + Get summaries for multiple segments. + + Args: + segment_ids: List of segment IDs (chunk_ids) + dataset_id: Dataset ID + + Returns: + Dictionary mapping segment_id to DocumentSegmentSummary (only enabled summaries) + """ + if not segment_ids: + return {} + + with session_factory.create_session() as session: + summary_records = ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .all() + ) + + return {summary.chunk_id: summary for summary in summary_records} + + @staticmethod + def get_document_summaries( + document_id: str, dataset_id: str, segment_ids: list[str] | None = None + ) -> list[DocumentSegmentSummary]: + """ + Get all summary records for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + segment_ids: Optional list of segment IDs to filter by + + Returns: + List of DocumentSegmentSummary instances (only enabled summaries) + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter( + DocumentSegmentSummary.document_id == document_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + return query.all() + + @staticmethod + def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None: + """ + Get summary_index_status for a single document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + "SUMMARIZING" if there are pending summaries, None otherwise + """ + # Get all segments for this document (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + segment_ids = [seg.id for seg in segments] + + if not segment_ids: + return None + + # Get all summary records for these segments + summaries = SummaryIndexService.get_segments_summaries(segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Check if there are any "not_started" or "generating" status summaries + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + return "SUMMARIZING" if has_pending_summaries else None + + @staticmethod + def get_documents_summary_index_status( + document_ids: list[str], dataset_id: str, tenant_id: str + ) -> dict[str, str | None]: + """ + Get summary_index_status for multiple documents. + + Args: + document_ids: List of document IDs + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + Dictionary mapping document_id to summary_index_status ("SUMMARIZING" or None) + """ + if not document_ids: + return {} + + # Get all segments for these documents (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id, DocumentSegment.document_id) + .where( + DocumentSegment.document_id.in_(document_ids), + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + + # Group segments by document_id + document_segments_map: dict[str, list[str]] = {} + for segment in segments: + doc_id = str(segment.document_id) + if doc_id not in document_segments_map: + document_segments_map[doc_id] = [] + document_segments_map[doc_id].append(segment.id) + + # Get all summary records for these segments + all_segment_ids = [seg.id for seg in segments] + summaries = SummaryIndexService.get_segments_summaries(all_segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Calculate summary_index_status for each document + result: dict[str, str | None] = {} + for doc_id in document_ids: + segment_ids = document_segments_map.get(doc_id, []) + if not segment_ids: + # No segments, status is None (not started) + result[doc_id] = None + continue + + # Check if there are any "not_started" or "generating" status summaries + # Only check enabled=True summaries (already filtered in query) + # If segment has no summary record (summary_status_map.get returns None), + # it means the summary is disabled (enabled=False) or not created yet, ignore it + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + if has_pending_summaries: + # Task is still running (not started or generating) + result[doc_id] = "SUMMARIZING" + else: + # All enabled=True summaries are "completed" or "error", task finished + # Or no enabled=True summaries exist (all disabled) + result[doc_id] = None + + return result + + @staticmethod + def get_document_summary_status_detail( + document_id: str, + dataset_id: str, + ) -> dict[str, Any]: + """ + Get detailed summary status for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + + Returns: + Dictionary containing: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + from services.dataset_service import SegmentService + + # Get all segments for this document + segments = SegmentService.get_segments_by_document_and_dataset( + document_id=document_id, + dataset_id=dataset_id, + status="completed", + enabled=True, + ) + + total_segments = len(segments) + + # Get all summary records for these segments + segment_ids = [segment.id for segment in segments] + summaries = [] + if segment_ids: + summaries = SummaryIndexService.get_document_summaries( + document_id=document_id, + dataset_id=dataset_id, + segment_ids=segment_ids, + ) + + # Create a mapping of chunk_id to summary + summary_map = {summary.chunk_id: summary for summary in summaries} + + # Count statuses + status_counts = { + "completed": 0, + "generating": 0, + "error": 0, + "not_started": 0, + } + + summary_list = [] + for segment in segments: + summary = summary_map.get(segment.id) + if summary: + status = summary.status + status_counts[status] = status_counts.get(status, 0) + 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": segment.position, + "status": summary.status, + "summary_preview": ( + summary.summary_content[:100] + "..." + if summary.summary_content and len(summary.summary_content) > 100 + else summary.summary_content + ), + "error": summary.error, + "created_at": int(summary.created_at.timestamp()) if summary.created_at else None, + "updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None, + } + ) + else: + status_counts["not_started"] += 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": segment.position, + "status": "not_started", + "summary_preview": None, + "error": None, + "created_at": None, + "updated_at": None, + } + ) + + return { + "total_segments": total_segments, + "summary_status": status_counts, + "summaries": summary_list, + } diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index ab5d5480d..6d84d4e25 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,8 +1,6 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -10,8 +8,8 @@ from sqlalchemy.orm import Session from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -38,12 +36,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -75,7 +71,7 @@ class WorkflowToolManageService: label=label, icon=json.dumps(icon), description=description, - parameter_configuration=json.dumps(parameters), + parameter_configuration=json.dumps([p.model_dump() for p in parameters]), privacy_policy=privacy_policy, version=workflow.version, ) @@ -104,7 +100,7 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): @@ -122,8 +118,6 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -162,7 +156,7 @@ class WorkflowToolManageService: workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version workflow_tool_provider.updated_at = datetime.now() diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 62e6497e9..2d3d00cd5 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str): ) session.commit() + # Enable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e)) + end_at = time.perf_counter() logger.info( click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 74b939e84..d38828498 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 86e7cc716..91ace6be0 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index bcca1bf49..4214f043e 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): ).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) segment_ids = [segment.id for segment in segments] segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) session.execute(segment_delete_stmt) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 9f2ee8abd..c13f1525b 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -49,6 +49,7 @@ def delete_segment_from_index_task( doc_form = dataset_document.doc_form # Proceed with index cleanup using the index_node_ids directly + # For actual deletion, we should delete summaries (not just disable them) index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor.clean( dataset, @@ -56,6 +57,7 @@ def delete_segment_from_index_task( with_keywords=True, delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, + delete_summaries=True, # Actually delete summaries when segment is deleted ) if dataset.is_multimodal: # delete segment attachment binding diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 0ce6429a9..bc4517162 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.clean(dataset, [segment.index_node_id]) + # Disable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + disabled_by=segment.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info( click.style( diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 03635902d..3cc267e82 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + # Disable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + # Get disabled_by from first segment (they should all have the same disabled_by) + disabled_by = segments[0].disabled_by if segments else None + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 3bdff6019..34496e9c6 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService +from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) @@ -99,6 +100,78 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): indexing_runner.run(documents) end_at = time.perf_counter() logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + + # Trigger summary index generation for completed documents if enabled + # Only generate for high_quality indexing technique and when summary_index_setting is enabled + # Re-query dataset to get latest summary_index_setting (in case it was updated) + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.warning("Dataset %s not found after indexing", dataset_id) + return + + if dataset.indexing_technique == "high_quality": + summary_index_setting = dataset.summary_index_setting + if summary_index_setting and summary_index_setting.get("enable"): + # expire all session to get latest document's indexing status + session.expire_all() + # Check each document's indexing status and trigger summary generation if completed + for document_id in document_ids: + # Re-query document to get latest status (IndexingRunner may have updated it) + document = ( + session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() + ) + if document: + logger.info( + "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + if ( + document.indexing_status == "completed" + and document.doc_form != "qa_model" + and document.need_summary is True + ): + try: + generate_summary_index_task.delay(dataset.id, document_id, None) + logger.info( + "Queued summary index generation task for document %s in dataset %s " + "after indexing completed", + document_id, + dataset.id, + ) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document_id, + ) + # Don't fail the entire indexing process if summary task queuing fails + else: + logger.info( + "Skipping summary generation for document %s: " + "status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + else: + logger.warning("Document %s not found after indexing", document_id) + else: + logger.info( + "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", + dataset.id, + summary_index_setting.get("enable") if summary_index_setting else None, + ) + else: + logger.info( + "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", + dataset.id, + dataset.indexing_technique, + ) except DocumentIsPausedError as ex: logger.info(click.style(str(ex), fg="yellow")) except Exception: diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 1f9f21aa7..41ebb0b07 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -106,6 +106,17 @@ def enable_segment_to_index_task(segment_id: str): # save vector index index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) + # Enable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + ) + except Exception as e: + logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 48d3c8e17..d90eb4c39 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i # save vector index index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) + # Enable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py new file mode 100644 index 000000000..e4273e16b --- /dev/null +++ b/api/tasks/generate_summary_index_task.py @@ -0,0 +1,119 @@ +"""Async task for generating summary indexes.""" + +import logging +import time + +import click +from celery import shared_task + +from core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None): + """ + Async generate summary index for document segments. + + Args: + dataset_id: Dataset ID + document_id: Document ID + segment_ids: Optional list of specific segment IDs to process. If None, process all segments. + + Usage: + generate_summary_index_task.delay(dataset_id, document_id) + generate_summary_index_task.delay(dataset_id, document_id, segment_ids) + """ + logger.info( + click.style( + f"Start generating summary index for document {document_id} in dataset {dataset_id}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + if not document: + logger.error(click.style(f"Document not found: {document_id}", fg="red")) + return + + # Check if document needs summary + if not document.need_summary: + logger.info( + click.style( + f"Skipping summary generation for document {document_id}: need_summary is False", + fg="cyan", + ) + ) + return + + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary generation for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + # Determine if only parent chunks should be processed + only_parent_chunks = dataset.chunk_structure == "parent_child_index" + + # Generate summaries + summary_records = SummaryIndexService.generate_summaries_for_document( + dataset=dataset, + document=document, + summary_index_setting=summary_index_setting, + segment_ids=segment_ids, + only_parent_chunks=only_parent_chunks, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Summary index generation completed for document {document_id}: " + f"{len(summary_records)} summaries generated, latency: {end_at - start_at}", + fg="green", + ) + ) + + except Exception as e: + logger.exception("Failed to generate summary index for document %s", document_id) + # Update document segments with error status if needed + if segment_ids: + error_message = f"Summary generation failed: {str(e)}" + with session_factory.create_session() as session: + session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + ).update( + { + DocumentSegment.error: error_message, + }, + synchronize_session=False, + ) + session.commit() diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py new file mode 100644 index 000000000..cf8988d13 --- /dev/null +++ b/api/tasks/regenerate_summary_index_task.py @@ -0,0 +1,315 @@ +"""Task for regenerating summary indexes when dataset settings change.""" + +import logging +import time +from collections import defaultdict + +import click +from celery import shared_task +from sqlalchemy import or_, select + +from core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def regenerate_summary_index_task( + dataset_id: str, + regenerate_reason: str = "summary_model_changed", + regenerate_vectors_only: bool = False, +): + """ + Regenerate summary indexes for all documents in a dataset. + + This task is triggered when: + 1. summary_index_setting model changes (regenerate_reason="summary_model_changed") + - Regenerates summary content and vectors for all existing summaries + 2. embedding_model changes (regenerate_reason="embedding_model_changed") + - Only regenerates vectors for existing summaries (keeps summary content) + + Args: + dataset_id: Dataset ID + regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed") + regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content + """ + logger.info( + click.style( + f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + # Only regenerate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary regeneration for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled (only for summary_model change) + # For embedding_model change, we still re-vectorize existing summaries even if setting is disabled + summary_index_setting = dataset.summary_index_setting + if not regenerate_vectors_only: + # For summary_model change, require summary_index_setting to be enabled + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + total_segments_processed = 0 + total_segments_failed = 0 + + if regenerate_vectors_only: + # For embedding_model change: directly query all segments with existing summaries + # Don't require document indexing_status == "completed" + # Include summaries with status "completed" or "error" (if they have content) + segments_with_summaries = ( + session.query(DocumentSegment, DocumentSegmentSummary) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .join( + DatasetDocument, + DocumentSegment.document_id == DatasetDocument.id, + ) + .where( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", # Segment must be completed + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content + # Include completed summaries or error summaries (with content) + or_( + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.status == "error", + ), + DatasetDocument.enabled == True, # Document must be enabled + DatasetDocument.archived == False, # Document must not be archived + DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + ) + .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) + .all() + ) + + if not segments_with_summaries: + logger.info( + click.style( + f"No segments with summaries found for re-vectorization in dataset {dataset_id}", + fg="cyan", + ) + ) + return + + logger.info( + "Found %s segments with summaries for re-vectorization in dataset %s", + len(segments_with_summaries), + dataset_id, + ) + + # Group by document for logging + segments_by_document = defaultdict(list) + for segment, summary_record in segments_with_summaries: + segments_by_document[segment.document_id].append((segment, summary_record)) + + logger.info( + "Segments grouped into %s documents for re-vectorization", + len(segments_by_document), + ) + + for document_id, segment_summary_pairs in segments_by_document.items(): + logger.info( + "Re-vectorizing summaries for %s segments in document %s", + len(segment_summary_pairs), + document_id, + ) + + for segment, summary_record in segment_summary_pairs: + try: + # Delete old vector + if summary_record.summary_index_node_id: + try: + from core.rag.datasource.vdb.vector_factory import Vector + + vector = Vector(dataset) + vector.delete_by_ids([summary_record.summary_index_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize with new embedding model + SummaryIndexService.vectorize_summary(summary_record, segment, dataset) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to re-vectorize summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + summary_record.status = "error" + summary_record.error = f"Re-vectorization failed: {str(e)}" + session.add(summary_record) + session.commit() + continue + + else: + # For summary_model change: require document indexing_status == "completed" + # Get all documents with completed indexing status + dataset_documents = session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + ).all() + + if not dataset_documents: + logger.info( + click.style( + f"No documents found for summary regeneration in dataset {dataset_id}", + fg="cyan", + ) + ) + return + + logger.info( + "Found %s documents for summary regeneration in dataset %s", + len(dataset_documents), + dataset_id, + ) + + for dataset_document in dataset_documents: + # Skip qa_model documents + if dataset_document.doc_form == "qa_model": + continue + + try: + # Get all segments with existing summaries + segments = ( + session.query(DocumentSegment) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + ) + .order_by(DocumentSegment.position.asc()) + .all() + ) + + if not segments: + continue + + logger.info( + "Regenerating summaries for %s segments in document %s", + len(segments), + dataset_document.id, + ) + + for segment in segments: + summary_record = None + try: + # Get existing summary record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by( + chunk_id=segment.id, + dataset_id=dataset_id, + ) + .first() + ) + + if not summary_record: + logger.warning("Summary record not found for segment %s, skipping", segment.id) + continue + + # Regenerate both summary content and vectors (for summary_model change) + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to regenerate summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + if summary_record: + summary_record.status = "error" + summary_record.error = f"Regeneration failed: {str(e)}" + session.add(summary_record) + session.commit() + continue + + except Exception as e: + logger.error( + "Failed to process document %s for summary regeneration: %s", + dataset_document.id, + str(e), + exc_info=True, + ) + continue + + end_at = time.perf_counter() + if regenerate_vectors_only: + logger.info( + click.style( + f"Summary re-vectorization completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + else: + logger.info( + click.style( + f"Summary index regeneration completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + + except Exception: + logger.exception("Regenerate summary index failed for dataset %s", dataset_id) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 817249845..6240f2200 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): - def del_workflow_archive_log(workflow_archive_log_id: str): - db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( + def del_workflow_archive_log(session, workflow_archive_log_id: str): + session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( synchronize_session=False ) @@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: total_files_deleted = 0 while True: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): # Get a batch of draft variable IDs along with their file_ids query_sql = """ SELECT id, file_id FROM workflow_draft_variables diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index c3c255fb1..55259ab52 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -46,6 +46,21 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() + + # Disable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=document.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e)) + index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index f46d1bf5d..d02023362 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -10,7 +10,10 @@ from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile -from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch +from tasks.remove_app_and_related_data_task import ( + _delete_draft_variables, + delete_draft_variables_batch, +) @pytest.fixture @@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.return_value = None with session_factory.create_session() as session: draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = session.query(WorkflowDraftVariableFile).count() - upload_files_before = session.query(UploadFile).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.side_effect = [Exception("Storage error"), None] deleted_count = delete_draft_variables_batch(app_id, batch_size=10) @@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration: if app2_obj: session.delete(app2_obj) session.commit() + + +class TestDeleteDraftVariablesSessionCommit: + """Test suite to verify session commit behavior in delete_draft_variables_batch.""" + + @pytest.fixture + def setup_offload_test_data(self, app_and_tenant): + """Create test data with offload files for session commit tests.""" + from core.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now + + tenant, app = app_and_tenant + + with session_factory.create_session() as session: + upload_file1 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file1.json", + name="file1.json", + size=1024, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + upload_file2 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file2.json", + name="file2.json", + size=2048, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + session.add(upload_file1) + session.add(upload_file2) + session.flush() + + var_file1 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file1.id, + size=1024, + length=10, + value_type=SegmentType.STRING, + ) + var_file2 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file2.id, + size=2048, + length=20, + value_type=SegmentType.OBJECT, + ) + session.add(var_file1) + session.add(var_file2) + session.flush() + + draft_var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_1", + name="large_var_1", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file1.id, + ) + draft_var2 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_2", + name="large_var_2", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file2.id, + ) + draft_var3 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_3", + name="regular_var", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(draft_var1) + session.add(draft_var2) + session.add(draft_var3) + session.commit() + + data = { + "app": app, + "tenant": tenant, + "upload_files": [upload_file1, upload_file2], + "variable_files": [var_file1, var_file2], + "draft_variables": [draft_var1, draft_var2, draft_var3], + } + + yield data + + with session_factory.create_session() as session: + for table, ids in [ + (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]), + (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]), + (UploadFile, [uf.id for uf in data["upload_files"]]), + ]: + cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) + session.execute(cleanup_query) + session.commit() + + @pytest.fixture + def setup_commit_test_data(self, app_and_tenant): + """Create test data for session commit tests.""" + tenant, app = app_and_tenant + variable_ids: list[str] = [] + + with session_factory.create_session() as session: + variables = [] + for i in range(10): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + variables.append(var) + session.commit() + variable_ids = [v.id for v in variables] + + yield { + "app": app, + "tenant": tenant, + "variable_ids": variable_ids, + } + + with session_factory.create_session() as session: + cleanup_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) + .execution_options(synchronize_session=False) + ) + session.execute(cleanup_query) + session.commit() + + def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data): + """Test that session.begin() is used for automatic transaction management.""" + data = setup_commit_test_data + app_id = data["app"].id + + # Since session.begin() is used, the transaction is automatically committed + # when the with block exits successfully. We verify this by checking that + # data is actually persisted. + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + # Verify all data was deleted (proves transaction was committed) + with session_factory.create_session() as session: + remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + + assert deleted_count == 10 + assert remaining_count == 0 + + def test_data_persisted_after_batch_deletion(self, setup_commit_test_data): + """Test that data is actually persisted to database after batch deletion with commits.""" + data = setup_commit_test_data + app_id = data["app"].id + variable_ids = data["variable_ids"] + + # Verify initial state + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Perform deletion with small batch size to force multiple commits + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + assert deleted_count == 10 + + # Verify all data is deleted in a new session (proves commits worked) + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + # Verify specific IDs are deleted + with session_factory.create_session() as session: + remaining_vars = ( + session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + ) + assert remaining_vars == 0 + + def test_session_commit_with_empty_dataset(self, setup_commit_test_data): + """Test session behavior when deleting from an empty dataset.""" + nonexistent_app_id = str(uuid.uuid4()) + + # Should not raise any errors and should return 0 + deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10) + assert deleted_count == 0 + + def test_session_commit_with_single_batch(self, setup_commit_test_data): + """Test that commit happens correctly when all data fits in a single batch.""" + data = setup_commit_test_data + app_id = data["app"].id + + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Delete all in a single batch + deleted_count = delete_draft_variables_batch(app_id, batch_size=100) + assert deleted_count == 10 + + # Verify data is persisted + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + def test_invalid_batch_size_raises_error(self, setup_commit_test_data): + """Test that invalid batch size raises ValueError.""" + data = setup_commit_test_data + app_id = data["app"].id + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=0) + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=-1) + + @patch("extensions.ext_storage.storage") + def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data): + """Test that session commits correctly when cleaning up offload data.""" + data = setup_offload_test_data + app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + mock_storage.delete.return_value = None + + # Verify initial state + with session_factory.create_session() as session: + draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_before == 3 + assert var_files_before == 2 + assert upload_files_before == 2 + + # Delete variables with offload data + deleted_count = delete_draft_variables_batch(app_id, batch_size=10) + assert deleted_count == 3 + + # Verify all data is persisted (deleted) in new session + with session_factory.create_session() as session: + draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_after == 0 + assert var_files_after == 0 + assert upload_files_after == 0 + + # Verify storage cleanup was called + assert mock_storage.delete.call_count == 2 diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index fe0e03f7b..a2bf10001 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -1,3 +1,5 @@ +import uuid + from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector from core.rag.models.document import Document from tests.integration_tests.vdb.test_vector_store import ( @@ -18,6 +20,10 @@ class QdrantVectorTest(AbstractVectorTest): api_key="difyai123456", ), ) + # Additional doc IDs for multi-keyword search tests + self.doc_apple_id = "" + self.doc_banana_id = "" + self.doc_both_id = "" def search_by_vector(self): super().search_by_vector() @@ -27,6 +33,77 @@ class QdrantVectorTest(AbstractVectorTest): ) assert len(hits_by_vector) == 0 + def _create_document(self, content: str, doc_id: str) -> Document: + """Create a document with the given content and doc_id.""" + return Document( + page_content=content, + metadata={ + "doc_id": doc_id, + "doc_hash": doc_id, + "document_id": doc_id, + "dataset_id": self.dataset_id, + }, + ) + + def setup_multi_keyword_documents(self): + """Create test documents with different keyword combinations for multi-keyword search tests.""" + self.doc_apple_id = str(uuid.uuid4()) + self.doc_banana_id = str(uuid.uuid4()) + self.doc_both_id = str(uuid.uuid4()) + + documents = [ + self._create_document("This document contains apple only", self.doc_apple_id), + self._create_document("This document contains banana only", self.doc_banana_id), + self._create_document("This document contains both apple and banana", self.doc_both_id), + ] + embeddings = [self.example_embedding] * len(documents) + + self.vector.add_texts(documents=documents, embeddings=embeddings) + + def search_by_full_text_multi_keyword(self): + """Test multi-keyword search returns docs matching ANY keyword (OR logic).""" + # First verify single keyword searches work correctly + hits_apple = self.vector.search_by_full_text(query="apple", top_k=10) + apple_ids = {doc.metadata["doc_id"] for doc in hits_apple} + assert self.doc_apple_id in apple_ids, "Document with 'apple' should be found" + assert self.doc_both_id in apple_ids, "Document with 'apple and banana' should be found" + + hits_banana = self.vector.search_by_full_text(query="banana", top_k=10) + banana_ids = {doc.metadata["doc_id"] for doc in hits_banana} + assert self.doc_banana_id in banana_ids, "Document with 'banana' should be found" + assert self.doc_both_id in banana_ids, "Document with 'apple and banana' should be found" + + # Test multi-keyword search returns all matching documents + hits = self.vector.search_by_full_text(query="apple banana", top_k=10) + doc_ids = {doc.metadata["doc_id"] for doc in hits} + + assert self.doc_apple_id in doc_ids, "Document with 'apple' should be found in multi-keyword search" + assert self.doc_banana_id in doc_ids, "Document with 'banana' should be found in multi-keyword search" + assert self.doc_both_id in doc_ids, "Document with both keywords should be found" + # Expect 3 results: doc_apple (apple only), doc_banana (banana only), doc_both (contains both) + assert len(hits) == 3, f"Expected 3 documents, got {len(hits)}" + + # Test keyword order independence + hits_ba = self.vector.search_by_full_text(query="banana apple", top_k=10) + ids_ba = {doc.metadata["doc_id"] for doc in hits_ba} + assert doc_ids == ids_ba, "Keyword order should not affect search results" + + # Test no duplicates in results + doc_id_list = [doc.metadata["doc_id"] for doc in hits] + assert len(doc_id_list) == len(set(doc_id_list)), "Search results should not contain duplicates" + + def run_all_tests(self): + self.create_vector() + self.search_by_vector() + self.search_by_full_text() + self.text_exists() + self.get_ids_by_metadata_field() + # Multi-keyword search tests + self.setup_multi_keyword_documents() + self.search_by_full_text_multi_keyword() + # Cleanup - delete_vector() removes the entire collection + self.delete_vector() + def test_qdrant_vector(setup_mock_redis): QdrantVectorTest().run_all_tests() diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index e3431fd38..934d1bdd3 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -90,6 +90,7 @@ class TestWebhookService: "id": "webhook_node", "type": "webhook", "data": { + "type": "trigger-webhook", "title": "Test Webhook", "method": "post", "content_type": "application/json", diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 3d46735a1..3c0a660e7 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -3,7 +3,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import ValidationError +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService @@ -130,20 +132,24 @@ class TestWorkflowToolManageService: def _create_test_workflow_tool_parameters(self): """Helper method to create valid workflow tool parameters.""" return [ - { - "name": "input_text", - "description": "Input text for processing", - "form": "form", - "type": "string", - "required": True, - }, - { - "name": "output_format", - "description": "Output format specification", - "form": "form", - "type": "select", - "required": False, - }, + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + "description": "Input text for processing", + "form": "form", + "type": "string", + "required": True, + } + ), + WorkflowToolParameterConfiguration.model_validate( + { + "name": "output_format", + "description": "Output format specification", + "form": "form", + "type": "select", + "required": False, + } + ), ] def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -208,7 +214,7 @@ class TestWorkflowToolManageService: assert created_tool_provider.label == tool_label assert created_tool_provider.icon == json.dumps(tool_icon) assert created_tool_provider.description == tool_description - assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) + assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters]) assert created_tool_provider.privacy_policy == tool_privacy_policy assert created_tool_provider.version == workflow.version assert created_tool_provider.user_id == account.id @@ -353,18 +359,9 @@ class TestWorkflowToolManageService: app, account, workflow = self._create_test_app_and_account( db_session_with_containers, mock_external_service_dependencies ) - - # Setup invalid workflow tool parameters (missing required fields) - invalid_parameters = [ - { - "name": "input_text", - # Missing description and form fields - "type": "string", - "required": True, - } - ] # Attempt to create workflow tool with invalid parameters - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValidationError) as exc_info: + # Setup invalid workflow tool parameters (missing required fields) WorkflowToolManageService.create_workflow_tool( user_id=account.id, tenant_id=account.current_tenant.id, @@ -373,7 +370,16 @@ class TestWorkflowToolManageService: label=fake.word(), icon={"type": "emoji", "emoji": "🔧"}, description=fake.text(max_nb_chars=200), - parameters=invalid_parameters, + parameters=[ + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + # Missing description and form fields + "type": "string", + "required": True, + } + ) + ], ) # Verify error message contains validation error @@ -579,11 +585,12 @@ class TestWorkflowToolManageService: # Verify database state was updated db.session.refresh(created_tool) + assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label assert created_tool.icon == json.dumps(updated_tool_icon) assert created_tool.description == updated_tool_description - assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) + assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters]) assert created_tool.privacy_policy == updated_tool_privacy_policy assert created_tool.version == workflow.version assert created_tool.updated_at is not None @@ -750,13 +757,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILE type file_parameters = [ - { - "name": "document", - "description": "Upload a document", - "form": "form", - "type": "file", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "document", + "description": "Upload a document", + "form": "form", + "type": "file", + "required": False, + } + ) ] # Execute the method under test @@ -823,13 +832,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILES type files_parameters = [ - { - "name": "documents", - "description": "Upload multiple documents", - "form": "form", - "type": "files", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "documents", + "description": "Upload multiple documents", + "form": "form", + "type": "files", + "required": False, + } + ) ] # Execute the method under test diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index c5e157618..e3c1a617f 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from sqlalchemy import create_engine # Getting the absolute path of the current file's directory ABS_PATH = os.path.dirname(os.path.abspath(__file__)) @@ -36,6 +37,7 @@ import sys sys.path.insert(0, PROJECT_DIR) +from core.db.session_factory import configure_session_factory, session_factory from extensions import ext_redis @@ -102,3 +104,18 @@ def reset_secret_key(): yield finally: dify_config.SECRET_KEY = original + + +@pytest.fixture(scope="session") +def _unit_test_engine(): + engine = create_engine("sqlite:///:memory:") + yield engine + engine.dispose() + + +@pytest.fixture(autouse=True) +def _configure_session_factory(_unit_test_engine): + try: + session_factory.get_session_maker() + except RuntimeError: + configure_session_factory(_unit_test_engine, expire_on_commit=False) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 40eb59a8f..c55760591 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -31,6 +31,13 @@ def _load_app_module(): def schema_model(self, name, schema): self.models[name] = schema + return schema + + def model(self, name, model_dict=None, **kwargs): + """Register a model with the namespace (flask-restx compatibility).""" + if model_dict is not None: + self.models[name] = model_dict + return model_dict def _decorator(self, obj): return obj diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py new file mode 100644 index 000000000..b9bc42fb2 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py @@ -0,0 +1,46 @@ +import builtins +from unittest.mock import patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + app.secret_key = "test-secret-key" + return app + + +def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch): + ext_fastopenapi.init_app(app) + monkeypatch.delenv("INIT_PASSWORD", raising=False) + + with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"): + client = app.test_client() + response = client.get("/console/api/init") + + assert response.status_code == 200 + assert response.get_json() == {"status": "finished"} + + +def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch): + ext_fastopenapi.init_app(app) + monkeypatch.setenv("INIT_PASSWORD", "test-init-password") + + with ( + patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0), + ): + client = app.test_client() + response = client.post("/console/api/init", json={"password": "test-init-password"}) + + assert response.status_code == 201 + assert response.get_json() == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py new file mode 100644 index 000000000..cb2604cf1 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py @@ -0,0 +1,92 @@ +import builtins +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch + +import httpx +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_console_remote_files_fastopenapi_get_info(app: Flask): + ext_fastopenapi.init_app(app) + + response = httpx.Response( + 200, + request=httpx.Request("HEAD", "http://example.com/file.txt"), + headers={"Content-Type": "text/plain", "Content-Length": "10"}, + ) + + with patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response): + client = app.test_client() + encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt" + resp = client.get(f"/console/api/remote-files/{encoded_url}") + + assert resp.status_code == 200 + assert resp.get_json() == {"file_type": "text/plain", "file_length": 10} + + +def test_console_remote_files_fastopenapi_upload(app: Flask): + ext_fastopenapi.init_app(app) + + head_response = httpx.Response( + 200, + request=httpx.Request("GET", "http://example.com/file.txt"), + content=b"hello", + ) + file_info = SimpleNamespace( + extension="txt", + size=5, + filename="file.txt", + mimetype="text/plain", + ) + uploaded = SimpleNamespace( + id="file-id", + name="file.txt", + size=5, + extension="txt", + mime_type="text/plain", + created_by="user-id", + created_at=datetime(2024, 1, 1), + ) + + with ( + patch("controllers.console.remote_files.db", new=SimpleNamespace(engine=object())), + patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response), + patch("controllers.console.remote_files.helpers.guess_file_info_from_response", return_value=file_info), + patch("controllers.console.remote_files.FileService.is_file_size_within_limit", return_value=True), + patch("controllers.console.remote_files.FileService.__init__", return_value=None), + patch("controllers.console.remote_files.current_account_with_tenant", return_value=(object(), "tenant-id")), + patch("controllers.console.remote_files.FileService.upload_file", return_value=uploaded), + patch("controllers.console.remote_files.file_helpers.get_signed_file_url", return_value="signed-url"), + ): + client = app.test_client() + resp = client.post( + "/console/api/remote-files/upload", + json={"url": "http://example.com/file.txt"}, + ) + + assert resp.status_code == 201 + assert resp.get_json() == { + "id": "file-id", + "name": "file.txt", + "size": 5, + "extension": "txt", + "url": "signed-url", + "mime_type": "text/plain", + "created_by": "user-id", + "created_at": int(uploaded.created_at.timestamp()), + } diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py new file mode 100644 index 000000000..94c3019d5 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -0,0 +1,364 @@ +"""Endpoint tests for controllers.console.workspace.tool_providers.""" + +from __future__ import annotations + +import builtins +import importlib +from contextlib import contextmanager +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask.views import MethodView + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +_CONTROLLER_MODULE: ModuleType | None = None +_WRAPS_MODULE: ModuleType | None = None +_CONTROLLER_PATCHERS: list[patch] = [] + + +@contextmanager +def _mock_db(): + mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + with patch("extensions.ext_database.db.session", mock_session): + yield + + +@pytest.fixture +def app() -> Flask: + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture +def controller_module(monkeypatch: pytest.MonkeyPatch): + module_name = "controllers.console.workspace.tool_providers" + global _CONTROLLER_MODULE + if _CONTROLLER_MODULE is None: + + def _noop(func): + return func + + patch_targets = [ + ("libs.login.login_required", _noop), + ("controllers.console.wraps.setup_required", _noop), + ("controllers.console.wraps.account_initialization_required", _noop), + ("controllers.console.wraps.is_admin_or_owner_required", _noop), + ("controllers.console.wraps.enterprise_license_required", _noop), + ] + for target, value in patch_targets: + patcher = patch(target, value) + patcher.start() + _CONTROLLER_PATCHERS.append(patcher) + monkeypatch.setenv("DIFY_SETUP_READY", "true") + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) + + module = _CONTROLLER_MODULE + monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) + + # Ensure decorators that consult deployment edition do not reach the database. + global _WRAPS_MODULE + wraps_module = importlib.import_module("controllers.console.wraps") + _WRAPS_MODULE = wraps_module + monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None) + return module + + +def _mock_account(user_id: str = "user-123") -> SimpleNamespace: + return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None) + + +def _set_current_account( + monkeypatch: pytest.MonkeyPatch, + controller_module: ModuleType, + user: SimpleNamespace, + tenant_id: str, +) -> None: + def _getter(): + return user, tenant_id + + user.current_tenant_id = tenant_id + + monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter) + if _WRAPS_MODULE is not None: + monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter) + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "_get_user", lambda: user) + + +def test_tool_provider_list_calls_service_with_query( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value=[{"provider": "builtin"}]) + monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock) + + with app.test_request_context("/workspaces/current/tool-providers?type=builtin"): + response = controller_module.ToolProviderListApi().get() + + assert response == [{"provider": "builtin"}] + service_mock.assert_called_once_with(user.id, "tenant-456", "builtin") + + +def test_builtin_provider_add_passes_payload( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value={"status": "ok"}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock) + + payload = { + "credentials": {"api_key": "sk-test"}, + "name": "MyTool", + "type": controller_module.CredentialType.API_KEY, + } + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/openai/add", + method="POST", + json=payload, + ): + response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai") + + assert response == {"status": "ok"} + service_mock.assert_called_once_with( + user_id="user-123", + tenant_id="tenant-456", + provider="openai", + credentials={"api_key": "sk-test"}, + name="MyTool", + api_type=controller_module.CredentialType.API_KEY, + ) + + +def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-789") + _set_current_account(monkeypatch, controller_module, user, "tenant-789") + + service_mock = MagicMock(return_value=[{"name": "tool-a"}]) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock) + monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload) + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/my-provider/tools", + method="GET", + ): + response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider") + + assert response == [{"name": "tool-a"}] + service_mock.assert_called_once_with("tenant-789", "my-provider") + + +def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-9") + _set_current_account(monkeypatch, controller_module, user, "tenant-9") + service_mock = MagicMock(return_value={"info": True}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock) + + with app.test_request_context("/info", method="GET"): + resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo") + + assert resp == {"info": True} + service_mock.assert_called_once_with("tenant-9", "demo") + + +def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-cred") + _set_current_account(monkeypatch, controller_module, user, "tenant-cred") + service_mock = MagicMock(return_value=[{"cred": 1}]) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "get_builtin_tool_provider_credentials", + service_mock, + ) + + with app.test_request_context("/creds", method="GET"): + resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo") + + assert resp == [{"cred": 1}] + service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo") + + +def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-10") + service_mock = MagicMock(return_value={"schema": "ok"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock) + + with app.test_request_context("/remote?url=https://example.com/"): + resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get() + + assert resp == {"schema": "ok"} + service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/") + + +def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-11") + service_mock = MagicMock(return_value=[{"tool": "t"}]) + monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock) + + with app.test_request_context("/tools?provider=foo"): + resp = controller_module.ToolApiProviderListToolsApi().get() + + assert resp == [{"tool": "t"}] + service_mock.assert_called_once_with(user.id, "tenant-11", "foo") + + +def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-12") + service_mock = MagicMock(return_value={"provider": "foo"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock) + + with app.test_request_context("/get?provider=foo"): + resp = controller_module.ToolApiProviderGetApi().get() + + assert resp == {"provider": "foo"} + service_mock.assert_called_once_with(user.id, "tenant-12", "foo") + + +def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-13") + _set_current_account(monkeypatch, controller_module, user, "tenant-13") + service_mock = MagicMock(return_value={"schema": True}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_provider_credentials_schema", + service_mock, + ) + + with app.test_request_context("/schema", method="GET"): + resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get( + provider="demo", credential_type="api-key" + ) + + assert resp == {"schema": True} + service_mock.assert_called_once() + + +def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf") + tool_service = MagicMock(return_value={"wf": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_tool_id", + tool_service, + ) + + tool_id = "00000000-0000-0000-0000-000000000001" + with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"wf": 1} + tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id) + + +def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf2") + service_mock = MagicMock(return_value={"app": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_app_id", + service_mock, + ) + + app_id = "00000000-0000-0000-0000-000000000002" + with app.test_request_context(f"/workflow?workflow_app_id={app_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"app": 1} + service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id) + + +def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf3") + service_mock = MagicMock(return_value=[{"id": 1}]) + monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock) + + tool_id = "00000000-0000-0000-0000-000000000003" + with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderListToolApi().get() + + assert resp == [{"id": 1}] + service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id) + + +def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-bt") + + provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/builtin"): + resp = controller_module.ToolBuiltinListApi().get() + + assert resp == [{"name": "builtin"}] + + +def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-api") + _set_current_account(monkeypatch, controller_module, user, "tenant-api") + + provider = SimpleNamespace(to_dict=lambda: {"name": "api"}) + monkeypatch.setattr( + controller_module.ApiToolManageService, + "list_api_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/api"): + resp = controller_module.ToolApiListApi().get() + + assert resp == [{"name": "api"}] + + +def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf4") + + provider = SimpleNamespace(to_dict=lambda: {"name": "wf"}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "list_tenant_workflow_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/workflow"): + resp = controller_module.ToolWorkflowListApi().get() + + assert resp == [{"name": "wf"}] + + +def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-label") + _set_current_account(monkeypatch, controller_module, user, "tenant-labels") + monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"]) + + with app.test_request_context("/tool-labels"): + resp = controller_module.ToolLabelsApi().get() + + assert resp == ["a", "b"] diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py index 91352b2a5..cfdeef6a8 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -101,3 +101,26 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): assert result.message.tool_calls == [] assert result.usage == LLMUsage.empty_usage() assert result.system_fingerprint is None + + +def test__normalize_non_stream_plugin_result__closes_chunk_iterator(): + prompt_messages = [UserPromptMessage(content="hi")] + + chunk = _make_chunk(content="hello", usage=LLMUsage.empty_usage()) + closed: list[bool] = [] + + def _chunk_iter(): + try: + yield chunk + yield _make_chunk(content="ignored", usage=LLMUsage.empty_usage()) + finally: + closed.append(True) + + result = _normalize_non_stream_plugin_result( + model="test-model", + prompt_messages=prompt_messages, + result=_chunk_iter(), + ) + + assert result.message.content == "hello" + assert closed == [True] diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index f9e59a5f0..0792ada19 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,7 +1,9 @@ """Primarily used for testing merged cell scenarios""" +import io import os import tempfile +from pathlib import Path from types import SimpleNamespace from docx import Document @@ -56,6 +58,42 @@ def test_parse_row(): assert extractor._parse_row(row, {}, 3) == gt[idx] +def test_init_downloads_via_ssrf_proxy(monkeypatch): + doc = Document() + doc.add_paragraph("hello") + buf = io.BytesIO() + doc.save(buf) + docx_bytes = buf.getvalue() + + calls: list[tuple[str, object]] = [] + + class FakeResponse: + status_code = 200 + content = docx_bytes + + def close(self) -> None: + calls.append(("close", None)) + + def fake_get(url: str, **kwargs): + calls.append(("get", (url, kwargs))) + return FakeResponse() + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + + extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id") + try: + assert calls + assert calls[0][0] == "get" + url, kwargs = calls[0][1] + assert url == "https://example.com/test.docx" + assert kwargs.get("timeout") is None + assert extractor.web_path == "https://example.com/test.docx" + assert extractor.file_path != extractor.web_path + assert Path(extractor.file_path).read_bytes() == docx_bytes + finally: + extractor.temp_file.close() + + def test_extract_images_from_docx(monkeypatch): external_bytes = b"ext-bytes" internal_bytes = b"int-bytes" diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index 0aabe2fc3..08818945e 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -138,6 +138,7 @@ class TestDatasetServiceUpdateDataset: "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, + patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task, patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, @@ -147,6 +148,7 @@ class TestDatasetServiceUpdateDataset: "model_manager": mock_model_manager, "get_binding": mock_get_binding, "task": mock_task, + "regenerate_task": mock_regenerate_task, "current_user": mock_current_user, } @@ -549,6 +551,13 @@ class TestDatasetServiceUpdateDataset: # Verify vector index task was triggered mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update") + # Verify regenerate summary index task was triggered (when embedding_model changes) + mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with( + "dataset-123", + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + # Verify return value assert result == dataset diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index a14bbb01d..2b11e42cd 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs: mock_query.where.return_value = mock_delete_query mock_db.session.query.return_value = mock_query - delete_func("log-1") + delete_func(mock_db.session, "log-1") mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) mock_query.where.assert_called_once() diff --git a/api/ty.toml b/api/ty.toml index bb4ff5bbc..380e14dbe 100644 --- a/api/ty.toml +++ b/api/ty.toml @@ -1,16 +1,34 @@ [src] exclude = [ - # TODO: enable when violations fixed + # deps groups (A1/A2/B/C/D/E) + # B: app runner + prompt + "core/prompt", + "core/app/apps/base_app_runner.py", "core/app/apps/workflow_app_runner.py", + "core/agent", + "core/plugin", + # C: services/controllers/fields/libs + "services", + "controllers/inner_api", "controllers/console/app", "controllers/console/explore", "controllers/console/datasets", "controllers/console/workspace", + "controllers/service_api/wraps.py", + "fields/conversation_fields.py", + "libs/external_api.py", + # D: observability + integrations + "core/ops", + "extensions", + # E: vector DB integrations + "core/rag/datasource/vdb", # non-producition or generated code "migrations", "tests", ] + [rules] -missing-argument = "ignore" # TODO: restore when **args for constructor is supported properly -possibly-unbound-attribute = "ignore" +deprecated = "ignore" +unused-ignore-comment = "ignore" +# possibly-missing-attribute = "ignore" \ No newline at end of file diff --git a/api/uv.lock b/api/uv.lock index c95426d19..3ed8cc977 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1395,7 +1395,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.11.4" +version = "1.12.1" source = { virtual = "." } dependencies = [ { name = "alibabacloud-dingtalk" }, @@ -1512,6 +1512,7 @@ dev = [ { name = "pytest-env" }, { name = "pytest-mock" }, { name = "pytest-timeout" }, + { name = "pytest-xdist" }, { name = "ruff" }, { name = "scipy-stubs" }, { name = "sseclient-py" }, @@ -1623,7 +1624,7 @@ requires-dist = [ { name = "flask-migrate", specifier = "~=4.0.7" }, { name = "flask-orjson", specifier = "~=2.0.0" }, { name = "flask-restful", specifier = "~=0.3.10" }, - { name = "flask-restx", specifier = "~=1.3.0" }, + { name = "flask-restx", specifier = "~=1.3.2" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=25.9.1" }, { name = "gmpy2", specifier = "~=2.2.1" }, @@ -1717,11 +1718,12 @@ dev = [ { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, + { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "ruff", specifier = "~=0.14.0" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.13.2" }, - { name = "ty", specifier = "~=0.0.1a19" }, + { name = "ty", specifier = ">=0.0.14" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, { name = "types-cachetools", specifier = "~=5.5.0" }, @@ -1744,7 +1746,7 @@ dev = [ { name = "types-openpyxl", specifier = "~=3.1.5" }, { name = "types-pexpect", specifier = "~=4.9.0" }, { name = "types-protobuf", specifier = "~=5.29.1" }, - { name = "types-psutil", specifier = "~=7.0.0" }, + { name = "types-psutil", specifier = "~=7.2.2" }, { name = "types-psycopg2", specifier = "~=2.9.21" }, { name = "types-pygments", specifier = "~=2.19.0" }, { name = "types-pymysql", specifier = "~=1.1.0" }, @@ -1935,6 +1937,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/d8/2a1c638d9e0aa7e269269a1a1bf423ddd94267f1a01bbe3ad03432b67dd4/eval_type_backport-0.3.0-py3-none-any.whl", hash = "sha256:975a10a0fe333c8b6260d7fdb637698c9a16c3a9e3b6eb943fee6a6f67a37fe8", size = 6061, upload-time = "2025-11-13T20:56:49.499Z" }, ] +[[package]] +name = "execnet" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, +] + [[package]] name = "faker" version = "38.2.0" @@ -5216,6 +5227,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, ] +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, +] + [[package]] name = "python-calamine" version = "0.5.4" @@ -6290,27 +6314,26 @@ wheels = [ [[package]] name = "ty" -version = "0.0.1a27" +version = "0.0.14" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8f/65/3592d7c73d80664378fc90d0a00c33449a99cbf13b984433c883815245f3/ty-0.0.1a27.tar.gz", hash = "sha256:d34fe04979f2c912700cbf0919e8f9b4eeaa10c4a2aff7450e5e4c90f998bc28", size = 4516059, upload-time = "2025-11-18T21:55:18.381Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/57/22c3d6bf95c2229120c49ffc2f0da8d9e8823755a1c3194da56e51f1cc31/ty-0.0.14.tar.gz", hash = "sha256:a691010565f59dd7f15cf324cdcd1d9065e010c77a04f887e1ea070ba34a7de2", size = 5036573, upload-time = "2026-01-27T00:57:31.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/05/7945aa97356446fd53ed3ddc7ee02a88d8ad394217acd9428f472d6b109d/ty-0.0.1a27-py3-none-linux_armv6l.whl", hash = "sha256:3cbb735f5ecb3a7a5f5b82fb24da17912788c109086df4e97d454c8fb236fbc5", size = 9375047, upload-time = "2025-11-18T21:54:31.577Z" }, - { url = "https://files.pythonhosted.org/packages/69/4e/89b167a03de0e9ec329dc89bc02e8694768e4576337ef6c0699987681342/ty-0.0.1a27-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4a6367236dc456ba2416563301d498aef8c6f8959be88777ef7ba5ac1bf15f0b", size = 9169540, upload-time = "2025-11-18T21:54:34.036Z" }, - { url = "https://files.pythonhosted.org/packages/38/07/e62009ab9cc242e1becb2bd992097c80a133fce0d4f055fba6576150d08a/ty-0.0.1a27-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8e93e231a1bcde964cdb062d2d5e549c24493fb1638eecae8fcc42b81e9463a4", size = 8711942, upload-time = "2025-11-18T21:54:36.3Z" }, - { url = "https://files.pythonhosted.org/packages/b5/43/f35716ec15406f13085db52e762a3cc663c651531a8124481d0ba602eca0/ty-0.0.1a27-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5b6a8166b60117da1179851a3d719cc798bf7e61f91b35d76242f0059e9ae1d", size = 8984208, upload-time = "2025-11-18T21:54:39.453Z" }, - { url = "https://files.pythonhosted.org/packages/2d/79/486a3374809523172379768de882c7a369861165802990177fe81489b85f/ty-0.0.1a27-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bfbe8b0e831c072b79a078d6c126d7f4d48ca17f64a103de1b93aeda32265dc5", size = 9157209, upload-time = "2025-11-18T21:54:42.664Z" }, - { url = "https://files.pythonhosted.org/packages/ff/08/9a7c8efcb327197d7d347c548850ef4b54de1c254981b65e8cd0672dc327/ty-0.0.1a27-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90e09678331552e7c25d7eb47868b0910dc5b9b212ae22c8ce71a52d6576ddbb", size = 9519207, upload-time = "2025-11-18T21:54:45.311Z" }, - { url = "https://files.pythonhosted.org/packages/e0/9d/7b4680683e83204b9edec551bb91c21c789ebc586b949c5218157ee474b7/ty-0.0.1a27-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:88c03e4beeca79d85a5618921e44b3a6ea957e0453e08b1cdd418b51da645939", size = 10148794, upload-time = "2025-11-18T21:54:48.329Z" }, - { url = "https://files.pythonhosted.org/packages/89/21/8b961b0ab00c28223f06b33222427a8e31aa04f39d1b236acc93021c626c/ty-0.0.1a27-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ece5811322789fefe22fc088ed36c5879489cd39e913f9c1ff2a7678f089c61", size = 9900563, upload-time = "2025-11-18T21:54:51.214Z" }, - { url = "https://files.pythonhosted.org/packages/85/eb/95e1f0b426c2ea8d443aa923fcab509059c467bbe64a15baaf573fea1203/ty-0.0.1a27-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f2ccb4f0fddcd6e2017c268dfce2489e9a36cb82a5900afe6425835248b1086", size = 9926355, upload-time = "2025-11-18T21:54:53.927Z" }, - { url = "https://files.pythonhosted.org/packages/f5/78/40e7f072049e63c414f2845df780be3a494d92198c87c2ffa65e63aecf3f/ty-0.0.1a27-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33450528312e41d003e96a1647780b2783ab7569bbc29c04fc76f2d1908061e3", size = 9480580, upload-time = "2025-11-18T21:54:56.617Z" }, - { url = "https://files.pythonhosted.org/packages/18/da/f4a2dfedab39096808ddf7475f35ceb750d9a9da840bee4afd47b871742f/ty-0.0.1a27-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a0a9ac635deaa2b15947701197ede40cdecd13f89f19351872d16f9ccd773fa1", size = 8957524, upload-time = "2025-11-18T21:54:59.085Z" }, - { url = "https://files.pythonhosted.org/packages/21/ea/26fee9a20cf77a157316fd3ab9c6db8ad5a0b20b2d38a43f3452622587ac/ty-0.0.1a27-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:797fb2cd49b6b9b3ac9f2f0e401fb02d3aa155badc05a8591d048d38d28f1e0c", size = 9201098, upload-time = "2025-11-18T21:55:01.845Z" }, - { url = "https://files.pythonhosted.org/packages/b0/53/e14591d1275108c9ae28f97ac5d4b93adcc2c8a4b1b9a880dfa9d07c15f8/ty-0.0.1a27-py3-none-musllinux_1_2_i686.whl", hash = "sha256:7fe81679a0941f85e98187d444604e24b15bde0a85874957c945751756314d03", size = 9275470, upload-time = "2025-11-18T21:55:04.23Z" }, - { url = "https://files.pythonhosted.org/packages/37/44/e2c9acecac70bf06fb41de285e7be2433c2c9828f71e3bf0e886fc85c4fd/ty-0.0.1a27-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:355f651d0cdb85535a82bd9f0583f77b28e3fd7bba7b7da33dcee5a576eff28b", size = 9592394, upload-time = "2025-11-18T21:55:06.542Z" }, - { url = "https://files.pythonhosted.org/packages/ee/a7/4636369731b24ed07c2b4c7805b8d990283d677180662c532d82e4ef1a36/ty-0.0.1a27-py3-none-win32.whl", hash = "sha256:61782e5f40e6df622093847b34c366634b75d53f839986f1bf4481672ad6cb55", size = 8783816, upload-time = "2025-11-18T21:55:09.648Z" }, - { url = "https://files.pythonhosted.org/packages/a7/1d/b76487725628d9e81d9047dc0033a5e167e0d10f27893d04de67fe1a9763/ty-0.0.1a27-py3-none-win_amd64.whl", hash = "sha256:c682b238085d3191acddcf66ef22641562946b1bba2a7f316012d5b2a2f4de11", size = 9616833, upload-time = "2025-11-18T21:55:12.457Z" }, - { url = "https://files.pythonhosted.org/packages/3a/db/c7cd5276c8f336a3cf87992b75ba9d486a7cf54e753fcd42495b3bc56fb7/ty-0.0.1a27-py3-none-win_arm64.whl", hash = "sha256:e146dfa32cbb0ac6afb0cb65659e87e4e313715e68d76fe5ae0a4b3d5b912ce8", size = 9137796, upload-time = "2025-11-18T21:55:15.897Z" }, + { url = "https://files.pythonhosted.org/packages/99/cb/cc6d1d8de59beb17a41f9a614585f884ec2d95450306c173b3b7cc090d2e/ty-0.0.14-py3-none-linux_armv6l.whl", hash = "sha256:32cf2a7596e693094621d3ae568d7ee16707dce28c34d1762947874060fdddaa", size = 10034228, upload-time = "2026-01-27T00:57:53.133Z" }, + { url = "https://files.pythonhosted.org/packages/f3/96/dd42816a2075a8f31542296ae687483a8d047f86a6538dfba573223eaf9a/ty-0.0.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f971bf9805f49ce8c0968ad53e29624d80b970b9eb597b7cbaba25d8a18ce9a2", size = 9939162, upload-time = "2026-01-27T00:57:43.857Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b4/73c4859004e0f0a9eead9ecb67021438b2e8e5fdd8d03e7f5aca77623992/ty-0.0.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:45448b9e4806423523268bc15e9208c4f3f2ead7c344f615549d2e2354d6e924", size = 9418661, upload-time = "2026-01-27T00:58:03.411Z" }, + { url = "https://files.pythonhosted.org/packages/58/35/839c4551b94613db4afa20ee555dd4f33bfa7352d5da74c5fa416ffa0fd2/ty-0.0.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94a9b747ff40114085206bdb3205a631ef19a4d3fb89e302a88754cbbae54c", size = 9837872, upload-time = "2026-01-27T00:57:23.718Z" }, + { url = "https://files.pythonhosted.org/packages/41/2b/bbecf7e2faa20c04bebd35fc478668953ca50ee5847ce23e08acf20ea119/ty-0.0.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6756715a3c33182e9ab8ffca2bb314d3c99b9c410b171736e145773ee0ae41c3", size = 9848819, upload-time = "2026-01-27T00:57:58.501Z" }, + { url = "https://files.pythonhosted.org/packages/be/60/3c0ba0f19c0f647ad9d2b5b5ac68c0f0b4dc899001bd53b3a7537fb247a2/ty-0.0.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89d0038a2f698ba8b6fec5cf216a4e44e2f95e4a5095a8c0f57fe549f87087c2", size = 10324371, upload-time = "2026-01-27T00:57:29.291Z" }, + { url = "https://files.pythonhosted.org/packages/24/32/99d0a0b37d0397b0a989ffc2682493286aa3bc252b24004a6714368c2c3d/ty-0.0.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c64a83a2d669b77f50a4957039ca1450626fb474619f18f6f8a3eb885bf7544", size = 10865898, upload-time = "2026-01-27T00:57:33.542Z" }, + { url = "https://files.pythonhosted.org/packages/1a/88/30b583a9e0311bb474269cfa91db53350557ebec09002bfc3fb3fc364e8c/ty-0.0.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:242488bfb547ef080199f6fd81369ab9cb638a778bb161511d091ffd49c12129", size = 10555777, upload-time = "2026-01-27T00:58:05.853Z" }, + { url = "https://files.pythonhosted.org/packages/cd/a2/cb53fb6325dcf3d40f2b1d0457a25d55bfbae633c8e337bde8ec01a190eb/ty-0.0.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4790c3866f6c83a4f424fc7d09ebdb225c1f1131647ba8bdc6fcdc28f09ed0ff", size = 10412913, upload-time = "2026-01-27T00:57:38.834Z" }, + { url = "https://files.pythonhosted.org/packages/42/8f/f2f5202d725ed1e6a4e5ffaa32b190a1fe70c0b1a2503d38515da4130b4c/ty-0.0.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:950f320437f96d4ea9a2332bbfb5b68f1c1acd269ebfa4c09b6970cc1565bd9d", size = 9837608, upload-time = "2026-01-27T00:57:55.898Z" }, + { url = "https://files.pythonhosted.org/packages/f7/ba/59a2a0521640c489dafa2c546ae1f8465f92956fede18660653cce73b4c5/ty-0.0.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4a0ec3ee70d83887f86925bbc1c56f4628bd58a0f47f6f32ddfe04e1f05466df", size = 9884324, upload-time = "2026-01-27T00:57:46.786Z" }, + { url = "https://files.pythonhosted.org/packages/03/95/8d2a49880f47b638743212f011088552ecc454dd7a665ddcbdabea25772a/ty-0.0.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a1a4e6b6da0c58b34415955279eff754d6206b35af56a18bb70eb519d8d139ef", size = 10033537, upload-time = "2026-01-27T00:58:01.149Z" }, + { url = "https://files.pythonhosted.org/packages/e9/40/4523b36f2ce69f92ccf783855a9e0ebbbd0f0bb5cdce6211ee1737159ed3/ty-0.0.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dc04384e874c5de4c5d743369c277c8aa73d1edea3c7fc646b2064b637db4db3", size = 10495910, upload-time = "2026-01-27T00:57:26.691Z" }, + { url = "https://files.pythonhosted.org/packages/08/d5/655beb51224d1bfd4f9ddc0bb209659bfe71ff141bcf05c418ab670698f0/ty-0.0.14-py3-none-win32.whl", hash = "sha256:b20e22cf54c66b3e37e87377635da412d9a552c9bf4ad9fc449fed8b2e19dad2", size = 9507626, upload-time = "2026-01-27T00:57:41.43Z" }, + { url = "https://files.pythonhosted.org/packages/b6/d9/c569c9961760e20e0a4bc008eeb1415754564304fd53997a371b7cf3f864/ty-0.0.14-py3-none-win_amd64.whl", hash = "sha256:e312ff9475522d1a33186657fe74d1ec98e4a13e016d66f5758a452c90ff6409", size = 10437980, upload-time = "2026-01-27T00:57:36.422Z" }, + { url = "https://files.pythonhosted.org/packages/ad/0c/186829654f5bfd9a028f6648e9caeb11271960a61de97484627d24443f91/ty-0.0.14-py3-none-win_arm64.whl", hash = "sha256:b6facdbe9b740cb2c15293a1d178e22ffc600653646452632541d01c36d5e378", size = 9885831, upload-time = "2026-01-27T00:57:49.747Z" }, ] [[package]] @@ -6560,11 +6583,11 @@ wheels = [ [[package]] name = "types-psutil" -version = "7.0.0.20251116" +version = "7.2.2.20260130" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" } +sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" }, + { url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" }, ] [[package]] diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index 496cb4095..a03408330 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -1,10 +1,16 @@ #!/bin/bash -set -x +set -euxo pipefail SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-20}" +PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}" -# libs -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests +# Run most tests in parallel (excluding controllers which have import conflicts with xdist) +# Controller tests have module-level side effects (Flask route registration) that cause +# race conditions when imported concurrently by multiple pytest-xdist workers. +pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} api/tests/unit_tests --ignore=api/tests/unit_tests/controllers + +# Run controller tests sequentially to avoid import race conditions +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests/controllers diff --git a/docker/.env.example b/docker/.env.example index 916b4627d..cbfe6aa5c 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -397,7 +397,7 @@ WEB_API_CORS_ALLOW_ORIGINS=* # Specifies the allowed origins for cross-origin requests to the console API, # e.g. https://cloud.dify.ai or * for all origins. CONSOLE_CORS_ALLOW_ORIGINS=* -# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). Leading dots are optional. +# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site's top-level domain (e.g., `example.com`). Leading dots are optional. COOKIE_DOMAIN= # When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. NEXT_PUBLIC_COOKIE_DOMAIN= @@ -1080,7 +1080,7 @@ ALIYUN_SLS_ENDPOINT= ALIYUN_SLS_REGION= # Aliyun SLS Project Name ALIYUN_SLS_PROJECT_NAME= -# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage) +# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage) ALIYUN_SLS_LOGSTORE_TTL=365 # Enable dual-write to both SLS LogStore and SQL database (default: false) LOGSTORE_DUAL_WRITE_ENABLED=false @@ -1375,6 +1375,7 @@ PLUGIN_DAEMON_PORT=5002 PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi PLUGIN_DAEMON_URL=http://plugin_daemon:5002 PLUGIN_MAX_PACKAGE_SIZE=52428800 +PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 PLUGIN_PPROF_ENABLED=false PLUGIN_DEBUGGING_HOST=0.0.0.0 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 965999038..cb5e2c47f 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.4 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -270,7 +270,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. @@ -662,13 +662,14 @@ services: - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" - "${IRIS_WEB_SERVER_PORT:-52773}:52773" volumes: - - ./volumes/iris:/opt/iris + - ./volumes/iris:/durable - ./iris/iris-init.script:/iris-init.script - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh entrypoint: ["/custom-entrypoint.sh"] tty: true environment: TZ: ${IRIS_TIMEZONE:-UTC} + ISC_DATA_DIRECTORY: /durable/iris # Oracle vector database oracle: diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 81c34fc6a..4a739bbbe 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -123,7 +123,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 902ca3103..1886f848e 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -589,6 +589,7 @@ x-shared-env: &shared-api-worker-env PLUGIN_DAEMON_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} PLUGIN_DAEMON_URL: ${PLUGIN_DAEMON_URL:-http://plugin_daemon:5002} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_MODEL_SCHEMA_CACHE_TTL: ${PLUGIN_MODEL_SCHEMA_CACHE_TTL:-3600} PLUGIN_PPROF_ENABLED: ${PLUGIN_PPROF_ENABLED:-false} PLUGIN_DEBUGGING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0} PLUGIN_DEBUGGING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003} @@ -706,7 +707,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -748,7 +749,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -787,7 +788,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -817,7 +818,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.4 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -955,7 +956,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. @@ -1347,13 +1348,14 @@ services: - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" - "${IRIS_WEB_SERVER_PORT:-52773}:52773" volumes: - - ./volumes/iris:/opt/iris + - ./volumes/iris:/durable - ./iris/iris-init.script:/iris-init.script - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh entrypoint: ["/custom-entrypoint.sh"] tty: true environment: TZ: ${IRIS_TIMEZONE:-UTC} + ISC_DATA_DIRECTORY: /durable/iris # Oracle vector database oracle: diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose index b5c0acefb..bf6c1423c 100755 --- a/docker/generate_docker_compose +++ b/docker/generate_docker_compose @@ -9,7 +9,7 @@ def parse_env_example(file_path): Parses the .env.example file and returns a dictionary with variable names as keys and default values as values. """ env_vars = {} - with open(file_path, "r") as f: + with open(file_path, "r", encoding="utf-8") as f: for line_number, line in enumerate(f, 1): line = line.strip() # Ignore empty lines and comments @@ -55,7 +55,7 @@ def insert_shared_env(template_path, output_path, shared_env_block, header_comme Inserts the shared environment variables block and header comments into the template file, removing any existing x-shared-env anchors, and generates the final docker-compose.yaml file. """ - with open(template_path, "r") as f: + with open(template_path, "r", encoding="utf-8") as f: template_content = f.read() # Remove existing x-shared-env: &shared-api-worker-env lines @@ -69,7 +69,7 @@ def insert_shared_env(template_path, output_path, shared_env_block, header_comme # Prepare the final content with header comments and shared env block final_content = f"{header_comments}\n{shared_env_block}\n\n{template_content}" - with open(output_path, "w") as f: + with open(output_path, "w", encoding="utf-8") as f: f.write(final_content) print(f"Generated {output_path}") diff --git a/docker/iris/docker-entrypoint.sh b/docker/iris/docker-entrypoint.sh index 067bfa03e..1a3b10423 100755 --- a/docker/iris/docker-entrypoint.sh +++ b/docker/iris/docker-entrypoint.sh @@ -1,15 +1,33 @@ #!/bin/bash set -e -# IRIS configuration flag file -IRIS_CONFIG_DONE="/opt/iris/.iris-configured" +# IRIS configuration flag file (stored in durable directory to persist with data) +IRIS_CONFIG_DONE="/durable/.iris-configured" + +# Function to wait for IRIS to be ready +wait_for_iris() { + echo "Waiting for IRIS to be ready..." + local max_attempts=30 + local attempt=1 + while [ "$attempt" -le "$max_attempts" ]; do + if iris qlist IRIS 2>/dev/null | grep -q "running"; then + echo "IRIS is ready." + return 0 + fi + echo "Attempt $attempt/$max_attempts: IRIS not ready yet, waiting..." + sleep 2 + attempt=$((attempt + 1)) + done + echo "ERROR: IRIS failed to start within expected time." >&2 + return 1 +} # Function to configure IRIS configure_iris() { echo "Configuring IRIS for first-time setup..." # Wait for IRIS to be fully started - sleep 5 + wait_for_iris # Execute the initialization script iris session IRIS < /iris-init.script diff --git a/web/AGENTS.md b/web/AGENTS.md index 7362cd51d..5dd41b8a3 100644 --- a/web/AGENTS.md +++ b/web/AGENTS.md @@ -1,5 +1,9 @@ +## Frontend Workflow + +- Refer to the `./docs/test.md` and `./docs/lint.md` for detailed frontend workflow instructions. + ## Automated Test Generation -- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests. +- Use `./docs/test.md` as the canonical instruction set for generating frontend automated tests. - When proposing or saving tests, re-read that document and follow every requirement. - All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance. diff --git a/web/README.md b/web/README.md index 3210cfd21..26ce999b4 100644 --- a/web/README.md +++ b/web/README.md @@ -97,6 +97,8 @@ Open [http://localhost:6006](http://localhost:6006) with your browser to see the If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting. +Then follow the [Lint Documentation](./docs/lint.md) to lint the code. + ## Test We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index 3410ecbe9..dfbac5d74 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -3,7 +3,7 @@ import type { ReactNode } from 'react' import Cookies from 'js-cookie' import { usePathname, useRouter, useSearchParams } from 'next/navigation' -import { parseAsString, useQueryState } from 'nuqs' +import { parseAsBoolean, useQueryState } from 'nuqs' import { useCallback, useEffect, useState } from 'react' import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, @@ -28,7 +28,7 @@ export const AppInitializer = ({ const [init, setInit] = useState(false) const [oauthNewUser, setOauthNewUser] = useQueryState( 'oauth_new_user', - parseAsString.withOptions({ history: 'replace' }), + parseAsBoolean.withOptions({ history: 'replace' }), ) const isSetupFinished = useCallback(async () => { @@ -46,7 +46,7 @@ export const AppInitializer = ({ (async () => { const action = searchParams.get('action') - if (oauthNewUser === 'true') { + if (oauthNewUser) { let utmInfo = null const utmInfoStr = Cookies.get('utm_info') if (utmInfoStr) { diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 255feaccd..aa31f0201 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -31,6 +31,7 @@ import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import AppIcon from '../base/app-icon' import AppOperations from './app-operations' @@ -145,13 +146,8 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx appID: appDetail.id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${appDetail.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 4d7c832e0..96127c421 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -11,6 +11,7 @@ import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/kn import { useInvalid } from '@/service/use-base' import { useExportPipelineDSL } from '@/service/use-pipeline' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import ActionButton from '../../base/action-button' import Confirm from '../../base/confirm' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' @@ -64,13 +65,8 @@ const DropDown = ({ pipelineId: pipeline_id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${name}.pipeline` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${name}.pipeline` }) } catch { Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/app-sidebar/toggle-button.tsx b/web/app/components/app-sidebar/toggle-button.tsx index a6bdee4f7..cbfbeee45 100644 --- a/web/app/components/app-sidebar/toggle-button.tsx +++ b/web/app/components/app-sidebar/toggle-button.tsx @@ -4,7 +4,7 @@ import { useTranslation } from 'react-i18next' import { cn } from '@/utils/classnames' import Button from '../base/button' import Tooltip from '../base/tooltip' -import { getKeyboardKeyNameBySystem } from '../workflow/utils' +import ShortcutsName from '../workflow/shortcuts-name' type TooltipContentProps = { expand: boolean @@ -20,18 +20,7 @@ const TooltipContent = ({ return (
{expand ? t('sidebar.collapseSidebar', { ns: 'layout' }) : t('sidebar.expandSidebar', { ns: 'layout' })} -
- { - TOGGLE_SHORTCUT.map(key => ( - - {getKeyboardKeyNameBySystem(key)} - - )) - } -
+
) } diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 5add1aed3..4fc1e2600 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -21,6 +21,7 @@ import { LanguagesSupported } from '@/i18n-config/language' import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import Button from '../../../base/button' import AddAnnotationModal from '../add-annotation-modal' import BatchAddModal from '../batch-add-annotation-modal' @@ -56,28 +57,23 @@ const HeaderOptions: FC = ({ ) const JSONLOutput = () => { - const a = document.createElement('a') const content = listTransformer(list).join('\n') const file = new Blob([content], { type: 'application/jsonl' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `annotations-${locale}.jsonl` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `annotations-${locale}.jsonl` }) } - const fetchList = async () => { + const fetchList = React.useCallback(async () => { const { data }: any = await fetchExportAnnotationList(appId) setList(data as AnnotationItemBasic[]) - } + }, [appId]) useEffect(() => { fetchList() - }, []) + }, [fetchList]) useEffect(() => { if (controlUpdateList) fetchList() - }, [controlUpdateList]) + }, [controlUpdateList, fetchList]) const [showBulkImportModal, setShowBulkImportModal] = useState(false) const [showClearConfirm, setShowClearConfirm] = useState(false) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 0a026a680..0fc364cb7 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -49,7 +49,8 @@ import Divider from '../../base/divider' import Loading from '../../base/loading' import Toast from '../../base/toast' import Tooltip from '../../base/tooltip' -import { getKeyboardKeyCodeBySystem, getKeyboardKeyNameBySystem } from '../../workflow/utils' +import ShortcutsName from '../../workflow/shortcuts-name' +import { getKeyboardKeyCodeBySystem } from '../../workflow/utils' import AccessControl from '../app-access-control' import PublishWithMultipleModel from './publish-with-multiple-model' import SuggestedAction from './suggested-action' @@ -345,13 +346,7 @@ const AppPublisher = ({ : (
{t('common.publishUpdate', { ns: 'workflow' })} -
- {PUBLISH_SHORTCUT.map(key => ( - - {getKeyboardKeyNameBySystem(key)} - - ))} -
+
) } diff --git a/web/app/components/app/configuration/config-var/index.spec.tsx b/web/app/components/app/configuration/config-var/index.spec.tsx index b5015ed07..490d7b441 100644 --- a/web/app/components/app/configuration/config-var/index.spec.tsx +++ b/web/app/components/app/configuration/config-var/index.spec.tsx @@ -2,7 +2,7 @@ import type { ReactNode } from 'react' import type { IConfigVarProps } from './index' import type { ExternalDataTool } from '@/models/common' import type { PromptVariable } from '@/models/debug' -import { act, fireEvent, render, screen } from '@testing-library/react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { vi } from 'vitest' import Toast from '@/app/components/base/toast' @@ -240,7 +240,9 @@ describe('ConfigVar', () => { const saveButton = await screen.findByRole('button', { name: 'common.operation.save' }) fireEvent.click(saveButton) - expect(onPromptVariablesChange).toHaveBeenCalledTimes(1) + await waitFor(() => { + expect(onPromptVariablesChange).toHaveBeenCalledTimes(1) + }) }) it('should show error when variable key is duplicated', async () => { diff --git a/web/app/components/app/configuration/config/automatic/automatic-btn.spec.tsx b/web/app/components/app/configuration/config/automatic/automatic-btn.spec.tsx new file mode 100644 index 000000000..f027f643a --- /dev/null +++ b/web/app/components/app/configuration/config/automatic/automatic-btn.spec.tsx @@ -0,0 +1,77 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import AutomaticBtn from './automatic-btn' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('AutomaticBtn', () => { + const mockOnClick = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render the button with correct text', () => { + render() + + expect(screen.getByText('operation.automatic')).toBeInTheDocument() + }) + + it('should render the sparkling icon', () => { + const { container } = render() + + // The icon should be an SVG element inside the button + const svg = container.querySelector('svg') + expect(svg).toBeTruthy() + }) + + it('should render as a button element', () => { + render() + + expect(screen.getByRole('button')).toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should call onClick when button is clicked', () => { + render() + + const button = screen.getByRole('button') + fireEvent.click(button) + + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('should call onClick multiple times on multiple clicks', () => { + render() + + const button = screen.getByRole('button') + + fireEvent.click(button) + fireEvent.click(button) + fireEvent.click(button) + + expect(mockOnClick).toHaveBeenCalledTimes(3) + }) + }) + + describe('Styling', () => { + it('should have secondary-accent variant', () => { + render() + + const button = screen.getByRole('button') + expect(button.className).toContain('secondary-accent') + }) + + it('should have small size', () => { + render() + + const button = screen.getByRole('button') + expect(button.className).toContain('small') + }) + }) +}) diff --git a/web/app/components/app/create-app-dialog/app-card/index.tsx b/web/app/components/app/create-app-dialog/app-card/index.tsx index 15cfbd541..e203edfc8 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.tsx +++ b/web/app/components/app/create-app-dialog/app-card/index.tsx @@ -62,19 +62,19 @@ const AppCard = ({ {app.description} - {canCreate && ( + {(canCreate || isTrialApp) && ( )} diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index cb8f4db67..d26a581fd 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -124,7 +124,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({ name: 'My App', @@ -152,7 +152,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalled()) expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' }) diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index e2b50cf03..66c7bce80 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { AppIconSelection } from '../../base/app-icon-picker' -import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLine, RiExchange2Fill } from '@remixicon/react' +import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon/react' import { useDebounceFn, useKeyPress } from 'ahooks' import Image from 'next/image' @@ -29,6 +29,7 @@ import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' import { basePath } from '@/utils/var' import AppIconPicker from '../../base/app-icon-picker' +import ShortcutsName from '../../workflow/shortcuts-name' type CreateAppProps = { onSuccess: () => void @@ -269,10 +270,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 838e9cc03..04d8b1e75 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { MouseEventHandler } from 'react' -import { RiCloseLine, RiCommandLine, RiCornerDownLeftLine } from '@remixicon/react' +import { RiCloseLine } from '@remixicon/react' import { useDebounceFn, useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' import { useRouter } from 'next/navigation' @@ -28,6 +28,7 @@ import { } from '@/service/apps' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import ShortcutsName from '../../workflow/shortcuts-name' import Uploader from './uploader' type CreateFromDSLModalProps = { @@ -298,10 +299,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS className="gap-1" > {t('newApp.Create', { ns: 'app' })} -
- - -
+ diff --git a/web/app/components/app/log/empty-element.spec.tsx b/web/app/components/app/log/empty-element.spec.tsx new file mode 100644 index 000000000..71d2bd0dd --- /dev/null +++ b/web/app/components/app/log/empty-element.spec.tsx @@ -0,0 +1,134 @@ +import type { App } from '@/types/app' +import { render, screen } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import EmptyElement from './empty-element' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), + Trans: ({ i18nKey, components }: { i18nKey: string, components: Record }) => ( + + {i18nKey} + {components.shareLink} + {components.testLink} + + ), +})) + +vi.mock('@/utils/app-redirection', () => ({ + getRedirectionPath: (isTest: boolean, _app: App) => isTest ? '/test-path' : '/prod-path', +})) + +vi.mock('@/utils/var', () => ({ + basePath: '/base', +})) + +describe('EmptyElement', () => { + const createMockAppDetail = (mode: AppModeEnum) => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test description', + mode, + icon_type: 'emoji', + icon: 'test-icon', + icon_background: '#ffffff', + enable_site: true, + enable_api: true, + created_at: Date.now(), + site: { + access_token: 'test-token', + app_base_url: 'https://app.example.com', + }, + }) as unknown as App + + describe('Rendering', () => { + it('should render empty element with title', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + expect(screen.getByText('table.empty.element.title')).toBeInTheDocument() + }) + + it('should render Trans component with i18n key', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const transComponent = screen.getByTestId('trans-component') + expect(transComponent).toHaveAttribute('data-i18n-key', 'table.empty.element.content') + }) + + it('should render ThreeDotsIcon SVG', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + const { container } = render() + + const svg = container.querySelector('svg') + expect(svg).toBeInTheDocument() + }) + }) + + describe('App Mode Handling', () => { + it('should use CHAT mode for chat apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token') + }) + + it('should use COMPLETION mode for completion apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.COMPLETION) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/completion/test-token') + }) + + it('should use WORKFLOW mode for workflow apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.WORKFLOW) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/workflow/test-token') + }) + + it('should use CHAT mode for advanced-chat apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.ADVANCED_CHAT) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token') + }) + + it('should use CHAT mode for agent-chat apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.AGENT_CHAT) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token') + }) + }) + + describe('Links', () => { + it('should render share link with correct attributes', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const links = screen.getAllByRole('link') + const shareLink = links[0] + + expect(shareLink).toHaveAttribute('target', '_blank') + expect(shareLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should render test link with redirection path', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const links = screen.getAllByRole('link') + const testLink = links[1] + + expect(testLink).toHaveAttribute('href', '/test-path') + }) + }) +}) diff --git a/web/app/components/app/log/filter.spec.tsx b/web/app/components/app/log/filter.spec.tsx new file mode 100644 index 000000000..8e978cdf9 --- /dev/null +++ b/web/app/components/app/log/filter.spec.tsx @@ -0,0 +1,210 @@ +import type { QueryParam } from './index' +import { fireEvent, render, screen } from '@testing-library/react' +import Filter, { TIME_PERIOD_MAPPING } from './filter' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { count?: number }) => { + if (options?.count !== undefined) + return `${key} (${options.count})` + return key + }, + }), +})) + +vi.mock('@/service/use-log', () => ({ + useAnnotationsCount: () => ({ + data: { count: 10 }, + isLoading: false, + }), +})) + +describe('Filter', () => { + const defaultQueryParams: QueryParam = { + period: '9', + annotation_status: 'all', + keyword: '', + } + + const mockSetQueryParams = vi.fn() + const defaultProps = { + appId: 'test-app-id', + queryParams: defaultQueryParams, + setQueryParams: mockSetQueryParams, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render filter components', () => { + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + + it('should return null when loading', () => { + // This test verifies the component renders correctly with the mocked data + const { container } = render() + expect(container.firstChild).not.toBeNull() + }) + + it('should render sort component in chat mode', () => { + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + + it('should not render sort component when not in chat mode', () => { + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + }) + + describe('TIME_PERIOD_MAPPING', () => { + it('should have correct period keys', () => { + expect(Object.keys(TIME_PERIOD_MAPPING)).toEqual(['1', '2', '3', '4', '5', '6', '7', '8', '9']) + }) + + it('should have today period with value 0', () => { + expect(TIME_PERIOD_MAPPING['1'].value).toBe(0) + expect(TIME_PERIOD_MAPPING['1'].name).toBe('today') + }) + + it('should have last7days period with value 7', () => { + expect(TIME_PERIOD_MAPPING['2'].value).toBe(7) + expect(TIME_PERIOD_MAPPING['2'].name).toBe('last7days') + }) + + it('should have last4weeks period with value 28', () => { + expect(TIME_PERIOD_MAPPING['3'].value).toBe(28) + expect(TIME_PERIOD_MAPPING['3'].name).toBe('last4weeks') + }) + + it('should have allTime period with value -1', () => { + expect(TIME_PERIOD_MAPPING['9'].value).toBe(-1) + expect(TIME_PERIOD_MAPPING['9'].name).toBe('allTime') + }) + }) + + describe('User Interactions', () => { + it('should update keyword when typing in search input', () => { + render() + + const searchInput = screen.getByPlaceholderText('operation.search') + fireEvent.change(searchInput, { target: { value: 'test search' } }) + + expect(mockSetQueryParams).toHaveBeenCalledWith({ + ...defaultQueryParams, + keyword: 'test search', + }) + }) + + it('should clear keyword when clear button is clicked', () => { + const propsWithKeyword = { + ...defaultProps, + queryParams: { ...defaultQueryParams, keyword: 'existing search' }, + } + + render() + + const clearButton = screen.getByTestId('input-clear') + fireEvent.click(clearButton) + + expect(mockSetQueryParams).toHaveBeenCalledWith({ + ...defaultQueryParams, + keyword: '', + }) + }) + }) + + describe('Query Params', () => { + it('should display "today" when period is set to 1', () => { + const propsWithPeriod = { + ...defaultProps, + queryParams: { ...defaultQueryParams, period: '1' }, + } + + render() + + // Period '1' maps to 'today' in TIME_PERIOD_MAPPING + expect(screen.getByText('filter.period.today')).toBeInTheDocument() + }) + + it('should display "last7days" when period is set to 2', () => { + const propsWithPeriod = { + ...defaultProps, + queryParams: { ...defaultQueryParams, period: '2' }, + } + + render() + + expect(screen.getByText('filter.period.last7days')).toBeInTheDocument() + }) + + it('should display "allTime" when period is set to 9', () => { + render() + + // Default period is '9' which maps to 'allTime' + expect(screen.getByText('filter.period.allTime')).toBeInTheDocument() + }) + + it('should display annotated status with count when annotation_status is annotated', () => { + const propsWithAnnotation = { + ...defaultProps, + queryParams: { ...defaultQueryParams, annotation_status: 'annotated' }, + } + + render() + + // The mock returns count: 10, so the text should include the count + expect(screen.getByText('filter.annotation.annotated (10)')).toBeInTheDocument() + }) + + it('should display not_annotated status when annotation_status is not_annotated', () => { + const propsWithNotAnnotated = { + ...defaultProps, + queryParams: { ...defaultQueryParams, annotation_status: 'not_annotated' }, + } + + render() + + expect(screen.getByText('filter.annotation.not_annotated')).toBeInTheDocument() + }) + + it('should display all annotation status when annotation_status is all', () => { + render() + + // Default annotation_status is 'all' + expect(screen.getByText('filter.annotation.all')).toBeInTheDocument() + }) + }) + + describe('Chat Mode', () => { + it('should display sort component with sort_by parameter', () => { + const propsWithSort = { + ...defaultProps, + isChatMode: true, + queryParams: { ...defaultQueryParams, sort_by: 'created_at' }, + } + + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + + it('should handle descending sort order', () => { + const propsWithDescSort = { + ...defaultProps, + isChatMode: true, + queryParams: { ...defaultQueryParams, sort_by: '-created_at' }, + } + + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/log/model-info.spec.tsx b/web/app/components/app/log/model-info.spec.tsx new file mode 100644 index 000000000..c8263c236 --- /dev/null +++ b/web/app/components/app/log/model-info.spec.tsx @@ -0,0 +1,221 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import ModelInfo from './model-info' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useTextGenerationCurrentProviderAndModelAndModelList: () => ({ + currentModel: { + model: 'gpt-4', + model_display_name: 'GPT-4', + }, + currentProvider: { + provider: 'openai', + label: 'OpenAI', + }, + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-icon', () => ({ + default: ({ modelName }: { provider: unknown, modelName: string }) => ( +
ModelIcon
+ ), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-name', () => ({ + default: ({ modelItem, showMode }: { modelItem: { model: string }, showMode: boolean }) => ( +
+ {modelItem?.model} +
+ ), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +describe('ModelInfo', () => { + const defaultModel = { + name: 'gpt-4', + provider: 'openai', + completion_params: { + temperature: 0.7, + top_p: 0.9, + presence_penalty: 0.1, + max_tokens: 2048, + stop: ['END'], + }, + } + + describe('Rendering', () => { + it('should render model icon', () => { + render() + + expect(screen.getByTestId('model-icon')).toBeInTheDocument() + }) + + it('should render model name', () => { + render() + + expect(screen.getByTestId('model-name')).toBeInTheDocument() + expect(screen.getByTestId('model-name')).toHaveTextContent('gpt-4') + }) + + it('should render info icon button', () => { + const { container } = render() + + // The info button should contain an SVG icon + const svgs = container.querySelectorAll('svg') + expect(svgs.length).toBeGreaterThan(0) + }) + + it('should show model name with showMode prop', () => { + render() + + expect(screen.getByTestId('model-name')).toHaveAttribute('data-show-mode', 'true') + }) + }) + + describe('Info Panel Toggle', () => { + it('should be closed by default', () => { + render() + + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + }) + + it('should open when info button is clicked', () => { + render() + + const trigger = screen.getByTestId('portal-trigger') + fireEvent.click(trigger) + + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true') + }) + + it('should close when info button is clicked again', () => { + render() + + const trigger = screen.getByTestId('portal-trigger') + + // Open + fireEvent.click(trigger) + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true') + + // Close + fireEvent.click(trigger) + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + }) + }) + + describe('Model Parameters Display', () => { + it('should render model params header', () => { + render() + + expect(screen.getByText('detail.modelParams')).toBeInTheDocument() + }) + + it('should render temperature parameter', () => { + render() + + expect(screen.getByText('Temperature')).toBeInTheDocument() + expect(screen.getByText('0.7')).toBeInTheDocument() + }) + + it('should render top_p parameter', () => { + render() + + expect(screen.getByText('Top P')).toBeInTheDocument() + expect(screen.getByText('0.9')).toBeInTheDocument() + }) + + it('should render presence_penalty parameter', () => { + render() + + expect(screen.getByText('Presence Penalty')).toBeInTheDocument() + expect(screen.getByText('0.1')).toBeInTheDocument() + }) + + it('should render max_tokens parameter', () => { + render() + + expect(screen.getByText('Max Token')).toBeInTheDocument() + expect(screen.getByText('2048')).toBeInTheDocument() + }) + + it('should render stop parameter as comma-separated values', () => { + render() + + expect(screen.getByText('Stop')).toBeInTheDocument() + expect(screen.getByText('END')).toBeInTheDocument() + }) + }) + + describe('Missing Parameters', () => { + it('should show dash for missing parameters', () => { + const modelWithNoParams = { + name: 'gpt-4', + provider: 'openai', + completion_params: {}, + } + + render() + + const dashes = screen.getAllByText('-') + expect(dashes.length).toBeGreaterThan(0) + }) + + it('should show dash for non-array stop values', () => { + const modelWithInvalidStop = { + name: 'gpt-4', + provider: 'openai', + completion_params: { + stop: 'not-an-array', + }, + } + + render() + + const stopValues = screen.getAllByText('-') + expect(stopValues.length).toBeGreaterThan(0) + }) + + it('should join array stop values with comma', () => { + const modelWithMultipleStops = { + name: 'gpt-4', + provider: 'openai', + completion_params: { + stop: ['END', 'STOP', 'DONE'], + }, + } + + render() + + expect(screen.getByText('END,STOP,DONE')).toBeInTheDocument() + }) + }) + + describe('Model without completion_params', () => { + it('should handle undefined completion_params', () => { + const modelWithNoCompletionParams = { + name: 'gpt-4', + provider: 'openai', + } + + render() + + expect(screen.getByTestId('model-icon')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/log/var-panel.spec.tsx b/web/app/components/app/log/var-panel.spec.tsx new file mode 100644 index 000000000..eff186e5b --- /dev/null +++ b/web/app/components/app/log/var-panel.spec.tsx @@ -0,0 +1,217 @@ +import { act, fireEvent, render, screen } from '@testing-library/react' +import VarPanel from './var-panel' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/app/components/base/image-uploader/image-preview', () => ({ + default: ({ url, title, onCancel }: { url: string, title: string, onCancel: () => void }) => ( +
+ +
+ ), +})) + +describe('VarPanel', () => { + const defaultProps = { + varList: [ + { label: 'name', value: 'John Doe' }, + { label: 'age', value: '25' }, + ], + message_files: [], + } + + describe('Rendering', () => { + it('should render variables section header', () => { + render() + + expect(screen.getByText('detail.variables')).toBeInTheDocument() + }) + + it('should render variable labels with braces', () => { + render() + + expect(screen.getByText('name')).toBeInTheDocument() + expect(screen.getByText('age')).toBeInTheDocument() + }) + + it('should render variable values', () => { + render() + + expect(screen.getByText('John Doe')).toBeInTheDocument() + expect(screen.getByText('25')).toBeInTheDocument() + }) + + it('should render opening and closing braces', () => { + render() + + const openingBraces = screen.getAllByText('{{') + const closingBraces = screen.getAllByText('}}') + + expect(openingBraces.length).toBe(2) + expect(closingBraces.length).toBe(2) + }) + + it('should render Variable02 icon', () => { + const { container } = render() + + const svg = container.querySelector('svg') + expect(svg).toBeInTheDocument() + }) + }) + + describe('Collapse/Expand', () => { + it('should show expanded state by default', () => { + render() + + expect(screen.getByText('John Doe')).toBeInTheDocument() + expect(screen.getByText('25')).toBeInTheDocument() + }) + + it('should collapse when header is clicked', () => { + render() + + const header = screen.getByText('detail.variables').closest('div') + fireEvent.click(header!) + + expect(screen.queryByText('John Doe')).not.toBeInTheDocument() + expect(screen.queryByText('25')).not.toBeInTheDocument() + }) + + it('should expand when clicked again', () => { + render() + + const header = screen.getByText('detail.variables').closest('div') + + // Collapse + fireEvent.click(header!) + expect(screen.queryByText('John Doe')).not.toBeInTheDocument() + + // Expand + fireEvent.click(header!) + expect(screen.getByText('John Doe')).toBeInTheDocument() + }) + + it('should show arrow icon when collapsed', () => { + const { container } = render() + + const header = screen.getByText('detail.variables').closest('div') + fireEvent.click(header!) + + // When collapsed, there should be SVG icons in the component + const svgs = container.querySelectorAll('svg') + expect(svgs.length).toBeGreaterThan(0) + }) + + it('should show arrow icon when expanded', () => { + const { container } = render() + + // When expanded, there should be SVG icons in the component + const svgs = container.querySelectorAll('svg') + expect(svgs.length).toBeGreaterThan(0) + }) + }) + + describe('Message Files', () => { + it('should not render images section when message_files is empty', () => { + render() + + expect(screen.queryByText('detail.uploadImages')).not.toBeInTheDocument() + }) + + it('should render images section when message_files has items', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg', 'https://example.com/image2.jpg'], + } + + render() + + expect(screen.getByText('detail.uploadImages')).toBeInTheDocument() + }) + + it('should render image thumbnails with correct background', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg'], + } + + const { container } = render() + + const thumbnail = container.querySelector('[style*="background-image"]') + expect(thumbnail).toBeInTheDocument() + expect(thumbnail).toHaveStyle({ backgroundImage: 'url(https://example.com/image1.jpg)' }) + }) + + it('should open image preview when thumbnail is clicked', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg'], + } + + const { container } = render() + + const thumbnail = container.querySelector('[style*="background-image"]') + fireEvent.click(thumbnail!) + + expect(screen.getByTestId('image-preview')).toBeInTheDocument() + expect(screen.getByTestId('image-preview')).toHaveAttribute('data-url', 'https://example.com/image1.jpg') + }) + + it('should close image preview when close button is clicked', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg'], + } + + const { container } = render() + + // Open preview + const thumbnail = container.querySelector('[style*="background-image"]') + fireEvent.click(thumbnail!) + + expect(screen.getByTestId('image-preview')).toBeInTheDocument() + + // Close preview + act(() => { + fireEvent.click(screen.getByTestId('close-preview')) + }) + + expect(screen.queryByTestId('image-preview')).not.toBeInTheDocument() + }) + }) + + describe('Empty State', () => { + it('should render with empty varList', () => { + const emptyProps = { + varList: [], + message_files: [], + } + + render() + + expect(screen.getByText('detail.variables')).toBeInTheDocument() + }) + }) + + describe('Multiple Images', () => { + it('should render multiple image thumbnails', () => { + const propsWithMultipleFiles = { + ...defaultProps, + message_files: [ + 'https://example.com/image1.jpg', + 'https://example.com/image2.jpg', + 'https://example.com/image3.jpg', + ], + } + + const { container } = render() + + const thumbnails = container.querySelectorAll('[style*="background-image"]') + expect(thumbnails.length).toBe(3) + }) + }) +}) diff --git a/web/app/components/app/overview/trigger-card.spec.tsx b/web/app/components/app/overview/trigger-card.spec.tsx new file mode 100644 index 000000000..0ee9da582 --- /dev/null +++ b/web/app/components/app/overview/trigger-card.spec.tsx @@ -0,0 +1,390 @@ +import type { AppDetailResponse } from '@/models/app' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import TriggerCard from './trigger-card' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { count?: number }) => { + if (options?.count !== undefined) + return `${key} (${options.count})` + return key + }, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: true, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +const mockSetTriggerStatus = vi.fn() +const mockSetTriggerStatuses = vi.fn() +vi.mock('@/app/components/workflow/store/trigger-status', () => ({ + useTriggerStatusStore: () => ({ + setTriggerStatus: mockSetTriggerStatus, + setTriggerStatuses: mockSetTriggerStatuses, + }), +})) + +const mockUpdateTriggerStatus = vi.fn() +const mockInvalidateAppTriggers = vi.fn() +let mockTriggers: Array<{ + id: string + node_id: string + title: string + trigger_type: string + status: string + provider_name?: string +}> = [] +let mockIsLoading = false + +vi.mock('@/service/use-tools', () => ({ + useAppTriggers: () => ({ + data: { data: mockTriggers }, + isLoading: mockIsLoading, + }), + useUpdateTriggerStatus: () => ({ + mutateAsync: mockUpdateTriggerStatus, + }), + useInvalidateAppTriggers: () => mockInvalidateAppTriggers, +})) + +vi.mock('@/service/use-triggers', () => ({ + useAllTriggerPlugins: () => ({ + data: [ + { id: 'plugin-1', name: 'Test Plugin', icon: 'test-icon' }, + ], + }), +})) + +vi.mock('@/utils', () => ({ + canFindTool: () => false, +})) + +vi.mock('@/app/components/workflow/block-icon', () => ({ + default: ({ type }: { type: string }) => ( +
BlockIcon
+ ), +})) + +vi.mock('@/app/components/base/switch', () => ({ + default: ({ defaultValue, onChange, disabled }: { defaultValue: boolean, onChange: (v: boolean) => void, disabled: boolean }) => ( + + ), +})) + +describe('TriggerCard', () => { + const mockAppInfo = { + id: 'test-app-id', + name: 'Test App', + description: 'Test description', + mode: AppModeEnum.WORKFLOW, + icon_type: 'emoji', + icon: 'test-icon', + icon_background: '#ffffff', + created_at: Date.now(), + updated_at: Date.now(), + enable_site: true, + enable_api: true, + } as AppDetailResponse + + const mockOnToggleResult = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockTriggers = [] + mockIsLoading = false + mockUpdateTriggerStatus.mockResolvedValue({}) + }) + + describe('Loading State', () => { + it('should render loading skeleton when isLoading is true', () => { + mockIsLoading = true + + const { container } = render( + , + ) + + expect(container.querySelector('.animate-pulse')).toBeInTheDocument() + }) + }) + + describe('Empty State', () => { + it('should show no triggers added message when triggers is empty', () => { + mockTriggers = [] + + render() + + expect(screen.getByText('overview.triggerInfo.noTriggerAdded')).toBeInTheDocument() + }) + + it('should show trigger status description when no triggers', () => { + mockTriggers = [] + + render() + + expect(screen.getByText('overview.triggerInfo.triggerStatusDescription')).toBeInTheDocument() + }) + + it('should show learn more link when no triggers', () => { + mockTriggers = [] + + render() + + const learnMoreLink = screen.getByText('overview.triggerInfo.learnAboutTriggers') + expect(learnMoreLink).toBeInTheDocument() + expect(learnMoreLink).toHaveAttribute('href', 'https://docs.example.com/use-dify/nodes/trigger/overview') + }) + }) + + describe('With Triggers', () => { + beforeEach(() => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Webhook Trigger', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + { + id: 'trigger-2', + node_id: 'node-2', + title: 'Schedule Trigger', + trigger_type: 'trigger-schedule', + status: 'disabled', + }, + ] + }) + + it('should show triggers count message', () => { + render() + + expect(screen.getByText('overview.triggerInfo.triggersAdded (2)')).toBeInTheDocument() + }) + + it('should render trigger titles', () => { + render() + + expect(screen.getByText('Webhook Trigger')).toBeInTheDocument() + expect(screen.getByText('Schedule Trigger')).toBeInTheDocument() + }) + + it('should show running status for enabled triggers', () => { + render() + + expect(screen.getByText('overview.status.running')).toBeInTheDocument() + }) + + it('should show disable status for disabled triggers', () => { + render() + + expect(screen.getByText('overview.status.disable')).toBeInTheDocument() + }) + + it('should render block icons for each trigger', () => { + render() + + const blockIcons = screen.getAllByTestId('block-icon') + expect(blockIcons.length).toBe(2) + }) + + it('should render switches for each trigger', () => { + render() + + const switches = screen.getAllByTestId('switch') + expect(switches.length).toBe(2) + }) + }) + + describe('Toggle Trigger', () => { + beforeEach(() => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Test Trigger', + trigger_type: 'trigger-webhook', + status: 'disabled', + }, + ] + }) + + it('should call updateTriggerStatus when toggle is clicked', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockUpdateTriggerStatus).toHaveBeenCalledWith({ + appId: 'test-app-id', + triggerId: 'trigger-1', + enableTrigger: true, + }) + }) + }) + + it('should update trigger status in store optimistically', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockSetTriggerStatus).toHaveBeenCalledWith('node-1', 'enabled') + }) + }) + + it('should invalidate app triggers after successful update', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockInvalidateAppTriggers).toHaveBeenCalledWith('test-app-id') + }) + }) + + it('should call onToggleResult with null on success', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockOnToggleResult).toHaveBeenCalledWith(null) + }) + }) + + it('should rollback status and call onToggleResult with error on failure', async () => { + const error = new Error('Update failed') + mockUpdateTriggerStatus.mockRejectedValueOnce(error) + + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockSetTriggerStatus).toHaveBeenCalledWith('node-1', 'disabled') + expect(mockOnToggleResult).toHaveBeenCalledWith(error) + }) + }) + }) + + describe('Trigger Types', () => { + it('should render webhook trigger type correctly', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Webhook', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + ] + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', 'trigger-webhook') + }) + + it('should render schedule trigger type correctly', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Schedule', + trigger_type: 'trigger-schedule', + status: 'enabled', + }, + ] + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', 'trigger-schedule') + }) + + it('should render plugin trigger type correctly', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Plugin', + trigger_type: 'trigger-plugin', + status: 'enabled', + provider_name: 'plugin-1', + }, + ] + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', 'trigger-plugin') + }) + }) + + describe('Editor Permissions', () => { + it('should render switches for triggers', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Test Trigger', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + ] + + render() + + const switchBtn = screen.getByTestId('switch') + expect(switchBtn).toBeInTheDocument() + }) + }) + + describe('Status Sync', () => { + it('should sync trigger statuses to store when data loads', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Test', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + { + id: 'trigger-2', + node_id: 'node-2', + title: 'Test 2', + trigger_type: 'trigger-schedule', + status: 'disabled', + }, + ] + + render() + + expect(mockSetTriggerStatuses).toHaveBeenCalledWith({ + 'node-1': 'enabled', + 'node-2': 'disabled', + }) + }) + }) +}) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index beb3c0699..57330dda6 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -35,6 +35,7 @@ import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import { formatTime } from '@/utils/time' import { basePath } from '@/utils/var' @@ -172,13 +173,8 @@ const AppCard = ({ app, onRefresh, onApp = false }: AppCardProps) => { appID: app.id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${app.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${app.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) @@ -436,7 +432,7 @@ const AppCard = ({ app, onRefresh, onApp = false }: AppCardProps) => { dateFormat: `${t('segment.dateTimeFormat', { ns: 'datasetDocuments' })}`, }) return `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${timeText}` - }, [app.updated_at, app.created_at]) + }, [app.updated_at, app.created_at, t]) return ( <> diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index 255bfbf9c..3be849248 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -105,6 +105,7 @@ const Apps = () => { {isShowTryAppPanel && ( { e.stopPropagation() - downloadFile(url || base64Url || '', name) + downloadUrl({ url: url || base64Url || '', fileName: name, target: '_blank' }) }} > diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx index 77dc3e35b..d9118aac4 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx @@ -8,9 +8,9 @@ import Button from '@/app/components/base/button' import { ReplayLine } from '@/app/components/base/icons/src/vender/other' import ImagePreview from '@/app/components/base/image-uploader/image-preview' import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { downloadUrl } from '@/utils/download' import FileImageRender from '../file-image-render' import { - downloadFile, fileIsUploaded, } from '../utils' @@ -85,7 +85,7 @@ const FileImageItem = ({ className="absolute bottom-0.5 right-0.5 flex h-6 w-6 items-center justify-center rounded-lg bg-components-actionbar-bg shadow-md" onClick={(e) => { e.stopPropagation() - downloadFile(download_url || '', name) + downloadUrl({ url: download_url || '', fileName: name, target: '_blank' }) }} > diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx index 828864239..af32f917b 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx @@ -12,10 +12,10 @@ import VideoPreview from '@/app/components/base/file-uploader/video-preview' import { ReplayLine } from '@/app/components/base/icons/src/vender/other' import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' import { cn } from '@/utils/classnames' +import { downloadUrl } from '@/utils/download' import { formatFileSize } from '@/utils/format' import FileTypeIcon from '../file-type-icon' import { - downloadFile, fileIsUploaded, getFileAppearanceType, getFileExtension, @@ -100,7 +100,7 @@ const FileItem = ({ className="absolute -right-1 -top-1 hidden group-hover/file-item:flex" onClick={(e) => { e.stopPropagation() - downloadFile(download_url || '', name) + downloadUrl({ url: download_url || '', fileName: name, target: '_blank' }) }} > diff --git a/web/app/components/base/file-uploader/utils.spec.ts b/web/app/components/base/file-uploader/utils.spec.ts index de167a8c2..f69b3c27f 100644 --- a/web/app/components/base/file-uploader/utils.spec.ts +++ b/web/app/components/base/file-uploader/utils.spec.ts @@ -1,4 +1,3 @@ -import type { MockInstance } from 'vitest' import mime from 'mime' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { upload } from '@/service/base' @@ -6,7 +5,6 @@ import { TransferMethod } from '@/types/app' import { FILE_EXTS } from '../prompt-editor/constants' import { FileAppearanceTypeEnum } from './types' import { - downloadFile, fileIsUploaded, fileUpload, getFileAppearanceType, @@ -782,74 +780,4 @@ describe('file-uploader utils', () => { } as any)).toBe(true) }) }) - - describe('downloadFile', () => { - let mockAnchor: HTMLAnchorElement - let createElementMock: MockInstance - let appendChildMock: MockInstance - let removeChildMock: MockInstance - - beforeEach(() => { - // Mock createElement and appendChild - mockAnchor = { - href: '', - download: '', - style: { display: '' }, - target: '', - title: '', - click: vi.fn(), - } as unknown as HTMLAnchorElement - - createElementMock = vi.spyOn(document, 'createElement').mockReturnValue(mockAnchor as any) - appendChildMock = vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => { - return node - }) - removeChildMock = vi.spyOn(document.body, 'removeChild').mockImplementation((node: Node) => { - return node - }) - }) - - afterEach(() => { - vi.resetAllMocks() - }) - - it('should create and trigger download with correct attributes', () => { - const url = 'https://example.com/test.pdf' - const filename = 'test.pdf' - - downloadFile(url, filename) - - // Verify anchor element was created with correct properties - expect(createElementMock).toHaveBeenCalledWith('a') - expect(mockAnchor.href).toBe(url) - expect(mockAnchor.download).toBe(filename) - expect(mockAnchor.style.display).toBe('none') - expect(mockAnchor.target).toBe('_blank') - expect(mockAnchor.title).toBe(filename) - - // Verify DOM operations - expect(appendChildMock).toHaveBeenCalledWith(mockAnchor) - expect(mockAnchor.click).toHaveBeenCalled() - expect(removeChildMock).toHaveBeenCalledWith(mockAnchor) - }) - - it('should handle empty filename', () => { - const url = 'https://example.com/test.pdf' - const filename = '' - - downloadFile(url, filename) - - expect(mockAnchor.download).toBe('') - expect(mockAnchor.title).toBe('') - }) - - it('should handle empty url', () => { - const url = '' - const filename = 'test.pdf' - - downloadFile(url, filename) - - expect(mockAnchor.href).toBe('') - }) - }) }) diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index 5d5754b8f..23e460db5 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -249,15 +249,3 @@ export const fileIsUploaded = (file: FileEntity) => { if (file.transferMethod === TransferMethod.remote_url && file.progress === 100) return true } - -export const downloadFile = (url: string, filename: string) => { - const anchor = document.createElement('a') - anchor.href = url - anchor.download = filename - anchor.style.display = 'none' - anchor.target = '_blank' - anchor.title = filename - document.body.appendChild(anchor) - anchor.click() - document.body.removeChild(anchor) -} diff --git a/web/app/components/base/image-uploader/image-preview.tsx b/web/app/components/base/image-uploader/image-preview.tsx index b6a07c60a..0641af3d7 100644 --- a/web/app/components/base/image-uploader/image-preview.tsx +++ b/web/app/components/base/image-uploader/image-preview.tsx @@ -8,6 +8,7 @@ import { createPortal } from 'react-dom' import { useHotkeys } from 'react-hotkeys-hook' import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { downloadUrl } from '@/utils/download' type ImagePreviewProps = { url: string @@ -60,27 +61,14 @@ const ImagePreview: FC = ({ const downloadImage = () => { // Open in a new window, considering the case when the page is inside an iframe - if (url.startsWith('http') || url.startsWith('https')) { - const a = document.createElement('a') - a.href = url - a.target = '_blank' - a.download = title - a.click() - } - else if (url.startsWith('data:image')) { - // Base64 image - const a = document.createElement('a') - a.href = url - a.target = '_blank' - a.download = title - a.click() - } - else { - Toast.notify({ - type: 'error', - message: `Unable to open image: ${url}`, - }) + if (url.startsWith('http') || url.startsWith('https') || url.startsWith('data:image')) { + downloadUrl({ url, fileName: title, target: '_blank' }) + return } + Toast.notify({ + type: 'error', + message: `Unable to open image: ${url}`, + }) } const zoomIn = () => { @@ -135,12 +123,7 @@ const ImagePreview: FC = ({ catch (err) { console.error('Failed to copy image:', err) - const link = document.createElement('a') - link.href = url - link.download = `${title}.png` - document.body.appendChild(link) - link.click() - document.body.removeChild(link) + downloadUrl({ url, fileName: `${title}.png` }) Toast.notify({ type: 'info', @@ -215,6 +198,7 @@ const ImagePreview: FC = ({ tabIndex={-1} > { } + {/* eslint-disable-next-line next/no-img-element */} {title} { }, [isShow]) const downloadQR = () => { - const canvas = document.getElementsByTagName('canvas')[0] - const link = document.createElement('a') - link.download = 'qrcode.png' - link.href = canvas.toDataURL() - link.click() + const canvas = qrCodeRef.current?.querySelector('canvas') + if (!(canvas instanceof HTMLCanvasElement)) + return + downloadUrl({ url: canvas.toDataURL(), fileName: 'qrcode.png' }) } const handlePanelClick = (event: React.MouseEvent) => { diff --git a/web/app/components/billing/annotation-full/usage.spec.tsx b/web/app/components/billing/annotation-full/usage.spec.tsx new file mode 100644 index 000000000..c5fd1a2b1 --- /dev/null +++ b/web/app/components/billing/annotation-full/usage.spec.tsx @@ -0,0 +1,57 @@ +import { render, screen } from '@testing-library/react' +import Usage from './usage' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +const mockPlan = { + usage: { + annotatedResponse: 50, + }, + total: { + annotatedResponse: 100, + }, +} + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: mockPlan, + }), +})) + +describe('Usage', () => { + // Rendering: renders UsageInfo with correct props from context + describe('Rendering', () => { + it('should render usage info with data from provider context', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('annotatedResponse.quotaTitle')).toBeInTheDocument() + }) + + it('should pass className to UsageInfo component', () => { + // Arrange + const testClassName = 'mt-4' + + // Act + const { container } = render() + + // Assert + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass(testClassName) + }) + + it('should display usage and total values from context', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('50')).toBeInTheDocument() + expect(screen.getByText('100')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/billing/billing-page/index.spec.tsx b/web/app/components/billing/billing-page/index.spec.tsx index 8b68f7401..f80c688d4 100644 --- a/web/app/components/billing/billing-page/index.spec.tsx +++ b/web/app/components/billing/billing-page/index.spec.tsx @@ -73,6 +73,56 @@ describe('Billing', () => { }) }) + it('returns the refetched url from the async callback', async () => { + const newUrl = 'https://new-billing-url' + refetchMock.mockResolvedValue({ data: newUrl }) + render() + + const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ }) + fireEvent.click(actionButton) + + await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled()) + const [asyncCallback] = openAsyncWindowMock.mock.calls[0] + + // Execute the async callback passed to openAsyncWindow + const result = await asyncCallback() + expect(result).toBe(newUrl) + expect(refetchMock).toHaveBeenCalled() + }) + + it('returns null when refetch returns no url', async () => { + refetchMock.mockResolvedValue({ data: null }) + render() + + const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ }) + fireEvent.click(actionButton) + + await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled()) + const [asyncCallback] = openAsyncWindowMock.mock.calls[0] + + // Execute the async callback when url is null + const result = await asyncCallback() + expect(result).toBeNull() + }) + + it('handles errors in onError callback', async () => { + const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {}) + render() + + const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ }) + fireEvent.click(actionButton) + + await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled()) + const [, options] = openAsyncWindowMock.mock.calls[0] + + // Execute the onError callback + const testError = new Error('Test error') + options.onError(testError) + expect(consoleError).toHaveBeenCalledWith('Failed to fetch billing url', testError) + + consoleError.mockRestore() + }) + it('disables the button while billing url is fetching', () => { fetching = true render() diff --git a/web/app/components/billing/plan/index.spec.tsx b/web/app/components/billing/plan/index.spec.tsx index 473f81f9f..fb1800653 100644 --- a/web/app/components/billing/plan/index.spec.tsx +++ b/web/app/components/billing/plan/index.spec.tsx @@ -125,4 +125,70 @@ describe('PlanComp', () => { expect(setShowAccountSettingModalMock).toHaveBeenCalledWith(null) }) + + it('does not trigger verify when isPending is true', async () => { + isPending = true + render() + + const verifyBtn = screen.getByText('education.toVerified') + fireEvent.click(verifyBtn) + + await waitFor(() => expect(mutateAsyncMock).not.toHaveBeenCalled()) + }) + + it('renders sandbox plan', () => { + providerContextMock.mockReturnValue({ + plan: { ...planMock, type: Plan.sandbox }, + enableEducationPlan: false, + allowRefreshEducationVerify: false, + isEducationAccount: false, + }) + render() + + expect(screen.getByText('billing.plans.sandbox.name')).toBeInTheDocument() + }) + + it('renders team plan', () => { + providerContextMock.mockReturnValue({ + plan: { ...planMock, type: Plan.team }, + enableEducationPlan: false, + allowRefreshEducationVerify: false, + isEducationAccount: false, + }) + render() + + expect(screen.getByText('billing.plans.team.name')).toBeInTheDocument() + }) + + it('shows verify button when education account is about to expire', () => { + providerContextMock.mockReturnValue({ + plan: planMock, + enableEducationPlan: true, + allowRefreshEducationVerify: true, + isEducationAccount: true, + }) + render() + + expect(screen.getByText('education.toVerified')).toBeInTheDocument() + }) + + it('handles modal onConfirm and onCancel callbacks', async () => { + mutateAsyncMock.mockRejectedValueOnce(new Error('boom')) + render() + + // Trigger verify to show modal + const verifyBtn = screen.getByText('education.toVerified') + fireEvent.click(verifyBtn) + + await waitFor(() => expect(screen.getByTestId('verify-modal').getAttribute('data-is-show')).toBe('true')) + + // Get the props passed to the modal and call onConfirm/onCancel + const lastCall = verifyStateModalMock.mock.calls[verifyStateModalMock.mock.calls.length - 1][0] + expect(lastCall.onConfirm).toBeDefined() + expect(lastCall.onCancel).toBeDefined() + + // Call onConfirm to close modal + lastCall.onConfirm() + lastCall.onCancel() + }) }) diff --git a/web/app/components/billing/pricing/assets/index.spec.tsx b/web/app/components/billing/pricing/assets/index.spec.tsx index 7980f9a18..cc56c5759 100644 --- a/web/app/components/billing/pricing/assets/index.spec.tsx +++ b/web/app/components/billing/pricing/assets/index.spec.tsx @@ -52,6 +52,24 @@ describe('Pricing Assets', () => { expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-saas-dify-blue-accessible)')).toBe(true) }) + it('should render inactive state for Cloud', () => { + // Arrange + const { container } = render() + + // Assert + const rects = Array.from(container.querySelectorAll('rect')) + expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-text-primary)')).toBe(true) + }) + + it('should render active state for SelfHosted', () => { + // Arrange + const { container } = render() + + // Assert + const rects = Array.from(container.querySelectorAll('rect')) + expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-saas-dify-blue-accessible)')).toBe(true) + }) + it('should render inactive state for SelfHosted', () => { // Arrange const { container } = render() diff --git a/web/app/components/billing/utils/index.spec.ts b/web/app/components/billing/utils/index.spec.ts new file mode 100644 index 000000000..03a159c18 --- /dev/null +++ b/web/app/components/billing/utils/index.spec.ts @@ -0,0 +1,301 @@ +import type { CurrentPlanInfoBackend } from '../type' +import { DocumentProcessingPriority, Plan } from '../type' +import { getPlanVectorSpaceLimitMB, parseCurrentPlan, parseVectorSpaceToMB } from './index' + +describe('billing utils', () => { + // parseVectorSpaceToMB tests + describe('parseVectorSpaceToMB', () => { + it('should parse MB values correctly', () => { + expect(parseVectorSpaceToMB('50MB')).toBe(50) + expect(parseVectorSpaceToMB('100MB')).toBe(100) + }) + + it('should parse GB values and convert to MB', () => { + expect(parseVectorSpaceToMB('5GB')).toBe(5 * 1024) + expect(parseVectorSpaceToMB('20GB')).toBe(20 * 1024) + }) + + it('should be case insensitive', () => { + expect(parseVectorSpaceToMB('50mb')).toBe(50) + expect(parseVectorSpaceToMB('5gb')).toBe(5 * 1024) + }) + + it('should return 0 for invalid format', () => { + expect(parseVectorSpaceToMB('50')).toBe(0) + expect(parseVectorSpaceToMB('invalid')).toBe(0) + expect(parseVectorSpaceToMB('')).toBe(0) + expect(parseVectorSpaceToMB('50TB')).toBe(0) + }) + }) + + // getPlanVectorSpaceLimitMB tests + describe('getPlanVectorSpaceLimitMB', () => { + it('should return correct vector space for sandbox plan', () => { + expect(getPlanVectorSpaceLimitMB(Plan.sandbox)).toBe(50) + }) + + it('should return correct vector space for professional plan', () => { + expect(getPlanVectorSpaceLimitMB(Plan.professional)).toBe(5 * 1024) + }) + + it('should return correct vector space for team plan', () => { + expect(getPlanVectorSpaceLimitMB(Plan.team)).toBe(20 * 1024) + }) + + it('should return 0 for invalid plan', () => { + // @ts-expect-error - Testing invalid plan input + expect(getPlanVectorSpaceLimitMB('invalid')).toBe(0) + }) + }) + + // parseCurrentPlan tests + describe('parseCurrentPlan', () => { + const createMockPlanData = (overrides: Partial = {}): CurrentPlanInfoBackend => ({ + billing: { + enabled: true, + subscription: { + plan: Plan.sandbox, + }, + }, + members: { + size: 1, + limit: 1, + }, + apps: { + size: 2, + limit: 5, + }, + vector_space: { + size: 10, + limit: 50, + }, + annotation_quota_limit: { + size: 5, + limit: 10, + }, + documents_upload_quota: { + size: 20, + limit: 0, + }, + docs_processing: DocumentProcessingPriority.standard, + can_replace_logo: false, + model_load_balancing_enabled: false, + dataset_operator_enabled: false, + education: { + enabled: false, + activated: false, + }, + webapp_copyright_enabled: false, + workspace_members: { + size: 1, + limit: 1, + }, + is_allow_transfer_workspace: false, + knowledge_pipeline: { + publish_enabled: false, + }, + ...overrides, + }) + + it('should parse plan type correctly', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + expect(result.type).toBe(Plan.sandbox) + }) + + it('should parse usage values correctly', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + + expect(result.usage.vectorSpace).toBe(10) + expect(result.usage.buildApps).toBe(2) + expect(result.usage.teamMembers).toBe(1) + expect(result.usage.annotatedResponse).toBe(5) + expect(result.usage.documentsUploadQuota).toBe(20) + }) + + it('should parse total limits correctly', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + + expect(result.total.vectorSpace).toBe(50) + expect(result.total.buildApps).toBe(5) + expect(result.total.teamMembers).toBe(1) + expect(result.total.annotatedResponse).toBe(10) + }) + + it('should convert 0 limits to NUM_INFINITE (-1)', () => { + const data = createMockPlanData({ + documents_upload_quota: { + size: 20, + limit: 0, + }, + }) + const result = parseCurrentPlan(data) + expect(result.total.documentsUploadQuota).toBe(-1) + }) + + it('should handle api_rate_limit quota', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: null, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.usage.apiRateLimit).toBe(100) + expect(result.total.apiRateLimit).toBe(5000) + }) + + it('should handle trigger_event quota', () => { + const data = createMockPlanData({ + trigger_event: { + usage: 50, + limit: 3000, + reset_date: null, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.usage.triggerEvents).toBe(50) + expect(result.total.triggerEvents).toBe(3000) + }) + + it('should use fallback for api_rate_limit when not provided', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + + // Fallback to plan preset value for sandbox: 5000 + expect(result.total.apiRateLimit).toBe(5000) + }) + + it('should convert 0 or -1 rate limits to NUM_INFINITE', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 0, + limit: 0, + reset_date: null, + }, + }) + const result = parseCurrentPlan(data) + expect(result.total.apiRateLimit).toBe(-1) + + const data2 = createMockPlanData({ + api_rate_limit: { + usage: 0, + limit: -1, + reset_date: null, + }, + }) + const result2 = parseCurrentPlan(data2) + expect(result2.total.apiRateLimit).toBe(-1) + }) + + it('should handle reset dates with milliseconds timestamp', () => { + const futureDate = Date.now() + 86400000 // Tomorrow in ms + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: futureDate, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.reset.apiRateLimit).toBe(1) + }) + + it('should handle reset dates with seconds timestamp', () => { + const futureDate = Math.floor(Date.now() / 1000) + 86400 // Tomorrow in seconds + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: futureDate, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.reset.apiRateLimit).toBe(1) + }) + + it('should handle reset dates in YYYYMMDD format', () => { + const tomorrow = new Date() + tomorrow.setDate(tomorrow.getDate() + 1) + const year = tomorrow.getFullYear() + const month = String(tomorrow.getMonth() + 1).padStart(2, '0') + const day = String(tomorrow.getDate()).padStart(2, '0') + const dateNumber = Number.parseInt(`${year}${month}${day}`, 10) + + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: dateNumber, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.reset.apiRateLimit).toBe(1) + }) + + it('should return null for invalid reset dates', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: 0, + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + + it('should return null for negative reset dates', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: -1, + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + + it('should return null when reset date is in the past', () => { + const pastDate = Date.now() - 86400000 // Yesterday + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: pastDate, + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + + it('should handle missing apps field', () => { + const data = createMockPlanData() + // @ts-expect-error - Testing edge case + delete data.apps + const result = parseCurrentPlan(data) + expect(result.usage.buildApps).toBe(0) + }) + + it('should return null for unrecognized date format', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: 12345, // Unrecognized format + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + }) +}) diff --git a/web/app/components/datasets/api/index.spec.tsx b/web/app/components/datasets/api/index.spec.tsx new file mode 100644 index 000000000..33ee656a2 --- /dev/null +++ b/web/app/components/datasets/api/index.spec.tsx @@ -0,0 +1,24 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it } from 'vitest' +import ApiIndex from './index' + +afterEach(() => { + cleanup() +}) + +describe('ApiIndex', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText('index')).toBeInTheDocument() + }) + + it('should render a div with text "index"', () => { + const { container } = render() + expect(container.firstChild).toBeInstanceOf(HTMLDivElement) + expect(container.textContent).toBe('index') + }) + + it('should be a valid function component', () => { + expect(typeof ApiIndex).toBe('function') + }) +}) diff --git a/web/app/components/datasets/chunk.spec.tsx b/web/app/components/datasets/chunk.spec.tsx new file mode 100644 index 000000000..d3dc011ae --- /dev/null +++ b/web/app/components/datasets/chunk.spec.tsx @@ -0,0 +1,111 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it } from 'vitest' +import { ChunkContainer, ChunkLabel, QAPreview } from './chunk' + +afterEach(() => { + cleanup() +}) + +describe('ChunkLabel', () => { + it('should render label text', () => { + render() + expect(screen.getByText('Chunk 1')).toBeInTheDocument() + }) + + it('should render character count', () => { + render() + expect(screen.getByText('150 characters')).toBeInTheDocument() + }) + + it('should render separator dot', () => { + render() + expect(screen.getByText('·')).toBeInTheDocument() + }) + + it('should render with zero character count', () => { + render() + expect(screen.getByText('0 characters')).toBeInTheDocument() + }) + + it('should render with large character count', () => { + render() + expect(screen.getByText('999999 characters')).toBeInTheDocument() + }) +}) + +describe('ChunkContainer', () => { + it('should render label and character count', () => { + render(Content) + expect(screen.getByText('Container 1')).toBeInTheDocument() + expect(screen.getByText('200 characters')).toBeInTheDocument() + }) + + it('should render children content', () => { + render(Test Content) + expect(screen.getByText('Test Content')).toBeInTheDocument() + }) + + it('should render with complex children', () => { + render( + +
+ Nested content +
+
, + ) + expect(screen.getByTestId('child-div')).toBeInTheDocument() + expect(screen.getByText('Nested content')).toBeInTheDocument() + }) + + it('should render empty children', () => { + render({null}) + expect(screen.getByText('Empty')).toBeInTheDocument() + }) +}) + +describe('QAPreview', () => { + const mockQA = { + question: 'What is the meaning of life?', + answer: 'The meaning of life is 42.', + } + + it('should render question text', () => { + render() + expect(screen.getByText('What is the meaning of life?')).toBeInTheDocument() + }) + + it('should render answer text', () => { + render() + expect(screen.getByText('The meaning of life is 42.')).toBeInTheDocument() + }) + + it('should render Q label', () => { + render() + expect(screen.getByText('Q')).toBeInTheDocument() + }) + + it('should render A label', () => { + render() + expect(screen.getByText('A')).toBeInTheDocument() + }) + + it('should render with empty strings', () => { + render() + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('A')).toBeInTheDocument() + }) + + it('should render with long text', () => { + const longQuestion = 'Q'.repeat(500) + const longAnswer = 'A'.repeat(500) + render() + expect(screen.getByText(longQuestion)).toBeInTheDocument() + expect(screen.getByText(longAnswer)).toBeInTheDocument() + }) + + it('should render with special characters', () => { + render(?', answer: '& special chars!' }} />) + expect(screen.getByText('What about & < > " \'' - renderWithProviders( - , + it('should show error notification when operation fails', async () => { + vi.useFakeTimers() + mockEnable.mockRejectedValue(new Error('API Error')) + render( + , ) - - // Act - hover to show tooltip - const tooltipTrigger = screen.getByTestId('error-tooltip-trigger') - fireEvent.mouseEnter(tooltipTrigger) - - // Assert - await waitFor(() => { - expect(screen.getByText(specialChars)).toBeInTheDocument() + const switchElement = document.querySelector('[role="switch"]') + await act(async () => { + fireEvent.click(switchElement!) }) - }) - - it('should handle all status types in sequence', () => { - // Arrange - const statuses: DocumentDisplayStatus[] = [ - 'queuing', - 'indexing', - 'paused', - 'error', - 'available', - 'enabled', - 'disabled', - 'archived', - ] - - // Act & Assert - statuses.forEach((status) => { - const { unmount } = renderWithProviders() - const indicator = screen.getByTestId('status-indicator') - expect(indicator).toBeInTheDocument() - unmount() + await act(async () => { + vi.advanceTimersByTime(600) + // Flush promises + await Promise.resolve() }) + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'actionMsg.modifiedUnsuccessfully', + }) + vi.useRealTimers() }) }) - // ==================== Component Memoization ==================== - // Test React.memo behavior - describe('Component Memoization', () => { + describe('status color mapping', () => { + it('should have correct color class for green status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-green-green-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for orange status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-warning-warning-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for red status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-red-red-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for blue status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-blue-light-blue-light-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for gray status', () => { + const { container } = render() + const text = container.querySelector('.text-text-tertiary') + expect(text).toBeInTheDocument() + }) + }) + + describe('memoization', () => { it('should be wrapped with React.memo', () => { - // Assert - expect(StatusItem).toHaveProperty('$$typeof', Symbol.for('react.memo')) - }) - - it('should render correctly with same props', () => { - // Arrange - const props = { - status: 'available' as const, - scene: 'detail' as const, - detail: createDetailProps(), - } - - // Act - const { rerender } = renderWithProviders() - rerender( - - - , - ) - - // Assert - const indicator = screen.getByTestId('status-indicator') - expect(indicator).toBeInTheDocument() - }) - - it('should update when status prop changes', () => { - // Arrange - const { rerender } = renderWithProviders() - - // Assert initial - green/success background - let indicator = screen.getByTestId('status-indicator') - expect(indicator).toHaveClass('bg-components-badge-status-light-success-bg') - - // Act - rerender( - - - , - ) - - // Assert updated - red/error background - indicator = screen.getByTestId('status-indicator') - expect(indicator).toHaveClass('bg-components-badge-status-light-error-bg') + expect((StatusItem as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) }) }) - // ==================== Styling Tests ==================== - // Test CSS classes and styling - describe('Styling', () => { - it('should apply correct status text color for green status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.available') - expect(statusText).toHaveClass('text-util-colors-green-green-600') - }) - - it('should apply correct status text color for red status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.error') - expect(statusText).toHaveClass('text-util-colors-red-red-600') - }) - - it('should apply correct status text color for orange status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.queuing') - expect(statusText).toHaveClass('text-util-colors-warning-warning-600') - }) - - it('should apply correct status text color for blue status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.indexing') - expect(statusText).toHaveClass('text-util-colors-blue-light-blue-light-600') - }) - - it('should apply correct status text color for gray status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.disabled') - expect(statusText).toHaveClass('text-text-tertiary') - }) - - it('should render switch with md size in detail scene', () => { - // Arrange & Act - renderWithProviders( + describe('default props', () => { + it('should work with default datasetId', () => { + render( , ) + const switchElement = document.querySelector('[role="switch"]') + expect(switchElement).toBeInTheDocument() + }) - // Assert - check switch has the md size class (h-4 w-7) - const switchEl = screen.getByRole('switch') - expect(switchEl).toHaveClass('h-4', 'w-7') + it('should work without detail prop', () => { + render() + expect(screen.getByText('Available')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/extra-info/api-access/index.spec.tsx b/web/app/components/datasets/extra-info/api-access/index.spec.tsx index fb4930cbd..19e6b1ebc 100644 --- a/web/app/components/datasets/extra-info/api-access/index.spec.tsx +++ b/web/app/components/datasets/extra-info/api-access/index.spec.tsx @@ -1,792 +1,137 @@ -import type { DataSet } from '@/models/datasets' -import { render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -// ============================================================================ -// Component Imports (after mocks) -// ============================================================================ - -import Card from './card' +import { act, cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' import ApiAccess from './index' -// ============================================================================ -// Mock Setup -// ============================================================================ - -// Mock next/navigation -vi.mock('next/navigation', () => ({ - useRouter: () => ({ - push: vi.fn(), - replace: vi.fn(), +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, }), - usePathname: () => '/test', - useSearchParams: () => new URLSearchParams(), })) -// Mock next/link -vi.mock('next/link', () => ({ - default: ({ children, href, ...props }: { children: React.ReactNode, href: string, [key: string]: unknown }) => ( - {children} - ), -})) - -// Dataset context mock data -const mockDataset: Partial = { - id: 'dataset-123', - name: 'Test Dataset', - enable_api: true, -} - -// Mock use-context-selector -vi.mock('use-context-selector', () => ({ - useContext: vi.fn(() => ({ dataset: mockDataset })), - useContextSelector: vi.fn((_, selector) => selector({ dataset: mockDataset })), - createContext: vi.fn(() => ({})), -})) - -// Mock dataset detail context -const mockMutateDatasetRes = vi.fn() +// Mock context and hooks for Card component vi.mock('@/context/dataset-detail', () => ({ - default: {}, - useDatasetDetailContext: vi.fn(() => ({ - dataset: mockDataset, - mutateDatasetRes: mockMutateDatasetRes, - })), - useDatasetDetailContextWithSelector: vi.fn((selector: (v: { dataset?: typeof mockDataset, mutateDatasetRes?: () => void }) => unknown) => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ), + useDatasetDetailContextWithSelector: vi.fn(() => 'test-dataset-id'), })) -// Mock app context for workspace permissions -let mockIsCurrentWorkspaceManager = true vi.mock('@/context/app-context', () => ({ - useSelector: vi.fn((selector: (state: { isCurrentWorkspaceManager: boolean }) => unknown) => - selector({ isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager }), - ), + useSelector: vi.fn(() => true), })) -// Mock service hooks -const mockEnableDatasetServiceApi = vi.fn(() => Promise.resolve({ result: 'success' })) -const mockDisableDatasetServiceApi = vi.fn(() => Promise.resolve({ result: 'success' })) +vi.mock('@/hooks/use-api-access-url', () => ({ + useDatasetApiAccessUrl: vi.fn(() => 'https://api.example.com/docs'), +})) vi.mock('@/service/knowledge/use-dataset', () => ({ - useDatasetApiBaseUrl: vi.fn(() => ({ - data: { api_base_url: 'https://api.example.com' }, - isLoading: false, - })), - useEnableDatasetServiceApi: vi.fn(() => ({ - mutateAsync: mockEnableDatasetServiceApi, - isPending: false, - })), - useDisableDatasetServiceApi: vi.fn(() => ({ - mutateAsync: mockDisableDatasetServiceApi, - isPending: false, - })), + useEnableDatasetServiceApi: vi.fn(() => ({ mutateAsync: vi.fn() })), + useDisableDatasetServiceApi: vi.fn(() => ({ mutateAsync: vi.fn() })), })) -// Mock API access URL hook -vi.mock('@/hooks/use-api-access-url', () => ({ - useDatasetApiAccessUrl: vi.fn(() => 'https://docs.dify.ai/api-reference/datasets'), -})) - -// ============================================================================ -// ApiAccess Component Tests -// ============================================================================ +afterEach(() => { + cleanup() +}) describe('ApiAccess', () => { - beforeEach(() => { - vi.clearAllMocks() + it('should render without crashing', () => { + render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // Rendering Tests - // -------------------------------------------------------------------------- - describe('Rendering', () => { - it('should render without crashing', () => { - render() - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should render API title when expanded', () => { - render() - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should not render API title when collapsed', () => { - render() - expect(screen.queryByText(/appMenus\.apiAccess/i)).not.toBeInTheDocument() - }) - - it('should render ApiAggregate icon', () => { - const { container } = render() - const icon = container.querySelector('svg') - expect(icon).toBeInTheDocument() - }) - - it('should render Indicator component', () => { - const { container } = render() - const indicatorElement = container.querySelector('.relative.flex.h-8') - expect(indicatorElement).toBeInTheDocument() - }) - - it('should render with proper container padding', () => { - const { container } = render() - const wrapper = container.firstChild as HTMLElement - expect(wrapper).toHaveClass('p-3', 'pt-2') - }) + it('should render API access text when expanded', () => { + render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // Props Variations Tests - // -------------------------------------------------------------------------- - describe('Props Variations', () => { - it('should apply compressed layout when expand is false', () => { - const { container } = render() - const triggerContainer = container.querySelector('[class*="w-8"]') - expect(triggerContainer).toBeInTheDocument() - }) - - it('should apply full width when expand is true', () => { - const { container } = render() - const trigger = container.querySelector('.w-full') - expect(trigger).toBeInTheDocument() - }) - - it('should pass apiEnabled=true to Indicator with green color', () => { - const { container } = render() - // Indicator uses color prop - test the visual presence - const indicatorContainer = container.querySelector('.relative.flex.h-8') - expect(indicatorContainer).toBeInTheDocument() - }) - - it('should pass apiEnabled=false to Indicator with yellow color', () => { - const { container } = render() - const indicatorContainer = container.querySelector('.relative.flex.h-8') - expect(indicatorContainer).toBeInTheDocument() - }) - - it('should position Indicator absolutely when collapsed', () => { - const { container } = render() - // When collapsed, Indicator has 'absolute -right-px -top-px' classes - const triggerDiv = container.querySelector('[class*="w-8"][class*="justify-center"]') - expect(triggerDiv).toBeInTheDocument() - }) + it('should not render API access text when collapsed', () => { + render() + expect(screen.queryByText('appMenus.apiAccess')).not.toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // User Interactions Tests - // -------------------------------------------------------------------------- - describe('User Interactions', () => { - it('should toggle popup open state on click', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - expect(trigger).toBeInTheDocument() - - if (trigger) - await user.click(trigger) - - // After click, the popup should toggle (Card should be rendered via portal) - }) - - it('should apply hover styles on trigger', () => { - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('div[class*="cursor-pointer"]') - expect(trigger).toHaveClass('cursor-pointer') - }) - - it('should toggle open state from false to true on first click', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // The handleToggle function should flip open from false to true - }) - - it('should toggle open state back to false on second click', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) { - await user.click(trigger) // open - await user.click(trigger) // close - } - - // The handleToggle function should flip open from true to false - }) - - it('should apply open state styling when popup is open', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // When open, the trigger should have bg-state-base-hover class - }) + it('should render with apiEnabled=true', () => { + render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // Portal and Card Integration Tests - // -------------------------------------------------------------------------- - describe('Portal and Card Integration', () => { - it('should render Card component inside portal when open', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // Wait for portal content to appear - await waitFor(() => { - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - }) - - it('should pass apiEnabled prop to Card component', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - await waitFor(() => { - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() - }) - }) - - it('should use correct portal placement configuration', () => { - render() - // PortalToFollowElem is configured with placement="top-start" - // The component should render without errors - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should use correct portal offset configuration', () => { - render() - // PortalToFollowElem is configured with offset={{ mainAxis: 4, crossAxis: -4 }} - // The component should render without errors - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - }) - - // -------------------------------------------------------------------------- - // Edge Cases Tests - // -------------------------------------------------------------------------- - describe('Edge Cases', () => { - it('should handle rapid toggle clicks gracefully', async () => { - const user = userEvent.setup() - - const { container } = render() - - // Use a more specific selector to find the trigger in the main component - const trigger = container.querySelector('.p-3 [class*="cursor-pointer"]') - if (trigger) { - // Rapid clicks - await user.click(trigger) - await user.click(trigger) - await user.click(trigger) - } - - // Component should handle state changes without errors - use getAllByText since Card may be open - const elements = screen.getAllByText(/appMenus\.apiAccess/i) - expect(elements.length).toBeGreaterThan(0) - }) - - it('should render correctly when both expand and apiEnabled are false', () => { - render() - // Should render without title but with indicator - expect(screen.queryByText(/appMenus\.apiAccess/i)).not.toBeInTheDocument() - }) - - it('should maintain state across prop changes', () => { - const { rerender } = render() - - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - - rerender() - - // Component should still render after prop change - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - }) - - // -------------------------------------------------------------------------- - // Memoization Tests - // -------------------------------------------------------------------------- - describe('Memoization', () => { - it('should be memoized with React.memo', () => { - const { rerender } = render() - - rerender() - - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should not re-render unnecessarily with same props', () => { - const { rerender } = render() - - rerender() - rerender() - - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - }) -}) - -// ============================================================================ -// Card Component Tests -// ============================================================================ - -describe('Card (api-access)', () => { - beforeEach(() => { - vi.clearAllMocks() - mockIsCurrentWorkspaceManager = true - mockEnableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - mockDisableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - }) - - // -------------------------------------------------------------------------- - // Rendering Tests - // -------------------------------------------------------------------------- - describe('Rendering', () => { - it('should render without crashing', () => { - render() - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - - it('should display enabled status when API is enabled', () => { - render() - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - - it('should display disabled status when API is disabled', () => { - render() - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() - }) - - it('should render switch component', () => { - render() - expect(screen.getByRole('switch')).toBeInTheDocument() - }) - - it('should render API Reference link', () => { - render() - expect(screen.getByText(/overview\.apiInfo\.doc/i)).toBeInTheDocument() - }) - - it('should render Indicator component', () => { - const { container } = render() - // Indicator is rendered - verify card structure - const cardContainer = container.querySelector('.w-\\[208px\\]') - expect(cardContainer).toBeInTheDocument() - }) - - it('should render description tip text', () => { - render() - expect(screen.getByText(/appMenus\.apiAccessTip/i)).toBeInTheDocument() - }) - - it('should apply success text color when enabled', () => { - render() - const statusText = screen.getByText(/serviceApi\.enabled/i) - expect(statusText).toHaveClass('text-text-success') - }) - - it('should apply warning text color when disabled', () => { - render() - const statusText = screen.getByText(/serviceApi\.disabled/i) - expect(statusText).toHaveClass('text-text-warning') - }) - }) - - // -------------------------------------------------------------------------- - // User Interactions Tests - // -------------------------------------------------------------------------- - describe('User Interactions', () => { - it('should call enableDatasetServiceApi when switch is toggled on', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') - }) - }) - - it('should call disableDatasetServiceApi when switch is toggled off', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockDisableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') - }) - }) - - it('should call mutateDatasetRes after successful API enable', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockMutateDatasetRes).toHaveBeenCalled() - }) - }) - - it('should call mutateDatasetRes after successful API disable', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockMutateDatasetRes).toHaveBeenCalled() - }) - }) - - it('should not call mutateDatasetRes on API enable failure', async () => { - mockEnableDatasetServiceApi.mockResolvedValueOnce({ result: 'fail' }) - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalled() - }) - - expect(mockMutateDatasetRes).not.toHaveBeenCalled() - }) - - it('should not call mutateDatasetRes on API disable failure', async () => { - mockDisableDatasetServiceApi.mockResolvedValueOnce({ result: 'fail' }) - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockDisableDatasetServiceApi).toHaveBeenCalled() - }) - - expect(mockMutateDatasetRes).not.toHaveBeenCalled() - }) - - it('should have correct href for API Reference link', () => { - render() - - const apiRefLink = screen.getByText(/overview\.apiInfo\.doc/i).closest('a') - expect(apiRefLink).toHaveAttribute('href', 'https://docs.dify.ai/api-reference/datasets') - }) - - it('should open API Reference in new tab', () => { - render() - - const apiRefLink = screen.getByText(/overview\.apiInfo\.doc/i).closest('a') - expect(apiRefLink).toHaveAttribute('target', '_blank') - expect(apiRefLink).toHaveAttribute('rel', 'noopener noreferrer') - }) - }) - - // -------------------------------------------------------------------------- - // Permission Handling Tests - // -------------------------------------------------------------------------- - describe('Permission Handling', () => { - it('should disable switch when user is not workspace manager', () => { - mockIsCurrentWorkspaceManager = false - - render() - - const switchButton = screen.getByRole('switch') - expect(switchButton).toHaveClass('!cursor-not-allowed') - expect(switchButton).toHaveClass('!opacity-50') - }) - - it('should enable switch when user is workspace manager', () => { - mockIsCurrentWorkspaceManager = true - - render() - - const switchButton = screen.getByRole('switch') - expect(switchButton).not.toHaveClass('!cursor-not-allowed') - expect(switchButton).not.toHaveClass('!opacity-50') - }) - - it('should not trigger API call when switch is disabled and clicked', async () => { - mockIsCurrentWorkspaceManager = false - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - // API should not be called when disabled - expect(mockEnableDatasetServiceApi).not.toHaveBeenCalled() - }) - }) - - // -------------------------------------------------------------------------- - // Edge Cases Tests - // -------------------------------------------------------------------------- - describe('Edge Cases', () => { - it('should handle empty datasetId gracefully', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - return selector({ - dataset: { ...mockDataset, id: '' } as DataSet, - mutateDatasetRes: mockMutateDatasetRes, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('') - }) - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - - it('should handle undefined datasetId gracefully when enabling API', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - const partialDataset = { ...mockDataset } as Partial - delete partialDataset.id - return selector({ - dataset: partialDataset as DataSet, - mutateDatasetRes: mockMutateDatasetRes, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - // Should use fallback empty string - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('') - }) - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - - it('should handle undefined datasetId gracefully when disabling API', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - const partialDataset = { ...mockDataset } as Partial - delete partialDataset.id - return selector({ - dataset: partialDataset as DataSet, - mutateDatasetRes: mockMutateDatasetRes, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - // Should use fallback empty string for disableDatasetServiceApi - expect(mockDisableDatasetServiceApi).toHaveBeenCalledWith('') - }) - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - - it('should handle undefined mutateDatasetRes gracefully', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - return selector({ - dataset: mockDataset as DataSet, - mutateDatasetRes: undefined, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalled() - }) - - // Should not throw error when mutateDatasetRes is undefined - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - }) - - // -------------------------------------------------------------------------- - // Memoization Tests - // -------------------------------------------------------------------------- - describe('Memoization', () => { - it('should be memoized with React.memo', () => { - const { rerender } = render() - - rerender() - - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - - it('should use useCallback for onToggle handler', () => { - const { rerender } = render() - - rerender() - - // Component should render without issues with memoized callbacks - expect(screen.getByRole('switch')).toBeInTheDocument() - }) - - it('should update when apiEnabled prop changes', () => { - const { rerender } = render() - - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - - rerender() - - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() - }) - }) -}) - -// ============================================================================ -// Integration Tests -// ============================================================================ - -describe('ApiAccess Integration', () => { - beforeEach(() => { - vi.clearAllMocks() - mockIsCurrentWorkspaceManager = true - mockEnableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - mockDisableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - }) - - it('should open Card popup and toggle API status', async () => { - const user = userEvent.setup() - + it('should render with apiEnabled=false', () => { render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() + }) - // Open popup - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) + it('should be wrapped with React.memo', () => { + expect((ApiAccess as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) - // Wait for Card to appear - await waitFor(() => { - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() + describe('toggle functionality', () => { + it('should toggle open state when trigger is clicked', async () => { + const { container } = render() + const trigger = container.querySelector('.cursor-pointer') + expect(trigger).toBeInTheDocument() + + // Click to open + await act(async () => { + fireEvent.click(trigger!) + }) + + // The component should update its state - check for state change via class + expect(trigger).toBeInTheDocument() }) - // Toggle API on - const switchButton = screen.getByRole('switch') - await user.click(switchButton) + it('should toggle open state multiple times', async () => { + const { container } = render() + const trigger = container.querySelector('.cursor-pointer') - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') + // First click - open + await act(async () => { + fireEvent.click(trigger!) + }) + + // Second click - close + await act(async () => { + fireEvent.click(trigger!) + }) + + expect(trigger).toBeInTheDocument() + }) + + it('should work when collapsed', async () => { + const { container } = render() + const trigger = container.querySelector('.cursor-pointer') + + await act(async () => { + fireEvent.click(trigger!) + }) + + expect(trigger).toBeInTheDocument() }) }) - it('should complete full workflow: open -> view status -> toggle -> verify callback', async () => { - const user = userEvent.setup() - - render() - - // Open popup - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // Verify enabled status is shown - await waitFor(() => { - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() + describe('indicator color', () => { + it('should render with green indicator when apiEnabled is true', () => { + const { container } = render() + // Indicator component should be present + const indicator = container.querySelector('.shrink-0') + expect(indicator).toBeInTheDocument() }) - // Toggle API off - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - // Verify API call and callback - await waitFor(() => { - expect(mockDisableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') - expect(mockMutateDatasetRes).toHaveBeenCalled() + it('should render with yellow indicator when apiEnabled is false', () => { + const { container } = render() + const indicator = container.querySelector('.shrink-0') + expect(indicator).toBeInTheDocument() }) }) - it('should navigate to API Reference from Card', async () => { - const user = userEvent.setup() - - render() - - // Open popup - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // Wait for Card to appear - await waitFor(() => { - expect(screen.getByText(/overview\.apiInfo\.doc/i)).toBeInTheDocument() + describe('layout', () => { + it('should have justify-center when collapsed', () => { + const { container } = render() + const trigger = container.querySelector('.justify-center') + expect(trigger).toBeInTheDocument() }) - // Verify link - const apiRefLink = screen.getByText(/overview\.apiInfo\.doc/i).closest('a') - expect(apiRefLink).toHaveAttribute('href', 'https://docs.dify.ai/api-reference/datasets') + it('should not have justify-center when expanded', () => { + const { container } = render() + const innerDiv = container.querySelector('.cursor-pointer') + // When expanded, should have gap-2 and text, not justify-center + expect(innerDiv).not.toHaveClass('justify-center') + }) }) }) diff --git a/web/app/components/datasets/extra-info/statistics.spec.tsx b/web/app/components/datasets/extra-info/statistics.spec.tsx new file mode 100644 index 000000000..d7f79a1ab --- /dev/null +++ b/web/app/components/datasets/extra-info/statistics.spec.tsx @@ -0,0 +1,87 @@ +import type { RelatedApp, RelatedAppResponse } from '@/models/datasets' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import { AppModeEnum } from '@/types/app' +import Statistics from './statistics' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock useDocLink +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +afterEach(() => { + cleanup() +}) + +describe('Statistics', () => { + const mockRelatedApp: RelatedApp = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon_type: 'emoji', + icon: '🤖', + icon_background: '#ffffff', + icon_url: '', + } + + const mockRelatedApps: RelatedAppResponse = { + data: [mockRelatedApp], + total: 1, + } + + it('should render document count', () => { + render() + expect(screen.getByText('5')).toBeInTheDocument() + }) + + it('should render document label', () => { + render() + expect(screen.getByText('datasetMenus.documents')).toBeInTheDocument() + }) + + it('should render related apps total', () => { + render() + expect(screen.getByText('1')).toBeInTheDocument() + }) + + it('should render related app label', () => { + render() + expect(screen.getByText('datasetMenus.relatedApp')).toBeInTheDocument() + }) + + it('should render -- for undefined document count', () => { + render() + expect(screen.getByText('--')).toBeInTheDocument() + }) + + it('should render -- for undefined related apps total', () => { + render() + const dashes = screen.getAllByText('--') + expect(dashes.length).toBeGreaterThan(0) + }) + + it('should render with zero document count', () => { + render() + expect(screen.getByText('0')).toBeInTheDocument() + }) + + it('should render with empty related apps', () => { + const emptyRelatedApps: RelatedAppResponse = { + data: [], + total: 0, + } + render() + expect(screen.getByText('0')).toBeInTheDocument() + }) + + it('should be wrapped with React.memo', () => { + expect((Statistics as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) +}) diff --git a/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx b/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx index 3ae95ec53..04ebd16f6 100644 --- a/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx +++ b/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx @@ -3,8 +3,6 @@ import type { FC, ReactNode } from 'react' import type { SliceProps } from './type' import { autoUpdate, flip, FloatingFocusManager, offset, shift, useDismiss, useFloating, useHover, useInteractions, useRole } from '@floating-ui/react' import { RiDeleteBinLine } from '@remixicon/react' -// @ts-expect-error no types available -import lineClamp from 'line-clamp' import { useState } from 'react' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' import { cn } from '@/utils/classnames' @@ -58,12 +56,8 @@ export const EditSlice: FC = (props) => { <> { - refs.setReference(ref) - if (ref) - lineClamp(ref, 4) - }} + className={cn('mr-0 line-clamp-4 block', className)} + ref={refs.setReference} {...getReferenceProps()} > { + cleanup() +}) + +describe('DatasetsLoading', () => { + it('should render null', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('should not throw on multiple renders', () => { + expect(() => { + render() + render() + }).not.toThrow() + }) +}) diff --git a/web/app/components/datasets/no-linked-apps-panel.spec.tsx b/web/app/components/datasets/no-linked-apps-panel.spec.tsx new file mode 100644 index 000000000..aa66e43fb --- /dev/null +++ b/web/app/components/datasets/no-linked-apps-panel.spec.tsx @@ -0,0 +1,58 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import NoLinkedAppsPanel from './no-linked-apps-panel' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock useDocLink +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +afterEach(() => { + cleanup() +}) + +describe('NoLinkedAppsPanel', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText('datasetMenus.emptyTip')).toBeInTheDocument() + }) + + it('should render the empty tip text', () => { + render() + expect(screen.getByText('datasetMenus.emptyTip')).toBeInTheDocument() + }) + + it('should render the view doc link', () => { + render() + expect(screen.getByText('datasetMenus.viewDoc')).toBeInTheDocument() + }) + + it('should render link with correct href', () => { + render() + const link = screen.getByRole('link') + expect(link).toHaveAttribute('href', 'https://docs.example.com/use-dify/knowledge/integrate-knowledge-within-application') + }) + + it('should render link with target="_blank"', () => { + render() + const link = screen.getByRole('link') + expect(link).toHaveAttribute('target', '_blank') + }) + + it('should render link with rel="noopener noreferrer"', () => { + render() + const link = screen.getByRole('link') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should be wrapped with React.memo', () => { + expect((NoLinkedAppsPanel as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) +}) diff --git a/web/app/components/datasets/preview/index.spec.tsx b/web/app/components/datasets/preview/index.spec.tsx new file mode 100644 index 000000000..56638fb61 --- /dev/null +++ b/web/app/components/datasets/preview/index.spec.tsx @@ -0,0 +1,25 @@ +import { cleanup, render } from '@testing-library/react' +import { afterEach, describe, expect, it } from 'vitest' +import DatasetPreview from './index' + +afterEach(() => { + cleanup() +}) + +describe('DatasetPreview', () => { + it('should render null', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('should be a valid function component', () => { + expect(typeof DatasetPreview).toBe('function') + }) + + it('should not throw on multiple renders', () => { + expect(() => { + render() + render() + }).not.toThrow() + }) +}) diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index 1993c9fd8..ca072cfca 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import { IS_CE_EDITION } from '@/config' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDocLink } from '@/context/i18n' @@ -359,7 +360,7 @@ const Form = () => { { indexMethod === IndexingType.QUALIFIED && [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode) - && ( + && IS_CE_EDITION && ( <> ({ + default: ({ isShow, onClose }: { isShow: boolean, onClose: () => void }) => ( + isShow ?
: null + ), +})) + +describe('ApiServer', () => { + const defaultProps = { + apiBaseUrl: 'https://api.example.com', + } + + describe('rendering', () => { + it('should render the API server label', () => { + render() + expect(screen.getByText('appApi.apiServer')).toBeInTheDocument() + }) + + it('should render the API base URL', () => { + render() + expect(screen.getByText('https://api.example.com')).toBeInTheDocument() + }) + + it('should render the OK status badge', () => { + render() + expect(screen.getByText('appApi.ok')).toBeInTheDocument() + }) + + it('should render the API key button', () => { + render() + expect(screen.getByText('appApi.apiKey')).toBeInTheDocument() + }) + + it('should render CopyFeedback component', () => { + render() + // CopyFeedback renders a button for copying + const copyButtons = screen.getAllByRole('button') + expect(copyButtons.length).toBeGreaterThan(0) + }) + }) + + describe('with different API URLs', () => { + it('should render localhost URL', () => { + render() + expect(screen.getByText('http://localhost:3000/api')).toBeInTheDocument() + }) + + it('should render production URL', () => { + render() + expect(screen.getByText('https://api.dify.ai/v1')).toBeInTheDocument() + }) + + it('should render URL with path', () => { + render() + expect(screen.getByText('https://api.example.com/v1/chat')).toBeInTheDocument() + }) + }) + + describe('with appId prop', () => { + it('should render without appId', () => { + render() + expect(screen.getByText('https://api.example.com')).toBeInTheDocument() + }) + + it('should render with appId', () => { + render() + expect(screen.getByText('https://api.example.com')).toBeInTheDocument() + }) + }) + + describe('SecretKeyButton interaction', () => { + it('should open modal when API key button is clicked', async () => { + const user = userEvent.setup() + render() + + const apiKeyButton = screen.getByText('appApi.apiKey') + await act(async () => { + await user.click(apiKeyButton) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + + it('should close modal when close button is clicked', async () => { + const user = userEvent.setup() + render() + + // Open modal + const apiKeyButton = screen.getByText('appApi.apiKey') + await act(async () => { + await user.click(apiKeyButton) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + + // Close modal + const closeButton = screen.getByText('Close Modal') + await act(async () => { + await user.click(closeButton) + }) + + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have flex layout with wrap', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('flex') + expect(wrapper.className).toContain('flex-wrap') + }) + + it('should have items-center alignment', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('items-center') + }) + + it('should have gap-y-2 for vertical spacing', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('gap-y-2') + }) + + it('should apply green styling to OK badge', () => { + render() + const okBadge = screen.getByText('appApi.ok') + expect(okBadge.className).toContain('bg-[#ECFDF3]') + expect(okBadge.className).toContain('text-[#039855]') + }) + + it('should have border styling on URL container', () => { + render() + const urlText = screen.getByText('https://api.example.com') + const urlContainer = urlText.closest('div[class*="rounded-lg"]') + expect(urlContainer).toBeInTheDocument() + }) + }) + + describe('API server label', () => { + it('should have correct styling for label', () => { + render() + const label = screen.getByText('appApi.apiServer') + expect(label.className).toContain('rounded-md') + expect(label.className).toContain('border') + }) + + it('should have tertiary text color on label', () => { + render() + const label = screen.getByText('appApi.apiServer') + expect(label.className).toContain('text-text-tertiary') + }) + }) + + describe('URL display', () => { + it('should have truncate class for long URLs', () => { + render() + const urlText = screen.getByText('https://api.example.com') + expect(urlText.className).toContain('truncate') + }) + + it('should have font-medium class on URL', () => { + render() + const urlText = screen.getByText('https://api.example.com') + expect(urlText.className).toContain('font-medium') + }) + + it('should have secondary text color on URL', () => { + render() + const urlText = screen.getByText('https://api.example.com') + expect(urlText.className).toContain('text-text-secondary') + }) + }) + + describe('divider', () => { + it('should render vertical divider between URL and copy button', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider).toBeInTheDocument() + }) + + it('should have correct divider dimensions', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider?.className).toContain('h-[14px]') + expect(divider?.className).toContain('w-[1px]') + }) + }) + + describe('SecretKeyButton styling', () => { + it('should have shrink-0 class to prevent shrinking', () => { + render() + // The SecretKeyButton wraps a Button component + const button = screen.getByRole('button', { name: /apiKey/i }) + // Check parent container has shrink-0 + const buttonContainer = button.closest('.shrink-0') + expect(buttonContainer).toBeInTheDocument() + }) + }) + + describe('accessibility', () => { + it('should have accessible button for API key', () => { + render() + const button = screen.getByRole('button', { name: /apiKey/i }) + expect(button).toBeInTheDocument() + }) + + it('should have multiple buttons (copy + API key)', () => { + render() + const buttons = screen.getAllByRole('button') + expect(buttons.length).toBeGreaterThanOrEqual(2) + }) + }) +}) diff --git a/web/app/components/develop/code.spec.tsx b/web/app/components/develop/code.spec.tsx new file mode 100644 index 000000000..b279c41a6 --- /dev/null +++ b/web/app/components/develop/code.spec.tsx @@ -0,0 +1,590 @@ +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { Code, CodeGroup, Embed, Pre } from './code' + +// Mock the clipboard utility +vi.mock('@/utils/clipboard', () => ({ + writeTextToClipboard: vi.fn().mockResolvedValue(undefined), +})) + +describe('code.tsx components', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers({ shouldAdvanceTime: true }) + }) + + afterEach(() => { + vi.runOnlyPendingTimers() + vi.useRealTimers() + }) + + describe('Code', () => { + it('should render children', () => { + render(const x = 1) + expect(screen.getByText('const x = 1')).toBeInTheDocument() + }) + + it('should render as code element', () => { + render(code snippet) + const codeElement = screen.getByText('code snippet') + expect(codeElement.tagName).toBe('CODE') + }) + + it('should pass through additional props', () => { + render(snippet) + const codeElement = screen.getByTestId('custom-code') + expect(codeElement).toHaveClass('custom-class') + }) + + it('should render with complex children', () => { + render( + + part1 + part2 + , + ) + expect(screen.getByText('part1')).toBeInTheDocument() + expect(screen.getByText('part2')).toBeInTheDocument() + }) + }) + + describe('Embed', () => { + it('should render value prop', () => { + render(ignored children) + expect(screen.getByText('embedded content')).toBeInTheDocument() + }) + + it('should render as span element', () => { + render(children) + const span = screen.getByText('test value') + expect(span.tagName).toBe('SPAN') + }) + + it('should pass through additional props', () => { + render(children) + const embed = screen.getByTestId('embed-test') + expect(embed).toHaveClass('embed-class') + }) + + it('should not render children, only value', () => { + render(hidden children) + expect(screen.getByText('shown')).toBeInTheDocument() + expect(screen.queryByText('hidden children')).not.toBeInTheDocument() + }) + }) + + describe('CodeGroup', () => { + describe('with string targetCode', () => { + it('should render code from targetCode string', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('const hello = \'world\'')).toBeInTheDocument() + }) + + it('should have shadow and rounded styles', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.shadow-md') + expect(codeGroup).toBeInTheDocument() + expect(codeGroup).toHaveClass('rounded-2xl') + }) + + it('should have bg-zinc-900 background', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.bg-zinc-900') + expect(codeGroup).toBeInTheDocument() + }) + }) + + describe('with array targetCode', () => { + it('should render single code example without tabs', () => { + const examples = [{ code: 'single example' }] + render( + +
fallback
+
, + ) + expect(screen.getByText('single example')).toBeInTheDocument() + }) + + it('should render multiple code examples with tabs', () => { + const examples = [ + { title: 'JavaScript', code: 'console.log("js")' }, + { title: 'Python', code: 'print("py")' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByRole('tab', { name: 'JavaScript' })).toBeInTheDocument() + expect(screen.getByRole('tab', { name: 'Python' })).toBeInTheDocument() + }) + + it('should show first tab content by default', () => { + const examples = [ + { title: 'Tab1', code: 'first content' }, + { title: 'Tab2', code: 'second content' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByText('first content')).toBeInTheDocument() + }) + + it('should switch tabs on click', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + const examples = [ + { title: 'Tab1', code: 'first content' }, + { title: 'Tab2', code: 'second content' }, + ] + render( + +
fallback
+
, + ) + + const tab2 = screen.getByRole('tab', { name: 'Tab2' }) + await act(async () => { + await user.click(tab2) + }) + + await waitFor(() => { + expect(screen.getByText('second content')).toBeInTheDocument() + }) + }) + + it('should use "Code" as default title when title not provided', () => { + const examples = [ + { code: 'example 1' }, + { code: 'example 2' }, + ] + render( + +
fallback
+
, + ) + const codeTabs = screen.getAllByRole('tab', { name: 'Code' }) + expect(codeTabs).toHaveLength(2) + }) + }) + + describe('with title prop', () => { + it('should render title in header', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('API Example')).toBeInTheDocument() + }) + + it('should render title in h3 element', () => { + render( + +
fallback
+
, + ) + const h3 = screen.getByRole('heading', { level: 3 }) + expect(h3).toHaveTextContent('Example Title') + }) + }) + + describe('with tag and label props', () => { + it('should render tag in code panel header', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render label in code panel header', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('/api/users')).toBeInTheDocument() + }) + + it('should render both tag and label with separator', () => { + const { container } = render( + +
fallback
+
, + ) + expect(screen.getByText('POST')).toBeInTheDocument() + expect(screen.getByText('/api/create')).toBeInTheDocument() + // Separator should be present + const separator = container.querySelector('.rounded-full.bg-zinc-500') + expect(separator).toBeInTheDocument() + }) + }) + + describe('CopyButton functionality', () => { + it('should render copy button', () => { + render( + +
fallback
+
, + ) + const copyButton = screen.getByRole('button') + expect(copyButton).toBeInTheDocument() + }) + + it('should show "Copy" text initially', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('Copy')).toBeInTheDocument() + }) + + it('should show "Copied!" after clicking copy button', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + const { writeTextToClipboard } = await import('@/utils/clipboard') + + render( + +
fallback
+
, + ) + + const copyButton = screen.getByRole('button') + await act(async () => { + await user.click(copyButton) + }) + + await waitFor(() => { + expect(writeTextToClipboard).toHaveBeenCalledWith('code to copy') + }) + + expect(screen.getByText('Copied!')).toBeInTheDocument() + }) + + it('should reset copy state after timeout', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + + render( + +
fallback
+
, + ) + + const copyButton = screen.getByRole('button') + await act(async () => { + await user.click(copyButton) + }) + + await waitFor(() => { + expect(screen.getByText('Copied!')).toBeInTheDocument() + }) + + // Advance time past the timeout + await act(async () => { + vi.advanceTimersByTime(1500) + }) + + await waitFor(() => { + expect(screen.getByText('Copy')).toBeInTheDocument() + }) + }) + }) + + describe('without targetCode (using children)', () => { + it('should render children when no targetCode provided', () => { + render( + +
child code content
+
, + ) + expect(screen.getByText('child code content')).toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have not-prose class to prevent prose styling', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.not-prose') + expect(codeGroup).toBeInTheDocument() + }) + + it('should have my-6 margin', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.my-6') + expect(codeGroup).toBeInTheDocument() + }) + + it('should have overflow-hidden', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.overflow-hidden') + expect(codeGroup).toBeInTheDocument() + }) + }) + }) + + describe('Pre', () => { + describe('when outside CodeGroup context', () => { + it('should wrap children in CodeGroup', () => { + const { container } = render( +
+            
code content
+
, + ) + // Should render within a CodeGroup structure + const codeGroup = container.querySelector('.bg-zinc-900') + expect(codeGroup).toBeInTheDocument() + }) + + it('should pass props to CodeGroup', () => { + render( +
+            
code
+
, + ) + expect(screen.getByText('Pre Title')).toBeInTheDocument() + }) + }) + + describe('when inside CodeGroup context (isGrouped)', () => { + it('should return children directly without wrapping', () => { + render( + +
+              inner code
+            
+
, + ) + // The outer code should be rendered (from targetCode) + expect(screen.getByText('outer code')).toBeInTheDocument() + }) + }) + }) + + describe('CodePanelHeader (via CodeGroup)', () => { + it('should not render when neither tag nor label provided', () => { + const { container } = render( + +
fallback
+
, + ) + const headerDivider = container.querySelector('.border-b-white\\/7\\.5') + expect(headerDivider).not.toBeInTheDocument() + }) + + it('should render when only tag is provided', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render when only label is provided', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('/api/endpoint')).toBeInTheDocument() + }) + + it('should render label with font-mono styling', () => { + render( + +
fallback
+
, + ) + const label = screen.getByText('/api/test') + expect(label.className).toContain('font-mono') + expect(label.className).toContain('text-xs') + }) + }) + + describe('CodeGroupHeader (via CodeGroup with multiple tabs)', () => { + it('should render tab list for multiple examples', () => { + const examples = [ + { title: 'cURL', code: 'curl example' }, + { title: 'Node.js', code: 'node example' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByRole('tablist')).toBeInTheDocument() + }) + + it('should style active tab differently', () => { + const examples = [ + { title: 'Active', code: 'active code' }, + { title: 'Inactive', code: 'inactive code' }, + ] + render( + +
fallback
+
, + ) + const activeTab = screen.getByRole('tab', { name: 'Active' }) + expect(activeTab.className).toContain('border-emerald-500') + expect(activeTab.className).toContain('text-emerald-400') + }) + + it('should have header background styling', () => { + const examples = [ + { title: 'Tab1', code: 'code1' }, + { title: 'Tab2', code: 'code2' }, + ] + const { container } = render( + +
fallback
+
, + ) + const header = container.querySelector('.bg-zinc-800') + expect(header).toBeInTheDocument() + }) + }) + + describe('CodePanel (via CodeGroup)', () => { + it('should render code in pre element', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('pre content').closest('pre') + expect(preElement).toBeInTheDocument() + }) + + it('should have text-white class on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('white text').closest('pre') + expect(preElement?.className).toContain('text-white') + }) + + it('should have text-xs class on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('small text').closest('pre') + expect(preElement?.className).toContain('text-xs') + }) + + it('should have overflow-x-auto on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('scrollable').closest('pre') + expect(preElement?.className).toContain('overflow-x-auto') + }) + + it('should have p-4 padding on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('padded').closest('pre') + expect(preElement?.className).toContain('p-4') + }) + }) + + describe('ClipboardIcon (via CopyButton in CodeGroup)', () => { + it('should render clipboard icon in copy button', () => { + render( + +
fallback
+
, + ) + const copyButton = screen.getByRole('button') + const svg = copyButton.querySelector('svg') + expect(svg).toBeInTheDocument() + expect(svg).toHaveAttribute('viewBox', '0 0 20 20') + }) + }) + + describe('edge cases', () => { + it('should handle empty string targetCode', () => { + render( + +
fallback
+
, + ) + // Should render copy button even with empty code + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should handle targetCode with special characters', () => { + const specialCode = '
&
' + render( + +
fallback
+
, + ) + expect(screen.getByText(specialCode)).toBeInTheDocument() + }) + + it('should handle multiline targetCode', () => { + const multilineCode = `line1 +line2 +line3` + render( + +
fallback
+
, + ) + // Multiline code should be rendered - use a partial match + expect(screen.getByText(/line1/)).toBeInTheDocument() + expect(screen.getByText(/line2/)).toBeInTheDocument() + expect(screen.getByText(/line3/)).toBeInTheDocument() + }) + + it('should handle examples with tag property', () => { + const examples = [ + { title: 'Example', tag: 'v1', code: 'versioned code' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByText('versioned code')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/index.spec.tsx b/web/app/components/develop/index.spec.tsx new file mode 100644 index 000000000..f90e33e69 --- /dev/null +++ b/web/app/components/develop/index.spec.tsx @@ -0,0 +1,339 @@ +import { render, screen } from '@testing-library/react' +import DevelopMain from './index' + +// Mock the app store with a factory function to control state +const mockAppDetailValue: { current: unknown } = { current: undefined } +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: unknown) => unknown) => { + const state = { appDetail: mockAppDetailValue.current } + return selector(state) + }, +})) + +// Mock the Doc component since it has complex dependencies +vi.mock('@/app/components/develop/doc', () => ({ + default: ({ appDetail }: { appDetail: { name?: string } | null }) => ( +
+ Doc Component - + {appDetail?.name} +
+ ), +})) + +// Mock the ApiServer component +vi.mock('@/app/components/develop/ApiServer', () => ({ + default: ({ apiBaseUrl, appId }: { apiBaseUrl: string, appId: string }) => ( +
+ API Server - + {apiBaseUrl} + {' '} + - + {appId} +
+ ), +})) + +describe('DevelopMain', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppDetailValue.current = undefined + }) + + describe('loading state', () => { + it('should show loading when appDetail is undefined', () => { + mockAppDetailValue.current = undefined + render() + + // Loading component renders with role="status" + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should show loading when appDetail is null', () => { + mockAppDetailValue.current = null + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should have centered loading container', () => { + mockAppDetailValue.current = undefined + const { container } = render() + + const loadingContainer = container.querySelector('.flex.h-full.items-center.justify-center') + expect(loadingContainer).toBeInTheDocument() + }) + + it('should have correct background on loading state', () => { + mockAppDetailValue.current = undefined + const { container } = render() + + const loadingContainer = container.querySelector('.bg-background-default') + expect(loadingContainer).toBeInTheDocument() + }) + }) + + describe('with appDetail loaded', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com/v1', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should render ApiServer component', () => { + render() + expect(screen.getByTestId('api-server')).toBeInTheDocument() + }) + + it('should pass api_base_url to ApiServer', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('https://api.example.com/v1') + }) + + it('should pass appId to ApiServer', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('app-123') + }) + + it('should render Doc component', () => { + render() + expect(screen.getByTestId('doc-component')).toBeInTheDocument() + }) + + it('should pass appDetail to Doc component', () => { + render() + expect(screen.getByTestId('doc-component')).toHaveTextContent('Test Application') + }) + + it('should not show loading when appDetail exists', () => { + render() + expect(screen.queryByRole('status')).not.toBeInTheDocument() + }) + }) + + describe('layout structure', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have flex column layout', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('flex') + expect(mainContainer.className).toContain('flex-col') + }) + + it('should have relative positioning', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('relative') + }) + + it('should have full height', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('h-full') + }) + + it('should have overflow-hidden', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('overflow-hidden') + }) + }) + + describe('header section', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have header with border', () => { + const { container } = render() + const header = container.querySelector('.border-b') + expect(header).toBeInTheDocument() + }) + + it('should have shrink-0 on header to prevent shrinking', () => { + const { container } = render() + const header = container.querySelector('.shrink-0') + expect(header).toBeInTheDocument() + }) + + it('should have horizontal padding on header', () => { + const { container } = render() + const header = container.querySelector('.px-6') + expect(header).toBeInTheDocument() + }) + + it('should have vertical padding on header', () => { + const { container } = render() + const header = container.querySelector('.py-2') + expect(header).toBeInTheDocument() + }) + + it('should have items centered in header', () => { + const { container } = render() + const header = container.querySelector('.items-center') + expect(header).toBeInTheDocument() + }) + + it('should have justify-between in header', () => { + const { container } = render() + const header = container.querySelector('.justify-between') + expect(header).toBeInTheDocument() + }) + }) + + describe('content section', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have grow class for content area', () => { + const { container } = render() + const content = container.querySelector('.grow') + expect(content).toBeInTheDocument() + }) + + it('should have overflow-auto for content scrolling', () => { + const { container } = render() + const content = container.querySelector('.overflow-auto') + expect(content).toBeInTheDocument() + }) + + it('should have horizontal padding on content', () => { + const { container } = render() + const content = container.querySelector('.px-4') + expect(content).toBeInTheDocument() + }) + + it('should have vertical padding on content', () => { + const { container } = render() + const content = container.querySelector('.py-4') + expect(content).toBeInTheDocument() + }) + + it('should have responsive padding', () => { + const { container } = render() + const content = container.querySelector('[class*="sm:px-10"]') + expect(content).toBeInTheDocument() + }) + }) + + describe('with different appIds', () => { + const mockAppDetail = { + id: 'app-456', + name: 'Another App', + api_base_url: 'https://another-api.com', + mode: 'completion', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should pass different appId to ApiServer', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('app-456') + }) + + it('should handle app with different api_base_url', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('https://another-api.com') + }) + }) + + describe('empty state handling', () => { + it('should handle appDetail with minimal properties', () => { + mockAppDetailValue.current = { + api_base_url: 'https://api.test.com', + } + render() + expect(screen.getByTestId('api-server')).toBeInTheDocument() + }) + + it('should handle appDetail with empty api_base_url', () => { + mockAppDetailValue.current = { + api_base_url: '', + name: 'Empty URL App', + } + render() + expect(screen.getByTestId('api-server')).toBeInTheDocument() + }) + }) + + describe('title element', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have title div with correct styling', () => { + const { container } = render() + const title = container.querySelector('.text-lg.font-medium.text-text-primary') + expect(title).toBeInTheDocument() + }) + + it('should render empty title div', () => { + const { container } = render() + const title = container.querySelector('.text-lg.font-medium.text-text-primary') + expect(title?.textContent).toBe('') + }) + }) + + describe('border styling', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have solid border style', () => { + const { container } = render() + const header = container.querySelector('.border-solid') + expect(header).toBeInTheDocument() + }) + + it('should have divider regular color on border', () => { + const { container } = render() + const header = container.querySelector('.border-b-divider-regular') + expect(header).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/md.spec.tsx b/web/app/components/develop/md.spec.tsx new file mode 100644 index 000000000..8eab1c0ac --- /dev/null +++ b/web/app/components/develop/md.spec.tsx @@ -0,0 +1,655 @@ +import { render, screen } from '@testing-library/react' +import { Col, Heading, Properties, Property, PropertyInstruction, Row, SubProperty } from './md' + +describe('md.tsx components', () => { + describe('Heading', () => { + const defaultProps = { + url: '/api/messages', + method: 'GET' as const, + title: 'Get Messages', + name: '#get-messages', + } + + describe('rendering', () => { + it('should render the method badge', () => { + render() + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render the url', () => { + render() + expect(screen.getByText('/api/messages')).toBeInTheDocument() + }) + + it('should render the title as a link', () => { + render() + const link = screen.getByRole('link', { name: 'Get Messages' }) + expect(link).toBeInTheDocument() + expect(link).toHaveAttribute('href', '#get-messages') + }) + + it('should render an anchor span with correct id', () => { + const { container } = render() + const anchor = container.querySelector('#get-messages') + expect(anchor).toBeInTheDocument() + }) + + it('should strip # prefix from name for id', () => { + const { container } = render() + const anchor = container.querySelector('#with-hash') + expect(anchor).toBeInTheDocument() + }) + }) + + describe('method styling', () => { + it('should apply emerald styles for GET method', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('text-emerald') + expect(badge.className).toContain('bg-emerald-400/10') + expect(badge.className).toContain('ring-emerald-300') + }) + + it('should apply sky styles for POST method', () => { + render() + const badge = screen.getByText('POST') + expect(badge.className).toContain('text-sky') + expect(badge.className).toContain('bg-sky-400/10') + expect(badge.className).toContain('ring-sky-300') + }) + + it('should apply amber styles for PUT method', () => { + render() + const badge = screen.getByText('PUT') + expect(badge.className).toContain('text-amber') + expect(badge.className).toContain('bg-amber-400/10') + expect(badge.className).toContain('ring-amber-300') + }) + + it('should apply rose styles for DELETE method', () => { + render() + const badge = screen.getByText('DELETE') + expect(badge.className).toContain('text-red') + expect(badge.className).toContain('bg-rose') + expect(badge.className).toContain('ring-rose') + }) + + it('should apply violet styles for PATCH method', () => { + render() + const badge = screen.getByText('PATCH') + expect(badge.className).toContain('text-violet') + expect(badge.className).toContain('bg-violet-400/10') + expect(badge.className).toContain('ring-violet-300') + }) + }) + + describe('badge base styles', () => { + it('should have rounded-lg class', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('rounded-lg') + }) + + it('should have font-mono class', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('font-mono') + }) + + it('should have font-semibold class', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('font-semibold') + }) + + it('should have ring-1 and ring-inset classes', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('ring-1') + expect(badge.className).toContain('ring-inset') + }) + }) + + describe('url styles', () => { + it('should have font-mono class on url', () => { + render() + const url = screen.getByText('/api/messages') + expect(url.className).toContain('font-mono') + }) + + it('should have text-xs class on url', () => { + render() + const url = screen.getByText('/api/messages') + expect(url.className).toContain('text-xs') + }) + + it('should have zinc text color on url', () => { + render() + const url = screen.getByText('/api/messages') + expect(url.className).toContain('text-zinc-400') + }) + }) + + describe('h2 element', () => { + it('should render title inside h2', () => { + render() + const h2 = screen.getByRole('heading', { level: 2 }) + expect(h2).toBeInTheDocument() + expect(h2).toHaveTextContent('Get Messages') + }) + + it('should have scroll-mt-32 class on h2', () => { + render() + const h2 = screen.getByRole('heading', { level: 2 }) + expect(h2.className).toContain('scroll-mt-32') + }) + }) + }) + + describe('Row', () => { + it('should render children', () => { + render( + +
Child 1
+
Child 2
+
, + ) + expect(screen.getByText('Child 1')).toBeInTheDocument() + expect(screen.getByText('Child 2')).toBeInTheDocument() + }) + + it('should have grid layout', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('grid') + expect(row.className).toContain('grid-cols-1') + }) + + it('should have gap classes', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('gap-x-16') + expect(row.className).toContain('gap-y-10') + }) + + it('should have xl responsive classes', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('xl:grid-cols-2') + expect(row.className).toContain('xl:!max-w-none') + }) + + it('should have items-start class', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('items-start') + }) + }) + + describe('Col', () => { + it('should render children', () => { + render( + +
Column Content
+ , + ) + expect(screen.getByText('Column Content')).toBeInTheDocument() + }) + + it('should have first/last child margin classes', () => { + const { container } = render( + +
Content
+ , + ) + const col = container.firstChild as HTMLElement + expect(col.className).toContain('[&>:first-child]:mt-0') + expect(col.className).toContain('[&>:last-child]:mb-0') + }) + + it('should apply sticky classes when sticky is true', () => { + const { container } = render( + +
Sticky Content
+ , + ) + const col = container.firstChild as HTMLElement + expect(col.className).toContain('xl:sticky') + expect(col.className).toContain('xl:top-24') + }) + + it('should not apply sticky classes when sticky is false', () => { + const { container } = render( + +
Non-sticky Content
+ , + ) + const col = container.firstChild as HTMLElement + expect(col.className).not.toContain('xl:sticky') + expect(col.className).not.toContain('xl:top-24') + }) + }) + + describe('Properties', () => { + it('should render children', () => { + render( + +
  • Property 1
  • +
  • Property 2
  • +
    , + ) + expect(screen.getByText('Property 1')).toBeInTheDocument() + expect(screen.getByText('Property 2')).toBeInTheDocument() + }) + + it('should render as ul with role list', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list).toBeInTheDocument() + expect(list.tagName).toBe('UL') + }) + + it('should have my-6 margin class', () => { + const { container } = render( + +
  • Property
  • +
    , + ) + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('my-6') + }) + + it('should have list-none class on ul', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('list-none') + }) + + it('should have m-0 and p-0 classes on ul', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('m-0') + expect(list.className).toContain('p-0') + }) + + it('should have divide-y class on ul', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('divide-y') + }) + + it('should have max-w constraint class', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('max-w-[calc(theme(maxWidth.lg)-theme(spacing.8))]') + }) + }) + + describe('Property', () => { + const defaultProps = { + name: 'user_id', + type: 'string', + anchor: false, + } + + it('should render name in code element', () => { + render( + + User identifier + , + ) + const code = screen.getByText('user_id') + expect(code.tagName).toBe('CODE') + }) + + it('should render type', () => { + render( + + User identifier + , + ) + expect(screen.getByText('string')).toBeInTheDocument() + }) + + it('should render children as description', () => { + render( + + User identifier + , + ) + expect(screen.getByText('User identifier')).toBeInTheDocument() + }) + + it('should render as li element', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('li')).toBeInTheDocument() + }) + + it('should have m-0 class on li', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('m-0') + }) + + it('should have padding classes on li', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('px-0') + expect(li.className).toContain('py-4') + }) + + it('should have first:pt-0 and last:pb-0 classes', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('first:pt-0') + expect(li.className).toContain('last:pb-0') + }) + + it('should render dl element with proper structure', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('dl')).toBeInTheDocument() + }) + + it('should have sr-only dt elements for accessibility', () => { + const { container } = render( + + User identifier + , + ) + const dtElements = container.querySelectorAll('dt') + expect(dtElements.length).toBe(3) + dtElements.forEach((dt) => { + expect(dt.className).toContain('sr-only') + }) + }) + + it('should have font-mono class on type', () => { + render( + + Description + , + ) + const typeElement = screen.getByText('string') + expect(typeElement.className).toContain('font-mono') + expect(typeElement.className).toContain('text-xs') + }) + }) + + describe('SubProperty', () => { + const defaultProps = { + name: 'sub_field', + type: 'number', + anchor: false, + } + + it('should render name in code element', () => { + render( + + Sub field description + , + ) + const code = screen.getByText('sub_field') + expect(code.tagName).toBe('CODE') + }) + + it('should render type', () => { + render( + + Sub field description + , + ) + expect(screen.getByText('number')).toBeInTheDocument() + }) + + it('should render children as description', () => { + render( + + Sub field description + , + ) + expect(screen.getByText('Sub field description')).toBeInTheDocument() + }) + + it('should render as li element', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('li')).toBeInTheDocument() + }) + + it('should have m-0 class on li', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('m-0') + }) + + it('should have different padding than Property (py-1 vs py-4)', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('px-0') + expect(li.className).toContain('py-1') + }) + + it('should have last:pb-0 class', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('last:pb-0') + }) + + it('should render dl element with proper structure', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('dl')).toBeInTheDocument() + }) + + it('should have sr-only dt elements for accessibility', () => { + const { container } = render( + + Sub field description + , + ) + const dtElements = container.querySelectorAll('dt') + expect(dtElements.length).toBe(3) + dtElements.forEach((dt) => { + expect(dt.className).toContain('sr-only') + }) + }) + + it('should have font-mono and text-xs on type', () => { + render( + + Description + , + ) + const typeElement = screen.getByText('number') + expect(typeElement.className).toContain('font-mono') + expect(typeElement.className).toContain('text-xs') + }) + }) + + describe('PropertyInstruction', () => { + it('should render children', () => { + render( + + This is an instruction + , + ) + expect(screen.getByText('This is an instruction')).toBeInTheDocument() + }) + + it('should render as li element', () => { + const { container } = render( + + Instruction text + , + ) + expect(container.querySelector('li')).toBeInTheDocument() + }) + + it('should have m-0 class', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('m-0') + }) + + it('should have padding classes', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('px-0') + expect(li.className).toContain('py-4') + }) + + it('should have italic class', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('italic') + }) + + it('should have first:pt-0 class', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('first:pt-0') + }) + }) + + describe('integration tests', () => { + it('should render Property inside Properties', () => { + render( + + + Unique identifier + + + Display name + + , + ) + + expect(screen.getByText('id')).toBeInTheDocument() + expect(screen.getByText('name')).toBeInTheDocument() + expect(screen.getByText('Unique identifier')).toBeInTheDocument() + expect(screen.getByText('Display name')).toBeInTheDocument() + }) + + it('should render Col inside Row', () => { + render( + + +
    Left column
    + + +
    Right column
    + +
    , + ) + + expect(screen.getByText('Left column')).toBeInTheDocument() + expect(screen.getByText('Right column')).toBeInTheDocument() + }) + + it('should render PropertyInstruction inside Properties', () => { + render( + + + Note: All fields are required + + + A required field + + , + ) + + expect(screen.getByText('Note: All fields are required')).toBeInTheDocument() + expect(screen.getByText('required_field')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/secret-key/input-copy.spec.tsx b/web/app/components/develop/secret-key/input-copy.spec.tsx new file mode 100644 index 000000000..0216f2bfa --- /dev/null +++ b/web/app/components/develop/secret-key/input-copy.spec.tsx @@ -0,0 +1,314 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import copy from 'copy-to-clipboard' +import InputCopy from './input-copy' + +// Mock copy-to-clipboard +vi.mock('copy-to-clipboard', () => ({ + default: vi.fn().mockReturnValue(true), +})) + +describe('InputCopy', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers({ shouldAdvanceTime: true }) + }) + + afterEach(() => { + vi.runOnlyPendingTimers() + vi.useRealTimers() + }) + + describe('rendering', () => { + it('should render the value', () => { + render() + expect(screen.getByText('test-api-key-12345')).toBeInTheDocument() + }) + + it('should render with empty value by default', () => { + render() + // Empty string should be rendered + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should render children when provided', () => { + render( + + Custom Content + , + ) + expect(screen.getByTestId('custom-child')).toBeInTheDocument() + }) + + it('should render CopyFeedback component', () => { + render() + // CopyFeedback should render a button + const buttons = screen.getAllByRole('button') + expect(buttons.length).toBeGreaterThan(0) + }) + }) + + describe('styling', () => { + it('should apply custom className', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('custom-class') + }) + + it('should have flex layout', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('flex') + }) + + it('should have items-center alignment', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('items-center') + }) + + it('should have rounded-lg class', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('rounded-lg') + }) + + it('should have background class', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('bg-components-input-bg-normal') + }) + + it('should have hover state', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('hover:bg-state-base-hover') + }) + + it('should have py-2 padding', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('py-2') + }) + }) + + describe('copy functionality', () => { + it('should copy value when clicked', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('copy-this-value') + await act(async () => { + await user.click(copyableArea) + }) + + expect(copy).toHaveBeenCalledWith('copy-this-value') + }) + + it('should update copied state after clicking', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('test-value') + await act(async () => { + await user.click(copyableArea) + }) + + // Copy function should have been called + expect(copy).toHaveBeenCalledWith('test-value') + }) + + it('should reset copied state after timeout', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('test-value') + await act(async () => { + await user.click(copyableArea) + }) + + expect(copy).toHaveBeenCalledWith('test-value') + + // Advance time to reset the copied state + await act(async () => { + vi.advanceTimersByTime(1500) + }) + + // Component should still be functional + expect(screen.getByText('test-value')).toBeInTheDocument() + }) + + it('should render tooltip on value', () => { + render() + // Value should be wrapped in tooltip (tooltip shows on hover, not as visible text) + const valueText = screen.getByText('test-value') + expect(valueText).toBeInTheDocument() + }) + }) + + describe('tooltip', () => { + it('should render tooltip wrapper', () => { + render() + const valueText = screen.getByText('test') + expect(valueText).toBeInTheDocument() + }) + + it('should have cursor-pointer on clickable area', () => { + render() + const valueText = screen.getByText('test') + const clickableArea = valueText.closest('div[class*="cursor-pointer"]') + expect(clickableArea).toBeInTheDocument() + }) + }) + + describe('divider', () => { + it('should render vertical divider', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider).toBeInTheDocument() + }) + + it('should have correct divider dimensions', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider?.className).toContain('h-4') + expect(divider?.className).toContain('w-px') + }) + + it('should have shrink-0 on divider', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider?.className).toContain('shrink-0') + }) + }) + + describe('value display', () => { + it('should have truncate class for long values', () => { + render() + const valueText = screen.getByText('very-long-api-key-value-that-might-overflow') + const container = valueText.closest('div[class*="truncate"]') + expect(container).toBeInTheDocument() + }) + + it('should have text-secondary color on value', () => { + render() + const valueText = screen.getByText('test-value') + expect(valueText.className).toContain('text-text-secondary') + }) + + it('should have absolute positioning for overlay', () => { + render() + const valueText = screen.getByText('test') + const container = valueText.closest('div[class*="absolute"]') + expect(container).toBeInTheDocument() + }) + }) + + describe('inner container', () => { + it('should have grow class on inner container', () => { + const { container } = render() + const innerContainer = container.querySelector('.grow') + expect(innerContainer).toBeInTheDocument() + }) + + it('should have h-5 height on inner container', () => { + const { container } = render() + const innerContainer = container.querySelector('.h-5') + expect(innerContainer).toBeInTheDocument() + }) + }) + + describe('with children', () => { + it('should render children before value', () => { + const { container } = render( + + Prefix: + , + ) + const children = container.querySelector('[data-testid="prefix"]') + expect(children).toBeInTheDocument() + }) + + it('should render both children and value', () => { + render( + + Label: + , + ) + expect(screen.getByText('Label:')).toBeInTheDocument() + expect(screen.getByText('api-key')).toBeInTheDocument() + }) + }) + + describe('CopyFeedback section', () => { + it('should have margin on CopyFeedback container', () => { + const { container } = render() + const copyFeedbackContainer = container.querySelector('.mx-1') + expect(copyFeedbackContainer).toBeInTheDocument() + }) + }) + + describe('relative container', () => { + it('should have relative positioning on value container', () => { + const { container } = render() + const relativeContainer = container.querySelector('.relative') + expect(relativeContainer).toBeInTheDocument() + }) + + it('should have grow on value container', () => { + const { container } = render() + // Find the relative container that also has grow + const valueContainer = container.querySelector('.relative.grow') + expect(valueContainer).toBeInTheDocument() + }) + + it('should have full height on value container', () => { + const { container } = render() + const valueContainer = container.querySelector('.relative.h-full') + expect(valueContainer).toBeInTheDocument() + }) + }) + + describe('edge cases', () => { + it('should handle undefined value', () => { + render() + // Should not crash + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should handle empty string value', () => { + render() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should handle very long values', () => { + const longValue = 'a'.repeat(500) + render() + expect(screen.getByText(longValue)).toBeInTheDocument() + }) + + it('should handle special characters in value', () => { + const specialValue = 'key-with-special-chars!@#$%^&*()' + render() + expect(screen.getByText(specialValue)).toBeInTheDocument() + }) + }) + + describe('multiple clicks', () => { + it('should handle multiple rapid clicks', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('test') + + // Click multiple times rapidly + await act(async () => { + await user.click(copyableArea) + await user.click(copyableArea) + await user.click(copyableArea) + }) + + expect(copy).toHaveBeenCalledTimes(3) + }) + }) +}) diff --git a/web/app/components/develop/secret-key/secret-key-button.spec.tsx b/web/app/components/develop/secret-key/secret-key-button.spec.tsx new file mode 100644 index 000000000..4b4fbaab2 --- /dev/null +++ b/web/app/components/develop/secret-key/secret-key-button.spec.tsx @@ -0,0 +1,297 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SecretKeyButton from './secret-key-button' + +// Mock the SecretKeyModal since it has complex dependencies +vi.mock('@/app/components/develop/secret-key/secret-key-modal', () => ({ + default: ({ isShow, onClose, appId }: { isShow: boolean, onClose: () => void, appId?: string }) => ( + isShow + ? ( +
    + {`Modal for ${appId || 'no-app'}`} + +
    + ) + : null + ), +})) + +describe('SecretKeyButton', () => { + describe('rendering', () => { + it('should render the button', () => { + render() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should render the API key text', () => { + render() + expect(screen.getByText('appApi.apiKey')).toBeInTheDocument() + }) + + it('should render the key icon', () => { + const { container } = render() + // RiKey2Line icon should be rendered as an svg + const svg = container.querySelector('svg') + expect(svg).toBeInTheDocument() + }) + + it('should not show modal initially', () => { + render() + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + }) + + describe('button interaction', () => { + it('should open modal when button is clicked', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + + it('should close modal when onClose is called', async () => { + const user = userEvent.setup() + render() + + // Open modal + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + + // Close modal + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + + it('should toggle modal visibility', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + + // Open + await act(async () => { + await user.click(button) + }) + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + + // Close + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + + // Open again + await act(async () => { + await user.click(button) + }) + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + }) + + describe('props', () => { + it('should apply custom className', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('custom-class') + }) + + it('should pass appId to modal', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByText('Modal for app-123')).toBeInTheDocument() + }) + + it('should handle undefined appId', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByText('Modal for no-app')).toBeInTheDocument() + }) + + it('should apply custom textCls', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('custom-text-class') + }) + }) + + describe('button styling', () => { + it('should have px-3 padding', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('px-3') + }) + + it('should have small size', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('btn-small') + }) + + it('should have ghost variant', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('btn-ghost') + }) + }) + + describe('icon styling', () => { + it('should have icon container with flex layout', () => { + const { container } = render() + const iconContainer = container.querySelector('.flex.items-center.justify-center') + expect(iconContainer).toBeInTheDocument() + }) + + it('should have correct icon dimensions', () => { + const { container } = render() + const iconContainer = container.querySelector('.h-3\\.5.w-3\\.5') + expect(iconContainer).toBeInTheDocument() + }) + + it('should have tertiary text color on icon', () => { + const { container } = render() + const icon = container.querySelector('.text-text-tertiary') + expect(icon).toBeInTheDocument() + }) + }) + + describe('text styling', () => { + it('should have system-xs-medium class', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('system-xs-medium') + }) + + it('should have horizontal padding', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('px-[3px]') + }) + + it('should have tertiary text color', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('text-text-tertiary') + }) + }) + + describe('modal props', () => { + it('should pass isShow prop to modal', async () => { + const user = userEvent.setup() + render() + + // Initially modal should not be visible + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + // Now modal should be visible + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + + it('should pass onClose callback to modal', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + + // Modal should be closed after clicking close + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + }) + + describe('accessibility', () => { + it('should have accessible button', () => { + render() + const button = screen.getByRole('button') + expect(button).toBeInTheDocument() + }) + + it('should be keyboard accessible', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + button.focus() + expect(document.activeElement).toBe(button) + + // Press Enter to activate + await act(async () => { + await user.keyboard('{Enter}') + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + }) + + describe('multiple instances', () => { + it('should work independently when multiple instances exist', async () => { + const user = userEvent.setup() + render( + <> + + + , + ) + + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(2) + + // Click first button + await act(async () => { + await user.click(buttons[0]) + }) + + expect(screen.getByText('Modal for app-1')).toBeInTheDocument() + + // Close first modal + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + + // Click second button + await act(async () => { + await user.click(buttons[1]) + }) + + expect(screen.getByText('Modal for app-2')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/secret-key/secret-key-generate.spec.tsx b/web/app/components/develop/secret-key/secret-key-generate.spec.tsx new file mode 100644 index 000000000..5988d6b7f --- /dev/null +++ b/web/app/components/develop/secret-key/secret-key-generate.spec.tsx @@ -0,0 +1,302 @@ +import type { CreateApiKeyResponse } from '@/models/app' +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SecretKeyGenerateModal from './secret-key-generate' + +// Helper to create a valid CreateApiKeyResponse +const createMockApiKey = (token: string): CreateApiKeyResponse => ({ + id: 'mock-id', + token, + created_at: '2024-01-01T00:00:00Z', +}) + +describe('SecretKeyGenerateModal', () => { + const defaultProps = { + isShow: true, + onClose: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering when shown', () => { + it('should render the modal when isShow is true', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should render the generate tips text', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.generateTips')).toBeInTheDocument() + }) + + it('should render the OK button', () => { + render() + expect(screen.getByText('appApi.actionMsg.ok')).toBeInTheDocument() + }) + + it('should render the close icon', () => { + render() + // Modal renders via portal, so query from document.body + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + }) + + it('should render InputCopy component', () => { + render() + expect(screen.getByText('test-token-123')).toBeInTheDocument() + }) + }) + + describe('rendering when hidden', () => { + it('should not render content when isShow is false', () => { + render() + expect(screen.queryByText('appApi.apiKeyModal.apiSecretKey')).not.toBeInTheDocument() + }) + }) + + describe('newKey prop', () => { + it('should display the token when newKey is provided', () => { + render() + expect(screen.getByText('sk-abc123xyz')).toBeInTheDocument() + }) + + it('should handle undefined newKey', () => { + render() + // Should not crash and modal should still render + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should handle newKey with empty token', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should display long tokens correctly', () => { + const longToken = `sk-${'a'.repeat(100)}` + render() + expect(screen.getByText(longToken)).toBeInTheDocument() + }) + }) + + describe('close functionality', () => { + it('should call onClose when X icon is clicked', async () => { + const user = userEvent.setup() + const onClose = vi.fn() + render() + + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + + await act(async () => { + await user.click(closeIcon!) + }) + + // HeadlessUI Dialog may trigger onClose multiple times (icon click handler + dialog close) + expect(onClose).toHaveBeenCalled() + }) + + it('should call onClose when OK button is clicked', async () => { + const user = userEvent.setup() + const onClose = vi.fn() + render() + + const okButton = screen.getByRole('button', { name: /ok/i }) + await act(async () => { + await user.click(okButton) + }) + + // HeadlessUI Dialog calls onClose both from button click and modal close + expect(onClose).toHaveBeenCalled() + }) + }) + + describe('className prop', () => { + it('should apply custom className', () => { + render( + , + ) + // Modal renders via portal + const modal = document.body.querySelector('.custom-modal-class') + expect(modal).toBeInTheDocument() + }) + + it('should apply shrink-0 class', () => { + render( + , + ) + // Modal renders via portal + const modal = document.body.querySelector('.shrink-0') + expect(modal).toBeInTheDocument() + }) + }) + + describe('modal styling', () => { + it('should have px-8 padding', () => { + render() + // Modal renders via portal + const modal = document.body.querySelector('.px-8') + expect(modal).toBeInTheDocument() + }) + }) + + describe('close icon styling', () => { + it('should have cursor-pointer class on close icon', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + }) + + it('should have correct dimensions on close icon', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg[class*="h-6"][class*="w-6"]') + expect(closeIcon).toBeInTheDocument() + }) + + it('should have tertiary text color on close icon', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg[class*="text-text-tertiary"]') + expect(closeIcon).toBeInTheDocument() + }) + }) + + describe('header section', () => { + it('should have flex justify-end on close container', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + const closeContainer = closeIcon?.parentElement + expect(closeContainer).toBeInTheDocument() + expect(closeContainer?.className).toContain('flex') + expect(closeContainer?.className).toContain('justify-end') + }) + + it('should have negative margin on close container', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + const closeContainer = closeIcon?.parentElement + expect(closeContainer).toBeInTheDocument() + expect(closeContainer?.className).toContain('-mr-2') + expect(closeContainer?.className).toContain('-mt-6') + }) + + it('should have bottom margin on close container', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + const closeContainer = closeIcon?.parentElement + expect(closeContainer).toBeInTheDocument() + expect(closeContainer?.className).toContain('mb-4') + }) + }) + + describe('tips text styling', () => { + it('should have mt-1 margin on tips', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('mt-1') + }) + + it('should have correct font size', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('text-[13px]') + }) + + it('should have normal font weight', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('font-normal') + }) + + it('should have leading-5 line height', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('leading-5') + }) + + it('should have tertiary text color', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('text-text-tertiary') + }) + }) + + describe('InputCopy section', () => { + it('should render InputCopy with token value', () => { + render() + expect(screen.getByText('test-token')).toBeInTheDocument() + }) + + it('should have w-full class on InputCopy', () => { + render() + // The InputCopy component should have w-full + const inputText = screen.getByText('test') + const inputContainer = inputText.closest('.w-full') + expect(inputContainer).toBeInTheDocument() + }) + }) + + describe('OK button section', () => { + it('should render OK button', () => { + render() + const button = screen.getByRole('button', { name: /ok/i }) + expect(button).toBeInTheDocument() + }) + + it('should have button container with flex layout', () => { + render() + const button = screen.getByRole('button', { name: /ok/i }) + const container = button.parentElement + expect(container).toBeInTheDocument() + expect(container?.className).toContain('flex') + }) + + it('should have shrink-0 on button', () => { + render() + const button = screen.getByRole('button', { name: /ok/i }) + expect(button.className).toContain('shrink-0') + }) + }) + + describe('button text styling', () => { + it('should have text-xs font size on button text', () => { + render() + const buttonText = screen.getByText('appApi.actionMsg.ok') + expect(buttonText.className).toContain('text-xs') + }) + + it('should have font-medium on button text', () => { + render() + const buttonText = screen.getByText('appApi.actionMsg.ok') + expect(buttonText.className).toContain('font-medium') + }) + + it('should have secondary text color on button text', () => { + render() + const buttonText = screen.getByText('appApi.actionMsg.ok') + expect(buttonText.className).toContain('text-text-secondary') + }) + }) + + describe('default prop values', () => { + it('should default isShow to false', () => { + // When isShow is explicitly set to false + render() + expect(screen.queryByText('appApi.apiKeyModal.apiSecretKey')).not.toBeInTheDocument() + }) + }) + + describe('modal title', () => { + it('should display the correct title', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/secret-key/secret-key-modal.spec.tsx b/web/app/components/develop/secret-key/secret-key-modal.spec.tsx new file mode 100644 index 000000000..79c51759e --- /dev/null +++ b/web/app/components/develop/secret-key/secret-key-modal.spec.tsx @@ -0,0 +1,614 @@ +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SecretKeyModal from './secret-key-modal' + +// Mock the app context +const mockCurrentWorkspace = vi.fn().mockReturnValue({ + id: 'workspace-1', + name: 'Test Workspace', +}) +const mockIsCurrentWorkspaceManager = vi.fn().mockReturnValue(true) +const mockIsCurrentWorkspaceEditor = vi.fn().mockReturnValue(true) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + currentWorkspace: mockCurrentWorkspace(), + isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager(), + isCurrentWorkspaceEditor: mockIsCurrentWorkspaceEditor(), + }), +})) + +// Mock the timestamp hook +vi.mock('@/hooks/use-timestamp', () => ({ + default: () => ({ + formatTime: vi.fn((value: number, _format: string) => `Formatted: ${value}`), + formatDate: vi.fn((value: string, _format: string) => `Formatted: ${value}`), + }), +})) + +// Mock API services +const mockCreateAppApikey = vi.fn().mockResolvedValue({ token: 'new-app-token-123' }) +const mockDelAppApikey = vi.fn().mockResolvedValue({}) +vi.mock('@/service/apps', () => ({ + createApikey: (...args: unknown[]) => mockCreateAppApikey(...args), + delApikey: (...args: unknown[]) => mockDelAppApikey(...args), +})) + +const mockCreateDatasetApikey = vi.fn().mockResolvedValue({ token: 'new-dataset-token-123' }) +const mockDelDatasetApikey = vi.fn().mockResolvedValue({}) +vi.mock('@/service/datasets', () => ({ + createApikey: (...args: unknown[]) => mockCreateDatasetApikey(...args), + delApikey: (...args: unknown[]) => mockDelDatasetApikey(...args), +})) + +// Mock React Query hooks for apps +const mockAppApiKeysData = vi.fn().mockReturnValue({ data: [] }) +const mockIsAppApiKeysLoading = vi.fn().mockReturnValue(false) +const mockInvalidateAppApiKeys = vi.fn() + +vi.mock('@/service/use-apps', () => ({ + useAppApiKeys: (_appId: string, _options: unknown) => ({ + data: mockAppApiKeysData(), + isLoading: mockIsAppApiKeysLoading(), + }), + useInvalidateAppApiKeys: () => mockInvalidateAppApiKeys, +})) + +// Mock React Query hooks for datasets +const mockDatasetApiKeysData = vi.fn().mockReturnValue({ data: [] }) +const mockIsDatasetApiKeysLoading = vi.fn().mockReturnValue(false) +const mockInvalidateDatasetApiKeys = vi.fn() + +vi.mock('@/service/knowledge/use-dataset', () => ({ + useDatasetApiKeys: (_options: unknown) => ({ + data: mockDatasetApiKeysData(), + isLoading: mockIsDatasetApiKeysLoading(), + }), + useInvalidateDatasetApiKeys: () => mockInvalidateDatasetApiKeys, +})) + +describe('SecretKeyModal', () => { + const defaultProps = { + isShow: true, + onClose: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + mockCurrentWorkspace.mockReturnValue({ id: 'workspace-1', name: 'Test Workspace' }) + mockIsCurrentWorkspaceManager.mockReturnValue(true) + mockIsCurrentWorkspaceEditor.mockReturnValue(true) + mockAppApiKeysData.mockReturnValue({ data: [] }) + mockIsAppApiKeysLoading.mockReturnValue(false) + mockDatasetApiKeysData.mockReturnValue({ data: [] }) + mockIsDatasetApiKeysLoading.mockReturnValue(false) + }) + + describe('rendering when shown', () => { + it('should render the modal when isShow is true', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should render the tips text', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKeyTips')).toBeInTheDocument() + }) + + it('should render the create new key button', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.createNewSecretKey')).toBeInTheDocument() + }) + + it('should render the close icon', () => { + render() + // Modal renders via portal, so we need to query from document.body + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + }) + }) + + describe('rendering when hidden', () => { + it('should not render content when isShow is false', () => { + render() + expect(screen.queryByText('appApi.apiKeyModal.apiSecretKey')).not.toBeInTheDocument() + }) + }) + + describe('loading state', () => { + it('should show loading when app API keys are loading', () => { + mockIsAppApiKeysLoading.mockReturnValue(true) + render() + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should show loading when dataset API keys are loading', () => { + mockIsDatasetApiKeysLoading.mockReturnValue(true) + render() + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should not show loading when data is loaded', () => { + mockIsAppApiKeysLoading.mockReturnValue(false) + render() + expect(screen.queryByRole('status')).not.toBeInTheDocument() + }) + }) + + describe('API keys list for app', () => { + const apiKeys = [ + { id: 'key-1', token: 'sk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + { id: 'key-2', token: 'sk-xyz987wvu654tsr321', created_at: 1700050000, last_used_at: null }, + ] + + beforeEach(() => { + mockAppApiKeysData.mockReturnValue({ data: apiKeys }) + }) + + it('should render API keys when available', () => { + render() + // Token 'sk-abc123def456ghi789' (21 chars) -> first 3 'sk-' + '...' + last 20 'k-abc123def456ghi789' + expect(screen.getByText('sk-...k-abc123def456ghi789')).toBeInTheDocument() + }) + + it('should render created time for keys', () => { + render() + expect(screen.getByText('Formatted: 1700000000')).toBeInTheDocument() + }) + + it('should render last used time for keys', () => { + render() + expect(screen.getByText('Formatted: 1700100000')).toBeInTheDocument() + }) + + it('should render "never" for keys without last_used_at', () => { + render() + expect(screen.getByText('appApi.never')).toBeInTheDocument() + }) + + it('should render delete button for managers', () => { + render() + // Delete button contains RiDeleteBinLine SVG - look for SVGs with h-4 w-4 class within buttons + const buttons = screen.getAllByRole('button') + // There should be at least 3 buttons: copy feedback, delete, and create + expect(buttons.length).toBeGreaterThanOrEqual(2) + // Check for delete icon SVG - Modal renders via portal + const deleteIcon = document.body.querySelector('svg[class*="h-4"][class*="w-4"]') + expect(deleteIcon).toBeInTheDocument() + }) + + it('should not render delete button for non-managers', () => { + mockIsCurrentWorkspaceManager.mockReturnValue(false) + render() + // The specific delete action button should not be present + const actionButtons = screen.getAllByRole('button') + // Should only have copy and create buttons, not delete + expect(actionButtons.length).toBeGreaterThan(0) + }) + + it('should render table headers', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.secretKey')).toBeInTheDocument() + expect(screen.getByText('appApi.apiKeyModal.created')).toBeInTheDocument() + expect(screen.getByText('appApi.apiKeyModal.lastUsed')).toBeInTheDocument() + }) + }) + + describe('API keys list for dataset', () => { + const datasetKeys = [ + { id: 'dk-1', token: 'dk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + ] + + beforeEach(() => { + mockDatasetApiKeysData.mockReturnValue({ data: datasetKeys }) + }) + + it('should render dataset API keys when no appId', () => { + render() + // Token 'dk-abc123def456ghi789' (21 chars) -> first 3 'dk-' + '...' + last 20 'k-abc123def456ghi789' + expect(screen.getByText('dk-...k-abc123def456ghi789')).toBeInTheDocument() + }) + }) + + describe('close functionality', () => { + it('should call onClose when X icon is clicked', async () => { + const user = userEvent.setup() + const onClose = vi.fn() + render() + + // Modal renders via portal, so we need to query from document.body + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + + await act(async () => { + await user.click(closeIcon!) + }) + + expect(onClose).toHaveBeenCalledTimes(1) + }) + }) + + describe('create new key', () => { + it('should call create API for app when button is clicked', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockCreateAppApikey).toHaveBeenCalledWith({ + url: '/apps/app-123/api-keys', + body: {}, + }) + }) + }) + + it('should call create API for dataset when no appId', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockCreateDatasetApikey).toHaveBeenCalledWith({ + url: '/datasets/api-keys', + body: {}, + }) + }) + }) + + it('should show generate modal after creating key', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + // The SecretKeyGenerateModal should be shown with the new token + expect(screen.getByText('appApi.apiKeyModal.generateTips')).toBeInTheDocument() + }) + }) + + it('should invalidate app API keys after creating', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockInvalidateAppApiKeys).toHaveBeenCalledWith('app-123') + }) + }) + + it('should invalidate dataset API keys after creating (no appId)', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockInvalidateDatasetApiKeys).toHaveBeenCalled() + }) + }) + + it('should disable create button when no workspace', () => { + mockCurrentWorkspace.mockReturnValue(null) + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey').closest('button') + expect(createButton).toBeDisabled() + }) + + it('should disable create button when not editor', () => { + mockIsCurrentWorkspaceEditor.mockReturnValue(false) + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey').closest('button') + expect(createButton).toBeDisabled() + }) + }) + + describe('delete key', () => { + const apiKeys = [ + { id: 'key-1', token: 'sk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + ] + + beforeEach(() => { + mockAppApiKeysData.mockReturnValue({ data: apiKeys }) + }) + + it('should render delete button for managers', () => { + render() + + // Find buttons that contain SVG (delete/copy buttons) + const actionButtons = screen.getAllByRole('button') + // There should be at least copy, delete, and create buttons + expect(actionButtons.length).toBeGreaterThanOrEqual(3) + }) + + it('should render API key row with actions', () => { + render() + + // Verify the truncated token is rendered + expect(screen.getByText('sk-...k-abc123def456ghi789')).toBeInTheDocument() + }) + + it('should have action buttons in the key row', () => { + render() + + // Check for action button containers - Modal renders via portal + const actionContainers = document.body.querySelectorAll('[class*="space-x-2"]') + expect(actionContainers.length).toBeGreaterThan(0) + }) + + it('should have delete button visible for managers', async () => { + render() + + // Find the delete button by looking for the button with the delete icon + const deleteIcon = document.body.querySelector('svg[class*="h-4"][class*="w-4"]') + const deleteButton = deleteIcon?.closest('button') + expect(deleteButton).toBeInTheDocument() + }) + + it('should show confirm dialog when delete button is clicked', async () => { + const user = userEvent.setup() + render() + + // Find delete button by action-btn class (second action button after copy) + const actionButtons = document.body.querySelectorAll('button.action-btn') + // The delete button is the second action button (first is copy) + const deleteButton = actionButtons[1] + expect(deleteButton).toBeInTheDocument() + + await act(async () => { + await user.click(deleteButton!) + }) + + // Confirm dialog should appear + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('appApi.actionMsg.deleteConfirmTips')).toBeInTheDocument() + }) + }) + + it('should call delete API for app when confirmed', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + // Find and click the confirm button + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockDelAppApikey).toHaveBeenCalledWith({ + url: '/apps/app-123/api-keys/key-1', + params: {}, + }) + }) + }) + + it('should invalidate app API keys after deleting', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockInvalidateAppApiKeys).toHaveBeenCalledWith('app-123') + }) + }) + + it('should close confirm dialog and clear delKeyId when cancel is clicked', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + // Click cancel button + const cancelButton = screen.getByText('common.operation.cancel') + await act(async () => { + await user.click(cancelButton) + }) + + // Confirm dialog should close + await waitFor(() => { + expect(screen.queryByText('appApi.actionMsg.deleteConfirmTitle')).not.toBeInTheDocument() + }) + + // Delete API should not be called + expect(mockDelAppApikey).not.toHaveBeenCalled() + }) + }) + + describe('delete key for dataset', () => { + const datasetKeys = [ + { id: 'dk-1', token: 'dk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + ] + + beforeEach(() => { + mockDatasetApiKeysData.mockReturnValue({ data: datasetKeys }) + }) + + it('should call delete API for dataset when no appId', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockDelDatasetApikey).toHaveBeenCalledWith({ + url: '/datasets/api-keys/dk-1', + params: {}, + }) + }) + }) + + it('should invalidate dataset API keys after deleting', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockInvalidateDatasetApiKeys).toHaveBeenCalled() + }) + }) + }) + + describe('token truncation', () => { + it('should truncate token correctly', () => { + const apiKeys = [ + { id: 'key-1', token: 'sk-abcdefghijklmnopqrstuvwxyz1234567890', created_at: 1700000000, last_used_at: null }, + ] + mockAppApiKeysData.mockReturnValue({ data: apiKeys }) + + render() + + // Token format: first 3 chars + ... + last 20 chars + // 'sk-abcdefghijklmnopqrstuvwxyz1234567890' -> 'sk-...qrstuvwxyz1234567890' + expect(screen.getByText('sk-...qrstuvwxyz1234567890')).toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should render modal with expected structure', () => { + render() + // Modal should render and contain the title + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should render create button with flex styling', () => { + render() + // Modal renders via portal, so query from document.body + const flexContainers = document.body.querySelectorAll('[class*="flex"]') + expect(flexContainers.length).toBeGreaterThan(0) + }) + }) + + describe('empty state', () => { + it('should not render table when no keys', () => { + mockAppApiKeysData.mockReturnValue({ data: [] }) + render() + + expect(screen.queryByText('appApi.apiKeyModal.secretKey')).not.toBeInTheDocument() + }) + + it('should not render table when data is null', () => { + mockAppApiKeysData.mockReturnValue(null) + render() + + expect(screen.queryByText('appApi.apiKeyModal.secretKey')).not.toBeInTheDocument() + }) + }) + + describe('SecretKeyGenerateModal', () => { + it('should close generate modal on close', async () => { + const user = userEvent.setup() + render() + + // Create a new key to open generate modal + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(screen.getByText('appApi.apiKeyModal.generateTips')).toBeInTheDocument() + }) + + // Find and click the close/OK button in generate modal + const okButton = screen.getByText('appApi.actionMsg.ok') + await act(async () => { + await user.click(okButton) + }) + + await waitFor(() => { + expect(screen.queryByText('appApi.apiKeyModal.generateTips')).not.toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/develop/tag.spec.tsx b/web/app/components/develop/tag.spec.tsx new file mode 100644 index 000000000..60a12040f --- /dev/null +++ b/web/app/components/develop/tag.spec.tsx @@ -0,0 +1,242 @@ +import { render, screen } from '@testing-library/react' +import { Tag } from './tag' + +describe('Tag', () => { + describe('rendering', () => { + it('should render children text', () => { + render(GET) + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render as a span element', () => { + render(POST) + const tag = screen.getByText('POST') + expect(tag.tagName).toBe('SPAN') + }) + }) + + describe('default color mapping based on HTTP methods', () => { + it('should apply emerald color for GET method', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('text-emerald') + }) + + it('should apply sky color for POST method', () => { + render(POST) + const tag = screen.getByText('POST') + expect(tag.className).toContain('text-sky') + }) + + it('should apply amber color for PUT method', () => { + render(PUT) + const tag = screen.getByText('PUT') + expect(tag.className).toContain('text-amber') + }) + + it('should apply rose color for DELETE method', () => { + render(DELETE) + const tag = screen.getByText('DELETE') + expect(tag.className).toContain('text-red') + }) + + it('should apply emerald color for unknown methods', () => { + render(UNKNOWN) + const tag = screen.getByText('UNKNOWN') + expect(tag.className).toContain('text-emerald') + }) + + it('should handle lowercase method names', () => { + render(get) + const tag = screen.getByText('get') + expect(tag.className).toContain('text-emerald') + }) + + it('should handle mixed case method names', () => { + render(Post) + const tag = screen.getByText('Post') + expect(tag.className).toContain('text-sky') + }) + }) + + describe('custom color prop', () => { + it('should override default color with custom emerald color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-emerald') + }) + + it('should override default color with custom sky color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-sky') + }) + + it('should override default color with custom amber color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-amber') + }) + + it('should override default color with custom rose color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-red') + }) + + it('should override default color with custom zinc color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-zinc') + }) + + it('should override automatic color mapping with explicit color', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('text-sky') + }) + }) + + describe('variant styles', () => { + it('should apply medium variant styles by default', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('rounded-lg') + expect(tag.className).toContain('px-1.5') + expect(tag.className).toContain('ring-1') + expect(tag.className).toContain('ring-inset') + }) + + it('should apply small variant styles', () => { + render(GET) + const tag = screen.getByText('GET') + // Small variant should not have ring styles + expect(tag.className).not.toContain('rounded-lg') + expect(tag.className).not.toContain('ring-1') + }) + }) + + describe('base styles', () => { + it('should always have font-mono class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('font-mono') + }) + + it('should always have correct font-size class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('text-[0.625rem]') + }) + + it('should always have font-semibold class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('font-semibold') + }) + + it('should always have leading-6 class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('leading-6') + }) + }) + + describe('color styles for medium variant', () => { + it('should apply full emerald medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-emerald-300') + expect(tag.className).toContain('bg-emerald-400/10') + expect(tag.className).toContain('text-emerald-500') + }) + + it('should apply full sky medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-sky-300') + expect(tag.className).toContain('bg-sky-400/10') + expect(tag.className).toContain('text-sky-500') + }) + + it('should apply full amber medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-amber-300') + expect(tag.className).toContain('bg-amber-400/10') + expect(tag.className).toContain('text-amber-500') + }) + + it('should apply full rose medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-rose-200') + expect(tag.className).toContain('bg-rose-50') + expect(tag.className).toContain('text-red-500') + }) + + it('should apply full zinc medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-zinc-200') + expect(tag.className).toContain('bg-zinc-50') + expect(tag.className).toContain('text-zinc-500') + }) + }) + + describe('color styles for small variant', () => { + it('should apply emerald small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-emerald-500') + // Small variant should not have background/ring styles + expect(tag.className).not.toContain('bg-emerald-400/10') + expect(tag.className).not.toContain('ring-emerald-300') + }) + + it('should apply sky small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-sky-500') + }) + + it('should apply amber small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-amber-500') + }) + + it('should apply rose small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-red-500') + }) + + it('should apply zinc small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-zinc-400') + }) + }) + + describe('HTTP method color combinations', () => { + it('should correctly map PATCH to emerald (default)', () => { + render(PATCH) + const tag = screen.getByText('PATCH') + // PATCH is not in the valueColorMap, so it defaults to emerald + expect(tag.className).toContain('text-emerald') + }) + + it('should correctly render all standard HTTP methods', () => { + const methods = ['GET', 'POST', 'PUT', 'DELETE'] + const expectedColors = ['emerald', 'sky', 'amber', 'red'] + + methods.forEach((method, index) => { + const { unmount } = render({method}) + const tag = screen.getByText(method) + expect(tag.className).toContain(`text-${expectedColors[index]}`) + unmount() + }) + }) + }) +}) diff --git a/web/app/components/explore/app-card/index.tsx b/web/app/components/explore/app-card/index.tsx index 8693d1ed7..0c6c6f1f0 100644 --- a/web/app/components/explore/app-card/index.tsx +++ b/web/app/components/explore/app-card/index.tsx @@ -112,17 +112,19 @@ const AppCard = ({ {isExplore && (canCreate || isTrialApp) && ( )} diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index 3738c7489..151190b11 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -279,6 +279,7 @@ const Apps = ({ {isShowTryAppPanel && ( ({ + useCarousel: () => ({ + api: { + scrollTo: mockScrollTo, + slideNodes: mockSlideNodes, + }, + selectedIndex: 0, + }), +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'banner.viewMore': 'View More', + } + return translations[key] || key + }, + }), +})) + +const createMockBanner = (overrides: Partial = {}): Banner => ({ + id: 'banner-1', + status: 'enabled', + link: 'https://example.com', + content: { + 'category': 'Featured', + 'title': 'Test Banner Title', + 'description': 'Test banner description text', + 'img-src': 'https://example.com/image.png', + }, + ...overrides, +} as Banner) + +// Mock ResizeObserver methods declared at module level and initialized +const mockResizeObserverObserve = vi.fn() +const mockResizeObserverDisconnect = vi.fn() + +// Create mock class outside of describe block for proper hoisting +class MockResizeObserver { + constructor(_callback: ResizeObserverCallback) { + // Store callback if needed + } + + observe(...args: Parameters) { + mockResizeObserverObserve(...args) + } + + disconnect() { + mockResizeObserverDisconnect() + } + + unobserve() { + // No-op + } +} + +describe('BannerItem', () => { + let mockWindowOpen: ReturnType + + beforeEach(() => { + mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null) + mockSlideNodes.mockReturnValue([{}, {}, {}]) // 3 slides + + vi.stubGlobal('ResizeObserver', MockResizeObserver) + + // Mock window.innerWidth for responsive tests + Object.defineProperty(window, 'innerWidth', { + writable: true, + configurable: true, + value: 1400, // Above RESPONSIVE_BREAKPOINT (1200) + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + vi.unstubAllGlobals() + mockWindowOpen.mockRestore() + }) + + describe('basic rendering', () => { + it('renders banner category', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('Featured')).toBeInTheDocument() + }) + + it('renders banner title', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + + it('renders banner description', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('Test banner description text')).toBeInTheDocument() + }) + + it('renders banner image with correct src and alt', () => { + const banner = createMockBanner() + render( + , + ) + + const image = screen.getByRole('img') + expect(image).toHaveAttribute('src', 'https://example.com/image.png') + expect(image).toHaveAttribute('alt', 'Test Banner Title') + }) + + it('renders view more text', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('View More')).toBeInTheDocument() + }) + }) + + describe('click handling', () => { + it('opens banner link in new tab when clicked', () => { + const banner = createMockBanner({ link: 'https://test-link.com' }) + render( + , + ) + + const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') + fireEvent.click(bannerElement!) + + expect(mockWindowOpen).toHaveBeenCalledWith( + 'https://test-link.com', + '_blank', + 'noopener,noreferrer', + ) + }) + + it('does not open window when banner has no link', () => { + const banner = createMockBanner({ link: '' }) + render( + , + ) + + const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') + fireEvent.click(bannerElement!) + + expect(mockWindowOpen).not.toHaveBeenCalled() + }) + }) + + describe('slide indicators', () => { + it('renders correct number of indicator buttons', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(3) + }) + + it('renders indicator buttons with correct numbers', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('01')).toBeInTheDocument() + expect(screen.getByText('02')).toBeInTheDocument() + expect(screen.getByText('03')).toBeInTheDocument() + }) + + it('calls scrollTo when indicator is clicked', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + const secondIndicator = screen.getByText('02').closest('button') + fireEvent.click(secondIndicator!) + + expect(mockScrollTo).toHaveBeenCalledWith(1) + }) + + it('renders no indicators when no slides', () => { + mockSlideNodes.mockReturnValue([]) + const banner = createMockBanner() + render( + , + ) + + expect(screen.queryByRole('button')).not.toBeInTheDocument() + }) + }) + + describe('isPaused prop', () => { + it('defaults isPaused to false', () => { + const banner = createMockBanner() + render( + , + ) + + // Component should render without issues + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + + it('accepts isPaused prop', () => { + const banner = createMockBanner() + render( + , + ) + + // Component should render with isPaused + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + }) + + describe('responsive behavior', () => { + it('sets up ResizeObserver on mount', () => { + const banner = createMockBanner() + render( + , + ) + + expect(mockResizeObserverObserve).toHaveBeenCalled() + }) + + it('adds resize event listener on mount', () => { + const addEventListenerSpy = vi.spyOn(window, 'addEventListener') + const banner = createMockBanner() + render( + , + ) + + expect(addEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) + addEventListenerSpy.mockRestore() + }) + + it('removes resize event listener on unmount', () => { + const removeEventListenerSpy = vi.spyOn(window, 'removeEventListener') + const banner = createMockBanner() + const { unmount } = render( + , + ) + + unmount() + + expect(removeEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) + removeEventListenerSpy.mockRestore() + }) + + it('sets maxWidth when window width is below breakpoint', () => { + // Set window width below RESPONSIVE_BREAKPOINT (1200) + Object.defineProperty(window, 'innerWidth', { + writable: true, + configurable: true, + value: 1000, + }) + + const banner = createMockBanner() + render( + , + ) + + // Component should render and apply responsive styles + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + + it('applies responsive styles when below breakpoint', () => { + // Set window width below RESPONSIVE_BREAKPOINT (1200) + Object.defineProperty(window, 'innerWidth', { + writable: true, + configurable: true, + value: 800, + }) + + const banner = createMockBanner() + render( + , + ) + + // The component should render even with responsive mode + expect(screen.getByText('View More')).toBeInTheDocument() + }) + }) + + describe('content variations', () => { + it('renders long category text', () => { + const banner = createMockBanner({ + content: { + 'category': 'Very Long Category Name', + 'title': 'Title', + 'description': 'Description', + 'img-src': 'https://example.com/img.png', + }, + } as Partial) + render( + , + ) + + expect(screen.getByText('Very Long Category Name')).toBeInTheDocument() + }) + + it('renders long title with truncation class', () => { + const banner = createMockBanner({ + content: { + 'category': 'Category', + 'title': 'A Very Long Title That Should Be Truncated Eventually', + 'description': 'Description', + 'img-src': 'https://example.com/img.png', + }, + } as Partial) + render( + , + ) + + const titleElement = screen.getByText('A Very Long Title That Should Be Truncated Eventually') + expect(titleElement).toHaveClass('line-clamp-2') + }) + + it('renders long description with truncation class', () => { + const banner = createMockBanner({ + content: { + 'category': 'Category', + 'title': 'Title', + 'description': 'A very long description that should be limited to a certain number of lines for proper display in the banner component.', + 'img-src': 'https://example.com/img.png', + }, + } as Partial) + render( + , + ) + + const descriptionElement = screen.getByText(/A very long description/) + expect(descriptionElement).toHaveClass('line-clamp-4') + }) + }) + + describe('slide calculation', () => { + it('calculates next index correctly for first slide', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + // With selectedIndex=0 and 3 slides, nextIndex should be 1 + // The second indicator button should show the "next slide" state + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(3) + }) + + it('handles single slide case', () => { + mockSlideNodes.mockReturnValue([{}]) + const banner = createMockBanner() + render( + , + ) + + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(1) + }) + }) + + describe('wrapper styling', () => { + it('has cursor-pointer class', () => { + const banner = createMockBanner() + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('cursor-pointer') + }) + + it('has rounded-2xl class', () => { + const banner = createMockBanner() + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('rounded-2xl') + }) + }) +}) diff --git a/web/app/components/explore/banner/banner.spec.tsx b/web/app/components/explore/banner/banner.spec.tsx new file mode 100644 index 000000000..de719c393 --- /dev/null +++ b/web/app/components/explore/banner/banner.spec.tsx @@ -0,0 +1,472 @@ +import type * as React from 'react' +import type { Banner as BannerType } from '@/models/app' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import Banner from './banner' + +const mockUseGetBanners = vi.fn() + +vi.mock('@/service/use-explore', () => ({ + useGetBanners: (...args: unknown[]) => mockUseGetBanners(...args), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/app/components/base/carousel', () => ({ + Carousel: Object.assign( + ({ children, onMouseEnter, onMouseLeave, className }: { + children: React.ReactNode + onMouseEnter?: () => void + onMouseLeave?: () => void + className?: string + }) => ( +
    + {children} +
    + ), + { + Content: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), + Item: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), + Plugin: { + Autoplay: (config: Record) => ({ type: 'autoplay', ...config }), + }, + }, + ), + useCarousel: () => ({ + api: { + scrollTo: vi.fn(), + slideNodes: () => [], + }, + selectedIndex: 0, + }), +})) + +vi.mock('./banner-item', () => ({ + BannerItem: ({ banner, autoplayDelay, isPaused }: { + banner: BannerType + autoplayDelay: number + isPaused?: boolean + }) => ( +
    + BannerItem: + {' '} + {banner.content.title} +
    + ), +})) + +const createMockBanner = (id: string, status: string = 'enabled', title: string = 'Test Banner'): BannerType => ({ + id, + status, + link: 'https://example.com', + content: { + 'category': 'Featured', + title, + 'description': 'Test description', + 'img-src': 'https://example.com/image.png', + }, +} as BannerType) + +describe('Banner', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + vi.useRealTimers() + }) + + describe('loading state', () => { + it('renders loading state when isLoading is true', () => { + mockUseGetBanners.mockReturnValue({ + data: null, + isLoading: true, + isError: false, + }) + + render() + + // Loading component renders a spinner + const loadingWrapper = document.querySelector('[style*="min-height"]') + expect(loadingWrapper).toBeInTheDocument() + }) + + it('shows loading indicator with correct minimum height', () => { + mockUseGetBanners.mockReturnValue({ + data: null, + isLoading: true, + isError: false, + }) + + render() + + const loadingWrapper = document.querySelector('[style*="min-height: 168px"]') + expect(loadingWrapper).toBeInTheDocument() + }) + }) + + describe('error state', () => { + it('returns null when isError is true', () => { + mockUseGetBanners.mockReturnValue({ + data: null, + isLoading: false, + isError: true, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + }) + + describe('empty state', () => { + it('returns null when banners array is empty', () => { + mockUseGetBanners.mockReturnValue({ + data: [], + isLoading: false, + isError: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('returns null when all banners are disabled', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'disabled'), + createMockBanner('2', 'disabled'), + ], + isLoading: false, + isError: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('returns null when data is undefined', () => { + mockUseGetBanners.mockReturnValue({ + data: undefined, + isLoading: false, + isError: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + }) + + describe('successful render', () => { + it('renders carousel when enabled banners exist', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + expect(screen.getByTestId('carousel')).toBeInTheDocument() + }) + + it('renders only enabled banners', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'Enabled Banner 1'), + createMockBanner('2', 'disabled', 'Disabled Banner'), + createMockBanner('3', 'enabled', 'Enabled Banner 2'), + ], + isLoading: false, + isError: false, + }) + + render() + + const bannerItems = screen.getAllByTestId('banner-item') + expect(bannerItems).toHaveLength(2) + expect(screen.getByText('BannerItem: Enabled Banner 1')).toBeInTheDocument() + expect(screen.getByText('BannerItem: Enabled Banner 2')).toBeInTheDocument() + expect(screen.queryByText('BannerItem: Disabled Banner')).not.toBeInTheDocument() + }) + + it('passes correct autoplayDelay to BannerItem', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-autoplay-delay', '5000') + }) + + it('renders carousel with correct class', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + expect(screen.getByTestId('carousel')).toHaveClass('rounded-2xl') + }) + }) + + describe('hover behavior', () => { + it('sets isPaused to true on mouse enter', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + const carousel = screen.getByTestId('carousel') + fireEvent.mouseEnter(carousel) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'true') + }) + + it('sets isPaused to false on mouse leave', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + const carousel = screen.getByTestId('carousel') + + // Enter and then leave + fireEvent.mouseEnter(carousel) + fireEvent.mouseLeave(carousel) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'false') + }) + }) + + describe('resize behavior', () => { + it('pauses animation during resize', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + // Trigger resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'true') + }) + + it('resumes animation after resize debounce delay', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + // Trigger resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + // Wait for debounce delay (50ms) + act(() => { + vi.advanceTimersByTime(50) + }) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'false') + }) + + it('resets debounce timer on multiple resize events', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + // Trigger first resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + // Wait partial time + act(() => { + vi.advanceTimersByTime(30) + }) + + // Trigger second resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + // Wait another 30ms (total 60ms from second resize but only 30ms after) + act(() => { + vi.advanceTimersByTime(30) + }) + + // Should still be paused (debounce resets) + let bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'true') + + // Wait remaining time + act(() => { + vi.advanceTimersByTime(20) + }) + + bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'false') + }) + }) + + describe('cleanup', () => { + it('removes resize event listener on unmount', () => { + const removeEventListenerSpy = vi.spyOn(window, 'removeEventListener') + + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + const { unmount } = render() + unmount() + + expect(removeEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) + removeEventListenerSpy.mockRestore() + }) + + it('clears resize timer on unmount', () => { + const clearTimeoutSpy = vi.spyOn(globalThis, 'clearTimeout') + + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + const { unmount } = render() + + // Trigger resize to create timer + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + unmount() + + expect(clearTimeoutSpy).toHaveBeenCalled() + clearTimeoutSpy.mockRestore() + }) + }) + + describe('hook calls', () => { + it('calls useGetBanners with correct locale', () => { + mockUseGetBanners.mockReturnValue({ + data: [], + isLoading: false, + isError: false, + }) + + render() + + expect(mockUseGetBanners).toHaveBeenCalledWith('en-US') + }) + }) + + describe('multiple banners', () => { + it('renders all enabled banners in carousel items', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'Banner 1'), + createMockBanner('2', 'enabled', 'Banner 2'), + createMockBanner('3', 'enabled', 'Banner 3'), + ], + isLoading: false, + isError: false, + }) + + render() + + const carouselItems = screen.getAllByTestId('carousel-item') + expect(carouselItems).toHaveLength(3) + }) + + it('preserves banner order', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'First Banner'), + createMockBanner('2', 'enabled', 'Second Banner'), + createMockBanner('3', 'enabled', 'Third Banner'), + ], + isLoading: false, + isError: false, + }) + + render() + + const bannerItems = screen.getAllByTestId('banner-item') + expect(bannerItems[0]).toHaveAttribute('data-banner-id', '1') + expect(bannerItems[1]).toHaveAttribute('data-banner-id', '2') + expect(bannerItems[2]).toHaveAttribute('data-banner-id', '3') + }) + }) + + describe('React.memo behavior', () => { + it('renders as memoized component', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + const { rerender } = render() + + // Re-render with same props + rerender() + + // Component should still be present (memo doesn't break rendering) + expect(screen.getByTestId('carousel')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/explore/banner/indicator-button.spec.tsx b/web/app/components/explore/banner/indicator-button.spec.tsx new file mode 100644 index 000000000..545f4e2f9 --- /dev/null +++ b/web/app/components/explore/banner/indicator-button.spec.tsx @@ -0,0 +1,448 @@ +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { IndicatorButton } from './indicator-button' + +describe('IndicatorButton', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + vi.useRealTimers() + }) + + describe('basic rendering', () => { + it('renders button with correct index number', () => { + const mockOnClick = vi.fn() + render( + , + ) + + expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText('01')).toBeInTheDocument() + }) + + it('renders two-digit index numbers', () => { + const mockOnClick = vi.fn() + render( + , + ) + + expect(screen.getByText('10')).toBeInTheDocument() + }) + + it('pads single digit index numbers with leading zero', () => { + const mockOnClick = vi.fn() + render( + , + ) + + expect(screen.getByText('05')).toBeInTheDocument() + }) + }) + + describe('active state', () => { + it('applies active styles when index equals selectedIndex', () => { + const mockOnClick = vi.fn() + render( + , + ) + + const button = screen.getByRole('button') + expect(button).toHaveClass('bg-text-primary') + }) + + it('applies inactive styles when index does not equal selectedIndex', () => { + const mockOnClick = vi.fn() + render( + , + ) + + const button = screen.getByRole('button') + expect(button).toHaveClass('bg-components-panel-on-panel-item-bg') + }) + }) + + describe('click handling', () => { + it('calls onClick when button is clicked', () => { + const mockOnClick = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByRole('button')) + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('stops event propagation when clicked', () => { + const mockOnClick = vi.fn() + const mockParentClick = vi.fn() + + render( +
    + +
    , + ) + + fireEvent.click(screen.getByRole('button')) + expect(mockOnClick).toHaveBeenCalledTimes(1) + expect(mockParentClick).not.toHaveBeenCalled() + }) + }) + + describe('progress indicator', () => { + it('does not show progress indicator when not next slide', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + // Check for conic-gradient style which indicates progress indicator + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).not.toBeInTheDocument() + }) + + it('shows progress indicator when isNextSlide is true and not active', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).toBeInTheDocument() + }) + + it('does not show progress indicator when isNextSlide but also active', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).not.toBeInTheDocument() + }) + }) + + describe('animation behavior', () => { + it('starts progress from 0 when isNextSlide becomes true', () => { + const mockOnClick = vi.fn() + const { container, rerender } = render( + , + ) + + // Initially no progress indicator + expect(container.querySelector('[style*="conic-gradient"]')).not.toBeInTheDocument() + + // Rerender with isNextSlide=true + rerender( + , + ) + + // Now progress indicator should be visible + expect(container.querySelector('[style*="conic-gradient"]')).toBeInTheDocument() + }) + + it('resets progress when resetKey changes', () => { + const mockOnClick = vi.fn() + const { rerender, container } = render( + , + ) + + // Progress indicator should be present + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).toBeInTheDocument() + + // Rerender with new resetKey - this should reset the progress animation + rerender( + , + ) + + const newProgressIndicator = container.querySelector('[style*="conic-gradient"]') + // The progress indicator should still be present after reset + expect(newProgressIndicator).toBeInTheDocument() + }) + + it('stops animation when isPaused is true', () => { + const mockOnClick = vi.fn() + const mockRequestAnimationFrame = vi.spyOn(window, 'requestAnimationFrame') + + render( + , + ) + + // The component should still render but animation should be paused + // requestAnimationFrame might still be called for polling but progress won't update + expect(screen.getByRole('button')).toBeInTheDocument() + mockRequestAnimationFrame.mockRestore() + }) + + it('cancels animation frame on unmount', () => { + const mockOnClick = vi.fn() + const mockCancelAnimationFrame = vi.spyOn(window, 'cancelAnimationFrame') + + const { unmount } = render( + , + ) + + // Trigger animation frame + act(() => { + vi.advanceTimersToNextTimer() + }) + + unmount() + + expect(mockCancelAnimationFrame).toHaveBeenCalled() + mockCancelAnimationFrame.mockRestore() + }) + + it('cancels animation frame when isNextSlide becomes false', () => { + const mockOnClick = vi.fn() + const mockCancelAnimationFrame = vi.spyOn(window, 'cancelAnimationFrame') + + const { rerender } = render( + , + ) + + // Trigger animation frame + act(() => { + vi.advanceTimersToNextTimer() + }) + + // Change isNextSlide to false - this should cancel the animation frame + rerender( + , + ) + + expect(mockCancelAnimationFrame).toHaveBeenCalled() + mockCancelAnimationFrame.mockRestore() + }) + + it('continues polling when document is hidden', () => { + const mockOnClick = vi.fn() + const mockRequestAnimationFrame = vi.spyOn(window, 'requestAnimationFrame') + + // Mock document.hidden to be true + Object.defineProperty(document, 'hidden', { + writable: true, + configurable: true, + value: true, + }) + + render( + , + ) + + // Component should still render + expect(screen.getByRole('button')).toBeInTheDocument() + + // Reset document.hidden + Object.defineProperty(document, 'hidden', { + writable: true, + configurable: true, + value: false, + }) + + mockRequestAnimationFrame.mockRestore() + }) + }) + + describe('isPaused prop default', () => { + it('defaults isPaused to false when not provided', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + // Progress indicator should be visible (animation running) + expect(container.querySelector('[style*="conic-gradient"]')).toBeInTheDocument() + }) + }) + + describe('button styling', () => { + it('has correct base classes', () => { + const mockOnClick = vi.fn() + render( + , + ) + + const button = screen.getByRole('button') + expect(button).toHaveClass('relative') + expect(button).toHaveClass('flex') + expect(button).toHaveClass('items-center') + expect(button).toHaveClass('justify-center') + expect(button).toHaveClass('rounded-[7px]') + expect(button).toHaveClass('border') + expect(button).toHaveClass('transition-colors') + }) + }) +}) diff --git a/web/app/components/explore/create-app-modal/index.spec.tsx b/web/app/components/explore/create-app-modal/index.spec.tsx index 7ddb5a908..65ec0e609 100644 --- a/web/app/components/explore/create-app-modal/index.spec.tsx +++ b/web/app/components/explore/create-app-modal/index.spec.tsx @@ -138,7 +138,7 @@ describe('CreateAppModal', () => { setup({ appName: 'My App', isEditModal: false }) expect(screen.getByText('explore.appCustomize.title:{"name":"My App"}')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeInTheDocument() expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument() }) @@ -146,7 +146,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: true, appMode: AppModeEnum.CHAT, max_active_requests: 5 }) expect(screen.getByText('app.editAppTitle')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeInTheDocument() expect(screen.getByRole('switch')).toBeInTheDocument() expect((screen.getByRole('spinbutton') as HTMLInputElement).value).toBe('5') }) @@ -166,7 +166,7 @@ describe('CreateAppModal', () => { it('should not render modal content when hidden', () => { setup({ show: false }) - expect(screen.queryByRole('button', { name: 'common.operation.create' })).not.toBeInTheDocument() + expect(screen.queryByRole('button', { name: /common\.operation\.create/ })).not.toBeInTheDocument() }) }) @@ -175,13 +175,13 @@ describe('CreateAppModal', () => { it('should disable confirm action when confirmDisabled is true', () => { setup({ confirmDisabled: true }) - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) it('should disable confirm action when appName is empty', () => { setup({ appName: ' ' }) - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) }) @@ -245,7 +245,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: false }) expect(screen.getByText('billing.apps.fullTip2')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) it('should allow saving when apps quota is reached in edit mode', () => { @@ -257,7 +257,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: true }) expect(screen.queryByText('billing.apps.fullTip2')).not.toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeEnabled() + expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeEnabled() }) }) @@ -384,7 +384,7 @@ describe('CreateAppModal', () => { fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' })) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -433,7 +433,7 @@ describe('CreateAppModal', () => { expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument() // Submit and verify the payload uses the original icon (cancel reverts to props) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -471,7 +471,7 @@ describe('CreateAppModal', () => { appIconBackground: '#000000', }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -495,7 +495,7 @@ describe('CreateAppModal', () => { const { onConfirm } = setup({ appDescription: 'Old description' }) fireEvent.change(screen.getByPlaceholderText('app.newApp.appDescriptionPlaceholder'), { target: { value: 'Updated description' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -512,7 +512,7 @@ describe('CreateAppModal', () => { appIconBackground: null, }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -536,7 +536,7 @@ describe('CreateAppModal', () => { fireEvent.click(screen.getByRole('switch')) fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -551,7 +551,7 @@ describe('CreateAppModal', () => { it('should omit max_active_requests when input is empty', () => { const { onConfirm } = setup({ isEditModal: true, max_active_requests: null }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -564,7 +564,7 @@ describe('CreateAppModal', () => { const { onConfirm } = setup({ isEditModal: true, max_active_requests: null }) fireEvent.change(screen.getByRole('spinbutton'), { target: { value: 'abc' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -576,7 +576,7 @@ describe('CreateAppModal', () => { it('should show toast error and not submit when name becomes empty before debounced submit runs', () => { const { onConfirm, onHide } = setup({ appName: 'My App' }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) fireEvent.change(screen.getByPlaceholderText('app.newApp.appNamePlaceholder'), { target: { value: ' ' } }) act(() => { diff --git a/web/app/components/explore/create-app-modal/index.tsx b/web/app/components/explore/create-app-modal/index.tsx index 9bffcc6c6..cfe59fb7f 100644 --- a/web/app/components/explore/create-app-modal/index.tsx +++ b/web/app/components/explore/create-app-modal/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { AppIconType } from '@/types/app' -import { RiCloseLine, RiCommandLine, RiCornerDownLeftLine } from '@remixicon/react' +import { RiCloseLine } from '@remixicon/react' import { useDebounceFn, useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' import * as React from 'react' @@ -17,6 +17,7 @@ import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { useProviderContext } from '@/context/provider-context' import { AppModeEnum } from '@/types/app' import AppIconPicker from '../../base/app-icon-picker' +import ShortcutsName from '../../workflow/shortcuts-name' export type CreateAppModalProps = { show: boolean @@ -198,10 +199,7 @@ const CreateAppModal = ({ onClick={handleSubmit} > {!isEditModal ? t('operation.create', { ns: 'common' }) : t('operation.save', { ns: 'common' })} -
    - - -
    + diff --git a/web/app/components/explore/sidebar/no-apps/index.tsx b/web/app/components/explore/sidebar/no-apps/index.tsx index 39b425ce3..f2f406008 100644 --- a/web/app/components/explore/sidebar/no-apps/index.tsx +++ b/web/app/components/explore/sidebar/no-apps/index.tsx @@ -2,6 +2,7 @@ import type { FC } from 'react' import * as React from 'react' import { useTranslation } from 'react-i18next' +import { useDocLink } from '@/context/i18n' import useTheme from '@/hooks/use-theme' import { Theme } from '@/types/app' import { cn } from '@/utils/classnames' @@ -12,12 +13,13 @@ const i18nPrefix = 'sidebar.noApps' const NoApps: FC = () => { const { t } = useTranslation() const { theme } = useTheme() + const docLink = useDocLink() return (
    {t(`${i18nPrefix}.title`, { ns: 'explore' })}
    {t(`${i18nPrefix}.description`, { ns: 'explore' })}
    - {t(`${i18nPrefix}.learnMore`, { ns: 'explore' })} + {t(`${i18nPrefix}.learnMore`, { ns: 'explore' })}
    ) } diff --git a/web/app/components/explore/try-app/app-info/index.spec.tsx b/web/app/components/explore/try-app/app-info/index.spec.tsx new file mode 100644 index 000000000..cfae862a7 --- /dev/null +++ b/web/app/components/explore/try-app/app-info/index.spec.tsx @@ -0,0 +1,395 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import AppInfo from './index' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'types.advanced': 'Advanced', + 'types.chatbot': 'Chatbot', + 'types.agent': 'Agent', + 'types.workflow': 'Workflow', + 'types.completion': 'Completion', + 'tryApp.createFromSampleApp': 'Create from Sample', + 'tryApp.category': 'Category', + 'tryApp.requirements': 'Requirements', + } + return translations[key] || key + }, + }), +})) + +const mockUseGetRequirements = vi.fn() + +vi.mock('./use-get-requirements', () => ({ + default: (...args: unknown[]) => mockUseGetRequirements(...args), +})) + +const createMockAppDetail = (mode: string, overrides: Partial = {}): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App Name', + description: 'Test App Description', + mode, + site: { + title: 'Test Site Title', + icon: '🚀', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, + ...overrides, +} as unknown as TryAppInfo) + +describe('AppInfo', () => { + beforeEach(() => { + mockUseGetRequirements.mockReturnValue({ + requirements: [], + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('app name and icon', () => { + it('renders app name', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Test App Name')).toBeInTheDocument() + }) + + it('renders app name with title attribute', () => { + const appDetail = createMockAppDetail('chat', { + name: 'Very Long App Name That Should Be Truncated', + } as Partial) + const mockOnCreate = vi.fn() + + render( + , + ) + + const nameElement = screen.getByText('Very Long App Name That Should Be Truncated') + expect(nameElement).toHaveAttribute('title', 'Very Long App Name That Should Be Truncated') + }) + }) + + describe('app type', () => { + it('displays ADVANCED for advanced-chat mode', () => { + const appDetail = createMockAppDetail('advanced-chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('ADVANCED')).toBeInTheDocument() + }) + + it('displays CHATBOT for chat mode', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('CHATBOT')).toBeInTheDocument() + }) + + it('displays AGENT for agent-chat mode', () => { + const appDetail = createMockAppDetail('agent-chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('AGENT')).toBeInTheDocument() + }) + + it('displays WORKFLOW for workflow mode', () => { + const appDetail = createMockAppDetail('workflow') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('WORKFLOW')).toBeInTheDocument() + }) + + it('displays COMPLETION for completion mode', () => { + const appDetail = createMockAppDetail('completion') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('COMPLETION')).toBeInTheDocument() + }) + }) + + describe('description', () => { + it('renders description when provided', () => { + const appDetail = createMockAppDetail('chat', { + description: 'This is a test description', + } as Partial) + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('This is a test description')).toBeInTheDocument() + }) + + it('does not render description when empty', () => { + const appDetail = createMockAppDetail('chat', { + description: '', + } as Partial) + const mockOnCreate = vi.fn() + + const { container } = render( + , + ) + + // Check that there's no element with the description class that has empty content + const descriptionElements = container.querySelectorAll('.system-sm-regular.mt-\\[14px\\]') + expect(descriptionElements.length).toBe(0) + }) + }) + + describe('create button', () => { + it('renders create button with correct text', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Create from Sample')).toBeInTheDocument() + }) + + it('calls onCreate when button is clicked', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByText('Create from Sample')) + expect(mockOnCreate).toHaveBeenCalledTimes(1) + }) + }) + + describe('category', () => { + it('renders category when provided', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Category')).toBeInTheDocument() + expect(screen.getByText('AI Assistant')).toBeInTheDocument() + }) + + it('does not render category section when not provided', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.queryByText('Category')).not.toBeInTheDocument() + }) + }) + + describe('requirements', () => { + it('renders requirements when available', () => { + mockUseGetRequirements.mockReturnValue({ + requirements: [ + { name: 'OpenAI GPT-4', iconUrl: 'https://example.com/icon1.png' }, + { name: 'Google Search', iconUrl: 'https://example.com/icon2.png' }, + ], + }) + + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Requirements')).toBeInTheDocument() + expect(screen.getByText('OpenAI GPT-4')).toBeInTheDocument() + expect(screen.getByText('Google Search')).toBeInTheDocument() + }) + + it('does not render requirements section when empty', () => { + mockUseGetRequirements.mockReturnValue({ + requirements: [], + }) + + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.queryByText('Requirements')).not.toBeInTheDocument() + }) + + it('renders requirement icons with correct background image', () => { + mockUseGetRequirements.mockReturnValue({ + requirements: [ + { name: 'Test Tool', iconUrl: 'https://example.com/test-icon.png' }, + ], + }) + + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + const { container } = render( + , + ) + + const iconElement = container.querySelector('[style*="background-image"]') + expect(iconElement).toBeInTheDocument() + expect(iconElement).toHaveStyle({ backgroundImage: 'url(https://example.com/test-icon.png)' }) + }) + }) + + describe('className prop', () => { + it('applies custom className', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + const { container } = render( + , + ) + + expect(container.firstChild).toHaveClass('custom-class') + }) + }) + + describe('hook calls', () => { + it('calls useGetRequirements with correct parameters', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(mockUseGetRequirements).toHaveBeenCalledWith({ + appDetail, + appId: 'my-app-id', + }) + }) + }) +}) diff --git a/web/app/components/explore/try-app/app-info/use-get-requirements.spec.ts b/web/app/components/explore/try-app/app-info/use-get-requirements.spec.ts new file mode 100644 index 000000000..c8af6121d --- /dev/null +++ b/web/app/components/explore/try-app/app-info/use-get-requirements.spec.ts @@ -0,0 +1,425 @@ +import type { TryAppInfo } from '@/service/try-app' +import { renderHook } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import useGetRequirements from './use-get-requirements' + +const mockUseGetTryAppFlowPreview = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppFlowPreview: (...args: unknown[]) => mockUseGetTryAppFlowPreview(...args), +})) + +vi.mock('@/config', () => ({ + MARKETPLACE_API_PREFIX: 'https://marketplace.api', +})) + +const createMockAppDetail = (mode: string, overrides: Partial = {}): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: 'icon', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, + ...overrides, +} as unknown as TryAppInfo) + +describe('useGetRequirements', () => { + afterEach(() => { + vi.clearAllMocks() + }) + + describe('basic app modes (chat, completion, agent-chat)', () => { + it('returns model provider for chat mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('chat') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('openai') + expect(result.current.requirements[0].iconUrl).toBe('https://marketplace.api/plugins/langgenius/openai/icon') + }) + + it('returns model provider for completion mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('completion', { + model_config: { + model: { + provider: 'anthropic/claude/claude', + name: 'claude-3', + mode: 'completion', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { tools: [] }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('claude') + }) + + it('returns model provider and tools for agent-chat mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('agent-chat', { + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { + tools: [ + { + enabled: true, + provider_id: 'langgenius/google_search/google_search', + tool_label: 'Google Search', + }, + { + enabled: true, + provider_id: 'langgenius/web_scraper/web_scraper', + tool_label: 'Web Scraper', + }, + { + enabled: false, + provider_id: 'langgenius/disabled_tool/disabled_tool', + tool_label: 'Disabled Tool', + }, + ], + }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(3) + expect(result.current.requirements.map(r => r.name)).toContain('openai') + expect(result.current.requirements.map(r => r.name)).toContain('Google Search') + expect(result.current.requirements.map(r => r.name)).toContain('Web Scraper') + expect(result.current.requirements.map(r => r.name)).not.toContain('Disabled Tool') + }) + + it('filters out disabled tools in agent mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('agent-chat', { + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { + tools: [ + { + enabled: false, + provider_id: 'langgenius/tool1/tool1', + tool_label: 'Tool 1', + }, + { + enabled: false, + provider_id: 'langgenius/tool2/tool2', + tool_label: 'Tool 2', + }, + ], + }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + // Only model provider should be included, no disabled tools + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('openai') + }) + }) + + describe('advanced app modes (workflow, advanced-chat)', () => { + it('returns requirements from flow data for workflow mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + { + data: { + type: 'tool', + provider_id: 'langgenius/google/google', + tool_label: 'Google Tool', + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(2) + expect(result.current.requirements.map(r => r.name)).toContain('gpt-4') + expect(result.current.requirements.map(r => r.name)).toContain('Google Tool') + }) + + it('returns requirements from flow data for advanced-chat mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'anthropic/claude/claude', + name: 'claude-3-opus', + }, + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('advanced-chat') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('claude-3-opus') + }) + + it('returns empty requirements when flow data has no nodes', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(0) + }) + + it('returns empty requirements when flow data is null', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: null, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(0) + }) + + it('extracts multiple LLM nodes from flow data', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + { + data: { + type: 'llm', + model: { + provider: 'anthropic/claude/claude', + name: 'claude-3', + }, + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(2) + expect(result.current.requirements.map(r => r.name)).toContain('gpt-4') + expect(result.current.requirements.map(r => r.name)).toContain('claude-3') + }) + + it('extracts multiple tool nodes from flow data', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'tool', + provider_id: 'langgenius/tool1/tool1', + tool_label: 'Tool 1', + }, + }, + { + data: { + type: 'tool', + provider_id: 'langgenius/tool2/tool2', + tool_label: 'Tool 2', + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(2) + expect(result.current.requirements.map(r => r.name)).toContain('Tool 1') + expect(result.current.requirements.map(r => r.name)).toContain('Tool 2') + }) + }) + + describe('deduplication', () => { + it('removes duplicate requirements by name', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('gpt-4') + }) + }) + + describe('icon URL generation', () => { + it('generates correct icon URL for model providers', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('chat', { + model_config: { + model: { + provider: 'org/plugin/model', + name: 'model-name', + mode: 'chat', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { tools: [] }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements[0].iconUrl).toBe('https://marketplace.api/plugins/org/plugin/icon') + }) + }) + + describe('hook calls', () => { + it('calls useGetTryAppFlowPreview with correct parameters for basic apps', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('chat') + renderHook(() => useGetRequirements({ appDetail, appId: 'test-app-id' })) + + expect(mockUseGetTryAppFlowPreview).toHaveBeenCalledWith('test-app-id', true) + }) + + it('calls useGetTryAppFlowPreview with correct parameters for advanced apps', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('workflow') + renderHook(() => useGetRequirements({ appDetail, appId: 'test-app-id' })) + + expect(mockUseGetTryAppFlowPreview).toHaveBeenCalledWith('test-app-id', false) + }) + }) +}) diff --git a/web/app/components/explore/try-app/app/chat.spec.tsx b/web/app/components/explore/try-app/app/chat.spec.tsx new file mode 100644 index 000000000..ebd430c4e --- /dev/null +++ b/web/app/components/explore/try-app/app/chat.spec.tsx @@ -0,0 +1,357 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import TryApp from './chat' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'chat.resetChat': 'Reset Chat', + 'tryApp.tryInfo': 'This is try mode info', + } + return translations[key] || key + }, + }), +})) + +const mockRemoveConversationIdInfo = vi.fn() +const mockHandleNewConversation = vi.fn() +const mockUseEmbeddedChatbot = vi.fn() + +vi.mock('@/app/components/base/chat/embedded-chatbot/hooks', () => ({ + useEmbeddedChatbot: (...args: unknown[]) => mockUseEmbeddedChatbot(...args), +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => 'pc', + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +vi.mock('../../../base/chat/embedded-chatbot/theme/theme-context', () => ({ + useThemeContext: () => ({ + primaryColor: '#1890ff', + }), +})) + +vi.mock('@/app/components/base/chat/embedded-chatbot/chat-wrapper', () => ({ + default: () =>
    ChatWrapper
    , +})) + +vi.mock('@/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown', () => ({ + default: () =>
    ViewFormDropdown
    , +})) + +const createMockAppDetail = (overrides: Partial = {}): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test Chat App', + description: 'Test Description', + mode: 'chat', + site: { + title: 'Test Site Title', + icon: '💬', + icon_type: 'emoji', + icon_background: '#4F46E5', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, + ...overrides, +} as unknown as TryAppInfo) + +describe('TryApp (chat.tsx)', () => { + beforeEach(() => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: null, + inputsForms: [], + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('basic rendering', () => { + it('renders app name', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByText('Test Chat App')).toBeInTheDocument() + }) + + it('renders app name with title attribute', () => { + const appDetail = createMockAppDetail({ name: 'Long App Name' } as Partial) + + render( + , + ) + + const nameElement = screen.getByText('Long App Name') + expect(nameElement).toHaveAttribute('title', 'Long App Name') + }) + + it('renders ChatWrapper', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByTestId('chat-wrapper')).toBeInTheDocument() + }) + + it('renders alert with try info', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByText('This is try mode info')).toBeInTheDocument() + }) + + it('applies className prop', () => { + const appDetail = createMockAppDetail() + + const { container } = render( + , + ) + + // The component wraps with EmbeddedChatbotContext.Provider, first child is the div with className + const innerDiv = container.querySelector('.custom-class') + expect(innerDiv).toBeInTheDocument() + }) + }) + + describe('reset button', () => { + it('does not render reset button when no conversation', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: null, + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + // Reset button should not be present + expect(screen.queryByRole('button')).not.toBeInTheDocument() + }) + + it('renders reset button when conversation exists', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + // Should have a button (the reset button) + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('calls handleNewConversation when reset button is clicked', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + fireEvent.click(screen.getByRole('button')) + + expect(mockRemoveConversationIdInfo).toHaveBeenCalledWith('test-app-id') + expect(mockHandleNewConversation).toHaveBeenCalled() + }) + }) + + describe('view form dropdown', () => { + it('does not render view form dropdown when no conversation', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: null, + inputsForms: [{ id: 'form1' }], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.queryByTestId('view-form-dropdown')).not.toBeInTheDocument() + }) + + it('does not render view form dropdown when no input forms', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.queryByTestId('view-form-dropdown')).not.toBeInTheDocument() + }) + + it('renders view form dropdown when conversation and input forms exist', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [{ id: 'form1' }], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByTestId('view-form-dropdown')).toBeInTheDocument() + }) + }) + + describe('alert hiding', () => { + it('hides alert when onHide is called', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + // Find and click the hide button on the alert + const alertElement = screen.getByText('This is try mode info').closest('[class*="alert"]')?.parentElement + const hideButton = alertElement?.querySelector('button, [role="button"], svg') + + if (hideButton) { + fireEvent.click(hideButton) + // After hiding, the alert should not be visible + expect(screen.queryByText('This is try mode info')).not.toBeInTheDocument() + } + }) + }) + + describe('hook calls', () => { + it('calls useEmbeddedChatbot with correct parameters', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(mockUseEmbeddedChatbot).toHaveBeenCalledWith('tryApp', 'my-app-id') + }) + + it('calls removeConversationIdInfo on mount', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(mockRemoveConversationIdInfo).toHaveBeenCalledWith('my-app-id') + }) + }) +}) diff --git a/web/app/components/explore/try-app/app/index.spec.tsx b/web/app/components/explore/try-app/app/index.spec.tsx new file mode 100644 index 000000000..927365a64 --- /dev/null +++ b/web/app/components/explore/try-app/app/index.spec.tsx @@ -0,0 +1,188 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import TryApp from './index' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/hooks/use-document-title', () => ({ + default: vi.fn(), +})) + +vi.mock('./chat', () => ({ + default: ({ appId, appDetail, className }: { appId: string, appDetail: TryAppInfo, className: string }) => ( +
    + Chat Component +
    + ), +})) + +vi.mock('./text-generation', () => ({ + default: ({ + appId, + className, + isWorkflow, + appData, + }: { appId: string, className: string, isWorkflow: boolean, appData: { mode: string } }) => ( +
    + TextGeneration Component +
    + ), +})) + +const createMockAppDetail = (mode: string): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: 'icon', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'test/provider', + name: 'test-model', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, +} as unknown as TryAppInfo) + +describe('TryApp (app/index.tsx)', () => { + afterEach(() => { + cleanup() + }) + + describe('chat mode rendering', () => { + it('renders Chat component for chat mode', () => { + const appDetail = createMockAppDetail('chat') + render() + + expect(screen.getByTestId('chat-component')).toBeInTheDocument() + expect(screen.queryByTestId('text-generation-component')).not.toBeInTheDocument() + }) + + it('renders Chat component for advanced-chat mode', () => { + const appDetail = createMockAppDetail('advanced-chat') + render() + + expect(screen.getByTestId('chat-component')).toBeInTheDocument() + expect(screen.queryByTestId('text-generation-component')).not.toBeInTheDocument() + }) + + it('renders Chat component for agent-chat mode', () => { + const appDetail = createMockAppDetail('agent-chat') + render() + + expect(screen.getByTestId('chat-component')).toBeInTheDocument() + expect(screen.queryByTestId('text-generation-component')).not.toBeInTheDocument() + }) + + it('passes correct props to Chat component', () => { + const appDetail = createMockAppDetail('chat') + render() + + const chatComponent = screen.getByTestId('chat-component') + expect(chatComponent).toHaveAttribute('data-app-id', 'test-app-id') + expect(chatComponent).toHaveAttribute('data-mode', 'chat') + expect(chatComponent).toHaveClass('h-full', 'grow') + }) + }) + + describe('completion mode rendering', () => { + it('renders TextGeneration component for completion mode', () => { + const appDetail = createMockAppDetail('completion') + render() + + expect(screen.getByTestId('text-generation-component')).toBeInTheDocument() + expect(screen.queryByTestId('chat-component')).not.toBeInTheDocument() + }) + + it('renders TextGeneration component for workflow mode', () => { + const appDetail = createMockAppDetail('workflow') + render() + + expect(screen.getByTestId('text-generation-component')).toBeInTheDocument() + expect(screen.queryByTestId('chat-component')).not.toBeInTheDocument() + }) + + it('passes isWorkflow=true for workflow mode', () => { + const appDetail = createMockAppDetail('workflow') + render() + + const textGenComponent = screen.getByTestId('text-generation-component') + expect(textGenComponent).toHaveAttribute('data-is-workflow', 'true') + }) + + it('passes isWorkflow=false for completion mode', () => { + const appDetail = createMockAppDetail('completion') + render() + + const textGenComponent = screen.getByTestId('text-generation-component') + expect(textGenComponent).toHaveAttribute('data-is-workflow', 'false') + }) + + it('passes correct props to TextGeneration component', () => { + const appDetail = createMockAppDetail('completion') + render() + + const textGenComponent = screen.getByTestId('text-generation-component') + expect(textGenComponent).toHaveAttribute('data-app-id', 'test-app-id') + expect(textGenComponent).toHaveClass('h-full', 'grow') + }) + }) + + describe('document title', () => { + it('calls useDocumentTitle with site title', async () => { + const useDocumentTitle = (await import('@/hooks/use-document-title')).default + const appDetail = createMockAppDetail('chat') + appDetail.site.title = 'My App Title' + + render() + + expect(useDocumentTitle).toHaveBeenCalledWith('My App Title') + }) + + it('calls useDocumentTitle with empty string when site.title is undefined', async () => { + const useDocumentTitle = (await import('@/hooks/use-document-title')).default + const appDetail = createMockAppDetail('chat') + appDetail.site = undefined as unknown as TryAppInfo['site'] + + render() + + expect(useDocumentTitle).toHaveBeenCalledWith('') + }) + }) + + describe('wrapper styling', () => { + it('renders with correct wrapper classes', () => { + const appDetail = createMockAppDetail('chat') + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'h-full', 'w-full') + }) + }) +}) diff --git a/web/app/components/explore/try-app/app/text-generation.spec.tsx b/web/app/components/explore/try-app/app/text-generation.spec.tsx new file mode 100644 index 000000000..cbeafc513 --- /dev/null +++ b/web/app/components/explore/try-app/app/text-generation.spec.tsx @@ -0,0 +1,468 @@ +import type { AppData } from '@/models/share' +import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import TextGeneration from './text-generation' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'tryApp.tryInfo': 'This is a try app notice', + } + return translations[key] || key + }, + }), +})) + +const mockUpdateAppInfo = vi.fn() +const mockUpdateAppParams = vi.fn() +const mockAppParams = { + user_input_form: [], + more_like_this: { enabled: false }, + file_upload: null, + text_to_speech: { enabled: false }, + system_parameters: {}, +} +let mockStoreAppParams: typeof mockAppParams | null = mockAppParams + +vi.mock('@/context/web-app-context', () => ({ + useWebAppStore: (selector: (state: unknown) => unknown) => { + const state = { + updateAppInfo: mockUpdateAppInfo, + updateAppParams: mockUpdateAppParams, + appParams: mockStoreAppParams, + } + return selector(state) + }, +})) + +const mockUseGetTryAppParams = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppParams: (...args: unknown[]) => mockUseGetTryAppParams(...args), +})) + +let mockMediaType = 'pc' + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => mockMediaType, + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +vi.mock('@/app/components/share/text-generation/run-once', () => ({ + default: ({ + siteInfo, + onSend, + onInputsChange, + }: { siteInfo: { title: string }, onSend: () => void, onInputsChange: (inputs: Record) => void }) => ( +
    + {siteInfo?.title} + + +
    + ), +})) + +vi.mock('@/app/components/share/text-generation/result', () => ({ + default: ({ + isWorkflow, + appId, + onCompleted, + onRunStart, + }: { isWorkflow: boolean, appId: string, onCompleted: () => void, onRunStart: () => void }) => ( +
    + + +
    + ), +})) + +const createMockAppData = (overrides: Partial = {}): AppData => ({ + app_id: 'test-app-id', + site: { + title: 'Test App Title', + description: 'Test App Description', + icon: '🚀', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + default_language: 'en', + prompt_public: true, + copyright: '', + privacy_policy: '', + custom_disclaimer: '', + }, + custom_config: { + remove_webapp_brand: false, + }, + ...overrides, +} as AppData) + +describe('TextGeneration', () => { + beforeEach(() => { + mockStoreAppParams = mockAppParams + mockMediaType = 'pc' + mockUseGetTryAppParams.mockReturnValue({ + data: mockAppParams, + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('loading state', () => { + it('renders loading when appData is null', () => { + render( + , + ) + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('renders loading when appParams is not available', () => { + mockStoreAppParams = null + mockUseGetTryAppParams.mockReturnValue({ + data: null, + }) + + render( + , + ) + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + }) + + describe('content rendering', () => { + it('renders app title', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + // Multiple elements may have the title (header and RunOnce mock) + const titles = screen.getAllByText('Test App Title') + expect(titles.length).toBeGreaterThan(0) + }) + }) + + it('renders app description when available', async () => { + const appData = createMockAppData({ + site: { + title: 'Test App', + description: 'This is a description', + icon: '🚀', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + default_language: 'en', + prompt_public: true, + copyright: '', + privacy_policy: '', + custom_disclaimer: '', + }, + } as unknown as Partial) + + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('This is a description')).toBeInTheDocument() + }) + }) + + it('renders RunOnce component', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('run-once')).toBeInTheDocument() + }) + }) + + it('renders Result component', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) + }) + + describe('workflow mode', () => { + it('passes isWorkflow=true to Result when isWorkflow prop is true', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + const resultComponent = screen.getByTestId('result-component') + expect(resultComponent).toHaveAttribute('data-is-workflow', 'true') + }) + }) + + it('passes isWorkflow=false to Result when isWorkflow prop is false', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + const resultComponent = screen.getByTestId('result-component') + expect(resultComponent).toHaveAttribute('data-is-workflow', 'false') + }) + }) + }) + + describe('send functionality', () => { + it('triggers send when RunOnce sends', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('send-button')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('send-button')) + + // The send should work without errors + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) + + describe('completion handling', () => { + it('shows alert after completion', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('complete-button')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('complete-button')) + + await waitFor(() => { + expect(screen.getByText('This is a try app notice')).toBeInTheDocument() + }) + }) + }) + + describe('className prop', () => { + it('applies custom className', async () => { + const appData = createMockAppData() + + const { container } = render( + , + ) + + await waitFor(() => { + const element = container.querySelector('.custom-class') + expect(element).toBeInTheDocument() + }) + }) + }) + + describe('hook effects', () => { + it('calls updateAppInfo when appData changes', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(mockUpdateAppInfo).toHaveBeenCalledWith(appData) + }) + }) + + it('calls updateAppParams when tryAppParams changes', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(mockUpdateAppParams).toHaveBeenCalledWith(mockAppParams) + }) + }) + + it('calls useGetTryAppParams with correct appId', () => { + const appData = createMockAppData() + + render( + , + ) + + expect(mockUseGetTryAppParams).toHaveBeenCalledWith('my-app-id') + }) + }) + + describe('result panel visibility', () => { + it('shows result panel after run starts', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('run-start-button')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('run-start-button')) + + // Result panel should remain visible + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) + + describe('input handling', () => { + it('handles input changes from RunOnce', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('inputs-change-button')).toBeInTheDocument() + }) + + // Trigger input change which should call setInputs callback + fireEvent.click(screen.getByTestId('inputs-change-button')) + + // The component should handle the input change without errors + expect(screen.getByTestId('run-once')).toBeInTheDocument() + }) + }) + + describe('mobile behavior', () => { + it('renders mobile toggle panel on mobile', async () => { + mockMediaType = 'mobile' + const appData = createMockAppData() + + const { container } = render( + , + ) + + await waitFor(() => { + // Mobile toggle panel should be rendered + const togglePanel = container.querySelector('.cursor-grab') + expect(togglePanel).toBeInTheDocument() + }) + }) + + it('toggles result panel visibility on mobile', async () => { + mockMediaType = 'mobile' + const appData = createMockAppData() + + const { container } = render( + , + ) + + await waitFor(() => { + const togglePanel = container.querySelector('.cursor-grab') + expect(togglePanel).toBeInTheDocument() + }) + + // Click to show result panel + const toggleParent = container.querySelector('.cursor-grab')?.parentElement + if (toggleParent) { + fireEvent.click(toggleParent) + } + + // Click again to hide result panel + await waitFor(() => { + const newToggleParent = container.querySelector('.cursor-grab')?.parentElement + if (newToggleParent) { + fireEvent.click(newToggleParent) + } + }) + + // Component should handle both show and hide without errors + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/explore/try-app/index.spec.tsx b/web/app/components/explore/try-app/index.spec.tsx new file mode 100644 index 000000000..dc057b4d9 --- /dev/null +++ b/web/app/components/explore/try-app/index.spec.tsx @@ -0,0 +1,419 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import TryApp from './index' +import { TypeEnum } from './tab' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'tryApp.tabHeader.try': 'Try', + 'tryApp.tabHeader.detail': 'Detail', + } + return translations[key] || key + }, + }), +})) + +vi.mock('@/config', async (importOriginal) => { + const actual = await importOriginal() as object + return { + ...actual, + IS_CLOUD_EDITION: true, + } +}) + +const mockUseGetTryAppInfo = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppInfo: (...args: unknown[]) => mockUseGetTryAppInfo(...args), +})) + +vi.mock('./app', () => ({ + default: ({ appId, appDetail }: { appId: string, appDetail: TryAppInfo }) => ( +
    + App Component +
    + ), +})) + +vi.mock('./preview', () => ({ + default: ({ appId, appDetail }: { appId: string, appDetail: TryAppInfo }) => ( +
    + Preview Component +
    + ), +})) + +vi.mock('./app-info', () => ({ + default: ({ + appId, + appDetail, + category, + className, + onCreate, + }: { appId: string, appDetail: TryAppInfo, category?: string, className?: string, onCreate: () => void }) => ( +
    + + App Info: + {' '} + {appDetail?.name} +
    + ), +})) + +const createMockAppDetail = (mode: string = 'chat'): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App Name', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: '🚀', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, +} as unknown as TryAppInfo) + +describe('TryApp (main index.tsx)', () => { + beforeEach(() => { + mockUseGetTryAppInfo.mockReturnValue({ + data: createMockAppDetail(), + isLoading: false, + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('loading state', () => { + it('renders loading when isLoading is true', () => { + mockUseGetTryAppInfo.mockReturnValue({ + data: null, + isLoading: true, + }) + + render( + , + ) + + expect(document.body.querySelector('[role="status"]')).toBeInTheDocument() + }) + }) + + describe('content rendering', () => { + it('renders Tab component', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Try')).toBeInTheDocument() + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + }) + + it('renders App component by default (TRY mode)', async () => { + render( + , + ) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="app-component"]')).toBeInTheDocument() + expect(document.body.querySelector('[data-testid="preview-component"]')).not.toBeInTheDocument() + }) + }) + + it('renders AppInfo component', async () => { + render( + , + ) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="app-info-component"]')).toBeInTheDocument() + }) + }) + + it('renders close button', async () => { + render( + , + ) + + await waitFor(() => { + // Find the close button (the one with RiCloseLine icon) + const buttons = document.body.querySelectorAll('button') + expect(buttons.length).toBeGreaterThan(0) + }) + }) + }) + + describe('tab switching', () => { + it('switches to Preview when Detail tab is clicked', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByText('Detail')) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="preview-component"]')).toBeInTheDocument() + expect(document.body.querySelector('[data-testid="app-component"]')).not.toBeInTheDocument() + }) + }) + + it('switches back to App when Try tab is clicked', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + // First switch to Detail + fireEvent.click(screen.getByText('Detail')) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="preview-component"]')).toBeInTheDocument() + }) + + // Then switch back to Try + fireEvent.click(screen.getByText('Try')) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="app-component"]')).toBeInTheDocument() + }) + }) + }) + + describe('close functionality', () => { + it('calls onClose when close button is clicked', async () => { + const mockOnClose = vi.fn() + + render( + , + ) + + await waitFor(() => { + // Find the button with close icon + const buttons = document.body.querySelectorAll('button') + const closeButton = Array.from(buttons).find(btn => + btn.querySelector('svg') || btn.className.includes('rounded-[10px]'), + ) + expect(closeButton).toBeInTheDocument() + + if (closeButton) + fireEvent.click(closeButton) + }) + + expect(mockOnClose).toHaveBeenCalled() + }) + }) + + describe('create functionality', () => { + it('calls onCreate when create button in AppInfo is clicked', async () => { + const mockOnCreate = vi.fn() + + render( + , + ) + + await waitFor(() => { + const createButton = document.body.querySelector('[data-testid="create-button"]') + expect(createButton).toBeInTheDocument() + + if (createButton) + fireEvent.click(createButton) + }) + + expect(mockOnCreate).toHaveBeenCalledTimes(1) + }) + }) + + describe('category prop', () => { + it('passes category to AppInfo when provided', async () => { + render( + , + ) + + await waitFor(() => { + const appInfo = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfo).toHaveAttribute('data-category', 'AI Assistant') + }) + }) + + it('does not pass category to AppInfo when not provided', async () => { + render( + , + ) + + await waitFor(() => { + const appInfo = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfo).not.toHaveAttribute('data-category', expect.any(String)) + }) + }) + }) + + describe('hook calls', () => { + it('calls useGetTryAppInfo with correct appId', () => { + render( + , + ) + + expect(mockUseGetTryAppInfo).toHaveBeenCalledWith('my-specific-app-id') + }) + }) + + describe('props passing', () => { + it('passes appId to App component', async () => { + render( + , + ) + + await waitFor(() => { + const appComponent = document.body.querySelector('[data-testid="app-component"]') + expect(appComponent).toHaveAttribute('data-app-id', 'my-app-id') + }) + }) + + it('passes appId to Preview component when in Detail mode', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByText('Detail')) + + await waitFor(() => { + const previewComponent = document.body.querySelector('[data-testid="preview-component"]') + expect(previewComponent).toHaveAttribute('data-app-id', 'my-app-id') + }) + }) + + it('passes appId to AppInfo component', async () => { + render( + , + ) + + await waitFor(() => { + const appInfoComponent = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfoComponent).toHaveAttribute('data-app-id', 'my-app-id') + }) + }) + + it('passes appDetail to AppInfo component', async () => { + render( + , + ) + + await waitFor(() => { + const appInfoComponent = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfoComponent?.textContent).toContain('Test App Name') + }) + }) + }) + + describe('TypeEnum export', () => { + it('exports TypeEnum correctly', () => { + expect(TypeEnum.TRY).toBe('try') + expect(TypeEnum.DETAIL).toBe('detail') + }) + }) +}) diff --git a/web/app/components/explore/try-app/index.tsx b/web/app/components/explore/try-app/index.tsx index b2e2b7214..c6f00ed08 100644 --- a/web/app/components/explore/try-app/index.tsx +++ b/web/app/components/explore/try-app/index.tsx @@ -1,11 +1,13 @@ /* eslint-disable style/multiline-ternary */ 'use client' import type { FC } from 'react' +import type { App as AppType } from '@/models/explore' import { RiCloseLine } from '@remixicon/react' import * as React from 'react' import { useState } from 'react' import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal/index' +import { useGlobalPublicStore } from '@/context/global-public-context' import { useGetTryAppInfo } from '@/service/use-try-app' import Button from '../../base/button' import App from './app' @@ -15,6 +17,7 @@ import Tab, { TypeEnum } from './tab' type Props = { appId: string + app?: AppType category?: string onClose: () => void onCreate: () => void @@ -22,13 +25,23 @@ type Props = { const TryApp: FC = ({ appId, + app, category, onClose, onCreate, }) => { - const [type, setType] = useState(TypeEnum.TRY) + const { systemFeatures } = useGlobalPublicStore() + const isTrialApp = !!(app && app.can_trial && systemFeatures.enable_trial_app) + const [type, setType] = useState(() => (app && !isTrialApp ? TypeEnum.DETAIL : TypeEnum.TRY)) const { data: appDetail, isLoading } = useGetTryAppInfo(appId) + React.useEffect(() => { + if (app && !isTrialApp && type !== TypeEnum.DETAIL) + // eslint-disable-next-line react-hooks-extra/no-direct-set-state-in-use-effect + setType(TypeEnum.DETAIL) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [app, isTrialApp]) + return ( = ({ + + ), })) @@ -110,65 +142,504 @@ describe('GotoAnything', () => { mockQueryResult = { data: [], isLoading: false, isError: false, error: null } matchActionMock.mockReset() searchAnythingMock.mockClear() + mockFindCommand = null }) - it('should open modal via shortcut and navigate to selected result', async () => { - mockQueryResult = { - data: [{ - id: 'app-1', - type: 'app', - title: 'Sample App', - description: 'desc', - path: '/apps/1', - icon:
    🧩
    , - data: {}, - } as any], - isLoading: false, - isError: false, - error: null, - } + describe('modal behavior', () => { + it('should open modal via Ctrl+K shortcut', async () => { + render() - render() + triggerKeyPress('ctrl.k') - triggerKeyPress('ctrl.k') + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + }) - const input = await screen.findByPlaceholderText('app.gotoAnything.searchPlaceholder') - await userEvent.type(input, 'app') + it('should close modal via ESC key', async () => { + render() - const result = await screen.findByText('Sample App') - await userEvent.click(result) + triggerKeyPress('ctrl.k') + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) - expect(routerPush).toHaveBeenCalledWith('/apps/1') + triggerKeyPress('esc') + await waitFor(() => { + expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument() + }) + }) + + it('should toggle modal when pressing Ctrl+K twice', async () => { + render() + + triggerKeyPress('ctrl.k') + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + triggerKeyPress('ctrl.k') + await waitFor(() => { + expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument() + }) + }) + + it('should call onHide when modal closes', async () => { + const onHide = vi.fn() + render() + + triggerKeyPress('ctrl.k') + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + triggerKeyPress('esc') + await waitFor(() => { + expect(onHide).toHaveBeenCalled() + }) + }) + + it('should reset search query when modal opens', async () => { + const user = userEvent.setup() + render() + + // Open modal first time + triggerKeyPress('ctrl.k') + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + // Type something + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'test') + + // Close modal + triggerKeyPress('esc') + await waitFor(() => { + expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument() + }) + + // Open modal again - should be empty + triggerKeyPress('ctrl.k') + await waitFor(() => { + const newInput = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + expect(newInput).toHaveValue('') + }) + }) }) - it('should open plugin installer when selecting plugin result', async () => { - mockQueryResult = { - data: [{ - id: 'plugin-1', - type: 'plugin', - title: 'Plugin Item', - description: 'desc', - path: '', - icon:
    , - data: { - name: 'Plugin Item', - latest_package_identifier: 'pkg', - }, - } as any], - isLoading: false, - isError: false, - error: null, - } + describe('search functionality', () => { + it('should navigate to selected result', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [{ + id: 'app-1', + type: 'app', + title: 'Sample App', + description: 'desc', + path: '/apps/1', + icon:
    🧩
    , + data: {}, + }], + isLoading: false, + isError: false, + error: null, + } - render() + render() + triggerKeyPress('ctrl.k') - triggerKeyPress('ctrl.k') - const input = await screen.findByPlaceholderText('app.gotoAnything.searchPlaceholder') - await userEvent.type(input, 'plugin') + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) - const pluginItem = await screen.findByText('Plugin Item') - await userEvent.click(pluginItem) + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'app') - expect(await screen.findByTestId('install-modal')).toHaveTextContent('Plugin Item') + const result = await screen.findByText('Sample App') + await user.click(result) + + expect(routerPush).toHaveBeenCalledWith('/apps/1') + }) + + it('should clear selection when typing without prefix', async () => { + const user = userEvent.setup() + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'test query') + + // Should not throw and input should have value + expect(input).toHaveValue('test query') + }) + }) + + describe('empty states', () => { + it('should show loading state', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [], + isLoading: true, + isError: false, + error: null, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'search') + + // Loading state shows in both EmptyState (spinner) and Footer + const searchingTexts = screen.getAllByText('app.gotoAnything.searching') + expect(searchingTexts.length).toBeGreaterThanOrEqual(1) + }) + + it('should show error state', async () => { + const user = userEvent.setup() + const testError = new Error('Search failed') + mockQueryResult = { + data: [], + isLoading: false, + isError: true, + error: testError, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'search') + + expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument() + }) + + it('should show default state when no query', async () => { + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument() + }) + + it('should show no results state when search returns empty', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [], + isLoading: false, + isError: false, + error: null, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'nonexistent') + + expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument() + }) + }) + + describe('plugin installation', () => { + it('should open plugin installer when selecting plugin result', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [{ + id: 'plugin-1', + type: 'plugin', + title: 'Plugin Item', + description: 'desc', + path: '', + icon:
    , + data: { + name: 'Plugin Item', + latest_package_identifier: 'pkg', + }, + }], + isLoading: false, + isError: false, + error: null, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'plugin') + + const pluginItem = await screen.findByText('Plugin Item') + await user.click(pluginItem) + + expect(await screen.findByTestId('install-modal')).toHaveTextContent('Plugin Item') + }) + + it('should close plugin installer via close button', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [{ + id: 'plugin-1', + type: 'plugin', + title: 'Plugin Item', + description: 'desc', + path: '', + icon:
    , + data: { + name: 'Plugin Item', + latest_package_identifier: 'pkg', + }, + }], + isLoading: false, + isError: false, + error: null, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'plugin') + + const pluginItem = await screen.findByText('Plugin Item') + await user.click(pluginItem) + + const closeBtn = await screen.findByTestId('close-install') + await user.click(closeBtn) + + await waitFor(() => { + expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument() + }) + }) + + it('should close plugin installer on success', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [{ + id: 'plugin-1', + type: 'plugin', + title: 'Plugin Item', + description: 'desc', + path: '', + icon:
    , + data: { + name: 'Plugin Item', + latest_package_identifier: 'pkg', + }, + }], + isLoading: false, + isError: false, + error: null, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'plugin') + + const pluginItem = await screen.findByText('Plugin Item') + await user.click(pluginItem) + + const successBtn = await screen.findByTestId('success-install') + await user.click(successBtn) + + await waitFor(() => { + expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument() + }) + }) + }) + + describe('slash command handling', () => { + it('should execute direct slash command on Enter', async () => { + const user = userEvent.setup() + const executeMock = vi.fn() + mockFindCommand = { + mode: 'direct', + execute: executeMock, + isAvailable: () => true, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, '/theme') + await user.keyboard('{Enter}') + + expect(executeMock).toHaveBeenCalled() + }) + + it('should NOT execute unavailable slash command', async () => { + const user = userEvent.setup() + const executeMock = vi.fn() + mockFindCommand = { + mode: 'direct', + execute: executeMock, + isAvailable: () => false, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, '/theme') + await user.keyboard('{Enter}') + + expect(executeMock).not.toHaveBeenCalled() + }) + + it('should NOT execute non-direct mode slash command on Enter', async () => { + const user = userEvent.setup() + const executeMock = vi.fn() + mockFindCommand = { + mode: 'submenu', + execute: executeMock, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, '/language') + await user.keyboard('{Enter}') + + expect(executeMock).not.toHaveBeenCalled() + }) + + it('should close modal after executing direct slash command', async () => { + const user = userEvent.setup() + mockFindCommand = { + mode: 'direct', + execute: vi.fn(), + isAvailable: () => true, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, '/theme') + await user.keyboard('{Enter}') + + await waitFor(() => { + expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument() + }) + }) + }) + + describe('result navigation', () => { + it('should handle knowledge result navigation', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [{ + id: 'kb-1', + type: 'knowledge', + title: 'Knowledge Base', + description: 'desc', + path: '/datasets/kb-1', + icon:
    , + data: {}, + }], + isLoading: false, + isError: false, + error: null, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'knowledge') + + const result = await screen.findByText('Knowledge Base') + await user.click(result) + + expect(routerPush).toHaveBeenCalledWith('/datasets/kb-1') + }) + + it('should NOT navigate when result has no path', async () => { + const user = userEvent.setup() + mockQueryResult = { + data: [{ + id: 'item-1', + type: 'app', + title: 'No Path Item', + description: 'desc', + path: '', + icon:
    , + data: {}, + }], + isLoading: false, + isError: false, + error: null, + } + + render() + triggerKeyPress('ctrl.k') + + await waitFor(() => { + expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument() + }) + + const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder') + await user.type(input, 'no path') + + const result = await screen.findByText('No Path Item') + await user.click(result) + + expect(routerPush).not.toHaveBeenCalled() + }) }) }) diff --git a/web/app/components/goto-anything/index.tsx b/web/app/components/goto-anything/index.tsx index d34176e4c..8ee2395cc 100644 --- a/web/app/components/goto-anything/index.tsx +++ b/web/app/components/goto-anything/index.tsx @@ -1,299 +1,149 @@ 'use client' -import type { FC } from 'react' -import type { Plugin } from '../plugins/types' -import type { SearchResult } from './actions' -import { RiSearchLine } from '@remixicon/react' -import { useQuery } from '@tanstack/react-query' -import { useDebounce, useKeyPress } from 'ahooks' +import type { FC, KeyboardEvent } from 'react' import { Command } from 'cmdk' -import { useRouter } from 'next/navigation' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useCallback, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next' -import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { getKeyboardKeyCodeBySystem, isEventTargetInputArea, isMac } from '@/app/components/workflow/utils/common' -import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation' -import { useGetLanguage } from '@/context/i18n' import InstallFromMarketplace from '../plugins/install-plugin/install-from-marketplace' -import { createActions, matchAction, searchAnything } from './actions' import { SlashCommandProvider } from './actions/commands' import { slashCommandRegistry } from './actions/commands/registry' import CommandSelector from './command-selector' +import { EmptyState, Footer, ResultList, SearchInput } from './components' import { GotoAnythingProvider, useGotoAnythingContext } from './context' +import { + useGotoAnythingModal, + useGotoAnythingNavigation, + useGotoAnythingResults, + useGotoAnythingSearch, +} from './hooks' type Props = { onHide?: () => void } + const GotoAnything: FC = ({ onHide, }) => { - const router = useRouter() - const defaultLocale = useGetLanguage() - const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext() const { t } = useTranslation() - const [show, setShow] = useState(false) - const [searchQuery, setSearchQuery] = useState('') - const [cmdVal, setCmdVal] = useState('_') - const inputRef = useRef(null) + const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext() + const prevShowRef = useRef(false) - // Filter actions based on context - const Actions = useMemo(() => { - // Create actions based on current page context - return createActions(isWorkflowPage, isRagPipelinePage) - }, [isWorkflowPage, isRagPipelinePage]) + // Search state management (called first so setSearchQuery is available) + const { + searchQuery, + setSearchQuery, + searchQueryDebouncedValue, + searchMode, + isCommandsMode, + cmdVal, + setCmdVal, + clearSelection, + Actions, + } = useGotoAnythingSearch() - const [activePlugin, setActivePlugin] = useState() + // Modal state management + const { + show, + setShow, + inputRef, + handleClose: modalClose, + } = useGotoAnythingModal() - // Handle keyboard shortcuts - const handleToggleModal = useCallback((e: KeyboardEvent) => { - // Allow closing when modal is open, even if focus is in the search input - if (!show && isEventTargetInputArea(e.target as HTMLElement)) - return - e.preventDefault() - setShow((prev) => { - if (!prev) { - // Opening modal - reset search state - setSearchQuery('') - } - return !prev - }) - }, [show]) - - useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.k`, handleToggleModal, { - exactMatch: true, - useCapture: true, - }) - - useKeyPress(['esc'], (e) => { - if (show) { - e.preventDefault() - setShow(false) + // Reset state when modal opens/closes + useEffect(() => { + if (show && !prevShowRef.current) { + // Modal just opened - reset search setSearchQuery('') } + else if (!show && prevShowRef.current) { + // Modal just closed + setSearchQuery('') + clearSelection() + onHide?.() + } + prevShowRef.current = show + }, [show, setSearchQuery, clearSelection, onHide]) + + // Results fetching and processing + const { + dedupedResults, + groupedResults, + isLoading, + isError, + error, + } = useGotoAnythingResults({ + searchQueryDebouncedValue, + searchMode, + isCommandsMode, + Actions, + isWorkflowPage, + isRagPipelinePage, + cmdVal, + setCmdVal, }) - const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), { - wait: 300, + // Navigation handlers + const { + handleCommandSelect, + handleNavigate, + activePlugin, + setActivePlugin, + } = useGotoAnythingNavigation({ + Actions, + setSearchQuery, + clearSelection, + inputRef, + onClose: () => setShow(false), }) - const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/' - || (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), Actions)) - || (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), Actions)) + // Handle search input change + const handleSearchChange = useCallback((value: string) => { + setSearchQuery(value) + if (!value.startsWith('@') && !value.startsWith('/')) + clearSelection() + }, [setSearchQuery, clearSelection]) - const searchMode = useMemo(() => { - if (isCommandsMode) { - // Distinguish between @ (scopes) and / (commands) mode - if (searchQuery.trim().startsWith('@')) - return 'scopes' - else if (searchQuery.trim().startsWith('/')) - return 'commands' - return 'commands' // default fallback - } + // Handle search input keydown for slash commands + const handleSearchKeyDown = useCallback((e: KeyboardEvent) => { + if (e.key === 'Enter') { + const query = searchQuery.trim() + // Check if it's a complete slash command + if (query.startsWith('/')) { + const commandName = query.substring(1).split(' ')[0] + const handler = slashCommandRegistry.findCommand(commandName) - const query = searchQueryDebouncedValue.toLowerCase() - const action = matchAction(query, Actions) - - if (!action) - return 'general' - - return action.key === '/' ? '@command' : action.key - }, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery]) - - const { data: searchResults = [], isLoading, isError, error } = useQuery( - { - queryKey: [ - 'goto-anything', - 'search-result', - searchQueryDebouncedValue, - searchMode, - isWorkflowPage, - isRagPipelinePage, - defaultLocale, - Actions, - ], - queryFn: async () => { - const query = searchQueryDebouncedValue.toLowerCase() - const action = matchAction(query, Actions) - return await searchAnything(defaultLocale, query, action, Actions) - }, - enabled: !!searchQueryDebouncedValue && !isCommandsMode, - staleTime: 30000, - gcTime: 300000, - }, - ) - - // Prevent automatic selection of the first option when cmdVal is not set - const clearSelection = () => { - setCmdVal('_') - } - - const handleCommandSelect = useCallback((commandKey: string) => { - // Check if it's a slash command - if (commandKey.startsWith('/')) { - const commandName = commandKey.substring(1) - const handler = slashCommandRegistry.findCommand(commandName) - - // If it's a direct mode command, execute immediately - if (handler?.mode === 'direct' && handler.execute) { - handler.execute() - setShow(false) - setSearchQuery('') - return + // If it's a direct mode command, execute immediately + const isAvailable = handler?.isAvailable?.() ?? true + if (handler?.mode === 'direct' && handler.execute && isAvailable) { + e.preventDefault() + handler.execute() + setShow(false) + setSearchQuery('') + } } } + }, [searchQuery, setShow, setSearchQuery]) - // Otherwise, proceed with the normal flow (submenu mode) - setSearchQuery(`${commandKey} `) - clearSelection() - setTimeout(() => { - inputRef.current?.focus() - }, 0) - }, []) - - // Handle navigation to selected result - const handleNavigate = useCallback((result: SearchResult) => { - setShow(false) - setSearchQuery('') - - switch (result.type) { - case 'command': { - // Execute slash commands - const action = Actions.slash - action?.action?.(result) - break - } - case 'plugin': - setActivePlugin(result.data) - break - case 'workflow-node': - // Handle workflow node selection and navigation - if (result.metadata?.nodeId) - selectWorkflowNode(result.metadata.nodeId, true) - - break - default: - if (result.path) - router.push(result.path) - } - }, [router]) - - const dedupedResults = useMemo(() => { - const seen = new Set() - return searchResults.filter((result) => { - const key = `${result.type}-${result.id}` - if (seen.has(key)) - return false - seen.add(key) - return true - }) - }, [searchResults]) - - // Group results by type - const groupedResults = useMemo(() => dedupedResults.reduce((acc, result) => { - if (!acc[result.type]) - acc[result.type] = [] - - acc[result.type].push(result) - return acc - }, {} as { [key: string]: SearchResult[] }), [dedupedResults]) - - useEffect(() => { - if (isCommandsMode) - return - - if (!dedupedResults.length) - return - - const currentValueExists = dedupedResults.some(result => `${result.type}-${result.id}` === cmdVal) - - if (!currentValueExists) - setCmdVal(`${dedupedResults[0].type}-${dedupedResults[0].id}`) - }, [isCommandsMode, dedupedResults, cmdVal]) - - const emptyResult = useMemo(() => { - if (dedupedResults.length || !searchQuery.trim() || isLoading || isCommandsMode) - return null - - const isCommandSearch = searchMode !== 'general' - const commandType = isCommandSearch ? searchMode.replace('@', '') : '' - - if (isError) { - return ( -
    -
    -
    {t('gotoAnything.searchTemporarilyUnavailable', { ns: 'app' })}
    -
    - {t('gotoAnything.servicesUnavailableMessage', { ns: 'app' })} -
    -
    -
    - ) - } - - return ( -
    -
    -
    - {isCommandSearch - ? (() => { - const keyMap = { - app: 'gotoAnything.emptyState.noAppsFound', - plugin: 'gotoAnything.emptyState.noPluginsFound', - knowledge: 'gotoAnything.emptyState.noKnowledgeBasesFound', - node: 'gotoAnything.emptyState.noWorkflowNodesFound', - } as const - return t(keyMap[commandType as keyof typeof keyMap] || 'gotoAnything.noResults', { ns: 'app' }) - })() - : t('gotoAnything.noResults', { ns: 'app' })} -
    -
    - {isCommandSearch - ? t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' }) - : t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts: Object.values(Actions).map(action => action.shortcut).join(', ') })} -
    -
    -
    - ) - }, [dedupedResults, searchQuery, Actions, searchMode, isLoading, isError, isCommandsMode]) - - const defaultUI = useMemo(() => { - if (searchQuery.trim()) - return null - - return ( -
    -
    -
    {t('gotoAnything.searchTitle', { ns: 'app' })}
    -
    -
    {t('gotoAnything.searchHint', { ns: 'app' })}
    -
    {t('gotoAnything.commandHint', { ns: 'app' })}
    -
    {t('gotoAnything.slashHint', { ns: 'app' })}
    -
    -
    -
    - ) - }, [searchQuery, Actions]) - - useEffect(() => { - if (show) { - requestAnimationFrame(() => { - inputRef.current?.focus() - }) - } - }, [show]) + // Determine which empty state to show + const emptyStateVariant = useMemo(() => { + if (isLoading) + return 'loading' + if (isError) + return 'error' + if (!searchQuery.trim()) + return 'default' + if (dedupedResults.length === 0 && !isCommandsMode) + return 'no-results' + return null + }, [isLoading, isError, searchQuery, dedupedResults.length, isCommandsMode]) return ( <> { - setShow(false) - setSearchQuery('') - clearSelection() - onHide?.() - }} + onClose={modalClose} closable={false} className="!w-[480px] !p-0" highPriority={true} @@ -306,85 +156,24 @@ const GotoAnything: FC = ({ disablePointerSelection loop > -
    - -
    - { - setSearchQuery(e.target.value) - if (!e.target.value.startsWith('@') && !e.target.value.startsWith('/')) - clearSelection() - }} - onKeyDown={(e) => { - if (e.key === 'Enter') { - const query = searchQuery.trim() - // Check if it's a complete slash command - if (query.startsWith('/')) { - const commandName = query.substring(1).split(' ')[0] - const handler = slashCommandRegistry.findCommand(commandName) - - // If it's a direct mode command, execute immediately - const isAvailable = handler?.isAvailable?.() ?? true - if (handler?.mode === 'direct' && handler.execute && isAvailable) { - e.preventDefault() - handler.execute() - setShow(false) - setSearchQuery('') - } - } - } - }} - className="flex-1 !border-0 !bg-transparent !shadow-none" - wrapperClassName="flex-1 !border-0 !bg-transparent" - autoFocus - /> - {searchMode !== 'general' && ( -
    - - {(() => { - if (searchMode === 'scopes') - return 'SCOPES' - else if (searchMode === 'commands') - return 'COMMANDS' - else - return searchMode.replace('@', '').toUpperCase() - })()} - -
    - )} -
    -
    - - {isMac() ? '⌘' : 'Ctrl'} - - - K - -
    -
    + - {isLoading && ( -
    -
    -
    - {t('gotoAnything.searching', { ns: 'app' })} -
    -
    + {emptyStateVariant === 'loading' && ( + )} - {isError && ( -
    -
    -
    {t('gotoAnything.searchFailed', { ns: 'app' })}
    -
    - {error.message} -
    -
    -
    + + {emptyStateVariant === 'error' && ( + )} + {!isLoading && !isError && ( <> {isCommandsMode @@ -399,118 +188,46 @@ const GotoAnything: FC = ({ /> ) : ( - Object.entries(groupedResults).map(([type, results], groupIndex) => ( - { - const typeMap = { - 'app': 'gotoAnything.groups.apps', - 'plugin': 'gotoAnything.groups.plugins', - 'knowledge': 'gotoAnything.groups.knowledgeBases', - 'workflow-node': 'gotoAnything.groups.workflowNodes', - 'command': 'gotoAnything.groups.commands', - } as const - return t(typeMap[type as keyof typeof typeMap] || `${type}s`, { ns: 'app' }) - })()} - className="p-2 capitalize text-text-secondary" - > - {results.map(result => ( - handleNavigate(result)} - > - {result.icon} -
    -
    - {result.title} -
    - {result.description && ( -
    - {result.description} -
    - )} -
    -
    - {result.type} -
    -
    - ))} -
    - )) + )} - {!isCommandsMode && emptyResult} - {!isCommandsMode && defaultUI} + + {!isCommandsMode && emptyStateVariant === 'no-results' && ( + + )} + + {!isCommandsMode && emptyStateVariant === 'default' && ( + + )} )}
    - {/* Always show footer to prevent height jumping */} -
    -
    - {(!!dedupedResults.length || isError) - ? ( - <> - - {isError - ? ( - {t('gotoAnything.someServicesUnavailable', { ns: 'app' })} - ) - : ( - <> - {t('gotoAnything.resultCount', { ns: 'app', count: dedupedResults.length })} - {searchMode !== 'general' && ( - - {t('gotoAnything.inScope', { ns: 'app', scope: searchMode.replace('@', '') })} - - )} - - )} - - - {searchMode !== 'general' - ? t('gotoAnything.clearToSearchAll', { ns: 'app' }) - : t('gotoAnything.useAtForSpecific', { ns: 'app' })} - - - ) - : ( - <> - - {(() => { - if (isCommandsMode) - return t('gotoAnything.selectToNavigate', { ns: 'app' }) - - if (searchQuery.trim()) - return t('gotoAnything.searching', { ns: 'app' }) - - return t('gotoAnything.startTyping', { ns: 'app' }) - })()} - - - {searchQuery.trim() || isCommandsMode - ? t('gotoAnything.tips', { ns: 'app' }) - : t('gotoAnything.pressEscToClose', { ns: 'app' })} - - - )} -
    -
    +
    - - { - activePlugin && ( - setActivePlugin(undefined)} - onSuccess={() => setActivePlugin(undefined)} - /> - ) - } + + {activePlugin && ( + setActivePlugin(undefined)} + onSuccess={() => setActivePlugin(undefined)} + /> + )} ) } diff --git a/web/app/components/header/account-dropdown/compliance.tsx b/web/app/components/header/account-dropdown/compliance.tsx index 562914dd0..6bc5b5c3f 100644 --- a/web/app/components/header/account-dropdown/compliance.tsx +++ b/web/app/components/header/account-dropdown/compliance.tsx @@ -10,6 +10,7 @@ import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import { getDocDownloadUrl } from '@/service/common' import { cn } from '@/utils/classnames' +import { downloadUrl } from '@/utils/download' import Button from '../../base/button' import Gdpr from '../../base/icons/src/public/common/Gdpr' import Iso from '../../base/icons/src/public/common/Iso' @@ -47,9 +48,7 @@ const UpgradeOrDownload: FC = ({ doc_name }) => { mutationFn: async () => { try { const ret = await getDocDownloadUrl(doc_name) - const a = document.createElement('a') - a.href = ret.url - a.click() + downloadUrl({ url: ret.url }) Toast.notify({ type: 'success', message: t('operation.downloadSuccess', { ns: 'common' }), diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx index 543d3deeb..9155fa15b 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx @@ -1,5 +1,5 @@ import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' // Import after mocks @@ -821,6 +821,9 @@ describe('CommonCreateModal', () => { expect(mockCreateBuilder).toHaveBeenCalled() }) + // Flush pending state updates from createBuilder promise resolution + await act(async () => {}) + const input = screen.getByTestId('form-field-webhook_url') fireEvent.change(input, { target: { value: 'https://example.com/webhook' } }) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/footer-tip.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/footer-tip.spec.tsx new file mode 100644 index 000000000..5d5cde973 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/footer-tip.spec.tsx @@ -0,0 +1,59 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import FooterTip from './footer-tip' + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('FooterTip', () => { + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('Drag to adjust grouping')).toBeInTheDocument() + }) + + it('should render the drag tip text', () => { + render() + + expect(screen.getByText('Drag to adjust grouping')).toBeInTheDocument() + }) + + it('should have correct container classes', () => { + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'shrink-0', 'items-center', 'justify-center', 'gap-x-2', 'py-4') + }) + + it('should have correct text styling', () => { + render() + + const text = screen.getByText('Drag to adjust grouping') + expect(text).toHaveClass('system-xs-regular') + }) + + it('should have correct text color', () => { + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('text-text-quaternary') + }) + + it('should render the drag icon', () => { + const { container } = render() + + // The RiDragDropLine icon should be rendered + const icon = container.querySelector('.size-4') + expect(icon).toBeInTheDocument() + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((FooterTip as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/hooks.spec.ts b/web/app/components/rag-pipeline/components/panel/input-field/hooks.spec.ts new file mode 100644 index 000000000..452963ba7 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/hooks.spec.ts @@ -0,0 +1,166 @@ +import { renderHook } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { useFloatingRight } from './hooks' + +// Mock reactflow +const mockGetNodes = vi.fn() +vi.mock('reactflow', () => ({ + useStore: (selector: (s: { getNodes: () => { id: string, data: { selected: boolean } }[] }) => unknown) => { + return selector({ getNodes: mockGetNodes }) + }, +})) + +// Mock zustand/react/shallow +vi.mock('zustand/react/shallow', () => ({ + useShallow: (fn: (...args: unknown[]) => unknown) => fn, +})) + +// Mock workflow store +let mockNodePanelWidth = 400 +let mockWorkflowCanvasWidth: number | undefined = 1200 +let mockOtherPanelWidth = 0 + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => { + return selector({ + nodePanelWidth: mockNodePanelWidth, + workflowCanvasWidth: mockWorkflowCanvasWidth, + otherPanelWidth: mockOtherPanelWidth, + }) + }, +})) + +beforeEach(() => { + mockNodePanelWidth = 400 + mockWorkflowCanvasWidth = 1200 + mockOtherPanelWidth = 0 + mockGetNodes.mockReturnValue([]) +}) + +afterEach(() => { + vi.clearAllMocks() +}) + +describe('useFloatingRight', () => { + describe('initial state', () => { + it('should return floatingRight as false initially', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useFloatingRight(600)) + + expect(result.current.floatingRight).toBe(false) + }) + + it('should return floatingRightWidth as target width when not floating', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useFloatingRight(600)) + + expect(result.current.floatingRightWidth).toBe(600) + }) + }) + + describe('with no selected node', () => { + it('should calculate space without node panel width', () => { + mockGetNodes.mockReturnValue([{ id: 'node-1', data: { selected: false } }]) + mockWorkflowCanvasWidth = 1000 + + const { result } = renderHook(() => useFloatingRight(400)) + + // leftWidth = 1000 - 0 (no selected node) - 0 - 400 - 4 = 596 + // 596 >= 404 so floatingRight should be false + expect(result.current.floatingRight).toBe(false) + }) + }) + + describe('with selected node', () => { + it('should subtract node panel width from available space', () => { + mockGetNodes.mockReturnValue([{ id: 'node-1', data: { selected: true } }]) + mockWorkflowCanvasWidth = 1200 + + const { result } = renderHook(() => useFloatingRight(400)) + + // leftWidth = 1200 - 400 (node panel) - 0 - 400 - 4 = 396 + // 396 < 404 so floatingRight should be true + expect(result.current.floatingRight).toBe(true) + }) + }) + + describe('floatingRightWidth calculation', () => { + it('should return target width when not floating', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = 2000 + + const { result } = renderHook(() => useFloatingRight(600)) + + expect(result.current.floatingRightWidth).toBe(600) + }) + + it('should return minimum of target width and available panel widths when floating with no selected node', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = 500 + mockOtherPanelWidth = 200 + + const { result } = renderHook(() => useFloatingRight(600)) + + // When floating and no selected node, width = min(600, 0 + 200) = 200 + expect(result.current.floatingRightWidth).toBeLessThanOrEqual(600) + }) + + it('should include node panel width when node is selected', () => { + mockGetNodes.mockReturnValue([{ id: 'node-1', data: { selected: true } }]) + mockWorkflowCanvasWidth = 500 + mockNodePanelWidth = 300 + mockOtherPanelWidth = 100 + + const { result } = renderHook(() => useFloatingRight(600)) + + // When floating with selected node, width = min(600, 300 + 100) = 400 + expect(result.current.floatingRightWidth).toBeLessThanOrEqual(600) + }) + }) + + describe('edge cases', () => { + it('should handle undefined workflowCanvasWidth', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = undefined + + const { result } = renderHook(() => useFloatingRight(400)) + + // Should not throw and should maintain initial state + expect(result.current.floatingRight).toBe(false) + }) + + it('should handle zero target element width', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useFloatingRight(0)) + + expect(result.current.floatingRightWidth).toBe(0) + }) + + it('should handle very large target element width', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = 500 + + const { result } = renderHook(() => useFloatingRight(10000)) + + // Should be floating due to limited space + expect(result.current.floatingRight).toBe(true) + }) + + it('should return first selected node id when multiple nodes exist', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { selected: false } }, + { id: 'node-2', data: { selected: true } }, + { id: 'node-3', data: { selected: false } }, + ]) + mockWorkflowCanvasWidth = 1200 + + const { result } = renderHook(() => useFloatingRight(400)) + + // Should have selected node so node panel is considered + expect(result.current).toBeDefined() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/index.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/index.spec.tsx new file mode 100644 index 000000000..71be12bb8 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/index.spec.tsx @@ -0,0 +1,212 @@ +import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import { BlockEnum } from '@/app/components/workflow/types' +import Datasource from './datasource' +import GlobalInputs from './global-inputs' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock BlockIcon +vi.mock('@/app/components/workflow/block-icon', () => ({ + default: ({ type, toolIcon, className }: { type: BlockEnum, toolIcon?: string, className?: string }) => ( +
    + ), +})) + +// Mock useToolIcon +vi.mock('@/app/components/workflow/hooks', () => ({ + useToolIcon: (nodeData: DataSourceNodeType) => nodeData.provider_name || 'default-icon', +})) + +// Mock Tooltip +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ popupContent, popupClassName }: { popupContent: string, popupClassName?: string }) => ( +
    + ), +})) + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('Datasource', () => { + const createMockNodeData = (overrides?: Partial): DataSourceNodeType => ({ + title: 'Test Data Source', + desc: 'Test description', + type: BlockEnum.DataSource, + provider_name: 'test-provider', + provider_type: 'api', + datasource_name: 'test-datasource', + datasource_label: 'Test Datasource', + plugin_id: 'test-plugin', + datasource_parameters: {}, + datasource_configurations: {}, + ...overrides, + } as DataSourceNodeType) + + describe('rendering', () => { + it('should render without crashing', () => { + const nodeData = createMockNodeData() + + render() + + expect(screen.getByTestId('block-icon')).toBeInTheDocument() + }) + + it('should render the node title', () => { + const nodeData = createMockNodeData({ title: 'My Custom Data Source' }) + + render() + + expect(screen.getByText('My Custom Data Source')).toBeInTheDocument() + }) + + it('should render BlockIcon with correct type', () => { + const nodeData = createMockNodeData() + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', BlockEnum.DataSource) + }) + + it('should pass toolIcon from useToolIcon hook', () => { + const nodeData = createMockNodeData({ provider_name: 'custom-provider' }) + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'custom-provider') + }) + + it('should have correct icon container styling', () => { + const nodeData = createMockNodeData() + + const { container } = render() + + const iconContainer = container.querySelector('.size-5') + expect(iconContainer).toBeInTheDocument() + expect(iconContainer).toHaveClass('flex', 'items-center', 'justify-center', 'rounded-md') + }) + + it('should have correct text styling', () => { + const nodeData = createMockNodeData() + + render() + + const titleElement = screen.getByText('Test Data Source') + expect(titleElement).toHaveClass('system-sm-medium', 'text-text-secondary') + }) + + it('should have correct container layout', () => { + const nodeData = createMockNodeData() + + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'items-center', 'gap-x-1.5') + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((Datasource as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) + + describe('edge cases', () => { + it('should handle empty title', () => { + const nodeData = createMockNodeData({ title: '' }) + + render() + + // Should still render without the title text + expect(screen.getByTestId('block-icon')).toBeInTheDocument() + }) + + it('should handle long title', () => { + const longTitle = 'A'.repeat(100) + const nodeData = createMockNodeData({ title: longTitle }) + + render() + + expect(screen.getByText(longTitle)).toBeInTheDocument() + }) + + it('should handle special characters in title', () => { + const nodeData = createMockNodeData({ title: 'Test ' }) + + render() + + expect(screen.getByText('Test ')).toBeInTheDocument() + }) + }) +}) + +describe('GlobalInputs', () => { + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('inputFieldPanel.globalInputs.title')).toBeInTheDocument() + }) + + it('should render title with correct translation key', () => { + render() + + expect(screen.getByText('inputFieldPanel.globalInputs.title')).toBeInTheDocument() + }) + + it('should render tooltip component', () => { + render() + + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + }) + + it('should pass correct tooltip content', () => { + render() + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toHaveAttribute('data-content', 'inputFieldPanel.globalInputs.tooltip') + }) + + it('should have correct tooltip className', () => { + render() + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toHaveClass('w-[240px]') + }) + + it('should have correct container layout', () => { + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'items-center', 'gap-x-1') + }) + + it('should have correct title styling', () => { + render() + + const titleElement = screen.getByText('inputFieldPanel.globalInputs.title') + expect(titleElement).toHaveClass('system-sm-semibold-uppercase', 'text-text-secondary') + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((GlobalInputs as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/publish-toast.spec.tsx b/web/app/components/rag-pipeline/components/publish-toast.spec.tsx new file mode 100644 index 000000000..d61f091ed --- /dev/null +++ b/web/app/components/rag-pipeline/components/publish-toast.spec.tsx @@ -0,0 +1,129 @@ +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import PublishToast from './publish-toast' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock workflow store with controllable state +let mockPublishedAt = 0 +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => { + return selector({ publishedAt: mockPublishedAt }) + }, +})) + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('PublishToast', () => { + beforeEach(() => { + mockPublishedAt = 0 + }) + + describe('rendering', () => { + it('should render when publishedAt is 0', () => { + mockPublishedAt = 0 + render() + + expect(screen.getByText('publishToast.title')).toBeInTheDocument() + }) + + it('should render toast title', () => { + render() + + expect(screen.getByText('publishToast.title')).toBeInTheDocument() + }) + + it('should render toast description', () => { + render() + + expect(screen.getByText('publishToast.desc')).toBeInTheDocument() + }) + + it('should not render when publishedAt is set', () => { + mockPublishedAt = Date.now() + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('should have correct positioning classes', () => { + render() + + const container = screen.getByText('publishToast.title').closest('.absolute') + expect(container).toHaveClass('bottom-[45px]', 'left-0', 'right-0', 'z-10') + }) + + it('should render info icon', () => { + const { container } = render() + + // The RiInformation2Fill icon should be rendered + const iconContainer = container.querySelector('.text-text-accent') + expect(iconContainer).toBeInTheDocument() + }) + + it('should render close button', () => { + const { container } = render() + + // The close button is a div with cursor-pointer, not a semantic button + const closeButton = container.querySelector('.cursor-pointer') + expect(closeButton).toBeInTheDocument() + }) + }) + + describe('user interactions', () => { + it('should hide toast when close button is clicked', () => { + const { container } = render() + + // The close button is a div with cursor-pointer, not a semantic button + const closeButton = container.querySelector('.cursor-pointer') + expect(screen.getByText('publishToast.title')).toBeInTheDocument() + + fireEvent.click(closeButton!) + + expect(screen.queryByText('publishToast.title')).not.toBeInTheDocument() + }) + + it('should remain hidden after close button is clicked', () => { + const { container, rerender } = render() + + // The close button is a div with cursor-pointer, not a semantic button + const closeButton = container.querySelector('.cursor-pointer') + fireEvent.click(closeButton!) + + rerender() + + expect(screen.queryByText('publishToast.title')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have gradient overlay', () => { + const { container } = render() + + const gradientOverlay = container.querySelector('.bg-gradient-to-r') + expect(gradientOverlay).toBeInTheDocument() + }) + + it('should have correct toast width', () => { + render() + + const toastContainer = screen.getByText('publishToast.title').closest('.w-\\[420px\\]') + expect(toastContainer).toBeInTheDocument() + }) + + it('should have rounded border', () => { + render() + + const toastContainer = screen.getByText('publishToast.title').closest('.rounded-xl') + expect(toastContainer).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx index 0cdc9a032..c66b293d8 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx @@ -28,11 +28,12 @@ import { useToastContext } from '@/app/components/base/toast' import { useChecklistBeforePublish, } from '@/app/components/workflow/hooks' +import ShortcutsName from '@/app/components/workflow/shortcuts-name' import { useStore, useWorkflowStore, } from '@/app/components/workflow/store' -import { getKeyboardKeyCodeBySystem, getKeyboardKeyNameBySystem } from '@/app/components/workflow/utils' +import { getKeyboardKeyCodeBySystem } from '@/app/components/workflow/utils' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDocLink } from '@/context/i18n' import { useModalContextSelector } from '@/context/modal-context' @@ -261,13 +262,7 @@ const Popup = () => { : (
    {t('common.publishUpdate', { ns: 'workflow' })} -
    - {PUBLISH_SHORTCUT.map(key => ( - - {getKeyboardKeyNameBySystem(key)} - - ))} -
    +
    ) } diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/run-mode.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/run-mode.tsx index 00c531004..81389e51b 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/run-mode.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/run-mode.tsx @@ -4,9 +4,9 @@ import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' import { useWorkflowRun, useWorkflowStartRun } from '@/app/components/workflow/hooks' +import ShortcutsName from '@/app/components/workflow/shortcuts-name' import { useStore, useWorkflowStore } from '@/app/components/workflow/store' import { WorkflowRunningStatus } from '@/app/components/workflow/types' -import { getKeyboardKeyNameBySystem } from '@/app/components/workflow/utils' import { EVENT_WORKFLOW_STOP } from '@/app/components/workflow/variable-inspect/types' import { useEventEmitterContextContext } from '@/context/event-emitter' import { cn } from '@/utils/classnames' @@ -78,14 +78,7 @@ const RunMode = ({ )} { !isDisabled && ( -
    -
    - {getKeyboardKeyNameBySystem('alt')} -
    -
    - R -
    -
    + ) } diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-main.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-main.spec.tsx new file mode 100644 index 000000000..3de3c3dee --- /dev/null +++ b/web/app/components/rag-pipeline/components/rag-pipeline-main.spec.tsx @@ -0,0 +1,276 @@ +import type { PropsWithChildren } from 'react' +import type { Edge, Node, Viewport } from 'reactflow' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import RagPipelineMain from './rag-pipeline-main' + +// Mock hooks from ../hooks +vi.mock('../hooks', () => ({ + useAvailableNodesMetaData: () => ({ nodes: [], nodesMap: {} }), + useDSL: () => ({ + exportCheck: vi.fn(), + handleExportDSL: vi.fn(), + }), + useGetRunAndTraceUrl: () => ({ + getWorkflowRunAndTraceUrl: vi.fn(), + }), + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: vi.fn(), + syncWorkflowDraftWhenPageClose: vi.fn(), + }), + usePipelineRefreshDraft: () => ({ + handleRefreshWorkflowDraft: vi.fn(), + }), + usePipelineRun: () => ({ + handleBackupDraft: vi.fn(), + handleLoadBackupDraft: vi.fn(), + handleRestoreFromPublishedWorkflow: vi.fn(), + handleRun: vi.fn(), + handleStopRun: vi.fn(), + }), + usePipelineStartRun: () => ({ + handleStartWorkflowRun: vi.fn(), + handleWorkflowStartRunInWorkflow: vi.fn(), + }), +})) + +// Mock useConfigsMap +vi.mock('../hooks/use-configs-map', () => ({ + useConfigsMap: () => ({ + flowId: 'test-flow-id', + flowType: 'ragPipeline', + fileSettings: {}, + }), +})) + +// Mock useInspectVarsCrud +vi.mock('../hooks/use-inspect-vars-crud', () => ({ + useInspectVarsCrud: () => ({ + hasNodeInspectVars: vi.fn(), + hasSetInspectVar: vi.fn(), + fetchInspectVarValue: vi.fn(), + editInspectVarValue: vi.fn(), + renameInspectVarName: vi.fn(), + appendNodeInspectVars: vi.fn(), + deleteInspectVar: vi.fn(), + deleteNodeInspectorVars: vi.fn(), + deleteAllInspectorVars: vi.fn(), + isInspectVarEdited: vi.fn(), + resetToLastRunVar: vi.fn(), + invalidateSysVarValues: vi.fn(), + resetConversationVar: vi.fn(), + invalidateConversationVarValues: vi.fn(), + }), +})) + +// Mock workflow store +const mockSetRagPipelineVariables = vi.fn() +const mockSetEnvironmentVariables = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setRagPipelineVariables: mockSetRagPipelineVariables, + setEnvironmentVariables: mockSetEnvironmentVariables, + }), + }), +})) + +// Mock workflow hooks +vi.mock('@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars', () => ({ + useSetWorkflowVarsWithValue: () => ({ + fetchInspectVars: vi.fn(), + }), +})) + +// Mock WorkflowWithInnerContext +vi.mock('@/app/components/workflow', () => ({ + WorkflowWithInnerContext: ({ children, onWorkflowDataUpdate }: PropsWithChildren<{ onWorkflowDataUpdate?: (payload: unknown) => void }>) => ( +
    + {children} + + +
    + ), +})) + +// Mock RagPipelineChildren +vi.mock('./rag-pipeline-children', () => ({ + default: () =>
    Children
    , +})) + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('RagPipelineMain', () => { + const defaultProps = { + nodes: [] as Node[], + edges: [] as Edge[], + viewport: { x: 0, y: 0, zoom: 1 } as Viewport, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should render RagPipelineChildren', () => { + render() + + expect(screen.getByTestId('rag-pipeline-children')).toBeInTheDocument() + }) + + it('should pass nodes to WorkflowWithInnerContext', () => { + const nodes = [{ id: 'node-1', type: 'custom', position: { x: 0, y: 0 }, data: {} }] as Node[] + + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should pass edges to WorkflowWithInnerContext', () => { + const edges = [{ id: 'edge-1', source: 'node-1', target: 'node-2' }] as Edge[] + + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should pass viewport to WorkflowWithInnerContext', () => { + const viewport = { x: 100, y: 200, zoom: 1.5 } + + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + }) + + describe('handleWorkflowDataUpdate callback', () => { + it('should update rag_pipeline_variables when provided', () => { + render() + + const button = screen.getByTestId('trigger-update') + button.click() + + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([{ id: '1', name: 'var1' }]) + }) + + it('should update environment_variables when provided', () => { + render() + + const button = screen.getByTestId('trigger-update') + button.click() + + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([{ id: '2', name: 'env1' }]) + }) + + it('should only update rag_pipeline_variables when environment_variables is not provided', () => { + render() + + const button = screen.getByTestId('trigger-update-partial') + button.click() + + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([{ id: '3', name: 'var2' }]) + expect(mockSetEnvironmentVariables).not.toHaveBeenCalled() + }) + }) + + describe('hooks integration', () => { + it('should use useNodesSyncDraft hook', () => { + render() + + // If the component renders, the hook was called successfully + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use usePipelineRefreshDraft hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use usePipelineRun hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use usePipelineStartRun hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useAvailableNodesMetaData hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useGetRunAndTraceUrl hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useDSL hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useConfigsMap hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useInspectVarsCrud hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + }) + + describe('edge cases', () => { + it('should handle empty nodes array', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should handle empty edges array', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should handle default viewport', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx new file mode 100644 index 000000000..f57bd80d7 --- /dev/null +++ b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx @@ -0,0 +1,1088 @@ +import type { PropsWithChildren } from 'react' +import { act, cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { DSLImportStatus } from '@/models/app' +import UpdateDSLModal from './update-dsl-modal' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock use-context-selector +const mockNotify = vi.fn() +vi.mock('use-context-selector', () => ({ + useContext: () => ({ notify: mockNotify }), +})) + +// Mock toast context +vi.mock('@/app/components/base/toast', () => ({ + ToastContext: { Provider: ({ children }: PropsWithChildren) => children }, +})) + +// Mock event emitter +const mockEmit = vi.fn() +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { emit: mockEmit }, + }), +})) + +// Mock workflow store +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + pipelineId: 'test-pipeline-id', + }), + }), +})) + +// Mock workflow utils +vi.mock('@/app/components/workflow/utils', () => ({ + initialNodes: (nodes: unknown[]) => nodes, + initialEdges: (edges: unknown[]) => edges, +})) + +// Mock plugin dependencies +const mockHandleCheckPluginDependencies = vi.fn() +vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({ + usePluginDependencies: () => ({ + handleCheckPluginDependencies: mockHandleCheckPluginDependencies, + }), +})) + +// Mock pipeline service +const mockImportDSL = vi.fn() +const mockImportDSLConfirm = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useImportPipelineDSL: () => ({ mutateAsync: mockImportDSL }), + useImportPipelineDSLConfirm: () => ({ mutateAsync: mockImportDSLConfirm }), +})) + +// Mock workflow service +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: vi.fn().mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: { x: 0, y: 0, zoom: 1 } }, + hash: 'test-hash', + rag_pipeline_variables: [], + }), +})) + +// Mock Uploader +vi.mock('@/app/components/app/create-from-dsl-modal/uploader', () => ({ + default: ({ updateFile }: { updateFile: (file?: File) => void }) => ( +
    + { + const file = e.target.files?.[0] + updateFile(file) + }} + /> + +
    + ), +})) + +// Mock Button +vi.mock('@/app/components/base/button', () => ({ + default: ({ children, onClick, disabled, className, variant, loading }: { + children: React.ReactNode + onClick?: () => void + disabled?: boolean + className?: string + variant?: string + loading?: boolean + }) => ( + + ), +})) + +// Mock Modal +vi.mock('@/app/components/base/modal', () => ({ + default: ({ children, isShow, _onClose, className }: PropsWithChildren<{ + isShow: boolean + _onClose: () => void + className?: string + }>) => isShow + ? ( +
    + {children} +
    + ) + : null, +})) + +// Mock workflow constants +vi.mock('@/app/components/workflow/constants', () => ({ + WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE', +})) + +// Mock FileReader +class MockFileReader { + result: string | null = null + onload: ((e: { target: { result: string | null } }) => void) | null = null + + readAsText(_file: File) { + // Simulate async file reading using queueMicrotask for more reliable async behavior + queueMicrotask(() => { + this.result = 'test file content' + if (this.onload) { + this.onload({ target: { result: this.result } }) + } + }) + } +} + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('UpdateDSLModal', () => { + const mockOnCancel = vi.fn() + const mockOnBackup = vi.fn() + const mockOnImport = vi.fn() + let originalFileReader: typeof FileReader + + const defaultProps = { + onCancel: mockOnCancel, + onBackup: mockOnBackup, + onImport: mockOnImport, + } + + beforeEach(() => { + vi.clearAllMocks() + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + // Mock FileReader + originalFileReader = globalThis.FileReader + globalThis.FileReader = MockFileReader as unknown as typeof FileReader + }) + + afterEach(() => { + globalThis.FileReader = originalFileReader + }) + + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should render title', () => { + render() + + // The component uses t('common.importDSL', { ns: 'workflow' }) which returns 'common.importDSL' + expect(screen.getByText('common.importDSL')).toBeInTheDocument() + }) + + it('should render warning tip', () => { + render() + + // The component uses t('common.importDSLTip', { ns: 'workflow' }) + expect(screen.getByText('common.importDSLTip')).toBeInTheDocument() + }) + + it('should render uploader', () => { + render() + + expect(screen.getByTestId('uploader')).toBeInTheDocument() + }) + + it('should render backup button', () => { + render() + + // The component uses t('common.backupCurrentDraft', { ns: 'workflow' }) + expect(screen.getByText('common.backupCurrentDraft')).toBeInTheDocument() + }) + + it('should render cancel button', () => { + render() + + // The component uses t('newApp.Cancel', { ns: 'app' }) + expect(screen.getByText('newApp.Cancel')).toBeInTheDocument() + }) + + it('should render import button', () => { + render() + + // The component uses t('common.overwriteAndImport', { ns: 'workflow' }) + expect(screen.getByText('common.overwriteAndImport')).toBeInTheDocument() + }) + + it('should render choose DSL section', () => { + render() + + // The component uses t('common.chooseDSL', { ns: 'workflow' }) + expect(screen.getByText('common.chooseDSL')).toBeInTheDocument() + }) + }) + + describe('user interactions', () => { + it('should call onCancel when cancel button is clicked', () => { + render() + + const cancelButton = screen.getByText('newApp.Cancel') + fireEvent.click(cancelButton) + + expect(mockOnCancel).toHaveBeenCalled() + }) + + it('should call onBackup when backup button is clicked', () => { + render() + + const backupButton = screen.getByText('common.backupCurrentDraft') + fireEvent.click(backupButton) + + expect(mockOnBackup).toHaveBeenCalled() + }) + + it('should handle file upload', async () => { + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + + fireEvent.change(fileInput, { target: { files: [file] } }) + + // File should be processed + await waitFor(() => { + expect(screen.getByTestId('uploader')).toBeInTheDocument() + }) + }) + + it('should clear file when clear button is clicked', () => { + render() + + const clearButton = screen.getByTestId('clear-file') + fireEvent.click(clearButton) + + // File should be cleared + expect(screen.getByTestId('uploader')).toBeInTheDocument() + }) + + it('should call onCancel when close icon is clicked', () => { + render() + + // The close icon is in a div with onClick={onCancel} + const closeIconContainer = document.querySelector('.cursor-pointer') + if (closeIconContainer) { + fireEvent.click(closeIconContainer) + expect(mockOnCancel).toHaveBeenCalled() + } + }) + }) + + describe('import functionality', () => { + it('should show import button disabled when no file is selected', () => { + render() + + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).toBeDisabled() + }) + + it('should enable import button when file is selected', async () => { + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + }) + + it('should disable import button after file is cleared', async () => { + render() + + // First select a file + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + // Clear the file + const clearButton = screen.getByTestId('clear-file') + fireEvent.click(clearButton) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).toBeDisabled() + }) + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((UpdateDSLModal as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) + + describe('edge cases', () => { + it('should handle missing onImport callback', () => { + const props = { + onCancel: mockOnCancel, + onBackup: mockOnBackup, + } + + render() + + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should render import button with warning variant', () => { + render() + + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).toHaveAttribute('data-variant', 'warning') + }) + + it('should render backup button with secondary variant', () => { + render() + + // The backup button text is inside a nested div, so we need to find the closest button + const backupButtonText = screen.getByText('common.backupCurrentDraft') + const backupButton = backupButtonText.closest('button') + expect(backupButton).toHaveAttribute('data-variant', 'secondary') + }) + }) + + describe('import flow', () => { + it('should call importDSL when import button is clicked with file content', async () => { + render() + + // Select a file + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + // Wait for FileReader to process + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + // Click import button + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for import to be called + await waitFor(() => { + expect(mockImportDSL).toHaveBeenCalled() + }) + }) + + it('should show success notification on completed import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + // Select a file and click import + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + }) + }) + + it('should call onCancel after successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockOnCancel).toHaveBeenCalled() + }) + }) + + it('should call onImport after successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }, { timeout: 1000 }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockOnImport).toHaveBeenCalled() + }, { timeout: 1000 }) + }) + + it('should show warning notification on import with warnings', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED_WITH_WARNINGS, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'warning', + })) + }) + }) + + it('should show error notification when import fails', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.FAILED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error notification when pipeline_id is missing on success', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: undefined, + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error notification when import throws exception', async () => { + mockImportDSL.mockRejectedValue(new Error('Import failed')) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + // Wait for FileReader to complete (setTimeout 0) and button to be enabled + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + // Give extra time for the FileReader's setTimeout to complete + await new Promise(resolve => setTimeout(resolve, 10)) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should call handleCheckPluginDependencies on successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true) + }) + }) + + it('should emit WORKFLOW_DATA_UPDATE event after successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockEmit).toHaveBeenCalled() + }) + }) + + it('should show error modal when import status is PENDING', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + + await act(async () => { + fireEvent.change(fileInput, { target: { files: [file] } }) + // Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask) + await new Promise(resolve => queueMicrotask(resolve)) + }) + + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + + await act(async () => { + fireEvent.click(importButton) + // Flush the promise resolution from mockImportDSL + await Promise.resolve() + // Advance past the 300ms setTimeout in the component + await vi.advanceTimersByTimeAsync(350) + }) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }) + + vi.useRealTimers() + }) + + it('should show version info in error modal', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for error modal with version info + await waitFor(() => { + expect(screen.getByText('1.0.0')).toBeInTheDocument() + expect(screen.getByText('2.0.0')).toBeInTheDocument() + }, { timeout: 500 }) + }) + + it('should close error modal when cancel button is clicked', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for error modal + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + // Find and click cancel button in error modal - it should be the one with secondary variant + const cancelButtons = screen.getAllByText('newApp.Cancel') + const errorModalCancelButton = cancelButtons.find(btn => + btn.getAttribute('data-variant') === 'secondary', + ) + if (errorModalCancelButton) { + fireEvent.click(errorModalCancelButton) + } + + // Modal should be closed + await waitFor(() => { + expect(screen.queryByText('newApp.appCreateDSLErrorTitle')).not.toBeInTheDocument() + }) + }) + + it('should call importDSLConfirm when confirm button is clicked in error modal', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for error modal + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + // Click confirm button + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-id') + }) + }) + + it('should show success notification after confirm completes', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + }) + }) + + it('should show error notification when confirm fails with FAILED status', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.FAILED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error notification when confirm throws exception', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockRejectedValue(new Error('Confirm failed')) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error when confirm completes but pipeline_id is missing', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: undefined, + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should call onImport after confirm completes successfully', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockOnImport).toHaveBeenCalled() + }) + }) + + it('should call handleCheckPluginDependencies after confirm', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true) + }) + }) + + it('should handle undefined imported_dsl_version and current_dsl_version', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: undefined, + current_dsl_version: undefined, + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Should show error modal even with undefined versions + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + }) + + it('should not call importDSLConfirm when importId is not set', async () => { + // Render without triggering PENDING status first + render() + + // importId is not set, so confirm should not be called + // This is hard to test directly, but we can verify by checking the confirm flow + expect(mockImportDSLConfirm).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/index.spec.ts b/web/app/components/rag-pipeline/hooks/index.spec.ts new file mode 100644 index 000000000..7917275c1 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/index.spec.ts @@ -0,0 +1,536 @@ +import type { RAGPipelineVariables, VAR_TYPE_MAP } from '@/models/pipeline' +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { BlockEnum } from '@/app/components/workflow/types' +import { Resolution, TransferMethod } from '@/types/app' +import { FlowType } from '@/types/common' + +// ============================================================================ +// Import hooks after mocks +// ============================================================================ + +import { + useAvailableNodesMetaData, + useDSL, + useGetRunAndTraceUrl, + useInputFieldPanel, + useNodesSyncDraft, + usePipelineInit, + usePipelineRefreshDraft, + usePipelineRun, + usePipelineStartRun, +} from './index' +import { useConfigsMap } from './use-configs-map' +import { useConfigurations, useInitialData } from './use-input-fields' +import { usePipelineTemplate } from './use-pipeline-template' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock the workflow store +const _mockGetState = vi.fn() +const mockUseStore = vi.fn() +const mockUseWorkflowStore = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => mockUseStore(selector), + useWorkflowStore: () => mockUseWorkflowStore(), +})) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock toast context +const mockNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ + useToastContext: () => ({ + notify: mockNotify, + }), +})) + +// Mock event emitter context +const mockEventEmit = vi.fn() +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: mockEventEmit, + }, + }), +})) + +// Mock i18n docLink +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.dify.ai${path}`, +})) + +// Mock workflow constants +vi.mock('@/app/components/workflow/constants', () => ({ + DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', + WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE', + START_INITIAL_POSITION: { x: 100, y: 100 }, +})) + +// Mock workflow constants/node +vi.mock('@/app/components/workflow/constants/node', () => ({ + WORKFLOW_COMMON_NODES: [ + { + metaData: { type: BlockEnum.Start }, + defaultValue: { type: BlockEnum.Start }, + }, + { + metaData: { type: BlockEnum.End }, + defaultValue: { type: BlockEnum.End }, + }, + ], +})) + +// Mock data source defaults +vi.mock('@/app/components/workflow/nodes/data-source-empty/default', () => ({ + default: { + metaData: { type: BlockEnum.DataSourceEmpty }, + defaultValue: { type: BlockEnum.DataSourceEmpty }, + }, +})) + +vi.mock('@/app/components/workflow/nodes/data-source/default', () => ({ + default: { + metaData: { type: BlockEnum.DataSource }, + defaultValue: { type: BlockEnum.DataSource }, + }, +})) + +vi.mock('@/app/components/workflow/nodes/knowledge-base/default', () => ({ + default: { + metaData: { type: BlockEnum.KnowledgeBase }, + defaultValue: { type: BlockEnum.KnowledgeBase }, + }, +})) + +// Mock workflow utils with all needed exports +vi.mock('@/app/components/workflow/utils', async (importOriginal) => { + const actual = await importOriginal() as Record + return { + ...actual, + generateNewNode: ({ id, data, position }: { id: string, data: object, position: { x: number, y: number } }) => ({ + newNode: { id, data, position, type: 'custom' }, + }), + } +}) + +// Mock pipeline service +const mockExportPipelineConfig = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipelineConfig, + }), +})) + +// Mock workflow service +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: vi.fn().mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + environment_variables: [], + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useConfigsMap', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + pipelineId: 'test-pipeline-id', + fileUploadConfig: { max_file_size: 10 }, + } + return selector(state) + }) + }) + + it('should return config map with correct flowId', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.flowId).toBe('test-pipeline-id') + }) + + it('should return config map with correct flowType', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.flowType).toBe(FlowType.ragPipeline) + }) + + it('should return file settings with image config', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.fileSettings.image).toEqual({ + enabled: false, + detail: Resolution.high, + number_limits: 3, + transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url], + }) + }) + + it('should include fileUploadConfig from store', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.fileSettings.fileUploadConfig).toEqual({ max_file_size: 10 }) + }) +}) + +describe('useGetRunAndTraceUrl', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseWorkflowStore.mockReturnValue({ + getState: () => ({ + pipelineId: 'pipeline-123', + }), + }) + }) + + it('should return getWorkflowRunAndTraceUrl function', () => { + const { result } = renderHook(() => useGetRunAndTraceUrl()) + + expect(result.current.getWorkflowRunAndTraceUrl).toBeDefined() + expect(typeof result.current.getWorkflowRunAndTraceUrl).toBe('function') + }) + + it('should generate correct run URL', () => { + const { result } = renderHook(() => useGetRunAndTraceUrl()) + + const { runUrl } = result.current.getWorkflowRunAndTraceUrl('run-456') + + expect(runUrl).toBe('/rag/pipelines/pipeline-123/workflow-runs/run-456') + }) + + it('should generate correct trace URL', () => { + const { result } = renderHook(() => useGetRunAndTraceUrl()) + + const { traceUrl } = result.current.getWorkflowRunAndTraceUrl('run-456') + + expect(traceUrl).toBe('/rag/pipelines/pipeline-123/workflow-runs/run-456/node-executions') + }) +}) + +describe('useInputFieldPanel', () => { + const mockSetShowInputFieldPanel = vi.fn() + const mockSetShowInputFieldPreviewPanel = vi.fn() + const mockSetInputFieldEditPanelProps = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + showInputFieldPreviewPanel: false, + inputFieldEditPanelProps: null, + } + return selector(state) + }) + mockUseWorkflowStore.mockReturnValue({ + getState: () => ({ + showInputFieldPreviewPanel: false, + setShowInputFieldPanel: mockSetShowInputFieldPanel, + setShowInputFieldPreviewPanel: mockSetShowInputFieldPreviewPanel, + setInputFieldEditPanelProps: mockSetInputFieldEditPanelProps, + }), + }) + }) + + it('should return isPreviewing as false when showInputFieldPreviewPanel is false', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isPreviewing).toBe(false) + }) + + it('should return isPreviewing as true when showInputFieldPreviewPanel is true', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + showInputFieldPreviewPanel: true, + inputFieldEditPanelProps: null, + } + return selector(state) + }) + + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isPreviewing).toBe(true) + }) + + it('should return isEditing as false when inputFieldEditPanelProps is null', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isEditing).toBe(false) + }) + + it('should return isEditing as true when inputFieldEditPanelProps exists', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + showInputFieldPreviewPanel: false, + inputFieldEditPanelProps: { some: 'props' }, + } + return selector(state) + }) + + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isEditing).toBe(true) + }) + + it('should call all setters when closeAllInputFieldPanels is called', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + act(() => { + result.current.closeAllInputFieldPanels() + }) + + expect(mockSetShowInputFieldPanel).toHaveBeenCalledWith(false) + expect(mockSetShowInputFieldPreviewPanel).toHaveBeenCalledWith(false) + expect(mockSetInputFieldEditPanelProps).toHaveBeenCalledWith(null) + }) + + it('should toggle preview panel when toggleInputFieldPreviewPanel is called', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + act(() => { + result.current.toggleInputFieldPreviewPanel() + }) + + expect(mockSetShowInputFieldPreviewPanel).toHaveBeenCalledWith(true) + }) + + it('should set edit panel props when toggleInputFieldEditPanel is called', () => { + const { result } = renderHook(() => useInputFieldPanel()) + const editContent = { type: 'edit', data: {} } + + act(() => { + // eslint-disable-next-line ts/no-explicit-any + result.current.toggleInputFieldEditPanel(editContent as any) + }) + + expect(mockSetInputFieldEditPanelProps).toHaveBeenCalledWith(editContent) + }) +}) + +describe('useInitialData', () => { + it('should return empty object for empty variables', () => { + const { result } = renderHook(() => useInitialData([], undefined)) + + expect(result.current).toEqual({}) + }) + + it('should handle text input type with default value', () => { + const variables: RAGPipelineVariables = [ + { + type: 'text-input' as keyof typeof VAR_TYPE_MAP, + variable: 'textVar', + label: 'Text', + required: false, + default_value: 'default text', + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, undefined)) + + expect(result.current.textVar).toBe('default text') + }) + + it('should use lastRunInputData over default value', () => { + const variables: RAGPipelineVariables = [ + { + type: 'text-input' as keyof typeof VAR_TYPE_MAP, + variable: 'textVar', + label: 'Text', + required: false, + default_value: 'default text', + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, { textVar: 'last run value' })) + + expect(result.current.textVar).toBe('last run value') + }) + + it('should handle number input type with default 0', () => { + const variables: RAGPipelineVariables = [ + { + type: 'number' as keyof typeof VAR_TYPE_MAP, + variable: 'numVar', + label: 'Number', + required: false, + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, undefined)) + + expect(result.current.numVar).toBe(0) + }) + + it('should handle file type with default empty array', () => { + const variables: RAGPipelineVariables = [ + { + type: 'file' as keyof typeof VAR_TYPE_MAP, + variable: 'fileVar', + label: 'File', + required: false, + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, undefined)) + + expect(result.current.fileVar).toEqual([]) + }) +}) + +describe('useConfigurations', () => { + it('should return empty array for empty variables', () => { + const { result } = renderHook(() => useConfigurations([])) + + expect(result.current).toEqual([]) + }) + + it('should transform variables to configurations', () => { + const variables: RAGPipelineVariables = [ + { + type: 'text-input' as keyof typeof VAR_TYPE_MAP, + variable: 'textVar', + label: 'Text Label', + required: true, + max_length: 100, + placeholder: 'Enter text', + tooltips: 'Help text', + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useConfigurations(variables)) + + expect(result.current.length).toBe(1) + expect(result.current[0].variable).toBe('textVar') + expect(result.current[0].label).toBe('Text Label') + expect(result.current[0].required).toBe(true) + expect(result.current[0].maxLength).toBe(100) + expect(result.current[0].placeholder).toBe('Enter text') + expect(result.current[0].tooltip).toBe('Help text') + }) + + it('should transform options correctly', () => { + const variables: RAGPipelineVariables = [ + { + type: 'select' as keyof typeof VAR_TYPE_MAP, + variable: 'selectVar', + label: 'Select', + required: false, + options: ['option1', 'option2', 'option3'], + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useConfigurations(variables)) + + expect(result.current[0].options).toEqual([ + { label: 'option1', value: 'option1' }, + { label: 'option2', value: 'option2' }, + { label: 'option3', value: 'option3' }, + ]) + }) +}) + +describe('useAvailableNodesMetaData', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should return nodes array', () => { + const { result } = renderHook(() => useAvailableNodesMetaData()) + + expect(result.current.nodes).toBeDefined() + expect(Array.isArray(result.current.nodes)).toBe(true) + }) + + it('should return nodesMap object', () => { + const { result } = renderHook(() => useAvailableNodesMetaData()) + + expect(result.current.nodesMap).toBeDefined() + expect(typeof result.current.nodesMap).toBe('object') + }) +}) + +describe('usePipelineTemplate', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should return nodes array with knowledge base node', () => { + const { result } = renderHook(() => usePipelineTemplate()) + + expect(result.current.nodes).toBeDefined() + expect(Array.isArray(result.current.nodes)).toBe(true) + expect(result.current.nodes.length).toBe(1) + }) + + it('should return empty edges array', () => { + const { result } = renderHook(() => usePipelineTemplate()) + + expect(result.current.edges).toEqual([]) + }) +}) + +describe('useDSL', () => { + it('should be defined and exported', () => { + expect(useDSL).toBeDefined() + expect(typeof useDSL).toBe('function') + }) +}) + +describe('exports', () => { + it('should export useAvailableNodesMetaData', () => { + expect(useAvailableNodesMetaData).toBeDefined() + }) + + it('should export useDSL', () => { + expect(useDSL).toBeDefined() + }) + + it('should export useGetRunAndTraceUrl', () => { + expect(useGetRunAndTraceUrl).toBeDefined() + }) + + it('should export useInputFieldPanel', () => { + expect(useInputFieldPanel).toBeDefined() + }) + + it('should export useNodesSyncDraft', () => { + expect(useNodesSyncDraft).toBeDefined() + }) + + it('should export usePipelineInit', () => { + expect(usePipelineInit).toBeDefined() + }) + + it('should export usePipelineRefreshDraft', () => { + expect(usePipelineRefreshDraft).toBeDefined() + }) + + it('should export usePipelineRun', () => { + expect(usePipelineRun).toBeDefined() + }) + + it('should export usePipelineStartRun', () => { + expect(usePipelineStartRun).toBeDefined() + }) +}) + +afterEach(() => { + vi.clearAllMocks() +}) diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts new file mode 100644 index 000000000..0d217f360 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts @@ -0,0 +1,353 @@ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { useDSL } from './use-DSL' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock toast context +const mockNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ + useToastContext: () => ({ + notify: mockNotify, + }), +})) + +// Mock event emitter context +const mockEmit = vi.fn() +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: mockEmit, + }, + }), +})) + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useNodesSyncDraft +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('./use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: mockDoSyncWorkflowDraft, + }), +})) + +// Mock pipeline service +const mockExportPipelineConfig = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipelineConfig, + }), +})) + +// Mock download utility +const mockDownloadBlob = vi.fn() +vi.mock('@/utils/download', () => ({ + downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args), +})) + +// Mock workflow service +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), +})) + +// Mock workflow constants +vi.mock('@/app/components/workflow/constants', () => ({ + DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useDSL', () => { + beforeEach(() => { + vi.clearAllMocks() + + // Default store state + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + knowledgeName: 'Test Knowledge Base', + }) + + mockDoSyncWorkflowDraft.mockResolvedValue(undefined) + mockExportPipelineConfig.mockResolvedValue({ data: 'yaml-content' }) + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [], + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return exportCheck function', () => { + const { result } = renderHook(() => useDSL()) + + expect(result.current.exportCheck).toBeDefined() + expect(typeof result.current.exportCheck).toBe('function') + }) + + it('should return handleExportDSL function', () => { + const { result } = renderHook(() => useDSL()) + + expect(result.current.handleExportDSL).toBeDefined() + expect(typeof result.current.handleExportDSL).toBe('function') + }) + }) + + describe('handleExportDSL', () => { + it('should not export when pipelineId is missing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + knowledgeName: 'Test', + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDoSyncWorkflowDraft).not.toHaveBeenCalled() + expect(mockExportPipelineConfig).not.toHaveBeenCalled() + }) + + it('should sync workflow draft before export', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should call exportPipelineConfig with correct params', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL(true) + }) + + expect(mockExportPipelineConfig).toHaveBeenCalledWith({ + pipelineId: 'test-pipeline-id', + include: true, + }) + }) + + it('should create and download file', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDownloadBlob).toHaveBeenCalled() + }) + + it('should use correct file extension for download', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDownloadBlob).toHaveBeenCalledWith( + expect.objectContaining({ + fileName: 'Test Knowledge Base.pipeline', + }), + ) + }) + + it('should pass blob data to downloadBlob', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDownloadBlob).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.any(Blob), + }), + ) + }) + + it('should show error notification on export failure', async () => { + mockExportPipelineConfig.mockRejectedValue(new Error('Export failed')) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', + }) + }) + }) + + describe('exportCheck', () => { + it('should not check when pipelineId is missing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + knowledgeName: 'Test', + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockFetchWorkflowDraft).not.toHaveBeenCalled() + }) + + it('should fetch workflow draft', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + }) + + it('should directly export when no secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'string', value: 'test' }, + ], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should call doSyncWorkflowDraft (which means handleExportDSL was called) + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should emit DSL_EXPORT_CHECK event when secret variables exist', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'secret', value: 'secret-value' }, + ], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockEmit).toHaveBeenCalledWith({ + type: 'DSL_EXPORT_CHECK', + payload: { + data: [{ id: '1', value_type: 'secret', value: 'secret-value' }], + }, + }) + }) + + it('should show error notification on check failure', async () => { + mockFetchWorkflowDraft.mockRejectedValue(new Error('Fetch failed')) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', + }) + }) + + it('should filter only secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'string', value: 'plain' }, + { id: '2', value_type: 'secret', value: 'secret1' }, + { id: '3', value_type: 'number', value: '123' }, + { id: '4', value_type: 'secret', value: 'secret2' }, + ], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockEmit).toHaveBeenCalledWith({ + type: 'DSL_EXPORT_CHECK', + payload: { + data: [ + { id: '2', value_type: 'secret', value: 'secret1' }, + { id: '4', value_type: 'secret', value: 'secret2' }, + ], + }, + }) + }) + + it('should handle empty environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should directly call handleExportDSL since no secrets + expect(mockEmit).not.toHaveBeenCalled() + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should handle undefined environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: undefined, + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should directly call handleExportDSL since no secrets + expect(mockEmit).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.ts b/web/app/components/rag-pipeline/hooks/use-DSL.ts index 1660d555e..5c0f9def1 100644 --- a/web/app/components/rag-pipeline/hooks/use-DSL.ts +++ b/web/app/components/rag-pipeline/hooks/use-DSL.ts @@ -11,6 +11,7 @@ import { useWorkflowStore } from '@/app/components/workflow/store' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useExportPipelineDSL } from '@/service/use-pipeline' import { fetchWorkflowDraft } from '@/service/workflow' +import { downloadBlob } from '@/utils/download' import { useNodesSyncDraft } from './use-nodes-sync-draft' export const useDSL = () => { @@ -37,13 +38,8 @@ export const useDSL = () => { pipelineId, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${knowledgeName}.pipeline` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${knowledgeName}.pipeline` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.spec.ts b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.spec.ts new file mode 100644 index 000000000..5817d187a --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.spec.ts @@ -0,0 +1,469 @@ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { useNodesSyncDraft } from './use-nodes-sync-draft' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock reactflow +const mockGetNodes = vi.fn() +const mockStoreGetState = vi.fn() + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: mockStoreGetState, + }), +})) + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useNodesReadOnly +const mockGetNodesReadOnly = vi.fn() +vi.mock('@/app/components/workflow/hooks/use-workflow', () => ({ + useNodesReadOnly: () => ({ + getNodesReadOnly: mockGetNodesReadOnly, + }), +})) + +// Mock useSerialAsyncCallback - must pass through arguments +vi.mock('@/app/components/workflow/hooks/use-serial-async-callback', () => ({ + useSerialAsyncCallback: (fn: (...args: unknown[]) => Promise, checkFn: () => boolean) => { + return (...args: unknown[]) => { + if (!checkFn()) { + return fn(...args) + } + } + }, +})) + +// Mock service +const mockSyncWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + syncWorkflowDraft: (params: unknown) => mockSyncWorkflowDraft(params), +})) + +// Mock usePipelineRefreshDraft +const mockHandleRefreshWorkflowDraft = vi.fn() +vi.mock('@/app/components/rag-pipeline/hooks', () => ({ + usePipelineRefreshDraft: () => ({ + handleRefreshWorkflowDraft: mockHandleRefreshWorkflowDraft, + }), +})) + +// Mock API_PREFIX +vi.mock('@/config', () => ({ + API_PREFIX: '/api', +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useNodesSyncDraft', () => { + const mockSendBeacon = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + // Setup navigator.sendBeacon mock + Object.defineProperty(navigator, 'sendBeacon', { + value: mockSendBeacon, + writable: true, + configurable: true, + }) + + // Default store state + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + edges: [], + transform: [0, 0, 1], + }) + + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', _temp: true }, position: { x: 0, y: 0 } }, + { id: 'node-2', data: { type: 'end' }, position: { x: 100, y: 0 } }, + ]) + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }) + + mockGetNodesReadOnly.mockReturnValue(false) + mockSyncWorkflowDraft.mockResolvedValue({ + hash: 'new-hash', + updated_at: '2024-01-01T00:00:00Z', + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return doSyncWorkflowDraft function', () => { + const { result } = renderHook(() => useNodesSyncDraft()) + + expect(result.current.doSyncWorkflowDraft).toBeDefined() + expect(typeof result.current.doSyncWorkflowDraft).toBe('function') + }) + + it('should return syncWorkflowDraftWhenPageClose function', () => { + const { result } = renderHook(() => useNodesSyncDraft()) + + expect(result.current.syncWorkflowDraftWhenPageClose).toBeDefined() + expect(typeof result.current.syncWorkflowDraftWhenPageClose).toBe('function') + }) + }) + + describe('syncWorkflowDraftWhenPageClose', () => { + it('should not call sendBeacon when nodes are read only', () => { + mockGetNodesReadOnly.mockReturnValue(true) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should call sendBeacon with correct URL and params', () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).toHaveBeenCalledWith( + '/api/rag/pipelines/test-pipeline-id/workflows/draft', + expect.any(String), + ) + }) + + it('should not call sendBeacon when pipelineId is missing', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + }) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should not call sendBeacon when nodes array is empty', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should filter out temp nodes', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', _isTempNode: true }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + // Should not call sendBeacon because after filtering temp nodes, array is empty + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should remove underscore-prefixed data keys from nodes', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', _privateData: 'secret' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).toHaveBeenCalled() + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.graph.nodes[0].data._privateData).toBeUndefined() + }) + }) + + describe('doSyncWorkflowDraft', () => { + it('should not sync when nodes are read only', async () => { + mockGetNodesReadOnly.mockReturnValue(true) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft() + }) + + expect(mockSyncWorkflowDraft).not.toHaveBeenCalled() + }) + + it('should call syncWorkflowDraft service', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft() + }) + + expect(mockSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should call onSuccess callback when sync succeeds', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + const onSuccess = vi.fn() + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false, { onSuccess }) + }) + + expect(onSuccess).toHaveBeenCalled() + }) + + it('should call onSettled callback after sync completes', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + const onSettled = vi.fn() + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false, { onSettled }) + }) + + expect(onSettled).toHaveBeenCalled() + }) + + it('should call onError callback when sync fails', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + mockSyncWorkflowDraft.mockRejectedValue(new Error('Sync failed')) + const onError = vi.fn() + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false, { onError }) + }) + + expect(onError).toHaveBeenCalled() + }) + + it('should update hash and draft updated at on success', async () => { + const mockSetSyncWorkflowDraftHash = vi.fn() + const mockSetDraftUpdatedAt = vi.fn() + + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setDraftUpdatedAt: mockSetDraftUpdatedAt, + }) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft() + }) + + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash') + expect(mockSetDraftUpdatedAt).toHaveBeenCalledWith('2024-01-01T00:00:00Z') + }) + + it('should handle draft not sync error', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), + bodyUsed: false, + } + mockSyncWorkflowDraft.mockRejectedValue(mockJsonError) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false) + }) + + // Wait for json to be called + await new Promise(resolve => setTimeout(resolve, 0)) + + expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalled() + }) + + it('should not refresh when notRefreshWhenSyncError is true', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), + bodyUsed: false, + } + mockSyncWorkflowDraft.mockRejectedValue(mockJsonError) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(true) + }) + + // Wait for json to be called + await new Promise(resolve => setTimeout(resolve, 0)) + + expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() + }) + }) + + describe('getPostParams', () => { + it('should include viewport coordinates in params', () => { + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + edges: [], + transform: [100, 200, 1.5], + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.graph.viewport).toEqual({ x: 100, y: 200, zoom: 1.5 }) + }) + + it('should include environment variables in params', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [{ key: 'API_KEY', value: 'secret' }], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.environment_variables).toEqual([{ key: 'API_KEY', value: 'secret' }]) + }) + + it('should include rag pipeline variables in params', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [{ variable: 'input', type: 'text-input' }], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.rag_pipeline_variables).toEqual([{ variable: 'input', type: 'text-input' }]) + }) + + it('should remove underscore-prefixed keys from edges', () => { + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + edges: [{ id: 'edge-1', source: 'node-1', target: 'node-2', data: { _hidden: true, visible: false } }], + transform: [0, 0, 1], + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.graph.edges[0].data._hidden).toBeUndefined() + expect(sentData.graph.edges[0].data.visible).toBe(false) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-config.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-config.spec.ts new file mode 100644 index 000000000..491d2828d --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-config.spec.ts @@ -0,0 +1,299 @@ +import { renderHook } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineConfig } from './use-pipeline-config' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockUseStore = vi.fn() +const mockWorkflowStoreGetState = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => mockUseStore(selector), + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useWorkflowConfig +const mockUseWorkflowConfig = vi.fn() +vi.mock('@/service/use-workflow', () => ({ + useWorkflowConfig: (url: string, callback: (data: unknown) => void) => mockUseWorkflowConfig(url, callback), +})) + +// Mock useDataSourceList +const mockUseDataSourceList = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useDataSourceList: (enabled: boolean, callback: (data: unknown) => void) => mockUseDataSourceList(enabled, callback), +})) + +// Mock basePath +vi.mock('@/utils/var', () => ({ + basePath: '/base', +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineConfig', () => { + const mockSetNodesDefaultConfigs = vi.fn() + const mockSetPublishedAt = vi.fn() + const mockSetDataSourceList = vi.fn() + const mockSetFileUploadConfig = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { pipelineId: 'test-pipeline-id' } + return selector(state) + }) + + mockWorkflowStoreGetState.mockReturnValue({ + setNodesDefaultConfigs: mockSetNodesDefaultConfigs, + setPublishedAt: mockSetPublishedAt, + setDataSourceList: mockSetDataSourceList, + setFileUploadConfig: mockSetFileUploadConfig, + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should render without crashing', () => { + expect(() => renderHook(() => usePipelineConfig())).not.toThrow() + }) + + it('should call useWorkflowConfig with correct URL for nodes default configs', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflows/default-workflow-block-configs', + expect.any(Function), + ) + }) + + it('should call useWorkflowConfig with correct URL for published workflow', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflows/publish', + expect.any(Function), + ) + }) + + it('should call useWorkflowConfig with correct URL for file upload config', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith( + '/files/upload', + expect.any(Function), + ) + }) + + it('should call useDataSourceList when pipelineId exists', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseDataSourceList).toHaveBeenCalledWith(true, expect.any(Function)) + }) + + it('should call useDataSourceList with false when pipelineId is missing', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { pipelineId: undefined } + return selector(state) + }) + + renderHook(() => usePipelineConfig()) + + expect(mockUseDataSourceList).toHaveBeenCalledWith(false, expect.any(Function)) + }) + + it('should use empty URL when pipelineId is missing for nodes configs', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { pipelineId: undefined } + return selector(state) + }) + + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith('', expect.any(Function)) + }) + }) + + describe('handleUpdateNodesDefaultConfigs', () => { + it('should handle array format configs', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('default-workflow-block-configs')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + const arrayConfigs = [ + { type: 'llm', config: { model: 'gpt-4' } }, + { type: 'code', config: { language: 'python' } }, + ] + + capturedCallback?.(arrayConfigs) + + expect(mockSetNodesDefaultConfigs).toHaveBeenCalledWith({ + llm: { model: 'gpt-4' }, + code: { language: 'python' }, + }) + }) + + it('should handle object format configs', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('default-workflow-block-configs')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + const objectConfigs = { + llm: { model: 'gpt-4' }, + code: { language: 'python' }, + } + + capturedCallback?.(objectConfigs) + + expect(mockSetNodesDefaultConfigs).toHaveBeenCalledWith(objectConfigs) + }) + }) + + describe('handleUpdatePublishedAt', () => { + it('should set published at from workflow response', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('/publish')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + capturedCallback?.({ created_at: '2024-01-01T00:00:00Z' }) + + expect(mockSetPublishedAt).toHaveBeenCalledWith('2024-01-01T00:00:00Z') + }) + + it('should handle undefined workflow response', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('/publish')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + capturedCallback?.(undefined) + + expect(mockSetPublishedAt).toHaveBeenCalledWith(undefined) + }) + }) + + describe('handleUpdateDataSourceList', () => { + it('should set data source list', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: '/icon.png' } } }, + ] + + capturedCallback?.(dataSourceList) + + expect(mockSetDataSourceList).toHaveBeenCalled() + }) + + it('should prepend basePath to icon if not included', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: '/icon.png' } } }, + ] + + capturedCallback?.(dataSourceList) + + // The callback modifies the array in place + expect(dataSourceList[0].declaration.identity.icon).toBe('/base/icon.png') + }) + + it('should not modify icon if it already includes basePath', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: '/base/icon.png' } } }, + ] + + capturedCallback?.(dataSourceList) + + expect(dataSourceList[0].declaration.identity.icon).toBe('/base/icon.png') + }) + + it('should handle non-string icon', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: { url: '/icon.png' } } } }, + ] + + capturedCallback?.(dataSourceList) + + // Should not modify object icon + expect(dataSourceList[0].declaration.identity.icon).toEqual({ url: '/icon.png' }) + }) + }) + + describe('handleUpdateWorkflowFileUploadConfig', () => { + it('should set file upload config', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url === '/files/upload') { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + const config = { max_file_size: 10 * 1024 * 1024 } + capturedCallback?.(config) + + expect(mockSetFileUploadConfig).toHaveBeenCalledWith(config) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-init.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-init.spec.ts new file mode 100644 index 000000000..393852531 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-init.spec.ts @@ -0,0 +1,345 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineInit } from './use-pipeline-init' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + setState: mockWorkflowStoreSetState, + }), +})) + +// Mock dataset detail context +const mockUseDatasetDetailContextWithSelector = vi.fn() +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: Record) => unknown) => + mockUseDatasetDetailContextWithSelector(selector), +})) + +// Mock workflow service +const mockFetchWorkflowDraft = vi.fn() +const mockSyncWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), + syncWorkflowDraft: (params: unknown) => mockSyncWorkflowDraft(params), +})) + +// Mock usePipelineConfig +vi.mock('./use-pipeline-config', () => ({ + usePipelineConfig: vi.fn(), +})) + +// Mock usePipelineTemplate +vi.mock('./use-pipeline-template', () => ({ + usePipelineTemplate: () => ({ + nodes: [{ id: 'template-node' }], + edges: [], + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineInit', () => { + const mockSetEnvSecrets = vi.fn() + const mockSetEnvironmentVariables = vi.fn() + const mockSetSyncWorkflowDraftHash = vi.fn() + const mockSetDraftUpdatedAt = vi.fn() + const mockSetToolPublished = vi.fn() + const mockSetRagPipelineVariables = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockWorkflowStoreGetState.mockReturnValue({ + setEnvSecrets: mockSetEnvSecrets, + setEnvironmentVariables: mockSetEnvironmentVariables, + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setDraftUpdatedAt: mockSetDraftUpdatedAt, + setToolPublished: mockSetToolPublished, + setRagPipelineVariables: mockSetRagPipelineVariables, + }) + + mockUseDatasetDetailContextWithSelector.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + dataset: { + pipeline_id: 'test-pipeline-id', + name: 'Test Knowledge', + icon_info: { icon: 'test-icon' }, + }, + } + return selector(state) + }) + + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [{ id: 'node-1' }], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: true, + environment_variables: [], + rag_pipeline_variables: [], + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return data and isLoading', async () => { + const { result } = renderHook(() => usePipelineInit()) + + expect(result.current.isLoading).toBe(true) + expect(result.current.data).toBeUndefined() + }) + + it('should set pipelineId in workflow store on mount', () => { + renderHook(() => usePipelineInit()) + + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ + pipelineId: 'test-pipeline-id', + knowledgeName: 'Test Knowledge', + knowledgeIcon: { icon: 'test-icon' }, + }) + }) + }) + + describe('data fetching', () => { + it('should fetch workflow draft on mount', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + }) + }) + + it('should set data after successful fetch', async () => { + const { result } = renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(result.current.data).toBeDefined() + }) + }) + + it('should set isLoading to false after fetch', async () => { + const { result } = renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(result.current.isLoading).toBe(false) + }) + }) + + it('should set draft updated at', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetDraftUpdatedAt).toHaveBeenCalledWith('2024-01-01T00:00:00Z') + }) + }) + + it('should set tool published status', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetToolPublished).toHaveBeenCalledWith(true) + }) + }) + + it('should set sync hash', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('test-hash') + }) + }) + }) + + describe('environment variables handling', () => { + it('should extract secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({ 'env-1': 'secret-value' }) + }) + }) + + it('should mask secret values in environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([ + { id: 'env-1', value_type: 'secret', value: '[__HIDDEN__]' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ]) + }) + }) + + it('should handle empty environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({}) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + }) + }) + }) + + describe('rag pipeline variables handling', () => { + it('should set rag pipeline variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: [ + { variable: 'query', type: 'text-input' }, + ], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([ + { variable: 'query', type: 'text-input' }, + ]) + }) + }) + + it('should handle undefined rag pipeline variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: undefined, + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([]) + }) + }) + }) + + describe('draft not exist error handling', () => { + it('should create initial workflow when draft does not exist', async () => { + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), + bodyUsed: false, + } + mockFetchWorkflowDraft.mockRejectedValueOnce(mockJsonError) + mockSyncWorkflowDraft.mockResolvedValue({ updated_at: '2024-01-02T00:00:00Z' }) + + // Second fetch succeeds + mockFetchWorkflowDraft.mockResolvedValueOnce({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'new-hash', + updated_at: '2024-01-02T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ + notInitialWorkflow: true, + shouldAutoOpenStartNodeSelector: true, + }) + }) + }) + + it('should sync initial workflow with template nodes', async () => { + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), + bodyUsed: false, + } + mockFetchWorkflowDraft.mockRejectedValueOnce(mockJsonError) + mockSyncWorkflowDraft.mockResolvedValue({ updated_at: '2024-01-02T00:00:00Z' }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSyncWorkflowDraft).toHaveBeenCalledWith({ + url: '/rag/pipelines/test-pipeline-id/workflows/draft', + params: { + graph: { + nodes: [{ id: 'template-node' }], + edges: [], + }, + environment_variables: [], + }, + }) + }) + }) + }) + + describe('missing datasetId', () => { + it('should not fetch when datasetId is missing', async () => { + mockUseDatasetDetailContextWithSelector.mockImplementation((selector: (state: Record) => unknown) => { + const state = { dataset: undefined } + return selector(state) + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockFetchWorkflowDraft).toHaveBeenCalled() + }) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.spec.ts new file mode 100644 index 000000000..efdb18b7d --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.spec.ts @@ -0,0 +1,246 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineRefreshDraft } from './use-pipeline-refresh-draft' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useWorkflowUpdate +const mockHandleUpdateWorkflowCanvas = vi.fn() +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowUpdate: () => ({ + handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas, + }), +})) + +// Mock workflow service +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), +})) + +// Mock utils +vi.mock('../utils', () => ({ + processNodesWithoutDataSource: (nodes: unknown[], viewport: unknown) => ({ + nodes, + viewport, + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineRefreshDraft', () => { + const mockSetSyncWorkflowDraftHash = vi.fn() + const mockSetIsSyncingWorkflowDraft = vi.fn() + const mockSetEnvironmentVariables = vi.fn() + const mockSetEnvSecrets = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setIsSyncingWorkflowDraft: mockSetIsSyncingWorkflowDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setEnvSecrets: mockSetEnvSecrets, + }) + + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [{ id: 'node-1' }], + edges: [{ id: 'edge-1' }], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [], + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return handleRefreshWorkflowDraft function', () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + expect(result.current.handleRefreshWorkflowDraft).toBeDefined() + expect(typeof result.current.handleRefreshWorkflowDraft).toBe('function') + }) + }) + + describe('handleRefreshWorkflowDraft', () => { + it('should set syncing state to true at start', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + expect(mockSetIsSyncingWorkflowDraft).toHaveBeenCalledWith(true) + }) + + it('should fetch workflow draft with correct URL', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + }) + + it('should update workflow canvas with response data', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalled() + }) + }) + + it('should update sync hash after fetch', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash') + }) + }) + + it('should set syncing state to false after completion', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetIsSyncingWorkflowDraft).toHaveBeenLastCalledWith(false) + }) + }) + + it('should handle secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({ 'env-1': 'secret-value' }) + }) + }) + + it('should mask secret values in environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([ + { id: 'env-1', value_type: 'secret', value: '[__HIDDEN__]' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ]) + }) + }) + + it('should handle empty environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [], + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({}) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + }) + }) + + it('should handle undefined environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: undefined, + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({}) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + }) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-run.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-run.spec.ts new file mode 100644 index 000000000..2b2100183 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-run.spec.ts @@ -0,0 +1,825 @@ +/* eslint-disable ts/no-explicit-any */ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { WorkflowRunningStatus } from '@/app/components/workflow/types' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineRun } from './use-pipeline-run' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock reactflow +const mockStoreGetState = vi.fn() +const mockGetViewport = vi.fn() +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: mockStoreGetState, + }), + useReactFlow: () => ({ + getViewport: mockGetViewport, + }), +})) + +// Mock workflow store +const mockUseStore = vi.fn() +const mockWorkflowStoreGetState = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => mockUseStore(selector), + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + setState: mockWorkflowStoreSetState, + }), +})) + +// Mock useNodesSyncDraft +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('./use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: mockDoSyncWorkflowDraft, + }), +})) + +// Mock workflow hooks +vi.mock('@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars', () => ({ + useSetWorkflowVarsWithValue: () => ({ + fetchInspectVars: vi.fn(), + }), +})) + +const mockHandleUpdateWorkflowCanvas = vi.fn() +vi.mock('@/app/components/workflow/hooks/use-workflow-interactions', () => ({ + useWorkflowUpdate: () => ({ + handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas, + }), +})) + +vi.mock('@/app/components/workflow/hooks/use-workflow-run-event/use-workflow-run-event', () => ({ + useWorkflowRunEvent: () => ({ + handleWorkflowStarted: vi.fn(), + handleWorkflowFinished: vi.fn(), + handleWorkflowFailed: vi.fn(), + handleWorkflowNodeStarted: vi.fn(), + handleWorkflowNodeFinished: vi.fn(), + handleWorkflowNodeIterationStarted: vi.fn(), + handleWorkflowNodeIterationNext: vi.fn(), + handleWorkflowNodeIterationFinished: vi.fn(), + handleWorkflowNodeLoopStarted: vi.fn(), + handleWorkflowNodeLoopNext: vi.fn(), + handleWorkflowNodeLoopFinished: vi.fn(), + handleWorkflowNodeRetry: vi.fn(), + handleWorkflowAgentLog: vi.fn(), + handleWorkflowTextChunk: vi.fn(), + handleWorkflowTextReplace: vi.fn(), + }), +})) + +// Mock service +const mockSsePost = vi.fn() +vi.mock('@/service/base', () => ({ + ssePost: (url: string, ...args: unknown[]) => mockSsePost(url, ...args), +})) + +const mockStopWorkflowRun = vi.fn() +vi.mock('@/service/workflow', () => ({ + stopWorkflowRun: (url: string) => mockStopWorkflowRun(url), +})) + +const mockInvalidAllLastRun = vi.fn() +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +// Mock FlowType +vi.mock('@/types/common', () => ({ + FlowType: { + ragPipeline: 'rag-pipeline', + }, +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineRun', () => { + const mockSetNodes = vi.fn() + const mockGetNodes = vi.fn() + const mockSetBackupDraft = vi.fn() + const mockSetEnvironmentVariables = vi.fn() + const mockSetRagPipelineVariables = vi.fn() + const mockSetWorkflowRunningData = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + // Mock DOM element + const mockWorkflowContainer = document.createElement('div') + mockWorkflowContainer.id = 'workflow-container' + Object.defineProperty(mockWorkflowContainer, 'clientWidth', { value: 1000 }) + Object.defineProperty(mockWorkflowContainer, 'clientHeight', { value: 800 }) + document.body.appendChild(mockWorkflowContainer) + + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + setNodes: mockSetNodes, + edges: [], + }) + + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', selected: true, _runningStatus: WorkflowRunningStatus.Running } }, + ]) + + mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 }) + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft: undefined, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + return selector({ pipelineId: 'test-pipeline-id' }) + }) + + mockDoSyncWorkflowDraft.mockResolvedValue(undefined) + }) + + afterEach(() => { + const container = document.getElementById('workflow-container') + if (container) { + document.body.removeChild(container) + } + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return handleBackupDraft function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleBackupDraft).toBeDefined() + expect(typeof result.current.handleBackupDraft).toBe('function') + }) + + it('should return handleLoadBackupDraft function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleLoadBackupDraft).toBeDefined() + expect(typeof result.current.handleLoadBackupDraft).toBe('function') + }) + + it('should return handleRun function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleRun).toBeDefined() + expect(typeof result.current.handleRun).toBe('function') + }) + + it('should return handleStopRun function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleStopRun).toBeDefined() + expect(typeof result.current.handleStopRun).toBe('function') + }) + + it('should return handleRestoreFromPublishedWorkflow function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleRestoreFromPublishedWorkflow).toBeDefined() + expect(typeof result.current.handleRestoreFromPublishedWorkflow).toBe('function') + }) + }) + + describe('handleBackupDraft', () => { + it('should backup draft when no backup exists', () => { + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleBackupDraft() + }) + + expect(mockSetBackupDraft).toHaveBeenCalled() + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should not backup draft when backup already exists', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft: { nodes: [], edges: [], viewport: {}, environmentVariables: [] }, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleBackupDraft() + }) + + expect(mockSetBackupDraft).not.toHaveBeenCalled() + }) + }) + + describe('handleLoadBackupDraft', () => { + it('should load backup draft when exists', () => { + const backupDraft = { + nodes: [{ id: 'backup-node' }], + edges: [{ id: 'backup-edge' }], + viewport: { x: 100, y: 100, zoom: 1.5 }, + environmentVariables: [{ key: 'ENV', value: 'test' }], + } + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleLoadBackupDraft() + }) + + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({ + nodes: backupDraft.nodes, + edges: backupDraft.edges, + viewport: backupDraft.viewport, + }) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith(backupDraft.environmentVariables) + expect(mockSetBackupDraft).toHaveBeenCalledWith(undefined) + }) + + it('should not load when no backup exists', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft: undefined, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleLoadBackupDraft() + }) + + expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled() + }) + }) + + describe('handleStopRun', () => { + it('should call stop workflow run service', () => { + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleStopRun('task-123') + }) + + expect(mockStopWorkflowRun).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflow-runs/tasks/task-123/stop', + ) + }) + }) + + describe('handleRestoreFromPublishedWorkflow', () => { + it('should restore workflow from published version', () => { + const publishedWorkflow = { + graph: { + nodes: [{ id: 'pub-node', data: { type: 'start' } }], + edges: [{ id: 'pub-edge' }], + viewport: { x: 50, y: 50, zoom: 1 }, + }, + environment_variables: [{ key: 'PUB_ENV', value: 'pub' }], + rag_pipeline_variables: [{ variable: 'input', type: 'text-input' }], + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({ + nodes: [{ id: 'pub-node', data: { type: 'start', selected: false }, selected: false }], + edges: publishedWorkflow.graph.edges, + viewport: publishedWorkflow.graph.viewport, + }) + }) + + it('should set environment variables from published workflow', () => { + const publishedWorkflow = { + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + environment_variables: [{ key: 'ENV', value: 'value' }], + rag_pipeline_variables: [], + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([{ key: 'ENV', value: 'value' }]) + }) + + it('should set rag pipeline variables from published workflow', () => { + const publishedWorkflow = { + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + environment_variables: [], + rag_pipeline_variables: [{ variable: 'query', type: 'text-input' }], + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([{ variable: 'query', type: 'text-input' }]) + }) + + it('should handle empty environment and rag pipeline variables', () => { + const publishedWorkflow = { + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + environment_variables: undefined, + rag_pipeline_variables: undefined, + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([]) + }) + }) + + describe('handleRun', () => { + it('should sync workflow draft before running', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should reset node selection and running status', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockSetNodes).toHaveBeenCalled() + }) + + it('should clear history workflow data', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ historyWorkflowData: undefined }) + }) + + it('should set initial running data', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockSetWorkflowRunningData).toHaveBeenCalledWith({ + result: { + inputs_truncated: false, + process_data_truncated: false, + outputs_truncated: false, + status: WorkflowRunningStatus.Running, + }, + tracing: [], + resultText: '', + }) + }) + + it('should call ssePost with correct URL', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: { query: 'test' } }) + }) + + expect(mockSsePost).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflows/draft/run', + expect.any(Object), + expect.any(Object), + ) + }) + + it('should call onWorkflowStarted callback when provided', async () => { + const onWorkflowStarted = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onWorkflowStarted }) + }) + + // Trigger the callback + await act(async () => { + capturedCallbacks.onWorkflowStarted?.({ task_id: 'task-1' }) + }) + + expect(onWorkflowStarted).toHaveBeenCalledWith({ task_id: 'task-1' }) + }) + + it('should call onWorkflowFinished callback when provided', async () => { + const onWorkflowFinished = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onWorkflowFinished }) + }) + + await act(async () => { + capturedCallbacks.onWorkflowFinished?.({ status: 'succeeded' }) + }) + + expect(onWorkflowFinished).toHaveBeenCalledWith({ status: 'succeeded' }) + }) + + it('should call onError callback when provided', async () => { + const onError = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onError }) + }) + + await act(async () => { + capturedCallbacks.onError?.({ message: 'error' }) + }) + + expect(onError).toHaveBeenCalledWith({ message: 'error' }) + }) + + it('should call onNodeStarted callback when provided', async () => { + const onNodeStarted = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onNodeStarted }) + }) + + await act(async () => { + capturedCallbacks.onNodeStarted?.({ node_id: 'node-1' }) + }) + + expect(onNodeStarted).toHaveBeenCalledWith({ node_id: 'node-1' }) + }) + + it('should call onNodeFinished callback when provided', async () => { + const onNodeFinished = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onNodeFinished }) + }) + + await act(async () => { + capturedCallbacks.onNodeFinished?.({ node_id: 'node-1' }) + }) + + expect(onNodeFinished).toHaveBeenCalledWith({ node_id: 'node-1' }) + }) + + it('should call onIterationStart callback when provided', async () => { + const onIterationStart = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onIterationStart }) + }) + + await act(async () => { + capturedCallbacks.onIterationStart?.({ iteration_id: 'iter-1' }) + }) + + expect(onIterationStart).toHaveBeenCalledWith({ iteration_id: 'iter-1' }) + }) + + it('should call onIterationNext callback when provided', async () => { + const onIterationNext = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onIterationNext }) + }) + + await act(async () => { + capturedCallbacks.onIterationNext?.({ index: 1 }) + }) + + expect(onIterationNext).toHaveBeenCalledWith({ index: 1 }) + }) + + it('should call onIterationFinish callback when provided', async () => { + const onIterationFinish = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onIterationFinish }) + }) + + await act(async () => { + capturedCallbacks.onIterationFinish?.({ iteration_id: 'iter-1' }) + }) + + expect(onIterationFinish).toHaveBeenCalledWith({ iteration_id: 'iter-1' }) + }) + + it('should call onLoopStart callback when provided', async () => { + const onLoopStart = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onLoopStart }) + }) + + await act(async () => { + capturedCallbacks.onLoopStart?.({ loop_id: 'loop-1' }) + }) + + expect(onLoopStart).toHaveBeenCalledWith({ loop_id: 'loop-1' }) + }) + + it('should call onLoopNext callback when provided', async () => { + const onLoopNext = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onLoopNext }) + }) + + await act(async () => { + capturedCallbacks.onLoopNext?.({ index: 2 }) + }) + + expect(onLoopNext).toHaveBeenCalledWith({ index: 2 }) + }) + + it('should call onLoopFinish callback when provided', async () => { + const onLoopFinish = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onLoopFinish }) + }) + + await act(async () => { + capturedCallbacks.onLoopFinish?.({ loop_id: 'loop-1' }) + }) + + expect(onLoopFinish).toHaveBeenCalledWith({ loop_id: 'loop-1' }) + }) + + it('should call onNodeRetry callback when provided', async () => { + const onNodeRetry = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onNodeRetry }) + }) + + await act(async () => { + capturedCallbacks.onNodeRetry?.({ node_id: 'node-1', retry: 1 }) + }) + + expect(onNodeRetry).toHaveBeenCalledWith({ node_id: 'node-1', retry: 1 }) + }) + + it('should call onAgentLog callback when provided', async () => { + const onAgentLog = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onAgentLog }) + }) + + await act(async () => { + capturedCallbacks.onAgentLog?.({ message: 'agent log' }) + }) + + expect(onAgentLog).toHaveBeenCalledWith({ message: 'agent log' }) + }) + + it('should handle onTextChunk callback', async () => { + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + await act(async () => { + capturedCallbacks.onTextChunk?.({ text: 'chunk' }) + }) + + // Just verify it doesn't throw + expect(capturedCallbacks.onTextChunk).toBeDefined() + }) + + it('should handle onTextReplace callback', async () => { + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + await act(async () => { + capturedCallbacks.onTextReplace?.({ text: 'replaced' }) + }) + + // Just verify it doesn't throw + expect(capturedCallbacks.onTextReplace).toBeDefined() + }) + + it('should pass rest callback to ssePost', async () => { + const customCallback = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onData: customCallback } as any) + }) + + expect(capturedCallbacks.onData).toBeDefined() + }) + + it('should handle callbacks without optional handlers', async () => { + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + // Run without any optional callbacks + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + // Trigger all callbacks - they should not throw even without optional handlers + await act(async () => { + capturedCallbacks.onWorkflowStarted?.({ task_id: 'task-1' }) + capturedCallbacks.onWorkflowFinished?.({ status: 'succeeded' }) + capturedCallbacks.onError?.({ message: 'error' }) + capturedCallbacks.onNodeStarted?.({ node_id: 'node-1' }) + capturedCallbacks.onNodeFinished?.({ node_id: 'node-1' }) + capturedCallbacks.onIterationStart?.({ iteration_id: 'iter-1' }) + capturedCallbacks.onIterationNext?.({ index: 1 }) + capturedCallbacks.onIterationFinish?.({ iteration_id: 'iter-1' }) + capturedCallbacks.onLoopStart?.({ loop_id: 'loop-1' }) + capturedCallbacks.onLoopNext?.({ index: 2 }) + capturedCallbacks.onLoopFinish?.({ loop_id: 'loop-1' }) + capturedCallbacks.onNodeRetry?.({ node_id: 'node-1', retry: 1 }) + capturedCallbacks.onAgentLog?.({ message: 'agent log' }) + capturedCallbacks.onTextChunk?.({ text: 'chunk' }) + capturedCallbacks.onTextReplace?.({ text: 'replaced' }) + }) + + // Verify ssePost was called + expect(mockSsePost).toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-start-run.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-start-run.spec.ts new file mode 100644 index 000000000..4266fb993 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-start-run.spec.ts @@ -0,0 +1,217 @@ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { WorkflowRunningStatus } from '@/app/components/workflow/types' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineStartRun } from './use-pipeline-start-run' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + setState: mockWorkflowStoreSetState, + }), +})) + +// Mock workflow interactions +const mockHandleCancelDebugAndPreviewPanel = vi.fn() +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowInteractions: () => ({ + handleCancelDebugAndPreviewPanel: mockHandleCancelDebugAndPreviewPanel, + }), +})) + +// Mock useNodesSyncDraft +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('@/app/components/rag-pipeline/hooks', () => ({ + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: mockDoSyncWorkflowDraft, + }), + useInputFieldPanel: () => ({ + closeAllInputFieldPanels: vi.fn(), + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineStartRun', () => { + const mockSetIsPreparingDataSource = vi.fn() + const mockSetShowEnvPanel = vi.fn() + const mockSetShowDebugAndPreviewPanel = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + mockDoSyncWorkflowDraft.mockResolvedValue(undefined) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return handleStartWorkflowRun function', () => { + const { result } = renderHook(() => usePipelineStartRun()) + + expect(result.current.handleStartWorkflowRun).toBeDefined() + expect(typeof result.current.handleStartWorkflowRun).toBe('function') + }) + + it('should return handleWorkflowStartRunInWorkflow function', () => { + const { result } = renderHook(() => usePipelineStartRun()) + + expect(result.current.handleWorkflowStartRunInWorkflow).toBeDefined() + expect(typeof result.current.handleWorkflowStartRunInWorkflow).toBe('function') + }) + }) + + describe('handleWorkflowStartRunInWorkflow', () => { + it('should not proceed when workflow is already running', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: { + result: { status: WorkflowRunningStatus.Running }, + }, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockSetShowEnvPanel).not.toHaveBeenCalled() + }) + + it('should set preparing data source when not preparing and has running data', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: { + result: { status: WorkflowRunningStatus.Succeeded }, + }, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ + isPreparingDataSource: true, + workflowRunningData: undefined, + }) + }) + + it('should cancel debug panel when already showing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: true, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockSetIsPreparingDataSource).toHaveBeenCalledWith(false) + expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalled() + }) + + it('should sync draft and show debug panel when conditions are met', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + expect(mockSetIsPreparingDataSource).toHaveBeenCalledWith(true) + expect(mockSetShowDebugAndPreviewPanel).toHaveBeenCalledWith(true) + }) + + it('should hide env panel at start', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockSetShowEnvPanel).toHaveBeenCalledWith(false) + }) + }) + + describe('handleStartWorkflowRun', () => { + it('should call handleWorkflowStartRunInWorkflow', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + result.current.handleStartWorkflowRun() + }) + + // Should trigger the same workflow as handleWorkflowStartRunInWorkflow + expect(mockSetShowEnvPanel).toHaveBeenCalledWith(false) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/store/index.spec.ts b/web/app/components/rag-pipeline/store/index.spec.ts new file mode 100644 index 000000000..c8c0a3533 --- /dev/null +++ b/web/app/components/rag-pipeline/store/index.spec.ts @@ -0,0 +1,289 @@ +/* eslint-disable ts/no-explicit-any */ +import type { DataSourceItem } from '@/app/components/workflow/block-selector/types' +import { describe, expect, it, vi } from 'vitest' +import { createRagPipelineSliceSlice } from './index' + +// Mock the transformDataSourceToTool function +vi.mock('@/app/components/workflow/block-selector/utils', () => ({ + transformDataSourceToTool: (item: DataSourceItem) => ({ + ...item, + transformed: true, + }), +})) + +describe('createRagPipelineSliceSlice', () => { + const mockSet = vi.fn() + + describe('initial state', () => { + it('should have empty pipelineId', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.pipelineId).toBe('') + }) + + it('should have empty knowledgeName', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.knowledgeName).toBe('') + }) + + it('should have showInputFieldPanel as false', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.showInputFieldPanel).toBe(false) + }) + + it('should have showInputFieldPreviewPanel as false', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.showInputFieldPreviewPanel).toBe(false) + }) + + it('should have inputFieldEditPanelProps as null', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.inputFieldEditPanelProps).toBeNull() + }) + + it('should have empty nodesDefaultConfigs', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.nodesDefaultConfigs).toEqual({}) + }) + + it('should have empty ragPipelineVariables', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.ragPipelineVariables).toEqual([]) + }) + + it('should have empty dataSourceList', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.dataSourceList).toEqual([]) + }) + + it('should have isPreparingDataSource as false', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.isPreparingDataSource).toBe(false) + }) + }) + + describe('setShowInputFieldPanel', () => { + it('should call set with showInputFieldPanel true', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPanel(true) + + expect(mockSet).toHaveBeenCalledWith(expect.any(Function)) + + // Get the setter function and execute it + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPanel: true }) + }) + + it('should call set with showInputFieldPanel false', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPanel(false) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPanel: false }) + }) + }) + + describe('setShowInputFieldPreviewPanel', () => { + it('should call set with showInputFieldPreviewPanel true', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPreviewPanel(true) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPreviewPanel: true }) + }) + + it('should call set with showInputFieldPreviewPanel false', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPreviewPanel(false) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPreviewPanel: false }) + }) + }) + + describe('setInputFieldEditPanelProps', () => { + it('should call set with inputFieldEditPanelProps object', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const props = { type: 'create' as const } + + slice.setInputFieldEditPanelProps(props as any) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ inputFieldEditPanelProps: props }) + }) + + it('should call set with inputFieldEditPanelProps null', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setInputFieldEditPanelProps(null) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ inputFieldEditPanelProps: null }) + }) + }) + + describe('setNodesDefaultConfigs', () => { + it('should call set with nodesDefaultConfigs', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const configs = { node1: { key: 'value' } } + + slice.setNodesDefaultConfigs(configs) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ nodesDefaultConfigs: configs }) + }) + + it('should call set with empty nodesDefaultConfigs', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setNodesDefaultConfigs({}) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ nodesDefaultConfigs: {} }) + }) + }) + + describe('setRagPipelineVariables', () => { + it('should call set with ragPipelineVariables', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const variables = [ + { type: 'text-input', variable: 'var1', label: 'Var 1', required: true }, + ] + + slice.setRagPipelineVariables(variables as any) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ ragPipelineVariables: variables }) + }) + + it('should call set with empty ragPipelineVariables', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setRagPipelineVariables([]) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ ragPipelineVariables: [] }) + }) + }) + + describe('setDataSourceList', () => { + it('should transform and set dataSourceList', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const dataSourceList: DataSourceItem[] = [ + { name: 'source1', key: 'key1' } as unknown as DataSourceItem, + { name: 'source2', key: 'key2' } as unknown as DataSourceItem, + ] + + slice.setDataSourceList(dataSourceList) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result.dataSourceList).toHaveLength(2) + expect(result.dataSourceList[0]).toEqual({ name: 'source1', key: 'key1', transformed: true }) + expect(result.dataSourceList[1]).toEqual({ name: 'source2', key: 'key2', transformed: true }) + }) + + it('should set empty dataSourceList', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setDataSourceList([]) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result.dataSourceList).toEqual([]) + }) + }) + + describe('setIsPreparingDataSource', () => { + it('should call set with isPreparingDataSource true', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setIsPreparingDataSource(true) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ isPreparingDataSource: true }) + }) + + it('should call set with isPreparingDataSource false', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setIsPreparingDataSource(false) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ isPreparingDataSource: false }) + }) + }) +}) + +describe('RagPipelineSliceShape type', () => { + it('should define all required properties', () => { + const slice = createRagPipelineSliceSlice(vi.fn(), vi.fn() as any, vi.fn() as any) + + // Check all properties exist + expect(slice).toHaveProperty('pipelineId') + expect(slice).toHaveProperty('knowledgeName') + expect(slice).toHaveProperty('showInputFieldPanel') + expect(slice).toHaveProperty('setShowInputFieldPanel') + expect(slice).toHaveProperty('showInputFieldPreviewPanel') + expect(slice).toHaveProperty('setShowInputFieldPreviewPanel') + expect(slice).toHaveProperty('inputFieldEditPanelProps') + expect(slice).toHaveProperty('setInputFieldEditPanelProps') + expect(slice).toHaveProperty('nodesDefaultConfigs') + expect(slice).toHaveProperty('setNodesDefaultConfigs') + expect(slice).toHaveProperty('ragPipelineVariables') + expect(slice).toHaveProperty('setRagPipelineVariables') + expect(slice).toHaveProperty('dataSourceList') + expect(slice).toHaveProperty('setDataSourceList') + expect(slice).toHaveProperty('isPreparingDataSource') + expect(slice).toHaveProperty('setIsPreparingDataSource') + }) + + it('should have all setters as functions', () => { + const slice = createRagPipelineSliceSlice(vi.fn(), vi.fn() as any, vi.fn() as any) + + expect(typeof slice.setShowInputFieldPanel).toBe('function') + expect(typeof slice.setShowInputFieldPreviewPanel).toBe('function') + expect(typeof slice.setInputFieldEditPanelProps).toBe('function') + expect(typeof slice.setNodesDefaultConfigs).toBe('function') + expect(typeof slice.setRagPipelineVariables).toBe('function') + expect(typeof slice.setDataSourceList).toBe('function') + expect(typeof slice.setIsPreparingDataSource).toBe('function') + }) +}) diff --git a/web/app/components/rag-pipeline/utils/index.spec.ts b/web/app/components/rag-pipeline/utils/index.spec.ts new file mode 100644 index 000000000..9d816af68 --- /dev/null +++ b/web/app/components/rag-pipeline/utils/index.spec.ts @@ -0,0 +1,348 @@ +import type { Viewport } from 'reactflow' +import type { Node } from '@/app/components/workflow/types' +import { describe, expect, it, vi } from 'vitest' +import { BlockEnum } from '@/app/components/workflow/types' +import { processNodesWithoutDataSource } from './nodes' + +// Mock constants +vi.mock('@/app/components/workflow/constants', () => ({ + CUSTOM_NODE: 'custom', + NODE_WIDTH_X_OFFSET: 400, + START_INITIAL_POSITION: { x: 100, y: 100 }, +})) + +vi.mock('@/app/components/workflow/nodes/data-source-empty/constants', () => ({ + CUSTOM_DATA_SOURCE_EMPTY_NODE: 'data-source-empty', +})) + +vi.mock('@/app/components/workflow/note-node/constants', () => ({ + CUSTOM_NOTE_NODE: 'note', +})) + +vi.mock('@/app/components/workflow/note-node/types', () => ({ + NoteTheme: { blue: 'blue' }, +})) + +vi.mock('@/app/components/workflow/utils', () => ({ + generateNewNode: ({ id, type, data, position }: { id: string, type?: string, data: object, position: { x: number, y: number } }) => ({ + newNode: { id, type: type || 'custom', data, position }, + }), +})) + +describe('processNodesWithoutDataSource', () => { + describe('when nodes contain DataSource', () => { + it('should return original nodes and viewport unchanged', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.DataSource, title: 'Data Source' }, + position: { x: 100, y: 100 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.End, title: 'End' }, + position: { x: 500, y: 100 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + expect(result.nodes).toBe(nodes) + expect(result.viewport).toBe(viewport) + }) + + it('should check all nodes before returning early', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.Start, title: 'Start' }, + position: { x: 0, y: 0 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.DataSource, title: 'Data Source' }, + position: { x: 100, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + expect(result.nodes).toBe(nodes) + }) + }) + + describe('when nodes do not contain DataSource', () => { + it('should add data source empty node and note node for single custom node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'Knowledge Base' }, + position: { x: 500, y: 200 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + expect(result.nodes.length).toBe(3) + expect(result.nodes[0].id).toBe('data-source-empty') + expect(result.nodes[1].id).toBe('note') + expect(result.nodes[2]).toBe(nodes[0]) + }) + + it('should use the leftmost custom node position for new nodes', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB 1' }, + position: { x: 700, y: 100 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.End, title: 'End' }, + position: { x: 200, y: 100 }, // This is the leftmost + } as Node, + { + id: 'node-3', + type: 'custom', + data: { type: BlockEnum.Start, title: 'Start' }, + position: { x: 500, y: 100 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // New nodes should be positioned based on the leftmost node (x: 200) + // startX = 200 - 400 = -200 + expect(result.nodes[0].position.x).toBe(-200) + expect(result.nodes[0].position.y).toBe(100) + }) + + it('should adjust viewport based on new node position', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 300, y: 200 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // startX = 300 - 400 = -100 + // startY = 200 + // viewport.x = (100 - (-100)) * 1 = 200 + // viewport.y = (100 - 200) * 1 = -100 + expect(result.viewport).toEqual({ + x: 200, + y: -100, + zoom: 1, + }) + }) + + it('should apply zoom factor to viewport calculation', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 300, y: 200 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 2 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // startX = 300 - 400 = -100 + // startY = 200 + // viewport.x = (100 - (-100)) * 2 = 400 + // viewport.y = (100 - 200) * 2 = -200 + expect(result.viewport).toEqual({ + x: 400, + y: -200, + zoom: 2, + }) + }) + + it('should use default zoom 1 when viewport zoom is undefined', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes, undefined) + + expect(result.viewport?.zoom).toBe(1) + }) + + it('should add note node below data source empty node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + // Data source empty node position + const dataSourceEmptyNode = result.nodes[0] + const noteNode = result.nodes[1] + + // Note node should be 100px below data source empty node + expect(noteNode.position.x).toBe(dataSourceEmptyNode.position.x) + expect(noteNode.position.y).toBe(dataSourceEmptyNode.position.y + 100) + }) + + it('should set correct data for data source empty node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + expect(result.nodes[0].data.type).toBe(BlockEnum.DataSourceEmpty) + expect(result.nodes[0].data._isTempNode).toBe(true) + expect(result.nodes[0].data.width).toBe(240) + }) + + it('should set correct data for note node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + const noteNode = result.nodes[1] + const noteData = noteNode.data as Record + expect(noteData._isTempNode).toBe(true) + expect(noteData.theme).toBe('blue') + expect(noteData.width).toBe(240) + expect(noteData.height).toBe(300) + expect(noteData.showAuthor).toBe(true) + }) + }) + + describe('when nodes array is empty', () => { + it('should return empty nodes array unchanged', () => { + const nodes: Node[] = [] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + expect(result.nodes).toEqual([]) + expect(result.viewport).toBe(viewport) + }) + }) + + describe('when no custom nodes exist', () => { + it('should return original nodes when only non-custom nodes', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'special', // Not 'custom' + data: { type: BlockEnum.Start, title: 'Start' }, + position: { x: 100, y: 100 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // No custom nodes to find leftmost, so no new nodes are added + expect(result.nodes).toBe(nodes) + expect(result.viewport).toBe(viewport) + }) + }) + + describe('edge cases', () => { + it('should handle nodes with same x position', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB 1' }, + position: { x: 300, y: 100 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.End, title: 'End' }, + position: { x: 300, y: 200 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + // First node should be used as leftNode + expect(result.nodes.length).toBe(4) + }) + + it('should handle negative positions', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: -100, y: -50 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + // startX = -100 - 400 = -500 + expect(result.nodes[0].position.x).toBe(-500) + expect(result.nodes[0].position.y).toBe(-50) + }) + + it('should handle undefined viewport gracefully', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes, undefined) + + expect(result.viewport).toBeDefined() + expect(result.viewport?.zoom).toBe(1) + }) + }) +}) + +describe('module exports', () => { + it('should export processNodesWithoutDataSource', () => { + expect(processNodesWithoutDataSource).toBeDefined() + expect(typeof processNodesWithoutDataSource).toBe('function') + }) +}) diff --git a/web/app/components/share/text-generation/info-modal.spec.tsx b/web/app/components/share/text-generation/info-modal.spec.tsx new file mode 100644 index 000000000..025c5edde --- /dev/null +++ b/web/app/components/share/text-generation/info-modal.spec.tsx @@ -0,0 +1,205 @@ +import type { SiteInfo } from '@/models/share' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import InfoModal from './info-modal' + +// Only mock react-i18next for translations +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +afterEach(() => { + cleanup() +}) + +describe('InfoModal', () => { + const mockOnClose = vi.fn() + + const baseSiteInfo: SiteInfo = { + title: 'Test App', + icon: '🚀', + icon_type: 'emoji', + icon_background: '#ffffff', + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should not render when isShow is false', () => { + render( + , + ) + + expect(screen.queryByText('Test App')).not.toBeInTheDocument() + }) + + it('should render when isShow is true', () => { + render( + , + ) + + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should render app title', () => { + render( + , + ) + + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should render copyright when provided', () => { + const siteInfoWithCopyright: SiteInfo = { + ...baseSiteInfo, + copyright: 'Dify Inc.', + } + + render( + , + ) + + expect(screen.getByText(/Dify Inc./)).toBeInTheDocument() + }) + + it('should render current year in copyright', () => { + const siteInfoWithCopyright: SiteInfo = { + ...baseSiteInfo, + copyright: 'Test Company', + } + + render( + , + ) + + const currentYear = new Date().getFullYear().toString() + expect(screen.getByText(new RegExp(currentYear))).toBeInTheDocument() + }) + + it('should render custom disclaimer when provided', () => { + const siteInfoWithDisclaimer: SiteInfo = { + ...baseSiteInfo, + custom_disclaimer: 'This is a custom disclaimer', + } + + render( + , + ) + + expect(screen.getByText('This is a custom disclaimer')).toBeInTheDocument() + }) + + it('should not render copyright section when not provided', () => { + render( + , + ) + + const year = new Date().getFullYear().toString() + expect(screen.queryByText(new RegExp(`©.*${year}`))).not.toBeInTheDocument() + }) + + it('should render with undefined data', () => { + render( + , + ) + + // Modal should still render but without content + expect(screen.queryByText('Test App')).not.toBeInTheDocument() + }) + + it('should render with image icon type', () => { + const siteInfoWithImage: SiteInfo = { + ...baseSiteInfo, + icon_type: 'image', + icon_url: 'https://example.com/icon.png', + } + + render( + , + ) + + expect(screen.getByText(siteInfoWithImage.title!)).toBeInTheDocument() + }) + }) + + describe('close functionality', () => { + it('should call onClose when close button is clicked', () => { + render( + , + ) + + // Find the close icon (RiCloseLine) which has text-text-tertiary class + const closeIcon = document.querySelector('[class*="text-text-tertiary"]') + expect(closeIcon).toBeInTheDocument() + if (closeIcon) { + fireEvent.click(closeIcon) + expect(mockOnClose).toHaveBeenCalled() + } + }) + }) + + describe('both copyright and disclaimer', () => { + it('should render both when both are provided', () => { + const siteInfoWithBoth: SiteInfo = { + ...baseSiteInfo, + copyright: 'My Company', + custom_disclaimer: 'Disclaimer text here', + } + + render( + , + ) + + expect(screen.getByText(/My Company/)).toBeInTheDocument() + expect(screen.getByText('Disclaimer text here')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/share/text-generation/menu-dropdown.spec.tsx b/web/app/components/share/text-generation/menu-dropdown.spec.tsx new file mode 100644 index 000000000..b54a2df63 --- /dev/null +++ b/web/app/components/share/text-generation/menu-dropdown.spec.tsx @@ -0,0 +1,261 @@ +import type { SiteInfo } from '@/models/share' +import { act, cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import MenuDropdown from './menu-dropdown' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock next/navigation +const mockReplace = vi.fn() +const mockPathname = '/test-path' +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), + usePathname: () => mockPathname, +})) + +// Mock web-app-context +const mockShareCode = 'test-share-code' +vi.mock('@/context/web-app-context', () => ({ + useWebAppStore: (selector: (state: Record) => unknown) => { + const state = { + webAppAccessMode: 'code', + shareCode: mockShareCode, + } + return selector(state) + }, +})) + +// Mock webapp-auth service +const mockWebAppLogout = vi.fn().mockResolvedValue(undefined) +vi.mock('@/service/webapp-auth', () => ({ + webAppLogout: (...args: unknown[]) => mockWebAppLogout(...args), +})) + +afterEach(() => { + cleanup() +}) + +describe('MenuDropdown', () => { + const baseSiteInfo: SiteInfo = { + title: 'Test App', + icon: '🚀', + icon_type: 'emoji', + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the trigger button', () => { + render() + + // The trigger button contains a settings icon (RiEqualizer2Line) + const triggerButton = screen.getByRole('button') + expect(triggerButton).toBeInTheDocument() + }) + + it('should not show dropdown content initially', () => { + render() + + // Dropdown content should not be visible initially + expect(screen.queryByText('theme.theme')).not.toBeInTheDocument() + }) + + it('should show dropdown content when clicked', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('theme.theme')).toBeInTheDocument() + }) + }) + + it('should show About option in dropdown', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.about')).toBeInTheDocument() + }) + }) + }) + + describe('privacy policy link', () => { + it('should show privacy policy link when provided', async () => { + const siteInfoWithPrivacy: SiteInfo = { + ...baseSiteInfo, + privacy_policy: 'https://example.com/privacy', + } + + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('chat.privacyPolicyMiddle')).toBeInTheDocument() + }) + }) + + it('should not show privacy policy link when not provided', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.queryByText('chat.privacyPolicyMiddle')).not.toBeInTheDocument() + }) + }) + + it('should have correct href for privacy policy link', async () => { + const privacyUrl = 'https://example.com/privacy' + const siteInfoWithPrivacy: SiteInfo = { + ...baseSiteInfo, + privacy_policy: privacyUrl, + } + + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + const link = screen.getByText('chat.privacyPolicyMiddle').closest('a') + expect(link).toHaveAttribute('href', privacyUrl) + expect(link).toHaveAttribute('target', '_blank') + }) + }) + }) + + describe('logout functionality', () => { + it('should show logout option when hideLogout is false', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.logout')).toBeInTheDocument() + }) + }) + + it('should hide logout option when hideLogout is true', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.queryByText('userProfile.logout')).not.toBeInTheDocument() + }) + }) + + it('should call webAppLogout and redirect when logout is clicked', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.logout')).toBeInTheDocument() + }) + + const logoutButton = screen.getByText('userProfile.logout') + await act(async () => { + fireEvent.click(logoutButton) + }) + + await waitFor(() => { + expect(mockWebAppLogout).toHaveBeenCalledWith(mockShareCode) + expect(mockReplace).toHaveBeenCalledWith(`/webapp-signin?redirect_url=${mockPathname}`) + }) + }) + }) + + describe('about modal', () => { + it('should show InfoModal when About is clicked', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.about')).toBeInTheDocument() + }) + + const aboutButton = screen.getByText('userProfile.about') + fireEvent.click(aboutButton) + + await waitFor(() => { + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + }) + }) + + describe('forceClose prop', () => { + it('should close dropdown when forceClose changes to true', async () => { + const { rerender } = render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('theme.theme')).toBeInTheDocument() + }) + + rerender() + + await waitFor(() => { + expect(screen.queryByText('theme.theme')).not.toBeInTheDocument() + }) + }) + }) + + describe('placement prop', () => { + it('should accept custom placement', () => { + render() + + const triggerButton = screen.getByRole('button') + expect(triggerButton).toBeInTheDocument() + }) + }) + + describe('toggle behavior', () => { + it('should close dropdown when clicking trigger again', async () => { + render() + + const triggerButton = screen.getByRole('button') + + // Open + fireEvent.click(triggerButton) + await waitFor(() => { + expect(screen.getByText('theme.theme')).toBeInTheDocument() + }) + + // Close + fireEvent.click(triggerButton) + await waitFor(() => { + expect(screen.queryByText('theme.theme')).not.toBeInTheDocument() + }) + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((MenuDropdown as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/share/text-generation/result/content.spec.tsx b/web/app/components/share/text-generation/result/content.spec.tsx new file mode 100644 index 000000000..242ae7aa5 --- /dev/null +++ b/web/app/components/share/text-generation/result/content.spec.tsx @@ -0,0 +1,133 @@ +import type { FeedbackType } from '@/app/components/base/chat/chat/type' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import Result from './content' + +// Only mock react-i18next for translations +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock copy-to-clipboard for the Header component +vi.mock('copy-to-clipboard', () => ({ + default: vi.fn(() => true), +})) + +// Mock the format function from service/base +vi.mock('@/service/base', () => ({ + format: (content: string) => content.replace(/\n/g, '
    '), +})) + +afterEach(() => { + cleanup() +}) + +describe('Result (content)', () => { + const mockOnFeedback = vi.fn() + + const defaultProps = { + content: 'Test content here', + showFeedback: true, + feedback: { rating: null } as FeedbackType, + onFeedback: mockOnFeedback, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the Header component', () => { + render() + + // Header renders the result title + expect(screen.getByText('generation.resultTitle')).toBeInTheDocument() + }) + + it('should render content', () => { + render() + + expect(screen.getByText('Test content here')).toBeInTheDocument() + }) + + it('should render formatted content with line breaks', () => { + render( + , + ) + + // The format function converts \n to
    + const contentDiv = document.querySelector('[class*="overflow-scroll"]') + expect(contentDiv?.innerHTML).toContain('Line 1
    Line 2') + }) + + it('should have max height style', () => { + render() + + const contentDiv = document.querySelector('[class*="overflow-scroll"]') + expect(contentDiv).toHaveStyle({ maxHeight: '70vh' }) + }) + + it('should render with empty content', () => { + render( + , + ) + + expect(screen.getByText('generation.resultTitle')).toBeInTheDocument() + }) + + it('should render with HTML content safely', () => { + render( + , + ) + + // Content is rendered via dangerouslySetInnerHTML + const contentDiv = document.querySelector('[class*="overflow-scroll"]') + expect(contentDiv).toBeInTheDocument() + }) + }) + + describe('feedback props', () => { + it('should pass showFeedback to Header', () => { + render( + , + ) + + // Feedback buttons should not be visible + const feedbackArea = document.querySelector('[class*="space-x-1 rounded-lg border"]') + expect(feedbackArea).not.toBeInTheDocument() + }) + + it('should pass feedback to Header', () => { + render( + , + ) + + // Like button should be highlighted + const likeButton = document.querySelector('[class*="primary"]') + expect(likeButton).toBeInTheDocument() + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((Result as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/share/text-generation/result/header.spec.tsx b/web/app/components/share/text-generation/result/header.spec.tsx new file mode 100644 index 000000000..b2ef0fadc --- /dev/null +++ b/web/app/components/share/text-generation/result/header.spec.tsx @@ -0,0 +1,176 @@ +import type { FeedbackType } from '@/app/components/base/chat/chat/type' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import Header from './header' + +// Only mock react-i18next for translations +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock copy-to-clipboard +const mockCopy = vi.fn((_text: string) => true) +vi.mock('copy-to-clipboard', () => ({ + default: (text: string) => mockCopy(text), +})) + +afterEach(() => { + cleanup() +}) + +describe('Header', () => { + const mockOnFeedback = vi.fn() + + const defaultProps = { + result: 'Test result content', + showFeedback: true, + feedback: { rating: null } as FeedbackType, + onFeedback: mockOnFeedback, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the result title', () => { + render(
    ) + + expect(screen.getByText('generation.resultTitle')).toBeInTheDocument() + }) + + it('should render the copy button', () => { + render(
    ) + + expect(screen.getByText('generation.copy')).toBeInTheDocument() + }) + }) + + describe('copy functionality', () => { + it('should copy result when copy button is clicked', () => { + render(
    ) + + const copyButton = screen.getByText('generation.copy').closest('button') + fireEvent.click(copyButton!) + + expect(mockCopy).toHaveBeenCalledWith('Test result content') + }) + }) + + describe('feedback buttons when showFeedback is true', () => { + it('should show feedback buttons when no rating is given', () => { + render(
    ) + + // Should show both thumbs up and down buttons + const buttons = document.querySelectorAll('[class*="cursor-pointer"]') + expect(buttons.length).toBeGreaterThan(0) + }) + + it('should show like button highlighted when rating is like', () => { + render( +
    , + ) + + // Should show the undo button for like + const likeButton = document.querySelector('[class*="primary"]') + expect(likeButton).toBeInTheDocument() + }) + + it('should show dislike button highlighted when rating is dislike', () => { + render( +
    , + ) + + // Should show the undo button for dislike + const dislikeButton = document.querySelector('[class*="red"]') + expect(dislikeButton).toBeInTheDocument() + }) + + it('should call onFeedback with like when thumbs up is clicked', () => { + render(
    ) + + // Find the thumbs up button (first one in the feedback area) + const thumbButtons = document.querySelectorAll('[class*="cursor-pointer"]') + const thumbsUp = Array.from(thumbButtons).find(btn => + btn.className.includes('rounded-md') && !btn.className.includes('primary'), + ) + + if (thumbsUp) { + fireEvent.click(thumbsUp) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: 'like' }) + } + }) + + it('should call onFeedback with dislike when thumbs down is clicked', () => { + render(
    ) + + // Find the thumbs down button + const thumbButtons = document.querySelectorAll('[class*="cursor-pointer"]') + const thumbsDown = Array.from(thumbButtons).pop() + + if (thumbsDown) { + fireEvent.click(thumbsDown) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: 'dislike' }) + } + }) + + it('should call onFeedback with null when undo like is clicked', () => { + render( +
    , + ) + + // When liked, clicking the like button again should undo it (has bg-primary-100 class) + const likeButton = document.querySelector('[class*="bg-primary-100"]') + expect(likeButton).toBeInTheDocument() + fireEvent.click(likeButton!) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: null }) + }) + + it('should call onFeedback with null when undo dislike is clicked', () => { + render( +
    , + ) + + // When disliked, clicking the dislike button again should undo it (has bg-red-100 class) + const dislikeButton = document.querySelector('[class*="bg-red-100"]') + expect(dislikeButton).toBeInTheDocument() + fireEvent.click(dislikeButton!) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: null }) + }) + }) + + describe('feedback buttons when showFeedback is false', () => { + it('should not show feedback buttons', () => { + render( +
    , + ) + + // Should not show feedback area buttons (only copy button) + const feedbackArea = document.querySelector('[class*="space-x-1 rounded-lg border"]') + expect(feedbackArea).not.toBeInTheDocument() + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((Header as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/share/text-generation/run-once/index.spec.tsx b/web/app/components/share/text-generation/run-once/index.spec.tsx index ea5ce3c90..af3d723d2 100644 --- a/web/app/components/share/text-generation/run-once/index.spec.tsx +++ b/web/app/components/share/text-generation/run-once/index.spec.tsx @@ -1,6 +1,7 @@ +import type { InputValueTypes } from '../types' import type { PromptConfig, PromptVariable } from '@/models/debug' import type { SiteInfo } from '@/models/share' -import type { VisionSettings } from '@/types/app' +import type { VisionFile, VisionSettings } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { useEffect, useRef, useState } from 'react' @@ -27,7 +28,7 @@ vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', ( })) vi.mock('@/app/components/base/image-uploader/text-generation-image-uploader', () => { - function TextGenerationImageUploaderMock({ onFilesChange }: { onFilesChange: (files: any[]) => void }) { + function TextGenerationImageUploaderMock({ onFilesChange }: { onFilesChange: (files: VisionFile[]) => void }) { useEffect(() => { onFilesChange([]) }, [onFilesChange]) @@ -38,6 +39,20 @@ vi.mock('@/app/components/base/image-uploader/text-generation-image-uploader', ( } }) +// Mock FileUploaderInAttachmentWrapper as it requires context providers not available in tests +vi.mock('@/app/components/base/file-uploader', () => ({ + FileUploaderInAttachmentWrapper: ({ value, onChange }: { value: object[], onChange: (files: object[]) => void }) => ( +
    + + + {value?.length || 0} + {' '} + files + +
    + ), +})) + const createPromptVariable = (overrides: Partial): PromptVariable => ({ key: 'input', name: 'Input', @@ -95,11 +110,11 @@ const setup = (overrides: { const onInputsChange = vi.fn() const onSend = vi.fn() const onVisionFilesChange = vi.fn() - let inputsRefCapture: React.MutableRefObject> | null = null + let inputsRefCapture: React.MutableRefObject> | null = null const Wrapper = () => { - const [inputs, setInputs] = useState>({}) - const inputsRef = useRef>({}) + const [inputs, setInputs] = useState>({}) + const inputsRef = useRef>({}) inputsRefCapture = inputsRef return ( { expect(stopButton).toBeDisabled() }) + describe('select input type', () => { + it('should render select input and handle selection', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'selectInput', + name: 'Select Input', + type: 'select', + options: ['Option A', 'Option B', 'Option C'], + default: 'Option A', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + selectInput: 'Option A', + }) + }) + // The Select component should be rendered + expect(screen.getByText('Select Input')).toBeInTheDocument() + }) + }) + + describe('file input types', () => { + it('should render file uploader for single file input', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'fileInput', + name: 'File Input', + type: 'file', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + fileInput: undefined, + }) + }) + expect(screen.getByText('File Input')).toBeInTheDocument() + }) + + it('should render file uploader for file-list input', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'fileListInput', + name: 'File List Input', + type: 'file-list', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + fileListInput: [], + }) + }) + expect(screen.getByText('File List Input')).toBeInTheDocument() + }) + }) + + describe('json_object input type', () => { + it('should render code editor for json_object input', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'jsonInput', + name: 'JSON Input', + type: 'json_object' as PromptVariable['type'], + json_schema: '{"type": "object"}', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + jsonInput: undefined, + }) + }) + expect(screen.getByText('JSON Input')).toBeInTheDocument() + expect(screen.getByTestId('code-editor-mock')).toBeInTheDocument() + }) + + it('should update json_object input when code editor changes', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'jsonInput', + name: 'JSON Input', + type: 'json_object' as PromptVariable['type'], + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + onInputsChange.mockClear() + + const codeEditor = screen.getByTestId('code-editor-mock') + fireEvent.change(codeEditor, { target: { value: '{"key": "value"}' } }) + + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + jsonInput: '{"key": "value"}', + }) + }) + }) + }) + + describe('hidden and optional fields', () => { + it('should not render hidden variables', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'hiddenInput', + name: 'Hidden Input', + type: 'string', + hide: true, + }), + createPromptVariable({ + key: 'visibleInput', + name: 'Visible Input', + type: 'string', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + expect(screen.queryByText('Hidden Input')).not.toBeInTheDocument() + expect(screen.getByText('Visible Input')).toBeInTheDocument() + }) + + it('should show optional label for non-required fields', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'optionalInput', + name: 'Optional Input', + type: 'string', + required: false, + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + expect(screen.getByText('workflow.panel.optional')).toBeInTheDocument() + }) + }) + + describe('vision uploader', () => { + it('should not render vision uploader when disabled', async () => { + const { onInputsChange } = setup({ visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + expect(screen.queryByText('common.imageUploader.imageUpload')).not.toBeInTheDocument() + }) + }) + + describe('clear with different input types', () => { + it('should clear select input to undefined', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'selectInput', + name: 'Select Input', + type: 'select', + options: ['Option A', 'Option B'], + default: 'Option A', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + onInputsChange.mockClear() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.clear' })) + + expect(onInputsChange).toHaveBeenCalledWith({ + selectInput: undefined, + }) + }) + }) + describe('maxLength behavior', () => { it('should not have maxLength attribute when max_length is not set', async () => { const promptConfig: PromptConfig = { diff --git a/web/app/components/share/utils.spec.ts b/web/app/components/share/utils.spec.ts new file mode 100644 index 000000000..ee2aab58e --- /dev/null +++ b/web/app/components/share/utils.spec.ts @@ -0,0 +1,71 @@ +import { describe, expect, it } from 'vitest' +import { getInitialTokenV2, isTokenV1 } from './utils' + +describe('utils', () => { + describe('isTokenV1', () => { + it('should return true when token has no version property', () => { + const token = { someKey: 'value' } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is undefined', () => { + const token = { version: undefined } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is null', () => { + const token = { version: null } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is 0', () => { + const token = { version: 0 } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is empty string', () => { + const token = { version: '' } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return false when token has version 1', () => { + const token = { version: 1 } + expect(isTokenV1(token)).toBe(false) + }) + + it('should return false when token has version 2', () => { + const token = { version: 2 } + expect(isTokenV1(token)).toBe(false) + }) + + it('should return false when token has string version', () => { + const token = { version: '2' } + expect(isTokenV1(token)).toBe(false) + }) + + it('should handle empty object', () => { + const token = {} + expect(isTokenV1(token)).toBe(true) + }) + }) + + describe('getInitialTokenV2', () => { + it('should return object with version 2', () => { + const token = getInitialTokenV2() + expect(token.version).toBe(2) + }) + + it('should return a new object each time', () => { + const token1 = getInitialTokenV2() + const token2 = getInitialTokenV2() + expect(token1).not.toBe(token2) + }) + + it('should return an object that can be modified without affecting future calls', () => { + const token1 = getInitialTokenV2() + token1.customField = 'test' + const token2 = getInitialTokenV2() + expect(token2.customField).toBeUndefined() + }) + }) +}) diff --git a/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx b/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx index 848412f0a..204772a3e 100644 --- a/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx @@ -172,6 +172,9 @@ describe('EditCustomCollectionModal', () => { expect(parseParamsSchemaMock).toHaveBeenCalledWith('{}') }) + // Flush pending state updates from parseParamsSchema promise resolution + await act(async () => {}) + await act(async () => { fireEvent.click(screen.getByText('common.operation.save')) }) @@ -184,6 +187,10 @@ describe('EditCustomCollectionModal', () => { credentials: { auth_type: 'none', }, + icon: { + content: '🕵️', + background: '#FEF7C3', + }, labels: [], })) expect(toastNotifySpy).not.toHaveBeenCalled() diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx index 1314dc90d..525946bb1 100644 --- a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx @@ -551,8 +551,8 @@ describe('WorkflowOnboardingModal', () => { // Assert const escKey = screen.getByText('workflow.onboarding.escTip.key') - expect(escKey.closest('kbd')).toBeInTheDocument() - expect(escKey.closest('kbd')).toHaveClass('system-kbd') + // ShortcutsName renders a
    with class system-kbd, not a element + expect(escKey.closest('.system-kbd')).toBeInTheDocument() }) it('should have descriptive text for ESC functionality', () => { diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx index c483abfb0..16bae5124 100644 --- a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx @@ -7,6 +7,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import Modal from '@/app/components/base/modal' +import ShortcutsName from '@/app/components/workflow/shortcuts-name' import { BlockEnum } from '@/app/components/workflow/types' import StartNodeSelectionPanel from './start-node-selection-panel' @@ -75,9 +76,7 @@ const WorkflowOnboardingModal: FC = ({ {isShow && (
    {t('onboarding.escTip.press', { ns: 'workflow' })} - - {t('onboarding.escTip.key', { ns: 'workflow' })} - + {t('onboarding.escTip.toDismiss', { ns: 'workflow' })}
    )} diff --git a/web/app/components/workflow-app/hooks/use-DSL.ts b/web/app/components/workflow-app/hooks/use-DSL.ts index 6c01509bc..939e43b55 100644 --- a/web/app/components/workflow-app/hooks/use-DSL.ts +++ b/web/app/components/workflow-app/hooks/use-DSL.ts @@ -11,6 +11,7 @@ import { import { useEventEmitterContextContext } from '@/context/event-emitter' import { exportAppConfig } from '@/service/apps' import { fetchWorkflowDraft } from '@/service/workflow' +import { downloadBlob } from '@/utils/download' import { useNodesSyncDraft } from './use-nodes-sync-draft' export const useDSL = () => { @@ -37,13 +38,8 @@ export const useDSL = () => { include, workflowID: workflowId, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${appDetail.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/workflow/block-selector/market-place-plugin/action.tsx b/web/app/components/workflow/block-selector/market-place-plugin/action.tsx index b8300d6f2..abdbae1b4 100644 --- a/web/app/components/workflow/block-selector/market-place-plugin/action.tsx +++ b/web/app/components/workflow/block-selector/market-place-plugin/action.tsx @@ -15,7 +15,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { useDownloadPlugin } from '@/service/use-plugins' import { cn } from '@/utils/classnames' -import { downloadFile } from '@/utils/format' +import { downloadBlob } from '@/utils/download' import { getMarketplaceUrl } from '@/utils/var' type Props = { @@ -67,7 +67,7 @@ const OperationDropdown: FC = ({ if (!needDownload || !blob) return const fileName = `${author}-${name}_${version}.zip` - downloadFile({ data: blob, fileName }) + downloadBlob({ data: blob, fileName }) setNeedDownload(false) queryClient.removeQueries({ queryKey: ['plugins', 'downloadPlugin', downloadInfo], diff --git a/web/app/components/workflow/header/run-mode.tsx b/web/app/components/workflow/header/run-mode.tsx index 1a101bc6d..74bc5bc80 100644 --- a/web/app/components/workflow/header/run-mode.tsx +++ b/web/app/components/workflow/header/run-mode.tsx @@ -7,9 +7,9 @@ import { trackEvent } from '@/app/components/base/amplitude' import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' import { useToastContext } from '@/app/components/base/toast' import { useWorkflowRun, useWorkflowRunValidation, useWorkflowStartRun } from '@/app/components/workflow/hooks' +import ShortcutsName from '@/app/components/workflow/shortcuts-name' import { useStore } from '@/app/components/workflow/store' import { WorkflowRunningStatus } from '@/app/components/workflow/types' -import { getKeyboardKeyNameBySystem } from '@/app/components/workflow/utils' import { EVENT_WORKFLOW_STOP } from '@/app/components/workflow/variable-inspect/types' import { useEventEmitterContextContext } from '@/context/event-emitter' import { cn } from '@/utils/classnames' @@ -143,14 +143,7 @@ const RunMode = ({ > {text ?? t('common.run', { ns: 'workflow' })} -
    -
    - {getKeyboardKeyNameBySystem('alt')} -
    -
    - R -
    -
    +
    ) diff --git a/web/app/components/workflow/header/version-history-button.tsx b/web/app/components/workflow/header/version-history-button.tsx index 32e72dc18..b98dfeea7 100644 --- a/web/app/components/workflow/header/version-history-button.tsx +++ b/web/app/components/workflow/header/version-history-button.tsx @@ -8,7 +8,8 @@ import useTheme from '@/hooks/use-theme' import { cn } from '@/utils/classnames' import Button from '../../base/button' import Tooltip from '../../base/tooltip' -import { getKeyboardKeyCodeBySystem, getKeyboardKeyNameBySystem } from '../utils' +import ShortcutsName from '../shortcuts-name' +import { getKeyboardKeyCodeBySystem } from '../utils' type VersionHistoryButtonProps = { onClick: () => Promise | unknown @@ -23,16 +24,7 @@ const PopupContent = React.memo(() => {
    {t('common.versionHistory', { ns: 'workflow' })}
    -
    - {VERSION_HISTORY_SHORTCUT.map(key => ( - - {getKeyboardKeyNameBySystem(key)} - - ))} -
    +
    ) }) diff --git a/web/app/components/workflow/nodes/http/components/curl-panel.tsx b/web/app/components/workflow/nodes/http/components/curl-panel.tsx index aa67a2a0a..6c809c310 100644 --- a/web/app/components/workflow/nodes/http/components/curl-panel.tsx +++ b/web/app/components/workflow/nodes/http/components/curl-panel.tsx @@ -41,7 +41,7 @@ const parseCurl = (curlCommand: string): { node: HttpNodeType | null, error: str case '--request': if (i + 1 >= args.length) return { node: null, error: 'Missing HTTP method after -X or --request.' } - node.method = (args[++i].replace(/^['"]|['"]$/g, '') as Method) || Method.get + node.method = (args[++i].replace(/^['"]|['"]$/g, '').toLowerCase() as Method) || Method.get hasData = true break case '-H': diff --git a/web/app/components/workflow/nodes/knowledge-base/panel.tsx b/web/app/components/workflow/nodes/knowledge-base/panel.tsx index 0a275645a..2845d605b 100644 --- a/web/app/components/workflow/nodes/knowledge-base/panel.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/panel.tsx @@ -18,6 +18,7 @@ import { Group, } from '@/app/components/workflow/nodes/_base/components/layout' import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker' +import { IS_CE_EDITION } from '@/config' import Split from '../_base/components/split' import ChunkStructure from './components/chunk-structure' import EmbeddingModel from './components/embedding-model' @@ -172,7 +173,7 @@ const Panel: FC> = ({ { data.indexing_technique === IndexMethodEnum.QUALIFIED && [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure) - && ( + && IS_CE_EDITION && ( <> void } -const Key = (props: { keyName: string }) => { - const { keyName } = props - return ( - - {keyName} - - ) -} - const AdvancedActions: FC = ({ isConfirmDisabled, onCancel, @@ -48,10 +40,7 @@ const AdvancedActions: FC = ({ onClick={onConfirm} > {t('operation.confirm', { ns: 'common' })} -
    - - -
    +
    ) diff --git a/web/app/components/workflow/nodes/trigger-plugin/default.ts b/web/app/components/workflow/nodes/trigger-plugin/default.ts index 0cb2a72c9..605a1e3f1 100644 --- a/web/app/components/workflow/nodes/trigger-plugin/default.ts +++ b/web/app/components/workflow/nodes/trigger-plugin/default.ts @@ -221,7 +221,7 @@ const buildOutputVars = (schema: Record, schemaTypeDefinitions?: Sc const metaData = genNodeMetaData({ sort: 1, type: BlockEnum.TriggerPlugin, - helpLinkUri: 'plugin-trigger', + helpLinkUri: 'trigger/plugin-trigger', isStart: true, }) diff --git a/web/app/components/workflow/nodes/trigger-schedule/default.ts b/web/app/components/workflow/nodes/trigger-schedule/default.ts index 4f166675e..587a125c2 100644 --- a/web/app/components/workflow/nodes/trigger-schedule/default.ts +++ b/web/app/components/workflow/nodes/trigger-schedule/default.ts @@ -110,7 +110,7 @@ const validateVisualConfig = (payload: ScheduleTriggerNodeType, t: any): string const metaData = genNodeMetaData({ sort: 2, type: BlockEnum.TriggerSchedule, - helpLinkUri: 'schedule-trigger', + helpLinkUri: 'trigger/schedule-trigger', isStart: true, }) diff --git a/web/app/components/workflow/nodes/trigger-webhook/default.ts b/web/app/components/workflow/nodes/trigger-webhook/default.ts index ec0369d75..66fae30b0 100644 --- a/web/app/components/workflow/nodes/trigger-webhook/default.ts +++ b/web/app/components/workflow/nodes/trigger-webhook/default.ts @@ -8,7 +8,7 @@ import { createWebhookRawVariable } from './utils/raw-variable' const metaData = genNodeMetaData({ sort: 3, type: BlockEnum.TriggerWebhook, - helpLinkUri: 'webhook-trigger', + helpLinkUri: 'trigger/webhook-trigger', isStart: true, }) diff --git a/web/app/components/workflow/operator/more-actions.tsx b/web/app/components/workflow/operator/more-actions.tsx index e9fc1ea87..7e6617e84 100644 --- a/web/app/components/workflow/operator/more-actions.tsx +++ b/web/app/components/workflow/operator/more-actions.tsx @@ -19,6 +19,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { useStore } from '@/app/components/workflow/store' import { cn } from '@/utils/classnames' +import { downloadUrl } from '@/utils/download' import { useNodesReadOnly } from '../hooks' import TipPopup from './tip-popup' @@ -146,26 +147,14 @@ const MoreActions: FC = () => { } } + const fileName = `${filename}.${type}` + if (currentWorkflow) { setPreviewUrl(dataUrl) - setPreviewTitle(`${filename}.${type}`) + setPreviewTitle(fileName) + } - const link = document.createElement('a') - link.href = dataUrl - link.download = `${filename}.${type}` - document.body.appendChild(link) - link.click() - document.body.removeChild(link) - } - else { - // For current view, just download - const link = document.createElement('a') - link.href = dataUrl - link.download = `${filename}.${type}` - document.body.appendChild(link) - link.click() - document.body.removeChild(link) - } + downloadUrl({ url: dataUrl, fileName }) } catch (error) { console.error('Export image failed:', error) diff --git a/web/app/components/workflow/shortcuts-name.tsx b/web/app/components/workflow/shortcuts-name.tsx index d0ce007f6..3d21cff31 100644 --- a/web/app/components/workflow/shortcuts-name.tsx +++ b/web/app/components/workflow/shortcuts-name.tsx @@ -6,11 +6,13 @@ type ShortcutsNameProps = { keys: string[] className?: string textColor?: 'default' | 'secondary' + bgColor?: 'gray' | 'white' } const ShortcutsName = ({ keys, className, textColor = 'default', + bgColor = 'gray', }: ShortcutsNameProps) => { return (
    diff --git a/web/app/components/workflow/utils/gen-node-meta-data.ts b/web/app/components/workflow/utils/gen-node-meta-data.ts index f45bfcb01..e625e3a8a 100644 --- a/web/app/components/workflow/utils/gen-node-meta-data.ts +++ b/web/app/components/workflow/utils/gen-node-meta-data.ts @@ -1,4 +1,5 @@ import type { BlockEnum } from '@/app/components/workflow/types' +import type { UseDifyNodesPath } from '@/types/doc-paths' import { BlockClassificationEnum } from '@/app/components/workflow/block-selector/types' export type GenNodeMetaDataParams = { @@ -7,7 +8,7 @@ export type GenNodeMetaDataParams = { type: BlockEnum title?: string author?: string - helpLinkUri?: string + helpLinkUri?: UseDifyNodesPath isRequired?: boolean isUndeletable?: boolean isStart?: boolean diff --git a/web/app/styles/monaco-sticky-fix.css b/web/app/styles/monaco-sticky-fix.css index 66bb5921c..ac928cf24 100644 --- a/web/app/styles/monaco-sticky-fix.css +++ b/web/app/styles/monaco-sticky-fix.css @@ -9,8 +9,7 @@ html[data-theme="dark"] .monaco-editor .sticky-line-content:hover { background-color: var(--color-components-sticky-header-bg-hover) !important; } -/* Fallback: any app sticky header using input-bg variables should use the sticky header bg when sticky */ -html[data-theme="dark"] .sticky, html[data-theme="dark"] .is-sticky { +/* Monaco editor specific sticky scroll styles in dark mode */ +html[data-theme="dark"] .monaco-editor .sticky-line-root { background-color: var(--color-components-sticky-header-bg) !important; - border-bottom: 1px solid var(--color-components-sticky-header-border) !important; } \ No newline at end of file diff --git a/web/context/i18n.ts b/web/context/i18n.ts index 2766dfe5e..5f39d1afb 100644 --- a/web/context/i18n.ts +++ b/web/context/i18n.ts @@ -1,6 +1,7 @@ import type { Locale } from '@/i18n-config/language' import type { DocPathWithoutLang } from '@/types/doc-paths' import { useTranslation } from '#i18n' +import { useCallback } from 'react' import { getDocLanguage, getLanguage, getPricingPageLanguage } from '@/i18n-config/language' import { apiReferencePathTranslations } from '@/types/doc-paths' @@ -27,21 +28,24 @@ export const useDocLink = (baseUrl?: string): ((path?: DocPathWithoutLang, pathM let baseDocUrl = baseUrl || defaultDocBaseUrl baseDocUrl = (baseDocUrl.endsWith('/')) ? baseDocUrl.slice(0, -1) : baseDocUrl const locale = useLocale() - const docLanguage = getDocLanguage(locale) - return (path?: DocPathWithoutLang, pathMap?: DocPathMap): string => { - const pathUrl = path || '' - let targetPath = (pathMap) ? pathMap[locale] || pathUrl : pathUrl - let languagePrefix = `/${docLanguage}` + return useCallback( + (path?: DocPathWithoutLang, pathMap?: DocPathMap): string => { + const docLanguage = getDocLanguage(locale) + const pathUrl = path || '' + let targetPath = (pathMap) ? pathMap[locale] || pathUrl : pathUrl + let languagePrefix = `/${docLanguage}` - // Translate API reference paths for non-English locales - if (targetPath.startsWith('/api-reference/') && docLanguage !== 'en') { - const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage as 'zh' | 'ja'] - if (translatedPath) { - targetPath = translatedPath - languagePrefix = '' + // Translate API reference paths for non-English locales + if (targetPath.startsWith('/api-reference/') && docLanguage !== 'en') { + const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage as 'zh' | 'ja'] + if (translatedPath) { + targetPath = translatedPath + languagePrefix = '' + } } - } - return `${baseDocUrl}${languagePrefix}${targetPath}` - } + return `${baseDocUrl}${languagePrefix}${targetPath}` + }, + [baseDocUrl, locale], + ) } diff --git a/web/docs/lint.md b/web/docs/lint.md new file mode 100644 index 000000000..051f9e6ec --- /dev/null +++ b/web/docs/lint.md @@ -0,0 +1,51 @@ +# Lint Guide + +We use ESLint and Typescript to maintain code quality and consistency across the project. + +## ESLint + +### Common Flags + +**File/folder targeting**: Append paths to lint specific files or directories. + +```sh +pnpm eslint [options] file.js [file.js] [dir] +``` + +**`--cache`**: Caches lint results for faster subsequent runs. Keep this enabled by default; only disable when you encounter unexpected lint results. + +**`--concurrency`**: Enables multi-threaded linting. Use `--concurrency=auto` or experiment with specific numbers to find the optimal setting for your machine. Keep this enabled when linting multiple files. + +- [ESLint multi-thread linting blog post](https://eslint.org/blog/2025/08/multithread-linting/) + +**`--fix`**: Automatically fixes auto-fixable rule violations. Always review the diff before committing to ensure no unintended changes. + +**`--quiet`**: Suppresses warnings and only shows errors. Useful when you want to reduce noise from existing issues. + +**`--suppress-all`**: Temporarily suppresses error-level violations and records them, allowing CI to pass. Treat this as an escape hatch—fix these errors when time permits. + +**`--prune-suppressions`**: Removes outdated suppressions after you've fixed the underlying errors. + +- [ESLint bulk suppressions blog post](https://eslint.org/blog/2025/04/introducing-bulk-suppressions/) + +### Type-Aware Linting + +Some ESLint rules require type information, such as [no-leaked-conditional-rendering](https://www.eslint-react.xyz/docs/rules/no-leaked-conditional-rendering). However, [typed linting via typescript-eslint](https://typescript-eslint.io/getting-started/typed-linting) is too slow for practical use, so we use [TSSLint](https://github.com/johnsoncodehk/tsslint) instead. + +```sh +pnpm lint:tss +``` + +This command lints the entire project and is intended for final verification before committing or pushing changes. + +## Type Check + +You should be able to see suggestions from TypeScript in your editor for all open files. + +However, it can be useful to run the TypeScript 7 command-line (tsgo) to type check all files: + +```sh +pnpm type-check:tsgo +``` + +Prefer using `tsgo` for type checking as it is significantly faster than the standard TypeScript compiler. Only fall back to `pnpm type-check` (which uses `tsc`) if you encounter unexpected results. diff --git a/web/testing/testing.md b/web/docs/test.md similarity index 99% rename from web/testing/testing.md rename to web/docs/test.md index 47341e445..cac0e0e35 100644 --- a/web/testing/testing.md +++ b/web/docs/test.md @@ -360,11 +360,11 @@ describe('ComponentName', () => { let mockPortalOpenState = false vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open, ...props }: any) => { + PortalToFollowElem: ({ children, open, ...props }) => { mockPortalOpenState = open || false // Update shared state return
    {children}
    }, - PortalToFollowElemContent: ({ children }: any) => { + PortalToFollowElemContent: ({ children }) => { // ✅ Matches actual: returns null when open is false if (!mockPortalOpenState) return null diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 90632b9ff..e5ced085f 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -177,6 +177,11 @@ "count": 1 } }, + "app/components/app/annotation/add-annotation-modal/edit-item/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx": { "ts/no-explicit-any": { "count": 2 @@ -186,6 +191,9 @@ "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -193,6 +201,9 @@ "app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 + }, + "react-refresh/only-export-components": { + "count": 1 } }, "app/components/app/annotation/edit-annotation-modal/index.spec.tsx": { @@ -252,6 +263,11 @@ "count": 6 } }, + "app/components/app/configuration/base/var-highlight/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/app/configuration/config-prompt/advanced-prompt-input.tsx": { "ts/no-explicit-any": { "count": 2 @@ -422,6 +438,11 @@ "count": 6 } }, + "app/components/app/configuration/debug/debug-with-multiple-model/context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx": { "ts/no-explicit-any": { "count": 5 @@ -504,6 +525,11 @@ "count": 1 } }, + "app/components/app/create-app-dialog/app-list/sidebar.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/app/create-app-modal/index.spec.tsx": { "ts/no-explicit-any": { "count": 7 @@ -520,6 +546,14 @@ "app/components/app/create-from-dsl-modal/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 + }, + "react-refresh/only-export-components": { + "count": 1 + } + }, + "app/components/app/log/filter.tsx": { + "react-refresh/only-export-components": { + "count": 1 } }, "app/components/app/log/index.tsx": { @@ -588,6 +622,9 @@ "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 }, + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 4 } @@ -597,6 +634,11 @@ "count": 2 } }, + "app/components/app/workflow-log/filter.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/app/workflow-log/list.spec.tsx": { "ts/no-explicit-any": { "count": 1 @@ -648,6 +690,11 @@ "count": 1 } }, + "app/components/base/action-button/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/agent-log-modal/detail.tsx": { "ts/no-explicit-any": { "count": 1 @@ -676,6 +723,11 @@ "count": 2 } }, + "app/components/base/amplitude/AmplitudeProvider.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/amplitude/utils.ts": { "ts/no-explicit-any": { "count": 2 @@ -724,6 +776,9 @@ "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, + "react-refresh/only-export-components": { + "count": 1 + }, "react/no-nested-component-definitions": { "count": 1 } @@ -733,11 +788,21 @@ "count": 1 } }, + "app/components/base/button/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/button/sync-button.stories.tsx": { "no-console": { "count": 1 } }, + "app/components/base/carousel/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/chat/chat-with-history/chat-wrapper.tsx": { "ts/no-explicit-any": { "count": 6 @@ -817,6 +882,11 @@ "count": 1 } }, + "app/components/base/chat/chat/context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/chat/chat/hooks.ts": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 @@ -920,10 +990,18 @@ } }, "app/components/base/error-boundary/index.tsx": { + "react-refresh/only-export-components": { + "count": 3 + }, "ts/no-explicit-any": { "count": 2 } }, + "app/components/base/features/context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/features/new-feature-panel/annotation-reply/index.tsx": { "ts/no-explicit-any": { "count": 3 @@ -989,12 +1067,17 @@ "count": 3 } }, + "app/components/base/file-uploader/store.tsx": { + "react-refresh/only-export-components": { + "count": 4 + } + }, "app/components/base/file-uploader/utils.spec.ts": { "test/no-identical-title": { "count": 1 }, "ts/no-explicit-any": { - "count": 3 + "count": 2 } }, "app/components/base/file-uploader/utils.ts": { @@ -1085,6 +1168,11 @@ "count": 2 } }, + "app/components/base/ga/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/icons/utils.ts": { "ts/no-explicit-any": { "count": 3 @@ -1136,6 +1224,16 @@ "count": 1 } }, + "app/components/base/input/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "app/components/base/logo/dify-logo.tsx": { + "react-refresh/only-export-components": { + "count": 2 + } + }, "app/components/base/markdown-blocks/audio-block.tsx": { "ts/no-explicit-any": { "count": 5 @@ -1286,6 +1384,11 @@ "count": 1 } }, + "app/components/base/node-status/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/base/notion-connector/index.stories.tsx": { "no-console": { "count": 1 @@ -1317,6 +1420,9 @@ } }, "app/components/base/portal-to-follow-elem/index.tsx": { + "react-refresh/only-export-components": { + "count": 2 + }, "ts/no-explicit-any": { "count": 1 } @@ -1483,6 +1589,16 @@ "count": 1 } }, + "app/components/base/textarea/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "app/components/base/toast/index.tsx": { + "react-refresh/only-export-components": { + "count": 2 + } + }, "app/components/base/video-gallery/VideoPlayer.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -1526,6 +1642,16 @@ "count": 2 } }, + "app/components/billing/pricing/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "app/components/billing/pricing/plan-switcher/plan-range-switcher.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx": { "test/prefer-hooks-in-order": { "count": 1 @@ -1576,6 +1702,11 @@ "count": 3 } }, + "app/components/datasets/common/image-uploader/store.tsx": { + "react-refresh/only-export-components": { + "count": 4 + } + }, "app/components/datasets/common/image-uploader/utils.ts": { "ts/no-explicit-any": { "count": 2 @@ -1586,6 +1717,16 @@ "count": 1 } }, + "app/components/datasets/common/retrieval-method-info/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/datasets/create/file-preview/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -1616,6 +1757,11 @@ "count": 3 } }, + "app/components/datasets/create/step-two/preview-item/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/datasets/create/stop-embedding-modal/index.spec.tsx": { "test/prefer-hooks-in-order": { "count": 1 @@ -1661,7 +1807,7 @@ "count": 1 }, "ts/no-explicit-any": { - "count": 5 + "count": 4 } }, "app/components/datasets/create/website/watercrawl/options.tsx": { @@ -1673,6 +1819,9 @@ "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -1737,6 +1886,11 @@ "count": 2 } }, + "app/components/datasets/documents/create-from-pipeline/data-source/store/provider.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/datasets/documents/create-from-pipeline/data-source/store/slices/online-drive.ts": { "ts/no-explicit-any": { "count": 4 @@ -1782,6 +1936,11 @@ "count": 1 } }, + "app/components/datasets/documents/detail/completed/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/datasets/documents/detail/completed/new-child-segment.tsx": { "ts/no-explicit-any": { "count": 1 @@ -1805,6 +1964,11 @@ "count": 1 } }, + "app/components/datasets/documents/detail/segment-add/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx": { "ts/no-explicit-any": { "count": 6 @@ -1945,6 +2109,11 @@ "count": 1 } }, + "app/components/explore/try-app/tab.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/goto-anything/actions/commands/command-bus.ts": { "ts/no-explicit-any": { "count": 2 @@ -1956,6 +2125,9 @@ } }, "app/components/goto-anything/actions/commands/slash.tsx": { + "react-refresh/only-export-components": { + "count": 3 + }, "ts/no-explicit-any": { "count": 1 } @@ -1973,15 +2145,8 @@ "app/components/goto-anything/context.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 4 - } - }, - "app/components/goto-anything/index.spec.tsx": { - "ts/no-explicit-any": { - "count": 5 - } - }, - "app/components/goto-anything/index.tsx": { - "react-hooks-extra/no-direct-set-state-in-use-effect": { + }, + "react-refresh/only-export-components": { "count": 1 } }, @@ -2169,6 +2334,11 @@ "count": 4 } }, + "app/components/plugins/install-plugin/install-bundle/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/plugins/install-plugin/install-bundle/item/github-item.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -2245,6 +2415,11 @@ "count": 2 } }, + "app/components/plugins/plugin-auth/index.tsx": { + "react-refresh/only-export-components": { + "count": 3 + } + }, "app/components/plugins/plugin-auth/plugin-auth-in-agent.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2329,6 +2504,9 @@ } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -2369,6 +2547,9 @@ } }, "app/components/plugins/plugin-page/context.tsx": { + "react-refresh/only-export-components": { + "count": 2 + }, "ts/no-explicit-any": { "count": 1 } @@ -2584,11 +2765,6 @@ "count": 2 } }, - "app/components/share/text-generation/run-once/index.spec.tsx": { - "ts/no-explicit-any": { - "count": 4 - } - }, "app/components/share/text-generation/run-once/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -2738,6 +2914,11 @@ "count": 1 } }, + "app/components/workflow/block-selector/constants.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/block-selector/featured-tools.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 @@ -2759,6 +2940,11 @@ "count": 1 } }, + "app/components/workflow/block-selector/index-bar.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/block-selector/market-place-plugin/action.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -2797,11 +2983,26 @@ "count": 1 } }, + "app/components/workflow/block-selector/view-type-select.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/candidate-node-main.tsx": { "ts/no-explicit-any": { "count": 2 } }, + "app/components/workflow/context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "app/components/workflow/datasets-detail-store/provider.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/header/run-mode.tsx": { "no-console": { "count": 1 @@ -2810,11 +3011,21 @@ "count": 1 } }, + "app/components/workflow/header/test-run-menu.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/header/view-workflow-history.tsx": { "ts/no-explicit-any": { "count": 1 } }, + "app/components/workflow/hooks-store/provider.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/hooks-store/store.ts": { "ts/no-explicit-any": { "count": 6 @@ -2935,10 +3146,18 @@ } }, "app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx": { + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 6 } }, + "app/components/workflow/nodes/_base/components/entry-node-container.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/error-handle/default-value.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2964,6 +3183,16 @@ "count": 1 } }, + "app/components/workflow/nodes/_base/components/layout/index.tsx": { + "react-refresh/only-export-components": { + "count": 7 + } + }, + "app/components/workflow/nodes/_base/components/mcp-tool-availability.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/memory-config.tsx": { "unicorn/prefer-number-properties": { "count": 1 @@ -3047,6 +3276,9 @@ } }, "app/components/workflow/nodes/_base/components/workflow-panel/tab.tsx": { + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -3100,6 +3332,9 @@ } }, "app/components/workflow/nodes/agent/panel.tsx": { + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -3383,6 +3618,11 @@ "count": 2 } }, + "app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx": { + "react-refresh/only-export-components": { + "count": 3 + } + }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/auto-width-input.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -3680,6 +3920,11 @@ "count": 1 } }, + "app/components/workflow/note-node/note-editor/toolbar/color-picker.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "app/components/workflow/note-node/note-editor/utils.ts": { "regexp/no-useless-quantifier": { "count": 1 @@ -3716,6 +3961,9 @@ } }, "app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx": { + "react-refresh/only-export-components": { + "count": 1 + }, "ts/no-explicit-any": { "count": 5 }, @@ -3995,6 +4243,11 @@ "count": 8 } }, + "app/components/workflow/workflow-history-store.tsx": { + "react-refresh/only-export-components": { + "count": 2 + } + }, "app/components/workflow/workflow-preview/components/nodes/constants.ts": { "ts/no-explicit-any": { "count": 1 @@ -4056,30 +4309,79 @@ } }, "context/app-context.tsx": { + "react-refresh/only-export-components": { + "count": 2 + }, "ts/no-explicit-any": { "count": 1 } }, + "context/datasets-context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "context/event-emitter.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "context/external-api-panel-context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "context/external-knowledge-api-context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "context/global-public-context.tsx": { + "react-refresh/only-export-components": { + "count": 4 + } + }, "context/hooks/use-trigger-events-limit-modal.ts": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 } }, + "context/mitt-context.tsx": { + "react-refresh/only-export-components": { + "count": 3 + } + }, "context/modal-context.test.tsx": { "ts/no-explicit-any": { "count": 3 } }, "context/modal-context.tsx": { + "react-refresh/only-export-components": { + "count": 2 + }, "ts/no-explicit-any": { "count": 5 } }, "context/provider-context.tsx": { + "react-refresh/only-export-components": { + "count": 3 + }, "ts/no-explicit-any": { "count": 1 } }, + "context/web-app-context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, + "context/workspace-context.tsx": { + "react-refresh/only-export-components": { + "count": 1 + } + }, "hooks/use-async-window-open.spec.ts": { "ts/no-explicit-any": { "count": 6 @@ -4116,6 +4418,9 @@ "hooks/use-pay.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 4 + }, + "react-refresh/only-export-components": { + "count": 3 } }, "i18n-config/README.md": { @@ -4323,11 +4628,6 @@ "count": 10 } }, - "testing/testing.md": { - "ts/no-explicit-any": { - "count": 2 - } - }, "types/app.ts": { "ts/no-explicit-any": { "count": 1 @@ -4381,11 +4681,6 @@ "count": 1 } }, - "utils/format.spec.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, "utils/get-icon.spec.ts": { "ts/no-explicit-any": { "count": 2 diff --git a/web/eslint.config.mjs b/web/eslint.config.mjs index 9ef3f8d04..3f3bef8c0 100644 --- a/web/eslint.config.mjs +++ b/web/eslint.config.mjs @@ -4,7 +4,7 @@ import pluginQuery from '@tanstack/eslint-plugin-query' import sonar from 'eslint-plugin-sonarjs' import storybook from 'eslint-plugin-storybook' import tailwind from 'eslint-plugin-tailwindcss' -import difyI18n from './eslint-rules/index.js' +import dify from './eslint-rules/index.js' export default antfu( { @@ -104,44 +104,25 @@ export default antfu( 'tailwindcss/migration-from-tailwind-2': 'warn', }, }, - // dify i18n namespace migration - // { - // files: ['**/*.ts', '**/*.tsx'], - // ignores: ['eslint-rules/**', 'i18n/**', 'i18n-config/**'], - // plugins: { - // 'dify-i18n': difyI18n, - // }, - // rules: { - // // 'dify-i18n/no-as-any-in-t': ['error', { mode: 'all' }], - // 'dify-i18n/no-as-any-in-t': 'error', - // // 'dify-i18n/no-legacy-namespace-prefix': 'error', - // // 'dify-i18n/require-ns-option': 'error', - // }, - // }, - // i18n JSON validation rules + { + plugins: { dify }, + }, { files: ['i18n/**/*.json'], - plugins: { - 'dify-i18n': difyI18n, - }, rules: { 'sonarjs/max-lines': 'off', 'max-lines': 'off', 'jsonc/sort-keys': 'error', - 'dify-i18n/valid-i18n-keys': 'error', - 'dify-i18n/no-extra-keys': 'error', - 'dify-i18n/consistent-placeholders': 'error', + 'dify/valid-i18n-keys': 'error', + 'dify/no-extra-keys': 'error', + 'dify/consistent-placeholders': 'error', }, }, - // package.json version prefix validation { files: ['**/package.json'], - plugins: { - 'dify-i18n': difyI18n, - }, rules: { - 'dify-i18n/no-version-prefix': 'error', + 'dify/no-version-prefix': 'error', }, }, ) diff --git a/web/i18n/en-US/explore.json b/web/i18n/en-US/explore.json index 89bbea81e..68b8b30b0 100644 --- a/web/i18n/en-US/explore.json +++ b/web/i18n/en-US/explore.json @@ -36,5 +36,5 @@ "tryApp.requirements": "Requirements", "tryApp.tabHeader.detail": "Orchestration Details", "tryApp.tabHeader.try": "Try it", - "tryApp.tryInfo": "This is a sample app. You can try up to 5 messages. To keep using it, click \"Create form this sample app\" and set it up!" + "tryApp.tryInfo": "This is a sample app. You can try up to 5 messages. To keep using it, click \"Create from this sample app\" and set it up!" } diff --git a/web/next.config.ts b/web/next.config.ts index 3a672d74c..3999039fd 100644 --- a/web/next.config.ts +++ b/web/next.config.ts @@ -76,6 +76,9 @@ const nextConfig: NextConfig = { compiler: { removeConsole: isDev ? false : { exclude: ['warn', 'error'] }, }, + experimental: { + turbopackFileSystemCacheForDev: false, + }, } export default withBundleAnalyzer(withMDX(nextConfig)) diff --git a/web/package.json b/web/package.json index d17cc92fe..e63a27650 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.11.4", + "version": "1.12.1", "private": true, "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a", "imports": { @@ -22,6 +22,9 @@ "and_uc >= 15.5", "and_qq >= 14.9" ], + "engines": { + "node": ">=24" + }, "scripts": { "dev": "next dev", "dev:inspect": "next dev --inspect", @@ -116,7 +119,6 @@ "ky": "1.12.0", "lamejs": "1.2.1", "lexical": "0.38.2", - "line-clamp": "1.0.0", "mermaid": "11.11.0", "mime": "4.1.0", "mitt": "3.0.1", @@ -163,13 +165,13 @@ "zustand": "5.0.9" }, "devDependencies": { - "@antfu/eslint-config": "7.0.1", + "@antfu/eslint-config": "7.2.0", "@chromatic-com/storybook": "5.0.0", - "@eslint-react/eslint-plugin": "2.7.0", + "@eslint-react/eslint-plugin": "2.8.1", "@mdx-js/loader": "3.1.1", "@mdx-js/react": "3.1.1", "@next/bundle-analyzer": "16.1.5", - "@next/eslint-plugin-next": "16.1.5", + "@next/eslint-plugin-next": "16.1.6", "@next/mdx": "16.1.5", "@rgrove/parse-xml": "4.2.0", "@serwist/turbopack": "9.5.0", @@ -179,7 +181,7 @@ "@storybook/addon-themes": "10.2.0", "@storybook/nextjs-vite": "10.2.0", "@storybook/react": "10.2.0", - "@tanstack/eslint-plugin-query": "5.91.2", + "@tanstack/eslint-plugin-query": "5.91.3", "@tanstack/react-devtools": "0.9.2", "@tanstack/react-form-devtools": "0.2.12", "@tanstack/react-query-devtools": "5.90.2", @@ -187,9 +189,9 @@ "@testing-library/jest-dom": "6.9.1", "@testing-library/react": "16.3.0", "@testing-library/user-event": "14.6.1", - "@tsslint/cli": "3.0.1", - "@tsslint/compat-eslint": "3.0.1", - "@tsslint/config": "3.0.1", + "@tsslint/cli": "3.0.2", + "@tsslint/compat-eslint": "3.0.2", + "@tsslint/config": "3.0.2", "@types/js-cookie": "3.0.6", "@types/js-yaml": "4.0.9", "@types/negotiator": "0.6.4", @@ -204,7 +206,7 @@ "@types/semver": "7.7.1", "@types/sortablejs": "1.15.8", "@types/uuid": "10.0.0", - "@typescript-eslint/parser": "8.53.0", + "@typescript-eslint/parser": "8.54.0", "@typescript/native-preview": "7.0.0-dev.20251209.1", "@vitejs/plugin-react": "5.1.2", "@vitest/coverage-v8": "4.0.17", @@ -215,8 +217,8 @@ "eslint": "9.39.2", "eslint-plugin-react-hooks": "7.0.1", "eslint-plugin-react-refresh": "0.4.26", - "eslint-plugin-sonarjs": "3.0.5", - "eslint-plugin-storybook": "10.2.0", + "eslint-plugin-sonarjs": "3.0.6", + "eslint-plugin-storybook": "10.2.1", "eslint-plugin-tailwindcss": "3.18.2", "husky": "9.1.7", "jsdom": "27.3.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 884debb67..2b283b83e 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -240,9 +240,6 @@ importers: lexical: specifier: 0.38.2 version: 0.38.2 - line-clamp: - specifier: 1.0.0 - version: 1.0.0 mermaid: specifier: 11.11.0 version: 11.11.0 @@ -377,14 +374,14 @@ importers: version: 5.0.9(@types/react@19.2.9)(immer@11.1.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) devDependencies: '@antfu/eslint-config': - specifier: 7.0.1 - version: 7.0.1(@eslint-react/eslint-plugin@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.5)(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@9.39.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.26(eslint@9.39.2(jiti@1.21.7)))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.17) + specifier: 7.2.0 + version: 7.2.0(@eslint-react/eslint-plugin@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.6)(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@9.39.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.26(eslint@9.39.2(jiti@1.21.7)))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.17) '@chromatic-com/storybook': specifier: 5.0.0 version: 5.0.0(storybook@10.2.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@eslint-react/eslint-plugin': - specifier: 2.7.0 - version: 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + specifier: 2.8.1 + version: 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) '@mdx-js/loader': specifier: 3.1.1 version: 3.1.1(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) @@ -395,8 +392,8 @@ importers: specifier: 16.1.5 version: 16.1.5 '@next/eslint-plugin-next': - specifier: 16.1.5 - version: 16.1.5 + specifier: 16.1.6 + version: 16.1.6 '@next/mdx': specifier: 16.1.5 version: 16.1.5(@mdx-js/loader@3.1.1(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.9)(react@19.2.4)) @@ -425,8 +422,8 @@ importers: specifier: 10.2.0 version: 10.2.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.2.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) '@tanstack/eslint-plugin-query': - specifier: 5.91.2 - version: 5.91.2(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + specifier: 5.91.3 + version: 5.91.3(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) '@tanstack/react-devtools': specifier: 0.9.2 version: 0.9.2(@types/react-dom@19.2.3(@types/react@19.2.9))(@types/react@19.2.9)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(solid-js@1.9.11) @@ -449,14 +446,14 @@ importers: specifier: 14.6.1 version: 14.6.1(@testing-library/dom@10.4.1) '@tsslint/cli': - specifier: 3.0.1 - version: 3.0.1(@tsslint/compat-eslint@3.0.1(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) + specifier: 3.0.2 + version: 3.0.2(@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) '@tsslint/compat-eslint': - specifier: 3.0.1 - version: 3.0.1(jiti@1.21.7)(typescript@5.9.3) + specifier: 3.0.2 + version: 3.0.2(jiti@1.21.7)(typescript@5.9.3) '@tsslint/config': - specifier: 3.0.1 - version: 3.0.1(@tsslint/compat-eslint@3.0.1(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) + specifier: 3.0.2 + version: 3.0.2(@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) '@types/js-cookie': specifier: 3.0.6 version: 3.0.6 @@ -500,8 +497,8 @@ importers: specifier: 10.0.0 version: 10.0.0 '@typescript-eslint/parser': - specifier: 8.53.0 - version: 8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + specifier: 8.54.0 + version: 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) '@typescript/native-preview': specifier: 7.0.0-dev.20251209.1 version: 7.0.0-dev.20251209.1 @@ -533,11 +530,11 @@ importers: specifier: 0.4.26 version: 0.4.26(eslint@9.39.2(jiti@1.21.7)) eslint-plugin-sonarjs: - specifier: 3.0.5 - version: 3.0.5(eslint@9.39.2(jiti@1.21.7)) + specifier: 3.0.6 + version: 3.0.6(eslint@9.39.2(jiti@1.21.7)) eslint-plugin-storybook: - specifier: 10.2.0 - version: 10.2.0(eslint@9.39.2(jiti@1.21.7))(storybook@10.2.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + specifier: 10.2.1 + version: 10.2.1(eslint@9.39.2(jiti@1.21.7))(storybook@10.2.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) eslint-plugin-tailwindcss: specifier: 3.18.2 version: 3.18.2(tailwindcss@3.4.18(tsx@4.21.0)(yaml@2.8.2)) @@ -685,8 +682,8 @@ packages: '@amplitude/targeting@0.2.0': resolution: {integrity: sha512-/50ywTrC4hfcfJVBbh5DFbqMPPfaIOivZeb5Gb+OGM03QrA+lsUqdvtnKLNuWtceD4H6QQ2KFzPJ5aAJLyzVDA==} - '@antfu/eslint-config@7.0.1': - resolution: {integrity: sha512-QbCDrLPo2Bpn9/W5PnpGvUuD/EIKhiCmLBuIj9ylxeMvl47XSkXy3MZyinqUVsBJzk196B7BcJQByDZRr5TbZQ==} + '@antfu/eslint-config@7.2.0': + resolution: {integrity: sha512-I/GWDvkvUfp45VolhrMpOdkfBC69f6lstJi0BCSooylQZwH4OTJPkbXCkp4lKh9V4BeMrcO3G5iC+YIfY28/aA==} hasBin: true peerDependencies: '@eslint-react/eslint-plugin': ^2.0.1 @@ -1127,48 +1124,44 @@ packages: peerDependencies: eslint: ^6.0.0 || ^7.0.0 || >=8.0.0 - '@eslint-community/regexpp@4.12.1': - resolution: {integrity: sha512-CCZCDJuduB9OUkFkY2IgppNZMi2lBQgD2qzwXkEia16cge2pijY/aXi96CJMquDMn3nJdlPV1A5KrJEXwfLNzQ==} - engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} - '@eslint-community/regexpp@4.12.2': resolution: {integrity: sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==} engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} - '@eslint-react/ast@2.7.0': - resolution: {integrity: sha512-GGrvel9+kR++wK7orcS2kS1xtHpY0o0rh6hbHbiGVWsSiZmg0X8jZfK1nSf8a3FLJR2WLtQlUsrrtJ4hObaqeQ==} + '@eslint-react/ast@2.8.1': + resolution: {integrity: sha512-4D442lxeFvvd9PMvBbA621rfz/Ne8Kod8RW0/FLKO0vx+IOxm74pP6be1uU56rqL9TvoIHxjclBjfgXplEF+Yw==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@eslint-react/core@2.7.0': - resolution: {integrity: sha512-xeRSnzLI35Msr2lnGjH4vxgOwohODy2FaXRmXUS1IpmMRDp1Ct+7I3SDknfeW/YExjGZXvpxR0uD2P9dSjU6NA==} + '@eslint-react/core@2.8.1': + resolution: {integrity: sha512-zF73p8blyuX+zrfgyTtpKesichYzK+G54TEjFWtzagWIbnqQjtVscebL/eGep72oWzAOd5B04ACBvJ2hW4fp5g==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@eslint-react/eff@2.7.0': - resolution: {integrity: sha512-+uUI53LkS6EDU0ysVUeM2SdyZQwt/xEfh4OSJ0JMLT8fJbseZY8c0hyev7X5arifcLs0PVPHwUP1IPcNhSLOFw==} + '@eslint-react/eff@2.8.1': + resolution: {integrity: sha512-ZASOs8oTZJSiu1giue7V87GEKQvlKLfGfLppal6Rl+aKnfIEz+vartmjpH12pkFQZ9ESRyHzYbU533S6pEDoNg==} engines: {node: '>=20.19.0'} - '@eslint-react/eslint-plugin@2.7.0': - resolution: {integrity: sha512-Bog14dOrsG/jBA9B8URZPJMI6dZuEwqHdkPcTuIkJe92EjFj8NwyziNGFXKY3j7o9AU9ILCBbjfC4JFq56lwjQ==} + '@eslint-react/eslint-plugin@2.8.1': + resolution: {integrity: sha512-ob+SSDnTPnA5dhiWWJLfyHRLEzWnjilCsohgo5s9PPKF5b5bjxG+c/rwqhQwT3M9Ey83mGNdkrLzt00SOfr4pw==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@eslint-react/shared@2.7.0': - resolution: {integrity: sha512-/lF5uiGYd+XIfO5t2YMC5RdbQ9lxLkxfL4icZgrbiJIPndirAKjFNl1cdXd+C/qqRCYDACrTPqI8HEL1T4N1Iw==} + '@eslint-react/shared@2.8.1': + resolution: {integrity: sha512-NDmJBiMiPDXR6qeZzYOtiILHxWjYwBHxquQ/bMQkWcWK+1qF5LeD8UTRcWtBpZoMPi3sNBWwR3k2Sc5HWZpJ7g==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@eslint-react/var@2.7.0': - resolution: {integrity: sha512-EFztHstOAYYCrFFNUOPZ7+J3o/X/zawqPKgLL7b5/271rhL6/DMxUmTcKtJIHO7hCdFPMcGT+vPxe+omq62Ukg==} + '@eslint-react/var@2.8.1': + resolution: {integrity: sha512-iHIdEBz6kgW4dEFdhEjpy9SEQ6+d4RYg+WBzHg5J5ktT2xSQFi77Dq6Wtemik6QvvAPnYLRseQxgW+m+1rQlfA==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 @@ -1199,6 +1192,10 @@ packages: resolution: {integrity: sha512-gBrxN88gOIf3R7ja5K9slwNayVcZgK6SOUORm2uBzTeIEfeVaIhOpCtTox3P6R7o2jLFwLFTLnC7kU/RGcYEgw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/config-helpers@0.5.1': + resolution: {integrity: sha512-QN8067dXsXAl9HIvqws7STEviheRFojX3zek5OpC84oBxDGqizW9731ByF/ASxqQihbWrVDdZXS+Ihnsckm9dg==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} + '@eslint/core@0.14.0': resolution: {integrity: sha512-qIbV0/JZr7iSDjqAc60IqbLdsj9GDt16xQtWD+B78d/HAlvysGdZZ6rpJHGAc2T0FQx1X6thsSPdnoiGKdNtdg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} @@ -1765,8 +1762,8 @@ packages: '@next/env@16.1.5': resolution: {integrity: sha512-CRSCPJiSZoi4Pn69RYBDI9R7YK2g59vLexPQFXY0eyw+ILevIenCywzg+DqmlBik9zszEnw2HLFOUlLAcJbL7g==} - '@next/eslint-plugin-next@16.1.5': - resolution: {integrity: sha512-gUWcEsOl+1W7XakmouClcJ0TNFCkblvDUho31wulbDY9na0C6mGtBTSXGRU5GXJY65GjGj0zNaCD/GaBp888Mg==} + '@next/eslint-plugin-next@16.1.6': + resolution: {integrity: sha512-/Qq3PTagA6+nYVfryAtQ7/9FEr/6YVyvOtl6rZnGsbReGLf0jZU6gkpr1FuChAQpvV46a78p4cmHOVP8mbfSMQ==} '@next/mdx@16.1.5': resolution: {integrity: sha512-TYzfGfZiXtf6HXZpqJoKq+2DRB1FjY9BR1HWhfl7WoSW/BAEr6X+WmdrdrCtqNpkY8VSoWHVWP0KNbyTqY7ZTA==} @@ -2927,8 +2924,8 @@ packages: peerDependencies: solid-js: 1.9.11 - '@tanstack/eslint-plugin-query@5.91.2': - resolution: {integrity: sha512-UPeWKl/Acu1IuuHJlsN+eITUHqAaa9/04geHHPedY8siVarSaWprY0SVMKrkpKfk5ehRT7+/MZ5QwWuEtkWrFw==} + '@tanstack/eslint-plugin-query@5.91.3': + resolution: {integrity: sha512-5GMGZMYFK9dOvjpdedjJs4hU40EdPuO2AjzObQzP7eOSsikunCfrXaU3oNGXSsvoU9ve1Z1xQZZuDyPi0C1M7Q==} peerDependencies: eslint: ^8.57.0 || ^9.0.0 @@ -3034,18 +3031,18 @@ packages: peerDependencies: '@testing-library/dom': '>=7.21.4' - '@tsslint/cli@3.0.1': - resolution: {integrity: sha512-y5yzMFl6sKQNsomuGInmFzMiKW37xxDcJauHnPqYoCWL8LldNLnaUOBqx0illfNZ0FDAiSuV/oshC/NG8/F2Tw==} + '@tsslint/cli@3.0.2': + resolution: {integrity: sha512-8lyZcDEs86zitz0wZ5QRdswY6xGz8j+WL11baN4rlpwahtPgYatujpYV5gpoKeyMAyerlNTdQh6u2LUJLoLNyQ==} engines: {node: '>=22.6.0'} hasBin: true peerDependencies: typescript: '*' - '@tsslint/compat-eslint@3.0.1': - resolution: {integrity: sha512-cojBaB1C9RxWjDfCvLBhbffshyizb+Cf1Os9NXHuzyQOPvU1IwYPW5Sxo1RU19pCOE9/TvQcuxgnGfwbkk/Dig==} + '@tsslint/compat-eslint@3.0.2': + resolution: {integrity: sha512-2TzSJPybCEfU/kHNi9UybwI//A7Fe14CwqmNuJ4fR4WYGpfIclXqfDJwsn5U1NzrWbHjWzRSntJITQPNw1SCNA==} - '@tsslint/config@3.0.1': - resolution: {integrity: sha512-1S8YYLrZE22xfH3GtDXRO7YzkeQj9+FjoxaWhYQsjWDU82HHeSRWq5d2UzPSN/ac6WFmFq8yApXIGylfvrG6MA==} + '@tsslint/config@3.0.2': + resolution: {integrity: sha512-oHzteAwL6NHVrLzJnrpqMwewEFOydhDH228weO4wkHW8SwvE4oVV5qrKmjwL69ClYt5Le3y2aGDzGou+GuTbKg==} engines: {node: '>=22.6.0'} hasBin: true peerDependencies: @@ -3057,12 +3054,12 @@ packages: tsl: optional: true - '@tsslint/core@3.0.1': - resolution: {integrity: sha512-8FEczJ20hdpmEH5vm272hS3QAycsk5574yZT6VMS8TUK8kNY4qoRKY/gdOY0nYNYWZrRPs+6dr1TmEVPBZjlvw==} + '@tsslint/core@3.0.2': + resolution: {integrity: sha512-Cu50e9vBojEMQjbqMoshkgLSoBj1BKbbmhSvzgbo07TiQ1wrOblZjvhU8ygB1fAIIHgU4laExX3pLU5OOeeR9g==} engines: {node: '>=22.6.0'} - '@tsslint/types@3.0.1': - resolution: {integrity: sha512-JPK/+tSJ2hPTwgN173fkenPEnAI2CD0r0FDJ23PfftTc0NM449ZiAFHvs1KuPUOjAvBFIo5BsLr7Kxc1Ekdgtw==} + '@tsslint/types@3.0.2': + resolution: {integrity: sha512-RbF3TIxu/YQwRpYrH5j2EL3ff4+Lr2SSmwCJmPJfi832F0hpgJj6xB9xKEorrUj0ZaTHE1QOr5SOMe5B6Qv+2Q==} '@tybys/wasm-util@0.10.1': resolution: {integrity: sha512-9tTaPJLSiejZKx+Bmog4uSubteqTvFrVrURwkmHixBo0G4seD0zUxp98E1DzUBJxLQ3NPwXrGKDiVjwx/DpPsg==} @@ -3301,41 +3298,41 @@ packages: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/parser@8.53.0': - resolution: {integrity: sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg==} + '@typescript-eslint/parser@8.54.0': + resolution: {integrity: sha512-BtE0k6cjwjLZoZixN0t5AKP0kSzlGu7FctRXYuPAm//aaiZhmfq1JwdYpYr1brzEspYyFeF+8XF5j2VK6oalrA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/project-service@8.53.0': - resolution: {integrity: sha512-Bl6Gdr7NqkqIP5yP9z1JU///Nmes4Eose6L1HwpuVHwScgDPPuEWbUVhvlZmb8hy0vX9syLk5EGNL700WcBlbg==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/project-service@8.53.1': resolution: {integrity: sha512-WYC4FB5Ra0xidsmlPb+1SsnaSKPmS3gsjIARwbEkHkoWloQmuzcfypljaJcR78uyLA1h8sHdWWPHSLDI+MtNog==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/scope-manager@8.53.0': - resolution: {integrity: sha512-kWNj3l01eOGSdVBnfAF2K1BTh06WS0Yet6JUgb9Cmkqaz3Jlu0fdVUjj9UI8gPidBWSMqDIglmEXifSgDT/D0g==} + '@typescript-eslint/project-service@8.54.0': + resolution: {integrity: sha512-YPf+rvJ1s7MyiWM4uTRhE4DvBXrEV+d8oC3P9Y2eT7S+HBS0clybdMIPnhiATi9vZOYDc7OQ1L/i6ga6NFYK/g==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + typescript: '>=4.8.4 <6.0.0' '@typescript-eslint/scope-manager@8.53.1': resolution: {integrity: sha512-Lu23yw1uJMFY8cUeq7JlrizAgeQvWugNQzJp8C3x8Eo5Jw5Q2ykMdiiTB9vBVOOUBysMzmRRmUfwFrZuI2C4SQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/tsconfig-utils@8.53.0': - resolution: {integrity: sha512-K6Sc0R5GIG6dNoPdOooQ+KtvT5KCKAvTcY8h2rIuul19vxH5OTQk7ArKkd4yTzkw66WnNY0kPPzzcmWA+XRmiA==} + '@typescript-eslint/scope-manager@8.54.0': + resolution: {integrity: sha512-27rYVQku26j/PbHYcVfRPonmOlVI6gihHtXFbTdB5sb6qA0wdAQAbyXFVarQ5t4HRojIz64IV90YtsjQSSGlQg==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@typescript-eslint/tsconfig-utils@8.53.1': + resolution: {integrity: sha512-qfvLXS6F6b1y43pnf0pPbXJ+YoXIC7HKg0UGZ27uMIemKMKA6XH2DTxsEDdpdN29D+vHV07x/pnlPNVLhdhWiA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/tsconfig-utils@8.53.1': - resolution: {integrity: sha512-qfvLXS6F6b1y43pnf0pPbXJ+YoXIC7HKg0UGZ27uMIemKMKA6XH2DTxsEDdpdN29D+vHV07x/pnlPNVLhdhWiA==} + '@typescript-eslint/tsconfig-utils@8.54.0': + resolution: {integrity: sha512-dRgOyT2hPk/JwxNMZDsIXDgyl9axdJI3ogZ2XWhBPsnZUv+hPesa5iuhdYt2gzwA9t8RE5ytOJ6xB0moV0Ujvw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' @@ -3347,22 +3344,29 @@ packages: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/types@8.53.0': - resolution: {integrity: sha512-Bmh9KX31Vlxa13+PqPvt4RzKRN1XORYSLlAE+sO1i28NkisGbTtSLFVB3l7PWdHtR3E0mVMuC7JilWJ99m2HxQ==} + '@typescript-eslint/type-utils@8.54.0': + resolution: {integrity: sha512-hiLguxJWHjjwL6xMBwD903ciAwd7DmK30Y9Axs/etOkftC3ZNN9K44IuRD/EB08amu+Zw6W37x9RecLkOo3pMA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + eslint: ^8.57.0 || ^9.0.0 + typescript: '>=4.8.4 <6.0.0' '@typescript-eslint/types@8.53.1': resolution: {integrity: sha512-jr/swrr2aRmUAUjW5/zQHbMaui//vQlsZcJKijZf3M26bnmLj8LyZUpj8/Rd6uzaek06OWsqdofN/Thenm5O8A==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/typescript-estree@8.53.0': - resolution: {integrity: sha512-pw0c0Gdo7Z4xOG987u3nJ8akL9093yEEKv8QTJ+Bhkghj1xyj8cgPaavlr9rq8h7+s6plUJ4QJYw2gCZodqmGw==} + '@typescript-eslint/types@8.54.0': + resolution: {integrity: sha512-PDUI9R1BVjqu7AUDsRBbKMtwmjWcn4J3le+5LpcFgWULN3LvHC5rkc9gCVxbrsrGmO1jfPybN5s6h4Jy+OnkAA==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@typescript-eslint/typescript-estree@8.53.1': + resolution: {integrity: sha512-RGlVipGhQAG4GxV1s34O91cxQ/vWiHJTDHbXRr0li2q/BGg3RR/7NM8QDWgkEgrwQYCvmJV9ichIwyoKCQ+DTg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/typescript-estree@8.53.1': - resolution: {integrity: sha512-RGlVipGhQAG4GxV1s34O91cxQ/vWiHJTDHbXRr0li2q/BGg3RR/7NM8QDWgkEgrwQYCvmJV9ichIwyoKCQ+DTg==} + '@typescript-eslint/typescript-estree@8.54.0': + resolution: {integrity: sha512-BUwcskRaPvTk6fzVWgDPdUndLjB87KYDrN5EYGetnktoeAvPtO4ONHlAZDnj5VFnUANg0Sjm7j4usBlnoVMHwA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' @@ -3374,14 +3378,21 @@ packages: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/visitor-keys@8.53.0': - resolution: {integrity: sha512-LZ2NqIHFhvFwxG0qZeLL9DvdNAHPGCY5dIRwBhyYeU+LfLhcStE1ImjsuTG/WaVh3XysGaeLW8Rqq7cGkPCFvw==} + '@typescript-eslint/utils@8.54.0': + resolution: {integrity: sha512-9Cnda8GS57AQakvRyG0PTejJNlA2xhvyNtEVIMlDWOOeEyBkYWhGPnfrIAnqxLMTSTo6q8g12XVjjev5l1NvMA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + eslint: ^8.57.0 || ^9.0.0 + typescript: '>=4.8.4 <6.0.0' '@typescript-eslint/visitor-keys@8.53.1': resolution: {integrity: sha512-oy+wV7xDKFPRyNggmXuZQSBzvoLnpmJs+GhzRhPjrxl2b/jIlyjVokzm47CZCDUdXKr2zd7ZLodPfOBpOPyPlg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@typescript-eslint/visitor-keys@8.54.0': + resolution: {integrity: sha512-VFlhGSl4opC0bprJiItPQ1RfUhGDIBokcPwaFH4yiBCaNPeld/9VeXbiPO1cLyorQi1G1vL+ecBk1x8o1axORA==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@typescript/native-preview-darwin-arm64@7.0.0-dev.20251209.1': resolution: {integrity: sha512-F1cnYi+ZeinYQnaTQKKIsbuoq8vip5iepBkSZXlB8PjbG62LW1edUdktd/nVEc+Q+SEysSQ3jRdk9eU766s5iw==} cpu: [arm64] @@ -4312,8 +4323,14 @@ packages: resolution: {integrity: sha512-k1gCAXAsNgLwEL+Y8Wvl+M6oEFj5bgazfZULpS5CneoPPXRaCCW7dm+q21Ky2VEE5X+VeRDBVg1Pcvvsr4TtNQ==} engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0} +<<<<<<< HEAD dingtalk-jsapi@3.2.5: resolution: {integrity: sha512-GHtDTmilJQhr07GNarjlzhvgUkPWc0+52zbN2ToW+JzkydaOwmhiJCTO42+BI+onAlhdfLUbtUnGsjQNDTrM1w==} +======= + diff-sequences@29.6.3: + resolution: {integrity: sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==} + engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} +>>>>>>> cd03e0a9ef7f2383853ace444e3aefe4fac05cde dlv@1.1.3: resolution: {integrity: sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==} @@ -4461,8 +4478,8 @@ packages: peerDependencies: eslint: ^9.5.0 - eslint-flat-config-utils@2.1.4: - resolution: {integrity: sha512-bEnmU5gqzS+4O+id9vrbP43vByjF+8KOs+QuuV4OlqAuXmnRW2zfI/Rza1fQvdihQ5h4DUo0NqFAiViD4mSrzQ==} + eslint-flat-config-utils@3.0.0: + resolution: {integrity: sha512-bzTam/pSnPANR0GUz4g7lo4fyzlQZwuz/h8ytsSS4w59N/JlXH/l7jmyNVBLxPz3B9/9ntz5ZLevGpazyDXJQQ==} eslint-json-compat-utils@0.2.1: resolution: {integrity: sha512-YzEodbDyW8DX8bImKhAcCeu/L31Dd/70Bidx2Qex9OFUtgzXLqtfWL4Hr5fM/aCCB8QUZLuJur0S9k6UfgFkfg==} @@ -4535,15 +4552,15 @@ packages: peerDependencies: eslint: ^9.0.0 - eslint-plugin-react-dom@2.7.0: - resolution: {integrity: sha512-9dvpfaAG3dC14jkDx5c9yXK9mQkYvxAUphQYfzorCntumQi5iOPsWNhITO+M1P+uIEpoc4HwuWkX42E/395AGQ==} + eslint-plugin-react-dom@2.8.1: + resolution: {integrity: sha512-VAVs3cp/0XTxdjTeLePtZVadj+om+N1VNVy7hyzSPACfh5ncAicC0zOIc5MB15KUWCj8PoG/ZnVny0YqeubgRg==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - eslint-plugin-react-hooks-extra@2.7.0: - resolution: {integrity: sha512-pvjuFvUJkmmHLRjWgJcuRKI+UUq8DddyVU5PrMJY2G3LTYewr4kMHRGaFQ6qg+mbVZWovfxy+VjZjJ8PTfJTDg==} + eslint-plugin-react-hooks-extra@2.8.1: + resolution: {integrity: sha512-YeZLGzcib6UxlY7Gf+3zz8Mfl7u+OoVj3MukGaTuU6zkm1XQMI8/k4o16bKHuWtUauhn7Udl1bLAWfLgQM5UFw==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 @@ -4555,8 +4572,8 @@ packages: peerDependencies: eslint: ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0 || ^9.0.0 - eslint-plugin-react-naming-convention@2.7.0: - resolution: {integrity: sha512-BENL2tUVW/PSpFjLyfS0WloG5Buh76rvBM1hG/dCEyWDpHA6s4oJpF2Th9J92eKfim48/uprIPkKCB520Ev2nQ==} + eslint-plugin-react-naming-convention@2.8.1: + resolution: {integrity: sha512-fVj+hSzIe2I6HyPTf1nccMBXq72c4jbM3gk0T+szo/wewEF8/LgenjfquJoxHPpheb1fujFgdlo5HBhsilAX7Q==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 @@ -4567,15 +4584,15 @@ packages: peerDependencies: eslint: '>=8.40' - eslint-plugin-react-web-api@2.7.0: - resolution: {integrity: sha512-vIuYyHbn2H337YZR8tKqUbzSNAiH6+9jk3atQBEgISJT0NTuwd80nhEPm3oPHfbgB3Sc4+rEhchVTnG+4BsFfg==} + eslint-plugin-react-web-api@2.8.1: + resolution: {integrity: sha512-NYsZKW1aJZ2XZuYTPzbwYLShvGcuXKRV/5TW61VO56gik/btil4Snt5UtyxshHbvT/zXx/Z+QsHul51/XM4/Qw==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: '>=4.8.4 <6.0.0' - eslint-plugin-react-x@2.7.0: - resolution: {integrity: sha512-/za228LsbKt1OlZ2XxP3R4xouG0rXeeuLyEnpHfKsAcY0mKPklempmQ5s0E9+SqcpQ/Jd+O4Jg9/30RU+vCqfw==} + eslint-plugin-react-x@2.8.1: + resolution: {integrity: sha512-4IpCMrsb63AVEa9diOApIm+T3wUGIzK+EB5vyYocO31YYPJ16+R7Fh4lV3S3fOuX1+aQ+Ad4SE0cYuZ2pF2Tlg==} engines: {node: '>=20.19.0'} peerDependencies: eslint: ^8.57.0 || ^9.0.0 @@ -4587,16 +4604,16 @@ packages: peerDependencies: eslint: '>=8.44.0' - eslint-plugin-sonarjs@3.0.5: - resolution: {integrity: sha512-dI62Ff3zMezUToi161hs2i1HX1ie8Ia2hO0jtNBfdgRBicAG4ydy2WPt0rMTrAe3ZrlqhpAO3w1jcQEdneYoFA==} + eslint-plugin-sonarjs@3.0.6: + resolution: {integrity: sha512-3mVUqsAUSylGfkJMj2v0aC2Cu/eUunDLm+XMjLf0uLjAZao205NWF3g6EXxcCAFO+rCZiQ6Or1WQkUcU9/sKFQ==} peerDependencies: eslint: ^8.0.0 || ^9.0.0 - eslint-plugin-storybook@10.2.0: - resolution: {integrity: sha512-OtQJ153FOusr8bIMzccjkfMFJEex/3NFx0iXZ+UaeQ0WXearQ+37EGgBay3onkFElyu8AySggq/fdTknPAEvPA==} + eslint-plugin-storybook@10.2.1: + resolution: {integrity: sha512-5+V+dlzTuZfNKUD8hPbLvCVtggcWfI2lDGTpiq0AENrHeAgcztj17wwDva96lbg/sAG20uX71l8HQo3s/GmpHw==} peerDependencies: eslint: '>=8' - storybook: ^10.2.0 + storybook: ^10.2.1 eslint-plugin-tailwindcss@3.18.2: resolution: {integrity: sha512-QbkMLDC/OkkjFQ1iz/5jkMdHfiMu/uwujUHLAJK5iwNHD8RTxVTlsUezE0toTZ6VhybNBsk+gYGPDq2agfeRNA==} @@ -4639,11 +4656,11 @@ packages: '@typescript-eslint/parser': optional: true - eslint-plugin-yml@1.19.1: - resolution: {integrity: sha512-bYkOxyEiXh9WxUhVYPELdSHxGG5pOjCSeJOVkfdIyj6tuiHDxrES2WAW1dBxn3iaZQey57XflwLtCYRcNPOiOg==} - engines: {node: ^14.17.0 || >=16.0.0} + eslint-plugin-yml@3.0.0: + resolution: {integrity: sha512-kuAW6o3hlFHyF5p7TLon+AtvNWnsvRrb88pqywGMSCEqAP5d1gOMvNGgWLVlKHqmx5RbFhQLcxFDGmS4IU9DwA==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24.0.0} peerDependencies: - eslint: '>=6.0.0' + eslint: '>=9.38.0' eslint-processor-vue-blocks@2.0.0: resolution: {integrity: sha512-u4W0CJwGoWY3bjXAuFpc/b6eK3NQEI8MoeW7ritKj3G3z/WtHrKjkqf+wk8mPEy5rlMGS+k6AZYOw2XBoN/02Q==} @@ -5468,9 +5485,6 @@ packages: resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==} engines: {node: '>=14'} - line-clamp@1.0.0: - resolution: {integrity: sha512-dCDlvMj572RIRBQ3x9aIX0DTdt2St1bMdpi64jVTAi5vqBck7wf+J97//+J7+pS80rFJaYa8HiyXCTp0flpnBA==} - lines-and-columns@1.2.4: resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} @@ -6642,11 +6656,6 @@ packages: resolution: {integrity: sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==} hasBin: true - semver@7.7.2: - resolution: {integrity: sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==} - engines: {node: '>=10'} - hasBin: true - semver@7.7.3: resolution: {integrity: sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==} engines: {node: '>=10'} @@ -7504,10 +7513,6 @@ packages: yallist@3.1.1: resolution: {integrity: sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==} - yaml-eslint-parser@1.3.2: - resolution: {integrity: sha512-odxVsHAkZYYglR30aPYRY4nUGJnoJ2y1ww2HDvZALo0BDETv9kWbi16J52eHs+PWRNmF4ub6nZqfVOeesOvntg==} - engines: {node: ^14.17.0 || >=16.0.0} - yaml-eslint-parser@2.0.0: resolution: {integrity: sha512-h0uDm97wvT2bokfwwTmY6kJ1hp6YDFL0nRHwNKz8s/VD1FH/vvZjAKoMUE+un0eaYBSG7/c6h+lJTP+31tjgTw==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} @@ -7742,21 +7747,21 @@ snapshots: idb: 8.0.3 tslib: 2.8.1 - '@antfu/eslint-config@7.0.1(@eslint-react/eslint-plugin@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.5)(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@9.39.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.26(eslint@9.39.2(jiti@1.21.7)))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.17)': + '@antfu/eslint-config@7.2.0(@eslint-react/eslint-plugin@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.6)(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@9.39.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.26(eslint@9.39.2(jiti@1.21.7)))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.17)': dependencies: '@antfu/install-pkg': 1.1.0 '@clack/prompts': 0.11.0 '@eslint-community/eslint-plugin-eslint-comments': 4.6.0(eslint@9.39.2(jiti@1.21.7)) '@eslint/markdown': 7.5.1 '@stylistic/eslint-plugin': 5.7.1(eslint@9.39.2(jiti@1.21.7)) - '@typescript-eslint/eslint-plugin': 8.53.1(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/parser': 8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/eslint-plugin': 8.53.1(@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/parser': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) '@vitest/eslint-plugin': 1.6.6(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.17) ansis: 4.2.0 cac: 6.7.14 eslint: 9.39.2(jiti@1.21.7) eslint-config-flat-gitignore: 2.1.0(eslint@9.39.2(jiti@1.21.7)) - eslint-flat-config-utils: 2.1.4 + eslint-flat-config-utils: 3.0.0 eslint-merge-processors: 2.0.0(eslint@9.39.2(jiti@1.21.7)) eslint-plugin-antfu: 3.1.3(eslint@9.39.2(jiti@1.21.7)) eslint-plugin-command: 3.4.0(eslint@9.39.2(jiti@1.21.7)) @@ -7770,9 +7775,9 @@ snapshots: eslint-plugin-regexp: 2.10.0(eslint@9.39.2(jiti@1.21.7)) eslint-plugin-toml: 1.0.3(eslint@9.39.2(jiti@1.21.7)) eslint-plugin-unicorn: 62.0.0(eslint@9.39.2(jiti@1.21.7)) - eslint-plugin-unused-imports: 4.3.0(@typescript-eslint/eslint-plugin@8.53.1(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7)) - eslint-plugin-vue: 10.7.0(@stylistic/eslint-plugin@5.7.1(eslint@9.39.2(jiti@1.21.7)))(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(vue-eslint-parser@10.2.0(eslint@9.39.2(jiti@1.21.7))) - eslint-plugin-yml: 1.19.1(eslint@9.39.2(jiti@1.21.7)) + eslint-plugin-unused-imports: 4.3.0(@typescript-eslint/eslint-plugin@8.53.1(@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7)) + eslint-plugin-vue: 10.7.0(@stylistic/eslint-plugin@5.7.1(eslint@9.39.2(jiti@1.21.7)))(@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(vue-eslint-parser@10.2.0(eslint@9.39.2(jiti@1.21.7))) + eslint-plugin-yml: 3.0.0(eslint@9.39.2(jiti@1.21.7)) eslint-processor-vue-blocks: 2.0.0(@vue/compiler-sfc@3.5.27)(eslint@9.39.2(jiti@1.21.7)) globals: 17.1.0 jsonc-eslint-parser: 2.4.2 @@ -7780,10 +7785,10 @@ snapshots: parse-gitignore: 2.0.0 toml-eslint-parser: 1.0.3 vue-eslint-parser: 10.2.0(eslint@9.39.2(jiti@1.21.7)) - yaml-eslint-parser: 1.3.2 + yaml-eslint-parser: 2.0.0 optionalDependencies: - '@eslint-react/eslint-plugin': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@next/eslint-plugin-next': 16.1.5 + '@eslint-react/eslint-plugin': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@next/eslint-plugin-next': 16.1.6 eslint-plugin-react-hooks: 7.0.1(eslint@9.39.2(jiti@1.21.7)) eslint-plugin-react-refresh: 0.4.26(eslint@9.39.2(jiti@1.21.7)) transitivePeerDependencies: @@ -8183,63 +8188,60 @@ snapshots: eslint: 9.39.2(jiti@1.21.7) eslint-visitor-keys: 3.4.3 - '@eslint-community/regexpp@4.12.1': {} - '@eslint-community/regexpp@4.12.2': {} - '@eslint-react/ast@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/ast@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/eff': 2.7.0 - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/typescript-estree': 8.53.1(typescript@5.9.3) - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/typescript-estree': 8.54.0(typescript@5.9.3) + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) string-ts: 2.3.1 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@eslint-react/core@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/core@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/ast': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.7.0 - '@eslint-react/shared': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - birecord: 0.1.1 + '@eslint-react/ast': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@eslint-react/eff@2.7.0': {} + '@eslint-react/eff@2.8.1': {} - '@eslint-react/eslint-plugin@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/eslint-plugin@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/eff': 2.7.0 - '@eslint-react/shared': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/type-utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/type-utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) - eslint-plugin-react-dom: 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-hooks-extra: 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-naming-convention: 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-web-api: 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-x: 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-dom: 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-hooks-extra: 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-naming-convention: 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-web-api: 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-x: 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) ts-api-utils: 2.4.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@eslint-react/shared@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/shared@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/eff': 2.7.0 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 @@ -8247,13 +8249,14 @@ snapshots: transitivePeerDependencies: - supports-color - '@eslint-react/var@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/var@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/ast': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.7.0 - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 @@ -8288,6 +8291,10 @@ snapshots: dependencies: '@eslint/core': 0.17.0 + '@eslint/config-helpers@0.5.1': + dependencies: + '@eslint/core': 1.0.1 + '@eslint/core@0.14.0': dependencies: '@types/json-schema': 7.0.15 @@ -8937,7 +8944,7 @@ snapshots: '@next/env@16.1.5': {} - '@next/eslint-plugin-next@16.1.5': + '@next/eslint-plugin-next@16.1.6': dependencies: fast-glob: 3.3.1 @@ -9995,7 +10002,7 @@ snapshots: - csstype - utf-8-validate - '@tanstack/eslint-plugin-query@5.91.2(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@tanstack/eslint-plugin-query@5.91.3(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) @@ -10133,11 +10140,11 @@ snapshots: dependencies: '@testing-library/dom': 10.4.1 - '@tsslint/cli@3.0.1(@tsslint/compat-eslint@3.0.1(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3)': + '@tsslint/cli@3.0.2(@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3)': dependencies: '@clack/prompts': 0.8.2 - '@tsslint/config': 3.0.1(@tsslint/compat-eslint@3.0.1(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) - '@tsslint/core': 3.0.1 + '@tsslint/config': 3.0.2(@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) + '@tsslint/core': 3.0.2 '@volar/language-core': 2.4.27 '@volar/language-hub': 0.0.1 '@volar/typescript': 2.4.27 @@ -10147,32 +10154,32 @@ snapshots: - '@tsslint/compat-eslint' - tsl - '@tsslint/compat-eslint@3.0.1(jiti@1.21.7)(typescript@5.9.3)': + '@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3)': dependencies: - '@tsslint/types': 3.0.1 - '@typescript-eslint/parser': 8.53.0(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3) + '@tsslint/types': 3.0.2 + '@typescript-eslint/parser': 8.54.0(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3) eslint: 9.27.0(jiti@1.21.7) transitivePeerDependencies: - jiti - supports-color - typescript - '@tsslint/config@3.0.1(@tsslint/compat-eslint@3.0.1(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3)': + '@tsslint/config@3.0.2(@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3)': dependencies: - '@tsslint/types': 3.0.1 + '@tsslint/types': 3.0.2 minimatch: 10.1.1 ts-api-utils: 2.4.0(typescript@5.9.3) optionalDependencies: - '@tsslint/compat-eslint': 3.0.1(jiti@1.21.7)(typescript@5.9.3) + '@tsslint/compat-eslint': 3.0.2(jiti@1.21.7)(typescript@5.9.3) transitivePeerDependencies: - typescript - '@tsslint/core@3.0.1': + '@tsslint/core@3.0.2': dependencies: - '@tsslint/types': 3.0.1 + '@tsslint/types': 3.0.2 minimatch: 10.1.1 - '@tsslint/types@3.0.1': {} + '@tsslint/types@3.0.2': {} '@tybys/wasm-util@0.10.1': dependencies: @@ -10432,10 +10439,10 @@ snapshots: '@types/zen-observable@0.8.3': {} - '@typescript-eslint/eslint-plugin@8.53.1(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/eslint-plugin@8.53.1(@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: '@eslint-community/regexpp': 4.12.2 - '@typescript-eslint/parser': 8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/parser': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) '@typescript-eslint/scope-manager': 8.53.1 '@typescript-eslint/type-utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) @@ -10448,39 +10455,30 @@ snapshots: transitivePeerDependencies: - supports-color - '@typescript-eslint/parser@8.53.0(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/parser@8.54.0(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@typescript-eslint/scope-manager': 8.53.0 - '@typescript-eslint/types': 8.53.0 - '@typescript-eslint/typescript-estree': 8.53.0(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.53.0 + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/typescript-estree': 8.54.0(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.54.0 debug: 4.4.3 eslint: 9.27.0(jiti@1.21.7) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@typescript-eslint/scope-manager': 8.53.0 - '@typescript-eslint/types': 8.53.0 - '@typescript-eslint/typescript-estree': 8.53.0(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.53.0 + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/typescript-estree': 8.54.0(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.54.0 debug: 4.4.3 eslint: 9.39.2(jiti@1.21.7) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/project-service@8.53.0(typescript@5.9.3)': - dependencies: - '@typescript-eslint/tsconfig-utils': 8.53.1(typescript@5.9.3) - '@typescript-eslint/types': 8.53.1 - debug: 4.4.3 - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - '@typescript-eslint/project-service@8.53.1(typescript@5.9.3)': dependencies: '@typescript-eslint/tsconfig-utils': 8.53.1(typescript@5.9.3) @@ -10490,21 +10488,30 @@ snapshots: transitivePeerDependencies: - supports-color - '@typescript-eslint/scope-manager@8.53.0': + '@typescript-eslint/project-service@8.54.0(typescript@5.9.3)': dependencies: - '@typescript-eslint/types': 8.53.0 - '@typescript-eslint/visitor-keys': 8.53.0 + '@typescript-eslint/tsconfig-utils': 8.54.0(typescript@5.9.3) + '@typescript-eslint/types': 8.54.0 + debug: 4.4.3 + typescript: 5.9.3 + transitivePeerDependencies: + - supports-color '@typescript-eslint/scope-manager@8.53.1': dependencies: '@typescript-eslint/types': 8.53.1 '@typescript-eslint/visitor-keys': 8.53.1 - '@typescript-eslint/tsconfig-utils@8.53.0(typescript@5.9.3)': + '@typescript-eslint/scope-manager@8.54.0': + dependencies: + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/visitor-keys': 8.54.0 + + '@typescript-eslint/tsconfig-utils@8.53.1(typescript@5.9.3)': dependencies: typescript: 5.9.3 - '@typescript-eslint/tsconfig-utils@8.53.1(typescript@5.9.3)': + '@typescript-eslint/tsconfig-utils@8.54.0(typescript@5.9.3)': dependencies: typescript: 5.9.3 @@ -10520,16 +10527,28 @@ snapshots: transitivePeerDependencies: - supports-color - '@typescript-eslint/types@8.53.0': {} + '@typescript-eslint/type-utils@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': + dependencies: + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/typescript-estree': 8.54.0(typescript@5.9.3) + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + debug: 4.4.3 + eslint: 9.39.2(jiti@1.21.7) + ts-api-utils: 2.4.0(typescript@5.9.3) + typescript: 5.9.3 + transitivePeerDependencies: + - supports-color '@typescript-eslint/types@8.53.1': {} - '@typescript-eslint/typescript-estree@8.53.0(typescript@5.9.3)': + '@typescript-eslint/types@8.54.0': {} + + '@typescript-eslint/typescript-estree@8.53.1(typescript@5.9.3)': dependencies: - '@typescript-eslint/project-service': 8.53.0(typescript@5.9.3) - '@typescript-eslint/tsconfig-utils': 8.53.0(typescript@5.9.3) - '@typescript-eslint/types': 8.53.0 - '@typescript-eslint/visitor-keys': 8.53.0 + '@typescript-eslint/project-service': 8.53.1(typescript@5.9.3) + '@typescript-eslint/tsconfig-utils': 8.53.1(typescript@5.9.3) + '@typescript-eslint/types': 8.53.1 + '@typescript-eslint/visitor-keys': 8.53.1 debug: 4.4.3 minimatch: 9.0.5 semver: 7.7.3 @@ -10539,12 +10558,12 @@ snapshots: transitivePeerDependencies: - supports-color - '@typescript-eslint/typescript-estree@8.53.1(typescript@5.9.3)': + '@typescript-eslint/typescript-estree@8.54.0(typescript@5.9.3)': dependencies: - '@typescript-eslint/project-service': 8.53.1(typescript@5.9.3) - '@typescript-eslint/tsconfig-utils': 8.53.1(typescript@5.9.3) - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/visitor-keys': 8.53.1 + '@typescript-eslint/project-service': 8.54.0(typescript@5.9.3) + '@typescript-eslint/tsconfig-utils': 8.54.0(typescript@5.9.3) + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/visitor-keys': 8.54.0 debug: 4.4.3 minimatch: 9.0.5 semver: 7.7.3 @@ -10565,16 +10584,27 @@ snapshots: transitivePeerDependencies: - supports-color - '@typescript-eslint/visitor-keys@8.53.0': + '@typescript-eslint/utils@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@typescript-eslint/types': 8.53.0 - eslint-visitor-keys: 4.2.1 + '@eslint-community/eslint-utils': 4.9.1(eslint@9.39.2(jiti@1.21.7)) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/typescript-estree': 8.54.0(typescript@5.9.3) + eslint: 9.39.2(jiti@1.21.7) + typescript: 5.9.3 + transitivePeerDependencies: + - supports-color '@typescript-eslint/visitor-keys@8.53.1': dependencies: '@typescript-eslint/types': 8.53.1 eslint-visitor-keys: 4.2.1 + '@typescript-eslint/visitor-keys@8.54.0': + dependencies: + '@typescript-eslint/types': 8.54.0 + eslint-visitor-keys: 4.2.1 + '@typescript/native-preview-darwin-arm64@7.0.0-dev.20251209.1': optional: true @@ -11587,9 +11617,13 @@ snapshots: diff-sequences@27.5.1: {} +<<<<<<< HEAD dingtalk-jsapi@3.2.5: dependencies: promise-polyfill: 7.1.2 +======= + diff-sequences@29.6.3: {} +>>>>>>> cd03e0a9ef7f2383853ace444e3aefe4fac05cde dlv@1.1.3: {} @@ -11742,8 +11776,9 @@ snapshots: '@eslint/compat': 1.4.1(eslint@9.39.2(jiti@1.21.7)) eslint: 9.39.2(jiti@1.21.7) - eslint-flat-config-utils@2.1.4: + eslint-flat-config-utils@3.0.0: dependencies: + '@eslint/config-helpers': 0.5.1 pathe: 2.0.3 eslint-json-compat-utils@0.2.1(eslint@9.39.2(jiti@1.21.7))(jsonc-eslint-parser@2.4.2): @@ -11848,16 +11883,16 @@ snapshots: yaml: 2.8.2 yaml-eslint-parser: 2.0.0 - eslint-plugin-react-dom@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-dom@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.7.0 - '@eslint-react/shared': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) compare-versions: 6.1.1 eslint: 9.39.2(jiti@1.21.7) string-ts: 2.3.1 @@ -11866,17 +11901,17 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-react-hooks-extra@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-hooks-extra@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.7.0 - '@eslint-react/shared': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/type-utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/type-utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) string-ts: 2.3.1 ts-pattern: 5.9.0 @@ -11895,17 +11930,17 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-react-naming-convention@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-naming-convention@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.7.0 - '@eslint-react/shared': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/type-utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/type-utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) compare-versions: 6.1.1 eslint: 9.39.2(jiti@1.21.7) string-ts: 2.3.1 @@ -11918,16 +11953,17 @@ snapshots: dependencies: eslint: 9.39.2(jiti@1.21.7) - eslint-plugin-react-web-api@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-web-api@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.7.0 - '@eslint-react/shared': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + birecord: 0.1.1 eslint: 9.39.2(jiti@1.21.7) string-ts: 2.3.1 ts-pattern: 5.9.0 @@ -11935,17 +11971,17 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-react-x@2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-x@2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.7.0 - '@eslint-react/shared': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.7.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.53.1 - '@typescript-eslint/type-utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.53.1 - '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/eff': 2.8.1 + '@eslint-react/shared': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 2.8.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.54.0 + '@typescript-eslint/type-utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.54.0 + '@typescript-eslint/utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) compare-versions: 6.1.1 eslint: 9.39.2(jiti@1.21.7) is-immutable-type: 5.0.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) @@ -11967,21 +12003,21 @@ snapshots: regexp-ast-analysis: 0.7.1 scslre: 0.3.0 - eslint-plugin-sonarjs@3.0.5(eslint@9.39.2(jiti@1.21.7)): + eslint-plugin-sonarjs@3.0.6(eslint@9.39.2(jiti@1.21.7)): dependencies: - '@eslint-community/regexpp': 4.12.1 + '@eslint-community/regexpp': 4.12.2 builtin-modules: 3.3.0 bytes: 3.1.2 eslint: 9.39.2(jiti@1.21.7) functional-red-black-tree: 1.0.1 jsx-ast-utils-x: 0.1.0 lodash.merge: 4.6.2 - minimatch: 9.0.5 + minimatch: 10.1.1 scslre: 0.3.0 - semver: 7.7.2 + semver: 7.7.3 typescript: 5.9.3 - eslint-plugin-storybook@10.2.0(eslint@9.39.2(jiti@1.21.7))(storybook@10.2.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): + eslint-plugin-storybook@10.2.1(eslint@9.39.2(jiti@1.21.7))(storybook@10.2.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): dependencies: '@typescript-eslint/utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) @@ -12028,13 +12064,13 @@ snapshots: semver: 7.7.3 strip-indent: 4.1.1 - eslint-plugin-unused-imports@4.3.0(@typescript-eslint/eslint-plugin@8.53.1(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7)): + eslint-plugin-unused-imports@4.3.0(@typescript-eslint/eslint-plugin@8.53.1(@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7)): dependencies: eslint: 9.39.2(jiti@1.21.7) optionalDependencies: - '@typescript-eslint/eslint-plugin': 8.53.1(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/eslint-plugin': 8.53.1(@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-vue@10.7.0(@stylistic/eslint-plugin@5.7.1(eslint@9.39.2(jiti@1.21.7)))(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(vue-eslint-parser@10.2.0(eslint@9.39.2(jiti@1.21.7))): + eslint-plugin-vue@10.7.0(@stylistic/eslint-plugin@5.7.1(eslint@9.39.2(jiti@1.21.7)))(@typescript-eslint/parser@8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3))(eslint@9.39.2(jiti@1.21.7))(vue-eslint-parser@10.2.0(eslint@9.39.2(jiti@1.21.7))): dependencies: '@eslint-community/eslint-utils': 4.9.1(eslint@9.39.2(jiti@1.21.7)) eslint: 9.39.2(jiti@1.21.7) @@ -12046,17 +12082,18 @@ snapshots: xml-name-validator: 4.0.0 optionalDependencies: '@stylistic/eslint-plugin': 5.7.1(eslint@9.39.2(jiti@1.21.7)) - '@typescript-eslint/parser': 8.53.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/parser': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-yml@1.19.1(eslint@9.39.2(jiti@1.21.7)): + eslint-plugin-yml@3.0.0(eslint@9.39.2(jiti@1.21.7)): dependencies: + '@eslint/core': 1.0.1 + '@eslint/plugin-kit': 0.5.1 debug: 4.4.3 - diff-sequences: 27.5.1 - escape-string-regexp: 4.0.0 + diff-sequences: 29.6.3 + escape-string-regexp: 5.0.0 eslint: 9.39.2(jiti@1.21.7) - eslint-compat-utils: 0.6.5(eslint@9.39.2(jiti@1.21.7)) natural-compare: 1.4.0 - yaml-eslint-parser: 1.3.2 + yaml-eslint-parser: 2.0.0 transitivePeerDependencies: - supports-color @@ -12733,7 +12770,7 @@ snapshots: is-immutable-type@5.0.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@typescript-eslint/type-utils': 8.53.1(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/type-utils': 8.54.0(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) eslint: 9.39.2(jiti@1.21.7) ts-api-utils: 2.4.0(typescript@5.9.3) ts-declaration-location: 1.0.7(typescript@5.9.3) @@ -12966,8 +13003,6 @@ snapshots: lilconfig@3.1.3: {} - line-clamp@1.0.0: {} - lines-and-columns@1.2.4: {} lint-staged@15.5.2: @@ -14582,8 +14617,6 @@ snapshots: semver@6.3.1: {} - semver@7.7.2: {} - semver@7.7.3: {} serialize-javascript@6.0.2: @@ -15485,11 +15518,6 @@ snapshots: yallist@3.1.1: {} - yaml-eslint-parser@1.3.2: - dependencies: - eslint-visitor-keys: 3.4.3 - yaml: 2.8.2 - yaml-eslint-parser@2.0.0: dependencies: eslint-visitor-keys: 5.0.0 diff --git a/web/scripts/analyze-component.js b/web/scripts/analyze-component.js index b09301503..2fdff2f3d 100755 --- a/web/scripts/analyze-component.js +++ b/web/scripts/analyze-component.js @@ -337,7 +337,7 @@ Test file under review: ${testPath} Checklist (ensure every item is addressed in your review): -- Confirm the tests satisfy all requirements listed above and in web/testing/TESTING.md. +- Confirm the tests satisfy all requirements listed above and in web/docs/test.md. - Verify Arrange → Act → Assert structure, mocks, and cleanup follow project conventions. - Ensure all detected component features (state, effects, routing, API, events, etc.) are exercised, including edge cases and error paths. - Check coverage of prop variations, null/undefined inputs, and high-priority workflows implied by usage score. @@ -382,7 +382,7 @@ Examples: # Review existing test pnpm analyze-component app/components/base/button/index.tsx --review -For complete testing guidelines, see: web/testing/testing.md +For complete testing guidelines, see: web/docs/test.md `) } diff --git a/web/scripts/gen-doc-paths.ts b/web/scripts/gen-doc-paths.ts index f0393937c..03c3cdadd 100644 --- a/web/scripts/gen-doc-paths.ts +++ b/web/scripts/gen-doc-paths.ts @@ -282,6 +282,15 @@ function generateTypeDefinitions( } lines.push('') + + // Add UseDifyNodesPath helper type after UseDifyPath + if (section === 'use-dify') { + lines.push('// UseDify node paths (without prefix)') + // eslint-disable-next-line no-template-curly-in-string + lines.push('type ExtractNodesPath = T extends `/use-dify/nodes/${infer Path}` ? Path : never') + lines.push('export type UseDifyNodesPath = ExtractNodesPath') + lines.push('') + } } // Generate API reference type (English paths only) diff --git a/web/service/client.spec.ts b/web/service/client.spec.ts new file mode 100644 index 000000000..d8b46ad4b --- /dev/null +++ b/web/service/client.spec.ts @@ -0,0 +1,80 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +const loadGetBaseURL = async (isClientValue: boolean) => { + vi.resetModules() + vi.doMock('@/utils/client', () => ({ isClient: isClientValue })) + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + // eslint-disable-next-line next/no-assign-module-variable + const module = await import('./client') + warnSpy.mockClear() + return { getBaseURL: module.getBaseURL, warnSpy } +} + +// Scenario: base URL selection and warnings. +describe('getBaseURL', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + // Scenario: client environment uses window origin. + it('should use window origin when running on the client', async () => { + // Arrange + const { origin } = window.location + const { getBaseURL, warnSpy } = await loadGetBaseURL(true) + + // Act + const url = getBaseURL('/api') + + // Assert + expect(url.href).toBe(`${origin}/api`) + expect(warnSpy).not.toHaveBeenCalled() + }) + + // Scenario: server environment falls back to localhost with warning. + it('should fall back to localhost and warn on the server', async () => { + // Arrange + const { getBaseURL, warnSpy } = await loadGetBaseURL(false) + + // Act + const url = getBaseURL('/api') + + // Assert + expect(url.href).toBe('http://localhost/api') + expect(warnSpy).toHaveBeenCalledTimes(1) + expect(warnSpy).toHaveBeenCalledWith('Using localhost as base URL in server environment, please configure accordingly.') + }) + + // Scenario: non-http protocols surface warnings. + it('should warn when protocol is not http or https', async () => { + // Arrange + const { getBaseURL, warnSpy } = await loadGetBaseURL(true) + + // Act + const url = getBaseURL('localhost:5001/console/api') + + // Assert + expect(url.protocol).toBe('localhost:') + expect(url.href).toBe('localhost:5001/console/api') + expect(warnSpy).toHaveBeenCalledTimes(1) + expect(warnSpy).toHaveBeenCalledWith( + 'Unexpected protocol for API requests, expected http or https. Current protocol: localhost:. Please configure accordingly.', + ) + }) + + // Scenario: absolute http URLs are preserved. + it('should keep absolute http URLs intact', async () => { + // Arrange + const { getBaseURL, warnSpy } = await loadGetBaseURL(true) + + // Act + const url = getBaseURL('https://api.example.com/console/api') + + // Assert + expect(url.href).toBe('https://api.example.com/console/api') + expect(warnSpy).not.toHaveBeenCalled() + }) +}) diff --git a/web/service/client.ts b/web/service/client.ts index a2cf56453..eb6b1af81 100644 --- a/web/service/client.ts +++ b/web/service/client.ts @@ -13,6 +13,7 @@ import { consoleRouterContract, marketplaceRouterContract, } from '@/contract/router' +import { isClient } from '@/utils/client' import { request } from './base' // extend: CVE-2025-63387 跨域时 Cookie 可能为 None,用 Header 携带 JWT @@ -25,6 +26,31 @@ const getMarketplaceHeaders = () => new Headers({ 'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0', }) +function isURL(path: string) { + try { + // eslint-disable-next-line no-new + new URL(path) + return true + } + catch { + return false + } +} + +export function getBaseURL(path: string) { + const url = new URL(path, isURL(path) ? undefined : isClient ? window.location.origin : 'http://localhost') + + if (!isClient && !isURL(path)) { + console.warn('Using localhost as base URL in server environment, please configure accordingly.') + } + + if (url.protocol !== 'http:' && url.protocol !== 'https:') { + console.warn(`Unexpected protocol for API requests, expected http or https. Current protocol: ${url.protocol}. Please configure accordingly.`) + } + + return url +} + const getConsoleHeaders = () => { const h = new Headers() if (loginConfigToken) @@ -52,7 +78,7 @@ export const marketplaceClient: JsonifiedClient getConsoleHeaders(), fetch: (input, init) => { return request( diff --git a/web/types/doc-paths.ts b/web/types/doc-paths.ts index 7a74f0905..8f9524935 100644 --- a/web/types/doc-paths.ts +++ b/web/types/doc-paths.ts @@ -2,7 +2,7 @@ // DON NOT EDIT IT MANUALLY // // Generated from: https://raw.githubusercontent.com/langgenius/dify-docs/refs/heads/main/docs.json -// Generated at: 2026-01-21T07:24:02.413Z +// Generated at: 2026-01-30T09:14:29.304Z // Language prefixes export type DocLanguage = 'en' | 'zh' | 'ja' @@ -104,6 +104,10 @@ export type UseDifyPath = | '/use-dify/workspace/subscription-management' | '/use-dify/workspace/team-members-management' +// UseDify node paths (without prefix) +type ExtractNodesPath = T extends `/use-dify/nodes/${infer Path}` ? Path : never +export type UseDifyNodesPath = ExtractNodesPath + // SelfHost paths export type SelfHostPath = | '/self-host/advanced-deployments/local-source-code' diff --git a/web/utils/completion-params.spec.ts b/web/utils/completion-params.spec.ts index 0b691a0ba..e56957de8 100644 --- a/web/utils/completion-params.spec.ts +++ b/web/utils/completion-params.spec.ts @@ -21,7 +21,7 @@ describe('completion-params', () => { it('validates int type parameter within range', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 100 } const result = mergeValidCompletionParams(oldParams, rules) @@ -32,7 +32,7 @@ describe('completion-params', () => { it('removes int parameter below minimum', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 0 } const result = mergeValidCompletionParams(oldParams, rules) @@ -43,7 +43,7 @@ describe('completion-params', () => { it('removes int parameter above maximum', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 5000 } const result = mergeValidCompletionParams(oldParams, rules) @@ -54,7 +54,7 @@ describe('completion-params', () => { it('removes int parameter with invalid type', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 'not a number' as any } const result = mergeValidCompletionParams(oldParams, rules) @@ -184,7 +184,7 @@ describe('completion-params', () => { it('handles multiple parameters with mixed validity', () => { const rules: ModelParameterRule[] = [ { name: 'temperature', type: 'float', min: 0, max: 2, label: { en_US: 'Temperature', zh_Hans: '温度' }, required: false }, - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, { name: 'model', type: 'string', options: ['gpt-4'], label: { en_US: 'Model', zh_Hans: '模型' }, required: false }, ] const oldParams: FormValue = { diff --git a/web/utils/download.spec.ts b/web/utils/download.spec.ts new file mode 100644 index 000000000..ff41ddfff --- /dev/null +++ b/web/utils/download.spec.ts @@ -0,0 +1,75 @@ +import { downloadBlob, downloadUrl } from './download' + +describe('downloadUrl', () => { + let mockAnchor: HTMLAnchorElement + + beforeEach(() => { + mockAnchor = { + href: '', + download: '', + rel: '', + target: '', + style: { display: '' }, + click: vi.fn(), + remove: vi.fn(), + } as unknown as HTMLAnchorElement + + vi.spyOn(document, 'createElement').mockReturnValue(mockAnchor) + vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => node) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('should create a link and trigger a download correctly', () => { + downloadUrl({ url: 'https://example.com/file.txt', fileName: 'file.txt', target: '_blank' }) + + expect(mockAnchor.href).toBe('https://example.com/file.txt') + expect(mockAnchor.download).toBe('file.txt') + expect(mockAnchor.rel).toBe('noopener noreferrer') + expect(mockAnchor.target).toBe('_blank') + expect(mockAnchor.style.display).toBe('none') + expect(mockAnchor.click).toHaveBeenCalled() + expect(mockAnchor.remove).toHaveBeenCalled() + }) + + it('should skip when url is empty', () => { + downloadUrl({ url: '' }) + expect(document.createElement).not.toHaveBeenCalled() + }) +}) + +describe('downloadBlob', () => { + it('should create a blob url, trigger download, and revoke url', () => { + const blob = new Blob(['test'], { type: 'text/plain' }) + const mockUrl = 'blob:mock-url' + const createObjectURLMock = vi.spyOn(window.URL, 'createObjectURL').mockReturnValue(mockUrl) + const revokeObjectURLMock = vi.spyOn(window.URL, 'revokeObjectURL').mockImplementation(() => {}) + + const mockAnchor = { + href: '', + download: '', + rel: '', + target: '', + style: { display: '' }, + click: vi.fn(), + remove: vi.fn(), + } as unknown as HTMLAnchorElement + + vi.spyOn(document, 'createElement').mockReturnValue(mockAnchor) + vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => node) + + downloadBlob({ data: blob, fileName: 'file.txt' }) + + expect(createObjectURLMock).toHaveBeenCalledWith(blob) + expect(mockAnchor.href).toBe(mockUrl) + expect(mockAnchor.download).toBe('file.txt') + expect(mockAnchor.rel).toBe('noopener noreferrer') + expect(mockAnchor.click).toHaveBeenCalled() + expect(mockAnchor.remove).toHaveBeenCalled() + expect(revokeObjectURLMock).toHaveBeenCalledWith(mockUrl) + + vi.restoreAllMocks() + }) +}) diff --git a/web/utils/format.spec.ts b/web/utils/format.spec.ts index 3a1709dbd..2796854e3 100644 --- a/web/utils/format.spec.ts +++ b/web/utils/format.spec.ts @@ -1,4 +1,4 @@ -import { downloadFile, formatFileSize, formatNumber, formatNumberAbbreviated, formatTime } from './format' +import { formatFileSize, formatNumber, formatNumberAbbreviated, formatTime } from './format' describe('formatNumber', () => { it('should correctly format integers', () => { @@ -82,49 +82,6 @@ describe('formatTime', () => { expect(formatTime(7200)).toBe('2.00 h') }) }) -describe('downloadFile', () => { - it('should create a link and trigger a download correctly', () => { - // Mock data - const blob = new Blob(['test content'], { type: 'text/plain' }) - const fileName = 'test-file.txt' - const mockUrl = 'blob:mockUrl' - - // Mock URL.createObjectURL - const createObjectURLMock = vi.fn().mockReturnValue(mockUrl) - const revokeObjectURLMock = vi.fn() - Object.defineProperty(window.URL, 'createObjectURL', { value: createObjectURLMock }) - Object.defineProperty(window.URL, 'revokeObjectURL', { value: revokeObjectURLMock }) - - // Mock createElement and appendChild - const mockLink = { - href: '', - download: '', - click: vi.fn(), - remove: vi.fn(), - } - const createElementMock = vi.spyOn(document, 'createElement').mockReturnValue(mockLink as any) - const appendChildMock = vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => { - return node - }) - - // Call the function - downloadFile({ data: blob, fileName }) - - // Assertions - expect(createObjectURLMock).toHaveBeenCalledWith(blob) - expect(createElementMock).toHaveBeenCalledWith('a') - expect(mockLink.href).toBe(mockUrl) - expect(mockLink.download).toBe(fileName) - expect(appendChildMock).toHaveBeenCalledWith(mockLink) - expect(mockLink.click).toHaveBeenCalled() - expect(mockLink.remove).toHaveBeenCalled() - expect(revokeObjectURLMock).toHaveBeenCalledWith(mockUrl) - - // Clean up mocks - vi.restoreAllMocks() - }) -}) - describe('formatNumberAbbreviated', () => { it('should return number as string when less than 1000', () => { expect(formatNumberAbbreviated(0)).toBe('0') diff --git a/web/utils/format.ts b/web/utils/format.ts index ce813d399..d6968e0ef 100644 --- a/web/utils/format.ts +++ b/web/utils/format.ts @@ -100,17 +100,6 @@ export const formatTime = (seconds: number) => { return `${seconds.toFixed(2)} ${units[index]}` } -export const downloadFile = ({ data, fileName }: { data: Blob, fileName: string }) => { - const url = window.URL.createObjectURL(data) - const a = document.createElement('a') - a.href = url - a.download = fileName - document.body.appendChild(a) - a.click() - a.remove() - window.URL.revokeObjectURL(url) -} - /** * Formats a number into a readable string using "k", "M", or "B" suffix. * @example