mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-12 18:11:34 +08:00
Fix: Failure to update local model configuration to gateway
This commit is contained in:
@@ -43,6 +43,14 @@ type imlBalanceModule struct {
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error {
|
||||
releases, err := i.getLocalBalances(ctx, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
|
||||
has, err := i.balanceService.Exist(ctx, input.Provider, input.Model)
|
||||
if err != nil {
|
||||
@@ -63,6 +71,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
|
||||
}
|
||||
providerName := ""
|
||||
modelName := ""
|
||||
base := ""
|
||||
switch input.Type {
|
||||
case ai_balance_dto.ModelTypeOnline:
|
||||
p, has := model_runtime.GetProvider(input.Provider)
|
||||
@@ -71,14 +80,16 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
|
||||
}
|
||||
providerName = p.Name()
|
||||
modelName = input.Model
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
case ai_balance_dto.ModelTypeLocal:
|
||||
input.Provider = "ollama"
|
||||
providerName = "Ollama"
|
||||
modelName = input.Model
|
||||
}
|
||||
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return fmt.Errorf("ollama address not found")
|
||||
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return fmt.Errorf("ollama address not found")
|
||||
}
|
||||
base = v
|
||||
}
|
||||
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
@@ -98,7 +109,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, v)}, true)
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true)
|
||||
})
|
||||
|
||||
}
|
||||
@@ -143,7 +154,16 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
|
||||
}
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(list))
|
||||
for _, item := range list {
|
||||
releases = append(releases, newRelease(item, v))
|
||||
base := v
|
||||
if item.Provider != "ollama" {
|
||||
p, has := model_runtime.GetProvider(item.Provider)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
}
|
||||
|
||||
releases = append(releases, newRelease(item, base))
|
||||
}
|
||||
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
|
||||
if err != nil {
|
||||
@@ -237,3 +257,73 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
|
||||
balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v == "" {
|
||||
var has bool
|
||||
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return nil, fmt.Errorf("ollama address not found")
|
||||
}
|
||||
}
|
||||
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(balances))
|
||||
for _, item := range balances {
|
||||
base := v
|
||||
if item.Provider != "ollama" {
|
||||
p, has := model_runtime.GetProvider(item.Provider)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
}
|
||||
releases = append(releases, newRelease(item, base))
|
||||
}
|
||||
return releases, nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) {
|
||||
balances, err := i.balanceService.Search(ctx, "", nil, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return nil, fmt.Errorf("ollama address not found")
|
||||
}
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(balances))
|
||||
for _, item := range balances {
|
||||
base := v
|
||||
if item.Provider != "ollama" {
|
||||
p, has := model_runtime.GetProvider(item.Provider)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
}
|
||||
releases = append(releases, newRelease(item, base))
|
||||
}
|
||||
return releases, nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
|
||||
releases, err := i.getBalances(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, p := range releases {
|
||||
client, err := clientDriver.Dynamic(p.Resource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = client.Online(ctx, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/APIParkLab/APIPark/gateway"
|
||||
|
||||
"github.com/eolinker/go-common/autowire"
|
||||
|
||||
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
|
||||
@@ -14,10 +16,13 @@ type IBalanceModule interface {
|
||||
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
|
||||
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
SyncLocalBalances(ctx context.Context, address string) error
|
||||
}
|
||||
|
||||
func init() {
|
||||
balanceModule := new(imlBalanceModule)
|
||||
autowire.Auto[IBalanceModule](func() reflect.Value {
|
||||
return reflect.ValueOf(new(imlBalanceModule))
|
||||
gateway.RegisterInitHandleFunc(balanceModule.initGateway)
|
||||
return reflect.ValueOf(balanceModule)
|
||||
})
|
||||
}
|
||||
|
||||
+75
-8
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/setting"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/api"
|
||||
@@ -45,6 +47,7 @@ type imlLocalModel struct {
|
||||
localModelPackageService ai_local.ILocalModelPackageService `autowired:""`
|
||||
localModelStateService ai_local.ILocalModelInstallStateService `autowired:""`
|
||||
localModelCacheService ai_local.ILocalModelCacheService `autowired:""`
|
||||
balanceService ai_balance.IBalanceService `autowired:""`
|
||||
clusterService cluster.IClusterService `autowired:""`
|
||||
aiAPIService ai_api.IAPIService `autowired:""`
|
||||
routerService api.IAPIService `autowired:""`
|
||||
@@ -53,6 +56,14 @@ type imlLocalModel struct {
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) SyncLocalModels(ctx context.Context, address string) error {
|
||||
releases, err := i.getLocalModels(ctx, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
|
||||
list, err := i.localModelService.List(ctx)
|
||||
if err != nil {
|
||||
@@ -177,6 +188,13 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.
|
||||
State: state,
|
||||
Msg: msg.Msg,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err = i.localModelStateService.Get(ctx, msg.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
if info.Complete < msg.Completed {
|
||||
@@ -422,8 +440,36 @@ func (i *imlLocalModel) Enable(ctx context.Context, model string) error {
|
||||
return err
|
||||
}
|
||||
if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() {
|
||||
status := ai_local_dto.LocalModelStateNormal.Int()
|
||||
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
|
||||
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
status := ai_local_dto.LocalModelStateNormal.Int()
|
||||
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
cfg := make(map[string]interface{})
|
||||
cfg["provider"] = "ollama"
|
||||
cfg["model"] = info.Id
|
||||
cfg["model_config"] = ai_provider_local.OllamaConfig
|
||||
cfg["priority"] = 0
|
||||
cfg["base"] = v
|
||||
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: info.Id,
|
||||
Description: info.Id,
|
||||
Resource: "ai-provider",
|
||||
Version: info.UpdateAt.Format("20060102150405"),
|
||||
MatchLabels: map[string]string{
|
||||
"module": "ai-provider",
|
||||
},
|
||||
},
|
||||
Attr: cfg,
|
||||
}}, true)
|
||||
})
|
||||
|
||||
}
|
||||
return fmt.Errorf("model %s is not disabled state,can not enable", model)
|
||||
}
|
||||
@@ -435,7 +481,21 @@ func (i *imlLocalModel) Disable(ctx context.Context, model string) error {
|
||||
}
|
||||
if info.State == ai_local_dto.LocalModelStateNormal.Int() {
|
||||
disable := ai_local_dto.LocalModelStateDisable.Int()
|
||||
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: info.Id,
|
||||
Resource: "ai-provider",
|
||||
},
|
||||
}}, false)
|
||||
})
|
||||
|
||||
}
|
||||
return fmt.Errorf("model %s is not enabled state,can not disable", model)
|
||||
}
|
||||
@@ -515,17 +575,24 @@ func (i *imlLocalModel) OnInit() {
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicRelease, error) {
|
||||
func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
|
||||
list, err := i.localModelService.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return nil, errors.New("ollama_address not set")
|
||||
if v == "" {
|
||||
var has bool
|
||||
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
|
||||
if !has {
|
||||
return nil, errors.New("ollama_address not set")
|
||||
}
|
||||
}
|
||||
|
||||
releases := make([]*gateway.DynamicRelease, 0, len(list))
|
||||
for _, l := range list {
|
||||
if l.State != ai_local_dto.LocalModelStateNormal.Int() {
|
||||
continue
|
||||
}
|
||||
cfg := make(map[string]interface{})
|
||||
cfg["provider"] = "ollama"
|
||||
cfg["model"] = l.Id
|
||||
@@ -548,7 +615,7 @@ func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicR
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
|
||||
releases, err := i.getLocalModels(ctx)
|
||||
releases, err := i.getLocalModels(ctx, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@ type ILocalModelModule interface {
|
||||
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
|
||||
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
|
||||
SaveCache(ctx context.Context, model string, target string) error
|
||||
|
||||
SyncLocalModels(ctx context.Context, address string) error
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
+8
-8
@@ -482,7 +482,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
return fmt.Errorf("ai provider not found")
|
||||
}
|
||||
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
info, err := i.providerService.Get(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -533,12 +533,12 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
Config: &input.Config,
|
||||
Status: &status,
|
||||
}
|
||||
_, err = i.aiKeyService.DefaultKey(txCtx, id)
|
||||
_, err = i.aiKeyService.DefaultKey(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
err = i.aiKeyService.Create(txCtx, &ai_key.Create{
|
||||
err = i.aiKeyService.Create(ctx, &ai_key.Create{
|
||||
ID: id,
|
||||
Name: info.Name,
|
||||
Config: input.Config,
|
||||
@@ -549,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
Priority: 1,
|
||||
})
|
||||
} else {
|
||||
err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{
|
||||
err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{
|
||||
Config: &input.Config,
|
||||
Status: &status,
|
||||
})
|
||||
@@ -557,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = i.providerService.Save(txCtx, id, pInfo)
|
||||
err = i.providerService.Save(ctx, id, pInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *pInfo.Status == 0 {
|
||||
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: id,
|
||||
@@ -573,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
}, false)
|
||||
}
|
||||
// 获取当前供应商默认Key信息
|
||||
defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id)
|
||||
defaultKey, err := i.aiKeyService.DefaultKey(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -582,7 +582,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
cfg["model"] = info.DefaultLLM
|
||||
cfg["model_config"] = model.DefaultConfig()
|
||||
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
|
||||
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: id,
|
||||
|
||||
Reference in New Issue
Block a user