Fix: Failure to update local model configuration to gateway

This commit is contained in:
Liujian
2025-02-20 14:24:39 +08:00
parent dcf18705da
commit 6eaa946be6
8 changed files with 259 additions and 38 deletions
+96 -6
View File
@@ -43,6 +43,14 @@ type imlBalanceModule struct {
transaction store.ITransaction `autowired:""`
}
func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error {
releases, err := i.getLocalBalances(ctx, address)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
}
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
has, err := i.balanceService.Exist(ctx, input.Provider, input.Model)
if err != nil {
@@ -63,6 +71,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
}
providerName := ""
modelName := ""
base := ""
switch input.Type {
case ai_balance_dto.ModelTypeOnline:
p, has := model_runtime.GetProvider(input.Provider)
@@ -71,14 +80,16 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
}
providerName = p.Name()
modelName = input.Model
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
case ai_balance_dto.ModelTypeLocal:
input.Provider = "ollama"
providerName = "Ollama"
modelName = input.Model
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
}
base = v
}
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
@@ -98,7 +109,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, v)}, true)
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true)
})
}
@@ -143,7 +154,16 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, item := range list {
releases = append(releases, newRelease(item, v))
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
if err != nil {
@@ -237,3 +257,73 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re
return nil
}
func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc")
if err != nil {
return nil, err
}
if v == "" {
var has bool
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, fmt.Errorf("ollama address not found")
}
}
releases := make([]*gateway.DynamicRelease, 0, len(balances))
for _, item := range balances {
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
return releases, nil
}
func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) {
balances, err := i.balanceService.Search(ctx, "", nil, "priority asc")
if err != nil {
return nil, err
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, fmt.Errorf("ollama address not found")
}
releases := make([]*gateway.DynamicRelease, 0, len(balances))
for _, item := range balances {
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
return releases, nil
}
func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
releases, err := i.getBalances(ctx)
if err != nil {
return err
}
for _, p := range releases {
client, err := clientDriver.Dynamic(p.Resource)
if err != nil {
return err
}
err = client.Online(ctx, p)
if err != nil {
return err
}
}
return nil
}
+6 -1
View File
@@ -4,6 +4,8 @@ import (
"context"
"reflect"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/go-common/autowire"
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
@@ -14,10 +16,13 @@ type IBalanceModule interface {
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
Delete(ctx context.Context, id string) error
SyncLocalBalances(ctx context.Context, address string) error
}
func init() {
balanceModule := new(imlBalanceModule)
autowire.Auto[IBalanceModule](func() reflect.Value {
return reflect.ValueOf(new(imlBalanceModule))
gateway.RegisterInitHandleFunc(balanceModule.initGateway)
return reflect.ValueOf(balanceModule)
})
}
+75 -8
View File
@@ -6,6 +6,8 @@ import (
"fmt"
"strings"
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
"github.com/APIParkLab/APIPark/service/setting"
"github.com/APIParkLab/APIPark/service/api"
@@ -45,6 +47,7 @@ type imlLocalModel struct {
localModelPackageService ai_local.ILocalModelPackageService `autowired:""`
localModelStateService ai_local.ILocalModelInstallStateService `autowired:""`
localModelCacheService ai_local.ILocalModelCacheService `autowired:""`
balanceService ai_balance.IBalanceService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
routerService api.IAPIService `autowired:""`
@@ -53,6 +56,14 @@ type imlLocalModel struct {
transaction store.ITransaction `autowired:""`
}
func (i *imlLocalModel) SyncLocalModels(ctx context.Context, address string) error {
releases, err := i.getLocalModels(ctx, address)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
}
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
@@ -177,6 +188,13 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.
State: state,
Msg: msg.Msg,
})
if err != nil {
return err
}
info, err = i.localModelStateService.Get(ctx, msg.Model)
if err != nil {
return err
}
} else {
if info.Complete < msg.Completed {
@@ -422,8 +440,36 @@ func (i *imlLocalModel) Enable(ctx context.Context, model string) error {
return err
}
if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() {
status := ai_local_dto.LocalModelStateNormal.Int()
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
status := ai_local_dto.LocalModelStateNormal.Int()
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
if err != nil {
return err
}
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = info.Id
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["priority"] = 0
cfg["base"] = v
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Description: info.Id,
Resource: "ai-provider",
Version: info.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
}}, true)
})
}
return fmt.Errorf("model %s is not disabled state,can not enable", model)
}
@@ -435,7 +481,21 @@ func (i *imlLocalModel) Disable(ctx context.Context, model string) error {
}
if info.State == ai_local_dto.LocalModelStateNormal.Int() {
disable := ai_local_dto.LocalModelStateDisable.Int()
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Resource: "ai-provider",
},
}}, false)
})
}
return fmt.Errorf("model %s is not enabled state,can not disable", model)
}
@@ -515,17 +575,24 @@ func (i *imlLocalModel) OnInit() {
})
}
func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicRelease, error) {
func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
return nil, err
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, errors.New("ollama_address not set")
if v == "" {
var has bool
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, errors.New("ollama_address not set")
}
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, l := range list {
if l.State != ai_local_dto.LocalModelStateNormal.Int() {
continue
}
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = l.Id
@@ -548,7 +615,7 @@ func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicR
}
func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
releases, err := i.getLocalModels(ctx)
releases, err := i.getLocalModels(ctx, "")
if err != nil {
return err
}
+2
View File
@@ -24,6 +24,8 @@ type ILocalModelModule interface {
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
SaveCache(ctx context.Context, model string, target string) error
SyncLocalModels(ctx context.Context, address string) error
}
func init() {
+8 -8
View File
@@ -482,7 +482,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
return fmt.Errorf("ai provider not found")
}
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
info, err := i.providerService.Get(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
@@ -533,12 +533,12 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Config: &input.Config,
Status: &status,
}
_, err = i.aiKeyService.DefaultKey(txCtx, id)
_, err = i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
err = i.aiKeyService.Create(txCtx, &ai_key.Create{
err = i.aiKeyService.Create(ctx, &ai_key.Create{
ID: id,
Name: info.Name,
Config: input.Config,
@@ -549,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Priority: 1,
})
} else {
err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{
err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{
Config: &input.Config,
Status: &status,
})
@@ -557,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
if err != nil {
return err
}
err = i.providerService.Save(txCtx, id, pInfo)
err = i.providerService.Save(ctx, id, pInfo)
if err != nil {
return err
}
if *pInfo.Status == 0 {
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
@@ -573,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
}, false)
}
// 获取当前供应商默认Key信息
defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id)
defaultKey, err := i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
return err
}
@@ -582,7 +582,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
cfg["model"] = info.DefaultLLM
cfg["model_config"] = model.DefaultConfig()
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,