mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-12 18:11:42 +08:00
141 lines
5.2 KiB
Python
141 lines
5.2 KiB
Python
import logging
|
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
from core.provider_manager import ProviderManager
|
|
from extensions.ext_database import db
|
|
from models.tenant_model_sync_extend import *
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelProviderExtendService:
|
|
"""
|
|
Model Provider Service
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self.provider_manager = ProviderManager()
|
|
|
|
def get_model_credentials_obfuscated(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
|
|
"""
|
|
get model credentials.
|
|
|
|
:param tenant_id: workspace id
|
|
:param provider: provider name
|
|
:param model_type: model type
|
|
:param model: model name
|
|
:return:
|
|
"""
|
|
# Get all provider configurations of the current workspace
|
|
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
|
|
|
# Get provider configuration
|
|
provider_configuration = provider_configurations.get(provider)
|
|
if not provider_configuration:
|
|
raise ValueError(f"Provider {provider} does not exist.")
|
|
|
|
# Get model custom credentials from ProviderModel if exists
|
|
return provider_configuration.get_custom_model_credentials(
|
|
model_type=ModelType.value_of(model_type), model=model, obfuscated=False
|
|
)
|
|
|
|
@staticmethod
|
|
def create_tenant_model_sync_if_not_exist(
|
|
tenant_id: str, model_id, origin_model_id: str, is_all: bool = False
|
|
) -> bool:
|
|
available_ta = (
|
|
TenantModelSyncExtend.query.filter_by(
|
|
tenant_id=tenant_id, model_id=model_id, origin_model_id=origin_model_id
|
|
)
|
|
.order_by(TenantModelSyncExtend.id.asc())
|
|
.first()
|
|
)
|
|
|
|
if available_ta:
|
|
return False
|
|
|
|
ta = TenantModelSyncExtend(
|
|
tenant_id=tenant_id, model_id=model_id, origin_model_id=origin_model_id, is_all=is_all
|
|
)
|
|
db.session.add(ta)
|
|
db.session.commit()
|
|
return True
|
|
|
|
def save_model_credentials_without_validate(
|
|
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
|
) -> str:
|
|
"""
|
|
save model credentials.
|
|
|
|
:param tenant_id: workspace id
|
|
:param provider: provider name
|
|
:param model_type: model type
|
|
:param model: model name
|
|
:param credentials: model credentials
|
|
:return:
|
|
"""
|
|
# Get all provider configurations of the current workspace
|
|
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
|
|
|
# Get provider configuration
|
|
provider_configuration = provider_configurations.get(provider)
|
|
if not provider_configuration:
|
|
raise ValueError(f"Provider {provider} does not exist.")
|
|
# Add or update custom model credentials
|
|
return provider_configuration.add_or_update_custom_model_credentials_without_validate_extend(
|
|
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
|
)
|
|
|
|
def save_provider_credentials_without_validate(self, tenant_id: str, provider: str, credentials: dict) -> str:
|
|
"""
|
|
save custom provider config.
|
|
|
|
:param tenant_id: workspace id
|
|
:param provider: provider name
|
|
:param credentials: provider credentials
|
|
:return:
|
|
"""
|
|
# Get all provider configurations of the current workspace
|
|
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
|
|
|
# Get provider configuration
|
|
provider_configuration = provider_configurations.get(provider)
|
|
if not provider_configuration:
|
|
raise ValueError(f"Provider {provider} does not exist.")
|
|
|
|
# Add or update custom provider credentials.
|
|
return provider_configuration.add_or_update_custom_credentials_without_validate_extend(credentials)
|
|
|
|
def get_provider_credentials_obfuscated(self, tenant_id: str, provider: str) -> dict:
|
|
"""
|
|
get provider credentials.
|
|
|
|
:param tenant_id:
|
|
:param provider:
|
|
:return:
|
|
"""
|
|
# Get all provider configurations of the current workspace
|
|
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
|
|
|
# Get provider configuration
|
|
provider_configuration = provider_configurations.get(provider)
|
|
if not provider_configuration:
|
|
raise ValueError(f"Provider {provider} does not exist.")
|
|
|
|
# Get provider custom credentials from workspace
|
|
return provider_configuration.get_custom_credentials(obfuscated=False)
|
|
|
|
@staticmethod
|
|
def get_current_syned_tenants(origin_model_id: str) -> list[TenantModelSyncExtend]:
|
|
return db.session.query(TenantModelSyncExtend).filter(TenantModelSyncExtend.origin_model_id == origin_model_id).all()
|
|
|
|
@staticmethod
|
|
def delete_syned_tenants(origin_model_id, tenant_id: str
|
|
) -> bool:
|
|
syned_tenant = db.session.query(TenantModelSyncExtend).filter(TenantModelSyncExtend.origin_model_id == origin_model_id, TenantModelSyncExtend.tenant_id == tenant_id).first()
|
|
|
|
db.session.delete(syned_tenant)
|
|
db.session.commit()
|
|
|
|
return True
|