Fix: Failure to update local model configuration to gateway

This commit is contained in:
Liujian
2025-02-20 14:24:39 +08:00
parent dcf18705da
commit 6eaa946be6
8 changed files with 259 additions and 38 deletions
+34 -3
View File
@@ -7,11 +7,15 @@ import (
"io"
"math"
"net/http"
"net/url"
"strings"
"time"
system_dto "github.com/APIParkLab/APIPark/module/system/dto"
ai_balance "github.com/APIParkLab/APIPark/module/ai-balance"
"github.com/APIParkLab/APIPark/module/system"
system_dto "github.com/APIParkLab/APIPark/module/system/dto"
ollama_api "github.com/ollama/ollama/api"
"github.com/APIParkLab/APIPark/module/subscribe"
subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto"
@@ -51,6 +55,7 @@ type imlLocalModelController struct {
serviceModule service.IServiceModule `autowired:""`
catalogueModule catalogue.ICatalogueModule `autowired:""`
aiAPIModule ai_api.IAPIModule `autowired:""`
aiBalanceModule ai_balance.IBalanceModule `autowired:""`
appModule service.IAppModule `autowired:""`
routerModule router.IRouterModule `autowired:""`
subscribeModule subscribe.ISubscribeModule `autowired:""`
@@ -66,9 +71,35 @@ func (i *imlLocalModelController) OllamaConfig(ctx *gin.Context) (*ai_local_dto.
}, nil
}
// client is the shared HTTP client passed to the Ollama API client when
// probing a configured address; each request it makes is limited to two
// seconds.
var client = &http.Client{Timeout: 2 * time.Second}
func (i *imlLocalModelController) OllamaConfigUpdate(ctx *gin.Context, input *ai_local_dto.OllamaConfig) error {
return i.settingModule.Set(ctx, &system_dto.InputSetting{
OllamaAddress: &input.Address,
u, err := url.Parse(input.Address)
if err != nil {
return nil
}
ollamaClient := ollama_api.NewClient(u, client)
_, err = ollamaClient.Version(ctx)
if err != nil {
return err
}
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.module.SyncLocalModels(ctx, input.Address)
if err != nil {
return err
}
err = i.aiBalanceModule.SyncLocalBalances(ctx, input.Address)
if err != nil {
return err
}
return i.settingModule.Set(ctx, &system_dto.InputSetting{
OllamaAddress: &input.Address,
})
})
}
+12 -3
View File
@@ -477,13 +477,13 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se
}
var err error
var info *service_dto.Service
err = i.transaction.Transaction(ctx, func(txCtx context.Context) error {
info, err = i.module.Create(txCtx, teamID, input)
err = i.transaction.Transaction(ctx, func(ctx context.Context) error {
info, err = i.module.Create(ctx, teamID, input)
if err != nil {
return err
}
path := fmt.Sprintf("/%s/", strings.Trim(input.Prefix, "/"))
_, err = i.routerModule.Create(txCtx, info.Id, &router_dto.Create{
_, err = i.routerModule.Create(ctx, info.Id, &router_dto.Create{
Id: uuid.New().String(),
Name: "",
Path: path + "*",
@@ -499,6 +499,15 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se
},
Disable: false,
})
apps, err := i.appModule.Search(ctx, teamID, "")
if err != nil {
return err
}
for _, app := range apps {
i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{
Application: app.Id,
})
}
return err
})
return info, err
+26 -9
View File
@@ -10,6 +10,8 @@ import (
"strings"
"time"
subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto"
"github.com/eolinker/eosc/log"
ai_dto "github.com/APIParkLab/APIPark/module/ai/dto"
@@ -222,6 +224,7 @@ type imlInitController struct {
applicationAuthorizationModule application_authorization.IAuthorizationModule `autowired:""`
catalogueModule catalogue.ICatalogueModule `autowired:""`
providerModule ai.IProviderModule `autowired:""`
subscribeModule subscribe.ISubscribeModule `autowired:""`
transaction store.ITransaction `autowired:""`
aiAPIModule ai_api.IAPIModule `autowired:""`
docModule service.IServiceDocModule `autowired:""`
@@ -264,6 +267,13 @@ func (i *imlInitController) OnInit() {
if err != nil {
return fmt.Errorf("create default team error: %v", err)
}
app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{
Name: "Demo Application",
Description: "Auto created By APIPark",
})
if err != nil {
return fmt.Errorf("create default app error: %v", err)
}
// 创建Rest服务
restPath := "/rest-demo"
serviceInfo, err := i.serviceModule.Create(ctx, info.Id, &service_dto.CreateService{
@@ -298,6 +308,13 @@ func (i *imlInitController) OnInit() {
if err != nil {
return fmt.Errorf("create default router error: %v", err)
}
err = i.subscribeModule.AddSubscriber(ctx, serviceInfo.Id, &subscribe_dto.AddSubscriber{
Application: app.Id,
})
if err != nil {
return err
}
// 创建AI服务
err = i.createAIService(ctx, info.Id, &service_dto.CreateService{
Name: "AI Demo Service",
@@ -307,17 +324,11 @@ func (i *imlInitController) OnInit() {
Catalogue: catalogueId,
ApprovalType: "auto",
Kind: "ai",
})
}, app.Id)
if err != nil {
return err
}
app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{
Name: "Demo Application",
Description: "Auto created By APIPark",
})
if err != nil {
return fmt.Errorf("create default app error: %v", err)
}
_, err = i.applicationAuthorizationModule.AddAuthorization(ctx, app.Id, &application_authorization_dto.CreateAuthorization{
Name: "Default API Key",
Driver: "apikey",
@@ -338,7 +349,7 @@ func (i *imlInitController) OnInit() {
}
})
}
func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService) error {
func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService, appId string) error {
providerId := "fakegpt"
err := i.providerModule.UpdateProviderConfig(ctx, providerId, &ai_dto.UpdateConfig{
@@ -469,6 +480,12 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string,
if err != nil {
return err
}
err = i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{
Application: appId,
})
if err != nil {
return err
}
return i.docModule.SaveServiceDoc(ctx, info.Id, &service_dto.SaveServiceDoc{
Doc: "The Translation API allows developers to translate text from one language to another. It supports multiple languages and enables easy integration of high-quality translation features into applications. With simple API requests, you can quickly translate content into different target languages.",
+96 -6
View File
@@ -43,6 +43,14 @@ type imlBalanceModule struct {
transaction store.ITransaction `autowired:""`
}
// SyncLocalBalances builds the gateway releases for the balances of the
// "ollama" provider, using address as their base (getLocalBalances falls back
// to the stored setting when address is empty), and pushes them to the
// default cluster.
func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error {
	localReleases, buildErr := i.getLocalBalances(ctx, address)
	if buildErr != nil {
		return buildErr
	}

	// The trailing true is forwarded unchanged to syncGateway.
	return i.syncGateway(ctx, cluster.DefaultClusterID, localReleases, true)
}
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
has, err := i.balanceService.Exist(ctx, input.Provider, input.Model)
if err != nil {
@@ -63,6 +71,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
}
providerName := ""
modelName := ""
base := ""
switch input.Type {
case ai_balance_dto.ModelTypeOnline:
p, has := model_runtime.GetProvider(input.Provider)
@@ -71,14 +80,16 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
}
providerName = p.Name()
modelName = input.Model
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
case ai_balance_dto.ModelTypeLocal:
input.Provider = "ollama"
providerName = "Ollama"
modelName = input.Model
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
}
base = v
}
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
@@ -98,7 +109,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, v)}, true)
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true)
})
}
@@ -143,7 +154,16 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, item := range list {
releases = append(releases, newRelease(item, v))
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
if err != nil {
@@ -237,3 +257,73 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re
return nil
}
// getLocalBalances searches the balances filtered by provider "ollama"
// (ordered by "priority asc") and converts them into gateway releases.
//
// v is the Ollama base address; when it is empty the value stored under
// "system.ai_model.ollama_address" is used, and an error is returned if that
// setting is missing. Any returned item whose provider is not "ollama" gets
// the base URL of its registered provider instead, and is skipped when the
// provider is unknown.
func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
	filter := map[string]interface{}{"provider": "ollama"}
	balances, err := i.balanceService.Search(ctx, "", filter, "priority asc")
	if err != nil {
		return nil, err
	}

	if v == "" {
		stored, ok := i.settingService.Get(ctx, "system.ai_model.ollama_address")
		if !ok {
			return nil, fmt.Errorf("ollama address not found")
		}
		v = stored
	}

	result := make([]*gateway.DynamicRelease, 0, len(balances))
	for _, balance := range balances {
		if balance.Provider == "ollama" {
			result = append(result, newRelease(balance, v))
			continue
		}

		provider, ok := model_runtime.GetProvider(balance.Provider)
		if !ok {
			continue
		}
		providerBase := fmt.Sprintf("%s://%s", provider.URI().Scheme(), provider.URI().Host())
		result = append(result, newRelease(balance, providerBase))
	}
	return result, nil
}
// getBalances loads every stored balance ordered by "priority asc" and
// converts it into a gateway release.
//
// Balances of the "ollama" provider use the address stored under
// "system.ai_model.ollama_address" as their base; every other balance uses the
// scheme and host of its registered provider and is skipped when the provider
// is unknown.
//
// The Ollama address is looked up only when an "ollama" balance is actually
// present, so a missing setting no longer blocks syncing balances of online
// providers. This mirrors Create, which only requires the address for local
// models. An error is returned when the search fails or when an "ollama"
// balance exists but the address setting is missing.
func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) {
	balances, err := i.balanceService.Search(ctx, "", nil, "priority asc")
	if err != nil {
		return nil, err
	}

	var (
		ollamaAddress string
		ollamaLoaded  bool
	)
	releases := make([]*gateway.DynamicRelease, 0, len(balances))
	for _, item := range balances {
		if item.Provider == "ollama" {
			// Resolve the setting once, on the first local balance.
			if !ollamaLoaded {
				v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
				if !has {
					return nil, fmt.Errorf("ollama address not found")
				}
				ollamaAddress, ollamaLoaded = v, true
			}
			releases = append(releases, newRelease(item, ollamaAddress))
			continue
		}

		p, has := model_runtime.GetProvider(item.Provider)
		if !has {
			continue
		}
		base := fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
		releases = append(releases, newRelease(item, base))
	}
	return releases, nil
}
// initGateway publishes every balance release to the gateway through
// clientDriver: for each release it obtains the dynamic client for the
// release's resource and brings the release online. It stops at the first
// error. clusterId is not used by this implementation.
func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
	allReleases, err := i.getBalances(ctx)
	if err != nil {
		return err
	}

	for _, release := range allReleases {
		dynamicClient, err := clientDriver.Dynamic(release.Resource)
		if err != nil {
			return err
		}
		if err = dynamicClient.Online(ctx, release); err != nil {
			return err
		}
	}
	return nil
}
+6 -1
View File
@@ -4,6 +4,8 @@ import (
"context"
"reflect"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/go-common/autowire"
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
@@ -14,10 +16,13 @@ type IBalanceModule interface {
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
Delete(ctx context.Context, id string) error
SyncLocalBalances(ctx context.Context, address string) error
}
// init registers the balance module with the autowire container and hooks its
// initGateway method into gateway initialization. A single instance is created
// up front so that the handler registered with the gateway and the value
// handed to autowire refer to the same module.
func init() {
	balanceModule := new(imlBalanceModule)
	autowire.Auto[IBalanceModule](func() reflect.Value {
		// The registration must come before the return; a leftover early
		// return here would make it unreachable and hand autowire a second,
		// unrelated instance.
		gateway.RegisterInitHandleFunc(balanceModule.initGateway)
		return reflect.ValueOf(balanceModule)
	})
}
+75 -8
View File
@@ -6,6 +6,8 @@ import (
"fmt"
"strings"
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
"github.com/APIParkLab/APIPark/service/setting"
"github.com/APIParkLab/APIPark/service/api"
@@ -45,6 +47,7 @@ type imlLocalModel struct {
localModelPackageService ai_local.ILocalModelPackageService `autowired:""`
localModelStateService ai_local.ILocalModelInstallStateService `autowired:""`
localModelCacheService ai_local.ILocalModelCacheService `autowired:""`
balanceService ai_balance.IBalanceService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
routerService api.IAPIService `autowired:""`
@@ -53,6 +56,14 @@ type imlLocalModel struct {
transaction store.ITransaction `autowired:""`
}
// SyncLocalModels builds the gateway releases for the local models via
// getLocalModels, passing address through as the Ollama address, and pushes
// them to the default cluster.
func (i *imlLocalModel) SyncLocalModels(ctx context.Context, address string) error {
	modelReleases, buildErr := i.getLocalModels(ctx, address)
	if buildErr != nil {
		return buildErr
	}

	// The trailing true is forwarded unchanged to syncGateway.
	return i.syncGateway(ctx, cluster.DefaultClusterID, modelReleases, true)
}
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
@@ -177,6 +188,13 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.
State: state,
Msg: msg.Msg,
})
if err != nil {
return err
}
info, err = i.localModelStateService.Get(ctx, msg.Model)
if err != nil {
return err
}
} else {
if info.Complete < msg.Completed {
@@ -422,8 +440,36 @@ func (i *imlLocalModel) Enable(ctx context.Context, model string) error {
return err
}
if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() {
status := ai_local_dto.LocalModelStateNormal.Int()
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
status := ai_local_dto.LocalModelStateNormal.Int()
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
if err != nil {
return err
}
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = info.Id
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["priority"] = 0
cfg["base"] = v
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Description: info.Id,
Resource: "ai-provider",
Version: info.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
}}, true)
})
}
return fmt.Errorf("model %s is not disabled state,can not enable", model)
}
@@ -435,7 +481,21 @@ func (i *imlLocalModel) Disable(ctx context.Context, model string) error {
}
if info.State == ai_local_dto.LocalModelStateNormal.Int() {
disable := ai_local_dto.LocalModelStateDisable.Int()
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Resource: "ai-provider",
},
}}, false)
})
}
return fmt.Errorf("model %s is not enabled state,can not disable", model)
}
@@ -515,17 +575,24 @@ func (i *imlLocalModel) OnInit() {
})
}
func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicRelease, error) {
func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
return nil, err
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, errors.New("ollama_address not set")
if v == "" {
var has bool
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, errors.New("ollama_address not set")
}
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, l := range list {
if l.State != ai_local_dto.LocalModelStateNormal.Int() {
continue
}
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = l.Id
@@ -548,7 +615,7 @@ func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicR
}
func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
releases, err := i.getLocalModels(ctx)
releases, err := i.getLocalModels(ctx, "")
if err != nil {
return err
}
+2
View File
@@ -24,6 +24,8 @@ type ILocalModelModule interface {
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
SaveCache(ctx context.Context, model string, target string) error
SyncLocalModels(ctx context.Context, address string) error
}
func init() {
+8 -8
View File
@@ -482,7 +482,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
return fmt.Errorf("ai provider not found")
}
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
info, err := i.providerService.Get(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
@@ -533,12 +533,12 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Config: &input.Config,
Status: &status,
}
_, err = i.aiKeyService.DefaultKey(txCtx, id)
_, err = i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
err = i.aiKeyService.Create(txCtx, &ai_key.Create{
err = i.aiKeyService.Create(ctx, &ai_key.Create{
ID: id,
Name: info.Name,
Config: input.Config,
@@ -549,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Priority: 1,
})
} else {
err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{
err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{
Config: &input.Config,
Status: &status,
})
@@ -557,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
if err != nil {
return err
}
err = i.providerService.Save(txCtx, id, pInfo)
err = i.providerService.Save(ctx, id, pInfo)
if err != nil {
return err
}
if *pInfo.Status == 0 {
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
@@ -573,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
}, false)
}
// 获取当前供应商默认Key信息
defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id)
defaultKey, err := i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
return err
}
@@ -582,7 +582,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
cfg["model"] = info.DefaultLLM
cfg["model_config"] = model.DefaultConfig()
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,