diff --git a/ai-provider/model-runtime/model.go b/ai-provider/model-runtime/model.go index 9ee7caa6..d96008a4 100644 --- a/ai-provider/model-runtime/model.go +++ b/ai-provider/model-runtime/model.go @@ -12,6 +12,7 @@ type IModel interface { Name() string Logo() string Source() string + ModelType() string SetLogo(logo string) AccessConfiguration() string ModelParameters() string @@ -26,6 +27,8 @@ type Model struct { modelParameters string // default: ""/"system", "customize" source string + // @see ModelTypeLLM etc + modelType string //defaultConfig string IConfig //validator IParamValidator @@ -52,6 +55,10 @@ func (m *Model) Source() string { return m.source } +func (m *Model) ModelType() string { + return m.modelType +} + func (m *Model) Logo() string { return m.logo } @@ -73,6 +80,7 @@ func NewCustomizeModel(id string, name string, logo string, accessConfiguration id: id, name: name, logo: logo, + modelType: ModelTypeLLM, source: "customize", accessConfiguration: accessConfiguration, modelParameters: modelParameters, @@ -154,6 +162,7 @@ func NewModel(data string, logo string) (IModel, error) { id: cfg.Model, name: cfg.Model, logo: logo, + modelType: cfg.ModelType, accessConfiguration: "", IConfig: NewConfig(string(dCfg), params), }, nil diff --git a/ai-provider/model-runtime/provider.go b/ai-provider/model-runtime/provider.go index dde582ea..9a36d562 100644 --- a/ai-provider/model-runtime/provider.go +++ b/ai-provider/model-runtime/provider.go @@ -3,8 +3,10 @@ package model_runtime import ( "encoding/json" "fmt" + "hash/fnv" "net/url" "strings" + "sync" yaml "gopkg.in/yaml.v3" @@ -19,8 +21,8 @@ const ( type IProvider interface { IProviderInfo GetModelConfig() ModelConfig - SetModelsByType(modelType string, models []IModel) SetModel(id string, model IModel) + RemoveModel(id string) SetDefaultModel(modelType string, model IModel) GetModel(name string) (IModel, bool) Models() []IModel @@ -66,17 +68,15 @@ func NewCustomizeProvider(id string, name string, models []IModel, defaultModel } provider := &Provider{ - id: id, - name: name, - logo: GetCustomizeLogo(), - helpUrl: "", - models: eosc.BuildUntyped[string, IModel](), - defaultModels: eosc.BuildUntyped[string, IModel](), - modelsByType: eosc.BuildUntyped[string, []IModel](), - maskKeys: make([]string, 0), - recommend: false, - sort: 0, - uri: uri, + id: id, + name: name, + logo: GetCustomizeLogo(), + helpUrl: "", + registry: newModelRegistry(), + maskKeys: make([]string, 0), + recommend: false, + sort: 0, + uri: uri, modelConfig: ModelConfig{ AccessConfigurationStatus: false, AccessConfigurationDemo: "", @@ -93,7 +93,6 @@ func NewCustomizeProvider(id string, name string, models []IModel, defaultModel provider.SetDefaultModel(name, model) } } - provider.SetModelsByType(ModelTypeLLM, models) return provider, nil } @@ -123,17 +122,15 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri return nil, fmt.Errorf("model logo not found:%s", providerCfg.Provider) } provider := &Provider{ - id: providerCfg.Provider, - name: providerCfg.Label[entity.LanguageEnglish], - logo: modelLogo, - helpUrl: providerCfg.Help.URL[entity.LanguageEnglish], - models: eosc.BuildUntyped[string, IModel](), - defaultModels: eosc.BuildUntyped[string, IModel](), - modelsByType: eosc.BuildUntyped[string, []IModel](), - maskKeys: make([]string, 0), - recommend: providerCfg.Recommend, - sort: providerCfg.Sort, - uri: uri, + id: providerCfg.Provider, + name: providerCfg.Label[entity.LanguageEnglish], + logo: modelLogo, + helpUrl: providerCfg.Help.URL[entity.LanguageEnglish], + registry: newModelRegistry(), + maskKeys: make([]string, 0), + recommend: providerCfg.Recommend, + sort: providerCfg.Sort, + uri: uri, modelConfig: ModelConfig{ AccessConfigurationStatus: providerCfg.ModelConfig.AccessConfigurationStatus, AccessConfigurationDemo: providerCfg.ModelConfig.AccessConfigurationDemo, @@ -159,7 +156,6 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri defaultCfgByte, _ := json.MarshalIndent(defaultCfg, "", " ") provider.IConfig = NewConfig(string(defaultCfgByte), params) for name, f := range modelContents { - models := make([]IModel, 0, f.Count()) defaultModel := providerCfg.Default[name] for i, v := range f.List() { model, err := NewModel(v, modelLogo) @@ -173,28 +169,55 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri if model.ID() == defaultModel { provider.SetDefaultModel(name, model) } - models = append(models, model) } - provider.SetModelsByType(name, models) } return provider, nil } type Provider struct { - id string - name string - logo string - helpUrl string + id string + name string + logo string + helpUrl string + registry *ModelRegistry + maskKeys []string + uri IProviderURI + sort int + recommend bool + modelConfig ModelConfig + mu sync.Mutex + IConfig +} + +type ModelRegistry struct { models eosc.Untyped[string, IModel] defaultModels eosc.Untyped[string, IModel] - modelsByType eosc.Untyped[string, []IModel] - maskKeys []string - uri IProviderURI - sort int - recommend bool - modelConfig ModelConfig - IConfig + + typeIndex map[string]*modelNode // type->header node + reverseMap map[string]*modelNode // ID->node + typeShard [8]sync.RWMutex // lock +} + +type modelNode struct { + model IModel + prev, next *modelNode + typeKey string +} + +func newModelRegistry() *ModelRegistry { + return &ModelRegistry{ + models: eosc.BuildUntyped[string, IModel](), + defaultModels: eosc.BuildUntyped[string, IModel](), + typeIndex: make(map[string]*modelNode), + reverseMap: make(map[string]*modelNode), + } +} + +func (r *ModelRegistry) getShard(key string) *sync.RWMutex { + h := fnv.New32a() + h.Write([]byte(key)) + return &r.typeShard[h.Sum32()%8] } type ModelConfig struct { @@ -234,20 +257,89 @@ func (p *Provider) Logo() string { return p.logo } +func (r *ModelRegistry) addModel(m IModel, isDefault bool) { + r.models.Set(m.ID(), m) + + // get lock + shard := r.getShard(m.ID()) + shard.Lock() + defer shard.Unlock() + + // create model node + node := &modelNode{ + model: m, + typeKey: m.ModelType(), + } + + // update index + if head := r.typeIndex[m.ModelType()]; head != nil { + node.next = head + head.prev = node + } + r.typeIndex[m.ModelType()] = node + r.reverseMap[m.ID()] = node + + // default model + if isDefault { + r.defaultModels.Set(m.ModelType(), m) + } +} + +func (r *ModelRegistry) removeModel(id string) { + r.models.Del(id) + + // check node exist + node, exist := r.reverseMap[id] + if !exist { + return + } + + // get lock + shard := r.getShard(node.typeKey) + shard.Lock() + defer shard.Unlock() + + // delete node chain + if node.prev != nil { + node.prev.next = node.next + } else { + r.typeIndex[node.typeKey] = node.next + } + if node.next != nil { + node.next.prev = node.prev + } + + // clean index + delete(r.reverseMap, id) + if r.typeIndex[node.typeKey] == nil { + delete(r.typeIndex, node.typeKey) + } +} + func (p *Provider) DefaultModel(modelType string) (IModel, bool) { - return p.defaultModels.Get(modelType) + return p.registry.defaultModels.Get(modelType) } func (p *Provider) GetModel(name string) (IModel, bool) { - return p.models.Get(name) + return p.registry.models.Get(name) } func (p *Provider) Models() []IModel { - return p.models.List() + return p.registry.models.List() } func (p *Provider) ModelsByType(modelType string) ([]IModel, bool) { - return p.modelsByType.Get(modelType) + shard := p.registry.getShard(modelType) + shard.RLock() + defer shard.RUnlock() + + var result []IModel + if node := p.registry.typeIndex[modelType]; node != nil { + for n := node; n != nil; n = n.next { + result = append(result, n.model) + } + } + return result, true } func (p *Provider) MaskConfig(cfg string) string { @@ -266,19 +358,15 @@ func (p *Provider) MaskConfig(cfg string) string { } func (p *Provider) SetDefaultModel(modelType string, model IModel) { - p.defaultModels.Set(modelType, model) + p.registry.addModel(model, true) } func (p *Provider) SetModel(id string, model IModel) { - p.models.Set(id, model) + p.registry.addModel(model, false) } func (p *Provider) RemoveModel(id string) { - p.models.Del(id) -} - -func (p *Provider) SetModelsByType(modelType string, models []IModel) { - p.modelsByType.Set(modelType, models) + p.registry.removeModel(id) } type providerUri struct { diff --git a/module/ai-model/iml.go b/module/ai-model/iml.go index 4a163f8a..5b3564c1 100644 --- a/module/ai-model/iml.go +++ b/module/ai-model/iml.go @@ -78,19 +78,25 @@ func (i *imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider } func (i *imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error { + p, has := model_runtime.GetProvider(provider) // check provider exist providerInfo, err := i.providerService.Get(ctx, provider) if err != nil { return err } - if providerInfo == nil { + if providerInfo == nil || !has { return fmt.Errorf("provider not found") } modelInfo, _ := i.providerModelService.Get(ctx, id) if modelInfo == nil || modelInfo.Provider != provider { return fmt.Errorf("model not found") } - return i.providerModelService.Delete(ctx, id) + if err := i.providerModelService.Delete(ctx, id); err != nil { + return err + } + p.RemoveModel(id) + + return nil } func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) { diff --git a/module/ai/iml.go b/module/ai/iml.go index 44a9179f..4bc359fb 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -200,12 +200,12 @@ func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewPr } id := uuid.New().String() config, defaultLLM := "{\"api_endpoint_url\": \"http://127.0.0.1\", \"api_key\": \"\"}", "" - typeValue := 1 - if err := i.providerService.Save(ctx, id, &ai.SetProvider{ - Name: &input.Name, - DefaultLLM: &defaultLLM, - Config: &config, - Type: &typeValue, + if err := i.providerService.Create(ctx, &ai.CreateProvider{ + Id: id, + Name: input.Name, + DefaultLLM: defaultLLM, + Config: config, + Type: 1, }); err != nil { return nil, err } @@ -213,8 +213,10 @@ func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewPr iProvider, _ := model_runtime.NewCustomizeProvider(id, input.Name, []model_runtime.IModel{}, "", "") model_runtime.Register(id, iProvider) return &ai_dto.SimpleProvider{ - Id: id, - Name: input.Name, + Id: id, + Name: input.Name, + DefaultConfig: config, + Logo: model_runtime.GetCustomizeLogo(), }, nil } diff --git a/service/ai-model/iml.go b/service/ai-model/iml.go index 7a594556..79d7faf7 100644 --- a/service/ai-model/iml.go +++ b/service/ai-model/iml.go @@ -70,8 +70,8 @@ func (i *imlProviderModelService) Save(ctx context.Context, id string, model *Mo func (i *imlProviderModelService) CheckNameDuplicate(ctx context.Context, provider string, name string, excludeId string) bool { v, _ := i.store.First(ctx, map[string]interface{}{"provider": provider, "name": name}) - if v != nil { - return true + if v == nil { + return false } else if excludeId != "" && v.UUID != excludeId { return true } diff --git a/service/ai/iml.go b/service/ai/iml.go index 7e62c2d3..d94f7ba0 100644 --- a/service/ai/iml.go +++ b/service/ai/iml.go @@ -72,10 +72,8 @@ type imlProviderService struct { func (i *imlProviderService) CheckNameDuplicate(ctx context.Context, name string) bool { v, _ := i.store.First(ctx, map[string]interface{}{"name": name}) - if v != nil { - return true - } - return false + + return v != nil } func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[string]string { diff --git a/service/ai/model.go b/service/ai/model.go index a10fe7a5..1beb7d08 100644 --- a/service/ai/model.go +++ b/service/ai/model.go @@ -27,6 +27,7 @@ type CreateProvider struct { DefaultLLM string Config string Status int + Type int } type SetProvider struct {