diff --git a/.agent/skills/component-refactoring b/.agent/skills/component-refactoring
new file mode 120000
index 000000000..53ae67e2f
--- /dev/null
+++ b/.agent/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.agent/skills/frontend-code-review b/.agent/skills/frontend-code-review
new file mode 120000
index 000000000..55654ffbd
--- /dev/null
+++ b/.agent/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.agent/skills/frontend-testing b/.agent/skills/frontend-testing
new file mode 120000
index 000000000..092cec774
--- /dev/null
+++ b/.agent/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.agent/skills/orpc-contract-first b/.agent/skills/orpc-contract-first
new file mode 120000
index 000000000..da47b335c
--- /dev/null
+++ b/.agent/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.agent/skills/skill-creator b/.agent/skills/skill-creator
new file mode 120000
index 000000000..b87455490
--- /dev/null
+++ b/.agent/skills/skill-creator
@@ -0,0 +1 @@
+../../.agents/skills/skill-creator
\ No newline at end of file
diff --git a/.agent/skills/vercel-react-best-practices b/.agent/skills/vercel-react-best-practices
new file mode 120000
index 000000000..e567923b3
--- /dev/null
+++ b/.agent/skills/vercel-react-best-practices
@@ -0,0 +1 @@
+../../.agents/skills/vercel-react-best-practices
\ No newline at end of file
diff --git a/.agent/skills/web-design-guidelines b/.agent/skills/web-design-guidelines
new file mode 120000
index 000000000..886b26ded
--- /dev/null
+++ b/.agent/skills/web-design-guidelines
@@ -0,0 +1 @@
+../../.agents/skills/web-design-guidelines
\ No newline at end of file
diff --git a/.claude/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md
similarity index 100%
rename from .claude/skills/component-refactoring/SKILL.md
rename to .agents/skills/component-refactoring/SKILL.md
diff --git a/.claude/skills/component-refactoring/references/complexity-patterns.md b/.agents/skills/component-refactoring/references/complexity-patterns.md
similarity index 100%
rename from .claude/skills/component-refactoring/references/complexity-patterns.md
rename to .agents/skills/component-refactoring/references/complexity-patterns.md
diff --git a/.claude/skills/component-refactoring/references/component-splitting.md b/.agents/skills/component-refactoring/references/component-splitting.md
similarity index 100%
rename from .claude/skills/component-refactoring/references/component-splitting.md
rename to .agents/skills/component-refactoring/references/component-splitting.md
diff --git a/.claude/skills/component-refactoring/references/hook-extraction.md b/.agents/skills/component-refactoring/references/hook-extraction.md
similarity index 100%
rename from .claude/skills/component-refactoring/references/hook-extraction.md
rename to .agents/skills/component-refactoring/references/hook-extraction.md
diff --git a/.claude/skills/frontend-code-review/SKILL.md b/.agents/skills/frontend-code-review/SKILL.md
similarity index 100%
rename from .claude/skills/frontend-code-review/SKILL.md
rename to .agents/skills/frontend-code-review/SKILL.md
diff --git a/.claude/skills/frontend-code-review/references/business-logic.md b/.agents/skills/frontend-code-review/references/business-logic.md
similarity index 100%
rename from .claude/skills/frontend-code-review/references/business-logic.md
rename to .agents/skills/frontend-code-review/references/business-logic.md
diff --git a/.claude/skills/frontend-code-review/references/code-quality.md b/.agents/skills/frontend-code-review/references/code-quality.md
similarity index 100%
rename from .claude/skills/frontend-code-review/references/code-quality.md
rename to .agents/skills/frontend-code-review/references/code-quality.md
diff --git a/.claude/skills/frontend-code-review/references/performance.md b/.agents/skills/frontend-code-review/references/performance.md
similarity index 100%
rename from .claude/skills/frontend-code-review/references/performance.md
rename to .agents/skills/frontend-code-review/references/performance.md
diff --git a/.claude/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md
similarity index 98%
rename from .claude/skills/frontend-testing/SKILL.md
rename to .agents/skills/frontend-testing/SKILL.md
index dd9677a78..0716c81ef 100644
--- a/.claude/skills/frontend-testing/SKILL.md
+++ b/.agents/skills/frontend-testing/SKILL.md
@@ -83,6 +83,9 @@ vi.mock('next/navigation', () => ({
usePathname: () => '/test',
}))
+// ✅ Zustand stores: Use real stores (auto-mocked globally)
+// Set test state with: useAppStore.setState({ ... })
+
// Shared state for mocks (if needed)
let mockSharedState = false
@@ -296,7 +299,7 @@ For each test file generated, aim for:
For more detailed information, refer to:
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
-- `references/mocking.md` - Mock patterns and best practices
+- `references/mocking.md` - Mock patterns, Zustand store testing, and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
diff --git a/.claude/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx
similarity index 100%
rename from .claude/skills/frontend-testing/assets/component-test.template.tsx
rename to .agents/skills/frontend-testing/assets/component-test.template.tsx
diff --git a/.claude/skills/frontend-testing/assets/hook-test.template.ts b/.agents/skills/frontend-testing/assets/hook-test.template.ts
similarity index 100%
rename from .claude/skills/frontend-testing/assets/hook-test.template.ts
rename to .agents/skills/frontend-testing/assets/hook-test.template.ts
diff --git a/.claude/skills/frontend-testing/assets/utility-test.template.ts b/.agents/skills/frontend-testing/assets/utility-test.template.ts
similarity index 100%
rename from .claude/skills/frontend-testing/assets/utility-test.template.ts
rename to .agents/skills/frontend-testing/assets/utility-test.template.ts
diff --git a/.claude/skills/frontend-testing/references/async-testing.md b/.agents/skills/frontend-testing/references/async-testing.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/async-testing.md
rename to .agents/skills/frontend-testing/references/async-testing.md
diff --git a/.claude/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/checklist.md
rename to .agents/skills/frontend-testing/references/checklist.md
diff --git a/.claude/skills/frontend-testing/references/common-patterns.md b/.agents/skills/frontend-testing/references/common-patterns.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/common-patterns.md
rename to .agents/skills/frontend-testing/references/common-patterns.md
diff --git a/.claude/skills/frontend-testing/references/domain-components.md b/.agents/skills/frontend-testing/references/domain-components.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/domain-components.md
rename to .agents/skills/frontend-testing/references/domain-components.md
diff --git a/.claude/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md
similarity index 64%
rename from .claude/skills/frontend-testing/references/mocking.md
rename to .agents/skills/frontend-testing/references/mocking.md
index c70bcf0ae..86bd37598 100644
--- a/.claude/skills/frontend-testing/references/mocking.md
+++ b/.agents/skills/frontend-testing/references/mocking.md
@@ -37,16 +37,36 @@ Only mock these categories:
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
+### Zustand Stores - DO NOT Mock Manually
+
+**Zustand is globally mocked** in `web/vitest.setup.ts`. Use real stores with `setState()`:
+
+```typescript
+// ✅ CORRECT: Use real store, set test state
+import { useAppStore } from '@/app/components/app/store'
+
+useAppStore.setState({ appDetail: { id: 'test', name: 'Test' } })
+render( )
+
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({ ... }))
+```
+
+See [Zustand Store Testing](#zustand-store-testing) section for full details.
+
## Mock Placement
| Location | Purpose |
|----------|---------|
-| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
+| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
+| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
+**Note**: Zustand is special - it's globally mocked but you should NOT mock store modules manually. See [Zustand Store Testing](#zustand-store-testing).
+
## Essential Mocks
### 1. i18n (Auto-loaded via Global Mock)
@@ -276,6 +296,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
+1. **Use real Zustand stores** - Set test state via `store.setState()`
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
@@ -285,6 +306,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
@@ -308,10 +330,151 @@ Need to use a component in test?
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
│
+├─ Is it a Zustand store?
+│ └─ YES → DO NOT mock the module!
+│ Use real store + setState() to set test state
+│ (Global mock handles auto-reset)
+│
└─ Is it i18n?
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
+## Zustand Store Testing
+
+### Global Zustand Mock (Auto-loaded)
+
+Zustand is globally mocked in `web/vitest.setup.ts` following the [official Zustand testing guide](https://zustand.docs.pmnd.rs/guides/testing). The mock in `web/__mocks__/zustand.ts` provides:
+
+- Real store behavior with `getState()`, `setState()`, `subscribe()` methods
+- Automatic store reset after each test via `afterEach`
+- Proper test isolation between tests
+
+### ✅ Recommended: Use Real Stores (Official Best Practice)
+
+**DO NOT mock store modules manually.** Import and use the real store, then use `setState()` to set test state:
+
+```typescript
+// ✅ CORRECT: Use real store with setState
+import { useAppStore } from '@/app/components/app/store'
+
+describe('MyComponent', () => {
+ it('should render app details', () => {
+ // Arrange: Set test state via setState
+ useAppStore.setState({
+ appDetail: {
+ id: 'test-app',
+ name: 'Test App',
+ mode: 'chat',
+ },
+ })
+
+ // Act
+ render( )
+
+ // Assert
+ expect(screen.getByText('Test App')).toBeInTheDocument()
+ // Can also verify store state directly
+ expect(useAppStore.getState().appDetail?.name).toBe('Test App')
+ })
+
+ // No cleanup needed - global mock auto-resets after each test
+})
+```
+
+### ❌ Avoid: Manual Store Module Mocking
+
+Manual mocking conflicts with the global Zustand mock and loses store functionality:
+
+```typescript
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({
+ useStore: (selector) => mockSelector(selector), // Missing getState, setState!
+}))
+
+// ❌ WRONG: This conflicts with global zustand mock
+vi.mock('@/app/components/workflow/store', () => ({
+ useWorkflowStore: vi.fn(() => mockState),
+}))
+```
+
+**Problems with manual mocking:**
+
+1. Loses `getState()`, `setState()`, `subscribe()` methods
+1. Conflicts with global Zustand mock behavior
+1. Requires manual maintenance of store API
+1. Tests don't reflect actual store behavior
+
+### When Manual Store Mocking is Necessary
+
+In rare cases where the store has complex initialization or side effects, you can mock it, but ensure you provide the full store API:
+
+```typescript
+// If you MUST mock (rare), include full store API
+const mockStore = {
+ appDetail: { id: 'test', name: 'Test' },
+ setAppDetail: vi.fn(),
+}
+
+vi.mock('@/app/components/app/store', () => ({
+ useStore: Object.assign(
+ (selector: (state: typeof mockStore) => unknown) => selector(mockStore),
+ {
+ getState: () => mockStore,
+ setState: vi.fn(),
+ subscribe: vi.fn(),
+ },
+ ),
+}))
+```
+
+### Store Testing Decision Tree
+
+```
+Need to test a component using Zustand store?
+│
+├─ Can you use the real store?
+│ └─ YES → Use real store + setState (RECOMMENDED)
+│ useAppStore.setState({ ... })
+│
+├─ Does the store have complex initialization/side effects?
+│ └─ YES → Consider mocking, but include full API
+│ (getState, setState, subscribe)
+│
+└─ Are you testing the store itself (not a component)?
+ └─ YES → Test store directly with getState/setState
+ const store = useMyStore
+ store.setState({ count: 0 })
+ store.getState().increment()
+ expect(store.getState().count).toBe(1)
+```
+
+### Example: Testing Store Actions
+
+```typescript
+import { useCounterStore } from '@/stores/counter'
+
+describe('Counter Store', () => {
+ it('should increment count', () => {
+ // Initial state (auto-reset by global mock)
+ expect(useCounterStore.getState().count).toBe(0)
+
+ // Call action
+ useCounterStore.getState().increment()
+
+ // Verify state change
+ expect(useCounterStore.getState().count).toBe(1)
+ })
+
+ it('should reset to initial state', () => {
+ // Set some state
+ useCounterStore.setState({ count: 100 })
+ expect(useCounterStore.getState().count).toBe(100)
+
+ // After this test, global mock will reset to initial state
+ })
+})
+```
+
## Factory Function Pattern
```typescript
diff --git a/.claude/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/workflow.md
rename to .agents/skills/frontend-testing/references/workflow.md
diff --git a/.agents/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md
new file mode 100644
index 000000000..4e3bfc7a3
--- /dev/null
+++ b/.agents/skills/orpc-contract-first/SKILL.md
@@ -0,0 +1,46 @@
+---
+name: orpc-contract-first
+description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
+---
+
+# oRPC Contract-First Development
+
+## Project Structure
+
+```
+web/contract/
+├── base.ts # Base contract (inputStructure: 'detailed')
+├── router.ts # Router composition & type exports
+├── marketplace.ts # Marketplace contracts
+└── console/ # Console contracts by domain
+ ├── system.ts
+ └── billing.ts
+```
+
+## Workflow
+
+1. **Create contract** in `web/contract/console/{domain}.ts`
+ - Import `base` from `../base` and `type` from `@orpc/contract`
+ - Define route with `path`, `method`, `input`, `output`
+
+2. **Register in router** at `web/contract/router.ts`
+ - Import directly from domain file (no barrel files)
+ - Nest by API prefix: `billing: { invoices, bindPartnerStack }`
+
+3. **Create hooks** in `web/service/use-{domain}.ts`
+ - Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
+ - Use `consoleClient.{group}.{contract}()` for API calls
+
+## Key Rules
+
+- **Input structure**: Always use `{ params, query?, body? }` format
+- **Path params**: Use `{paramName}` in path, match in `params` object
+- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
+- **No barrel files**: Import directly from specific files
+- **Types**: Import from `@/types/`, use `type()` helper
+
+## Type Export
+
+```typescript
+export type ConsoleInputs = InferContractRouterInputs
+```
diff --git a/.claude/skills/skill-creator/SKILL.md b/.agents/skills/skill-creator/SKILL.md
similarity index 100%
rename from .claude/skills/skill-creator/SKILL.md
rename to .agents/skills/skill-creator/SKILL.md
diff --git a/.claude/skills/skill-creator/references/output-patterns.md b/.agents/skills/skill-creator/references/output-patterns.md
similarity index 100%
rename from .claude/skills/skill-creator/references/output-patterns.md
rename to .agents/skills/skill-creator/references/output-patterns.md
diff --git a/.claude/skills/skill-creator/references/workflows.md b/.agents/skills/skill-creator/references/workflows.md
similarity index 100%
rename from .claude/skills/skill-creator/references/workflows.md
rename to .agents/skills/skill-creator/references/workflows.md
diff --git a/.claude/skills/skill-creator/scripts/init_skill.py b/.agents/skills/skill-creator/scripts/init_skill.py
similarity index 100%
rename from .claude/skills/skill-creator/scripts/init_skill.py
rename to .agents/skills/skill-creator/scripts/init_skill.py
diff --git a/.claude/skills/skill-creator/scripts/package_skill.py b/.agents/skills/skill-creator/scripts/package_skill.py
similarity index 100%
rename from .claude/skills/skill-creator/scripts/package_skill.py
rename to .agents/skills/skill-creator/scripts/package_skill.py
diff --git a/.claude/skills/skill-creator/scripts/quick_validate.py b/.agents/skills/skill-creator/scripts/quick_validate.py
similarity index 100%
rename from .claude/skills/skill-creator/scripts/quick_validate.py
rename to .agents/skills/skill-creator/scripts/quick_validate.py
diff --git a/.agents/skills/vercel-react-best-practices/AGENTS.md b/.agents/skills/vercel-react-best-practices/AGENTS.md
new file mode 100644
index 000000000..f9b9e99c4
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/AGENTS.md
@@ -0,0 +1,2410 @@
+# React Best Practices
+
+**Version 1.0.0**
+Vercel Engineering
+January 2026
+
+> **Note:**
+> This document is mainly for agents and LLMs to follow when maintaining,
+> generating, or refactoring React and Next.js codebases at Vercel. Humans
+> may also find it useful, but guidance here is optimized for automation
+> and consistency by AI-assisted workflows.
+
+---
+
+## Abstract
+
+Comprehensive performance optimization guide for React and Next.js applications, designed for AI agents and LLMs. Contains 40+ rules across 8 categories, prioritized by impact from critical (eliminating waterfalls, reducing bundle size) to incremental (advanced patterns). Each rule includes detailed explanations, real-world examples comparing incorrect vs. correct implementations, and specific impact metrics to guide automated refactoring and code generation.
+
+---
+
+## Table of Contents
+
+1. [Eliminating Waterfalls](#1-eliminating-waterfalls) — **CRITICAL**
+ - 1.1 [Defer Await Until Needed](#11-defer-await-until-needed)
+ - 1.2 [Dependency-Based Parallelization](#12-dependency-based-parallelization)
+ - 1.3 [Prevent Waterfall Chains in API Routes](#13-prevent-waterfall-chains-in-api-routes)
+ - 1.4 [Promise.all() for Independent Operations](#14-promiseall-for-independent-operations)
+ - 1.5 [Strategic Suspense Boundaries](#15-strategic-suspense-boundaries)
+2. [Bundle Size Optimization](#2-bundle-size-optimization) — **CRITICAL**
+ - 2.1 [Avoid Barrel File Imports](#21-avoid-barrel-file-imports)
+ - 2.2 [Conditional Module Loading](#22-conditional-module-loading)
+ - 2.3 [Defer Non-Critical Third-Party Libraries](#23-defer-non-critical-third-party-libraries)
+ - 2.4 [Dynamic Imports for Heavy Components](#24-dynamic-imports-for-heavy-components)
+ - 2.5 [Preload Based on User Intent](#25-preload-based-on-user-intent)
+3. [Server-Side Performance](#3-server-side-performance) — **HIGH**
+ - 3.1 [Cross-Request LRU Caching](#31-cross-request-lru-caching)
+ - 3.2 [Minimize Serialization at RSC Boundaries](#32-minimize-serialization-at-rsc-boundaries)
+ - 3.3 [Parallel Data Fetching with Component Composition](#33-parallel-data-fetching-with-component-composition)
+ - 3.4 [Per-Request Deduplication with React.cache()](#34-per-request-deduplication-with-reactcache)
+ - 3.5 [Use after() for Non-Blocking Operations](#35-use-after-for-non-blocking-operations)
+4. [Client-Side Data Fetching](#4-client-side-data-fetching) — **MEDIUM-HIGH**
+ - 4.1 [Deduplicate Global Event Listeners](#41-deduplicate-global-event-listeners)
+ - 4.2 [Use Passive Event Listeners for Scrolling Performance](#42-use-passive-event-listeners-for-scrolling-performance)
+ - 4.3 [Use SWR for Automatic Deduplication](#43-use-swr-for-automatic-deduplication)
+ - 4.4 [Version and Minimize localStorage Data](#44-version-and-minimize-localstorage-data)
+5. [Re-render Optimization](#5-re-render-optimization) — **MEDIUM**
+ - 5.1 [Defer State Reads to Usage Point](#51-defer-state-reads-to-usage-point)
+ - 5.2 [Extract to Memoized Components](#52-extract-to-memoized-components)
+ - 5.3 [Narrow Effect Dependencies](#53-narrow-effect-dependencies)
+ - 5.4 [Subscribe to Derived State](#54-subscribe-to-derived-state)
+ - 5.5 [Use Functional setState Updates](#55-use-functional-setstate-updates)
+ - 5.6 [Use Lazy State Initialization](#56-use-lazy-state-initialization)
+ - 5.7 [Use Transitions for Non-Urgent Updates](#57-use-transitions-for-non-urgent-updates)
+6. [Rendering Performance](#6-rendering-performance) — **MEDIUM**
+ - 6.1 [Animate SVG Wrapper Instead of SVG Element](#61-animate-svg-wrapper-instead-of-svg-element)
+ - 6.2 [CSS content-visibility for Long Lists](#62-css-content-visibility-for-long-lists)
+ - 6.3 [Hoist Static JSX Elements](#63-hoist-static-jsx-elements)
+ - 6.4 [Optimize SVG Precision](#64-optimize-svg-precision)
+ - 6.5 [Prevent Hydration Mismatch Without Flickering](#65-prevent-hydration-mismatch-without-flickering)
+ - 6.6 [Use Activity Component for Show/Hide](#66-use-activity-component-for-showhide)
+ - 6.7 [Use Explicit Conditional Rendering](#67-use-explicit-conditional-rendering)
+7. [JavaScript Performance](#7-javascript-performance) — **LOW-MEDIUM**
+ - 7.1 [Batch DOM CSS Changes](#71-batch-dom-css-changes)
+ - 7.2 [Build Index Maps for Repeated Lookups](#72-build-index-maps-for-repeated-lookups)
+ - 7.3 [Cache Property Access in Loops](#73-cache-property-access-in-loops)
+ - 7.4 [Cache Repeated Function Calls](#74-cache-repeated-function-calls)
+ - 7.5 [Cache Storage API Calls](#75-cache-storage-api-calls)
+ - 7.6 [Combine Multiple Array Iterations](#76-combine-multiple-array-iterations)
+ - 7.7 [Early Length Check for Array Comparisons](#77-early-length-check-for-array-comparisons)
+ - 7.8 [Early Return from Functions](#78-early-return-from-functions)
+ - 7.9 [Hoist RegExp Creation](#79-hoist-regexp-creation)
+ - 7.10 [Use Loop for Min/Max Instead of Sort](#710-use-loop-for-minmax-instead-of-sort)
+ - 7.11 [Use Set/Map for O(1) Lookups](#711-use-setmap-for-o1-lookups)
+ - 7.12 [Use toSorted() Instead of sort() for Immutability](#712-use-tosorted-instead-of-sort-for-immutability)
+8. [Advanced Patterns](#8-advanced-patterns) — **LOW**
+ - 8.1 [Store Event Handlers in Refs](#81-store-event-handlers-in-refs)
+ - 8.2 [useLatest for Stable Callback Refs](#82-uselatest-for-stable-callback-refs)
+
+---
+
+## 1. Eliminating Waterfalls
+
+**Impact: CRITICAL**
+
+Waterfalls are the #1 performance killer. Each sequential await adds full network latency. Eliminating them yields the largest gains.
+
+### 1.1 Defer Await Until Needed
+
+**Impact: HIGH (avoids blocking unused code paths)**
+
+Move `await` operations into the branches where they're actually used to avoid blocking code paths that don't need them.
+
+**Incorrect: blocks both branches**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ const userData = await fetchUserData(userId)
+
+ if (skipProcessing) {
+ // Returns immediately but still waited for userData
+ return { skipped: true }
+ }
+
+ // Only this branch uses userData
+ return processUserData(userData)
+}
+```
+
+**Correct: only blocks when needed**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ if (skipProcessing) {
+ // Returns immediately without waiting
+ return { skipped: true }
+ }
+
+ // Fetch only when needed
+ const userData = await fetchUserData(userId)
+ return processUserData(userData)
+}
+```
+
+**Another example: early return optimization**
+
+```typescript
+// Incorrect: always fetches permissions
+async function updateResource(resourceId: string, userId: string) {
+ const permissions = await fetchPermissions(userId)
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+
+// Correct: fetches only when needed
+async function updateResource(resourceId: string, userId: string) {
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ const permissions = await fetchPermissions(userId)
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+```
+
+This optimization is especially valuable when the skipped branch is frequently taken, or when the deferred operation is expensive.
+
+### 1.2 Dependency-Based Parallelization
+
+**Impact: CRITICAL (2-10× improvement)**
+
+For operations with partial dependencies, use `better-all` to maximize parallelism. It automatically starts each task at the earliest possible moment.
+
+**Incorrect: profile waits for config unnecessarily**
+
+```typescript
+const [user, config] = await Promise.all([
+ fetchUser(),
+ fetchConfig()
+])
+const profile = await fetchProfile(user.id)
+```
+
+**Correct: config and profile run in parallel**
+
+```typescript
+import { all } from 'better-all'
+
+const { user, config, profile } = await all({
+ async user() { return fetchUser() },
+ async config() { return fetchConfig() },
+ async profile() {
+ return fetchProfile((await this.$.user).id)
+ }
+})
+```
+
+Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
+
+### 1.3 Prevent Waterfall Chains in API Routes
+
+**Impact: CRITICAL (2-10× improvement)**
+
+In API routes and Server Actions, start independent operations immediately, even if you don't await them yet.
+
+**Incorrect: config waits for auth, data waits for both**
+
+```typescript
+export async function GET(request: Request) {
+ const session = await auth()
+ const config = await fetchConfig()
+ const data = await fetchData(session.user.id)
+ return Response.json({ data, config })
+}
+```
+
+**Correct: auth and config start immediately**
+
+```typescript
+export async function GET(request: Request) {
+ const sessionPromise = auth()
+ const configPromise = fetchConfig()
+ const session = await sessionPromise
+ const [config, data] = await Promise.all([
+ configPromise,
+ fetchData(session.user.id)
+ ])
+ return Response.json({ data, config })
+}
+```
+
+For operations with more complex dependency chains, use `better-all` to automatically maximize parallelism (see Dependency-Based Parallelization).
+
+### 1.4 Promise.all() for Independent Operations
+
+**Impact: CRITICAL (2-10× improvement)**
+
+When async operations have no interdependencies, execute them concurrently using `Promise.all()`.
+
+**Incorrect: sequential execution, 3 round trips**
+
+```typescript
+const user = await fetchUser()
+const posts = await fetchPosts()
+const comments = await fetchComments()
+```
+
+**Correct: parallel execution, 1 round trip**
+
+```typescript
+const [user, posts, comments] = await Promise.all([
+ fetchUser(),
+ fetchPosts(),
+ fetchComments()
+])
+```
+
+### 1.5 Strategic Suspense Boundaries
+
+**Impact: HIGH (faster initial paint)**
+
+Instead of awaiting data in async components before returning JSX, use Suspense boundaries to show the wrapper UI faster while data loads.
+
+**Incorrect: wrapper blocked by data fetching**
+
+```tsx
+async function Page() {
+ const data = await fetchData() // Blocks entire page
+
+ return (
+
+
Sidebar
+
Header
+
+
+
+
Footer
+
+ )
+}
+```
+
+The entire layout waits for data even though only the middle section needs it.
+
+**Correct: wrapper shows immediately, data streams in**
+
+```tsx
+function Page() {
+ return (
+
+
Sidebar
+
Header
+
+ }>
+
+
+
+
Footer
+
+ )
+}
+
+async function DataDisplay() {
+ const data = await fetchData() // Only blocks this component
+ return {data.content}
+}
+```
+
+Sidebar, Header, and Footer render immediately. Only DataDisplay waits for data.
+
+**Alternative: share promise across components**
+
+```tsx
+function Page() {
+ // Start fetch immediately, but don't await
+ const dataPromise = fetchData()
+
+ return (
+
+
Sidebar
+
Header
+
}>
+
+
+
+
Footer
+
+ )
+}
+
+function DataDisplay({ dataPromise }: { dataPromise: Promise }) {
+ const data = use(dataPromise) // Unwraps the promise
+ return {data.content}
+}
+
+function DataSummary({ dataPromise }: { dataPromise: Promise }) {
+ const data = use(dataPromise) // Reuses the same promise
+ return {data.summary}
+}
+```
+
+Both components share the same promise, so only one fetch occurs. Layout renders immediately while both components wait together.
+
+**When NOT to use this pattern:**
+
+- Critical data needed for layout decisions (affects positioning)
+
+- SEO-critical content above the fold
+
+- Small, fast queries where suspense overhead isn't worth it
+
+- When you want to avoid layout shift (loading → content jump)
+
+**Trade-off:** Faster initial paint vs potential layout shift. Choose based on your UX priorities.
+
+---
+
+## 2. Bundle Size Optimization
+
+**Impact: CRITICAL**
+
+Reducing initial bundle size improves Time to Interactive and Largest Contentful Paint.
+
+### 2.1 Avoid Barrel File Imports
+
+**Impact: CRITICAL (200-800ms import cost, slow builds)**
+
+Import directly from source files instead of barrel files to avoid loading thousands of unused modules. **Barrel files** are entry points that re-export multiple modules (e.g., `index.js` that does `export * from './module'`).
+
+Popular icon and component libraries can have **up to 10,000 re-exports** in their entry file. For many React packages, **it takes 200-800ms just to import them**, affecting both development speed and production cold starts.
+
+**Why tree-shaking doesn't help:** When a library is marked as external (not bundled), the bundler can't optimize it. If you bundle it to enable tree-shaking, builds become substantially slower analyzing the entire module graph.
+
+**Incorrect: imports entire library**
+
+```tsx
+import { Check, X, Menu } from 'lucide-react'
+// Loads 1,583 modules, takes ~2.8s extra in dev
+// Runtime cost: 200-800ms on every cold start
+
+import { Button, TextField } from '@mui/material'
+// Loads 2,225 modules, takes ~4.2s extra in dev
+```
+
+**Correct: imports only what you need**
+
+```tsx
+import Check from 'lucide-react/dist/esm/icons/check'
+import X from 'lucide-react/dist/esm/icons/x'
+import Menu from 'lucide-react/dist/esm/icons/menu'
+// Loads only 3 modules (~2KB vs ~1MB)
+
+import Button from '@mui/material/Button'
+import TextField from '@mui/material/TextField'
+// Loads only what you use
+```
+
+**Alternative: Next.js 13.5+**
+
+```js
+// next.config.js - use optimizePackageImports
+module.exports = {
+ experimental: {
+ optimizePackageImports: ['lucide-react', '@mui/material']
+ }
+}
+
+// Then you can keep the ergonomic barrel imports:
+import { Check, X, Menu } from 'lucide-react'
+// Automatically transformed to direct imports at build time
+```
+
+Direct imports provide 15-70% faster dev boot, 28% faster builds, 40% faster cold starts, and significantly faster HMR.
+
+Libraries commonly affected: `lucide-react`, `@mui/material`, `@mui/icons-material`, `@tabler/icons-react`, `react-icons`, `@headlessui/react`, `@radix-ui/react-*`, `lodash`, `ramda`, `date-fns`, `rxjs`, `react-use`.
+
+Reference: [https://vercel.com/blog/how-we-optimized-package-imports-in-next-js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
+
+### 2.2 Conditional Module Loading
+
+**Impact: HIGH (loads large data only when needed)**
+
+Load large data or modules only when a feature is activated.
+
+**Example: lazy-load animation frames**
+
+```tsx
+function AnimationPlayer({ enabled, setEnabled }: { enabled: boolean; setEnabled: React.Dispatch> }) {
+ const [frames, setFrames] = useState (null)
+
+ useEffect(() => {
+ if (enabled && !frames && typeof window !== 'undefined') {
+ import('./animation-frames.js')
+ .then(mod => setFrames(mod.frames))
+ .catch(() => setEnabled(false))
+ }
+ }, [enabled, frames, setEnabled])
+
+ if (!frames) return
+ return
+}
+```
+
+The `typeof window !== 'undefined'` check prevents bundling this module for SSR, optimizing server bundle size and build speed.
+
+### 2.3 Defer Non-Critical Third-Party Libraries
+
+**Impact: MEDIUM (loads after hydration)**
+
+Analytics, logging, and error tracking don't block user interaction. Load them after hydration.
+
+**Incorrect: blocks initial bundle**
+
+```tsx
+import { Analytics } from '@vercel/analytics/react'
+
+export default function RootLayout({ children }) {
+ return (
+
+
+ {children}
+
+
+
+ )
+}
+```
+
+**Correct: loads after hydration**
+
+```tsx
+import dynamic from 'next/dynamic'
+
+const Analytics = dynamic(
+ () => import('@vercel/analytics/react').then(m => m.Analytics),
+ { ssr: false }
+)
+
+export default function RootLayout({ children }) {
+ return (
+
+
+ {children}
+
+
+
+ )
+}
+```
+
+### 2.4 Dynamic Imports for Heavy Components
+
+**Impact: CRITICAL (directly affects TTI and LCP)**
+
+Use `next/dynamic` to lazy-load large components not needed on initial render.
+
+**Incorrect: Monaco bundles with main chunk ~300KB**
+
+```tsx
+import { MonacoEditor } from './monaco-editor'
+
+function CodePanel({ code }: { code: string }) {
+ return
+}
+```
+
+**Correct: Monaco loads on demand**
+
+```tsx
+import dynamic from 'next/dynamic'
+
+const MonacoEditor = dynamic(
+ () => import('./monaco-editor').then(m => m.MonacoEditor),
+ { ssr: false }
+)
+
+function CodePanel({ code }: { code: string }) {
+ return
+}
+```
+
+### 2.5 Preload Based on User Intent
+
+**Impact: MEDIUM (reduces perceived latency)**
+
+Preload heavy bundles before they're needed to reduce perceived latency.
+
+**Example: preload on hover/focus**
+
+```tsx
+function EditorButton({ onClick }: { onClick: () => void }) {
+ const preload = () => {
+ if (typeof window !== 'undefined') {
+ void import('./monaco-editor')
+ }
+ }
+
+ return (
+
+ Open Editor
+
+ )
+}
+```
+
+**Example: preload when feature flag is enabled**
+
+```tsx
+function FlagsProvider({ children, flags }: Props) {
+ useEffect(() => {
+ if (flags.editorEnabled && typeof window !== 'undefined') {
+ void import('./monaco-editor').then(mod => mod.init())
+ }
+ }, [flags.editorEnabled])
+
+ return
+ {children}
+
+}
+```
+
+The `typeof window !== 'undefined'` check prevents bundling preloaded modules for SSR, optimizing server bundle size and build speed.
+
+---
+
+## 3. Server-Side Performance
+
+**Impact: HIGH**
+
+Optimizing server-side rendering and data fetching eliminates server-side waterfalls and reduces response times.
+
+### 3.1 Cross-Request LRU Caching
+
+**Impact: HIGH (caches across requests)**
+
+`React.cache()` only works within one request. For data shared across sequential requests (user clicks button A then button B), use an LRU cache.
+
+**Implementation:**
+
+```typescript
+import { LRUCache } from 'lru-cache'
+
+const cache = new LRUCache({
+ max: 1000,
+ ttl: 5 * 60 * 1000 // 5 minutes
+})
+
+export async function getUser(id: string) {
+ const cached = cache.get(id)
+ if (cached) return cached
+
+ const user = await db.user.findUnique({ where: { id } })
+ cache.set(id, user)
+ return user
+}
+
+// Request 1: DB query, result cached
+// Request 2: cache hit, no DB query
+```
+
+Use when sequential user actions hit multiple endpoints needing the same data within seconds.
+
+**With Vercel's [Fluid Compute](https://vercel.com/docs/fluid-compute):** LRU caching is especially effective because multiple concurrent requests can share the same function instance and cache. This means the cache persists across requests without needing external storage like Redis.
+
+**In traditional serverless:** Each invocation runs in isolation, so consider Redis for cross-process caching.
+
+Reference: [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
+
+### 3.2 Minimize Serialization at RSC Boundaries
+
+**Impact: HIGH (reduces data transfer size)**
+
+The React Server/Client boundary serializes all object properties into strings and embeds them in the HTML response and subsequent RSC requests. This serialized data directly impacts page weight and load time, so **size matters a lot**. Only pass fields that the client actually uses.
+
+**Incorrect: serializes all 50 fields**
+
+```tsx
+async function Page() {
+ const user = await fetchUser() // 50 fields
+ return
+}
+
+'use client'
+function Profile({ user }: { user: User }) {
+ return {user.name}
// uses 1 field
+}
+```
+
+**Correct: serializes only 1 field**
+
+```tsx
+async function Page() {
+ const user = await fetchUser()
+ return
+}
+
+'use client'
+function Profile({ name }: { name: string }) {
+ return {name}
+}
+```
+
+### 3.3 Parallel Data Fetching with Component Composition
+
+**Impact: CRITICAL (eliminates server-side waterfalls)**
+
+React Server Components execute sequentially within a tree. Restructure with composition to parallelize data fetching.
+
+**Incorrect: Sidebar waits for Page's fetch to complete**
+
+```tsx
+export default async function Page() {
+ const header = await fetchHeader()
+ return (
+
+ )
+}
+
+async function Sidebar() {
+ const items = await fetchSidebarItems()
+ return {items.map(renderItem)}
+}
+```
+
+**Correct: both fetch simultaneously**
+
+```tsx
+async function Header() {
+ const data = await fetchHeader()
+ return {data}
+}
+
+async function Sidebar() {
+ const items = await fetchSidebarItems()
+ return {items.map(renderItem)}
+}
+
+export default function Page() {
+ return (
+
+
+
+
+ )
+}
+```
+
+**Alternative with children prop:**
+
+```tsx
+async function Header() {
+ const data = await fetchHeader()
+ return {data}
+}
+
+async function Sidebar() {
+ const items = await fetchSidebarItems()
+ return {items.map(renderItem)}
+}
+
+function Layout({ children }: { children: ReactNode }) {
+ return (
+
+
+ {children}
+
+ )
+}
+
+export default function Page() {
+ return (
+
+
+
+ )
+}
+```
+
+### 3.4 Per-Request Deduplication with React.cache()
+
+**Impact: MEDIUM (deduplicates within request)**
+
+Use `React.cache()` for server-side request deduplication. Authentication and database queries benefit most.
+
+**Usage:**
+
+```typescript
+import { cache } from 'react'
+
+export const getCurrentUser = cache(async () => {
+ const session = await auth()
+ if (!session?.user?.id) return null
+ return await db.user.findUnique({
+ where: { id: session.user.id }
+ })
+})
+```
+
+Within a single request, multiple calls to `getCurrentUser()` execute the query only once.
+
+**Avoid inline objects as arguments:**
+
+`React.cache()` uses shallow equality (`Object.is`) to determine cache hits. Inline objects create new references each call, preventing cache hits.
+
+**Incorrect: always cache miss**
+
+```typescript
+const getUser = cache(async (params: { uid: number }) => {
+ return await db.user.findUnique({ where: { id: params.uid } })
+})
+
+// Each call creates new object, never hits cache
+getUser({ uid: 1 })
+getUser({ uid: 1 }) // Cache miss, runs query again
+```
+
+**Correct: cache hit**
+
+```typescript
+const params = { uid: 1 }
+getUser(params) // Query runs
+getUser(params) // Cache hit (same reference)
+```
+
+If you must pass objects, pass the same reference:
+
+**Next.js-Specific Note:**
+
+In Next.js, the `fetch` API is automatically extended with request memoization. Requests with the same URL and options are automatically deduplicated within a single request, so you don't need `React.cache()` for `fetch` calls. However, `React.cache()` is still essential for other async tasks:
+
+- Database queries (Prisma, Drizzle, etc.)
+
+- Heavy computations
+
+- Authentication checks
+
+- File system operations
+
+- Any non-fetch async work
+
+Use `React.cache()` to deduplicate these operations across your component tree.
+
+Reference: [https://react.dev/reference/react/cache](https://react.dev/reference/react/cache)
+
+### 3.5 Use after() for Non-Blocking Operations
+
+**Impact: MEDIUM (faster response times)**
+
+Use Next.js's `after()` to schedule work that should execute after a response is sent. This prevents logging, analytics, and other side effects from blocking the response.
+
+**Incorrect: blocks response**
+
+```tsx
+import { logUserAction } from '@/app/utils'
+
+export async function POST(request: Request) {
+ // Perform mutation
+ await updateDatabase(request)
+
+ // Logging blocks the response
+ const userAgent = request.headers.get('user-agent') || 'unknown'
+ await logUserAction({ userAgent })
+
+ return new Response(JSON.stringify({ status: 'success' }), {
+ status: 200,
+ headers: { 'Content-Type': 'application/json' }
+ })
+}
+```
+
+**Correct: non-blocking**
+
+```tsx
+import { after } from 'next/server'
+import { headers, cookies } from 'next/headers'
+import { logUserAction } from '@/app/utils'
+
+export async function POST(request: Request) {
+ // Perform mutation
+ await updateDatabase(request)
+
+ // Log after response is sent
+ after(async () => {
+ const userAgent = (await headers()).get('user-agent') || 'unknown'
+ const sessionCookie = (await cookies()).get('session-id')?.value || 'anonymous'
+
+ logUserAction({ sessionCookie, userAgent })
+ })
+
+ return new Response(JSON.stringify({ status: 'success' }), {
+ status: 200,
+ headers: { 'Content-Type': 'application/json' }
+ })
+}
+```
+
+The response is sent immediately while logging happens in the background.
+
+**Common use cases:**
+
+- Analytics tracking
+
+- Audit logging
+
+- Sending notifications
+
+- Cache invalidation
+
+- Cleanup tasks
+
+**Important notes:**
+
+- `after()` runs even if the response fails or redirects
+
+- Works in Server Actions, Route Handlers, and Server Components
+
+Reference: [https://nextjs.org/docs/app/api-reference/functions/after](https://nextjs.org/docs/app/api-reference/functions/after)
+
+---
+
+## 4. Client-Side Data Fetching
+
+**Impact: MEDIUM-HIGH**
+
+Automatic deduplication and efficient data fetching patterns reduce redundant network requests.
+
+### 4.1 Deduplicate Global Event Listeners
+
+**Impact: LOW (single listener for N components)**
+
+Use `useSWRSubscription()` to share global event listeners across component instances.
+
+**Incorrect: N instances = N listeners**
+
+```tsx
+function useKeyboardShortcut(key: string, callback: () => void) {
+ useEffect(() => {
+ const handler = (e: KeyboardEvent) => {
+ if (e.metaKey && e.key === key) {
+ callback()
+ }
+ }
+ window.addEventListener('keydown', handler)
+ return () => window.removeEventListener('keydown', handler)
+ }, [key, callback])
+}
+```
+
+When using the `useKeyboardShortcut` hook multiple times, each instance will register a new listener.
+
+**Correct: N instances = 1 listener**
+
+```tsx
+import useSWRSubscription from 'swr/subscription'
+
+// Module-level Map to track callbacks per key
+const keyCallbacks = new Map void>>()
+
+function useKeyboardShortcut(key: string, callback: () => void) {
+ // Register this callback in the Map
+ useEffect(() => {
+ if (!keyCallbacks.has(key)) {
+ keyCallbacks.set(key, new Set())
+ }
+ keyCallbacks.get(key)!.add(callback)
+
+ return () => {
+ const set = keyCallbacks.get(key)
+ if (set) {
+ set.delete(callback)
+ if (set.size === 0) {
+ keyCallbacks.delete(key)
+ }
+ }
+ }
+ }, [key, callback])
+
+ useSWRSubscription('global-keydown', () => {
+ const handler = (e: KeyboardEvent) => {
+ if (e.metaKey && keyCallbacks.has(e.key)) {
+ keyCallbacks.get(e.key)!.forEach(cb => cb())
+ }
+ }
+ window.addEventListener('keydown', handler)
+ return () => window.removeEventListener('keydown', handler)
+ })
+}
+
+function Profile() {
+ // Multiple shortcuts will share the same listener
+ useKeyboardShortcut('p', () => { /* ... */ })
+ useKeyboardShortcut('k', () => { /* ... */ })
+ // ...
+}
+```
+
+### 4.2 Use Passive Event Listeners for Scrolling Performance
+
+**Impact: MEDIUM (eliminates scroll delay caused by event listeners)**
+
+Add `{ passive: true }` to touch and wheel event listeners to enable immediate scrolling. Browsers normally wait for listeners to finish to check if `preventDefault()` is called, causing scroll delay.
+
+**Incorrect:**
+
+```typescript
+useEffect(() => {
+ const handleTouch = (e: TouchEvent) => console.log(e.touches[0].clientX)
+ const handleWheel = (e: WheelEvent) => console.log(e.deltaY)
+
+ document.addEventListener('touchstart', handleTouch)
+ document.addEventListener('wheel', handleWheel)
+
+ return () => {
+ document.removeEventListener('touchstart', handleTouch)
+ document.removeEventListener('wheel', handleWheel)
+ }
+}, [])
+```
+
+**Correct:**
+
+```typescript
+useEffect(() => {
+ const handleTouch = (e: TouchEvent) => console.log(e.touches[0].clientX)
+ const handleWheel = (e: WheelEvent) => console.log(e.deltaY)
+
+ document.addEventListener('touchstart', handleTouch, { passive: true })
+ document.addEventListener('wheel', handleWheel, { passive: true })
+
+ return () => {
+ document.removeEventListener('touchstart', handleTouch)
+ document.removeEventListener('wheel', handleWheel)
+ }
+}, [])
+```
+
+**Use passive when:** tracking/analytics, logging, any listener that doesn't call `preventDefault()`.
+
+**Don't use passive when:** implementing custom swipe gestures, custom zoom controls, or any listener that needs `preventDefault()`.
+
+### 4.3 Use SWR for Automatic Deduplication
+
+**Impact: MEDIUM-HIGH (automatic deduplication)**
+
+SWR enables request deduplication, caching, and revalidation across component instances.
+
+**Incorrect: no deduplication, each instance fetches**
+
+```tsx
+function UserList() {
+ const [users, setUsers] = useState([])
+ useEffect(() => {
+ fetch('/api/users')
+ .then(r => r.json())
+ .then(setUsers)
+ }, [])
+}
+```
+
+**Correct: multiple instances share one request**
+
+```tsx
+import useSWR from 'swr'
+
+function UserList() {
+ const { data: users } = useSWR('/api/users', fetcher)
+}
+```
+
+**For immutable data:**
+
+```tsx
+import { useImmutableSWR } from '@/lib/swr'
+
+function StaticContent() {
+ const { data } = useImmutableSWR('/api/config', fetcher)
+}
+```
+
+**For mutations:**
+
+```tsx
+import { useSWRMutation } from 'swr/mutation'
+
+function UpdateButton() {
+ const { trigger } = useSWRMutation('/api/user', updateUser)
+ return trigger()}>Update
+}
+```
+
+Reference: [https://swr.vercel.app](https://swr.vercel.app)
+
+### 4.4 Version and Minimize localStorage Data
+
+**Impact: MEDIUM (prevents schema conflicts, reduces storage size)**
+
+Add version prefix to keys and store only needed fields. Prevents schema conflicts and accidental storage of sensitive data.
+
+**Incorrect:**
+
+```typescript
+// No version, stores everything, no error handling
+localStorage.setItem('userConfig', JSON.stringify(fullUserObject))
+const data = localStorage.getItem('userConfig')
+```
+
+**Correct:**
+
+```typescript
+const VERSION = 'v2'
+
+function saveConfig(config: { theme: string; language: string }) {
+ try {
+ localStorage.setItem(`userConfig:${VERSION}`, JSON.stringify(config))
+ } catch {
+ // Throws in incognito/private browsing, quota exceeded, or disabled
+ }
+}
+
+function loadConfig() {
+ try {
+ const data = localStorage.getItem(`userConfig:${VERSION}`)
+ return data ? JSON.parse(data) : null
+ } catch {
+ return null
+ }
+}
+
+// Migration from v1 to v2
+function migrate() {
+ try {
+ const v1 = localStorage.getItem('userConfig:v1')
+ if (v1) {
+ const old = JSON.parse(v1)
+ saveConfig({ theme: old.darkMode ? 'dark' : 'light', language: old.lang })
+ localStorage.removeItem('userConfig:v1')
+ }
+ } catch {}
+}
+```
+
+**Store minimal fields from server responses:**
+
+```typescript
+// User object has 20+ fields, only store what UI needs
+function cachePrefs(user: FullUser) {
+ try {
+ localStorage.setItem('prefs:v1', JSON.stringify({
+ theme: user.preferences.theme,
+ notifications: user.preferences.notifications
+ }))
+ } catch {}
+}
+```
+
+**Always wrap in try-catch:** `getItem()` and `setItem()` throw in incognito/private browsing (Safari, Firefox), when quota exceeded, or when disabled.
+
+**Benefits:** Schema evolution via versioning, reduced storage size, prevents storing tokens/PII/internal flags.
+
+---
+
+## 5. Re-render Optimization
+
+**Impact: MEDIUM**
+
+Reducing unnecessary re-renders minimizes wasted computation and improves UI responsiveness.
+
+### 5.1 Defer State Reads to Usage Point
+
+**Impact: MEDIUM (avoids unnecessary subscriptions)**
+
+Don't subscribe to dynamic state (searchParams, localStorage) if you only read it inside callbacks.
+
+**Incorrect: subscribes to all searchParams changes**
+
+```tsx
+function ShareButton({ chatId }: { chatId: string }) {
+ const searchParams = useSearchParams()
+
+ const handleShare = () => {
+ const ref = searchParams.get('ref')
+ shareChat(chatId, { ref })
+ }
+
+ return Share
+}
+```
+
+**Correct: reads on demand, no subscription**
+
+```tsx
+function ShareButton({ chatId }: { chatId: string }) {
+ const handleShare = () => {
+ const params = new URLSearchParams(window.location.search)
+ const ref = params.get('ref')
+ shareChat(chatId, { ref })
+ }
+
+ return Share
+}
+```
+
+### 5.2 Extract to Memoized Components
+
+**Impact: MEDIUM (enables early returns)**
+
+Extract expensive work into memoized components to enable early returns before computation.
+
+**Incorrect: computes avatar even when loading**
+
+```tsx
+function Profile({ user, loading }: Props) {
+ const avatar = useMemo(() => {
+ const id = computeAvatarId(user)
+ return
+ }, [user])
+
+ if (loading) return
+ return {avatar}
+}
+```
+
+**Correct: skips computation when loading**
+
+```tsx
+const UserAvatar = memo(function UserAvatar({ user }: { user: User }) {
+ const id = useMemo(() => computeAvatarId(user), [user])
+ return
+})
+
+function Profile({ user, loading }: Props) {
+ if (loading) return
+ return (
+
+
+
+ )
+}
+```
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, manual memoization with `memo()` and `useMemo()` is not necessary. The compiler automatically optimizes re-renders.
+
+### 5.3 Narrow Effect Dependencies
+
+**Impact: LOW (minimizes effect re-runs)**
+
+Specify primitive dependencies instead of objects to minimize effect re-runs.
+
+**Incorrect: re-runs on any user field change**
+
+```tsx
+useEffect(() => {
+ console.log(user.id)
+}, [user])
+```
+
+**Correct: re-runs only when id changes**
+
+```tsx
+useEffect(() => {
+ console.log(user.id)
+}, [user.id])
+```
+
+**For derived state, compute outside effect:**
+
+```tsx
+// Incorrect: runs on width=767, 766, 765...
+useEffect(() => {
+ if (width < 768) {
+ enableMobileMode()
+ }
+}, [width])
+
+// Correct: runs only on boolean transition
+const isMobile = width < 768
+useEffect(() => {
+ if (isMobile) {
+ enableMobileMode()
+ }
+}, [isMobile])
+```
+
+### 5.4 Subscribe to Derived State
+
+**Impact: MEDIUM (reduces re-render frequency)**
+
+Subscribe to derived boolean state instead of continuous values to reduce re-render frequency.
+
+**Incorrect: re-renders on every pixel change**
+
+```tsx
+function Sidebar() {
+ const width = useWindowWidth() // updates continuously
+ const isMobile = width < 768
+ return
+}
+```
+
+**Correct: re-renders only when boolean changes**
+
+```tsx
+function Sidebar() {
+ const isMobile = useMediaQuery('(max-width: 767px)')
+ return
+}
+```
+
+### 5.5 Use Functional setState Updates
+
+**Impact: MEDIUM (prevents stale closures and unnecessary callback recreations)**
+
+When updating state based on the current state value, use the functional update form of setState instead of directly referencing the state variable. This prevents stale closures, eliminates unnecessary dependencies, and creates stable callback references.
+
+**Incorrect: requires state as dependency**
+
+```tsx
+function TodoList() {
+ const [items, setItems] = useState(initialItems)
+
+ // Callback must depend on items, recreated on every items change
+ const addItems = useCallback((newItems: Item[]) => {
+ setItems([...items, ...newItems])
+ }, [items]) // ❌ items dependency causes recreations
+
+ // Risk of stale closure if dependency is forgotten
+ const removeItem = useCallback((id: string) => {
+ setItems(items.filter(item => item.id !== id))
+ }, []) // ❌ Missing items dependency - will use stale items!
+
+ return
+}
+```
+
+The first callback is recreated every time `items` changes, which can cause child components to re-render unnecessarily. The second callback has a stale closure bug—it will always reference the initial `items` value.
+
+**Correct: stable callbacks, no stale closures**
+
+```tsx
+function TodoList() {
+ const [items, setItems] = useState(initialItems)
+
+ // Stable callback, never recreated
+ const addItems = useCallback((newItems: Item[]) => {
+ setItems(curr => [...curr, ...newItems])
+ }, []) // ✅ No dependencies needed
+
+ // Always uses latest state, no stale closure risk
+ const removeItem = useCallback((id: string) => {
+ setItems(curr => curr.filter(item => item.id !== id))
+ }, []) // ✅ Safe and stable
+
+ return
+}
+```
+
+**Benefits:**
+
+1. **Stable callback references** - Callbacks don't need to be recreated when state changes
+
+2. **No stale closures** - Always operates on the latest state value
+
+3. **Fewer dependencies** - Simplifies dependency arrays and reduces memory leaks
+
+4. **Prevents bugs** - Eliminates the most common source of React closure bugs
+
+**When to use functional updates:**
+
+- Any setState that depends on the current state value
+
+- Inside useCallback/useMemo when state is needed
+
+- Event handlers that reference state
+
+- Async operations that update state
+
+**When direct updates are fine:**
+
+- Setting state to a static value: `setCount(0)`
+
+- Setting state from props/arguments only: `setName(newName)`
+
+- State doesn't depend on previous value
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler can automatically optimize some cases, but functional updates are still recommended for correctness and to prevent stale closure bugs.
+
+### 5.6 Use Lazy State Initialization
+
+**Impact: MEDIUM (wasted computation on every render)**
+
+Pass a function to `useState` for expensive initial values. Without the function form, the initializer runs on every render even though the value is only used once.
+
+**Incorrect: runs on every render**
+
+```tsx
+function FilteredList({ items }: { items: Item[] }) {
+ // buildSearchIndex() runs on EVERY render, even after initialization
+ const [searchIndex, setSearchIndex] = useState(buildSearchIndex(items))
+ const [query, setQuery] = useState('')
+
+ // When query changes, buildSearchIndex runs again unnecessarily
+ return
+}
+
+function UserProfile() {
+ // JSON.parse runs on every render
+ const [settings, setSettings] = useState(
+ JSON.parse(localStorage.getItem('settings') || '{}')
+ )
+
+ return
+}
+```
+
+**Correct: runs only once**
+
+```tsx
+function FilteredList({ items }: { items: Item[] }) {
+ // buildSearchIndex() runs ONLY on initial render
+ const [searchIndex, setSearchIndex] = useState(() => buildSearchIndex(items))
+ const [query, setQuery] = useState('')
+
+ return
+}
+
+function UserProfile() {
+ // JSON.parse runs only on initial render
+ const [settings, setSettings] = useState(() => {
+ const stored = localStorage.getItem('settings')
+ return stored ? JSON.parse(stored) : {}
+ })
+
+ return
+}
+```
+
+Use lazy initialization when computing initial values from localStorage/sessionStorage, building data structures (indexes, maps), reading from the DOM, or performing heavy transformations.
+
+For simple primitives (`useState(0)`), direct references (`useState(props.value)`), or cheap literals (`useState({})`), the function form is unnecessary.
+
+### 5.7 Use Transitions for Non-Urgent Updates
+
+**Impact: MEDIUM (maintains UI responsiveness)**
+
+Mark frequent, non-urgent state updates as transitions to maintain UI responsiveness.
+
+**Incorrect: blocks UI on every scroll**
+
+```tsx
+function ScrollTracker() {
+ const [scrollY, setScrollY] = useState(0)
+ useEffect(() => {
+ const handler = () => setScrollY(window.scrollY)
+ window.addEventListener('scroll', handler, { passive: true })
+ return () => window.removeEventListener('scroll', handler)
+ }, [])
+}
+```
+
+**Correct: non-blocking updates**
+
+```tsx
+import { startTransition } from 'react'
+
+function ScrollTracker() {
+ const [scrollY, setScrollY] = useState(0)
+ useEffect(() => {
+ const handler = () => {
+ startTransition(() => setScrollY(window.scrollY))
+ }
+ window.addEventListener('scroll', handler, { passive: true })
+ return () => window.removeEventListener('scroll', handler)
+ }, [])
+}
+```
+
+---
+
+## 6. Rendering Performance
+
+**Impact: MEDIUM**
+
+Optimizing the rendering process reduces the work the browser needs to do.
+
+### 6.1 Animate SVG Wrapper Instead of SVG Element
+
+**Impact: LOW (enables hardware acceleration)**
+
+Many browsers don't have hardware acceleration for CSS3 animations on SVG elements. Wrap SVG in a `` and animate the wrapper instead.
+
+**Incorrect: animating SVG directly - no hardware acceleration**
+
+```tsx
+function LoadingSpinner() {
+ return (
+
+
+
+ )
+}
+```
+
+**Correct: animating wrapper div - hardware accelerated**
+
+```tsx
+function LoadingSpinner() {
+ return (
+
+
+
+
+
+ )
+}
+```
+
+This applies to all CSS transforms and transitions (`transform`, `opacity`, `translate`, `scale`, `rotate`). The wrapper div allows browsers to use GPU acceleration for smoother animations.
+
+### 6.2 CSS content-visibility for Long Lists
+
+**Impact: HIGH (faster initial render)**
+
+Apply `content-visibility: auto` to defer off-screen rendering.
+
+**CSS:**
+
+```css
+.message-item {
+ content-visibility: auto;
+ contain-intrinsic-size: 0 80px;
+}
+```
+
+**Example:**
+
+```tsx
+function MessageList({ messages }: { messages: Message[] }) {
+ return (
+
+ {messages.map(msg => (
+
+ ))}
+
+ )
+}
+```
+
+For 1000 messages, browser skips layout/paint for ~990 off-screen items (10× faster initial render).
+
+### 6.3 Hoist Static JSX Elements
+
+**Impact: LOW (avoids re-creation)**
+
+Extract static JSX outside components to avoid re-creation.
+
+**Incorrect: recreates element every render**
+
+```tsx
+function LoadingSkeleton() {
+ return
+}
+
+function Container() {
+ return (
+
+ {loading && }
+
+ )
+}
+```
+
+**Correct: reuses same element**
+
+```tsx
+const loadingSkeleton = (
+
+)
+
+function Container() {
+ return (
+
+ {loading && loadingSkeleton}
+
+ )
+}
+```
+
+This is especially helpful for large and static SVG nodes, which can be expensive to recreate on every render.
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler automatically hoists static JSX elements and optimizes component re-renders, making manual hoisting unnecessary.
+
+### 6.4 Optimize SVG Precision
+
+**Impact: LOW (reduces file size)**
+
+Reduce SVG coordinate precision to decrease file size. The optimal precision depends on the viewBox size, but in general reducing precision should be considered.
+
+**Incorrect: excessive precision**
+
+```svg
+
+```
+
+**Correct: 1 decimal place**
+
+```svg
+
+```
+
+**Automate with SVGO:**
+
+```bash
+npx svgo --precision=1 --multipass icon.svg
+```
+
+### 6.5 Prevent Hydration Mismatch Without Flickering
+
+**Impact: MEDIUM (avoids visual flicker and hydration errors)**
+
+When rendering content that depends on client-side storage (localStorage, cookies), avoid both SSR breakage and post-hydration flickering by injecting a synchronous script that updates the DOM before React hydrates.
+
+**Incorrect: breaks SSR**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ // localStorage is not available on server - throws error
+ const theme = localStorage.getItem('theme') || 'light'
+
+ return (
+
+ {children}
+
+ )
+}
+```
+
+Server-side rendering will fail because `localStorage` is undefined.
+
+**Incorrect: visual flickering**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ const [theme, setTheme] = useState('light')
+
+ useEffect(() => {
+ // Runs after hydration - causes visible flash
+ const stored = localStorage.getItem('theme')
+ if (stored) {
+ setTheme(stored)
+ }
+ }, [])
+
+ return (
+
+ {children}
+
+ )
+}
+```
+
+Component first renders with default value (`light`), then updates after hydration, causing a visible flash of incorrect content.
+
+**Correct: no flicker, no hydration mismatch**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ return (
+ <>
+
+ {children}
+
+
+ >
+ )
+}
+```
+
+The inline script executes synchronously before showing the element, ensuring the DOM already has the correct value. No flickering, no hydration mismatch.
+
+This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
+
+### 6.6 Use Activity Component for Show/Hide
+
+**Impact: MEDIUM (preserves state/DOM)**
+
+Use React's `
` to preserve state/DOM for expensive components that frequently toggle visibility.
+
+**Usage:**
+
+```tsx
+import { Activity } from 'react'
+
+function Dropdown({ isOpen }: Props) {
+ return (
+
+
+
+ )
+}
+```
+
+Avoids expensive re-renders and state loss.
+
+### 6.7 Use Explicit Conditional Rendering
+
+**Impact: LOW (prevents rendering 0 or NaN)**
+
+Use explicit ternary operators (`? :`) instead of `&&` for conditional rendering when the condition can be `0`, `NaN`, or other falsy values that render.
+
+**Incorrect: renders "0" when count is 0**
+
+```tsx
+function Badge({ count }: { count: number }) {
+ return (
+
+ {count && {count} }
+
+ )
+}
+
+// When count = 0, renders: 0
+// When count = 5, renders: 5
+```
+
+**Correct: renders nothing when count is 0**
+
+```tsx
+function Badge({ count }: { count: number }) {
+ return (
+
+ {count > 0 ? {count} : null}
+
+ )
+}
+
+// When count = 0, renders:
+// When count = 5, renders: 5
+```
+
+---
+
+## 7. JavaScript Performance
+
+**Impact: LOW-MEDIUM**
+
+Micro-optimizations for hot paths can add up to meaningful improvements.
+
+### 7.1 Batch DOM CSS Changes
+
+**Impact: MEDIUM (reduces reflows/repaints)**
+
+Avoid changing styles one property at a time. Group multiple CSS changes together via classes or `cssText` to minimize browser reflows.
+
+**Incorrect: multiple reflows**
+
+```typescript
+function updateElementStyles(element: HTMLElement) {
+ // Each line triggers a reflow
+ element.style.width = '100px'
+ element.style.height = '200px'
+ element.style.backgroundColor = 'blue'
+ element.style.border = '1px solid black'
+}
+```
+
+**Correct: add class - single reflow**
+
+```typescript
+// CSS file
+.highlighted-box {
+ width: 100px;
+ height: 200px;
+ background-color: blue;
+ border: 1px solid black;
+}
+
+// JavaScript
+function updateElementStyles(element: HTMLElement) {
+ element.classList.add('highlighted-box')
+}
+```
+
+**Correct: change cssText - single reflow**
+
+```typescript
+function updateElementStyles(element: HTMLElement) {
+ element.style.cssText = `
+ width: 100px;
+ height: 200px;
+ background-color: blue;
+ border: 1px solid black;
+ `
+}
+```
+
+**React example:**
+
+```tsx
+// Incorrect: changing styles one by one
+function Box({ isHighlighted }: { isHighlighted: boolean }) {
+ const ref = useRef(null)
+
+ useEffect(() => {
+ if (ref.current && isHighlighted) {
+ ref.current.style.width = '100px'
+ ref.current.style.height = '200px'
+ ref.current.style.backgroundColor = 'blue'
+ }
+ }, [isHighlighted])
+
+ return Content
+}
+
+// Correct: toggle class
+function Box({ isHighlighted }: { isHighlighted: boolean }) {
+ return (
+
+ Content
+
+ )
+}
+```
+
+Prefer CSS classes over inline styles when possible. Classes are cached by the browser and provide better separation of concerns.
+
+### 7.2 Build Index Maps for Repeated Lookups
+
+**Impact: LOW-MEDIUM (1M ops to 2K ops)**
+
+Multiple `.find()` calls by the same key should use a Map.
+
+**Incorrect (O(n) per lookup):**
+
+```typescript
+function processOrders(orders: Order[], users: User[]) {
+ return orders.map(order => ({
+ ...order,
+ user: users.find(u => u.id === order.userId)
+ }))
+}
+```
+
+**Correct (O(1) per lookup):**
+
+```typescript
+function processOrders(orders: Order[], users: User[]) {
+ const userById = new Map(users.map(u => [u.id, u]))
+
+ return orders.map(order => ({
+ ...order,
+ user: userById.get(order.userId)
+ }))
+}
+```
+
+Build map once (O(n)), then all lookups are O(1).
+
+For 1000 orders × 1000 users: 1M ops → 2K ops.
+
+### 7.3 Cache Property Access in Loops
+
+**Impact: LOW-MEDIUM (reduces lookups)**
+
+Cache object property lookups in hot paths.
+
+**Incorrect: 3 lookups × N iterations**
+
+```typescript
+for (let i = 0; i < arr.length; i++) {
+ process(obj.config.settings.value)
+}
+```
+
+**Correct: 1 lookup total**
+
+```typescript
+const value = obj.config.settings.value
+const len = arr.length
+for (let i = 0; i < len; i++) {
+ process(value)
+}
+```
+
+### 7.4 Cache Repeated Function Calls
+
+**Impact: MEDIUM (avoid redundant computation)**
+
+Use a module-level Map to cache function results when the same function is called repeatedly with the same inputs during render.
+
+**Incorrect: redundant computation**
+
+```typescript
+function ProjectList({ projects }: { projects: Project[] }) {
+ return (
+
+ {projects.map(project => {
+ // slugify() called 100+ times for same project names
+ const slug = slugify(project.name)
+
+ return
+ })}
+
+ )
+}
+```
+
+**Correct: cached results**
+
+```typescript
+// Module-level cache
+const slugifyCache = new Map()
+
+function cachedSlugify(text: string): string {
+ if (slugifyCache.has(text)) {
+ return slugifyCache.get(text)!
+ }
+ const result = slugify(text)
+ slugifyCache.set(text, result)
+ return result
+}
+
+function ProjectList({ projects }: { projects: Project[] }) {
+ return (
+
+ {projects.map(project => {
+ // Computed only once per unique project name
+ const slug = cachedSlugify(project.name)
+
+ return
+ })}
+
+ )
+}
+```
+
+**Simpler pattern for single-value functions:**
+
+```typescript
+let isLoggedInCache: boolean | null = null
+
+function isLoggedIn(): boolean {
+ if (isLoggedInCache !== null) {
+ return isLoggedInCache
+ }
+
+ isLoggedInCache = document.cookie.includes('auth=')
+ return isLoggedInCache
+}
+
+// Clear cache when auth changes
+function onAuthChange() {
+ isLoggedInCache = null
+}
+```
+
+Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
+
+Reference: [https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast](https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast)
+
+### 7.5 Cache Storage API Calls
+
+**Impact: LOW-MEDIUM (reduces expensive I/O)**
+
+`localStorage`, `sessionStorage`, and `document.cookie` are synchronous and expensive. Cache reads in memory.
+
+**Incorrect: reads storage on every call**
+
+```typescript
+function getTheme() {
+ return localStorage.getItem('theme') ?? 'light'
+}
+// Called 10 times = 10 storage reads
+```
+
+**Correct: Map cache**
+
+```typescript
+const storageCache = new Map()
+
+function getLocalStorage(key: string) {
+ if (!storageCache.has(key)) {
+ storageCache.set(key, localStorage.getItem(key))
+ }
+ return storageCache.get(key)
+}
+
+function setLocalStorage(key: string, value: string) {
+ localStorage.setItem(key, value)
+ storageCache.set(key, value) // keep cache in sync
+}
+```
+
+Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
+
+**Cookie caching:**
+
+```typescript
+let cookieCache: Record | null = null
+
+function getCookie(name: string) {
+ if (!cookieCache) {
+ cookieCache = Object.fromEntries(
+ document.cookie.split('; ').map(c => c.split('='))
+ )
+ }
+ return cookieCache[name]
+}
+```
+
+**Important: invalidate on external changes**
+
+```typescript
+window.addEventListener('storage', (e) => {
+ if (e.key) storageCache.delete(e.key)
+})
+
+document.addEventListener('visibilitychange', () => {
+ if (document.visibilityState === 'visible') {
+ storageCache.clear()
+ }
+})
+```
+
+If storage can change externally (another tab, server-set cookies), invalidate cache:
+
+### 7.6 Combine Multiple Array Iterations
+
+**Impact: LOW-MEDIUM (reduces iterations)**
+
+Multiple `.filter()` or `.map()` calls iterate the array multiple times. Combine into one loop.
+
+**Incorrect: 3 iterations**
+
+```typescript
+const admins = users.filter(u => u.isAdmin)
+const testers = users.filter(u => u.isTester)
+const inactive = users.filter(u => !u.isActive)
+```
+
+**Correct: 1 iteration**
+
+```typescript
+const admins: User[] = []
+const testers: User[] = []
+const inactive: User[] = []
+
+for (const user of users) {
+ if (user.isAdmin) admins.push(user)
+ if (user.isTester) testers.push(user)
+ if (!user.isActive) inactive.push(user)
+}
+```
+
+### 7.7 Early Length Check for Array Comparisons
+
+**Impact: MEDIUM-HIGH (avoids expensive operations when lengths differ)**
+
+When comparing arrays with expensive operations (sorting, deep equality, serialization), check lengths first. If lengths differ, the arrays cannot be equal.
+
+In real-world applications, this optimization is especially valuable when the comparison runs in hot paths (event handlers, render loops).
+
+**Incorrect: always runs expensive comparison**
+
+```typescript
+function hasChanges(current: string[], original: string[]) {
+ // Always sorts and joins, even when lengths differ
+ return current.sort().join() !== original.sort().join()
+}
+```
+
+Two O(n log n) sorts run even when `current.length` is 5 and `original.length` is 100. There is also overhead of joining the arrays and comparing the strings.
+
+**Correct (O(1) length check first):**
+
+```typescript
+function hasChanges(current: string[], original: string[]) {
+ // Early return if lengths differ
+ if (current.length !== original.length) {
+ return true
+ }
+ // Only sort/join when lengths match
+ const currentSorted = current.toSorted()
+ const originalSorted = original.toSorted()
+ for (let i = 0; i < currentSorted.length; i++) {
+ if (currentSorted[i] !== originalSorted[i]) {
+ return true
+ }
+ }
+ return false
+}
+```
+
+This new approach is more efficient because:
+
+- It avoids the overhead of sorting and joining the arrays when lengths differ
+
+- It avoids consuming memory for the joined strings (especially important for large arrays)
+
+- It avoids mutating the original arrays
+
+- It returns early when a difference is found
+
+### 7.8 Early Return from Functions
+
+**Impact: LOW-MEDIUM (avoids unnecessary computation)**
+
+Return early when result is determined to skip unnecessary processing.
+
+**Incorrect: processes all items even after finding answer**
+
+```typescript
+function validateUsers(users: User[]) {
+ let hasError = false
+ let errorMessage = ''
+
+ for (const user of users) {
+ if (!user.email) {
+ hasError = true
+ errorMessage = 'Email required'
+ }
+ if (!user.name) {
+ hasError = true
+ errorMessage = 'Name required'
+ }
+ // Continues checking all users even after error found
+ }
+
+ return hasError ? { valid: false, error: errorMessage } : { valid: true }
+}
+```
+
+**Correct: returns immediately on first error**
+
+```typescript
+function validateUsers(users: User[]) {
+ for (const user of users) {
+ if (!user.email) {
+ return { valid: false, error: 'Email required' }
+ }
+ if (!user.name) {
+ return { valid: false, error: 'Name required' }
+ }
+ }
+
+ return { valid: true }
+}
+```
+
+### 7.9 Hoist RegExp Creation
+
+**Impact: LOW-MEDIUM (avoids recreation)**
+
+Don't create RegExp inside render. Hoist to module scope or memoize with `useMemo()`.
+
+**Incorrect: new RegExp every render**
+
+```tsx
+function Highlighter({ text, query }: Props) {
+ const regex = new RegExp(`(${query})`, 'gi')
+ const parts = text.split(regex)
+ return <>{parts.map((part, i) => ...)}>
+}
+```
+
+**Correct: memoize or hoist**
+
+```tsx
+const EMAIL_REGEX = /^[^\s@]+@[^\s@]+\.[^\s@]+$/
+
+function Highlighter({ text, query }: Props) {
+ const regex = useMemo(
+ () => new RegExp(`(${escapeRegex(query)})`, 'gi'),
+ [query]
+ )
+ const parts = text.split(regex)
+ return <>{parts.map((part, i) => ...)}>
+}
+```
+
+**Warning: global regex has mutable state**
+
+```typescript
+const regex = /foo/g
+regex.test('foo') // true, lastIndex = 3
+regex.test('foo') // false, lastIndex = 0
+```
+
+Global regex (`/g`) has mutable `lastIndex` state:
+
+### 7.10 Use Loop for Min/Max Instead of Sort
+
+**Impact: LOW (O(n) instead of O(n log n))**
+
+Finding the smallest or largest element only requires a single pass through the array. Sorting is wasteful and slower.
+
+**Incorrect (O(n log n) - sort to find latest):**
+
+```typescript
+interface Project {
+ id: string
+ name: string
+ updatedAt: number
+}
+
+function getLatestProject(projects: Project[]) {
+ const sorted = [...projects].sort((a, b) => b.updatedAt - a.updatedAt)
+ return sorted[0]
+}
+```
+
+Sorts the entire array just to find the maximum value.
+
+**Incorrect (O(n log n) - sort for oldest and newest):**
+
+```typescript
+function getOldestAndNewest(projects: Project[]) {
+ const sorted = [...projects].sort((a, b) => a.updatedAt - b.updatedAt)
+ return { oldest: sorted[0], newest: sorted[sorted.length - 1] }
+}
+```
+
+Still sorts unnecessarily when only min/max are needed.
+
+**Correct (O(n) - single loop):**
+
+```typescript
+function getLatestProject(projects: Project[]) {
+ if (projects.length === 0) return null
+
+ let latest = projects[0]
+
+ for (let i = 1; i < projects.length; i++) {
+ if (projects[i].updatedAt > latest.updatedAt) {
+ latest = projects[i]
+ }
+ }
+
+ return latest
+}
+
+function getOldestAndNewest(projects: Project[]) {
+ if (projects.length === 0) return { oldest: null, newest: null }
+
+ let oldest = projects[0]
+ let newest = projects[0]
+
+ for (let i = 1; i < projects.length; i++) {
+ if (projects[i].updatedAt < oldest.updatedAt) oldest = projects[i]
+ if (projects[i].updatedAt > newest.updatedAt) newest = projects[i]
+ }
+
+ return { oldest, newest }
+}
+```
+
+Single pass through the array, no copying, no sorting.
+
+**Alternative: Math.min/Math.max for small arrays**
+
+```typescript
+const numbers = [5, 2, 8, 1, 9]
+const min = Math.min(...numbers)
+const max = Math.max(...numbers)
+```
+
+This works for small arrays but can be slower for very large arrays due to spread operator limitations. Use the loop approach for reliability.
+
+### 7.11 Use Set/Map for O(1) Lookups
+
+**Impact: LOW-MEDIUM (O(n) to O(1))**
+
+Convert arrays to Set/Map for repeated membership checks.
+
+**Incorrect (O(n) per check):**
+
+```typescript
+const allowedIds = ['a', 'b', 'c', ...]
+items.filter(item => allowedIds.includes(item.id))
+```
+
+**Correct (O(1) per check):**
+
+```typescript
+const allowedIds = new Set(['a', 'b', 'c', ...])
+items.filter(item => allowedIds.has(item.id))
+```
+
+### 7.12 Use toSorted() Instead of sort() for Immutability
+
+**Impact: MEDIUM-HIGH (prevents mutation bugs in React state)**
+
+`.sort()` mutates the array in place, which can cause bugs with React state and props. Use `.toSorted()` to create a new sorted array without mutation.
+
+**Incorrect: mutates original array**
+
+```typescript
+function UserList({ users }: { users: User[] }) {
+ // Mutates the users prop array!
+ const sorted = useMemo(
+ () => users.sort((a, b) => a.name.localeCompare(b.name)),
+ [users]
+ )
+ return {sorted.map(renderUser)}
+}
+```
+
+**Correct: creates new array**
+
+```typescript
+function UserList({ users }: { users: User[] }) {
+ // Creates new sorted array, original unchanged
+ const sorted = useMemo(
+ () => users.toSorted((a, b) => a.name.localeCompare(b.name)),
+ [users]
+ )
+ return {sorted.map(renderUser)}
+}
+```
+
+**Why this matters in React:**
+
+1. Props/state mutations break React's immutability model - React expects props and state to be treated as read-only
+
+2. Causes stale closure bugs - Mutating arrays inside closures (callbacks, effects) can lead to unexpected behavior
+
+**Browser support: fallback for older browsers**
+
+```typescript
+// Fallback for older browsers
+const sorted = [...items].sort((a, b) => a.value - b.value)
+```
+
+`.toSorted()` is available in all modern browsers (Chrome 110+, Safari 16+, Firefox 115+, Node.js 20+). For older environments, use spread operator:
+
+**Other immutable array methods:**
+
+- `.toSorted()` - immutable sort
+
+- `.toReversed()` - immutable reverse
+
+- `.toSpliced()` - immutable splice
+
+- `.with()` - immutable element replacement
+
+---
+
+## 8. Advanced Patterns
+
+**Impact: LOW**
+
+Advanced patterns for specific cases that require careful implementation.
+
+### 8.1 Store Event Handlers in Refs
+
+**Impact: LOW (stable subscriptions)**
+
+Store callbacks in refs when used in effects that shouldn't re-subscribe on callback changes.
+
+**Incorrect: re-subscribes on every render**
+
+```tsx
+function useWindowEvent(event: string, handler: () => void) {
+ useEffect(() => {
+ window.addEventListener(event, handler)
+ return () => window.removeEventListener(event, handler)
+ }, [event, handler])
+}
+```
+
+**Correct: stable subscription**
+
+```tsx
+import { useEffectEvent } from 'react'
+
+function useWindowEvent(event: string, handler: () => void) {
+ const onEvent = useEffectEvent(handler)
+
+ useEffect(() => {
+ window.addEventListener(event, onEvent)
+ return () => window.removeEventListener(event, onEvent)
+ }, [event])
+}
+```
+
+**Alternative: use `useEffectEvent` if you're on latest React:**
+
+`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
+
+### 8.2 useLatest for Stable Callback Refs
+
+**Impact: LOW (prevents effect re-runs)**
+
+Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
+
+**Implementation:**
+
+```typescript
+function useLatest(value: T) {
+ const ref = useRef(value)
+ useEffect(() => {
+ ref.current = value
+ }, [value])
+ return ref
+}
+```
+
+**Incorrect: effect re-runs on every callback change**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearch(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query, onSearch])
+}
+```
+
+**Correct: stable effect, fresh callback**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+ const onSearchRef = useLatest(onSearch)
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearchRef.current(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query])
+}
+```
+
+---
+
+## References
+
+1. [https://react.dev](https://react.dev)
+2. [https://nextjs.org](https://nextjs.org)
+3. [https://swr.vercel.app](https://swr.vercel.app)
+4. [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
+5. [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
+6. [https://vercel.com/blog/how-we-optimized-package-imports-in-next-js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
+7. [https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast](https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast)
diff --git a/.agents/skills/vercel-react-best-practices/SKILL.md b/.agents/skills/vercel-react-best-practices/SKILL.md
new file mode 100644
index 000000000..b064716f6
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/SKILL.md
@@ -0,0 +1,125 @@
+---
+name: vercel-react-best-practices
+description: React and Next.js performance optimization guidelines from Vercel Engineering. This skill should be used when writing, reviewing, or refactoring React/Next.js code to ensure optimal performance patterns. Triggers on tasks involving React components, Next.js pages, data fetching, bundle optimization, or performance improvements.
+license: MIT
+metadata:
+ author: vercel
+ version: "1.0.0"
+---
+
+# Vercel React Best Practices
+
+Comprehensive performance optimization guide for React and Next.js applications, maintained by Vercel. Contains 45 rules across 8 categories, prioritized by impact to guide automated refactoring and code generation.
+
+## When to Apply
+
+Reference these guidelines when:
+- Writing new React components or Next.js pages
+- Implementing data fetching (client or server-side)
+- Reviewing code for performance issues
+- Refactoring existing React/Next.js code
+- Optimizing bundle size or load times
+
+## Rule Categories by Priority
+
+| Priority | Category | Impact | Prefix |
+|----------|----------|--------|--------|
+| 1 | Eliminating Waterfalls | CRITICAL | `async-` |
+| 2 | Bundle Size Optimization | CRITICAL | `bundle-` |
+| 3 | Server-Side Performance | HIGH | `server-` |
+| 4 | Client-Side Data Fetching | MEDIUM-HIGH | `client-` |
+| 5 | Re-render Optimization | MEDIUM | `rerender-` |
+| 6 | Rendering Performance | MEDIUM | `rendering-` |
+| 7 | JavaScript Performance | LOW-MEDIUM | `js-` |
+| 8 | Advanced Patterns | LOW | `advanced-` |
+
+## Quick Reference
+
+### 1. Eliminating Waterfalls (CRITICAL)
+
+- `async-defer-await` - Move await into branches where actually used
+- `async-parallel` - Use Promise.all() for independent operations
+- `async-dependencies` - Use better-all for partial dependencies
+- `async-api-routes` - Start promises early, await late in API routes
+- `async-suspense-boundaries` - Use Suspense to stream content
+
+### 2. Bundle Size Optimization (CRITICAL)
+
+- `bundle-barrel-imports` - Import directly, avoid barrel files
+- `bundle-dynamic-imports` - Use next/dynamic for heavy components
+- `bundle-defer-third-party` - Load analytics/logging after hydration
+- `bundle-conditional` - Load modules only when feature is activated
+- `bundle-preload` - Preload on hover/focus for perceived speed
+
+### 3. Server-Side Performance (HIGH)
+
+- `server-cache-react` - Use React.cache() for per-request deduplication
+- `server-cache-lru` - Use LRU cache for cross-request caching
+- `server-serialization` - Minimize data passed to client components
+- `server-parallel-fetching` - Restructure components to parallelize fetches
+- `server-after-nonblocking` - Use after() for non-blocking operations
+
+### 4. Client-Side Data Fetching (MEDIUM-HIGH)
+
+- `client-swr-dedup` - Use SWR for automatic request deduplication
+- `client-event-listeners` - Deduplicate global event listeners
+
+### 5. Re-render Optimization (MEDIUM)
+
+- `rerender-defer-reads` - Don't subscribe to state only used in callbacks
+- `rerender-memo` - Extract expensive work into memoized components
+- `rerender-dependencies` - Use primitive dependencies in effects
+- `rerender-derived-state` - Subscribe to derived booleans, not raw values
+- `rerender-functional-setstate` - Use functional setState for stable callbacks
+- `rerender-lazy-state-init` - Pass function to useState for expensive values
+- `rerender-transitions` - Use startTransition for non-urgent updates
+
+### 6. Rendering Performance (MEDIUM)
+
+- `rendering-animate-svg-wrapper` - Animate div wrapper, not SVG element
+- `rendering-content-visibility` - Use content-visibility for long lists
+- `rendering-hoist-jsx` - Extract static JSX outside components
+- `rendering-svg-precision` - Reduce SVG coordinate precision
+- `rendering-hydration-no-flicker` - Use inline script for client-only data
+- `rendering-activity` - Use Activity component for show/hide
+- `rendering-conditional-render` - Use ternary, not && for conditionals
+
+### 7. JavaScript Performance (LOW-MEDIUM)
+
+- `js-batch-dom-css` - Group CSS changes via classes or cssText
+- `js-index-maps` - Build Map for repeated lookups
+- `js-cache-property-access` - Cache object properties in loops
+- `js-cache-function-results` - Cache function results in module-level Map
+- `js-cache-storage` - Cache localStorage/sessionStorage reads
+- `js-combine-iterations` - Combine multiple filter/map into one loop
+- `js-length-check-first` - Check array length before expensive comparison
+- `js-early-exit` - Return early from functions
+- `js-hoist-regexp` - Hoist RegExp creation outside loops
+- `js-min-max-loop` - Use loop for min/max instead of sort
+- `js-set-map-lookups` - Use Set/Map for O(1) lookups
+- `js-tosorted-immutable` - Use toSorted() for immutability
+
+### 8. Advanced Patterns (LOW)
+
+- `advanced-event-handler-refs` - Store event handlers in refs
+- `advanced-use-latest` - useLatest for stable callback refs
+
+## How to Use
+
+Read individual rule files for detailed explanations and code examples:
+
+```
+rules/async-parallel.md
+rules/bundle-barrel-imports.md
+rules/_sections.md
+```
+
+Each rule file contains:
+- Brief explanation of why it matters
+- Incorrect code example with explanation
+- Correct code example with explanation
+- Additional context and references
+
+## Full Compiled Document
+
+For the complete guide with all rules expanded: `AGENTS.md`
diff --git a/.agents/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md b/.agents/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md
new file mode 100644
index 000000000..97e7ade24
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md
@@ -0,0 +1,55 @@
+---
+title: Store Event Handlers in Refs
+impact: LOW
+impactDescription: stable subscriptions
+tags: advanced, hooks, refs, event-handlers, optimization
+---
+
+## Store Event Handlers in Refs
+
+Store callbacks in refs when used in effects that shouldn't re-subscribe on callback changes.
+
+**Incorrect (re-subscribes on every render):**
+
+```tsx
+function useWindowEvent(event: string, handler: (e) => void) {
+ useEffect(() => {
+ window.addEventListener(event, handler)
+ return () => window.removeEventListener(event, handler)
+ }, [event, handler])
+}
+```
+
+**Correct (stable subscription):**
+
+```tsx
+function useWindowEvent(event: string, handler: (e) => void) {
+ const handlerRef = useRef(handler)
+ useEffect(() => {
+ handlerRef.current = handler
+ }, [handler])
+
+ useEffect(() => {
+ const listener = (e) => handlerRef.current(e)
+ window.addEventListener(event, listener)
+ return () => window.removeEventListener(event, listener)
+ }, [event])
+}
+```
+
+**Alternative: use `useEffectEvent` if you're on latest React:**
+
+```tsx
+import { useEffectEvent } from 'react'
+
+function useWindowEvent(event: string, handler: (e) => void) {
+ const onEvent = useEffectEvent(handler)
+
+ useEffect(() => {
+ window.addEventListener(event, onEvent)
+ return () => window.removeEventListener(event, onEvent)
+ }, [event])
+}
+```
+
+`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
diff --git a/.agents/skills/vercel-react-best-practices/rules/advanced-use-latest.md b/.agents/skills/vercel-react-best-practices/rules/advanced-use-latest.md
new file mode 100644
index 000000000..483c2ef7d
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/advanced-use-latest.md
@@ -0,0 +1,49 @@
+---
+title: useLatest for Stable Callback Refs
+impact: LOW
+impactDescription: prevents effect re-runs
+tags: advanced, hooks, useLatest, refs, optimization
+---
+
+## useLatest for Stable Callback Refs
+
+Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
+
+**Implementation:**
+
+```typescript
+function useLatest(value: T) {
+ const ref = useRef(value)
+ useLayoutEffect(() => {
+ ref.current = value
+ }, [value])
+ return ref
+}
+```
+
+**Incorrect (effect re-runs on every callback change):**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearch(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query, onSearch])
+}
+```
+
+**Correct (stable effect, fresh callback):**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+ const onSearchRef = useLatest(onSearch)
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearchRef.current(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query])
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/async-api-routes.md b/.agents/skills/vercel-react-best-practices/rules/async-api-routes.md
new file mode 100644
index 000000000..6feda1ef0
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/async-api-routes.md
@@ -0,0 +1,38 @@
+---
+title: Prevent Waterfall Chains in API Routes
+impact: CRITICAL
+impactDescription: 2-10× improvement
+tags: api-routes, server-actions, waterfalls, parallelization
+---
+
+## Prevent Waterfall Chains in API Routes
+
+In API routes and Server Actions, start independent operations immediately, even if you don't await them yet.
+
+**Incorrect (config waits for auth, data waits for both):**
+
+```typescript
+export async function GET(request: Request) {
+ const session = await auth()
+ const config = await fetchConfig()
+ const data = await fetchData(session.user.id)
+ return Response.json({ data, config })
+}
+```
+
+**Correct (auth and config start immediately):**
+
+```typescript
+export async function GET(request: Request) {
+ const sessionPromise = auth()
+ const configPromise = fetchConfig()
+ const session = await sessionPromise
+ const [config, data] = await Promise.all([
+ configPromise,
+ fetchData(session.user.id)
+ ])
+ return Response.json({ data, config })
+}
+```
+
+For operations with more complex dependency chains, use `better-all` to automatically maximize parallelism (see Dependency-Based Parallelization).
diff --git a/.agents/skills/vercel-react-best-practices/rules/async-defer-await.md b/.agents/skills/vercel-react-best-practices/rules/async-defer-await.md
new file mode 100644
index 000000000..ea7082a36
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/async-defer-await.md
@@ -0,0 +1,80 @@
+---
+title: Defer Await Until Needed
+impact: HIGH
+impactDescription: avoids blocking unused code paths
+tags: async, await, conditional, optimization
+---
+
+## Defer Await Until Needed
+
+Move `await` operations into the branches where they're actually used to avoid blocking code paths that don't need them.
+
+**Incorrect (blocks both branches):**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ const userData = await fetchUserData(userId)
+
+ if (skipProcessing) {
+ // Returns immediately but still waited for userData
+ return { skipped: true }
+ }
+
+ // Only this branch uses userData
+ return processUserData(userData)
+}
+```
+
+**Correct (only blocks when needed):**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ if (skipProcessing) {
+ // Returns immediately without waiting
+ return { skipped: true }
+ }
+
+ // Fetch only when needed
+ const userData = await fetchUserData(userId)
+ return processUserData(userData)
+}
+```
+
+**Another example (early return optimization):**
+
+```typescript
+// Incorrect: always fetches permissions
+async function updateResource(resourceId: string, userId: string) {
+ const permissions = await fetchPermissions(userId)
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+
+// Correct: fetches only when needed
+async function updateResource(resourceId: string, userId: string) {
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ const permissions = await fetchPermissions(userId)
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+```
+
+This optimization is especially valuable when the skipped branch is frequently taken, or when the deferred operation is expensive.
diff --git a/.agents/skills/vercel-react-best-practices/rules/async-dependencies.md b/.agents/skills/vercel-react-best-practices/rules/async-dependencies.md
new file mode 100644
index 000000000..fb90d861a
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/async-dependencies.md
@@ -0,0 +1,36 @@
+---
+title: Dependency-Based Parallelization
+impact: CRITICAL
+impactDescription: 2-10× improvement
+tags: async, parallelization, dependencies, better-all
+---
+
+## Dependency-Based Parallelization
+
+For operations with partial dependencies, use `better-all` to maximize parallelism. It automatically starts each task at the earliest possible moment.
+
+**Incorrect (profile waits for config unnecessarily):**
+
+```typescript
+const [user, config] = await Promise.all([
+ fetchUser(),
+ fetchConfig()
+])
+const profile = await fetchProfile(user.id)
+```
+
+**Correct (config and profile run in parallel):**
+
+```typescript
+import { all } from 'better-all'
+
+const { user, config, profile } = await all({
+ async user() { return fetchUser() },
+ async config() { return fetchConfig() },
+ async profile() {
+ return fetchProfile((await this.$.user).id)
+ }
+})
+```
+
+Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
diff --git a/.agents/skills/vercel-react-best-practices/rules/async-parallel.md b/.agents/skills/vercel-react-best-practices/rules/async-parallel.md
new file mode 100644
index 000000000..64133f6c3
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/async-parallel.md
@@ -0,0 +1,28 @@
+---
+title: Promise.all() for Independent Operations
+impact: CRITICAL
+impactDescription: 2-10× improvement
+tags: async, parallelization, promises, waterfalls
+---
+
+## Promise.all() for Independent Operations
+
+When async operations have no interdependencies, execute them concurrently using `Promise.all()`.
+
+**Incorrect (sequential execution, 3 round trips):**
+
+```typescript
+const user = await fetchUser()
+const posts = await fetchPosts()
+const comments = await fetchComments()
+```
+
+**Correct (parallel execution, 1 round trip):**
+
+```typescript
+const [user, posts, comments] = await Promise.all([
+ fetchUser(),
+ fetchPosts(),
+ fetchComments()
+])
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md b/.agents/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md
new file mode 100644
index 000000000..1fbc05b04
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md
@@ -0,0 +1,99 @@
+---
+title: Strategic Suspense Boundaries
+impact: HIGH
+impactDescription: faster initial paint
+tags: async, suspense, streaming, layout-shift
+---
+
+## Strategic Suspense Boundaries
+
+Instead of awaiting data in async components before returning JSX, use Suspense boundaries to show the wrapper UI faster while data loads.
+
+**Incorrect (wrapper blocked by data fetching):**
+
+```tsx
+async function Page() {
+ const data = await fetchData() // Blocks entire page
+
+ return (
+
+
Sidebar
+
Header
+
+
+
+
Footer
+
+ )
+}
+```
+
+The entire layout waits for data even though only the middle section needs it.
+
+**Correct (wrapper shows immediately, data streams in):**
+
+```tsx
+function Page() {
+ return (
+
+
Sidebar
+
Header
+
+ }>
+
+
+
+
Footer
+
+ )
+}
+
+async function DataDisplay() {
+ const data = await fetchData() // Only blocks this component
+ return {data.content}
+}
+```
+
+Sidebar, Header, and Footer render immediately. Only DataDisplay waits for data.
+
+**Alternative (share promise across components):**
+
+```tsx
+function Page() {
+ // Start fetch immediately, but don't await
+ const dataPromise = fetchData()
+
+ return (
+
+
Sidebar
+
Header
+
}>
+
+
+
+
Footer
+
+ )
+}
+
+function DataDisplay({ dataPromise }: { dataPromise: Promise }) {
+ const data = use(dataPromise) // Unwraps the promise
+ return {data.content}
+}
+
+function DataSummary({ dataPromise }: { dataPromise: Promise }) {
+ const data = use(dataPromise) // Reuses the same promise
+ return {data.summary}
+}
+```
+
+Both components share the same promise, so only one fetch occurs. Layout renders immediately while both components wait together.
+
+**When NOT to use this pattern:**
+
+- Critical data needed for layout decisions (affects positioning)
+- SEO-critical content above the fold
+- Small, fast queries where suspense overhead isn't worth it
+- When you want to avoid layout shift (loading → content jump)
+
+**Trade-off:** Faster initial paint vs potential layout shift. Choose based on your UX priorities.
diff --git a/.agents/skills/vercel-react-best-practices/rules/bundle-barrel-imports.md b/.agents/skills/vercel-react-best-practices/rules/bundle-barrel-imports.md
new file mode 100644
index 000000000..ee48f3273
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/bundle-barrel-imports.md
@@ -0,0 +1,59 @@
+---
+title: Avoid Barrel File Imports
+impact: CRITICAL
+impactDescription: 200-800ms import cost, slow builds
+tags: bundle, imports, tree-shaking, barrel-files, performance
+---
+
+## Avoid Barrel File Imports
+
+Import directly from source files instead of barrel files to avoid loading thousands of unused modules. **Barrel files** are entry points that re-export multiple modules (e.g., `index.js` that does `export * from './module'`).
+
+Popular icon and component libraries can have **up to 10,000 re-exports** in their entry file. For many React packages, **it takes 200-800ms just to import them**, affecting both development speed and production cold starts.
+
+**Why tree-shaking doesn't help:** When a library is marked as external (not bundled), the bundler can't optimize it. If you bundle it to enable tree-shaking, builds become substantially slower analyzing the entire module graph.
+
+**Incorrect (imports entire library):**
+
+```tsx
+import { Check, X, Menu } from 'lucide-react'
+// Loads 1,583 modules, takes ~2.8s extra in dev
+// Runtime cost: 200-800ms on every cold start
+
+import { Button, TextField } from '@mui/material'
+// Loads 2,225 modules, takes ~4.2s extra in dev
+```
+
+**Correct (imports only what you need):**
+
+```tsx
+import Check from 'lucide-react/dist/esm/icons/check'
+import X from 'lucide-react/dist/esm/icons/x'
+import Menu from 'lucide-react/dist/esm/icons/menu'
+// Loads only 3 modules (~2KB vs ~1MB)
+
+import Button from '@mui/material/Button'
+import TextField from '@mui/material/TextField'
+// Loads only what you use
+```
+
+**Alternative (Next.js 13.5+):**
+
+```js
+// next.config.js - use optimizePackageImports
+module.exports = {
+ experimental: {
+ optimizePackageImports: ['lucide-react', '@mui/material']
+ }
+}
+
+// Then you can keep the ergonomic barrel imports:
+import { Check, X, Menu } from 'lucide-react'
+// Automatically transformed to direct imports at build time
+```
+
+Direct imports provide 15-70% faster dev boot, 28% faster builds, 40% faster cold starts, and significantly faster HMR.
+
+Libraries commonly affected: `lucide-react`, `@mui/material`, `@mui/icons-material`, `@tabler/icons-react`, `react-icons`, `@headlessui/react`, `@radix-ui/react-*`, `lodash`, `ramda`, `date-fns`, `rxjs`, `react-use`.
+
+Reference: [How we optimized package imports in Next.js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
diff --git a/.agents/skills/vercel-react-best-practices/rules/bundle-conditional.md b/.agents/skills/vercel-react-best-practices/rules/bundle-conditional.md
new file mode 100644
index 000000000..99d6fc90e
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/bundle-conditional.md
@@ -0,0 +1,31 @@
+---
+title: Conditional Module Loading
+impact: HIGH
+impactDescription: loads large data only when needed
+tags: bundle, conditional-loading, lazy-loading
+---
+
+## Conditional Module Loading
+
+Load large data or modules only when a feature is activated.
+
+**Example (lazy-load animation frames):**
+
+```tsx
+function AnimationPlayer({ enabled, setEnabled }: { enabled: boolean; setEnabled: React.Dispatch> }) {
+ const [frames, setFrames] = useState (null)
+
+ useEffect(() => {
+ if (enabled && !frames && typeof window !== 'undefined') {
+ import('./animation-frames.js')
+ .then(mod => setFrames(mod.frames))
+ .catch(() => setEnabled(false))
+ }
+ }, [enabled, frames, setEnabled])
+
+ if (!frames) return
+ return
+}
+```
+
+The `typeof window !== 'undefined'` check prevents bundling this module for SSR, optimizing server bundle size and build speed.
diff --git a/.agents/skills/vercel-react-best-practices/rules/bundle-defer-third-party.md b/.agents/skills/vercel-react-best-practices/rules/bundle-defer-third-party.md
new file mode 100644
index 000000000..db041d151
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/bundle-defer-third-party.md
@@ -0,0 +1,49 @@
+---
+title: Defer Non-Critical Third-Party Libraries
+impact: MEDIUM
+impactDescription: loads after hydration
+tags: bundle, third-party, analytics, defer
+---
+
+## Defer Non-Critical Third-Party Libraries
+
+Analytics, logging, and error tracking don't block user interaction. Load them after hydration.
+
+**Incorrect (blocks initial bundle):**
+
+```tsx
+import { Analytics } from '@vercel/analytics/react'
+
+export default function RootLayout({ children }) {
+ return (
+
+
+ {children}
+
+
+
+ )
+}
+```
+
+**Correct (loads after hydration):**
+
+```tsx
+import dynamic from 'next/dynamic'
+
+const Analytics = dynamic(
+ () => import('@vercel/analytics/react').then(m => m.Analytics),
+ { ssr: false }
+)
+
+export default function RootLayout({ children }) {
+ return (
+
+
+ {children}
+
+
+
+ )
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/bundle-dynamic-imports.md b/.agents/skills/vercel-react-best-practices/rules/bundle-dynamic-imports.md
new file mode 100644
index 000000000..60b62695e
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/bundle-dynamic-imports.md
@@ -0,0 +1,35 @@
+---
+title: Dynamic Imports for Heavy Components
+impact: CRITICAL
+impactDescription: directly affects TTI and LCP
+tags: bundle, dynamic-import, code-splitting, next-dynamic
+---
+
+## Dynamic Imports for Heavy Components
+
+Use `next/dynamic` to lazy-load large components not needed on initial render.
+
+**Incorrect (Monaco bundles with main chunk ~300KB):**
+
+```tsx
+import { MonacoEditor } from './monaco-editor'
+
+function CodePanel({ code }: { code: string }) {
+ return
+}
+```
+
+**Correct (Monaco loads on demand):**
+
+```tsx
+import dynamic from 'next/dynamic'
+
+const MonacoEditor = dynamic(
+ () => import('./monaco-editor').then(m => m.MonacoEditor),
+ { ssr: false }
+)
+
+function CodePanel({ code }: { code: string }) {
+ return
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/bundle-preload.md b/.agents/skills/vercel-react-best-practices/rules/bundle-preload.md
new file mode 100644
index 000000000..700050406
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/bundle-preload.md
@@ -0,0 +1,50 @@
+---
+title: Preload Based on User Intent
+impact: MEDIUM
+impactDescription: reduces perceived latency
+tags: bundle, preload, user-intent, hover
+---
+
+## Preload Based on User Intent
+
+Preload heavy bundles before they're needed to reduce perceived latency.
+
+**Example (preload on hover/focus):**
+
+```tsx
+function EditorButton({ onClick }: { onClick: () => void }) {
+ const preload = () => {
+ if (typeof window !== 'undefined') {
+ void import('./monaco-editor')
+ }
+ }
+
+ return (
+
+ Open Editor
+
+ )
+}
+```
+
+**Example (preload when feature flag is enabled):**
+
+```tsx
+function FlagsProvider({ children, flags }: Props) {
+ useEffect(() => {
+ if (flags.editorEnabled && typeof window !== 'undefined') {
+ void import('./monaco-editor').then(mod => mod.init())
+ }
+ }, [flags.editorEnabled])
+
+ return
+ {children}
+
+}
+```
+
+The `typeof window !== 'undefined'` check prevents bundling preloaded modules for SSR, optimizing server bundle size and build speed.
diff --git a/.agents/skills/vercel-react-best-practices/rules/client-event-listeners.md b/.agents/skills/vercel-react-best-practices/rules/client-event-listeners.md
new file mode 100644
index 000000000..aad4ae916
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/client-event-listeners.md
@@ -0,0 +1,74 @@
+---
+title: Deduplicate Global Event Listeners
+impact: LOW
+impactDescription: single listener for N components
+tags: client, swr, event-listeners, subscription
+---
+
+## Deduplicate Global Event Listeners
+
+Use `useSWRSubscription()` to share global event listeners across component instances.
+
+**Incorrect (N instances = N listeners):**
+
+```tsx
+function useKeyboardShortcut(key: string, callback: () => void) {
+ useEffect(() => {
+ const handler = (e: KeyboardEvent) => {
+ if (e.metaKey && e.key === key) {
+ callback()
+ }
+ }
+ window.addEventListener('keydown', handler)
+ return () => window.removeEventListener('keydown', handler)
+ }, [key, callback])
+}
+```
+
+When using the `useKeyboardShortcut` hook multiple times, each instance will register a new listener.
+
+**Correct (N instances = 1 listener):**
+
+```tsx
+import useSWRSubscription from 'swr/subscription'
+
+// Module-level Map to track callbacks per key
+const keyCallbacks = new Map void>>()
+
+function useKeyboardShortcut(key: string, callback: () => void) {
+ // Register this callback in the Map
+ useEffect(() => {
+ if (!keyCallbacks.has(key)) {
+ keyCallbacks.set(key, new Set())
+ }
+ keyCallbacks.get(key)!.add(callback)
+
+ return () => {
+ const set = keyCallbacks.get(key)
+ if (set) {
+ set.delete(callback)
+ if (set.size === 0) {
+ keyCallbacks.delete(key)
+ }
+ }
+ }
+ }, [key, callback])
+
+ useSWRSubscription('global-keydown', () => {
+ const handler = (e: KeyboardEvent) => {
+ if (e.metaKey && keyCallbacks.has(e.key)) {
+ keyCallbacks.get(e.key)!.forEach(cb => cb())
+ }
+ }
+ window.addEventListener('keydown', handler)
+ return () => window.removeEventListener('keydown', handler)
+ })
+}
+
+function Profile() {
+ // Multiple shortcuts will share the same listener
+ useKeyboardShortcut('p', () => { /* ... */ })
+ useKeyboardShortcut('k', () => { /* ... */ })
+ // ...
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/client-localstorage-schema.md b/.agents/skills/vercel-react-best-practices/rules/client-localstorage-schema.md
new file mode 100644
index 000000000..d30a1a7d4
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/client-localstorage-schema.md
@@ -0,0 +1,71 @@
+---
+title: Version and Minimize localStorage Data
+impact: MEDIUM
+impactDescription: prevents schema conflicts, reduces storage size
+tags: client, localStorage, storage, versioning, data-minimization
+---
+
+## Version and Minimize localStorage Data
+
+Add version prefix to keys and store only needed fields. Prevents schema conflicts and accidental storage of sensitive data.
+
+**Incorrect:**
+
+```typescript
+// No version, stores everything, no error handling
+localStorage.setItem('userConfig', JSON.stringify(fullUserObject))
+const data = localStorage.getItem('userConfig')
+```
+
+**Correct:**
+
+```typescript
+const VERSION = 'v2'
+
+function saveConfig(config: { theme: string; language: string }) {
+ try {
+ localStorage.setItem(`userConfig:${VERSION}`, JSON.stringify(config))
+ } catch {
+ // Throws in incognito/private browsing, quota exceeded, or disabled
+ }
+}
+
+function loadConfig() {
+ try {
+ const data = localStorage.getItem(`userConfig:${VERSION}`)
+ return data ? JSON.parse(data) : null
+ } catch {
+ return null
+ }
+}
+
+// Migration from v1 to v2
+function migrate() {
+ try {
+ const v1 = localStorage.getItem('userConfig:v1')
+ if (v1) {
+ const old = JSON.parse(v1)
+ saveConfig({ theme: old.darkMode ? 'dark' : 'light', language: old.lang })
+ localStorage.removeItem('userConfig:v1')
+ }
+ } catch {}
+}
+```
+
+**Store minimal fields from server responses:**
+
+```typescript
+// User object has 20+ fields, only store what UI needs
+function cachePrefs(user: FullUser) {
+ try {
+ localStorage.setItem('prefs:v1', JSON.stringify({
+ theme: user.preferences.theme,
+ notifications: user.preferences.notifications
+ }))
+ } catch {}
+}
+```
+
+**Always wrap in try-catch:** `getItem()` and `setItem()` throw in incognito/private browsing (Safari, Firefox), when quota exceeded, or when disabled.
+
+**Benefits:** Schema evolution via versioning, reduced storage size, prevents storing tokens/PII/internal flags.
diff --git a/.agents/skills/vercel-react-best-practices/rules/client-passive-event-listeners.md b/.agents/skills/vercel-react-best-practices/rules/client-passive-event-listeners.md
new file mode 100644
index 000000000..ce39a889e
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/client-passive-event-listeners.md
@@ -0,0 +1,48 @@
+---
+title: Use Passive Event Listeners for Scrolling Performance
+impact: MEDIUM
+impactDescription: eliminates scroll delay caused by event listeners
+tags: client, event-listeners, scrolling, performance, touch, wheel
+---
+
+## Use Passive Event Listeners for Scrolling Performance
+
+Add `{ passive: true }` to touch and wheel event listeners to enable immediate scrolling. Browsers normally wait for listeners to finish to check if `preventDefault()` is called, causing scroll delay.
+
+**Incorrect:**
+
+```typescript
+useEffect(() => {
+ const handleTouch = (e: TouchEvent) => console.log(e.touches[0].clientX)
+ const handleWheel = (e: WheelEvent) => console.log(e.deltaY)
+
+ document.addEventListener('touchstart', handleTouch)
+ document.addEventListener('wheel', handleWheel)
+
+ return () => {
+ document.removeEventListener('touchstart', handleTouch)
+ document.removeEventListener('wheel', handleWheel)
+ }
+}, [])
+```
+
+**Correct:**
+
+```typescript
+useEffect(() => {
+ const handleTouch = (e: TouchEvent) => console.log(e.touches[0].clientX)
+ const handleWheel = (e: WheelEvent) => console.log(e.deltaY)
+
+ document.addEventListener('touchstart', handleTouch, { passive: true })
+ document.addEventListener('wheel', handleWheel, { passive: true })
+
+ return () => {
+ document.removeEventListener('touchstart', handleTouch)
+ document.removeEventListener('wheel', handleWheel)
+ }
+}, [])
+```
+
+**Use passive when:** tracking/analytics, logging, any listener that doesn't call `preventDefault()`.
+
+**Don't use passive when:** implementing custom swipe gestures, custom zoom controls, or any listener that needs `preventDefault()`.
diff --git a/.agents/skills/vercel-react-best-practices/rules/client-swr-dedup.md b/.agents/skills/vercel-react-best-practices/rules/client-swr-dedup.md
new file mode 100644
index 000000000..2a430f27f
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/client-swr-dedup.md
@@ -0,0 +1,56 @@
+---
+title: Use SWR for Automatic Deduplication
+impact: MEDIUM-HIGH
+impactDescription: automatic deduplication
+tags: client, swr, deduplication, data-fetching
+---
+
+## Use SWR for Automatic Deduplication
+
+SWR enables request deduplication, caching, and revalidation across component instances.
+
+**Incorrect (no deduplication, each instance fetches):**
+
+```tsx
+function UserList() {
+ const [users, setUsers] = useState([])
+ useEffect(() => {
+ fetch('/api/users')
+ .then(r => r.json())
+ .then(setUsers)
+ }, [])
+}
+```
+
+**Correct (multiple instances share one request):**
+
+```tsx
+import useSWR from 'swr'
+
+function UserList() {
+ const { data: users } = useSWR('/api/users', fetcher)
+}
+```
+
+**For immutable data:**
+
+```tsx
+import { useImmutableSWR } from '@/lib/swr'
+
+function StaticContent() {
+ const { data } = useImmutableSWR('/api/config', fetcher)
+}
+```
+
+**For mutations:**
+
+```tsx
+import { useSWRMutation } from 'swr/mutation'
+
+function UpdateButton() {
+ const { trigger } = useSWRMutation('/api/user', updateUser)
+ return trigger()}>Update
+}
+```
+
+Reference: [https://swr.vercel.app](https://swr.vercel.app)
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-batch-dom-css.md b/.agents/skills/vercel-react-best-practices/rules/js-batch-dom-css.md
new file mode 100644
index 000000000..84b655292
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-batch-dom-css.md
@@ -0,0 +1,57 @@
+---
+title: Batch DOM CSS Changes
+impact: MEDIUM
+impactDescription: reduces reflows/repaints
+tags: javascript, dom, css, performance, reflow
+---
+
+## Batch DOM CSS Changes
+
+Avoid interleaving style writes with layout reads. When you read a layout property (like `offsetWidth`, `getBoundingClientRect()`, or `getComputedStyle()`) between style changes, the browser is forced to trigger a synchronous reflow.
+
+**Incorrect (interleaved reads and writes force reflows):**
+
+```typescript
+function updateElementStyles(element: HTMLElement) {
+ element.style.width = '100px'
+ const width = element.offsetWidth // Forces reflow
+ element.style.height = '200px'
+ const height = element.offsetHeight // Forces another reflow
+}
+```
+
+**Correct (batch writes, then read once):**
+
+```typescript
+function updateElementStyles(element: HTMLElement) {
+ // Batch all writes together
+ element.style.width = '100px'
+ element.style.height = '200px'
+ element.style.backgroundColor = 'blue'
+ element.style.border = '1px solid black'
+
+ // Read after all writes are done (single reflow)
+ const { width, height } = element.getBoundingClientRect()
+}
+```
+
+**Better: use CSS classes**
+
+```css
+.highlighted-box {
+ width: 100px;
+ height: 200px;
+ background-color: blue;
+ border: 1px solid black;
+}
+```
+
+```typescript
+function updateElementStyles(element: HTMLElement) {
+ element.classList.add('highlighted-box')
+
+ const { width, height } = element.getBoundingClientRect()
+}
+```
+
+Prefer CSS classes over inline styles when possible. CSS files are cached by the browser, and classes provide better separation of concerns and are easier to maintain.
\ No newline at end of file
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-cache-function-results.md b/.agents/skills/vercel-react-best-practices/rules/js-cache-function-results.md
new file mode 100644
index 000000000..180f8ac8f
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-cache-function-results.md
@@ -0,0 +1,80 @@
+---
+title: Cache Repeated Function Calls
+impact: MEDIUM
+impactDescription: avoid redundant computation
+tags: javascript, cache, memoization, performance
+---
+
+## Cache Repeated Function Calls
+
+Use a module-level Map to cache function results when the same function is called repeatedly with the same inputs during render.
+
+**Incorrect (redundant computation):**
+
+```typescript
+function ProjectList({ projects }: { projects: Project[] }) {
+ return (
+
+ {projects.map(project => {
+ // slugify() called 100+ times for same project names
+ const slug = slugify(project.name)
+
+ return
+ })}
+
+ )
+}
+```
+
+**Correct (cached results):**
+
+```typescript
+// Module-level cache
+const slugifyCache = new Map()
+
+function cachedSlugify(text: string): string {
+ if (slugifyCache.has(text)) {
+ return slugifyCache.get(text)!
+ }
+ const result = slugify(text)
+ slugifyCache.set(text, result)
+ return result
+}
+
+function ProjectList({ projects }: { projects: Project[] }) {
+ return (
+
+ {projects.map(project => {
+ // Computed only once per unique project name
+ const slug = cachedSlugify(project.name)
+
+ return
+ })}
+
+ )
+}
+```
+
+**Simpler pattern for single-value functions:**
+
+```typescript
+let isLoggedInCache: boolean | null = null
+
+function isLoggedIn(): boolean {
+ if (isLoggedInCache !== null) {
+ return isLoggedInCache
+ }
+
+ isLoggedInCache = document.cookie.includes('auth=')
+ return isLoggedInCache
+}
+
+// Clear cache when auth changes
+function onAuthChange() {
+ isLoggedInCache = null
+}
+```
+
+Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
+
+Reference: [How we made the Vercel Dashboard twice as fast](https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast)
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-cache-property-access.md b/.agents/skills/vercel-react-best-practices/rules/js-cache-property-access.md
new file mode 100644
index 000000000..39eec9061
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-cache-property-access.md
@@ -0,0 +1,28 @@
+---
+title: Cache Property Access in Loops
+impact: LOW-MEDIUM
+impactDescription: reduces lookups
+tags: javascript, loops, optimization, caching
+---
+
+## Cache Property Access in Loops
+
+Cache object property lookups in hot paths.
+
+**Incorrect (3 lookups × N iterations):**
+
+```typescript
+for (let i = 0; i < arr.length; i++) {
+ process(obj.config.settings.value)
+}
+```
+
+**Correct (1 lookup total):**
+
+```typescript
+const value = obj.config.settings.value
+const len = arr.length
+for (let i = 0; i < len; i++) {
+ process(value)
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-cache-storage.md b/.agents/skills/vercel-react-best-practices/rules/js-cache-storage.md
new file mode 100644
index 000000000..aa4a30c08
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-cache-storage.md
@@ -0,0 +1,70 @@
+---
+title: Cache Storage API Calls
+impact: LOW-MEDIUM
+impactDescription: reduces expensive I/O
+tags: javascript, localStorage, storage, caching, performance
+---
+
+## Cache Storage API Calls
+
+`localStorage`, `sessionStorage`, and `document.cookie` are synchronous and expensive. Cache reads in memory.
+
+**Incorrect (reads storage on every call):**
+
+```typescript
+function getTheme() {
+ return localStorage.getItem('theme') ?? 'light'
+}
+// Called 10 times = 10 storage reads
+```
+
+**Correct (Map cache):**
+
+```typescript
+const storageCache = new Map()
+
+function getLocalStorage(key: string) {
+ if (!storageCache.has(key)) {
+ storageCache.set(key, localStorage.getItem(key))
+ }
+ return storageCache.get(key)
+}
+
+function setLocalStorage(key: string, value: string) {
+ localStorage.setItem(key, value)
+ storageCache.set(key, value) // keep cache in sync
+}
+```
+
+Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
+
+**Cookie caching:**
+
+```typescript
+let cookieCache: Record | null = null
+
+function getCookie(name: string) {
+ if (!cookieCache) {
+ cookieCache = Object.fromEntries(
+ document.cookie.split('; ').map(c => c.split('='))
+ )
+ }
+ return cookieCache[name]
+}
+```
+
+**Important (invalidate on external changes):**
+
+If storage can change externally (another tab, server-set cookies), invalidate cache:
+
+```typescript
+window.addEventListener('storage', (e) => {
+ if (e.key) storageCache.delete(e.key)
+})
+
+document.addEventListener('visibilitychange', () => {
+ if (document.visibilityState === 'visible') {
+ storageCache.clear()
+ }
+})
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-combine-iterations.md b/.agents/skills/vercel-react-best-practices/rules/js-combine-iterations.md
new file mode 100644
index 000000000..044d017ec
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-combine-iterations.md
@@ -0,0 +1,32 @@
+---
+title: Combine Multiple Array Iterations
+impact: LOW-MEDIUM
+impactDescription: reduces iterations
+tags: javascript, arrays, loops, performance
+---
+
+## Combine Multiple Array Iterations
+
+Multiple `.filter()` or `.map()` calls iterate the array multiple times. Combine into one loop.
+
+**Incorrect (3 iterations):**
+
+```typescript
+const admins = users.filter(u => u.isAdmin)
+const testers = users.filter(u => u.isTester)
+const inactive = users.filter(u => !u.isActive)
+```
+
+**Correct (1 iteration):**
+
+```typescript
+const admins: User[] = []
+const testers: User[] = []
+const inactive: User[] = []
+
+for (const user of users) {
+ if (user.isAdmin) admins.push(user)
+ if (user.isTester) testers.push(user)
+ if (!user.isActive) inactive.push(user)
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-early-exit.md b/.agents/skills/vercel-react-best-practices/rules/js-early-exit.md
new file mode 100644
index 000000000..f46cb89c6
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-early-exit.md
@@ -0,0 +1,50 @@
+---
+title: Early Return from Functions
+impact: LOW-MEDIUM
+impactDescription: avoids unnecessary computation
+tags: javascript, functions, optimization, early-return
+---
+
+## Early Return from Functions
+
+Return early when result is determined to skip unnecessary processing.
+
+**Incorrect (processes all items even after finding answer):**
+
+```typescript
+function validateUsers(users: User[]) {
+ let hasError = false
+ let errorMessage = ''
+
+ for (const user of users) {
+ if (!user.email) {
+ hasError = true
+ errorMessage = 'Email required'
+ }
+ if (!user.name) {
+ hasError = true
+ errorMessage = 'Name required'
+ }
+ // Continues checking all users even after error found
+ }
+
+ return hasError ? { valid: false, error: errorMessage } : { valid: true }
+}
+```
+
+**Correct (returns immediately on first error):**
+
+```typescript
+function validateUsers(users: User[]) {
+ for (const user of users) {
+ if (!user.email) {
+ return { valid: false, error: 'Email required' }
+ }
+ if (!user.name) {
+ return { valid: false, error: 'Name required' }
+ }
+ }
+
+ return { valid: true }
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-hoist-regexp.md b/.agents/skills/vercel-react-best-practices/rules/js-hoist-regexp.md
new file mode 100644
index 000000000..dae3fefdc
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-hoist-regexp.md
@@ -0,0 +1,45 @@
+---
+title: Hoist RegExp Creation
+impact: LOW-MEDIUM
+impactDescription: avoids recreation
+tags: javascript, regexp, optimization, memoization
+---
+
+## Hoist RegExp Creation
+
+Don't create RegExp inside render. Hoist to module scope or memoize with `useMemo()`.
+
+**Incorrect (new RegExp every render):**
+
+```tsx
+function Highlighter({ text, query }: Props) {
+ const regex = new RegExp(`(${query})`, 'gi')
+ const parts = text.split(regex)
+ return <>{parts.map((part, i) => ...)}>
+}
+```
+
+**Correct (memoize or hoist):**
+
+```tsx
+const EMAIL_REGEX = /^[^\s@]+@[^\s@]+\.[^\s@]+$/
+
+function Highlighter({ text, query }: Props) {
+ const regex = useMemo(
+ () => new RegExp(`(${escapeRegex(query)})`, 'gi'),
+ [query]
+ )
+ const parts = text.split(regex)
+ return <>{parts.map((part, i) => ...)}>
+}
+```
+
+**Warning (global regex has mutable state):**
+
+Global regex (`/g`) has mutable `lastIndex` state:
+
+```typescript
+const regex = /foo/g
+regex.test('foo') // true, lastIndex = 3
+regex.test('foo') // false, lastIndex = 0
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-index-maps.md b/.agents/skills/vercel-react-best-practices/rules/js-index-maps.md
new file mode 100644
index 000000000..9d357a00b
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-index-maps.md
@@ -0,0 +1,37 @@
+---
+title: Build Index Maps for Repeated Lookups
+impact: LOW-MEDIUM
+impactDescription: 1M ops to 2K ops
+tags: javascript, map, indexing, optimization, performance
+---
+
+## Build Index Maps for Repeated Lookups
+
+Multiple `.find()` calls by the same key should use a Map.
+
+**Incorrect (O(n) per lookup):**
+
+```typescript
+function processOrders(orders: Order[], users: User[]) {
+ return orders.map(order => ({
+ ...order,
+ user: users.find(u => u.id === order.userId)
+ }))
+}
+```
+
+**Correct (O(1) per lookup):**
+
+```typescript
+function processOrders(orders: Order[], users: User[]) {
+ const userById = new Map(users.map(u => [u.id, u]))
+
+ return orders.map(order => ({
+ ...order,
+ user: userById.get(order.userId)
+ }))
+}
+```
+
+Build map once (O(n)), then all lookups are O(1).
+For 1000 orders × 1000 users: 1M ops → 2K ops.
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-length-check-first.md b/.agents/skills/vercel-react-best-practices/rules/js-length-check-first.md
new file mode 100644
index 000000000..8b8957363
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-length-check-first.md
@@ -0,0 +1,49 @@
+---
+title: Early Length Check for Array Comparisons
+impact: MEDIUM-HIGH
+impactDescription: avoids expensive operations when lengths differ
+tags: javascript, arrays, performance, optimization, comparison
+---
+
+## Early Length Check for Array Comparisons
+
+When comparing arrays with expensive operations (sorting, deep equality, serialization), check lengths first. If lengths differ, the arrays cannot be equal.
+
+In real-world applications, this optimization is especially valuable when the comparison runs in hot paths (event handlers, render loops).
+
+**Incorrect (always runs expensive comparison):**
+
+```typescript
+function hasChanges(current: string[], original: string[]) {
+ // Always sorts and joins, even when lengths differ
+ return current.sort().join() !== original.sort().join()
+}
+```
+
+Two O(n log n) sorts run even when `current.length` is 5 and `original.length` is 100. There is also overhead of joining the arrays and comparing the strings.
+
+**Correct (O(1) length check first):**
+
+```typescript
+function hasChanges(current: string[], original: string[]) {
+ // Early return if lengths differ
+ if (current.length !== original.length) {
+ return true
+ }
+ // Only sort when lengths match
+ const currentSorted = current.toSorted()
+ const originalSorted = original.toSorted()
+ for (let i = 0; i < currentSorted.length; i++) {
+ if (currentSorted[i] !== originalSorted[i]) {
+ return true
+ }
+ }
+ return false
+}
+```
+
+This new approach is more efficient because:
+- It avoids the overhead of sorting and joining the arrays when lengths differ
+- It avoids consuming memory for the joined strings (especially important for large arrays)
+- It avoids mutating the original arrays
+- It returns early when a difference is found
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-min-max-loop.md b/.agents/skills/vercel-react-best-practices/rules/js-min-max-loop.md
new file mode 100644
index 000000000..4b6656e96
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-min-max-loop.md
@@ -0,0 +1,82 @@
+---
+title: Use Loop for Min/Max Instead of Sort
+impact: LOW
+impactDescription: O(n) instead of O(n log n)
+tags: javascript, arrays, performance, sorting, algorithms
+---
+
+## Use Loop for Min/Max Instead of Sort
+
+Finding the smallest or largest element only requires a single pass through the array. Sorting is wasteful and slower.
+
+**Incorrect (O(n log n) - sort to find latest):**
+
+```typescript
+interface Project {
+ id: string
+ name: string
+ updatedAt: number
+}
+
+function getLatestProject(projects: Project[]) {
+ const sorted = [...projects].sort((a, b) => b.updatedAt - a.updatedAt)
+ return sorted[0]
+}
+```
+
+Sorts the entire array just to find the maximum value.
+
+**Incorrect (O(n log n) - sort for oldest and newest):**
+
+```typescript
+function getOldestAndNewest(projects: Project[]) {
+ const sorted = [...projects].sort((a, b) => a.updatedAt - b.updatedAt)
+ return { oldest: sorted[0], newest: sorted[sorted.length - 1] }
+}
+```
+
+Still sorts unnecessarily when only min/max are needed.
+
+**Correct (O(n) - single loop):**
+
+```typescript
+function getLatestProject(projects: Project[]) {
+ if (projects.length === 0) return null
+
+ let latest = projects[0]
+
+ for (let i = 1; i < projects.length; i++) {
+ if (projects[i].updatedAt > latest.updatedAt) {
+ latest = projects[i]
+ }
+ }
+
+ return latest
+}
+
+function getOldestAndNewest(projects: Project[]) {
+ if (projects.length === 0) return { oldest: null, newest: null }
+
+ let oldest = projects[0]
+ let newest = projects[0]
+
+ for (let i = 1; i < projects.length; i++) {
+ if (projects[i].updatedAt < oldest.updatedAt) oldest = projects[i]
+ if (projects[i].updatedAt > newest.updatedAt) newest = projects[i]
+ }
+
+ return { oldest, newest }
+}
+```
+
+Single pass through the array, no copying, no sorting.
+
+**Alternative (Math.min/Math.max for small arrays):**
+
+```typescript
+const numbers = [5, 2, 8, 1, 9]
+const min = Math.min(...numbers)
+const max = Math.max(...numbers)
+```
+
+This works for small arrays, but can be slower or just throw an error for very large arrays due to spread operator limitations. Maximal array length is approximately 124000 in Chrome 143 and 638000 in Safari 18; exact numbers may vary - see [the fiddle](https://jsfiddle.net/qw1jabsx/4/). Use the loop approach for reliability.
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-set-map-lookups.md b/.agents/skills/vercel-react-best-practices/rules/js-set-map-lookups.md
new file mode 100644
index 000000000..680a4892e
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-set-map-lookups.md
@@ -0,0 +1,24 @@
+---
+title: Use Set/Map for O(1) Lookups
+impact: LOW-MEDIUM
+impactDescription: O(n) to O(1)
+tags: javascript, set, map, data-structures, performance
+---
+
+## Use Set/Map for O(1) Lookups
+
+Convert arrays to Set/Map for repeated membership checks.
+
+**Incorrect (O(n) per check):**
+
+```typescript
+const allowedIds = ['a', 'b', 'c', ...]
+items.filter(item => allowedIds.includes(item.id))
+```
+
+**Correct (O(1) per check):**
+
+```typescript
+const allowedIds = new Set(['a', 'b', 'c', ...])
+items.filter(item => allowedIds.has(item.id))
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/js-tosorted-immutable.md b/.agents/skills/vercel-react-best-practices/rules/js-tosorted-immutable.md
new file mode 100644
index 000000000..eae8b3f8a
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/js-tosorted-immutable.md
@@ -0,0 +1,57 @@
+---
+title: Use toSorted() Instead of sort() for Immutability
+impact: MEDIUM-HIGH
+impactDescription: prevents mutation bugs in React state
+tags: javascript, arrays, immutability, react, state, mutation
+---
+
+## Use toSorted() Instead of sort() for Immutability
+
+`.sort()` mutates the array in place, which can cause bugs with React state and props. Use `.toSorted()` to create a new sorted array without mutation.
+
+**Incorrect (mutates original array):**
+
+```typescript
+function UserList({ users }: { users: User[] }) {
+ // Mutates the users prop array!
+ const sorted = useMemo(
+ () => users.sort((a, b) => a.name.localeCompare(b.name)),
+ [users]
+ )
+ return {sorted.map(renderUser)}
+}
+```
+
+**Correct (creates new array):**
+
+```typescript
+function UserList({ users }: { users: User[] }) {
+ // Creates new sorted array, original unchanged
+ const sorted = useMemo(
+ () => users.toSorted((a, b) => a.name.localeCompare(b.name)),
+ [users]
+ )
+ return {sorted.map(renderUser)}
+}
+```
+
+**Why this matters in React:**
+
+1. Props/state mutations break React's immutability model - React expects props and state to be treated as read-only
+2. Causes stale closure bugs - Mutating arrays inside closures (callbacks, effects) can lead to unexpected behavior
+
+**Browser support (fallback for older browsers):**
+
+`.toSorted()` is available in all modern browsers (Chrome 110+, Safari 16+, Firefox 115+, Node.js 20+). For older environments, use spread operator:
+
+```typescript
+// Fallback for older browsers
+const sorted = [...items].sort((a, b) => a.value - b.value)
+```
+
+**Other immutable array methods:**
+
+- `.toSorted()` - immutable sort
+- `.toReversed()` - immutable reverse
+- `.toSpliced()` - immutable splice
+- `.with()` - immutable element replacement
diff --git a/.agents/skills/vercel-react-best-practices/rules/rendering-activity.md b/.agents/skills/vercel-react-best-practices/rules/rendering-activity.md
new file mode 100644
index 000000000..c957a490b
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rendering-activity.md
@@ -0,0 +1,26 @@
+---
+title: Use Activity Component for Show/Hide
+impact: MEDIUM
+impactDescription: preserves state/DOM
+tags: rendering, activity, visibility, state-preservation
+---
+
+## Use Activity Component for Show/Hide
+
+Use React's `` to preserve state/DOM for expensive components that frequently toggle visibility.
+
+**Usage:**
+
+```tsx
+import { Activity } from 'react'
+
+function Dropdown({ isOpen }: Props) {
+ return (
+
+
+
+ )
+}
+```
+
+Avoids expensive re-renders and state loss.
diff --git a/.agents/skills/vercel-react-best-practices/rules/rendering-animate-svg-wrapper.md b/.agents/skills/vercel-react-best-practices/rules/rendering-animate-svg-wrapper.md
new file mode 100644
index 000000000..646744cbe
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rendering-animate-svg-wrapper.md
@@ -0,0 +1,47 @@
+---
+title: Animate SVG Wrapper Instead of SVG Element
+impact: LOW
+impactDescription: enables hardware acceleration
+tags: rendering, svg, css, animation, performance
+---
+
+## Animate SVG Wrapper Instead of SVG Element
+
+Many browsers don't have hardware acceleration for CSS3 animations on SVG elements. Wrap SVG in a `` and animate the wrapper instead.
+
+**Incorrect (animating SVG directly - no hardware acceleration):**
+
+```tsx
+function LoadingSpinner() {
+ return (
+
+
+
+ )
+}
+```
+
+**Correct (animating wrapper div - hardware accelerated):**
+
+```tsx
+function LoadingSpinner() {
+ return (
+
+
+
+
+
+ )
+}
+```
+
+This applies to all CSS transforms and transitions (`transform`, `opacity`, `translate`, `scale`, `rotate`). The wrapper div allows browsers to use GPU acceleration for smoother animations.
diff --git a/.agents/skills/vercel-react-best-practices/rules/rendering-conditional-render.md b/.agents/skills/vercel-react-best-practices/rules/rendering-conditional-render.md
new file mode 100644
index 000000000..7e866f585
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rendering-conditional-render.md
@@ -0,0 +1,40 @@
+---
+title: Use Explicit Conditional Rendering
+impact: LOW
+impactDescription: prevents rendering 0 or NaN
+tags: rendering, conditional, jsx, falsy-values
+---
+
+## Use Explicit Conditional Rendering
+
+Use explicit ternary operators (`? :`) instead of `&&` for conditional rendering when the condition can be `0`, `NaN`, or other falsy values that render.
+
+**Incorrect (renders "0" when count is 0):**
+
+```tsx
+function Badge({ count }: { count: number }) {
+ return (
+
+ {count && {count} }
+
+ )
+}
+
+// When count = 0, renders:
0
+// When count = 5, renders:
5
+```
+
+**Correct (renders nothing when count is 0):**
+
+```tsx
+function Badge({ count }: { count: number }) {
+ return (
+
+ {count > 0 ? {count} : null}
+
+ )
+}
+
+// When count = 0, renders:
+// When count = 5, renders:
5
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/rendering-content-visibility.md b/.agents/skills/vercel-react-best-practices/rules/rendering-content-visibility.md
new file mode 100644
index 000000000..aa6656362
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rendering-content-visibility.md
@@ -0,0 +1,38 @@
+---
+title: CSS content-visibility for Long Lists
+impact: HIGH
+impactDescription: faster initial render
+tags: rendering, css, content-visibility, long-lists
+---
+
+## CSS content-visibility for Long Lists
+
+Apply `content-visibility: auto` to defer off-screen rendering.
+
+**CSS:**
+
+```css
+.message-item {
+ content-visibility: auto;
+ contain-intrinsic-size: 0 80px;
+}
+```
+
+**Example:**
+
+```tsx
+function MessageList({ messages }: { messages: Message[] }) {
+ return (
+
+ {messages.map(msg => (
+
+ ))}
+
+ )
+}
+```
+
+For 1000 messages, browser skips layout/paint for ~990 off-screen items (10× faster initial render).
diff --git a/.agents/skills/vercel-react-best-practices/rules/rendering-hoist-jsx.md b/.agents/skills/vercel-react-best-practices/rules/rendering-hoist-jsx.md
new file mode 100644
index 000000000..32d2f3fce
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rendering-hoist-jsx.md
@@ -0,0 +1,46 @@
+---
+title: Hoist Static JSX Elements
+impact: LOW
+impactDescription: avoids re-creation
+tags: rendering, jsx, static, optimization
+---
+
+## Hoist Static JSX Elements
+
+Extract static JSX outside components to avoid re-creation.
+
+**Incorrect (recreates element every render):**
+
+```tsx
+function LoadingSkeleton() {
+ return
+}
+
+function Container() {
+ return (
+
+ {loading && }
+
+ )
+}
+```
+
+**Correct (reuses same element):**
+
+```tsx
+const loadingSkeleton = (
+
+)
+
+function Container() {
+ return (
+
+ {loading && loadingSkeleton}
+
+ )
+}
+```
+
+This is especially helpful for large and static SVG nodes, which can be expensive to recreate on every render.
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler automatically hoists static JSX elements and optimizes component re-renders, making manual hoisting unnecessary.
diff --git a/.agents/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md b/.agents/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md
new file mode 100644
index 000000000..5cf0e79b6
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md
@@ -0,0 +1,82 @@
+---
+title: Prevent Hydration Mismatch Without Flickering
+impact: MEDIUM
+impactDescription: avoids visual flicker and hydration errors
+tags: rendering, ssr, hydration, localStorage, flicker
+---
+
+## Prevent Hydration Mismatch Without Flickering
+
+When rendering content that depends on client-side storage (localStorage, cookies), avoid both SSR breakage and post-hydration flickering by injecting a synchronous script that updates the DOM before React hydrates.
+
+**Incorrect (breaks SSR):**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ // localStorage is not available on server - throws error
+ const theme = localStorage.getItem('theme') || 'light'
+
+ return (
+
+ {children}
+
+ )
+}
+```
+
+Server-side rendering will fail because `localStorage` is undefined.
+
+**Incorrect (visual flickering):**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ const [theme, setTheme] = useState('light')
+
+ useEffect(() => {
+ // Runs after hydration - causes visible flash
+ const stored = localStorage.getItem('theme')
+ if (stored) {
+ setTheme(stored)
+ }
+ }, [])
+
+ return (
+
+ {children}
+
+ )
+}
+```
+
+Component first renders with default value (`light`), then updates after hydration, causing a visible flash of incorrect content.
+
+**Correct (no flicker, no hydration mismatch):**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ return (
+ <>
+
+ {children}
+
+
+ >
+ )
+}
+```
+
+The inline script executes synchronously before showing the element, ensuring the DOM already has the correct value. No flickering, no hydration mismatch.
+
+This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
diff --git a/.agents/skills/vercel-react-best-practices/rules/rendering-svg-precision.md b/.agents/skills/vercel-react-best-practices/rules/rendering-svg-precision.md
new file mode 100644
index 000000000..6d7712860
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rendering-svg-precision.md
@@ -0,0 +1,28 @@
+---
+title: Optimize SVG Precision
+impact: LOW
+impactDescription: reduces file size
+tags: rendering, svg, optimization, svgo
+---
+
+## Optimize SVG Precision
+
+Reduce SVG coordinate precision to decrease file size. The optimal precision depends on the viewBox size, but in general reducing precision should be considered.
+
+**Incorrect (excessive precision):**
+
+```svg
+
+```
+
+**Correct (1 decimal place):**
+
+```svg
+
+```
+
+**Automate with SVGO:**
+
+```bash
+npx svgo --precision=1 --multipass icon.svg
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/rerender-defer-reads.md b/.agents/skills/vercel-react-best-practices/rules/rerender-defer-reads.md
new file mode 100644
index 000000000..e867c95f0
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rerender-defer-reads.md
@@ -0,0 +1,39 @@
+---
+title: Defer State Reads to Usage Point
+impact: MEDIUM
+impactDescription: avoids unnecessary subscriptions
+tags: rerender, searchParams, localStorage, optimization
+---
+
+## Defer State Reads to Usage Point
+
+Don't subscribe to dynamic state (searchParams, localStorage) if you only read it inside callbacks.
+
+**Incorrect (subscribes to all searchParams changes):**
+
+```tsx
+function ShareButton({ chatId }: { chatId: string }) {
+ const searchParams = useSearchParams()
+
+ const handleShare = () => {
+ const ref = searchParams.get('ref')
+ shareChat(chatId, { ref })
+ }
+
+ return
Share
+}
+```
+
+**Correct (reads on demand, no subscription):**
+
+```tsx
+function ShareButton({ chatId }: { chatId: string }) {
+ const handleShare = () => {
+ const params = new URLSearchParams(window.location.search)
+ const ref = params.get('ref')
+ shareChat(chatId, { ref })
+ }
+
+ return
Share
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/rerender-dependencies.md b/.agents/skills/vercel-react-best-practices/rules/rerender-dependencies.md
new file mode 100644
index 000000000..47a4d9268
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rerender-dependencies.md
@@ -0,0 +1,45 @@
+---
+title: Narrow Effect Dependencies
+impact: LOW
+impactDescription: minimizes effect re-runs
+tags: rerender, useEffect, dependencies, optimization
+---
+
+## Narrow Effect Dependencies
+
+Specify primitive dependencies instead of objects to minimize effect re-runs.
+
+**Incorrect (re-runs on any user field change):**
+
+```tsx
+useEffect(() => {
+ console.log(user.id)
+}, [user])
+```
+
+**Correct (re-runs only when id changes):**
+
+```tsx
+useEffect(() => {
+ console.log(user.id)
+}, [user.id])
+```
+
+**For derived state, compute outside effect:**
+
+```tsx
+// Incorrect: runs on width=767, 766, 765...
+useEffect(() => {
+ if (width < 768) {
+ enableMobileMode()
+ }
+}, [width])
+
+// Correct: runs only on boolean transition
+const isMobile = width < 768
+useEffect(() => {
+ if (isMobile) {
+ enableMobileMode()
+ }
+}, [isMobile])
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/rerender-derived-state.md b/.agents/skills/vercel-react-best-practices/rules/rerender-derived-state.md
new file mode 100644
index 000000000..e5c899f6c
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rerender-derived-state.md
@@ -0,0 +1,29 @@
+---
+title: Subscribe to Derived State
+impact: MEDIUM
+impactDescription: reduces re-render frequency
+tags: rerender, derived-state, media-query, optimization
+---
+
+## Subscribe to Derived State
+
+Subscribe to derived boolean state instead of continuous values to reduce re-render frequency.
+
+**Incorrect (re-renders on every pixel change):**
+
+```tsx
+function Sidebar() {
+ const width = useWindowWidth() // updates continuously
+ const isMobile = width < 768
+ return
+}
+```
+
+**Correct (re-renders only when boolean changes):**
+
+```tsx
+function Sidebar() {
+ const isMobile = useMediaQuery('(max-width: 767px)')
+ return
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md b/.agents/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md
new file mode 100644
index 000000000..b004ef45e
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md
@@ -0,0 +1,74 @@
+---
+title: Use Functional setState Updates
+impact: MEDIUM
+impactDescription: prevents stale closures and unnecessary callback recreations
+tags: react, hooks, useState, useCallback, callbacks, closures
+---
+
+## Use Functional setState Updates
+
+When updating state based on the current state value, use the functional update form of setState instead of directly referencing the state variable. This prevents stale closures, eliminates unnecessary dependencies, and creates stable callback references.
+
+**Incorrect (requires state as dependency):**
+
+```tsx
+function TodoList() {
+ const [items, setItems] = useState(initialItems)
+
+ // Callback must depend on items, recreated on every items change
+ const addItems = useCallback((newItems: Item[]) => {
+ setItems([...items, ...newItems])
+ }, [items]) // ❌ items dependency causes recreations
+
+ // Risk of stale closure if dependency is forgotten
+ const removeItem = useCallback((id: string) => {
+ setItems(items.filter(item => item.id !== id))
+ }, []) // ❌ Missing items dependency - will use stale items!
+
+ return
+}
+```
+
+The first callback is recreated every time `items` changes, which can cause child components to re-render unnecessarily. The second callback has a stale closure bug—it will always reference the initial `items` value.
+
+**Correct (stable callbacks, no stale closures):**
+
+```tsx
+function TodoList() {
+ const [items, setItems] = useState(initialItems)
+
+ // Stable callback, never recreated
+ const addItems = useCallback((newItems: Item[]) => {
+ setItems(curr => [...curr, ...newItems])
+ }, []) // ✅ No dependencies needed
+
+ // Always uses latest state, no stale closure risk
+ const removeItem = useCallback((id: string) => {
+ setItems(curr => curr.filter(item => item.id !== id))
+ }, []) // ✅ Safe and stable
+
+ return
+}
+```
+
+**Benefits:**
+
+1. **Stable callback references** - Callbacks don't need to be recreated when state changes
+2. **No stale closures** - Always operates on the latest state value
+3. **Fewer dependencies** - Simplifies dependency arrays and reduces memory leaks
+4. **Prevents bugs** - Eliminates the most common source of React closure bugs
+
+**When to use functional updates:**
+
+- Any setState that depends on the current state value
+- Inside useCallback/useMemo when state is needed
+- Event handlers that reference state
+- Async operations that update state
+
+**When direct updates are fine:**
+
+- Setting state to a static value: `setCount(0)`
+- Setting state from props/arguments only: `setName(newName)`
+- State doesn't depend on previous value
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler can automatically optimize some cases, but functional updates are still recommended for correctness and to prevent stale closure bugs.
diff --git a/.agents/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md b/.agents/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md
new file mode 100644
index 000000000..4ecb350fb
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md
@@ -0,0 +1,58 @@
+---
+title: Use Lazy State Initialization
+impact: MEDIUM
+impactDescription: wasted computation on every render
+tags: react, hooks, useState, performance, initialization
+---
+
+## Use Lazy State Initialization
+
+Pass a function to `useState` for expensive initial values. Without the function form, the initializer runs on every render even though the value is only used once.
+
+**Incorrect (runs on every render):**
+
+```tsx
+function FilteredList({ items }: { items: Item[] }) {
+ // buildSearchIndex() runs on EVERY render, even after initialization
+ const [searchIndex, setSearchIndex] = useState(buildSearchIndex(items))
+ const [query, setQuery] = useState('')
+
+ // When query changes, buildSearchIndex runs again unnecessarily
+ return
+}
+
+function UserProfile() {
+ // JSON.parse runs on every render
+ const [settings, setSettings] = useState(
+ JSON.parse(localStorage.getItem('settings') || '{}')
+ )
+
+ return
+}
+```
+
+**Correct (runs only once):**
+
+```tsx
+function FilteredList({ items }: { items: Item[] }) {
+ // buildSearchIndex() runs ONLY on initial render
+ const [searchIndex, setSearchIndex] = useState(() => buildSearchIndex(items))
+ const [query, setQuery] = useState('')
+
+ return
+}
+
+function UserProfile() {
+ // JSON.parse runs only on initial render
+ const [settings, setSettings] = useState(() => {
+ const stored = localStorage.getItem('settings')
+ return stored ? JSON.parse(stored) : {}
+ })
+
+ return
+}
+```
+
+Use lazy initialization when computing initial values from localStorage/sessionStorage, building data structures (indexes, maps), reading from the DOM, or performing heavy transformations.
+
+For simple primitives (`useState(0)`), direct references (`useState(props.value)`), or cheap literals (`useState({})`), the function form is unnecessary.
diff --git a/.agents/skills/vercel-react-best-practices/rules/rerender-memo.md b/.agents/skills/vercel-react-best-practices/rules/rerender-memo.md
new file mode 100644
index 000000000..f8982ab61
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rerender-memo.md
@@ -0,0 +1,44 @@
+---
+title: Extract to Memoized Components
+impact: MEDIUM
+impactDescription: enables early returns
+tags: rerender, memo, useMemo, optimization
+---
+
+## Extract to Memoized Components
+
+Extract expensive work into memoized components to enable early returns before computation.
+
+**Incorrect (computes avatar even when loading):**
+
+```tsx
+function Profile({ user, loading }: Props) {
+ const avatar = useMemo(() => {
+ const id = computeAvatarId(user)
+ return
+ }, [user])
+
+ if (loading) return
+ return
{avatar}
+}
+```
+
+**Correct (skips computation when loading):**
+
+```tsx
+const UserAvatar = memo(function UserAvatar({ user }: { user: User }) {
+ const id = useMemo(() => computeAvatarId(user), [user])
+ return
+})
+
+function Profile({ user, loading }: Props) {
+ if (loading) return
+ return (
+
+
+
+ )
+}
+```
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, manual memoization with `memo()` and `useMemo()` is not necessary. The compiler automatically optimizes re-renders.
diff --git a/.agents/skills/vercel-react-best-practices/rules/rerender-transitions.md b/.agents/skills/vercel-react-best-practices/rules/rerender-transitions.md
new file mode 100644
index 000000000..d99f43f76
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/rerender-transitions.md
@@ -0,0 +1,40 @@
+---
+title: Use Transitions for Non-Urgent Updates
+impact: MEDIUM
+impactDescription: maintains UI responsiveness
+tags: rerender, transitions, startTransition, performance
+---
+
+## Use Transitions for Non-Urgent Updates
+
+Mark frequent, non-urgent state updates as transitions to maintain UI responsiveness.
+
+**Incorrect (blocks UI on every scroll):**
+
+```tsx
+function ScrollTracker() {
+ const [scrollY, setScrollY] = useState(0)
+ useEffect(() => {
+ const handler = () => setScrollY(window.scrollY)
+ window.addEventListener('scroll', handler, { passive: true })
+ return () => window.removeEventListener('scroll', handler)
+ }, [])
+}
+```
+
+**Correct (non-blocking updates):**
+
+```tsx
+import { startTransition } from 'react'
+
+function ScrollTracker() {
+ const [scrollY, setScrollY] = useState(0)
+ useEffect(() => {
+ const handler = () => {
+ startTransition(() => setScrollY(window.scrollY))
+ }
+ window.addEventListener('scroll', handler, { passive: true })
+ return () => window.removeEventListener('scroll', handler)
+ }, [])
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/server-after-nonblocking.md b/.agents/skills/vercel-react-best-practices/rules/server-after-nonblocking.md
new file mode 100644
index 000000000..e8f5b260f
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/server-after-nonblocking.md
@@ -0,0 +1,73 @@
+---
+title: Use after() for Non-Blocking Operations
+impact: MEDIUM
+impactDescription: faster response times
+tags: server, async, logging, analytics, side-effects
+---
+
+## Use after() for Non-Blocking Operations
+
+Use Next.js's `after()` to schedule work that should execute after a response is sent. This prevents logging, analytics, and other side effects from blocking the response.
+
+**Incorrect (blocks response):**
+
+```tsx
+import { logUserAction } from '@/app/utils'
+
+export async function POST(request: Request) {
+ // Perform mutation
+ await updateDatabase(request)
+
+ // Logging blocks the response
+ const userAgent = request.headers.get('user-agent') || 'unknown'
+ await logUserAction({ userAgent })
+
+ return new Response(JSON.stringify({ status: 'success' }), {
+ status: 200,
+ headers: { 'Content-Type': 'application/json' }
+ })
+}
+```
+
+**Correct (non-blocking):**
+
+```tsx
+import { after } from 'next/server'
+import { headers, cookies } from 'next/headers'
+import { logUserAction } from '@/app/utils'
+
+export async function POST(request: Request) {
+ // Perform mutation
+ await updateDatabase(request)
+
+ // Log after response is sent
+ after(async () => {
+ const userAgent = (await headers()).get('user-agent') || 'unknown'
+ const sessionCookie = (await cookies()).get('session-id')?.value || 'anonymous'
+
+ logUserAction({ sessionCookie, userAgent })
+ })
+
+ return new Response(JSON.stringify({ status: 'success' }), {
+ status: 200,
+ headers: { 'Content-Type': 'application/json' }
+ })
+}
+```
+
+The response is sent immediately while logging happens in the background.
+
+**Common use cases:**
+
+- Analytics tracking
+- Audit logging
+- Sending notifications
+- Cache invalidation
+- Cleanup tasks
+
+**Important notes:**
+
+- `after()` runs even if the response fails or redirects
+- Works in Server Actions, Route Handlers, and Server Components
+
+Reference: [https://nextjs.org/docs/app/api-reference/functions/after](https://nextjs.org/docs/app/api-reference/functions/after)
diff --git a/.agents/skills/vercel-react-best-practices/rules/server-cache-lru.md b/.agents/skills/vercel-react-best-practices/rules/server-cache-lru.md
new file mode 100644
index 000000000..ef6938aa5
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/server-cache-lru.md
@@ -0,0 +1,41 @@
+---
+title: Cross-Request LRU Caching
+impact: HIGH
+impactDescription: caches across requests
+tags: server, cache, lru, cross-request
+---
+
+## Cross-Request LRU Caching
+
+`React.cache()` only works within one request. For data shared across sequential requests (user clicks button A then button B), use an LRU cache.
+
+**Implementation:**
+
+```typescript
+import { LRUCache } from 'lru-cache'
+
+const cache = new LRUCache
({
+ max: 1000,
+ ttl: 5 * 60 * 1000 // 5 minutes
+})
+
+export async function getUser(id: string) {
+ const cached = cache.get(id)
+ if (cached) return cached
+
+ const user = await db.user.findUnique({ where: { id } })
+ cache.set(id, user)
+ return user
+}
+
+// Request 1: DB query, result cached
+// Request 2: cache hit, no DB query
+```
+
+Use when sequential user actions hit multiple endpoints needing the same data within seconds.
+
+**With Vercel's [Fluid Compute](https://vercel.com/docs/fluid-compute):** LRU caching is especially effective because multiple concurrent requests can share the same function instance and cache. This means the cache persists across requests without needing external storage like Redis.
+
+**In traditional serverless:** Each invocation runs in isolation, so consider Redis for cross-process caching.
+
+Reference: [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
diff --git a/.agents/skills/vercel-react-best-practices/rules/server-cache-react.md b/.agents/skills/vercel-react-best-practices/rules/server-cache-react.md
new file mode 100644
index 000000000..87c9ca331
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/server-cache-react.md
@@ -0,0 +1,76 @@
+---
+title: Per-Request Deduplication with React.cache()
+impact: MEDIUM
+impactDescription: deduplicates within request
+tags: server, cache, react-cache, deduplication
+---
+
+## Per-Request Deduplication with React.cache()
+
+Use `React.cache()` for server-side request deduplication. Authentication and database queries benefit most.
+
+**Usage:**
+
+```typescript
+import { cache } from 'react'
+
+export const getCurrentUser = cache(async () => {
+ const session = await auth()
+ if (!session?.user?.id) return null
+ return await db.user.findUnique({
+ where: { id: session.user.id }
+ })
+})
+```
+
+Within a single request, multiple calls to `getCurrentUser()` execute the query only once.
+
+**Avoid inline objects as arguments:**
+
+`React.cache()` uses shallow equality (`Object.is`) to determine cache hits. Inline objects create new references each call, preventing cache hits.
+
+**Incorrect (always cache miss):**
+
+```typescript
+const getUser = cache(async (params: { uid: number }) => {
+ return await db.user.findUnique({ where: { id: params.uid } })
+})
+
+// Each call creates new object, never hits cache
+getUser({ uid: 1 })
+getUser({ uid: 1 }) // Cache miss, runs query again
+```
+
+**Correct (cache hit):**
+
+```typescript
+const getUser = cache(async (uid: number) => {
+ return await db.user.findUnique({ where: { id: uid } })
+})
+
+// Primitive args use value equality
+getUser(1)
+getUser(1) // Cache hit, returns cached result
+```
+
+If you must pass objects, pass the same reference:
+
+```typescript
+const params = { uid: 1 }
+getUser(params) // Query runs
+getUser(params) // Cache hit (same reference)
+```
+
+**Next.js-Specific Note:**
+
+In Next.js, the `fetch` API is automatically extended with request memoization. Requests with the same URL and options are automatically deduplicated within a single request, so you don't need `React.cache()` for `fetch` calls. However, `React.cache()` is still essential for other async tasks:
+
+- Database queries (Prisma, Drizzle, etc.)
+- Heavy computations
+- Authentication checks
+- File system operations
+- Any non-fetch async work
+
+Use `React.cache()` to deduplicate these operations across your component tree.
+
+Reference: [React.cache documentation](https://react.dev/reference/react/cache)
diff --git a/.agents/skills/vercel-react-best-practices/rules/server-parallel-fetching.md b/.agents/skills/vercel-react-best-practices/rules/server-parallel-fetching.md
new file mode 100644
index 000000000..1affc835a
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/server-parallel-fetching.md
@@ -0,0 +1,83 @@
+---
+title: Parallel Data Fetching with Component Composition
+impact: CRITICAL
+impactDescription: eliminates server-side waterfalls
+tags: server, rsc, parallel-fetching, composition
+---
+
+## Parallel Data Fetching with Component Composition
+
+React Server Components execute sequentially within a tree. Restructure with composition to parallelize data fetching.
+
+**Incorrect (Sidebar waits for Page's fetch to complete):**
+
+```tsx
+export default async function Page() {
+ const header = await fetchHeader()
+ return (
+
+ )
+}
+
+async function Sidebar() {
+ const items = await fetchSidebarItems()
+ return {items.map(renderItem)}
+}
+```
+
+**Correct (both fetch simultaneously):**
+
+```tsx
+async function Header() {
+ const data = await fetchHeader()
+ return {data}
+}
+
+async function Sidebar() {
+ const items = await fetchSidebarItems()
+ return {items.map(renderItem)}
+}
+
+export default function Page() {
+ return (
+
+
+
+
+ )
+}
+```
+
+**Alternative with children prop:**
+
+```tsx
+async function Header() {
+ const data = await fetchHeader()
+ return {data}
+}
+
+async function Sidebar() {
+ const items = await fetchSidebarItems()
+ return {items.map(renderItem)}
+}
+
+function Layout({ children }: { children: ReactNode }) {
+ return (
+
+
+ {children}
+
+ )
+}
+
+export default function Page() {
+ return (
+
+
+
+ )
+}
+```
diff --git a/.agents/skills/vercel-react-best-practices/rules/server-serialization.md b/.agents/skills/vercel-react-best-practices/rules/server-serialization.md
new file mode 100644
index 000000000..39c5c4164
--- /dev/null
+++ b/.agents/skills/vercel-react-best-practices/rules/server-serialization.md
@@ -0,0 +1,38 @@
+---
+title: Minimize Serialization at RSC Boundaries
+impact: HIGH
+impactDescription: reduces data transfer size
+tags: server, rsc, serialization, props
+---
+
+## Minimize Serialization at RSC Boundaries
+
+The React Server/Client boundary serializes all object properties into strings and embeds them in the HTML response and subsequent RSC requests. This serialized data directly impacts page weight and load time, so **size matters a lot**. Only pass fields that the client actually uses.
+
+**Incorrect (serializes all 50 fields):**
+
+```tsx
+async function Page() {
+ const user = await fetchUser() // 50 fields
+ return
+}
+
+'use client'
+function Profile({ user }: { user: User }) {
+ return {user.name}
// uses 1 field
+}
+```
+
+**Correct (serializes only 1 field):**
+
+```tsx
+async function Page() {
+ const user = await fetchUser()
+ return
+}
+
+'use client'
+function Profile({ name }: { name: string }) {
+ return {name}
+}
+```
diff --git a/.agents/skills/web-design-guidelines/SKILL.md b/.agents/skills/web-design-guidelines/SKILL.md
new file mode 100644
index 000000000..ceae92ab3
--- /dev/null
+++ b/.agents/skills/web-design-guidelines/SKILL.md
@@ -0,0 +1,39 @@
+---
+name: web-design-guidelines
+description: Review UI code for Web Interface Guidelines compliance. Use when asked to "review my UI", "check accessibility", "audit design", "review UX", or "check my site against best practices".
+metadata:
+ author: vercel
+ version: "1.0.0"
+ argument-hint:
+---
+
+# Web Interface Guidelines
+
+Review files for compliance with Web Interface Guidelines.
+
+## How It Works
+
+1. Fetch the latest guidelines from the source URL below
+2. Read the specified files (or prompt user for files/pattern)
+3. Check against all rules in the fetched guidelines
+4. Output findings in the terse `file:line` format
+
+## Guidelines Source
+
+Fetch fresh guidelines before each review:
+
+```
+https://raw.githubusercontent.com/vercel-labs/web-interface-guidelines/main/command.md
+```
+
+Use WebFetch to retrieve the latest rules. The fetched content contains all the rules and output format instructions.
+
+## Usage
+
+When a user provides a file or pattern argument:
+1. Fetch guidelines from the source URL above
+2. Read the specified files
+3. Apply all rules from the fetched guidelines
+4. Output findings using the format specified in the guidelines
+
+If no files specified, ask the user which files to review.
diff --git a/.claude/settings.json b/.claude/settings.json
index 509dbe844..f9e1016d0 100644
--- a/.claude/settings.json
+++ b/.claude/settings.json
@@ -1,9 +1,20 @@
{
+ "hooks": {
+ "PreToolUse": [
+ {
+ "matcher": "Bash",
+ "hooks": [
+ {
+ "type": "command",
+ "command": "npx -y block-no-verify@1.1.1"
+ }
+ ]
+ }
+ ]
+ },
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
- "typescript-lsp@claude-plugins-official": true,
- "pyright-lsp@claude-plugins-official": true,
"ralph-loop@claude-plugins-official": true
}
}
diff --git a/.claude/skills/component-refactoring b/.claude/skills/component-refactoring
new file mode 120000
index 000000000..53ae67e2f
--- /dev/null
+++ b/.claude/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.claude/skills/frontend-code-review b/.claude/skills/frontend-code-review
new file mode 120000
index 000000000..55654ffbd
--- /dev/null
+++ b/.claude/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.claude/skills/frontend-testing b/.claude/skills/frontend-testing
new file mode 120000
index 000000000..092cec774
--- /dev/null
+++ b/.claude/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.claude/skills/orpc-contract-first b/.claude/skills/orpc-contract-first
new file mode 120000
index 000000000..da47b335c
--- /dev/null
+++ b/.claude/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.claude/skills/skill-creator b/.claude/skills/skill-creator
new file mode 120000
index 000000000..b87455490
--- /dev/null
+++ b/.claude/skills/skill-creator
@@ -0,0 +1 @@
+../../.agents/skills/skill-creator
\ No newline at end of file
diff --git a/.claude/skills/vercel-react-best-practices b/.claude/skills/vercel-react-best-practices
new file mode 120000
index 000000000..e567923b3
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices
@@ -0,0 +1 @@
+../../.agents/skills/vercel-react-best-practices
\ No newline at end of file
diff --git a/.claude/skills/web-design-guidelines b/.claude/skills/web-design-guidelines
new file mode 120000
index 000000000..886b26ded
--- /dev/null
+++ b/.claude/skills/web-design-guidelines
@@ -0,0 +1 @@
+../../.agents/skills/web-design-guidelines
\ No newline at end of file
diff --git a/.codex/skills b/.codex/skills
deleted file mode 120000
index 454b8427c..000000000
--- a/.codex/skills
+++ /dev/null
@@ -1 +0,0 @@
-../.claude/skills
\ No newline at end of file
diff --git a/.codex/skills/component-refactoring b/.codex/skills/component-refactoring
new file mode 120000
index 000000000..53ae67e2f
--- /dev/null
+++ b/.codex/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.codex/skills/frontend-code-review b/.codex/skills/frontend-code-review
new file mode 120000
index 000000000..55654ffbd
--- /dev/null
+++ b/.codex/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.codex/skills/frontend-testing b/.codex/skills/frontend-testing
new file mode 120000
index 000000000..092cec774
--- /dev/null
+++ b/.codex/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.codex/skills/orpc-contract-first b/.codex/skills/orpc-contract-first
new file mode 120000
index 000000000..da47b335c
--- /dev/null
+++ b/.codex/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.codex/skills/skill-creator b/.codex/skills/skill-creator
new file mode 120000
index 000000000..b87455490
--- /dev/null
+++ b/.codex/skills/skill-creator
@@ -0,0 +1 @@
+../../.agents/skills/skill-creator
\ No newline at end of file
diff --git a/.codex/skills/vercel-react-best-practices b/.codex/skills/vercel-react-best-practices
new file mode 120000
index 000000000..e567923b3
--- /dev/null
+++ b/.codex/skills/vercel-react-best-practices
@@ -0,0 +1 @@
+../../.agents/skills/vercel-react-best-practices
\ No newline at end of file
diff --git a/.codex/skills/web-design-guidelines b/.codex/skills/web-design-guidelines
new file mode 120000
index 000000000..886b26ded
--- /dev/null
+++ b/.codex/skills/web-design-guidelines
@@ -0,0 +1 @@
+../../.agents/skills/web-design-guidelines
\ No newline at end of file
diff --git a/.cursor/skills/component-refactoring b/.cursor/skills/component-refactoring
new file mode 120000
index 000000000..53ae67e2f
--- /dev/null
+++ b/.cursor/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.cursor/skills/frontend-code-review b/.cursor/skills/frontend-code-review
new file mode 120000
index 000000000..55654ffbd
--- /dev/null
+++ b/.cursor/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.cursor/skills/frontend-testing b/.cursor/skills/frontend-testing
new file mode 120000
index 000000000..092cec774
--- /dev/null
+++ b/.cursor/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.cursor/skills/orpc-contract-first b/.cursor/skills/orpc-contract-first
new file mode 120000
index 000000000..da47b335c
--- /dev/null
+++ b/.cursor/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.cursor/skills/skill-creator b/.cursor/skills/skill-creator
new file mode 120000
index 000000000..b87455490
--- /dev/null
+++ b/.cursor/skills/skill-creator
@@ -0,0 +1 @@
+../../.agents/skills/skill-creator
\ No newline at end of file
diff --git a/.cursor/skills/vercel-react-best-practices b/.cursor/skills/vercel-react-best-practices
new file mode 120000
index 000000000..e567923b3
--- /dev/null
+++ b/.cursor/skills/vercel-react-best-practices
@@ -0,0 +1 @@
+../../.agents/skills/vercel-react-best-practices
\ No newline at end of file
diff --git a/.cursor/skills/web-design-guidelines b/.cursor/skills/web-design-guidelines
new file mode 120000
index 000000000..886b26ded
--- /dev/null
+++ b/.cursor/skills/web-design-guidelines
@@ -0,0 +1 @@
+../../.agents/skills/web-design-guidelines
\ No newline at end of file
diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index 220f77e5c..637593b9d 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -8,7 +8,7 @@ pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
-echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
+echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
diff --git a/.gemini/skills/component-refactoring b/.gemini/skills/component-refactoring
new file mode 120000
index 000000000..53ae67e2f
--- /dev/null
+++ b/.gemini/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.gemini/skills/frontend-code-review b/.gemini/skills/frontend-code-review
new file mode 120000
index 000000000..55654ffbd
--- /dev/null
+++ b/.gemini/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.gemini/skills/frontend-testing b/.gemini/skills/frontend-testing
new file mode 120000
index 000000000..092cec774
--- /dev/null
+++ b/.gemini/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.gemini/skills/orpc-contract-first b/.gemini/skills/orpc-contract-first
new file mode 120000
index 000000000..da47b335c
--- /dev/null
+++ b/.gemini/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.gemini/skills/skill-creator b/.gemini/skills/skill-creator
new file mode 120000
index 000000000..b87455490
--- /dev/null
+++ b/.gemini/skills/skill-creator
@@ -0,0 +1 @@
+../../.agents/skills/skill-creator
\ No newline at end of file
diff --git a/.gemini/skills/vercel-react-best-practices b/.gemini/skills/vercel-react-best-practices
new file mode 120000
index 000000000..e567923b3
--- /dev/null
+++ b/.gemini/skills/vercel-react-best-practices
@@ -0,0 +1 @@
+../../.agents/skills/vercel-react-best-practices
\ No newline at end of file
diff --git a/.gemini/skills/web-design-guidelines b/.gemini/skills/web-design-guidelines
new file mode 120000
index 000000000..886b26ded
--- /dev/null
+++ b/.gemini/skills/web-design-guidelines
@@ -0,0 +1 @@
+../../.agents/skills/web-design-guidelines
\ No newline at end of file
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 000000000..d1d324d38
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,3 @@
+web:
+ - changed-files:
+ - any-glob-to-any-file: 'web/**'
diff --git a/.github/skills/component-refactoring b/.github/skills/component-refactoring
new file mode 120000
index 000000000..53ae67e2f
--- /dev/null
+++ b/.github/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.github/skills/frontend-code-review b/.github/skills/frontend-code-review
new file mode 120000
index 000000000..55654ffbd
--- /dev/null
+++ b/.github/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.github/skills/frontend-testing b/.github/skills/frontend-testing
new file mode 120000
index 000000000..092cec774
--- /dev/null
+++ b/.github/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.github/skills/orpc-contract-first b/.github/skills/orpc-contract-first
new file mode 120000
index 000000000..da47b335c
--- /dev/null
+++ b/.github/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.github/skills/skill-creator b/.github/skills/skill-creator
new file mode 120000
index 000000000..b87455490
--- /dev/null
+++ b/.github/skills/skill-creator
@@ -0,0 +1 @@
+../../.agents/skills/skill-creator
\ No newline at end of file
diff --git a/.github/skills/vercel-react-best-practices b/.github/skills/vercel-react-best-practices
new file mode 120000
index 000000000..e567923b3
--- /dev/null
+++ b/.github/skills/vercel-react-best-practices
@@ -0,0 +1 @@
+../../.agents/skills/vercel-react-best-practices
\ No newline at end of file
diff --git a/.github/skills/web-design-guidelines b/.github/skills/web-design-guidelines
new file mode 120000
index 000000000..886b26ded
--- /dev/null
+++ b/.github/skills/web-design-guidelines
@@ -0,0 +1 @@
+../../.agents/skills/web-design-guidelines
\ No newline at end of file
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 8ec7dbf4d..7c21bec7f 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -39,12 +39,6 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- - name: Run pyrefly check
- run: |
- cd api
- uv add --dev pyrefly
- uv run pyrefly check || true
-
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 5413f83c2..4a8c61e7d 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -16,14 +16,14 @@ jobs:
- name: Check Docker Compose inputs
id: docker-compose-changes
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- - uses: actions/setup-python@v5
+ - uses: actions/setup-python@v6
with:
python-version: "3.11"
@@ -79,9 +79,32 @@ jobs:
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Setup Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Install web dependencies
+ run: |
+ cd web
+ pnpm install --frozen-lockfile
+
+ - name: ESLint autofix
+ run: |
+ cd web
+ pnpm lint:fix || true
+
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
- uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
+ uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index bbf89236d..704d89619 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -112,7 +112,7 @@ jobs:
context: "web"
steps:
- name: Download digests
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v7
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*
diff --git a/.github/workflows/deploy-trigger-dev.yml b/.github/workflows/deploy-agent-dev.yml
similarity index 69%
rename from .github/workflows/deploy-trigger-dev.yml
rename to .github/workflows/deploy-agent-dev.yml
index 2d9a904fc..dd759f7ba 100644
--- a/.github/workflows/deploy-trigger-dev.yml
+++ b/.github/workflows/deploy-agent-dev.yml
@@ -1,4 +1,4 @@
-name: Deploy Trigger Dev
+name: Deploy Agent Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- - "deploy/trigger-dev"
+ - "deploy/agent-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
- github.event.workflow_run.head_branch == 'deploy/trigger-dev'
+ github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
+ uses: appleboy/ssh-action@v1
with:
- host: ${{ secrets.TRIGGER_SSH_HOST }}
+ host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml
index cd1c86e66..38fa0b9a7 100644
--- a/.github/workflows/deploy-dev.yml
+++ b/.github/workflows/deploy-dev.yml
@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
+ uses: appleboy/ssh-action@v1
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml
new file mode 100644
index 000000000..7d5f0a22e
--- /dev/null
+++ b/.github/workflows/deploy-hitl.yml
@@ -0,0 +1,29 @@
+name: Deploy HITL
+
+on:
+ workflow_run:
+ workflows: ["Build and Push API & Web"]
+ branches:
+ - "feat/hitl-frontend"
+ - "feat/hitl-backend"
+ types:
+ - completed
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ if: |
+ github.event.workflow_run.conclusion == 'success' &&
+ (
+ github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
+ github.event.workflow_run.head_branch == 'feat/hitl-backend'
+ )
+ steps:
+ - name: Deploy to server
+ uses: appleboy/ssh-action@v1
+ with:
+ host: ${{ secrets.HITL_SSH_HOST }}
+ username: ${{ secrets.SSH_USER }}
+ key: ${{ secrets.SSH_PRIVATE_KEY }}
+ script: |
+ ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 000000000..06782b53c
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,14 @@
+name: "Pull Request Labeler"
+on:
+ pull_request_target:
+
+jobs:
+ labeler:
+ permissions:
+ contents: read
+ pull-requests: write
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/labeler@v6
+ with:
+ sync-labels: true
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 1870b1f67..b6df1d7e9 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- - uses: actions/stale@v5
+ - uses: actions/stale@v10
with:
days-before-issue-stale: 15
days-before-issue-close: 3
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 6c5d6f413..fdc05d1d6 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -65,6 +65,9 @@ jobs:
defaults:
run:
working-directory: ./web
+ permissions:
+ checks: write
+ pull-requests: read
steps:
- name: Checkout code
@@ -103,23 +106,32 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
- pnpm run lint
+ pnpm run lint:ci
+ # pnpm run lint:report
+ # continue-on-error: true
+
+ # - name: Annotate Code
+ # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
+ # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
+ # with:
+ # eslint-report: web/eslint_report.json
+ # github-token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Web tsslint
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm run lint:tss
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
- run: pnpm run type-check:tsgo
+ run: pnpm run type-check
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run knip
- - name: Web build check
- if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
- run: pnpm run build
-
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index dcbc675cf..ec392cb3b 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -21,17 +21,12 @@ jobs:
working-directory: sdks/nodejs-client
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
with:
persist-credentials: false
-<<<<<<< HEAD
- - name: Use Node.js ${{ matrix.node-version }}
- uses: actions/setup-node@v4
-=======
- name: Use Node.js
uses: actions/setup-node@v6
->>>>>>> 328897f81c (build: require node 24.13.0 (#30945))
with:
node-version: 24
cache: ''
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
index 8344af989..5d9440ff3 100644
--- a/.github/workflows/translate-i18n-claude.yml
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -134,6 +134,9 @@ jobs:
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
+ # Allow github-actions bot to trigger this workflow via repository_dispatch
+ # See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
+ allowed_bots: 'github-actions[bot]'
prompt: |
You are a professional i18n synchronization engineer for the Dify project.
Your task is to keep all language translations in sync with the English source (en-US).
@@ -285,6 +288,22 @@ jobs:
- `${variable}` - Template literal
- `content ` - HTML tags
- `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
+
+ **CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them**
+
+ ✅ CORRECT examples:
+ - English: "{{count}} items" → Japanese: "{{count}} 個のアイテム"
+ - English: "{{name}} updated" → Korean: "{{name}} 업데이트됨"
+ - English: "{{email}} " → Chinese: "{{email}} "
+ - English: "Marketplace " → Japanese: "マーケットプレイス "
+
+ ❌ WRONG examples (NEVER do this - will break the application):
+ - "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese)
+ - "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean)
+ - "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese)
+ - "" → "<メール>" ❌ (tag name translated)
+ - "" → "<自定义链接>" ❌ (component name translated)
+
- Use appropriate language register (formal/informal) based on existing translations
- Match existing translation style in each language
- Technical terms: check existing conventions per language
diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml
new file mode 100644
index 000000000..66a29453b
--- /dev/null
+++ b/.github/workflows/trigger-i18n-sync.yml
@@ -0,0 +1,66 @@
+name: Trigger i18n Sync on Push
+
+# This workflow bridges the push event to repository_dispatch
+# because claude-code-action doesn't support push events directly.
+# See: https://github.com/langgenius/dify/issues/30743
+
+on:
+ push:
+ branches: [main]
+ paths:
+ - 'web/i18n/en-US/*.json'
+
+permissions:
+ contents: write
+
+jobs:
+ trigger:
+ if: github.repository == 'langgenius/dify'
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+
+ - name: Detect changed files and generate diff
+ id: detect
+ run: |
+ BEFORE_SHA="${{ github.event.before }}"
+ # Handle edge case: force push may have null/zero SHA
+ if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
+ BEFORE_SHA="HEAD~1"
+ fi
+
+ # Detect changed i18n files
+ changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
+ echo "changed_files=$changed" >> $GITHUB_OUTPUT
+
+ # Generate diff for context
+ git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
+
+ # Truncate if too large (keep first 50KB to match receiving workflow)
+ head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
+ mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
+
+ # Base64 encode the diff for safe JSON transport (portable, single-line)
+ diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
+ echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
+
+ if [ -n "$changed" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ echo "Detected changed files: $changed"
+ else
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ echo "No i18n changes detected"
+ fi
+
+ - name: Trigger i18n sync workflow
+ if: steps.detect.outputs.has_changes == 'true'
+ uses: peter-evans/repository-dispatch@v3
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ event-type: i18n-sync
+ client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 65c958a45..191ce56aa 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -366,3 +366,48 @@ jobs:
path: web/coverage
retention-days: 30
if-no-files-found: error
+
+ web-build:
+ name: Web Build
+ runs-on: ubuntu-latest
+ defaults:
+ run:
+ working-directory: ./web
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v6
+ with:
+ persist-credentials: false
+
+ - name: Check changed files
+ id: changed-files
+ uses: tj-actions/changed-files@v47
+ with:
+ files: |
+ web/**
+ .github/workflows/web-tests.yml
+
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Setup NodeJS
+ uses: actions/setup-node@v6
+ if: steps.changed-files.outputs.any_changed == 'true'
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Web dependencies
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm install --frozen-lockfile
+
+ - name: Web build check
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm run build
diff --git a/.nvmrc b/.nvmrc
deleted file mode 100644
index 7af24b7dd..000000000
--- a/.nvmrc
+++ /dev/null
@@ -1 +0,0 @@
-22.11.0
diff --git a/AGENTS.md~upstream_main b/AGENTS.md~upstream_main
index 782861ad3..deab7c862 100644
--- a/AGENTS.md~upstream_main
+++ b/AGENTS.md~upstream_main
@@ -12,12 +12,8 @@ The codebase is split into:
## Backend Workflow
+- Read `api/AGENTS.md` for details
- Run backend CLI commands through `uv run --project api `.
-
-- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-
-- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
-
- Integration tests are CI-only and are not expected to run in the local environment.
## Frontend Workflow
diff --git a/Makefile b/Makefile
index 60c32948b..e92a7b131 100644
--- a/Makefile
+++ b/Makefile
@@ -61,7 +61,8 @@ check:
lint:
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
- @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
+ @uv run --project api --dev ruff format ./api
+ @uv run --project api --dev ruff check --fix ./api
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
@@ -73,7 +74,12 @@ type-check:
test:
@echo "🧪 Running backend unit tests..."
- @uv run --project api --dev dev/pytest/pytest_unit_tests.sh
+ @if [ -n "$(TARGET_TESTS)" ]; then \
+ echo "Target: $(TARGET_TESTS)"; \
+ uv run --project api --dev pytest $(TARGET_TESTS); \
+ else \
+ uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
+ fi
@echo "✅ Tests complete"
# Build Docker images
@@ -125,7 +131,7 @@ help:
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
- @echo " make test - Run backend unit tests"
+ @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
diff --git a/agent-notes/.gitkeep b/agent-notes/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/.env.example b/api/.env.example
index 53e24571d..d9f03ec09 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -417,6 +417,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
+# Optional: override the local hostname used for SMTP HELO/EHLO
+SMTP_LOCAL_HOSTNAME=
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration
@@ -589,6 +591,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
+ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
@@ -712,6 +715,8 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
+SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
+
# OA Oauth2(二开新增配置)
diff --git a/api/.importlinter b/api/.importlinter
index 2dec95878..b676e9759 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -27,7 +27,9 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
- core.workflow.nodes.node_factory -> core.workflow.graph
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
@@ -57,6 +59,252 @@ ignore_imports =
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
+[importlinter:contract:workflow-external-imports]
+name = Workflow External Imports
+type = forbidden
+source_modules =
+ core.workflow
+forbidden_modules =
+ configs
+ controllers
+ extensions
+ models
+ services
+ tasks
+ core.agent
+ core.app
+ core.base
+ core.callback_handler
+ core.datasource
+ core.db
+ core.entities
+ core.errors
+ core.extension
+ core.external_data_tool
+ core.file
+ core.helper
+ core.hosting_configuration
+ core.indexing_runner
+ core.llm_generator
+ core.logging
+ core.mcp
+ core.memory
+ core.model_manager
+ core.moderation
+ core.ops
+ core.plugin
+ core.prompt
+ core.provider_manager
+ core.rag
+ core.repositories
+ core.schemas
+ core.tools
+ core.trigger
+ core.variables
+ignore_imports =
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+ core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
+ core.workflow.graph_engine.layers.observability -> configs
+ core.workflow.graph_engine.layers.observability -> extensions.otel.runtime
+ core.workflow.graph_engine.layers.persistence -> core.ops.ops_trace_manager
+ core.workflow.graph_engine.worker_management.worker_pool -> configs
+ core.workflow.nodes.agent.agent_node -> core.model_manager
+ core.workflow.nodes.agent.agent_node -> core.provider_manager
+ core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
+ core.workflow.nodes.datasource.datasource_node -> models.model
+ core.workflow.nodes.datasource.datasource_node -> models.tools
+ core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
+ core.workflow.nodes.document_extractor.node -> configs
+ core.workflow.nodes.document_extractor.node -> core.file.file_manager
+ core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
+ core.workflow.nodes.http_request.entities -> configs
+ core.workflow.nodes.http_request.executor -> configs
+ core.workflow.nodes.http_request.executor -> core.file.file_manager
+ core.workflow.nodes.http_request.node -> configs
+ core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.llm.llm_utils -> configs
+ core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.llm.llm_utils -> core.file.models
+ core.workflow.nodes.llm.llm_utils -> core.model_manager
+ core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.llm.llm_utils -> models.model
+ core.workflow.nodes.llm.llm_utils -> models.provider
+ core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
+ core.workflow.nodes.llm.node -> core.tools.signature
+ core.workflow.nodes.template_transform.template_transform_node -> configs
+ core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
+ core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
+ core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
+ core.workflow.workflow_entry -> configs
+ core.workflow.workflow_entry -> models.workflow
+ core.workflow.nodes.agent.agent_node -> core.agent.entities
+ core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
+ core.workflow.graph_engine.layers.persistence -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.start.entities -> core.app.app_config.entities
+ core.workflow.nodes.start.start_node -> core.app.app_config.entities
+ core.workflow.workflow_entry -> core.app.apps.exc
+ core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
+ core.workflow.workflow_entry -> core.app.workflow.node_factory
+ core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
+ core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
+ core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
+ core.workflow.node_events.node -> core.file
+ core.workflow.nodes.agent.agent_node -> core.file
+ core.workflow.nodes.datasource.datasource_node -> core.file
+ core.workflow.nodes.datasource.datasource_node -> core.file.enums
+ core.workflow.nodes.document_extractor.node -> core.file
+ core.workflow.nodes.http_request.executor -> core.file.enums
+ core.workflow.nodes.http_request.node -> core.file
+ core.workflow.nodes.http_request.node -> core.file.file_manager
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
+ core.workflow.nodes.list_operator.node -> core.file
+ core.workflow.nodes.llm.file_saver -> core.file
+ core.workflow.nodes.llm.llm_utils -> core.variables.segments
+ core.workflow.nodes.llm.node -> core.file
+ core.workflow.nodes.llm.node -> core.file.file_manager
+ core.workflow.nodes.llm.node -> core.file.models
+ core.workflow.nodes.loop.entities -> core.variables.types
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
+ core.workflow.nodes.protocols -> core.file
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
+ core.workflow.nodes.tool.tool_node -> core.file
+ core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
+ core.workflow.nodes.tool.tool_node -> models
+ core.workflow.nodes.trigger_webhook.node -> core.file
+ core.workflow.runtime.variable_pool -> core.file
+ core.workflow.runtime.variable_pool -> core.file.file_manager
+ core.workflow.system_variable -> core.file.models
+ core.workflow.utils.condition.processor -> core.file
+ core.workflow.utils.condition.processor -> core.file.file_manager
+ core.workflow.workflow_entry -> core.file.models
+ core.workflow.workflow_type_encoder -> core.file.models
+ core.workflow.nodes.agent.agent_node -> models.model
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
+ core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
+ core.workflow.nodes.datasource.datasource_node -> core.variables.variables
+ core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
+ core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
+ core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
+ core.workflow.nodes.llm.node -> core.helper.code_executor
+ core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
+ core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
+ core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
+ core.workflow.nodes.llm.node -> core.model_manager
+ core.workflow.graph_engine.layers.persistence -> core.ops.entities.trace_entity
+ core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.llm.node -> models.dataset
+ core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
+ core.workflow.nodes.llm.file_saver -> core.tools.signature
+ core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
+ core.workflow.nodes.tool.tool_node -> core.tools.errors
+ core.workflow.conversation_variable_updater -> core.variables
+ core.workflow.graph_engine.entities.commands -> core.variables.variables
+ core.workflow.nodes.agent.agent_node -> core.variables.segments
+ core.workflow.nodes.answer.answer_node -> core.variables
+ core.workflow.nodes.code.code_node -> core.variables.segments
+ core.workflow.nodes.code.code_node -> core.variables.types
+ core.workflow.nodes.code.entities -> core.variables.types
+ core.workflow.nodes.datasource.datasource_node -> core.variables.segments
+ core.workflow.nodes.document_extractor.node -> core.variables
+ core.workflow.nodes.document_extractor.node -> core.variables.segments
+ core.workflow.nodes.http_request.executor -> core.variables.segments
+ core.workflow.nodes.http_request.node -> core.variables.segments
+ core.workflow.nodes.iteration.iteration_node -> core.variables
+ core.workflow.nodes.iteration.iteration_node -> core.variables.segments
+ core.workflow.nodes.iteration.iteration_node -> core.variables.variables
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
+ core.workflow.nodes.list_operator.node -> core.variables
+ core.workflow.nodes.list_operator.node -> core.variables.segments
+ core.workflow.nodes.llm.node -> core.variables
+ core.workflow.nodes.loop.loop_node -> core.variables
+ core.workflow.nodes.parameter_extractor.entities -> core.variables.types
+ core.workflow.nodes.parameter_extractor.exc -> core.variables.types
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
+ core.workflow.nodes.tool.tool_node -> core.variables.segments
+ core.workflow.nodes.tool.tool_node -> core.variables.variables
+ core.workflow.nodes.trigger_webhook.node -> core.variables.types
+ core.workflow.nodes.trigger_webhook.node -> core.variables.variables
+ core.workflow.nodes.variable_aggregator.entities -> core.variables.types
+ core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
+ core.workflow.nodes.variable_assigner.v1.node -> core.variables
+ core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
+ core.workflow.nodes.variable_assigner.v2.node -> core.variables
+ core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
+ core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
+ core.workflow.runtime.read_only_wrappers -> core.variables.segments
+ core.workflow.runtime.variable_pool -> core.variables
+ core.workflow.runtime.variable_pool -> core.variables.consts
+ core.workflow.runtime.variable_pool -> core.variables.segments
+ core.workflow.runtime.variable_pool -> core.variables.variables
+ core.workflow.utils.condition.processor -> core.variables
+ core.workflow.utils.condition.processor -> core.variables.segments
+ core.workflow.variable_loader -> core.variables
+ core.workflow.variable_loader -> core.variables.consts
+ core.workflow.workflow_type_encoder -> core.variables
+ core.workflow.graph_engine.manager -> extensions.ext_redis
+ core.workflow.nodes.agent.agent_node -> extensions.ext_database
+ core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
+ core.workflow.nodes.llm.file_saver -> extensions.ext_database
+ core.workflow.nodes.llm.llm_utils -> extensions.ext_database
+ core.workflow.nodes.llm.node -> extensions.ext_database
+ core.workflow.nodes.tool.tool_node -> extensions.ext_database
+ core.workflow.workflow_entry -> extensions.otel.runtime
+ core.workflow.nodes.agent.agent_node -> models
+ core.workflow.nodes.base.node -> models.enums
+ core.workflow.nodes.llm.llm_utils -> models.provider_ids
+ core.workflow.nodes.llm.node -> models.model
+ core.workflow.workflow_entry -> models.enums
+ core.workflow.nodes.agent.agent_node -> services
+ core.workflow.nodes.tool.tool_node -> services
+
[importlinter:contract:rsc]
name = RSC
type = layers
diff --git a/api/AGENTS.md b/api/AGENTS.md
index 17398ec4b..6ce419828 100644
--- a/api/AGENTS.md
+++ b/api/AGENTS.md
@@ -1,62 +1,236 @@
-# Agent Skill Index
+# API Agent Guide
+
+## Agent Notes (must-check)
+
+Before you start work on any backend file under `api/`, you MUST check whether a related note exists under:
+
+- `agent-notes/.md`
+
+Rules:
+
+- **Path mapping**: for a target file `/.py`, the note must be `agent-notes//.py.md` (same folder structure, same filename, plus `.md`).
+- **Before working**:
+ - If the note exists, read it first and follow any constraints/decisions recorded there.
+ - If the note conflicts with the current code, or references an "origin" file/path that has been deleted, renamed, or migrated, treat the **code as the single source of truth** and update the note to match reality.
+ - If the note does not exist, create it with a short architecture/intent summary and any relevant invariants/edge cases.
+- **During working**:
+ - Keep the note in sync as you discover constraints, make decisions, or change approach.
+ - If you move/rename a file, migrate its note to the new mapped path (and fix any outdated references inside the note).
+ - Record non-obvious edge cases, trade-offs, and the test/verification plan as you go (not just at the end).
+ - Keep notes **coherent**: integrate new findings into the relevant sections and rewrite for clarity; avoid append-only “recent fix” / changelog-style additions unless the note is explicitly intended to be a changelog.
+- **When finishing work**:
+ - Update the related note(s) to reflect what changed, why, and any new edge cases/tests.
+ - If a file is deleted, remove or clearly deprecate the corresponding note so it cannot be mistaken as current guidance.
+ - Keep notes concise and accurate; they are meant to prevent repeated rediscovery.
+
+## Skill Index
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
-______________________________________________________________________
+### Platform Foundations
-## Platform Foundations
-
-- **[Infrastructure Overview](agent_skills/infra.md)**\
- When to read this:
+#### [Infrastructure Overview](agent_skills/infra.md)
+- **When to read this**
- You need to understand where a feature belongs in the architecture.
- You’re wiring storage, Redis, vector stores, or OTEL.
- - You’re about to add CLI commands or async jobs.\
- What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
+ - You’re about to add CLI commands or async jobs.
+- **What it covers**
+ - Configuration stack (`configs/app_config.py`, remote settings)
+ - Storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`)
+ - Redis conventions (`extensions/ext_redis.py`)
+ - Plugin runtime topology
+ - Vector-store factory (`core/rag/datasource/vdb/*`)
+ - Observability hooks
+ - SSRF proxy usage
+ - Core CLI commands
-- **[Coding Style](agent_skills/coding_style.md)**\
- When to read this:
+### Plugin & Extension Development
- - You’re writing or reviewing backend code and need the authoritative checklist.
- - You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- - You want the exact lint/type/test commands used in PRs.\
- Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
-
-______________________________________________________________________
-
-## Plugin & Extension Development
-
-- **[Plugin Systems](agent_skills/plugin.md)**\
- When to read this:
+#### [Plugin Systems](agent_skills/plugin.md)
+- **When to read this**
- You’re building or debugging a marketplace plugin.
- - You need to know how manifests, providers, daemons, and migrations fit together.\
- What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
+ - You need to know how manifests, providers, daemons, and migrations fit together.
+- **What it covers**
+ - Plugin manifests (`core/plugin/entities/plugin.py`)
+ - Installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands)
+ - Runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent)
+ - Daemon coordination (`core/plugin/entities/plugin_daemon.py`)
+ - How provider registries surface capabilities to the rest of the platform
-- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
- When to read this:
+#### [Plugin OAuth](agent_skills/plugin_oauth.md)
+- **When to read this**
- You must integrate OAuth for a plugin or datasource.
- - You’re handling credential encryption or refresh flows.\
- Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
+ - You’re handling credential encryption or refresh flows.
+- **Topics**
+ - Credential storage
+ - Encryption helpers (`core/helper/provider_encryption.py`)
+ - OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`)
+ - How console/API layers expose the flows
-______________________________________________________________________
+### Workflow Entry & Execution
-## Workflow Entry & Execution
+#### [Trigger Concepts](agent_skills/trigger.md)
-- **[Trigger Concepts](agent_skills/trigger.md)**\
- When to read this:
+- **When to read this**
- You’re debugging why a workflow didn’t start.
- You’re adding a new trigger type or hook.
- - You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
- Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
+ - You need to trace async execution, draft debugging, or webhook/schedule pipelines.
+- **Details**
+ - Start-node taxonomy
+ - Webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`)
+ - Async orchestration (`services/async_workflow_service.py`, Celery queues)
+ - Debug event bus
+ - Storage/logging interactions
-______________________________________________________________________
+## General Reminders
-## Additional Notes for Agents
-
-- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
+- All skill docs assume you follow the coding style rules below—run the lint/type/test commands before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
+
+## Coding Style
+
+This is the default standard for backend code in this repo. Follow it for new code and use it as the checklist when reviewing changes.
+
+### Linting & Formatting
+
+- Use Ruff for formatting and linting (follow `.ruff.toml`).
+- Keep each line under 120 characters (including spaces).
+
+### Naming Conventions
+
+- Use `snake_case` for variables and functions.
+- Use `PascalCase` for classes.
+- Use `UPPER_CASE` for constants.
+
+### Typing & Class Layout
+
+- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
+- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
+- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
+
+```python
+from datetime import datetime
+
+
+class Example:
+ user_id: str
+ created_at: datetime
+
+ def __init__(self, user_id: str, created_at: datetime) -> None:
+ self.user_id = user_id
+ self.created_at = created_at
+```
+
+### General Rules
+
+- Use Pydantic v2 conventions.
+- Use `uv` for Python package management in this repo (usually with `--project api`).
+- Prefer simple functions over small “utility classes” for lightweight helpers.
+- Avoid implementing dunder methods unless it’s clearly needed and matches existing patterns.
+- Never start long-running services as part of agent work (`uv run app.py`, `flask run`, etc.); running tests is allowed.
+- Keep files below ~800 lines; split when necessary.
+- Keep code readable and explicit—avoid clever hacks.
+
+### Architecture & Boundaries
+
+- Mirror the layered architecture: controller → service → core/domain.
+- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
+- Optimise for observability: deterministic control flow, clear logging, actionable errors.
+
+### Logging & Errors
+
+- Never use `print`; use a module-level logger:
+ - `logger = logging.getLogger(__name__)`
+- Include tenant/app/workflow identifiers in log context when relevant.
+- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate them into HTTP responses in controllers.
+- Log retryable events at `warning`, terminal failures at `error`.
+
+### SQLAlchemy Patterns
+
+- Models inherit from `models.base.TypeBase`; do not create ad-hoc metadata or engines.
+- Open sessions with context managers:
+
+```python
+from sqlalchemy.orm import Session
+
+with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(Workflow).where(
+ Workflow.id == workflow_id,
+ Workflow.tenant_id == tenant_id,
+ )
+ workflow = session.execute(stmt).scalar_one_or_none()
+```
+
+- Prefer SQLAlchemy expressions; avoid raw SQL unless necessary.
+- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
+- Introduce repository abstractions only for very large tables (e.g., workflow executions) or when alternative storage strategies are required.
+
+### Storage & External I/O
+
+- Access storage via `extensions.ext_storage.storage`.
+- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
+- Background tasks that touch storage must be idempotent, and should log relevant object identifiers.
+
+### Pydantic Usage
+
+- Define DTOs with Pydantic v2 models and forbid extras by default.
+- Use `@field_validator` / `@model_validator` for domain rules.
+
+Example:
+
+```python
+from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
+
+
+class TriggerConfig(BaseModel):
+ endpoint: HttpUrl
+ secret: str
+
+ model_config = ConfigDict(extra="forbid")
+
+ @field_validator("secret")
+ def ensure_secret_prefix(cls, value: str) -> str:
+ if not value.startswith("dify_"):
+ raise ValueError("secret must start with dify_")
+ return value
+```
+
+### Generics & Protocols
+
+- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
+- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
+- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
+
+### Tooling & Checks
+
+Quick checks while iterating:
+
+- Format: `make format`
+- Lint (includes auto-fix): `make lint`
+- Type check: `make type-check`
+- Targeted tests: `make test TARGET_TESTS=./api/tests/`
+
+Before opening a PR / submitting:
+
+- `make lint`
+- `make type-check`
+- `make test`
+
+### Controllers & Services
+
+- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
+- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
+- Document non-obvious behaviour with concise comments.
+
+### Miscellaneous
+
+- Use `configs.dify_config` for configuration—never read environment variables directly.
+- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
+- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
+- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/agent_skills/coding_style.md b/api/agent_skills/coding_style.md
deleted file mode 100644
index a2b66f0bd..000000000
--- a/api/agent_skills/coding_style.md
+++ /dev/null
@@ -1,115 +0,0 @@
-## Linter
-
-- Always follow `.ruff.toml`.
-- Run `uv run ruff check --fix --unsafe-fixes`.
-- Keep each line under 100 characters (including spaces).
-
-## Code Style
-
-- `snake_case` for variables and functions.
-- `PascalCase` for classes.
-- `UPPER_CASE` for constants.
-
-## Rules
-
-- Use Pydantic v2 standard.
-- Use `uv` for package management.
-- Do not override dunder methods like `__init__`, `__iadd__`, etc.
-- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
-- Prefer simple functions over classes for lightweight helpers.
-- Keep files below 800 lines; split when necessary.
-- Keep code readable—no clever hacks.
-- Never use `print`; log with `logger = logging.getLogger(__name__)`.
-
-## Guiding Principles
-
-- Mirror the project’s layered architecture: controller → service → core/domain.
-- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
-- Optimise for observability: deterministic control flow, clear logging, actionable errors.
-
-## SQLAlchemy Patterns
-
-- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
-
-- Open sessions with context managers:
-
- ```python
- from sqlalchemy.orm import Session
-
- with Session(db.engine, expire_on_commit=False) as session:
- stmt = select(Workflow).where(
- Workflow.id == workflow_id,
- Workflow.tenant_id == tenant_id,
- )
- workflow = session.execute(stmt).scalar_one_or_none()
- ```
-
-- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
-
-- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
-
-- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
-
-## Storage & External IO
-
-- Access storage via `extensions.ext_storage.storage`.
-- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
-- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
-
-## Pydantic Usage
-
-- Define DTOs with Pydantic v2 models and forbid extras by default.
-
-- Use `@field_validator` / `@model_validator` for domain rules.
-
-- Example:
-
- ```python
- from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
-
- class TriggerConfig(BaseModel):
- endpoint: HttpUrl
- secret: str
-
- model_config = ConfigDict(extra="forbid")
-
- @field_validator("secret")
- def ensure_secret_prefix(cls, value: str) -> str:
- if not value.startswith("dify_"):
- raise ValueError("secret must start with dify_")
- return value
- ```
-
-## Generics & Protocols
-
-- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
-- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
-- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
-
-## Error Handling & Logging
-
-- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
-- Declare `logger = logging.getLogger(__name__)` at module top.
-- Include tenant/app/workflow identifiers in log context.
-- Log retryable events at `warning`, terminal failures at `error`.
-
-## Tooling & Checks
-
-- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
-- Type checks: `uv run --directory api --dev basedpyright`.
-- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-- Run all of the above before submitting your work.
-
-## Controllers & Services
-
-- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
-- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
-- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
-- Document non-obvious behaviour with concise comments.
-
-## Miscellaneous
-
-- Use `configs.dify_config` for configuration—never read environment variables directly.
-- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
-- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
-- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/agent_skills/infra.md b/api/agent_skills/infra.md
deleted file mode 100644
index bc36c7bf6..000000000
--- a/api/agent_skills/infra.md
+++ /dev/null
@@ -1,96 +0,0 @@
-## Configuration
-
-- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
-- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
-- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
-- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
-
-## Dependencies
-
-- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
-- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
-- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
-
-## Storage & Files
-
-- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
-- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
-- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
-- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
-
-## Redis & Shared State
-
-- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
-- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
-
-## Models
-
-- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
-- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
-- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
-- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
-
-## Vector Stores
-
-- Vector client implementations live in `core/rag/datasource/vdb/`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
-- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
-- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
-- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
-
-## Observability & OTEL
-
-- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
-- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
-- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
-- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
-
-## Ops Integrations
-
-- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
-- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
-- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
-
-## Controllers, Services, Core
-
-- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
-- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
-- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
-
-## Plugins, Tools, Providers
-
-- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
-- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
-- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
-- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
-- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
-- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
-
-## Async Workloads
-
-see `agent_skills/trigger.md` for more detailed documentation.
-
-- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
-- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
-- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
-
-## Database & Migrations
-
-- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
-- Generate migrations with `uv run --project api flask db revision --autogenerate -m ""`, then review the diff; never hand-edit the database outside Alembic.
-- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
-- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
-
-## CLI Commands
-
-- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask `.
-- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
-- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
-- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
-- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
-
-## When You Add Features
-
-- Check for an existing helper or service before writing a new util.
-- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
-- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
-- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
diff --git a/api/agent_skills/plugin.md b/api/agent_skills/plugin.md
deleted file mode 100644
index 954ddd236..000000000
--- a/api/agent_skills/plugin.md
+++ /dev/null
@@ -1 +0,0 @@
-// TBD
diff --git a/api/agent_skills/plugin_oauth.md b/api/agent_skills/plugin_oauth.md
deleted file mode 100644
index 954ddd236..000000000
--- a/api/agent_skills/plugin_oauth.md
+++ /dev/null
@@ -1 +0,0 @@
-// TBD
diff --git a/api/agent_skills/trigger.md b/api/agent_skills/trigger.md
deleted file mode 100644
index f4b076332..000000000
--- a/api/agent_skills/trigger.md
+++ /dev/null
@@ -1,53 +0,0 @@
-## Overview
-
-Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
-
-## Trigger nodes
-
-- `UserInput`
-- `Trigger Webhook`
-- `Trigger Schedule`
-- `Trigger Plugin`
-
-### UserInput
-
-Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
-
-1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
-1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
-1. For its detailed implementation, please refer to `core/workflow/nodes/start`
-
-### Trigger Webhook
-
-Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
-
-Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
-
-### Trigger Schedule
-
-`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
-
-To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
-
-### Trigger Plugin
-
-`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
-
-1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
-1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
-
-A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
-
-## Worker Pool / Async Task
-
-All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
-
-The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
-
-## Debug Strategy
-
-Dify divided users into 2 groups: builders / end users.
-
-Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
-
-A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.
diff --git a/api/app_factory.py b/api/app_factory.py
index f827842d6..07859a375 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -71,6 +71,8 @@ def create_app() -> DifyApp:
def initialize_extensions(app: DifyApp):
+ # Initialize Flask context capture for workflow execution
+ from context.flask_app_context import init_flask_context
from extensions import (
ext_app_metrics,
ext_blueprints,
@@ -79,6 +81,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
+ ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
@@ -100,6 +103,8 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
)
+ init_flask_context()
+
extensions = [
ext_timezone,
ext_logging,
@@ -124,6 +129,7 @@ def initialize_extensions(app: DifyApp):
ext_proxy_fix,
ext_blueprints,
ext_commands,
+ ext_fastopenapi,
ext_otel,
ext_request_logging,
ext_session_factory,
diff --git a/api/commands.py b/api/commands.py
index 511b77bfc..a1aa8ac45 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,7 +1,9 @@
import base64
+import datetime
import json
import logging
import secrets
+import time
from typing import Any
import click
@@ -34,7 +36,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
-from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
+from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@@ -45,6 +47,9 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -62,8 +67,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
+ normalized_email = email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@@ -84,7 +91,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -100,20 +107,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
+ normalized_new_email = new_email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
- email_validate(new_email)
+ email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
- account.email = new_email
+ account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@@ -658,7 +667,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
- email = email.strip()
+ email = email.strip().lower()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@@ -852,6 +861,435 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
+@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
+@click.option(
+ "--before-days",
+ "--days",
+ default=30,
+ show_default=True,
+ type=click.IntRange(min=0),
+ help="Delete workflow runs created before N days ago.",
+)
+@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
+@click.option(
+ "--from-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+ "--to-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option(
+ "--dry-run",
+ is_flag=True,
+ help="Preview cleanup results without deleting any workflow run data.",
+)
+def clean_workflow_runs(
+ before_days: int,
+ batch_size: int,
+ from_days_ago: int | None,
+ to_days_ago: int | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ dry_run: bool,
+):
+ """
+ Clean workflow runs and related workflow data for free tenants.
+ """
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+
+ if (from_days_ago is None) ^ (to_days_ago is None):
+ raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
+
+ if from_days_ago is not None and to_days_ago is not None:
+ if start_from or end_before:
+ raise click.UsageError("Choose either day offsets or explicit dates, not both.")
+ if from_days_ago <= to_days_ago:
+ raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
+ now = datetime.datetime.now()
+ start_from = now - datetime.timedelta(days=from_days_ago)
+ end_before = now - datetime.timedelta(days=to_days_ago)
+ before_days = 0
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
+
+ WorkflowRunCleanup(
+ days=before_days,
+ batch_size=batch_size,
+ start_from=start_from,
+ end_before=end_before,
+ dry_run=dry_run,
+ ).run()
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Workflow run cleanup completed. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+
+
+@click.command(
+ "archive-workflow-runs",
+ help="Archive workflow runs for paid plan tenants to S3-compatible storage.",
+)
+@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.")
+@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.")
+@click.option(
+ "--from-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+ "--to-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Archive runs created at or after this timestamp (UTC if no timezone).",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Archive runs created before this timestamp (UTC if no timezone).",
+)
+@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.")
+@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.")
+@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.")
+@click.option("--dry-run", is_flag=True, help="Preview without archiving.")
+@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.")
+def archive_workflow_runs(
+ tenant_ids: str | None,
+ before_days: int,
+ from_days_ago: int | None,
+ to_days_ago: int | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ batch_size: int,
+ workers: int,
+ limit: int | None,
+ dry_run: bool,
+ delete_after_archive: bool,
+):
+ """
+ Archive workflow runs for paid plan tenants older than the specified days.
+
+ This command archives the following tables to storage:
+ - workflow_node_executions
+ - workflow_node_execution_offload
+ - workflow_pauses
+ - workflow_pause_reasons
+ - workflow_trigger_logs
+
+ The workflow_runs and workflow_app_logs tables are preserved for UI listing.
+ """
+ from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
+
+ run_started_at = datetime.datetime.now(datetime.UTC)
+ click.echo(
+ click.style(
+ f"Starting workflow run archiving at {run_started_at.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ if (start_from is None) ^ (end_before is None):
+ click.echo(click.style("start-from and end-before must be provided together.", fg="red"))
+ return
+
+ if (from_days_ago is None) ^ (to_days_ago is None):
+ click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red"))
+ return
+
+ if from_days_ago is not None and to_days_ago is not None:
+ if start_from or end_before:
+ click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red"))
+ return
+ if from_days_ago <= to_days_ago:
+ click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red"))
+ return
+ now = datetime.datetime.now()
+ start_from = now - datetime.timedelta(days=from_days_ago)
+ end_before = now - datetime.timedelta(days=to_days_ago)
+ before_days = 0
+
+ if start_from and end_before and start_from >= end_before:
+ click.echo(click.style("start-from must be earlier than end-before.", fg="red"))
+ return
+ if workers < 1:
+ click.echo(click.style("workers must be at least 1.", fg="red"))
+ return
+
+ archiver = WorkflowRunArchiver(
+ days=before_days,
+ batch_size=batch_size,
+ start_from=start_from,
+ end_before=end_before,
+ workers=workers,
+ tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None,
+ limit=limit,
+ dry_run=dry_run,
+ delete_after_archive=delete_after_archive,
+ )
+ summary = archiver.run()
+ click.echo(
+ click.style(
+ f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
+ f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
+ f"time={summary.total_elapsed_time:.2f}s",
+ fg="cyan",
+ )
+ )
+
+ run_finished_at = datetime.datetime.now(datetime.UTC)
+ elapsed = run_finished_at - run_started_at
+ click.echo(
+ click.style(
+ f"Workflow run archiving completed. start={run_started_at.isoformat()} "
+ f"end={run_finished_at.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+
+
+@click.command(
+ "restore-workflow-runs",
+ help="Restore archived workflow runs from S3-compatible storage.",
+)
+@click.option(
+ "--tenant-ids",
+ required=False,
+ help="Tenant IDs (comma-separated).",
+)
+@click.option("--run-id", required=False, help="Workflow run ID to restore.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.")
+@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.")
+@click.option("--dry-run", is_flag=True, help="Preview without restoring.")
+def restore_workflow_runs(
+ tenant_ids: str | None,
+ run_id: str | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ workers: int,
+ limit: int,
+ dry_run: bool,
+):
+ """
+ Restore an archived workflow run from storage to the database.
+
+ This restores the following tables:
+ - workflow_node_executions
+ - workflow_node_execution_offload
+ - workflow_pauses
+ - workflow_pause_reasons
+ - workflow_trigger_logs
+ """
+ from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
+
+ parsed_tenant_ids = None
+ if tenant_ids:
+ parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
+ if not parsed_tenant_ids:
+ raise click.BadParameter("tenant-ids must not be empty")
+
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+ if run_id is None and (start_from is None or end_before is None):
+ raise click.UsageError("--start-from and --end-before are required for batch restore.")
+ if workers < 1:
+ raise click.BadParameter("workers must be at least 1")
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ click.echo(
+ click.style(
+ f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
+ if run_id:
+ results = [restorer.restore_by_run_id(run_id)]
+ else:
+ assert start_from is not None
+ assert end_before is not None
+ results = restorer.restore_batch(
+ parsed_tenant_ids,
+ start_date=start_from,
+ end_date=end_before,
+ limit=limit,
+ )
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+
+ successes = sum(1 for result in results if result.success)
+ failures = len(results) - successes
+
+ if failures == 0:
+ click.echo(
+ click.style(
+ f"Restore completed successfully. success={successes} duration={elapsed}",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
+ fg="red",
+ )
+ )
+
+
+@click.command(
+ "delete-archived-workflow-runs",
+ help="Delete archived workflow runs from the database.",
+)
+@click.option(
+ "--tenant-ids",
+ required=False,
+ help="Tenant IDs (comma-separated).",
+)
+@click.option("--run-id", required=False, help="Workflow run ID to delete.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.")
+@click.option("--dry-run", is_flag=True, help="Preview without deleting.")
+def delete_archived_workflow_runs(
+ tenant_ids: str | None,
+ run_id: str | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ limit: int,
+ dry_run: bool,
+):
+ """
+ Delete archived workflow runs from the database.
+ """
+ from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
+
+ parsed_tenant_ids = None
+ if tenant_ids:
+ parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
+ if not parsed_tenant_ids:
+ raise click.BadParameter("tenant-ids must not be empty")
+
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+ if run_id is None and (start_from is None or end_before is None):
+ raise click.UsageError("--start-from and --end-before are required for batch delete.")
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
+ click.echo(
+ click.style(
+ f"Starting delete of {target_desc} at {start_time.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run)
+ if run_id:
+ results = [deleter.delete_by_run_id(run_id)]
+ else:
+ assert start_from is not None
+ assert end_before is not None
+ results = deleter.delete_batch(
+ parsed_tenant_ids,
+ start_date=start_from,
+ end_date=end_before,
+ limit=limit,
+ )
+
+ for result in results:
+ if result.success:
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
+ f"workflow run {result.run_id} (tenant={result.tenant_id})",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Failed to delete workflow run {result.run_id}: {result.error}",
+ fg="red",
+ )
+ )
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+
+ successes = sum(1 for result in results if result.success)
+ failures = len(results) - successes
+
+ if failures == 0:
+ click.echo(
+ click.style(
+ f"Delete completed successfully. success={successes} duration={elapsed}",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
+ fg="red",
+ )
+ )
+
+
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):
@@ -2113,6 +2551,82 @@ def migrate_oss(
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
+@click.command("clean-expired-messages", help="Clean expired messages.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Lower bound (inclusive) for created_at.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Upper bound (exclusive) for created_at.",
+)
+@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
+@click.option(
+ "--graceful-period",
+ default=21,
+ show_default=True,
+ help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
+)
+@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
+def clean_expired_messages(
+ batch_size: int,
+ graceful_period: int,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ dry_run: bool,
+):
+ """
+ Clean expired messages and related data for tenants based on clean policy.
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ # NOTE: graceful_period will be ignored when billing is disabled.
+ policy = create_message_clean_policy(graceful_period_days=graceful_period)
+
+ # Create and run the cleanup service
+ service = MessagesCleanService.from_time_range(
+ policy=policy,
+ start_from=start_from,
+ end_before=end_before,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
+
+ click.echo(click.style("messages cleanup completed.", fg="green"))
+
+
# extend: start 管理二开db扩展
@click.group("extend_db", help="管理二开扩展表的数据库迁移")
def extend_db():
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 6a04171d2..786094f29 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -949,6 +949,12 @@ class MailConfig(BaseSettings):
default=False,
)
+ SMTP_LOCAL_HOSTNAME: str | None = Field(
+ description="Override the local hostname used in SMTP HELO/EHLO. "
+ "Useful behind NAT or when the default hostname causes rejections.",
+ default=None,
+ )
+
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
@@ -959,6 +965,16 @@ class MailConfig(BaseSettings):
default=None,
)
+ ENABLE_TRIAL_APP: bool = Field(
+ description="Enable trial app",
+ default=False,
+ )
+
+ ENABLE_EXPLORE_BANNER: bool = Field(
+ description="Enable explore banner",
+ default=False,
+ )
+
class RagEtlConfig(BaseSettings):
"""
@@ -1101,6 +1117,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
+ ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
+ description="Enable scheduled workflow run cleanup task",
+ default=False,
+ )
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,
@@ -1288,6 +1308,10 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
+ SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
+ description="Lock TTL for sandbox expired records clean task in seconds",
+ default=90000,
+ )
class FeatureConfig(
diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py
index be01f2dc3..2a3530040 100644
--- a/api/configs/middleware/storage/volcengine_tos_storage_config.py
+++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py
@@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
- Configuration settings for Volcengine Tinder Object Storage (TOS)
+ Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(
diff --git a/api/context/__init__.py b/api/context/__init__.py
new file mode 100644
index 000000000..aebf9750c
--- /dev/null
+++ b/api/context/__init__.py
@@ -0,0 +1,74 @@
+"""
+Core Context - Framework-agnostic context management.
+
+This module provides context management that is independent of any specific
+web framework. Framework-specific implementations register their context
+capture functions at application initialization time.
+
+This ensures the workflow layer remains completely decoupled from Flask
+or any other web framework.
+"""
+
+import contextvars
+from collections.abc import Callable
+
+from core.workflow.context.execution_context import (
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+)
+
+# Global capturer function - set by framework-specific modules
+_capturer: Callable[[], IExecutionContext] | None = None
+
+
+def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
+ """
+ Register a context capture function.
+
+ This should be called by framework-specific modules (e.g., Flask)
+ during application initialization.
+
+ Args:
+ capturer: Function that captures current context and returns IExecutionContext
+ """
+ global _capturer
+ _capturer = capturer
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context.
+
+ This function uses the registered context capturer. If no capturer
+ is registered, it returns a minimal context with only contextvars
+ (suitable for non-framework environments like tests or standalone scripts).
+
+ Returns:
+ IExecutionContext with captured context
+ """
+ if _capturer is None:
+ # No framework registered - return minimal context
+ return ExecutionContext(
+ app_context=NullAppContext(),
+ context_vars=contextvars.copy_context(),
+ )
+
+ return _capturer()
+
+
+def reset_context_provider() -> None:
+ """
+ Reset the context capturer.
+
+ This is primarily useful for testing to ensure a clean state.
+ """
+ global _capturer
+ _capturer = None
+
+
+__all__ = [
+ "capture_current_context",
+ "register_context_capturer",
+ "reset_context_provider",
+]
diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py
new file mode 100644
index 000000000..2d465c8cf
--- /dev/null
+++ b/api/context/flask_app_context.py
@@ -0,0 +1,192 @@
+"""
+Flask App Context - Flask implementation of AppContext interface.
+"""
+
+import contextvars
+import threading
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any, final
+
+from flask import Flask, current_app, g
+
+from core.workflow.context import register_context_capturer
+from core.workflow.context.execution_context import (
+ AppContext,
+ IExecutionContext,
+)
+
+
+@final
+class FlaskAppContext(AppContext):
+ """
+ Flask implementation of AppContext.
+
+ This adapts Flask's app context to the AppContext interface.
+ """
+
+ def __init__(self, flask_app: Flask) -> None:
+ """
+ Initialize Flask app context.
+
+ Args:
+ flask_app: The Flask application instance
+ """
+ self._flask_app = flask_app
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value from Flask app config."""
+ return self._flask_app.config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name."""
+ return self._flask_app.extensions.get(name)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask app context."""
+ with self._flask_app.app_context():
+ yield
+
+ @property
+ def flask_app(self) -> Flask:
+ """Get the underlying Flask app instance."""
+ return self._flask_app
+
+
+def capture_flask_context(user: Any = None) -> IExecutionContext:
+ """
+ Capture current Flask execution context.
+
+ This function captures the Flask app context and contextvars from the
+ current environment. It should be called from within a Flask request or
+ app context.
+
+ Args:
+ user: Optional user object to include in context
+
+ Returns:
+ IExecutionContext with captured Flask context
+
+ Raises:
+ RuntimeError: If called outside Flask context
+ """
+ # Get Flask app instance
+ flask_app = current_app._get_current_object() # type: ignore
+
+ # Save current user if available
+ saved_user = user
+ if saved_user is None:
+ # Check for user in g (flask-login)
+ if hasattr(g, "_login_user"):
+ saved_user = g._login_user
+
+ # Capture contextvars
+ context_vars = contextvars.copy_context()
+
+ return FlaskExecutionContext(
+ flask_app=flask_app,
+ context_vars=context_vars,
+ user=saved_user,
+ )
+
+
+@final
+class FlaskExecutionContext:
+ """
+ Flask-specific execution context.
+
+ This is a specialized version of ExecutionContext that includes Flask app
+ context. It provides the same interface as ExecutionContext but with
+ Flask-specific implementation.
+ """
+
+ def __init__(
+ self,
+ flask_app: Flask,
+ context_vars: contextvars.Context,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize Flask execution context.
+
+ Args:
+ flask_app: Flask application instance
+ context_vars: Python contextvars
+ user: Optional user object
+ """
+ self._app_context = FlaskAppContext(flask_app)
+ self._context_vars = context_vars
+ self._user = user
+ self._flask_app = flask_app
+ self._local = threading.local()
+
+ @property
+ def app_context(self) -> FlaskAppContext:
+ """Get Flask app context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ def __enter__(self) -> "FlaskExecutionContext":
+ """Enter the Flask execution context."""
+ # Restore non-Flask context variables to avoid leaking Flask tokens across threads
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter Flask app context
+ cm = self._app_context.enter()
+ self._local.cm = cm
+ cm.__enter__()
+
+ # Restore user in new app context
+ if self._user is not None:
+ g._login_user = self._user
+
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the Flask execution context."""
+ cm = getattr(self._local, "cm", None)
+ if cm is not None:
+ cm.__exit__(*args)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask execution context as context manager."""
+ # Restore non-Flask context variables to avoid leaking Flask tokens across threads
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter Flask app context
+ with self._flask_app.app_context():
+ # Restore user in new app context
+ if self._user is not None:
+ g._login_user = self._user
+ yield
+
+
+def init_flask_context() -> None:
+ """
+ Initialize Flask context capture by registering the capturer.
+
+ This function should be called during Flask application initialization
+ to register the Flask-specific context capturer with the core context module.
+
+ Example:
+ app = Flask(__name__)
+ init_flask_context() # Register Flask context capturer
+
+ Note:
+ This function does not need the app instance as it uses Flask's
+ `current_app` to get the app when capturing context.
+ """
+ register_context_capturer(capture_flask_context)
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 0d7bd22e7..3c14d56f6 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -112,10 +112,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
+ banner,
installed_app,
parameter,
recommended_app,
saved_message,
+ trial,
)
# Import tag controllers
@@ -152,6 +154,7 @@ __all__ = [
"apikey",
"app",
"audio",
+ "banner",
"billing",
"bp",
"completion",
@@ -205,6 +208,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
+ "trial",
"trigger_providers",
"version",
"website",
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index a25ca5ef5..e1ee2c24b 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
-from models.model import App, InstalledApp, RecommendedApp
+from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
P = ParamSpec("P")
R = TypeVar("R")
@@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
+ can_trial: bool = Field(default=False)
+ trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
@@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
return supported_language(value)
+class InsertExploreBannerPayload(BaseModel):
+ category: str = Field(...)
+ title: str = Field(...)
+ description: str = Field(...)
+ img_src: str = Field(..., alias="img-src")
+ language: str = Field(default="en-US")
+ link: str = Field(...)
+ sort: int = Field(...)
+
+ @field_validator("language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+ model_config = {"populate_by_name": True}
+
+
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
+console_ns.schema_model(
+ InsertExploreBannerPayload.__name__,
+ InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
def admin_required(view: Callable[P, R]):
@wraps(view)
@@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
+ if payload.can_trial:
+ trial_app = db.session.execute(
+ select(TrialApp).where(TrialApp.app_id == payload.app_id)
+ ).scalar_one_or_none()
+ if not trial_app:
+ db.session.add(
+ TrialApp(
+ app_id=payload.app_id,
+ tenant_id=app.tenant_id,
+ trial_limit=payload.trial_limit,
+ )
+ )
+ else:
+ trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = payload.category
recommended_app.position = payload.position
+ if payload.can_trial:
+ trial_app = db.session.execute(
+ select(TrialApp).where(TrialApp.app_id == payload.app_id)
+ ).scalar_one_or_none()
+ if not trial_app:
+ db.session.add(
+ TrialApp(
+ app_id=payload.app_id,
+ tenant_id=app.tenant_id,
+ trial_limit=payload.trial_limit,
+ )
+ )
+ else:
+ trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
+ trial_app = session.execute(
+ select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
+ ).scalar_one_or_none()
+ if trial_app:
+ session.delete(trial_app)
+
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
+
+
+@console_ns.route("/admin/insert-explore-banner")
+class InsertExploreBannerApi(Resource):
+ @console_ns.doc("insert_explore_banner")
+ @console_ns.doc(description="Insert an explore banner")
+ @console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
+ @console_ns.response(201, "Banner inserted successfully")
+ @only_edition_cloud
+ @admin_required
+ def post(self):
+ payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
+
+ content = {
+ "category": payload.category,
+ "title": payload.title,
+ "description": payload.description,
+ "img-src": payload.img_src,
+ }
+
+ banner = ExporleBanner(
+ content=content,
+ link=payload.link,
+ sort=payload.sort,
+ language=payload.language,
+ )
+ db.session.add(banner)
+ db.session.commit()
+
+ return {"result": "success"}, 201
+
+
+@console_ns.route("/admin/delete-explore-banner/")
+class DeleteExploreBannerApi(Resource):
+ @console_ns.doc("delete_explore_banner")
+ @console_ns.doc(description="Delete an explore banner")
+ @console_ns.doc(params={"banner_id": "Banner ID to delete"})
+ @console_ns.response(204, "Banner deleted successfully")
+ @only_edition_cloud
+ @admin_required
+ def delete(self, banner_id):
+ banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
+ if not banner:
+ raise NotFound(f"Banner '{banner_id}' is not found")
+
+ db.session.delete(banner)
+ db.session.commit()
+
+ return {"result": "success"}, 204
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 449dbb92f..3f018fa2b 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,4 +1,3 @@
-import re
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -87,48 +86,6 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
-# XSS prevention: patterns that could lead to XSS attacks
-# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
-_XSS_PATTERNS = [
- r"", # Script tags
- r")", # Iframe tags (including self-closing)
- r"javascript:", # JavaScript protocol
- r"]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
- r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
- r"]*(?:\s*/>|>.*? )", # Object tags (opening tag)
- r"]*>", # Embed tags (self-closing)
- r" ]*>", # Link tags with javascript
-]
-
-
-def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
- """
- Validate that a string value doesn't contain potential XSS payloads.
-
- Args:
- value: The string value to validate
- field_name: Name of the field for error messages
-
- Returns:
- The original value if safe
-
- Raises:
- ValueError: If the value contains XSS patterns
- """
- if value is None:
- return None
-
- value_lower = value.lower()
- for pattern in _XSS_PATTERNS:
- if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
- raise ValueError(
- f"{field_name} contains invalid characters or patterns. "
- "HTML tags, JavaScript, and other potentially dangerous content are not allowed."
- )
-
- return value
-
-
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@@ -137,11 +94,6 @@ class CreateAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
- @field_validator("name", "description", mode="before")
- @classmethod
- def validate_xss_safe(cls, value: str | None, info) -> str | None:
- return _validate_xss_safe(value, info.field_name)
-
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
@@ -152,11 +104,6 @@ class UpdateAppPayload(BaseModel):
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
- @field_validator("name", "description", mode="before")
- @classmethod
- def validate_xss_safe(cls, value: str | None, info) -> str | None:
- return _validate_xss_safe(value, info.field_name)
-
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
@@ -165,11 +112,6 @@ class CopyAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
- @field_validator("name", "description", mode="before")
- @classmethod
- def validate_xss_safe(cls, value: str | None, info) -> str | None:
- return _validate_xss_safe(value, info.field_name)
-
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 56816dd46..55fdcb51e 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -592,9 +592,12 @@ def _get_conversation(app_model, conversation_id):
if not conversation:
raise NotFound("Conversation Not Exists.")
- if not conversation.read_at:
- conversation.read_at = naive_utc_now()
- conversation.read_account_id = current_user.id
- db.session.commit()
+ db.session.execute(
+ sa.update(Conversation)
+ .where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
+ .values(read_at=naive_utc_now(), read_account_id=current_user.id)
+ )
+ db.session.commit()
+ db.session.refresh(conversation)
return conversation
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index fbd790164..3fa15d6d6 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist"
description = "Draft workflow need to be initialized."
- code = 400
+ code = 404
class DraftWorkflowNotSync(BaseHTTPException):
error_code = "draft_workflow_not_sync"
description = "Workflow graph might have been modified, please refresh and resubmit."
- code = 400
+ code = 409
class TracingConfigNotExist(BaseHTTPException):
@@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
+
+
+class NeedAddIdsError(BaseHTTPException):
+ error_code = "need_add_ids"
+ description = "Need to add ids."
+ code = 400
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index fa67fb815..6736f24a2 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
-from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
+from fields.workflow_app_log_fields import (
+ build_workflow_app_log_pagination_model,
+ build_workflow_archived_log_pagination_model,
+)
from libs.login import login_required
from models import App
from models.model import AppMode
@@ -61,6 +64,7 @@ console_ns.schema_model(
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
+workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@console_ns.route("/apps//workflow-app-logs")
@@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource):
)
return workflow_app_log_pagination
+
+
+@console_ns.route("/apps//workflow-archived-logs")
+class WorkflowArchivedLogApi(Resource):
+ @console_ns.doc("get_workflow_archived_logs")
+ @console_ns.doc(description="Get workflow archived execution logs")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
+ @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model(mode=[AppMode.WORKFLOW])
+ @marshal_with(workflow_archived_log_pagination_model)
+ def get(self, app_model: App):
+ """
+ Get workflow archived logs
+ """
+ args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+
+ workflow_app_service = WorkflowAppService()
+ with Session(db.engine) as session:
+ workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
+ session=session,
+ app_model=app_model,
+ page=args.page,
+ limit=args.limit,
+ )
+
+ return workflow_app_log_pagination
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 8f1871f1e..fa74f8aea 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,12 +1,15 @@
+from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
+from sqlalchemy import select
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
+from extensions.ext_database import db
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import (
@@ -19,14 +22,17 @@ from fields.workflow_run_fields import (
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
-from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
+from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
+EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
+workflow_run_export_fields = console_ns.model(
+ "WorkflowRunExport",
+ {
+ "status": fields.String(description="Export status: success/failed"),
+ "presigned_url": fields.String(description="Pre-signed URL for download", required=False),
+ "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
+ },
+)
+
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
return result
+@console_ns.route("/apps//workflow-runs//export")
+class WorkflowRunExportApi(Resource):
+ @console_ns.doc("get_workflow_run_export_url")
+ @console_ns.doc(description="Generate a download URL for an archived workflow run.")
+ @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
+ @console_ns.response(200, "Export URL generated", workflow_run_export_fields)
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model()
+ def get(self, app_model: App, run_id: str):
+ tenant_id = str(app_model.tenant_id)
+ app_id = str(app_model.id)
+ run_id_str = str(run_id)
+
+ run_created_at = db.session.scalar(
+ select(WorkflowArchiveLog.run_created_at)
+ .where(
+ WorkflowArchiveLog.tenant_id == tenant_id,
+ WorkflowArchiveLog.app_id == app_id,
+ WorkflowArchiveLog.workflow_run_id == run_id_str,
+ )
+ .limit(1)
+ )
+ if not run_created_at:
+ return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
+
+ prefix = (
+ f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
+ f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
+ )
+ archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+ try:
+ archive_storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ return {"code": "archive_storage_not_configured", "message": str(e)}, 500
+
+ presigned_url = archive_storage.generate_presigned_url(
+ archive_key,
+ expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
+ )
+ expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
+ return {
+ "status": "success",
+ "presigned_url": presigned_url,
+ "presigned_url_expires_at": expires_at.isoformat(),
+ }, 200
+
+
@console_ns.route("/apps//advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs_count")
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 9bb2718f8..e687d980f 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
+def _load_app_model_with_trial(app_id: str) -> App | None:
+ app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
+ return app_model
+
+
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
+
+
+def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
+ def decorator(view_func: Callable[P, R]):
+ @wraps(view_func)
+ def decorated_view(*args: P.args, **kwargs: P.kwargs):
+ if not kwargs.get("app_id"):
+ raise ValueError("missing app_id in path parameters")
+
+ app_id = kwargs.get("app_id")
+ app_id = str(app_id)
+
+ del kwargs["app_id"]
+
+ app_model = _load_app_model_with_trial(app_id)
+
+ if not app_model:
+ raise AppNotFoundError()
+
+ app_mode = AppMode.value_of(app_model.mode)
+
+ if mode is not None:
+ if isinstance(mode, list):
+ modes = mode
+ else:
+ modes = [mode]
+
+ if app_mode not in modes:
+ mode_values = {m.value for m in modes}
+ raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
+
+ kwargs["app_model"] = app_model
+
+ return view_func(*args, **kwargs)
+
+ return decorated_view
+
+ if view is None:
+ return decorator
+ else:
+ return decorator(view)
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index fe70d930f..f741107b8 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -63,13 +63,19 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
- reg_email = args.email
token = args.token
- invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
+ invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
+
+ # Check workspace permission
+ if tenant:
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
@@ -100,11 +106,12 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
- invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
+ normalized_request_email = args.email.lower() if args.email else None
+ invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
- RegisterService.revoke_token(args.workspace_id, args.email, args.token)
+ RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"]
account.name = args.name
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index fa082c735..c2a95ddad 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
- token = None
- token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+ token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
+ is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_email_register_error_rate_limit(args.email)
+ AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
- AccountService.reset_email_register_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_email_register_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register")
@@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
+ normalized_email = email.lower()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
- account = self._create_new_account(email, args.password_confirm)
+ account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
- def _create_new_account(self, email, password) -> Account | None:
+ def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
try:
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 661f59118..394f205d9 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
- email=args.email,
+ email=normalized_email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(args.email)
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=args.code, additional_data={"phase": "reset"}
+ token_email, code=args.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets")
@@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
-
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 4a52bf8ab..400df138b 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -90,32 +90,38 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
+ request_email = args.email
+ normalized_email = request_email.lower()
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
- is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
+ is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
+ invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
- if args.invite_token:
- invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
+ if invite_token:
+ invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
+ if invitation_data is None:
+ invite_token = None
try:
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
- if invitee_email != args.email:
+ invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
+ if invitee_email_normalized != normalized_email:
raise InvalidEmailError()
- account = AccountService.authenticate(args.email, args.password, args.invite_token)
- else:
- account = AccountService.authenticate(args.email, args.password)
+ account = _authenticate_account_with_case_fallback(
+ request_email, normalized_email, args.password, invite_token
+ )
except services.errors.account.AccountLoginError:
raise AccountBannedError()
- except services.errors.account.AccountPasswordError:
- AccountService.add_login_error_rate_limit(args.email)
- raise AuthenticationFailedError()
+ except services.errors.account.AccountPasswordError as exc:
+ AccountService.add_login_error_rate_limit(normalized_email)
+ raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@@ -130,7 +136,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
- email=args.email,
+ email=normalized_email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
- token = AccountService.send_email_code_login_email(email=args.email, language=language)
+ token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else:
raise AccountNotFound()
else:
@@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
- user_email = args.email
+ original_email = args.email
+ user_email = original_email.lower()
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args.email:
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
- account = AccountService.get_user_through_email(user_email)
+ account = _get_account_with_case_fallback(original_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
+
+
+def _get_account_with_case_fallback(email: str):
+ account = AccountService.get_user_through_email(email)
+ if account or email == email.lower():
+ return account
+
+ return AccountService.get_user_through_email(email.lower())
+
+
+def _authenticate_account_with_case_fallback(
+ original_email: str, normalized_email: str, password: str, invite_token: str | None
+):
+ try:
+ return AccountService.authenticate(original_email, password, invite_token)
+ except services.errors.account.AccountPasswordError:
+ if original_email == normalized_email:
+ raise
+ return AccountService.authenticate(normalized_email, password, invite_token)
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index e5a05a219..c98891b90 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -4,7 +4,6 @@ from typing import Optional # Extend: OAuto third-party login
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
-from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -155,7 +154,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
- if invitation_email != user_info.email:
+ invitation_email_normalized = (
+ invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
+ )
+ if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@@ -213,7 +215,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
@@ -235,9 +237,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
+ normalized_email = user_info.email.lower()
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@@ -248,7 +251,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
- email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
+ email=normalized_email,
+ name=account_name,
+ password=None,
+ open_id=user_info.id,
+ provider=provider,
)
# Set interface language
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index ac78d3854..2599e6293 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -2,12 +2,14 @@ import json
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
-from typing import Literal, cast
+from contextlib import ExitStack
+from typing import Any, Literal, cast
+from uuid import UUID
import sqlalchemy as sa
-from flask import request
+from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -42,6 +44,7 @@ from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+from services.file_service import FileService
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@@ -65,6 +68,9 @@ from ..wraps import (
logger = logging.getLogger(__name__)
+# NOTE: Keep constants near the top of the module for discoverability.
+DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
+
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
@@ -104,6 +110,21 @@ class DocumentRenamePayload(BaseModel):
name: str
+class DocumentBatchDownloadZipPayload(BaseModel):
+ """Request payload for bulk downloading documents as a zip archive."""
+
+ document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
+
+
+class DocumentDatasetListParam(BaseModel):
+ page: int = Field(1, title="Page", description="Page number.")
+ limit: int = Field(20, title="Limit", description="Page size.")
+ search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
+ sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
+ status: str | None = Field(None, title="Status", description="Document status.")
+ fetch_val: str = Field("false", alias="fetch")
+
+
register_schema_models(
console_ns,
KnowledgeConfig,
@@ -111,6 +132,7 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
+ DocumentBatchDownloadZipPayload,
)
@@ -225,14 +247,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
- sort = request.args.get("sort", default="-created_at", type=str)
- status = request.args.get("status", default=None, type=str)
+ raw_args = request.args.to_dict()
+ param = DocumentDatasetListParam.model_validate(raw_args)
+ page = param.page
+ limit = param.limit
+ search = param.search
+ sort = param.sort_by
+ status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
- fetch_val = request.args.get("fetch", default="false")
+ fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:
@@ -842,6 +866,62 @@ class DocumentApi(DocumentResource):
return {"result": "success"}, 204
+@console_ns.route("/datasets//documents//download")
+class DocumentDownloadApi(DocumentResource):
+ """Return a signed download URL for a dataset document's original uploaded file."""
+
+ @console_ns.doc("get_dataset_document_download_url")
+ @console_ns.doc(description="Get a signed download URL for a dataset document's original uploaded file")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
+ # Reuse the shared permission/tenant checks implemented in DocumentResource.
+ document = self.get_document(str(dataset_id), str(document_id))
+ return {"url": DocumentService.get_document_download_url(document)}
+
+
+@console_ns.route("/datasets//documents/download-zip")
+class DocumentBatchDownloadZipApi(DocumentResource):
+ """Download multiple uploaded-file documents as a single ZIP (avoids browser multi-download limits)."""
+
+ @console_ns.doc("download_dataset_documents_as_zip")
+ @console_ns.doc(description="Download selected dataset documents as a single ZIP archive (upload-file only)")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
+ def post(self, dataset_id: str):
+ """Stream a ZIP archive containing the requested uploaded documents."""
+ # Parse and validate request payload.
+ payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
+
+ current_user, current_tenant_id = current_account_with_tenant()
+ dataset_id = str(dataset_id)
+ document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
+ upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
+ dataset_id=dataset_id,
+ document_ids=document_ids,
+ tenant_id=current_tenant_id,
+ current_user=current_user,
+ )
+
+ # Delegate ZIP packing to FileService, but keep Flask response+cleanup in the route.
+ with ExitStack() as stack:
+ zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
+ response = send_file(
+ zip_path,
+ mimetype="application/zip",
+ as_attachment=True,
+ download_name=download_name,
+ )
+ cleanup = stack.pop_all()
+ response.call_on_close(cleanup.close)
+ return response
+
+
@console_ns.route("/datasets//documents//processing/")
class DocumentProcessingApi(DocumentResource):
@console_ns.doc("update_document_processing")
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index 89c9fcad3..a70a7ce48 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -81,7 +81,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
- name: str = Field(..., min_length=1, max_length=40)
+ name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py
new file mode 100644
index 000000000..da306fbc9
--- /dev/null
+++ b/api/controllers/console/explore/banner.py
@@ -0,0 +1,43 @@
+from flask import request
+from flask_restx import Resource
+
+from controllers.console import api
+from controllers.console.explore.wraps import explore_banner_enabled
+from extensions.ext_database import db
+from models.model import ExporleBanner
+
+
+class BannerApi(Resource):
+ """Resource for banner list."""
+
+ @explore_banner_enabled
+ def get(self):
+ """Get banner list."""
+ language = request.args.get("language", "en-US")
+
+ # Build base query for enabled banners
+ base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
+
+ # Try to get banners in the requested language
+ banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
+
+ # Fallback to en-US if no banners found and language is not en-US
+ if not banners and language != "en-US":
+ banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
+ # Convert banners to serializable format
+ result = []
+ for banner in banners:
+ banner_data = {
+ "id": banner.id,
+ "content": banner.content, # Already parsed as JSON by SQLAlchemy
+ "link": banner.link,
+ "sort": banner.sort,
+ "status": banner.status,
+ "created_at": banner.created_at.isoformat() if banner.created_at else None,
+ }
+ result.append(banner_data)
+
+ return result
+
+
+api.add_resource(BannerApi, "/explore/banners")
diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py
index 1e05ff420..e96fa64f8 100644
--- a/api/controllers/console/explore/error.py
+++ b/api/controllers/console/explore/error.py
@@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
+
+
+class TrialAppNotAllowed(BaseHTTPException):
+ """*403* `Trial App Not Allowed`
+
+ Raise if the user has reached the trial app limit.
+ """
+
+ error_code = "trial_app_not_allowed"
+ code = 403
+ description = "the app is not allowed to be trial."
+
+
+class TrialAppLimitExceeded(BaseHTTPException):
+ """*403* `Trial App Limit Exceeded`
+
+ Raise if the user has exceeded the trial app limit.
+ """
+
+ error_code = "trial_app_limit_exceeded"
+ code = 403
+ description = "The user has exceeded the trial app limit."
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index 2b2f80769..362513ec1 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -29,6 +29,7 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
+ "can_trial": fields.Boolean,
}
recommended_app_list_fields = {
diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py
new file mode 100644
index 000000000..97d856beb
--- /dev/null
+++ b/api/controllers/console/explore/trial.py
@@ -0,0 +1,512 @@
+import logging
+from typing import Any, cast
+
+from flask import request
+from flask_restx import Resource, marshal, marshal_with, reqparse
+from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+
+import services
+from controllers.common.fields import Parameters as ParametersResponse
+from controllers.common.fields import Site as SiteResponse
+from controllers.console import api
+from controllers.console.app.error import (
+ AppUnavailableError,
+ AudioTooLargeError,
+ CompletionRequestError,
+ ConversationCompletedError,
+ NeedAddIdsError,
+ NoAudioUploadedError,
+ ProviderModelCurrentlyNotSupportError,
+ ProviderNotInitializeError,
+ ProviderNotSupportSpeechToTextError,
+ ProviderQuotaExceededError,
+ UnsupportedAudioTypeError,
+)
+from controllers.console.app.wraps import get_app_model_with_trial
+from controllers.console.explore.error import (
+ AppSuggestedQuestionsAfterAnswerDisabledError,
+ NotChatAppError,
+ NotCompletionAppError,
+ NotWorkflowAppError,
+)
+from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
+from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.errors.error import (
+ ModelCurrentlyNotSupportError,
+ ProviderTokenNotInitError,
+ QuotaExceededError,
+)
+from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
+from extensions.ext_database import db
+from fields.app_fields import app_detail_fields_with_site
+from fields.dataset_fields import dataset_fields
+from fields.workflow_fields import workflow_fields
+from libs import helper
+from libs.helper import uuid_value
+from libs.login import current_user
+from models import Account
+from models.account import TenantStatus
+from models.model import AppMode, Site
+from models.workflow import Workflow
+from services.app_generate_service import AppGenerateService
+from services.app_service import AppService
+from services.audio_service import AudioService
+from services.dataset_service import DatasetService
+from services.errors.audio import (
+ AudioTooLargeServiceError,
+ NoAudioUploadedServiceError,
+ ProviderNotSupportSpeechToTextServiceError,
+ UnsupportedAudioTypeServiceError,
+)
+from services.errors.conversation import ConversationNotExistsError
+from services.errors.llm import InvokeRateLimitError
+from services.errors.message import (
+ MessageNotExistsError,
+ SuggestedQuestionsAfterAnswerDisabledError,
+)
+from services.message_service import MessageService
+from services.recommended_app_service import RecommendedAppService
+
+logger = logging.getLogger(__name__)
+
+
+class TrialAppWorkflowRunApi(TrialAppResource):
+ def post(self, trial_app):
+ """
+ Run workflow
+ """
+ app_model = trial_app
+ if not app_model:
+ raise NotWorkflowAppError()
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode != AppMode.WORKFLOW:
+ raise NotWorkflowAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("files", type=list, required=False, location="json")
+ args = parser.parse_args()
+ assert current_user is not None
+ try:
+ app_id = app_model.id
+ user_id = current_user.id
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+ )
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except InvokeRateLimitError as ex:
+ raise InvokeRateLimitHttpError(ex.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialAppWorkflowTaskStopApi(TrialAppResource):
+ def post(self, trial_app, task_id: str):
+ """
+ Stop workflow task
+ """
+ app_model = trial_app
+ if not app_model:
+ raise NotWorkflowAppError()
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode != AppMode.WORKFLOW:
+ raise NotWorkflowAppError()
+ assert current_user is not None
+
+ # Stop using both mechanisms for backward compatibility
+ # Legacy stop flag mechanism (without user check)
+ AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+ # New graph engine command channel mechanism
+ GraphEngineManager.send_stop_command(task_id)
+
+ return {"result": "success"}
+
+
+class TrialChatApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+ raise NotChatAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, location="json")
+ parser.add_argument("query", type=str, required=True, location="json")
+ parser.add_argument("files", type=list, required=False, location="json")
+ parser.add_argument("conversation_id", type=uuid_value, location="json")
+ parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
+ parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+ args = parser.parse_args()
+
+ args["auto_generate_name"] = False
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+ )
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except InvokeRateLimitError as ex:
+ raise InvokeRateLimitHttpError(ex.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialMessageSuggestedQuestionApi(TrialAppResource):
+ @trial_feature_enable
+ def get(self, trial_app, message_id):
+ app_model = trial_app
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+ raise NotChatAppError()
+
+ message_id = str(message_id)
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+ questions = MessageService.get_suggested_questions_after_answer(
+ app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
+ )
+ except MessageNotExistsError:
+ raise NotFound("Message not found")
+ except ConversationNotExistsError:
+ raise NotFound("Conversation not found")
+ except SuggestedQuestionsAfterAnswerDisabledError:
+ raise AppSuggestedQuestionsAfterAnswerDisabledError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+ return {"data": questions}
+
+
+class TrialChatAudioApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+
+ file = request.files["file"]
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return response
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except NoAudioUploadedServiceError:
+ raise NoAudioUploadedError()
+ except AudioTooLargeServiceError as e:
+ raise AudioTooLargeError(str(e))
+ except UnsupportedAudioTypeServiceError:
+ raise UnsupportedAudioTypeError()
+ except ProviderNotSupportSpeechToTextServiceError:
+ raise ProviderNotSupportSpeechToTextError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception as e:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialChatTextApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ try:
+ parser = reqparse.RequestParser()
+ parser.add_argument("message_id", type=str, required=False, location="json")
+ parser.add_argument("voice", type=str, location="json")
+ parser.add_argument("text", type=str, location="json")
+ parser.add_argument("streaming", type=bool, location="json")
+ args = parser.parse_args()
+
+ message_id = args.get("message_id", None)
+ text = args.get("text", None)
+ voice = args.get("voice", None)
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return response
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except NoAudioUploadedServiceError:
+ raise NoAudioUploadedError()
+ except AudioTooLargeServiceError as e:
+ raise AudioTooLargeError(str(e))
+ except UnsupportedAudioTypeServiceError:
+ raise UnsupportedAudioTypeError()
+ except ProviderNotSupportSpeechToTextServiceError:
+ raise ProviderNotSupportSpeechToTextError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception as e:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialCompletionApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ if app_model.mode != "completion":
+ raise NotCompletionAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, location="json")
+ parser.add_argument("query", type=str, location="json", default="")
+ parser.add_argument("files", type=list, required=False, location="json")
+ parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+ parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+ args = parser.parse_args()
+
+ streaming = args["response_mode"] == "streaming"
+ args["auto_generate_name"] = False
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
+ )
+
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialSitApi(Resource):
+ """Resource for trial app sites."""
+
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ """Retrieve app site info.
+
+ Returns the site configuration for the application including theme, icons, and text.
+ """
+ site = db.session.query(Site).where(Site.app_id == app_model.id).first()
+
+ if not site:
+ raise Forbidden()
+
+ assert app_model.tenant
+ if app_model.tenant.status == TenantStatus.ARCHIVE:
+ raise Forbidden()
+
+ return SiteResponse.model_validate(site).model_dump(mode="json")
+
+
+class TrialAppParameterApi(Resource):
+ """Resource for app variables."""
+
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ """Retrieve app parameters."""
+
+ if app_model is None:
+ raise AppUnavailableError()
+
+ if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+ workflow = app_model.workflow
+ if workflow is None:
+ raise AppUnavailableError()
+
+ features_dict = workflow.features_dict
+ user_input_form = workflow.user_input_form(to_old_structure=True)
+ else:
+ app_model_config = app_model.app_model_config
+ if app_model_config is None:
+ raise AppUnavailableError()
+
+ features_dict = app_model_config.to_dict()
+
+ user_input_form = features_dict.get("user_input_form", [])
+
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return ParametersResponse.model_validate(parameters).model_dump(mode="json")
+
+
+class AppApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ @marshal_with(app_detail_fields_with_site)
+ def get(self, app_model):
+ """Get app detail"""
+
+ app_service = AppService()
+ app_model = app_service.get_app(app_model)
+
+ return app_model
+
+
+class AppWorkflowApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ @marshal_with(workflow_fields)
+ def get(self, app_model):
+ """Get workflow detail"""
+ if not app_model.workflow_id:
+ raise AppUnavailableError()
+
+ workflow = (
+ db.session.query(Workflow)
+ .where(
+ Workflow.id == app_model.workflow_id,
+ )
+ .first()
+ )
+ return workflow
+
+
+class DatasetListApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ page = request.args.get("page", default=1, type=int)
+ limit = request.args.get("limit", default=20, type=int)
+ ids = request.args.getlist("ids")
+
+ tenant_id = app_model.tenant_id
+ if ids:
+ datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
+ else:
+ raise NeedAddIdsError()
+
+ data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
+
+ response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+ return response
+
+
+api.add_resource(TrialChatApi, "/trial-apps//chat-messages", endpoint="trial_app_chat_completion")
+
+api.add_resource(
+ TrialMessageSuggestedQuestionApi,
+ "/trial-apps//messages//suggested-questions",
+ endpoint="trial_app_suggested_question",
+)
+
+api.add_resource(TrialChatAudioApi, "/trial-apps//audio-to-text", endpoint="trial_app_audio")
+api.add_resource(TrialChatTextApi, "/trial-apps//text-to-audio", endpoint="trial_app_text")
+
+api.add_resource(TrialCompletionApi, "/trial-apps//completion-messages", endpoint="trial_app_completion")
+
+api.add_resource(TrialSitApi, "/trial-apps//site")
+
+api.add_resource(TrialAppParameterApi, "/trial-apps//parameters", endpoint="trial_app_parameters")
+
+api.add_resource(AppApi, "/trial-apps/", endpoint="trial_app")
+
+api.add_resource(TrialAppWorkflowRunApi, "/trial-apps//workflows/run", endpoint="trial_app_workflow_run")
+api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps//workflows/tasks//stop")
+
+api.add_resource(AppWorkflowApi, "/trial-apps//workflows", endpoint="trial_app_workflow")
+api.add_resource(DatasetListApi, "/trial-apps//datasets", endpoint="trial_app_datasets")
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index 2aee795a2..4e99604d0 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -2,13 +2,16 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
+from flask import abort
from flask_restx import Resource
from werkzeug.exceptions import NotFound
+from controllers.console.explore.error import TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
-from models import InstalledApp
+from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
+from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
@@ -57,6 +60,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
+def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
+ def decorator(view: Callable[Concatenate[App, P], R]):
+ @wraps(view)
+ def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
+ current_user, _ = current_account_with_tenant()
+
+ trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
+
+ if trial_app is None:
+ raise TrialAppNotAllowed()
+ app = trial_app.app
+
+ if app is None:
+ raise TrialAppNotAllowed()
+
+ account_trial_app_record = (
+ db.session.query(AccountTrialAppRecord)
+ .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
+ .first()
+ )
+ if account_trial_app_record:
+ if account_trial_app_record.count >= trial_app.trial_limit:
+ raise TrialAppLimitExceeded()
+
+ return view(app, *args, **kwargs)
+
+ return decorated
+
+ if view:
+ return decorator(view)
+ return decorator
+
+
+def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if not features.enable_trial_app:
+ abort(403, "Trial app feature is not enabled.")
+ return view(*args, **kwargs)
+
+ return decorated
+
+
+def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if not features.enable_explore_banner:
+ abort(403, "Explore banner feature is not enabled.")
+ return view(*args, **kwargs)
+
+ return decorated
+
+
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@@ -66,3 +124,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
+
+
+class TrialAppResource(Resource):
+ # must be reversed if there are multiple decorators
+
+ method_decorators = [
+ trial_app_required,
+ account_initialization_required,
+ login_required,
+ ]
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 6951c906e..d3811e2d1 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,6 +1,7 @@
from flask_restx import Resource, fields
+from werkzeug.exceptions import Unauthorized
-from libs.login import current_account_with_tenant, login_required
+from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureService
from . import console_ns
@@ -39,5 +40,21 @@ class SystemFeatureApi(Resource):
),
)
def get(self):
- """Get system-wide feature configuration"""
- return FeatureService.get_system_features().model_dump()
+ """Get system-wide feature configuration
+
+ NOTE: This endpoint is unauthenticated by design, as it provides system features
+ data required for dashboard initialization.
+
+ Authentication would create circular dependency (can't login without dashboard loading).
+
+ Only non-sensitive configuration data should be returned by this endpoint.
+ """
+ # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
+ # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
+ # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
+ # raise `Unauthorized` exception if authentication token is not provided.
+ try:
+ is_authenticated = current_user.is_authenticated
+ except Unauthorized:
+ is_authenticated = False
+ return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index 25a3d8052..d480af312 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,17 +1,17 @@
-from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
-from . import console_ns
+from controllers.fastopenapi import console_router
-@console_ns.route("/ping")
-class PingApi(Resource):
- @console_ns.doc("health_check")
- @console_ns.doc(description="Health check endpoint for connection testing")
- @console_ns.response(
- 200,
- "Success",
- console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
- )
- def get(self):
- """Health check endpoint for connection testing"""
- return {"result": "pong"}
+class PingResponse(BaseModel):
+ result: str = Field(description="Health check result", examples=["pong"])
+
+
+@console_router.get(
+ "/ping",
+ response_model=PingResponse,
+ tags=["console"],
+)
+def ping() -> PingResponse:
+ """Health check endpoint for connection testing."""
+ return PingResponse(result="pong")
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 7fa02ae28..ed22ef045 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -84,10 +84,11 @@ class SetupApi(Resource):
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
# setup
RegisterService.setup(
- email=args.email,
+ email=normalized_email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index ea3e326b4..d77675149 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -31,6 +31,11 @@ class TagBindingRemovePayload(BaseModel):
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
+class TagListQueryParam(BaseModel):
+ type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
+ keyword: str | None = Field(None, description="Search keyword")
+
+
register_schema_models(
console_ns,
TagBasePayload,
@@ -44,12 +49,15 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
+ @console_ns.doc(
+ params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
+ )
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
- tag_type = request.args.get("type", type=str, default="")
- keyword = request.args.get("keyword", default=None, type=str)
- tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
+ raw_args = request.args.to_dict()
+ param = TagListQueryParam.model_validate(raw_args)
+ tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
return tags, 200
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 03ad0f423..527aabbc3 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -41,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
-from models import Account, AccountIntegrate, InvitationCode
+from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
- user_email = args.email
+ user_email = None
+ email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
- if user_email != current_user.email:
+ if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
+
+ user_email = current_user.email
else:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
+ email_for_sending = account.email
+ user_email = account.email
token = AccountService.send_change_email_email(
- account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
+ account=account,
+ email=email_for_sending,
+ old_email=user_email,
+ language=language,
+ phase=args.phase,
)
return {"result": "success", "data": token}
@@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_change_email_error_rate_limit(args.email)
+ AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
- AccountService.reset_change_email_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_change_email_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset")
@@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
+ normalized_new_email = args.new_email.lower()
- if AccountService.is_account_in_freeze(args.new_email):
+ if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.new_email):
+ if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
- if current_user.email != old_email:
+ if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
- updated_account = AccountService.update_account_email(current_user, email=args.new_email)
+ updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
- email=args.new_email,
+ email=normalized_new_email,
)
return updated_account
@@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
- if AccountService.is_account_in_freeze(args.email):
+ normalized_email = args.email.lower()
+ if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.email):
+ if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 0142e14fb..01cca2a8a 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -107,6 +107,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
+
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(inviter.current_tenant.id)
+
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
@@ -116,26 +122,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
+ normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
- inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
+ tenant=inviter.current_tenant,
+ email=invitee_email,
+ language=interface_language,
+ role=invitee_role,
+ inviter=inviter,
)
- encoded_invitee_email = parse.quote(invitee_email)
+ encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
- "email": invitee_email,
+ "email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
- {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
+ {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
- invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
+ invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 1fde362f8..163d35cd9 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
+ only_edition_enterprise,
setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -28,6 +29,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
+from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -292,3 +294,31 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
+
+
+@console_ns.route("/workspaces/current/permission")
+class WorkspacePermissionApi(Resource):
+ """Get workspace permissions for the current workspace."""
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @only_edition_enterprise
+ def get(self):
+ """
+ Get workspace permission settings.
+ Returns permission flags that control workspace features like member invitations and owner transfer.
+ """
+ _, current_tenant_id = current_account_with_tenant()
+
+ if not current_tenant_id:
+ raise ValueError("No current tenant")
+
+ # Get workspace permissions from enterprise service
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
+
+ return {
+ "workspace_id": permission.workspace_id,
+ "allow_member_invite": permission.allow_member_invite,
+ "allow_owner_transfer": permission.allow_owner_transfer,
+ }, 200
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index 95fc006a1..fd928b077 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -286,13 +286,12 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
- _, current_tenant_id = current_account_with_tenant()
- features = FeatureService.get_features(current_tenant_id)
- if features.is_allow_transfer_workspace:
- return view(*args, **kwargs)
+ from libs.workspace_permission import check_workspace_owner_transfer_permission
- # otherwise, return 403
- abort(403)
+ _, current_tenant_id = current_account_with_tenant()
+ # Check both billing/plan level and workspace policy level permissions
+ check_workspace_owner_transfer_permission(current_tenant_id)
+ return view(*args, **kwargs)
return decorated
diff --git a/api/controllers/fastopenapi.py b/api/controllers/fastopenapi.py
new file mode 100644
index 000000000..c13f22338
--- /dev/null
+++ b/api/controllers/fastopenapi.py
@@ -0,0 +1,3 @@
+from fastopenapi.routers import FlaskRouter
+
+console_router = FlaskRouter()
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index c800c0e4e..49ff4f57d 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -261,17 +261,6 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
- args = {}
- if "data" in request.form:
- args = json.loads(request.form["data"])
- if "doc_form" not in args:
- args["doc_form"] = "text_model"
- if "doc_language" not in args:
- args["doc_language"] = "English"
-
- # get dataset info
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -280,6 +269,18 @@ class DocumentAddByFileApi(DatasetApiResource):
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
+ args = {}
+ if "data" in request.form:
+ args = json.loads(request.form["data"])
+ if "doc_form" not in args:
+ args["doc_form"] = dataset.chunk_structure or "text_model"
+ if "doc_language" not in args:
+ args["doc_language"] = "English"
+
+ # get dataset info
+ dataset_id = str(dataset_id)
+ tenant_id = str(tenant_id)
+
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.")
@@ -370,17 +371,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
- args = {}
- if "data" in request.form:
- args = json.loads(request.form["data"])
- if "doc_form" not in args:
- args["doc_form"] = "text_model"
- if "doc_language" not in args:
- args["doc_language"] = "English"
-
- # get dataset info
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -389,6 +379,18 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
+ args = {}
+ if "data" in request.form:
+ args = json.loads(request.form["data"])
+ if "doc_form" not in args:
+ args["doc_form"] = dataset.chunk_structure or "text_model"
+ if "doc_language" not in args:
+ args["doc_language"] = "English"
+
+ # get dataset info
+ dataset_id = str(dataset_id)
+ tenant_id = str(tenant_id)
+
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py
index cce3dae95..2540bf02f 100644
--- a/api/controllers/web/feature.py
+++ b/api/controllers/web/feature.py
@@ -17,5 +17,15 @@ class SystemFeatureApi(Resource):
Returns:
dict: System feature configuration object
+
+ This endpoint is akin to the `SystemFeatureApi` endpoint in api/controllers/console/feature.py,
+ except it is intended for use by the web app, instead of the console dashboard.
+
+ NOTE: This endpoint is unauthenticated by design, as it provides system features
+ data required for webapp initialization.
+
+ Authentication would create circular dependency (can't authenticate without webapp loading).
+
+ Only non-sensitive configuration data should be returned by this endpoint.
"""
return FeatureService.get_system_features().model_dump()
diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py
index 690b76655..91d206f72 100644
--- a/api/controllers/web/forgot_password.py
+++ b/api/controllers/web/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@@ -22,7 +21,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
+from models.account import Account
from services.account_service import AccountService
@@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
+ request_email = payload.email
+ normalized_email = request_email.lower()
+
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
raise AuthenticationFailedError()
else:
- token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
+ token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token}
@@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
- user_email = payload.email
+ user_email = payload.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if payload.code != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(payload.email)
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=payload.code, additional_data={"phase": "reset"}
+ token_email, code=payload.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(payload.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets")
@@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py
index bf404dc9d..398aa5107 100644
--- a/api/controllers/web/login.py
+++ b/api/controllers/web/login.py
@@ -1,9 +1,11 @@
from flask import make_response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
from jwt import InvalidTokenError
+from pydantic import BaseModel, Field, field_validator
import services
from configs import dify_config
+from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
@@ -18,7 +20,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
-from libs.helper import email
+from libs.helper import EmailStr
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@@ -31,10 +33,35 @@ from services.app_service import AppService
from services.webapp_auth_service import WebAppAuthService
+class LoginPayload(BaseModel):
+ email: EmailStr
+ password: str
+
+ @field_validator("password")
+ @classmethod
+ def validate_password(cls, value: str) -> str:
+ return valid_password(value)
+
+
+class EmailCodeLoginSendPayload(BaseModel):
+ email: EmailStr
+ language: str | None = None
+
+
+class EmailCodeLoginVerifyPayload(BaseModel):
+ email: EmailStr
+ code: str
+ token: str = Field(min_length=1)
+
+
+register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
+
+
@web_ns.route("/login")
class LoginApi(Resource):
"""Resource for web app email/password login."""
+ @web_ns.expect(web_ns.models[LoginPayload.__name__])
@setup_required
@only_edition_enterprise
@web_ns.doc("web_app_login")
@@ -51,15 +78,10 @@ class LoginApi(Resource):
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("password", type=valid_password, required=True, location="json")
- )
- args = parser.parse_args()
+ payload = LoginPayload.model_validate(web_ns.payload or {})
try:
- account = WebAppAuthService.authenticate(args["email"], args["password"])
+ account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
@@ -159,6 +181,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@only_edition_enterprise
@web_ns.doc("send_email_code_login")
@web_ns.doc(description="Send email verification code for login")
+ @web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
@web_ns.doc(
responses={
200: "Email code sent successfully",
@@ -167,19 +190,14 @@ class EmailCodeLoginSendEmailApi(Resource):
}
)
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if payload.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
- account = WebAppAuthService.get_user_through_email(args["email"])
+ account = WebAppAuthService.get_user_through_email(payload.email)
if account is None:
raise AuthenticationFailedError()
else:
@@ -193,6 +211,7 @@ class EmailCodeLoginApi(Resource):
@only_edition_enterprise
@web_ns.doc("verify_email_code_login")
@web_ns.doc(description="Verify email code and complete login")
+ @web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
@web_ns.doc(
responses={
200: "Email code verified and login successful",
@@ -203,33 +222,31 @@ class EmailCodeLoginApi(Resource):
)
@decrypt_code_field
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, location="json")
- )
- args = parser.parse_args()
+ payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
- user_email = args["email"]
+ user_email = payload.email.lower()
- token_data = WebAppAuthService.get_email_code_login_data(args["token"])
+ token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args["email"]:
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+ if normalized_token_email != user_email:
raise InvalidEmailError()
- if token_data["code"] != args["code"]:
+ if token_data["code"] != payload.code:
raise EmailCodeError()
- WebAppAuthService.revoke_email_code_login_token(args["token"])
- account = WebAppAuthService.get_user_through_email(user_email)
+ WebAppAuthService.revoke_email_code_login_token(payload.token)
+ account = WebAppAuthService.get_user_through_email(token_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
- AccountService.reset_login_error_rate_limit(args["email"])
+ AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index bb69423d7..af96b057a 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,8 +1,10 @@
import logging
+from typing import Any
-from flask_restx import reqparse
+from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
CompletionRequestError,
@@ -27,6 +29,12 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
+
+class WorkflowRunPayload(BaseModel):
+ inputs: dict[str, Any] = Field(description="Input variables for the workflow")
+ files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
+
+
logger = logging.getLogger(__name__)
# extend: start 您必须登录才能访问您的帐户扩展功能
@@ -39,17 +47,14 @@ from services.app_generate_service_extend import AppGenerateServiceExtend
# extend: stop 您必须登录才能访问您的帐户扩展功能
+register_schema_models(web_ns, WorkflowRunPayload)
+
@web_ns.route("/workflows/run")
class WorkflowRunApi(WebApiResource):
@web_ns.doc("Run Workflow")
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
- @web_ns.doc(
- params={
- "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
- "files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
- }
- )
+ @web_ns.expect(web_ns.models[WorkflowRunPayload.__name__])
@web_ns.doc(
responses={
200: "Success",
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index c196dbbdf..3c6d36afe 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -1,6 +1,7 @@
import json
import logging
import uuid
+from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
@@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
+from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
+ tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
- message_unit_price=0,
- message_price_unit=0,
+ message_unit_price=Decimal(0),
+ message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
- answer_unit_price=0,
- answer_price_unit=0,
+ answer_unit_price=Decimal(0),
+ answer_price_unit=Decimal("0.001"),
tokens=0,
- total_price=0,
+ total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
- created_by_role="account",
+ created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
- agent_thought.thought += thought
+ existing_thought = agent_thought.thought or ""
+ agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
- tools = agent_thought.tool
- if tools:
- tools = tools.split(";")
+ tool_names_raw = agent_thought.tool
+ if tool_names_raw:
+ tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
- try:
- tool_inputs = json.loads(agent_thought.tool_input)
- except Exception:
- tool_inputs = {tool: {} for tool in tools}
- try:
- tool_responses = json.loads(agent_thought.observation)
- except Exception:
- tool_responses = dict.fromkeys(tools, agent_thought.observation)
+ tool_input_payload = agent_thought.tool_input
+ if tool_input_payload:
+ try:
+ tool_inputs = json.loads(tool_input_payload)
+ except Exception:
+ tool_inputs = {tool: {} for tool in tool_names}
+ else:
+ tool_inputs = {tool: {} for tool in tool_names}
- for tool in tools:
+ observation_payload = agent_thought.observation
+ if observation_payload:
+ try:
+ tool_responses = json.loads(observation_payload)
+ except Exception:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+ else:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+
+ for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
- if not tools:
+ if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 68d14ad02..7c5c9136a 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
- assistant_message = AssistantPromptMessage(content="", tool_calls=[])
+ assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
- else:
- assistant_message.content = response
self._current_thoughts.append(assistant_message)
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 307af3747..13c51529c 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,4 +1,3 @@
-import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
- json_schema: str | None = Field(default=None)
+ json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
- def validate_json_schema(cls, schema: str | None) -> str | None:
+ def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
-
try:
- json_schema = json.loads(schema)
- except json.JSONDecodeError:
- raise ValueError(f"invalid json_schema value {schema}")
-
- try:
- Draft7Validator.check_schema(json_schema)
+ Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py
index e4b308a6f..c21c494ef 100644
--- a/api/core/app/apps/advanced_chat/app_config_manager.py
+++ b/api/core/app/apps/advanced_chat/app_config_manager.py
@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
-
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 303c2d9f3..999925a0d 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@@ -319,7 +319,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
- def _initialize_conversation_variables(self) -> list[VariableUnion]:
+ def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@@ -344,7 +344,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
- return cast(list[VariableUnion], conversation_variables)
+ return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 21efd026c..d88dd71fe 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -374,25 +374,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
- # For ANSWER nodes, check if we need to send a message_replace event
- # Only send if the final output differs from the accumulated task_state.answer
- # This happens when variables were updated by variable_assigner during workflow execution
- if event.node_type == NodeType.ANSWER and event.outputs:
- final_answer = event.outputs.get("answer")
- if final_answer is not None and final_answer != self._task_state.answer:
- logger.info(
- "ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
- final_answer,
- self._task_state.answer,
- )
- # Update the task state answer
- self._task_state.answer = str(final_answer)
- # Send message_replace event to update the UI
- yield self._message_cycle_manager.message_replace_to_stream_response(
- answer=str(final_answer),
- reason="variable_update",
- )
-
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index a6aace168..07bae6686 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,4 +1,3 @@
-import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -76,12 +75,24 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
- if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
- raise ValueError("Invalid input type")
- if any(
- filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
- ):
- raise ValueError("Invalid input type")
+ invalid_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, dict)
+ and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
+ ]
+ if invalid_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_dict_keys}")
+
+ invalid_list_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, list)
+ and any(isinstance(item, dict) for item in v)
+ and entity_dictionary[k].type != VariableEntityType.FILE_LIST
+ ]
+ if invalid_list_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
return user_inputs
@@ -178,12 +189,8 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
- if not isinstance(value, str):
- raise ValueError(f"{variable_entity.variable} in input form must be a string")
- try:
- json.loads(value)
- except json.JSONDecodeError:
- raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
+ if value and not isinstance(value, dict):
+ raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index ecea4716b..2e27b89f9 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
+from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
-from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 933a134d0..f16342dbd 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -10,7 +10,7 @@ from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -492,7 +493,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
- with Session(db.engine, expire_on_commit=False) as session:
+ with session_factory.create_session() as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 7adf3504a..0e8f8b8db 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -25,6 +25,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
+from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -53,7 +54,6 @@ from core.workflow.graph_events import (
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
-from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -166,18 +166,22 @@ class WorkflowBasedAppRunner:
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
- graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+ graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
+ node_type_filter_key="iteration_id",
+ node_type_label="iteration",
)
elif single_loop_run:
- graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
+ graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
+ node_type_filter_key="loop_id",
+ node_type_label="loop",
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
@@ -314,44 +318,6 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
- def _get_graph_and_variable_pool_of_single_iteration(
- self,
- workflow: Workflow,
- node_id: str,
- user_inputs: dict[str, Any],
- graph_runtime_state: GraphRuntimeState,
- ) -> tuple[Graph, VariablePool]:
- """
- Get variable pool of single iteration
- """
- return self._get_graph_and_variable_pool_for_single_node_run(
- workflow=workflow,
- node_id=node_id,
- user_inputs=user_inputs,
- graph_runtime_state=graph_runtime_state,
- node_type_filter_key="iteration_id",
- node_type_label="iteration",
- )
-
- def _get_graph_and_variable_pool_of_single_loop(
- self,
- workflow: Workflow,
- node_id: str,
- user_inputs: dict[str, Any],
- graph_runtime_state: GraphRuntimeState,
- ) -> tuple[Graph, VariablePool]:
- """
- Get variable pool of single loop
- """
- return self._get_graph_and_variable_pool_for_single_node_run(
- workflow=workflow,
- node_id=node_id,
- user_inputs=user_inputs,
- graph_runtime_state=graph_runtime_state,
- node_type_filter_key="loop_id",
- node_type_label="loop",
- )
-
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py
index 77cc00bdc..c070845b7 100644
--- a/api/core/app/layers/conversation_variable_persist_layer.py
+++ b/api/core/app/layers/conversation_variable_persist_layer.py
@@ -1,6 +1,6 @@
import logging
-from core.variables import Variable
+from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py
index 225b758fc..a7ea9ef44 100644
--- a/api/core/app/layers/trigger_post_layer.py
+++ b/api/core/app/layers/trigger_post_layer.py
@@ -3,8 +3,8 @@ from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
-from sqlalchemy.orm import Session, sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@@ -31,13 +31,11 @@ class TriggerPostLayer(GraphEngineLayer):
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
- session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
- self.session_maker = session_maker
def on_graph_start(self):
pass
@@ -47,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer):
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
- with self.session_maker() as session:
+ with session_factory.create_session() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:
diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py
new file mode 100644
index 000000000..172ee5d70
--- /dev/null
+++ b/api/core/app/workflow/__init__.py
@@ -0,0 +1,3 @@
+from .node_factory import DifyNodeFactory
+
+__all__ = ["DifyNodeFactory"]
diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/app/workflow/node_factory.py
similarity index 79%
rename from api/core/workflow/nodes/node_factory.py
rename to api/core/app/workflow/node_factory.py
index 557d3a330..e0a0059a3 100644
--- a/api/core/workflow/nodes/node_factory.py
+++ b/api/core/app/workflow/node_factory.py
@@ -1,16 +1,22 @@
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
+from core.file import file_manager
+from core.helper import ssrf_proxy
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
+from core.tools.tool_file_manager import ToolFileManager
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
+from core.workflow.nodes.http_request.node import HttpRequestNode
+from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
@@ -18,8 +24,6 @@ from core.workflow.nodes.template_transform.template_renderer import (
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
-from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
@@ -43,6 +47,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
+ http_request_http_client: HttpClientProtocol = ssrf_proxy,
+ http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ http_request_file_manager: FileManagerProtocol = file_manager,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@@ -61,6 +68,9 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
+ self._http_request_http_client = http_request_http_client
+ self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
+ self._http_request_file_manager = http_request_file_manager
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@@ -113,6 +123,7 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
+
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
@@ -122,6 +133,17 @@ class DifyNodeFactory(NodeFactory):
template_renderer=self._template_renderer,
)
+ if node_type == NodeType.HTTP_REQUEST:
+ return HttpRequestNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ http_client=self._http_request_http_client,
+ tool_file_manager_factory=self._http_request_tool_file_manager_factory,
+ file_manager=self._http_request_file_manager,
+ )
+
return node_class(
id=node_id,
config=node_config,
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index 1785cbde4..128c64ff2 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
+request_error = httpx.RequestError
+max_retries_exceeded_error = MaxRetriesExceededError
+
+
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index b4c3ec1ca..be1e306d4 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -71,8 +71,8 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
- answer = cast(str, response.message.content)
- if answer is None:
+ answer = response.message.get_text_content()
+ if answer == "":
return ""
try:
result_dict = json.loads(answer)
@@ -184,7 +184,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- rule_config["prompt"] = cast(str, response.message.content)
+ rule_config["prompt"] = response.message.get_text_content()
except InvokeError as e:
error = str(e)
@@ -237,13 +237,11 @@ class LLMGenerator:
return rule_config
- rule_config["prompt"] = cast(str, prompt_content.message.content)
+ rule_config["prompt"] = prompt_content.message.get_text_content()
- if not isinstance(prompt_content.message.content, str):
- raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
- "INPUT_TEXT": prompt_content.message.content,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -253,7 +251,7 @@ class LLMGenerator:
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
- "INPUT_TEXT": prompt_content.message.content,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -263,7 +261,7 @@ class LLMGenerator:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
- rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
+ rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
@@ -272,7 +270,7 @@ class LLMGenerator:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
- rule_config["opening_statement"] = cast(str, statement_content.message.content)
+ rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -315,7 +313,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- generated_code = cast(str, response.message.content)
+ generated_code = response.message.get_text_content()
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -351,7 +349,7 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
- answer = cast(str, response.message.content)
+ answer = response.message.get_text_content()
return answer.strip()
@classmethod
@@ -375,10 +373,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- raw_content = response.message.content
-
- if not isinstance(raw_content, str):
- raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
+ raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 3ac83b4c9..9e46d7289 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
- if not super().is_empty() and not self.tool_calls:
- return False
-
- return True
+ return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index d6bd4d201..22ad756c9 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
+from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
@@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
index d3324f8f8..762458636 100644
--- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py
+++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
- kind=trace_api.SpanKind.INTERNAL,
+ kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
index 20ff2d087..907803149 100644
--- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
+++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
-from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
+ span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index f45f15a6d..84f5bf551 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -35,7 +35,6 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
-from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@@ -473,6 +472,9 @@ class TraceTask:
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
+ # Lazy import to avoid circular import during module initialization
+ from repositories.factory import DifyAPIRepositoryFactory
+
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index 0e49824ad..7a6a598a2 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -320,18 +320,17 @@ class BasePluginClient:
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
- args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
- raise InvokeRateLimitError(description=args.get("description"))
+ raise InvokeRateLimitError(description=error_object.get("message"))
case InvokeAuthorizationError.__name__:
- raise InvokeAuthorizationError(description=args.get("description"))
+ raise InvokeAuthorizationError(description=error_object.get("message"))
case InvokeBadRequestError.__name__:
- raise InvokeBadRequestError(description=args.get("description"))
+ raise InvokeBadRequestError(description=error_object.get("message"))
case InvokeConnectionError.__name__:
- raise InvokeConnectionError(description=args.get("description"))
+ raise InvokeConnectionError(description=error_object.get("message"))
case InvokeServerUnavailableError.__name__:
- raise InvokeServerUnavailableError(description=args.get("description"))
+ raise InvokeServerUnavailableError(description=error_object.get("message"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
@@ -339,11 +338,11 @@ class BasePluginClient:
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
- raise TriggerPluginInvokeError(description=error_object.get("description"))
+ raise TriggerPluginInvokeError(description=error_object.get("message"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
- raise EventIgnoreError(description=error_object.get("description"))
+ raise EventIgnoreError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:
diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py
index 5b88742be..2db5185a2 100644
--- a/api/core/plugin/impl/endpoint.py
+++ b/api/core/plugin/impl/endpoint.py
@@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
+from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
+
+ This operation is idempotent: if the endpoint is already deleted (record not found),
+ it will return True instead of raising an error.
"""
- return self._request_with_plugin_daemon_response(
- "POST",
- f"plugin/{tenant_id}/endpoint/remove",
- bool,
- data={
- "endpoint_id": endpoint_id,
- },
- headers={
- "Content-Type": "application/json",
- },
- )
+ try:
+ return self._request_with_plugin_daemon_response(
+ "POST",
+ f"plugin/{tenant_id}/endpoint/remove",
+ bool,
+ data={
+ "endpoint_id": endpoint_id,
+ },
+ headers={
+ "Content-Type": "application/json",
+ },
+ )
+ except PluginDaemonInternalServerError as e:
+ # Make delete idempotent: if record is not found, consider it a success
+ if "record not found" in str(e.description).lower():
+ return True
+ raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py
index 5bdb0af0b..50bb2429e 100644
--- a/api/core/rag/datasource/vdb/iris/iris_vector.py
+++ b/api/core/rag/datasource/vdb/iris/iris_vector.py
@@ -154,7 +154,7 @@ class IrisConnectionPool:
# Add to cache to skip future checks
self._schemas_initialized.add(schema)
- except Exception as e:
+ except Exception:
conn.rollback()
logger.exception("Failed to ensure schema %s exists", schema)
raise
@@ -177,6 +177,9 @@ class IrisConnectionPool:
class IrisVector(BaseVector):
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
+ # Fallback score for full-text search when Rank function unavailable or TEXT_INDEX disabled
+ _FULL_TEXT_FALLBACK_SCORE = 0.5
+
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
super().__init__(collection_name)
self.config = config
@@ -272,41 +275,131 @@ class IrisVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- """Search documents by full-text using iFind index or fallback to LIKE search."""
+ """Search documents by full-text using iFind index with BM25 relevance scoring.
+
+ When IRIS_TEXT_INDEX is enabled, this method uses the auto-generated Rank
+ function from %iFind.Index.Basic to calculate BM25 relevance scores. The Rank
+ function is automatically created with naming: {schema}.{table_name}_{index}Rank
+
+ Args:
+ query: Search query string
+ **kwargs: Optional parameters including top_k, document_ids_filter
+
+ Returns:
+ List of Document objects with relevance scores in metadata["score"]
+ """
top_k = kwargs.get("top_k", 5)
+ document_ids_filter = kwargs.get("document_ids_filter")
with self._get_cursor() as cursor:
if self.config.IRIS_TEXT_INDEX:
- # Use iFind full-text search with index
+ # Use iFind full-text search with auto-generated Rank function
text_index_name = f"idx_{self.table_name}_text"
+ # IRIS removes underscores from function names
+ table_no_underscore = self.table_name.replace("_", "")
+ index_no_underscore = text_index_name.replace("_", "")
+ rank_function = f"{self.schema}.{table_no_underscore}_{index_no_underscore}Rank"
+
+ # Build WHERE clause with document ID filter if provided
+ where_clause = f"WHERE %ID %FIND search_index({text_index_name}, ?)"
+ # First param for Rank function, second for FIND
+ params = [query, query]
+
+ if document_ids_filter:
+ # Add document ID filter
+ placeholders = ",".join("?" * len(document_ids_filter))
+ where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
+ params.extend(document_ids_filter)
+
sql = f"""
- SELECT TOP {top_k} id, text, meta
+ SELECT TOP {top_k}
+ id,
+ text,
+ meta,
+ {rank_function}(%ID, ?) AS score
FROM {self.schema}.{self.table_name}
- WHERE %ID %FIND search_index({text_index_name}, ?)
+ {where_clause}
+ ORDER BY score DESC
"""
- cursor.execute(sql, (query,))
+
+ logger.debug(
+ "iFind search: query='%s', index='%s', rank='%s'",
+ query,
+ text_index_name,
+ rank_function,
+ )
+
+ try:
+ cursor.execute(sql, params)
+ except Exception: # pylint: disable=broad-exception-caught
+ # Fallback to query without Rank function if it fails
+ logger.warning(
+ "Rank function '%s' failed, using fixed score",
+ rank_function,
+ exc_info=True,
+ )
+ sql_fallback = f"""
+ SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
+ FROM {self.schema}.{self.table_name}
+ {where_clause}
+ """
+ # Skip first param (for Rank function)
+ cursor.execute(sql_fallback, params[1:])
else:
- # Fallback to LIKE search (inefficient for large datasets)
- # Escape special characters for LIKE clause to prevent SQL injection
- from libs.helper import escape_like_pattern
+ # Fallback to LIKE search (IRIS_TEXT_INDEX disabled)
+ from libs.helper import ( # pylint: disable=import-outside-toplevel
+ escape_like_pattern,
+ )
escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"
+
+ # Build WHERE clause with document ID filter if provided
+ where_clause = "WHERE text LIKE ? ESCAPE '\\\\'"
+ params = [query_pattern]
+
+ if document_ids_filter:
+ placeholders = ",".join("?" * len(document_ids_filter))
+ where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
+ params.extend(document_ids_filter)
+
sql = f"""
- SELECT TOP {top_k} id, text, meta
+ SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
FROM {self.schema}.{self.table_name}
- WHERE text LIKE ? ESCAPE '\\'
+ {where_clause}
+ ORDER BY LENGTH(text) ASC
"""
- cursor.execute(sql, (query_pattern,))
+
+ logger.debug(
+ "LIKE fallback (TEXT_INDEX disabled): query='%s'",
+ query_pattern,
+ )
+ cursor.execute(sql, params)
docs = []
for row in cursor.fetchall():
- if len(row) >= 3:
- metadata = json.loads(row[2]) if row[2] else {}
- docs.append(Document(page_content=row[1], metadata=metadata))
+ # Expecting 4 columns: id, text, meta, score
+ if len(row) >= 4:
+ text_content = row[1]
+ meta_str = row[2]
+ score_value = row[3]
+
+ metadata = json.loads(meta_str) if meta_str else {}
+ # Add score to metadata for hybrid search compatibility
+ score = float(score_value) if score_value is not None else 0.0
+ metadata["score"] = score
+
+ docs.append(Document(page_content=text_content, metadata=metadata))
+
+ logger.info(
+ "Full-text search completed: query='%s', results=%d/%d",
+ query,
+ len(docs),
+ top_k,
+ )
if not docs:
- logger.info("Full-text search for '%s' returned no results", query)
+ logger.warning("Full-text search for '%s' returned no results", query)
return docs
@@ -370,7 +463,11 @@ class IrisVector(BaseVector):
AS %iFind.Index.Basic
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
"""
- logger.info("Creating text index: %s with language: %s", text_index_name, language)
+ logger.info(
+ "Creating text index: %s with language: %s",
+ text_index_name,
+ language,
+ )
logger.info("SQL for text index: %s", sql_text_index)
cursor.execute(sql_text_index)
logger.info("Text index created successfully: %s", text_index_name)
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index b5c7a6310..96268d029 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -130,7 +130,7 @@ class ToolInvokeMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
- json_object: dict
+ json_object: dict | list
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
@@ -144,7 +144,14 @@ class ToolInvokeMessage(BaseModel):
end: bool = Field(..., description="Whether the chunk is the last chunk")
class FileMessage(BaseModel):
- pass
+ file_marker: str = Field(default="file_marker")
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_file_message(cls, values):
+ if isinstance(values, dict) and "file_marker" not in values:
+ raise ValueError("Invalid FileMessage: missing file_marker")
+ return values
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
@@ -234,10 +241,22 @@ class ToolInvokeMessage(BaseModel):
@field_validator("message", mode="before")
@classmethod
- def decode_blob_message(cls, v):
+ def decode_blob_message(cls, v, info: ValidationInfo):
+ # 处理 blob 解码
if isinstance(v, dict) and "blob" in v:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])
+
+ # Force correct message type based on type field
+ # Only wrap dict types to avoid wrapping already parsed Pydantic model objects
+ if info.data and isinstance(info.data, dict) and isinstance(v, dict):
+ msg_type = info.data.get("type")
+ if msg_type == cls.MessageType.JSON:
+ if "json_object" not in v:
+ v = {"json_object": v}
+ elif msg_type == cls.MessageType.FILE:
+ v = {"file_marker": "file_marker"}
+
return v
@field_serializer("message")
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 13fd579e2..3f57a346c 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -1,5 +1,6 @@
import contextlib
import json
+import logging
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
@@ -36,6 +37,8 @@ from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import Message, MessageFile
+logger = logging.getLogger(__name__)
+
class ToolEngine:
"""
@@ -123,25 +126,31 @@ class ToolEngine:
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
+ logger.error(e, exc_info=True)
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
+ logger.error(e, exc_info=True)
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 81a1d5419..9c1ceff14 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -5,10 +5,9 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
-from flask import has_request_context
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@@ -20,9 +19,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
-from extensions.ext_database import db
from factories.file_factory import build_from_mapping
-from libs.login import current_user
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
@@ -210,50 +207,38 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
- if has_request_context():
- return self._resolve_user_from_request()
- else:
- return self._resolve_user_from_database(user_id=user_id)
-
- def _resolve_user_from_request(self) -> Account | EndUser | None:
- """
- Resolve user from Flask request context.
- """
- try:
- # Note: `current_user` is a LocalProxy. Never compare it with None directly.
- return getattr(current_user, "_get_current_object", lambda: current_user)()
- except Exception as e:
- logger.warning("Failed to resolve user from request context: %s", e)
- return None
+ return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
"""
Resolve user from database (worker/Celery context).
"""
+ with session_factory.create_session() as session:
+ tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
+ tenant = session.scalar(tenant_stmt)
+ if not tenant:
+ return None
+
+ user_stmt = select(Account).where(Account.id == user_id)
+ user = session.scalar(user_stmt)
+ if user:
+ user.current_tenant = tenant
+ session.expunge(user)
+ return user
+
+ end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
+ end_user = session.scalar(end_user_stmt)
+ if end_user:
+ session.expunge(end_user)
+ return end_user
- tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
- tenant = db.session.scalar(tenant_stmt)
- if not tenant:
return None
- user_stmt = select(Account).where(Account.id == user_id)
- user = db.session.scalar(user_stmt)
- if user:
- user.current_tenant = tenant
- return user
-
- end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
- end_user = db.session.scalar(end_user_stmt)
- if end_user:
- return end_user
-
- return None
-
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@@ -265,22 +250,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
- if not workflow:
- raise ValueError("workflow not found or not published")
+ if not workflow:
+ raise ValueError("workflow not found or not published")
- return workflow
+ session.expunge(workflow)
+ return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
- if not app:
- raise ValueError("app not found")
+ if not app:
+ raise ValueError("app not found")
- return app
+ session.expunge(app)
+ return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py
index 7a1cbf994..749822492 100644
--- a/api/core/variables/__init__.py
+++ b/api/core/variables/__init__.py
@@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
+ VariableBase,
)
__all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
+ "VariableBase",
]
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
index 406b4e6f9..8330f1fe1 100644
--- a/api/core/variables/segments.py
+++ b/api/core/variables/segments.py
@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
-# - `Variable` and its subclasses, which are handled by `VariableUnion`.
+# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index 9fd0bbc5b..a19c53918 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
-class Variable(Segment):
+class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
-class StringVariable(StringSegment, Variable):
+class StringVariable(StringSegment, VariableBase):
pass
-class FloatVariable(FloatSegment, Variable):
+class FloatVariable(FloatSegment, VariableBase):
pass
-class IntegerVariable(IntegerSegment, Variable):
+class IntegerVariable(IntegerSegment, VariableBase):
pass
-class ObjectVariable(ObjectSegment, Variable):
+class ObjectVariable(ObjectSegment, VariableBase):
pass
-class ArrayVariable(ArraySegment, Variable):
+class ArrayVariable(ArraySegment, VariableBase):
pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
-class NoneVariable(NoneSegment, Variable):
+class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
-class FileVariable(FileSegment, Variable):
+class FileVariable(FileSegment, VariableBase):
pass
-class BooleanVariable(BooleanSegment, Variable):
+class BooleanVariable(BooleanSegment, VariableBase):
pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
-# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
-# Use `Variable` for type hinting when serialization is not required.
+# The `Variable` type is used to enable serialization and deserialization with Pydantic.
+# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
-# - All variants in `VariableUnion` must inherit from the `Variable` class.
-# - The union must include all non-abstract subclasses of `Segment`, except:
-VariableUnion: TypeAlias = Annotated[
+# - All variants in `Variable` must inherit from the `VariableBase` class.
+# - The union must include all non-abstract subclasses of `VariableBase`.
+Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py
new file mode 100644
index 000000000..1237d6a01
--- /dev/null
+++ b/api/core/workflow/context/__init__.py
@@ -0,0 +1,34 @@
+"""
+Execution Context - Context management for workflow execution.
+
+This package provides Flask-independent context management for workflow
+execution in multi-threaded environments.
+"""
+
+from core.workflow.context.execution_context import (
+ AppContext,
+ ContextProviderNotFoundError,
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+ capture_current_context,
+ read_context,
+ register_context,
+ register_context_capturer,
+ reset_context_provider,
+)
+from core.workflow.context.models import SandboxContext
+
+__all__ = [
+ "AppContext",
+ "ContextProviderNotFoundError",
+ "ExecutionContext",
+ "IExecutionContext",
+ "NullAppContext",
+ "SandboxContext",
+ "capture_current_context",
+ "read_context",
+ "register_context",
+ "register_context_capturer",
+ "reset_context_provider",
+]
diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py
new file mode 100644
index 000000000..e3007530f
--- /dev/null
+++ b/api/core/workflow/context/execution_context.py
@@ -0,0 +1,284 @@
+"""
+Execution Context - Abstracted context management for workflow execution.
+"""
+
+import contextvars
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Generator
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Protocol, TypeVar, final, runtime_checkable
+
+from pydantic import BaseModel
+
+
+class AppContext(ABC):
+ """
+ Abstract application context interface.
+
+ This abstraction allows workflow execution to work with or without Flask
+ by providing a common interface for application context management.
+ """
+
+ @abstractmethod
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ pass
+
+ @abstractmethod
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name (e.g., 'db', 'cache')."""
+ pass
+
+ @abstractmethod
+ def enter(self) -> AbstractContextManager[None]:
+ """Enter the application context."""
+ pass
+
+
+@runtime_checkable
+class IExecutionContext(Protocol):
+ """
+ Protocol for execution context.
+
+ This protocol defines the interface that all execution contexts must implement,
+ allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
+ """
+
+ def __enter__(self) -> "IExecutionContext":
+ """Enter the execution context."""
+ ...
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ ...
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ ...
+
+
+@final
+class ExecutionContext:
+ """
+ Execution context for workflow execution in worker threads.
+
+ This class encapsulates all context needed for workflow execution:
+ - Application context (Flask app or standalone)
+ - Context variables for Python contextvars
+ - User information (optional)
+
+ It is designed to be serializable and passable to worker threads.
+ """
+
+ def __init__(
+ self,
+ app_context: AppContext | None = None,
+ context_vars: contextvars.Context | None = None,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize execution context.
+
+ Args:
+ app_context: Application context (Flask or standalone)
+ context_vars: Python contextvars to preserve
+ user: User object (optional)
+ """
+ self._app_context = app_context
+ self._context_vars = context_vars
+ self._user = user
+ self._local = threading.local()
+
+ @property
+ def app_context(self) -> AppContext | None:
+ """Get application context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context | None:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """
+ Enter this execution context.
+
+ This is a convenience method that creates a context manager.
+ """
+ # Restore context variables if provided
+ if self._context_vars:
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter app context if available
+ if self._app_context is not None:
+ with self._app_context.enter():
+ yield
+ else:
+ yield
+
+ def __enter__(self) -> "ExecutionContext":
+ """Enter the execution context."""
+ cm = self.enter()
+ self._local.cm = cm
+ cm.__enter__()
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ cm = getattr(self._local, "cm", None)
+ if cm is not None:
+ cm.__exit__(*args)
+
+
+class NullAppContext(AppContext):
+ """
+ Null implementation of AppContext for non-Flask environments.
+
+ This is used when running without Flask (e.g., in tests or standalone mode).
+ """
+
+ def __init__(self, config: dict[str, Any] | None = None) -> None:
+ """
+ Initialize null app context.
+
+ Args:
+ config: Optional configuration dictionary
+ """
+ self._config = config or {}
+ self._extensions: dict[str, Any] = {}
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ return self._config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get extension by name."""
+ return self._extensions.get(name)
+
+ def set_extension(self, name: str, extension: Any) -> None:
+ """Set extension by name."""
+ self._extensions[name] = extension
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter null context (no-op)."""
+ yield
+
+
+class ExecutionContextBuilder:
+ """
+ Builder for creating ExecutionContext instances.
+
+ This provides a fluent API for building execution contexts.
+ """
+
+ def __init__(self) -> None:
+ self._app_context: AppContext | None = None
+ self._context_vars: contextvars.Context | None = None
+ self._user: Any = None
+
+ def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
+ """Set application context."""
+ self._app_context = app_context
+ return self
+
+ def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
+ """Set context variables."""
+ self._context_vars = context_vars
+ return self
+
+ def with_user(self, user: Any) -> "ExecutionContextBuilder":
+ """Set user."""
+ self._user = user
+ return self
+
+ def build(self) -> ExecutionContext:
+ """Build the execution context."""
+ return ExecutionContext(
+ app_context=self._app_context,
+ context_vars=self._context_vars,
+ user=self._user,
+ )
+
+
+_capturer: Callable[[], IExecutionContext] | None = None
+
+# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
+# Key mapping:
+# (name, tenant_id) -> provider
+# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
+# - tenant_id: tenant identifier string
+# Value:
+# provider: Callable[[], BaseModel] returning the typed context value
+# Type-safety note:
+# - This registry cannot enforce that all providers for a given name return the same BaseModel type.
+# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
+# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
+# def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
+_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class ContextProviderNotFoundError(KeyError):
+ """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
+
+ pass
+
+
+def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
+ """Register a single enterable execution context capturer (e.g., Flask)."""
+ global _capturer
+ _capturer = capturer
+
+
+def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
+ """Register a tenant-specific provider for a named context.
+
+ Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
+ Consider adding a typed wrapper for this registration in your feature module.
+ """
+ _tenant_context_providers[(name, tenant_id)] = provider
+
+
+def read_context(name: str, *, tenant_id: str) -> BaseModel:
+ """
+ Read a context value for a specific tenant.
+
+ Raises KeyError if the provider for (name, tenant_id) is not registered.
+ """
+ prov = _tenant_context_providers.get((name, tenant_id))
+ if prov is None:
+ raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
+ return prov()
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context from the calling environment.
+
+ If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
+ context with NullAppContext + copy of current contextvars.
+ """
+ if _capturer is None:
+ return ExecutionContext(
+ app_context=NullAppContext(),
+ context_vars=contextvars.copy_context(),
+ )
+ return _capturer()
+
+
+def reset_context_provider() -> None:
+ """Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
+ global _capturer
+ _capturer = None
+ _tenant_context_providers.clear()
diff --git a/api/core/workflow/context/models.py b/api/core/workflow/context/models.py
new file mode 100644
index 000000000..af5a4b261
--- /dev/null
+++ b/api/core/workflow/context/models.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from pydantic import AnyHttpUrl, BaseModel
+
+
+class SandboxContext(BaseModel):
+ """Typed context for sandbox integration. All fields optional by design."""
+
+ sandbox_url: AnyHttpUrl | None = None
+ sandbox_token: str | None = None # optional, if later needed for auth
+
+
+__all__ = ["SandboxContext"]
diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py
index fd78248c1..75f47691d 100644
--- a/api/core/workflow/conversation_variable_updater.py
+++ b/api/core/workflow/conversation_variable_updater.py
@@ -1,7 +1,7 @@
import abc
from typing import Protocol
-from core.variables import Variable
+from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
- def update(self, conversation_id: str, variable: "Variable"):
+ def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
- :param variable: The `Variable` instance containing the updated value.
+ :param variable: The `VariableBase` instance containing the updated value.
"""
pass
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index c08b62a25..bb3b13e8c 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
+ @classmethod
+ def ended_values(cls) -> list[str]:
+ return [status.value for status in _END_STATE]
+
_END_STATE = frozenset(
[
diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py
index 6dce03c94..41276eb44 100644
--- a/api/core/workflow/graph_engine/entities/commands.py
+++ b/api/core/workflow/graph_engine/entities/commands.py
@@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
class CommandType(StrEnum):
@@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
- value: VariableUnion = Field(description="New variable value")
+ value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 9a870d7bf..dbb2727c9 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -7,15 +7,13 @@ Domain-Driven Design principles for improved maintainability and testability.
from __future__ import annotations
-import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
-from flask import Flask, current_app
-
+from core.workflow.context import capture_current_context
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@@ -159,17 +157,8 @@ class GraphEngine:
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
- # Capture Flask app context for worker threads
- flask_app: Flask | None = None
- try:
- app = current_app._get_current_object() # type: ignore
- if isinstance(app, Flask):
- flask_app = app
- except RuntimeError:
- pass
-
- # Capture context variables for worker threads
- context_vars = contextvars.copy_context()
+ # Capture execution context for worker threads
+ execution_context = capture_current_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
@@ -177,8 +166,7 @@ class GraphEngine:
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
- flask_app=flask_app,
- context_vars=context_vars,
+ execution_context=execution_context,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py
index 83419830b..6c69ea5df 100644
--- a/api/core/workflow/graph_engine/worker.py
+++ b/api/core/workflow/graph_engine/worker.py
@@ -5,26 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
-import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
-from typing import final
-from uuid import uuid4
+from typing import TYPE_CHECKING, final
-from flask import Flask
from typing_extensions import override
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
-from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
+if TYPE_CHECKING:
+ pass
+
@final
class Worker(threading.Thread):
@@ -44,8 +44,7 @@ class Worker(threading.Thread):
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
- flask_app: Flask | None = None,
- context_vars: contextvars.Context | None = None,
+ execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize worker thread.
@@ -56,19 +55,17 @@ class Worker(threading.Thread):
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
- flask_app: Optional Flask application for context preservation
- context_vars: Optional context variables to preserve in worker thread
+ execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
- self._flask_app = flask_app
- self._context_vars = context_vars
- self._last_task_time = time.time()
+ self._execution_context = execution_context
self._stop_event = stop_event
self._layers = layers if layers is not None else []
+ self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
@@ -115,7 +112,7 @@ class Worker(threading.Thread):
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
- id=str(uuid4()),
+ id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
@@ -135,11 +132,9 @@ class Worker(threading.Thread):
error: Exception | None = None
- if self._flask_app and self._context_vars:
- with preserve_flask_contexts(
- flask_app=self._flask_app,
- context_vars=self._context_vars,
- ):
+ # Execute the node with preserved context if execution context is provided
+ if self._execution_context is not None:
+ with self._execution_context:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py
index df76ebe88..9ce7d16e9 100644
--- a/api/core/workflow/graph_engine/worker_management/worker_pool.py
+++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py
@@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
-from typing import TYPE_CHECKING, final
+from typing import final
from configs import dify_config
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
@@ -20,11 +21,6 @@ from ..worker import Worker
logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
- from contextvars import Context
-
- from flask import Flask
-
@final
class WorkerPool:
@@ -42,8 +38,7 @@ class WorkerPool:
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
- flask_app: "Flask | None" = None,
- context_vars: "Context | None" = None,
+ execution_context: IExecutionContext | None = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
@@ -57,8 +52,7 @@ class WorkerPool:
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
- flask_app: Optional Flask app for context preservation
- context_vars: Optional context variables
+ execution_context: Optional execution context for context preservation
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
@@ -67,8 +61,7 @@ class WorkerPool:
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
- self._flask_app = flask_app
- self._context_vars = context_vars
+ self._execution_context = execution_context
self._layers = layers
# Scaling parameters with defaults
@@ -152,8 +145,7 @@ class WorkerPool:
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
- flask_app=self._flask_app,
- context_vars=self._context_vars,
+ execution_context=self._execution_context,
stop_event=self._stop_event,
)
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 234651ce9..5a365f769 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -235,7 +235,18 @@ class AgentNode(Node[AgentNodeData]):
0,
):
value_param = param.get("value", {})
- params[key] = value_param.get("value", "") if value_param is not None else None
+ if value_param and value_param.get("type", "") == "variable":
+ variable_selector = value_param.get("value")
+ if not variable_selector:
+ raise ValueError("Variable selector is missing for a variable-type parameter.")
+
+ variable = variable_pool.get(variable_selector)
+ if variable is None:
+ raise AgentVariableNotFoundError(str(variable_selector))
+
+ params[key] = variable.value
+ else:
+ params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
@@ -483,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
text = ""
files: list[File] = []
- json_list: list[dict] = []
+ json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
@@ -557,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
- msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
- llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
- agent_execution_metadata = {
- WorkflowNodeExecutionMetadataKey(key): value
- for key, value in msg_metadata.items()
- if key in WorkflowNodeExecutionMetadataKey.__members__.values()
- }
+ if isinstance(message.message.json_object, dict):
+ msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
+ llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
+ agent_execution_metadata = {
+ WorkflowNodeExecutionMetadataKey(key): value
+ for key, value in msg_metadata.items()
+ if key in WorkflowNodeExecutionMetadataKey.__members__.values()
+ }
+ else:
+ msg_metadata = {}
+ llm_usage = LLMUsage.empty_usage()
+ agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
@@ -672,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
- json_output: list[dict[str, Any]] = []
+ json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index 55c8db40e..63e026034 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -469,12 +469,8 @@ class Node(Generic[NodeDataT]):
import core.workflow.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
- # Avoid importing modules that depend on the registry to prevent circular imports
- # e.g. node_factory imports node_mapping which builds the mapping here.
- if _modname in {
- "core.workflow.nodes.node_factory",
- "core.workflow.nodes.node_mapping",
- }:
+ # Avoid importing modules that depend on the registry to prevent circular imports.
+ if _modname == "core.workflow.nodes.node_mapping":
continue
importlib.import_module(_modname)
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index bb2140f42..925561cf7 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
text = ""
files: list[File] = []
- json: list[dict] = []
+ json: list[dict | list] = []
variables: dict[str, Any] = {}
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index 931c6113a..429f8411a 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
+from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
+ http_client: HttpClientProtocol = ssrf_proxy,
+ file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -104,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
+ self._http_client = http_client
+ self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@@ -200,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
- self.content = file_manager.download(file)
+ self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@@ -239,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
- file_manager.download(file),
+ self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@@ -332,19 +337,18 @@ class Executor:
do http request depending on api bundle
"""
_METHOD_MAP = {
- "get": ssrf_proxy.get,
- "head": ssrf_proxy.head,
- "post": ssrf_proxy.post,
- "put": ssrf_proxy.put,
- "delete": ssrf_proxy.delete,
- "patch": ssrf_proxy.patch,
+ "get": self._http_client.get,
+ "head": self._http_client.head,
+ "post": self._http_client.post,
+ "put": self._http_client.put,
+ "delete": self._http_client.delete,
+ "patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
- "url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
@@ -357,8 +361,12 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
- response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
- except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
+ response: httpx.Response = _METHOD_MAP[method_lc](
+ url=self.url,
+ **request_args,
+ max_retries=self.max_retries,
+ )
+ except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index 9bd1cb976..964e53e03 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -1,10 +1,11 @@
import logging
import mimetypes
-from collections.abc import Mapping, Sequence
-from typing import Any
+from collections.abc import Callable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
from configs import dify_config
-from core.file import File, FileTransferMethod
+from core.file import File, FileTransferMethod, file_manager
+from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ http_client: HttpClientProtocol = ssrf_proxy,
+ tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ file_manager: FileManagerProtocol = file_manager,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._http_client = http_client
+ self._tool_file_manager_factory = tool_file_manager_factory
+ self._file_manager = file_manager
+
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
+ http_client=self._http_client,
+ file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
- tool_file_manager = ToolFileManager()
+ tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index e5d86414c..ced996e7e 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -1,17 +1,15 @@
-import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
-from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
-from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -51,6 +48,7 @@ from .exc import (
)
if TYPE_CHECKING:
+ from core.workflow.context import IExecutionContext
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -240,7 +238,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
- dict[str, VariableUnion],
+ dict[str, Variable],
LLMUsage,
]
],
@@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._execute_single_iteration_parallel,
index=index,
item=item,
- flask_app=current_app._get_current_object(), # type: ignore
- context_vars=contextvars.copy_context(),
+ execution_context=self._capture_execution_context(),
)
future_to_index[future] = index
@@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self,
index: int,
item: object,
- flask_app: Flask,
- context_vars: contextvars.Context,
- ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
+ execution_context: "IExecutionContext",
+ ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
- with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
+ with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
@@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph_engine.graph_runtime_state.llm_usage,
)
+ def _capture_execution_context(self) -> "IExecutionContext":
+ """Capture current execution context for parallel iterations."""
+ from core.workflow.context import capture_current_context
+
+ return capture_current_context()
+
def _handle_iteration_success(
self,
started_at: datetime,
@@ -515,11 +517,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
- def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
+ def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
- def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
+ def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
@@ -586,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
+ from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
- from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 1f9fc8a11..07d05966c 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
+ from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
- from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py
new file mode 100644
index 000000000..e7dcf62fc
--- /dev/null
+++ b/api/core/workflow/nodes/protocols.py
@@ -0,0 +1,29 @@
+from typing import Protocol
+
+import httpx
+
+from core.file import File
+
+
+class HttpClientProtocol(Protocol):
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]: ...
+
+ @property
+ def request_error(self) -> type[Exception]: ...
+
+ def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+
+class FileManagerProtocol(Protocol):
+ def download(self, f: File, /) -> bytes: ...
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 36fc5078c..53c1b4ee6 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,4 +1,3 @@
-import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
+ # If no value provided, skip further processing for this key
+ if not value:
+ continue
+
+ if not isinstance(value, dict):
+ raise ValueError(f"JSON object for '{key}' must be an object")
+
+ # Overwrite with normalized dict to ensure downstream consistency
+ node_inputs[key] = value
+
+ # If schema exists, then validate against it
schema = variable.json_schema
if not schema:
continue
- if not value:
- continue
-
try:
- json_schema = json.loads(schema)
- except json.JSONDecodeError as e:
- raise ValueError(f"{schema} must be a valid JSON object")
-
- try:
- json_value = json.loads(value)
- except json.JSONDecodeError as e:
- raise ValueError(f"{value} must be a valid JSON object")
-
- try:
- Draft7Validator(json_schema).validate(json_value)
+ Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
- node_inputs[key] = json_value
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 2e7ec757b..68ac60e4f 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
text = ""
files: list[File] = []
- json: list[dict] = []
+ json: list[dict | list] = []
variables: dict[str, Any] = {}
@@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
message.message.metadata = dict_metadata
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
- json_output: list[dict[str, Any]] = []
+ json_output: list[dict[str, Any] | list[Any]] = []
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index d2ea7d94e..9f5818f4b 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_runtime_state=graph_runtime_state,
)
+ def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
+ """
+ Check if this Variable Assigner node blocks the output of specific variables.
+
+ Returns True if this node updates any of the requested conversation variables.
+ """
+ assigned_selector = tuple(self.node_data.assigned_variable_selector)
+ return assigned_selector in variable_selectors
+
@classmethod
def version(cls) -> str:
return "1"
@@ -64,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
- if not isinstance(original_variable, Variable):
+ if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index 486e6bb6a..5857702e7 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
@@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
- variable: Variable,
+ variable: VariableBase,
operation: Operation,
value: Any,
):
diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py
index 85ceb9d59..d205c6ac8 100644
--- a/api/core/workflow/runtime/variable_pool.py
+++ b/api/core/workflow/runtime/variable_pool.py
@@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
-from core.variables import Segment, SegmentGroup, Variable
+from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
-from core.variables.variables import RAGPipelineVariableInput, VariableUnion
+from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
- variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
+ variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables",
default_factory=SystemVariable.empty,
)
- environment_variables: Sequence[VariableUnion] = Field(
+ environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
- conversation_variables: Sequence[VariableUnion] = Field(
+ conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py
index ea0bdc353..7992785fe 100644
--- a/api/core/workflow/variable_loader.py
+++ b/api/core/workflow/variable_loader.py
@@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
-from core.variables import Variable
+from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
- :return: a list of Variable objects that match the provided selectors.
+ :return: a list of VariableBase objects that match the provided selectors.
"""
pass
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index ddf545bb3..c7bcc66c8 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -7,6 +7,7 @@ from typing import Any
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
@@ -136,13 +137,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
- node_config = workflow.get_node_config_by_id(node_id)
+ node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
- # Get node class
+ # Get node type
node_type = NodeType(node_config_data.get("type"))
- node_version = node_config_data.get("version", "1")
- node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -158,12 +157,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
- node = node_cls(
- id=str(uuid.uuid4()),
- config=node_config,
+ node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
+ node = node_factory.create_node(node_config)
+ node_cls = type(node)
try:
# variable selector to variable mapping
@@ -190,8 +189,7 @@ class WorkflowEntry:
)
try:
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
@@ -324,8 +322,7 @@ class WorkflowEntry:
tenant_id=tenant_id,
)
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
return node, generator
except Exception as e:
@@ -431,3 +428,26 @@ class WorkflowEntry:
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)
+
+ @staticmethod
+ def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]:
+ """
+ Wraps a node's run method with OpenTelemetry tracing and returns a generator.
+ """
+ # Wrap node.run() with ObservabilityLayer hooks to produce node-level spans
+ layer = ObservabilityLayer()
+ layer.on_graph_start()
+ node.ensure_execution_id()
+
+ def _gen():
+ error: Exception | None = None
+ layer.on_node_run_start(node)
+ try:
+ yield from node.run()
+ except Exception as exc:
+ error = exc
+ raise
+ finally:
+ layer.on_node_run_end(node, error)
+
+ return _gen()
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 305fb4c6d..216712359 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -164,6 +164,13 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
+ if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
+ # for saas only
+ imports.append("schedule.clean_workflow_runs_task")
+ beat_schedule["clean_workflow_runs_task"] = {
+ "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
+ "schedule": crontab(minute="0", hour="0"),
+ }
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index ac8967cd7..569568722 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -4,11 +4,15 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
+ archive_workflow_runs,
+ clean_expired_messages,
+ clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
+ delete_archived_workflow_runs,
extend_db,
extract_plugins,
extract_unique_plugins,
@@ -23,6 +27,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
+ restore_workflow_runs,
setup_datasource_oauth_client,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
@@ -57,6 +62,11 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
+ archive_workflow_runs,
+ delete_archived_workflow_runs,
+ restore_workflow_runs,
+ clean_workflow_runs,
+ clean_expired_messages,
extend_db,
]
for cmd in cmds_to_register:
diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py
new file mode 100644
index 000000000..0ef1513e1
--- /dev/null
+++ b/api/extensions/ext_fastopenapi.py
@@ -0,0 +1,43 @@
+from fastopenapi.routers import FlaskRouter
+from flask_cors import CORS
+
+from configs import dify_config
+from controllers.fastopenapi import console_router
+from dify_app import DifyApp
+from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
+
+DOCS_PREFIX = "/fastopenapi"
+
+
+def init_app(app: DifyApp) -> None:
+ docs_enabled = dify_config.SWAGGER_UI_ENABLED
+ docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
+ redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None
+ openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None
+
+ router = FlaskRouter(
+ app=app,
+ docs_url=docs_url,
+ redoc_url=redoc_url,
+ openapi_url=openapi_url,
+ openapi_version="3.0.0",
+ title="Dify API (FastOpenAPI PoC)",
+ version="1.0",
+ description="FastOpenAPI proof of concept for Dify API",
+ )
+
+ # Ensure route decorators are evaluated.
+ import controllers.console.ping as ping_module
+
+ _ = ping_module
+
+ router.include_router(console_router, prefix="/console/api")
+ CORS(
+ app,
+ resources={r"/console/api/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
+ supports_credentials=True,
+ allow_headers=list(AUTHENTICATED_HEADERS),
+ methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ expose_headers=list(EXPOSED_HEADERS),
+ )
+ app.extensions["fastopenapi"] = router
diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py
index 502f0bb46..cda2d1ad1 100644
--- a/api/extensions/ext_logstore.py
+++ b/api/extensions/ext_logstore.py
@@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
+from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
+ Logstore is considered enabled when:
+ 1. All required Aliyun SLS environment variables are set
+ 2. At least one repository configuration points to a logstore implementation
+
Returns:
- True if all required Aliyun SLS environment variables are set, False otherwise
+ True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
+ # Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
- all_set = all(os.environ.get(var) for var in required_vars)
+ sls_vars_set = all(os.environ.get(var) for var in required_vars)
- if not all_set:
- logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
+ if not sls_vars_set:
+ return False
- return all_set
+ # Check if any repository configuration points to logstore implementation
+ repository_configs = [
+ dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
+ dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_RUN_REPOSITORY,
+ ]
+
+ uses_logstore = any("logstore" in config.lower() for config in repository_configs)
+
+ if not uses_logstore:
+ return False
+
+ logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
+ return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
-
- This function:
- 1. Creates Aliyun SLS project if it doesn't exist
- 2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
- 3. Creates indexes with field configurations based on PostgreSQL table structures
-
- This operation is idempotent and only executes once during application startup.
+ If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
- logger.info("Initializing logstore...")
+ logger.info("Initializing Aliyun SLS Logstore...")
- # Create logstore client and initialize project/logstores/indexes
+ # Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
- # Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
+
except Exception:
- logger.exception("Failed to initialize logstore")
- # Don't raise - allow application to continue even if logstore init fails
- # This ensures that the application can still run if logstore is misconfigured
+ logger.exception(
+ "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
+ "Application will continue but logstore features will NOT work.",
+ os.environ.get("ALIYUN_SLS_ENDPOINT"),
+ os.environ.get("ALIYUN_SLS_REGION"),
+ os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
+ os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
+ )
+ # Don't raise - allow application to continue even if logstore setup fails
diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py
index 8c64a25be..f6a4765f1 100644
--- a/api/extensions/logstore/aliyun_logstore.py
+++ b/api/extensions/logstore/aliyun_logstore.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import os
+import socket
import threading
import time
from collections.abc import Sequence
@@ -179,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
- self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+ self.log_enabled: bool = (
+ os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+ or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
+ )
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
+ # Get timeout configuration
+ check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
+
+ # Pre-check endpoint connectivity to prevent indefinite hangs
+ self._check_endpoint_connectivity(self.endpoint, check_timeout)
+
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@@ -199,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
+ @staticmethod
+ def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
+ """
+ Check if the SLS endpoint is reachable before creating LogClient.
+ Prevents indefinite hangs when the endpoint is unreachable.
+
+ Args:
+ endpoint: SLS endpoint URL
+ timeout: Connection timeout in seconds
+
+ Raises:
+ ConnectionError: If endpoint is not reachable
+ """
+ # Parse endpoint URL to extract hostname and port
+ from urllib.parse import urlparse
+
+ parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
+ hostname = parsed_url.hostname
+ port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
+
+ if not hostname:
+ raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
+
+ sock = None
+ try:
+ # Create socket and set timeout
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(timeout)
+ sock.connect((hostname, port))
+ except Exception as e:
+ # Catch all exceptions and provide clear error message
+ error_type = type(e).__name__
+ raise ConnectionError(
+ f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
+ ) from e
+ finally:
+ # Ensure socket is properly closed
+ if sock:
+ try:
+ sock.close()
+ except Exception: # noqa: S110
+ pass # Ignore errors during cleanup
+
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@@ -220,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
- logger.info("Successfully connected to project %s using PG protocol", self.project_name)
+ logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
- logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
+ logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
- logger.warning(
- "Failed to establish PG connection for project %s: %s. Will use SDK mode.",
- self.project_name,
- str(e),
- )
+ logger.info("Using SDK mode for project %s", self.project_name)
+ logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@@ -246,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
- logger.info(
- "Attempting delayed PG connection for newly created project %s ...",
- self.project_name,
- )
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@@ -284,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
- logger.info(
- "Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
- self.project_name,
- self.__class__._pg_connection_delay,
- )
+ logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@@ -299,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
- logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@@ -318,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
- "Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
- "PG protocol requires scan_index to be enabled.",
+ "Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
+ self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@@ -748,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
- # Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@@ -770,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
- # Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@@ -845,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
- # Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@@ -853,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
- query,
- sql,
+ full_query,
)
try:
@@ -865,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
- # Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py
index 35aa51ce5..874c20d14 100644
--- a/api/extensions/logstore/aliyun_logstore_pg.py
+++ b/api/extensions/logstore/aliyun_logstore_pg.py
@@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
-import psycopg2.pool
-from psycopg2 import InterfaceError, OperationalError
+from sqlalchemy import create_engine
from configs import dify_config
@@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
- """
- PostgreSQL protocol support for Aliyun SLS LogStore.
-
- Handles PG connection pooling and operations for regions that support PG protocol.
- """
+ """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
- self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
+ self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
- """
- Check if a TCP port is reachable using socket connection.
-
- This provides a fast check before attempting full database connection,
- preventing long waits when connecting to unsupported regions.
-
- Args:
- host: Hostname or IP address
- port: Port number
- timeout: Connection timeout in seconds (default: 2.0)
-
- Returns:
- True if port is reachable, False otherwise
- """
+ """Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
- """
- Initialize PostgreSQL connection pool for SLS PG protocol support.
-
- Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
- _use_pg_protocol to True and creates a connection pool. If connection fails
- (region doesn't support PG protocol or other errors), returns False.
-
- Returns:
- True if PG protocol is supported and initialized, False otherwise
- """
+ """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
- # Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
- # Get pool configuration
- pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
+ # Pool configuration
+ pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
+ max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
+ pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
+ pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
- logger.debug(
- "Check PG protocol connection to SLS: host=%s, project=%s",
- pg_host,
- self.project_name,
- )
+ logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
- # Fast port connectivity check before attempting full connection
- # This prevents long waits when connecting to unsupported regions
+ # Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
- logger.info(
- "USE SDK mode for read/write operations, host=%s",
- pg_host,
- )
+ logger.debug("Using SDK mode for host=%s", pg_host)
return False
- # Create connection pool
- self._pg_pool = psycopg2.pool.SimpleConnectionPool(
- minconn=1,
- maxconn=pg_max_connections,
- host=pg_host,
- port=5432,
- database=self.project_name,
- user=self._access_key_id,
- password=self._access_key_secret,
- sslmode="require",
- connect_timeout=5,
- application_name=f"Dify-{dify_config.project.version}",
+ # Build connection URL
+ from urllib.parse import quote_plus
+
+ username = quote_plus(self._access_key_id)
+ password = quote_plus(self._access_key_secret)
+ database_url = (
+ f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
- # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
- # Connection pool creation success already indicates connectivity
+ # Create SQLAlchemy engine with connection pool
+ self._engine = create_engine(
+ database_url,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ pool_pre_ping=pool_pre_ping,
+ pool_timeout=30,
+ connect_args={
+ "connect_timeout": 5,
+ "application_name": f"Dify-{dify_config.project.version}-fixautocommit",
+ "keepalives": 1,
+ "keepalives_idle": 60,
+ "keepalives_interval": 10,
+ "keepalives_count": 5,
+ },
+ )
self._use_pg_protocol = True
logger.info(
- "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
+ "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
+ pool_size,
+ pool_recycle,
)
return True
except Exception as e:
- # PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
- if self._pg_pool:
+ if self._engine:
try:
- self._pg_pool.closeall()
+ self._engine.dispose()
except Exception:
- logger.debug("Failed to close PG connection pool during cleanup, ignoring")
- self._pg_pool = None
+ logger.debug("Failed to dispose engine during cleanup, ignoring")
+ self._engine = None
- logger.info(
- "PG protocol connection failed (region may not support PG protocol): %s. "
- "Falling back to SDK mode for read/write operations.",
- str(e),
- )
- return False
-
- def _is_connection_valid(self, conn: Any) -> bool:
- """
- Check if a connection is still valid.
-
- Args:
- conn: psycopg2 connection object
-
- Returns:
- True if connection is valid, False otherwise
- """
- try:
- # Check if connection is closed
- if conn.closed:
- return False
-
- # Quick ping test - execute a lightweight query
- # For SLS PG protocol, we can't use SELECT 1 without FROM,
- # so we just check the connection status
- with conn.cursor() as cursor:
- cursor.execute("SELECT 1")
- cursor.fetchone()
- return True
- except Exception:
+ logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
- """
- Context manager to get a PostgreSQL connection from the pool.
+ """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
+ if not self._engine:
+ raise RuntimeError("SQLAlchemy engine is not initialized")
- Automatically validates and refreshes stale connections.
-
- Note: Aliyun SLS PG protocol does not support transactions, so we always
- use autocommit mode.
-
- Yields:
- psycopg2 connection object
-
- Raises:
- RuntimeError: If PG pool is not initialized
- """
- if not self._pg_pool:
- raise RuntimeError("PG connection pool is not initialized")
-
- conn = self._pg_pool.getconn()
+ connection = self._engine.raw_connection()
try:
- # Validate connection and get a fresh one if needed
- if not self._is_connection_valid(conn):
- logger.debug("Connection is stale, marking as bad and getting a new one")
- # Mark connection as bad and get a new one
- self._pg_pool.putconn(conn, close=True)
- conn = self._pg_pool.getconn()
-
- # Aliyun SLS PG protocol does not support transactions, always use autocommit
- conn.autocommit = True
- yield conn
+ connection.autocommit = True # SLS PG protocol does not support transactions
+ yield connection
+ except Exception:
+ raise
finally:
- # Return connection to pool (or close if it's bad)
- if self._is_connection_valid(conn):
- self._pg_pool.putconn(conn)
- else:
- self._pg_pool.putconn(conn, close=True)
+ connection.close()
def close(self) -> None:
- """Close the PostgreSQL connection pool."""
- if self._pg_pool:
+ """Dispose SQLAlchemy engine and close all connections."""
+ if self._engine:
try:
- self._pg_pool.closeall()
- logger.info("PG connection pool closed")
+ self._engine.dispose()
+ logger.info("SQLAlchemy engine disposed")
except Exception:
- logger.exception("Failed to close PG connection pool")
+ logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
- """
- Check if an error is retriable (connection-related issues).
-
- Args:
- error: Exception to check
-
- Returns:
- True if the error is retriable, False otherwise
- """
- # Retry on connection-related errors
- if isinstance(error, (OperationalError, InterfaceError)):
+ """Check if error is retriable (connection-related issues)."""
+ # Check for psycopg2 connection errors directly
+ if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
- # Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
+ "operational error",
+ "interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
- """
- Write log to SLS using PostgreSQL protocol with automatic retry.
-
- Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
- writes with log_version field for versioning, same as SDK implementation.
-
- Args:
- logstore: Name of the logstore table
- contents: List of (field_name, value) tuples
- log_enabled: Whether to enable logging
-
- Raises:
- psycopg2.Error: If database operation fails after all retries
- """
+ """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
- # Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
-
- # Build INSERT statement with literal values
- # Note: Aliyun SLS PG protocol doesn't support parameterized queries,
- # so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
- # Retry configuration
max_retries = 3
- retry_delay = 0.1 # Start with 100ms
+ retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
- # Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
- # Success - exit retry loop
return
except psycopg2.Error as e:
- # Check if error is retriable
if not self._is_retriable_error(e):
- # Not a retriable error (e.g., data validation error), fail immediately
- logger.exception(
- "Failed to put logs to logstore %s via PG protocol (non-retriable error)",
- logstore,
- )
+ logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
- # Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
- "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
+ "Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
- retry_delay *= 2 # Exponential backoff
+ retry_delay *= 2
else:
- # Last attempt failed
- logger.exception(
- "Failed to put logs to logstore %s via PG protocol after %d attempts",
- logstore,
- max_retries,
- )
+ logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
- """
- Execute SQL query using PostgreSQL protocol with automatic retry.
-
- Args:
- sql: SQL query string
- logstore: Name of the logstore (for logging purposes)
- log_enabled: Whether to enable logging
-
- Returns:
- List of result rows as dictionaries
-
- Raises:
- psycopg2.Error: If database operation fails after all retries
- """
+ """Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
- # Retry configuration
max_retries = 3
- retry_delay = 0.1 # Start with 100ms
+ retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
-
- # Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
- # Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
- # Check if error is retriable
if not self._is_retriable_error(e):
- # Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
- "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
+ "Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
- # Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
- "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
+ "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
- retry_delay *= 2 # Exponential backoff
+ retry_delay *= 2
else:
- # Last attempt failed
logger.exception(
- "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
+ "Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
- # This line should never be reached due to raise above, but makes type checker happy
return []
diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py
index e69de29bb..b5a4fcf84 100644
--- a/api/extensions/logstore/repositories/__init__.py
+++ b/api/extensions/logstore/repositories/__init__.py
@@ -0,0 +1,29 @@
+"""
+LogStore repository utilities.
+"""
+
+from typing import Any
+
+
+def safe_float(value: Any, default: float = 0.0) -> float:
+ """
+ Safely convert a value to float, handling 'null' strings and None.
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ return default
+
+
+def safe_int(value: Any, default: int = 0) -> int:
+ """
+ Safely convert a value to int, handling 'null' strings and None.
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return int(float(value))
+ except (ValueError, TypeError):
+ return default
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
index 8c804d6bb..f67723630 100644
--- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
@@ -14,6 +14,8 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
- # Numeric fields with defaults
- model.index = int(data.get("index", 0))
- model.elapsed_time = float(data.get("elapsed_time", 0))
+ model.index = safe_int(data.get("index", 0))
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_id = escape_identifier(workflow_id)
+ escaped_node_id = escape_identifier(node_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE tenant_id = '{tenant_id}'
- AND app_id = '{app_id}'
- AND workflow_id = '{workflow_id}'
- AND node_id = '{node_id}'
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_id = '{escaped_workflow_id}'
+ AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
- f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE tenant_id = '{tenant_id}'
- AND app_id = '{app_id}'
- AND workflow_run_id = '{workflow_run_id}'
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
- query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
+ query = (
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_run_id: {escaped_workflow_run_id}"
+ )
from_time = 0
to_time = int(time.time()) # now
@@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_execution_id = escape_identifier(execution_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
- tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
+ if tenant_id:
+ escaped_tenant_id = escape_identifier(tenant_id)
+ tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
+ else:
+ tenant_filter = ""
+
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
+ WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
- query = f"id: {execution_id} and tenant_id: {tenant_id}"
+ query = (
+ f"id:{escape_logstore_query_value(execution_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
+ )
else:
- query = f"id: {execution_id}"
+ query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
index 252cdcc4d..14382ed87 100644
--- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
@@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
+- SQL injection prevention via parameter escaping
"""
import logging
@@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
- # Numeric fields with defaults
- model.total_tokens = int(data.get("total_tokens", 0))
- model.total_steps = int(data.get("total_steps", 0))
- model.exceptions_count = int(data.get("exceptions_count", 0))
+ model.total_tokens = safe_int(data.get("total_tokens", 0))
+ model.total_steps = safe_int(data.get("total_steps", 0))
+ model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
- model.elapsed_time = float(data.get("elapsed_time", 0))
+ # Use safe conversion to handle 'null' strings and None values
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
- if isinstance(triggered_from, WorkflowRunTriggeredFrom):
+ if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
- # Build triggered_from filter
- triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
- # Build status filter
- status_filter = f"AND status='{status}'" if status else ""
+ # Build triggered_from filter with escaped values
+ # Support both enum and string values for triggered_from
+ triggered_from_filter = " OR ".join(
+ [
+ f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
+ for tf in triggered_from_list
+ ]
+ )
+
+ # Build status filter with escaped value
+ status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
- WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
+ WHERE id = '{escaped_run_id}'
+ AND tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
- query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
+ query = (
+ f"id:{escape_logstore_query_value(run_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
+ f"and app_id:{escape_logstore_query_value(app_id)}"
+ )
from_time = 0
to_time = int(time.time()) # now
@@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
+ # Escape parameter to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
- WHERE id = '{run_id}' AND __time__ > 0
+ WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
- query = f"id: {run_id}"
+ # Note: Values must be quoted in LogStore query syntax
+ query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
# Build time range filter
time_filter = ""
if time_range:
@@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
+ escaped_status = escape_sql_string(status)
+
if status == "running":
# Running status requires window function
sql = f"""
@@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
- AND status='{status}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
+ # Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by
diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
index 1119534d5..9928879a7 100644
--- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
@@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
-def to_serializable(obj):
- """
- Convert non-JSON-serializable objects into JSON-compatible formats.
-
- - Uses `to_dict()` if it's a callable method.
- - Falls back to string representation.
- """
- if hasattr(obj, "to_dict") and callable(obj.to_dict):
- return obj.to_dict()
- return str(obj)
-
-
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
- self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
@@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
+ # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
+ json_converter = WorkflowRuntimeTypeConverter()
+
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
("version", domain_model.workflow_version),
(
"graph",
- json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
+ json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
- json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
+ json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"outputs",
- json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
+ json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),
diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
index 400a08951..4897171b1 100644
--- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
@@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
- index=int(data.get("index", 0)),
+ index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
- elapsed_time=float(data.get("elapsed_time", 0.0)),
+ elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
- self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
- For LogStore implementation, this is similar to save() since we always write
- complete records. We append a new record with updated data fields.
+ For LogStore implementation, this is a no-op for the LogStore write because save()
+ already writes all fields including inputs, process_data, and outputs. The caller
+ typically calls save() first to persist status/metadata, then calls save_execution_data()
+ to persist data fields. Since LogStore writes complete records atomically, we don't
+ need a separate write here to avoid duplicate records.
+
+ However, if dual-write is enabled, we still need to call the SQL repository's
+ save_execution_data() method to properly update the SQL database.
Args:
execution: The NodeExecution instance with data to save
"""
- logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
- # In LogStore, we simply write a new complete record with the data
- # The log_version timestamp will ensure this is treated as the latest version
- self.save(execution)
+ logger.debug(
+ "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
+ execution.id,
+ execution.node_execution_id,
+ )
+ # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
+ # Calling save() again would create a duplicate record in the append-only LogStore
+
+ # Dual-write to SQL database if enabled (for safe migration)
+ if self._enable_dual_write:
+ try:
+ self.sql_repository.save_execution_data(execution)
+ logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
+ except Exception:
+ logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
+ # Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
- Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
- This ensures we only get the final version of each node execution.
+ Uses LogStore SQL query with window function to get the latest version of each node execution.
+ This ensures we only get the most recent version of each node execution record.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
- This method filters by finished_at IS NOT NULL to avoid duplicates from
- version updates. For complete history including intermediate states,
- a different query strategy would be needed.
+ This method uses ROW_NUMBER() window function partitioned by node_execution_id
+ to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
- # Build SQL query with deduplication using finished_at IS NOT NULL
- # This optimization avoids window functions for common case where we only
- # want the final state of each node execution
+ # Build SQL query with deduplication using window function
+ # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
+ # ensures we get the latest version of each node execution
- # Build ORDER BY clause
+ # Escape parameters to prevent SQL injection
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+ escaped_tenant_id = escape_identifier(self._tenant_id)
+
+ # Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
- sql = f"""
- SELECT *
- FROM {AliyunLogStore.workflow_node_execution_logstore}
- WHERE workflow_run_id='{workflow_run_id}'
- AND tenant_id='{self._tenant_id}'
- AND finished_at IS NOT NULL
- """
-
+ # Build app_id filter for subquery
+ app_id_filter = ""
if self._app_id:
- sql += f" AND app_id='{self._app_id}'"
+ escaped_app_id = escape_identifier(self._app_id)
+ app_id_filter = f" AND app_id='{escaped_app_id}'"
+
+ # Use window function to get latest version of each node execution
+ sql = f"""
+ SELECT * FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
+ FROM {AliyunLogStore.workflow_node_execution_logstore}
+ WHERE workflow_run_id='{escaped_workflow_run_id}'
+ AND tenant_id='{escaped_tenant_id}'
+ {app_id_filter}
+ ) t
+ WHERE rn = 1
+ """
if order_clause:
sql += f" {order_clause}"
diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py
new file mode 100644
index 000000000..d88d6bd95
--- /dev/null
+++ b/api/extensions/logstore/sql_escape.py
@@ -0,0 +1,134 @@
+"""
+SQL Escape Utility for LogStore Queries
+
+This module provides escaping utilities to prevent injection attacks in LogStore queries.
+
+LogStore supports two query modes:
+1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
+2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
+
+Key Security Concerns:
+- Prevent tenant A from accessing tenant B's data via injection
+- SLS queries are read-only, so we focus on data access control
+- Different escaping strategies for SQL vs LogStore query syntax
+"""
+
+
+def escape_sql_string(value: str) -> str:
+ """
+ Escape a string value for safe use in SQL queries.
+
+ This function escapes single quotes by doubling them, which is the standard
+ SQL escaping method. This prevents SQL injection by ensuring that user input
+ cannot break out of string literals.
+
+ Args:
+ value: The string value to escape
+
+ Returns:
+ Escaped string safe for use in SQL queries
+
+ Examples:
+ >>> escape_sql_string("normal_value")
+ "normal_value"
+ >>> escape_sql_string("value' OR '1'='1")
+ "value'' OR ''1''=''1"
+ >>> escape_sql_string("tenant's_id")
+ "tenant''s_id"
+
+ Security:
+ - Prevents breaking out of string literals
+ - Stops injection attacks like: ' OR '1'='1
+ - Protects against cross-tenant data access
+ """
+ if not value:
+ return value
+
+ # Escape single quotes by doubling them (standard SQL escaping)
+ # This prevents breaking out of string literals in SQL queries
+ return value.replace("'", "''")
+
+
+def escape_identifier(value: str) -> str:
+ """
+ Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
+
+ This function is for PG protocol mode (SQL syntax).
+ For SDK mode, use escape_logstore_query_value() instead.
+
+ Args:
+ value: The identifier value to escape
+
+ Returns:
+ Escaped identifier safe for use in SQL queries
+
+ Examples:
+ >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
+ "550e8400-e29b-41d4-a716-446655440000"
+ >>> escape_identifier("tenant_id' OR '1'='1")
+ "tenant_id'' OR ''1''=''1"
+
+ Security:
+ - Prevents SQL injection via identifiers
+ - Stops cross-tenant access attempts
+ - Works for UUIDs, alphanumeric IDs, and similar identifiers
+ """
+ # For identifiers, use the same escaping as strings
+ # This is simple and effective for preventing injection
+ return escape_sql_string(value)
+
+
+def escape_logstore_query_value(value: str) -> str:
+ """
+ Escape value for LogStore query syntax (SDK mode).
+
+ LogStore query syntax rules:
+ 1. Keywords (and/or/not) are case-insensitive
+ 2. Single quotes are ordinary characters (no special meaning)
+ 3. Double quotes wrap values: key:"value"
+ 4. Backslash is the escape character:
+ - \" for double quote inside value
+ - \\ for backslash itself
+ 5. Parentheses can change query structure
+
+ To prevent injection:
+ - Wrap value in double quotes to treat special chars as literals
+ - Escape backslashes and double quotes using backslash
+
+ Args:
+ value: The value to escape for LogStore query syntax
+
+ Returns:
+ Quoted and escaped value safe for LogStore query syntax (includes the quotes)
+
+ Examples:
+ >>> escape_logstore_query_value("normal_value")
+ '"normal_value"'
+ >>> escape_logstore_query_value("value or field:evil")
+ '"value or field:evil"' # 'or' and ':' are now literals
+ >>> escape_logstore_query_value('value"test')
+ '"value\\"test"' # Internal double quote escaped
+ >>> escape_logstore_query_value('value\\test')
+ '"value\\\\test"' # Backslash escaped
+
+ Security:
+ - Prevents injection via and/or/not keywords
+ - Prevents injection via colons (:)
+ - Prevents injection via parentheses
+ - Protects against cross-tenant data access
+
+ Note:
+ Escape order is critical: backslash first, then double quotes.
+ Otherwise, we'd double-escape the escape character itself.
+ """
+ if not value:
+ return '""'
+
+ # IMPORTANT: Escape backslashes FIRST, then double quotes
+ # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
+ escaped = value.replace("\\", "\\\\") # \ -> \\
+ escaped = escaped.replace('"', '\\"') # " -> \"
+
+ # Wrap in double quotes to treat as literal string
+ # This prevents and/or/not/:/() from being interpreted as operators
+ return f'"{escaped}"'
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index bd71f18af..0be836c8f 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -115,7 +115,18 @@ def build_from_mappings(
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
- valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
+ def is_valid_mapping(m: Mapping[str, Any]) -> bool:
+ if not m or not m.get("transfer_method"):
+ return False
+ # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None
+ transfer_method = m.get("transfer_method")
+ if transfer_method == FileTransferMethod.REMOTE_URL:
+ url = m.get("url") or m.get("remote_url")
+ if not url:
+ return False
+ return True
+
+ valid_mappings = [m for m in mappings if is_valid_mapping(m)]
files = [
build_from_mapping(
mapping=mapping,
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index 494194369..3f030ae12 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
- Variable,
+ VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
-def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
-def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
-def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
-def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
+def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
- result: Variable
+ result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
- return cast(Variable, result)
+ return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
-) -> Variable:
- if isinstance(segment, Variable):
+) -> VariableBase:
+ if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
- Variable,
+ VariableBase,
variable_class(
id=id,
name=name,
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index 2bba198fa..c81e482f7 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import datetime
from typing import TypeAlias
+from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -20,8 +21,8 @@ class SimpleFeedback(ResponseModel):
class RetrieverResource(ResponseModel):
- id: str
- message_id: str
+ id: str = Field(default_factory=lambda: str(uuid4()))
+ message_id: str = Field(default_factory=lambda: str(uuid4()))
position: int
dataset_id: str | None = None
dataset_name: str | None = None
diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index 0ebc03a98..ae7035632 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -2,7 +2,12 @@ from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
-from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields
+from fields.workflow_run_fields import (
+ build_workflow_run_for_archived_log_model,
+ build_workflow_run_for_log_model,
+ workflow_run_for_archived_log_fields,
+ workflow_run_for_log_fields,
+)
from libs.helper import TimestampField
workflow_app_log_partial_fields = {
@@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Namespace):
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
+workflow_archived_log_partial_fields = {
+ "id": fields.String,
+ "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True),
+ "trigger_metadata": fields.Raw,
+ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
+ "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
+ "created_at": TimestampField,
+}
+
+
+def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
+ """Build the workflow archived log partial model for the API or Namespace."""
+ workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
+ simple_account_model = build_simple_account_model(api_or_ns)
+ simple_end_user_model = build_simple_end_user_model(api_or_ns)
+
+ copied_fields = workflow_archived_log_partial_fields.copy()
+ copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
+ copied_fields["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+ )
+ copied_fields["created_by_end_user"] = fields.Nested(
+ simple_end_user_model, attribute="created_by_end_user", allow_null=True
+ )
+ return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)
+
+
workflow_app_log_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer,
@@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
copied_fields = workflow_app_log_pagination_fields.copy()
copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model))
return api_or_ns.model("WorkflowAppLogPagination", copied_fields)
+
+
+workflow_archived_log_pagination_fields = {
+ "page": fields.Integer,
+ "limit": fields.Integer,
+ "total": fields.Integer,
+ "has_more": fields.Boolean,
+ "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)),
+}
+
+
+def build_workflow_archived_log_pagination_model(api_or_ns: Namespace):
+ """Build the workflow archived log pagination model for the API or Namespace."""
+ workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns)
+
+ copied_fields = workflow_archived_log_pagination_fields.copy()
+ copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model))
+ return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields)
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index d037b0c44..2755f77f6 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
-from core.variables import SecretVariable, SegmentType, Variable
+from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 476025064..35bb442c5 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -23,6 +23,19 @@ def build_workflow_run_for_log_model(api_or_ns: Namespace):
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
+workflow_run_for_archived_log_fields = {
+ "id": fields.String,
+ "status": fields.String,
+ "triggered_from": fields.String,
+ "elapsed_time": fields.Float,
+ "total_tokens": fields.Integer,
+}
+
+
+def build_workflow_run_for_archived_log_model(api_or_ns: Namespace):
+ return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields)
+
+
workflow_run_for_list_fields = {
"id": fields.String,
"version": fields.String,
diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py
index f84d22644..66b57ac66 100644
--- a/api/libs/archive_storage.py
+++ b/api/libs/archive_storage.py
@@ -7,7 +7,6 @@ to S3-compatible object storage.
import base64
import datetime
-import gzip
import hashlib
import logging
from collections.abc import Generator
@@ -39,7 +38,7 @@ class ArchiveStorage:
"""
S3-compatible storage client for archiving or exporting.
- This client provides methods for storing and retrieving archived data in JSONL+gzip format.
+ This client provides methods for storing and retrieving archived data in JSONL format.
"""
def __init__(self, bucket: str):
@@ -69,7 +68,10 @@ class ArchiveStorage:
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
region_name=dify_config.ARCHIVE_STORAGE_REGION,
- config=Config(s3={"addressing_style": "path"}),
+ config=Config(
+ s3={"addressing_style": "path"},
+ max_pool_connections=64,
+ ),
)
# Verify bucket accessibility
@@ -100,12 +102,18 @@ class ArchiveStorage:
"""
checksum = hashlib.md5(data).hexdigest()
try:
- self.client.put_object(
+ response = self.client.put_object(
Bucket=self.bucket,
Key=key,
Body=data,
ContentMD5=self._content_md5(data),
)
+ etag = response.get("ETag")
+ if not etag:
+ raise ArchiveStorageError(f"Missing ETag for '{key}'")
+ normalized_etag = etag.strip('"')
+ if normalized_etag != checksum:
+ raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}")
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
return checksum
except ClientError as e:
@@ -240,19 +248,18 @@ class ArchiveStorage:
return base64.b64encode(hashlib.md5(data).digest()).decode()
@staticmethod
- def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
+ def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes:
"""
- Serialize records to gzipped JSONL format.
+ Serialize records to JSONL format.
Args:
records: List of dictionaries to serialize
Returns:
- Gzipped JSONL bytes
+ JSONL bytes
"""
lines = []
for record in records:
- # Convert datetime objects to ISO format strings
serialized = ArchiveStorage._serialize_record(record)
lines.append(orjson.dumps(serialized))
@@ -260,23 +267,22 @@ class ArchiveStorage:
if jsonl_content:
jsonl_content += b"\n"
- return gzip.compress(jsonl_content)
+ return jsonl_content
@staticmethod
- def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
+ def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]:
"""
- Deserialize gzipped JSONL data to records.
+ Deserialize JSONL data to records.
Args:
- data: Gzipped JSONL bytes
+ data: JSONL bytes
Returns:
List of dictionaries
"""
- jsonl_content = gzip.decompress(data)
records = []
- for line in jsonl_content.splitlines():
+ for line in data.splitlines():
if line:
records.append(orjson.loads(line))
diff --git a/api/libs/smtp.py b/api/libs/smtp.py
index 4044c6f7e..6f82f1440 100644
--- a/api/libs/smtp.py
+++ b/api/libs/smtp.py
@@ -3,6 +3,8 @@ import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
+from configs import dify_config
+
logger = logging.getLogger(__name__)
@@ -19,20 +21,21 @@ class SMTPClient:
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
- smtp = None
+ smtp: smtplib.SMTP | None = None
+ local_host = dify_config.SMTP_LOCAL_HOSTNAME
try:
- if self.use_tls:
- if self.opportunistic_tls:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
- # Send EHLO command with the HELO domain name as the server address
- smtp.ehlo(self.server)
- smtp.starttls()
- # Resend EHLO command to identify the TLS session
- smtp.ehlo(self.server)
- else:
- smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
+ if self.use_tls and not self.opportunistic_tls:
+ # SMTP with SSL (implicit TLS)
+ smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10, local_hostname=local_host)
else:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
+ # Plain SMTP or SMTP with STARTTLS (explicit TLS)
+ smtp = smtplib.SMTP(self.server, self.port, timeout=10, local_hostname=local_host)
+
+ assert smtp is not None
+ if self.use_tls and self.opportunistic_tls:
+ smtp.ehlo(self.server)
+ smtp.starttls()
+ smtp.ehlo(self.server)
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():
diff --git a/api/libs/workspace_permission.py b/api/libs/workspace_permission.py
new file mode 100644
index 000000000..dd42a7fac
--- /dev/null
+++ b/api/libs/workspace_permission.py
@@ -0,0 +1,74 @@
+"""
+Workspace permission helper functions.
+
+These helpers check both billing/plan level and workspace-specific policy level permissions.
+Checks are performed at two levels:
+1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
+2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
+"""
+
+import logging
+
+from werkzeug.exceptions import Forbidden
+
+from configs import dify_config
+from services.enterprise.enterprise_service import EnterpriseService
+from services.feature_service import FeatureService
+
+logger = logging.getLogger(__name__)
+
+
+def check_workspace_member_invite_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows member invitations at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - For future expansion (currently no plan-level restriction)
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits member invitations
+ """
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_member_invite:
+ raise Forbidden("Workspace policy prohibits member invitations")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace invite permission for %s", workspace_id)
+
+
+def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows owner transfer at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - SANDBOX plan blocks owner transfer
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits ownership transfer
+ """
+ features = FeatureService.get_features(workspace_id)
+ if not features.is_allow_transfer_workspace:
+ raise Forbidden("Your current plan does not allow workspace ownership transfer")
+
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_owner_transfer:
+ raise Forbidden("Workspace policy prohibits ownership transfer")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace transfer permission for %s", workspace_id)
diff --git a/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py
new file mode 100644
index 000000000..624be1d07
--- /dev/null
+++ b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py
@@ -0,0 +1,60 @@
+"""make message annotation question not nullable
+
+Revision ID: 9e6fa5cbcd80
+Revises: 03f8dcbc611e
+Create Date: 2025-11-06 16:03:54.549378
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '9e6fa5cbcd80'
+down_revision = '288345cd01d1'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ bind = op.get_bind()
+ message_annotations = sa.table(
+ "message_annotations",
+ sa.column("id", sa.String),
+ sa.column("message_id", sa.String),
+ sa.column("question", sa.Text),
+ )
+ messages = sa.table(
+ "messages",
+ sa.column("id", sa.String),
+ sa.column("query", sa.Text),
+ )
+ update_question_from_message = (
+ sa.update(message_annotations)
+ .where(
+ sa.and_(
+ message_annotations.c.question.is_(None),
+ message_annotations.c.message_id.isnot(None),
+ )
+ )
+ .values(
+ question=sa.select(sa.func.coalesce(messages.c.query, ""))
+ .where(messages.c.id == message_annotations.c.message_id)
+ .scalar_subquery()
+ )
+ )
+ bind.execute(update_question_from_message)
+
+ fill_remaining_questions = (
+ sa.update(message_annotations)
+ .where(message_annotations.c.question.is_(None))
+ .values(question="")
+ )
+ bind.execute(fill_remaining_questions)
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False)
+
+
+def downgrade():
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True)
diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
new file mode 100644
index 000000000..7e0cc8ec9
--- /dev/null
+++ b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
@@ -0,0 +1,30 @@
+"""add workflow_run_created_at_id_idx
+
+Revision ID: 905527cc8fd3
+Revises: 7df29de0f6be
+Create Date: 2025-01-09 16:30:02.462084
+
+"""
+from alembic import op
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '905527cc8fd3'
+down_revision = '7df29de0f6be'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_run_created_at_id_idx')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
new file mode 100644
index 000000000..758369ba9
--- /dev/null
+++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
@@ -0,0 +1,33 @@
+"""feat: add created_at id index to messages
+
+Revision ID: 3334862ee907
+Revises: 905527cc8fd3
+Create Date: 2026-01-12 17:29:44.846544
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '3334862ee907'
+down_revision = '905527cc8fd3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.drop_index('message_created_at_id_idx')
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
new file mode 100644
index 000000000..2e1af0c83
--- /dev/null
+++ b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
@@ -0,0 +1,35 @@
+"""change workflow node execution workflow_run index
+
+Revision ID: 288345cd01d1
+Revises: 3334862ee907
+Create Date: 2026-01-16 17:15:00.000000
+
+"""
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "288345cd01d1"
+down_revision = "3334862ee907"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_id_idx",
+ ["workflow_run_id"],
+ unique=False,
+ )
+
+
+def downgrade():
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_id_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_idx",
+ ["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"],
+ unique=False,
+ )
diff --git a/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py
new file mode 100644
index 000000000..b99ca04e3
--- /dev/null
+++ b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py
@@ -0,0 +1,73 @@
+"""add table explore banner and trial
+
+Revision ID: f9f6d18a37f9
+Revises: 9e6fa5cbcd80
+Create Date: 2026-01-017 11:10:18.079355
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'f9f6d18a37f9'
+down_revision = '9e6fa5cbcd80'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('account_trial_app_records',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('account_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('count', sa.Integer(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
+ sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
+ )
+ with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
+ batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
+ batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
+
+ op.create_table('exporle_banners',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('content', sa.JSON(), nullable=False),
+ sa.Column('link', sa.String(length=255), nullable=False),
+ sa.Column('sort', sa.Integer(), nullable=False),
+ sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
+ )
+ op.create_table('trial_apps',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('trial_limit', sa.Integer(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
+ sa.UniqueConstraint('app_id', name='unique_trail_app_id')
+ )
+ with op.batch_alter_table('trial_apps', schema=None) as batch_op:
+ batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
+ batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('trial_apps', schema=None) as batch_op:
+ batch_op.drop_index('trial_app_tenant_id_idx')
+ batch_op.drop_index('trial_app_app_id_idx')
+
+ op.drop_table('trial_apps')
+ op.drop_table('exporle_banners')
+ with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
+ batch_op.drop_index('account_trial_app_record_app_id_idx')
+ batch_op.drop_index('account_trial_app_record_account_id_idx')
+
+ op.drop_table('account_trial_app_records')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py
new file mode 100644
index 000000000..5e7298af5
--- /dev/null
+++ b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py
@@ -0,0 +1,95 @@
+"""create workflow_archive_logs
+
+Revision ID: 9d77545f524e
+Revises: f9f6d18a37f9
+Create Date: 2026-01-06 17:18:56.292479
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+def _is_pg(conn):
+ return conn.dialect.name == "postgresql"
+
+# revision identifiers, used by Alembic.
+revision = '9d77545f524e'
+down_revision = 'f9f6d18a37f9'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ conn = op.get_bind()
+ if _is_pg(conn):
+ op.create_table('workflow_archive_logs',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
+ sa.Column('log_id', models.types.StringUUID(), nullable=True),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_by_role', sa.String(length=255), nullable=False),
+ sa.Column('created_by', models.types.StringUUID(), nullable=False),
+ sa.Column('log_created_at', sa.DateTime(), nullable=True),
+ sa.Column('log_created_from', sa.String(length=255), nullable=True),
+ sa.Column('run_version', sa.String(length=255), nullable=False),
+ sa.Column('run_status', sa.String(length=255), nullable=False),
+ sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
+ sa.Column('run_error', models.types.LongText(), nullable=True),
+ sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('run_created_at', sa.DateTime(), nullable=False),
+ sa.Column('run_finished_at', sa.DateTime(), nullable=True),
+ sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
+ sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
+ )
+ else:
+ op.create_table('workflow_archive_logs',
+ sa.Column('id', models.types.StringUUID(), nullable=False),
+ sa.Column('log_id', models.types.StringUUID(), nullable=True),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_by_role', sa.String(length=255), nullable=False),
+ sa.Column('created_by', models.types.StringUUID(), nullable=False),
+ sa.Column('log_created_at', sa.DateTime(), nullable=True),
+ sa.Column('log_created_from', sa.String(length=255), nullable=True),
+ sa.Column('run_version', sa.String(length=255), nullable=False),
+ sa.Column('run_status', sa.String(length=255), nullable=False),
+ sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
+ sa.Column('run_error', models.types.LongText(), nullable=True),
+ sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('run_created_at', sa.DateTime(), nullable=False),
+ sa.Column('run_finished_at', sa.DateTime(), nullable=True),
+ sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
+ sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
+ )
+ with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
+ batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False)
+ batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False)
+ batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False)
+
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_archive_log_workflow_run_id_idx')
+ batch_op.drop_index('workflow_archive_log_run_created_at_idx')
+ batch_op.drop_index('workflow_archive_log_app_idx')
+
+ op.drop_table('workflow_archive_logs')
+ # ### end Alembic commands ###
diff --git a/api/models/__init__.py b/api/models/__init__.py
index 0ed7a4d79..c8f143ccf 100644
--- a/api/models/__init__.py
+++ b/api/models/__init__.py
@@ -38,6 +38,7 @@ from .enums import (
WorkflowTriggerStatus,
)
from .model import (
+ AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@@ -50,6 +51,7 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
+ ExporleBanner,
IconType,
InstalledApp,
Message,
@@ -65,6 +67,7 @@ from .model import (
TagBinding,
TenantCreditPool,
TraceAppConfig,
+ TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -104,6 +107,7 @@ from .workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
+ WorkflowArchiveLog,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
WorkflowNodeExecutionTriggeredFrom,
@@ -118,6 +122,7 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
+ "AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@@ -154,6 +159,7 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
+ "ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@@ -193,6 +199,7 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
+ "TrialApp",
"TriggerOAuthSystemClient",
"TriggerOAuthTenantClient",
"TriggerSubscription",
@@ -202,6 +209,7 @@ __all__ = [
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
+ "WorkflowArchiveLog",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 445ac6086..62f11b8c7 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1149,7 +1149,7 @@ class DatasetCollectionBinding(TypeBase):
)
-class TidbAuthBinding(Base):
+class TidbAuthBinding(TypeBase):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@@ -1158,7 +1158,13 @@ class TidbAuthBinding(Base):
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1166,7 +1172,9 @@ class TidbAuthBinding(Base):
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
- created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
class Whitelist(TypeBase):
diff --git a/api/models/model.py b/api/models/model.py
index e09960acc..0d796bfbf 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -327,40 +327,48 @@ class AppStatisticsExtend(db.Model):
number = db.Column(db.Integer, nullable=False, default=0)
-class AppModelConfig(Base):
+class AppModelConfig(TypeBase):
__tablename__ = "app_model_configs"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- app_id = mapped_column(StringUUID, nullable=False)
- provider = mapped_column(String(255), nullable=True)
- model_id = mapped_column(String(255), nullable=True)
- configs = mapped_column(sa.JSON, nullable=True)
- created_by = mapped_column(StringUUID, nullable=True)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = mapped_column(StringUUID, nullable=True)
- updated_at = mapped_column(
- sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ configs: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True, default=None)
+ created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
- opening_statement = mapped_column(LongText)
- suggested_questions = mapped_column(LongText)
- suggested_questions_after_answer = mapped_column(LongText)
- speech_to_text = mapped_column(LongText)
- text_to_speech = mapped_column(LongText)
- more_like_this = mapped_column(LongText)
- model = mapped_column(LongText)
- user_input_form = mapped_column(LongText)
- dataset_query_variable = mapped_column(String(255))
- pre_prompt = mapped_column(LongText)
- agent_mode = mapped_column(LongText)
- sensitive_word_avoidance = mapped_column(LongText)
- retriever_resource = mapped_column(LongText)
- prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'"))
- chat_prompt_config = mapped_column(LongText)
- completion_prompt_config = mapped_column(LongText)
- dataset_configs = mapped_column(LongText)
- external_data_tools = mapped_column(LongText)
- file_upload = mapped_column(LongText)
+ updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ updated_at: Mapped[datetime] = mapped_column(
+ sa.DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
+ )
+ opening_statement: Mapped[str | None] = mapped_column(LongText, default=None)
+ suggested_questions: Mapped[str | None] = mapped_column(LongText, default=None)
+ suggested_questions_after_answer: Mapped[str | None] = mapped_column(LongText, default=None)
+ speech_to_text: Mapped[str | None] = mapped_column(LongText, default=None)
+ text_to_speech: Mapped[str | None] = mapped_column(LongText, default=None)
+ more_like_this: Mapped[str | None] = mapped_column(LongText, default=None)
+ model: Mapped[str | None] = mapped_column(LongText, default=None)
+ user_input_form: Mapped[str | None] = mapped_column(LongText, default=None)
+ dataset_query_variable: Mapped[str | None] = mapped_column(String(255), default=None)
+ pre_prompt: Mapped[str | None] = mapped_column(LongText, default=None)
+ agent_mode: Mapped[str | None] = mapped_column(LongText, default=None)
+ sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None)
+ retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None)
+ prompt_type: Mapped[str] = mapped_column(
+ String(255), nullable=False, server_default=sa.text("'simple'"), default="simple"
+ )
+ chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
+ completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
+ dataset_configs: Mapped[str | None] = mapped_column(LongText, default=None)
+ external_data_tools: Mapped[str | None] = mapped_column(LongText, default=None)
+ file_upload: Mapped[str | None] = mapped_column(LongText, default=None)
@property
def app(self) -> App | None:
@@ -641,6 +649,64 @@ class InstalledApp(TypeBase):
return tenant
+class TrialApp(Base):
+ __tablename__ = "trial_apps"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
+ sa.Index("trial_app_app_id_idx", "app_id"),
+ sa.Index("trial_app_tenant_id_idx", "tenant_id"),
+ sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
+ )
+
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+ trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
+
+ @property
+ def app(self) -> App | None:
+ app = db.session.query(App).where(App.id == self.app_id).first()
+ return app
+
+
+class AccountTrialAppRecord(Base):
+ __tablename__ = "account_trial_app_records"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
+ sa.Index("account_trial_app_record_account_id_idx", "account_id"),
+ sa.Index("account_trial_app_record_app_id_idx", "app_id"),
+ sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
+ )
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ account_id = mapped_column(StringUUID, nullable=False)
+ app_id = mapped_column(StringUUID, nullable=False)
+ count = mapped_column(sa.Integer, nullable=False, default=0)
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+
+ @property
+ def app(self) -> App | None:
+ app = db.session.query(App).where(App.id == self.app_id).first()
+ return app
+
+ @property
+ def user(self) -> Account | None:
+ user = db.session.query(Account).where(Account.id == self.account_id).first()
+ return user
+
+
+class ExporleBanner(Base):
+ __tablename__ = "exporle_banners"
+ __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ content = mapped_column(sa.JSON, nullable=False)
+ link = mapped_column(String(255), nullable=False)
+ sort = mapped_column(sa.Integer, nullable=False)
+ status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+ language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
+
+
class OAuthProviderApp(TypeBase):
"""
Globally shared OAuth provider app information.
@@ -790,8 +856,8 @@ class Conversation(Base):
override_model_configs = json.loads(self.override_model_configs)
if "model" in override_model_configs:
- app_model_config = AppModelConfig()
- app_model_config = app_model_config.from_model_config_dict(override_model_configs)
+ # where is app_id?
+ app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
@@ -1026,6 +1092,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
+ Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -1480,7 +1547,7 @@ class MessageAnnotation(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
- question: Mapped[str | None] = mapped_column(LongText, nullable=True)
+ question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1505,7 +1572,7 @@ class MessageAnnotation(Base):
return account
-class AppAnnotationHitHistory(Base):
+class AppAnnotationHitHistory(TypeBase):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1515,17 +1582,19 @@ class AppAnnotationHitHistory(Base):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- app_id = mapped_column(StringUUID, nullable=False)
+ id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- source = mapped_column(LongText, nullable=False)
- question = mapped_column(LongText, nullable=False)
- account_id = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
- message_id = mapped_column(StringUUID, nullable=False)
- annotation_question = mapped_column(LongText, nullable=False)
- annotation_content = mapped_column(LongText, nullable=False)
+ source: Mapped[str] = mapped_column(LongText, nullable=False)
+ question: Mapped[str] = mapped_column(LongText, nullable=False)
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+ score: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ annotation_question: Mapped[str] = mapped_column(LongText, nullable=False)
+ annotation_content: Mapped[str] = mapped_column(LongText, nullable=False)
@property
def account(self):
@@ -1901,7 +1970,7 @@ class MessageChain(TypeBase):
)
-class MessageAgentThought(Base):
+class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1909,34 +1978,42 @@ class MessageAgentThought(Base):
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- message_id = mapped_column(StringUUID, nullable=False)
- message_chain_id = mapped_column(StringUUID, nullable=True)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
- thought = mapped_column(LongText, nullable=True)
- tool = mapped_column(LongText, nullable=True)
- tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_input = mapped_column(LongText, nullable=True)
- observation = mapped_column(LongText, nullable=True)
+ created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
- tool_process_data = mapped_column(LongText, nullable=True)
- message = mapped_column(LongText, nullable=True)
- message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- message_unit_price = mapped_column(sa.Numeric, nullable=True)
- message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- message_files = mapped_column(LongText, nullable=True)
- answer = mapped_column(LongText, nullable=True)
- answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- answer_unit_price = mapped_column(sa.Numeric, nullable=True)
- answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- total_price = mapped_column(sa.Numeric, nullable=True)
- currency = mapped_column(String(255), nullable=True)
- latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
- created_by_role = mapped_column(String(255), nullable=False)
- created_by = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
+ tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ message_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ answer_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
+ )
@property
def files(self) -> list[Any]:
@@ -2133,7 +2210,7 @@ class TraceAppConfig(TypeBase):
}
-class TenantCreditPool(Base):
+class TenantCreditPool(TypeBase):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
@@ -2141,14 +2218,20 @@ class TenantCreditPool(Base):
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
- id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
- tenant_id = mapped_column(StringUUID, nullable=False)
- pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
- quota_limit = mapped_column(BigInteger, nullable=False, default=0)
- quota_used = mapped_column(BigInteger, nullable=False, default=0)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
- updated_at = mapped_column(
- sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
+ quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), init=False
+ )
+ updated_at: Mapped[datetime] = mapped_column(
+ sa.DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
)
@property
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 8db14c079..330fc9c0f 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1,12 +1,10 @@
-from __future__ import annotations
-
import json
import logging
import uuid
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -47,7 +45,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
-from core.variables import SecretVariable, Segment, SegmentType, Variable
+from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@@ -70,7 +68,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
- def value_of(cls, value: str) -> WorkflowType:
+ def value_of(cls, value: str) -> "WorkflowType":
"""
Get value of given mode.
@@ -83,7 +81,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
- def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
+ def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
"""
Get workflow type from app mode.
@@ -179,12 +177,12 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
- ) -> Workflow:
+ ) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@@ -229,8 +227,7 @@ class Workflow(Base): # bug
#
# Currently, the following functions / methods would mutate the returned dict:
#
- # - `_get_graph_and_variable_pool_of_single_iteration`.
- # - `_get_graph_and_variable_pool_of_single_loop`.
+ # - `_get_graph_and_variable_pool_for_single_node_run`.
return json.loads(self.graph) if self.graph else {}
def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
@@ -448,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
- var: Variable,
+ var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@@ -464,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
- def environment_variables(self, value: Sequence[Variable]):
+ def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@@ -488,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
- def encrypt_func(var: Variable) -> Variable:
+ def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -518,7 +515,7 @@ class Workflow(Base): # bug
return result
@property
- def conversation_variables(self) -> Sequence[Variable]:
+ def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@@ -528,7 +525,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
- def conversation_variables(self, value: Sequence[Variable]):
+ def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@@ -598,6 +595,7 @@ class WorkflowRun(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
+ sa.Index("workflow_run_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -622,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
- pause: Mapped[WorkflowPause | None] = orm.relationship(
+ pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@@ -692,7 +690,7 @@ class WorkflowRun(Base):
}
@classmethod
- def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
+ def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -783,11 +781,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
- "workflow_node_execution_workflow_run_idx",
- "tenant_id",
- "app_id",
- "workflow_id",
- "triggered_from",
+ "workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
@@ -844,7 +838,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
- offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
+ offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@@ -854,13 +848,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
- query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
+ query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
- query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
+ query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@@ -935,7 +929,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
- def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
+ def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@@ -1049,7 +1043,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
- file: Mapped[UploadFile | None] = orm.relationship(
+ file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1067,7 +1061,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
- def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
+ def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
"""
Get value of given mode.
@@ -1216,6 +1210,69 @@ class WorkflowAppLog(TypeBase):
}
+class WorkflowArchiveLog(TypeBase):
+ """
+ Workflow archive log.
+
+ Stores essential workflow run snapshot data for archived app logs.
+
+ Field sources:
+ - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency.
+ - log_* fields: from WorkflowAppLog when present; null if the run has no app log.
+ - run_* fields: workflow run snapshot fields from WorkflowRun.
+ - trigger_metadata: snapshot from WorkflowTriggerLog when present.
+ """
+
+ __tablename__ = "workflow_archive_logs"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"),
+ sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"),
+ sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"),
+ sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"),
+ )
+
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
+
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+
+ log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
+ log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+ log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
+
+ run_version: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_status: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_error: Mapped[str | None] = mapped_column(LongText, nullable=True)
+ run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
+ run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
+ run_total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
+ run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
+ run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+ run_exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
+
+ trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True)
+ archived_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+
+ @property
+ def workflow_run_summary(self) -> dict[str, Any]:
+ return {
+ "id": self.workflow_run_id,
+ "status": self.run_status,
+ "triggered_from": self.run_triggered_from,
+ "elapsed_time": self.run_elapsed_time,
+ "total_tokens": self.run_total_tokens,
+ }
+
+
class ConversationVariable(TypeBase):
__tablename__ = "workflow_conversation_variables"
@@ -1231,7 +1288,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
- def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
+ def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@@ -1240,7 +1297,7 @@ class ConversationVariable(TypeBase):
)
return obj
- def to_variable(self) -> Variable:
+ def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@@ -1384,7 +1441,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
- variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
+ variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1554,7 +1611,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@@ -1577,7 +1634,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@@ -1598,7 +1655,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@@ -1621,7 +1678,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=node_id,
@@ -1717,7 +1774,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
- upload_file: Mapped[UploadFile] = orm.relationship(
+ upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@@ -1784,7 +1841,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
- workflow_run: Mapped[WorkflowRun] = orm.relationship(
+ workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@@ -1840,7 +1897,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
- def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
+ def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 07f132999..9e641503d 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
-version = "1.11.2"
+version = "1.11.4"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -31,7 +31,7 @@ dependencies = [
"gunicorn~=23.0.0",
"httpx[socks]~=0.27.0",
"jieba==0.42.1",
- "json-repair>=0.41.1",
+ "json-repair>=0.55.1",
"jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
@@ -93,6 +93,7 @@ dependencies = [
"weaviate-client==4.17.0",
"apscheduler>=3.11.0",
"weave>=0.52.16",
+ "fastopenapi[flask]>=0.7.0",
##### start extend ######
"tokenizers~=0.22.0",
"validators>=0.34.0",
@@ -197,7 +198,7 @@ storage = [
"opendal~=0.46.0",
"oss2==2.18.5",
"supabase~=2.18.1",
- "tos~=2.7.1",
+ "tos~=2.9.0",
]
############################################################
diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json
index f295d2fb1..658296b2b 100644
--- a/api/pyrightconfig.json
+++ b/api/pyrightconfig.json
@@ -8,6 +8,7 @@
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [
+ "fastopenapi",
"flask_restx",
"flask_login",
"opentelemetry.instrumentation.celery",
diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py
index fa2c94b62..5b3f63530 100644
--- a/api/repositories/api_workflow_node_execution_repository.py
+++ b/api/repositories/api_workflow_node_execution_repository.py
@@ -13,8 +13,10 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
@@ -130,6 +132,18 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
"""
...
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions and offloads for the given workflow run ids.
+ """
+ ...
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions and offloads for the given workflow run ids.
+ """
+ ...
+
def delete_executions_by_app(
self,
tenant_id: str,
@@ -195,3 +209,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
The number of executions deleted
"""
...
+
+ def get_offloads_by_execution_ids(
+ self,
+ session: Session,
+ node_execution_ids: Sequence[str],
+ ) -> Sequence[WorkflowNodeExecutionOffload]:
+ """
+ Get offload records by node execution IDs.
+
+ This method retrieves workflow node execution offload records
+ that belong to the given node execution IDs.
+
+ Args:
+ session: The database session to use
+ node_execution_ids: List of node execution IDs to filter by
+
+ Returns:
+ A sequence of WorkflowNodeExecutionOffload instances
+ """
+ ...
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
index fd547c78b..1d3954571 100644
--- a/api/repositories/api_workflow_run_repository.py
+++ b/api/repositories/api_workflow_run_repository.py
@@ -34,15 +34,18 @@ Example:
```
"""
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.entities.pause_reason import PauseReason
+from core.workflow.enums import WorkflowType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowRun
+from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
AverageInteractionStats,
@@ -253,6 +256,151 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
...
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Fetch ended workflow runs in a time window for archival and clean batching.
+ """
+ ...
+
+ def get_archived_run_ids(
+ self,
+ session: Session,
+ run_ids: Sequence[str],
+ ) -> set[str]:
+ """
+ Fetch workflow run IDs that already have archive log records.
+ """
+ ...
+
+ def get_archived_logs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowArchiveLog]:
+ """
+ Fetch archived workflow logs by time range for restore.
+ """
+ ...
+
+ def get_archived_log_by_run_id(
+ self,
+ run_id: str,
+ ) -> WorkflowArchiveLog | None:
+ """
+ Fetch a workflow archive log by workflow run ID.
+ """
+ ...
+
+ def delete_archive_log_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> int:
+ """
+ Delete archive log by workflow run ID.
+
+ Used after restoring a workflow run to remove the archive log record,
+ allowing the run to be archived again if needed.
+
+ Args:
+ session: Database session
+ run_id: Workflow run ID
+
+ Returns:
+ Number of records deleted (0 or 1)
+ """
+ ...
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Delete workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons).
+ """
+ ...
+
+ def get_pause_records_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowPause]:
+ """
+ Fetch workflow pause records by workflow run ID.
+ """
+ ...
+
+ def get_pause_reason_records_by_run_id(
+ self,
+ session: Session,
+ pause_ids: Sequence[str],
+ ) -> Sequence[WorkflowPauseReason]:
+ """
+ Fetch workflow pause reason records by pause IDs.
+ """
+ ...
+
+ def get_app_logs_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowAppLog]:
+ """
+ Fetch workflow app logs by workflow run ID.
+ """
+ ...
+
+ def create_archive_logs(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ app_logs: Sequence[WorkflowAppLog],
+ trigger_metadata: str | None,
+ ) -> int:
+ """
+ Create archive log records for a workflow run.
+ """
+ ...
+
+ def get_archived_runs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Return workflow runs that already have archive logs, for cleanup of `workflow_runs`.
+ """
+ ...
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Count workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons) without deleting data.
+ """
+ ...
+
def create_workflow_pause(
self,
workflow_run_id: str,
diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
index 7e2173acd..b19cc73bd 100644
--- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
@@ -9,11 +9,14 @@ from collections.abc import Sequence
from datetime import datetime
from typing import cast
-from sqlalchemy import asc, delete, desc, select
+from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import (
+ WorkflowNodeExecutionModel,
+ WorkflowNodeExecutionOffload,
+)
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -290,3 +293,85 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
result = cast(CursorResult, session.execute(stmt))
session.commit()
return result.rowcount
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+
+ offloads_deleted = (
+ cast(
+ CursorResult,
+ session.execute(
+ delete(WorkflowNodeExecutionOffload).where(
+ WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
+ )
+ ),
+ ).rowcount
+ or 0
+ )
+
+ node_executions_deleted = (
+ cast(
+ CursorResult,
+ session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)),
+ ).rowcount
+ or 0
+ )
+
+ return node_executions_deleted, offloads_deleted
+
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+
+ node_executions_count = (
+ session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0
+ )
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+ offloads_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowNodeExecutionOffload)
+ .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
+ )
+ or 0
+ )
+
+ return int(node_executions_count), int(offloads_count)
+
+ @staticmethod
+ def get_by_run(
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Fetch node executions for a run using workflow_run_id.
+ """
+ stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def get_offloads_by_execution_ids(
+ self,
+ session: Session,
+ node_execution_ids: Sequence[str],
+ ) -> Sequence[WorkflowNodeExecutionOffload]:
+ if not node_execution_ids:
+ return []
+
+ stmt = select(WorkflowNodeExecutionOffload).where(
+ WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
+ )
+ return list(session.scalars(stmt))
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index b172c6a3a..d5214be04 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -21,7 +21,7 @@ Implementation Notes:
import logging
import uuid
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast
@@ -32,7 +32,7 @@ from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
-from core.workflow.enums import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import convert_datetime_to_date
@@ -40,8 +40,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowPauseReason, WorkflowRun
+from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
@@ -314,6 +313,335 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
return total_deleted
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Fetch ended workflow runs in a time window for archival and clean batching.
+
+ Query scope:
+ - created_at in [start_from, end_before)
+ - type in run_types (when provided)
+ - status is an ended state
+ - optional tenant_id filter and cursor (last_seen) for pagination
+ """
+ with self._session_maker() as session:
+ stmt = (
+ select(WorkflowRun)
+ .where(
+ WorkflowRun.created_at < end_before,
+ WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
+ )
+ .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
+ .limit(batch_size)
+ )
+ if run_types is not None:
+ if not run_types:
+ return []
+ stmt = stmt.where(WorkflowRun.type.in_(run_types))
+
+ if start_from:
+ stmt = stmt.where(WorkflowRun.created_at >= start_from)
+
+ if tenant_ids:
+ stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids))
+
+ if last_seen:
+ stmt = stmt.where(
+ or_(
+ WorkflowRun.created_at > last_seen[0],
+ and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
+ )
+ )
+
+ return session.scalars(stmt).all()
+
+ def get_archived_run_ids(
+ self,
+ session: Session,
+ run_ids: Sequence[str],
+ ) -> set[str]:
+ if not run_ids:
+ return set()
+
+ stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids))
+ return set(session.scalars(stmt).all())
+
+ def get_archived_log_by_run_id(
+ self,
+ run_id: str,
+ ) -> WorkflowArchiveLog | None:
+ with self._session_maker() as session:
+ stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1)
+ return session.scalar(stmt)
+
+ def delete_archive_log_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> int:
+ stmt = delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id)
+ result = session.execute(stmt)
+ return cast(CursorResult, result).rowcount or 0
+
+ def get_pause_records_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowPause]:
+ stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def get_pause_reason_records_by_run_id(
+ self,
+ session: Session,
+ pause_ids: Sequence[str],
+ ) -> Sequence[WorkflowPauseReason]:
+ if not pause_ids:
+ return []
+
+ stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ return list(session.scalars(stmt))
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if delete_node_executions:
+ node_executions_deleted, offloads_deleted = delete_node_executions(session, runs)
+ else:
+ node_executions_deleted, offloads_deleted = 0, 0
+
+ app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))
+ app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0
+
+ pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
+ pause_ids = session.scalars(pause_stmt).all()
+ pause_reasons_deleted = 0
+ pauses_deleted = 0
+
+ if pause_ids:
+ pause_reasons_result = session.execute(
+ delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0
+ pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids)))
+ pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
+
+ trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
+
+ runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
+ runs_deleted = cast(CursorResult, runs_result).rowcount or 0
+
+ session.commit()
+
+ return {
+ "runs": runs_deleted,
+ "node_executions": node_executions_deleted,
+ "offloads": offloads_deleted,
+ "app_logs": app_logs_deleted,
+ "trigger_logs": trigger_logs_deleted,
+ "pauses": pauses_deleted,
+ "pause_reasons": pause_reasons_deleted,
+ }
+
+ def get_app_logs_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowAppLog]:
+ stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def create_archive_logs(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ app_logs: Sequence[WorkflowAppLog],
+ trigger_metadata: str | None,
+ ) -> int:
+ if not app_logs:
+ archive_log = WorkflowArchiveLog(
+ log_id=None,
+ log_created_at=None,
+ log_created_from=None,
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_id=run.workflow_id,
+ workflow_run_id=run.id,
+ created_by_role=run.created_by_role,
+ created_by=run.created_by,
+ run_version=run.version,
+ run_status=run.status,
+ run_triggered_from=run.triggered_from,
+ run_error=run.error,
+ run_elapsed_time=run.elapsed_time,
+ run_total_tokens=run.total_tokens,
+ run_total_steps=run.total_steps,
+ run_created_at=run.created_at,
+ run_finished_at=run.finished_at,
+ run_exceptions_count=run.exceptions_count,
+ trigger_metadata=trigger_metadata,
+ )
+ session.add(archive_log)
+ return 1
+
+ archive_logs = [
+ WorkflowArchiveLog(
+ log_id=app_log.id,
+ log_created_at=app_log.created_at,
+ log_created_from=app_log.created_from,
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_id=run.workflow_id,
+ workflow_run_id=run.id,
+ created_by_role=run.created_by_role,
+ created_by=run.created_by,
+ run_version=run.version,
+ run_status=run.status,
+ run_triggered_from=run.triggered_from,
+ run_error=run.error,
+ run_elapsed_time=run.elapsed_time,
+ run_total_tokens=run.total_tokens,
+ run_total_steps=run.total_steps,
+ run_created_at=run.created_at,
+ run_finished_at=run.finished_at,
+ run_exceptions_count=run.exceptions_count,
+ trigger_metadata=trigger_metadata,
+ )
+ for app_log in app_logs
+ ]
+ session.add_all(archive_logs)
+ return len(archive_logs)
+
+ def get_archived_runs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Retrieves WorkflowRun records by joining workflow_archive_logs.
+
+ Used to identify runs that are already archived and ready for deletion.
+ """
+ stmt = (
+ select(WorkflowRun)
+ .join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id)
+ .where(
+ WorkflowArchiveLog.run_created_at >= start_date,
+ WorkflowArchiveLog.run_created_at < end_date,
+ )
+ .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
+ .limit(limit)
+ )
+ if tenant_ids:
+ stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
+ return list(session.scalars(stmt))
+
+ def get_archived_logs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowArchiveLog]:
+ # Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted.
+ stmt = (
+ select(WorkflowArchiveLog)
+ .where(
+ WorkflowArchiveLog.run_created_at >= start_date,
+ WorkflowArchiveLog.run_created_at < end_date,
+ )
+ .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
+ .limit(limit)
+ )
+ if tenant_ids:
+ stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
+ return list(session.scalars(stmt))
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if count_node_executions:
+ node_executions_count, offloads_count = count_node_executions(session, runs)
+ else:
+ node_executions_count, offloads_count = 0, 0
+
+ app_logs_count = (
+ session.scalar(
+ select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))
+ )
+ or 0
+ )
+
+ pause_ids = session.scalars(
+ select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
+ ).all()
+ pauses_count = len(pause_ids)
+ pause_reasons_count = 0
+ if pause_ids:
+ pause_reasons_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowPauseReason)
+ .where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ or 0
+ )
+
+ trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0
+
+ return {
+ "runs": len(runs),
+ "node_executions": node_executions_count,
+ "offloads": offloads_count,
+ "app_logs": int(app_logs_count),
+ "trigger_logs": trigger_logs_count,
+ "pauses": pauses_count,
+ "pause_reasons": int(pause_reasons_count),
+ }
+
def create_workflow_pause(
self,
workflow_run_id: str,
@@ -340,9 +668,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
ValueError: If workflow_run_id is invalid or workflow run doesn't exist
RuntimeError: If workflow is already paused or in invalid state
"""
- previous_pause_model_query = select(WorkflowPauseModel).where(
- WorkflowPauseModel.workflow_run_id == workflow_run_id
- )
+ previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)
with self._session_maker() as session, session.begin():
# Get the workflow run
workflow_run = session.get(WorkflowRun, workflow_run_id)
@@ -367,7 +693,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Upload the state file
# Create the pause record
- pause_model = WorkflowPauseModel()
+ pause_model = WorkflowPause()
pause_model.id = str(uuidv7())
pause_model.workflow_id = workflow_run.workflow_id
pause_model.workflow_run_id = workflow_run.id
@@ -539,13 +865,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"""
with self._session_maker() as session, session.begin():
# Get the pause model by ID
- pause_model = session.get(WorkflowPauseModel, pause_entity.id)
+ pause_model = session.get(WorkflowPause, pause_entity.id)
if pause_model is None:
raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")
self._delete_pause_model(session, pause_model)
@staticmethod
- def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
+ def _delete_pause_model(session: Session, pause_model: WorkflowPause):
storage.delete(pause_model.state_object_key)
# Delete the pause record
@@ -580,15 +906,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
_limit: int = limit or 1000
pruned_record_ids: list[str] = []
cond = or_(
- WorkflowPauseModel.created_at < expiration,
+ WorkflowPause.created_at < expiration,
and_(
- WorkflowPauseModel.resumed_at.is_not(null()),
- WorkflowPauseModel.resumed_at < resumption_expiration,
+ WorkflowPause.resumed_at.is_not(null()),
+ WorkflowPause.resumed_at < resumption_expiration,
),
)
# First, collect pause records to delete with their state files
# Expired pauses (created before expiration time)
- stmt = select(WorkflowPauseModel).where(cond).limit(_limit)
+ stmt = select(WorkflowPause).where(cond).limit(_limit)
with self._session_maker(expire_on_commit=False) as session:
# Old resumed pauses (resumed more than resumption_duration ago)
@@ -599,7 +925,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Delete state files from storage
for pause in pauses_to_delete:
with self._session_maker(expire_on_commit=False) as session, session.begin():
- # todo: this issues a separate query for each WorkflowPauseModel record.
+ # todo: this issues a separate query for each WorkflowPause record.
# consider batching this lookup.
try:
storage.delete(pause.state_object_key)
@@ -851,7 +1177,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
def __init__(
self,
*,
- pause_model: WorkflowPauseModel,
+ pause_model: WorkflowPause,
reason_models: Sequence[WorkflowPauseReason],
human_input_form: Sequence = (),
) -> None:
diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
index 0d67e286b..f3dc4cd60 100644
--- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
+++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
@@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
+from typing import cast
-from sqlalchemy import and_, select
+from sqlalchemy import and_, delete, func, select
+from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from models.enums import WorkflowTriggerStatus
@@ -44,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
return self.session.scalar(query)
+ def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]:
+ """List trigger logs for a workflow run."""
+ query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id)
+ return list(self.session.scalars(query).all())
+
def get_failed_for_retry(
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> Sequence[WorkflowTriggerLog]:
@@ -84,3 +91,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
)
return list(self.session.scalars(query).all())
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows deleted.
+ """
+ if not run_ids:
+ return 0
+
+ result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)))
+ return cast(CursorResult, result).rowcount or 0
+
+ def count_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Count trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows matched.
+ """
+ if not run_ids:
+ return 0
+
+ count = self.session.scalar(
+ select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
+ )
+ return int(count or 0)
diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py
index 138b8779a..b0009e398 100644
--- a/api/repositories/workflow_trigger_log_repository.py
+++ b/api/repositories/workflow_trigger_log_repository.py
@@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol):
A sequence of recent WorkflowTriggerLog instances
"""
...
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs for workflow run IDs.
+
+ Args:
+ run_ids: Workflow run IDs to delete
+
+ Returns:
+ Number of rows deleted
+ """
+ ...
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 352a84b59..be5f483b9 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -1,90 +1,78 @@
-import datetime
import logging
import time
import click
-from sqlalchemy.exc import SQLAlchemyError
+from redis.exceptions import LockError
import app
from configs import dify_config
-from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
-from models.model import (
- App,
- Message,
- MessageAgentThought,
- MessageAnnotation,
- MessageChain,
- MessageFeedback,
- MessageFile,
-)
-from models.web import SavedMessage
-from services.feature_service import FeatureService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
logger = logging.getLogger(__name__)
-@app.celery.task(queue="dataset")
+@app.celery.task(queue="retention")
def clean_messages():
- click.echo(click.style("Start clean messages.", fg="green"))
- start_at = time.perf_counter()
- plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
- days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
- )
- while True:
- try:
- # Main query with join and filter
- messages = (
- db.session.query(Message)
- .where(Message.created_at < plan_sandbox_clean_message_day)
- .order_by(Message.created_at.desc())
- .limit(100)
- .all()
- )
+ """
+ Clean expired messages based on clean policy.
- except SQLAlchemyError:
- raise
- if not messages:
- break
- for message in messages:
- app = db.session.query(App).filter_by(id=message.app_id).first()
- if not app:
- logger.warning(
- "Expected App record to exist, but none was found, app_id=%s, message_id=%s",
- message.app_id,
- message.id,
- )
- continue
- features_cache_key = f"features:{app.tenant_id}"
- plan_cache = redis_client.get(features_cache_key)
- if plan_cache is None:
- features = FeatureService.get_features(app.tenant_id)
- redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
- plan = features.billing.subscription.plan
- else:
- plan = plan_cache.decode()
- if plan == CloudPlan.SANDBOX:
- # clean related message
- db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(Message).where(Message.id == message.id).delete()
- db.session.commit()
- end_at = time.perf_counter()
- click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
+ This task uses MessagesCleanService to efficiently clean messages in batches.
+ The behavior depends on BILLING_ENABLED configuration:
+ - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period)
+ - BILLING_ENABLED=False: delete all messages within the time range
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ policy = create_message_clean_policy(
+ graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
+ )
+
+ # Create and run the cleanup service
+ # lock the task to avoid concurrent execution in case of the future data volume growth
+ with redis_client.lock(
+ "retention:clean_messages", timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, blocking=False
+ ):
+ service = MessagesCleanService.from_days(
+ policy=policy,
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except LockError:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages: acquire task lock failed, skip current execution")
+ click.echo(
+ click.style(
+ f"clean_messages: skipped (lock already held) - latency: {end_at - start_at:.2f}s",
+ fg="yellow",
+ )
+ )
+ raise
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py
new file mode 100644
index 000000000..ff45a3ddf
--- /dev/null
+++ b/api/schedule/clean_workflow_runs_task.py
@@ -0,0 +1,79 @@
+import logging
+from datetime import UTC, datetime
+
+import click
+from redis.exceptions import LockError
+
+import app
+from configs import dify_config
+from extensions.ext_redis import redis_client
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
+
+logger = logging.getLogger(__name__)
+
+
+@app.celery.task(queue="retention")
+def clean_workflow_runs_task() -> None:
+ """
+ Scheduled cleanup for workflow runs and related records (sandbox tenants only).
+ """
+ click.echo(
+ click.style(
+ (
+ "Scheduled workflow run cleanup starting: "
+ f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, "
+ f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}"
+ ),
+ fg="green",
+ )
+ )
+
+ start_time = datetime.now(UTC)
+
+ try:
+ # lock the task to avoid concurrent execution in case of the future data volume growth
+ with redis_client.lock(
+ "retention:clean_workflow_runs_task",
+ timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL,
+ blocking=False,
+ ):
+ WorkflowRunCleanup(
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ start_from=None,
+ end_before=None,
+ ).run()
+
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+ except LockError:
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ logger.exception("clean_workflow_runs_task: acquire task lock failed, skip current execution")
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup skipped (lock already held). "
+ f"start={start_time.isoformat()} end={end_time.isoformat()} duration={elapsed}",
+ fg="yellow",
+ )
+ )
+ raise
+ except Exception as e:
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ logger.exception("clean_workflow_runs_task failed")
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup failed. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed} - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
index c343063fa..ed46c1c70 100644
--- a/api/schedule/create_tidb_serverless_task.py
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -50,10 +50,13 @@ def create_clusters(batch_size):
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
+ tenant_id=None,
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
+ active=False,
+ status="CREATING",
)
db.session.add(tidb_auth_binding)
db.session.commit()
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 556154bf0..d14963e35 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -749,6 +749,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
+ @staticmethod
+ def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
+ """
+ Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
+
+ This keeps backward compatibility for older records that stored uppercase emails while the
+ rest of the system gradually normalizes new inputs.
+ """
+ query_session = session or db.session
+ account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ if account or email == email.lower():
+ return account
+
+ return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
+
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@@ -1382,16 +1397,27 @@ class RegisterService:
if not inviter:
raise ValueError("Inviter is required")
+ normalized_email = email.lower()
+
"""Invite new member"""
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
with Session(db.engine) as session:
- account = session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
- name = email.split("@")[0]
+ name = normalized_email.split("@")[0]
account = cls.register(
- email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
+ email=normalized_email,
+ name=name,
+ language=language,
+ status=AccountStatus.PENDING,
+ is_setup=True,
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
@@ -1413,7 +1439,7 @@ class RegisterService:
# send email
send_invite_member_mail_task.delay(
language=language,
- to=email,
+ to=account.email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
@@ -1512,6 +1538,16 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
+ @classmethod
+ def get_invitation_with_case_fallback(
+ cls, workspace_id: str | None, email: str | None, token: str
+ ) -> dict[str, Any] | None:
+ invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
+ if invitation or not email or email == email.lower():
+ return invitation
+ normalized_email = email.lower()
+ return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
+
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index b73302508..56e9cc6a0 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -209,8 +209,12 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
+ question = args.get("question")
+ if question is None:
+ raise ValueError("'question' is required")
+
annotation = MessageAnnotation(
- app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
+ app_id=app.id, content=args["answer"], question=question, account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
@@ -219,7 +223,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
- args["question"],
+ question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@@ -244,8 +248,12 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
+ question = args.get("question")
+ if question is None:
+ raise ValueError("'question' is required")
+
annotation.content = args["answer"]
- annotation.question = args["question"]
+ annotation.question = question
db.session.commit()
# if annotation reply is enabled , add annotation to index
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index acd2a25a8..da22464d3 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -521,12 +521,10 @@ class AppDslService:
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
- app_model_config = AppModelConfig().from_model_config_dict(model_config)
+ app_model_config = AppModelConfig(
+ app_id=app.id, created_by=account.id, updated_by=account.id
+ ).from_model_config_dict(model_config)
app_model_config.id = str(uuid4())
- app_model_config.app_id = app.id
- app_model_config.created_by = account.id
- app_model_config.updated_by = account.id
-
app.app_model_config_id = app_model_config.id
self._session.add(app_model_config)
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 31f6a4960..7ab57d5a4 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -169,10 +169,9 @@ class AppService:
db.session.flush()
if default_model_config:
- app_model_config = AppModelConfig(**default_model_config)
- app_model_config.app_id = app.id
- app_model_config.created_by = account.id
- app_model_config.updated_by = account.id
+ app_model_config = AppModelConfig(
+ **default_model_config, app_id=app.id, created_by=account.id, updated_by=account.id
+ )
db.session.add(app_model_config)
db.session.flush()
diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py
index acc0ec2b2..92008d5ff 100644
--- a/api/services/conversation_variable_updater.py
+++ b/api/services/conversation_variable_updater.py
@@ -1,7 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
-from core.variables.variables import Variable
+from core.variables.variables import VariableBase
from models import ConversationVariable
@@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
- def update(self, conversation_id: str, variable: Variable) -> None:
+ def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 18e561343..be9a0e927 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -13,10 +13,11 @@ import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
-from werkzeug.exceptions import NotFound
+from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
@@ -73,6 +74,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
+from services.file_service import FileService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
@@ -1162,6 +1164,7 @@ class DocumentService:
Document.archived.is_(True),
),
}
+ DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION = ".zip"
@classmethod
def normalize_display_status(cls, status: str | None) -> str | None:
@@ -1288,6 +1291,143 @@ class DocumentService:
else:
return None
+ @staticmethod
+ def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]:
+ """Fetch documents for a dataset in a single batch query."""
+ if not document_ids:
+ return []
+ document_id_list: list[str] = [str(document_id) for document_id in document_ids]
+ # Fetch all requested documents in one query to avoid N+1 lookups.
+ documents: Sequence[Document] = db.session.scalars(
+ select(Document).where(
+ Document.dataset_id == dataset_id,
+ Document.id.in_(document_id_list),
+ )
+ ).all()
+ return documents
+
+ @staticmethod
+ def get_document_download_url(document: Document) -> str:
+ """
+ Return a signed download URL for an upload-file document.
+ """
+ upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
+ return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
+
+ @staticmethod
+ def prepare_document_batch_download_zip(
+ *,
+ dataset_id: str,
+ document_ids: Sequence[str],
+ tenant_id: str,
+ current_user: Account,
+ ) -> tuple[list[UploadFile], str]:
+ """
+ Resolve upload files for batch ZIP downloads and generate a client-visible filename.
+ """
+ dataset = DatasetService.get_dataset(dataset_id)
+ if not dataset:
+ raise NotFound("Dataset not found.")
+ try:
+ DatasetService.check_dataset_permission(dataset, current_user)
+ except NoPermissionError as e:
+ raise Forbidden(str(e))
+
+ upload_files_by_document_id = DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset_id,
+ document_ids=document_ids,
+ tenant_id=tenant_id,
+ )
+ upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids]
+ download_name = DocumentService._generate_document_batch_download_zip_filename()
+ return upload_files, download_name
+
+ @staticmethod
+ def _generate_document_batch_download_zip_filename() -> str:
+ """
+ Generate a random attachment filename for the batch download ZIP.
+ """
+ return f"{uuid.uuid4().hex}{DocumentService.DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION}"
+
+ @staticmethod
+ def _get_upload_file_id_for_upload_file_document(
+ document: Document,
+ *,
+ invalid_source_message: str,
+ missing_file_message: str,
+ ) -> str:
+ """
+ Normalize and validate `Document -> UploadFile` linkage for download flows.
+ """
+ if document.data_source_type != "upload_file":
+ raise NotFound(invalid_source_message)
+
+ data_source_info: dict[str, Any] = document.data_source_info_dict or {}
+ upload_file_id: str | None = data_source_info.get("upload_file_id")
+ if not upload_file_id:
+ raise NotFound(missing_file_message)
+
+ return str(upload_file_id)
+
+ @staticmethod
+ def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile:
+ """
+ Load the `UploadFile` row for an upload-file document.
+ """
+ upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="Document does not have an uploaded file to download.",
+ missing_file_message="Uploaded file not found.",
+ )
+ upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
+ upload_file = upload_files_by_id.get(upload_file_id)
+ if not upload_file:
+ raise NotFound("Uploaded file not found.")
+ return upload_file
+
+ @staticmethod
+ def _get_upload_files_by_document_id_for_zip_download(
+ *,
+ dataset_id: str,
+ document_ids: Sequence[str],
+ tenant_id: str,
+ ) -> dict[str, UploadFile]:
+ """
+ Batch load upload files keyed by document id for ZIP downloads.
+ """
+ document_id_list: list[str] = [str(document_id) for document_id in document_ids]
+
+ documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list)
+ documents_by_id: dict[str, Document] = {str(document.id): document for document in documents}
+
+ missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys())
+ if missing_document_ids:
+ raise NotFound("Document not found.")
+
+ upload_file_ids: list[str] = []
+ upload_file_ids_by_document_id: dict[str, str] = {}
+ for document_id, document in documents_by_id.items():
+ if document.tenant_id != tenant_id:
+ raise Forbidden("No permission.")
+
+ upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="Only uploaded-file documents can be downloaded as ZIP.",
+ missing_file_message="Only uploaded-file documents can be downloaded as ZIP.",
+ )
+ upload_file_ids.append(upload_file_id)
+ upload_file_ids_by_document_id[document_id] = upload_file_id
+
+ upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
+ missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
+ if missing_upload_file_ids:
+ raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")
+
+ return {
+ document_id: upload_files_by_id[upload_file_id]
+ for document_id, upload_file_id in upload_file_ids_by_document_id.items()
+ }
+
@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py
index bdc960aa2..e3832475a 100644
--- a/api/services/enterprise/base.py
+++ b/api/services/enterprise/base.py
@@ -1,9 +1,14 @@
+import logging
import os
from collections.abc import Mapping
from typing import Any
import httpx
+from core.helper.trace_id_helper import generate_traceparent_header
+
+logger = logging.getLogger(__name__)
+
class BaseRequest:
proxies: Mapping[str, str] | None = {
@@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
+
+ try:
+ # ensure traceparent even when OTEL is disabled
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+ except Exception:
+ logger.debug("Failed to generate traceparent header", exc_info=True)
+
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()
diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py
index c0cc0e523..a5133dfcb 100644
--- a/api/services/enterprise/enterprise_service.py
+++ b/api/services/enterprise/enterprise_service.py
@@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
)
+class WorkspacePermission(BaseModel):
+ workspace_id: str = Field(
+ description="The ID of the workspace.",
+ alias="workspaceId",
+ )
+ allow_member_invite: bool = Field(
+ description="Whether to allow members to invite new members to the workspace.",
+ default=False,
+ alias="allowMemberInvite",
+ )
+ allow_owner_transfer: bool = Field(
+ description="Whether to allow owners to transfer ownership of the workspace.",
+ default=False,
+ alias="allowOwnerTransfer",
+ )
+
+
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -44,6 +61,16 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
+ class WorkspacePermissionService:
+ @classmethod
+ def get_permission(cls, workspace_id: str):
+ if not workspace_id:
+ raise ValueError("workspace_id must be provided.")
+ data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
+ if not data or "permission" not in data:
+ raise ValueError("No data found.")
+ return WorkspacePermission.model_validate(data["permission"])
+
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index 5de9aa459..5f3cdc042 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -180,6 +180,8 @@ class SystemFeatureModel(BaseModel):
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
+ enable_trial_app: bool = False
+ enable_explore_banner: bool = False
is_custom_auth2: str = "" # extend: Customizing AUTH2
is_custom_auth2_logout: str = "" # extend: Customizing AUTH2
ding_talk_client_id: str = "" # extend: DingTalk third-party login
@@ -215,7 +217,7 @@ class FeatureService:
return knowledge_rate_limit
@classmethod
- def get_system_features(cls) -> SystemFeatureModel:
+ def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
system_features = SystemFeatureModel()
# extend start: oauth2
# 检查是否有请求上下文(在 Celery worker 中可能没有)
@@ -237,7 +239,7 @@ class FeatureService:
system_features.webapp_auth.enabled = True
system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
- cls._fulfill_params_from_enterprise(system_features)
+ cls._fulfill_params_from_enterprise(system_features, is_authenticated)
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
@@ -252,6 +254,8 @@ class FeatureService:
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
+ system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
+ system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
# extend start: DingTalk third-party login
# 检查是否有应用上下文(访问 db.session 需要应用上下文)
if has_app_context():
@@ -350,7 +354,7 @@ class FeatureService:
features.next_credit_reset_date = billing_info["next_credit_reset_date"]
@classmethod
- def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
+ def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False):
enterprise_info = EnterpriseService.get_info()
if "SSOEnforcedForSignin" in enterprise_info:
@@ -387,19 +391,14 @@ class FeatureService:
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
- if "License" in enterprise_info:
- license_info = enterprise_info["License"]
+ if is_authenticated and (license_info := enterprise_info.get("License")):
+ features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
+ features.license.expired_at = license_info.get("expiredAt", "")
- if "status" in license_info:
- features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
-
- if "expiredAt" in license_info:
- features.license.expired_at = license_info["expiredAt"]
-
- if "workspaces" in license_info:
- features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
- features.license.workspaces.limit = license_info["workspaces"]["limit"]
- features.license.workspaces.size = license_info["workspaces"]["used"]
+ if workspaces_info := license_info.get("workspaces"):
+ features.license.workspaces.enabled = workspaces_info.get("enabled", False)
+ features.license.workspaces.limit = workspaces_info.get("limit", 0)
+ features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 0911cf38c..a0a99f3f8 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -2,7 +2,11 @@ import base64
import hashlib
import os
import uuid
+from collections.abc import Iterator, Sequence
+from contextlib import contextmanager, suppress
+from tempfile import NamedTemporaryFile
from typing import Literal, Union
+from zipfile import ZIP_DEFLATED, ZipFile
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
@@ -17,6 +21,7 @@ from constants import (
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
+from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
@@ -167,6 +172,9 @@ class FileService:
return upload_file
def get_file_preview(self, file_id: str):
+ """
+ Return a short text preview extracted from a document file.
+ """
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
@@ -253,3 +261,101 @@ class FileService:
return
storage.delete(upload_file.key)
session.delete(upload_file)
+
+ @staticmethod
+ def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
+ """
+ Fetch `UploadFile` rows for a tenant in a single batch query.
+
+ This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`.
+ """
+ if not upload_file_ids:
+ return {}
+
+ # Normalize and deduplicate ids before using them in the IN clause.
+ upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids]
+ unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
+
+ # Fetch upload files in one query for efficient batch access.
+ upload_files: Sequence[UploadFile] = db.session.scalars(
+ select(UploadFile).where(
+ UploadFile.tenant_id == tenant_id,
+ UploadFile.id.in_(unique_upload_file_ids),
+ )
+ ).all()
+ return {str(upload_file.id): upload_file for upload_file in upload_files}
+
+ @staticmethod
+ def _sanitize_zip_entry_name(name: str) -> str:
+ """
+ Sanitize a ZIP entry name to avoid path traversal and weird separators.
+
+ We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows (or imported data)
+ could still contain unsafe names.
+ """
+ # Drop any directory components and prevent empty names.
+ base = os.path.basename(name).strip() or "file"
+
+ # ZIP uses forward slashes as separators; remove any residual separator characters.
+ return base.replace("/", "_").replace("\\", "_")
+
+ @staticmethod
+ def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
+ """
+ Return a unique ZIP entry name, inserting suffixes before the extension.
+ """
+ # Keep the original name when it's not already used.
+ if original_name not in used_names:
+ return original_name
+
+ # Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt").
+ stem, extension = os.path.splitext(original_name)
+ suffix = 1
+ while True:
+ candidate = f"{stem} ({suffix}){extension}"
+ if candidate not in used_names:
+ return candidate
+ suffix += 1
+
+ @staticmethod
+ @contextmanager
+ def build_upload_files_zip_tempfile(
+ *,
+ upload_files: Sequence[UploadFile],
+ ) -> Iterator[str]:
+ """
+ Build a ZIP from `UploadFile`s and yield a tempfile path.
+
+ We yield a path (rather than an open file handle) to avoid "read of closed file" issues when Flask/Werkzeug
+ streams responses. The caller is expected to keep this context open until the response is fully sent, then
+ close it (e.g., via `response.call_on_close(...)`) to delete the tempfile.
+ """
+ used_names: set[str] = set()
+
+ # Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it.
+ tmp_path: str | None = None
+ try:
+ with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp:
+ tmp_path = tmp.name
+ with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
+ for upload_file in upload_files:
+ # Ensure the entry name is safe and unique.
+ safe_name = FileService._sanitize_zip_entry_name(upload_file.name)
+ arcname = FileService._dedupe_zip_entry_name(safe_name, used_names)
+ used_names.add(arcname)
+
+ # Stream file bytes from storage into the ZIP entry.
+ with zf.open(arcname, "w") as entry:
+ for chunk in storage.load(upload_file.key, stream=True):
+ entry.write(chunk)
+
+ # Flush so `send_file(path, ...)` can re-open it safely on all platforms.
+ tmp.flush()
+
+ assert tmp_path is not None
+ yield tmp_path
+ finally:
+ # Remove the temp file when the context is closed (typically after the response finishes streaming).
+ if tmp_path is not None:
+ with suppress(FileNotFoundError):
+ os.remove(tmp_path)
diff --git a/api/services/message_service.py b/api/services/message_service.py
index e1a256e64..a53ca8b22 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -261,10 +261,9 @@ class MessageService:
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
- id=conversation.app_model_config_id,
app_id=app_model.id,
)
-
+ app_model_config.id = conversation.app_model_config_id
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if not app_model_config:
raise ValueError("did not find app model config")
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index b8303eb72..411c335c1 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
+from sqlalchemy import select
from yarl import URL
from configs import dify_config
@@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
+from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
+
+ # Get plugin info before uninstalling to delete associated credentials
+ try:
+ plugins = manager.list_plugins(tenant_id)
+ plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
+
+ if plugin:
+ plugin_id = plugin.plugin_id
+ logger.info("Deleting credentials for plugin: %s", plugin_id)
+
+ # Delete provider credentials that match this plugin
+ credentials = db.session.scalars(
+ select(ProviderCredential).where(
+ ProviderCredential.tenant_id == tenant_id,
+ ProviderCredential.provider_name.like(f"{plugin_id}/%"),
+ )
+ ).all()
+
+ for cred in credentials:
+ db.session.delete(cred)
+
+ db.session.commit()
+ logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
+ except Exception as e:
+ logger.warning("Failed to delete credentials: %s", e)
+ # Continue with uninstall even if credential deletion fails
+
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 1ba64813b..2d8418900 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.variables.variables import Variable
+from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 544383a10..6b211a563 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -1,4 +1,7 @@
from configs import dify_config
+from extensions.ext_database import db
+from models.model import AccountTrialAppRecord, TrialApp
+from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@@ -20,6 +23,15 @@ class RecommendedAppService:
)
)
+ if FeatureService.get_system_features().enable_trial_app:
+ apps = result["recommended_apps"]
+ for app in apps:
+ app_id = app["app_id"]
+ trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
+ if trial_app_model:
+ app["can_trial"] = True
+ else:
+ app["can_trial"] = False
return result
@classmethod
@@ -32,4 +44,30 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
+ if FeatureService.get_system_features().enable_trial_app:
+ app_id = result["id"]
+ trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
+ if trial_app_model:
+ result["can_trial"] = True
+ else:
+ result["can_trial"] = False
return result
+
+ @classmethod
+ def add_trial_app_record(cls, app_id: str, account_id: str):
+ """
+ Add trial app record.
+ :param app_id: app id
+ :return:
+ """
+ account_trial_app_record = (
+ db.session.query(AccountTrialAppRecord)
+ .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
+ .first()
+ )
+ if account_trial_app_record:
+ account_trial_app_record.count += 1
+ db.session.commit()
+ else:
+ db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
+ db.session.commit()
diff --git a/api/services/retention/__init__.py b/api/services/retention/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py
new file mode 100644
index 000000000..6e647b983
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_policy.py
@@ -0,0 +1,216 @@
+import datetime
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SimpleMessage:
+ id: str
+ app_id: str
+ created_at: datetime.datetime
+
+
+class MessagesCleanPolicy(ABC):
+ """
+ Abstract base class for message cleanup policies.
+
+ A policy determines which messages from a batch should be deleted.
+ """
+
+ @abstractmethod
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages and return IDs of messages that should be deleted.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ ...
+
+
+class BillingDisabledPolicy(MessagesCleanPolicy):
+ """
+ Policy for community or enterpriseedition (billing disabled).
+
+ No special filter logic, just return all message ids.
+ """
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ return [msg.id for msg in messages]
+
+
+class BillingSandboxPolicy(MessagesCleanPolicy):
+ """
+ Policy for sandbox plan tenants in cloud edition (billing enabled).
+
+ Filters messages based on sandbox plan expiration rules:
+ - Skip tenants in the whitelist
+ - Only delete messages from sandbox plan tenants
+ - Respect grace period after subscription expiration
+ - Safe default: if tenant mapping or plan is missing, do NOT delete
+ """
+
+ def __init__(
+ self,
+ plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
+ graceful_period_days: int = 21,
+ tenant_whitelist: Sequence[str] | None = None,
+ current_timestamp: int | None = None,
+ ) -> None:
+ self._graceful_period_days = graceful_period_days
+ self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
+ self._plan_provider = plan_provider
+ self._current_timestamp = current_timestamp
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages based on sandbox plan expiration rules.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ if not messages or not app_to_tenant:
+ return []
+
+ # Get unique tenant_ids and fetch subscription plans
+ tenant_ids = list(set(app_to_tenant.values()))
+ tenant_plans = self._plan_provider(tenant_ids)
+
+ if not tenant_plans:
+ return []
+
+ # Apply sandbox deletion rules
+ return self._filter_expired_sandbox_messages(
+ messages=messages,
+ app_to_tenant=app_to_tenant,
+ tenant_plans=tenant_plans,
+ )
+
+ def _filter_expired_sandbox_messages(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ tenant_plans: dict[str, SubscriptionPlan],
+ ) -> list[str]:
+ """
+ Filter messages that should be deleted based on sandbox plan expiration.
+
+ A message should be deleted if:
+ 1. It belongs to a sandbox tenant AND
+ 2. Either:
+ a) The tenant has no previous subscription (expiration_date == -1), OR
+ b) The subscription expired more than graceful_period_days ago
+
+ Args:
+ messages: List of message objects with id and app_id attributes
+ app_to_tenant: Mapping from app_id to tenant_id
+ tenant_plans: Mapping from tenant_id to subscription plan info
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ current_timestamp = self._current_timestamp
+ if current_timestamp is None:
+ current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+
+ sandbox_message_ids: list[str] = []
+ graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
+
+ for msg in messages:
+ # Get tenant_id for this message's app
+ tenant_id = app_to_tenant.get(msg.app_id)
+ if not tenant_id:
+ continue
+
+ # Skip tenant messages in whitelist
+ if tenant_id in self._tenant_whitelist:
+ continue
+
+ # Get subscription plan for this tenant
+ tenant_plan = tenant_plans.get(tenant_id)
+ if not tenant_plan:
+ continue
+
+ plan = str(tenant_plan["plan"])
+ expiration_date = int(tenant_plan["expiration_date"])
+
+ # Only process sandbox plans
+ if plan != CloudPlan.SANDBOX:
+ continue
+
+ # Case 1: No previous subscription (-1 means never had a paid subscription)
+ if expiration_date == -1:
+ sandbox_message_ids.append(msg.id)
+ continue
+
+ # Case 2: Subscription expired beyond grace period
+ if current_timestamp - expiration_date > graceful_period_seconds:
+ sandbox_message_ids.append(msg.id)
+
+ return sandbox_message_ids
+
+
+def create_message_clean_policy(
+ graceful_period_days: int = 21,
+ current_timestamp: int | None = None,
+) -> MessagesCleanPolicy:
+ """
+ Factory function to create the appropriate message clean policy.
+
+ Determines which policy to use based on BILLING_ENABLED configuration:
+ - If BILLING_ENABLED is True: returns BillingSandboxPolicy
+ - If BILLING_ENABLED is False: returns BillingDisabledPolicy
+
+ Args:
+ graceful_period_days: Grace period in days after subscription expiration (default: 21)
+ current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
+ """
+ if not dify_config.BILLING_ENABLED:
+ logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
+ return BillingDisabledPolicy()
+
+ # Billing enabled - fetch whitelist from BillingService
+ tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
+ plan_provider = BillingService.get_plan_bulk_with_cache
+
+ logger.info(
+ "create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
+ "(graceful_period_days=%s, whitelist=%s)",
+ graceful_period_days,
+ tenant_whitelist,
+ )
+
+ return BillingSandboxPolicy(
+ plan_provider=plan_provider,
+ graceful_period_days=graceful_period_days,
+ tenant_whitelist=tenant_whitelist,
+ current_timestamp=current_timestamp,
+ )
diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py
new file mode 100644
index 000000000..3ca5d8286
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_service.py
@@ -0,0 +1,334 @@
+import datetime
+import logging
+import random
+from collections.abc import Sequence
+from typing import cast
+
+from sqlalchemy import delete, select
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.orm import Session
+
+from extensions.ext_database import db
+from models.model import (
+ App,
+ AppAnnotationHitHistory,
+ DatasetRetrieverResource,
+ Message,
+ MessageAgentThought,
+ MessageAnnotation,
+ MessageChain,
+ MessageFeedback,
+ MessageFile,
+)
+from models.web import SavedMessage
+from services.retention.conversation.messages_clean_policy import (
+ MessagesCleanPolicy,
+ SimpleMessage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MessagesCleanService:
+ """
+ Service for cleaning expired messages based on retention policies.
+
+ Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted.
+ If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support).
+ """
+
+ def __init__(
+ self,
+ policy: MessagesCleanPolicy,
+ end_before: datetime.datetime,
+ start_from: datetime.datetime | None = None,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> None:
+ """
+ Initialize the service with cleanup parameters.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ end_before: End time (exclusive) of the range
+ start_from: Optional start time (inclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+ """
+ self._policy = policy
+ self._end_before = end_before
+ self._start_from = start_from
+ self._batch_size = batch_size
+ self._dry_run = dry_run
+
+ @classmethod
+ def from_time_range(
+ cls,
+ policy: MessagesCleanPolicy,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages within a specific time range.
+
+ Time range is [start_from, end_before).
+
+ Args:
+ policy: The policy that determines which messages to delete
+ start_from: Start time (inclusive) of the range
+ end_before: End time (exclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+ ValueError: If start_from >= end_before or invalid parameters
+ """
+ if start_from >= end_before:
+ raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ logger.info(
+ "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
+ start_from,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(
+ policy=policy,
+ end_before=end_before,
+ start_from=start_from,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+
+ @classmethod
+ def from_days(
+ cls,
+ policy: MessagesCleanPolicy,
+ days: int = 30,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages older than specified days.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ days: Number of days to look back from now
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+ ValueError: If invalid parameters
+ """
+ if days < 0:
+ raise ValueError(f"days ({days}) must be greater than or equal to 0")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ end_before = datetime.datetime.now() - datetime.timedelta(days=days)
+
+ logger.info(
+ "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
+ days,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
+
+ def run(self) -> dict[str, int]:
+ """
+ Execute the message cleanup operation.
+
+ Returns:
+ Dict with statistics: batches, filtered_messages, total_deleted
+ """
+ return self._clean_messages_by_time_range()
+
+ def _clean_messages_by_time_range(self) -> dict[str, int]:
+ """
+ Clean messages within a time range using cursor-based pagination.
+
+ Time range is [start_from, end_before)
+
+ Steps:
+ 1. Iterate messages using cursor pagination (by created_at, id)
+ 2. Query app_id -> tenant_id mapping
+ 3. Delegate to policy to determine which messages to delete
+ 4. Batch delete messages and their relations
+
+ Returns:
+ Dict with statistics: batches, filtered_messages, total_deleted
+ """
+ stats = {
+ "batches": 0,
+ "total_messages": 0,
+ "filtered_messages": 0,
+ "total_deleted": 0,
+ }
+
+ # Cursor-based pagination using (created_at, id) to avoid infinite loops
+ # and ensure proper ordering with time-based filtering
+ _cursor: tuple[datetime.datetime, str] | None = None
+
+ logger.info(
+ "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
+ self._dry_run,
+ self._start_from,
+ self._end_before,
+ )
+
+ while True:
+ stats["batches"] += 1
+
+ # Step 1: Fetch a batch of messages using cursor
+ with Session(db.engine, expire_on_commit=False) as session:
+ msg_stmt = (
+ select(Message.id, Message.app_id, Message.created_at)
+ .where(Message.created_at < self._end_before)
+ .order_by(Message.created_at, Message.id)
+ .limit(self._batch_size)
+ )
+
+ if self._start_from:
+ msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
+
+ # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
+ # This translates to:
+ # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
+ if _cursor:
+ # Continuing from previous batch
+ msg_stmt = msg_stmt.where(
+ (Message.created_at > _cursor[0])
+ | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
+ )
+
+ raw_messages = list(session.execute(msg_stmt).all())
+ messages = [
+ SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
+ for msg_id, app_id, msg_created_at in raw_messages
+ ]
+
+ # Track total messages fetched across all batches
+ stats["total_messages"] += len(messages)
+
+ if not messages:
+ logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
+ break
+
+ # Update cursor to the last message's (created_at, id)
+ _cursor = (messages[-1].created_at, messages[-1].id)
+
+ # Step 2: Extract app_ids and query tenant_ids
+ app_ids = list({msg.app_id for msg in messages})
+
+ if not app_ids:
+ logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
+ continue
+
+ app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
+ apps = list(session.execute(app_stmt).all())
+
+ if not apps:
+ logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
+ continue
+
+ # Build app_id -> tenant_id mapping
+ app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
+
+ # Step 3: Delegate to policy to determine which messages to delete
+ message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
+
+ if not message_ids_to_delete:
+ logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
+ continue
+
+ stats["filtered_messages"] += len(message_ids_to_delete)
+
+ # Step 4: Batch delete messages and their relations
+ if not self._dry_run:
+ with Session(db.engine, expire_on_commit=False) as session:
+ # Delete related records first
+ self._batch_delete_message_relations(session, message_ids_to_delete)
+
+ # Delete messages
+ delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
+ delete_result = cast(CursorResult, session.execute(delete_stmt))
+ messages_deleted = delete_result.rowcount
+ session.commit()
+
+ stats["total_deleted"] += messages_deleted
+
+ logger.info(
+ "clean_messages (batch %s): processed %s messages, deleted %s messages",
+ stats["batches"],
+ len(messages),
+ messages_deleted,
+ )
+ else:
+ # Log random sample of message IDs that would be deleted (up to 10)
+ sample_size = min(10, len(message_ids_to_delete))
+ sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
+
+ logger.info(
+ "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
+ stats["batches"],
+ len(message_ids_to_delete),
+ sample_size,
+ )
+ for msg_id in sampled_ids:
+ logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
+
+ logger.info(
+ "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
+ stats["batches"],
+ stats["total_messages"],
+ stats["filtered_messages"],
+ stats["total_deleted"],
+ )
+
+ return stats
+
+ @staticmethod
+ def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
+ """
+ Batch delete all related records for given message IDs.
+
+ Args:
+ session: Database session
+ message_ids: List of message IDs to delete relations for
+ """
+ if not message_ids:
+ return
+
+ # Delete all related records in batch
+ session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
+
+ session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
+
+ session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
+
+ session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py
new file mode 100644
index 000000000..18dd42c91
--- /dev/null
+++ b/api/services/retention/workflow_run/__init__.py
@@ -0,0 +1 @@
+"""Workflow run retention services."""
diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
new file mode 100644
index 000000000..ea5cbb774
--- /dev/null
+++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
@@ -0,0 +1,531 @@
+"""
+Archive Paid Plan Workflow Run Logs Service.
+
+This service archives workflow run logs for paid plan users older than the configured
+retention period (default: 90 days) to S3-compatible storage.
+
+Archived tables:
+- workflow_runs
+- workflow_app_logs
+- workflow_node_executions
+- workflow_node_execution_offload
+- workflow_pauses
+- workflow_pause_reasons
+- workflow_trigger_logs
+
+"""
+
+import datetime
+import io
+import json
+import logging
+import time
+import zipfile
+from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from typing import Any
+
+import click
+from sqlalchemy import inspect
+from sqlalchemy.orm import Session, sessionmaker
+
+from configs import dify_config
+from core.workflow.enums import WorkflowType
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from libs.archive_storage import (
+ ArchiveStorage,
+ ArchiveStorageNotConfiguredError,
+ get_archive_storage,
+)
+from models.workflow import WorkflowAppLog, WorkflowRun
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from services.billing_service import BillingService
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TableStats:
+ """Statistics for a single archived table."""
+
+ table_name: str
+ row_count: int
+ checksum: str
+ size_bytes: int
+
+
+@dataclass
+class ArchiveResult:
+ """Result of archiving a single workflow run."""
+
+ run_id: str
+ tenant_id: str
+ success: bool
+ tables: list[TableStats] = field(default_factory=list)
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+@dataclass
+class ArchiveSummary:
+ """Summary of the entire archive operation."""
+
+ total_runs_processed: int = 0
+ runs_archived: int = 0
+ runs_skipped: int = 0
+ runs_failed: int = 0
+ total_elapsed_time: float = 0.0
+
+
+class WorkflowRunArchiver:
+ """
+ Archive workflow run logs for paid plan users.
+
+ Storage Layout:
+ {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/
+ └── archive.v1.0.zip
+ ├── manifest.json
+ ├── workflow_runs.jsonl
+ ├── workflow_app_logs.jsonl
+ ├── workflow_node_executions.jsonl
+ ├── workflow_node_execution_offload.jsonl
+ ├── workflow_pauses.jsonl
+ ├── workflow_pause_reasons.jsonl
+ └── workflow_trigger_logs.jsonl
+ """
+
+ ARCHIVED_TYPE = [
+ WorkflowType.WORKFLOW,
+ WorkflowType.RAG_PIPELINE,
+ ]
+ ARCHIVED_TABLES = [
+ "workflow_runs",
+ "workflow_app_logs",
+ "workflow_node_executions",
+ "workflow_node_execution_offload",
+ "workflow_pauses",
+ "workflow_pause_reasons",
+ "workflow_trigger_logs",
+ ]
+
+ start_from: datetime.datetime | None
+ end_before: datetime.datetime
+
+ def __init__(
+ self,
+ days: int = 90,
+ batch_size: int = 100,
+ start_from: datetime.datetime | None = None,
+ end_before: datetime.datetime | None = None,
+ workers: int = 1,
+ tenant_ids: Sequence[str] | None = None,
+ limit: int | None = None,
+ dry_run: bool = False,
+ delete_after_archive: bool = False,
+ workflow_run_repo: APIWorkflowRunRepository | None = None,
+ ):
+ """
+ Initialize the archiver.
+
+ Args:
+ days: Archive runs older than this many days
+ batch_size: Number of runs to process per batch
+ start_from: Optional start time (inclusive) for archiving
+ end_before: Optional end time (exclusive) for archiving
+ workers: Number of concurrent workflow runs to archive
+ tenant_ids: Optional tenant IDs for grayscale rollout
+ limit: Maximum number of runs to archive (None for unlimited)
+ dry_run: If True, only preview without making changes
+ delete_after_archive: If True, delete runs and related data after archiving
+ """
+ self.days = days
+ self.batch_size = batch_size
+ if start_from or end_before:
+ if start_from is None or end_before is None:
+ raise ValueError("start_from and end_before must be provided together")
+ if start_from >= end_before:
+ raise ValueError("start_from must be earlier than end_before")
+ self.start_from = start_from.replace(tzinfo=datetime.UTC)
+ self.end_before = end_before.replace(tzinfo=datetime.UTC)
+ else:
+ self.start_from = None
+ self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days)
+ if workers < 1:
+ raise ValueError("workers must be at least 1")
+ self.workers = workers
+ self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else []
+ self.limit = limit
+ self.dry_run = dry_run
+ self.delete_after_archive = delete_after_archive
+ self.workflow_run_repo = workflow_run_repo
+
+ def run(self) -> ArchiveSummary:
+ """
+ Main archiving loop.
+
+ Returns:
+ ArchiveSummary with statistics about the operation
+ """
+ summary = ArchiveSummary()
+ start_time = time.time()
+
+ click.echo(
+ click.style(
+ self._build_start_message(),
+ fg="white",
+ )
+ )
+
+ # Initialize archive storage (will raise if not configured)
+ try:
+ if not self.dry_run:
+ storage = get_archive_storage()
+ else:
+ storage = None
+ except ArchiveStorageNotConfiguredError as e:
+ click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
+ return summary
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = self._get_workflow_run_repo()
+
+ def _archive_with_session(run: WorkflowRun) -> ArchiveResult:
+ with session_maker() as session:
+ return self._archive_run(session, storage, run)
+
+ last_seen: tuple[datetime.datetime, str] | None = None
+ archived_count = 0
+
+ with ThreadPoolExecutor(max_workers=self.workers) as executor:
+ while True:
+ # Check limit
+ if self.limit and archived_count >= self.limit:
+ click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow"))
+ break
+
+ # Fetch batch of runs
+ runs = self._get_runs_batch(last_seen)
+
+ if not runs:
+ break
+
+ run_ids = [run.id for run in runs]
+ with session_maker() as session:
+ archived_run_ids = repo.get_archived_run_ids(session, run_ids)
+
+ last_seen = (runs[-1].created_at, runs[-1].id)
+
+ # Filter to paid tenants only
+ tenant_ids = {run.tenant_id for run in runs}
+ paid_tenants = self._filter_paid_tenants(tenant_ids)
+
+ runs_to_process: list[WorkflowRun] = []
+ for run in runs:
+ summary.total_runs_processed += 1
+
+ # Skip non-paid tenants
+ if run.tenant_id not in paid_tenants:
+ summary.runs_skipped += 1
+ continue
+
+ # Skip already archived runs
+ if run.id in archived_run_ids:
+ summary.runs_skipped += 1
+ continue
+
+ # Check limit
+ if self.limit and archived_count + len(runs_to_process) >= self.limit:
+ break
+
+ runs_to_process.append(run)
+
+ if not runs_to_process:
+ continue
+
+ results = list(executor.map(_archive_with_session, runs_to_process))
+
+ for run, result in zip(runs_to_process, results):
+ if result.success:
+ summary.runs_archived += 1
+ archived_count += 1
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} "
+ f"run {run.id} (tenant={run.tenant_id}, "
+ f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)",
+ fg="green",
+ )
+ )
+ else:
+ summary.runs_failed += 1
+ click.echo(
+ click.style(
+ f"Failed to archive run {run.id}: {result.error}",
+ fg="red",
+ )
+ )
+
+ summary.total_elapsed_time = time.time() - start_time
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: "
+ f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
+ f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
+ f"time={summary.total_elapsed_time:.2f}s",
+ fg="white",
+ )
+ )
+
+ return summary
+
+ def _get_runs_batch(
+ self,
+ last_seen: tuple[datetime.datetime, str] | None,
+ ) -> Sequence[WorkflowRun]:
+ """Fetch a batch of workflow runs to archive."""
+ repo = self._get_workflow_run_repo()
+ return repo.get_runs_batch_by_time_range(
+ start_from=self.start_from,
+ end_before=self.end_before,
+ last_seen=last_seen,
+ batch_size=self.batch_size,
+ run_types=self.ARCHIVED_TYPE,
+ tenant_ids=self.tenant_ids or None,
+ )
+
+ def _build_start_message(self) -> str:
+ range_desc = f"before {self.end_before.isoformat()}"
+ if self.start_from:
+ range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}"
+ return (
+ f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving "
+ f"for runs {range_desc} "
+ f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})"
+ )
+
+ def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]:
+ """Filter tenant IDs to only include paid tenants."""
+ if not dify_config.BILLING_ENABLED:
+ # If billing is not enabled, treat all tenants as paid
+ return tenant_ids
+
+ if not tenant_ids:
+ return set()
+
+ try:
+ bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids))
+ except Exception:
+ logger.exception("Failed to fetch billing plans for tenants")
+ # On error, skip all tenants in this batch
+ return set()
+
+ # Filter to paid tenants (any plan except SANDBOX)
+ paid = set()
+ for tid, info in bulk_info.items():
+ if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM):
+ paid.add(tid)
+
+ return paid
+
+ def _archive_run(
+ self,
+ session: Session,
+ storage: ArchiveStorage | None,
+ run: WorkflowRun,
+ ) -> ArchiveResult:
+ """Archive a single workflow run."""
+ start_time = time.time()
+ result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
+
+ try:
+ # Extract data from all tables
+ table_data, app_logs, trigger_metadata = self._extract_data(session, run)
+
+ if self.dry_run:
+ # In dry run, just report what would be archived
+ for table_name in self.ARCHIVED_TABLES:
+ records = table_data.get(table_name, [])
+ result.tables.append(
+ TableStats(
+ table_name=table_name,
+ row_count=len(records),
+ checksum="",
+ size_bytes=0,
+ )
+ )
+ result.success = True
+ else:
+ if storage is None:
+ raise ArchiveStorageNotConfiguredError("Archive storage not configured")
+ archive_key = self._get_archive_key(run)
+
+ # Serialize tables for the archive bundle
+ table_stats: list[TableStats] = []
+ table_payloads: dict[str, bytes] = {}
+ for table_name in self.ARCHIVED_TABLES:
+ records = table_data.get(table_name, [])
+ data = ArchiveStorage.serialize_to_jsonl(records)
+ table_payloads[table_name] = data
+ checksum = ArchiveStorage.compute_checksum(data)
+
+ table_stats.append(
+ TableStats(
+ table_name=table_name,
+ row_count=len(records),
+ checksum=checksum,
+ size_bytes=len(data),
+ )
+ )
+
+ # Generate and upload archive bundle
+ manifest = self._generate_manifest(run, table_stats)
+ manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8")
+ archive_data = self._build_archive_bundle(manifest_data, table_payloads)
+ storage.put_object(archive_key, archive_data)
+
+ repo = self._get_workflow_run_repo()
+ archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata)
+ session.commit()
+
+ deleted_counts = None
+ if self.delete_after_archive:
+ deleted_counts = repo.delete_runs_with_related(
+ [run],
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+
+ logger.info(
+ "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s",
+ run.id,
+ {s.table_name: s.row_count for s in table_stats},
+ archived_log_count,
+ deleted_counts,
+ )
+
+ result.tables = table_stats
+ result.success = True
+
+ except Exception as e:
+ logger.exception("Failed to archive workflow run %s", run.id)
+ result.error = str(e)
+ session.rollback()
+
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def _extract_data(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ ) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]:
+ table_data: dict[str, list[dict[str, Any]]] = {}
+ table_data["workflow_runs"] = [self._row_to_dict(run)]
+ repo = self._get_workflow_run_repo()
+ app_logs = repo.get_app_logs_by_run_id(session, run.id)
+ table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs]
+ node_exec_repo = self._get_workflow_node_execution_repo(session)
+ node_exec_records = node_exec_repo.get_executions_by_workflow_run(
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_run_id=run.id,
+ )
+ node_exec_ids = [record.id for record in node_exec_records]
+ offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids)
+ table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records]
+ table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records]
+ repo = self._get_workflow_run_repo()
+ pause_records = repo.get_pause_records_by_run_id(session, run.id)
+ pause_ids = [pause.id for pause in pause_records]
+ pause_reason_records = repo.get_pause_reason_records_by_run_id(
+ session,
+ pause_ids,
+ )
+ table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records]
+ table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records]
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ trigger_records = trigger_repo.list_by_run_id(run.id)
+ table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records]
+ trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None
+ return table_data, app_logs, trigger_metadata
+
+ @staticmethod
+ def _row_to_dict(row: Any) -> dict[str, Any]:
+ mapper = inspect(row).mapper
+ return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns}
+
+ def _get_archive_key(self, run: WorkflowRun) -> str:
+ """Get the storage key for the archive bundle."""
+ created_at = run.created_at
+ prefix = (
+ f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
+ f"month={created_at.strftime('%m')}/workflow_run_id={run.id}"
+ )
+ return f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+ def _generate_manifest(
+ self,
+ run: WorkflowRun,
+ table_stats: list[TableStats],
+ ) -> dict[str, Any]:
+ """Generate a manifest for the archived workflow run."""
+ return {
+ "schema_version": ARCHIVE_SCHEMA_VERSION,
+ "workflow_run_id": run.id,
+ "tenant_id": run.tenant_id,
+ "app_id": run.app_id,
+ "workflow_id": run.workflow_id,
+ "created_at": run.created_at.isoformat(),
+ "archived_at": datetime.datetime.now(datetime.UTC).isoformat(),
+ "tables": {
+ stat.table_name: {
+ "row_count": stat.row_count,
+ "checksum": stat.checksum,
+ "size_bytes": stat.size_bytes,
+ }
+ for stat in table_stats
+ },
+ }
+
+ def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes:
+ buffer = io.BytesIO()
+ with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
+ archive.writestr("manifest.json", manifest_data)
+ for table_name in self.ARCHIVED_TABLES:
+ data = table_payloads.get(table_name)
+ if data is None:
+ raise ValueError(f"Missing archive payload for {table_name}")
+ archive.writestr(f"{table_name}.jsonl", data)
+ return buffer.getvalue()
+
+ def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids)
+
+ def _get_workflow_node_execution_repo(
+ self,
+ session: Session,
+ ) -> DifyAPIWorkflowNodeExecutionRepository:
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+ return self.workflow_run_repo
diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
new file mode 100644
index 000000000..c3e0dce39
--- /dev/null
+++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
@@ -0,0 +1,293 @@
+import datetime
+import logging
+from collections.abc import Iterable, Sequence
+
+import click
+from sqlalchemy.orm import Session, sessionmaker
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowRunCleanup:
+ def __init__(
+ self,
+ days: int,
+ batch_size: int,
+ start_from: datetime.datetime | None = None,
+ end_before: datetime.datetime | None = None,
+ workflow_run_repo: APIWorkflowRunRepository | None = None,
+ dry_run: bool = False,
+ ):
+ if (start_from is None) ^ (end_before is None):
+ raise ValueError("start_from and end_before must be both set or both omitted.")
+
+ computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days)
+ self.window_start = start_from
+ self.window_end = end_before or computed_cutoff
+
+ if self.window_start and self.window_end <= self.window_start:
+ raise ValueError("end_before must be greater than start_from.")
+
+ if batch_size <= 0:
+ raise ValueError("batch_size must be greater than 0.")
+
+ self.batch_size = batch_size
+ self._cleanup_whitelist: set[str] | None = None
+ self.dry_run = dry_run
+ self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
+ self.workflow_run_repo: APIWorkflowRunRepository
+ if workflow_run_repo:
+ self.workflow_run_repo = workflow_run_repo
+ else:
+ # Lazy import to avoid circular dependencies during module import
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ def run(self) -> None:
+ click.echo(
+ click.style(
+ f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs "
+ f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}"
+ f"{self.window_end.isoformat()} (batch={self.batch_size})",
+ fg="white",
+ )
+ )
+ if self.dry_run:
+ click.echo(click.style("Dry run mode enabled. No data will be deleted.", fg="yellow"))
+
+ total_runs_deleted = 0
+ total_runs_targeted = 0
+ related_totals = self._empty_related_counts() if self.dry_run else None
+ batch_index = 0
+ last_seen: tuple[datetime.datetime, str] | None = None
+
+ while True:
+ run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
+ start_from=self.window_start,
+ end_before=self.window_end,
+ last_seen=last_seen,
+ batch_size=self.batch_size,
+ )
+ if not run_rows:
+ break
+
+ batch_index += 1
+ last_seen = (run_rows[-1].created_at, run_rows[-1].id)
+ tenant_ids = {row.tenant_id for row in run_rows}
+ free_tenants = self._filter_free_tenants(tenant_ids)
+ free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
+ paid_or_skipped = len(run_rows) - len(free_runs)
+
+ if not free_runs:
+ skipped_message = (
+ f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
+ )
+ click.echo(
+ click.style(
+ skipped_message,
+ fg="yellow",
+ )
+ )
+ continue
+
+ total_runs_targeted += len(free_runs)
+
+ if self.dry_run:
+ batch_counts = self.workflow_run_repo.count_runs_with_related(
+ free_runs,
+ count_node_executions=self._count_node_executions,
+ count_trigger_logs=self._count_trigger_logs,
+ )
+ if related_totals is not None:
+ for key in related_totals:
+ related_totals[key] += batch_counts.get(key, 0)
+ sample_ids = ", ".join(run.id for run in free_runs[:5])
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] would delete {len(free_runs)} runs "
+ f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
+ fg="yellow",
+ )
+ )
+ continue
+
+ try:
+ counts = self.workflow_run_repo.delete_runs_with_related(
+ free_runs,
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+ except Exception:
+ logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
+ raise
+
+ total_runs_deleted += counts["runs"]
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] deleted runs: {counts['runs']} "
+ f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
+ f"skipped {paid_or_skipped} paid/unknown",
+ fg="green",
+ )
+ )
+
+ if self.dry_run:
+ if self.window_start:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"before {self.window_end.isoformat()}"
+ )
+ if related_totals is not None:
+ summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
+ summary_color = "yellow"
+ else:
+ if self.window_start:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
+ )
+ summary_color = "white"
+
+ click.echo(click.style(summary_message, fg=summary_color))
+
+ def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
+ tenant_id_list = list(tenant_ids)
+
+ if not dify_config.BILLING_ENABLED:
+ return set(tenant_id_list)
+
+ if not tenant_id_list:
+ return set()
+
+ cleanup_whitelist = self._get_cleanup_whitelist()
+
+ try:
+ bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list)
+ except Exception:
+ bulk_info = {}
+ logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list)
+
+ eligible_free_tenants: set[str] = set()
+ for tenant_id in tenant_id_list:
+ if tenant_id in cleanup_whitelist:
+ continue
+
+ info = bulk_info.get(tenant_id)
+ if info is None:
+ logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
+ continue
+
+ if info.get("plan") != CloudPlan.SANDBOX:
+ continue
+
+ if self._is_within_grace_period(tenant_id, info):
+ continue
+
+ eligible_free_tenants.add(tenant_id)
+
+ return eligible_free_tenants
+
+ def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None:
+ if expiration_value < 0:
+ return None
+
+ try:
+ return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC)
+ except (OverflowError, OSError, ValueError):
+ logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id)
+ return None
+
+ def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool:
+ if self.free_plan_grace_period_days <= 0:
+ return False
+
+ expiration_value = info.get("expiration_date", -1)
+ expiration_at = self._expiration_datetime(tenant_id, expiration_value)
+ if expiration_at is None:
+ return False
+
+ grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days)
+ return datetime.datetime.now(datetime.UTC) < grace_deadline
+
+ def _get_cleanup_whitelist(self) -> set[str]:
+ if self._cleanup_whitelist is not None:
+ return self._cleanup_whitelist
+
+ if not dify_config.BILLING_ENABLED:
+ self._cleanup_whitelist = set()
+ return self._cleanup_whitelist
+
+ try:
+ whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist()
+ except Exception:
+ logger.exception("Failed to fetch cleanup whitelist from billing service")
+ whitelist_ids = []
+
+ self._cleanup_whitelist = set(whitelist_ids)
+ return self._cleanup_whitelist
+
+ def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.count_by_run_ids(run_ids)
+
+ @staticmethod
+ def _empty_related_counts() -> dict[str, int]:
+ return {
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ @staticmethod
+ def _format_related_counts(counts: dict[str, int]) -> str:
+ return (
+ f"node_executions {counts['node_executions']}, "
+ f"offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, "
+ f"trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, "
+ f"pause_reasons {counts['pause_reasons']}"
+ )
+
+ def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.count_by_runs(session, run_ids)
+
+ def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.delete_by_runs(session, run_ids)
diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py
new file mode 100644
index 000000000..162bb4947
--- /dev/null
+++ b/api/services/retention/workflow_run/constants.py
@@ -0,0 +1,2 @@
+ARCHIVE_SCHEMA_VERSION = "1.0"
+ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip"
diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py
new file mode 100644
index 000000000..11873bf1b
--- /dev/null
+++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py
@@ -0,0 +1,134 @@
+"""
+Delete Archived Workflow Run Service.
+
+This service deletes archived workflow run data from the database while keeping
+archive logs intact.
+"""
+
+import time
+from collections.abc import Sequence
+from dataclasses import dataclass, field
+from datetime import datetime
+
+from sqlalchemy.orm import Session, sessionmaker
+
+from extensions.ext_database import db
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+
+
+@dataclass
+class DeleteResult:
+ run_id: str
+ tenant_id: str
+ success: bool
+ deleted_counts: dict[str, int] = field(default_factory=dict)
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+class ArchivedWorkflowRunDeletion:
+ def __init__(self, dry_run: bool = False):
+ self.dry_run = dry_run
+ self.workflow_run_repo: APIWorkflowRunRepository | None = None
+
+ def delete_by_run_id(self, run_id: str) -> DeleteResult:
+ start_time = time.time()
+ result = DeleteResult(run_id=run_id, tenant_id="", success=False)
+
+ repo = self._get_workflow_run_repo()
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ with session_maker() as session:
+ run = session.get(WorkflowRun, run_id)
+ if not run:
+ result.error = f"Workflow run {run_id} not found"
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ result.tenant_id = run.tenant_id
+ if not repo.get_archived_run_ids(session, [run.id]):
+ result.error = f"Workflow run {run_id} is not archived"
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ result = self._delete_run(run)
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def delete_batch(
+ self,
+ tenant_ids: list[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int = 100,
+ ) -> list[DeleteResult]:
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ results: list[DeleteResult] = []
+
+ repo = self._get_workflow_run_repo()
+ with session_maker() as session:
+ runs = list(
+ repo.get_archived_runs_by_time_range(
+ session=session,
+ tenant_ids=tenant_ids,
+ start_date=start_date,
+ end_date=end_date,
+ limit=limit,
+ )
+ )
+ for run in runs:
+ results.append(self._delete_run(run))
+
+ return results
+
+ def _delete_run(self, run: WorkflowRun) -> DeleteResult:
+ start_time = time.time()
+ result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
+ if self.dry_run:
+ result.success = True
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ repo = self._get_workflow_run_repo()
+ try:
+ deleted_counts = repo.delete_runs_with_related(
+ [run],
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+ result.deleted_counts = deleted_counts
+ result.success = True
+ except Exception as e:
+ result.error = str(e)
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ @staticmethod
+ def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ @staticmethod
+ def _delete_node_executions(
+ session: Session,
+ runs: Sequence[WorkflowRun],
+ ) -> tuple[int, int]:
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.delete_by_runs(session, run_ids)
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+ sessionmaker(bind=db.engine, expire_on_commit=False)
+ )
+ return self.workflow_run_repo
diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py
new file mode 100644
index 000000000..d4a6e8758
--- /dev/null
+++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py
@@ -0,0 +1,481 @@
+"""
+Restore Archived Workflow Run Service.
+
+This service restores archived workflow run data from S3-compatible storage
+back to the database.
+"""
+
+import io
+import json
+import logging
+import time
+import zipfile
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, cast
+
+import click
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
+
+from extensions.ext_database import db
+from libs.archive_storage import (
+ ArchiveStorage,
+ ArchiveStorageNotConfiguredError,
+ get_archive_storage,
+)
+from models.trigger import WorkflowTriggerLog
+from models.workflow import (
+ WorkflowAppLog,
+ WorkflowArchiveLog,
+ WorkflowNodeExecutionModel,
+ WorkflowNodeExecutionOffload,
+ WorkflowPause,
+ WorkflowPauseReason,
+ WorkflowRun,
+)
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
+
+logger = logging.getLogger(__name__)
+
+
+# Mapping of table names to SQLAlchemy models
+TABLE_MODELS = {
+ "workflow_runs": WorkflowRun,
+ "workflow_app_logs": WorkflowAppLog,
+ "workflow_node_executions": WorkflowNodeExecutionModel,
+ "workflow_node_execution_offload": WorkflowNodeExecutionOffload,
+ "workflow_pauses": WorkflowPause,
+ "workflow_pause_reasons": WorkflowPauseReason,
+ "workflow_trigger_logs": WorkflowTriggerLog,
+}
+
+SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]]
+
+SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = {
+ "1.0": {},
+}
+
+
+@dataclass
+class RestoreResult:
+ """Result of restoring a single workflow run."""
+
+ run_id: str
+ tenant_id: str
+ success: bool
+ restored_counts: dict[str, int]
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+class WorkflowRunRestore:
+ """
+ Restore archived workflow run data from storage to database.
+
+ This service reads archived data from storage and restores it to the
+ database tables. It handles idempotency by skipping records that already
+ exist in the database.
+ """
+
+ def __init__(self, dry_run: bool = False, workers: int = 1):
+ """
+ Initialize the restore service.
+
+ Args:
+ dry_run: If True, only preview without making changes
+ workers: Number of concurrent workflow runs to restore
+ """
+ self.dry_run = dry_run
+ if workers < 1:
+ raise ValueError("workers must be at least 1")
+ self.workers = workers
+ self.workflow_run_repo: APIWorkflowRunRepository | None = None
+
+ def _restore_from_run(
+ self,
+ run: WorkflowRun | WorkflowArchiveLog,
+ *,
+ session_maker: sessionmaker,
+ ) -> RestoreResult:
+ start_time = time.time()
+ run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id
+ created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at
+ result = RestoreResult(
+ run_id=run_id,
+ tenant_id=run.tenant_id,
+ success=False,
+ restored_counts={},
+ )
+
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})",
+ fg="white",
+ )
+ )
+
+ try:
+ storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ result.error = str(e)
+ click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ prefix = (
+ f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
+ f"month={created_at.strftime('%m')}/workflow_run_id={run_id}"
+ )
+ archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+ try:
+ archive_data = storage.get_object(archive_key)
+ except FileNotFoundError:
+ result.error = f"Archive bundle not found: {archive_key}"
+ click.echo(click.style(result.error, fg="red"))
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ with session_maker() as session:
+ try:
+ with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive:
+ try:
+ manifest = self._load_manifest_from_zip(archive)
+ except ValueError as e:
+ result.error = f"Archive bundle invalid: {e}"
+ click.echo(click.style(result.error, fg="red"))
+ return result
+
+ tables = manifest.get("tables", {})
+ schema_version = self._get_schema_version(manifest)
+ for table_name, info in tables.items():
+ row_count = info.get("row_count", 0)
+ if row_count == 0:
+ result.restored_counts[table_name] = 0
+ continue
+
+ if self.dry_run:
+ result.restored_counts[table_name] = row_count
+ continue
+
+ member_path = f"{table_name}.jsonl"
+ try:
+ data = archive.read(member_path)
+ except KeyError:
+ click.echo(
+ click.style(
+ f" Warning: Table data not found in archive: {member_path}",
+ fg="yellow",
+ )
+ )
+ result.restored_counts[table_name] = 0
+ continue
+
+ records = ArchiveStorage.deserialize_from_jsonl(data)
+ restored = self._restore_table_records(
+ session,
+ table_name,
+ records,
+ schema_version=schema_version,
+ )
+ result.restored_counts[table_name] = restored
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f" Restored {restored}/{len(records)} records to {table_name}",
+ fg="white",
+ )
+ )
+
+ # Verify row counts match manifest
+ manifest_total = sum(info.get("row_count", 0) for info in tables.values())
+ restored_total = sum(result.restored_counts.values())
+
+ if not self.dry_run:
+ # Note: restored count might be less than manifest count if records already exist
+ logger.info(
+ "Restore verification: manifest_total=%d, restored_total=%d",
+ manifest_total,
+ restored_total,
+ )
+
+ # Delete the archive log record after successful restore
+ repo = self._get_workflow_run_repo()
+ repo.delete_archive_log_by_run_id(session, run_id)
+
+ session.commit()
+
+ result.success = True
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f"Completed restore for workflow run {run_id}: restored={result.restored_counts}",
+ fg="green",
+ )
+ )
+
+ except Exception as e:
+ logger.exception("Failed to restore workflow run %s", run_id)
+ result.error = str(e)
+ session.rollback()
+ click.echo(click.style(f"Restore failed: {e}", fg="red"))
+
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+ sessionmaker(bind=db.engine, expire_on_commit=False)
+ )
+ return self.workflow_run_repo
+
+ @staticmethod
+ def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
+ try:
+ data = archive.read("manifest.json")
+ except KeyError as e:
+ raise ValueError("manifest.json missing from archive bundle") from e
+ return json.loads(data.decode("utf-8"))
+
+ def _restore_table_records(
+ self,
+ session: Session,
+ table_name: str,
+ records: list[dict[str, Any]],
+ *,
+ schema_version: str,
+ ) -> int:
+ """
+ Restore records to a table.
+
+ Uses INSERT ... ON CONFLICT DO NOTHING for idempotency.
+
+ Args:
+ session: Database session
+ table_name: Name of the table
+ records: List of record dictionaries
+ schema_version: Archived schema version from manifest
+
+ Returns:
+ Number of records actually inserted
+ """
+ if not records:
+ return 0
+
+ model = TABLE_MODELS.get(table_name)
+ if not model:
+ logger.warning("Unknown table: %s", table_name)
+ return 0
+
+ column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model)
+ unknown_fields: set[str] = set()
+
+ # Apply schema mapping, filter to current columns, then convert datetimes
+ converted_records = []
+ for record in records:
+ mapped = self._apply_schema_mapping(table_name, schema_version, record)
+ unknown_fields.update(set(mapped.keys()) - column_names)
+ filtered = {key: value for key, value in mapped.items() if key in column_names}
+ for key in non_nullable_with_default:
+ if key in filtered and filtered[key] is None:
+ filtered.pop(key)
+ missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None]
+ if missing_required:
+ missing_cols = ", ".join(sorted(missing_required))
+ raise ValueError(
+ f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}"
+ )
+ converted = self._convert_datetime_fields(filtered, model)
+ converted_records.append(converted)
+ if unknown_fields:
+ logger.warning(
+ "Dropped unknown columns for %s (schema_version=%s): %s",
+ table_name,
+ schema_version,
+ ", ".join(sorted(unknown_fields)),
+ )
+
+ # Use INSERT ... ON CONFLICT DO NOTHING for idempotency
+ stmt = pg_insert(model).values(converted_records)
+ stmt = stmt.on_conflict_do_nothing(index_elements=["id"])
+
+ result = session.execute(stmt)
+ return cast(CursorResult, result).rowcount or 0
+
+ def _convert_datetime_fields(
+ self,
+ record: dict[str, Any],
+ model: type[DeclarativeBase] | Any,
+ ) -> dict[str, Any]:
+ """Convert ISO datetime strings to datetime objects."""
+ from sqlalchemy import DateTime
+
+ result = dict(record)
+
+ for column in model.__table__.columns:
+ if isinstance(column.type, DateTime):
+ value = result.get(column.key)
+ if isinstance(value, str):
+ try:
+ result[column.key] = datetime.fromisoformat(value)
+ except ValueError:
+ pass
+
+ return result
+
+ def _get_schema_version(self, manifest: dict[str, Any]) -> str:
+ schema_version = manifest.get("schema_version")
+ if not schema_version:
+ logger.warning("Manifest missing schema_version; defaulting to 1.0")
+ schema_version = "1.0"
+ schema_version = str(schema_version)
+ if schema_version not in SCHEMA_MAPPERS:
+ raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.")
+ return schema_version
+
+ def _apply_schema_mapping(
+ self,
+ table_name: str,
+ schema_version: str,
+ record: dict[str, Any],
+ ) -> dict[str, Any]:
+ # Keep hook for forward/backward compatibility when schema evolves.
+ mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name)
+ if mapper is None:
+ return dict(record)
+ return mapper(record)
+
+ def _get_model_column_info(
+ self,
+ model: type[DeclarativeBase] | Any,
+ ) -> tuple[set[str], set[str], set[str]]:
+ columns = list(model.__table__.columns)
+ column_names = {column.key for column in columns}
+ required_columns = {
+ column.key
+ for column in columns
+ if not column.nullable
+ and column.default is None
+ and column.server_default is None
+ and not column.autoincrement
+ }
+ non_nullable_with_default = {
+ column.key
+ for column in columns
+ if not column.nullable
+ and (column.default is not None or column.server_default is not None or column.autoincrement)
+ }
+ return column_names, required_columns, non_nullable_with_default
+
+ def restore_batch(
+ self,
+ tenant_ids: list[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int = 100,
+ ) -> list[RestoreResult]:
+ """
+ Restore multiple workflow runs by time range.
+
+ Args:
+ tenant_ids: Optional tenant IDs
+ start_date: Start date filter
+ end_date: End date filter
+ limit: Maximum number of runs to restore (default: 100)
+
+ Returns:
+ List of RestoreResult objects
+ """
+ results: list[RestoreResult] = []
+ if tenant_ids is not None and not tenant_ids:
+ return results
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = self._get_workflow_run_repo()
+
+ with session_maker() as session:
+ archive_logs = repo.get_archived_logs_by_time_range(
+ session=session,
+ tenant_ids=tenant_ids,
+ start_date=start_date,
+ end_date=end_date,
+ limit=limit,
+ )
+
+ click.echo(
+ click.style(
+ f"Found {len(archive_logs)} archived workflow runs to restore",
+ fg="white",
+ )
+ )
+
+ def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult:
+ return self._restore_from_run(
+ archive_log,
+ session_maker=session_maker,
+ )
+
+ with ThreadPoolExecutor(max_workers=self.workers) as executor:
+ results = list(executor.map(_restore_with_session, archive_logs))
+
+ total_counts: dict[str, int] = {}
+ for result in results:
+ for table_name, count in result.restored_counts.items():
+ total_counts[table_name] = total_counts.get(table_name, 0) + count
+ success_count = sum(1 for result in results if result.success)
+
+ if self.dry_run:
+ click.echo(
+ click.style(
+ f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}",
+ fg="yellow",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}",
+ fg="green",
+ )
+ )
+
+ return results
+
+ def restore_by_run_id(
+ self,
+ run_id: str,
+ ) -> RestoreResult:
+ """
+ Restore a single workflow run by run ID.
+ """
+ repo = self._get_workflow_run_repo()
+ archive_log = repo.get_archived_log_by_run_id(run_id)
+
+ if not archive_log:
+ click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red"))
+ return RestoreResult(
+ run_id=run_id,
+ tenant_id="",
+ success=False,
+ restored_counts={},
+ error=f"Workflow run archive {run_id} not found",
+ )
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ result = self._restore_from_run(archive_log, session_maker=session_maker)
+ if self.dry_run and result.success:
+ click.echo(
+ click.style(
+ f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}",
+ fg="yellow",
+ )
+ )
+ return result
diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py
index 9bd797a45..5ca0b6300 100644
--- a/api/services/webapp_auth_service.py
+++ b/api/services/webapp_auth_service.py
@@ -12,6 +12,7 @@ from libs.passport import PassportService
from libs.password import compare_password
from models import Account, AccountStatus
from models.model import App, EndUser, Site
+from services.account_service import AccountService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
@@ -32,7 +33,7 @@ class WebAppAuthService:
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
- account = db.session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
raise AccountNotFoundError()
@@ -52,7 +53,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
- account = db.session.query(Account).where(Account.email == email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
return None
diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py
index 8574d3025..efc76c33b 100644
--- a/api/services/workflow_app_service.py
+++ b/api/services/workflow_app_service.py
@@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowExecutionStatus
-from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
+from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
@@ -173,7 +173,80 @@ class WorkflowAppService:
"data": items,
}
- def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]:
+ def get_paginate_workflow_archive_logs(
+ self,
+ *,
+ session: Session,
+ app_model: App,
+ page: int = 1,
+ limit: int = 20,
+ ):
+ """
+ Get paginate workflow archive logs using SQLAlchemy 2.0 style.
+ """
+ stmt = select(WorkflowArchiveLog).where(
+ WorkflowArchiveLog.tenant_id == app_model.tenant_id,
+ WorkflowArchiveLog.app_id == app_model.id,
+ WorkflowArchiveLog.log_id.isnot(None),
+ )
+
+ stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc())
+
+ count_stmt = select(func.count()).select_from(stmt.subquery())
+ total = session.scalar(count_stmt) or 0
+
+ offset_stmt = stmt.offset((page - 1) * limit).limit(limit)
+
+ logs = list(session.scalars(offset_stmt).all())
+ account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT}
+ end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER}
+
+ accounts_by_id = {}
+ if account_ids:
+ accounts_by_id = {
+ account.id: account
+ for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all()
+ }
+
+ end_users_by_id = {}
+ if end_user_ids:
+ end_users_by_id = {
+ end_user.id: end_user
+ for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all()
+ }
+
+ items = []
+ for log in logs:
+ if log.created_by_role == CreatorUserRole.ACCOUNT:
+ created_by_account = accounts_by_id.get(log.created_by)
+ created_by_end_user = None
+ elif log.created_by_role == CreatorUserRole.END_USER:
+ created_by_account = None
+ created_by_end_user = end_users_by_id.get(log.created_by)
+ else:
+ created_by_account = None
+ created_by_end_user = None
+
+ items.append(
+ {
+ "id": log.id,
+ "workflow_run": log.workflow_run_summary,
+ "trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata),
+ "created_by_account": created_by_account,
+ "created_by_end_user": created_by_end_user,
+ "created_at": log.log_created_at,
+ }
+ )
+
+ return {
+ "page": page,
+ "limit": limit,
+ "total": total,
+ "has_more": total > page * limit,
+ "data": items,
+ }
+
+ def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:
return {}
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 9407a2b3f..70b019023 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
-from core.variables import Segment, StringSegment, Variable
+from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
- _fallback_variables: Sequence[Variable]
+ _fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
- fallback_variables: Sequence[Variable] | None = None,
+ fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
- # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
- variable_by_selector: dict[tuple[str, str], Variable] = {}
+ # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
+ variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
- def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
+ def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index b45a167b7..d8c315917 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
-from core.variables import Variable
-from core.variables.variables import VariableUnion
+from core.variables import VariableBase
+from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -198,8 +198,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
- conversation_variables: list[Variable],
+ conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- conversation_variables=cast(list[VariableUnion], conversation_variables), #
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index e7dead8a5..62e6497e9 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
@@ -28,106 +28,106 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
start_at = time.perf_counter()
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
+ if not dataset_document:
+ logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
+ return
- if dataset_document.indexing_status != "completed":
- db.session.close()
- return
+ if dataset_document.indexing_status != "completed":
+ return
- indexing_cache_key = f"document_{dataset_document.id}_indexing"
+ indexing_cache_key = f"document_{dataset_document.id}_indexing"
- try:
- dataset = dataset_document.dataset
- if not dataset:
- raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
+ try:
+ dataset = dataset_document.dataset
+ if not dataset:
+ raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
- segments = (
- db.session.query(DocumentSegment)
- .where(
- DocumentSegment.document_id == dataset_document.id,
- DocumentSegment.status == "completed",
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.status == "completed",
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
+
+ # delete auto disable log
+ session.query(DatasetAutoDisableLog).where(
+ DatasetAutoDisableLog.document_id == dataset_document.id
+ ).delete()
+
+ # update segment to enable
+ session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
+ {
+ DocumentSegment.enabled: True,
+ DocumentSegment.disabled_at: None,
+ DocumentSegment.disabled_by: None,
+ DocumentSegment.updated_at: naive_utc_now(),
+ }
)
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
+ session.commit()
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
-
- # delete auto disable log
- db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
-
- # update segment to enable
- db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
- {
- DocumentSegment.enabled: True,
- DocumentSegment.disabled_at: None,
- DocumentSegment.disabled_by: None,
- DocumentSegment.updated_at: naive_utc_now(),
- }
- )
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
- )
- except Exception as e:
- logger.exception("add document to index failed")
- dataset_document.enabled = False
- dataset_document.disabled_at = naive_utc_now()
- dataset_document.indexing_status = "error"
- dataset_document.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
+ )
+ except Exception as e:
+ logger.exception("add document to index failed")
+ dataset_document.enabled = False
+ dataset_document.disabled_at = naive_utc_now()
+ dataset_document.indexing_status = "error"
+ dataset_document.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index 775814318..fc6bf0345 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -5,9 +5,9 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -32,74 +32,72 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
active_jobs_key = f"annotation_import_active:{tenant_id}"
- # get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ with session_factory.create_session() as session:
+ # get app info
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- if app:
- try:
- documents = []
- for content in content_list:
- annotation = MessageAnnotation(
- app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
+ if app:
+ try:
+ documents = []
+ for content in content_list:
+ annotation = MessageAnnotation(
+ app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
+ )
+ session.add(annotation)
+ session.flush()
+
+ document = Document(
+ page_content=content["question"],
+ metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ )
+ documents.append(document)
+ # if annotation reply is enabled , batch add annotations' index
+ app_annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
- db.session.add(annotation)
- db.session.flush()
- document = Document(
- page_content=content["question"],
- metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
- )
- documents.append(document)
- # if annotation reply is enabled , batch add annotations' index
- app_annotation_setting = (
- db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
- )
+ if app_annotation_setting:
+ dataset_collection_binding = (
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ app_annotation_setting.collection_binding_id, "annotation"
+ )
+ )
+ if not dataset_collection_binding:
+ raise NotFound("App annotation setting not found")
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=dataset_collection_binding.provider_name,
+ embedding_model=dataset_collection_binding.model_name,
+ collection_binding_id=dataset_collection_binding.id,
+ )
- if app_annotation_setting:
- dataset_collection_binding = (
- DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
- app_annotation_setting.collection_binding_id, "annotation"
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ vector.create(documents, duplicate_check=True)
+
+ session.commit()
+ redis_client.setex(indexing_cache_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Build index successful for batch import annotation: {} latency: {}".format(
+ job_id, end_at - start_at
+ ),
+ fg="green",
)
)
- if not dataset_collection_binding:
- raise NotFound("App annotation setting not found")
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=dataset_collection_binding.provider_name,
- embedding_model=dataset_collection_binding.model_name,
- collection_binding_id=dataset_collection_binding.id,
- )
-
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- vector.create(documents, duplicate_check=True)
-
- db.session.commit()
- redis_client.setex(indexing_cache_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Build index successful for batch import annotation: {} latency: {}".format(
- job_id, end_at - start_at
- ),
- fg="green",
- )
- )
- except Exception as e:
- db.session.rollback()
- redis_client.setex(indexing_cache_key, 600, "error")
- indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
- redis_client.setex(indexing_error_msg_key, 600, str(e))
- logger.exception("Build index for batch import annotations failed")
- finally:
- # Clean up active job tracking to release concurrency slot
- try:
- redis_client.zrem(active_jobs_key, job_id)
- logger.debug("Released concurrency slot for job: %s", job_id)
- except Exception as cleanup_error:
- # Log but don't fail if cleanup fails - the job will be auto-expired
- logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
-
- # Close database session
- db.session.close()
+ except Exception as e:
+ session.rollback()
+ redis_client.setex(indexing_cache_key, 600, "error")
+ indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
+ redis_client.setex(indexing_error_msg_key, 600, str(e))
+ logger.exception("Build index for batch import annotations failed")
+ finally:
+ # Clean up active job tracking to release concurrency slot
+ try:
+ redis_client.zrem(active_jobs_key, job_id)
+ logger.debug("Released concurrency slot for job: %s", job_id)
+ except Exception as cleanup_error:
+ # Log but don't fail if cleanup fails - the job will be auto-expired
+ logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py
index c0020b29e..7b5cd46b0 100644
--- a/api/tasks/annotation/disable_annotation_reply_task.py
+++ b/api/tasks/annotation/disable_annotation_reply_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import exists, select
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
- if not app:
- logger.info(click.style(f"App not found: {app_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
+ if not app:
+ logger.info(click.style(f"App not found: {app_id}", fg="red"))
+ return
- app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
-
- if not app_annotation_setting:
- logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
- db.session.close()
- return
-
- disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
- disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
-
- try:
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- collection_binding_id=app_annotation_setting.collection_binding_id,
+ app_annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
+ if not app_annotation_setting:
+ logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
+ return
+
+ disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
+ disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
+
try:
- if annotations_exists:
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- vector.delete()
- except Exception:
- logger.exception("Delete annotation index failed when annotation deleted.")
- redis_client.setex(disable_app_annotation_job_key, 600, "completed")
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ collection_binding_id=app_annotation_setting.collection_binding_id,
+ )
- # delete annotation setting
- db.session.delete(app_annotation_setting)
- db.session.commit()
+ try:
+ if annotations_exists:
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ vector.delete()
+ except Exception:
+ logger.exception("Delete annotation index failed when annotation deleted.")
+ redis_client.setex(disable_app_annotation_job_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("Annotation batch deleted index failed")
- redis_client.setex(disable_app_annotation_job_key, 600, "error")
- disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
- redis_client.setex(disable_app_annotation_error_key, 600, str(e))
- finally:
- redis_client.delete(disable_app_annotation_key)
- db.session.close()
+ # delete annotation setting
+ session.delete(app_annotation_setting)
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"App annotations index deleted : {app_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ logger.exception("Annotation batch deleted index failed")
+ redis_client.setex(disable_app_annotation_job_key, 600, "error")
+ disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
+ redis_client.setex(disable_app_annotation_error_key, 600, str(e))
+ finally:
+ redis_client.delete(disable_app_annotation_key)
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index be1de3cdd..4f8e2fec7 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -5,9 +5,9 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
@@ -33,92 +33,98 @@ def enable_annotation_reply_task(
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ with session_factory.create_session() as session:
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- if not app:
- logger.info(click.style(f"App not found: {app_id}", fg="red"))
- db.session.close()
- return
+ if not app:
+ logger.info(click.style(f"App not found: {app_id}", fg="red"))
+ return
- annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
- enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
- enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
+ annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
+ enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
+ enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
- try:
- documents = []
- dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
- embedding_provider_name, embedding_model_name, "annotation"
- )
- annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
- if annotation_setting:
- if dataset_collection_binding.id != annotation_setting.collection_binding_id:
- old_dataset_collection_binding = (
- DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
- annotation_setting.collection_binding_id, "annotation"
- )
- )
- if old_dataset_collection_binding and annotations:
- old_dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=old_dataset_collection_binding.provider_name,
- embedding_model=old_dataset_collection_binding.model_name,
- collection_binding_id=old_dataset_collection_binding.id,
- )
-
- old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
- try:
- old_vector.delete()
- except Exception as e:
- logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
- annotation_setting.score_threshold = score_threshold
- annotation_setting.collection_binding_id = dataset_collection_binding.id
- annotation_setting.updated_user_id = user_id
- annotation_setting.updated_at = naive_utc_now()
- db.session.add(annotation_setting)
- else:
- new_app_annotation_setting = AppAnnotationSetting(
- app_id=app_id,
- score_threshold=score_threshold,
- collection_binding_id=dataset_collection_binding.id,
- created_user_id=user_id,
- updated_user_id=user_id,
+ try:
+ documents = []
+ dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+ embedding_provider_name, embedding_model_name, "annotation"
)
- db.session.add(new_app_annotation_setting)
+ annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+ )
+ if annotation_setting:
+ if dataset_collection_binding.id != annotation_setting.collection_binding_id:
+ old_dataset_collection_binding = (
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ annotation_setting.collection_binding_id, "annotation"
+ )
+ )
+ if old_dataset_collection_binding and annotations:
+ old_dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=old_dataset_collection_binding.provider_name,
+ embedding_model=old_dataset_collection_binding.model_name,
+ collection_binding_id=old_dataset_collection_binding.id,
+ )
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=embedding_provider_name,
- embedding_model=embedding_model_name,
- collection_binding_id=dataset_collection_binding.id,
- )
- if annotations:
- for annotation in annotations:
- document = Document(
- page_content=annotation.question_text,
- metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ try:
+ old_vector.delete()
+ except Exception as e:
+ logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
+ annotation_setting.score_threshold = score_threshold
+ annotation_setting.collection_binding_id = dataset_collection_binding.id
+ annotation_setting.updated_user_id = user_id
+ annotation_setting.updated_at = naive_utc_now()
+ session.add(annotation_setting)
+ else:
+ new_app_annotation_setting = AppAnnotationSetting(
+ app_id=app_id,
+ score_threshold=score_threshold,
+ collection_binding_id=dataset_collection_binding.id,
+ created_user_id=user_id,
+ updated_user_id=user_id,
)
- documents.append(document)
+ session.add(new_app_annotation_setting)
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- try:
- vector.delete_by_metadata_field("app_id", app_id)
- except Exception as e:
- logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
- vector.create(documents)
- db.session.commit()
- redis_client.setex(enable_app_annotation_job_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("Annotation batch created index failed")
- redis_client.setex(enable_app_annotation_job_key, 600, "error")
- enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
- redis_client.setex(enable_app_annotation_error_key, 600, str(e))
- db.session.rollback()
- finally:
- redis_client.delete(enable_app_annotation_key)
- db.session.close()
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=embedding_provider_name,
+ embedding_model=embedding_model_name,
+ collection_binding_id=dataset_collection_binding.id,
+ )
+ if annotations:
+ for annotation in annotations:
+ document = Document(
+ page_content=annotation.question_text,
+ metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ )
+ documents.append(document)
+
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ try:
+ vector.delete_by_metadata_field("app_id", app_id)
+ except Exception as e:
+ logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
+ vector.create(documents)
+ session.commit()
+ redis_client.setex(enable_app_annotation_job_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"App annotations added to index: {app_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ logger.exception("Annotation batch created index failed")
+ redis_client.setex(enable_app_annotation_job_key, 600, "error")
+ enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
+ redis_client.setex(enable_app_annotation_error_key, 600, str(e))
+ session.rollback()
+ finally:
+ redis_client.delete(enable_app_annotation_key)
diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py
index f8aac5b46..b51884148 100644
--- a/api/tasks/async_workflow_tasks.py
+++ b/api/tasks/async_workflow_tasks.py
@@ -10,13 +10,13 @@ from typing import Any
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.trigger_post_layer import TriggerPostLayer
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
@@ -98,10 +98,7 @@ def _execute_workflow_common(
):
"""Execute workflow with common logic and trigger log updates."""
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
# Get trigger log
@@ -157,7 +154,7 @@ def _execute_workflow_common(
root_node_id=trigger_data.root_node_id,
graph_engine_layers=[
# TODO: Re-enable TimeSliceLayer after the HITL release.
- TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
+ TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
],
)
diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py
index 3e1bd16cc..74b939e84 100644
--- a/api/tasks/batch_clean_document_task.py
+++ b/api/tasks/batch_clean_document_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
@@ -28,65 +28,64 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
"""
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
+ if not doc_form:
+ raise ValueError("doc_form is required")
- try:
- if not doc_form:
- raise ValueError("doc_form is required")
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- db.session.query(DatasetMetadataBinding).where(
- DatasetMetadataBinding.dataset_id == dataset_id,
- DatasetMetadataBinding.document_id.in_(document_ids),
- ).delete(synchronize_session=False)
+ session.query(DatasetMetadataBinding).where(
+ DatasetMetadataBinding.dataset_id == dataset_id,
+ DatasetMetadataBinding.document_id.in_(document_ids),
+ ).delete(synchronize_session=False)
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
- ).all()
- # check segment is exist
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+ ).all()
+ # check segment is exist
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
+ for image_file in image_files:
+ try:
+ if image_file and image_file.key:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+ image_upload_file_is: %s",
+ image_file.id,
+ )
+ stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(stmt)
+ session.delete(segment)
+ if file_ids:
+ files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
+ for file in files:
try:
- if image_file and image_file.key:
- storage.delete(image_file.key)
+ storage.delete(file.key)
except Exception:
- logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
- )
- db.session.delete(image_file)
- db.session.delete(segment)
+ logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
+ stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+ session.execute(stmt)
- db.session.commit()
- if file_ids:
- files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
- for file in files:
- try:
- storage.delete(file.key)
- except Exception:
- logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
- db.session.delete(file)
+ session.commit()
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Cleaned documents when documents deleted latency: {end_at - start_at}",
- fg="green",
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned documents when documents deleted latency: {end_at - start_at}",
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned documents when documents deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned documents when documents deleted failed")
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index bd95af261..8ee09d573 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -9,9 +9,9 @@ import pandas as pd
from celery import shared_task
from sqlalchemy import func
+from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
@@ -48,104 +48,107 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
- try:
- dataset = db.session.get(Dataset, dataset_id)
- if not dataset:
- raise ValueError("Dataset not exist.")
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.get(Dataset, dataset_id)
+ if not dataset:
+ raise ValueError("Dataset not exist.")
- dataset_document = db.session.get(Document, document_id)
- if not dataset_document:
- raise ValueError("Document not exist.")
+ dataset_document = session.get(Document, document_id)
+ if not dataset_document:
+ raise ValueError("Document not exist.")
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- raise ValueError("Document is not available.")
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ raise ValueError("Document is not available.")
- upload_file = db.session.get(UploadFile, upload_file_id)
- if not upload_file:
- raise ValueError("UploadFile not found.")
+ upload_file = session.get(UploadFile, upload_file_id)
+ if not upload_file:
+ raise ValueError("UploadFile not found.")
- with tempfile.TemporaryDirectory() as temp_dir:
- suffix = Path(upload_file.key).suffix
- file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
- storage.download(upload_file.key, file_path)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ suffix = Path(upload_file.key).suffix
+ file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
+ storage.download(upload_file.key, file_path)
- df = pd.read_csv(file_path)
- content = []
- for _, row in df.iterrows():
+ df = pd.read_csv(file_path)
+ content = []
+ for _, row in df.iterrows():
+ if dataset_document.doc_form == "qa_model":
+ data = {"content": row.iloc[0], "answer": row.iloc[1]}
+ else:
+ data = {"content": row.iloc[0]}
+ content.append(data)
+ if len(content) == 0:
+ raise ValueError("The CSV file is empty.")
+
+ document_segments = []
+ embedding_model = None
+ if dataset.indexing_technique == "high_quality":
+ model_manager = ModelManager()
+ embedding_model = model_manager.get_model_instance(
+ tenant_id=dataset.tenant_id,
+ provider=dataset.embedding_model_provider,
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=dataset.embedding_model,
+ )
+
+ word_count_change = 0
+ if embedding_model:
+ tokens_list = embedding_model.get_text_embedding_num_tokens(
+ texts=[segment["content"] for segment in content]
+ )
+ else:
+ tokens_list = [0] * len(content)
+
+ for segment, tokens in zip(content, tokens_list):
+ content = segment["content"]
+ doc_id = str(uuid.uuid4())
+ segment_hash = helper.generate_text_hash(content)
+ max_position = (
+ session.query(func.max(DocumentSegment.position))
+ .where(DocumentSegment.document_id == dataset_document.id)
+ .scalar()
+ )
+ segment_document = DocumentSegment(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ document_id=document_id,
+ index_node_id=doc_id,
+ index_node_hash=segment_hash,
+ position=max_position + 1 if max_position else 1,
+ content=content,
+ word_count=len(content),
+ tokens=tokens,
+ created_by=user_id,
+ indexing_at=naive_utc_now(),
+ status="completed",
+ completed_at=naive_utc_now(),
+ )
if dataset_document.doc_form == "qa_model":
- data = {"content": row.iloc[0], "answer": row.iloc[1]}
- else:
- data = {"content": row.iloc[0]}
- content.append(data)
- if len(content) == 0:
- raise ValueError("The CSV file is empty.")
+ segment_document.answer = segment["answer"]
+ segment_document.word_count += len(segment["answer"])
+ word_count_change += segment_document.word_count
+ session.add(segment_document)
+ document_segments.append(segment_document)
- document_segments = []
- embedding_model = None
- if dataset.indexing_technique == "high_quality":
- model_manager = ModelManager()
- embedding_model = model_manager.get_model_instance(
- tenant_id=dataset.tenant_id,
- provider=dataset.embedding_model_provider,
- model_type=ModelType.TEXT_EMBEDDING,
- model=dataset.embedding_model,
- )
+ assert dataset_document.word_count is not None
+ dataset_document.word_count += word_count_change
+ session.add(dataset_document)
- word_count_change = 0
- if embedding_model:
- tokens_list = embedding_model.get_text_embedding_num_tokens(
- texts=[segment["content"] for segment in content]
+ VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
+ session.commit()
+ redis_client.setex(indexing_cache_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Segment batch created job: {job_id} latency: {end_at - start_at}",
+ fg="green",
+ )
)
- else:
- tokens_list = [0] * len(content)
-
- for segment, tokens in zip(content, tokens_list):
- content = segment["content"]
- doc_id = str(uuid.uuid4())
- segment_hash = helper.generate_text_hash(content)
- max_position = (
- db.session.query(func.max(DocumentSegment.position))
- .where(DocumentSegment.document_id == dataset_document.id)
- .scalar()
- )
- segment_document = DocumentSegment(
- tenant_id=tenant_id,
- dataset_id=dataset_id,
- document_id=document_id,
- index_node_id=doc_id,
- index_node_hash=segment_hash,
- position=max_position + 1 if max_position else 1,
- content=content,
- word_count=len(content),
- tokens=tokens,
- created_by=user_id,
- indexing_at=naive_utc_now(),
- status="completed",
- completed_at=naive_utc_now(),
- )
- if dataset_document.doc_form == "qa_model":
- segment_document.answer = segment["answer"]
- segment_document.word_count += len(segment["answer"])
- word_count_change += segment_document.word_count
- db.session.add(segment_document)
- document_segments.append(segment_document)
-
- assert dataset_document.word_count is not None
- dataset_document.word_count += word_count_change
- db.session.add(dataset_document)
-
- VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
- db.session.commit()
- redis_client.setex(indexing_cache_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Segment batch created job: {job_id} latency: {end_at - start_at}",
- fg="green",
- )
- )
- except Exception:
- logger.exception("Segments batch created index failed")
- redis_client.setex(indexing_cache_key, 600, "error")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Segments batch created index failed")
+ redis_client.setex(indexing_cache_key, 600, "error")
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index b4d82a150..0d51a743a 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models import WorkflowType
from models.dataset import (
@@ -53,135 +53,155 @@ def clean_dataset_task(
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = Dataset(
- id=dataset_id,
- tenant_id=tenant_id,
- indexing_technique=indexing_technique,
- index_struct=index_struct,
- collection_binding_id=collection_binding_id,
- )
- documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
- # Use JOIN to fetch attachments with bindings in a single query
- attachments_with_bindings = db.session.execute(
- select(SegmentAttachmentBinding, UploadFile)
- .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
- .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
- ).all()
-
- # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
- # This ensures all invalid doc_form values are properly handled
- if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
- # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
- from core.rag.index_processor.constant.index_type import IndexStructureType
-
- doc_form = IndexStructureType.PARAGRAPH_INDEX
- logger.info(
- click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
- )
-
- # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
- # This ensures Document/Segment deletion can continue even if vector database cleanup fails
+ with session_factory.create_session() as session:
try:
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
- logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
- except Exception:
- logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
- # Continue with document and segment deletion even if vector cleanup fails
- logger.info(
- click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
+ dataset = Dataset(
+ id=dataset_id,
+ tenant_id=tenant_id,
+ indexing_technique=indexing_technique,
+ index_struct=index_struct,
+ collection_binding_id=collection_binding_id,
)
+ documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
+ # Use JOIN to fetch attachments with bindings in a single query
+ attachments_with_bindings = session.execute(
+ select(SegmentAttachmentBinding, UploadFile)
+ .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+ .where(
+ SegmentAttachmentBinding.tenant_id == tenant_id,
+ SegmentAttachmentBinding.dataset_id == dataset_id,
+ )
+ ).all()
- if documents is None or len(documents) == 0:
- logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
- else:
- logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
+ # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
+ # This ensures all invalid doc_form values are properly handled
+ if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
+ # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
+ from core.rag.index_processor.constant.index_type import IndexStructureType
- for document in documents:
- db.session.delete(document)
- # delete document file
+ doc_form = IndexStructureType.PARAGRAPH_INDEX
+ logger.info(
+ click.style(
+ f"Invalid doc_form detected, using default index type for cleanup: {doc_form}",
+ fg="yellow",
+ )
+ )
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
- if image_file is None:
- continue
+ # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
+ # This ensures Document/Segment deletion can continue even if vector database cleanup fails
+ try:
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
+ logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
+ except Exception:
+ logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
+ # Continue with document and segment deletion even if vector cleanup fails
+ logger.info(
+ click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
+ )
+
+ if documents is None or len(documents) == 0:
+ logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
+ else:
+ logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
+
+ for document in documents:
+ session.delete(document)
+
+ segment_ids = [segment.id for segment in segments]
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
+ for image_file in image_files:
+ if image_file is None:
+ continue
+ try:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+ image_upload_file_is: %s",
+ image_file.id,
+ )
+ stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(stmt)
+
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ # delete segment attachments
+ if attachments_with_bindings:
+ attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+ binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+ for binding, attachment_file in attachments_with_bindings:
try:
- storage.delete(image_file.key)
+ storage.delete(attachment_file.key)
except Exception:
logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
+ "Delete attachment_file failed when storage deleted, \
+ attachment_file_id: %s",
+ binding.attachment_id,
)
- db.session.delete(image_file)
- db.session.delete(segment)
- # delete segment attachments
- if attachments_with_bindings:
- for binding, attachment_file in attachments_with_bindings:
- try:
- storage.delete(attachment_file.key)
- except Exception:
- logger.exception(
- "Delete attachment_file failed when storage deleted, \
- attachment_file_id: %s",
- binding.attachment_id,
- )
- db.session.delete(attachment_file)
- db.session.delete(binding)
+ attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+ session.execute(attachment_file_delete_stmt)
- db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
- db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
- db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
- # delete dataset metadata
- db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
- db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
- # delete pipeline and workflow
- if pipeline_id:
- db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
- db.session.query(Workflow).where(
- Workflow.tenant_id == tenant_id,
- Workflow.app_id == pipeline_id,
- Workflow.type == WorkflowType.RAG_PIPELINE,
- ).delete()
- # delete files
- if documents:
- for document in documents:
- try:
+ binding_delete_stmt = delete(SegmentAttachmentBinding).where(
+ SegmentAttachmentBinding.id.in_(binding_ids)
+ )
+ session.execute(binding_delete_stmt)
+
+ session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
+ session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
+ session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
+ # delete dataset metadata
+ session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
+ session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
+ # delete pipeline and workflow
+ if pipeline_id:
+ session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
+ session.query(Workflow).where(
+ Workflow.tenant_id == tenant_id,
+ Workflow.app_id == pipeline_id,
+ Workflow.type == WorkflowType.RAG_PIPELINE,
+ ).delete()
+ # delete files
+ if documents:
+ file_ids = []
+ for document in documents:
if document.data_source_type == "upload_file":
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
- file = (
- db.session.query(UploadFile)
- .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
- .first()
- )
- if not file:
- continue
- storage.delete(file.key)
- db.session.delete(file)
- except Exception:
- continue
+ file_ids.append(file_id)
+ files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
+ for file in files:
+ storage.delete(file.key)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
- )
- except Exception:
- # Add rollback to prevent dirty session state in case of exceptions
- # This ensures the database session is properly cleaned up
- try:
- db.session.rollback()
- logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+ file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+ session.execute(file_delete_stmt)
+
+ session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
except Exception:
- logger.exception("Failed to rollback database session")
+ # Add rollback to prevent dirty session state in case of exceptions
+ # This ensures the database session is properly cleaned up
+ try:
+ session.rollback()
+ logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+ except Exception:
+ logger.exception("Failed to rollback database session")
- logger.exception("Cleaned dataset when dataset deleted failed")
- finally:
- db.session.close()
+ logger.exception("Cleaned dataset when dataset deleted failed")
+ finally:
+ # Explicitly close the session for test expectations and safety
+ try:
+ session.close()
+ except Exception:
+ logger.exception("Failed to close database session")
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 6d2feb1da..86e7cc716 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
@@ -29,85 +29,94 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- # Use JOIN to fetch attachments with bindings in a single query
- attachments_with_bindings = db.session.execute(
- select(SegmentAttachmentBinding, UploadFile)
- .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
- .where(
- SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
- SegmentAttachmentBinding.dataset_id == dataset_id,
- SegmentAttachmentBinding.document_id == document_id,
- )
- ).all()
- # check segment is exist
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ # Use JOIN to fetch attachments with bindings in a single query
+ attachments_with_bindings = session.execute(
+ select(SegmentAttachmentBinding, UploadFile)
+ .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+ .where(
+ SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
+ SegmentAttachmentBinding.dataset_id == dataset_id,
+ SegmentAttachmentBinding.document_id == document_id,
+ )
+ ).all()
+ # check segment is exist
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
- if image_file is None:
- continue
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.scalars(
+ select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ ).all()
+ for image_file in image_files:
+ if image_file is None:
+ continue
+ try:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+ image_upload_file_is: %s",
+ image_file.id,
+ )
+
+ image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(image_file_delete_stmt)
+ session.delete(segment)
+
+ session.commit()
+ if file_id:
+ file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+ if file:
try:
- storage.delete(image_file.key)
+ storage.delete(file.key)
+ except Exception:
+ logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
+ session.delete(file)
+ # delete segment attachments
+ if attachments_with_bindings:
+ attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+ binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+ for binding, attachment_file in attachments_with_bindings:
+ try:
+ storage.delete(attachment_file.key)
except Exception:
logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
+ "Delete attachment_file failed when storage deleted, \
+ attachment_file_id: %s",
+ binding.attachment_id,
)
- db.session.delete(image_file)
- db.session.delete(segment)
+ attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+ session.execute(attachment_file_delete_stmt)
- db.session.commit()
- if file_id:
- file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
- if file:
- try:
- storage.delete(file.key)
- except Exception:
- logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
- db.session.delete(file)
- db.session.commit()
- # delete segment attachments
- if attachments_with_bindings:
- for binding, attachment_file in attachments_with_bindings:
- try:
- storage.delete(attachment_file.key)
- except Exception:
- logger.exception(
- "Delete attachment_file failed when storage deleted, \
- attachment_file_id: %s",
- binding.attachment_id,
- )
- db.session.delete(attachment_file)
- db.session.delete(binding)
+ binding_delete_stmt = delete(SegmentAttachmentBinding).where(
+ SegmentAttachmentBinding.id.in_(binding_ids)
+ )
+ session.execute(binding_delete_stmt)
- # delete dataset metadata binding
- db.session.query(DatasetMetadataBinding).where(
- DatasetMetadataBinding.dataset_id == dataset_id,
- DatasetMetadataBinding.document_id == document_id,
- ).delete()
- db.session.commit()
+ # delete dataset metadata binding
+ session.query(DatasetMetadataBinding).where(
+ DatasetMetadataBinding.dataset_id == dataset_id,
+ DatasetMetadataBinding.document_id == document_id,
+ ).delete()
+ session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
- fg="green",
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when document deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned document when document deleted failed")
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index 771b43f9b..bcca1bf49 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -3,10 +3,10 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@@ -24,37 +24,37 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- for document_id in document_ids:
- document = db.session.query(Document).where(Document.id == document_id).first()
- db.session.delete(document)
+ if not dataset:
+ raise Exception("Document has no dataset")
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- index_node_ids = [segment.index_node_id for segment in segments]
+ document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
+ session.execute(document_delete_stmt)
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ for document_id in document_ids:
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Clean document when import form notion document deleted end :: {} latency: {}".format(
- dataset_id, end_at - start_at
- ),
- fg="green",
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Clean document when import form notion document deleted end :: {} latency: {}".format(
+ dataset_id, end_at - start_at
+ ),
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when import form notion document deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned document when import form notion document deleted failed")
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 6b2907cff..b5e472d71 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -4,9 +4,9 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "waiting":
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- # update segment status to indexing
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(
- {
- DocumentSegment.status: "indexing",
- DocumentSegment.indexing_at: naive_utc_now(),
- }
- )
- db.session.commit()
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
-
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "waiting":
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.load(dataset, [document])
+ try:
+ # update segment status to indexing
+ session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "indexing",
+ DocumentSegment.indexing_at: naive_utc_now(),
+ }
+ )
+ session.commit()
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- # update segment to completed
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(
- {
- DocumentSegment.status: "completed",
- DocumentSegment.completed_at: naive_utc_now(),
- }
- )
- db.session.commit()
+ dataset = segment.dataset
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("create segment to index failed")
- segment.enabled = False
- segment.disabled_at = naive_utc_now()
- segment.status = "error"
- segment.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.load(dataset, [document])
+
+ # update segment to completed
+ session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "completed",
+ DocumentSegment.completed_at: naive_utc_now(),
+ }
+ )
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("create segment to index failed")
+ segment.enabled = False
+ segment.disabled_at = naive_utc_now()
+ segment.status = "error"
+ segment.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py
index 3d13afdec..fa844a864 100644
--- a/api/tasks/deal_dataset_index_update_task.py
+++ b/api/tasks/deal_dataset_index_update_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task # type: ignore
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).filter_by(id=dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- if action == "upgrade":
- dataset_documents = (
- db.session.query(DatasetDocument)
- .where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ if action == "upgrade":
+ dataset_documents = (
+ session.query(DatasetDocument)
+ .where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ .all()
)
- .all()
- )
- if dataset_documents:
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
+ if dataset_documents:
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
- for dataset_document in dataset_documents:
- try:
- # add from vector index
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ for dataset_document in dataset_documents:
+ try:
+ # add from vector index
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
)
-
- documents.append(document)
- # save vector index
- # clean keywords
- index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
- index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- elif action == "update":
- dataset_documents = (
- db.session.query(DatasetDocument)
- .where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- .all()
- )
- # add new index
- if dataset_documents:
- # update document status
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
-
- # clean index
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
- for dataset_document in dataset_documents:
- # update from vector index
- try:
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(
- dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- else:
- # clean collection
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+ if segments:
+ documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- end_at = time.perf_counter()
- logging.info(
- click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
- )
- except Exception:
- logging.exception("Deal dataset vector index failed")
- finally:
- db.session.close()
+ documents.append(document)
+ # save vector index
+ # clean keywords
+ index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
+ index_processor.load(dataset, documents, with_keywords=False)
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ elif action == "update":
+ dataset_documents = (
+ session.query(DatasetDocument)
+ .where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ .all()
+ )
+ # add new index
+ if dataset_documents:
+ # update document status
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
+
+ # clean index
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ for dataset_document in dataset_documents:
+ # update from vector index
+ try:
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
+ )
+ if segments:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+ # save vector index
+ index_processor.load(
+ dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ )
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ else:
+ # clean collection
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style(
+ "Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at),
+ fg="green",
+ )
+ )
+ except Exception:
+ logging.exception("Deal dataset vector index failed")
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index 1c7de3b1c..0047e04a1 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).filter_by(id=dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- if action == "remove":
- index_processor.clean(dataset, None, with_keywords=False)
- elif action == "add":
- dataset_documents = db.session.scalars(
- select(DatasetDocument).where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- ).all()
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ if action == "remove":
+ index_processor.clean(dataset, None, with_keywords=False)
+ elif action == "add":
+ dataset_documents = session.scalars(
+ select(DatasetDocument).where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ ).all()
- if dataset_documents:
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
+ if dataset_documents:
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
- for dataset_document in dataset_documents:
- try:
- # add from vector index
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ for dataset_document in dataset_documents:
+ try:
+ # add from vector index
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
)
-
- documents.append(document)
- # save vector index
- index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- elif action == "update":
- dataset_documents = db.session.scalars(
- select(DatasetDocument).where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- ).all()
- # add new index
- if dataset_documents:
- # update document status
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
-
- # clean index
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
- for dataset_document in dataset_documents:
- # update from vector index
- try:
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(
- dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- else:
- # clean collection
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+ if segments:
+ documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- end_at = time.perf_counter()
- logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("Deal dataset vector index failed")
- finally:
- db.session.close()
+ documents.append(document)
+ # save vector index
+ index_processor.load(dataset, documents, with_keywords=False)
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ elif action == "update":
+ dataset_documents = session.scalars(
+ select(DatasetDocument).where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ ).all()
+ # add new index
+ if dataset_documents:
+ # update document status
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
+
+ # clean index
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ for dataset_document in dataset_documents:
+ # update from vector index
+ try:
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
+ )
+ if segments:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+ # save vector index
+ index_processor.load(
+ dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ )
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ else:
+ # clean collection
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("Deal dataset vector index failed")
diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py
index cb703cc26..ecf6f9cb3 100644
--- a/api/tasks/delete_account_task.py
+++ b/api/tasks/delete_account_task.py
@@ -3,7 +3,7 @@ import logging
from celery import shared_task
from configs import dify_config
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task
@@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
- account = db.session.query(Account).where(Account.id == account_id).first()
- try:
- if dify_config.BILLING_ENABLED:
- BillingService.delete_account(account_id)
- except Exception:
- logger.exception("Failed to delete account %s from billing service.", account_id)
- raise
+ with session_factory.create_session() as session:
+ account = session.query(Account).where(Account.id == account_id).first()
+ try:
+ if dify_config.BILLING_ENABLED:
+ BillingService.delete_account(account_id)
+ except Exception:
+ logger.exception("Failed to delete account %s from billing service.", account_id)
+ raise
- if not account:
- logger.error("Account %s not found.", account_id)
- return
- # send success email
- send_deletion_success_task.delay(account.email)
+ if not account:
+ logger.error("Account %s not found.", account_id)
+ return
+ # send success email
+ send_deletion_success_task.delay(account.email)
diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py
index 756b67c93..9664b8ac7 100644
--- a/api/tasks/delete_conversation_task.py
+++ b/api/tasks/delete_conversation_task.py
@@ -4,7 +4,7 @@ import time
import click
from celery import shared_task
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
@@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str):
)
start_at = time.perf_counter()
- try:
- db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(ToolConversationVariables).where(
- ToolConversationVariables.conversation_id == conversation_id
- ).delete(synchronize_session=False)
-
- db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
-
- db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
-
- db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
- fg="green",
+ with session_factory.create_session() as session:
+ try:
+ session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
+ synchronize_session=False
)
- )
- except Exception as e:
- logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
- db.session.rollback()
- raise e
- finally:
- db.session.close()
+ session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.query(ToolConversationVariables).where(
+ ToolConversationVariables.conversation_id == conversation_id
+ ).delete(synchronize_session=False)
+
+ session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ (
+ f"Succeeded cleaning data from db for conversation_id {conversation_id} "
+ f"latency: {end_at - start_at}"
+ ),
+ fg="green",
+ )
+ )
+
+ except Exception:
+ logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
+ session.rollback()
+ raise
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index fd8a4e7c9..9f2ee8abd 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
@@ -28,49 +28,52 @@ def delete_segment_from_index_task(
"""
logger.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
- return
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
+ return
- dataset_document = db.session.query(Document).where(Document.id == document_id).first()
- if not dataset_document:
- return
+ dataset_document = session.query(Document).where(Document.id == document_id).first()
+ if not dataset_document:
+ return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logging.info("Document not in valid state for index operations, skipping")
- return
- doc_form = dataset_document.doc_form
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logging.info("Document not in valid state for index operations, skipping")
+ return
+ doc_form = dataset_document.doc_form
- # Proceed with index cleanup using the index_node_ids directly
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(
- dataset,
- index_node_ids,
- with_keywords=True,
- delete_child_chunks=True,
- precomputed_child_node_ids=child_node_ids,
- )
- if dataset.is_multimodal:
- # delete segment attachment binding
- segment_attachment_bindings = (
- db.session.query(SegmentAttachmentBinding)
- .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
- .all()
+ # Proceed with index cleanup using the index_node_ids directly
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(
+ dataset,
+ index_node_ids,
+ with_keywords=True,
+ delete_child_chunks=True,
+ precomputed_child_node_ids=child_node_ids,
)
- if segment_attachment_bindings:
- attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
- index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
- for binding in segment_attachment_bindings:
- db.session.delete(binding)
- # delete upload file
- db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
- db.session.commit()
+ if dataset.is_multimodal:
+ # delete segment attachment binding
+ segment_attachment_bindings = (
+ session.query(SegmentAttachmentBinding)
+ .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+ .all()
+ )
+ if segment_attachment_bindings:
+ attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
+ index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
+ for binding in segment_attachment_bindings:
+ session.delete(binding)
+ # delete upload file
+ session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
+ session.commit()
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("delete segment from index failed")
- finally:
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
+ except Exception:
+ logger.exception("delete segment from index failed")
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index 6b5f01b41..0ce6429a9 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@@ -23,46 +23,53 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "completed":
- logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "completed":
+ logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_type = dataset_document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.clean(dataset, [segment.index_node_id])
+ try:
+ dataset = segment.dataset
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("remove segment from index failed")
- segment.enabled = True
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_type = dataset_document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.clean(dataset, [segment.index_node_id])
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Segment removed from index: {segment.id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("remove segment from index failed")
+ segment.enabled = True
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py
index c2a3de29f..03635902d 100644
--- a/api/tasks/disable_segments_from_index_task.py
+++ b/api/tasks/disable_segments_from_index_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
@@ -26,69 +26,65 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
+ return
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
- db.session.close()
- return
- # sync index processor
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if not dataset_document:
+ logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
+ return
+ if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+ logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
+ return
+ # sync index processor
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- )
- ).all()
-
- if not segments:
- db.session.close()
- return
-
- try:
- index_node_ids = [segment.index_node_id for segment in segments]
- if dataset.is_multimodal:
- segment_ids = [segment.id for segment in segments]
- segment_attachment_bindings = (
- db.session.query(SegmentAttachmentBinding)
- .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
- .all()
+ segments = session.scalars(
+ select(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
)
- if segment_attachment_bindings:
- attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
- index_node_ids.extend(attachment_ids)
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+ ).all()
- end_at = time.perf_counter()
- logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
- except Exception:
- # update segment error msg
- db.session.query(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- ).update(
- {
- "disabled_at": None,
- "disabled_by": None,
- "enabled": True,
- }
- )
- db.session.commit()
- finally:
- for segment in segments:
- indexing_cache_key = f"segment_{segment.id}_indexing"
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not segments:
+ return
+
+ try:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ if dataset.is_multimodal:
+ segment_ids = [segment.id for segment in segments]
+ segment_attachment_bindings = (
+ session.query(SegmentAttachmentBinding)
+ .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+ .all()
+ )
+ if segment_attachment_bindings:
+ attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
+ index_node_ids.extend(attachment_ids)
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
+ except Exception:
+ # update segment error msg
+ session.query(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
+ ).update(
+ {
+ "disabled_at": None,
+ "disabled_by": None,
+ "enabled": True,
+ }
+ )
+ session.commit()
+ finally:
+ for segment in segments:
+ indexing_cache_key = f"segment_{segment.id}_indexing"
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 5fc2597c9..149185f6e 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -3,12 +3,12 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.datasource_provider_service import DatasourceProviderService
@@ -28,105 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
-
- data_source_info = document.data_source_info_dict
- if document.data_source_type == "notion_import":
- if (
- not data_source_info
- or "notion_page_id" not in data_source_info
- or "notion_workspace_id" not in data_source_info
- ):
- raise ValueError("no notion page found")
- workspace_id = data_source_info["notion_workspace_id"]
- page_id = data_source_info["notion_page_id"]
- page_type = data_source_info["type"]
- page_edited_time = data_source_info["last_edited_time"]
- credential_id = data_source_info.get("credential_id")
-
- # Get credentials from datasource provider
- datasource_provider_service = DatasourceProviderService()
- credential = datasource_provider_service.get_datasource_credentials(
- tenant_id=document.tenant_id,
- credential_id=credential_id,
- provider="notion_datasource",
- plugin_id="langgenius/notion_datasource",
- )
-
- if not credential:
- logger.error(
- "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
- document_id,
- document.tenant_id,
- credential_id,
- )
- document.indexing_status = "error"
- document.error = "Datasource credential not found. Please reconnect your Notion workspace."
- document.stopped_at = naive_utc_now()
- db.session.commit()
- db.session.close()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
- loader = NotionExtractor(
- notion_workspace_id=workspace_id,
- notion_obj_id=page_id,
- notion_page_type=page_type,
- notion_access_token=credential.get("integration_secret"),
- tenant_id=document.tenant_id,
- )
+ data_source_info = document.data_source_info_dict
+ if document.data_source_type == "notion_import":
+ if (
+ not data_source_info
+ or "notion_page_id" not in data_source_info
+ or "notion_workspace_id" not in data_source_info
+ ):
+ raise ValueError("no notion page found")
+ workspace_id = data_source_info["notion_workspace_id"]
+ page_id = data_source_info["notion_page_id"]
+ page_type = data_source_info["type"]
+ page_edited_time = data_source_info["last_edited_time"]
+ credential_id = data_source_info.get("credential_id")
- last_edited_time = loader.get_notion_last_edited_time()
+ # Get credentials from datasource provider
+ datasource_provider_service = DatasourceProviderService()
+ credential = datasource_provider_service.get_datasource_credentials(
+ tenant_id=document.tenant_id,
+ credential_id=credential_id,
+ provider="notion_datasource",
+ plugin_id="langgenius/notion_datasource",
+ )
- # check the page is updated
- if last_edited_time != page_edited_time:
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.commit()
-
- # delete all document segment and index
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
-
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- index_node_ids = [segment.index_node_id for segment in segments]
-
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Cleaned document when document update data source or process rule: {} latency: {}".format(
- document_id, end_at - start_at
- ),
- fg="green",
- )
+ if not credential:
+ logger.error(
+ "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
+ document_id,
+ document.tenant_id,
+ credential_id,
)
- except Exception:
- logger.exception("Cleaned document when document update data source or process rule failed")
+ document.indexing_status = "error"
+ document.error = "Datasource credential not found. Please reconnect your Notion workspace."
+ document.stopped_at = naive_utc_now()
+ session.commit()
+ return
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- end_at = time.perf_counter()
- logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ loader = NotionExtractor(
+ notion_workspace_id=workspace_id,
+ notion_obj_id=page_id,
+ notion_page_type=page_type,
+ notion_access_token=credential.get("integration_secret"),
+ tenant_id=document.tenant_id,
+ )
+
+ last_edited_time = loader.get_notion_last_edited_time()
+
+ # check the page is updated
+ if last_edited_time != page_edited_time:
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.commit()
+
+ # delete all document segment and index
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
+
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Cleaned document when document update data source or process rule: {} latency: {}".format(
+ document_id, end_at - start_at
+ ),
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("Cleaned document when document update data source or process rule failed")
+
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ end_at = time.perf_counter()
+ logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index acbdab631..3bdff6019 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -6,11 +6,11 @@ import click
from celery import shared_task
from configs import dify_config
+from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@@ -46,66 +46,63 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
documents = []
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
- db.session.close()
- return
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- count = len(document_ids)
- batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
- if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
- raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
- if count > batch_upload_limit:
- raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
+ return
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ count = len(document_ids)
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+ if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ for document_id in document_ids:
+ document = (
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
- except Exception as e:
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ return
+
for document_id in document_ids:
+ logger.info(click.style(f"Start process document: {document_id}", fg="green"))
+
document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
+
if document:
- document.indexing_status = "error"
- document.error = str(e)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- db.session.close()
- return
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ documents.append(document)
+ session.add(document)
+ session.commit()
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- )
-
- if document:
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
-
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run(documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run(documents)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 161502a22..67a23be95 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -3,8 +3,9 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.commit()
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.commit()
- # delete all document segment and index
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
+ # delete all document segment and index
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise Exception("Dataset not found")
- index_type = document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_type = document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Cleaned document when document update data source or process rule: {} latency: {}".format(
- document_id, end_at - start_at
- ),
- fg="green",
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ db.session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Cleaned document when document update data source or process rule: {} latency: {}".format(
+ document_id, end_at - start_at
+ ),
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when document update data source or process rule failed")
+ except Exception:
+ logger.exception("Cleaned document when document update data source or process rule failed")
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- end_at = time.perf_counter()
- logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ end_at = time.perf_counter()
+ logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index 4078c8910..00a963255 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
from configs import dify_config
+from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@@ -76,63 +76,64 @@ def _duplicate_document_indexing_task_with_tenant_queue(
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
- documents = []
+ documents: list[Document] = []
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if dataset is None:
- logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
- db.session.close()
- return
-
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
+ with session_factory.create_session() as session:
try:
- if features.billing.enabled:
- vector_space = features.vector_space
- count = len(document_ids)
- if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
- raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
- batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
- if count > batch_upload_limit:
- raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
- current = int(getattr(vector_space, "size", 0) or 0)
- limit = int(getattr(vector_space, "limit", 0) or 0)
- if limit > 0 and (current + count) > limit:
- raise ValueError(
- "Your total number of documents plus the number of uploads have exceeded the limit of "
- "your subscription."
- )
- except Exception as e:
- for document_id in document_ids:
- document = (
- db.session.query(Document)
- .where(Document.id == document_id, Document.dataset_id == dataset_id)
- .first()
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if dataset is None:
+ logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ count = len(document_ids)
+ if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+ current = int(getattr(vector_space, "size", 0) or 0)
+ limit = int(getattr(vector_space, "limit", 0) or 0)
+ if limit > 0 and (current + count) > limit:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have exceeded the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ documents = list(
+ session.scalars(
+ select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+ ).all()
)
- if document:
- document.indexing_status = "error"
- document.error = str(e)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- return
+ for document in documents:
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ return
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ documents = list(
+ session.scalars(
+ select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+ ).all()
)
- if document:
+ for document in documents:
+ logger.info(click.style(f"Start process document: {document.id}", fg="green"))
+
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
+ session.add(document)
+ session.commit()
- indexing_runner = IndexingRunner()
- indexing_runner.run(documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
- finally:
- db.session.close()
+ indexing_runner = IndexingRunner()
+ indexing_runner.run(list(documents))
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
@shared_task(queue="dataset")
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index 7615469ed..1f9f21aa7 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -27,91 +27,93 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "completed":
- logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
-
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "completed":
+ logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ try:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+
+ dataset = segment.dataset
+
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ multimodel_documents = []
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodel_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
)
- child_documents.append(child_document)
- document.children = child_documents
- multimodel_documents = []
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodel_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- # save vector index
- index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
+ # save vector index
+ index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("enable segment to index failed")
- segment.enabled = False
- segment.disabled_at = naive_utc_now()
- segment.status = "error"
- segment.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("enable segment to index failed")
+ segment.enabled = False
+ segment.disabled_at = naive_utc_now()
+ segment.status = "error"
+ segment.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py
index 9f17d09e1..48d3c8e17 100644
--- a/api/tasks/enable_segments_to_index_task.py
+++ b/api/tasks/enable_segments_to_index_task.py
@@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
@@ -29,105 +29,102 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
- return
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
+ return
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
- db.session.close()
- return
- # sync index processor
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if not dataset_document:
+ logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
+ return
+ if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+ logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
+ return
+ # sync index processor
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- )
- ).all()
- if not segments:
- logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
- db.session.close()
- return
-
- try:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": document_id,
- "dataset_id": dataset_id,
- },
+ segments = session.scalars(
+ select(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
)
+ ).all()
+ if not segments:
+ logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
+ return
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": document_id,
- "dataset_id": dataset_id,
- },
+ try:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": document_id,
+ "dataset_id": dataset_id,
+ },
+ )
+
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": document_id,
+ "dataset_id": dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
)
- child_documents.append(child_document)
- document.children = child_documents
+ documents.append(document)
+ # save vector index
+ index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
-
- end_at = time.perf_counter()
- logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("enable segments to index failed")
- # update segment error msg
- db.session.query(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- ).update(
- {
- "error": str(e),
- "status": "error",
- "disabled_at": naive_utc_now(),
- "enabled": False,
- }
- )
- db.session.commit()
- finally:
- for segment in segments:
- indexing_cache_key = f"segment_{segment.id}_indexing"
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("enable segments to index failed")
+ # update segment error msg
+ session.query(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
+ ).update(
+ {
+ "error": str(e),
+ "status": "error",
+ "disabled_at": naive_utc_now(),
+ "enabled": False,
+ }
+ )
+ session.commit()
+ finally:
+ for segment in segments:
+ indexing_cache_key = f"segment_{segment.id}_indexing"
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index 1b2a653c0..af72023da 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
-from extensions.ext_database import db
from models.dataset import Document
logger = logging.getLogger(__name__)
@@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- try:
- indexing_runner = IndexingRunner()
- if document.indexing_status in {"waiting", "parsing", "cleaning"}:
- indexing_runner.run([document])
- elif document.indexing_status == "splitting":
- indexing_runner.run_in_splitting_status(document)
- elif document.indexing_status == "indexing":
- indexing_runner.run_in_indexing_status(document)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ if document.indexing_status in {"waiting", "parsing", "cleaning"}:
+ indexing_runner.run([document])
+ elif document.indexing_status == "splitting":
+ indexing_runner.run_in_splitting_status(document)
+ elif document.indexing_status == "indexing":
+ indexing_runner.run_in_indexing_status(document)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 3227f6da9..817249845 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -1,15 +1,20 @@
import logging
import time
from collections.abc import Callable
+from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
+from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
+from configs import dify_config
+from core.db.session_factory import session_factory
from extensions.ext_database import db
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from models import (
ApiToken,
AppAnnotationHitHistory,
@@ -40,6 +45,7 @@ from models.workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
+ WorkflowArchiveLog,
)
from repositories.factory import DifyAPIRepositoryFactory
@@ -64,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_app_workflow_runs(tenant_id, app_id)
_delete_app_workflow_node_executions(tenant_id, app_id)
_delete_app_workflow_app_logs(tenant_id, app_id)
+ if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED:
+ _delete_app_workflow_archive_logs(tenant_id, app_id)
+ _delete_archived_workflow_run_files(tenant_id, app_id)
_delete_app_conversations(tenant_id, app_id)
_delete_app_messages(tenant_id, app_id)
_delete_workflow_tool_providers(tenant_id, app_id)
@@ -77,7 +86,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_workflow_webhook_triggers(tenant_id, app_id)
_delete_workflow_schedule_plans(tenant_id, app_id)
_delete_workflow_trigger_logs(tenant_id, app_id)
-
end_at = time.perf_counter()
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:
@@ -89,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
- def del_model_config(model_config_id: str):
- db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
+ def del_model_config(session, model_config_id: str):
+ session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@@ -101,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
- def del_site(site_id: str):
- db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
+ def del_site(session, site_id: str):
+ session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@@ -113,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
- def del_mcp_server(mcp_server_id: str):
- db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
+ def del_mcp_server(session, mcp_server_id: str):
+ session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@@ -125,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
- def del_api_token(api_token_id: str):
- db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
+ def del_api_token(session, api_token_id: str):
+ session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@@ -137,8 +145,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
- def del_installed_app(installed_app_id: str):
- db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
+ def del_installed_app(session, installed_app_id: str):
+ session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -149,10 +157,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
- def del_recommended_app(recommended_app_id: str):
- db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
- synchronize_session=False
- )
+ def del_recommended_app(session, recommended_app_id: str):
+ session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@@ -163,8 +169,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
- def del_annotation_hit_history(annotation_hit_history_id: str):
- db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
+ def del_annotation_hit_history(session, annotation_hit_history_id: str):
+ session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
@@ -175,8 +181,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
"annotation hit history",
)
- def del_annotation_setting(annotation_setting_id: str):
- db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
+ def del_annotation_setting(session, annotation_setting_id: str):
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@@ -189,8 +195,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
- def del_dataset_join(dataset_join_id: str):
- db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
+ def del_dataset_join(session, dataset_join_id: str):
+ session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@@ -201,8 +207,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
- def del_workflow(workflow_id: str):
- db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
+ def del_workflow(session, workflow_id: str):
+ session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -241,10 +247,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
- def del_workflow_app_log(workflow_app_log_id: str):
- db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
- synchronize_session=False
- )
+ def del_workflow_app_log(session, workflow_app_log_id: str):
+ session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -254,12 +258,51 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
)
-def _delete_app_conversations(tenant_id: str, app_id: str):
- def del_conversation(conversation_id: str):
- db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
+ def del_workflow_archive_log(workflow_archive_log_id: str):
+ db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
- db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
+
+ _delete_records(
+ """select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
+ {"tenant_id": tenant_id, "app_id": app_id},
+ del_workflow_archive_log,
+ "workflow archive log",
+ )
+
+
+def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
+ prefix = f"{tenant_id}/app_id={app_id}/"
+ try:
+ archive_storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ logger.info("Archive storage not configured, skipping archive file cleanup: %s", e)
+ return
+
+ try:
+ keys = archive_storage.list_objects(prefix)
+ except Exception:
+ logger.exception("Failed to list archive files for app %s", app_id)
+ return
+
+ deleted = 0
+ for key in keys:
+ try:
+ archive_storage.delete_object(key)
+ deleted += 1
+ except Exception:
+ logger.exception("Failed to delete archive object %s", key)
+
+ logger.info("Deleted %s archive objects for app %s", deleted, app_id)
+
+
+def _delete_app_conversations(tenant_id: str, app_id: str):
+ def del_conversation(session, conversation_id: str):
+ session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+ session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@@ -270,28 +313,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
def _delete_conversation_variables(*, app_id: str):
- stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
- with db.engine.connect() as conn:
- conn.execute(stmt)
- conn.commit()
+ with session_factory.create_session() as session:
+ stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
+ session.execute(stmt)
+ session.commit()
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
- def del_message(message_id: str):
- db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
+ def del_message(session, message_id: str):
+ session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
+ session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
+ session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
+ session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
- db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
- synchronize_session=False
- )
- db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
- db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
- db.session.query(Message).where(Message.id == message_id).delete()
+ session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
+ session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
+ session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@@ -302,8 +343,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
- def del_tool_provider(tool_provider_id: str):
- db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
+ def del_tool_provider(session, tool_provider_id: str):
+ session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@@ -316,8 +357,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
- def del_tag_binding(tag_binding_id: str):
- db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
+ def del_tag_binding(session, tag_binding_id: str):
+ session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@@ -328,8 +369,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
- def del_end_user(end_user_id: str):
- db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
+ def del_end_user(session, end_user_id: str):
+ session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -340,10 +381,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
- def del_trace_app_config(trace_app_config_id: str):
- db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
- synchronize_session=False
- )
+ def del_trace_app_config(session, trace_app_config_id: str):
+ session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@@ -381,14 +420,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
- with db.engine.begin() as conn:
+ with session_factory.create_session() as session:
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
- result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
+ result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
rows = list(result)
if not rows:
@@ -399,7 +438,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
# Clean up associated Offload data first
if file_ids:
- files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
+ files_deleted = _delete_draft_variable_offload_data(session, file_ids)
total_files_deleted += files_deleted
# Delete the draft variables
@@ -407,8 +446,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
- deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
- batch_deleted = deleted_result.rowcount
+ deleted_result = cast(
+ CursorResult[Any],
+ session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}),
+ )
+ batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0)
total_deleted += batch_deleted
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
@@ -423,7 +465,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
return total_deleted
-def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
+def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
"""
Delete Offload data associated with WorkflowDraftVariable file_ids.
@@ -434,7 +476,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
4. Deletes WorkflowDraftVariableFile records
Args:
- conn: Database connection
+ session: Database connection
file_ids: List of WorkflowDraftVariableFile IDs
Returns:
@@ -450,12 +492,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
try:
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
query_sql = """
- SELECT wdvf.id, uf.key, uf.id as upload_file_id
- FROM workflow_draft_variable_files wdvf
- JOIN upload_files uf ON wdvf.upload_file_id = uf.id
- WHERE wdvf.id IN :file_ids
- """
- result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
+ SELECT wdvf.id, uf.key, uf.id as upload_file_id
+ FROM workflow_draft_variable_files wdvf
+ JOIN upload_files uf ON wdvf.upload_file_id = uf.id
+ WHERE wdvf.id IN :file_ids \
+ """
+ result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
file_records = list(result)
# Delete from object storage and collect upload file IDs
@@ -473,17 +515,19 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
# Delete UploadFile records
if upload_file_ids:
delete_upload_files_sql = """
- DELETE FROM upload_files
- WHERE id IN :upload_file_ids
- """
- conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
+ DELETE \
+ FROM upload_files
+ WHERE id IN :upload_file_ids \
+ """
+ session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
# Delete WorkflowDraftVariableFile records
delete_variable_files_sql = """
- DELETE FROM workflow_draft_variable_files
- WHERE id IN :file_ids
- """
- conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
+ DELETE \
+ FROM workflow_draft_variable_files
+ WHERE id IN :file_ids \
+ """
+ session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
except Exception:
logging.exception("Error deleting draft variable offload data:")
@@ -493,8 +537,8 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
def _delete_app_triggers(tenant_id: str, app_id: str):
- def del_app_trigger(trigger_id: str):
- db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
+ def del_app_trigger(session, trigger_id: str):
+ session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -505,8 +549,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
- def del_plugin_trigger(trigger_id: str):
- db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
+ def del_plugin_trigger(session, trigger_id: str):
+ session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
)
@@ -519,8 +563,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
- def del_webhook_trigger(trigger_id: str):
- db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
+ def del_webhook_trigger(session, trigger_id: str):
+ session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
)
@@ -533,10 +577,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
- def del_schedule_plan(plan_id: str):
- db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
- synchronize_session=False
- )
+ def del_schedule_plan(session, plan_id: str):
+ session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -547,8 +589,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
- def del_trigger_log(log_id: str):
- db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
+ def del_trigger_log(session, log_id: str):
+ session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -560,18 +602,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
- with db.engine.begin() as conn:
- rs = conn.execute(sa.text(query_sql), params)
- if rs.rowcount == 0:
+ with session_factory.create_session() as session:
+ rs = session.execute(sa.text(query_sql), params)
+ rows = rs.fetchall()
+ if not rows:
break
- for i in rs:
+ for i in rows:
record_id = str(i.id)
try:
- delete_func(record_id)
- db.session.commit()
+ delete_func(session, record_id)
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
except Exception:
logger.exception("Error occurred while deleting %s %s", name, record_id)
- continue
+ # continue with next record even if one deletion fails
+ session.rollback()
+ break
+ session.commit()
+
rs.close()
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index c0ab2d0b4..c3c255fb1 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
@@ -25,52 +25,55 @@ def remove_document_from_index_task(document_id: str):
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id).first()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- if document.indexing_status != "completed":
- logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
- db.session.close()
- return
+ if document.indexing_status != "completed":
+ logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
+ return
- indexing_cache_key = f"document_{document.id}_indexing"
+ indexing_cache_key = f"document_{document.id}_indexing"
- try:
- dataset = document.dataset
+ try:
+ dataset = document.dataset
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
- index_node_ids = [segment.index_node_id for segment in segments]
- if index_node_ids:
- try:
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
- except Exception:
- logger.exception("clean dataset %s from index failed", dataset.id)
- # update segment to disable
- db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
- {
- DocumentSegment.enabled: False,
- DocumentSegment.disabled_at: naive_utc_now(),
- DocumentSegment.disabled_by: document.disabled_by,
- DocumentSegment.updated_at: naive_utc_now(),
- }
- )
- db.session.commit()
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
+ if index_node_ids:
+ try:
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+ except Exception:
+ logger.exception("clean dataset %s from index failed", dataset.id)
+ # update segment to disable
+ session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
+ {
+ DocumentSegment.enabled: False,
+ DocumentSegment.disabled_at: naive_utc_now(),
+ DocumentSegment.disabled_by: document.disabled_by,
+ DocumentSegment.updated_at: naive_utc_now(),
+ }
+ )
+ session.commit()
- end_at = time.perf_counter()
- logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("remove document from index failed")
- if not document.archived:
- document.enabled = True
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Document removed from index: {document.id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("remove document from index failed")
+ if not document.archived:
+ document.enabled = True
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 9d208647e..f20b15ac8 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
@@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
- return
- user = db.session.query(Account).where(Account.id == user_id).first()
- if not user:
- logger.info(click.style(f"User not found: {user_id}", fg="red"))
- return
- tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
- if not tenant:
- raise ValueError("Tenant not found")
- user.current_tenant = tenant
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+ user = session.query(Account).where(Account.id == user_id).first()
+ if not user:
+ logger.info(click.style(f"User not found: {user_id}", fg="red"))
+ return
+ tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
+ if not tenant:
+ raise ValueError("Tenant not found")
+ user.current_tenant = tenant
- for document_id in document_ids:
- retry_indexing_cache_key = f"document_{document_id}_is_retried"
- # check document limit
- features = FeatureService.get_features(tenant.id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
- )
- except Exception as e:
+ for document_id in document_ids:
+ retry_indexing_cache_key = f"document_{document_id}_is_retried"
+ # check document limit
+ features = FeatureService.get_features(tenant.id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ document = (
+ session.query(Document)
+ .where(Document.id == document_id, Document.dataset_id == dataset_id)
+ .first()
+ )
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ redis_client.delete(retry_indexing_cache_key)
+ return
+
+ logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
- db.session.query(Document)
- .where(Document.id == document_id, Document.dataset_id == dataset_id)
- .first()
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
- if document:
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
+ return
+ try:
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+
+ if dataset.runtime_mode == "rag_pipeline":
+ rag_pipeline_service = RagPipelineService()
+ rag_pipeline_service.retry_error_document(dataset, document, user)
+ else:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(retry_indexing_cache_key)
+ except Exception as ex:
document.indexing_status = "error"
- document.error = str(e)
+ document.error = str(ex)
document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- redis_client.delete(retry_indexing_cache_key)
- return
-
- logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ session.add(document)
+ session.commit()
+ logger.info(click.style(str(ex), fg="yellow"))
+ redis_client.delete(retry_indexing_cache_key)
+ logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception(
+ "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
- return
- try:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
-
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
-
- if dataset.runtime_mode == "rag_pipeline":
- rag_pipeline_service = RagPipelineService()
- rag_pipeline_service.retry_error_document(dataset, document, user)
- else:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(retry_indexing_cache_key)
- except Exception as ex:
- document.indexing_status = "error"
- document.error = str(ex)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- logger.info(click.style(str(ex), fg="yellow"))
- redis_client.delete(retry_indexing_cache_key)
- logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
- end_at = time.perf_counter()
- logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception(
- "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
- )
- raise e
- finally:
- db.session.close()
+ raise e
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index 0dc1d841f..f1c8c5699 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if dataset is None:
- raise ValueError("Dataset not found")
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if dataset is None:
+ raise ValueError("Dataset not found")
- sync_indexing_cache_key = f"document_{document_id}_is_sync"
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
- )
- except Exception as e:
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- )
- if document:
+ sync_indexing_cache_key = f"document_{document_id}_is_sync"
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ document = (
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ )
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ redis_client.delete(sync_indexing_cache_key)
+ return
+
+ logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
+ return
+ try:
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(sync_indexing_cache_key)
+ except Exception as ex:
document.indexing_status = "error"
- document.error = str(e)
+ document.error = str(ex)
document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- redis_client.delete(sync_indexing_cache_key)
- return
-
- logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
- return
- try:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
-
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
-
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(sync_indexing_cache_key)
- except Exception as ex:
- document.indexing_status = "error"
- document.error = str(ex)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- logger.info(click.style(str(ex), fg="yellow"))
- redis_client.delete(sync_indexing_cache_key)
- logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
- end_at = time.perf_counter()
- logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
+ session.add(document)
+ session.commit()
+ logger.info(click.style(str(ex), fg="yellow"))
+ redis_client.delete(sync_indexing_cache_key)
+ logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py
index ee1d31aa9..d18ea2c23 100644
--- a/api/tasks/trigger_processing_tasks.py
+++ b/api/tasks/trigger_processing_tasks.py
@@ -16,6 +16,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginInvokeError
@@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
-from extensions.ext_database import db
from models.enums import (
AppTriggerType,
CreatorUserRole,
@@ -257,7 +257,7 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py
index ed92f3f3c..7698a1a6b 100644
--- a/api/tasks/trigger_subscription_refresh_tasks.py
+++ b/api/tasks/trigger_subscription_refresh_tasks.py
@@ -7,9 +7,9 @@ from celery import shared_task
from sqlalchemy.orm import Session
from configs import dify_config
+from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.utils.locks import build_trigger_refresh_lock_key
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
@@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
try:
now: int = _now_ts()
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
if not subscription:
diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py
index 7d145fb50..3b3c6e531 100644
--- a/api/tasks/workflow_execution_tasks.py
+++ b/api/tasks/workflow_execution_tasks.py
@@ -10,11 +10,10 @@ import logging
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
@@ -46,10 +45,7 @@ def save_workflow_execution_task(
True if successful, False otherwise
"""
try:
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py
index 8f5127670..b30a4ff15 100644
--- a/api/tasks/workflow_node_execution_tasks.py
+++ b/api/tasks/workflow_node_execution_tasks.py
@@ -10,13 +10,12 @@ import logging
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -48,10 +47,7 @@ def save_workflow_node_execution_task(
True if successful, False otherwise
"""
try:
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py
index f54e02a21..8c64d3ab2 100644
--- a/api/tasks/workflow_schedule_tasks.py
+++ b/api/tasks/workflow_schedule_tasks.py
@@ -1,15 +1,14 @@
import logging
from celery import shared_task
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.nodes.trigger_schedule.exc import (
ScheduleExecutionError,
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
-from extensions.ext_database import db
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
@@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
-
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")
diff --git a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html
index a07c5f4b1..7b296519f 100644
--- a/api/templates/invite_member_mail_template_en-US.html
+++ b/api/templates/invite_member_mail_template_en-US.html
@@ -83,7 +83,30 @@
Dear {{ to }},
{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.
Click the button below to log in to Dify and join the workspace.
- Login Here
+
Best regards,
Dify Team
diff --git a/api/templates/invite_member_mail_template_zh-CN.html b/api/templates/invite_member_mail_template_zh-CN.html
index 27709a3c6..c05b3ddb6 100644
--- a/api/templates/invite_member_mail_template_zh-CN.html
+++ b/api/templates/invite_member_mail_template_zh-CN.html
@@ -83,7 +83,30 @@
尊敬的 {{ to }},
{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。
点击下方按钮即可登录 Dify 并且加入空间。
- 在此登录
+
此致,
Dify 团队
diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html
index ac5042c27..e2bb99c98 100644
--- a/api/templates/register_email_when_account_exist_template_en-US.html
+++ b/api/templates/register_email_when_account_exist_template_en-US.html
@@ -115,7 +115,30 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here:
- Log In
+
If you forgot your password, you can reset it here: Reset Password
diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html
index 326b58343..6a5bbd135 100644
--- a/api/templates/register_email_when_account_exist_template_zh-CN.html
+++ b/api/templates/register_email_when_account_exist_template_zh-CN.html
@@ -115,7 +115,30 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录:
- 登录
+
如果您忘记了密码,可以在此重置: 重置密码
diff --git a/api/templates/without-brand/invite_member_mail_template_en-US.html b/api/templates/without-brand/invite_member_mail_template_en-US.html
index f9157284f..687ece617 100644
--- a/api/templates/without-brand/invite_member_mail_template_en-US.html
+++ b/api/templates/without-brand/invite_member_mail_template_en-US.html
@@ -92,12 +92,34 @@
platform specifically designed for LLM application development. On {{application_title}}, you can explore,
create, and collaborate to build and operate AI applications.
Click the button below to log in to {{application_title}} and join the workspace.
- Login Here
+
Best regards,
{{application_title}} Team