Files
2025-07-21 16:44:11 +08:00

303 lines
8.8 KiB
Go

package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/APIParkLab/APIPark/service/subscribe"
"github.com/getkin/kin-openapi/openapi3"
"gorm.io/gorm"
"github.com/APIParkLab/APIPark/service/release"
mcp_dto "github.com/APIParkLab/APIPark/module/mcp/dto"
"github.com/eolinker/go-common/utils"
api_doc "github.com/APIParkLab/APIPark/service/api-doc"
application_authorization "github.com/APIParkLab/APIPark/service/application-authorization"
"github.com/APIParkLab/APIPark/service/service"
"github.com/mark3labs/mcp-go/mcp"
)
// Assert at compile time that *imlMcpModule implements IMcpModule.
var _ IMcpModule = &imlMcpModule{}

// openapi3Loader parses OpenAPI 3 documents from the raw content of API-doc commits.
// NOTE(review): openapi3.Loader presumably keeps internal state between loads and may
// not be safe for concurrent use by simultaneous tool calls — confirm, or create a
// loader per call.
var openapi3Loader = openapi3.NewLoader()
// imlMcpModule implements IMcpModule on top of the service, subscription, release and
// API-doc services. Its fields carry `autowired` tags and are presumably populated by
// the project's dependency-injection container — confirm before renaming or retyping.
type imlMcpModule struct {
// serviceService searches and fetches services.
serviceService service.IServiceService `autowired:""`
// NOTE(review): appService has the same type as serviceService and is not used in this
// section — confirm whether an application-specific service was intended.
appService service.IServiceService `autowired:""`
// appAuthorizationService is not used in this section of the file.
appAuthorizationService application_authorization.IAuthorizationService `autowired:""`
// apiDocService loads API document commits by commit id.
apiDocService api_doc.IAPIDocService `autowired:""`
// subscriberService looks up an application's service subscriptions.
subscriberService subscribe.ISubscribeService `autowired:""`
// releaseService resolves a service's running release and that release's components.
releaseService release.IReleaseService `autowired:""`
}
// subscribeServiceIds returns the ids of the services that application appId holds an
// active subscription to (ApplyStatus == ApplyStatusSubscribe).
//
// It returns an error when the subscription lookup fails, and also when the
// application has no active subscription at all.
func (i *imlMcpModule) subscribeServiceIds(ctx context.Context, appId string) ([]string, error) {
	subscriptions, lookupErr := i.subscriberService.SubscriptionsByApplication(ctx, appId)
	if lookupErr != nil {
		return nil, fmt.Errorf("get subscriber error: %w,app id is %s", lookupErr, appId)
	}

	serviceOf := func(sub *subscribe.Subscribe) string { return sub.Service }
	isActive := func(sub *subscribe.Subscribe) bool {
		return sub.ApplyStatus == subscribe.ApplyStatusSubscribe
	}

	ids := utils.SliceToSlice(subscriptions, serviceOf, isActive)
	if len(ids) > 0 {
		return ids, nil
	}
	return nil, fmt.Errorf("no subscriber found,app id is %s", appId)
}
// Services lists the services flagged as_server, optionally filtered by the "keyword"
// argument, together with the APIs described by each service's running API document.
// The result is returned as JSON text.
//
// When ctx carries an "app" label, only services that application is actively
// subscribed to are listed. If the keyword matches nothing, the search is repeated
// without a keyword. Services with no running release, release info or doc commit are
// skipped. Any other lookup or parse failure aborts the whole call.
func (i *imlMcpModule) Services(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	keyword, _ := req.GetArguments()["keyword"].(string)
	condition := map[string]interface{}{
		"as_server": true,
	}
	if appId := utils.Label(ctx, "app"); appId != "" {
		serviceIds, err := i.subscribeServiceIds(ctx, appId)
		if err != nil {
			return nil, fmt.Errorf("get subscriber service ids error: %w,app id is %s", err, appId)
		}
		condition["uuid"] = serviceIds
	}

	list, err := i.serviceService.Search(ctx, keyword, condition, "update_at desc")
	if err != nil {
		return nil, fmt.Errorf("search service error: %w", err)
	}
	// Fall back to the unfiltered list when the keyword matched nothing. With an empty
	// keyword the second search would be identical, so it is skipped.
	if len(list) == 0 && keyword != "" {
		list, err = i.serviceService.Search(ctx, "", condition, "update_at desc")
		if err != nil {
			return nil, fmt.Errorf("search service error: %w", err)
		}
	}

	result := make([]*mcp_dto.Service, 0, len(list))
	for _, s := range list {
		doc, found, err := i.loadRunningOpenAPIDoc(ctx, s.Id)
		if err != nil {
			return nil, err
		}
		if !found {
			continue
		}
		result = append(result, &mcp_dto.Service{
			Id:   s.Id,
			Name: s.Name,
			// NOTE(review): Description is filled from the service name; confirm whether
			// the service's own description field was intended here.
			Description: s.Name,
			ServiceKind: s.Kind.String(),
			CreateTime:  s.CreateTime,
			UpdateTime:  s.UpdateTime,
			Apis:        apisFromOpenAPIDoc(doc),
		})
	}

	data, err := json.Marshal(result)
	if err != nil {
		return nil, fmt.Errorf("marshal services error: %w", err)
	}
	return mcp.NewToolResultText(string(data)), nil
}

// loadRunningOpenAPIDoc parses the OpenAPI document of the release currently running for
// serviceId. found is false (with a nil error) when the running release, its release info
// or its doc commit does not exist (gorm.ErrRecordNotFound).
func (i *imlMcpModule) loadRunningOpenAPIDoc(ctx context.Context, serviceId string) (*openapi3.T, bool, error) {
	serviceRelease, err := i.releaseService.GetRunning(ctx, serviceId)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, false, nil
		}
		return nil, false, fmt.Errorf("get service release error: %w,service id is %s", err, serviceId)
	}
	_, _, apiDocRelease, _, _, err := i.releaseService.GetReleaseInfos(ctx, serviceRelease.UUID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, false, nil
		}
		return nil, false, fmt.Errorf("get service release info error: %w,service id is %s", err, serviceId)
	}
	commit, err := i.apiDocService.GetDocCommit(ctx, apiDocRelease.Commit)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, false, nil
		}
		return nil, false, fmt.Errorf("get api doc release error: %w,service id is %s", err, serviceId)
	}
	// A fresh loader per call: openapi3.Loader keeps bookkeeping state between loads, so
	// a package-level loader shared by concurrent tool calls would race.
	doc, err := openapi3.NewLoader().LoadFromData([]byte(commit.Data.Content))
	if err != nil {
		return nil, false, fmt.Errorf("load openapi3 error: %w,service id is %s", err, serviceId)
	}
	return doc, true, nil
}

// apisFromOpenAPIDoc flattens every (path, method) operation of doc into an API entry
// carrying the operation's summary as its name. Order follows map iteration and is
// therefore unspecified.
func apisFromOpenAPIDoc(doc *openapi3.T) []*mcp_dto.API {
	paths := doc.Paths.Map()
	apis := make([]*mcp_dto.API, 0, len(paths))
	for path, item := range paths {
		for method, op := range item.Operations() {
			apis = append(apis, &mcp_dto.API{
				Name:        op.Summary,
				Method:      method,
				Path:        path,
				Description: op.Description,
			})
		}
	}
	return apis
}
// APIs returns, as JSON text, the OpenAPI document of the release currently running
// for the service named by the "service" argument, along with the service's id and name.
//
// When ctx carries an "app" label, the application's first subscription to the service
// must be in the subscribed state. A missing running release, release info or doc
// commit is reported as "no service found". Other lookup, parse or encode failures are
// returned wrapped.
func (i *imlMcpModule) APIs(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	serviceId, _ := req.GetArguments()["service"].(string)
	if serviceId == "" {
		return nil, fmt.Errorf("service id is empty")
	}
	s, err := i.serviceService.Get(ctx, serviceId)
	if err != nil {
		return nil, fmt.Errorf("get service error: %w,service id is %s", err, serviceId)
	}
	if appId := utils.Label(ctx, "app"); appId != "" {
		subscribers, err := i.subscriberService.ListByApplication(ctx, serviceId, appId)
		if err != nil {
			return nil, fmt.Errorf("get subscriber error: %w,app id is %s", err, appId)
		}
		if len(subscribers) < 1 || subscribers[0].ApplyStatus != subscribe.ApplyStatusSubscribe {
			return nil, fmt.Errorf("no subscriber found,app id is %s", appId)
		}
	}

	serviceRelease, err := i.releaseService.GetRunning(ctx, serviceId)
	if err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("get service release error: %w,service id is %s", err, s.Id)
		}
		return nil, fmt.Errorf("no service found,service id is %s", serviceId)
	}
	_, _, apiDocRelease, _, _, err := i.releaseService.GetReleaseInfos(ctx, serviceRelease.UUID)
	if err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("get service release info error: %w,service id is %s", err, s.Id)
		}
		return nil, fmt.Errorf("no service found,service id is %s", serviceId)
	}
	commit, err := i.apiDocService.GetDocCommit(ctx, apiDocRelease.Commit)
	if err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("get api doc release error: %w,service id is %s", err, s.Id)
		}
		return nil, fmt.Errorf("no service found,service id is %s", serviceId)
	}
	// A fresh loader per call: openapi3.Loader keeps bookkeeping state between loads, so
	// a package-level loader shared by concurrent tool calls would race.
	doc, err := openapi3.NewLoader().LoadFromData([]byte(commit.Data.Content))
	if err != nil {
		return nil, fmt.Errorf("load openapi3 error: %w,service id is %s", err, s.Id)
	}

	data, err := json.Marshal(&mcp_dto.ServiceAPI{
		ServiceID:   serviceId,
		ServiceName: s.Name,
		APIDoc:      doc,
	})
	if err != nil {
		return nil, fmt.Errorf("marshal service api error: %w,service id is %s", err, serviceId)
	}
	return mcp.NewToolResultText(string(data)), nil
}
// client is the HTTP client shared by every Invoke call. It sets no client-wide
// timeout, so how long a request may take is governed only by the request itself.
var client = new(http.Client)
// Invoke forwards an HTTP request to the gateway address carried in ctx
// (utils.GatewayInvoke) and returns the response body as text.
//
// Recognised arguments:
//   - path (string, required): appended to the gateway URL's path.
//   - method (string, default "GET"): upper-cased before sending.
//   - query (object): each value may be a string, number or boolean, or an array of those.
//   - header (object): same value types as query; array values become repeated headers.
//   - body: a string is sent as-is; any other JSON value is re-encoded as JSON.
//   - content-type (string): the Content-Type header. When absent, a Content-Type given
//     in header is kept, otherwise "application/json" is used.
//
// Any query string already present in the gateway URL is replaced by the query argument.
// If ctx carries an "apikey" label it is sent as Authorization, MD5-hashed when ctx has
// no "app" label. The outgoing request is bound to ctx so cancellation propagates.
// Non-2xx responses are returned as errors that include the status code and body.
func (i *imlMcpModule) Invoke(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	gatewayInvoke := utils.GatewayInvoke(ctx)
	if gatewayInvoke == "" {
		return nil, fmt.Errorf("gateway invoke is required")
	}
	u, err := url.Parse(gatewayInvoke)
	if err != nil {
		return nil, fmt.Errorf("parse gateway invoke error: %w", err)
	}
	if u.Scheme == "" {
		u.Scheme = "http"
	}

	args := req.GetArguments()
	path, ok := args["path"].(string)
	if !ok {
		return nil, fmt.Errorf("invalid path")
	}
	u.Path = fmt.Sprintf("%s/%s", strings.TrimSuffix(u.Path, "/"), strings.TrimPrefix(path, "/"))

	// Servers match methods case-sensitively, so normalise "get"/"post" to upper case.
	method, _ := args["method"].(string)
	method = strings.ToUpper(strings.TrimSpace(method))
	if method == "" {
		method = http.MethodGet
	}

	queryParam := url.Values{}
	if query, ok := args["query"].(map[string]interface{}); ok {
		for k, v := range query {
			values, ok := invokeArgValues(v)
			if !ok {
				return nil, fmt.Errorf("invalid query param type: %T", v)
			}
			for _, value := range values {
				queryParam.Add(k, value)
			}
		}
	}
	u.RawQuery = queryParam.Encode()

	headerParam := http.Header{}
	if header, ok := args["header"].(map[string]interface{}); ok {
		for k, v := range header {
			values, ok := invokeArgValues(v)
			if !ok {
				return nil, fmt.Errorf("invalid header param type: %T", v)
			}
			// Add (not Set) so every element of an array value is kept.
			for _, value := range values {
				headerParam.Add(k, value)
			}
		}
	}

	var body string
	switch b := args["body"].(type) {
	case nil:
		// No body supplied.
	case string:
		body = b
	default:
		// Clients sometimes send the body as a JSON object/array instead of a
		// pre-encoded string; encode it rather than silently sending an empty body.
		encoded, err := json.Marshal(b)
		if err != nil {
			return nil, fmt.Errorf("invalid body: %w", err)
		}
		body = string(encoded)
	}

	contentType, ok := args["content-type"].(string)
	if !ok {
		contentType = headerParam.Get("Content-Type")
		if contentType == "" {
			contentType = "application/json"
		}
	}

	request, err := http.NewRequestWithContext(ctx, method, u.String(), strings.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("new request error: %w", err)
	}
	request.Header = headerParam
	request.Header.Set("Content-Type", contentType)
	if apikey := utils.Label(ctx, "apikey"); apikey != "" {
		if utils.Label(ctx, "app") == "" {
			request.Header.Set("Authorization", utils.Md5(apikey))
		} else {
			request.Header.Set("Authorization", apikey)
		}
	}

	resp, err := client.Do(request)
	if err != nil {
		return nil, fmt.Errorf("request error: %w", err)
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response error: %w", err)
	}
	// Any 2xx (e.g. 201 Created, 204 No Content) is a successful call.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("response error: status %d: %s", resp.StatusCode, string(data))
	}
	return mcp.NewToolResultText(string(data)), nil
}

// invokeArgValues flattens a query/header argument value into its string values.
// JSON-decoded arrays arrive as []interface{}; []string is also accepted for callers
// that build arguments directly. ok is false for unsupported types.
func invokeArgValues(v interface{}) ([]string, bool) {
	switch v := v.(type) {
	case []string:
		return v, true
	case []interface{}:
		values := make([]string, 0, len(v))
		for _, item := range v {
			s, ok := invokeArgScalar(item)
			if !ok {
				return nil, false
			}
			values = append(values, s)
		}
		return values, true
	default:
		s, ok := invokeArgScalar(v)
		if !ok {
			return nil, false
		}
		return []string{s}, true
	}
}

// invokeArgScalar renders a single scalar argument value as a string. It accepts
// strings, numbers (float64 from JSON, plus int, int64 and json.Number) and booleans.
func invokeArgScalar(v interface{}) (string, bool) {
	switch v := v.(type) {
	case string:
		return v, true
	case float64:
		return strconv.FormatFloat(v, 'f', -1, 64), true
	case int:
		return strconv.Itoa(v), true
	case int64:
		return strconv.FormatInt(v, 10), true
	case json.Number:
		return v.String(), true
	case bool:
		return strconv.FormatBool(v), true
	default:
		return "", false
	}
}