refactor: adjust model_runtime structure of provider model

This commit is contained in:
sunanzhi
2025-03-07 18:32:20 +08:00
parent 69fd1b915b
commit a8c842b8d0
7 changed files with 169 additions and 65 deletions
+9
View File
@@ -12,6 +12,7 @@ type IModel interface {
Name() string
Logo() string
Source() string
ModelType() string
SetLogo(logo string)
AccessConfiguration() string
ModelParameters() string
@@ -26,6 +27,8 @@ type Model struct {
modelParameters string
// default: ""/"system", "customize"
source string
// @see ModelTypeLLM etc
modelType string
//defaultConfig string
IConfig
//validator IParamValidator
@@ -52,6 +55,10 @@ func (m *Model) Source() string {
return m.source
}
func (m *Model) ModelType() string {
return m.modelType
}
func (m *Model) Logo() string {
return m.logo
}
@@ -73,6 +80,7 @@ func NewCustomizeModel(id string, name string, logo string, accessConfiguration
id: id,
name: name,
logo: logo,
modelType: ModelTypeLLM,
source: "customize",
accessConfiguration: accessConfiguration,
modelParameters: modelParameters,
@@ -154,6 +162,7 @@ func NewModel(data string, logo string) (IModel, error) {
id: cfg.Model,
name: cfg.Model,
logo: logo,
modelType: cfg.ModelType,
accessConfiguration: "",
IConfig: NewConfig(string(dCfg), params),
}, nil
+137 -49
View File
@@ -3,8 +3,10 @@ package model_runtime
import (
"encoding/json"
"fmt"
"hash/fnv"
"net/url"
"strings"
"sync"
yaml "gopkg.in/yaml.v3"
@@ -19,8 +21,8 @@ const (
type IProvider interface {
IProviderInfo
GetModelConfig() ModelConfig
SetModelsByType(modelType string, models []IModel)
SetModel(id string, model IModel)
RemoveModel(id string)
SetDefaultModel(modelType string, model IModel)
GetModel(name string) (IModel, bool)
Models() []IModel
@@ -66,17 +68,15 @@ func NewCustomizeProvider(id string, name string, models []IModel, defaultModel
}
provider := &Provider{
id: id,
name: name,
logo: GetCustomizeLogo(),
helpUrl: "",
models: eosc.BuildUntyped[string, IModel](),
defaultModels: eosc.BuildUntyped[string, IModel](),
modelsByType: eosc.BuildUntyped[string, []IModel](),
maskKeys: make([]string, 0),
recommend: false,
sort: 0,
uri: uri,
id: id,
name: name,
logo: GetCustomizeLogo(),
helpUrl: "",
registry: newModelRegistry(),
maskKeys: make([]string, 0),
recommend: false,
sort: 0,
uri: uri,
modelConfig: ModelConfig{
AccessConfigurationStatus: false,
AccessConfigurationDemo: "",
@@ -93,7 +93,6 @@ func NewCustomizeProvider(id string, name string, models []IModel, defaultModel
provider.SetDefaultModel(name, model)
}
}
provider.SetModelsByType(ModelTypeLLM, models)
return provider, nil
}
@@ -123,17 +122,15 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri
return nil, fmt.Errorf("model logo not found:%s", providerCfg.Provider)
}
provider := &Provider{
id: providerCfg.Provider,
name: providerCfg.Label[entity.LanguageEnglish],
logo: modelLogo,
helpUrl: providerCfg.Help.URL[entity.LanguageEnglish],
models: eosc.BuildUntyped[string, IModel](),
defaultModels: eosc.BuildUntyped[string, IModel](),
modelsByType: eosc.BuildUntyped[string, []IModel](),
maskKeys: make([]string, 0),
recommend: providerCfg.Recommend,
sort: providerCfg.Sort,
uri: uri,
id: providerCfg.Provider,
name: providerCfg.Label[entity.LanguageEnglish],
logo: modelLogo,
helpUrl: providerCfg.Help.URL[entity.LanguageEnglish],
registry: newModelRegistry(),
maskKeys: make([]string, 0),
recommend: providerCfg.Recommend,
sort: providerCfg.Sort,
uri: uri,
modelConfig: ModelConfig{
AccessConfigurationStatus: providerCfg.ModelConfig.AccessConfigurationStatus,
AccessConfigurationDemo: providerCfg.ModelConfig.AccessConfigurationDemo,
@@ -159,7 +156,6 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri
defaultCfgByte, _ := json.MarshalIndent(defaultCfg, "", " ")
provider.IConfig = NewConfig(string(defaultCfgByte), params)
for name, f := range modelContents {
models := make([]IModel, 0, f.Count())
defaultModel := providerCfg.Default[name]
for i, v := range f.List() {
model, err := NewModel(v, modelLogo)
@@ -173,28 +169,55 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri
if model.ID() == defaultModel {
provider.SetDefaultModel(name, model)
}
models = append(models, model)
}
provider.SetModelsByType(name, models)
}
return provider, nil
}
type Provider struct {
id string
name string
logo string
helpUrl string
id string
name string
logo string
helpUrl string
registry *ModelRegistry
maskKeys []string
uri IProviderURI
sort int
recommend bool
modelConfig ModelConfig
mu sync.Mutex
IConfig
}
type ModelRegistry struct {
models eosc.Untyped[string, IModel]
defaultModels eosc.Untyped[string, IModel]
modelsByType eosc.Untyped[string, []IModel]
maskKeys []string
uri IProviderURI
sort int
recommend bool
modelConfig ModelConfig
IConfig
typeIndex map[string]*modelNode // type->header node
reverseMap map[string]*modelNode // ID->node
typeShard [8]sync.RWMutex // lock
}
type modelNode struct {
model IModel
prev, next *modelNode
typeKey string
}
func newModelRegistry() *ModelRegistry {
return &ModelRegistry{
models: eosc.BuildUntyped[string, IModel](),
defaultModels: eosc.BuildUntyped[string, IModel](),
typeIndex: make(map[string]*modelNode),
reverseMap: make(map[string]*modelNode),
}
}
func (r *ModelRegistry) getShard(key string) *sync.RWMutex {
h := fnv.New32a()
h.Write([]byte(key))
return &r.typeShard[h.Sum32()%8]
}
type ModelConfig struct {
@@ -234,20 +257,89 @@ func (p *Provider) Logo() string {
return p.logo
}
func (r *ModelRegistry) addModel(m IModel, isDefault bool) {
r.models.Set(m.ID(), m)
// get lock
shard := r.getShard(m.ID())
shard.Lock()
defer shard.Unlock()
// create model node
node := &modelNode{
model: m,
typeKey: m.ModelType(),
}
// update index
if head := r.typeIndex[m.ModelType()]; head != nil {
node.next = head
head.prev = node
}
r.typeIndex[m.ModelType()] = node
r.reverseMap[m.ID()] = node
// default model
if isDefault {
r.defaultModels.Set(m.ModelType(), m)
}
}
func (r *ModelRegistry) removeModel(id string) {
r.models.Del(id)
// check node exist
node, exist := r.reverseMap[id]
if !exist {
return
}
// get lock
shard := r.getShard(node.typeKey)
shard.Lock()
defer shard.Unlock()
// delete node chain
if node.prev != nil {
node.prev.next = node.next
} else {
r.typeIndex[node.typeKey] = node.next
}
if node.next != nil {
node.next.prev = node.prev
}
// clean index
delete(r.reverseMap, id)
if r.typeIndex[node.typeKey] == nil {
delete(r.typeIndex, node.typeKey)
}
}
func (p *Provider) DefaultModel(modelType string) (IModel, bool) {
return p.defaultModels.Get(modelType)
return p.registry.defaultModels.Get(modelType)
}
func (p *Provider) GetModel(name string) (IModel, bool) {
return p.models.Get(name)
return p.registry.models.Get(name)
}
func (p *Provider) Models() []IModel {
return p.models.List()
return p.registry.models.List()
}
func (p *Provider) ModelsByType(modelType string) ([]IModel, bool) {
return p.modelsByType.Get(modelType)
shard := p.registry.getShard(modelType)
shard.RLock()
defer shard.RUnlock()
var result []IModel
if node := p.registry.typeIndex[modelType]; node != nil {
for n := node; n != nil; n = n.next {
result = append(result, n.model)
}
}
return result, true
}
func (p *Provider) MaskConfig(cfg string) string {
@@ -266,19 +358,15 @@ func (p *Provider) MaskConfig(cfg string) string {
}
func (p *Provider) SetDefaultModel(modelType string, model IModel) {
p.defaultModels.Set(modelType, model)
p.registry.addModel(model, true)
}
func (p *Provider) SetModel(id string, model IModel) {
p.models.Set(id, model)
p.registry.addModel(model, false)
}
func (p *Provider) RemoveModel(id string) {
p.models.Del(id)
}
func (p *Provider) SetModelsByType(modelType string, models []IModel) {
p.modelsByType.Set(modelType, models)
p.registry.removeModel(id)
}
type providerUri struct {
+8 -2
View File
@@ -78,19 +78,25 @@ func (i *imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider
}
func (i *imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error {
p, has := model_runtime.GetProvider(provider)
// check provider exist
providerInfo, err := i.providerService.Get(ctx, provider)
if err != nil {
return err
}
if providerInfo == nil {
if providerInfo == nil || !has {
return fmt.Errorf("provider not found")
}
modelInfo, _ := i.providerModelService.Get(ctx, id)
if modelInfo == nil || modelInfo.Provider != provider {
return fmt.Errorf("model not found")
}
return i.providerModelService.Delete(ctx, id)
if err := i.providerModelService.Delete(ctx, id); err != nil {
return err
}
p.RemoveModel(id)
return nil
}
func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) {
+10 -8
View File
@@ -200,12 +200,12 @@ func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewPr
}
id := uuid.New().String()
config, defaultLLM := "{\"api_endpoint_url\": \"http://127.0.0.1\", \"api_key\": \"\"}", ""
typeValue := 1
if err := i.providerService.Save(ctx, id, &ai.SetProvider{
Name: &input.Name,
DefaultLLM: &defaultLLM,
Config: &config,
Type: &typeValue,
if err := i.providerService.Create(ctx, &ai.CreateProvider{
Id: id,
Name: input.Name,
DefaultLLM: defaultLLM,
Config: config,
Type: 1,
}); err != nil {
return nil, err
}
@@ -213,8 +213,10 @@ func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewPr
iProvider, _ := model_runtime.NewCustomizeProvider(id, input.Name, []model_runtime.IModel{}, "", "")
model_runtime.Register(id, iProvider)
return &ai_dto.SimpleProvider{
Id: id,
Name: input.Name,
Id: id,
Name: input.Name,
DefaultConfig: config,
Logo: model_runtime.GetCustomizeLogo(),
}, nil
}
+2 -2
View File
@@ -70,8 +70,8 @@ func (i *imlProviderModelService) Save(ctx context.Context, id string, model *Mo
func (i *imlProviderModelService) CheckNameDuplicate(ctx context.Context, provider string, name string, excludeId string) bool {
v, _ := i.store.First(ctx, map[string]interface{}{"provider": provider, "name": name})
if v != nil {
return true
if v == nil {
return false
} else if excludeId != "" && v.UUID != excludeId {
return true
}
+2 -4
View File
@@ -72,10 +72,8 @@ type imlProviderService struct {
func (i *imlProviderService) CheckNameDuplicate(ctx context.Context, name string) bool {
v, _ := i.store.First(ctx, map[string]interface{}{"name": name})
if v != nil {
return true
}
return false
return v != nil
}
func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[string]string {
+1
View File
@@ -27,6 +27,7 @@ type CreateProvider struct {
DefaultLLM string
Config string
Status int
Type int
}
type SetProvider struct {