Files
dify-plus/api/services/model_provider_service_extend.py
T
2025-03-28 15:18:33 +08:00

141 lines
5.2 KiB
Python

import logging
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from models.tenant_model_sync_extend import *
# Module-level logger named after this module.
# NOTE(review): not referenced by any code visible in this file; remove if truly unused.
logger = logging.getLogger(__name__)
class ModelProviderExtendService:
    """
    Model Provider Service (extension).

    Reads and writes provider/model custom credentials through the workspace's
    provider configurations (via ``ProviderManager``), and maintains
    ``TenantModelSyncExtend`` rows that record which tenants a model has been
    synced to.
    """

    def __init__(self) -> None:
        self.provider_manager = ProviderManager()

    def _get_provider_configuration(self, tenant_id: str, provider: str):
        """
        Return the configuration of ``provider`` in workspace ``tenant_id``.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: the provider configuration object
        :raises ValueError: if the provider has no configuration in the workspace
        """
        # All provider configurations of the given workspace, keyed by provider name.
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        return provider_configuration

    def get_model_credentials_obfuscated(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
        """
        Get the custom credentials of a model.

        NOTE: despite the method name, credentials are requested with
        ``obfuscated=False``, i.e. they are NOT obfuscated.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type, converted with ``ModelType.value_of``
        :param model: model name
        :return: the model's custom credentials
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        # Get model custom credentials from ProviderModel if exists
        return provider_configuration.get_custom_model_credentials(
            model_type=ModelType.value_of(model_type), model=model, obfuscated=False
        )

    @staticmethod
    def create_tenant_model_sync_if_not_exist(
        tenant_id: str, model_id, origin_model_id: str, is_all: bool = False
    ) -> bool:
        """
        Insert a ``TenantModelSyncExtend`` row unless an identical one exists.

        :param tenant_id: workspace id the model was synced to
        :param model_id: model id stored on the sync row
        :param origin_model_id: id of the origin model
        :param is_all: value stored in the row's ``is_all`` column
        :return: True if a new row was added and committed, False if a row with the
            same (tenant_id, model_id, origin_model_id) already existed
        """
        # Only existence matters here, so no ordering is needed.
        # NOTE(review): check-then-insert is not atomic; concurrent callers could both
        # insert unless a unique constraint exists on these columns — confirm.
        existing = (
            db.session.query(TenantModelSyncExtend)
            .filter_by(tenant_id=tenant_id, model_id=model_id, origin_model_id=origin_model_id)
            .first()
        )
        if existing:
            return False
        db.session.add(
            TenantModelSyncExtend(
                tenant_id=tenant_id, model_id=model_id, origin_model_id=origin_model_id, is_all=is_all
            )
        )
        db.session.commit()
        return True

    def save_model_credentials_without_validate(
        self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
    ) -> str:
        """
        Add or update a model's custom credentials.

        Delegates to ``add_or_update_custom_model_credentials_without_validate_extend``,
        which presumably skips credential validation (per its name) — confirm.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type, converted with ``ModelType.value_of``
        :param model: model name
        :param credentials: model credentials
        :return: the value returned by the provider configuration's save method
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        return provider_configuration.add_or_update_custom_model_credentials_without_validate_extend(
            model_type=ModelType.value_of(model_type), model=model, credentials=credentials
        )

    def save_provider_credentials_without_validate(self, tenant_id: str, provider: str, credentials: dict) -> str:
        """
        Add or update a provider's custom credentials.

        Delegates to ``add_or_update_custom_credentials_without_validate_extend``,
        which presumably skips credential validation (per its name) — confirm.

        :param tenant_id: workspace id
        :param provider: provider name
        :param credentials: provider credentials
        :return: the value returned by the provider configuration's save method
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        return provider_configuration.add_or_update_custom_credentials_without_validate_extend(credentials)

    def get_provider_credentials_obfuscated(self, tenant_id: str, provider: str) -> dict:
        """
        Get a provider's custom credentials.

        NOTE: despite the method name, credentials are requested with
        ``obfuscated=False``, i.e. they are NOT obfuscated.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: the provider's custom credentials
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        return provider_configuration.get_custom_credentials(obfuscated=False)

    @staticmethod
    def get_current_syned_tenants(origin_model_id: str) -> "list[TenantModelSyncExtend]":
        """
        Return every ``TenantModelSyncExtend`` row for ``origin_model_id``.

        :param origin_model_id: id of the origin model
        :return: list of sync rows (possibly empty)
        """
        return (
            db.session.query(TenantModelSyncExtend)
            .filter(TenantModelSyncExtend.origin_model_id == origin_model_id)
            .all()
        )

    @staticmethod
    def delete_syned_tenants(origin_model_id, tenant_id: str) -> bool:
        """
        Delete the sync row for (``origin_model_id``, ``tenant_id``).

        Only the first matching row is deleted.

        :param origin_model_id: id of the origin model
        :param tenant_id: workspace id
        :return: True if a row was found, deleted and committed; False if none matched
        """
        synced_tenant = (
            db.session.query(TenantModelSyncExtend)
            .filter(
                TenantModelSyncExtend.origin_model_id == origin_model_id,
                TenantModelSyncExtend.tenant_id == tenant_id,
            )
            .first()
        )
        if synced_tenant is None:
            # Nothing to delete; session.delete(None) would raise instead.
            return False
        db.session.delete(synced_tenant)
        db.session.commit()
        return True