Merge pull request #211 from APIParkLab/feature/1.5-local-model

This pull request includes several changes aimed at improving error handling, adding new functionalities, and refactoring existing code. The changes primarily focus on the AI provider and controller modules.

Error Handling Improvements:
Added checks to ensure the client is initialized before performing operations in multiple functions (PullModel, StopPull, CancelPipeline, RemoveModel, ModelsInstalled) in ai-provider/local/executor.go. [1] [2] [3] [4]
New Functionalities:
Introduced OllamaConfig and OllamaConfigUpdate methods to the ILocalModelController interface and implemented them in controller/ai-local/iml.go. These methods allow for getting and updating the Ollama configuration. [1] [2]
Added functionality to automatically subscribe all applications to new services in the Create method of controller/service/iml.go.
Refactoring:
Refactored the initialization of the Ollama client by replacing the static address with a ResetOllamaAddress function in ai-provider/local/local.go.
Removed unused code and imports, such as the newAIUpstream function and upstream_dto import in controller/service/iml.go. [1] [2]
Codebase Simplification:
Simplified the OnInit method in controller/system/iml.go by consolidating the creation of default entities and adding subscription logic. [1] [2] [3]
Additional Changes:
Added new imports and modules to support the new functionalities and refactoring efforts. [1] [2] [3]
This commit is contained in:
Dot.L
2025-02-20 14:31:31 +08:00
committed by GitHub
23 changed files with 489 additions and 143 deletions
+12
View File
@@ -210,6 +210,9 @@ func (e *AsyncExecutor) DistributeToModelPipelines(model string, msg PullMessage
type PullCallback func(msg PullMessage) error
func PullModel(model string, id string, fn PullCallback) (*Pipeline, error) {
if client == nil {
return nil, fmt.Errorf("client not initialized")
}
mp, has := taskExecutor.GetModelPipeline(model)
if !has {
mp = newModelPipeline(taskExecutor.ctx, 100)
@@ -279,6 +282,9 @@ func PullModel(model string, id string, fn PullCallback) (*Pipeline, error) {
}
func StopPull(model string) {
if client == nil {
return
}
taskExecutor.CloseModelPipeline(model)
}
@@ -287,6 +293,9 @@ func CancelPipeline(model string, id string) {
}
func RemoveModel(model string) error {
if client == nil {
return fmt.Errorf("client not initialized")
}
taskExecutor.CloseModelPipeline(model)
err := client.Delete(context.Background(), &api.DeleteRequest{Model: model})
if err != nil {
@@ -298,6 +307,9 @@ func RemoveModel(model string) error {
}
func ModelsInstalled() ([]Model, error) {
if client == nil {
return nil, fmt.Errorf("client not initialized")
}
result, err := client.List(context.Background())
if err != nil {
return nil, err
+4 -13
View File
@@ -4,27 +4,18 @@ import (
"net/http"
"net/url"
"github.com/eolinker/eosc/env"
"github.com/ollama/ollama/api"
)
var (
ollamaAddress = "http://127.0.0.1:11434"
EnvOllamaAddress = "OLLAMA_ADDRESS"
client *api.Client
client *api.Client
)
func init() {
address, has := env.GetEnv(EnvOllamaAddress)
if !has {
address = ollamaAddress
}
func ResetOllamaAddress(address string) error {
u, err := url.Parse(address)
if err != nil {
u, err = url.Parse(ollamaAddress)
if err != nil {
panic(err)
}
return err
}
client = api.NewClient(u, http.DefaultClient)
return nil
}
-1
View File
@@ -1,7 +1,6 @@
package ai_provider_local
var (
OllamaBase = "http://apipark-ollama:11434"
OllamaConfig = "{\n \"mirostat\": 0,\n \"mirostat_eta\": 0.1,\n \"mirostat_tau\": 5.0,\n \"num_ctx\": 4096,\n \"repeat_last_n\":64,\n \"repeat_penalty\": 1.1,\n \"temperature\": 0.7,\n \"seed\": 42,\n \"num_predict\": 42,\n \"top_k\": 40,\n \"top_p\": 0.9,\n \"min_p\": 0.5\n}\n"
OllamaSvg = `<?xml version="1.0" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN"
+2
View File
@@ -18,6 +18,8 @@ type ILocalModelController interface {
Update(ctx *gin.Context, model string, input *ai_local_dto.Update) error
State(ctx *gin.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
SimpleList(ctx *gin.Context) ([]*ai_local_dto.SimpleItem, error)
OllamaConfig(ctx *gin.Context) (*ai_local_dto.OllamaConfig, error)
OllamaConfigUpdate(ctx *gin.Context, input *ai_local_dto.OllamaConfig) error
}
func init() {
+50
View File
@@ -7,7 +7,15 @@ import (
"io"
"math"
"net/http"
"net/url"
"strings"
"time"
ai_balance "github.com/APIParkLab/APIPark/module/ai-balance"
"github.com/APIParkLab/APIPark/module/system"
system_dto "github.com/APIParkLab/APIPark/module/system/dto"
ollama_api "github.com/ollama/ollama/api"
"github.com/APIParkLab/APIPark/module/subscribe"
subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto"
@@ -47,13 +55,54 @@ type imlLocalModelController struct {
serviceModule service.IServiceModule `autowired:""`
catalogueModule catalogue.ICatalogueModule `autowired:""`
aiAPIModule ai_api.IAPIModule `autowired:""`
aiBalanceModule ai_balance.IBalanceModule `autowired:""`
appModule service.IAppModule `autowired:""`
routerModule router.IRouterModule `autowired:""`
subscribeModule subscribe.ISubscribeModule `autowired:""`
docModule service.IServiceDocModule `autowired:""`
settingModule system.ISettingModule `autowired:""`
transaction store.ITransaction `autowired:""`
}
func (i *imlLocalModelController) OllamaConfig(ctx *gin.Context) (*ai_local_dto.OllamaConfig, error) {
cfg := i.settingModule.Get(ctx)
return &ai_local_dto.OllamaConfig{
Address: cfg.OllamaAddress,
}, nil
}
var (
client = &http.Client{
Timeout: 2 * time.Second,
}
)
func (i *imlLocalModelController) OllamaConfigUpdate(ctx *gin.Context, input *ai_local_dto.OllamaConfig) error {
u, err := url.Parse(input.Address)
if err != nil {
return nil
}
ollamaClient := ollama_api.NewClient(u, client)
_, err = ollamaClient.Version(ctx)
if err != nil {
return err
}
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.module.SyncLocalModels(ctx, input.Address)
if err != nil {
return err
}
err = i.aiBalanceModule.SyncLocalBalances(ctx, input.Address)
if err != nil {
return err
}
return i.settingModule.Set(ctx, &system_dto.InputSetting{
OllamaAddress: &input.Address,
})
})
}
func (i *imlLocalModelController) SimpleList(ctx *gin.Context) ([]*ai_local_dto.SimpleItem, error) {
return i.module.SimpleList(ctx)
}
@@ -209,6 +258,7 @@ func (i *imlLocalModelController) initAILocalService(ctx context.Context, model
ApprovalType: "auto",
Kind: "ai",
Provider: &providerId,
Model: &model,
})
if err != nil {
return err
+44 -40
View File
@@ -32,8 +32,6 @@ import (
api_doc "github.com/APIParkLab/APIPark/module/api-doc"
upstream_dto "github.com/APIParkLab/APIPark/module/upstream/dto"
"github.com/eolinker/eosc/log"
application_authorization "github.com/APIParkLab/APIPark/module/application-authorization"
@@ -94,7 +92,10 @@ func (i *imlServiceController) QuickCreateAIService(ctx *gin.Context, input *ser
if err != nil {
return err
}
p, err := i.providerModule.Provider(ctx, input.Provider)
if err != nil {
return err
}
id := uuid.NewString()
prefix := fmt.Sprintf("/%s", id[:8])
catalogueInfo, err := i.catalogueModule.DefaultCatalogue(ctx)
@@ -111,6 +112,7 @@ func (i *imlServiceController) QuickCreateAIService(ctx *gin.Context, input *ser
Catalogue: catalogueInfo.Id,
ApprovalType: "auto",
Provider: &input.Provider,
Model: &p.DefaultLLM,
Kind: "ai",
})
return err
@@ -294,25 +296,18 @@ func (i *imlServiceController) editAIService(ctx *gin.Context, id string, input
if input.Provider == nil {
return nil, fmt.Errorf("provider is required")
}
p, has := model_runtime.GetProvider(*input.Provider)
if !has {
return nil, fmt.Errorf("provider not found")
}
info, err := i.module.Get(ctx, id)
if err != nil {
}
err = i.transaction.Transaction(ctx, func(txCtx context.Context) error {
info, err = i.module.Edit(ctx, id, input)
if err != nil {
return err
if *input.Provider != "ollama" {
_, has := model_runtime.GetProvider(*input.Provider)
if !has {
return nil, fmt.Errorf("provider not found")
}
_, err = i.upstreamModule.Save(ctx, id, newAIUpstream(id, *input.Provider, p.URI()))
return err
})
}
info, err := i.module.Edit(ctx, id, input)
if err != nil {
return nil, err
}
//_, err = i.upstreamModule.Save(ctx, id, newAIUpstream(id, *input.Provider, p.URI()))
return info, nil
}
@@ -482,13 +477,13 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se
}
var err error
var info *service_dto.Service
err = i.transaction.Transaction(ctx, func(txCtx context.Context) error {
info, err = i.module.Create(txCtx, teamID, input)
err = i.transaction.Transaction(ctx, func(ctx context.Context) error {
info, err = i.module.Create(ctx, teamID, input)
if err != nil {
return err
}
path := fmt.Sprintf("/%s/", strings.Trim(input.Prefix, "/"))
_, err = i.routerModule.Create(txCtx, info.Id, &router_dto.Create{
_, err = i.routerModule.Create(ctx, info.Id, &router_dto.Create{
Id: uuid.New().String(),
Name: "",
Path: path + "*",
@@ -504,6 +499,15 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se
},
Disable: false,
})
apps, err := i.appModule.Search(ctx, teamID, "")
if err != nil {
return err
}
for _, app := range apps {
i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{
Application: app.Id,
})
}
return err
})
return info, err
@@ -590,22 +594,22 @@ func (i *imlAppController) DeleteApp(ctx *gin.Context, appId string) error {
return i.module.DeleteApp(ctx, appId)
}
func newAIUpstream(id string, provider string, uri model_runtime.IProviderURI) *upstream_dto.Upstream {
return &upstream_dto.Upstream{
Type: "http",
Balance: "round-robin",
Timeout: 300000,
Retry: 0,
Remark: fmt.Sprintf("auto create by ai service %s,provider is %s", id, provider),
LimitPeerSecond: 0,
ProxyHeaders: nil,
Scheme: uri.Scheme(),
PassHost: "node",
Nodes: []*upstream_dto.NodeConfig{
{
Address: uri.Host(),
Weight: 100,
},
},
}
}
//func newAIUpstream(id string, provider string, uri model_runtime.IProviderURI) *upstream_dto.Upstream {
// return &upstream_dto.Upstream{
// Type: "http",
// Balance: "round-robin",
// Timeout: 300000,
// Retry: 0,
// Remark: fmt.Sprintf("auto create by ai service %s,provider is %s", id, provider),
// LimitPeerSecond: 0,
// ProxyHeaders: nil,
// Scheme: uri.Scheme(),
// PassHost: "node",
// Nodes: []*upstream_dto.NodeConfig{
// {
// Address: uri.Host(),
// Weight: 100,
// },
// },
// }
//}
+33 -10
View File
@@ -10,6 +10,8 @@ import (
"strings"
"time"
subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto"
"github.com/eolinker/eosc/log"
ai_dto "github.com/APIParkLab/APIPark/module/ai/dto"
@@ -222,6 +224,7 @@ type imlInitController struct {
applicationAuthorizationModule application_authorization.IAuthorizationModule `autowired:""`
catalogueModule catalogue.ICatalogueModule `autowired:""`
providerModule ai.IProviderModule `autowired:""`
subscribeModule subscribe.ISubscribeModule `autowired:""`
transaction store.ITransaction `autowired:""`
aiAPIModule ai_api.IAPIModule `autowired:""`
docModule service.IServiceDocModule `autowired:""`
@@ -248,7 +251,7 @@ func (i *imlInitController) OnInit() {
if len(items) == 0 {
err = i.catalogueModule.Create(ctx, &catalogue_dto.CreateCatalogue{
Id: catalogueId,
Name: "Default Catalogue",
Name: "Default Category",
})
if err != nil {
return fmt.Errorf("create default catalogue error: %v", err)
@@ -264,6 +267,13 @@ func (i *imlInitController) OnInit() {
if err != nil {
return fmt.Errorf("create default team error: %v", err)
}
app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{
Name: "Demo Application",
Description: "Auto created By APIPark",
})
if err != nil {
return fmt.Errorf("create default app error: %v", err)
}
// 创建Rest服务
restPath := "/rest-demo"
serviceInfo, err := i.serviceModule.Create(ctx, info.Id, &service_dto.CreateService{
@@ -298,6 +308,13 @@ func (i *imlInitController) OnInit() {
if err != nil {
return fmt.Errorf("create default router error: %v", err)
}
err = i.subscribeModule.AddSubscriber(ctx, serviceInfo.Id, &subscribe_dto.AddSubscriber{
Application: app.Id,
})
if err != nil {
return err
}
// 创建AI服务
err = i.createAIService(ctx, info.Id, &service_dto.CreateService{
Name: "AI Demo Service",
@@ -307,17 +324,11 @@ func (i *imlInitController) OnInit() {
Catalogue: catalogueId,
ApprovalType: "auto",
Kind: "ai",
})
}, app.Id)
if err != nil {
return err
}
app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{
Name: "Demo Application",
Description: "Auto created By APIPark",
})
if err != nil {
return fmt.Errorf("create default app error: %v", err)
}
_, err = i.applicationAuthorizationModule.AddAuthorization(ctx, app.Id, &application_authorization_dto.CreateAuthorization{
Name: "Default API Key",
Driver: "apikey",
@@ -338,7 +349,7 @@ func (i *imlInitController) OnInit() {
}
})
}
func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService) error {
func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService, appId string) error {
providerId := "fakegpt"
err := i.providerModule.UpdateProviderConfig(ctx, providerId, &ai_dto.UpdateConfig{
@@ -351,6 +362,12 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string,
if input.Id == "" {
input.Id = uuid.New().String()
}
providerInfo, err := i.providerModule.Provider(ctx, *input.Provider)
if err != nil {
return err
}
input.Model = &providerInfo.DefaultLLM
if input.Prefix == "" {
if len(input.Id) < 9 {
input.Prefix = input.Id
@@ -463,6 +480,12 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string,
if err != nil {
return err
}
err = i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{
Application: appId,
})
if err != nil {
return err
}
return i.docModule.SaveServiceDoc(ctx, info.Id, &service_dto.SaveServiceDoc{
Doc: "The Translation API allows developers to translate text from one language to another. It supports multiple languages and enables easy integration of high-quality translation features into applications. With simple API requests, you can quickly translate content into different target languages.",
+108 -10
View File
@@ -6,6 +6,8 @@ import (
"fmt"
"sort"
"github.com/APIParkLab/APIPark/service/setting"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
@@ -37,9 +39,18 @@ type imlBalanceModule struct {
aiAPIService ai_api.IAPIService `autowired:""`
aiKeyService ai_key.IKeyService `autowired:""`
balanceService ai_balance.IBalanceService `autowired:""`
settingService setting.ISettingService `autowired:""`
transaction store.ITransaction `autowired:""`
}
func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error {
releases, err := i.getLocalBalances(ctx, address)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
}
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
has, err := i.balanceService.Exist(ctx, input.Provider, input.Model)
if err != nil {
@@ -60,6 +71,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
}
providerName := ""
modelName := ""
base := ""
switch input.Type {
case ai_balance_dto.ModelTypeOnline:
p, has := model_runtime.GetProvider(input.Provider)
@@ -68,11 +80,18 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
}
providerName = p.Name()
modelName = input.Model
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
case ai_balance_dto.ModelTypeLocal:
input.Provider = "ollama"
providerName = "Ollama"
modelName = input.Model
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
}
base = v
}
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.balanceService.Create(ctx, &ai_balance.Create{
Id: input.Id,
@@ -90,23 +109,19 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item)}, true)
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true)
})
}
//var (
// ollamaConfig = "{\n \"mirostat\": 0,\n \"mirostat_eta\": 0.1,\n \"mirostat_tau\": 5.0,\n \"num_ctx\": 4096,\n \"repeat_last_n\":64,\n \"repeat_penalty\": 1.1,\n \"temperature\": 0.7,\n \"seed\": 42,\n \"num_predict\": 42,\n \"top_k\": 40,\n \"top_p\": 0.9,\n \"min_p\": 0.5\n}\n"
// ollamaBase = "http://apipark-ollama:11434"
//)
func newRelease(item *ai_balance.Balance) *gateway.DynamicRelease {
func newRelease(item *ai_balance.Balance, base string) *gateway.DynamicRelease {
cfg := make(map[string]interface{})
cfg["provider"] = item.Id
cfg["provider"] = item.Provider
cfg["model"] = item.Model
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["base"] = ai_provider_local.OllamaBase
cfg["base"] = base
cfg["priority"] = item.Priority
return &gateway.DynamicRelease{
BasicItem: &gateway.BasicItem{
ID: item.Id,
@@ -133,9 +148,22 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
if err != nil {
return err
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, item := range list {
releases = append(releases, newRelease(item))
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
if err != nil {
@@ -229,3 +257,73 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re
return nil
}
func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc")
if err != nil {
return nil, err
}
if v == "" {
var has bool
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, fmt.Errorf("ollama address not found")
}
}
releases := make([]*gateway.DynamicRelease, 0, len(balances))
for _, item := range balances {
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
return releases, nil
}
func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) {
balances, err := i.balanceService.Search(ctx, "", nil, "priority asc")
if err != nil {
return nil, err
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, fmt.Errorf("ollama address not found")
}
releases := make([]*gateway.DynamicRelease, 0, len(balances))
for _, item := range balances {
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
return releases, nil
}
func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
releases, err := i.getBalances(ctx)
if err != nil {
return err
}
for _, p := range releases {
client, err := clientDriver.Dynamic(p.Resource)
if err != nil {
return err
}
err = client.Online(ctx, p)
if err != nil {
return err
}
}
return nil
}
+6 -1
View File
@@ -4,6 +4,8 @@ import (
"context"
"reflect"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/go-common/autowire"
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
@@ -14,10 +16,13 @@ type IBalanceModule interface {
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
Delete(ctx context.Context, id string) error
SyncLocalBalances(ctx context.Context, address string) error
}
func init() {
balanceModule := new(imlBalanceModule)
autowire.Auto[IBalanceModule](func() reflect.Value {
return reflect.ValueOf(new(imlBalanceModule))
gateway.RegisterInitHandleFunc(balanceModule.initGateway)
return reflect.ValueOf(balanceModule)
})
}
+4
View File
@@ -58,6 +58,10 @@ func FromLocalModelState(state int) LocalModelState {
}
}
type OllamaConfig struct {
Address string `json:"address"`
}
type SimpleItem struct {
Id string `json:"id"`
Name string `json:"name"`
+113 -50
View File
@@ -4,12 +4,13 @@ import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"github.com/APIParkLab/APIPark/service/api"
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
"github.com/eolinker/eosc/env"
"github.com/APIParkLab/APIPark/service/setting"
"github.com/APIParkLab/APIPark/service/api"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/eosc/log"
@@ -46,28 +47,21 @@ type imlLocalModel struct {
localModelPackageService ai_local.ILocalModelPackageService `autowired:""`
localModelStateService ai_local.ILocalModelInstallStateService `autowired:""`
localModelCacheService ai_local.ILocalModelCacheService `autowired:""`
balanceService ai_balance.IBalanceService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
routerService api.IAPIService `autowired:""`
serviceService service.IServiceService `autowired:""`
settingService setting.ISettingService `autowired:""`
transaction store.ITransaction `autowired:""`
}
var (
// ollamaConfig = "{\n \"mirostat\": 0,\n \"mirostat_eta\": 0.1,\n \"mirostat_tau\": 5.0,\n \"num_ctx\": 4096,\n \"repeat_last_n\":64,\n \"repeat_penalty\": 1.1,\n \"temperature\": 0.7,\n \"seed\": 42,\n \"num_predict\": 42,\n \"top_k\": 40,\n \"top_p\": 0.9,\n \"min_p\": 0.5\n}\n"
ollamaBase = "http://apipark-ollama:11434"
)
func init() {
base, has := env.GetEnv("OLLAMA_BASE")
if !has {
return
func (i *imlLocalModel) SyncLocalModels(ctx context.Context, address string) error {
releases, err := i.getLocalModels(ctx, address)
if err != nil {
return err
}
_, err := url.Parse(base)
if err == nil {
ollamaBase = base
}
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
}
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
@@ -122,7 +116,7 @@ func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local
Name: s.Name,
State: ai_local_dto.FromLocalModelState(s.State),
APICount: count,
CanDelete: count < 1,
CanDelete: count < 1 && s.State != ai_local_dto.LocalModelStateDeploying.Int(),
UpdateTime: auto.TimeLabel(s.UpdateAt),
Provider: "ollama",
}
@@ -194,7 +188,14 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.
State: state,
Msg: msg.Msg,
})
if err != nil {
return err
}
info, err = i.localModelStateService.Get(ctx, msg.Model)
if err != nil {
return err
}
} else {
if info.Complete < msg.Completed {
info.Complete = msg.Completed
@@ -210,34 +211,31 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.
if err != nil {
return err
}
serviceState := 0
if msg.Status == "error" {
state = 2
}
serviceState := 0
if msg.Status == "error" {
state = 2
}
list, err := i.localModelCacheService.List(ctx, msg.Model, ai_local.CacheTypeService)
if err != nil {
return err
}
for _, l := range list {
serviceInfo, err := i.serviceService.Get(ctx, l.Target)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
continue
}
return err
}
list, err := i.localModelCacheService.List(ctx, msg.Model, ai_local.CacheTypeService)
if serviceInfo.State == serviceState {
continue
}
err = i.serviceService.Save(ctx, l.Target, &service.Edit{State: &serviceState})
if err != nil {
return err
}
for _, l := range list {
_, err := i.serviceService.Get(ctx, l.Target)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
continue
}
return err
}
if info.State == 0 {
continue
}
err = i.serviceService.Save(ctx, l.Target, &service.Edit{State: &serviceState})
if err != nil {
return err
}
}
}
if err != nil {
return err
}
if state == ai_local_dto.DeployStateFinish.Int() {
for _, f := range fn {
@@ -246,12 +244,14 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.
return err
}
}
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = msg.Model
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["priority"] = 0
cfg["base"] = ollamaBase
cfg["base"] = v
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
@@ -414,6 +414,16 @@ func (i *imlLocalModel) RemoveModel(ctx context.Context, model string) error {
if count > 0 {
return fmt.Errorf("model %s has api, can not remove", model)
}
info, err := i.localModelService.Get(ctx, model)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
return ai_provider_local.RemoveModel(model)
}
if info.State == ai_local_dto.LocalModelStateDeploying.Int() {
return fmt.Errorf("model %s is deploying, can not remove", model)
}
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
err = i.localModelService.Delete(ctx, model)
if err != nil {
@@ -430,8 +440,36 @@ func (i *imlLocalModel) Enable(ctx context.Context, model string) error {
return err
}
if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() {
status := ai_local_dto.LocalModelStateNormal.Int()
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
status := ai_local_dto.LocalModelStateNormal.Int()
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
if err != nil {
return err
}
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = info.Id
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["priority"] = 0
cfg["base"] = v
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Description: info.Id,
Resource: "ai-provider",
Version: info.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
}}, true)
})
}
return fmt.Errorf("model %s is not disabled state,can not enable", model)
}
@@ -443,7 +481,21 @@ func (i *imlLocalModel) Disable(ctx context.Context, model string) error {
}
if info.State == ai_local_dto.LocalModelStateNormal.Int() {
disable := ai_local_dto.LocalModelStateDisable.Int()
return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Resource: "ai-provider",
},
}}, false)
})
}
return fmt.Errorf("model %s is not enabled state,can not disable", model)
}
@@ -523,18 +575,29 @@ func (i *imlLocalModel) OnInit() {
})
}
func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicRelease, error) {
func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
return nil, err
}
if v == "" {
var has bool
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, errors.New("ollama_address not set")
}
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, l := range list {
if l.State != ai_local_dto.LocalModelStateNormal.Int() {
continue
}
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = l.Id
cfg["model_config"] = ai_provider_local.OllamaSvg
cfg["base"] = ollamaBase
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["base"] = v
releases = append(releases, &gateway.DynamicRelease{
BasicItem: &gateway.BasicItem{
ID: l.Id,
@@ -552,7 +615,7 @@ func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicR
}
func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
releases, err := i.getLocalModels(ctx)
releases, err := i.getLocalModels(ctx, "")
if err != nil {
return err
}
+2
View File
@@ -24,6 +24,8 @@ type ILocalModelModule interface {
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
SaveCache(ctx context.Context, model string, target string) error
SyncLocalModels(ctx context.Context, address string) error
}
func init() {
+8 -8
View File
@@ -482,7 +482,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
return fmt.Errorf("ai provider not found")
}
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
info, err := i.providerService.Get(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
@@ -533,12 +533,12 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Config: &input.Config,
Status: &status,
}
_, err = i.aiKeyService.DefaultKey(txCtx, id)
_, err = i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
err = i.aiKeyService.Create(txCtx, &ai_key.Create{
err = i.aiKeyService.Create(ctx, &ai_key.Create{
ID: id,
Name: info.Name,
Config: input.Config,
@@ -549,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Priority: 1,
})
} else {
err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{
err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{
Config: &input.Config,
Status: &status,
})
@@ -557,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
if err != nil {
return err
}
err = i.providerService.Save(txCtx, id, pInfo)
err = i.providerService.Save(ctx, id, pInfo)
if err != nil {
return err
}
if *pInfo.Status == 0 {
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
@@ -573,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
}, false)
}
// 获取当前供应商默认Key信息
defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id)
defaultKey, err := i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
return err
}
@@ -582,7 +582,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
cfg["model"] = info.DefaultLLM
cfg["model_config"] = model.DefaultConfig()
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
+2
View File
@@ -20,6 +20,7 @@ type CreateService struct {
Kind string `json:"service_kind"`
State string `json:"state"`
Provider *string `json:"provider"`
Model *string `json:"model"`
AsApp *bool `json:"as_app"`
AsServer *bool `json:"as_server"`
}
@@ -32,6 +33,7 @@ type EditService struct {
Logo *string `json:"logo"`
Tags *[]string `json:"tags"`
Provider *string `json:"provider"`
Model *string `json:"model"`
ApprovalType *string `json:"approval_type"`
State *string `json:"state"`
}
+6 -1
View File
@@ -97,7 +97,8 @@ type Service struct {
Tags []auto.Label `json:"tags" aolabel:"tag"`
Logo string `json:"logo"`
Provider *auto.Label `json:"provider,omitempty" aolabel:"ai_provider"`
ProviderType string `json:"provider_type"`
ProviderType string `json:"provider_type,omitempty"`
Model string `json:"model,omitempty"`
ApprovalType string `json:"approval_type"`
AsServer bool `json:"as_server"`
AsApp bool `json:"as_app"`
@@ -152,6 +153,10 @@ func ToService(model *service.Service) *Service {
if provider.Id != "ollama" {
s.ProviderType = "online"
}
modelId := model.AdditionalConfig["model"]
if modelId != "" {
s.Model = modelId
}
}
return s
}
+32
View File
@@ -8,6 +8,10 @@ import (
"strings"
"time"
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
"github.com/eolinker/eosc/log"
"github.com/APIParkLab/APIPark/resources/access"
@@ -58,6 +62,7 @@ type imlServiceModule struct {
teamService team.ITeamService `autowired:""`
teamMemberService team_member.ITeamMemberService `autowired:""`
tagService tag.ITagService `autowired:""`
localModelService ai_local.ILocalModelService `autowired:""`
serviceTagService service_tag.ITagService `autowired:""`
apiService api.IAPIService `autowired:""`
@@ -223,6 +228,25 @@ func (i *imlServiceModule) Get(ctx context.Context, id string) (*service_dto.Ser
s.Tags = auto.List(utils.SliceToSlice(tags, func(p *service_tag.Tag) string {
return p.Tid
}))
if s.Model == "" {
switch s.ProviderType {
case "online":
p, has := model_runtime.GetProvider(s.Provider.Id)
if has {
m, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if has {
s.Model = m.ID()
}
}
case "local":
info, err := i.localModelService.DefaultModel(ctx)
if err != nil {
return nil, err
}
s.Model = info.Id
}
}
log.Infof("get service cost %d ms", time.Since(now).Milliseconds())
return s, nil
}
@@ -328,6 +352,11 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser
return nil, fmt.Errorf("ai service: provider can not be empty")
}
mo.AdditionalConfig["provider"] = *input.Provider
if input.Model == nil {
return nil, fmt.Errorf("ai service: model can not be empty")
}
mo.AdditionalConfig["model"] = *input.Model
}
if input.AsApp == nil {
// 默认值为false
@@ -378,6 +407,9 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d
if input.Provider != nil {
info.AdditionalConfig["provider"] = *input.Provider
}
if input.Model != nil {
info.AdditionalConfig["model"] = *input.Model
}
}
err = i.transaction.Transaction(ctx, func(ctx context.Context) error {
+26 -7
View File
@@ -6,14 +6,24 @@ import (
)
type InputSetting struct {
InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"`
SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"`
InvokeAddress *string `json:"invoke_address" key:"system.node.invoke_address"`
SitePrefix *string `json:"site_prefix" key:"system.setting.site_prefix"`
OllamaAddress *string `json:"ollama_address" key:"system.ai_model.ollama_address"`
}
func (i *InputSetting) Validate() error {
_, err := url.Parse(i.InvokeAddress)
if err != nil {
return err
if i.InvokeAddress != nil {
_, err := url.Parse(*i.InvokeAddress)
if err != nil {
return err
}
}
if i.OllamaAddress != nil {
_, err := url.Parse(*i.OllamaAddress)
if err != nil {
return err
}
}
return nil
}
@@ -31,9 +41,18 @@ func ToKeyMap(i interface{}) map[string]string {
{
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.Tag.Get("key") != "" {
result[f.Tag.Get("key")] = val.Field(i).String()
v := val.Field(i)
if f.Type.Kind() == reflect.Ptr {
if v.IsNil() {
continue
}
v = v.Elem()
}
if f.Tag.Get("key") != "" {
result[f.Tag.Get("key")] = v.String()
}
}
}
}
+4 -2
View File
@@ -6,9 +6,11 @@ import (
)
func TestMap(t *testing.T) {
invokeAddress := "http://127.0.0.1:8080"
ollamaAddress := "http://127.0.0.1:8081"
input := &InputSetting{
InvokeAddress: "http://127.0.0.1:8080",
InvokeAddress: &invokeAddress,
OllamaAddress: &ollamaAddress,
}
err := input.Validate()
if err != nil {
+1
View File
@@ -8,6 +8,7 @@ import (
type Setting struct {
InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"`
SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"`
OllamaAddress string `json:"ollama_address" key:"system.ai_model.ollama_address"`
}
func MapStringToStruct[T any](m map[string]string) *T {
+20
View File
@@ -3,6 +3,11 @@ package system
import (
"context"
"github.com/eolinker/go-common/server"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
"github.com/eolinker/go-common/register"
"github.com/eolinker/go-common/store"
"github.com/eolinker/go-common/utils"
@@ -43,6 +48,21 @@ func (i *imlSettingModule) Set(ctx context.Context, input *system_dto.InputSetti
return err
}
}
if input.OllamaAddress != nil {
ai_provider_local.ResetOllamaAddress(*input.OllamaAddress)
}
return nil
})
}
func (i *imlSettingModule) OnInit() {
register.Handle(func(v server.Server) {
ctx := context.Background()
address, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if has {
ai_provider_local.ResetOllamaAddress(address)
}
})
}
+3
View File
@@ -17,5 +17,8 @@ func (p *plugin) aiLocalApis() []pm3.Api {
pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/model/local/info", []string{"context", "query:model", "body"}, nil, p.aiLocalController.Update),
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/model/local/state", []string{"context", "query:model"}, []string{"state", "info"}, p.aiLocalController.State),
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/simple/ai/models/local/configured", []string{"context"}, []string{"models"}, p.aiLocalController.SimpleList),
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/model/local/source/ollama", []string{"context"}, []string{"config"}, p.aiLocalController.OllamaConfig),
pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/model/local/source/ollama", []string{"context", "body"}, nil, p.aiLocalController.OllamaConfigUpdate),
}
}
+8
View File
@@ -20,6 +20,14 @@ type imlLocalModelService struct {
universally.IServiceDelete
}
func (i *imlLocalModelService) DefaultModel(ctx context.Context) (*LocalModel, error) {
info, err := i.store.First(ctx, map[string]interface{}{"state": 1})
if err != nil {
return nil, err
}
return i.fromEntity(info), nil
}
func (i *imlLocalModelService) OnComplete() {
i.IServiceGet = universally.NewGet[LocalModel, ai.LocalModel](i.store, i.fromEntity)
i.IServiceCreate = universally.NewCreator[CreateLocalModel, ai.LocalModel](i.store, "ai_local_model", i.createEntityHandler, i.uniquestHandler, i.labelHandler)
+1
View File
@@ -13,6 +13,7 @@ type ILocalModelService interface {
universally.IServiceCreate[CreateLocalModel]
universally.IServiceEdit[EditLocalModel]
universally.IServiceDelete
DefaultModel(ctx context.Context) (*LocalModel, error)
}
type ILocalModelPackageService interface {