diff --git a/module/ai-api/schema.go b/module/ai-api/schema.go index 5429182a..083f2836 100644 --- a/module/ai-api/schema.go +++ b/module/ai-api/schema.go @@ -143,17 +143,58 @@ func genMessagesSchema() *openapi3.Schema { func genResponseSchema() *openapi3.Schema { result := openapi3.NewObjectSchema() result.Description = "Response from the server" - result.WithPropertyRef("message", messageSchemaRef) - openapi3.NewIntegerSchema() - result.WithProperty("code", openapi3.NewIntegerSchema()) - result.WithProperty("error", openapi3.NewStringSchema()) - result.WithProperty("finish_reason", openapi3.NewStringSchema().WithEnum( + + // 创建 choices 数组 + choicesSchema := openapi3.NewArraySchema() + choiceItemSchema := openapi3.NewObjectSchema() + + // choice 中的 message 字段 + choiceItemSchema.WithPropertyRef("message", messageSchemaRef) + + // finish_reason 字段 + finishReasonSchema := openapi3.NewStringSchema().WithEnum( "stop", "length", "function_call", "content_filter", "null", - )) - + ) + choiceItemSchema.WithProperty("finish_reason", finishReasonSchema) + + // index 字段 + choiceItemSchema.WithProperty("index", openapi3.NewIntegerSchema()) + + // logprobs 字段,可以为 null + choiceItemSchema.WithProperty("logprobs", openapi3.NewSchema().WithNullable()) + + choicesSchema.Items = &openapi3.SchemaRef{Value: choiceItemSchema} + result.WithProperty("choices", choicesSchema) + + // object 字段 + result.WithProperty("object", openapi3.NewStringSchema().WithEnum("chat.completion")) + + // usage 字段 + usageSchema := openapi3.NewObjectSchema() + usageSchema.WithProperty("prompt_tokens", openapi3.NewIntegerSchema()) + usageSchema.WithProperty("completion_tokens", openapi3.NewIntegerSchema()) + usageSchema.WithProperty("total_tokens", openapi3.NewIntegerSchema()) + + // prompt_tokens_details 字段 + promptTokensDetailsSchema := openapi3.NewObjectSchema() + promptTokensDetailsSchema.WithProperty("cached_tokens", openapi3.NewIntegerSchema()) + usageSchema.WithProperty("prompt_tokens_details", promptTokensDetailsSchema) + + result.WithProperty("usage", usageSchema) + + // 其他字段 + result.WithProperty("created", openapi3.NewIntegerSchema()) + result.WithProperty("system_fingerprint", openapi3.NewStringSchema().WithNullable()) + result.WithProperty("model", openapi3.NewStringSchema()) + result.WithProperty("id", openapi3.NewStringSchema()) + + // 保留原有的错误字段 + result.WithProperty("code", openapi3.NewIntegerSchema()) + result.WithProperty("error", openapi3.NewStringSchema()) + return result } diff --git a/module/publish/iml.go b/module/publish/iml.go index 4226ef1f..692a5f3a 100644 --- a/module/publish/iml.go +++ b/module/publish/iml.go @@ -7,6 +7,11 @@ import ( "fmt" "time" + mcp_server "github.com/APIParkLab/APIPark/mcp-server" + api_doc "github.com/APIParkLab/APIPark/service/api-doc" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mitchellh/mapstructure" + strategy_driver "github.com/APIParkLab/APIPark/module/strategy/driver" strategy_dto "github.com/APIParkLab/APIPark/module/strategy/dto" @@ -50,6 +55,7 @@ type imlPublishModule struct { releaseModule releaseModule.IReleaseModule `autowired:""` publishService publish.IPublishService `autowired:""` apiService api.IAPIService `autowired:""` + apiDocService api_doc.IAPIDocService `autowired:""` upstreamService upstream.IUpstreamService `autowired:""` strategyService strategy.IStrategyService `autowired:""` releaseService release.IReleaseService `autowired:""` @@ -515,24 +521,39 @@ func (m *imlPublishModule) Publish(ctx context.Context, serviceId string, id str return err } hasError := false - - for _, c := range clusters { - err = m.publish(ctx, flow.Id, c.Uuid, projectReleaseMap[c.Uuid]) - if err != nil { - hasError = true - log.Error(err) - continue + return m.transaction.Transaction(ctx, func(ctx context.Context) error { + for _, c := range clusters { + err = m.publish(ctx, flow.Id, c.Uuid, projectReleaseMap[c.Uuid]) + if err != nil { + hasError = true + log.Error(err) + continue + } } - } - err = m.releaseService.SetRunning(ctx, serviceId, flow.Release) - if err != nil { - return err - } - status := publish.StatusDone - if hasError { - status = publish.StatusPublishError - } - return m.publishService.SetStatus(ctx, serviceId, id, status) + err = m.releaseService.SetRunning(ctx, serviceId, flow.Release) + if err != nil { + return err + } + status := publish.StatusDone + if hasError { + status = publish.StatusPublishError + } + if status == publish.StatusDone { + info, err := m.serviceService.Get(ctx, serviceId) + if err != nil { + return err + } + if info.EnableMCP { + err = m.updateMCPServer(ctx, serviceId, info.Name, flow.Version) + if err != nil { + return err + } + } + } + + return m.publishService.SetStatus(ctx, serviceId, id, status) + }) + } func (m *imlPublishModule) List(ctx context.Context, serviceId string, page, pageSize int) ([]*dto.Publish, int64, error) { @@ -550,6 +571,81 @@ func (m *imlPublishModule) List(ctx context.Context, serviceId string, page, pag }), total, nil } +func (i *imlPublishModule) updateMCPServer(ctx context.Context, sid string, name string, version string) error { + r, err := i.releaseService.GetRunning(ctx, sid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return err + } + _, _, apiDocCommit, _, _, err := i.releaseService.GetReleaseInfos(ctx, r.UUID) + if err != nil { + return fmt.Errorf("get release info error: %w", err) + } + commitDoc, err := i.apiDocService.GetDocCommit(ctx, apiDocCommit.Commit) + if err != nil { + return fmt.Errorf("get api doc commit error: %w", err) + } + mcpInfo, err := mcp_server.ConvertMCPFromOpenAPI3Data([]byte(commitDoc.Data.Content)) + if err != nil { + return fmt.Errorf("convert mcp from openapi3 data error: %w", err) + } + tools := make([]mcp_server.ITool, 0, len(mcpInfo.Apis)) + for _, a := range mcpInfo.Apis { + toolOptions := make([]mcp.ToolOption, 0, len(a.Params)+2) + toolOptions = append(toolOptions, mcp.WithDescription(a.Description)) + headers := make(map[string]interface{}) + queries := make(map[string]interface{}) + path := make(map[string]interface{}) + for _, v := range a.Params { + p := map[string]interface{}{ + "type": "string", + "required": v.Required, + "description": v.Description, + } + switch v.In { + case "header": + headers[v.Name] = p + case "query": + queries[v.Name] = p + case "path": + path[v.Name] = p + } + } + if len(headers) > 0 { + toolOptions = append(toolOptions, mcp.WithObject(mcp_server.MCPHeader, mcp.Properties(headers), mcp.Description("request headers."))) + } + if len(queries) > 0 { + toolOptions = append(toolOptions, mcp.WithObject(mcp_server.MCPQuery, mcp.Properties(queries), mcp.Description("request queries."))) + } + if len(path) > 0 { + toolOptions = append(toolOptions, mcp.WithObject(mcp_server.MCPPath, mcp.Properties(path), mcp.Description("request path params."))) + } + if a.Body != nil { + type Schema struct { + Type string `mapstructure:"type"` + Properties map[string]interface{} `mapstructure:"properties"` + Items interface{} `mapstructure:"items"` + } + var tmp Schema + err = mapstructure.Decode(a.Body, &tmp) + if err != nil { + return err + } + switch tmp.Type { + case "object": + toolOptions = append(toolOptions, mcp.WithObject(mcp_server.MCPBody, mcp.Properties(tmp.Properties), mcp.Description("request body,it is avalible when method is POST、PUT、PATCH."))) + case "array": + toolOptions = append(toolOptions, mcp.WithArray(mcp_server.MCPBody, mcp.Items(tmp.Items), mcp.Description("request body,it is avalible when method is POST、PUT、PATCH."))) + } + } + tools = append(tools, mcp_server.NewTool(a.Summary, a.Path, a.Method, a.ContentType, toolOptions...)) + } + mcp_server.SetSSEServer(sid, name, version, tools...) + return nil +} + func (m *imlPublishModule) Detail(ctx context.Context, serviceId string, id string) (*dto.PublishDetail, error) { _, err := m.serviceService.Check(ctx, serviceId, asServer) if err != nil {