mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-12 18:11:34 +08:00
Merge branch 'main' into feature/dashen/model_mapping
This commit is contained in:
@@ -27,6 +27,7 @@ type AiModel struct {
|
||||
Id string `json:"id"`
|
||||
Config string `json:"config"`
|
||||
Provider string `json:"provider"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type EditAPI struct {
|
||||
|
||||
@@ -4,6 +4,39 @@ import (
|
||||
"github.com/eolinker/go-common/auto"
|
||||
)
|
||||
|
||||
type ModelType string
|
||||
|
||||
const (
|
||||
ModelTypeOnline ModelType = "online"
|
||||
ModelTypeLocal ModelType = "local"
|
||||
)
|
||||
|
||||
func (m ModelType) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
func (m ModelType) Int() int {
|
||||
switch m {
|
||||
case ModelTypeOnline:
|
||||
return 0
|
||||
case ModelTypeLocal:
|
||||
return 1
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func FromModelType(m int) ModelType {
|
||||
switch m {
|
||||
case 0:
|
||||
return ModelTypeOnline
|
||||
case 1:
|
||||
return ModelTypeLocal
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type API struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -19,6 +52,7 @@ type API struct {
|
||||
type APIItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ModelType ModelType `json:"model_type"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Description string `json:"description"`
|
||||
Disable bool `json:"disabled"`
|
||||
|
||||
+12
-3
@@ -54,7 +54,7 @@ func (i *imlAPIModule) getAPIDoc(ctx context.Context, serviceId string) (*openap
|
||||
return openapi3Loader.LoadFromData([]byte(doc.Content))
|
||||
}
|
||||
|
||||
func (i *imlAPIModule) updateAPIDoc(ctx context.Context, serviceId string, serviceName string, path string, summary string, description string, aiPrompt *ai_api_dto.AiPrompt) error {
|
||||
func (i *imlAPIModule) updateAPIDoc(ctx context.Context, serviceId, serviceName, orgPath, path, summary, description string, aiPrompt *ai_api_dto.AiPrompt) error {
|
||||
doc, err := i.getAPIDoc(ctx, serviceId)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -64,6 +64,10 @@ func (i *imlAPIModule) updateAPIDoc(ctx context.Context, serviceId string, servi
|
||||
if aiPrompt != nil {
|
||||
variables = aiPrompt.Variables
|
||||
}
|
||||
if doc.Paths != nil {
|
||||
doc.Paths.Delete(orgPath)
|
||||
}
|
||||
|
||||
doc.AddOperation(path, http.MethodPost, genOperation(summary, description, variables))
|
||||
result, err := doc.MarshalJSON()
|
||||
if err != nil {
|
||||
@@ -103,10 +107,11 @@ func (i *imlAPIModule) Create(ctx context.Context, serviceId string, input *ai_a
|
||||
input.Id = uuid.New().String()
|
||||
}
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
err = i.updateAPIDoc(ctx, serviceId, info.Name, input.Path, input.Name, input.Description, input.AiPrompt)
|
||||
err = i.updateAPIDoc(ctx, serviceId, info.Name, "", input.Path, input.Name, input.Description, input.AiPrompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return i.aiAPIService.Create(ctx, &ai_api.Create{
|
||||
ID: input.Id,
|
||||
Name: input.Name,
|
||||
@@ -118,6 +123,7 @@ func (i *imlAPIModule) Create(ctx context.Context, serviceId string, input *ai_a
|
||||
Retry: input.Retry,
|
||||
Model: input.AiModel.Id,
|
||||
Provider: input.AiModel.Provider,
|
||||
Type: ai_api_dto.ModelType(input.AiModel.Type).Int(),
|
||||
AdditionalConfig: map[string]interface{}{
|
||||
"ai_prompt": input.AiPrompt,
|
||||
"ai_model": input.AiModel,
|
||||
@@ -141,13 +147,14 @@ func (i *imlAPIModule) Edit(ctx context.Context, serviceId string, apiId string,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
orgPath := apiInfo.Path
|
||||
if input.Path != nil {
|
||||
apiInfo.Path = *input.Path
|
||||
}
|
||||
if input.Description != nil {
|
||||
apiInfo.Description = *input.Description
|
||||
}
|
||||
err = i.updateAPIDoc(ctx, serviceId, info.Name, apiInfo.Path, apiInfo.Name, apiInfo.Description, input.AiPrompt)
|
||||
err = i.updateAPIDoc(ctx, serviceId, info.Name, orgPath, apiInfo.Path, apiInfo.Name, apiInfo.Description, input.AiPrompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -163,6 +170,7 @@ func (i *imlAPIModule) Edit(ctx context.Context, serviceId string, apiId string,
|
||||
if input.AiModel != nil {
|
||||
apiInfo.AdditionalConfig["ai_model"] = input.AiModel
|
||||
}
|
||||
typ := ai_api_dto.ModelType(input.AiModel.Type).Int()
|
||||
return i.aiAPIService.Save(ctx, apiId, &ai_api.Edit{
|
||||
Name: input.Name,
|
||||
Path: input.Path,
|
||||
@@ -171,6 +179,7 @@ func (i *imlAPIModule) Edit(ctx context.Context, serviceId string, apiId string,
|
||||
Retry: input.Retry,
|
||||
Model: modelId,
|
||||
Provider: providerId,
|
||||
Type: &typ,
|
||||
AdditionalConfig: &apiInfo.AdditionalConfig,
|
||||
Disable: input.Disable,
|
||||
})
|
||||
|
||||
+9
-11
@@ -13,12 +13,7 @@ func genOpenAPI3Template(title string, description string) *openapi3.T {
|
||||
Description: description,
|
||||
Version: "beta",
|
||||
}
|
||||
//result.Tags = openapi3.Tags{
|
||||
// {
|
||||
// Name: title,
|
||||
// Description: description,
|
||||
// },
|
||||
//}
|
||||
|
||||
result.Components = components
|
||||
result.Paths = new(openapi3.Paths)
|
||||
return result
|
||||
@@ -26,7 +21,6 @@ func genOpenAPI3Template(title string, description string) *openapi3.T {
|
||||
|
||||
func genOperation(summary string, description string, variables []*ai_api_dto.AiPromptVariable) *openapi3.Operation {
|
||||
operation := openapi3.NewOperation()
|
||||
//operation.Parameters = genRequestParameters(variables)
|
||||
operation.Summary = summary
|
||||
operation.Description = description
|
||||
operation.RequestBody = genRequestBody(variables)
|
||||
@@ -78,9 +72,13 @@ func genResponse() *openapi3.ResponseRef {
|
||||
|
||||
func genRequestBodySchema(variables []*ai_api_dto.AiPromptVariable) *openapi3.Schema {
|
||||
result := openapi3.NewObjectSchema()
|
||||
result.WithProperty("variables", genVariableSchema(variables))
|
||||
if len(variables) > 0 {
|
||||
result.WithProperty("variables", genVariableSchema(variables))
|
||||
result.WithRequired([]string{"variables", "messages"})
|
||||
}
|
||||
|
||||
result.WithPropertyRef("messages", messagesSchemaRef)
|
||||
result.WithRequired([]string{"variables", "messages"})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -136,10 +134,10 @@ func genMessageSchema() *openapi3.Schema {
|
||||
result.Description = "Chat Message"
|
||||
roleSchema := openapi3.NewStringSchema()
|
||||
roleSchema.Description = "Role of the message sender"
|
||||
roleSchema.Example = "assistant"
|
||||
roleSchema.Example = "user"
|
||||
contentSchema := openapi3.NewStringSchema()
|
||||
contentSchema.Description = "The message content"
|
||||
contentSchema.Example = "Hello, how can I help you?"
|
||||
contentSchema.Example = "Hello, who are you?"
|
||||
result.WithProperties(map[string]*openapi3.Schema{
|
||||
"role": roleSchema,
|
||||
"content": contentSchema,
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package ai_balance_dto
|
||||
|
||||
type Create struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type Sort struct {
|
||||
Origin string `json:"origin"`
|
||||
Target string `json:"target"`
|
||||
Sort string `json:"sort"`
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package ai_balance_dto
|
||||
|
||||
const (
|
||||
ModelTypeOnline = "online"
|
||||
ModelTypeLocal = "local"
|
||||
|
||||
StateNormal = "normal"
|
||||
StateAbnormal = "abnormal"
|
||||
)
|
||||
|
||||
type ModelType string
|
||||
|
||||
func (m ModelType) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
func (m ModelType) Int() int {
|
||||
switch m {
|
||||
case ModelTypeOnline:
|
||||
return 0
|
||||
case ModelTypeLocal:
|
||||
return 1
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func ModelTypeFromInt(i int) ModelType {
|
||||
switch i {
|
||||
case 0:
|
||||
return ModelTypeOnline
|
||||
case 1:
|
||||
return ModelTypeLocal
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type ModelState string
|
||||
|
||||
func (m ModelState) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
func (m ModelState) Int() int {
|
||||
switch m {
|
||||
case StateNormal:
|
||||
return 1
|
||||
case StateAbnormal:
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func ModelStateFromInt(i int) ModelState {
|
||||
switch i {
|
||||
case 1:
|
||||
return StateNormal
|
||||
case 0:
|
||||
return StateAbnormal
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Id string `json:"id"`
|
||||
Priority int `json:"priority"`
|
||||
Provider *BasicItem `json:"provider"`
|
||||
Model *BasicItem `json:"model"`
|
||||
Type ModelType `json:"type"`
|
||||
State ModelState `json:"state"`
|
||||
APICount int64 `json:"api_count"`
|
||||
KeyCount int64 `json:"key_count"`
|
||||
}
|
||||
|
||||
type BasicItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
package ai_balance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/setting"
|
||||
|
||||
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
|
||||
|
||||
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
|
||||
|
||||
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
|
||||
|
||||
"github.com/eolinker/go-common/store"
|
||||
|
||||
"github.com/APIParkLab/APIPark/gateway"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/cluster"
|
||||
|
||||
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
|
||||
"github.com/eolinker/eosc/log"
|
||||
)
|
||||
|
||||
var _ IBalanceModule = (*imlBalanceModule)(nil)
|
||||
|
||||
type imlBalanceModule struct {
|
||||
clusterService cluster.IClusterService `autowired:""`
|
||||
aiAPIService ai_api.IAPIService `autowired:""`
|
||||
aiKeyService ai_key.IKeyService `autowired:""`
|
||||
balanceService ai_balance.IBalanceService `autowired:""`
|
||||
settingService setting.ISettingService `autowired:""`
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error {
|
||||
releases, err := i.getLocalBalances(ctx, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
|
||||
has, err := i.balanceService.Exist(ctx, input.Provider, input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if has {
|
||||
return fmt.Errorf("model already exists")
|
||||
}
|
||||
priority, err := i.balanceService.MaxPriority(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
priority = 0
|
||||
}
|
||||
if input.Id == "" {
|
||||
input.Id = uuid.NewString()
|
||||
}
|
||||
providerName := ""
|
||||
modelName := ""
|
||||
base := ""
|
||||
switch input.Type {
|
||||
case ai_balance_dto.ModelTypeOnline:
|
||||
p, has := model_runtime.GetProvider(input.Provider)
|
||||
if !has {
|
||||
return fmt.Errorf("provider not found")
|
||||
}
|
||||
providerName = p.Name()
|
||||
modelName = input.Model
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
case ai_balance_dto.ModelTypeLocal:
|
||||
input.Provider = "ollama"
|
||||
providerName = "Ollama"
|
||||
modelName = input.Model
|
||||
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return fmt.Errorf("ollama address not found")
|
||||
}
|
||||
base = v
|
||||
}
|
||||
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
err = i.balanceService.Create(ctx, &ai_balance.Create{
|
||||
Id: input.Id,
|
||||
Priority: priority + 1,
|
||||
Provider: input.Provider,
|
||||
ProviderName: providerName,
|
||||
Model: input.Model,
|
||||
ModelName: modelName,
|
||||
Type: ai_balance_dto.ModelType(input.Type).Int(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item, err := i.balanceService.Get(ctx, input.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func newRelease(item *ai_balance.Balance, base string) *gateway.DynamicRelease {
|
||||
|
||||
cfg := make(map[string]interface{})
|
||||
cfg["provider"] = item.Provider
|
||||
cfg["model"] = item.Model
|
||||
cfg["model_config"] = ai_provider_local.OllamaConfig
|
||||
cfg["base"] = base
|
||||
cfg["priority"] = item.Priority
|
||||
return &gateway.DynamicRelease{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: item.Id,
|
||||
Description: item.ModelName,
|
||||
Resource: "ai-provider",
|
||||
Version: item.UpdateAt.Format("20060102150405"),
|
||||
MatchLabels: map[string]string{
|
||||
"module": "ai-provider",
|
||||
},
|
||||
},
|
||||
Attr: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort) error {
|
||||
var list []*ai_balance.Balance
|
||||
var err error
|
||||
switch input.Sort {
|
||||
case "after":
|
||||
list, err = i.balanceService.SortAfter(ctx, input.Origin, input.Target)
|
||||
default:
|
||||
list, err = i.balanceService.SortBefore(ctx, input.Origin, input.Target)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return fmt.Errorf("ollama address not found")
|
||||
}
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(list))
|
||||
for _, item := range list {
|
||||
base := v
|
||||
if item.Provider != "ollama" {
|
||||
p, has := model_runtime.GetProvider(item.Provider)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
}
|
||||
|
||||
releases = append(releases, newRelease(item, base))
|
||||
}
|
||||
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) {
|
||||
list, err := i.balanceService.Search(ctx, keyword, nil, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(list, func(i, j int) bool {
|
||||
return list[i].Priority < list[j].Priority
|
||||
})
|
||||
aiAPIMap, err := i.aiAPIService.CountMapByModel(ctx, "", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ai api count error:%v", err)
|
||||
}
|
||||
keyMap, err := i.aiKeyService.CountMapByProvider(ctx, "", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ai key count error:%v", err)
|
||||
}
|
||||
result := make([]*ai_balance_dto.Item, 0, len(list))
|
||||
for i, item := range list {
|
||||
priority := i + 1
|
||||
result = append(result, &ai_balance_dto.Item{
|
||||
Id: item.Id,
|
||||
Provider: &ai_balance_dto.BasicItem{
|
||||
Id: item.Provider,
|
||||
Name: item.ProviderName,
|
||||
},
|
||||
Model: &ai_balance_dto.BasicItem{
|
||||
Id: item.Model,
|
||||
Name: item.ModelName,
|
||||
},
|
||||
Priority: priority,
|
||||
Type: ai_balance_dto.ModelTypeFromInt(item.Type),
|
||||
State: ai_balance_dto.ModelStateFromInt(item.State),
|
||||
APICount: aiAPIMap[item.Model],
|
||||
KeyCount: keyMap[item.Provider],
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) Delete(ctx context.Context, id string) error {
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
err := i.balanceService.Delete(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: id,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
|
||||
client, err := i.clusterService.GatewayClient(ctx, clusterId)
|
||||
if err != nil {
|
||||
log.Errorf("get apinto client error: %v", err)
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
err := client.Close(ctx)
|
||||
if err != nil {
|
||||
log.Warn("close apinto client:", err)
|
||||
}
|
||||
}()
|
||||
for _, releaseInfo := range releases {
|
||||
dynamicClient, err := client.Dynamic(releaseInfo.Resource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if online {
|
||||
err = dynamicClient.Online(ctx, releaseInfo)
|
||||
} else {
|
||||
dynamicClient.Offline(ctx, releaseInfo)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
|
||||
balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v == "" {
|
||||
var has bool
|
||||
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return nil, fmt.Errorf("ollama address not found")
|
||||
}
|
||||
}
|
||||
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(balances))
|
||||
for _, item := range balances {
|
||||
base := v
|
||||
if item.Provider != "ollama" {
|
||||
p, has := model_runtime.GetProvider(item.Provider)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
}
|
||||
releases = append(releases, newRelease(item, base))
|
||||
}
|
||||
return releases, nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) {
|
||||
balances, err := i.balanceService.Search(ctx, "", nil, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return nil, fmt.Errorf("ollama address not found")
|
||||
}
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(balances))
|
||||
for _, item := range balances {
|
||||
base := v
|
||||
if item.Provider != "ollama" {
|
||||
p, has := model_runtime.GetProvider(item.Provider)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
}
|
||||
releases = append(releases, newRelease(item, base))
|
||||
}
|
||||
return releases, nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
|
||||
releases, err := i.getBalances(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, p := range releases {
|
||||
client, err := clientDriver.Dynamic(p.Resource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = client.Online(ctx, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package ai_balance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/APIParkLab/APIPark/gateway"
|
||||
|
||||
"github.com/eolinker/go-common/autowire"
|
||||
|
||||
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
|
||||
)
|
||||
|
||||
type IBalanceModule interface {
|
||||
Create(ctx context.Context, input *ai_balance_dto.Create) error
|
||||
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
|
||||
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
SyncLocalBalances(ctx context.Context, address string) error
|
||||
}
|
||||
|
||||
func init() {
|
||||
balanceModule := new(imlBalanceModule)
|
||||
autowire.Auto[IBalanceModule](func() reflect.Value {
|
||||
gateway.RegisterInitHandleFunc(balanceModule.initGateway)
|
||||
return reflect.ValueOf(balanceModule)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package ai_local_dto
|
||||
|
||||
type Update struct {
|
||||
Disable bool `json:"disable"`
|
||||
}
|
||||
|
||||
type CancelDeploy struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type DeployInput struct {
|
||||
Model string `json:"model"`
|
||||
Service string `json:"service"`
|
||||
Team string `json:"team"`
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package ai_local_dto
|
||||
|
||||
import "github.com/eolinker/go-common/auto"
|
||||
|
||||
type LocalModelState string
|
||||
|
||||
const (
|
||||
LocalModelStateNormal LocalModelState = "normal"
|
||||
LocalModelStateDisable LocalModelState = "disabled"
|
||||
LocalModelStateDeployingError LocalModelState = "deploying_error"
|
||||
LocalModelStateError LocalModelState = "error"
|
||||
LocalModelStateDeploying LocalModelState = "deploying"
|
||||
|
||||
DeployStateDownload DeployState = "download"
|
||||
DeployStateDeploy DeployState = "deploy"
|
||||
DeployStateInitializing DeployState = "initializing"
|
||||
DeployStateFinish DeployState = "finish"
|
||||
DeployStateDownloadError DeployState = "download error"
|
||||
DeployStateDeployError DeployState = "deploy error"
|
||||
DeployStateInitializingError DeployState = "initializing error"
|
||||
)
|
||||
|
||||
func (l LocalModelState) String() string {
|
||||
return string(l)
|
||||
}
|
||||
|
||||
func (l LocalModelState) Int() int {
|
||||
switch l {
|
||||
case LocalModelStateDisable:
|
||||
return 0
|
||||
case LocalModelStateNormal:
|
||||
return 1
|
||||
case LocalModelStateError:
|
||||
return 2
|
||||
case LocalModelStateDeploying:
|
||||
return 3
|
||||
case LocalModelStateDeployingError:
|
||||
return 4
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func FromLocalModelState(state int) LocalModelState {
|
||||
switch state {
|
||||
case 0:
|
||||
return LocalModelStateDisable
|
||||
case 1:
|
||||
return LocalModelStateNormal
|
||||
case 2:
|
||||
return LocalModelStateError
|
||||
case 3:
|
||||
return LocalModelStateDeploying
|
||||
case 4:
|
||||
return LocalModelStateDeployingError
|
||||
default:
|
||||
return LocalModelStateDisable
|
||||
}
|
||||
}
|
||||
|
||||
type OllamaConfig struct {
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
type SimpleItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DefaultConfig string `json:"default_config"`
|
||||
Logo string `json:"logo"`
|
||||
}
|
||||
|
||||
type LocalModelItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
State LocalModelState `json:"state"`
|
||||
APICount int64 `json:"api_count"`
|
||||
Provider string `json:"provider"`
|
||||
UpdateTime auto.TimeLabel `json:"update_time"`
|
||||
CanDelete bool `json:"can_delete"`
|
||||
}
|
||||
|
||||
type LocalModelPackageItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Size string `json:"size"`
|
||||
IsPopular bool `json:"is_popular"`
|
||||
}
|
||||
|
||||
type DeployState string
|
||||
|
||||
func (d DeployState) String() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
func (d DeployState) Int() int {
|
||||
switch d {
|
||||
case DeployStateDownload:
|
||||
return 1
|
||||
case DeployStateDeploy:
|
||||
return 2
|
||||
case DeployStateInitializing:
|
||||
return 3
|
||||
case DeployStateFinish:
|
||||
return 4
|
||||
case DeployStateDownloadError:
|
||||
return 5
|
||||
case DeployStateDeployError:
|
||||
return 6
|
||||
case DeployStateInitializingError:
|
||||
return 7
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func FromDeployState(state int) DeployState {
|
||||
switch state {
|
||||
case 1:
|
||||
return DeployStateDownload
|
||||
case 2:
|
||||
return DeployStateDeploy
|
||||
case 3:
|
||||
return DeployStateInitializing
|
||||
case 4:
|
||||
return DeployStateFinish
|
||||
case 5:
|
||||
return DeployStateDownloadError
|
||||
case 6:
|
||||
return DeployStateDeployError
|
||||
case 7:
|
||||
return DeployStateInitializingError
|
||||
default:
|
||||
return DeployStateDownload
|
||||
}
|
||||
}
|
||||
|
||||
type ModelInfo struct {
|
||||
Current int64 `json:"current"`
|
||||
Total int64 `json:"total"`
|
||||
LastMessage string `json:"last_message"`
|
||||
}
|
||||
@@ -0,0 +1,635 @@
|
||||
package ai_local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/setting"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/api"
|
||||
|
||||
"github.com/APIParkLab/APIPark/gateway"
|
||||
"github.com/eolinker/eosc/log"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/cluster"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/service"
|
||||
|
||||
"github.com/eolinker/go-common/auto"
|
||||
|
||||
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
|
||||
|
||||
"github.com/eolinker/go-common/register"
|
||||
"github.com/eolinker/go-common/server"
|
||||
|
||||
"github.com/eolinker/go-common/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
|
||||
|
||||
"github.com/eolinker/go-common/store"
|
||||
|
||||
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
|
||||
ai_local_dto "github.com/APIParkLab/APIPark/module/ai-local/dto"
|
||||
)
|
||||
|
||||
var (
|
||||
_ ILocalModelModule = (*imlLocalModel)(nil)
|
||||
)
|
||||
|
||||
type imlLocalModel struct {
|
||||
localModelService ai_local.ILocalModelService `autowired:""`
|
||||
localModelPackageService ai_local.ILocalModelPackageService `autowired:""`
|
||||
localModelStateService ai_local.ILocalModelInstallStateService `autowired:""`
|
||||
localModelCacheService ai_local.ILocalModelCacheService `autowired:""`
|
||||
balanceService ai_balance.IBalanceService `autowired:""`
|
||||
clusterService cluster.IClusterService `autowired:""`
|
||||
aiAPIService ai_api.IAPIService `autowired:""`
|
||||
routerService api.IAPIService `autowired:""`
|
||||
serviceService service.IServiceService `autowired:""`
|
||||
settingService setting.ISettingService `autowired:""`
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) SyncLocalModels(ctx context.Context, address string) error {
|
||||
releases, err := i.getLocalModels(ctx, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
|
||||
list, err := i.localModelService.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.SimpleItem {
|
||||
return &ai_local_dto.SimpleItem{
|
||||
Id: s.Id,
|
||||
Name: s.Name,
|
||||
DefaultConfig: ai_provider_local.OllamaConfig,
|
||||
Logo: ai_provider_local.OllamaSvg,
|
||||
}
|
||||
}, func(l *ai_local.LocalModel) bool {
|
||||
if l.State != ai_local_dto.LocalModelStateNormal.Int() && l.State != ai_local_dto.LocalModelStateDisable.Int() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) {
|
||||
info, err := i.localModelStateService.Get(ctx, model)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
state := ai_local_dto.FromDeployState(info.State)
|
||||
return &state, &ai_local_dto.ModelInfo{
|
||||
Current: info.Complete,
|
||||
Total: info.Total,
|
||||
LastMessage: info.Msg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelItem, error) {
|
||||
list, err := i.localModelService.Search(ctx, keyword, nil, "update_at desc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apiCountMap, err := i.aiAPIService.CountMapByModel(ctx, "", map[string]interface{}{
|
||||
"type": 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.LocalModelItem {
|
||||
count := apiCountMap[s.Id]
|
||||
return &ai_local_dto.LocalModelItem{
|
||||
Id: s.Id,
|
||||
Name: s.Name,
|
||||
State: ai_local_dto.FromLocalModelState(s.State),
|
||||
APICount: count,
|
||||
CanDelete: count < 1 && s.State != ai_local_dto.LocalModelStateDeploying.Int(),
|
||||
UpdateTime: auto.TimeLabel(s.UpdateAt),
|
||||
Provider: "ollama",
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) ListCanInstall(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelPackageItem, error) {
|
||||
|
||||
if keyword == "" {
|
||||
list, err := i.localModelPackageService.Search(ctx, keyword, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utils.SliceToSlice(list, func(s *ai_local.LocalModelPackage) *ai_local_dto.LocalModelPackageItem {
|
||||
return &ai_local_dto.LocalModelPackageItem{
|
||||
Id: s.Id,
|
||||
Name: s.Name,
|
||||
Size: s.Size,
|
||||
IsPopular: s.IsPopular,
|
||||
}
|
||||
}), nil
|
||||
} else {
|
||||
info, err := i.localModelPackageService.Get(ctx, keyword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]*ai_local_dto.LocalModelPackageItem, 0)
|
||||
|
||||
//for _, v := range list {
|
||||
models := ai_provider_local.ModelsCanInstallById(info.Id)
|
||||
for _, model := range models {
|
||||
result = append(result, &ai_local_dto.LocalModelPackageItem{
|
||||
Id: model.Id,
|
||||
Name: model.Name,
|
||||
Size: model.Size,
|
||||
IsPopular: model.IsPopular,
|
||||
})
|
||||
}
|
||||
//}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.PullMessage) error {
|
||||
return func(msg ai_provider_local.PullMessage) error {
|
||||
return i.transaction.Transaction(context.Background(), func(ctx context.Context) error {
|
||||
|
||||
state := ai_local_dto.DeployStateFinish.Int()
|
||||
modelState := ai_local_dto.LocalModelStateNormal.Int()
|
||||
if msg.Status == "error" {
|
||||
state = ai_local_dto.DeployStateDownloadError.Int()
|
||||
modelState = ai_local_dto.LocalModelStateDeployingError.Int()
|
||||
}
|
||||
err := i.localModelService.Save(ctx, msg.Model, &ai_local.EditLocalModel{State: &modelState})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err := i.localModelStateService.Get(ctx, msg.Model)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.localModelStateService.Create(ctx, &ai_local.CreateLocalModelInstallState{
|
||||
Id: msg.Model,
|
||||
Complete: msg.Completed,
|
||||
Total: msg.Total,
|
||||
State: state,
|
||||
Msg: msg.Msg,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err = i.localModelStateService.Get(ctx, msg.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
if info.Complete < msg.Completed {
|
||||
info.Complete = msg.Completed
|
||||
|
||||
}
|
||||
if info.Total < msg.Total {
|
||||
info.Total = msg.Total
|
||||
}
|
||||
if msg.Msg != "" {
|
||||
info.Msg = msg.Msg
|
||||
}
|
||||
err = i.localModelStateService.Save(ctx, msg.Model, &ai_local.EditLocalModelInstallState{State: &state, Complete: &info.Complete, Total: &info.Total, Msg: &info.Msg})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
serviceState := 0
|
||||
if msg.Status == "error" {
|
||||
state = 2
|
||||
}
|
||||
list, err := i.localModelCacheService.List(ctx, msg.Model, ai_local.CacheTypeService)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, l := range list {
|
||||
serviceInfo, err := i.serviceService.Get(ctx, l.Target)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if serviceInfo.State == serviceState {
|
||||
continue
|
||||
}
|
||||
err = i.serviceService.Save(ctx, l.Target, &service.Edit{State: &serviceState})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if state == ai_local_dto.DeployStateFinish.Int() {
|
||||
for _, f := range fn {
|
||||
err = f()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
|
||||
cfg := make(map[string]interface{})
|
||||
cfg["provider"] = "ollama"
|
||||
cfg["model"] = msg.Model
|
||||
cfg["model_config"] = ai_provider_local.OllamaConfig
|
||||
cfg["priority"] = 0
|
||||
cfg["base"] = v
|
||||
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: msg.Model,
|
||||
Description: msg.Model,
|
||||
Resource: "ai-provider",
|
||||
Version: info.UpdateAt.Format("20060102150405"),
|
||||
MatchLabels: map[string]string{
|
||||
"module": "ai-provider",
|
||||
},
|
||||
},
|
||||
Attr: cfg,
|
||||
}}, true)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
|
||||
client, err := i.clusterService.GatewayClient(ctx, clusterId)
|
||||
if err != nil {
|
||||
log.Errorf("get apinto client error: %v", err)
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
err := client.Close(ctx)
|
||||
if err != nil {
|
||||
log.Warn("close apinto client:", err)
|
||||
}
|
||||
}()
|
||||
for _, releaseInfo := range releases {
|
||||
dynamicClient, err := client.Dynamic(releaseInfo.Resource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if online {
|
||||
err = dynamicClient.Online(ctx, releaseInfo)
|
||||
} else {
|
||||
err = dynamicClient.Offline(ctx, releaseInfo)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) Deploy(ctx context.Context, model string, session string, fn ...func() error) (*ai_provider_local.Pipeline, error) {
|
||||
var p *ai_provider_local.Pipeline
|
||||
err := i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
item, err := i.localModelCacheService.GetByTarget(ctx, ai_local.CacheTypeService, model)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
model = item.Model
|
||||
}
|
||||
info, err := i.localModelService.Get(ctx, model)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
err = i.localModelService.Create(ctx, &ai_local.CreateLocalModel{
|
||||
Id: model,
|
||||
Name: model,
|
||||
Provider: "ollama",
|
||||
State: ai_local_dto.LocalModelStateDeploying.Int(),
|
||||
})
|
||||
|
||||
} else {
|
||||
if info.State == ai_local_dto.LocalModelStateDeployingError.Int() {
|
||||
state := ai_local_dto.LocalModelStateDeploying.Int()
|
||||
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &state})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p, err = ai_provider_local.PullModel(model, session, i.pullHook(fn...))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) SaveCache(ctx context.Context, model string, target string) error {
|
||||
return i.localModelCacheService.Save(ctx, model, ai_local.CacheTypeService, target)
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) CancelDeploy(ctx context.Context, model string) error {
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
item, err := i.localModelCacheService.GetByTarget(ctx, ai_local.CacheTypeService, model)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
model = item.Model
|
||||
}
|
||||
list, err := i.localModelCacheService.List(ctx, model, ai_local.CacheTypeService)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, l := range list {
|
||||
info, err := i.serviceService.Get(ctx, l.Target)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if info.State == 0 {
|
||||
continue
|
||||
}
|
||||
err = i.serviceService.Delete(ctx, info.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = i.aiAPIService.DeleteByService(ctx, info.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = i.routerService.DeleteByService(ctx, info.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = i.localModelCacheService.Delete(ctx, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除模型
|
||||
err = i.localModelService.Delete(ctx, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai_provider_local.StopPull(model)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) RemoveModel(ctx context.Context, model string) error {
|
||||
// 判断是否有api
|
||||
count, err := i.aiAPIService.CountByModel(ctx, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return fmt.Errorf("model %s has api, can not remove", model)
|
||||
}
|
||||
info, err := i.localModelService.Get(ctx, model)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
return ai_provider_local.RemoveModel(model)
|
||||
}
|
||||
if info.State == ai_local_dto.LocalModelStateDeploying.Int() {
|
||||
return fmt.Errorf("model %s is deploying, can not remove", model)
|
||||
}
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
err = i.localModelService.Delete(ctx, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ai_provider_local.RemoveModel(model)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) Enable(ctx context.Context, model string) error {
|
||||
info, err := i.localModelService.Get(ctx, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() {
|
||||
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
status := ai_local_dto.LocalModelStateNormal.Int()
|
||||
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
cfg := make(map[string]interface{})
|
||||
cfg["provider"] = "ollama"
|
||||
cfg["model"] = info.Id
|
||||
cfg["model_config"] = ai_provider_local.OllamaConfig
|
||||
cfg["priority"] = 0
|
||||
cfg["base"] = v
|
||||
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: info.Id,
|
||||
Description: info.Id,
|
||||
Resource: "ai-provider",
|
||||
Version: info.UpdateAt.Format("20060102150405"),
|
||||
MatchLabels: map[string]string{
|
||||
"module": "ai-provider",
|
||||
},
|
||||
},
|
||||
Attr: cfg,
|
||||
}}, true)
|
||||
})
|
||||
|
||||
}
|
||||
return fmt.Errorf("model %s is not disabled state,can not enable", model)
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) Disable(ctx context.Context, model string) error {
|
||||
info, err := i.localModelService.Get(ctx, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.State == ai_local_dto.LocalModelStateNormal.Int() {
|
||||
disable := ai_local_dto.LocalModelStateDisable.Int()
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: info.Id,
|
||||
Resource: "ai-provider",
|
||||
},
|
||||
}}, false)
|
||||
})
|
||||
|
||||
}
|
||||
return fmt.Errorf("model %s is not enabled state,can not disable", model)
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) OnInit() {
|
||||
register.Handle(func(v server.Server) {
|
||||
ctx := context.Background()
|
||||
|
||||
list, err := i.localModelPackageService.List(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
oldModels := utils.SliceToMapO(list, func(s *ai_local.LocalModelPackage) (string, *ai_local.LocalModelPackage) {
|
||||
return s.Id, s
|
||||
})
|
||||
models, version := ai_provider_local.ModelsCanInstall()
|
||||
for _, model := range models {
|
||||
delete(oldModels, model.Id)
|
||||
if v, ok := oldModels[model.Id]; ok {
|
||||
if v.Version == version {
|
||||
continue
|
||||
}
|
||||
err = i.localModelPackageService.Save(ctx, model.Id, &ai_local.EditLocalModelPackage{
|
||||
Size: &model.Size,
|
||||
Hash: &model.Digest,
|
||||
Description: &model.Description,
|
||||
Version: &version,
|
||||
Popular: &model.IsPopular,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
err = i.localModelPackageService.Create(ctx, &ai_local.CreateLocalModelPackage{
|
||||
Id: model.Id,
|
||||
Name: model.Name,
|
||||
Size: model.Size,
|
||||
Hash: model.Digest,
|
||||
Description: model.Description,
|
||||
Version: version,
|
||||
Popular: model.IsPopular,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
for id := range oldModels {
|
||||
err = i.localModelPackageService.Delete(ctx, id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
installModels, err := ai_provider_local.ModelsInstalled()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, model := range installModels {
|
||||
|
||||
id := strings.TrimSuffix(model.Name, ":latest")
|
||||
name := strings.TrimSuffix(model.Name, ":latest")
|
||||
_, err = i.localModelService.Get(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return
|
||||
}
|
||||
err = i.localModelService.Create(ctx, &ai_local.CreateLocalModel{
|
||||
Id: id,
|
||||
Name: name,
|
||||
State: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
|
||||
list, err := i.localModelService.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v == "" {
|
||||
var has bool
|
||||
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return nil, errors.New("ollama_address not set")
|
||||
}
|
||||
}
|
||||
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(list))
|
||||
for _, l := range list {
|
||||
if l.State != ai_local_dto.LocalModelStateNormal.Int() {
|
||||
continue
|
||||
}
|
||||
cfg := make(map[string]interface{})
|
||||
cfg["provider"] = "ollama"
|
||||
cfg["model"] = l.Id
|
||||
cfg["model_config"] = ai_provider_local.OllamaConfig
|
||||
cfg["base"] = v
|
||||
releases = append(releases, &gateway.DynamicRelease{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: l.Id,
|
||||
Description: l.Name,
|
||||
Resource: "ai-provider",
|
||||
Version: l.UpdateAt.Format("20060102150405"),
|
||||
MatchLabels: map[string]string{
|
||||
"module": "ai-provider",
|
||||
},
|
||||
},
|
||||
Attr: cfg,
|
||||
})
|
||||
}
|
||||
return releases, nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
|
||||
releases, err := i.getLocalModels(ctx, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, p := range releases {
|
||||
client, err := clientDriver.Dynamic(p.Resource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = client.Online(ctx, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package ai_local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/APIParkLab/APIPark/gateway"
|
||||
|
||||
"github.com/eolinker/go-common/autowire"
|
||||
|
||||
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
|
||||
|
||||
ai_local_dto "github.com/APIParkLab/APIPark/module/ai-local/dto"
|
||||
)
|
||||
|
||||
type ILocalModelModule interface {
|
||||
Search(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelItem, error)
|
||||
ListCanInstall(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelPackageItem, error)
|
||||
Deploy(ctx context.Context, model string, session string, fn ...func() error) (*ai_provider_local.Pipeline, error)
|
||||
CancelDeploy(ctx context.Context, model string) error
|
||||
RemoveModel(ctx context.Context, model string) error
|
||||
Enable(ctx context.Context, model string) error
|
||||
Disable(ctx context.Context, model string) error
|
||||
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
|
||||
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
|
||||
SaveCache(ctx context.Context, model string, target string) error
|
||||
|
||||
SyncLocalModels(ctx context.Context, address string) error
|
||||
}
|
||||
|
||||
func init() {
|
||||
localModel := new(imlLocalModel)
|
||||
autowire.Auto[ILocalModelModule](func() reflect.Value {
|
||||
gateway.RegisterInitHandleFunc(localModel.initGateway)
|
||||
return reflect.ValueOf(localModel)
|
||||
})
|
||||
}
|
||||
+13
-13
@@ -13,15 +13,15 @@ type SimpleProvider struct {
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Config string `json:"config"`
|
||||
GetAPIKeyUrl string `json:"get_apikey_url"`
|
||||
DefaultLLM string `json:"default_llm"`
|
||||
DefaultLLMConfig string `json:"-"`
|
||||
Priority int `json:"priority"`
|
||||
Status ProviderStatus `json:"status"`
|
||||
Configured bool `json:"configured"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Config string `json:"config"`
|
||||
GetAPIKeyUrl string `json:"get_apikey_url"`
|
||||
DefaultLLM string `json:"default_llm"`
|
||||
DefaultLLMConfig string `json:"-"`
|
||||
//Priority int `json:"priority"`
|
||||
Status ProviderStatus `json:"status"`
|
||||
Configured bool `json:"configured"`
|
||||
}
|
||||
|
||||
type ConfiguredProviderItem struct {
|
||||
@@ -31,9 +31,8 @@ type ConfiguredProviderItem struct {
|
||||
DefaultLLM string `json:"default_llm"`
|
||||
Status ProviderStatus `json:"status"`
|
||||
APICount int64 `json:"api_count"`
|
||||
KeyCount int `json:"key_count"`
|
||||
KeyStatus []*KeyStatus `json:"keys"`
|
||||
Priority int `json:"priority"`
|
||||
KeyCount int64 `json:"key_count"`
|
||||
CanDelete bool `json:"can_delete"`
|
||||
}
|
||||
|
||||
type KeyStatus struct {
|
||||
@@ -59,13 +58,14 @@ type SimpleProviderItem struct {
|
||||
DefaultConfig string `json:"default_config"`
|
||||
Status ProviderStatus `json:"status"`
|
||||
Model *BasicInfo `json:"model,omitempty"`
|
||||
Priority int `json:"-"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type BackupProvider struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Model *BasicInfo `json:"model,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type LLMItem struct {
|
||||
|
||||
+241
-229
@@ -8,9 +8,18 @@ import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/service"
|
||||
"github.com/google/uuid"
|
||||
|
||||
ai_key_dto "github.com/APIParkLab/APIPark/module/ai-key/dto"
|
||||
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
|
||||
|
||||
"github.com/eolinker/go-common/register"
|
||||
"github.com/eolinker/go-common/server"
|
||||
|
||||
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
|
||||
|
||||
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/service"
|
||||
|
||||
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
|
||||
|
||||
@@ -54,11 +63,105 @@ func newKey(key *ai_key.Key) *gateway.DynamicRelease {
|
||||
var _ IProviderModule = (*imlProviderModule)(nil)
|
||||
|
||||
type imlProviderModule struct {
|
||||
providerService ai.IProviderService `autowired:""`
|
||||
clusterService cluster.IClusterService `autowired:""`
|
||||
aiAPIService ai_api.IAPIService `autowired:""`
|
||||
aiKeyService ai_key.IKeyService `autowired:""`
|
||||
transaction store.ITransaction `autowired:""`
|
||||
providerService ai.IProviderService `autowired:""`
|
||||
clusterService cluster.IClusterService `autowired:""`
|
||||
aiAPIService ai_api.IAPIService `autowired:""`
|
||||
aiKeyService ai_key.IKeyService `autowired:""`
|
||||
aiBalanceService ai_balance.IBalanceService `autowired:""`
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) OnInit() {
|
||||
register.Handle(func(v server.Server) {
|
||||
ctx := context.Background()
|
||||
|
||||
list, err := i.providerService.List(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
for _, l := range list {
|
||||
if l.Priority < 1 {
|
||||
continue
|
||||
}
|
||||
has, err := i.aiBalanceService.Exist(ctx, l.Id, l.DefaultLLM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if has {
|
||||
continue
|
||||
}
|
||||
|
||||
p, has := model_runtime.GetProvider(l.Id)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
err = i.aiBalanceService.Create(ctx, &ai_balance.Create{
|
||||
Id: uuid.NewString(),
|
||||
Priority: l.Priority,
|
||||
Provider: l.Id,
|
||||
ProviderName: p.Name(),
|
||||
Model: l.DefaultLLM,
|
||||
ModelName: l.DefaultLLM,
|
||||
Type: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
priority := 0
|
||||
err = i.providerService.Save(ctx, l.Id, &ai.SetProvider{
|
||||
Priority: &priority,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) Delete(ctx context.Context, id string) error {
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
// 判断是否有api
|
||||
count, err := i.aiAPIService.CountByProvider(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return fmt.Errorf("provider has api")
|
||||
}
|
||||
keys, err := i.aiKeyService.KeysByProvider(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = i.aiKeyService.DeleteByProvider(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.providerService.Delete(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
releases = append(releases, newKey(key))
|
||||
}
|
||||
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: id,
|
||||
Resource: "ai-provider",
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error) {
|
||||
@@ -75,83 +178,19 @@ func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error {
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
list, err := i.providerService.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
providerMap := utils.SliceToMap(list, func(e *ai.Provider) string {
|
||||
return e.Id
|
||||
})
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(list))
|
||||
offlineReleases := make([]*gateway.DynamicRelease, 0, len(list))
|
||||
for index, id := range input.Providers {
|
||||
p, has := model_runtime.GetProvider(id)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
|
||||
l, has := providerMap[id]
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
model, has := p.GetModel(l.DefaultLLM)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
priority := index + 1
|
||||
err = i.providerService.Save(txCtx, id, &ai.SetProvider{
|
||||
Priority: &priority,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ai_dto.ToProviderStatus(l.Status) == ai_dto.ProviderDisabled {
|
||||
offlineReleases = append(offlineReleases, &gateway.DynamicRelease{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: l.Id,
|
||||
Resource: "ai-provider",
|
||||
}})
|
||||
} else {
|
||||
cfg := make(map[string]interface{})
|
||||
cfg["provider"] = l.Id
|
||||
cfg["model"] = l.DefaultLLM
|
||||
cfg["model_config"] = model.DefaultConfig()
|
||||
cfg["priority"] = l.Priority
|
||||
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
releases = append(releases, &gateway.DynamicRelease{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: l.Id,
|
||||
Description: l.Name,
|
||||
Resource: "ai-provider",
|
||||
Version: l.UpdateAt.Format("20060102150405"),
|
||||
MatchLabels: map[string]string{
|
||||
"module": "ai-provider",
|
||||
},
|
||||
},
|
||||
Attr: cfg,
|
||||
})
|
||||
}
|
||||
}
|
||||
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, offlineReleases, false)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.ConfiguredProviderItem, *ai_dto.BackupProvider, error) {
|
||||
func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword string) ([]*ai_dto.ConfiguredProviderItem, error) {
|
||||
// 获取已配置的AI服务商
|
||||
list, err := i.providerService.List(ctx)
|
||||
list, err := i.providerService.Search(ctx, keyword, nil, "update_at")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get provider list error:%v", err)
|
||||
return nil, fmt.Errorf("get provider list error:%v", err)
|
||||
}
|
||||
aiAPIMap, err := i.aiAPIService.CountMapByProvider(ctx, "", nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get ai api count error:%v", err)
|
||||
return nil, fmt.Errorf("get ai api count error:%v", err)
|
||||
}
|
||||
keyMap, err := i.aiKeyService.CountMapByProvider(ctx, "", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ai key count error:%v", err)
|
||||
}
|
||||
providers := make([]*ai_dto.ConfiguredProviderItem, 0, len(list))
|
||||
for _, l := range list {
|
||||
@@ -159,7 +198,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
|
||||
_, err = i.aiKeyService.DefaultKey(ctx, l.Id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
err = i.aiKeyService.Create(ctx, &ai_key.Create{
|
||||
ID: l.Id,
|
||||
@@ -173,7 +212,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
|
||||
Default: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create default key error:%v", err)
|
||||
return nil, fmt.Errorf("create default key error:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,29 +220,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
keys, err := i.aiKeyService.KeysByProvider(ctx, l.Id)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get provider keys error:%v", err)
|
||||
}
|
||||
|
||||
keysStatus := make([]*ai_dto.KeyStatus, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
status := ai_key_dto.ToKeyStatus(k.Status)
|
||||
switch status {
|
||||
case ai_key_dto.KeyNormal, ai_key_dto.KeyDisable, ai_key_dto.KeyError:
|
||||
default:
|
||||
status = ai_key_dto.KeyError
|
||||
}
|
||||
keysStatus = append(keysStatus, &ai_dto.KeyStatus{
|
||||
Id: k.ID,
|
||||
Name: k.Name,
|
||||
Status: status.String(),
|
||||
Priority: k.Priority,
|
||||
})
|
||||
}
|
||||
sort.Slice(keysStatus, func(i, j int) bool {
|
||||
return keysStatus[i].Priority < keysStatus[j].Priority
|
||||
})
|
||||
apiCount := aiAPIMap[l.Id]
|
||||
|
||||
providers = append(providers, &ai_dto.ConfiguredProviderItem{
|
||||
Id: l.Id,
|
||||
@@ -211,35 +228,13 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
|
||||
Logo: p.Logo(),
|
||||
DefaultLLM: l.DefaultLLM,
|
||||
Status: ai_dto.ToProviderStatus(l.Status),
|
||||
APICount: aiAPIMap[l.Id],
|
||||
KeyCount: len(keysStatus),
|
||||
KeyStatus: keysStatus,
|
||||
Priority: l.Priority,
|
||||
APICount: apiCount,
|
||||
KeyCount: keyMap[l.Id],
|
||||
CanDelete: apiCount < 1,
|
||||
})
|
||||
}
|
||||
sort.Slice(providers, func(i, j int) bool {
|
||||
if providers[i].Priority != providers[j].Priority {
|
||||
if providers[i].Priority == 0 {
|
||||
return false
|
||||
}
|
||||
if providers[j].Priority == 0 {
|
||||
return true
|
||||
}
|
||||
return providers[i].Priority < providers[j].Priority
|
||||
}
|
||||
return providers[i].Name < providers[j].Name
|
||||
})
|
||||
var backup *ai_dto.BackupProvider
|
||||
for _, p := range providers {
|
||||
if p.Status == ai_dto.ProviderEnabled {
|
||||
backup = &ai_dto.BackupProvider{
|
||||
Id: p.Id,
|
||||
Name: p.Name,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return providers, backup, nil
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, error) {
|
||||
@@ -252,6 +247,7 @@ func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.Simp
|
||||
providerMap := utils.SliceToMap(list, func(e *ai.Provider) string {
|
||||
return e.Id
|
||||
})
|
||||
|
||||
items := make([]*ai_dto.SimpleProviderItem, 0, len(providers))
|
||||
for _, v := range providers {
|
||||
item := &ai_dto.SimpleProviderItem{
|
||||
@@ -264,31 +260,35 @@ func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.Simp
|
||||
if info, has := providerMap[v.ID()]; has {
|
||||
item.Configured = true
|
||||
item.Status = ai_dto.ToProviderStatus(info.Status)
|
||||
item.Priority = info.Priority
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
if items[i].Priority != items[j].Priority {
|
||||
if items[i].Priority == 0 {
|
||||
return false
|
||||
}
|
||||
if items[j].Priority == 0 {
|
||||
return true
|
||||
}
|
||||
return items[i].Priority < items[j].Priority
|
||||
}
|
||||
return items[i].Name < items[j].Name
|
||||
})
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error) {
|
||||
func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context, all bool) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error) {
|
||||
list, err := i.providerService.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items := make([]*ai_dto.SimpleProviderItem, 0, len(list))
|
||||
|
||||
healthProvider := make(map[string]struct{})
|
||||
if all {
|
||||
healthProvider["ollama"] = struct{}{}
|
||||
items = append(items, &ai_dto.SimpleProviderItem{
|
||||
Id: "ollama",
|
||||
Name: "Ollama",
|
||||
Logo: ai_provider_local.OllamaSvg,
|
||||
Configured: true,
|
||||
DefaultConfig: "",
|
||||
Status: ai_dto.ProviderEnabled,
|
||||
Type: "local",
|
||||
})
|
||||
}
|
||||
|
||||
var backup *ai_dto.BackupProvider
|
||||
for _, l := range list {
|
||||
p, has := model_runtime.GetProvider(l.Id)
|
||||
@@ -308,34 +308,32 @@ func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context) ([]*a
|
||||
Logo: p.Logo(),
|
||||
DefaultConfig: p.DefaultConfig(),
|
||||
Status: ai_dto.ToProviderStatus(l.Status),
|
||||
Priority: l.Priority,
|
||||
Configured: true,
|
||||
Model: &ai_dto.BasicInfo{
|
||||
Id: model.ID(),
|
||||
Name: model.ID(),
|
||||
},
|
||||
}
|
||||
|
||||
if item.Status == ai_dto.ProviderEnabled {
|
||||
healthProvider[l.Id] = struct{}{}
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
if items[i].Priority != items[j].Priority {
|
||||
if items[i].Priority == 0 {
|
||||
return false
|
||||
}
|
||||
if items[j].Priority == 0 {
|
||||
return true
|
||||
}
|
||||
return items[i].Priority < items[j].Priority
|
||||
}
|
||||
return items[i].Name < items[j].Name
|
||||
})
|
||||
for _, item := range items {
|
||||
if item.Status == ai_dto.ProviderEnabled {
|
||||
|
||||
aiBalanceItems, err := i.aiBalanceService.Search(ctx, "", nil, "priority asc")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, item := range aiBalanceItems {
|
||||
if _, has := healthProvider[item.Provider]; has {
|
||||
backup = &ai_dto.BackupProvider{
|
||||
Id: item.Id,
|
||||
Name: item.Name,
|
||||
Model: item.Model,
|
||||
Id: item.Provider,
|
||||
Name: item.Provider,
|
||||
Model: &ai_dto.BasicInfo{
|
||||
Id: item.Model,
|
||||
Name: item.Model,
|
||||
},
|
||||
Type: "local",
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -388,13 +386,7 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
|
||||
if !has {
|
||||
return nil, fmt.Errorf("ai provider not found")
|
||||
}
|
||||
maxPriority, err := i.providerService.MaxPriority(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
maxPriority = maxPriority + 1
|
||||
|
||||
info, err := i.providerService.Get(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -412,7 +404,7 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
|
||||
DefaultLLM: defaultLLM.ID(),
|
||||
DefaultLLMConfig: defaultLLM.Logo(),
|
||||
Status: ai_dto.ProviderDisabled,
|
||||
Priority: maxPriority,
|
||||
//Priority: maxPriority,
|
||||
}, nil
|
||||
}
|
||||
defaultLLM, has := p.GetModel(info.DefaultLLM)
|
||||
@@ -423,9 +415,6 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
|
||||
}
|
||||
defaultLLM = model
|
||||
}
|
||||
if info.Priority == 0 {
|
||||
info.Priority = maxPriority
|
||||
}
|
||||
|
||||
return &ai_dto.Provider{
|
||||
Id: info.Id,
|
||||
@@ -434,9 +423,9 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
|
||||
GetAPIKeyUrl: p.HelpUrl(),
|
||||
DefaultLLM: defaultLLM.ID(),
|
||||
DefaultLLMConfig: defaultLLM.DefaultConfig(),
|
||||
Priority: info.Priority,
|
||||
Status: ai_dto.ToProviderStatus(info.Status),
|
||||
Configured: true,
|
||||
//Priority: info.Priority,
|
||||
Status: ai_dto.ToProviderStatus(info.Status),
|
||||
Configured: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -492,38 +481,48 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider not found")
|
||||
}
|
||||
info, err := i.providerService.Get(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
info, err := i.providerService.Get(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
if input.DefaultLLM == "" {
|
||||
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider default llm not found")
|
||||
}
|
||||
input.DefaultLLM = defaultLLM.ID()
|
||||
}
|
||||
info = &ai.Provider{
|
||||
Id: id,
|
||||
Name: p.Name(),
|
||||
DefaultLLM: input.DefaultLLM,
|
||||
Config: input.Config,
|
||||
}
|
||||
err = i.providerService.Create(ctx, &ai.CreateProvider{
|
||||
Id: info.Id,
|
||||
Name: info.Name,
|
||||
DefaultLLM: input.DefaultLLM,
|
||||
Config: input.Config,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
model, has := p.GetModel(input.DefaultLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider model not found")
|
||||
}
|
||||
err = p.Check(input.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if input.DefaultLLM == "" {
|
||||
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider default llm not found")
|
||||
}
|
||||
input.DefaultLLM = defaultLLM.ID()
|
||||
input.Config, err = p.GenConfig(input.Config, info.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info = &ai.Provider{
|
||||
Id: id,
|
||||
Name: p.Name(),
|
||||
DefaultLLM: input.DefaultLLM,
|
||||
Config: input.Config,
|
||||
}
|
||||
}
|
||||
model, has := p.GetModel(input.DefaultLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider model not found")
|
||||
}
|
||||
err = p.Check(input.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input.Config, err = p.GenConfig(input.Config, info.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
status := 0
|
||||
if input.Enable != nil && *input.Enable {
|
||||
status = 1
|
||||
@@ -532,15 +531,14 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
Name: &info.Name,
|
||||
DefaultLLM: &input.DefaultLLM,
|
||||
Config: &input.Config,
|
||||
Priority: input.Priority,
|
||||
Status: &status,
|
||||
}
|
||||
_, err = i.aiKeyService.DefaultKey(txCtx, id)
|
||||
_, err = i.aiKeyService.DefaultKey(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
err = i.aiKeyService.Create(txCtx, &ai_key.Create{
|
||||
err = i.aiKeyService.Create(ctx, &ai_key.Create{
|
||||
ID: id,
|
||||
Name: info.Name,
|
||||
Config: input.Config,
|
||||
@@ -551,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
Priority: 1,
|
||||
})
|
||||
} else {
|
||||
err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{
|
||||
err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{
|
||||
Config: &input.Config,
|
||||
Status: &status,
|
||||
})
|
||||
@@ -559,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = i.providerService.Save(txCtx, id, pInfo)
|
||||
err = i.providerService.Save(ctx, id, pInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *pInfo.Status == 0 {
|
||||
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: id,
|
||||
@@ -575,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
}, false)
|
||||
}
|
||||
// 获取当前供应商默认Key信息
|
||||
defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id)
|
||||
defaultKey, err := i.aiKeyService.DefaultKey(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -583,9 +581,8 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
cfg["provider"] = info.Id
|
||||
cfg["model"] = info.DefaultLLM
|
||||
cfg["model_config"] = model.DefaultConfig()
|
||||
cfg["priority"] = info.Priority
|
||||
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: id,
|
||||
@@ -624,7 +621,6 @@ func (i *imlProviderModule) getAiProviders(ctx context.Context) ([]*gateway.Dyna
|
||||
cfg["provider"] = l.Id
|
||||
cfg["model"] = l.DefaultLLM
|
||||
cfg["model_config"] = model.DefaultConfig()
|
||||
cfg["priority"] = l.Priority
|
||||
providers = append(providers, &gateway.DynamicRelease{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: l.Id,
|
||||
@@ -694,16 +690,38 @@ func (i *imlProviderModule) syncGateway(ctx context.Context, clusterId string, r
|
||||
var _ IAIAPIModule = (*imlAIApiModule)(nil)
|
||||
|
||||
type imlAIApiModule struct {
|
||||
aiAPIService ai_api.IAPIService `autowired:""`
|
||||
aiAPIUseService ai_api.IAPIUseService `autowired:""`
|
||||
serviceService service.IServiceService `autowired:""`
|
||||
aiAPIService ai_api.IAPIService `autowired:""`
|
||||
aiAPIUseService ai_api.IAPIUseService `autowired:""`
|
||||
serviceService service.IServiceService `autowired:""`
|
||||
aiLocalModelService ai_local.ILocalModelService `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId string, start int64, end int64, page int, pageSize int, sortCondition string, asc bool, models []string, serviceIds []string) ([]*ai_dto.APIItem, *ai_dto.Condition, int64, error) {
|
||||
p, has := model_runtime.GetProvider(providerId)
|
||||
if !has {
|
||||
return nil, nil, 0, fmt.Errorf("ai provider not found")
|
||||
modelItems := make([]*ai_dto.BasicInfo, 0)
|
||||
if providerId == "ollama" {
|
||||
items, err := i.aiLocalModelService.Search(ctx, "", nil, "update_at desc")
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
modelItems = utils.SliceToSlice(items, func(e *ai_local.LocalModel) *ai_dto.BasicInfo {
|
||||
return &ai_dto.BasicInfo{
|
||||
Id: e.Id,
|
||||
Name: e.Name,
|
||||
}
|
||||
})
|
||||
} else {
|
||||
p, has := model_runtime.GetProvider(providerId)
|
||||
if !has {
|
||||
return nil, nil, 0, fmt.Errorf("ai provider not found")
|
||||
}
|
||||
modelItems = utils.SliceToSlice(p.Models(), func(e model_runtime.IModel) *ai_dto.BasicInfo {
|
||||
return &ai_dto.BasicInfo{
|
||||
Id: e.ID(),
|
||||
Name: e.ID(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
sortRule := "desc"
|
||||
if asc {
|
||||
sortRule = "asc"
|
||||
@@ -723,12 +741,6 @@ func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId st
|
||||
|
||||
}
|
||||
|
||||
modelItems := utils.SliceToSlice(p.Models(), func(e model_runtime.IModel) *ai_dto.BasicInfo {
|
||||
return &ai_dto.BasicInfo{
|
||||
Id: e.ID(),
|
||||
Name: e.ID(),
|
||||
}
|
||||
})
|
||||
condition := &ai_dto.Condition{Services: serviceItems, Models: modelItems}
|
||||
switch sortCondition {
|
||||
default:
|
||||
|
||||
+6
-5
@@ -10,23 +10,24 @@ import (
|
||||
)
|
||||
|
||||
type IProviderModule interface {
|
||||
ConfiguredProviders(ctx context.Context) ([]*ai_dto.ConfiguredProviderItem, *ai_dto.BackupProvider, error)
|
||||
ConfiguredProviders(ctx context.Context, keyword string) ([]*ai_dto.ConfiguredProviderItem, error)
|
||||
UnConfiguredProviders(ctx context.Context) ([]*ai_dto.ProviderItem, error)
|
||||
SimpleProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, error)
|
||||
SimpleConfiguredProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error)
|
||||
SimpleConfiguredProviders(ctx context.Context, all bool) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error)
|
||||
Provider(ctx context.Context, id string) (*ai_dto.Provider, error)
|
||||
SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error)
|
||||
LLMs(ctx context.Context, driver string) ([]*ai_dto.LLMItem, *ai_dto.ProviderItem, error)
|
||||
//UpdateProviderStatus(ctx context.Context, id string, enable bool) error
|
||||
UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error
|
||||
//UpdateProviderDefaultLLM(ctx context.Context, id string, input *ai_dto.UpdateLLM) error
|
||||
Sort(ctx context.Context, input *ai_dto.Sort) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
type IAIAPIModule interface {
|
||||
APIs(ctx context.Context, keyword string, providerId string, start int64, end int64, page int, pageSize int, sortCondition string, asc bool, models []string, services []string) ([]*ai_dto.APIItem, *ai_dto.Condition, int64, error)
|
||||
}
|
||||
|
||||
type ILocalModelModule interface {
|
||||
}
|
||||
|
||||
func init() {
|
||||
autowire.Auto[IProviderModule](func() reflect.Value {
|
||||
module := new(imlProviderModule)
|
||||
|
||||
@@ -2,9 +2,10 @@ package catalogue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/APIParkLab/APIPark/module/system"
|
||||
"reflect"
|
||||
|
||||
"github.com/APIParkLab/APIPark/module/system"
|
||||
|
||||
"github.com/eolinker/go-common/autowire"
|
||||
|
||||
catalogue_dto "github.com/APIParkLab/APIPark/module/catalogue/dto"
|
||||
@@ -28,6 +29,7 @@ type ICatalogueModule interface {
|
||||
// Subscribe 订阅服务
|
||||
Subscribe(ctx context.Context, subscribeInfo *catalogue_dto.SubscribeService) error
|
||||
Sort(ctx context.Context, sorts []*catalogue_dto.SortItem) error
|
||||
DefaultCatalogue(ctx context.Context) (*catalogue_dto.Catalogue, error)
|
||||
//ExportAll(ctx context.Context) ([]*catalogue_dto.ExportCatalogue, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +66,24 @@ type imlCatalogueModule struct {
|
||||
root *Root
|
||||
}
|
||||
|
||||
func (i *imlCatalogueModule) DefaultCatalogue(ctx context.Context) (*catalogue_dto.Catalogue, error) {
|
||||
catalogues, err := i.catalogueService.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range catalogues {
|
||||
if v.Parent == "" {
|
||||
return &catalogue_dto.Catalogue{
|
||||
Id: v.Id,
|
||||
Name: v.Name,
|
||||
Parent: v.Parent,
|
||||
Sort: v.Sort,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("no default catalogue")
|
||||
}
|
||||
|
||||
func (i *imlCatalogueModule) onlineSubscriber(ctx context.Context, clusterId string, sub *gateway.SubscribeRelease) error {
|
||||
client, err := i.clusterService.GatewayClient(ctx, clusterId)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package service_dto
|
||||
|
||||
type QuickCreateAIService struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Config string `json:"config"`
|
||||
Team string `json:"team"`
|
||||
}
|
||||
|
||||
type CreateService struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -11,7 +18,9 @@ type CreateService struct {
|
||||
Catalogue string `json:"catalogue"`
|
||||
ApprovalType string `json:"approval_type"`
|
||||
Kind string `json:"service_kind"`
|
||||
Provider *string `json:"provider" aocheck:"ai_provider"`
|
||||
State string `json:"state"`
|
||||
Provider *string `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
AsApp *bool `json:"as_app"`
|
||||
AsServer *bool `json:"as_server"`
|
||||
ModelMapping string `json:"model_mapping"`
|
||||
@@ -24,8 +33,10 @@ type EditService struct {
|
||||
Catalogue *string `json:"catalogue"`
|
||||
Logo *string `json:"logo"`
|
||||
Tags *[]string `json:"tags"`
|
||||
Provider *string `json:"provider" aocheck:"ai_provider"`
|
||||
Provider *string `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
ApprovalType *string `json:"approval_type"`
|
||||
State *string `json:"state"`
|
||||
ModelMapping string `json:"model_mapping"`
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,44 @@ import (
|
||||
"github.com/eolinker/go-common/auto"
|
||||
)
|
||||
|
||||
type ServiceState string
|
||||
|
||||
const (
|
||||
ServiceStateNormal ServiceState = "normal"
|
||||
ServiceStateDeploying ServiceState = "deploying"
|
||||
ServiceStateDeployError ServiceState = "error"
|
||||
)
|
||||
|
||||
func (s ServiceState) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func (s ServiceState) Int() int {
|
||||
switch s {
|
||||
case ServiceStateNormal:
|
||||
return 0
|
||||
case ServiceStateDeploying:
|
||||
return 1
|
||||
case ServiceStateDeployError:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func FromServiceState(s int) ServiceState {
|
||||
switch s {
|
||||
case 0:
|
||||
return ServiceStateNormal
|
||||
case 1:
|
||||
return ServiceStateDeploying
|
||||
case 2:
|
||||
return ServiceStateDeployError
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type ServiceItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -15,6 +53,7 @@ type ServiceItem struct {
|
||||
CreateTime auto.TimeLabel `json:"create_time"`
|
||||
UpdateTime auto.TimeLabel `json:"update_time"`
|
||||
Provider *auto.Label `json:"provider,omitempty" aolabel:"ai_provider"`
|
||||
State string `json:"state"`
|
||||
CanDelete bool `json:"can_delete"`
|
||||
}
|
||||
|
||||
@@ -58,10 +97,13 @@ type Service struct {
|
||||
Tags []auto.Label `json:"tags" aolabel:"tag"`
|
||||
Logo string `json:"logo"`
|
||||
Provider *auto.Label `json:"provider,omitempty" aolabel:"ai_provider"`
|
||||
ProviderType string `json:"provider_type,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
ApprovalType string `json:"approval_type"`
|
||||
AsServer bool `json:"as_server"`
|
||||
AsApp bool `json:"as_app"`
|
||||
ServiceKind string `json:"service_kind"`
|
||||
State string `json:"state"`
|
||||
ModelMapping string `json:"model_mapping"`
|
||||
}
|
||||
|
||||
@@ -80,6 +122,7 @@ func ToService(model *service.Service) *Service {
|
||||
if model.Prefix != "" {
|
||||
prefix = model.Prefix
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
Id: model.Id,
|
||||
Name: model.Name,
|
||||
@@ -96,10 +139,25 @@ func ToService(model *service.Service) *Service {
|
||||
AsApp: model.AsApp,
|
||||
ServiceKind: model.Kind.String(),
|
||||
}
|
||||
state := FromServiceState(model.State)
|
||||
if state == ServiceStateNormal {
|
||||
s.State = model.ServiceType.String()
|
||||
} else {
|
||||
s.State = state.String()
|
||||
}
|
||||
|
||||
switch model.Kind {
|
||||
case service.AIService:
|
||||
provider := auto.UUID(model.AdditionalConfig["provider"])
|
||||
s.Provider = &provider
|
||||
s.ProviderType = "local"
|
||||
if provider.Id != "ollama" {
|
||||
s.ProviderType = "online"
|
||||
}
|
||||
modelId := model.AdditionalConfig["model"]
|
||||
if modelId != "" {
|
||||
s.Model = modelId
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
+60
-7
@@ -8,6 +8,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
|
||||
|
||||
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
|
||||
"github.com/APIParkLab/APIPark/resources/access"
|
||||
@@ -59,6 +63,7 @@ type imlServiceModule struct {
|
||||
teamService team.ITeamService `autowired:""`
|
||||
teamMemberService team_member.ITeamMemberService `autowired:""`
|
||||
tagService tag.ITagService `autowired:""`
|
||||
localModelService ai_local.ILocalModelService `autowired:""`
|
||||
|
||||
serviceTagService service_tag.ITagService `autowired:""`
|
||||
apiService api.IAPIService `autowired:""`
|
||||
@@ -126,7 +131,7 @@ func (i *imlServiceModule) searchMyServices(ctx context.Context, teamId string,
|
||||
return nil, err
|
||||
}
|
||||
condition["team"] = teamId
|
||||
return i.serviceService.Search(ctx, keyword, condition, "update_at desc")
|
||||
return i.serviceService.Search(ctx, keyword, condition, "create_at desc")
|
||||
} else {
|
||||
membersForUser, err := i.teamMemberService.FilterMembersForUser(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -134,7 +139,7 @@ func (i *imlServiceModule) searchMyServices(ctx context.Context, teamId string,
|
||||
}
|
||||
teamIds := membersForUser[userID]
|
||||
condition["team"] = teamIds
|
||||
return i.serviceService.Search(ctx, keyword, condition, "update_at desc")
|
||||
return i.serviceService.Search(ctx, keyword, condition, "create_at desc")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,6 +225,25 @@ func (i *imlServiceModule) Get(ctx context.Context, id string) (*service_dto.Ser
|
||||
s.Tags = auto.List(utils.SliceToSlice(tags, func(p *service_tag.Tag) string {
|
||||
return p.Tid
|
||||
}))
|
||||
if s.Model == "" {
|
||||
switch s.ProviderType {
|
||||
case "online":
|
||||
p, has := model_runtime.GetProvider(s.Provider.Id)
|
||||
if has {
|
||||
m, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
||||
if has {
|
||||
s.Model = m.ID()
|
||||
}
|
||||
}
|
||||
case "local":
|
||||
info, err := i.localModelService.DefaultModel(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Model = info.Id
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
serviceModelMapping, err := i.serviceModelMappingService.GetByService(ctx, id)
|
||||
if err != nil {
|
||||
@@ -239,9 +263,9 @@ func (i *imlServiceModule) Search(ctx context.Context, teamID string, keyword st
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"team": teamID, "as_server": true}, "update_at desc")
|
||||
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"team": teamID, "as_server": true}, "create_at desc")
|
||||
} else {
|
||||
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"as_server": true}, "update_at desc")
|
||||
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"as_server": true}, "create_at desc")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -277,8 +301,16 @@ func toServiceItem(model *service.Service) *service_dto.ServiceItem {
|
||||
Team: auto.UUID(model.Team),
|
||||
ServiceKind: model.Kind.String(),
|
||||
}
|
||||
state := service_dto.FromServiceState(model.State)
|
||||
if state == service_dto.ServiceStateNormal {
|
||||
item.State = model.ServiceType.String()
|
||||
} else {
|
||||
item.State = state.String()
|
||||
}
|
||||
|
||||
switch model.Kind {
|
||||
case service.RestService:
|
||||
item.State = model.ServiceType.String()
|
||||
return item
|
||||
case service.AIService:
|
||||
provider := auto.UUID(model.AdditionalConfig["provider"])
|
||||
@@ -293,6 +325,13 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser
|
||||
if input.Id == "" {
|
||||
input.Id = uuid.New().String()
|
||||
}
|
||||
if teamID == "" {
|
||||
item, err := i.teamService.DefaultTeam(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
teamID = item.Id
|
||||
}
|
||||
mo := &service.Create{
|
||||
Id: input.Id,
|
||||
Name: input.Name,
|
||||
@@ -302,6 +341,7 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser
|
||||
Catalogue: input.Catalogue,
|
||||
Prefix: input.Prefix,
|
||||
Logo: input.Logo,
|
||||
State: service_dto.ServiceState(input.State).Int(),
|
||||
ApprovalType: service.ApprovalType(input.ApprovalType),
|
||||
AdditionalConfig: make(map[string]string),
|
||||
Kind: service.Kind(input.Kind),
|
||||
@@ -315,6 +355,11 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser
|
||||
return nil, fmt.Errorf("ai service: provider can not be empty")
|
||||
}
|
||||
mo.AdditionalConfig["provider"] = *input.Provider
|
||||
if input.Model == nil {
|
||||
return nil, fmt.Errorf("ai service: model can not be empty")
|
||||
}
|
||||
mo.AdditionalConfig["model"] = *input.Model
|
||||
|
||||
}
|
||||
if input.AsApp == nil {
|
||||
// 默认值为false
|
||||
@@ -374,6 +419,9 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d
|
||||
if input.Provider != nil {
|
||||
info.AdditionalConfig["provider"] = *input.Provider
|
||||
}
|
||||
if input.Model != nil {
|
||||
info.AdditionalConfig["model"] = *input.Model
|
||||
}
|
||||
}
|
||||
err = i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
serviceType := (*service.ServiceType)(input.ServiceType)
|
||||
@@ -386,8 +434,7 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d
|
||||
if input.ApprovalType != nil {
|
||||
approvalType = service.ApprovalType(*input.ApprovalType)
|
||||
}
|
||||
|
||||
err = i.serviceService.Save(ctx, id, &service.Edit{
|
||||
editCfg := &service.Edit{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Logo: input.Logo,
|
||||
@@ -395,7 +442,13 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d
|
||||
Catalogue: input.Catalogue,
|
||||
AdditionalConfig: &info.AdditionalConfig,
|
||||
ApprovalType: &approvalType,
|
||||
})
|
||||
}
|
||||
if input.State != nil {
|
||||
state := service_dto.ServiceState(*input.State).Int()
|
||||
editCfg.State = &state
|
||||
}
|
||||
|
||||
err = i.serviceService.Save(ctx, id, editCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,14 +6,24 @@ import (
|
||||
)
|
||||
|
||||
type InputSetting struct {
|
||||
InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"`
|
||||
SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"`
|
||||
InvokeAddress *string `json:"invoke_address" key:"system.node.invoke_address"`
|
||||
SitePrefix *string `json:"site_prefix" key:"system.setting.site_prefix"`
|
||||
OllamaAddress *string `json:"ollama_address" key:"system.ai_model.ollama_address"`
|
||||
}
|
||||
|
||||
func (i *InputSetting) Validate() error {
|
||||
_, err := url.Parse(i.InvokeAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
if i.InvokeAddress != nil {
|
||||
_, err := url.Parse(*i.InvokeAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if i.OllamaAddress != nil {
|
||||
_, err := url.Parse(*i.OllamaAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -31,9 +41,18 @@ func ToKeyMap(i interface{}) map[string]string {
|
||||
{
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
if f.Tag.Get("key") != "" {
|
||||
result[f.Tag.Get("key")] = val.Field(i).String()
|
||||
v := val.Field(i)
|
||||
if f.Type.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
continue
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
if f.Tag.Get("key") != "" {
|
||||
result[f.Tag.Get("key")] = v.String()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
)
|
||||
|
||||
func TestMap(t *testing.T) {
|
||||
|
||||
invokeAddress := "http://127.0.0.1:8080"
|
||||
ollamaAddress := "http://127.0.0.1:8081"
|
||||
input := &InputSetting{
|
||||
InvokeAddress: "http://127.0.0.1:8080",
|
||||
InvokeAddress: &invokeAddress,
|
||||
OllamaAddress: &ollamaAddress,
|
||||
}
|
||||
err := input.Validate()
|
||||
if err != nil {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
type Setting struct {
|
||||
InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"`
|
||||
SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"`
|
||||
OllamaAddress string `json:"ollama_address" key:"system.ai_model.ollama_address"`
|
||||
}
|
||||
|
||||
func MapStringToStruct[T any](m map[string]string) *T {
|
||||
|
||||
@@ -3,6 +3,11 @@ package system
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/eolinker/go-common/server"
|
||||
|
||||
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
|
||||
"github.com/eolinker/go-common/register"
|
||||
|
||||
"github.com/eolinker/go-common/store"
|
||||
|
||||
"github.com/eolinker/go-common/utils"
|
||||
@@ -43,6 +48,21 @@ func (i *imlSettingModule) Set(ctx context.Context, input *system_dto.InputSetti
|
||||
return err
|
||||
}
|
||||
}
|
||||
if input.OllamaAddress != nil {
|
||||
ai_provider_local.ResetOllamaAddress(*input.OllamaAddress)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlSettingModule) OnInit() {
|
||||
register.Handle(func(v server.Server) {
|
||||
ctx := context.Background()
|
||||
|
||||
address, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if has {
|
||||
ai_provider_local.ResetOllamaAddress(address)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user