diff --git a/admin/server/api/v1/gaia/enter.go b/admin/server/api/v1/gaia/enter.go index 47fbe72b1..0b8f269eb 100644 --- a/admin/server/api/v1/gaia/enter.go +++ b/admin/server/api/v1/gaia/enter.go @@ -9,6 +9,7 @@ type ApiGroup struct { SystemApi TestApi SystemOAuth2Api + BatchWorkflowApi } var ( diff --git a/admin/server/api/v1/gaia/workflow.go b/admin/server/api/v1/gaia/workflow.go new file mode 100644 index 000000000..4b3a27e38 --- /dev/null +++ b/admin/server/api/v1/gaia/workflow.go @@ -0,0 +1,657 @@ +package gaia + +import ( + "bytes" + "encoding/csv" + "encoding/json" + "fmt" + "github.com/flipped-aurora/gin-vue-admin/server/global" + "github.com/flipped-aurora/gin-vue-admin/server/model/gaia" + "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request" + "github.com/flipped-aurora/gin-vue-admin/server/utils" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" + "io" + "net/http" + "strconv" + "strings" + + "github.com/flipped-aurora/gin-vue-admin/server/model/common/response" + "github.com/flipped-aurora/gin-vue-admin/server/service" + gaiaService "github.com/flipped-aurora/gin-vue-admin/server/service/gaia" + "github.com/gin-gonic/gin" +) + +type BatchWorkflowApi struct{} + +var batchWorkflowService = service.ServiceGroupApp.GaiaServiceGroup.BatchWorkflowService + +// CreateBatchWorkflow 创建批量处理工作流 +// @Tags BatchWorkflow +// @Summary 创建批量处理工作流 +// @Description 上传CSV文件并创建批量处理工作流 +// @Accept multipart/form-data +// @Produce application/json +// @Param file formData file true "CSV文件" +// @Param installed_id formData string true "安装的应用ID" +// @Param app_id formData string true "应用ID" +// @Param tenant_id formData string true "租户ID" +// @Success 200 {object} response.Response{data=gaia.BatchWorkflow} "成功" +// @Router /gaia/workflow/batch/processing [post] +func (api *BatchWorkflowApi) CreateBatchWorkflow(c *gin.Context) { + // 获取表单参数 + userID := utils.GetUserID(c) + installedID := c.PostForm("installed_id") + 
keyNameMappingStr := c.PostForm("key_name_mapping") + + if installedID == "" { + response.FailWithMessage("缺少必要参数", c) + return + } + + // 解析key-name映射 + var keyNameMapping map[string]string + if keyNameMappingStr != "" { + if err := json.Unmarshal([]byte(keyNameMappingStr), &keyNameMapping); err != nil { + response.FailWithMessage("解析key_name_mapping失败: "+err.Error(), c) + return + } + } + + // 获取上传的文件 + file, err := c.FormFile("file") + if err != nil { + response.FailWithMessage("获取文件失败: "+err.Error(), c) + return + } + + // 打开上传的文件 + src, err := file.Open() + if err != nil { + response.FailWithMessage("打开文件失败: "+err.Error(), c) + return + } + defer src.Close() + + // 读取文件内容并检测编码 + content, err := io.ReadAll(src) + if err != nil { + response.FailWithMessage("读取文件内容失败: "+err.Error(), c) + return + } + + // 尝试不同编码解析CSV + var data [][]string + var parseErr error + + // 1. 先尝试UTF-8读取,使用宽松的CSV解析器配置 + reader := bytes.NewReader(content) + csvReader := csv.NewReader(reader) + csvReader.LazyQuotes = true // 允许懒惰引号 + csvReader.TrimLeadingSpace = true // 去除前导空格 + data, parseErr = csvReader.ReadAll() + + // 2. 如果UTF-8失败或包含乱码,尝试GBK编码 + if parseErr != nil || containsGarbledText(data) { + decoder := simplifiedchinese.GBK.NewDecoder() + gbkReader := transform.NewReader(bytes.NewReader(content), decoder) + + csvReader = csv.NewReader(gbkReader) + csvReader.LazyQuotes = true // 允许懒惰引号 + csvReader.TrimLeadingSpace = true // 去除前导空格 + data, parseErr = csvReader.ReadAll() + + // 3. 如果GBK也失败,尝试GB18030编码 + if parseErr != nil || containsGarbledText(data) { + gb18030Decoder := simplifiedchinese.GB18030.NewDecoder() + gb18030Reader := transform.NewReader(bytes.NewReader(content), gb18030Decoder) + + csvReader = csv.NewReader(gb18030Reader) + csvReader.LazyQuotes = true // 允许懒惰引号 + csvReader.TrimLeadingSpace = true // 去除前导空格 + data, parseErr = csvReader.ReadAll() + } + } + + // 4. 
如果以上方法都失败,尝试最后的兜底解析方法 + if parseErr != nil { + data, parseErr = parseCSVWithFallback(content) + } + + if parseErr != nil { + response.FailWithMessage("解析CSV文件失败,请检查文件格式。错误详情: "+parseErr.Error(), c) + return + } + + // 创建批量处理工作流 + batchWorkflow, err := batchWorkflowService.CreateBatchWorkflow( + userID, installedID, file.Filename, data, keyNameMapping) + if err != nil { + // 特别处理数据库连接问题 + if strings.Contains(err.Error(), "数据库连接未初始化") { + response.FailWithMessage("系统初始化中,请稍后重试", c) + } else { + response.FailWithMessage("创建批量处理失败: "+err.Error(), c) + } + return + } + + response.OkWithData(batchWorkflow, c) +} + +// GetBatchWorkflow 获取批量处理信息 +// @Tags BatchWorkflow +// @Summary 获取批量处理信息 +// @Description 根据ID获取批量处理信息 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response{data=gaia.BatchWorkflow} "成功" +// @Router /gaia/workflow/batch/{id} [get] +func (api *BatchWorkflowApi) GetBatchWorkflow(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + batchWorkflow, err := batchWorkflowService.GetBatchWorkflow(id) + if err != nil { + response.FailWithMessage("获取批量处理信息失败: "+err.Error(), c) + return + } + + response.OkWithData(batchWorkflow, c) +} + +// GetBatchWorkflowTasks 获取批量处理任务列表 +// @Tags BatchWorkflow +// @Summary 获取批量处理任务列表 +// @Description 根据批量处理ID获取任务列表 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response{data=[]gaia.BatchWorkflowTask} "成功" +// @Router /gaia/workflow/batch/{id}/tasks [get] +func (api *BatchWorkflowApi) GetBatchWorkflowTasks(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + tasks, err := batchWorkflowService.GetBatchWorkflowTasks(id) + if err != nil { + response.FailWithMessage("获取任务列表失败: "+err.Error(), c) + return + } + + response.OkWithData(tasks, c) +} + +// GetBatchWorkflowProgress 获取批量处理进度 +// @Tags 
BatchWorkflow +// @Summary 获取批量处理进度 +// @Description 根据ID获取批量处理进度信息 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response{data=map[string]interface{}} "成功" +// @Router /gaia/workflow/batch/{id}/progress [get] +func (api *BatchWorkflowApi) GetBatchWorkflowProgress(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + progress, err := batchWorkflowService.GetBatchWorkflowProgress(id) + if err != nil { + response.FailWithMessage("获取进度信息失败: "+err.Error(), c) + return + } + + response.OkWithData(progress, c) +} + +// StopBatchWorkflow 停止批量处理 +// @Tags BatchWorkflow +// @Summary 停止批量处理 +// @Description 根据ID停止批量处理 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/batch/{id}/stop [post] +func (api *BatchWorkflowApi) StopBatchWorkflow(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + err := batchWorkflowService.StopBatchWorkflow(id) + if err != nil { + response.FailWithMessage("停止批量处理失败: "+err.Error(), c) + return + } + + response.OkWithMessage("停止成功", c) +} + +// RetryBatchWorkflow 重试批量处理(重新开始所有任务) +// @Tags BatchWorkflow +// @Summary 重试批量处理 +// @Description 根据ID重试批量处理,重置所有任务从头开始 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/batch/{id}/retry [post] +func (api *BatchWorkflowApi) RetryBatchWorkflow(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + err := batchWorkflowService.RetryBatchWorkflow(id) + if err != nil { + response.FailWithMessage("重试批量处理失败: "+err.Error(), c) + return + } + + response.OkWithMessage("重试成功,所有任务已重置", c) +} + +// RetryFailedTasks 仅重试失败的任务 +// @Tags BatchWorkflow +// @Summary 仅重试失败的任务 +// @Description 
根据ID仅重试失败的任务,保留已完成的任务 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/batch/{id}/retry-failed [post] +func (api *BatchWorkflowApi) RetryFailedTasks(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + err := batchWorkflowService.RetryFailedTasks(id) + if err != nil { + response.FailWithMessage("重试失败任务失败: "+err.Error(), c) + return + } + + response.OkWithMessage("失败任务重试成功", c) +} + +// ResumeBatchWorkflow 恢复批量处理 +// @Tags BatchWorkflow +// @Summary 恢复批量处理 +// @Description 根据ID恢复批量处理 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/batch/{id}/resume [post] +func (api *BatchWorkflowApi) ResumeBatchWorkflow(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + err := batchWorkflowService.ResumeBatchWorkflow(id) + if err != nil { + response.FailWithMessage("恢复批量处理失败: "+err.Error(), c) + return + } + + response.OkWithMessage("恢复成功", c) +} + +// ResetBatchWorkflowErrorCount 重置批量工作流错误计数 +// @Tags BatchWorkflow +// @Summary 重置批量工作流错误计数 +// @Description 重置指定批量工作流的错误计数,恢复用户并发位 +// @Produce application/json +// @Param id path string true "批量处理ID" +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/batch/{id}/reset-error-count [post] +func (api *BatchWorkflowApi) ResetBatchWorkflowErrorCount(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + // 调用worker_pool中的重置函数 + err := gaiaService.ResetBatchWorkflowErrorCount(id) + if err != nil { + response.FailWithMessage("重置错误计数失败: "+err.Error(), c) + return + } + + response.OkWithMessage("错误计数已重置,用户并发位将恢复", c) +} + +// ResetUserErrorCount 重置用户所有批量工作流错误计数 +// @Tags BatchWorkflow +// @Summary 重置用户所有批量工作流错误计数 +// @Description 
重置指定用户所有批量工作流的错误计数,恢复用户并发位 +// @Produce application/json +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/batch/reset-user-error-count [post] +func (api *BatchWorkflowApi) ResetUserErrorCount(c *gin.Context) { + userID := utils.GetUserID(c) + + // 调用worker_pool中的重置函数 + err := gaiaService.ResetUserErrorCount(userID) + if err != nil { + response.FailWithMessage("重置用户错误计数失败: "+err.Error(), c) + return + } + + response.OkWithMessage("用户所有批量工作流错误计数已重置,并发位将恢复", c) +} + +// DownloadBatchWorkflowResults 下载批量处理结果 +// @Tags BatchWorkflow +// @Summary 下载批量处理结果 +// @Description 根据ID下载批量处理结果 +// @Produce text/csv +// @Param id path string true "批量处理ID" +// @Success 200 {file} file "CSV文件" +// @Router /gaia/workflow/batch/{id}/download [get] +func (api *BatchWorkflowApi) DownloadBatchWorkflowResults(c *gin.Context) { + id := c.Param("id") + if id == "" { + response.FailWithMessage("缺少批量处理ID", c) + return + } + + // 获取批量处理信息 + flow, err := batchWorkflowService.GetBatchWorkflow(id) + if err != nil { + response.FailWithMessage("获取批量处理信息失败: "+err.Error(), c) + return + } + + // 获取任务列表 + tasks, err := batchWorkflowService.GetBatchWorkflowTasks(id) + if err != nil { + response.FailWithMessage("获取任务列表失败: "+err.Error(), c) + return + } + + // 生成CSV内容 + csvContent := generateCSVFromTasks(flow, tasks) + csvBytes := []byte(csvContent) + + // 添加 UTF-8 BOM 以确保在 Excel 中正确显示中文 + bom := []byte{0xEF, 0xBB, 0xBF} + fullContent := append(bom, csvBytes...) 
// parseCSVWithFallback is a last-resort parser for CSV content that
// encoding/csv rejects.
//
// It splits the content on "\n" and each line on ",", trims surrounding
// whitespace from every field, and unwraps fields that are fully enclosed in
// double quotes (collapsing doubled quotes inside them). Blank lines are
// skipped. Commas inside quoted fields are NOT protected: they still split
// the field. The content is interpreted as-is; no charset conversion is done.
//
// It returns an error when no non-blank line is found.
func parseCSVWithFallback(content []byte) ([][]string, error) {
	// A leading UTF-8 BOM would sit in front of the first field's opening
	// quote and prevent that field from being unquoted.
	text := strings.TrimPrefix(string(content), "\uFEFF")

	var rows [][]string
	for _, line := range strings.Split(text, "\n") {
		if strings.TrimSpace(line) == "" {
			continue
		}

		fields := strings.Split(line, ",")
		for i, field := range fields {
			// TrimSpace also removes the "\r" left behind by CRLF line endings.
			field = strings.TrimSpace(field)
			if len(field) >= 2 && field[0] == '"' && field[len(field)-1] == '"' {
				field = strings.ReplaceAll(field[1:len(field)-1], `""`, `"`)
			}
			fields[i] = field
		}
		rows = append(rows, fields)
	}

	if len(rows) == 0 {
		return nil, fmt.Errorf("CSV文件为空或格式无法识别")
	}
	return rows, nil
}

// containsGarbledText reports whether the first rows of data look like the
// product of decoding with the wrong character set.
//
// Only the first three rows are sampled. A cell counts as garbled when it
// contains the Unicode replacement character U+FFFD (this also matches bytes
// that are not valid UTF-8) or the letter "Ŀ" (U+013F), which is what the
// GBK bytes C4 BF of the common character "目" turn into when read as UTF-8.
func containsGarbledText(data [][]string) bool {
	const sampleRows = 3
	if len(data) > sampleRows {
		data = data[:sampleRows]
	}

	for _, row := range data {
		for _, cell := range row {
			if strings.ContainsRune(cell, '\uFFFD') || strings.Contains(cell, "Ŀ") {
				return true
			}
		}
	}
	return false
}
*gaia.BatchWorkflow, tasks []gaia.BatchWorkflowTask) string { + if len(tasks) == 0 { + return "" + } + + // 解析第一个任务的输入参数来获取列名 + var firstTaskInputs map[string]string + if err := json.Unmarshal([]byte(tasks[0].Inputs), &firstTaskInputs); err != nil { + return "" + } + + buf := &bytes.Buffer{} + w := csv.NewWriter(buf) + + // 标题:输入列 + 处理结果 + 状态 + var nameList []string + var keyMap map[string]string + _ = json.Unmarshal([]byte(flow.KeyName), &keyMap) + headers := make([]string, 0, len(keyMap)) + for key, value := range keyMap { + headers = append(headers, key) + nameList = append(nameList, value) + } + headers = append(headers, "生成结果") + _ = w.Write(headers) + + // 行数据 + for _, task := range tasks { + var inputs map[string]string + if err := json.Unmarshal([]byte(task.Inputs), &inputs); err != nil { + continue + } + var text string + row := make([]string, 0, len(headers)) + var result request.WorkflowBatchProcessing + for _, value := range nameList { + row = append(row, inputs[value]) + } + if err := json.Unmarshal([]byte(task.Result), &result); err == nil { + for key, v := range result.Outputs { + if key == "task_id" { + continue + } + text += fmt.Sprintf("%s\r", v) + } + } + row = append(row, text) + _ = w.Write(row) + } + + w.Flush() + return buf.String() +} + +// GetWorkerPoolStatus 获取工作池状态 +// @Tags BatchWorkflow +// @Summary 获取工作池状态 +// @Description 获取当前工作池的运行状态和统计信息 +// @Accept application/json +// @Produce application/json +// @Success 200 {object} response.Response{data=map[string]interface{}} "成功" +// @Router /gaia/workflow/worker-pool/status [get] +func (api *BatchWorkflowApi) GetWorkerPoolStatus(c *gin.Context) { + pool := batchWorkflowService.GetWorkerPool() + if pool == nil { + response.FailWithMessage("工作池未初始化", c) + return + } + + status := pool.GetStatus() + response.OkWithData(status, c) +} + +// RestartWorkerPool 重启工作池 +// @Tags BatchWorkflow +// @Summary 重启工作池 +// @Description 停止当前工作池并重新启动 +// @Accept application/json +// @Produce application/json 
+// @Param workers query int false "worker数量" default(5) +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/worker-pool/restart [post] +func (api *BatchWorkflowApi) RestartWorkerPool(c *gin.Context) { + workers := global.GVA_CONFIG.System.WorkFlowNumber + if workersParam := c.Query("workers"); workersParam != "" { + if w, err := strconv.Atoi(workersParam); err == nil && w > 0 && w <= 20 { + workers = w + } + } + + // 停止当前工作池 + batchWorkflowService.StopWorkerPool() + + // 启动新的工作池 + batchWorkflowService.InitWorkerPool(workers) + + response.OkWithMessage("工作池重启成功", c) +} + +// StopWorkerPool 停止工作池 +// @Tags BatchWorkflow +// @Summary 停止工作池 +// @Description 停止当前工作池 +// @Accept application/json +// @Produce application/json +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/worker-pool/stop [post] +func (api *BatchWorkflowApi) StopWorkerPool(c *gin.Context) { + batchWorkflowService.StopWorkerPool() + response.OkWithMessage("工作池已停止", c) +} + +// StartWorkerPool 启动工作池 +// @Tags BatchWorkflow +// @Summary 启动工作池 +// @Description 启动工作池 +// @Accept application/json +// @Produce application/json +// @Param workers query int false "worker数量" default(5) +// @Success 200 {object} response.Response "成功" +// @Router /gaia/workflow/worker-pool/start [post] +func (api *BatchWorkflowApi) StartWorkerPool(c *gin.Context) { + workers := global.GVA_CONFIG.System.WorkFlowNumber + if workersParam := c.Query("workers"); workersParam != "" { + if w, err := strconv.Atoi(workersParam); err == nil && w > 0 && w <= 20 { + workers = w + } + } + + batchWorkflowService.InitWorkerPool(workers) + response.OkWithMessage("工作池启动成功", c) +} + +// GetBatchWorkflowList 获取最近30天的批量工作流列表 +// @Tags BatchWorkflow +// @Summary 获取最近30天的批量工作流列表 +// @Description 获取指定用户最近30天的批量工作流列表,支持分页和按应用过滤 +// @Accept application/json +// @Produce application/json +// @Param installed_id query string false "安装的应用ID" +// @Param page query int false "页码" default(1) +// @Param limit 
query int false "每页数量" default(10) +// @Success 200 {object} response.Response{data=map[string]interface{}} "成功" +// @Router /gaia/workflow/batch/list [get] +func (api *BatchWorkflowApi) GetBatchWorkflowList(c *gin.Context) { + userID := utils.GetUserID(c) + installedID := c.Query("installed_id") + + // 解析分页参数 + page := 1 + limit := 10 + + if pageParam := c.Query("page"); pageParam != "" { + if p, err := strconv.Atoi(pageParam); err == nil && p > 0 { + page = p + } + } + + if limitParam := c.Query("limit"); limitParam != "" { + if l, err := strconv.Atoi(limitParam); err == nil && l > 0 && l <= 100 { + limit = l + } + } + + // 调用服务层方法 + batchWorkflows, total, err := batchWorkflowService.GetBatchWorkflowList(userID, installedID, page, limit) + if err != nil { + response.FailWithMessage("获取批量工作流列表失败: "+err.Error(), c) + return + } + + // 计算分页信息 + totalPages := (total + int64(limit) - 1) / int64(limit) + hasMore := int64(page) < totalPages + + response.OkWithData(map[string]interface{}{ + "items": batchWorkflows, + "total": total, + "page": page, + "limit": limit, + "total_pages": totalPages, + "has_more": hasMore, + }, c) +} diff --git a/admin/server/config.docker.yaml b/admin/server/config.docker.yaml index ea7a99496..d514f4e08 100644 --- a/admin/server/config.docker.yaml +++ b/admin/server/config.docker.yaml @@ -160,6 +160,7 @@ system: use-strict-auth: false user_default-group-id: "888" docker-run: true + work_flow_number: 100 tencent-cos: bucket: xxxxx-10005608 region: ap-shanghai diff --git a/admin/server/config.yaml b/admin/server/config.yaml index 8b0bf5699..38396a5d5 100644 --- a/admin/server/config.yaml +++ b/admin/server/config.yaml @@ -10,7 +10,7 @@ captcha: open-captcha: 0 open-captcha-timeout: 3600 jwt: - signing-key: + signing-key: sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U expires-time: 1d buffer-time: 1d issuer: CLOUD @@ -61,6 +61,7 @@ system: use-mongo: false use-strict-auth: false user_default-group-id: "888" + work_flow_number: 100 zap: level: 
info prefix: '[gaia/server]' diff --git a/admin/server/config/system.go b/admin/server/config/system.go index e30e12839..84b7448a6 100644 --- a/admin/server/config/system.go +++ b/admin/server/config/system.go @@ -12,6 +12,7 @@ type System struct { UseMongo bool `mapstructure:"use-mongo" json:"use-mongo" yaml:"use-mongo"` // 使用mongo UseStrictAuth bool `mapstructure:"use-strict-auth" json:"use-strict-auth" yaml:"use-strict-auth"` // 使用树形角色分配模式 // Extend: Start Custom Configuration + WorkFlowNumber int `mapstructure:"work_flow_number" default:"200" json:"work_flow_number" yaml:"work_flow_number"` UserDefaultGroupID string `mapstructure:"user_default-group-id" default:"888" json:"user_default-group-id" yaml:"user_default-group-id"` // 用户默认群组id DockerRun bool `mapstructure:"docker-run" default:false json:"docker-run" yaml:"docker-run"` // 是否在docker中运行,如果是的话,无需自动生成jwtkey,直接填sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U与dify保持一致 // Extend: Stop Custom Configuration diff --git a/admin/server/core/viper.go b/admin/server/core/viper.go index cc119a3c9..4314784f1 100644 --- a/admin/server/core/viper.go +++ b/admin/server/core/viper.go @@ -14,6 +14,19 @@ import ( "github.com/flipped-aurora/gin-vue-admin/server/global" ) +// Extend: Override JWT signing key from environment variable +// This ensures admin-server uses the same JWT signing key as the API server +func overrideJWTSigningKeyFromEnv() { + // Check JWT_SIGNING_KEY first, then fall back to SECRET_KEY + if jwtKey := os.Getenv("JWT_SIGNING_KEY"); jwtKey != "" { + global.GVA_CONFIG.JWT.SigningKey = jwtKey + fmt.Printf("JWT signing key overridden from JWT_SIGNING_KEY environment variable\n") + } else if secretKey := os.Getenv("SECRET_KEY"); secretKey != "" { + global.GVA_CONFIG.JWT.SigningKey = secretKey + fmt.Printf("JWT signing key overridden from SECRET_KEY environment variable\n") + } +} + // Viper // // 优先级: 命令行 > 环境变量 > 默认值 // Author [SliverHorn](https://github.com/SliverHorn) @@ -60,11 +73,16 @@ func Viper(path 
...string) *viper.Viper { if err = v.Unmarshal(&global.GVA_CONFIG); err != nil { fmt.Println(err) } + // Extend: Override JWT signing key from environment variable + overrideJWTSigningKeyFromEnv() }) if err = v.Unmarshal(&global.GVA_CONFIG); err != nil { panic(err) } + // Extend: Override JWT signing key from environment variable after initial load + overrideJWTSigningKeyFromEnv() + // root 适配性 根据root位置去找到对应迁移位置,保证root路径有效 global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("..") diff --git a/admin/server/initialize/ensure_tables.go b/admin/server/initialize/ensure_tables.go index f70df220e..d350eb07f 100644 --- a/admin/server/initialize/ensure_tables.go +++ b/admin/server/initialize/ensure_tables.go @@ -66,6 +66,9 @@ func (e *ensureTables) MigrateTable(ctx context.Context) (context.Context, error gaia.AppRequestTestBatch{}, gaia.AppRequestTest{}, gaia.SystemIntegration{}, // Extend System Integration + gaia.ForwardingExtend{}, // Extend Forwarding Extend + gaia.BatchWorkflow{}, // Extend Batch Workflow + gaia.BatchWorkflowTask{}, // Extend Batch Workflow Task sysModel.SysUserGlobalCode{}, // Extend Global Code // Extend gaia model } @@ -112,6 +115,9 @@ func (e *ensureTables) TableCreated(ctx context.Context) bool { gaia.AppRequestTestBatch{}, gaia.AppRequestTest{}, gaia.SystemIntegration{}, // Extend System Integration + gaia.ForwardingExtend{}, // Extend Forwarding Extend + gaia.BatchWorkflow{}, // Extend Batch Workflow + gaia.BatchWorkflowTask{}, // Extend Batch Workflow Task sysModel.SysUserGlobalCode{}, // Extend Global Code // Extend gaia model } diff --git a/admin/server/initialize/gorm.go b/admin/server/initialize/gorm.go index bb138f203..84353367e 100644 --- a/admin/server/initialize/gorm.go +++ b/admin/server/initialize/gorm.go @@ -38,7 +38,6 @@ func Gorm() *gorm.DB { func RegisterTables() { db := global.GVA_DB err := db.AutoMigrate( - system.SysApi{}, system.SysIgnoreApi{}, system.SysUser{}, @@ -68,6 +67,9 @@ func RegisterTables() { 
gaia.AppRequestTestBatch{}, gaia.AppRequestTest{}, gaia.SystemIntegration{}, // Extend System Integration + gaia.ForwardingExtend{}, // Extend Forwarding Extend + gaia.BatchWorkflow{}, // Extend Batch Workflow + gaia.BatchWorkflowTask{}, // Extend Batch Workflow Task system.SysUserGlobalCode{}, // Extend Global Code // Extend gaia model ) diff --git a/admin/server/initialize/router_biz.go b/admin/server/initialize/router_biz.go index 9b08faeaf..521da06f1 100644 --- a/admin/server/initialize/router_biz.go +++ b/admin/server/initialize/router_biz.go @@ -20,5 +20,6 @@ func initBizRouter(routers ...*gin.RouterGroup) { gaiaRouter.InitTenantsRouter(privateGroup, publicGroup) gaiaRouter.InitTestRouter(privateGroup, publicGroup) gaiaRouter.InitSystemRouter(privateGroup) + gaiaRouter.InitWorkflowRouter(privateGroup) } } diff --git a/admin/server/initialize/worker_pool.go b/admin/server/initialize/worker_pool.go new file mode 100644 index 000000000..d0f3849a6 --- /dev/null +++ b/admin/server/initialize/worker_pool.go @@ -0,0 +1,26 @@ +package initialize + +import ( + "fmt" + "github.com/flipped-aurora/gin-vue-admin/server/global" + "github.com/flipped-aurora/gin-vue-admin/server/service/gaia" +) + +// InitWorkerPool 初始化工作池 +func InitWorkerPool() { + // 从配置中获取worker数量,默认为5 + workerCount := 5 + if global.GVA_CONFIG.System.WorkFlowNumber > 0 { + workerCount = global.GVA_CONFIG.System.WorkFlowNumber + } + global.GVA_LOG.Info(fmt.Sprintf("正在启动批量任务工作池,工作器数量: %d", workerCount)) + gaia.InitWorkerPool(workerCount) + global.GVA_LOG.Info("批量任务工作池启动完成") +} + +// StopWorkerPool 停止工作池(优雅关闭时调用) +func StopWorkerPool() { + global.GVA_LOG.Info("正在停止批量任务工作池...") + gaia.StopWorkerPool() + global.GVA_LOG.Info("批量任务工作池已停止") +} diff --git a/admin/server/main.go b/admin/server/main.go index 6f4444892..1a53456b7 100644 --- a/admin/server/main.go +++ b/admin/server/main.go @@ -30,6 +30,7 @@ func main() { initialize.DBList() if global.GVA_DB != nil { initialize.RegisterTables() // 初始化表 + 
// Batch workflow status values.
const (
	BatchWorkflowStatusPending    = "pending"    // waiting to be processed
	BatchWorkflowStatusProcessing = "processing" // being processed
	BatchWorkflowStatusCompleted  = "completed"  // finished
	BatchWorkflowStatusFailed     = "failed"     // failed
	BatchWorkflowStatusStopped    = "stopped"    // stopped
)

// Batch workflow task status values.
const (
	BatchTaskStatusPending   = "pending"   // waiting to be processed
	BatchTaskStatusQueued    = "queued"    // in the queue
	BatchTaskStatusRunning   = "running"   // running
	BatchTaskStatusCompleted = "completed" // finished
	BatchTaskStatusFailed    = "failed"    // failed
	BatchTaskStatusCancelled = "cancelled" // cancelled
)

// Error messages recorded for batch workflows.
const (
	ErrorInsufficientBalance = "余额不足,调用失败!"
	ErrorMaxRetryExceeded    = "重试超过3次"
	ErrorWorkflowFailed      = "工作流执行失败"
	ErrorCallAPIFailed       = "调用Dify API失败"
	ErrorParseResultFailed   = "解析API返回结果失败"
)

// Batch workflow tuning constants.
const (
	MaxTaskRetryCount     = 3  // maximum number of retries per task
	ErrorPenaltyThreshold = 50 // error penalty threshold (one concurrency slot is removed per 50 errors)
)

// BatchWorkflow is one batch run created from an uploaded file; it is stored
// in the table named by TableName.
type BatchWorkflow struct {
	ID            string    `json:"id" gorm:"primaryKey;comment:批量处理ID"`
	UserID        uint      `json:"user_id" gorm:"index;comment:用户id"`
	InstalledID   string    `json:"installed_id" gorm:"not null;comment:安装的应用ID"`
	FileName      string    `json:"file_name" gorm:"not null;comment:上传的文件名"`
	TotalRows     int       `json:"total_rows" gorm:"not null;default:0;comment:总行数"`
	ProcessedRows int       `json:"processed_rows" gorm:"not null;default:0;comment:已处理行数"`
	Status        string    `json:"status" gorm:"not null;default:'pending';comment:状态: pending, processing, completed, failed, stopped"`
	Results       string    `json:"results" gorm:"type:text;comment:处理结果"`
	KeyName       string    `json:"key_name" gorm:"type:text;comment:键名"`
	Error         string    `json:"error" gorm:"comment:错误信息"`
	ErrorCount    int       `json:"error_count" gorm:"not null;default:0;comment:累计错误次数"`
	CreatedAt     time.Time `json:"created_at" gorm:"not null;default:CURRENT_TIMESTAMP(0);comment:创建时间"`
	UpdatedAt     time.Time `json:"updated_at" gorm:"not null;default:CURRENT_TIMESTAMP(0);comment:更新时间"`
}

// BatchWorkflowTask is a single task of a batch workflow, linked to it through
// BatchWorkflowID.
type BatchWorkflowTask struct {
	ID              string    `json:"id" gorm:"primaryKey;comment:任务ID"`
	BatchWorkflowID string    `json:"batch_workflow_id" gorm:"not null;comment:批量处理ID"`
	RowIndex        int       `json:"row_index" gorm:"not null;comment:行索引"`
	Inputs          string    `json:"inputs" gorm:"type:text;comment:输入参数"`
	Status          string    `json:"status" gorm:"not null;default:'pending';comment:状态: pending, running, completed, failed, cancelled"`
	Result          string    `json:"result" gorm:"type:text;comment:处理结果"`
	Error           string    `json:"error" gorm:"comment:错误信息"`
	ErrorCount      int       `json:"error_count" gorm:"not null;default:0;comment:错误次数"`
	CreatedAt       time.Time `json:"created_at" gorm:"not null;default:CURRENT_TIMESTAMP(0);comment:创建时间"`
	UpdatedAt       time.Time `json:"updated_at" gorm:"not null;default:CURRENT_TIMESTAMP(0);comment:更新时间"`
}

// TableName returns the table that stores BatchWorkflow rows.
func (BatchWorkflow) TableName() string { return "batch_workflows_extend" }

// TableName returns the table that stores BatchWorkflowTask rows.
func (BatchWorkflowTask) TableName() string { return "batch_workflow_tasks_extend" }

// The declarations below belong to package request
// (model/gaia/request/workflow.go).

// WorkflowBatchProcessing is the decoded form of a task result; its outputs
// are what the CSV export is generated from.
type WorkflowBatchProcessing struct {
	Outputs map[string]string `json:"outputs" gorm:"comment:从任务生成CSV内容"` // source of the generated CSV content
}

// SSEEvent represents one SSE event.
type SSEEvent struct {
	Event string                 `json:"event"`
	Data  map[string]interface{} `json:"data"`
}

// NodeExecution describes the execution of a single workflow node.
type NodeExecution struct {
	ID          string                 `json:"id"`
	NodeID      string                 `json:"node_id"`
	NodeType    string                 `json:"node_type"`
	Title       string                 `json:"title"`
	Index       int                    `json:"index"`
	Status      string                 `json:"status"`
	Error       string                 `json:"error,omitempty"`
	ElapsedTime float64                `json:"elapsed_time"`
	Inputs      map[string]interface{} `json:"inputs,omitempty"`
	Outputs     map[string]interface{} `json:"outputs,omitempty"`
	CreatedAt   int64                  `json:"created_at"`
	FinishedAt  int64                  `json:"finished_at,omitempty"`
}
string `json:"title"` + Index int `json:"index"` + Status string `json:"status"` + Error string `json:"error,omitempty"` + ElapsedTime float64 `json:"elapsed_time"` + Inputs map[string]interface{} `json:"inputs,omitempty"` + Outputs map[string]interface{} `json:"outputs,omitempty"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at,omitempty"` +} + +// WorkflowResult 表示工作流执行结果 +type WorkflowResult struct { + WorkflowRunID string `json:"workflow_run_id"` + WorkflowID string `json:"workflow_id"` + SequenceNumber int `json:"sequence_number"` + Status string `json:"status"` + Outputs map[string]interface{} `json:"outputs"` + Error string `json:"error,omitempty"` + ElapsedTime float64 `json:"elapsed_time"` + TotalTokens int `json:"total_tokens"` + TotalSteps int `json:"total_steps"` + ExceptionsCount int `json:"exceptions_count"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at,omitempty"` + Nodes []NodeExecution `json:"nodes"` +} diff --git a/admin/server/router/gaia/enter.go b/admin/server/router/gaia/enter.go index 39fa09088..ab5375e5f 100644 --- a/admin/server/router/gaia/enter.go +++ b/admin/server/router/gaia/enter.go @@ -8,6 +8,7 @@ type RouterGroup struct { TenantsRouter SystemRouter TestRouter + WorkflowRouter } var ( @@ -18,3 +19,4 @@ var systemOAuth2Api = api.ApiGroupApp.GaiaApiGroup.SystemOAuth2Api var systemApi = api.ApiGroupApp.GaiaApiGroup.SystemApi var quotaApi = api.ApiGroupApp.GaiaApiGroup.QuotaApi var testApi = api.ApiGroupApp.GaiaApiGroup.TestApi +var batchWorkflowApi = api.ApiGroupApp.GaiaApiGroup.BatchWorkflowApi diff --git a/admin/server/router/gaia/workflow.go b/admin/server/router/gaia/workflow.go new file mode 100644 index 000000000..d51a54e3a --- /dev/null +++ b/admin/server/router/gaia/workflow.go @@ -0,0 +1,35 @@ +package gaia + +import ( + "github.com/gin-gonic/gin" +) + +type WorkflowRouter struct{} + +// InitWorkflowRouter 初始化批量处理工作流路由 +func (w *WorkflowRouter) InitWorkflowRouter(Router 
*gin.RouterGroup) { + workflowRouter := Router.Group("gaia/workflow") + { + // 批量处理工作流相关路由 + workflowRouter.POST("batch/processing", batchWorkflowApi.CreateBatchWorkflow) // 创建批量处理 + workflowRouter.GET("batch/list", batchWorkflowApi.GetBatchWorkflowList) // 获取最近30天的批量工作流列表 + workflowRouter.GET("batch/:id", batchWorkflowApi.GetBatchWorkflow) // 获取批量处理信息 + workflowRouter.GET("batch/:id/tasks", batchWorkflowApi.GetBatchWorkflowTasks) // 获取任务列表 + workflowRouter.GET("batch/:id/progress", batchWorkflowApi.GetBatchWorkflowProgress) // 获取进度信息 + workflowRouter.POST("batch/:id/stop", batchWorkflowApi.StopBatchWorkflow) // 停止批量处理 + workflowRouter.POST("batch/:id/retry", batchWorkflowApi.RetryBatchWorkflow) // 重试批量处理(重新开始所有任务) + workflowRouter.POST("batch/:id/retry-failed", batchWorkflowApi.RetryFailedTasks) // 仅重试失败的任务 + workflowRouter.POST("batch/:id/resume", batchWorkflowApi.ResumeBatchWorkflow) // 恢复批量处理 + workflowRouter.GET("batch/:id/download", batchWorkflowApi.DownloadBatchWorkflowResults) // 下载结果 + + // 工作池管理相关路由 + workflowRouter.GET("worker-pool/status", batchWorkflowApi.GetWorkerPoolStatus) // 获取工作池状态 + workflowRouter.POST("worker-pool/restart", batchWorkflowApi.RestartWorkerPool) // 重启工作池 + workflowRouter.POST("worker-pool/stop", batchWorkflowApi.StopWorkerPool) // 停止工作池 + workflowRouter.POST("worker-pool/start", batchWorkflowApi.StartWorkerPool) // 启动工作池 + + // 错误计数重置相关路由 + workflowRouter.POST("batch/:id/reset-error-count", batchWorkflowApi.ResetBatchWorkflowErrorCount) // 重置批量工作流错误计数 + workflowRouter.POST("batch/reset-user-error-count", batchWorkflowApi.ResetUserErrorCount) // 重置用户所有批量工作流错误计数 + } +} diff --git a/admin/server/service/gaia/account.go b/admin/server/service/gaia/account.go index ed59302f6..bc17644cf 100644 --- a/admin/server/service/gaia/account.go +++ b/admin/server/service/gaia/account.go @@ -128,8 +128,12 @@ func RegisterUser(u system.SysUser, token string) (err error) { global.GVA_LOG.Debug("注册用户信息:", zap.Any("1", 1)) var acc gaia.Account if err 
= global.GVA_DB.Where("email=?", u.Email).First(&acc).Error; err == nil { - // 用户已存在 - global.GVA_LOG.Info(fmt.Sprintf("account %s", acc.Name)) + // 用户已存在,更新密码 + global.GVA_LOG.Info(fmt.Sprintf("account %s already exists, updating password", acc.Name)) + global.GVA_DB.Model(&acc).Updates(&map[string]interface{}{ + "password": passwordHashed, + "password_salt": salt, + }) return nil } // 默认以root执行 @@ -178,7 +182,7 @@ func RegisterUser(u system.SysUser, token string) (err error) { } // result - if result, ok := bodyMap["result"]; !ok && result != "success" { + if result, ok := bodyMap["result"]; !ok || result != "success" { return errors.New(fmt.Sprintf("failed to create user: %s", bodyMap["error"])) } // 修改密码 diff --git a/admin/server/service/gaia/batch_workflow.go b/admin/server/service/gaia/batch_workflow.go new file mode 100644 index 000000000..341e5f05d --- /dev/null +++ b/admin/server/service/gaia/batch_workflow.go @@ -0,0 +1,637 @@ +package gaia + +import ( + "database/sql" + "encoding/json" + "fmt" + "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request" + "github.com/pkg/errors" + "io" + "net/http" + "strings" + "time" + + "github.com/flipped-aurora/gin-vue-admin/server/global" + "github.com/flipped-aurora/gin-vue-admin/server/model/gaia" + "github.com/google/uuid" +) + +type BatchWorkflowService struct{} + +// CreateBatchWorkflow 创建批量处理工作流 + +func (s *BatchWorkflowService) CreateBatchWorkflow( + userId uint, installedID, fileName string, fileContent [][]string, keyNameMapping map[string]string) ( + *gaia.BatchWorkflow, error) { + // 检查数据库连接 + if global.GVA_DB == nil { + return nil, fmt.Errorf("数据库连接未初始化") + } + + // 创建批量处理记录 + keyByte, _ := json.Marshal(keyNameMapping) + batchWorkflow := &gaia.BatchWorkflow{ + ProcessedRows: 0, + UserID: userId, + FileName: fileName, + Status: "pending", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + InstalledID: installedID, + KeyName: string(keyByte), + ID: uuid.New().String(), + TotalRows: 0, // 
先设为0,后面会更新为实际有效行数 + } + + // 保存到数据库 + if err := global.GVA_DB.Create(batchWorkflow).Error; err != nil { + return nil, fmt.Errorf("保存批量处理记录失败: %v", err) + } + + // 创建任务记录 + headers := fileContent[0] + if len(headers) > 0 { + // 去除UTF-8 BOM + headers[0] = strings.TrimPrefix(headers[0], "\uFEFF") + } + dataRows := fileContent[1:] + + validRowCount := 0 // 记录有效行数 + for i, row := range dataRows { + // 构建输入参数 + inputs := make(map[string]string) + hasNonEmptyValue := false // 检查是否有非空值 + + for j, value := range row { + if j < len(headers) { + headerName := headers[j] + // 去除首尾空格 + value = strings.TrimSpace(value) + + // 如果有key-name映射,使用映射后的key,否则使用原始header + if keyNameMapping != nil { + if key, exists := keyNameMapping[headerName]; exists { + inputs[key] = value + } else { + inputs[headerName] = value + } + } else { + inputs[headerName] = value + } + + // 检查是否有非空值 + if value != "" { + hasNonEmptyValue = true + } + } + } + + // 如果所有字段都为空,跳过这一行 + if !hasNonEmptyValue { + global.GVA_LOG.Info(fmt.Sprintf("跳过空值行,行索引: %d", i+1)) + continue + } + + validRowCount++ + inputsJSON, _ := json.Marshal(inputs) + + task := &gaia.BatchWorkflowTask{ + ID: uuid.New().String(), + BatchWorkflowID: batchWorkflow.ID, + RowIndex: i + 1, + Inputs: string(inputsJSON), + Status: "pending", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := global.GVA_DB.Create(task).Error; err != nil { + return nil, fmt.Errorf("创建任务记录失败: %v", err) + } + } + + // 更新批量处理记录的总行数为实际有效行数 + if err := global.GVA_DB.Model(batchWorkflow).Update("total_rows", validRowCount).Error; err != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新总行数失败: %v", err)) + } + + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 创建完成,原始行数: %d,有效行数: %d", + batchWorkflow.ID, len(fileContent)-1, validRowCount)) + + // 任务已创建,工作池会自动处理 + // 确保工作池在运行 + if pool := GetWorkerPool(); pool == nil || !pool.IsRunning() { + global.GVA_LOG.Warn("工作池未运行,尝试重新启动") + InitWorkerPool(global.GVA_CONFIG.System.WorkFlowNumber) // 默认5个worker + } + + // 
更新批处理工作流状态为处理中 + if err := global.GVA_DB.Model(batchWorkflow).Update("status", "processing").Error; err != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新批处理工作流状态失败: %v", err)) + } + + return batchWorkflow, nil +} + +// parseSSEStream 解析SSE流并返回最终结果 +func (s *BatchWorkflowService) parseSSEStream(body []byte) (*request.WorkflowResult, error) { + lines := strings.Split(string(body), "\n") + result := &request.WorkflowResult{ + Nodes: make([]request.NodeExecution, 0), + } + nodeMap := make(map[string]*request.NodeExecution) // 用于跟踪节点状态 + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || !strings.HasPrefix(line, "data: ") { + continue + } + + // 移除 "data: " 前缀 + jsonStr := strings.TrimPrefix(line, "data: ") + + // 解析JSON + var event map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &event); err != nil { + continue // 跳过无法解析的行 + } + + eventType, ok := event["event"].(string) + if !ok { + continue + } + + // 兼容新旧格式:如果有data字段则使用data,否则使用顶层数据 + var data map[string]interface{} + if dataField, hasData := event["data"].(map[string]interface{}); hasData { + // 旧格式:事件数据在data字段中 + data = dataField + } else { + // 新格式:事件数据在顶层 + data = event + } + + switch eventType { + case "workflow_started": + if workflowRunID, ok := data["id"].(string); ok { + result.WorkflowRunID = workflowRunID + } + if workflowID, ok := data["workflow_id"].(string); ok { + result.WorkflowID = workflowID + } + if sequenceNumber, ok := data["sequence_number"].(float64); ok { + result.SequenceNumber = int(sequenceNumber) + } + if createdAt, ok := data["created_at"].(float64); ok { + result.CreatedAt = int64(createdAt) + } + + case "node_started": + nodeExecution := &request.NodeExecution{} + if id, ok := data["id"].(string); ok { + nodeExecution.ID = id + } + if nodeID, ok := data["node_id"].(string); ok { + nodeExecution.NodeID = nodeID + } + if nodeType, ok := data["node_type"].(string); ok { + nodeExecution.NodeType = nodeType + } + if title, ok := 
data["title"].(string); ok { + nodeExecution.Title = title + } + if index, ok := data["index"].(float64); ok { + nodeExecution.Index = int(index) + } + if inputs, ok := data["inputs"].(map[string]interface{}); ok { + nodeExecution.Inputs = inputs + } + if createdAt, ok := data["created_at"].(float64); ok { + nodeExecution.CreatedAt = int64(createdAt) + } + + nodeMap[nodeExecution.ID] = nodeExecution + + case "node_finished": + nodeID, ok := data["id"].(string) + if !ok { + continue + } + + node, exists := nodeMap[nodeID] + if !exists { + // 如果没有找到对应的开始节点,创建一个新的 + node = &request.NodeExecution{} + if id, ok := data["id"].(string); ok { + node.ID = id + } + if nodeIDStr, ok := data["node_id"].(string); ok { + node.NodeID = nodeIDStr + } + if nodeType, ok := data["node_type"].(string); ok { + node.NodeType = nodeType + } + if title, ok := data["title"].(string); ok { + node.Title = title + } + if index, ok := data["index"].(float64); ok { + node.Index = int(index) + } + nodeMap[nodeID] = node + } + + // 更新节点完成信息 + if status, ok := data["status"].(string); ok { + node.Status = status + } + if errorMsg, ok := data["error"].(string); ok && errorMsg != "" { + node.Error = errorMsg + } + if elapsedTime, ok := data["elapsed_time"].(float64); ok { + node.ElapsedTime = elapsedTime + } + if outputs, ok := data["outputs"].(map[string]interface{}); ok { + node.Outputs = outputs + } + if finishedAt, ok := data["finished_at"].(float64); ok { + node.FinishedAt = int64(finishedAt) + } + + case "workflow_finished": + if status, ok := data["status"].(string); ok { + result.Status = status + } + if outputs, ok := data["outputs"].(map[string]interface{}); ok { + result.Outputs = outputs + } + if errorMsg, ok := data["error"].(string); ok { + result.Error = errorMsg + } + if elapsedTime, ok := data["elapsed_time"].(float64); ok { + result.ElapsedTime = elapsedTime + } + if totalTokens, ok := data["total_tokens"].(float64); ok { + result.TotalTokens = int(totalTokens) + } + if totalSteps, 
ok := data["total_steps"].(float64); ok { + result.TotalSteps = int(totalSteps) + } + if exceptionsCount, ok := data["exceptions_count"].(float64); ok { + result.ExceptionsCount = int(exceptionsCount) + } + if finishedAt, ok := data["finished_at"].(float64); ok { + result.FinishedAt = int64(finishedAt) + } + + case "message": + // 处理新的message事件格式,将answer字段填充到outputs.text中 + if answer, ok := data["answer"].(string); ok && answer != "" { + // 如果result.Outputs为空,初始化它 + if result.Outputs == nil { + result.Outputs = make(map[string]interface{}) + } + if value, okText := result.Outputs["text"]; okText { + result.Outputs["text"] = value.(string) + answer + } else { + result.Outputs["text"] = answer + } + } + // 同时设置其他相关字段 + if messageID, ok := data["message_id"].(string); ok { + result.WorkflowRunID = messageID + } + if createdAt, ok := data["created_at"].(float64); ok { + result.CreatedAt = int64(createdAt) + } + } + } + + // 将节点按照index排序并添加到结果中 + for _, node := range nodeMap { + result.Nodes = append(result.Nodes, *node) + } + + // 按index排序 + for i := 0; i < len(result.Nodes)-1; i++ { + for j := i + 1; j < len(result.Nodes); j++ { + if result.Nodes[i].Index > result.Nodes[j].Index { + result.Nodes[i], result.Nodes[j] = result.Nodes[j], result.Nodes[i] + } + } + } + + return result, nil +} + +// callDifyAPI 调用Dify API +func (s *BatchWorkflowService) callDifyAPI( + installedID, userToken string, inputs map[string]string) (string, error) { + + var err error + var requestBodyJSON []byte + if requestBodyJSON, err = json.Marshal(&map[string]interface{}{ + "inputs": inputs, + "response_mode": "streaming", + }); err != nil { + return "", err + } + var url string + var mode sql.NullString + if err = global.GVA_DB.Raw("SELECT b.mode FROM installed_apps as a, apps as b WHERE a.app_id=b.id AND a.id = ?", installedID).Scan(&mode).Error; err != nil { + return "", err + } + // 区分model + if mode.String == "workflow" { + url = "%s/console/api/installed-apps/%s/workflows/run" + } else if 
mode.String == "completion" { + url = "%s/console/api/installed-apps/%s/completion-messages" + } else { + return "", errors.New(fmt.Sprintf("Unsupported dify API call: %s", mode.String)) + } + // 创建HTTP请求 + req, err := http.NewRequest("POST", fmt.Sprintf( + url, global.GVA_CONFIG.Gaia.Url, installedID), strings.NewReader(string(requestBodyJSON))) + if err != nil { + return "", err + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+userToken) + req.Header.Set("Accept", "text/event-stream") // 接受SSE流 + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // 读取响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("API请求失败,状态码: %d, 响应: %s", resp.StatusCode, string(body)) + } + + // 解析SSE流 + result, err := s.parseSSEStream(body) + if err != nil { + return "", fmt.Errorf("解析SSE流失败: %v", err) + } + + // 将结果转换为JSON字符串返回 + resultJSON, err := json.Marshal(result) + if err != nil { + return "", fmt.Errorf("序列化结果失败: %v", err) + } + + return string(resultJSON), nil +} + +// GetBatchWorkflow 获取批量处理信息 +func (s *BatchWorkflowService) GetBatchWorkflow(id string) (*gaia.BatchWorkflow, error) { + if global.GVA_DB == nil { + return nil, fmt.Errorf("数据库连接未初始化") + } + + var batchWorkflow gaia.BatchWorkflow + if err := global.GVA_DB.Where("id = ?", id).First(&batchWorkflow).Error; err != nil { + return nil, err + } + return &batchWorkflow, nil +} + +// GetBatchWorkflowTasks 获取批量处理的任务列表 +func (s *BatchWorkflowService) GetBatchWorkflowTasks(batchWorkflowID string) ([]gaia.BatchWorkflowTask, error) { + if global.GVA_DB == nil { + return nil, fmt.Errorf("数据库连接未初始化") + } + + var tasks []gaia.BatchWorkflowTask + if err := global.GVA_DB.Where("batch_workflow_id = ?", batchWorkflowID).Order("row_index").Find(&tasks).Error; err != nil { + return nil, err + } + 
return tasks, nil +} + +// StopBatchWorkflow 停止批量处理 +func (s *BatchWorkflowService) StopBatchWorkflow(id string) error { + if global.GVA_DB == nil { + return fmt.Errorf("数据库连接未初始化") + } + + return global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("id = ?", id).Update("status", "stopped").Error +} + +// RetryFailedTasks 仅重试失败的任务 +func (s *BatchWorkflowService) RetryFailedTasks(id string) error { + if global.GVA_DB == nil { + return fmt.Errorf("数据库连接未初始化") + } + + // 只重置失败的任务为待处理状态,保留已完成的任务 + errorCount := 0 + if err := global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where( + "batch_workflow_id = ? AND status IN ?", id, []string{"failed", "queued", "running"}).Updates( + map[string]interface{}{ + "status": "pending", + "error": "", + "error_count": &errorCount, + "updated_at": time.Now(), + }).Error; err != nil { + return err + } + + // 重新计算已处理行数 + var completedCount int64 + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where( + "batch_workflow_id = ? AND status = ?", id, "completed").Count(&completedCount) + + // 重置批量处理状态 + if err := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("id = ?", id).Updates(map[string]interface{}{ + "status": "pending", + "processed_rows": completedCount, + "error": "", + "updated_at": time.Now(), + }).Error; err != nil { + return err + } + + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 失败任务重试已启动,工作池将自动处理待处理任务", id)) + return nil +} + +// RetryBatchWorkflow 重试批量处理 +func (s *BatchWorkflowService) RetryBatchWorkflow(id string) error { + if global.GVA_DB == nil { + return fmt.Errorf("数据库连接未初始化") + } + + // 重置所有失败的任务为待处理状态 + if err := global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? AND status IN ?", id, []string{"failed", "queued", "running"}).Updates(map[string]interface{}{ + "status": "pending", + "error": "", + "updated_at": time.Now(), + }).Error; err != nil { + return err + } + + // 重新计算已处理行数 + var completedCount int64 + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? 
AND status = ?", id, "completed").Count(&completedCount) + + // 重置批量处理状态 + if err := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("id = ?", id).Updates(map[string]interface{}{ + "status": "processing", + "processed_rows": completedCount, + "error": "", + "updated_at": time.Now(), + }).Error; err != nil { + return err + } + + // 确保工作池在运行 + if pool := GetWorkerPool(); pool == nil || !pool.IsRunning() { + global.GVA_LOG.Warn("工作池未运行,尝试重新启动") + InitWorkerPool(global.GVA_CONFIG.System.WorkFlowNumber) + } + + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 重试已启动,工作池将自动处理待处理任务", id)) + return nil +} + +// ResumeBatchWorkflow 恢复批量处理 +func (s *BatchWorkflowService) ResumeBatchWorkflow(id string) error { + if global.GVA_DB == nil { + return fmt.Errorf("数据库连接未初始化") + } + + // 检查批量工作流是否存在 + var batchWorkflow gaia.BatchWorkflow + if err := global.GVA_DB.Where("id = ?", id).First(&batchWorkflow).Error; err != nil { + return fmt.Errorf("批量工作流不存在: %v", err) + } + + // 检查批量工作流状态是否为stopped + if batchWorkflow.Status != "stopped" { + return fmt.Errorf("只能恢复已停止的批量处理") + } + + // 检查是否有可恢复的任务(pending 或 cancelled 状态) + var resumableTasks int64 + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? AND status IN (?)", id, []string{"pending", "cancelled"}).Count(&resumableTasks) + + if resumableTasks == 0 { + return fmt.Errorf("没有可恢复的任务") + } + + // 将cancelled状态的任务恢复为pending状态 + if err := global.GVA_DB.Model(&gaia.BatchWorkflowTask{}). + Where("batch_workflow_id = ? AND status = ?", id, "cancelled"). 
+ Updates(map[string]interface{}{ + "status": "pending", + "updated_at": time.Now(), + }).Error; err != nil { + return fmt.Errorf("恢复取消的任务失败: %v", err) + } + + // 更新批量工作流状态为处理中 + if err := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("id = ?", id).Updates(map[string]interface{}{ + "status": "processing", + "updated_at": time.Now(), + }).Error; err != nil { + return err + } + + // 确保工作池在运行 + if pool := GetWorkerPool(); pool == nil || !pool.IsRunning() { + global.GVA_LOG.Warn("工作池未运行,尝试重新启动") + InitWorkerPool(global.GVA_CONFIG.System.WorkFlowNumber) + } + + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 恢复已启动,工作池将自动处理待处理任务", id)) + return nil +} + +// GetBatchWorkflowProgress 获取批量处理进度 +func (s *BatchWorkflowService) GetBatchWorkflowProgress(id string) (map[string]interface{}, error) { + if global.GVA_DB == nil { + return nil, fmt.Errorf("数据库连接未初始化") + } + + var batchWorkflow gaia.BatchWorkflow + if err := global.GVA_DB.Where("id = ?", id).First(&batchWorkflow).Error; err != nil { + return nil, err + } + + // 统计各种状态的任务数量 + var pendingCount, queuedCount, runningCount, completedCount, failedCount int64 + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? AND status = ?", id, "pending").Count(&pendingCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? AND status = ?", id, "queued").Count(&queuedCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? AND status = ?", id, "running").Count(&runningCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? AND status = ?", id, "completed").Count(&completedCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where("batch_workflow_id = ? 
AND status = ?", id, "failed").Count(&failedCount) + + progress := float64(completedCount) / float64(batchWorkflow.TotalRows) * 100 + + // 获取工作池状态 + var workerPoolStatus map[string]interface{} + if pool := GetWorkerPool(); pool != nil { + workerPoolStatus = pool.GetStatus() + } else { + workerPoolStatus = map[string]interface{}{ + "running": false, + "workers": 0, + "queue_length": 0, + } + } + + // 获取错误信息 - 从批量工作流本身和失败的任务中获取 + var errorInfo string + if batchWorkflow.Error != "" { + errorInfo = batchWorkflow.Error + } else if failedCount > 0 { + // 如果有失败的任务,获取第一个失败任务的错误信息作为代表 + var failedTask gaia.BatchWorkflowTask + if err := global.GVA_DB.Where("batch_workflow_id = ? AND status = ?", id, "failed").First(&failedTask).Error; err == nil && failedTask.Error != "" { + errorInfo = failedTask.Error + } + } + + return map[string]interface{}{ + "id": batchWorkflow.ID, + "status": batchWorkflow.Status, + "total_rows": batchWorkflow.TotalRows, + "processed_rows": completedCount, // 使用实时统计值确保与progress一致 + "progress": progress, + "pending_count": pendingCount, + "queued_count": queuedCount, + "running_count": runningCount, + "completed_count": completedCount, + "failed_count": failedCount, + "error": errorInfo, // 添加错误信息 + "worker_pool_status": workerPoolStatus, + "created_at": batchWorkflow.CreatedAt, + "updated_at": batchWorkflow.UpdatedAt, + }, nil +} + +// GetWorkerPool 获取全局工作池 +func (s *BatchWorkflowService) GetWorkerPool() *WorkerPool { + return GetWorkerPool() +} + +// InitWorkerPool 初始化工作池 +func (s *BatchWorkflowService) InitWorkerPool(workers int) { + InitWorkerPool(workers) +} + +// StopWorkerPool 停止工作池 +func (s *BatchWorkflowService) StopWorkerPool() { + StopWorkerPool() +} diff --git a/admin/server/service/gaia/enter.go b/admin/server/service/gaia/enter.go index 3685b74ab..2de7015c1 100644 --- a/admin/server/service/gaia/enter.go +++ b/admin/server/service/gaia/enter.go @@ -6,4 +6,5 @@ type ServiceGroup struct { QuotaService TenantsService TestService + 
BatchWorkflowService } diff --git a/admin/server/service/gaia/worker_pool.go b/admin/server/service/gaia/worker_pool.go new file mode 100644 index 000000000..caa4bc21b --- /dev/null +++ b/admin/server/service/gaia/worker_pool.go @@ -0,0 +1,1281 @@ +package gaia + +import ( + "context" + "encoding/json" + "fmt" + "github.com/flipped-aurora/gin-vue-admin/server/model/system" + "github.com/flipped-aurora/gin-vue-admin/server/utils" + "strconv" + "strings" + "sync" + "time" + + "github.com/flipped-aurora/gin-vue-admin/server/global" + "github.com/flipped-aurora/gin-vue-admin/server/model/gaia" +) + +// UserWorkerAllocation 用户工作器分配信息 +type UserWorkerAllocation struct { + UserID uint `json:"user_id"` + Workers int `json:"workers"` + MaxLimit int `json:"max_limit"` +} + +// WorkerPool 工作池管理器 +type WorkerPool struct { + ctx context.Context + cancel context.CancelFunc + totalWorkers int // 总工作器数量 + userWorkers map[uint]*UserWorkerAllocation // 每个用户的工作器分配 + userTaskChan map[uint]chan *gaia.BatchWorkflowTask // 每个用户的任务队列 + runningWorkers map[uint]int // 每个用户当前运行的worker数量 + wg sync.WaitGroup + batchService *BatchWorkflowService + running bool + mutex sync.RWMutex + userMutex sync.RWMutex +} + +// NewWorkerPool 创建新的工作池 +func NewWorkerPool(totalWorkers int) *WorkerPool { + ctx, cancel := context.WithCancel(context.Background()) + return &WorkerPool{ + ctx: ctx, + cancel: cancel, + totalWorkers: totalWorkers, + userWorkers: make(map[uint]*UserWorkerAllocation), + userTaskChan: make(map[uint]chan *gaia.BatchWorkflowTask), + runningWorkers: make(map[uint]int), + batchService: &BatchWorkflowService{}, + running: false, + } +} + +// calculateWorkerCountWithErrorPenalty 根据错误次数计算工作器数量 +// 每50个错误减少1个并发位,最少保留1个并发位 +func (wp *WorkerPool) calculateWorkerCountWithErrorPenalty(baseWorkers int, errorCount int) int { + if baseWorkers <= 0 { + return 1 + } + + // 计算错误惩罚:每50个错误减少1个并发位 + penalty := errorCount / gaia.ErrorPenaltyThreshold + adjustedWorkers := baseWorkers - penalty + + // 
确保至少保留1个并发位 + if adjustedWorkers < 1 { + adjustedWorkers = 1 + } + + return adjustedWorkers +} + +// calculateUserWorkerAllocation 计算用户工作器分配 +func (wp *WorkerPool) calculateUserWorkerAllocation() { + wp.userMutex.Lock() + defer wp.userMutex.Unlock() + + // 获取有批量任务的活跃用户(按需分配) + // 排除超过重试次数的任务,只考虑可以继续处理的任务 + // 分两个查询:1. 获取有活跃任务的用户,2. 获取用户的累计错误次数 + + // 第一个查询:获取有活跃批量任务的用户 + var activeUserIDs []uint + err := global.GVA_DB.Raw(` + SELECT DISTINCT bw.user_id + FROM batch_workflows_extend bw + INNER JOIN sys_users su ON bw.user_id = su.id + INNER JOIN batch_workflow_tasks_extend bwt ON bw.id = bwt.batch_workflow_id + WHERE su.enable = ? + AND bw.status IN (?, ?) + AND (bwt.status IN (?, ?) AND bwt.error_count < ?) + `, system.UserActive, gaia.BatchWorkflowStatusPending, gaia.BatchWorkflowStatusProcessing, + gaia.BatchTaskStatusPending, gaia.BatchTaskStatusQueued, gaia.MaxTaskRetryCount).Scan(&activeUserIDs).Error + + if err != nil { + global.GVA_LOG.Error("获取有批量任务的活跃用户失败: " + err.Error()) + return + } + + // 第二个查询:获取这些用户的所有批量工作流的累计错误次数(不限状态) + type UserErrorInfo struct { + UserID uint `json:"user_id"` + ErrorCount int `json:"error_count"` + } + var userErrorInfos []UserErrorInfo + if len(activeUserIDs) > 0 { + err = global.GVA_DB.Raw(` + SELECT bw.user_id, COALESCE(SUM(bw.error_count), 0) as error_count + FROM batch_workflows_extend bw + WHERE bw.user_id IN (?) 
AND bw.status='pending' + GROUP BY bw.user_id + `, activeUserIDs).Scan(&userErrorInfos).Error + + if err != nil { + global.GVA_LOG.Error("获取用户累计错误次数失败: " + err.Error()) + return + } + } + + // 提取活跃用户ID列表和错误次数映射 + userErrorMap := make(map[uint]int) + for _, info := range userErrorInfos { + userErrorMap[info.UserID] = info.ErrorCount + } + + userCount := len(activeUserIDs) + if userCount == 0 { + // 如果没有用户有批量任务,关闭所有队列 + for _, ch := range wp.userTaskChan { + close(ch) + } + wp.userWorkers = make(map[uint]*UserWorkerAllocation) + wp.userTaskChan = make(map[uint]chan *gaia.BatchWorkflowTask) + wp.runningWorkers = make(map[uint]int) + return + } + + // 创建活跃用户ID集合 + activeUserIDMap := make(map[uint]bool) + for _, userID := range activeUserIDs { + activeUserIDMap[userID] = true + } + + // 关闭不再有批量任务的用户的任务队列 + for userID, ch := range wp.userTaskChan { + if !activeUserIDMap[userID] { + close(ch) + delete(wp.userTaskChan, userID) + delete(wp.userWorkers, userID) + delete(wp.runningWorkers, userID) + } + } + + // 检查用户数量是否超过了最大支持数量(每用户最少1个工作器) + maxSupportedUsers := wp.totalWorkers / 1 + + // 存储新的分配计算结果 + newAllocations := make(map[uint]*UserWorkerAllocation) + + if userCount <= maxSupportedUsers { + // 用户数量在可支持范围内,采用两阶段分配策略 + baseAllocation := wp.totalWorkers / userCount + remainder := wp.totalWorkers % userCount + + // 第一阶段:计算每个用户的基础分配和错误惩罚后的实际分配 + type UserAllocationInfo struct { + UserID uint + BaseWorkers int + ActualWorkers int + ErrorCount int + PenaltyReduced int + } + + var userAllocations []UserAllocationInfo + totalPenaltyReduced := 0 + + for i, userID := range activeUserIDs { + baseWorkers := baseAllocation + // 处理余数,前几个用户多分配一个 + if i < remainder { + baseWorkers++ + } + + // 确保每个用户至少有1个并发位 + if baseWorkers < 1 { + baseWorkers = 1 + } + + // 应用错误惩罚:根据用户的累计错误次数减少并发位 + errorCount := userErrorMap[userID] + actualWorkers := wp.calculateWorkerCountWithErrorPenalty(baseWorkers, errorCount) + penaltyReduced := baseWorkers - actualWorkers + totalPenaltyReduced += 
penaltyReduced + userAllocations = append(userAllocations, UserAllocationInfo{ + UserID: userID, + BaseWorkers: baseWorkers, + ActualWorkers: actualWorkers, + ErrorCount: errorCount, + PenaltyReduced: penaltyReduced, + }) + } + + // 第二阶段:将空出来的并发位重新分配给错误较少的用户 + if totalPenaltyReduced > 0 { + // 按错误数量排序,错误少的用户优先获得额外分配 + for i := 0; i < len(userAllocations)-1; i++ { + for j := i + 1; j < len(userAllocations); j++ { + if userAllocations[i].ErrorCount > userAllocations[j].ErrorCount { + userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i] + } + } + } + + // 只为没有被惩罚的用户(PenaltyReduced = 0)重新分配空闲的并发位 + // 被惩罚的用户不应该获得额外分配 + remainingToDistribute := totalPenaltyReduced + eligibleUsers := 0 + + // 计算有资格获得额外分配的用户数量(没有被惩罚的用户) + for _, allocation := range userAllocations { + if allocation.PenaltyReduced == 0 { + eligibleUsers++ + } + } + + if eligibleUsers > 0 { + // 只为没有被惩罚的用户分配额外的并发位 + for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ { + if userAllocations[i].PenaltyReduced == 0 { + // 为没有错误惩罚的用户分配额外的并发位 + extraWorkers := remainingToDistribute / eligibleUsers + if extraWorkers < 1 { + extraWorkers = 1 + } + if extraWorkers > remainingToDistribute { + extraWorkers = remainingToDistribute + } + + userAllocations[i].ActualWorkers += extraWorkers + remainingToDistribute -= extraWorkers + eligibleUsers-- + } + } + } + } + + // 创建最终分配结果 + totalFinalWorkers := 0 + for _, allocation := range userAllocations { + newAllocations[allocation.UserID] = &UserWorkerAllocation{ + UserID: allocation.UserID, + Workers: allocation.ActualWorkers, + MaxLimit: wp.totalWorkers, + } + totalFinalWorkers += allocation.ActualWorkers + } + } else { + // 用户数量超过最大支持数量,采用降级分配策略(两阶段分配) + baseAllocation := wp.totalWorkers / userCount + remainder := wp.totalWorkers % userCount + + // 第一阶段:计算每个用户的基础分配和错误惩罚后的实际分配 + type UserAllocationInfo struct { + UserID uint + BaseWorkers int + ActualWorkers int + ErrorCount int + PenaltyReduced int + } + + var userAllocations 
[]UserAllocationInfo + totalPenaltyReduced := 0 + + for i, userID := range activeUserIDs { + baseWorkers := baseAllocation + // 处理余数,前几个用户多分配一个 + if i < remainder { + baseWorkers++ + } + + // 确保至少分配1个工作器 + if baseWorkers < 1 { + baseWorkers = 1 + } + + // 应用错误惩罚:根据用户的累计错误次数减少并发位 + errorCount := userErrorMap[userID] + actualWorkers := wp.calculateWorkerCountWithErrorPenalty(baseWorkers, errorCount) + penaltyReduced := baseWorkers - actualWorkers + totalPenaltyReduced += penaltyReduced + + // 添加详细的错误惩罚计算调试日志 + userAllocations = append(userAllocations, UserAllocationInfo{ + UserID: userID, + BaseWorkers: baseWorkers, + ActualWorkers: actualWorkers, + ErrorCount: errorCount, + PenaltyReduced: penaltyReduced, + }) + } + + // 第二阶段:将空出来的并发位重新分配给错误较少的用户 + if totalPenaltyReduced > 0 { + // 按错误数量排序,错误少的用户优先获得额外分配 + for i := 0; i < len(userAllocations)-1; i++ { + for j := i + 1; j < len(userAllocations); j++ { + if userAllocations[i].ErrorCount > userAllocations[j].ErrorCount { + userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i] + } + } + } + + // 只为没有被惩罚的用户(PenaltyReduced = 0)重新分配空闲的并发位 + // 被惩罚的用户不应该获得额外分配 + remainingToDistribute := totalPenaltyReduced + eligibleUsers := 0 + + // 计算有资格获得额外分配的用户数量(没有被惩罚的用户) + for _, allocation := range userAllocations { + if allocation.PenaltyReduced == 0 { + eligibleUsers++ + } + } + + if eligibleUsers > 0 { + // 只为没有被惩罚的用户分配额外的并发位 + for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ { + if userAllocations[i].PenaltyReduced == 0 { + // 为没有错误惩罚的用户分配额外的并发位 + extraWorkers := remainingToDistribute / eligibleUsers + if extraWorkers < 1 { + extraWorkers = 1 + } + if extraWorkers > remainingToDistribute { + extraWorkers = remainingToDistribute + } + + userAllocations[i].ActualWorkers += extraWorkers + remainingToDistribute -= extraWorkers + eligibleUsers-- + } + } + } + } + + // 创建最终分配结果,确保不超过总工作器数量 + allocatedWorkers := 0 + for _, allocation := range userAllocations { + workers := 
allocation.ActualWorkers + + // 确保不会超过剩余的工作器数量 + remainingWorkers := wp.totalWorkers - allocatedWorkers + if workers > remainingWorkers { + workers = remainingWorkers + } + + if workers > 0 { + newAllocations[allocation.UserID] = &UserWorkerAllocation{ + UserID: allocation.UserID, + Workers: workers, + MaxLimit: wp.totalWorkers, + } + allocatedWorkers += workers + } + + // 如果工作器已经分配完毕,剩余用户分配0个工作器 + if allocatedWorkers >= wp.totalWorkers { + break + } + } + + global.GVA_LOG.Warn(fmt.Sprintf("降级分配完成 - 总工作器: %d, 用户数: %d, 已分配: %d, 重新分配: %d, 平均每用户: %.1f个", + wp.totalWorkers, userCount, allocatedWorkers, totalPenaltyReduced, float64(allocatedWorkers)/float64(userCount))) + } + + // 应用新的分配,只更新有变化的用户 + for userID, newAllocation := range newAllocations { + oldAllocation, exists := wp.userWorkers[userID] + + if !exists { + // 新用户,创建分配和任务队列 + wp.userWorkers[userID] = newAllocation + wp.userTaskChan[userID] = make(chan *gaia.BatchWorkflowTask, newAllocation.Workers*2) + } else if oldAllocation.Workers != newAllocation.Workers { + // 现有用户的工作器数量发生变化,需要重新创建任务队列 + close(wp.userTaskChan[userID]) + wp.userWorkers[userID] = newAllocation + wp.userTaskChan[userID] = make(chan *gaia.BatchWorkflowTask, newAllocation.Workers*2) + // 重置运行中的worker计数,让adjustWorkers重新启动 + wp.runningWorkers[userID] = 0 + } else { + // 工作器数量没有变化,只更新分配信息 + wp.userWorkers[userID] = newAllocation + } + } +} + +// getUserWorkerCount 获取指定用户的工作器数量 +func (wp *WorkerPool) getUserWorkerCount(userID uint) int { + wp.userMutex.RLock() + defer wp.userMutex.RUnlock() + + if allocation, exists := wp.userWorkers[userID]; exists { + return allocation.Workers + } + return 0 +} + +// Start 启动工作池 +func (wp *WorkerPool) Start() { + wp.mutex.Lock() + defer wp.mutex.Unlock() + + if wp.running { + return + } + + // 计算用户工作器分配 + wp.calculateUserWorkerAllocation() + + wp.running = true + + // 启动初始工作器 + wp.startWorkers() + + // 启动任务调度器 + wp.wg.Add(1) + go wp.taskScheduler() + + // 启动用户工作器分配更新器 + wp.wg.Add(1) + go 
// Stop shuts the worker pool down and blocks until its goroutines have exited.
//
// Calling Stop on a pool that is not running is a no-op. Otherwise it cancels
// the pool context, clears the running flag, closes every per-user task
// channel and then waits on the pool's WaitGroup. wp.mutex is held for the
// whole call, including the wait.
//
// NOTE(review): the channels are closed before the wait, while the pool's
// background goroutines may still be mid-iteration. A goroutine that sends on
// or closes one of these channels at that moment would panic — confirm that
// this cannot happen once the context is cancelled.
func (wp *WorkerPool) Stop() {
	wp.mutex.Lock()
	defer wp.mutex.Unlock()

	if !wp.running {
		return
	}

	wp.cancel()
	wp.running = false

	// Close the task queue of every user.
	wp.userMutex.Lock()
	for _, ch := range wp.userTaskChan {
		close(ch)
	}
	wp.userMutex.Unlock()

	// Wait for all pool goroutines to finish.
	wp.wg.Wait()
}
<-ticker.C: + wp.adjustWorkers() + } + } +} + +// startWorkers 启动工作器 +func (wp *WorkerPool) startWorkers() { + wp.userMutex.Lock() + defer wp.userMutex.Unlock() + + for userID, allocation := range wp.userWorkers { + // 启动所有需要的worker + for i := 0; i < allocation.Workers; i++ { + wp.wg.Add(1) + workerID := fmt.Sprintf("user_%d_worker_%d", userID, i) + go wp.worker(workerID, userID) + } + // 更新运行中的worker数量 + wp.runningWorkers[userID] = allocation.Workers + } +} + +// adjustWorkers 调整工作器数量 +func (wp *WorkerPool) adjustWorkers() { + // 重新计算用户分配 + wp.calculateUserWorkerAllocation() + + // 检查哪些用户的工作器数量发生了变化,为它们启动新worker + wp.userMutex.Lock() + defer wp.userMutex.Unlock() + + for userID, allocation := range wp.userWorkers { + runningCount := wp.runningWorkers[userID] + neededCount := allocation.Workers + + if runningCount < neededCount { + // 需要启动更多worker + for i := runningCount; i < neededCount; i++ { + wp.wg.Add(1) + workerID := fmt.Sprintf("user_%d_worker_%d", userID, i) + go wp.worker(workerID, userID) + } + wp.runningWorkers[userID] = neededCount + } + } +} + +// worker 工作协程 +func (wp *WorkerPool) worker(workerID string, userID uint) { + defer wp.wg.Done() + defer func() { + // Worker退出时减少运行中的worker计数 + wp.userMutex.Lock() + if wp.runningWorkers[userID] > 0 { + wp.runningWorkers[userID]-- + } + wp.userMutex.Unlock() + }() + + // 获取用户专属的任务队列 + wp.userMutex.RLock() + userTaskChan, exists := wp.userTaskChan[userID] + wp.userMutex.RUnlock() + + if !exists { + global.GVA_LOG.Error(fmt.Sprintf("Worker %s: 用户 %d 的任务队列不存在", workerID, userID)) + return + } + + for { + select { + case <-wp.ctx.Done(): + return + case task, ok := <-userTaskChan: + if !ok { + // 任务队列已关闭 + return + } + if task != nil { + wp.processTask(task) + } + } + } +} + +// taskScheduler 任务调度器 +func (wp *WorkerPool) taskScheduler() { + defer wp.wg.Done() + ticker := time.NewTicker(2 * time.Second) // 每2秒检查一次新任务 + defer ticker.Stop() + + for { + select { + case <-wp.ctx.Done(): + return + case <-ticker.C: + 
global.GVA_LOG.Debug("任务调度器开始检查新任务...") + wp.fetchAndScheduleTasks() + } + } +} + +// fetchAndScheduleTasks 获取并调度任务 +func (wp *WorkerPool) fetchAndScheduleTasks() { + global.GVA_LOG.Debug("fetchAndScheduleTasks 开始执行") + + if global.GVA_DB == nil { + global.GVA_LOG.Error("数据库连接为空,无法获取任务") + return + } + + // 获取所有待处理的任务,但排除已停止的批量工作流的任务和超过重试次数的任务 + var tasks []gaia.BatchWorkflowTask + err := global.GVA_DB.Table("batch_workflow_tasks_extend bwt"). + Select("bwt.*"). + Joins("INNER JOIN batch_workflows_extend bw ON bwt.batch_workflow_id = bw.id"). + Where("bwt.status = ? AND bw.status != ? AND bwt.error_count < ?", gaia.BatchTaskStatusPending, gaia.BatchWorkflowStatusStopped, gaia.MaxTaskRetryCount). + Order("bwt.created_at ASC"). + Find(&tasks).Error + + if err != nil { + global.GVA_LOG.Error("获取待处理任务失败: " + err.Error()) + return + } + + // 按用户分组任务 + userTasks := make(map[uint][]*gaia.BatchWorkflowTask) + for i := range tasks { + task := &tasks[i] + // 获取任务对应的用户ID + var batchWorkflow gaia.BatchWorkflow + if err = global.GVA_DB.Where("id = ?", task.BatchWorkflowID).First(&batchWorkflow).Error; err != nil { + global.GVA_LOG.Error(fmt.Sprintf("找不到任务 %s 对应的批量工作流 %s: %s", task.ID, task.BatchWorkflowID, err.Error())) + continue + } + userTasks[batchWorkflow.UserID] = append(userTasks[batchWorkflow.UserID], task) + } + + // 在分配任务前,再次清理已停止的批量工作流任务 + cleanupStoppedBatchWorkflowTasks() + + // 为每个用户分配任务到队列 + for userID, userTaskList := range userTasks { + userWorkerCount := wp.getUserWorkerCount(userID) + if userWorkerCount == 0 { + continue + } + + // 限制任务数量 + if len(userTaskList) > userWorkerCount { + userTaskList = userTaskList[:userWorkerCount] + } + + // 获取用户专属的任务队列 + wp.userMutex.RLock() + userTaskChan, exists := wp.userTaskChan[userID] + wp.userMutex.RUnlock() + + if !exists { + continue + } + + // 将任务标记为排队状态并加入队列 + for _, task := range userTaskList { + // 更新任务状态为排队中 + if err = global.GVA_DB.Model(task).Update("status", gaia.BatchTaskStatusQueued).Error; err != nil { + 
global.GVA_LOG.Error(fmt.Sprintf("更新任务状态失败: %s", err.Error())) + continue + } + + // 非阻塞方式添加到用户专属队列 + select { + case userTaskChan <- task: + + //global.GVA_LOG.Info(fmt.Sprintf("成功将任务 %s 添加到用户 %d 的队列", task.ID, userID)) + case <-wp.ctx.Done(): + return + default: + // 队列满了,将任务状态改回pending + //global.GVA_LOG.Warn(fmt.Sprintf("用户 %d 的队列已满,任务 %s 状态改回pending", userID, task.ID)) + global.GVA_DB.Model(task).Update("status", gaia.BatchTaskStatusPending) + } + } + + if len(userTaskList) > 0 { + global.GVA_LOG.Info(fmt.Sprintf("为用户 %d 调度了 %d 个任务到队列", userID, len(userTaskList))) + } + } +} + +// processTask 处理单个任务 +func (wp *WorkerPool) processTask(task *gaia.BatchWorkflowTask) { + // 更新任务状态为运行中 + if err := global.GVA_DB.Model(task).Updates(map[string]interface{}{ + "status": gaia.BatchTaskStatusRunning, + "updated_at": time.Now(), + }).Error; err != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新任务状态失败: %s", err.Error())) + return + } + + // 获取批量工作流信息 + var batchWorkflow gaia.BatchWorkflow + if err := global.GVA_DB.Where("id = ?", task.BatchWorkflowID).First(&batchWorkflow).Error; err != nil { + wp.updateTaskError(task, "获取批量工作流信息失败: "+err.Error()) + return + } + + // 检查批量工作流是否被停止 + if batchWorkflow.Status == gaia.BatchWorkflowStatusStopped { + wp.updateTaskError(task, "批量工作流已被停止") + return + } + + // 解析输入参数 + var inputs map[string]string + if err := json.Unmarshal([]byte(task.Inputs), &inputs); err != nil { + wp.updateTaskError(task, "解析输入参数失败: "+err.Error()) + return + } + + // 检查输入参数是否全为空值 + hasNonEmptyValue := false + for _, value := range inputs { + if strings.TrimSpace(value) != "" { + hasNonEmptyValue = true + break + } + } + + // 如果所有输入都为空,跳过处理并标记为完成 + if !hasNonEmptyValue { + global.GVA_LOG.Info(fmt.Sprintf("任务 %s 包含全空值输入,跳过处理并标记为完成", task.ID)) + + // 创建空结果并标记为完成 + emptyResult := map[string]interface{}{ + "status": gaia.BatchTaskStatusCompleted, + "message": "跳过空值输入任务", + "outputs": map[string]interface{}{ + "text": "输入为空,已跳过处理", + }, + } + emptyResultJSON, _ := 
json.Marshal(emptyResult) + + // 更新任务状态为完成 + if err := global.GVA_DB.Model(task).Updates(map[string]interface{}{ + "status": gaia.BatchTaskStatusCompleted, + "result": string(emptyResultJSON), + "updated_at": time.Now(), + }).Error; err != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新空值任务状态失败: %s", err.Error())) + return + } + + // 更新批量处理的已处理行数 + global.GVA_DB.Exec("UPDATE batch_workflows_extend SET processed_rows = processed_rows + 1, updated_at = ? WHERE id = ?", + time.Now(), batchWorkflow.ID) + + // 检查批量工作流是否完成 + wp.checkBatchWorkflowCompletion(batchWorkflow.ID) + return + } + + // 快速生成即时token + var err error + var token string + var user system.SysUser + if err = global.GVA_DB.Where( + "id = ? AND enable = ?", batchWorkflow.UserID, system.UserActive).First(&user).Error; err != nil { + wp.updateTaskError(task, "用户不存在: "+err.Error()) + return + } + // 生成这个用户的token + if token, _, err = utils.LoginToken(&user); err != nil { + wp.updateTaskError(task, "用户token生成失败: "+err.Error()) + return + } + + // 调用Dify API + result, err := wp.batchService.callDifyAPI(batchWorkflow.InstalledID, token, inputs) + if err != nil { + // 检查是否是余额不足错误(403状态码) + if strings.Contains(err.Error(), "状态码: 403") && strings.Contains(err.Error(), "Insufficient balance") { + global.GVA_LOG.Warn(fmt.Sprintf("用户 %d 余额不足,将其所有pending和processing状态的批量工作流和任务设置为失败", + batchWorkflow.UserID)) + wp.handleInsufficientBalance(batchWorkflow.UserID, task.BatchWorkflowID) + wp.updateTaskError(task, gaia.ErrorInsufficientBalance) + return + } + wp.updateTaskError(task, gaia.ErrorCallAPIFailed+": "+err.Error()) + return + } + + // 解析返回结果,检查是否有错误 + var apiResult map[string]interface{} + if err = json.Unmarshal([]byte(result), &apiResult); err != nil { + wp.updateTaskError(task, gaia.ErrorParseResultFailed+": "+err.Error()) + return + } + + // 检查API返回的状态 + if status, ok := apiResult["status"].(string); ok && status == gaia.BatchTaskStatusFailed { + // API执行失败,提取错误信息 + var apiError string + errorMsg := 
// decodeUnicodeEscapes turns \uXXXX escape sequences embedded in input into
// the characters they denote.
//
// Fast path: if the whole input is a valid Go double-quoted string body,
// strconv.Unquote decodes it in one go (this also interprets the other Go
// escapes such as \n or \\, which is the long-standing behaviour).
//
// Fallback: otherwise the input is scanned once from left to right and only
// well-formed \uXXXX sequences are replaced:
//   - a UTF-16 surrogate pair (\uD83D\uDE00) is combined into one code point;
//   - malformed or truncated sequences and lone surrogates are copied through
//     unchanged, so no input text is ever dropped or cut inside a multi-byte
//     character;
//   - decoded output is never rescanned, so "\u005cu0041" yields the text
//     `\u0041` rather than being decoded a second time.
func decodeUnicodeEscapes(input string) string {
	if decoded, err := strconv.Unquote(`"` + input + `"`); err == nil {
		return decoded
	}
	if !strings.Contains(input, `\u`) {
		return input
	}

	var b strings.Builder
	b.Grow(len(input))

	rest := input
	for {
		idx := strings.Index(rest, `\u`)
		if idx < 0 {
			b.WriteString(rest)
			break
		}
		b.WriteString(rest[:idx])
		rest = rest[idx:]

		r, width, ok := decodeUnicodeEscapeAt(rest)
		if !ok {
			// Not a decodable escape: keep the two marker bytes verbatim and
			// continue scanning right after them.
			b.WriteString(`\u`)
			rest = rest[2:]
			continue
		}
		b.WriteRune(r)
		rest = rest[width:]
	}
	return b.String()
}

// decodeUnicodeEscapeAt decodes the escape at the start of s, which must begin
// with `\u`. It returns the rune, the number of input bytes consumed (6, or 12
// for a surrogate pair) and whether decoding succeeded.
func decodeUnicodeEscapeAt(s string) (rune, int, bool) {
	hi, ok := parseHex4Escape(s)
	if !ok {
		return 0, 0, false
	}
	switch {
	case hi >= 0xD800 && hi <= 0xDBFF:
		// High surrogate: only meaningful when a low surrogate follows.
		lo, ok := parseHex4Escape(s[6:])
		if !ok || lo < 0xDC00 || lo > 0xDFFF {
			return 0, 0, false
		}
		return 0x10000 + (hi-0xD800)<<10 + (lo - 0xDC00), 12, true
	case hi >= 0xDC00 && hi <= 0xDFFF:
		// Low surrogate without a preceding high surrogate.
		return 0, 0, false
	}
	return hi, 6, true
}

// parseHex4Escape parses a leading `\u` followed by exactly four hexadecimal
// digits and returns the encoded value.
func parseHex4Escape(s string) (rune, bool) {
	if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
		return 0, false
	}
	v, err := strconv.ParseUint(s[2:6], 16, 32)
	if err != nil {
		return 0, false
	}
	return rune(v), true
}
if err := global.GVA_DB.Where("id = ?", batchWorkflowID).First(&batchWorkflow).Error; err != nil { + return + } + + // 统计任务状态 + var pendingCount, queuedCount, runningCount, completedCount, failedCount int64 + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where( + "batch_workflow_id = ? AND status = ?", batchWorkflowID, gaia.BatchTaskStatusPending).Count(&pendingCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where( + "batch_workflow_id = ? AND status = ?", batchWorkflowID, gaia.BatchTaskStatusQueued).Count(&queuedCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where( + "batch_workflow_id = ? AND status = ?", batchWorkflowID, gaia.BatchTaskStatusRunning).Count(&runningCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where( + "batch_workflow_id = ? AND status = ?", batchWorkflowID, gaia.BatchTaskStatusCompleted).Count(&completedCount) + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}).Where( + "batch_workflow_id = ? AND status = ?", batchWorkflowID, gaia.BatchTaskStatusFailed).Count(&failedCount) + + // 如果所有任务都已完成 + if completedCount == int64(batchWorkflow.TotalRows) { + // 重置错误计数并更新状态 + if err := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("id = ?", batchWorkflowID).Updates(map[string]interface{}{ + "status": gaia.BatchWorkflowStatusCompleted, + "error": "", // 清空错误信息 + "error_count": 0, // 重置错误计数,恢复用户并发位 + "updated_at": time.Now(), + }).Error; err != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新批量工作流完成状态失败: %s", err.Error())) + } else { + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 已完成,错误计数已重置,用户 %d 的并发位将恢复", batchWorkflowID, batchWorkflow.UserID)) + } + } else if pendingCount == 0 && queuedCount == 0 && runningCount == 0 && failedCount > 0 { + // 如果没有待处理、排队或运行中的任务,但有失败的任务 + // 获取第一个失败任务的错误信息作为代表 + var failedTask gaia.BatchWorkflowTask + var errorInfo string + if err := global.GVA_DB.Where("batch_workflow_id = ? AND status = ?", batchWorkflowID, gaia.BatchTaskStatusFailed). 
+ First(&failedTask).Error; err == nil && failedTask.Error != "" { + errorInfo = failedTask.Error + } else { + errorInfo = gaia.ErrorWorkflowFailed + } + + global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("id = ?", batchWorkflowID).Updates(map[string]interface{}{ + "status": gaia.BatchWorkflowStatusFailed, + "error": errorInfo, + "updated_at": time.Now(), + }) + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 处理失败,错误信息: %s", batchWorkflowID, errorInfo)) + } +} + +// resetAbnormalTasks 重置异常状态的任务 +func resetAbnormalTasks() { + if global.GVA_DB == nil { + global.GVA_LOG.Error("数据库连接为空,无法重置异常任务状态") + return + } + + global.GVA_LOG.Info("开始重置异常状态的任务...") + + // 首先清理已停止的批量工作流中的待处理和排队任务 + cleanupStoppedBatchWorkflowTasks() + + // 重置 running 状态的任务为 pending + runningResult := global.GVA_DB.Model(&gaia.BatchWorkflowTask{}). + Where("status = ?", gaia.BatchTaskStatusRunning). + Update("status", gaia.BatchTaskStatusPending) + + if runningResult.Error != nil { + global.GVA_LOG.Error("重置running状态任务失败: " + runningResult.Error.Error()) + } else if runningResult.RowsAffected > 0 { + global.GVA_LOG.Info(fmt.Sprintf("重置了 %d 个running状态的任务为pending", runningResult.RowsAffected)) + } + + // 重置 queued 状态的任务为 pending + queuedResult := global.GVA_DB.Model(&gaia.BatchWorkflowTask{}). + Where("status = ?", gaia.BatchTaskStatusQueued). 
+ Update("status", gaia.BatchTaskStatusPending) + + if queuedResult.Error != nil { + global.GVA_LOG.Error("重置queued状态任务失败: " + queuedResult.Error.Error()) + } else if queuedResult.RowsAffected > 0 { + global.GVA_LOG.Info(fmt.Sprintf("重置了 %d 个queued状态的任务为pending", queuedResult.RowsAffected)) + } + + // 重置相关批量工作流的状态 + // 如果批量工作流状态为 processing 但没有 running 或 queued 的任务,将其重置为 pending + var batchWorkflows []gaia.BatchWorkflow + err := global.GVA_DB.Where("status = ?", gaia.BatchWorkflowStatusProcessing).Find(&batchWorkflows).Error + if err != nil { + global.GVA_LOG.Error("查询processing状态的批量工作流失败: " + err.Error()) + return + } + + for _, bw := range batchWorkflows { + var runningCount int64 + global.GVA_DB.Model(&gaia.BatchWorkflowTask{}). + Where("batch_workflow_id = ? AND status IN (?)", bw.ID, []string{gaia.BatchTaskStatusRunning, gaia.BatchTaskStatusQueued}). + Count(&runningCount) + + // 如果没有正在运行或排队的任务,将批量工作流状态重置为 pending + if runningCount == 0 { + if err = global.GVA_DB.Model(&gaia.BatchWorkflow{}). + Where("id = ?", bw.ID). + Update("status", gaia.BatchWorkflowStatusPending).Error; err != nil { + global.GVA_LOG.Error(fmt.Sprintf("重置批量工作流 %s 状态失败: %s", bw.ID, err.Error())) + } else { + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 状态从processing重置为pending", bw.ID)) + } + } + } + + global.GVA_LOG.Info("异常状态任务重置完成") +} + +// cleanupStoppedBatchWorkflowTasks 清理已停止的批量工作流中的待处理和排队任务 +func cleanupStoppedBatchWorkflowTasks() { + + // 将已停止的批量工作流中的pending和queued任务标记为cancelled + // 使用子查询方式避免JOIN在UPDATE中的别名问题 + result := global.GVA_DB.Table("batch_workflow_tasks_extend"). + Where("batch_workflow_id IN (?) AND status IN (?)", + global.GVA_DB.Table("batch_workflows_extend").Select("id").Where("status = ?", gaia.BatchWorkflowStatusStopped), + []string{gaia.BatchTaskStatusPending, gaia.BatchTaskStatusQueued}). 
+ Update("status", gaia.BatchTaskStatusCancelled) + + if result.Error != nil { + global.GVA_LOG.Error("清理已停止的批量工作流任务失败: " + result.Error.Error()) + return + } + + if result.RowsAffected > 0 { + global.GVA_LOG.Info(fmt.Sprintf("已清理 %d 个已停止批量工作流中的待处理和排队任务", result.RowsAffected)) + } +} + +// 全局工作池实例 +var globalWorkerPool *WorkerPool + +// InitWorkerPool 初始化全局工作池 +func InitWorkerPool(workers int) { + if globalWorkerPool != nil { + globalWorkerPool.Stop() + } + + // 重置所有异常状态的任务 + resetAbnormalTasks() + + globalWorkerPool = NewWorkerPool(workers) + globalWorkerPool.Start() +} + +// GetWorkerPool 获取全局工作池 +func GetWorkerPool() *WorkerPool { + return globalWorkerPool +} + +// StopWorkerPool 停止全局工作池 +func StopWorkerPool() { + if globalWorkerPool != nil { + globalWorkerPool.Stop() + globalWorkerPool = nil + } +} + +// ResetBatchWorkflowErrorCount 重置指定批量工作流的错误计数 +func ResetBatchWorkflowErrorCount(batchWorkflowID string) error { + if global.GVA_DB == nil { + return fmt.Errorf("数据库连接未初始化") + } + + // 获取批量工作流信息 + var batchWorkflow gaia.BatchWorkflow + if err := global.GVA_DB.Where("id = ?", batchWorkflowID).First(&batchWorkflow).Error; err != nil { + return fmt.Errorf("批量工作流不存在: %v", err) + } + + // 重置错误计数 + if err := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("id = ?", batchWorkflowID).Updates(map[string]interface{}{ + "error_count": 0, + "updated_at": time.Now(), + }).Error; err != nil { + return fmt.Errorf("重置错误计数失败: %v", err) + } + + global.GVA_LOG.Info(fmt.Sprintf("批量工作流 %s 的错误计数已手动重置,用户 %d 的并发位将恢复", batchWorkflowID, batchWorkflow.UserID)) + return nil +} + +// ResetUserErrorCount 重置指定用户所有批量工作流的错误计数 +func ResetUserErrorCount(userID uint) error { + if global.GVA_DB == nil { + return fmt.Errorf("数据库连接未初始化") + } + + // 重置该用户所有批量工作流的错误计数 + result := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where("user_id = ?", userID).Updates(map[string]interface{}{ + "error_count": 0, + "updated_at": time.Now(), + }) + + if result.Error != nil { + return fmt.Errorf("重置用户错误计数失败: %v", 
result.Error) + } + + global.GVA_LOG.Info(fmt.Sprintf("用户 %d 的所有批量工作流错误计数已重置,影响 %d 个工作流,并发位将恢复", userID, result.RowsAffected)) + return nil +} + +// handleInsufficientBalance 处理余额不足的情况,将用户所有pending和processing状态的工作流和任务设置为失败 +// 特别处理同batch_workflow_id的所有任务 +func (wp *WorkerPool) handleInsufficientBalance(userID uint, currentBatchWorkflowID string) { + if global.GVA_DB == nil { + global.GVA_LOG.Error("数据库连接未初始化,无法处理余额不足情况") + return + } + + // 优先处理当前batch_workflow_id的所有任务(包括processing状态) + currentWorkflowResult := global.GVA_DB.Model(&gaia.BatchWorkflow{}). + Where("id = ? AND status IN (?)", currentBatchWorkflowID, []string{gaia.BatchWorkflowStatusPending, gaia.BatchWorkflowStatusProcessing}). + Updates(map[string]interface{}{ + "status": gaia.BatchWorkflowStatusFailed, + "error": gaia.ErrorInsufficientBalance, + "updated_at": time.Now(), + }) + + if currentWorkflowResult.Error != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新批量工作流 %s 状态失败: %s", currentBatchWorkflowID, currentWorkflowResult.Error.Error())) + } else if currentWorkflowResult.RowsAffected > 0 { + global.GVA_LOG.Info(fmt.Sprintf("已将批量工作流 %s 设置为失败状态", currentBatchWorkflowID)) + } + + // 将当前batch_workflow_id的所有未完成任务设置为失败 + currentTaskResult := global.GVA_DB.Table("batch_workflow_tasks_extend"). + Where("batch_workflow_id = ? AND status IN (?)", currentBatchWorkflowID, []string{gaia.BatchTaskStatusPending, gaia.BatchTaskStatusQueued, gaia.BatchTaskStatusRunning}). 
+ Updates(map[string]interface{}{ + "status": gaia.BatchTaskStatusFailed, + "error": gaia.ErrorInsufficientBalance, + "updated_at": time.Now(), + }) + + if currentTaskResult.Error != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新批量工作流 %s 的任务状态失败: %s", currentBatchWorkflowID, currentTaskResult.Error.Error())) + } else if currentTaskResult.RowsAffected > 0 { + global.GVA_LOG.Info(fmt.Sprintf("已将批量工作流 %s 的 %d 个任务设置为失败状态", currentBatchWorkflowID, currentTaskResult.RowsAffected)) + } + + // 将用户其他所有pending状态的批量工作流设置为失败 + otherWorkflowResult := global.GVA_DB.Model(&gaia.BatchWorkflow{}). + Where("user_id = ? AND id != ? AND status = ?", userID, currentBatchWorkflowID, gaia.BatchWorkflowStatusPending). + Updates(map[string]interface{}{ + "status": gaia.BatchWorkflowStatusFailed, + "error": gaia.ErrorInsufficientBalance, + "updated_at": time.Now(), + }) + + if otherWorkflowResult.Error != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新用户 %d 其他pending批量工作流状态失败: %s", userID, otherWorkflowResult.Error.Error())) + } else if otherWorkflowResult.RowsAffected > 0 { + global.GVA_LOG.Info(fmt.Sprintf("已将用户 %d 的 %d 个其他pending批量工作流设置为失败状态", userID, otherWorkflowResult.RowsAffected)) + } + + // 将用户其他所有pending状态的批量工作流任务设置为失败 + otherTaskResult := global.GVA_DB.Table("batch_workflow_tasks_extend"). + Where("batch_workflow_id IN (?) AND batch_workflow_id != ? AND status = ?", + global.GVA_DB.Table("batch_workflows_extend").Select("id").Where("user_id = ?", userID), + currentBatchWorkflowID, + gaia.BatchTaskStatusPending). 
+ Updates(map[string]interface{}{ + "status": gaia.BatchTaskStatusFailed, + "error": gaia.ErrorInsufficientBalance, + "updated_at": time.Now(), + }) + + if otherTaskResult.Error != nil { + global.GVA_LOG.Error(fmt.Sprintf("更新用户 %d 其他pending批量工作流任务状态失败: %s", userID, otherTaskResult.Error.Error())) + } else if otherTaskResult.RowsAffected > 0 { + global.GVA_LOG.Info(fmt.Sprintf("已将用户 %d 的 %d 个其他pending批量工作流任务设置为失败状态", userID, otherTaskResult.RowsAffected)) + } +} + +// GetBatchWorkflowList 获取最近30天的批量工作流列表 +func (s *BatchWorkflowService) GetBatchWorkflowList(userID uint, installedID string, page, limit int) ([]gaia.BatchWorkflow, int64, error) { + if global.GVA_DB == nil { + return nil, 0, fmt.Errorf("数据库连接未初始化") + } + + // 计算30天前的时间 + thirtyDaysAgo := time.Now().AddDate(0, 0, -30) + + // 构建查询条件 + query := global.GVA_DB.Model(&gaia.BatchWorkflow{}). + Where("user_id = ? AND created_at >= ?", userID, thirtyDaysAgo) + + // 如果指定了installedID,则添加该条件 + if installedID != "" { + query = query.Where("installed_id = ?", installedID) + } + + // 获取总数 + var total int64 + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // 分页查询 + var batchWorkflows []gaia.BatchWorkflow + offset := (page - 1) * limit + if err := query.Order("created_at DESC"). + Limit(limit). + Offset(offset). 
+ Find(&batchWorkflows).Error; err != nil { + return nil, 0, err + } + + // 解码错误信息中的 Unicode 转义序列 + for i := range batchWorkflows { + if batchWorkflows[i].Error != "" { + batchWorkflows[i].Error = decodeUnicodeEscapes(batchWorkflows[i].Error) + } + } + + return batchWorkflows, total, nil +} diff --git a/admin/server/service/system/sys_user.go b/admin/server/service/system/sys_user.go index 497aa35d3..c89618fd4 100644 --- a/admin/server/service/system/sys_user.go +++ b/admin/server/service/system/sys_user.go @@ -39,9 +39,22 @@ var UserServiceApp = new(UserService) // @return: err error, userInter *model.SysUser func (userService *UserService) Register(u system.SysUser, token string) (userInter system.SysUser, err error) { var user system.SysUser + // 首先检查email是否已注册 if !errors.Is(global.GVA_DB.Where("email = ?", u.Email).First(&user).Error, gorm.ErrRecordNotFound) { + global.GVA_LOG.Info(fmt.Sprintf("用户email已存在: %s", u.Email)) return userInter, errors.New("用户名已注册") } + + // 如果传入了UUID,检查UUID是否已存在 + if u.UUID != uuid.Nil { + var existingUser system.SysUser + if !errors.Is(global.GVA_DB.Where("uuid = ?", u.UUID).First(&existingUser).Error, gorm.ErrRecordNotFound) { + global.GVA_LOG.Info(fmt.Sprintf("用户UUID已存在: %s, email: %s", u.UUID, u.Email)) + // UUID已存在,返回已存在的用户而不是报错(用于SyncUser场景) + return existingUser, nil + } + } + global.GVA_LOG.Debug("注册用户信息:", zap.Any("1", 1)) // Extend Start: Gaia Register User @@ -50,9 +63,18 @@ func (userService *UserService) Register(u system.SysUser, token string) (userIn } // Extend Stop: Gaia Register User + // 再次检查email是否已注册(防止并发创建) + if !errors.Is(global.GVA_DB.Where("email = ?", u.Email).First(&user).Error, gorm.ErrRecordNotFound) { + global.GVA_LOG.Info(fmt.Sprintf("并发检测:用户email已被创建: %s", u.Email)) + return user, nil + } + // 否则 附加uuid 密码hash加密 注册 u.Password = utils.BcryptHash(u.Password) - u.UUID = uuid.Must(uuid.NewV4()) + // 如果没有设置UUID,才生成新的UUID + if u.UUID == uuid.Nil { + u.UUID = uuid.Must(uuid.NewV4()) + } err = 
global.GVA_DB.Create(&u).Error return u, err } diff --git a/admin/server/service/system/sys_user_extend.go b/admin/server/service/system/sys_user_extend.go index fd037d454..c38c6921c 100644 --- a/admin/server/service/system/sys_user_extend.go +++ b/admin/server/service/system/sys_user_extend.go @@ -3,6 +3,7 @@ package system import ( "fmt" "strings" + "sync" "github.com/flipped-aurora/gin-vue-admin/server/global" "github.com/flipped-aurora/gin-vue-admin/server/model/gaia" @@ -11,6 +12,9 @@ import ( "github.com/gofrs/uuid/v5" ) +// 全局互斥锁,防止SyncUser并发执行 +var syncUserMutex sync.Mutex + //@author: [piexlmax](https://github.com/piexlmax) //@function: Register //@description: 用户注册 @@ -47,6 +51,10 @@ func (userService *UserExtendService) OaLogin(u *system.SysUser) (userInter *sys // @param: u *model.SysUser // @return: err error, userInter *model.SysUser func (userService *UserExtendService) SyncUser() { + // 使用互斥锁防止并发执行 + syncUserMutex.Lock() + defer syncUserMutex.Unlock() + // init var err error var isInit = true diff --git a/admin/server/source/gaia/forwarding_extend.go b/admin/server/source/gaia/forwarding_extend.go new file mode 100644 index 000000000..107999b27 --- /dev/null +++ b/admin/server/source/gaia/forwarding_extend.go @@ -0,0 +1,83 @@ +package gaia + +import ( + "context" + "github.com/flipped-aurora/gin-vue-admin/server/model/gaia" + "github.com/flipped-aurora/gin-vue-admin/server/service/system" + "github.com/gofrs/uuid/v5" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +const initOrderForwardingExtend = system.InitOrderInternal + 1 + +type initForwardingExtend struct{} + +// auto run +func init() { + system.RegisterInit(initOrderForwardingExtend, &initForwardingExtend{}) +} + +func (i *initForwardingExtend) MigrateTable(ctx context.Context) (context.Context, error) { + db, ok := ctx.Value("db").(*gorm.DB) + if !ok { + return ctx, system.ErrMissingDBContext + } + return ctx, db.AutoMigrate(&gaia.ForwardingExtend{}) +} + +func (i *initForwardingExtend) 
TableCreated(ctx context.Context) bool { + db, ok := ctx.Value("db").(*gorm.DB) + if !ok { + return false + } + return db.Migrator().HasTable(&gaia.ForwardingExtend{}) +} + +func (i initForwardingExtend) InitializerName() string { + return gaia.ForwardingExtend{}.TableName() +} + +func (i *initForwardingExtend) InitializeData(ctx context.Context) (context.Context, error) { + db, ok := ctx.Value("db").(*gorm.DB) + if !ok { + return ctx, system.ErrMissingDBContext + } + + // 使用指定的 UUID + id, err := uuid.FromString("dbb08cae-2118-469c-a991-0c8f3f2515da") + if err != nil { + return ctx, errors.Wrap(err, "解析 UUID 失败") + } + + entities := []gaia.ForwardingExtend{ + { + ID: id, + Path: "workflow", + Address: "http://admin-server:8888/gaia/workflow/", + Header: "[]", + Description: "", + }, + } + + if err := db.Create(&entities).Error; err != nil { + return ctx, errors.Wrap(err, gaia.ForwardingExtend{}.TableName()+"表数据初始化失败!") + } + + next := context.WithValue(ctx, i.InitializerName(), entities) + return next, nil +} + +func (i *initForwardingExtend) DataInserted(ctx context.Context) bool { + db, ok := ctx.Value("db").(*gorm.DB) + if !ok { + return false + } + + // 检查是否存在指定的记录 + if errors.Is(db.Where("id = ?", "dbb08cae-2118-469c-a991-0c8f3f2515da"). 
+ First(&gaia.ForwardingExtend{}).Error, gorm.ErrRecordNotFound) { + return false + } + return true +} diff --git a/admin/server/source/system/api.go b/admin/server/source/system/api.go index 923979a47..67559f5df 100644 --- a/admin/server/source/system/api.go +++ b/admin/server/source/system/api.go @@ -207,6 +207,19 @@ func (i *initApi) InitializeData(ctx context.Context) (context.Context, error) { {ApiGroup: "应用集成配置", Method: "GET", Path: "/gaia/system/oauth2", Description: "设置OAuth2配置"}, {ApiGroup: "应用集成配置", Method: "POST", Path: "/gaia/system/oauth2", Description: "获取OAuth2集成配置"}, // Extend Stop: oauth2 + + // Extend Start: batch workflow + {ApiGroup: "批量处理工作流", Method: "POST", Path: "/gaia/workflow/batch/processing", Description: "创建批量处理"}, + {ApiGroup: "批量处理工作流", Method: "GET", Path: "/gaia/workflow/batch/list", Description: "获取最近30天的批量工作流列表"}, + {ApiGroup: "批量处理工作流", Method: "GET", Path: "/gaia/workflow/batch/:id", Description: "获取批量处理信息"}, + {ApiGroup: "批量处理工作流", Method: "GET", Path: "/gaia/workflow/batch/:id/tasks", Description: "获取任务列表"}, + {ApiGroup: "批量处理工作流", Method: "GET", Path: "/gaia/workflow/batch/:id/progress", Description: "获取进度信息"}, + {ApiGroup: "批量处理工作流", Method: "POST", Path: "/gaia/workflow/batch/:id/stop", Description: "停止批量处理"}, + {ApiGroup: "批量处理工作流", Method: "POST", Path: "/gaia/workflow/batch/:id/retry", Description: "重试批量处理(重新开始所有任务)"}, + {ApiGroup: "批量处理工作流", Method: "POST", Path: "/gaia/workflow/batch/:id/retry-failed", Description: "仅重试失败的任务"}, + {ApiGroup: "批量处理工作流", Method: "POST", Path: "/gaia/workflow/batch/:id/resume", Description: "恢复批量处理"}, + {ApiGroup: "批量处理工作流", Method: "GET", Path: "/gaia/workflow/batch/:id/download", Description: "下载结果"}, + // Extend Stop: batch workflow } if err := db.Create(&entities).Error; err != nil { return ctx, errors.Wrap(err, sysModel.SysApi{}.TableName()+"表数据初始化失败!") diff --git a/admin/server/source/system/casbin.go b/admin/server/source/system/casbin.go index caab8f718..2490dc595 100644 --- 
a/admin/server/source/system/casbin.go +++ b/admin/server/source/system/casbin.go @@ -293,6 +293,48 @@ func (i *initCasbin) InitializeData(ctx context.Context) (context.Context, error {Ptype: "p", V0: "888", V1: "/gaia/system/oauth2", V2: "GET"}, {Ptype: "p", V0: "888", V1: "/gaia/system/oauth2", V2: "POST"}, // Extend Stop: oauth2 + + // Extend Start: batch workflow + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/processing", V2: "POST"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id", V2: "GET"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id/tasks", V2: "GET"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id/progress", V2: "GET"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id/stop", V2: "POST"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id/retry", V2: "POST"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id/retry-failed", V2: "POST"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id/resume", V2: "POST"}, + {Ptype: "p", V0: "888", V1: "/gaia/workflow/batch/:id/download", V2: "GET"}, + + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/processing", V2: "POST"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id", V2: "GET"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id/tasks", V2: "GET"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id/progress", V2: "GET"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id/stop", V2: "POST"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id/retry", V2: "POST"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id/retry-failed", V2: "POST"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id/resume", V2: "POST"}, + {Ptype: "p", V0: "8881", V1: "/gaia/workflow/batch/:id/download", V2: "GET"}, + + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/processing", V2: "POST"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id", V2: "GET"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id/tasks", V2: 
"GET"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id/progress", V2: "GET"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id/stop", V2: "POST"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id/retry", V2: "POST"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id/retry-failed", V2: "POST"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id/resume", V2: "POST"}, + {Ptype: "p", V0: "9528", V1: "/gaia/workflow/batch/:id/download", V2: "GET"}, + + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/processing", V2: "POST"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id", V2: "GET"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id/tasks", V2: "GET"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id/progress", V2: "GET"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id/stop", V2: "POST"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id/retry", V2: "POST"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id/retry-failed", V2: "POST"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id/resume", V2: "POST"}, + {Ptype: "p", V0: "1", V1: "/gaia/workflow/batch/:id/download", V2: "GET"}, + // Extend Stop: batch workflow } if err := db.Create(&entities).Error; err != nil { return ctx, errors.Wrap(err, "Casbin 表 ("+i.InitializerName()+") 数据初始化失败!") diff --git a/admin/server/utils/claims.go b/admin/server/utils/claims.go index e2e5de9f8..ddda2487e 100644 --- a/admin/server/utils/claims.go +++ b/admin/server/utils/claims.go @@ -44,6 +44,12 @@ func SetToken(c *gin.Context, token string, maxAge int) { func GetToken(c *gin.Context) string { // Extend Start: Admin and Gaia JWT token, _ := c.Cookie("x-token") + if len(token) == 0 { + token = c.Request.Header.Get("Authorization") + } + if len(token) > 7 && token[0:7] == "Bearer " { + token = token[7:] + } if token == "" { j := NewJWT() token, _ = c.Cookie("x-token") @@ -65,20 +71,39 @@ func GetClaims(c *gin.Context) (*systemReq.CustomClaims, error) { if err != 
nil { global.GVA_LOG.Error("从Gin的Context中获取从jwt解析信息失败, 请检查请求头是否存在x-token且claims是否为规定结构") } + // 判断是否dify的token + if claims.Username == "" { + var user system.SysUser + var account gaia.Account + if err = global.GVA_DB.Where("uuid=?", claims.UserId).First(&user).Error; err == nil { + claims.BaseClaims.ID = user.ID + claims.Username = user.Username + claims.AuthorityId = user.AuthorityId + } else if err = global.GVA_DB.Where("id=?", claims.UserId).First(&account).Error; err == nil { + if err = global.GVA_DB.Where("email=?", account.Email).First(&user).Error; err == nil { + claims.AuthorityId = user.AuthorityId + claims.Username = user.Username + claims.BaseClaims.ID = user.ID + user.UUID = account.ID + global.GVA_DB.Save(&user) + } + } + } return claims, err } // GetUserID 从Gin的Context中获取从jwt解析出来的用户ID func GetUserID(c *gin.Context) uint { - if claims, exists := c.Get("claims"); !exists { - if cl, err := GetClaims(c); err != nil { - return 0 - } else { - return cl.BaseClaims.ID - } - } else { + if claims, exists := c.Get("claims"); exists { waitUse := claims.(*systemReq.CustomClaims) - return waitUse.BaseClaims.ID + if waitUse.BaseClaims.ID != 0 { + return waitUse.BaseClaims.ID + } + } + if cl, err := GetClaims(c); err != nil { + return 0 + } else { + return cl.BaseClaims.ID } } diff --git a/api/Dockerfile b/api/Dockerfile index 79a489276..f209e1874 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -12,7 +12,7 @@ RUN pip install --no-cache-dir uv==${UV_VERSION} FROM base AS packages # if you located in China, you can use aliyun mirror to speed up -# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources +RUN sed -i 's@deb.debian.org@mirrors.ustc.edu.cn@g' /etc/apt/sources.list.d/debian.sources RUN apt-get update \ && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev diff --git a/api/README.md b/api/README.md index 5ecf92a4f..5b7ce82d8 100644 --- a/api/README.md +++ 
b/api/README.md @@ -65,6 +65,7 @@ ```bash uv run flask db upgrade + uv run flask extend_db upgrade ``` 1. Start backend diff --git a/api/commands.py b/api/commands.py index f18d7fe1f..2a5844560 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1383,10 +1383,10 @@ def extend_db_heads(): def _run_alembic_command_extend(command, *args): """运行 alembic 命令""" import os - import sys - from flask import current_app - from alembic.config import Config + from alembic import command as alembic_command + from alembic.config import Config + from flask import current_app # 获取 api 目录的绝对路径 api_dir = os.path.abspath(os.path.dirname(__file__)) diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 314b0e369..465c36eea 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -9,11 +9,11 @@ from libs.file_utils import search_file_upwards from .deploy import DeploymentConfig from .enterprise import EnterpriseFeatureConfig +from .extend import ExtendConfig # 二开部分 新增配置 from .extra import ExtraServiceConfig from .feature import FeatureConfig from .middleware import MiddlewareConfig from .observability import ObservabilityConfig -from .extend import ExtendConfig # 二开部分 新增配置 from .packaging import PackagingInfo from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName from .remote_settings_sources.apollo import ApolloSettingsSource diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 6330df7a8..69514ebfe 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -49,9 +49,10 @@ from . 
import admin, apikey, extension, feature, ping, setup, version from .app import ( advanced_prompt_template, agent, + ai_draw_extnd, # Extend: The backend implements direct proxy forwarding of the API annotation, app, - app_extend, # 二开部分:新增同步应用到模版中心 + app_extend, # 二开部分:新增同步应用到模版中心 audio, completion, conversation, @@ -73,7 +74,16 @@ from .app import ( ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server, register_extend # 二开部分: 新增用户(调用dify注册接口) +from .auth import ( # 二开部分: 新增用户(调用dify注册接口) + activate, + data_source_bearer_auth, + data_source_oauth, + forgot_password, + login, + oauth, + oauth_server, + register_extend, +) # Import billing controllers from .billing import billing, compliance diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 1a2487d8c..3d90bf43e 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -5,8 +5,10 @@ from flask import request # 二开部分 - 密钥额度限制 from flask_login import current_user from flask_restx import Resource, fields, marshal_with from sqlalchemy import select -from sqlalchemy.orm import Session -from sqlalchemy.orm import aliased # 二开部分 - 密钥额度限制 +from sqlalchemy.orm import ( + Session, + aliased, # 二开部分 - 密钥额度限制 +) from werkzeug.exceptions import Forbidden from extensions.ext_database import db diff --git a/api/controllers/console/app/ai_draw_extnd.py b/api/controllers/console/app/ai_draw_extnd.py new file mode 100644 index 000000000..b39f2c935 --- /dev/null +++ b/api/controllers/console/app/ai_draw_extnd.py @@ -0,0 +1,182 @@ +""" +转发相关接口 +Created on 2024-03-21 +""" + +import concurrent.futures +import logging + +from flask import Response, current_app, request +from flask_restful import Resource + +from controllers.console import api +from libs.login_extend import repost_login_required +from services.ai_draw_extend import AiDrawForwarding +from services.billing_extend 
import AiDrawBilling + +logging.basicConfig(level=logging.DEBUG) +# 创建一个线程池 +executor = concurrent.futures.ThreadPoolExecutor() + + +class AiDrawTransit(Resource): + def __init__(self, *args, **kwargs): + # Destination address + self.target_url = current_app.config.get("HOSTED_FETCH_APP_TEMPLATES_MODE") + + def get(self, path): + pass + + def post(self, path): + pass + + def put(self, path): + pass + + def delete(self, path): + pass + + def patch(self, path): + pass + + def options(self, path): + pass + + @repost_login_required + def dispatch_request(self, *args, **kwargs): + # Replace with the address of the target server + print('1111') + path = kwargs.get("path", "") + path_list = path.split("/") + auth_header = request.headers.get("Authorization") + if auth_header is None: + auth_header = "Bearer " + request.cookies.get("x-token") + if len(path_list) < 1: + return Response("router error", status=500) + # obtains forwarding domain name + logging.warning("obtains forwarding domain name: {}".format(path_list[0])) + forwarding = AiDrawForwarding.get_forwarding(path_list[0]) + print(forwarding) + logging.warning("forwarding: {}".format(forwarding.id)) + if forwarding is None: + return Response("router is none", status=500) + # 使用线程池来运行异步函数 + return AiDrawBilling.billing_forward(forwarding, path_list, kwargs, auth_header, path) + + +# class YouDaoTranslationPictures(Resource): +# """有道翻译图片接口""" +# +# @setup_required +# @login_required +# def post(self): +# """ +# 翻译图片接口 +# --- +# 请求参数: +# - images: list[str] base64编码的图片列表 +# - language: str 目标语言代码 +# 返回: +# - code: int 状态码 +# - message: str 提示信息 +# - data: list[str] 翻译后的base64图片列表 +# """ +# parser = reqparse.RequestParser() +# parser.add_argument("language", type=str, required=True, location="json") +# parser.add_argument("image", type=str, required=True, location="json") +# parser.add_argument("from_code", type=str, required=True, location="json") +# args = parser.parse_args() +# +# if not args.image or not 
args.language: +# response_data = {"code": 400, "message": '参数错误:images和language不能为空', "data": None} +# response = make_response(response_data) +# self._add_cors_headers(response) +# return response +# +# # 翻译图片 +# forwarding = AiDrawForwarding.get_forwarding("youdao_ocr_translate") +# if forwarding is not None: +# AiDrawBilling.calculate_user_billing_information(current_user.id, forwarding.id, "/translate", args) +# img_url, err = AiDrawBilling.ocr_translate( +# image_base64=args.image, +# from_code=args.from_code, +# to_lang_code=args.language, +# ) +# if err != "": +# response_data = {"code": 500, "message": err, "data": None} +# response = make_response(response_data) +# self._add_cors_headers(response) +# return response +# else: +# # Extend start: 绘图 翻译图片有道的base64改储存到本地 +# try: +# # 解码 base64 图片数据 +# extension = 'png' +# mime_type = 'image/png' +# +# # 确保 base64 字符串格式正确 +# base64_data = img_url +# # 如果 img_url 已经包含 data URL 前缀,提取纯 base64 部分 +# if base64_data.startswith('data:image/'): +# base64_data = base64_data.split(',', 1)[1] +# +# # 添加必要的 padding +# missing_padding = len(base64_data) % 4 +# if missing_padding: +# base64_data += '=' * (4 - missing_padding) +# +# # 解码 base64 数据 +# image_content = base64.b64decode(base64_data) +# +# # 生成文件名 +# filename = f"translated_image_{uuid.uuid4().hex[:8]}.{extension}" +# +# # 使用 FileService 保存文件 +# upload_file = FileService.upload_file( +# filename=filename, +# content=image_content, +# mimetype=mime_type, +# user=current_user +# ) +# +# # 生成可访问的 URL +# base_url = dify_config.FILES_URL +# image_preview_url = f"{base_url}/files/{upload_file.id}/image-preview" +# signed_url = UrlSigner.get_signed_url( +# url=image_preview_url, +# sign_key=upload_file.id, +# prefix="image-preview" +# ) +# +# response_data = { +# 'code': 200, +# 'message': '翻译成功', +# 'data': { +# 'image_url': signed_url, +# 'file_id': upload_file.id +# } +# } +# response = make_response(response_data) +# self._add_cors_headers(response) +# return 
response +# +# except Exception as e: +# logging.error(f"保存翻译图片失败: {str(e)}") +# response_data = {"code": 500, "message": f'保存翻译图片失败: {str(e)}', "data": None} +# response = make_response(response_data) +# self._add_cors_headers(response) +# return response +# # Extend stop: 绘图 翻译图片有道的base64改储存到本地 +# + def _add_cors_headers(self, response): + """添加CORS头部""" + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS, DELETE" + response.headers["Access-Control-Max-Age"] = "3600" + response.headers["Access-Control-Allow-Headers"] = "x-requested-with,Authorization,token, content-type" + response.headers["Access-Control-Allow-Credentials"] = "true" + response.headers["X-Accel-Redirect"] = "" + + +api.add_resource(AiDrawTransit, "/extend/") +# api.add_resource(YouDaoTranslationPictures, "/youdao/translation/pictures") diff --git a/api/controllers/console/app/app_extend.py b/api/controllers/console/app/app_extend.py index fb660cff7..28fa0edf9 100644 --- a/api/controllers/console/app/app_extend.py +++ b/api/controllers/console/app/app_extend.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_restx import Resource, marshal_with from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/app/ding_talk_extend.py b/api/controllers/console/app/ding_talk_extend.py index cf9ee40c6..d01c2cc84 100644 --- a/api/controllers/console/app/ding_talk_extend.py +++ b/api/controllers/console/app/ding_talk_extend.py @@ -1,9 +1,9 @@ from flask import redirect, request -from flask_restful import Resource, reqparse +from flask_restx import Resource from controllers.console.app.error_extend import DingTalkNotExist -from services.ding_talk_extend import DingTalkService from controllers.console.wraps import setup_required +from services.ding_talk_extend import DingTalkService from .. 
import api diff --git a/api/controllers/console/app/passport_extend.py b/api/controllers/console/app/passport_extend.py index da447e01d..552124a05 100644 --- a/api/controllers/console/app/passport_extend.py +++ b/api/controllers/console/app/passport_extend.py @@ -1,7 +1,7 @@ from datetime import UTC, datetime, timedelta from flask import request -from flask_restful import Resource +from flask_restx import Resource from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5ce7cfb80..543bbe798 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -18,7 +18,6 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OaOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService -from services.account_service_extend import TenantExtendService from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService diff --git a/api/controllers/console/auth/register_extend.py b/api/controllers/console/auth/register_extend.py index 06b7701ab..d2e26bd26 100644 --- a/api/controllers/console/auth/register_extend.py +++ b/api/controllers/console/auth/register_extend.py @@ -3,7 +3,7 @@ from datetime import UTC, datetime import jwt from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from configs import dify_config from controllers.console import api diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index a523c9ba4..d32f57988 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ 
-15,8 +15,8 @@ from controllers.console.app.error import ( ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.console.money_extend import money_limit +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -31,10 +31,10 @@ from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService -from services.errors.llm import InvokeRateLimitError from services.app_generate_service_extend import ( AppGenerateServiceExtend, # Extend: App Center - Recommended list sorted by usage frequency ) +from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index bf5938bf8..774949bd8 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -11,8 +11,8 @@ from controllers.console.app.error import ( ) from controllers.console.explore.error import NotWorkflowAppError from controllers.console.explore.wraps import InstalledAppResource -from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.console.money_extend import money_limit +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -25,10 +25,10 @@ from libs import helper from libs.login import current_user from models.model 
import AppMode, InstalledApp from services.app_generate_service import AppGenerateService -from services.errors.llm import InvokeRateLimitError from services.app_generate_service_extend import ( AppGenerateServiceExtend, # Extend: App Center - Recommended list sorted by usage frequency ) +from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) diff --git a/api/controllers/console/workspace/account_extend.py b/api/controllers/console/workspace/account_extend.py index 0a028f33d..1239df1b4 100644 --- a/api/controllers/console/workspace/account_extend.py +++ b/api/controllers/console/workspace/account_extend.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_restx import Resource, marshal_with from configs import dify_config from controllers.console import api diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 03358f6e4..c97bab8b1 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -138,8 +138,7 @@ class WorkflowRunApi(Resource): 500: "Internal server error", } ) - @validate_app_token - @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # 二开部分End - 密钥额度限制,api_token """Execute a workflow. 
@@ -154,6 +153,11 @@ class WorkflowRunApi(Resource): external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id + + # ------------------- 二开部分Begin - 密钥额度限制 ------------------- + args["api_token"] = api_token + # # ------------------- 二开部分End - 密钥额度限制 ------------------- + streaming = args.get("response_mode") == "streaming" try: @@ -196,7 +200,7 @@ class WorkflowRunByIdApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, workflow_id: str): + def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): """Run specific workflow by ID. Executes a specific workflow version identified by its ID. diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 80bf73ec8..904f37a0c 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,5 +1,5 @@ -import time import logging # ---------------------二开部分 密钥额度限制 --------------------- +import time from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto @@ -23,12 +23,17 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import _get_user -from models.account import Account, Tenant, TenantAccountJoin, TenantStatus, TenantAccountRole # 二开部分 额度限制,API调用计费,新增TenantAccountRole -from models.dataset import Dataset, RateLimitLog +from models.account import ( # 二开部分 额度限制,API调用计费,新增TenantAccountRole + Account, + Tenant, + TenantAccountJoin, + TenantStatus, +) from models.account_money_extend import AccountMoneyExtend from models.api_token_money_extend import ( ApiTokenMoneyExtend, # 二开部分 密钥额度限制 ) +from models.dataset import Dataset, RateLimitLog from models.model import ApiToken, App, EndUser from models.model_extend import ( 
EndUserAccountJoinsExtend, # 二开部分 额度限制,API调用计费 diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index c3cc539e4..f70b947e1 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,7 +1,7 @@ import logging from flask import request # ----------------- start You must log in to access your account extend --------------- -from flask_restful import reqparse # type: ignore +from flask_restx import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services @@ -59,7 +59,7 @@ def is_end_login(end_user): if end_user.external_user_id is None: end_user.external_user_id = decoded.get("user_id") except: - logging.error("load_logged_in_account error") + logging.exception("load_logged_in_account error") pass # no login return user_info diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index f278f7bc1..310571e3a 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -29,10 +29,10 @@ from core.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService -from services.errors.llm import InvokeRateLimitError from services.app_generate_service_extend import ( AppGenerateServiceExtend, # Extend: App Center - Recommended list sorted by usage frequency ) +from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 75b068ffd..b40de97f4 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Literal, Optional, Union, overload, cast # 二开部分 - 密钥额度限制,新增cast +from typing 
import Any, Literal, Optional, Union, cast, overload # 二开部分 - 密钥额度限制,新增cast from flask import Flask, current_app from pydantic import ValidationError @@ -41,7 +41,16 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts -from models import ApiToken, Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom # 二开部分 - 密钥额度限制,新增ApiToken +from models import ( # 二开部分 - 密钥额度限制,新增ApiToken + Account, + ApiToken, + App, + Conversation, + EndUser, + Message, + Workflow, + WorkflowNodeExecutionTriggeredFrom, +) from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.workflow_draft_variable_service import ( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 4fdb731ff..88ec7ac57 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -26,8 +26,8 @@ from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser -from services.conversation_service import ConversationService from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制 +from services.conversation_service import ConversationService logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index c62f7304d..0d4b318ee 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Literal, Optional, Union, overload, cast # 二开部分 - 密钥额度限制,新增cast +from typing 
import Any, Literal, Optional, Union, cast, overload # 二开部分 - 密钥额度限制,新增cast from flask import Flask, current_app from pydantic import ValidationError @@ -34,7 +34,14 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts -from models import Account, ApiToken, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom # 二开部分 - 密钥额度限制,新增ApiToken +from models import ( # 二开部分 - 密钥额度限制,新增ApiToken + Account, + ApiToken, + App, + EndUser, + Workflow, + WorkflowNodeExecutionTriggeredFrom, +) from models.enums import WorkflowRunTriggeredFrom from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService @@ -156,7 +163,6 @@ class WorkflowAppGenerator(BaseAppGenerator): call_depth=call_depth, trace_manager=trace_manager, workflow_execution_id=workflow_run_id, - extras=extras, extras=extras, # 二开部分 - 密钥额度限制 ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index a05113cac..b3112f3e8 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -66,8 +66,8 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from models.account import Account -from models.enums import CreatorUserRole from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制 +from models.enums import CreatorUserRole from models.model import AppMode, EndUser # 二开部分End - 密钥额度限制,新增AppMode from models.workflow import ( Workflow, @@ -289,18 +289,17 @@ class WorkflowAppGenerateTaskPipeline: ) -> Generator[StreamResponse, None, None]: """Handle workflow started events.""" # init workflow run - workflow_execution = 
self._workflow_cycle_manager.handle_workflow_run_start( - ) + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() - # ------------------- 二开部分Begin - 密钥额度限制 ------------------- - app_token_id = self._application_generate_entity.extras.get("app_token_id") - if app_token_id: - ApiTokenMessageJoinsExtend( - app_token_id=app_token_id, record_id=workflow_run.id, app_mode=AppMode.WORKFLOW.value - ).add_app_token_record_id() - # ------------------- 二开部分End - 密钥额度限制 ------------------- + # ------------------- 二开部分Begin - 密钥额度限制 ------------------- + app_token_id = self._application_generate_entity.extras.get("app_token_id") + if app_token_id: + ApiTokenMessageJoinsExtend( + app_token_id=app_token_id, record_id=workflow_execution.id_, app_mode=AppMode.WORKFLOW.value + ).add_app_token_record_id() + # ------------------- 二开部分End - 密钥额度限制 ------------------- - self._workflow_run_id = workflow_execution.id_ + self._workflow_run_id = workflow_execution.id_ start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_execution=workflow_execution, diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py new file mode 100644 index 000000000..dfb143f4c --- /dev/null +++ b/api/core/helper/url_signer.py @@ -0,0 +1,52 @@ +import base64 +import hashlib +import hmac +import os +import time + +from pydantic import BaseModel, Field + +from configs import dify_config + + +class SignedUrlParams(BaseModel): + sign_key: str = Field(..., description="The sign key") + timestamp: str = Field(..., description="Timestamp") + nonce: str = Field(..., description="Nonce") + sign: str = Field(..., description="Signature") + + +class UrlSigner: + @classmethod + def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: + signed_url_params = cls.get_signed_url_params(sign_key, prefix) + return ( + f"{url}?timestamp={signed_url_params.timestamp}" + 
f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" + ) + + @classmethod + def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = cls._sign(sign_key, timestamp, nonce, prefix) + + return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) + + @classmethod + def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: + recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) + + return sign == recalculated_sign + + @classmethod + def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: + if not dify_config.SECRET_KEY: + raise Exception("SECRET_KEY is not set") + + data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return encoded_sign diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index c070e2a53..4f09de6a4 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -12,8 +12,8 @@ from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.code.control_extend import ExecutionControl # Extend: Adding execution control logic from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.code.control_extend import ExecutionControl # Extend: Adding execution control logic from core.workflow.nodes.code.entities import CodeNodeData from core.workflow.nodes.enums import ErrorStrategy, NodeType diff --git a/api/core/workflow/workflow_cycle_manager.py 
b/api/core/workflow/workflow_cycle_manager.py index a15c5ee28..aeeb96517 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -14,6 +14,9 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.app.task_pipeline.exc import WorkflowRunNotFoundError +from core.model_runtime.utils.encoders import jsonable_encoder + +# 二开部分Start - 密钥额度限制 from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType @@ -29,12 +32,10 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 - -# 二开部分Start - 密钥额度限制 -from core.model_runtime.utils.encoders import jsonable_encoder from tasks.extend.update_account_money_when_workflow_node_execution_created_extend import ( update_account_money_when_workflow_node_execution_created_extend, ) + # 二开部分End - 密钥额度限制 @dataclass @@ -196,8 +197,9 @@ class WorkflowCycleManager: self._workflow_node_execution_repository.save(domain_execution) # 二开部分Begin - 额度限制 - workflow_node_execution_dict = jsonable_encoder(domain_execution) # 转化为json字典 - update_account_money_when_workflow_node_execution_created_extend.delay(workflow_node_execution_dict) + # 异步任务计算费用并更新账户额度,将对象转换为字典传递 + domain_execution_dict = jsonable_encoder(domain_execution) + update_account_money_when_workflow_node_execution_created_extend.delay(domain_execution_dict) # 二开部分End - 额度限制 return domain_execution diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 1ef2437b0..29c2a365b 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -31,11 +31,18 @@ if [[ "${MODE}" == "worker" ]]; then else CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}" fi - + ## 二开部分,额度计算移动到新的队列中 
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,extend_high,extend_low} - + -Q ${CELERY_QUEUES:-mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} + ## 二开部分,额度计算移动到新的队列中 +## 二开部分,额度计算 +elif [[ "${MODE}" == "worker-gaia" ]]; then + exec celery -A app.celery worker -P gevent -c 1 -Q extend_high,extend_low --loglevel INFO +## 二开部分,单一运行的知识库,多容器执行会导致卡住问题 +elif [[ "${MODE}" == "worker-dataset" ]]; then + exec celery -A app.celery worker -P gevent -c 1 -Q dataset --prefetch-multiplier=1 --loglevel INFO +## 二开部分,end elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} else diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 4b4635a52..34b529a2b 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -7,7 +7,7 @@ from .delete_tool_parameters_cache_when_sync_draft_workflow import ( handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow, ) from .update_account_money_when_messaeg_created_extend import ( - handle as handle_update_account_money_when_messaeg_created_extend + handle as handle_update_account_money_when_messaeg_created_extend, ) # 二开部分:新增限额判断 from .update_app_dataset_join_when_app_model_config_updated import ( handle as handle_update_app_dataset_join_when_app_model_config_updated, @@ -27,8 +27,8 @@ __all__ = [ "handle_create_installed_app_when_app_created", "handle_create_site_record_when_app_created", "handle_delete_tool_parameters_cache_when_sync_draft_workflow", + "handle_update_account_money_when_messaeg_created_extend", "handle_update_app_dataset_join_when_app_model_config_updated", "handle_update_app_dataset_join_when_app_published_workflow_updated", 
"handle_update_provider_when_message_created", - "handle_update_account_money_when_messaeg_created_extend", ] diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 56c0ac7cb..1a398c480 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -153,6 +153,9 @@ def init_app(app: DifyApp) -> Celery: } # ---------------------------- 二开部分 Begin ---------------------------- + # 导入扩展的 Celery 任务 + imports.append("tasks.extend.update_account_money_when_workflow_node_execution_created_extend") + # 每月1号00:00,重置账号额度 imports.append("schedule.update_account_used_quota_extend") beat_schedule["update_account_used_quota_extend"] = { @@ -163,7 +166,7 @@ def init_app(app: DifyApp) -> Celery: imports.append("schedule.update_api_token_daily_used_quota_task_extend") beat_schedule["update_api_token_daily_used_quota_task_extend"] = { "task": "schedule.update_api_token_daily_used_quota_task_extend.update_api_token_daily_used_quota_task_extend", - "schedule": crontab(days=1), + "schedule": crontab(hour=0, minute=0), } # 每月1号00:00,重置密钥月额度 imports.append("schedule.update_api_token_monthly_used_quota_task_extend") diff --git a/api/fields/app_fields_extend.py b/api/fields/app_fields_extend.py index c0e2549d2..17b854b62 100644 --- a/api/fields/app_fields_extend.py +++ b/api/fields/app_fields_extend.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import AppIconUrlField diff --git a/api/fields/member_fields_extend.py b/api/fields/member_fields_extend.py index 0005b2d35..7f84eef92 100644 --- a/api/fields/member_fields_extend.py +++ b/api/fields/member_fields_extend.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields account_money_fields = { "total_quota": fields.Float, diff --git a/api/libs/login_extend.py b/api/libs/login_extend.py new file mode 100644 index 000000000..e821448a8 --- /dev/null +++ b/api/libs/login_extend.py @@ -0,0 +1,53 @@ +import time +from 
functools import wraps + +import jwt +from flask import request + +from configs import dify_config + + +def repost_login_required(func): + """ + If you decorate a view with this, it will ensure that the current user is logged in and authenticated via proxy + forwarding before calling the actual view. (If not, it will call the :attr:`LoginManager.unauthorized` callback.) + For example:: + + @app.route('/post') + @repost_login_required + def post(): + pass + """ + + @wraps(func) + def decorated_view(*args, **kwargs): + auth_header = request.headers.get("Authorization") + if auth_header is None: + auth_header = request.cookies.get("x-token") + try: + if auth_header is not None: + auth_header = auth_header[7:] if "Bearer " in auth_header else auth_header + decoded_token = jwt.decode(auth_header, dify_config.SECRET_KEY.encode(), algorithms=["HS256"]) + user_id = decoded_token.get("user_id") + if user_id and time.time() < decoded_token.get("exp", 0): + kwargs["account"] = user_id + return func(*args, **kwargs) + except jwt.ExpiredSignatureError: + return { + "code": 401, + "status": "token has expired", + "message": "account_token_has_expired", + } + except jwt.InvalidTokenError: + return { + "code": 401, + "status": "token is invalid", + "message": "account_token_is_invalid", + } + return { + "code": 403, + "status": "account_not_link_tenant", + "message": "Account not link tenant.", + } + + return decorated_view diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 4c83f6901..884fbb3e3 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -4,9 +4,10 @@ from dataclasses import dataclass from typing import Optional import requests + from configs import dify_config # Extend OAuto third-party login from extensions.ext_database import db # Extend OAuto third-party login -from models.system_extend import SystemIntegrationExtend, SystemIntegrationClassify # Extend OAuto third-party login +from models.system_extend import SystemIntegrationClassify, 
SystemIntegrationExtend # Extend OAuto third-party login @dataclass diff --git a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table.py similarity index 100% rename from api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py rename to api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table.py diff --git a/api/migrations_extend/env.py b/api/migrations_extend/env.py index ee3287b78..a450710a4 100644 --- a/api/migrations_extend/env.py +++ b/api/migrations_extend/env.py @@ -1,10 +1,10 @@ import logging import os -from logging.config import fileConfig import sys +from logging.config import fileConfig + from alembic import context from flask import current_app -from sqlalchemy import engine_from_config, pool USE_TWOPHASE = False @@ -40,6 +40,7 @@ def get_engine_url(): config.set_main_option('sqlalchemy.url', get_engine_url()) from models import db + target_metadata = db.Model.metadata # other values from the config, defined by the needs of env.py, diff --git a/api/migrations_extend/versions/09633b4cf949_add_account_money_extend.py b/api/migrations_extend/versions/09633b4cf949_add_account_money_extend.py index b0f0bd78f..c8bbb28e6 100644 --- a/api/migrations_extend/versions/09633b4cf949_add_account_money_extend.py +++ b/api/migrations_extend/versions/09633b4cf949_add_account_money_extend.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. 
revision = '005_account_money_extend' diff --git a/api/migrations_extend/versions/2024_08_03_0724-a6f3333821be_add_end_user_account_joins_extend.py b/api/migrations_extend/versions/2024_08_03_0724-a6f3333821be_add_end_user_account_joins_extend.py index 81888f0e4..898082f33 100644 --- a/api/migrations_extend/versions/2024_08_03_0724-a6f3333821be_add_end_user_account_joins_extend.py +++ b/api/migrations_extend/versions/2024_08_03_0724-a6f3333821be_add_end_user_account_joins_extend.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. revision = '006_end_user_account_joins' diff --git a/api/migrations_extend/versions/2024_08_05_0626-fb321d6d1ef0_add_account_money_monthly_stat_extend.py b/api/migrations_extend/versions/2024_08_05_0626-fb321d6d1ef0_add_account_money_monthly_stat_extend.py index 8d63106fa..a4400a740 100644 --- a/api/migrations_extend/versions/2024_08_05_0626-fb321d6d1ef0_add_account_money_monthly_stat_extend.py +++ b/api/migrations_extend/versions/2024_08_05_0626-fb321d6d1ef0_add_account_money_monthly_stat_extend.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. 
revision = '007_account_money_monthly_stat' diff --git a/api/migrations_extend/versions/2024_08_27_0308-fbd1f511a08e_add_account_layover_record_extend.py b/api/migrations_extend/versions/2024_08_27_0308-fbd1f511a08e_add_account_layover_record_extend.py index 5a34f27dd..713e54ed3 100644 --- a/api/migrations_extend/versions/2024_08_27_0308-fbd1f511a08e_add_account_layover_record_extend.py +++ b/api/migrations_extend/versions/2024_08_27_0308-fbd1f511a08e_add_account_layover_record_extend.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. revision = '008_account_layover_record' diff --git a/api/migrations_extend/versions/2024_08_29_0715-1b804f8bbd28_add_api_token_money_extend.py b/api/migrations_extend/versions/2024_08_29_0715-1b804f8bbd28_add_api_token_money_extend.py index c5b3cb884..ad7f7a4d3 100644 --- a/api/migrations_extend/versions/2024_08_29_0715-1b804f8bbd28_add_api_token_money_extend.py +++ b/api/migrations_extend/versions/2024_08_29_0715-1b804f8bbd28_add_api_token_money_extend.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. 
revision = '009_api_token_money_extend' diff --git a/api/migrations_extend/versions/2025_03_31_2136-588f1696997b_add_system_integration_extend.py b/api/migrations_extend/versions/2025_03_31_2136-588f1696997b_add_system_integration_extend.py index 31941058a..c98d2e7c2 100644 --- a/api/migrations_extend/versions/2025_03_31_2136-588f1696997b_add_system_integration_extend.py +++ b/api/migrations_extend/versions/2025_03_31_2136-588f1696997b_add_system_integration_extend.py @@ -5,11 +5,9 @@ Revises: 009_api_token_money_extend Create Date: 2025-03-31 21:36:03.818117 """ -from alembic import op -from models import db, types import sqlalchemy as sa +from alembic import op from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = '010_system_integration_extend' diff --git a/api/migrations_extend/versions/2025_04_01_0001-add_system_integration_extend_fields.py b/api/migrations_extend/versions/2025_04_01_0001-add_system_integration_extend_fields.py index 2bb58e6c3..9054fa18d 100644 --- a/api/migrations_extend/versions/2025_04_01_0001-add_system_integration_extend_fields.py +++ b/api/migrations_extend/versions/2025_04_01_0001-add_system_integration_extend_fields.py @@ -5,11 +5,9 @@ Revises: 010_system_integration_extend Create Date: 2025-04-01 00:01:00.000000 """ -from alembic import op -from models import db, types import sqlalchemy as sa +from alembic import op from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = '011_system_integration_fields' diff --git a/api/migrations_extend/versions/41e6e402d572_add_recommended_apps_category_join_.py b/api/migrations_extend/versions/41e6e402d572_add_recommended_apps_category_join_.py index df4b0cd91..128473452 100644 --- a/api/migrations_extend/versions/41e6e402d572_add_recommended_apps_category_join_.py +++ b/api/migrations_extend/versions/41e6e402d572_add_recommended_apps_category_join_.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. revision = '002_recommended_apps_category' diff --git a/api/migrations_extend/versions/9e52f36c2d6d_ai_billing_and_forwarding_two_extend.py b/api/migrations_extend/versions/9e52f36c2d6d_ai_billing_and_forwarding_two_extend.py index a9a5ef08d..ef02d593a 100644 --- a/api/migrations_extend/versions/9e52f36c2d6d_ai_billing_and_forwarding_two_extend.py +++ b/api/migrations_extend/versions/9e52f36c2d6d_ai_billing_and_forwarding_two_extend.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. revision = '004_ai_billing_forwarding' diff --git a/api/migrations_extend/versions/d8929f29057c_add_tenant_model_sync_extend.py b/api/migrations_extend/versions/d8929f29057c_add_tenant_model_sync_extend.py index 01b1bc853..0d73a5b61 100644 --- a/api/migrations_extend/versions/d8929f29057c_add_tenant_model_sync_extend.py +++ b/api/migrations_extend/versions/d8929f29057c_add_tenant_model_sync_extend.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.engine.reflection import Inspector -from models import db, types +from models import types # revision identifiers, used by Alembic. 
revision = '003_tenant_model_sync_extend' diff --git a/api/models/__init__.py b/api/models/__init__.py index ddab8c534..63ea14658 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -68,8 +68,8 @@ from .provider import ( TenantDefaultModel, TenantPreferredModelProvider, ) -from .system_extend import SystemIntegrationExtend # Extend System Integration from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from .system_extend import SystemIntegrationExtend # Extend System Integration from .task import CeleryTask, CeleryTaskSet from .tools import ( ApiToolProvider, @@ -152,6 +152,7 @@ __all__ = [ "RecommendedApp", "SavedMessage", "Site", + "SystemIntegrationExtend", # Extend System Integration "Tag", "TagBinding", "Tenant", @@ -179,5 +180,4 @@ __all__ = [ "WorkflowToolProvider", "WorkflowType", "db", - "SystemIntegrationExtend", # Extend System Integration ] diff --git a/api/models/ai_draw_extnd.py b/api/models/ai_draw_extnd.py new file mode 100644 index 000000000..7180768b1 --- /dev/null +++ b/api/models/ai_draw_extnd.py @@ -0,0 +1,162 @@ +import json +import logging +from enum import Enum + +from extensions.ext_database import db + +from .types import StringUUID + + +class ForwardingExtend(db.Model): + __tablename__ = "forwarding_extend" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="forwarding_extend_pkey"), + db.Index("idx_forwarding_path", "path"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + path = db.Column(db.String(255), nullable=False) + address = db.Column(db.String(255), nullable=False) + header = db.Column(db.Text, nullable=False, server_default=db.text("'[]'")) + description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) + + +class RequestContentType(Enum): + TypeNone = 0 + """Content-Type: none""" + + FormData = 1 + """Content-Type: form-data""" + + UrlEncoded = 2 + """Content-Type: x-www-form-urlencoded""" + + RawText = 3 + 
"""Content-Type: text/plain""" + + ApplicationJavaScript = 4 + """Content-Type: application/javascript""" + + ApplicationJson = 5 + """Content-Type: application/json""" + + TextHtml = 6 + """Content-Type: text/html""" + + ApplicationXml = 7 + """Content-Type: application/xml""" + + @staticmethod + def value_of(value): + for member in RequestContentType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class ForwardingAddressBillingExtend: + def __init__(self, remark: str, para: str, operation: int, benchmark: str, price: float, children: list): + self.remark = remark # 计费备注 + self.para = para # 参数路径 + self.operation = operation # 运算符 1: > ,2: < ,3: == , 4: >= , 5: <=, 6: +, 7: -, 8: *, 9: / + self.benchmark = benchmark # 计费基准 + self.price = price # 价格 + self.children = children # 子数据集 + + +def find_in_tree(data, path: str): + # 分离路径 + keys = path.replace("]", "").split(".") + for key in keys: + # 处理数组索引 + if "[" in key: + key, index = key.split("[") + index = int(index) + data = data[key][index] + else: + data = data[key] + return data + + +class ForwardingAddressExtend(db.Model): + __tablename__ = "forwarding_address_extend" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="forwarding_address_pkey"), + db.Index("idx_forwarding_address_id", "forwarding_id"), + db.Index("idx_forwarding_address_status", "status"), + db.Index("idx_forwarding_address_path", "path"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + forwarding_id = db.Column(StringUUID, nullable=False) + path = db.Column(db.String(255), nullable=False) + models = db.Column(db.String(255), nullable=False) + status = db.Column(db.Boolean, nullable=True, server_default=db.text("true")) + description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) + content_type = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + billing = db.Column(db.Text, 
nullable=False, server_default=db.text("'[]'")) + + @property + def encode(self): + return json.dumps(self.billing) + + @property + def decode_billing(self) -> list[ForwardingAddressBillingExtend]: + return [ForwardingAddressBillingExtend(**item) for item in json.loads(self.billing)] + + def funds_settlement(self, data, billing_list: list[ForwardingAddressBillingExtend]) -> (dict, int): + money = 0 + funds = {} + # differentiate request types + for i in billing_list: + # 判断路径是否存在 + try: + path_value = find_in_tree(data, i.para) + if path_value is None: + continue + # 判断当前是否符合条件 + # 0: != , 1: > ,2: < ,3: == , 4: >= , 5: <=, 6: +, 7: -, 8: *, 9: / + try: + if i.price and len(path_value) > 0: + funds[i.para] = path_value + if i.operation == 0 and i.benchmark != path_value: + # != 不等于 + money += float(i.price) + if i.operation == 1 and i.benchmark > path_value: + # > 大于 + money += float(i.price) + elif i.operation == 2 and i.benchmark < path_value: + # < 小于 + money += float(i.price) + elif i.operation == 3 and i.benchmark == path_value: + # == 等于 + money += float(i.price) + elif i.operation == 4 and i.benchmark >= path_value: + # >= 大于等于 + money += float(i.price) + elif i.operation == 5 and i.benchmark <= path_value: + # <= 小于等于 + money += float(i.price) + elif i.operation == 6: + # + 加 + money += float(i.price) + elif i.operation == 7: + # - 减 + money -= float(i.price) + elif i.operation == 8: + # * 乘 + money += float(i.price) * float(path_value) + elif i.operation == 9: + # / 除 + money += float(i.price) / float(path_value) + except Exception as e: + logging.debug(e, "billing error", i.price, path_value) + # 判断是否有子集 + if len(i.children) > 0: + # 有子集回调 + cache_funds, cache_money = self.funds_settlement(data, i.children) + funds.update(cache_funds) + money += cache_money + except: + pass + return funds, money diff --git a/api/models/system_extend.py b/api/models/system_extend.py index 1551f32ef..582e14228 100644 --- a/api/models/system_extend.py +++ 
b/api/models/system_extend.py @@ -1,9 +1,13 @@ -from Crypto.Util.Padding import unpad -from Crypto.Cipher import Blowfish -from configs import dify_config -from .engine import db import base64 +from Crypto.Cipher import Blowfish +from Crypto.Util.Padding import unpad + +from configs import dify_config + +from .engine import db + + class SystemIntegrationClassify: SYSTEM_INTEGRATION_DINGTALK = 1 # 钉钉 SYSTEM_INTEGRATION_WEIXIN = 2 # 微信 diff --git a/api/models/workflow.py b/api/models/workflow.py index 1abfcac9d..4f2dd6438 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -883,7 +883,7 @@ class WorkflowAppLog(Base): else: from models.model import EndUser end_user = db.session.query(EndUser).filter(EndUser.id == self.created_by).first() - if end_user is not None and len(end_user.external_user_id) > 0: + if end_user is not None and end_user.external_user_id is not None and len(end_user.external_user_id) > 0: user: Account = db.session.query(Account).filter(Account.id == end_user.external_user_id).first() if user: return { diff --git a/api/pyproject.toml b/api/pyproject.toml index aa169e741..351723ba0 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -91,11 +91,12 @@ dependencies = [ "sendgrid~=6.12.3", "flask-restx>=1.3.0", ##### start extend ###### - "tokenizers~=0.15.0", - "validators~=0.21.0", + "tokenizers~=0.21.2", + "validators~=0.22.0", "alibabacloud-dingtalk~=2.1.32", "ldap3~=2.9.1", "pypinyin~=0.53.0", + "flask-restful~=0.3.10" ##### stop extend ###### ] # Before adding new dependency, consider place it in @@ -230,4 +231,4 @@ vdb = [ [[tool.poetry.source]] name = "aliyun" -url = "https://mirrors.aliyun.com/pypi/simple" \ No newline at end of file +url = "https://mirrors.aliyun.com/pypi/simple" diff --git a/api/services/ai_draw_extend.py b/api/services/ai_draw_extend.py new file mode 100644 index 000000000..fd86a3202 --- /dev/null +++ b/api/services/ai_draw_extend.py @@ -0,0 +1,85 @@ +import json +import threading +import time + +from 
extensions.ext_database import db +from models.ai_draw_extnd import ForwardingExtend + +# Create a shared dictionary +FORWARDING = {} +# Create a lock object +dict_lock = threading.Lock() + + +def thread_forwarding_write(key, value: ForwardingExtend): + global dict_lock, FORWARDING + with dict_lock: + FORWARDING[key] = [ + json.dumps( + { + "id": value.id, + "path": value.path, + "header": value.header, + "address": value.address, + "description": value.description, + } + ), + int(time.time()), + ] + + +def thread_forwarding_read(key) -> ForwardingExtend | None: + global FORWARDING + # prevent error: is not bound to a Session; attribute refresh operation cannot proceed + info = FORWARDING.get(key) + if info is not None and info[1] < int(time.time()) + 600: + if info[0] is not None: + try: + forwarding_dict_back = json.loads(info[0]) + return ForwardingExtend( + id=forwarding_dict_back["id"], + path=forwarding_dict_back["path"], + header=forwarding_dict_back["header"], + address=forwarding_dict_back["address"], + description=forwarding_dict_back["description"], + ) + except Exception as e: + pass + else: + return None + forwarding: ForwardingExtend = db.session.query(ForwardingExtend).filter(ForwardingExtend.path == key).first() + # save + if forwarding is not None: + thread_forwarding_write(key, forwarding) + else: + FORWARDING[key] = [None, int(time.time())] + return forwarding + + +class AiDrawForwarding: + @classmethod + def get_forwarding(cls, path: str) -> ForwardingExtend: + """ + AI draws forwarding, obtains forwarding domain name + :param path: str + """ + info = thread_forwarding_read(path) + if info is not None: + return info + info: ForwardingExtend = db.session.query(ForwardingExtend).filter(ForwardingExtend.path == path).first() + # save + thread_forwarding_write(path, info) + return info + + @classmethod + def get_all_forwarding(cls): + address = {} + for i in db.session.query(ForwardingExtend).all(): + # 1. 
替换 https:// http:// :8000 + url = i.address.replace('https://', '', 1).replace('http://', '', 1).replace(':8000', '', 1) + # 2. 移除末尾的/(如果有) + url = url.rstrip('/') + address[url] = i.path + return address + + diff --git a/api/services/app_service.py b/api/services/app_service.py index 5aea8ce34..cb5a7afb6 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -19,7 +19,14 @@ from events.app_event import app_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.account import Account -from models.model import App, AppMode, AppModelConfig, Site, AppStatisticsExtend, RecommendedApp # Extend: App Center - Recommended list sorted by usage frequency +from models.model import ( # Extend: App Center - Recommended list sorted by usage frequency + App, + AppMode, + AppModelConfig, + AppStatisticsExtend, + RecommendedApp, + Site, +) from models.tools import ApiToolProvider from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService diff --git a/api/services/batch_workflow_statistics_service.py b/api/services/batch_workflow_statistics_service.py new file mode 100644 index 000000000..df1fdd5de --- /dev/null +++ b/api/services/batch_workflow_statistics_service.py @@ -0,0 +1,854 @@ +""" +批量工作流统计服务 - 生成专业的Excel报表 +""" +from datetime import datetime, timedelta +from typing import Any + +from openpyxl import Workbook +from openpyxl.chart import BarChart, LineChart, PieChart, Reference +from openpyxl.styles import Alignment, Border, Font, PatternFill, Side +from openpyxl.utils import get_column_letter +from sqlalchemy import text +from sqlalchemy.orm import Session + +from app_factory import create_app +from extensions.ext_database import db + + +class BatchWorkflowStatisticsService: + """批量工作流统计服务""" + + @staticmethod + def get_today_app_usage_stats(session: Session | None = None) -> list[dict[str, Any]]: + """ + 获取今天各个APP的使用统计(按使用次数排序) + + 
class BatchWorkflowStatisticsService:
    """Read-only SQL aggregations over batch workflow runs, used to build the Excel report.

    Every method accepts an optional session and falls back to ``db.session``.
    """

    # Longest error excerpt copied into a report row; longer text is cut and suffixed with "...".
    ERROR_EXAMPLE_MAX_LEN = 200

    # Buckets raw task error text into a readable category. Kept in one place
    # because the same expression is needed in both SELECT and GROUP BY.
    _ERROR_TYPE_CASE_SQL = """CASE
                    WHEN bwt.error LIKE '%rate limit%' THEN 'Rate Limit (频率限制)'
                    WHEN bwt.error LIKE '%quota%' THEN 'Quota Exceeded (配额超限)'
                    WHEN bwt.error LIKE '%timeout%' THEN 'Timeout (超时)'
                    WHEN bwt.error LIKE '%connection%' THEN 'Connection Error (连接错误)'
                    WHEN bwt.error LIKE '%authentication%' THEN 'Authentication Error (认证错误)'
                    WHEN bwt.error LIKE '%permission%' THEN 'Permission Error (权限错误)'
                    WHEN bwt.error LIKE '%model%' THEN 'Model Error (模型错误)'
                    WHEN bwt.error LIKE '%重试超过%' THEN 'Retry Exceeded (重试超限)'
                    ELSE 'Other Error (其他错误)'
                END"""

    @staticmethod
    def _today_start() -> datetime:
        """Return today's 00:00:00 as a naive ``datetime.now()``-based value."""
        return datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def _shorten_error(message) -> str:
        """Return ``message`` cut to ``ERROR_EXAMPLE_MAX_LEN`` characters plus "...".

        ``None`` (``MAX(error)`` over rows without error text) becomes ``""``
        instead of raising ``TypeError`` on ``len(None)``.
        """
        if message is None:
            return ""
        limit = BatchWorkflowStatisticsService.ERROR_EXAMPLE_MAX_LEN
        return message[:limit] + "..." if len(message) > limit else message

    @staticmethod
    def get_today_app_usage_stats(session: "Session | None" = None) -> list[dict[str, Any]]:
        """
        Per-app usage of batch workflows created since today's start.

        Args:
            session: database session; defaults to ``db.session``

        Returns:
            list[dict]: app_id, app_name, usage_count, total_rows, processed_rows
            and error_count per app, ordered by usage_count descending
        """
        if session is None:
            session = db.session

        query = text("""
            SELECT
                c.id as app_id,
                c.name as app_name,
                COUNT(a.id) as usage_count,
                SUM(a.total_rows) as total_rows,
                SUM(a.processed_rows) as processed_rows,
                SUM(a.error_count) as error_count
            FROM batch_workflows_extend as a
            INNER JOIN installed_apps as b ON a.installed_id::uuid = b.id
            INNER JOIN apps as c ON b.app_id = c.id
            WHERE a.created_at >= :today_start
            GROUP BY c.id, c.name
            ORDER BY usage_count DESC
        """)
        result = session.execute(query, {"today_start": BatchWorkflowStatisticsService._today_start()})

        # SUM() yields NULL when there is nothing to add, hence the `or 0`.
        return [
            {
                "app_id": row.app_id,
                "app_name": row.app_name,
                "usage_count": row.usage_count,
                "total_rows": row.total_rows or 0,
                "processed_rows": row.processed_rows or 0,
                "error_count": row.error_count or 0,
            }
            for row in result
        ]

    @staticmethod
    def get_hourly_execution_stats(session: "Session | None" = None, hours: int = 24) -> list[dict[str, Any]]:
        """
        Per-hour counts of batch workflows created in the last ``hours`` hours.

        Args:
            session: database session; defaults to ``db.session``
            hours: size of the look-back window in hours, 24 by default

        Returns:
            list[dict]: one entry per hour (newest first) with the total and
            per-status workflow counts plus row totals
        """
        if session is None:
            session = db.session

        start_time = datetime.now() - timedelta(hours=hours)

        query = text("""
            SELECT
                DATE_TRUNC('hour', a.created_at) as hour_period,
                COUNT(DISTINCT a.id) as total_count,
                COUNT(DISTINCT CASE WHEN a.status = 'processing' THEN a.id END) as processing_count,
                COUNT(DISTINCT CASE WHEN a.status = 'completed' THEN a.id END) as completed_count,
                COUNT(DISTINCT CASE WHEN a.status = 'failed' THEN a.id END) as failed_count,
                COUNT(DISTINCT CASE WHEN a.status = 'pending' THEN a.id END) as pending_count,
                SUM(a.total_rows) as total_rows,
                SUM(a.processed_rows) as processed_rows
            FROM batch_workflows_extend as a
            WHERE a.created_at >= :start_time
            GROUP BY DATE_TRUNC('hour', a.created_at)
            ORDER BY hour_period DESC
        """)
        result = session.execute(query, {"start_time": start_time})

        return [
            {
                "hour_period": row.hour_period.strftime("%Y-%m-%d %H:00:00"),
                "total_count": row.total_count,
                "processing_count": row.processing_count or 0,
                "completed_count": row.completed_count or 0,
                "failed_count": row.failed_count or 0,
                "pending_count": row.pending_count or 0,
                "total_rows": row.total_rows or 0,
                "processed_rows": row.processed_rows or 0,
            }
            for row in result
        ]

    @staticmethod
    def get_user_batch_stats(session: "Session | None" = None) -> list[dict[str, Any]]:
        """
        Per-user batch statistics for workflows created since today's start.

        Args:
            session: database session; defaults to ``db.session``

        Returns:
            list[dict]: user identity plus batch/row/error/app counts, ordered
            by batch_count descending
        """
        if session is None:
            session = db.session

        # batch_workflows_extend.user_id is joined against sys_users.id.
        query = text("""
            SELECT
                su.id as account_id,
                COALESCE(su.nick_name, su.username) as account_name,
                su.email as account_email,
                COUNT(a.id) as batch_count,
                SUM(a.total_rows) as total_rows,
                SUM(a.processed_rows) as processed_rows,
                SUM(a.error_count) as error_count,
                COUNT(DISTINCT a.installed_id) as app_count
            FROM batch_workflows_extend as a
            INNER JOIN sys_users as su ON a.user_id = su.id
            WHERE a.created_at >= :today_start
            GROUP BY su.id, su.nick_name, su.username, su.email
            ORDER BY batch_count DESC
        """)
        result = session.execute(query, {"today_start": BatchWorkflowStatisticsService._today_start()})

        return [
            {
                "account_id": row.account_id,
                "account_name": row.account_name,
                "account_email": row.account_email,
                "batch_count": row.batch_count,
                "total_rows": row.total_rows or 0,
                "processed_rows": row.processed_rows or 0,
                "error_count": row.error_count or 0,
                "app_count": row.app_count or 0,
            }
            for row in result
        ]

    @staticmethod
    def get_current_executing_stats(session: "Session | None" = None) -> dict[str, Any]:
        """
        Snapshot of workflows whose status is 'processing' or 'pending'.

        Args:
            session: database session; defaults to ``db.session``

        Returns:
            dict: workflow/user/app counts, pending and completed row totals,
            and the time the snapshot was taken
        """
        if session is None:
            session = db.session

        query = text("""
            SELECT
                COUNT(DISTINCT a.id) as processing_workflows,
                COUNT(DISTINCT a.user_id) as active_users,
                COUNT(DISTINCT a.installed_id) as active_apps,
                SUM(a.total_rows - a.processed_rows) as pending_rows,
                SUM(a.processed_rows) as completed_rows
            FROM batch_workflows_extend as a
            WHERE a.status IN ('processing', 'pending')
        """)
        row = session.execute(query).fetchone()

        return {
            "processing_workflows": row.processing_workflows or 0,
            "active_users": row.active_users or 0,
            "active_apps": row.active_apps or 0,
            "pending_rows": row.pending_rows or 0,
            "completed_rows": row.completed_rows or 0,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

    @staticmethod
    def get_app_hourly_distribution(
        app_id: "str | None" = None,
        session: "Session | None" = None,
        hours: int = 24
    ) -> list[dict[str, Any]]:
        """
        Per-hour execution counts for one app, or across all apps.

        Args:
            app_id: app to report on; a falsy value aggregates every app
            session: database session; defaults to ``db.session``
            hours: size of the look-back window in hours

        Returns:
            list[dict]: hour_period, execution_count, total_rows and
            processed_rows, plus app_id/app_name (single app) or unique_apps
            (all apps); newest hour first
        """
        if session is None:
            session = db.session

        start_time = datetime.now() - timedelta(hours=hours)

        if app_id:
            query = text("""
                SELECT
                    DATE_TRUNC('hour', a.created_at) as hour_period,
                    c.id as app_id,
                    c.name as app_name,
                    COUNT(a.id) as execution_count,
                    SUM(a.total_rows) as total_rows,
                    SUM(a.processed_rows) as processed_rows
                FROM batch_workflows_extend as a
                INNER JOIN installed_apps as b ON a.installed_id::uuid = b.id
                INNER JOIN apps as c ON b.app_id = c.id
                WHERE a.created_at >= :start_time AND c.id = :app_id
                GROUP BY DATE_TRUNC('hour', a.created_at), c.id, c.name
                ORDER BY hour_period DESC
            """)
            result = session.execute(query, {"start_time": start_time, "app_id": app_id})
        else:
            query = text("""
                SELECT
                    DATE_TRUNC('hour', a.created_at) as hour_period,
                    COUNT(a.id) as execution_count,
                    COUNT(DISTINCT b.app_id) as unique_apps,
                    SUM(a.total_rows) as total_rows,
                    SUM(a.processed_rows) as processed_rows
                FROM batch_workflows_extend as a
                INNER JOIN installed_apps as b ON a.installed_id::uuid = b.id
                WHERE a.created_at >= :start_time
                GROUP BY DATE_TRUNC('hour', a.created_at)
                ORDER BY hour_period DESC
            """)
            result = session.execute(query, {"start_time": start_time})

        stats = []
        for row in result:
            stat = {
                "hour_period": row.hour_period.strftime("%Y-%m-%d %H:00:00"),
                "execution_count": row.execution_count,
                "total_rows": row.total_rows or 0,
                "processed_rows": row.processed_rows or 0,
            }
            # The two queries expose different extra columns.
            if app_id:
                stat["app_id"] = row.app_id
                stat["app_name"] = row.app_name
            else:
                stat["unique_apps"] = row.unique_apps or 0
            stats.append(stat)

        return stats

    @staticmethod
    def get_error_analysis_stats(session: "Session | None" = None) -> dict[str, Any]:
        """
        Analyse tasks with status 'failed' created since today's start.

        Args:
            session: database session; defaults to ``db.session``

        Returns:
            dict: ``error_types`` (top 10 categories), ``app_errors`` (per-app
            distribution), ``error_examples`` (10 newest failures),
            ``total_errors`` (sum over the listed categories) and
            ``affected_apps``
        """
        if session is None:
            session = db.session

        service = BatchWorkflowStatisticsService
        params = {"today_start": service._today_start()}
        case_sql = service._ERROR_TYPE_CASE_SQL

        # 1. Top 10 error categories. The CASE expression is a class constant,
        #    not user input, so interpolating it is safe.
        error_type_query = text(f"""
            SELECT
                {case_sql} as error_type,
                COUNT(*) as error_count,
                MAX(bwt.error) as error_example
            FROM batch_workflow_tasks_extend as bwt
            INNER JOIN batch_workflows_extend as bw ON bwt.batch_workflow_id = bw.id
            WHERE bwt.status = 'failed'
                AND bwt.created_at >= :today_start
            GROUP BY
                {case_sql}
            ORDER BY error_count DESC
            LIMIT 10
        """)
        error_types = [
            {
                "error_type": row.error_type,
                "error_count": row.error_count,
                "error_example": service._shorten_error(row.error_example),
            }
            for row in session.execute(error_type_query, params)
        ]

        # 2. Error distribution per app.
        app_error_query = text("""
            SELECT
                c.id as app_id,
                c.name as app_name,
                COUNT(bwt.id) as total_errors,
                COUNT(DISTINCT bwt.batch_workflow_id) as affected_workflows,
                COUNT(CASE WHEN bwt.error LIKE '%rate limit%' THEN 1 END) as rate_limit_errors,
                COUNT(CASE WHEN bwt.error LIKE '%quota%' THEN 1 END) as quota_errors,
                COUNT(CASE WHEN bwt.error LIKE '%重试超过%' THEN 1 END) as retry_errors,
                MAX(bwt.error) as error_example
            FROM batch_workflow_tasks_extend as bwt
            INNER JOIN batch_workflows_extend as bw ON bwt.batch_workflow_id = bw.id
            INNER JOIN installed_apps as b ON bw.installed_id::uuid = b.id
            INNER JOIN apps as c ON b.app_id = c.id
            WHERE bwt.status = 'failed'
                AND bwt.created_at >= :today_start
            GROUP BY c.id, c.name
            ORDER BY total_errors DESC
        """)
        app_errors = [
            {
                "app_id": row.app_id,
                "app_name": row.app_name,
                "total_errors": row.total_errors,
                "affected_workflows": row.affected_workflows,
                "rate_limit_errors": row.rate_limit_errors or 0,
                "quota_errors": row.quota_errors or 0,
                "retry_errors": row.retry_errors or 0,
                "error_example": service._shorten_error(row.error_example),
            }
            for row in session.execute(app_error_query, params)
        ]

        # 3. The 10 most recent failures, unabridged.
        error_examples_query = text("""
            SELECT
                c.name as app_name,
                bwt.error,
                bwt.created_at,
                bwt.error_count,
                bwt.row_index
            FROM batch_workflow_tasks_extend as bwt
            INNER JOIN batch_workflows_extend as bw ON bwt.batch_workflow_id = bw.id
            INNER JOIN installed_apps as b ON bw.installed_id::uuid = b.id
            INNER JOIN apps as c ON b.app_id = c.id
            WHERE bwt.status = 'failed'
                AND bwt.created_at >= :today_start
            ORDER BY bwt.created_at DESC
            LIMIT 10
        """)
        error_examples = [
            {
                "app_name": row.app_name,
                "error": row.error,
                "created_at": row.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                "error_count": row.error_count,
                "row_index": row.row_index,
            }
            for row in session.execute(error_examples_query, params)
        ]

        return {
            "error_types": error_types,
            "app_errors": app_errors,
            "error_examples": error_examples,
            "total_errors": sum(et["error_count"] for et in error_types),
            "affected_apps": len(app_errors),
        }
class ExcelReportGenerator:
    """Builds the multi-sheet Excel report from ``BatchWorkflowStatisticsService`` data."""

    def __init__(self):
        self.service = BatchWorkflowStatisticsService()
        self.wb = Workbook()
        # Shared styles.
        self.header_font = Font(name="微软雅黑", size=11, bold=True, color="FFFFFF")
        self.header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
        self.title_font = Font(name="微软雅黑", size=16, bold=True, color="2F5496")
        self.section_font = Font(name="微软雅黑", size=14, bold=True, color="2F5496")
        self.border = Border(
            left=Side(style="thin"),
            right=Side(style="thin"),
            top=Side(style="thin"),
            bottom=Side(style="thin"),
        )
        self.center_alignment = Alignment(horizontal="center", vertical="center")
        self.left_alignment = Alignment(horizontal="left", vertical="center")

    def _apply_header_style(self, ws, row: int, max_col: int):
        """Style columns 1..max_col of ``row`` as a table header."""
        for col in range(1, max_col + 1):
            cell = ws.cell(row=row, column=col)
            cell.font = self.header_font
            cell.fill = self.header_fill
            cell.alignment = self.center_alignment
            cell.border = self.border

    def _apply_data_style(self, ws, start_row: int, end_row: int, max_col: int):
        """Border every data cell; left-align column 1, centre the rest."""
        for row in range(start_row, end_row + 1):
            for col in range(1, max_col + 1):
                cell = ws.cell(row=row, column=col)
                cell.border = self.border
                cell.alignment = self.left_alignment if col == 1 else self.center_alignment

    def _auto_adjust_column_width(self, ws):
        """Size each column to its longest non-empty value plus 2, capped at 50."""
        for column in ws.columns:
            longest = max((len(str(cell.value)) for cell in column if cell.value), default=0)
            ws.column_dimensions[get_column_letter(column[0].column)].width = min(longest + 2, 50)

    def _write_title(self, ws, cell_range: str, title: str):
        """Merge ``cell_range`` and put a centred sheet title in its first cell."""
        ws.merge_cells(cell_range)
        anchor = cell_range.split(":")[0]
        ws[anchor] = title
        ws[anchor].font = self.title_font
        ws[anchor].alignment = self.center_alignment

    def _write_section_heading(self, ws, row: int, heading: str):
        """Write a section heading into column A of ``row``."""
        ws.cell(row=row, column=1, value=heading).font = self.section_font

    def _write_table(self, ws, header_row: int, headers: list, rows: list) -> int:
        """Write a styled header at ``header_row`` and ``rows`` directly below it.

        :param rows: sequence of value tuples, one per table row
        :return: number of data rows written
        """
        for col, header in enumerate(headers, start=1):
            ws.cell(row=header_row, column=col, value=header)
        self._apply_header_style(ws, header_row, len(headers))

        for offset, values in enumerate(rows, start=1):
            for col, value in enumerate(values, start=1):
                ws.cell(row=header_row + offset, column=col, value=value)
        if rows:
            self._apply_data_style(ws, header_row + 1, header_row + len(rows), len(headers))
        return len(rows)

    def create_summary_sheet(self):
        """Create the overview sheet with the current execution status."""
        ws = self.wb.active
        ws.title = "概览汇总"

        self._write_title(ws, "A1:F1", "批量工作流处理统计报表")

        ws.merge_cells("A2:F2")
        ws["A2"] = f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
        ws["A2"].alignment = self.center_alignment

        current_stats = self.service.get_current_executing_stats()
        self._write_section_heading(ws, 4, "当前执行状态")
        self._write_table(
            ws,
            5,
            ["指标", "数值"],
            [
                ("正在执行的工作流数", current_stats["processing_workflows"]),
                ("活跃用户数", current_stats["active_users"]),
                ("活跃APP数", current_stats["active_apps"]),
                ("待处理行数", current_stats["pending_rows"]),
                ("已完成行数", current_stats["completed_rows"]),
            ],
        )

        self._auto_adjust_column_width(ws)

    def create_app_usage_sheet(self):
        """Create the per-app usage sheet with a bar chart and, for up to 10 apps, a pie chart."""
        ws = self.wb.create_sheet("APP使用统计")
        self._write_title(ws, "A1:F1", "今天各APP使用统计")

        app_stats = self.service.get_today_app_usage_stats()
        rows = []
        for stat in app_stats:
            completion_rate = (
                round((stat["processed_rows"] / stat["total_rows"]) * 100, 2)
                if stat["total_rows"] > 0
                else 0
            )
            rows.append(
                (
                    stat["app_name"],
                    stat["usage_count"],
                    stat["total_rows"],
                    stat["processed_rows"],
                    stat["error_count"],
                    completion_rate,
                )
            )
        count = self._write_table(
            ws, 3, ["APP名称", "使用次数", "总行数", "已处理行数", "错误数", "完成率(%)"], rows
        )

        # With no data rows the category range would be inverted (row 4 to
        # row 3), so charts are only added when there is something to plot.
        if count:
            last_row = 3 + count
            cats = Reference(ws, min_col=1, min_row=4, max_row=last_row)

            bar = BarChart()
            bar.title = "APP使用次数排行"
            bar.style = 10
            bar.x_axis.title = "APP"
            bar.y_axis.title = "使用次数"
            # Row 3 is the header and supplies the series title.
            bar.add_data(Reference(ws, min_col=2, min_row=3, max_row=last_row), titles_from_data=True)
            bar.set_categories(cats)
            bar.height = 10
            bar.width = 20
            ws.add_chart(bar, "H3")

            if count <= 10:
                pie = PieChart()
                pie.title = "APP使用次数占比"
                pie.style = 10
                pie.add_data(Reference(ws, min_col=2, min_row=4, max_row=last_row))
                pie.set_categories(cats)
                pie.height = 10
                pie.width = 15
                ws.add_chart(pie, "H20")

        self._auto_adjust_column_width(ws)

    def create_hourly_stats_sheet(self):
        """Create the last-24-hours sheet with a trend line chart."""
        ws = self.wb.create_sheet("小时执行统计")
        self._write_title(ws, "A1:H1", "最近24小时执行统计")

        hourly_stats = self.service.get_hourly_execution_stats(hours=24)
        count = self._write_table(
            ws,
            3,
            ["时间段", "总数", "执行中", "已完成", "失败", "待处理", "总行数", "已处理行数"],
            [
                (
                    stat["hour_period"],
                    stat["total_count"],
                    stat["processing_count"],
                    stat["completed_count"],
                    stat["failed_count"],
                    stat["pending_count"],
                    stat["total_rows"],
                    stat["processed_rows"],
                )
                for stat in hourly_stats
            ],
        )

        if count:
            last_row = 3 + count
            chart = LineChart()
            chart.title = "执行数量趋势"
            chart.style = 10
            chart.x_axis.title = "时间"
            chart.y_axis.title = "数量"
            # Columns 2-6 are the total and per-status counts; row 3 holds the series titles.
            chart.add_data(
                Reference(ws, min_col=2, min_row=3, max_col=6, max_row=last_row),
                titles_from_data=True,
            )
            chart.set_categories(Reference(ws, min_col=1, min_row=4, max_row=last_row))
            chart.height = 12
            chart.width = 25
            ws.add_chart(chart, "J3")

        self._auto_adjust_column_width(ws)

    def create_user_stats_sheet(self):
        """Create the per-user sheet with a bar chart of the first 10 users."""
        ws = self.wb.create_sheet("用户统计")
        self._write_title(ws, "A1:G1", "今天用户批量处理统计")

        user_stats = self.service.get_user_batch_stats()
        count = self._write_table(
            ws,
            3,
            ["用户名", "邮箱", "批次数", "总行数", "已处理行数", "错误数", "使用APP数"],
            [
                (
                    stat["account_name"],
                    stat["account_email"],
                    stat["batch_count"],
                    stat["total_rows"],
                    stat["processed_rows"],
                    stat["error_count"],
                    stat["app_count"],
                )
                for stat in user_stats
            ],
        )

        if count:
            last_row = 3 + min(10, count)
            chart = BarChart()
            chart.title = "用户批次数排行 TOP 10"
            chart.style = 10
            chart.x_axis.title = "用户"
            chart.y_axis.title = "批次数"
            chart.add_data(Reference(ws, min_col=3, min_row=3, max_row=last_row), titles_from_data=True)
            chart.set_categories(Reference(ws, min_col=1, min_row=4, max_row=last_row))
            chart.height = 10
            chart.width = 20
            ws.add_chart(chart, "I3")

        self._auto_adjust_column_width(ws)

    def create_error_analysis_sheet(self):
        """Create the error-analysis sheet: categories, per-app errors, recent examples, summary."""
        ws = self.wb.create_sheet("错误分析")
        self._write_title(ws, "A1:H1", "今天错误分析统计")

        error_stats = self.service.get_error_analysis_stats()
        error_types = error_stats["error_types"]
        app_errors = error_stats["app_errors"]
        examples = error_stats["error_examples"]

        # 1. Top 10 error categories (header on row 4, data from row 5).
        self._write_section_heading(ws, 3, "错误类型TOP10统计")
        self._write_table(
            ws,
            4,
            ["错误类型", "错误次数", "错误示例"],
            [(item["error_type"], item["error_count"], item["error_example"]) for item in error_types],
        )

        if error_types and len(error_types) <= 10:
            last_row = 4 + len(error_types)
            pie = PieChart()
            pie.title = "错误类型分布"
            pie.style = 10
            # Values start at row 5, matching the categories. Starting at the
            # header row (4) plotted the header text and shifted every slice.
            pie.add_data(Reference(ws, min_col=2, min_row=5, max_row=last_row))
            pie.set_categories(Reference(ws, min_col=1, min_row=5, max_row=last_row))
            pie.height = 10
            pie.width = 15
            ws.add_chart(pie, "E4")

        # 2. Error distribution per app.
        start_row = 4 + len(error_types) + 3
        self._write_section_heading(ws, start_row, "各APP错误分布")
        self._write_table(
            ws,
            start_row + 1,
            ["APP名称", "总错误数", "受影响工作流", "频率限制", "配额超限", "重试超限", "错误示例"],
            [
                (
                    item["app_name"],
                    item["total_errors"],
                    item["affected_workflows"],
                    item["rate_limit_errors"],
                    item["quota_errors"],
                    item["retry_errors"],
                    item["error_example"],
                )
                for item in app_errors
            ],
        )

        if app_errors:
            last_row = start_row + 1 + min(10, len(app_errors))
            bar = BarChart()
            bar.title = "APP错误数量排行"
            bar.style = 10
            bar.x_axis.title = "APP"
            bar.y_axis.title = "错误数量"
            bar.add_data(
                Reference(ws, min_col=2, min_row=start_row + 1, max_row=last_row),
                titles_from_data=True,
            )
            bar.set_categories(Reference(ws, min_col=1, min_row=start_row + 2, max_row=last_row))
            bar.height = 10
            bar.width = 20
            ws.add_chart(bar, "I" + str(start_row + 1))

        # 3. Most recent concrete failures.
        examples_start_row = start_row + 2 + len(app_errors) + 3
        self._write_section_heading(ws, examples_start_row, "最新错误示例")
        self._write_table(
            ws,
            examples_start_row + 1,
            ["APP名称", "错误时间", "行索引", "重试次数", "错误详情"],
            [
                (
                    item["app_name"],
                    item["created_at"],
                    item["row_index"],
                    item["error_count"],
                    item["error"],
                )
                for item in examples
            ],
        )

        # 4. Summary figures.
        summary_start_row = examples_start_row + 2 + len(examples) + 3
        self._write_section_heading(ws, summary_start_row, "错误统计汇总")
        self._write_table(
            ws,
            summary_start_row + 1,
            ["指标", "数值"],
            [
                ("总错误数", error_stats["total_errors"]),
                ("受影响APP数", error_stats["affected_apps"]),
                ("错误类型数", len(error_types)),
            ],
        )

        self._auto_adjust_column_width(ws)

    def generate_report(self, output_path: "str | None" = None) -> str:
        """
        Build every sheet and save the workbook.

        Args:
            output_path: destination file; when None a timestamped
                ``batch_workflow_report_<YYYYmmdd_HHMMSS>.xlsx`` name is used

        Returns:
            str: the path the workbook was saved to
        """
        self.create_summary_sheet()
        self.create_app_usage_sheet()
        self.create_hourly_stats_sheet()
        self.create_user_stats_sheet()
        self.create_error_analysis_sheet()

        if output_path is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = f"batch_workflow_report_{timestamp}.xlsx"

        self.wb.save(output_path)
        return output_path
def generate_batch_workflow_report(output_path: "str | None" = None) -> str:
    """
    Generate the batch workflow statistics report inside a Flask app context.

    Args:
        output_path: destination file; passed through to
            ``ExcelReportGenerator.generate_report``

    Returns:
        str: path of the generated Excel file

    Example:
        >>> # default path
        >>> filepath = generate_batch_workflow_report()
        >>> print(f"报表已生成: {filepath}")

        >>> # explicit path
        >>> filepath = generate_batch_workflow_report("/tmp/report.xlsx")
    """
    # The statistics queries need an application context.
    flask_app = create_app()
    with flask_app.app_context():
        saved_to = ExcelReportGenerator().generate_report(output_path)
        print(f"✅ Excel报表已生成: {saved_to}")
    return saved_to


if __name__ == "__main__":
    # Generate the report, then describe its sheets.
    report_path = generate_batch_workflow_report()
    for line in (
        "\n📊 批量工作流统计报表已生成",
        f"📁 文件路径: {report_path}",
        "📈 报表包含以下工作表:",
        "  1. 概览汇总 - 当前执行状态概览",
        "  2. APP使用统计 - 各APP使用情况及图表",
        "  3. 小时执行统计 - 24小时执行趋势",
        "  4. 用户统计 - 用户批量处理统计",
        "  5. 错误分析 - 错误类型TOP10、APP错误分布、具体错误示例",
    ):
        print(line)
# Process-local cache: "<forwarding_id>_<url_path>" -> [JSON snapshot (or None for a miss), unix time cached].
billing = {}
# Serialises writers of `billing`.
dict_lock = threading.Lock()
# Seconds a cached snapshot is served before the database is asked again.
BILLING_CACHE_TTL = 600


def thread_billing_write(key: str, billing_info: "ForwardingAddressExtend"):
    """Store a JSON snapshot of ``billing_info`` in the cache under ``key``.

    :param key: cache key in the form ``"<forwarding_id>_<url_path>"``
    :param billing_info: the address row to snapshot
    """
    snapshot = json.dumps(
        {
            "id": billing_info.id,
            "path": billing_info.path,
            "models": billing_info.models,
            "status": billing_info.status,
            "billing": billing_info.billing,
            "description": billing_info.description,
            "content_type": billing_info.content_type,
            "forwarding_id": billing_info.forwarding_id,
        }
    )
    with dict_lock:
        billing[key] = [snapshot, int(time.time())]


def thread_billing_read(forwarding_id: str, path: str) -> "ForwardingAddressExtend | None":
    """Return the billing rule for a forwarded request path, cached for ``BILLING_CACHE_TTL`` seconds.

    The first ``/``-separated segment of ``path`` is dropped before the lookup.
    A fresh snapshot is rebuilt into a new, detached ``ForwardingAddressExtend``.
    Cached misses are not served: they fall through to the database, as before.

    :param forwarding_id: id of the forwarding target
    :param path: request path whose first segment is discarded
    :return: the matching row, or ``None`` when there is none
    """
    url_path = "/".join(path.split("/")[1:])
    key = "{}_{}".format(forwarding_id, url_path)
    cached = billing.get(key)
    # Age-based expiry. The previous test `cached_at < now + 600` was always
    # true, so billing rules were never refreshed.
    if cached is not None and cached[0] is not None and int(time.time()) - cached[1] < BILLING_CACHE_TTL:
        fields = json.loads(cached[0])
        return ForwardingAddressExtend(
            id=fields["id"],
            path=fields["path"],
            models=fields["models"],
            status=fields["status"],
            billing=fields["billing"],
            description=fields["description"],
            content_type=fields["content_type"],
            forwarding_id=fields["forwarding_id"],
        )

    billing_info = (
        db.session.query(ForwardingAddressExtend)
        .filter(ForwardingAddressExtend.forwarding_id == forwarding_id, ForwardingAddressExtend.path == url_path)
        .first()
    )
    if billing_info is not None:
        thread_billing_write(key, billing_info)
    else:
        with dict_lock:
            billing[key] = [None, int(time.time())]
    return billing_info


class AiDrawBilling:
    """Billing and request forwarding for AI-draw endpoints."""

    @classmethod
    def calculate_user_billing_information(
        cls, account_id: str, forwarding: str, path: str, data: dict
    ) -> tuple[int, str]:
        """
        Handling fee processing for forward transmission
        :param account_id: string
        :param forwarding: string, id of the forwarding target
        :param path: string
        :param data: dict of merged request parameters handed to ``funds_settlement``
        :return: ``(money, "")`` on success; ``(0, message)`` when the account or
            billing rule is missing; ``(500, "Insufficient balance")`` when the
            quota would be exceeded
        """
        account = db.session.query(Account).filter(Account.id == account_id).first()
        if account is None:
            return 0, "user does not exist"
        info = thread_billing_read(forwarding, path)
        if info is None:
            return 0, "count not found"
        # differentiate request types
        funds, money = info.funds_settlement(data, info.decode_billing)

        # Charge the account.
        # NOTE(review): used_quota is read and then written back, so concurrent
        # requests for the same account can lose an update.
        account_money = db.session.query(AccountMoneyExtend).filter(AccountMoneyExtend.account_id == account.id).first()
        if account_money:
            if float(account_money.used_quota) + money > float(account_money.total_quota):
                return 500, "Insufficient balance"
            db.session.query(AccountMoneyExtend).filter(AccountMoneyExtend.account_id == account.id).update(
                {"used_quota": float(account_money.used_quota) + money}
            )
        else:
            db.session.add(
                AccountMoneyExtend(
                    account_id=account.id,
                    used_quota=money,
                    total_quota=15,  # TODO: the default initial quota of 15 is a placeholder
                )
            )
        # Keep a record of the charge.
        db.session.add(
            AccountLayoverRecordExtend(
                account_id=account_id, forwarding_id=forwarding, money=money, info=funds, created_at=datetime.now()
            )
        )
        db.session.commit()

        return money, ""

    @classmethod
    def ocr_translate(cls, image_base64, to_lang_code, from_code):
        """Translate the text in an image through the Youdao OCR translation API.

        :param image_base64: base64-encoded image
        :param to_lang_code: target language code
        :param from_code: source language code
        :return: ``(rendered_image, "")`` on success, ``("", error_message)`` otherwise
        """
        if not dify_config.YOUDAO_APP_KEY or not dify_config.YOUDAO_APP_SECRET:
            return "", "请在配置文件中设置有道翻译的APP_KEY和APP_SECRET"

        salt = str(uuid.uuid4())
        curtime = str(int(time.time()))

        # Signature input: the whole payload when short, otherwise
        # first 10 chars + length + last 10 chars.
        if len(image_base64) <= 20:
            input_str = image_base64
        else:
            input_str = image_base64[:10] + str(len(image_base64)) + image_base64[-10:]

        sign_str = dify_config.YOUDAO_APP_KEY + input_str + salt + curtime + dify_config.YOUDAO_APP_SECRET
        sign = hashlib.sha256(sign_str.encode('utf-8')).hexdigest()

        try:
            response = requests.post(
                'https://openapi.youdao.com/ocrtransapi',
                data={
                    'type': '1',  # base64 payload
                    'q': image_base64,
                    'from': from_code,
                    'to': to_lang_code,
                    'appKey': dify_config.YOUDAO_APP_KEY,
                    'salt': salt,
                    'sign': sign,
                    'signType': 'v3',
                    'curtime': curtime,
                    'render': '1',
                    'docType': 'json'
                },
                timeout=30
            )
            result = response.json()

            if result.get('errorCode') == '0':
                return result.get('render_image', ''), ""
            return "", f"请求失败: {result.get('msg')}"

        except Exception as e:
            return "", f"翻译出错: {str(e)}"

    @classmethod
    def billing_forward(cls, forwarding, path_list, kwargs, auth_header, path):
        """Charge the caller, forward the current Flask request, and relay the reply.

        :param forwarding: forwarding target; ``address``, ``id`` and ``header`` are read
        :param path_list: request path segments; the first one is dropped
        :param kwargs: view kwargs; ``kwargs["account"]`` is the account id to charge
        :param auth_header: Authorization header value, forwarded as an ``x-token`` cookie
        :param path: raw request path used to find the billing rule
        :return: a Flask ``Response`` mirroring the upstream reply, with
            ``metadata.usage.total_price`` set when the body is a JSON object
        """
        method = request.method
        target_url = f"{forwarding.address}{'/'.join(path_list[1:])}"

        try:
            data = request.get_data()
        except Exception:
            data = ""

        # Billing input: JSON body (only when it is an object) overlaid with
        # query and form values. silent=True returns None instead of raising;
        # the copy keeps the parsed body that Flask caches untouched.
        payload = request.get_json(silent=True)
        cache_data = dict(payload) if isinstance(payload, dict) else {}
        cache_data.update(request.args.items())
        cache_data.update(request.form.items())

        headers = {key: value for key, value in request.headers if key != "Host"}

        money, err = cls.calculate_user_billing_information(kwargs.get("account", ''), forwarding.id, path, cache_data)
        # 500 is the sentinel returned for "Insufficient balance".
        if len(err) > 0 and money == 500:
            return Response(err, status=500)

        # Extra headers configured on the forwarding target.
        # NOTE(review): the stored format is not visible here; both a JSON object
        # and a list of [key, value] pairs are accepted. Iterating an object
        # directly, as before, yields only its keys.
        extra_headers = json.loads(forwarding.header) if forwarding.header else {}
        if isinstance(extra_headers, dict):
            extra_headers = extra_headers.items()
        for key, value in extra_headers:
            headers[key] = value

        # Set Cookie - strip the Bearer prefix
        token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else auth_header
        headers["cookie"] = f"x-token={token};"
        # Disable gzip compression
        headers["Accept-Encoding"] = "identity"

        # NOTE(review): these two lines log the auth cookie and Authorization header.
        logging.warning("target_url: {}. json: {}".format(target_url, json.dumps(request.args)))
        logging.warning("headers: {}".format(json.dumps(headers)))
        try:
            if method == 'GET':
                resp = requests.get(target_url, headers=headers, params=request.args, allow_redirects=False)
            elif method == "POST":
                resp = requests.post(target_url, headers=headers, data=data, params=request.args)
            elif method == "PUT":
                resp = requests.put(target_url, headers=headers, data=data, params=request.args)
            elif method == "DELETE":
                resp = requests.delete(target_url, headers=headers, data=data, params=request.args)
            else:
                return Response("Method not allowed", status=405)

            logging.warning("Response status: {}, content: {}".format(resp.status_code, resp.text[:500]))
        except Exception as e:
            logging.exception("Request failed: {}".format(str(e)))
            return Response("Forward request failed: {}".format(str(e)), status=500)

        response = Response(resp.content, status=resp.status_code)
        for key, value in resp.headers.items():
            response.headers[key] = value
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS, DELETE"
        response.headers["Access-Control-Max-Age"] = "3600"
        response.headers["Access-Control-Allow-Headers"] = "x-requested-with,Authorization,token, content-type"
        response.headers["Access-Control-Allow-Credentials"] = "true"
        response.headers["X-Accel-Redirect"] = ""

        try:
            # Compatible processing: expose the charged amount to the client.
            body = response.get_json()
            if isinstance(body, dict):
                body.setdefault("metadata", {}).setdefault("usage", {})["total_price"] = money
            if body is not None:
                response.data = json.dumps(body)
        except Exception:
            # Best effort: a non-JSON or oddly shaped body is relayed unchanged.
            pass
        return response
class DingTalkService: if err != "": return "", f"Failed to obtain token: {err}" response = requests.get( - f"https://api.dingtalk.com/v1.0/contact/users/me", + "https://api.dingtalk.com/v1.0/contact/users/me", headers={ "x-acs-dingtalk-access-token": userToken }, ) # Check the response status code diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 45c2a06b4..9b588b372 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,16 +1,16 @@ -import json # extend: oauth2 -import re # extend: oauth2 +import json # extend: oauth2 +import re # extend: oauth2 from enum import StrEnum +from flask import request # extend: oauth2 from pydantic import BaseModel, ConfigDict, Field from configs import dify_config -from extensions.ext_database import db # extend: oauth2 -from flask import request # extend: oauth2 -from extensions.ext_redis import redis_client # extend: oauth2 +from extensions.ext_database import db # extend: oauth2 +from extensions.ext_redis import redis_client # extend: oauth2 +from models.system_extend import SystemIntegrationClassify, SystemIntegrationExtend # Extend DingTalk third-party login from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService -from models.system_extend import SystemIntegrationExtend, SystemIntegrationClassify # Extend DingTalk third-party login class SubscriptionModel(BaseModel): diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 310333454..d0a2f960e 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,3 @@ -import logging from flask_login import current_user diff --git a/api/tasks/extend/update_account_money_when_workflow_node_execution_created_extend.py b/api/tasks/extend/update_account_money_when_workflow_node_execution_created_extend.py index 2af02af52..5af24616a 100644 --- 
a/api/tasks/extend/update_account_money_when_workflow_node_execution_created_extend.py +++ b/api/tasks/extend/update_account_money_when_workflow_node_execution_created_extend.py @@ -14,25 +14,36 @@ from models.account_money_extend import AccountMoneyExtend from models.api_token_money_extend import ApiTokenMessageJoinsExtend, ApiTokenMoneyExtend from models.enums import CreatorUserRole from models.model_extend import EndUserAccountJoinsExtend -from models.workflow import WorkflowNodeExecutionModel @shared_task(queue="extend_high", bind=True, max_retries=3) -def update_account_money_when_workflow_node_execution_created_extend(self, workflow_node_execution_dict: dict): - """ """ - workflowNodeExecution = WorkflowNodeExecutionModel(**workflow_node_execution_dict) - # 非大模型则跳过 - if workflowNodeExecution.node_type != NodeType.LLM.value: +def update_account_money_when_workflow_node_execution_created_extend( + self, workflow_node_execution_dict: dict): + """ + 计算工作流节点执行的费用并更新账户额度 + :param workflow_node_execution_dict: 工作流节点执行字典 + """ + + if not workflow_node_execution_dict: + logging.warning(click.style("工作流节点数据为空", fg="yellow")) return - logging.info(click.style("工作流节点ID: {}".format(workflowNodeExecution.id), fg="cyan")) + + # 非大模型则跳过 + if workflow_node_execution_dict.get("node_type") != NodeType.LLM.value: + return + + node_id = workflow_node_execution_dict.get("id") + logging.info(click.style("工作流节点ID: {}".format(node_id), fg="cyan")) # 拿到费用 - outputs = json.loads(workflowNodeExecution.outputs) if workflowNodeExecution.outputs else {} + outputs_str = workflow_node_execution_dict.get("outputs") + outputs = json.loads(outputs_str) if outputs_str else {} total_price = Decimal(outputs.get("usage", {}).get("total_price", 0)) currency = outputs.get("usage", {}).get("currency", "USD") if total_price == 0: return - price = float(total_price) if currency == "USD" else (float(total_price) / float(dify_config.RMB_TO_USD_RATE)) + price = float(total_price) if currency == "USD" else 
( + float(total_price) / float(dify_config.RMB_TO_USD_RATE)) logging.info(click.style("扣除费用: {}".format(price), fg="green")) try: @@ -40,20 +51,23 @@ def update_account_money_when_workflow_node_execution_created_extend(self, workf # 分两种情况 # web应用的请求,created_by记录的是登录账号的ID,可以拿这个ID来扣钱 # API调用,created_by记录的是节点登录账号ID,真正需要扣钱的在关联表EndUserAccountJoinsExtend,需要多做一层查询 - payerId = workflowNodeExecution.created_by # 付钱的ID - if workflowNodeExecution.created_by_role == CreatorUserRole.END_USER.value: - account = db.session.query(Account).filter(Account.id == workflowNodeExecution.created_by).first() + created_by = workflow_node_execution_dict.get("created_by") + created_by_role = workflow_node_execution_dict.get("created_by_role") + payerId = created_by # 付钱的ID + if created_by_role == CreatorUserRole.END_USER.value: + account = db.session.query(Account).filter(Account.id == created_by).first() if not account: end_user_account_joins = ( db.session.query(EndUserAccountJoinsExtend) - .filter(EndUserAccountJoinsExtend.end_user_id == workflowNodeExecution.created_by) + .filter(EndUserAccountJoinsExtend.end_user_id == created_by) .order_by(EndUserAccountJoinsExtend.created_at.desc()) .first() ) if end_user_account_joins: payerId = end_user_account_joins.account_id - account_money = db.session.query(AccountMoneyExtend).filter(AccountMoneyExtend.account_id == payerId).first() + account_money = db.session.query(AccountMoneyExtend).filter( + AccountMoneyExtend.account_id == payerId).first() logging.info(click.style("更新账号额度,账号ID: {}".format(payerId), fg="green")) if account_money: db.session.query(AccountMoneyExtend).filter(AccountMoneyExtend.account_id == payerId).update( @@ -69,14 +83,16 @@ def update_account_money_when_workflow_node_execution_created_extend(self, workf db.session.add(account_money_add) # 扣掉密钥的钱 + workflow_run_id = workflow_node_execution_dict.get("workflow_run_id") api_token_message = ( db.session.query(ApiTokenMessageJoinsExtend) - 
.filter(ApiTokenMessageJoinsExtend.record_id == workflowNodeExecution.workflow_run_id) + .filter(ApiTokenMessageJoinsExtend.record_id == workflow_run_id) .first() ) if api_token_message: - logging.info(click.style("更新密钥额度,密钥ID: {}".format(api_token_message.app_token_id), fg="green")) + logging.info(click.style("更新密钥额度,密钥ID: {}".format( + api_token_message.app_token_id), fg="green")) db.session.query(ApiTokenMoneyExtend).filter( ApiTokenMoneyExtend.app_token_id == api_token_message.app_token_id ).update( @@ -90,12 +106,14 @@ def update_account_money_when_workflow_node_execution_created_extend(self, workf db.session.commit() except SQLAlchemyError as e: logging.exception( - click.style(f"工作流节点ID: {format(workflowNodeExecution.id)},扣除费用:{format(price)} 数据库异常,60秒后进行重试,", fg="red") + click.style(f"工作流节点ID: {format(node_id)},扣除费用:" + f"{format(price)} 数据库异常,60秒后进行重试,", fg="red") ) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds except Exception as e: logging.exception( - click.style(f"工作流节点ID: {format(workflowNodeExecution.id)},扣除费用:{format(price)} 异常报错,60秒后进行重试,", fg="red") + click.style(f"工作流节点ID: {format(node_id)},扣除费用:" + f"{format(price)} 异常报错,60秒后进行重试,", fg="red") ) raise self.retry(exc=e, countdown=60) diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 77db5ef21..4ba0034f7 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -1,7 +1,5 @@ from textwrap import dedent -from sympy import false - from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer diff --git a/api/uv.lock 
b/api/uv.lock index 7e67a84ce..2666cd5ee 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -142,12 +142,39 @@ version = "1.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/a0/87/1d7019d23891897cb076b2f7e3c81ab3c2ba91de3bb067196f675d60d34c/alibabacloud-credentials-api-1.0.0.tar.gz", hash = "sha256:8c340038d904f0218d7214a8f4088c31912bfcf279af2cbc7d9be4897a97dd2f", size = 2330, upload-time = "2025-01-13T05:53:04.931Z" } +[[package]] +name = "alibabacloud-dingtalk" +version = "2.1.99" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "alibabacloud-endpoint-util" }, + { name = "alibabacloud-gateway-dingtalk" }, + { name = "alibabacloud-gateway-spi" }, + { name = "alibabacloud-openapi-util" }, + { name = "alibabacloud-tea-openapi" }, + { name = "alibabacloud-tea-util" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/41/5909fcad753c10dbe02080345e7d062fceff6ef696295f1c2ef52c217367/alibabacloud_dingtalk-2.1.99.tar.gz", hash = "sha256:1e9cfb9b2d4eefa3250dbf56a3dfac69c3dbf111d02cd6b04f4cae0aa399b41a", size = 1799354, upload-time = "2025-04-29T02:37:49.714Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/e9/99dd01c8daa12edf4cd9f454e55f3906465e71cd6095b0005d09eb0cd143/alibabacloud_dingtalk-2.1.99-py3-none-any.whl", hash = "sha256:a86094d1f99a2ee9bbc32266cf202214d1dab3329786fde593e2aa6697106cde", size = 1910635, upload-time = "2025-04-29T02:37:47.402Z" }, +] + [[package]] name = "alibabacloud-endpoint-util" version = "0.0.4" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/92/7d/8cc92a95c920e344835b005af6ea45a0db98763ad6ad19299d26892e6c8d/alibabacloud_endpoint_util-0.0.4.tar.gz", hash = "sha256:a593eb8ddd8168d5dc2216cd33111b144f9189fcd6e9ca20e48f358a739bbf90", size = 2813, upload-time = "2025-06-12T07:20:52.572Z" } +[[package]] +name = "alibabacloud-gateway-dingtalk" +version = "1.0.2" +source = 
{ registry = "https://pypi.org/simple" } +dependencies = [ + { name = "alibabacloud-gateway-spi" }, + { name = "alibabacloud-tea-util" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d2/40/751d8bdf133d7fcf053f10c98e8e506810e7bee06458a02eaaa14d30ac26/alibabacloud_gateway_dingtalk-1.0.2.tar.gz", hash = "sha256:acea8b0b1d11e0394913f0b0899ddd19c0bfceab716060449b57fcc250ceb300", size = 2938, upload-time = "2023-04-25T09:48:42.249Z" } + [[package]] name = "alibabacloud-gateway-spi" version = "0.0.3" @@ -911,6 +938,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/7a/10bf5dc92d13cc03230190fcc5016a0b138d99e5b36b8b89ee0fe1680e10/chromadb-0.5.20-py3-none-any.whl", hash = "sha256:9550ba1b6dce911e35cac2568b301badf4b42f457b99a432bdeec2b6b9dd3680", size = 617884, upload-time = "2024-11-19T05:13:56.29Z" }, ] +[[package]] +name = "circuitbreaker" +version = "2.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/ac/de7a92c4ed39cba31fe5ad9203b76a25ca67c530797f6bb420fff5f65ccb/circuitbreaker-2.1.3.tar.gz", hash = "sha256:1a4baee510f7bea3c91b194dcce7c07805fe96c4423ed5594b75af438531d084", size = 10787, upload-time = "2025-03-31T08:12:08.963Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/34/15f08edd4628f65217de1fc3c1a27c82e46fe357d60c217fc9881e12ebcc/circuitbreaker-2.1.3-py3-none-any.whl", hash = "sha256:87ba6a3ed03fdc7032bc175561c2b04d52ade9d5faf94ca2b035fbdc5e6b1dd1", size = 7737, upload-time = "2025-03-31T08:12:07.802Z" }, +] + [[package]] name = "click" version = "8.2.1" @@ -1164,43 +1200,43 @@ sdist = { url = "https://files.pythonhosted.org/packages/6b/b0/e595ce2a2527e169c [[package]] name = "cryptography" -version = "45.0.5" +version = "44.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/95/1e/49527ac611af559665f71cbb8f92b332b5ec9c6fbc4e88b0f8e92f5e85df/cryptography-45.0.5.tar.gz", hash = "sha256:72e76caa004ab63accdf26023fccd1d087f6d90ec6048ff33ad0445abf7f605a", size = 744903, upload-time = "2025-07-02T13:06:25.941Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/d6/1411ab4d6108ab167d06254c5be517681f1e331f90edf1379895bcb87020/cryptography-44.0.3.tar.gz", hash = "sha256:fe19d8bc5536a91a24a8133328880a41831b6c5df54599a8417b62fe015d3053", size = 711096, upload-time = "2025-05-02T19:36:04.667Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/fb/09e28bc0c46d2c547085e60897fea96310574c70fb21cd58a730a45f3403/cryptography-45.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:101ee65078f6dd3e5a028d4f19c07ffa4dd22cce6a20eaa160f8b5219911e7d8", size = 7043092, upload-time = "2025-07-02T13:05:01.514Z" }, - { url = "https://files.pythonhosted.org/packages/b1/05/2194432935e29b91fb649f6149c1a4f9e6d3d9fc880919f4ad1bcc22641e/cryptography-45.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3a264aae5f7fbb089dbc01e0242d3b67dffe3e6292e1f5182122bdf58e65215d", size = 4205926, upload-time = "2025-07-02T13:05:04.741Z" }, - { url = "https://files.pythonhosted.org/packages/07/8b/9ef5da82350175e32de245646b1884fc01124f53eb31164c77f95a08d682/cryptography-45.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e74d30ec9c7cb2f404af331d5b4099a9b322a8a6b25c4632755c8757345baac5", size = 4429235, upload-time = "2025-07-02T13:05:07.084Z" }, - { url = "https://files.pythonhosted.org/packages/7c/e1/c809f398adde1994ee53438912192d92a1d0fc0f2d7582659d9ef4c28b0c/cryptography-45.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:3af26738f2db354aafe492fb3869e955b12b2ef2e16908c8b9cb928128d42c57", size = 4209785, upload-time = "2025-07-02T13:05:09.321Z" }, - { url = 
"https://files.pythonhosted.org/packages/d0/8b/07eb6bd5acff58406c5e806eff34a124936f41a4fb52909ffa4d00815f8c/cryptography-45.0.5-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e6c00130ed423201c5bc5544c23359141660b07999ad82e34e7bb8f882bb78e0", size = 3893050, upload-time = "2025-07-02T13:05:11.069Z" }, - { url = "https://files.pythonhosted.org/packages/ec/ef/3333295ed58d900a13c92806b67e62f27876845a9a908c939f040887cca9/cryptography-45.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:dd420e577921c8c2d31289536c386aaa30140b473835e97f83bc71ea9d2baf2d", size = 4457379, upload-time = "2025-07-02T13:05:13.32Z" }, - { url = "https://files.pythonhosted.org/packages/d9/9d/44080674dee514dbb82b21d6fa5d1055368f208304e2ab1828d85c9de8f4/cryptography-45.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:d05a38884db2ba215218745f0781775806bde4f32e07b135348355fe8e4991d9", size = 4209355, upload-time = "2025-07-02T13:05:15.017Z" }, - { url = "https://files.pythonhosted.org/packages/c9/d8/0749f7d39f53f8258e5c18a93131919ac465ee1f9dccaf1b3f420235e0b5/cryptography-45.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:ad0caded895a00261a5b4aa9af828baede54638754b51955a0ac75576b831b27", size = 4456087, upload-time = "2025-07-02T13:05:16.945Z" }, - { url = "https://files.pythonhosted.org/packages/09/d7/92acac187387bf08902b0bf0699816f08553927bdd6ba3654da0010289b4/cryptography-45.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9024beb59aca9d31d36fcdc1604dd9bbeed0a55bface9f1908df19178e2f116e", size = 4332873, upload-time = "2025-07-02T13:05:18.743Z" }, - { url = "https://files.pythonhosted.org/packages/03/c2/840e0710da5106a7c3d4153c7215b2736151bba60bf4491bdb421df5056d/cryptography-45.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:91098f02ca81579c85f66df8a588c78f331ca19089763d733e34ad359f474174", size = 4564651, upload-time = "2025-07-02T13:05:21.382Z" }, - { url = 
"https://files.pythonhosted.org/packages/2e/92/cc723dd6d71e9747a887b94eb3827825c6c24b9e6ce2bb33b847d31d5eaa/cryptography-45.0.5-cp311-abi3-win32.whl", hash = "sha256:926c3ea71a6043921050eaa639137e13dbe7b4ab25800932a8498364fc1abec9", size = 2929050, upload-time = "2025-07-02T13:05:23.39Z" }, - { url = "https://files.pythonhosted.org/packages/1f/10/197da38a5911a48dd5389c043de4aec4b3c94cb836299b01253940788d78/cryptography-45.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:b85980d1e345fe769cfc57c57db2b59cff5464ee0c045d52c0df087e926fbe63", size = 3403224, upload-time = "2025-07-02T13:05:25.202Z" }, - { url = "https://files.pythonhosted.org/packages/fe/2b/160ce8c2765e7a481ce57d55eba1546148583e7b6f85514472b1d151711d/cryptography-45.0.5-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f3562c2f23c612f2e4a6964a61d942f891d29ee320edb62ff48ffb99f3de9ae8", size = 7017143, upload-time = "2025-07-02T13:05:27.229Z" }, - { url = "https://files.pythonhosted.org/packages/c2/e7/2187be2f871c0221a81f55ee3105d3cf3e273c0a0853651d7011eada0d7e/cryptography-45.0.5-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3fcfbefc4a7f332dece7272a88e410f611e79458fab97b5efe14e54fe476f4fd", size = 4197780, upload-time = "2025-07-02T13:05:29.299Z" }, - { url = "https://files.pythonhosted.org/packages/b9/cf/84210c447c06104e6be9122661159ad4ce7a8190011669afceeaea150524/cryptography-45.0.5-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:460f8c39ba66af7db0545a8c6f2eabcbc5a5528fc1cf6c3fa9a1e44cec33385e", size = 4420091, upload-time = "2025-07-02T13:05:31.221Z" }, - { url = "https://files.pythonhosted.org/packages/3e/6a/cb8b5c8bb82fafffa23aeff8d3a39822593cee6e2f16c5ca5c2ecca344f7/cryptography-45.0.5-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:9b4cf6318915dccfe218e69bbec417fdd7c7185aa7aab139a2c0beb7468c89f0", size = 4198711, upload-time = "2025-07-02T13:05:33.062Z" }, - { url = 
"https://files.pythonhosted.org/packages/04/f7/36d2d69df69c94cbb2473871926daf0f01ad8e00fe3986ac3c1e8c4ca4b3/cryptography-45.0.5-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2089cc8f70a6e454601525e5bf2779e665d7865af002a5dec8d14e561002e135", size = 3883299, upload-time = "2025-07-02T13:05:34.94Z" }, - { url = "https://files.pythonhosted.org/packages/82/c7/f0ea40f016de72f81288e9fe8d1f6748036cb5ba6118774317a3ffc6022d/cryptography-45.0.5-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0027d566d65a38497bc37e0dd7c2f8ceda73597d2ac9ba93810204f56f52ebc7", size = 4450558, upload-time = "2025-07-02T13:05:37.288Z" }, - { url = "https://files.pythonhosted.org/packages/06/ae/94b504dc1a3cdf642d710407c62e86296f7da9e66f27ab12a1ee6fdf005b/cryptography-45.0.5-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:be97d3a19c16a9be00edf79dca949c8fa7eff621763666a145f9f9535a5d7f42", size = 4198020, upload-time = "2025-07-02T13:05:39.102Z" }, - { url = "https://files.pythonhosted.org/packages/05/2b/aaf0adb845d5dabb43480f18f7ca72e94f92c280aa983ddbd0bcd6ecd037/cryptography-45.0.5-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:7760c1c2e1a7084153a0f68fab76e754083b126a47d0117c9ed15e69e2103492", size = 4449759, upload-time = "2025-07-02T13:05:41.398Z" }, - { url = "https://files.pythonhosted.org/packages/91/e4/f17e02066de63e0100a3a01b56f8f1016973a1d67551beaf585157a86b3f/cryptography-45.0.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:6ff8728d8d890b3dda5765276d1bc6fb099252915a2cd3aff960c4c195745dd0", size = 4319991, upload-time = "2025-07-02T13:05:43.64Z" }, - { url = "https://files.pythonhosted.org/packages/f2/2e/e2dbd629481b499b14516eed933f3276eb3239f7cee2dcfa4ee6b44d4711/cryptography-45.0.5-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7259038202a47fdecee7e62e0fd0b0738b6daa335354396c6ddebdbe1206af2a", size = 4554189, upload-time = "2025-07-02T13:05:46.045Z" }, - { url = 
"https://files.pythonhosted.org/packages/f8/ea/a78a0c38f4c8736287b71c2ea3799d173d5ce778c7d6e3c163a95a05ad2a/cryptography-45.0.5-cp37-abi3-win32.whl", hash = "sha256:1e1da5accc0c750056c556a93c3e9cb828970206c68867712ca5805e46dc806f", size = 2911769, upload-time = "2025-07-02T13:05:48.329Z" }, - { url = "https://files.pythonhosted.org/packages/79/b3/28ac139109d9005ad3f6b6f8976ffede6706a6478e21c889ce36c840918e/cryptography-45.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:90cb0a7bb35959f37e23303b7eed0a32280510030daba3f7fdfbb65defde6a97", size = 3390016, upload-time = "2025-07-02T13:05:50.811Z" }, - { url = "https://files.pythonhosted.org/packages/c0/71/9bdbcfd58d6ff5084687fe722c58ac718ebedbc98b9f8f93781354e6d286/cryptography-45.0.5-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8c4a6ff8a30e9e3d38ac0539e9a9e02540ab3f827a3394f8852432f6b0ea152e", size = 3587878, upload-time = "2025-07-02T13:06:06.339Z" }, - { url = "https://files.pythonhosted.org/packages/f0/63/83516cfb87f4a8756eaa4203f93b283fda23d210fc14e1e594bd5f20edb6/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bd4c45986472694e5121084c6ebbd112aa919a25e783b87eb95953c9573906d6", size = 4152447, upload-time = "2025-07-02T13:06:08.345Z" }, - { url = "https://files.pythonhosted.org/packages/22/11/d2823d2a5a0bd5802b3565437add16f5c8ce1f0778bf3822f89ad2740a38/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:982518cd64c54fcada9d7e5cf28eabd3ee76bd03ab18e08a48cad7e8b6f31b18", size = 4386778, upload-time = "2025-07-02T13:06:10.263Z" }, - { url = "https://files.pythonhosted.org/packages/5f/38/6bf177ca6bce4fe14704ab3e93627c5b0ca05242261a2e43ef3168472540/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:12e55281d993a793b0e883066f590c1ae1e802e3acb67f8b442e721e475e6463", size = 4151627, upload-time = "2025-07-02T13:06:13.097Z" }, - { url = 
"https://files.pythonhosted.org/packages/38/6a/69fc67e5266bff68a91bcb81dff8fb0aba4d79a78521a08812048913e16f/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:5aa1e32983d4443e310f726ee4b071ab7569f58eedfdd65e9675484a4eb67bd1", size = 4385593, upload-time = "2025-07-02T13:06:15.689Z" }, - { url = "https://files.pythonhosted.org/packages/f6/34/31a1604c9a9ade0fdab61eb48570e09a796f4d9836121266447b0eaf7feb/cryptography-45.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:e357286c1b76403dd384d938f93c46b2b058ed4dfcdce64a770f0537ed3feb6f", size = 3331106, upload-time = "2025-07-02T13:06:18.058Z" }, + { url = "https://files.pythonhosted.org/packages/08/53/c776d80e9d26441bb3868457909b4e74dd9ccabd182e10b2b0ae7a07e265/cryptography-44.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:962bc30480a08d133e631e8dfd4783ab71cc9e33d5d7c1e192f0b7c06397bb88", size = 6670281, upload-time = "2025-05-02T19:34:50.665Z" }, + { url = "https://files.pythonhosted.org/packages/6a/06/af2cf8d56ef87c77319e9086601bef621bedf40f6f59069e1b6d1ec498c5/cryptography-44.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffc61e8f3bf5b60346d89cd3d37231019c17a081208dfbbd6e1605ba03fa137", size = 3959305, upload-time = "2025-05-02T19:34:53.042Z" }, + { url = "https://files.pythonhosted.org/packages/ae/01/80de3bec64627207d030f47bf3536889efee8913cd363e78ca9a09b13c8e/cryptography-44.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58968d331425a6f9eedcee087f77fd3c927c88f55368f43ff7e0a19891f2642c", size = 4171040, upload-time = "2025-05-02T19:34:54.675Z" }, + { url = "https://files.pythonhosted.org/packages/bd/48/bb16b7541d207a19d9ae8b541c70037a05e473ddc72ccb1386524d4f023c/cryptography-44.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:e28d62e59a4dbd1d22e747f57d4f00c459af22181f0b2f787ea83f5a876d7c76", size = 3963411, upload-time = "2025-05-02T19:34:56.61Z" }, + { url = 
"https://files.pythonhosted.org/packages/42/b2/7d31f2af5591d217d71d37d044ef5412945a8a8e98d5a2a8ae4fd9cd4489/cryptography-44.0.3-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af653022a0c25ef2e3ffb2c673a50e5a0d02fecc41608f4954176f1933b12359", size = 3689263, upload-time = "2025-05-02T19:34:58.591Z" }, + { url = "https://files.pythonhosted.org/packages/25/50/c0dfb9d87ae88ccc01aad8eb93e23cfbcea6a6a106a9b63a7b14c1f93c75/cryptography-44.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:157f1f3b8d941c2bd8f3ffee0af9b049c9665c39d3da9db2dc338feca5e98a43", size = 4196198, upload-time = "2025-05-02T19:35:00.988Z" }, + { url = "https://files.pythonhosted.org/packages/66/c9/55c6b8794a74da652690c898cb43906310a3e4e4f6ee0b5f8b3b3e70c441/cryptography-44.0.3-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c6cd67722619e4d55fdb42ead64ed8843d64638e9c07f4011163e46bc512cf01", size = 3966502, upload-time = "2025-05-02T19:35:03.091Z" }, + { url = "https://files.pythonhosted.org/packages/b6/f7/7cb5488c682ca59a02a32ec5f975074084db4c983f849d47b7b67cc8697a/cryptography-44.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b424563394c369a804ecbee9b06dfb34997f19d00b3518e39f83a5642618397d", size = 4196173, upload-time = "2025-05-02T19:35:05.018Z" }, + { url = "https://files.pythonhosted.org/packages/d2/0b/2f789a8403ae089b0b121f8f54f4a3e5228df756e2146efdf4a09a3d5083/cryptography-44.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c91fc8e8fd78af553f98bc7f2a1d8db977334e4eea302a4bfd75b9461c2d8904", size = 4087713, upload-time = "2025-05-02T19:35:07.187Z" }, + { url = "https://files.pythonhosted.org/packages/1d/aa/330c13655f1af398fc154089295cf259252f0ba5df93b4bc9d9c7d7f843e/cryptography-44.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:25cd194c39fa5a0aa4169125ee27d1172097857b27109a45fadc59653ec06f44", size = 4299064, upload-time = "2025-05-02T19:35:08.879Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/a8/8c540a421b44fd267a7d58a1fd5f072a552d72204a3f08194f98889de76d/cryptography-44.0.3-cp37-abi3-win32.whl", hash = "sha256:3be3f649d91cb182c3a6bd336de8b61a0a71965bd13d1a04a0e15b39c3d5809d", size = 2773887, upload-time = "2025-05-02T19:35:10.41Z" }, + { url = "https://files.pythonhosted.org/packages/b9/0d/c4b1657c39ead18d76bbd122da86bd95bdc4095413460d09544000a17d56/cryptography-44.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:3883076d5c4cc56dbef0b898a74eb6992fdac29a7b9013870b34efe4ddb39a0d", size = 3209737, upload-time = "2025-05-02T19:35:12.12Z" }, + { url = "https://files.pythonhosted.org/packages/34/a3/ad08e0bcc34ad436013458d7528e83ac29910943cea42ad7dd4141a27bbb/cryptography-44.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:5639c2b16764c6f76eedf722dbad9a0914960d3489c0cc38694ddf9464f1bb2f", size = 6673501, upload-time = "2025-05-02T19:35:13.775Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f0/7491d44bba8d28b464a5bc8cc709f25a51e3eac54c0a4444cf2473a57c37/cryptography-44.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ffef566ac88f75967d7abd852ed5f182da252d23fac11b4766da3957766759", size = 3960307, upload-time = "2025-05-02T19:35:15.917Z" }, + { url = "https://files.pythonhosted.org/packages/f7/c8/e5c5d0e1364d3346a5747cdcd7ecbb23ca87e6dea4f942a44e88be349f06/cryptography-44.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192ed30fac1728f7587c6f4613c29c584abdc565d7417c13904708db10206645", size = 4170876, upload-time = "2025-05-02T19:35:18.138Z" }, + { url = "https://files.pythonhosted.org/packages/73/96/025cb26fc351d8c7d3a1c44e20cf9a01e9f7cf740353c9c7a17072e4b264/cryptography-44.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:7d5fe7195c27c32a64955740b949070f21cba664604291c298518d2e255931d2", size = 3964127, upload-time = "2025-05-02T19:35:19.864Z" }, + { url = 
"https://files.pythonhosted.org/packages/01/44/eb6522db7d9f84e8833ba3bf63313f8e257729cf3a8917379473fcfd6601/cryptography-44.0.3-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3f07943aa4d7dad689e3bb1638ddc4944cc5e0921e3c227486daae0e31a05e54", size = 3689164, upload-time = "2025-05-02T19:35:21.449Z" }, + { url = "https://files.pythonhosted.org/packages/68/fb/d61a4defd0d6cee20b1b8a1ea8f5e25007e26aeb413ca53835f0cae2bcd1/cryptography-44.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb90f60e03d563ca2445099edf605c16ed1d5b15182d21831f58460c48bffb93", size = 4198081, upload-time = "2025-05-02T19:35:23.187Z" }, + { url = "https://files.pythonhosted.org/packages/1b/50/457f6911d36432a8811c3ab8bd5a6090e8d18ce655c22820994913dd06ea/cryptography-44.0.3-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:ab0b005721cc0039e885ac3503825661bd9810b15d4f374e473f8c89b7d5460c", size = 3967716, upload-time = "2025-05-02T19:35:25.426Z" }, + { url = "https://files.pythonhosted.org/packages/35/6e/dca39d553075980ccb631955c47b93d87d27f3596da8d48b1ae81463d915/cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3bb0847e6363c037df8f6ede57d88eaf3410ca2267fb12275370a76f85786a6f", size = 4197398, upload-time = "2025-05-02T19:35:27.678Z" }, + { url = "https://files.pythonhosted.org/packages/9b/9d/d1f2fe681eabc682067c66a74addd46c887ebacf39038ba01f8860338d3d/cryptography-44.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0cc66c74c797e1db750aaa842ad5b8b78e14805a9b5d1348dc603612d3e3ff5", size = 4087900, upload-time = "2025-05-02T19:35:29.312Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f5/3599e48c5464580b73b236aafb20973b953cd2e7b44c7c2533de1d888446/cryptography-44.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6866df152b581f9429020320e5eb9794c8780e90f7ccb021940d7f50ee00ae0b", size = 4301067, upload-time = "2025-05-02T19:35:31.547Z" }, + { url = 
"https://files.pythonhosted.org/packages/a7/6c/d2c48c8137eb39d0c193274db5c04a75dab20d2f7c3f81a7dcc3a8897701/cryptography-44.0.3-cp39-abi3-win32.whl", hash = "sha256:c138abae3a12a94c75c10499f1cbae81294a6f983b3af066390adee73f433028", size = 2775467, upload-time = "2025-05-02T19:35:33.805Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375, upload-time = "2025-05-02T19:35:35.369Z" }, + { url = "https://files.pythonhosted.org/packages/8d/4b/c11ad0b6c061902de5223892d680e89c06c7c4d606305eb8de56c5427ae6/cryptography-44.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:896530bc9107b226f265effa7ef3f21270f18a2026bc09fed1ebd7b66ddf6375", size = 3390230, upload-time = "2025-05-02T19:35:49.062Z" }, + { url = "https://files.pythonhosted.org/packages/58/11/0a6bf45d53b9b2290ea3cec30e78b78e6ca29dc101e2e296872a0ffe1335/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9b4d4a5dbee05a2c390bf212e78b99434efec37b17a4bff42f50285c5c8c9647", size = 3895216, upload-time = "2025-05-02T19:35:51.351Z" }, + { url = "https://files.pythonhosted.org/packages/0a/27/b28cdeb7270e957f0077a2c2bfad1b38f72f1f6d699679f97b816ca33642/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02f55fb4f8b79c1221b0961488eaae21015b69b210e18c386b69de182ebb1259", size = 4115044, upload-time = "2025-05-02T19:35:53.044Z" }, + { url = "https://files.pythonhosted.org/packages/35/b0/ec4082d3793f03cb248881fecefc26015813199b88f33e3e990a43f79835/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:dd3db61b8fe5be220eee484a17233287d0be6932d056cf5738225b9c05ef4fff", size = 3898034, upload-time = "2025-05-02T19:35:54.72Z" }, + { url = 
"https://files.pythonhosted.org/packages/0b/7f/adf62e0b8e8d04d50c9a91282a57628c00c54d4ae75e2b02a223bd1f2613/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:978631ec51a6bbc0b7e58f23b68a8ce9e5f09721940933e9c217068388789fe5", size = 4114449, upload-time = "2025-05-02T19:35:57.139Z" }, + { url = "https://files.pythonhosted.org/packages/87/62/d69eb4a8ee231f4bf733a92caf9da13f1c81a44e874b1d4080c25ecbb723/cryptography-44.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5d20cc348cca3a8aa7312f42ab953a56e15323800ca3ab0706b8cd452a3a056c", size = 3134369, upload-time = "2025-05-02T19:35:58.907Z" }, ] [[package]] @@ -1263,6 +1299,7 @@ name = "dify-api" version = "1.8.1" source = { virtual = "." } dependencies = [ + { name = "alibabacloud-dingtalk" }, { name = "arize-phoenix-otel" }, { name = "authlib" }, { name = "azure-identity" }, @@ -1278,6 +1315,7 @@ dependencies = [ { name = "flask-login" }, { name = "flask-migrate" }, { name = "flask-orjson" }, + { name = "flask-restful" }, { name = "flask-restx" }, { name = "flask-sqlalchemy" }, { name = "gevent" }, @@ -1295,9 +1333,11 @@ dependencies = [ { name = "json-repair" }, { name = "langfuse" }, { name = "langsmith" }, + { name = "ldap3" }, { name = "mailchimp-transactional" }, { name = "markdown" }, { name = "numpy" }, + { name = "oci" }, { name = "openai" }, { name = "openpyxl" }, { name = "opentelemetry-api" }, @@ -1328,6 +1368,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "pyjwt" }, { name = "pypdfium2" }, + { name = "pypinyin" }, { name = "python-docx" }, { name = "python-dotenv" }, { name = "pyyaml" }, @@ -1340,8 +1381,10 @@ dependencies = [ { name = "sseclient-py" }, { name = "starlette" }, { name = "tiktoken" }, + { name = "tokenizers" }, { name = "transformers" }, { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, + { name = "validators" }, { name = "weave" }, { name = "webvtt-py" }, { name = "yarl" }, @@ -1453,6 +1496,7 @@ vdb = [ 
[package.metadata] requires-dist = [ + { name = "alibabacloud-dingtalk", specifier = "~=2.1.32" }, { name = "arize-phoenix-otel", specifier = "~=0.9.2" }, { name = "authlib", specifier = "==1.3.1" }, { name = "azure-identity", specifier = "==1.16.1" }, @@ -1468,6 +1512,7 @@ requires-dist = [ { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.0.7" }, { name = "flask-orjson", specifier = "~=2.0.0" }, + { name = "flask-restful", specifier = "~=0.3.10" }, { name = "flask-restx", specifier = ">=1.3.0" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=24.11.1" }, @@ -1485,9 +1530,11 @@ requires-dist = [ { name = "json-repair", specifier = ">=0.41.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, + { name = "ldap3", specifier = "~=2.9.1" }, { name = "mailchimp-transactional", specifier = "~=1.0.50" }, { name = "markdown", specifier = "~=3.5.1" }, { name = "numpy", specifier = "~=1.26.4" }, + { name = "oci", specifier = "~=2.135.1" }, { name = "openai", specifier = "~=1.61.0" }, { name = "openpyxl", specifier = "~=3.1.5" }, { name = "opentelemetry-api", specifier = "==1.27.0" }, @@ -1518,6 +1565,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = "~=2.9.1" }, { name = "pyjwt", specifier = "~=2.10.1" }, { name = "pypdfium2", specifier = "==4.30.0" }, + { name = "pypinyin", specifier = "~=0.53.0" }, { name = "python-docx", specifier = "~=1.1.0" }, { name = "python-dotenv", specifier = "==1.0.1" }, { name = "pyyaml", specifier = "~=6.0.1" }, @@ -1530,8 +1578,10 @@ requires-dist = [ { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "starlette", specifier = "==0.47.2" }, { name = "tiktoken", specifier = "~=0.9.0" }, + { name = "tokenizers", specifier = "~=0.21.2" }, { name = "transformers", specifier = "~=4.53.0" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, + { name = 
"validators", specifier = "~=0.22.0" }, { name = "weave", specifier = "~=0.51.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, { name = "yarl", specifier = "~=1.18.3" }, @@ -1902,6 +1952,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f3/ca/53e14be018a2284acf799830e8cd8e0b263c0fd3dff1ad7b35f8417e7067/flask_orjson-2.0.0-py3-none-any.whl", hash = "sha256:5d15f2ba94b8d6c02aee88fc156045016e83db9eda2c30545fabd640aebaec9d", size = 3622, upload-time = "2024-01-15T00:03:17.511Z" }, ] +[[package]] +name = "flask-restful" +version = "0.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aniso8601" }, + { name = "flask" }, + { name = "pytz" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/ce/a0a133db616ea47f78a41e15c4c68b9f08cab3df31eb960f61899200a119/Flask-RESTful-0.3.10.tar.gz", hash = "sha256:fe4af2ef0027df8f9b4f797aba20c5566801b6ade995ac63b588abf1a59cec37", size = 110453, upload-time = "2023-05-21T03:58:55.781Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/7b/f0b45f0df7d2978e5ae51804bb5939b7897b2ace24306009da0cc34d8d1f/Flask_RESTful-0.3.10-py2.py3-none-any.whl", hash = "sha256:1cf93c535172f112e080b0d4503a8d15f93a48c88bdd36dd87269bdaf405051b", size = 26217, upload-time = "2023-05-21T03:58:54.004Z" }, +] + [[package]] name = "flask-restx" version = "1.3.0" @@ -2915,6 +2980,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/f0/63b06b99b730b9954f8709f6f7d9b8d076fa0a973e472efe278089bde42b/langsmith-0.1.147-py3-none-any.whl", hash = "sha256:7166fc23b965ccf839d64945a78e9f1157757add228b086141eb03a60d699a15", size = 311812, upload-time = "2024-11-27T17:32:39.569Z" }, ] +[[package]] +name = "ldap3" +version = "2.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/43/ac/96bd5464e3edbc61595d0d69989f5d9969ae411866427b2500a8e5b812c0/ldap3-2.9.1.tar.gz", hash = "sha256:f3e7fc4718e3f09dda568b57100095e0ce58633bcabbed8667ce3f8fbaa4229f", size = 398830, upload-time = "2021-07-18T06:34:21.786Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/f6/71d6ec9f18da0b2201287ce9db6afb1a1f637dedb3f0703409558981c723/ldap3-2.9.1-py2.py3-none-any.whl", hash = "sha256:5869596fc4948797020d3f03b7939da938778a0f9e2009f7a072ccf92b8e8d70", size = 432192, upload-time = "2021-07-18T06:34:12.905Z" }, +] + [[package]] name = "litellm" version = "1.63.7" @@ -3424,6 +3501,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" }, ] +[[package]] +name = "oci" +version = "2.135.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "circuitbreaker" }, + { name = "cryptography" }, + { name = "pyopenssl" }, + { name = "python-dateutil" }, + { name = "pytz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/56/b828096e323c140edce4656b2ad073d5b662c9602c89658d4a33a9573d09/oci-2.135.2.tar.gz", hash = "sha256:520f78983c5246eae80dd5ecfd05e3a565c8b98d02ef0c1b11ba1f61bcccb61d", size = 13813532, upload-time = "2024-10-08T06:46:21.406Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/fe/7a106d278f3998ea2aca65d8772736396467efd4922c56c283604dbeec5d/oci-2.135.2-py3-none-any.whl", hash = "sha256:5213319244e1c7f108bcb417322f33f01f043fd9636d4063574039f5fdf4e4f7", size = 28290849, upload-time = "2024-10-08T06:45:51.567Z" }, +] + [[package]] name = "odfpy" version = "1.4.1" @@ -4580,6 +4674,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/5f/1f/a62754ba9b8a02c038d2a96cb641b71d3809f34d2ba4f921fecd7840d7fb/pyobvector-0.2.15-py3-none-any.whl", hash = "sha256:feeefe849ee5400e72a9a4d3844e425a58a99053dd02abe06884206923065ebb", size = 52680, upload-time = "2025-08-18T02:49:25.452Z" }, ] +[[package]] +name = "pyopenssl" +version = "24.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c1/d4/1067b82c4fc674d6f6e9e8d26b3dff978da46d351ca3bac171544693e085/pyopenssl-24.3.0.tar.gz", hash = "sha256:49f7a019577d834746bc55c5fce6ecbcec0f2b4ec5ce1cf43a9a173b8138bb36", size = 178944, upload-time = "2024-11-27T20:43:12.755Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/22/40f9162e943f86f0fc927ebc648078be87def360d9d8db346619fb97df2b/pyOpenSSL-24.3.0-py3-none-any.whl", hash = "sha256:e474f5a473cd7f92221cc04976e48f4d11502804657a08a989fb3be5514c904a", size = 56111, upload-time = "2024-11-27T20:43:21.112Z" }, +] + [[package]] name = "pypandoc" version = "1.15" @@ -4633,6 +4739,15 @@ version = "0.48.9" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/c7/2c/94ed7b91db81d61d7096ac8f2d325ec562fc75e35f3baea8749c85b28784/PyPika-0.48.9.tar.gz", hash = "sha256:838836a61747e7c8380cd1b7ff638694b7a7335345d0f559b04b2cd832ad5378", size = 67259, upload-time = "2022-03-15T11:22:57.066Z" } +[[package]] +name = "pypinyin" +version = "0.53.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b8/2d/58c9e7d0825d834fc5ac62a340640953d39a80e78cba70eb73d3bad5b4be/pypinyin-0.53.0.tar.gz", hash = "sha256:a2d39ddc2bd31b55897bbb10d2e11a0c4d399988a97c00ad489c151afd9b106d", size = 824458, upload-time = "2024-09-15T08:05:49.637Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b1/af/a1f9ee31b860ea55985a743b53fc06e61fe156bc1a9d64d94a81afa80470/pypinyin-0.53.0-py2.py3-none-any.whl", hash = "sha256:a906768919da3c31771f2c5e0e5a759214dc38d0087e15e6ff67649e03df8097", size = 834720, upload-time = "2024-09-15T08:05:47.379Z" }, +] + [[package]] name = "pyproject-hooks" version = "1.2.0" @@ -6503,11 +6618,11 @@ wheels = [ [[package]] name = "validators" -version = "0.35.0" +version = "0.22.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/53/66/a435d9ae49850b2f071f7ebd8119dd4e84872b01630d6736761e6e7fd847/validators-0.35.0.tar.gz", hash = "sha256:992d6c48a4e77c81f1b4daba10d16c3a9bb0dbb79b3a19ea847ff0928e70497a", size = 73399, upload-time = "2025-05-01T05:42:06.7Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9b/21/40a249498eee5a244a017582c06c0af01851179e2617928063a3d628bc8f/validators-0.22.0.tar.gz", hash = "sha256:77b2689b172eeeb600d9605ab86194641670cdb73b60afd577142a9397873370", size = 41479, upload-time = "2023-09-02T09:17:59.054Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/6e/3e955517e22cbdd565f2f8b2e73d52528b14b8bcfdb04f62466b071de847/validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd", size = 44712, upload-time = "2025-05-01T05:42:04.203Z" }, + { url = "https://files.pythonhosted.org/packages/3a/0c/785d317eea99c3739821718f118c70537639aa43f96bfa1d83a71f68eaf6/validators-0.22.0-py3-none-any.whl", hash = "sha256:61cf7d4a62bbae559f2e54aed3b000cea9ff3e2fdbe463f51179b92c58c9585a", size = 26195, upload-time = "2023-09-02T09:17:56.595Z" }, ] [[package]] diff --git a/docker/docker-compose.dify-plus.yaml b/docker/docker-compose.dify-plus.yaml index a8222c88b..3e98e8844 100644 --- a/docker/docker-compose.dify-plus.yaml +++ b/docker/docker-compose.dify-plus.yaml @@ -11,6 +11,10 @@ x-shared-env: &shared-api-worker-env APP_API_URL: ${APP_API_URL:-} 
APP_WEB_URL: ${APP_WEB_URL:-} FILES_URL: ${FILES_URL:-} + INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} + LANG: ${LANG:-en_US.UTF-8} + LC_ALL: ${LC_ALL:-en_US.UTF-8} + PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8} LOG_LEVEL: ${LOG_LEVEL:-INFO} LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} @@ -19,6 +23,7 @@ x-shared-env: &shared-api-worker-env LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} + ENABLE_REQUEST_LOGGING: ${ENABLE_REQUEST_LOGGING:-False} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} @@ -43,14 +48,20 @@ x-shared-env: &shared-api-worker-env CELERY_MIN_WORKERS: ${CELERY_MIN_WORKERS:-} API_TOOL_DEFAULT_CONNECT_TIMEOUT: ${API_TOOL_DEFAULT_CONNECT_TIMEOUT:-10} API_TOOL_DEFAULT_READ_TIMEOUT: ${API_TOOL_DEFAULT_READ_TIMEOUT:-60} + ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true} + ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true} + ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true} DB_USERNAME: ${DB_USERNAME:-postgres} DB_PASSWORD: ${DB_PASSWORD:-difyai123456} DB_HOST: ${DB_HOST:-db} DB_PORT: ${DB_PORT:-5432} DB_DATABASE: ${DB_DATABASE:-dify} SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30} + SQLALCHEMY_MAX_OVERFLOW: ${SQLALCHEMY_MAX_OVERFLOW:-10} SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600} SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false} + SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false} + SQLALCHEMY_POOL_USE_LIFO: ${SQLALCHEMY_POOL_USE_LIFO:-false} POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100} POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB} POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} @@ -61,6 +72,10 @@ x-shared-env: &shared-api-worker-env REDIS_USERNAME: ${REDIS_USERNAME:-} REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456} REDIS_USE_SSL: ${REDIS_USE_SSL:-false} + REDIS_SSL_CERT_REQS: 
${REDIS_SSL_CERT_REQS:-CERT_NONE} + REDIS_SSL_CA_CERTS: ${REDIS_SSL_CA_CERTS:-} + REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} + REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} REDIS_DB: ${REDIS_DB:-0} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-} @@ -72,15 +87,21 @@ x-shared-env: &shared-api-worker-env REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} + CELERY_BACKEND: ${CELERY_BACKEND:-redis} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} + CELERY_SENTINEL_PASSWORD: ${CELERY_SENTINEL_PASSWORD:-} CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} STORAGE_TYPE: ${STORAGE_TYPE:-opendal} OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} + CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user} + CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-} + CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_} + CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km} S3_ENDPOINT: ${S3_ENDPOINT:-} S3_REGION: ${S3_REGION:-us-east-1} S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} @@ -127,6 +148,7 @@ x-shared-env: &shared-api-worker-env SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key} SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} + VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index} WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} @@ -134,11 +156,14 @@ x-shared-env: &shared-api-worker-env QDRANT_CLIENT_TIMEOUT: 
${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} + QDRANT_REPLICATION_FACTOR: ${QDRANT_REPLICATION_FACTOR:-1} MILVUS_URI: ${MILVUS_URI:-http://host.docker.internal:19530} + MILVUS_DATABASE: ${MILVUS_DATABASE:-} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-} MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False} + MILVUS_ANALYZER_PARAMS: ${MILVUS_ANALYZER_PARAMS:-} MYSCALE_HOST: ${MYSCALE_HOST:-myscale} MYSCALE_PORT: ${MYSCALE_PORT:-8123} MYSCALE_USER: ${MYSCALE_USER:-default} @@ -159,6 +184,13 @@ x-shared-env: &shared-api-worker-env PGVECTOR_MAX_CONNECTION: ${PGVECTOR_MAX_CONNECTION:-5} PGVECTOR_PG_BIGM: ${PGVECTOR_PG_BIGM:-false} PGVECTOR_PG_BIGM_VERSION: ${PGVECTOR_PG_BIGM_VERSION:-1.2-20240606} + VASTBASE_HOST: ${VASTBASE_HOST:-vastbase} + VASTBASE_PORT: ${VASTBASE_PORT:-5432} + VASTBASE_USER: ${VASTBASE_USER:-dify} + VASTBASE_PASSWORD: ${VASTBASE_PASSWORD:-Difyai123456} + VASTBASE_DATABASE: ${VASTBASE_DATABASE:-dify} + VASTBASE_MIN_CONNECTION: ${VASTBASE_MIN_CONNECTION:-1} + VASTBASE_MAX_CONNECTION: ${VASTBASE_MAX_CONNECTION:-5} PGVECTO_RS_HOST: ${PGVECTO_RS_HOST:-pgvecto-rs} PGVECTO_RS_PORT: ${PGVECTO_RS_PORT:-5432} PGVECTO_RS_USER: ${PGVECTO_RS_USER:-postgres} @@ -181,6 +213,11 @@ x-shared-env: &shared-api-worker-env TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-} TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-} TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify} + MATRIXONE_HOST: ${MATRIXONE_HOST:-matrixone} + MATRIXONE_PORT: ${MATRIXONE_PORT:-6001} + MATRIXONE_USER: ${MATRIXONE_USER:-dump} + MATRIXONE_PASSWORD: ${MATRIXONE_PASSWORD:-111} + MATRIXONE_DATABASE: ${MATRIXONE_DATABASE:-dify} TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1} TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify} TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20} @@ -213,9 +250,13 @@ x-shared-env: &shared-api-worker-env 
RELYT_DATABASE: ${RELYT_DATABASE:-postgres} OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch} OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200} + OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} + OPENSEARCH_VERIFY_CERTS: ${OPENSEARCH_VERIFY_CERTS:-true} + OPENSEARCH_AUTH_METHOD: ${OPENSEARCH_AUTH_METHOD:-basic} OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} - OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} + OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1} + OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss} TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30} @@ -229,6 +270,14 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} KIBANA_PORT: ${KIBANA_PORT:-5601} + ELASTICSEARCH_USE_CLOUD: ${ELASTICSEARCH_USE_CLOUD:-false} + ELASTICSEARCH_CLOUD_URL: ${ELASTICSEARCH_CLOUD_URL:-YOUR-ELASTICSEARCH_CLOUD_URL} + ELASTICSEARCH_API_KEY: ${ELASTICSEARCH_API_KEY:-YOUR-ELASTICSEARCH_API_KEY} + ELASTICSEARCH_VERIFY_CERTS: ${ELASTICSEARCH_VERIFY_CERTS:-False} + ELASTICSEARCH_CA_CERTS: ${ELASTICSEARCH_CA_CERTS:-} + ELASTICSEARCH_REQUEST_TIMEOUT: ${ELASTICSEARCH_REQUEST_TIMEOUT:-100000} + ELASTICSEARCH_RETRY_ON_TIMEOUT: ${ELASTICSEARCH_RETRY_ON_TIMEOUT:-True} + ELASTICSEARCH_MAX_RETRIES: ${ELASTICSEARCH_MAX_RETRIES:-10} BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} @@ -246,6 +295,7 @@ x-shared-env: &shared-api-worker-env LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm} + LINDORM_QUERY_TIMEOUT: ${LINDORM_QUERY_TIMEOUT:-1} 
OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase} OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881} OCEANBASE_VECTOR_USER: ${OCEANBASE_VECTOR_USER:-root@test} @@ -262,12 +312,28 @@ x-shared-env: &shared-api-worker-env OPENGAUSS_MIN_CONNECTION: ${OPENGAUSS_MIN_CONNECTION:-1} OPENGAUSS_MAX_CONNECTION: ${OPENGAUSS_MAX_CONNECTION:-5} OPENGAUSS_ENABLE_PQ: ${OPENGAUSS_ENABLE_PQ:-false} + HUAWEI_CLOUD_HOSTS: ${HUAWEI_CLOUD_HOSTS:-https://127.0.0.1:9200} + HUAWEI_CLOUD_USER: ${HUAWEI_CLOUD_USER:-admin} + HUAWEI_CLOUD_PASSWORD: ${HUAWEI_CLOUD_PASSWORD:-admin} UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io} UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify} TABLESTORE_ENDPOINT: ${TABLESTORE_ENDPOINT:-https://instance-name.cn-hangzhou.ots.aliyuncs.com} TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name} TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx} TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx} + TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false} + CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-} + CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-} + CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-} + CLICKZETTA_SERVICE: ${CLICKZETTA_SERVICE:-api.clickzetta.com} + CLICKZETTA_WORKSPACE: ${CLICKZETTA_WORKSPACE:-quick_start} + CLICKZETTA_VCLUSTER: ${CLICKZETTA_VCLUSTER:-default_ap} + CLICKZETTA_SCHEMA: ${CLICKZETTA_SCHEMA:-dify} + CLICKZETTA_BATCH_SIZE: ${CLICKZETTA_BATCH_SIZE:-100} + CLICKZETTA_ENABLE_INVERTED_INDEX: ${CLICKZETTA_ENABLE_INVERTED_INDEX:-true} + CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese} + CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart} + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} @@ -276,6 +342,7 @@ x-shared-env: 
&shared-api-worker-env SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true} PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512} CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024} + PLUGIN_BASED_TOKEN_COUNTING_ENABLED: ${PLUGIN_BASED_TOKEN_COUNTING_ENABLED:-false} MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64} UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10} UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100} @@ -285,6 +352,8 @@ x-shared-env: &shared-api-worker-env API_SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} API_SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} WEB_SENTRY_DSN: ${WEB_SENTRY_DSN:-} + PLUGIN_SENTRY_ENABLED: ${PLUGIN_SENTRY_ENABLED:-false} + PLUGIN_SENTRY_DSN: ${PLUGIN_SENTRY_DSN:-} NOTION_INTEGRATION_TYPE: ${NOTION_INTEGRATION_TYPE:-public} NOTION_CLIENT_SECRET: ${NOTION_CLIENT_SECRET:-} NOTION_CLIENT_ID: ${NOTION_CLIENT_ID:-} @@ -299,9 +368,12 @@ x-shared-env: &shared-api-worker-env SMTP_PASSWORD: ${SMTP_PASSWORD:-} SMTP_USE_TLS: ${SMTP_USE_TLS:-true} SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false} + SENDGRID_API_KEY: ${SENDGRID_API_KEY:-} INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} + CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: ${CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES:-5} + OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5} CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} @@ -322,17 +394,28 @@ x-shared-env: &shared-api-worker-env MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} + 
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms} + CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository} + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository} + API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository} + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} + WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false} + WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30} + WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} + RESPECT_XFORWARD_HEADERS_ENABLED: ${RESPECT_XFORWARD_HEADERS_ENABLED:-false} SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} - MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5} + MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} - PGUSER: ${PGUSER:-${DB_USERNAME}} + ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} + MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50} + POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} POSTGRES_DB: 
${POSTGRES_DB:-${DB_DATABASE}} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} @@ -432,7 +515,9 @@ x-shared-env: &shared-api-worker-env ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} - FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} + FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-false} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} @@ -443,9 +528,10 @@ x-shared-env: &shared-api-worker-env PLUGIN_PACKAGE_CACHE_PATH: ${PLUGIN_PACKAGE_CACHE_PATH:-plugin_packages} PLUGIN_MEDIA_CACHE_PATH: ${PLUGIN_MEDIA_CACHE_PATH:-assets} PLUGIN_STORAGE_OSS_BUCKET: ${PLUGIN_STORAGE_OSS_BUCKET:-} - PLUGIN_S3_USE_AWS_MANAGED_IAM: ${PLUGIN_S3_USE_AWS_MANAGED_IAM:-} + PLUGIN_S3_USE_AWS: ${PLUGIN_S3_USE_AWS:-false} + PLUGIN_S3_USE_AWS_MANAGED_IAM: ${PLUGIN_S3_USE_AWS_MANAGED_IAM:-false} PLUGIN_S3_ENDPOINT: ${PLUGIN_S3_ENDPOINT:-} - PLUGIN_S3_USE_PATH_STYLE: ${PLUGIN_S3_USE_PATH_STYLE:-} + PLUGIN_S3_USE_PATH_STYLE: ${PLUGIN_S3_USE_PATH_STYLE:-false} PLUGIN_AWS_ACCESS_KEY: ${PLUGIN_AWS_ACCESS_KEY:-} PLUGIN_AWS_SECRET_KEY: ${PLUGIN_AWS_SECRET_KEY:-} PLUGIN_AWS_REGION: ${PLUGIN_AWS_REGION:-} @@ -454,11 +540,49 @@ x-shared-env: &shared-api-worker-env PLUGIN_TENCENT_COS_SECRET_KEY: ${PLUGIN_TENCENT_COS_SECRET_KEY:-} PLUGIN_TENCENT_COS_SECRET_ID: ${PLUGIN_TENCENT_COS_SECRET_ID:-} PLUGIN_TENCENT_COS_REGION: ${PLUGIN_TENCENT_COS_REGION:-} + PLUGIN_ALIYUN_OSS_REGION: ${PLUGIN_ALIYUN_OSS_REGION:-} + PLUGIN_ALIYUN_OSS_ENDPOINT: ${PLUGIN_ALIYUN_OSS_ENDPOINT:-} + PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID:-} + 
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} + PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} + PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} + ENABLE_OTEL: ${ENABLE_OTEL:-false} + OTLP_TRACE_ENDPOINT: ${OTLP_TRACE_ENDPOINT:-} + OTLP_METRIC_ENDPOINT: ${OTLP_METRIC_ENDPOINT:-} + OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318} + OTLP_API_KEY: ${OTLP_API_KEY:-} + OTEL_EXPORTER_OTLP_PROTOCOL: ${OTEL_EXPORTER_OTLP_PROTOCOL:-} + OTEL_EXPORTER_TYPE: ${OTEL_EXPORTER_TYPE:-otlp} + OTEL_SAMPLING_RATE: ${OTEL_SAMPLING_RATE:-0.1} + OTEL_BATCH_EXPORT_SCHEDULE_DELAY: ${OTEL_BATCH_EXPORT_SCHEDULE_DELAY:-5000} + OTEL_MAX_QUEUE_SIZE: ${OTEL_MAX_QUEUE_SIZE:-2048} + OTEL_MAX_EXPORT_BATCH_SIZE: ${OTEL_MAX_EXPORT_BATCH_SIZE:-512} + OTEL_METRIC_EXPORT_INTERVAL: ${OTEL_METRIC_EXPORT_INTERVAL:-60000} + OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000} + OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000} + ALLOW_EMBED: ${ALLOW_EMBED:-false} + QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} + QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} + QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} + SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true} + SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html} + ENABLE_CLEAN_EMBEDDING_CACHE_TASK: ${ENABLE_CLEAN_EMBEDDING_CACHE_TASK:-true} + ENABLE_CLEAN_UNUSED_DATASETS_TASK: ${ENABLE_CLEAN_UNUSED_DATASETS_TASK:-true} + ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-true} + ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-true} + ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-true} + 
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-true} + ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-true} + ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true} services: # API service api: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.2.1 + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.8.1 restart: always environment: # Use the shared environment variables. @@ -487,7 +611,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.2.1 + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.8.1 restart: always environment: # Use the shared environment variables. @@ -509,10 +633,60 @@ services: - ssrf_proxy_network - default + # worker-gaia service + # The Celery worker-gaia for processing the queue. + worker-gaia: + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.8.1 + restart: always + environment: + # Use the shared environment variables. + <<: *shared-api-worker-env + # Startup mode, 'worker-gaia' starts the Celery worker-gaia for processing the queue. + MODE: worker-gaia + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} + depends_on: + - db + - redis + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/api/storage + networks: + - ssrf_proxy_network + - default + + # worker-dataset service + # The Celery worker-dataset for processing the queue. + worker-dataset: + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.8.1 + restart: always + environment: + # Use the shared environment variables. 
+ <<: *shared-api-worker-env + # Startup mode, 'worker-dataset' starts the Celery worker-dataset for processing the queue. + MODE: worker-dataset + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} + depends_on: + - db + - redis + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/api/storage + networks: + - ssrf_proxy_network + - default + # beat service # The Celery worker for schedule tasks. beat: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.2.1 + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.8.1 restart: always environment: # Use the shared environment variables. @@ -536,7 +710,7 @@ services: # Frontend web application. 
web: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-web:1.2.0 + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-web:1.8.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -545,6 +719,8 @@ services: NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} CSP_WHITELIST: ${CSP_WHITELIST:-} + ALLOW_EMBED: ${ALLOW_EMBED:-false} + ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} @@ -553,14 +729,18 @@ services: LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} - MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5} + MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} + MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50} + ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true} + ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true} + ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true} # The postgres database. db: image: ccr.ccs.tencentyun.com/yfgaia/postgres:15-alpine restart: always environment: - PGUSER: ${PGUSER:-postgres} + POSTGRES_USER: ${POSTGRES_USER:-postgres} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} POSTGRES_DB: ${POSTGRES_DB:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} @@ -573,10 +753,20 @@ services: volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: - test: ['CMD', 'pg_isready'] + test: + [ + "CMD", + "pg_isready", + "-h", + "db", + "-U", + "${POSTGRES_USER:-postgres}", + "-d", + "${POSTGRES_DB:-dify}", + ] interval: 1s timeout: 3s - retries: 30 + retries: 60 ports: - 5432:5432 @@ -592,13 +782,17 @@ services: # Set the redis password when startup redis server. 
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} healthcheck: - test: ['CMD', 'redis-cli', 'ping'] + test: + [ + "CMD-SHELL", + "redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG", + ] ports: - 6379:6379 # The DifySandbox sandbox: - image: ccr.ccs.tencentyun.com/yfgaia/dify-sandbox:0.2.11 + image: ccr.ccs.tencentyun.com/yfgaia/dify-sandbox:0.2.12 restart: always environment: # The DifySandbox configurations @@ -611,6 +805,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf @@ -621,7 +816,7 @@ services: # plugin daemon plugin_daemon: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plugin-daemon:0.0.7-local + image: langgenius/dify-plugin-daemon:0.2.0-local restart: always environment: # Use the shared environment variables. 
@@ -636,9 +831,11 @@ services: PLUGIN_REMOTE_INSTALLING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0} PLUGIN_REMOTE_INSTALLING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_WORKING_PATH: ${PLUGIN_WORKING_PATH:-/app/storage/cwd} - FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} + FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-false} PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} @@ -647,22 +844,36 @@ services: PLUGIN_MEDIA_CACHE_PATH: ${PLUGIN_MEDIA_CACHE_PATH:-assets} PLUGIN_STORAGE_OSS_BUCKET: ${PLUGIN_STORAGE_OSS_BUCKET:-} S3_USE_AWS_MANAGED_IAM: ${PLUGIN_S3_USE_AWS_MANAGED_IAM:-false} + S3_USE_AWS: ${PLUGIN_S3_USE_AWS:-false} S3_ENDPOINT: ${PLUGIN_S3_ENDPOINT:-} S3_USE_PATH_STYLE: ${PLUGIN_S3_USE_PATH_STYLE:-false} AWS_ACCESS_KEY: ${PLUGIN_AWS_ACCESS_KEY:-} - PAWS_SECRET_KEY: ${PLUGIN_AWS_SECRET_KEY:-} + AWS_SECRET_KEY: ${PLUGIN_AWS_SECRET_KEY:-} AWS_REGION: ${PLUGIN_AWS_REGION:-} AZURE_BLOB_STORAGE_CONNECTION_STRING: ${PLUGIN_AZURE_BLOB_STORAGE_CONNECTION_STRING:-} AZURE_BLOB_STORAGE_CONTAINER_NAME: ${PLUGIN_AZURE_BLOB_STORAGE_CONTAINER_NAME:-} TENCENT_COS_SECRET_KEY: ${PLUGIN_TENCENT_COS_SECRET_KEY:-} TENCENT_COS_SECRET_ID: ${PLUGIN_TENCENT_COS_SECRET_ID:-} TENCENT_COS_REGION: ${PLUGIN_TENCENT_COS_REGION:-} + ALIYUN_OSS_REGION: ${PLUGIN_ALIYUN_OSS_REGION:-} + ALIYUN_OSS_ENDPOINT: ${PLUGIN_ALIYUN_OSS_ENDPOINT:-} + ALIYUN_OSS_ACCESS_KEY_ID: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID:-} + ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} + ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} + ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + 
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} + SENTRY_ENABLED: ${PLUGIN_SENTRY_ENABLED:-false} + SENTRY_DSN: ${PLUGIN_SENTRY_DSN:-} ports: - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" volumes: - ./volumes/plugin_daemon:/app/storage depends_on: - - db + db: + condition: service_healthy # ssrf_proxy server # for more information, please refer to @@ -699,8 +910,8 @@ services: - ./certbot/update-cert.template.txt:/update-cert.template.txt - ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh environment: - - CERTBOT_EMAIL=${CERTBOT_EMAIL} - - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} + - CERTBOT_EMAIL=${CERTBOT_EMAIL:-your_email@example.com} + - CERTBOT_DOMAIN=${CERTBOT_DOMAIN:-your_domain.com} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} entrypoint: [ '/docker-entrypoint.sh' ] command: [ 'tail', '-f', '/dev/null' ] @@ -814,6 +1025,30 @@ services: start_period: 30s timeout: 10s + # get image from https://www.vastdata.com.cn/ + vastbase: + image: vastdata/vastbase-vector + profiles: + - vastbase + restart: always + environment: + - VB_DBCOMPATIBILITY=PG + - VB_DB=dify + - VB_USERNAME=dify + - VB_PASSWORD=Difyai123456 + ports: + - "5434:5432" + volumes: + - ./vastbase/lic:/home/vastbase/vastbase/lic + - ./vastbase/data:/home/vastbase/data + - ./vastbase/backup:/home/vastbase/backup + - ./vastbase/backup_log:/home/vastbase/backup_log + healthcheck: + test: ["CMD", "pg_isready"] + interval: 1s + timeout: 3s + retries: 30 + # The pgvector vector database. pgvector: image: pgvector/pgvector:pg16 @@ -1038,6 +1273,18 @@ services: ports: - ${OPENGAUSS_PORT:-6600}:${OPENGAUSS_PORT:-6600} + # Matrixone vector store. 
+ matrixone: + hostname: matrixone + image: matrixorigin/matrixone:2.1.1 + profiles: + - matrixone + restart: always + volumes: + - ./volumes/matrixone/data:/mo-data + ports: + - ${MATRIXONE_PORT:-6001}:${MATRIXONE_PORT:-6001} + # MyScale vector database myscale: container_name: myscale @@ -1128,8 +1375,10 @@ services: # Extend - admin-web admin-web: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-web:1.2.1 + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-web:1.8.1 restart: always + ports: + - '8081:8081' depends_on: - admin-server command: [ 'nginx-debug', '-g', 'daemon off;' ] @@ -1138,8 +1387,12 @@ services: # Extend - admin-server admin-server: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-server:1.2.1 + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-server:1.8.1 restart: always + environment: + # JWT signing key must match API's SECRET_KEY for token compatibility + JWT_SIGNING_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} + SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} ports: - '8888:8888' depends_on: @@ -1158,7 +1411,7 @@ services: # Extend - sandbox-full sandbox-full: - image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-sandbox-full:0.0.7 + image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-sandbox-full:1.2.2 restart: always environment: # The DifySandbox configurations diff --git a/docker/nginx/conf.d/default.conf.template b/docker/nginx/conf.d/default.conf.template index a8688a096..441486176 100644 --- a/docker/nginx/conf.d/default.conf.template +++ b/docker/nginx/conf.d/default.conf.template @@ -3,7 +3,6 @@ server { listen ${NGINX_PORT}; server_name ${NGINX_SERVER_NAME}; - # 管理中心反向代理配置 location = /admin { return 301 /admin/; diff --git a/web/.env.example b/web/.env.example index 0d5d9539e..82ec6fcf2 100644 --- a/web/.env.example +++ b/web/.env.example @@ -70,3 +70,6 @@ NEXT_PUBLIC_DEFAULT_DOMAIN= # Auth2 Logout URL(二开新增配置) NEXT_PUBLIC_AUTH0_LOGOUT_URL= +# 后台地址 
+NEXT_PUBLIC_ADMIN_API_URL=http://localhost:5001/ + diff --git a/web/Dockerfile b/web/Dockerfile index 1b3a2c5ed..a60e5594b 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -6,7 +6,7 @@ LABEL maintainer="takatost@gmail.com" # RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories # if you located in China, you can use taobao registry to speed up -# RUN npm config set registry https://registry.npmmirror.com +RUN npm config set registry https://registry.npmmirror.com RUN apk add --no-cache tzdata RUN corepack enable @@ -26,7 +26,7 @@ COPY pnpm-lock.yaml . # Use packageManager from package.json RUN corepack install -#RUN pnpm install --frozen-lockfile +RUN pnpm install --no-frozen-lockfile # build resources FROM base AS builder diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/user_overview_extend/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/user_overview_extend/page.tsx index fefdbcac4..9c33589c5 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/user_overview_extend/page.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/user_overview_extend/page.tsx @@ -3,12 +3,15 @@ import React, { useState } from 'react' import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import { useTranslation } from 'react-i18next' +import type { PeriodParams } from '@/app/components/app/overview/app-chart' import { - PeriodParams, + AvgSessionInteractions, + AvgUserInteractions, + ConversationsChart, + CostChart, WorkflowCostChart, WorkflowMessagesChart, -} from '@/app/components/app/overview/appChart' -import { AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart } from '@/app/components/app/overview/appChart' +} from '@/app/components/app/overview/app-chart' import type { Item } from '@/app/components/base/select' import { SimpleSelect } from '@/app/components/base/select' @@ -65,7 +68,7 @@ const UserOverView = ({ params: { appId } }: 
UserOverViewProps) => { return (
-
+
{t('appOverview.analysis.title')} ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))} @@ -77,18 +80,18 @@ const UserOverView = ({ params: { appId } }: UserOverViewProps) => { {model === 'workflow' && ( <> {/* Extend: Workflow personal detection error */} -
+
-
+
)} {model !== 'workflow' && ( <> -
+
{model !== 'completion' && (isChatApp ? ( @@ -98,7 +101,7 @@ const UserOverView = ({ params: { appId } }: UserOverViewProps) => { ))}
-
+
diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 70a45a4bb..f40c7b0ff 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -47,7 +47,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [name, setName] = useState('') const [description, setDescription] = useState('') - const [isAppTypeExpanded, setIsAppTypeExpanded] = useState(false) + const [isAppTypeExpanded, setIsAppTypeExpanded] = useState(true) const { plan, enableBilling } = useProviderContext() const isAppsFull = (enableBilling && plan.usage.buildApps >= plan.total.buildApps) diff --git a/web/app/components/base/action-button/index.css b/web/app/components/base/action-button/index.css index 3c1a10b86..8eed53774 100644 --- a/web/app/components/base/action-button/index.css +++ b/web/app/components/base/action-button/index.css @@ -30,6 +30,10 @@ @apply p-0 w-4 h-4 rounded } + .action-btn-sm { + @apply px-2 py-1 h-7 min-w-[60px] rounded-md text-xs + } + .action-btn.action-btn-active { @apply text-text-accent bg-state-accent-active hover:bg-state-accent-active-alt } diff --git a/web/app/components/base/action-button/index.tsx b/web/app/components/base/action-button/index.tsx index c90d1a8de..ac06a8698 100644 --- a/web/app/components/base/action-button/index.tsx +++ b/web/app/components/base/action-button/index.tsx @@ -17,6 +17,7 @@ const actionButtonVariants = cva( variants: { size: { xs: 'action-btn-xs', + sm: 'action-btn-sm', m: 'action-btn-m', l: 'action-btn-l', xl: 'action-btn-xl', @@ -29,7 +30,7 @@ const actionButtonVariants = cva( ) export type ActionButtonProps = { - size?: 'xs' | 's' | 'm' | 'l' | 'xl' + size?: 'xs' | 'sm' | 'm' | 'l' | 'xl' state?: ActionButtonState styleCss?: CSSProperties } & React.ButtonHTMLAttributes & VariantProps diff --git 
a/web/app/components/header/account-money-extend/index.tsx b/web/app/components/header/account-money-extend/index.tsx index 7a8f0552a..3d8c4da4f 100644 --- a/web/app/components/header/account-money-extend/index.tsx +++ b/web/app/components/header/account-money-extend/index.tsx @@ -38,21 +38,21 @@ const AccountMoneyExtend = () => { // 根据警示级别设置颜色 const alertColorClass = isRedAlert - ? 'text-red-500' + ? 'text-text-destructive' : isYellowAlert - ? 'text-yellow-500' - : 'text-gray-700' + ? 'text-text-warning' + : 'text-text-secondary' return (
-
+
额度
-
- 已用: +
+ 已用: { > ¥{usedRMB} - / - + / + ¥{totalRMB.replace(/\B(?=(\d{3})+(?!\d))/g, ',')}
diff --git a/web/app/components/header/index.tsx b/web/app/components/header/index.tsx index a8d46e805..e947c0406 100644 --- a/web/app/components/header/index.tsx +++ b/web/app/components/header/index.tsx @@ -68,8 +68,8 @@ const Header = () => {
+ {/* // 二开部分 - 额度限制 */}
- {/* // 二开部分 - 额度限制 */}
{!isCurrentWorkspaceDatasetOperator && } {!isCurrentWorkspaceDatasetOperator && } @@ -97,8 +97,8 @@ const Header = () => { {enableBilling ? : } + {/* // 二开部分 - 额度限制 */}
- {/* // 二开部分 - 额度限制 */}
{!isCurrentWorkspaceDatasetOperator && } {!isCurrentWorkspaceDatasetOperator && } diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index bf8a14fb4..d13c8ed18 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -7,11 +7,13 @@ import { RiErrorWarningFill, } from '@remixicon/react' import { useBoolean } from 'ahooks' -import { useSearchParams } from 'next/navigation' +import { useRouter, useSearchParams } from 'next/navigation' // Extend: Batch import import TabHeader from '../../base/tab-header' import MenuDropdown from './menu-dropdown' import RunBatch from './run-batch' import ResDownload from './run-batch/res-download' +import BatchProgress from './run-batch/batch-progress' // Extend: Batch import +import Pagination from '@/app/components/base/pagination' // Extend: Batch import import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import RunOnce from '@/app/components/share/text-generation/run-once' import { fetchSavedMessage as doFetchSavedMessage, removeMessage, saveMessage } from '@/service/share' @@ -37,6 +39,7 @@ import { Resolution, TransferMethod } from '@/types/app' import { useAppFavicon } from '@/hooks/use-app-favicon' import DifyLogo from '@/app/components/base/logo/dify-logo' import cn from '@/utils/classnames' +import { downloadBatchApi, fetchBatchWorkflowListApi, processExcelUploadApi } from '@/service/web-extend' // Extend: Batch import import { AccessMode } from '@/models/access-control' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' @@ -81,6 +84,7 @@ const TextGeneration: FC = ({ const mode = searchParams.get('mode') || 'create' const [currentTab, setCurrentTab] = useState(['create', 'batch'].includes(mode) ? 
mode : 'create') + const router = useRouter() // extend // Notice this situation isCallBatchAPI but not in batch tab const [isCallBatchAPI, setIsCallBatchAPI] = useState(false) const isInBatchTab = currentTab === 'batch' @@ -148,6 +152,84 @@ const TextGeneration: FC = ({ doSetAllTaskList(taskList) allTaskListRef.current = taskList } + // Extend: Start Batch import + // 每页5个任务 + const batchJobsLimit = 5 + // 分页状态 + const [currentPage, setCurrentPage] = useState(1) + // 批量处理相关状态 + const [batchJobs, setBatchJobs] = useState>([]) + + const [totalBatchJobs, setTotalBatchJobs] = useState(0) + const [isLoadingBatchJobs, setIsLoadingBatchJobs] = useState(false) + + // 从后端获取批量工作流列表 + const loadBatchWorkflows = async () => { + if (!appId || currentTab !== 'batch') return + + setIsLoadingBatchJobs(true) + try { + const result = await fetchBatchWorkflowListApi(installedAppInfo?.id, currentPage, batchJobsLimit) + if (result) { + // 转换数据格式以兼容现有组件 + const convertedJobs = result.items.map(item => ({ + id: item.id, + fileName: item.file_name, + createdAt: item.created_at, + status: item.status, + totalRows: item.total_rows, + processedRows: item.processed_rows, + error: item.error, // 添加错误信息 + })) + setBatchJobs(convertedJobs) + setTotalBatchJobs(result.total) + } + } + catch (error) { + console.error('Failed to load batch workflows:', error) + } + finally { + setIsLoadingBatchJobs(false) + } + } + + // 加载批量工作流列表 + useEffect(() => { + loadBatchWorkflows() + }, [appId, currentTab, currentPage, installedAppInfo?.id]) + + // 自动刷新批量工作流列表(每3秒) + useEffect(() => { + if (currentTab !== 'batch' || batchJobs.length === 0) + return + + // 检查是否有进行中的任务 + const hasActiveJobs = batchJobs.some(job => + job.status === 'pending' || job.status === 'processing', + ) + + if (!hasActiveJobs) + return + + const refreshInterval = setInterval(() => { + loadBatchWorkflows() + }, 3000) // 每3秒刷新一次 + + return () => clearInterval(refreshInterval) + }, [currentTab, batchJobs, appId, installedAppInfo?.id, 
currentPage]) + + // 计算分页数据 - 现在数据已经是从后端分页获取的,不需要再切片 + const paginatedBatchJobs = batchJobs + // Extend: Stop Batch import + const pendingTaskList = allTaskList.filter(task => task.status === TaskStatus.pending) const noPendingTask = pendingTaskList.length === 0 const showTaskList = allTaskList.filter(task => task.status !== TaskStatus.pending) @@ -319,8 +401,76 @@ const TextGeneration: FC = ({ setControlStopResponding(Date.now()) // eslint-disable-next-line ts/no-use-before-define - showResultPanel() + doShowResultPanel() // Extend: Batch import } + // Extend: Start Batch import + // 处理批量上传 + const handleBatchUpload = async (originalFile: File, data: string[][], originalFileName?: string) => { + if (!checkBatchInputs(data)) + return + + try { + // 创建key-name映射 + const keyNameMapping: Record = {} + promptConfig?.prompt_variables.forEach((variable) => { + keyNameMapping[variable.name] = variable.key + }) + + // 直接使用原始文件 + const result = await processExcelUploadApi(originalFile, installedAppInfo?.id || '', appId, keyNameMapping) + if (result === null) { + // API调用失败,错误信息已经在processExcelUploadApi中显示 + return + } + // 添加到批量任务列表 - 最新的任务显示在顶部 + setBatchJobs(prev => [{ + id: result.id, + fileName: originalFileName || originalFile.name, + createdAt: new Date().toISOString(), + status: 'pending', + totalRows: 0, + processedRows: 0, + error: undefined, + }, ...prev]) + + // 显示结果面板 + // eslint-disable-next-line ts/no-use-before-define + doShowResultPanel() + notify({ type: 'success', message: t('extend.batchWorkflow.batchUploadSuccess') }) + } + catch (error) { + console.error('批量上传失败:', error) + notify({ type: 'error', message: t('extend.batchWorkflow.batchUploadFailed') }) + } + } + + // 下载批量处理结果 + const handleBatchDownload = async (batchId: string) => { + try { + const blob = await downloadBatchApi(batchId) + if (blob) { + const url = window.URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = url + a.download = `batch_results_${batchId}.csv` + 
document.body.appendChild(a) + a.click() + window.URL.revokeObjectURL(url) + document.body.removeChild(a) + } + } + catch (error) { + console.error('下载失败:', error) + notify({ type: 'error', message: t('extend.batchWorkflow.downloadFailed') }) + } + } + // 处理重试成功回调 + const handleRetrySuccess = () => { + // 重试成功后,重新加载批量工作流列表 + loadBatchWorkflows() + console.log('批量任务重试成功,已刷新列表') + } + // Extend: Stop Batch import const handleCompleted = (completionRes: string, taskId?: number, isSuccess?: boolean) => { const allTaskListLatest = getLatestTaskList() const batchCompletionResLatest = getBatchCompletionRes() @@ -479,15 +629,72 @@ const TextGeneration: FC = ({
0)) && 'pt-0', !isPC && 'p-0 pb-2', )}> - {!isCallBatchAPI ? renderRes() : renderBatchRes()} - {!noPendingTask && ( + {!isCallBatchAPI && !(isInBatchTab && batchJobs.length > 0) ? renderRes() : ( + <> + {isCallBatchAPI && renderBatchRes()} + {isInBatchTab && batchJobs.length > 0 && ( +
+ {/* 数据保留提示 */} +
+
+ {t('extend.batchWorkflow.dataRetentionNotice')}: {t('extend.batchWorkflow.dataRetentionDescription')} +
+
+ + {/* //extend start 批量任务列表 */} +
+ {isLoadingBatchJobs ? ( +
+
+
+ ) : paginatedBatchJobs.length > 0 ? ( + paginatedBatchJobs.map(job => ( + handleBatchDownload(job.id)} + onRetrySuccess={handleRetrySuccess} + /> + )) + ) : ( +
+ 暂无批量处理任务 +
+ )} +
+ {/* // extend stop 批量任务列表 */} + + {/* 分页控件 */} + {totalBatchJobs > batchJobsLimit && ( +
+ { + setCurrentPage(page) + // extend + }} + total={totalBatchJobs} + limit={batchJobsLimit} + className="w-auto" + /> +
+ )} +
+ )} + + )} + {!noPendingTask && isCallBatchAPI && (
)} + { /* // Extend: Stop Batch import */ }
{isCallBatchAPI && allFailedTaskList.length > 0 && (
@@ -556,7 +763,26 @@ const TextGeneration: FC = ({ : []), ]} value={currentTab} - onChange={setCurrentTab} + onChange={(tab) => { + // Extend: Start Batch import + + // 当从批量模式切换回单次运行时,重置批量相关状态 + if (currentTab === 'batch' && tab === 'create') { + setIsCallBatchAPI(false) + setAllTaskList([]) + setCurrGroupNum(0) + // 只清空显示状态,保留localStorage以便再次切换回批量时恢复 + setBatchJobs([]) + } + + // 当从单次运行切换到批量模式时,清理单次运行的结果 + if (currentTab === 'create' && tab === 'batch') { + setResultExisted(false) + setControlStopResponding(Date.now()) // 停止可能正在进行的单次运行 + } + setCurrentTab(tab) + }} + // Extend: Stop Batch import />
{/* form */} @@ -581,7 +807,10 @@ const TextGeneration: FC = ({
{currentTab === 'saved' && ( @@ -633,7 +862,7 @@ const TextGeneration: FC = ({ if (isShowResultPanel) hideResultPanel() else - showResultPanel() + doShowResultPanel()// Extend: Batch import }} >
diff --git a/web/app/components/share/text-generation/run-batch/batch-progress/index.tsx b/web/app/components/share/text-generation/run-batch/batch-progress/index.tsx new file mode 100644 index 000000000..7ecfbe3c3 --- /dev/null +++ b/web/app/components/share/text-generation/run-batch/batch-progress/index.tsx @@ -0,0 +1,313 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiCheckLine, + RiErrorWarningLine, + RiLoader2Line, + RiPauseLine, + RiPlayLargeLine, + RiRefreshLine, + RiStopLine, +} from '@remixicon/react' +import { resumeBatchApi, retryFailedTasksApi, stopBatchApi } from '@/service/web-extend' // extend: 批量运行工单 +import type { BatchStatus } from '@/utils/batch-progress-manager' // extend: 批量运行工单 +import ActionButton from '@/app/components/base/action-button' + +import cn from '@/utils/classnames' + +export type BatchProgressProps = { + batchId: string + fileName: string + workflowId?: string + jobData: { + id: string + fileName: string + createdAt: string + status: string + totalRows: number + processedRows: number + error?: string + } + onDownload: () => void + onRetrySuccess?: () => void +} + +const BatchProgress: FC = ({ + batchId, + fileName, + workflowId, + jobData, + onDownload, + onRetrySuccess, +}) => { + const { t } = useTranslation() + + const [isLoading, setIsLoading] = useState(false) + + // 停止批量处理 + const handleStop = async () => { + setIsLoading(true) + try { + const success = await stopBatchApi(batchId) + if (success) { + // 通知父组件刷新列表 + onRetrySuccess?.() + } + } + catch (error) { + console.error('Failed to stop batch:', error) + } + finally { + setIsLoading(false) + } + } + + // 恢复批量处理 + const handleResume = async () => { + setIsLoading(true) + try { + const success = await resumeBatchApi(batchId) + if (success) { + // 通知父组件刷新列表 + onRetrySuccess?.() + } + } + catch (error) { + console.error('Failed to resume batch:', 
error) + } + finally { + setIsLoading(false) + } + } + + // 重试失败任务(仅重试失败的任务,保留已完成的任务) + const handleRetry = async () => { + setIsLoading(true) + try { + const success = await retryFailedTasksApi(batchId) + if (success) { + // 通知父组件刷新列表 + onRetrySuccess?.() + } + } + catch (error) { + console.error('Failed to retry failed tasks:', error) + } + finally { + setIsLoading(false) + } + } + + const getStatusText = (status: BatchStatus) => { + switch (status) { + case 'pending': + return t('extend.batchWorkflow.pending') + case 'processing': + return t('extend.batchWorkflow.processing') + case 'completed': + return t('extend.batchWorkflow.completed') + case 'failed': + return t('extend.batchWorkflow.failed') + case 'stopped': + return t('extend.batchWorkflow.stopped') + default: + return t('extend.batchWorkflow.pending') + } + } + + const getStatusColor = (status: BatchStatus) => { + switch (status) { + case 'pending': + return 'text-gray-500' // Extend: 批量运行工单 + case 'processing': + return 'text-blue-700' // Extend: 批量运行工单 + case 'completed': + return 'text-green-500' + case 'failed': + return 'text-red-500' + case 'stopped': + return 'text-yellow-500' + default: + return 'text-gray-500' // Extend: 批量运行工单 + } + } + + const formatDate = (dateString: string) => { + if (!dateString) return '-' + + const date = new Date(dateString) + // 检查日期是否有效 + if (isNaN(date.getTime())) + return '-' + + return date.toLocaleString('zh-CN', { + year: 'numeric', + month: '2-digit', + day: '2-digit', + hour: '2-digit', + minute: '2-digit', + }) + } + + const currentTime = new Date().toLocaleString('zh-CN', { + year: 'numeric', + month: '2-digit', + day: '2-digit', + hour: '2-digit', + minute: '2-digit', + }) + + // 计算进度 + const progress = jobData.totalRows > 0 ? 
(jobData.processedRows / jobData.totalRows) * 100 : 0 + const status = jobData.status as BatchStatus + const failed_count = 0 // 从列表API没有这个字段,如果需要可以后续添加 + + const getBorderColor = (status: BatchStatus) => { + switch (status) { + case 'pending': + return 'border-gray-300' + case 'processing': + return 'border-blue-500' + case 'completed': + return 'border-green-500' + case 'failed': + return 'border-red-500' + case 'stopped': + return 'border-yellow-500' + default: + return 'border-gray-300' + } + } + + return ( +
+ {/* 统一的批量任务信息框 */} +
+ {/* 文件信息 */} +
+
+
{t('extend.batchWorkflow.uploadedFileName')}
+
{t('extend.batchWorkflow.uploadTime')}
+
+
+
+
{fileName}
+
{formatDate(jobData.createdAt)}
+
+
+
+ + {/* 进度条 */} +
+
+
+ {status === 'processing' && } + {status === 'completed' && } + {status === 'failed' && } + {status === 'pending' && } + {status === 'stopped' && } + + {getStatusText(status)} + +
+ + {isNaN(progress) ? '0' : Math.round(progress)}% + +
+ + {/* 进度条可视化 */} +
+
+
+ + {/* 详细进度信息 */} + {jobData.totalRows > 0 && ( +
+ {t('extend.batchWorkflow.processed', { + processed: jobData.processedRows || 0, + total: jobData.totalRows || 0, + })} +
+ )} + + {/* 错误信息显示 */} + {jobData.error && status === 'failed' && ( +
+
+ +
+
+ {t('extend.batchWorkflow.errorOccurred')} +
+
+ {jobData.error} +
+
+
+
+ )} + +
+ + {/* 操作按钮区域 */} +
+
+ {/* 控制按钮 */} + {(status === 'processing' || status === 'pending') && ( + + {isLoading ? ( + + ) : ( + + )} + {t('extend.batchWorkflow.stop')} + + )} + {status === 'stopped' && ( + + {isLoading ? ( + + ) : ( + + )} + {t('extend.batchWorkflow.resume')} + + )} + {(status === 'failed') && ( + + {isLoading ? ( + + ) : ( + + )} + {t('extend.batchWorkflow.retry')} + + )} +
+ +
+ {/* 下载按钮 */} + {(status === 'failed' || status === 'completed' || (status === 'processing' && progress >= 100)) && ( + + {t('extend.batchWorkflow.download')} + + )} +
+
+
+
+ ) +} + +export default React.memo(BatchProgress) diff --git a/web/app/components/share/text-generation/run-batch/csv-download/index.tsx b/web/app/components/share/text-generation/run-batch/csv-download/index.tsx index 54a5a6d52..fc69a35ef 100644 --- a/web/app/components/share/text-generation/run-batch/csv-download/index.tsx +++ b/web/app/components/share/text-generation/run-batch/csv-download/index.tsx @@ -46,23 +46,31 @@ const CSVDownload: FC = ({
- -
- - {t('share.generation.downloadTemplate')} -
-
+ + {/* Extend: start 聊天批量处理 */} +
+ +
+ + {t('share.generation.downloadTemplate')} +
+
+ + {t('extend.batchWorkflow.willUseBatchProcessing')} + +
+ {/* Extend: stop 聊天批量处理 */}
) diff --git a/web/app/components/share/text-generation/run-batch/csv-reader/index.tsx b/web/app/components/share/text-generation/run-batch/csv-reader/index.tsx index c26f78cca..d38168e4f 100644 --- a/web/app/components/share/text-generation/run-batch/csv-reader/index.tsx +++ b/web/app/components/share/text-generation/run-batch/csv-reader/index.tsx @@ -13,11 +13,11 @@ import cn from '@/utils/classnames' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' export type Props = { - onParsed: (data: string[][]) => void + onParsed: (data: string[][], originalFile?: File) => void } // 二开部分 - Begin 自定义CSVReader type CCProps = { - onUploadAccepted: (results: any) => void + onUploadAccepted: (results: any, file?: File) => void onDragOver: (event: DragEvent) => void onDragLeave: (event: DragEvent) => void children: (props: any) => React.ReactElement @@ -57,7 +57,6 @@ const CustomCSVReader: React.FC = ({ encoding = 'gbk' // Extend stop: 处理可能的误判,将 ISO-8859-2 视为 GBK - console.log('encoding: ', encoding) // 重新用检测到的编码读取文件内容 const correctReader = new FileReader() @@ -66,8 +65,9 @@ const CustomCSVReader: React.FC = ({ // 使用 PapaParse 解析 CSV 文件 Papa.parse(text, { - complete: (results) => { - onUploadAccepted(results) + // eslint-disable-next-line sonarjs/no-nested-functions + complete: (results: any) => { + onUploadAccepted(results, file) }, }) } @@ -144,8 +144,9 @@ const CSVReader: FC = ({ const [zoneHover, setZoneHover] = useState(false) return ( { - onParsed(results.data) + onUploadAccepted={(results: any, file?: File) => { + console.log('CSV Reader - 文件上传:', file ? 
file.name : 'no file') + onParsed(results.data, file) setZoneHover(false) }} onDragOver={(event: DragEvent) => { diff --git a/web/app/components/share/text-generation/run-batch/index.tsx b/web/app/components/share/text-generation/run-batch/index.tsx index eaaa31f4b..4f4adbb9e 100644 --- a/web/app/components/share/text-generation/run-batch/index.tsx +++ b/web/app/components/share/text-generation/run-batch/index.tsx @@ -14,12 +14,16 @@ import cn from '@/utils/classnames' export type IRunBatchProps = { vars: { name: string }[] onSend: (data: string[][]) => void + onBatchSend?: (originalFile: File, data: string[][], fileName?: string) => void // Extend: Batch import isAllFinished: boolean + isInstalledApp?: boolean // Extend: Batch import + installedAppInfo?: any // Extend: Batch import } const RunBatch: FC = ({ vars, onSend, + onBatchSend, // Extend: Batch import isAllFinished, }) => { const { t } = useTranslation() @@ -28,16 +32,80 @@ const RunBatch: FC = ({ const [csvData, setCsvData] = React.useState([]) const [isParsed, setIsParsed] = React.useState(false) - const handleParsed = (data: string[][]) => { + // Extend: Start Batch import + const [isUploading, setIsUploading] = React.useState(false) + const [fileName, setFileName] = React.useState('') + const [originalFile, setOriginalFile] = React.useState(null) + const [isRecentlyClicked, setIsRecentlyClicked] = React.useState(false) + + const handleParsed = (data: string[][], originalFile?: File) => { + console.log('handleParsed 被调用, originalFile:', originalFile ? originalFile.name : 'undefined') setCsvData(data) - // console.log(data) setIsParsed(true) + if (originalFile) { + setFileName(originalFile.name) + setOriginalFile(originalFile) + console.log('originalFile 已设置:', originalFile.name) + } + else { + console.warn('⚠️ originalFile 未传递!') + } } - const handleSend = () => { - onSend(csvData) + const handleSend = async () => { + console.log('=== 批量运行调试信息 ===') + console.log('csvData:', csvData ? 
csvData.length : 'null') + console.log('originalFile:', originalFile ? originalFile.name : 'null') + console.log('onBatchSend:', onBatchSend ? '已定义' : '未定义') + console.log('isRecentlyClicked:', isRecentlyClicked) + + if (!csvData || csvData.length === 0 || !originalFile || isRecentlyClicked) { + console.log('提前返回,原因:', { + noCsvData: !csvData || csvData.length === 0, + noOriginalFile: !originalFile, + isRecentlyClicked, + }) + return + } + + // 设置防重复点击状态 + setIsRecentlyClicked(true) + + // 3秒后允许再次点击 + setTimeout(() => { + setIsRecentlyClicked(false) + }, 3000) + + const dataRows = csvData.slice(1).filter(row => !row.every(cell => cell === '')) + const rowCount = dataRows.length + + console.log('有效数据行数:', rowCount) + console.log('判断条件: rowCount > 10 && onBatchSend =', rowCount > 10, '&&', !!onBatchSend, '=', rowCount > 10 && !!onBatchSend) + + // 如果超过10行,使用批量处理 + if (rowCount > 10 && onBatchSend) { + console.log('✅ 使用admin后台批量处理') + setIsUploading(true) + try { + await onBatchSend(originalFile, csvData, fileName) + } + catch (error) { + console.error('批量处理失败:', error) + } + finally { + setIsUploading(false) + } + } + else { + console.log('❌ 使用旧的前端处理逻辑') + onSend(csvData) + } } - const Icon = isAllFinished ? RiPlayLargeLine : RiLoader2Line + + const Icon = isAllFinished && !isUploading ? RiPlayLargeLine : RiLoader2Line + const isDisabled = !isParsed || (!isAllFinished && !isUploading) || isRecentlyClicked + + // Extend: Start Batch import return (
@@ -47,7 +115,7 @@ const RunBatch: FC = ({ variant="primary" className={cn('mt-4 pl-3 pr-4', !isPC && 'grow')} onClick={handleSend} - disabled={!isParsed || !isAllFinished} + disabled={isDisabled} >